mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-19 09:33:53 +00:00
Merge pull request #1147 from QwenLM/mingholy/feat/cli-sdk-stage-2
Custom tools support via SDK controlled MCP servers
This commit is contained in:
18
.github/workflows/release-sdk.yml
vendored
18
.github/workflows/release-sdk.yml
vendored
@@ -132,6 +132,24 @@ jobs:
|
||||
OPENAI_BASE_URL: '${{ secrets.OPENAI_BASE_URL }}'
|
||||
OPENAI_MODEL: '${{ secrets.OPENAI_MODEL }}'
|
||||
|
||||
- name: 'Build CLI for Integration Tests'
|
||||
if: |-
|
||||
${{ github.event.inputs.force_skip_tests != 'true' }}
|
||||
run: |
|
||||
npm run build
|
||||
npm run bundle
|
||||
|
||||
- name: 'Run SDK Integration Tests'
|
||||
if: |-
|
||||
${{ github.event.inputs.force_skip_tests != 'true' }}
|
||||
run: |
|
||||
npm run test:integration:sdk:sandbox:none
|
||||
npm run test:integration:sdk:sandbox:docker
|
||||
env:
|
||||
OPENAI_API_KEY: '${{ secrets.OPENAI_API_KEY }}'
|
||||
OPENAI_BASE_URL: '${{ secrets.OPENAI_BASE_URL }}'
|
||||
OPENAI_MODEL: '${{ secrets.OPENAI_MODEL }}'
|
||||
|
||||
- name: 'Configure Git User'
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
|
||||
@@ -532,7 +532,6 @@ describe('Configuration Options (E2E)', () => {
|
||||
cwd: testDir,
|
||||
authType: 'openai',
|
||||
debug: true,
|
||||
logLevel: 'debug',
|
||||
stderr: (msg: string) => {
|
||||
stderrMessages.push(msg);
|
||||
},
|
||||
|
||||
@@ -555,6 +555,15 @@ describe('Permission Control (E2E)', () => {
|
||||
...SHARED_TEST_OPTIONS,
|
||||
cwd: testDir,
|
||||
permissionMode: 'default',
|
||||
timeout: {
|
||||
/**
|
||||
* We use a short control request timeout and
|
||||
* wait till the time exceeded to test if
|
||||
* an immediate close() will raise an query close
|
||||
* error and no other uncaught timeout error
|
||||
*/
|
||||
controlRequest: 5000,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
@@ -563,7 +572,9 @@ describe('Permission Control (E2E)', () => {
|
||||
await expect(q.setPermissionMode('yolo')).rejects.toThrow(
|
||||
'Query is closed',
|
||||
);
|
||||
});
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 8000));
|
||||
}, 10_000);
|
||||
});
|
||||
|
||||
describe('canUseTool and setPermissionMode integration', () => {
|
||||
|
||||
465
integration-tests/sdk-typescript/sdk-mcp-server.test.ts
Normal file
465
integration-tests/sdk-typescript/sdk-mcp-server.test.ts
Normal file
@@ -0,0 +1,465 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen Team
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
/**
|
||||
* E2E tests for SDK-embedded MCP servers
|
||||
*
|
||||
* Tests that the SDK can create and manage MCP servers running in the SDK process
|
||||
* using the tool() and createSdkMcpServer() APIs.
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach } from 'vitest';
|
||||
import { z } from 'zod';
|
||||
import {
|
||||
query,
|
||||
tool,
|
||||
createSdkMcpServer,
|
||||
isSDKAssistantMessage,
|
||||
isSDKResultMessage,
|
||||
isSDKSystemMessage,
|
||||
type SDKMessage,
|
||||
type SDKSystemMessage,
|
||||
} from '@qwen-code/sdk-typescript';
|
||||
import {
|
||||
SDKTestHelper,
|
||||
extractText,
|
||||
findToolUseBlocks,
|
||||
createSharedTestOptions,
|
||||
} from './test-helper.js';
|
||||
|
||||
const SHARED_TEST_OPTIONS = {
|
||||
...createSharedTestOptions(),
|
||||
permissionMode: 'yolo' as const,
|
||||
};
|
||||
|
||||
describe('SDK MCP Server Integration (E2E)', () => {
|
||||
let helper: SDKTestHelper;
|
||||
let testDir: string;
|
||||
|
||||
beforeEach(async () => {
|
||||
helper = new SDKTestHelper();
|
||||
testDir = await helper.setup('sdk-mcp-server-integration');
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await helper.cleanup();
|
||||
});
|
||||
|
||||
describe('Basic SDK MCP Tool Usage', () => {
|
||||
it('should use SDK MCP tool to perform a simple calculation', async () => {
|
||||
// Define a simple calculator tool using the tool() API with Zod schema
|
||||
console.log(
|
||||
z.object({
|
||||
a: z.number().describe('First number'),
|
||||
b: z.number().describe('Second number'),
|
||||
}),
|
||||
);
|
||||
const calculatorTool = tool(
|
||||
'calculate_sum',
|
||||
'Calculate the sum of two numbers',
|
||||
z.object({
|
||||
a: z.number().describe('First number'),
|
||||
b: z.number().describe('Second number'),
|
||||
}).shape,
|
||||
async (args) => ({
|
||||
content: [{ type: 'text', text: String(args.a + args.b) }],
|
||||
}),
|
||||
);
|
||||
|
||||
// Create SDK MCP server with the tool
|
||||
const serverConfig = createSdkMcpServer({
|
||||
name: 'sdk-calculator',
|
||||
version: '1.0.0',
|
||||
tools: [calculatorTool],
|
||||
});
|
||||
|
||||
const q = query({
|
||||
prompt:
|
||||
'Use the calculate_sum tool to add 25 and 17. Output the result of tool only.',
|
||||
options: {
|
||||
...SHARED_TEST_OPTIONS,
|
||||
cwd: testDir,
|
||||
stderr: (message) => console.error(message),
|
||||
mcpServers: {
|
||||
'sdk-calculator': serverConfig,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const messages: SDKMessage[] = [];
|
||||
let assistantText = '';
|
||||
let foundToolUse = false;
|
||||
|
||||
try {
|
||||
for await (const message of q) {
|
||||
messages.push(message);
|
||||
console.log(JSON.stringify(message, null, 2));
|
||||
|
||||
if (isSDKAssistantMessage(message)) {
|
||||
const toolUseBlocks = findToolUseBlocks(message, 'calculate_sum');
|
||||
if (toolUseBlocks.length > 0) {
|
||||
foundToolUse = true;
|
||||
}
|
||||
assistantText += extractText(message.message.content);
|
||||
}
|
||||
}
|
||||
|
||||
// Validate tool was called
|
||||
expect(foundToolUse).toBe(true);
|
||||
|
||||
// Validate result contains expected answer: 25 + 17 = 42
|
||||
expect(assistantText).toMatch(/42/);
|
||||
|
||||
// Validate successful completion
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
expect(isSDKResultMessage(lastMessage)).toBe(true);
|
||||
if (isSDKResultMessage(lastMessage)) {
|
||||
expect(lastMessage.subtype).toBe('success');
|
||||
}
|
||||
} finally {
|
||||
await q.close();
|
||||
}
|
||||
});
|
||||
|
||||
it('should use SDK MCP tool with string operations', async () => {
|
||||
// Define a string manipulation tool with Zod schema
|
||||
const stringTool = tool(
|
||||
'reverse_string',
|
||||
'Reverse a string',
|
||||
{
|
||||
text: z.string().describe('The text to reverse'),
|
||||
},
|
||||
async (args) => ({
|
||||
content: [
|
||||
{ type: 'text', text: args.text.split('').reverse().join('') },
|
||||
],
|
||||
}),
|
||||
);
|
||||
|
||||
const serverConfig = createSdkMcpServer({
|
||||
name: 'sdk-string-utils',
|
||||
version: '1.0.0',
|
||||
tools: [stringTool],
|
||||
});
|
||||
|
||||
const q = query({
|
||||
prompt: `Use the 'reverse_string' tool to process the word "hello world". Output the tool result only.`,
|
||||
options: {
|
||||
...SHARED_TEST_OPTIONS,
|
||||
cwd: testDir,
|
||||
mcpServers: {
|
||||
'sdk-string-utils': serverConfig,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const messages: SDKMessage[] = [];
|
||||
let assistantText = '';
|
||||
let foundToolUse = false;
|
||||
|
||||
try {
|
||||
for await (const message of q) {
|
||||
messages.push(message);
|
||||
|
||||
if (isSDKAssistantMessage(message)) {
|
||||
const toolUseBlocks = findToolUseBlocks(message, 'reverse_string');
|
||||
if (toolUseBlocks.length > 0) {
|
||||
foundToolUse = true;
|
||||
}
|
||||
assistantText += extractText(message.message.content);
|
||||
}
|
||||
}
|
||||
console.log(JSON.stringify(messages, null, 2));
|
||||
|
||||
// Validate tool was called
|
||||
expect(foundToolUse).toBe(true);
|
||||
|
||||
// Validate result contains reversed string: "olleh"
|
||||
expect(assistantText.toLowerCase()).toMatch(/olleh/);
|
||||
|
||||
// Validate successful completion
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
expect(isSDKResultMessage(lastMessage)).toBe(true);
|
||||
} finally {
|
||||
await q.close();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('Multiple SDK MCP Tools', () => {
|
||||
it('should use multiple tools from the same SDK MCP server', async () => {
|
||||
// Define the Zod schema shape for two numbers
|
||||
const twoNumbersSchema = {
|
||||
a: z.number().describe('First number'),
|
||||
b: z.number().describe('Second number'),
|
||||
};
|
||||
|
||||
// Define multiple tools
|
||||
const addTool = tool(
|
||||
'sdk_add',
|
||||
'Add two numbers',
|
||||
twoNumbersSchema,
|
||||
async (args) => ({
|
||||
content: [{ type: 'text', text: String(args.a + args.b) }],
|
||||
}),
|
||||
);
|
||||
|
||||
const multiplyTool = tool(
|
||||
'sdk_multiply',
|
||||
'Multiply two numbers',
|
||||
twoNumbersSchema,
|
||||
async (args) => ({
|
||||
content: [{ type: 'text', text: String(args.a * args.b) }],
|
||||
}),
|
||||
);
|
||||
|
||||
const serverConfig = createSdkMcpServer({
|
||||
name: 'sdk-math',
|
||||
version: '1.0.0',
|
||||
tools: [addTool, multiplyTool],
|
||||
});
|
||||
|
||||
const q = query({
|
||||
prompt:
|
||||
'First use sdk_add to calculate 10 + 5, then use sdk_multiply to multiply the result by 3. Give me the final answer.',
|
||||
options: {
|
||||
...SHARED_TEST_OPTIONS,
|
||||
cwd: testDir,
|
||||
debug: false,
|
||||
mcpServers: {
|
||||
'sdk-math': serverConfig,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const messages: SDKMessage[] = [];
|
||||
let assistantText = '';
|
||||
const toolCalls: string[] = [];
|
||||
|
||||
try {
|
||||
for await (const message of q) {
|
||||
messages.push(message);
|
||||
|
||||
if (isSDKAssistantMessage(message)) {
|
||||
const toolUseBlocks = findToolUseBlocks(message);
|
||||
toolUseBlocks.forEach((block) => {
|
||||
toolCalls.push(block.name);
|
||||
});
|
||||
assistantText += extractText(message.message.content);
|
||||
}
|
||||
}
|
||||
|
||||
// Validate both tools were called
|
||||
expect(toolCalls).toContain('sdk_add');
|
||||
expect(toolCalls).toContain('sdk_multiply');
|
||||
|
||||
// Validate result: (10 + 5) * 3 = 45
|
||||
expect(assistantText).toMatch(/45/);
|
||||
|
||||
// Validate successful completion
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
expect(isSDKResultMessage(lastMessage)).toBe(true);
|
||||
} finally {
|
||||
await q.close();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('SDK MCP Server Discovery', () => {
|
||||
it('should list SDK MCP servers in system init message', async () => {
|
||||
// Define echo tool with Zod schema
|
||||
const echoTool = tool(
|
||||
'echo',
|
||||
'Echo a message',
|
||||
{
|
||||
message: z.string().describe('Message to echo'),
|
||||
},
|
||||
async (args) => ({
|
||||
content: [{ type: 'text', text: args.message }],
|
||||
}),
|
||||
);
|
||||
|
||||
const serverConfig = createSdkMcpServer({
|
||||
name: 'sdk-echo',
|
||||
version: '1.0.0',
|
||||
tools: [echoTool],
|
||||
});
|
||||
|
||||
const q = query({
|
||||
prompt: 'Hello',
|
||||
options: {
|
||||
...SHARED_TEST_OPTIONS,
|
||||
cwd: testDir,
|
||||
debug: false,
|
||||
mcpServers: {
|
||||
'sdk-echo': serverConfig,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
let systemMessage: SDKSystemMessage | null = null;
|
||||
|
||||
try {
|
||||
for await (const message of q) {
|
||||
if (isSDKSystemMessage(message) && message.subtype === 'init') {
|
||||
systemMessage = message;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Validate MCP server is listed
|
||||
expect(systemMessage).not.toBeNull();
|
||||
expect(systemMessage!.mcp_servers).toBeDefined();
|
||||
expect(Array.isArray(systemMessage!.mcp_servers)).toBe(true);
|
||||
|
||||
// Find our SDK MCP server
|
||||
const sdkServer = systemMessage!.mcp_servers?.find(
|
||||
(server) => server.name === 'sdk-echo',
|
||||
);
|
||||
expect(sdkServer).toBeDefined();
|
||||
} finally {
|
||||
await q.close();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('SDK MCP Tool Error Handling', () => {
|
||||
it('should handle tool errors gracefully', async () => {
|
||||
// Define a tool that throws an error with Zod schema
|
||||
const errorTool = tool(
|
||||
'maybe_fail',
|
||||
'A tool that may fail based on input',
|
||||
{
|
||||
shouldFail: z.boolean().describe('If true, the tool will fail'),
|
||||
},
|
||||
async (args) => {
|
||||
if (args.shouldFail) {
|
||||
throw new Error('Tool intentionally failed');
|
||||
}
|
||||
return { content: [{ type: 'text', text: 'Success!' }] };
|
||||
},
|
||||
);
|
||||
|
||||
const serverConfig = createSdkMcpServer({
|
||||
name: 'sdk-error-test',
|
||||
version: '1.0.0',
|
||||
tools: [errorTool],
|
||||
});
|
||||
|
||||
const q = query({
|
||||
prompt:
|
||||
'Use the maybe_fail tool with shouldFail set to true. Tell me what happens.',
|
||||
options: {
|
||||
...SHARED_TEST_OPTIONS,
|
||||
cwd: testDir,
|
||||
debug: false,
|
||||
mcpServers: {
|
||||
'sdk-error-test': serverConfig,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const messages: SDKMessage[] = [];
|
||||
let foundToolUse = false;
|
||||
|
||||
try {
|
||||
for await (const message of q) {
|
||||
messages.push(message);
|
||||
|
||||
if (isSDKAssistantMessage(message)) {
|
||||
const toolUseBlocks = findToolUseBlocks(message, 'maybe_fail');
|
||||
if (toolUseBlocks.length > 0) {
|
||||
foundToolUse = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tool should be called
|
||||
expect(foundToolUse).toBe(true);
|
||||
|
||||
// Query should complete (even with tool error)
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
expect(isSDKResultMessage(lastMessage)).toBe(true);
|
||||
} finally {
|
||||
await q.close();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
describe('Async Tool Handlers', () => {
|
||||
it('should handle async tool handlers with delays', async () => {
|
||||
// Define a tool with async delay using Zod schema
|
||||
const delayedTool = tool(
|
||||
'delayed_response',
|
||||
'Returns a value after a delay',
|
||||
{
|
||||
delay: z.number().describe('Delay in milliseconds (max 100)'),
|
||||
value: z.string().describe('Value to return'),
|
||||
},
|
||||
async (args) => {
|
||||
// Cap delay at 100ms for test performance
|
||||
const actualDelay = Math.min(args.delay, 100);
|
||||
await new Promise((resolve) => setTimeout(resolve, actualDelay));
|
||||
return {
|
||||
content: [{ type: 'text', text: `Delayed result: ${args.value}` }],
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
const serverConfig = createSdkMcpServer({
|
||||
name: 'sdk-async',
|
||||
version: '1.0.0',
|
||||
tools: [delayedTool],
|
||||
});
|
||||
|
||||
const q = query({
|
||||
prompt:
|
||||
'Use the delayed_response tool with delay=50 and value="test_async". Tell me the result.',
|
||||
options: {
|
||||
...SHARED_TEST_OPTIONS,
|
||||
cwd: testDir,
|
||||
debug: false,
|
||||
mcpServers: {
|
||||
'sdk-async': serverConfig,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const messages: SDKMessage[] = [];
|
||||
let assistantText = '';
|
||||
let foundToolUse = false;
|
||||
|
||||
try {
|
||||
for await (const message of q) {
|
||||
messages.push(message);
|
||||
|
||||
if (isSDKAssistantMessage(message)) {
|
||||
const toolUseBlocks = findToolUseBlocks(
|
||||
message,
|
||||
'delayed_response',
|
||||
);
|
||||
if (toolUseBlocks.length > 0) {
|
||||
foundToolUse = true;
|
||||
}
|
||||
assistantText += extractText(message.message.content);
|
||||
}
|
||||
}
|
||||
|
||||
// Validate tool was called
|
||||
expect(foundToolUse).toBe(true);
|
||||
|
||||
// Validate result contains the delayed response
|
||||
expect(assistantText.toLowerCase()).toMatch(/test_async/i);
|
||||
|
||||
// Validate successful completion
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
expect(isSDKResultMessage(lastMessage)).toBe(true);
|
||||
} finally {
|
||||
await q.close();
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -44,7 +44,6 @@ describe('Single-Turn Query (E2E)', () => {
|
||||
...SHARED_TEST_OPTIONS,
|
||||
cwd: testDir,
|
||||
debug: true,
|
||||
logLevel: 'debug',
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
@@ -37,6 +37,10 @@
|
||||
"test:integration:sandbox:none": "cross-env GEMINI_SANDBOX=false vitest run --root ./integration-tests",
|
||||
"test:integration:sandbox:docker": "cross-env GEMINI_SANDBOX=docker npm run build:sandbox && GEMINI_SANDBOX=docker vitest run --root ./integration-tests",
|
||||
"test:integration:sandbox:podman": "cross-env GEMINI_SANDBOX=podman vitest run --root ./integration-tests",
|
||||
"test:integration:sdk:sandbox:none": "cross-env GEMINI_SANDBOX=false vitest run --root ./integration-tests --dir sdk-typescript",
|
||||
"test:integration:sdk:sandbox:docker": "cross-env GEMINI_SANDBOX=docker npm run build:sandbox && GEMINI_SANDBOX=docker vitest run --root ./integration-tests --dir sdk-typescript",
|
||||
"test:integration:cli:sandbox:none": "cross-env GEMINI_SANDBOX=false vitest run --root ./integration-tests --exclude '**/sdk-typescript/**'",
|
||||
"test:integration:cli:sandbox:docker": "cross-env GEMINI_SANDBOX=docker npm run build:sandbox && GEMINI_SANDBOX=docker vitest run --root ./integration-tests --exclude '**/sdk-typescript/**'",
|
||||
"test:terminal-bench": "cross-env VERBOSE=true KEEP_OUTPUT=true vitest run --config ./vitest.terminal-bench.config.ts --root ./integration-tests",
|
||||
"test:terminal-bench:oracle": "cross-env VERBOSE=true KEEP_OUTPUT=true vitest run --config ./vitest.terminal-bench.config.ts --root ./integration-tests -t 'oracle'",
|
||||
"test:terminal-bench:qwen": "cross-env VERBOSE=true KEEP_OUTPUT=true vitest run --config ./vitest.terminal-bench.config.ts --root ./integration-tests -t 'qwen'",
|
||||
|
||||
@@ -276,8 +276,11 @@ export async function main() {
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
// For stream-json mode, don't read stdin here - it should be forwarded to the sandbox
|
||||
// and consumed by StreamJsonInputReader inside the container
|
||||
const inputFormat = argv.inputFormat as string | undefined;
|
||||
let stdinData = '';
|
||||
if (!process.stdin.isTTY) {
|
||||
if (!process.stdin.isTTY && inputFormat !== 'stream-json') {
|
||||
stdinData = await readStdin();
|
||||
}
|
||||
|
||||
|
||||
@@ -16,9 +16,12 @@
|
||||
* Controllers:
|
||||
* - SystemController: initialize, interrupt, set_model, supported_commands
|
||||
* - PermissionController: can_use_tool, set_permission_mode
|
||||
* - MCPController: mcp_message, mcp_server_status
|
||||
* - SdkMcpController: mcp_server_status (mcp_message handled via callback)
|
||||
* - HookController: hook_callback
|
||||
*
|
||||
* Note: mcp_message requests are NOT routed through the dispatcher. CLI MCP
|
||||
* clients send messages via SdkMcpController.createSendSdkMcpMessage() callback.
|
||||
*
|
||||
* Note: Control request types are centrally defined in the ControlRequestType
|
||||
* enum in packages/sdk/typescript/src/types/controlRequests.ts
|
||||
*/
|
||||
@@ -27,7 +30,7 @@ import type { IControlContext } from './ControlContext.js';
|
||||
import type { IPendingRequestRegistry } from './controllers/baseController.js';
|
||||
import { SystemController } from './controllers/systemController.js';
|
||||
import { PermissionController } from './controllers/permissionController.js';
|
||||
// import { MCPController } from './controllers/mcpController.js';
|
||||
import { SdkMcpController } from './controllers/sdkMcpController.js';
|
||||
// import { HookController } from './controllers/hookController.js';
|
||||
import type {
|
||||
CLIControlRequest,
|
||||
@@ -65,7 +68,7 @@ export class ControlDispatcher implements IPendingRequestRegistry {
|
||||
// Make controllers publicly accessible
|
||||
readonly systemController: SystemController;
|
||||
readonly permissionController: PermissionController;
|
||||
// readonly mcpController: MCPController;
|
||||
readonly sdkMcpController: SdkMcpController;
|
||||
// readonly hookController: HookController;
|
||||
|
||||
// Central pending request registries
|
||||
@@ -88,7 +91,11 @@ export class ControlDispatcher implements IPendingRequestRegistry {
|
||||
this,
|
||||
'PermissionController',
|
||||
);
|
||||
// this.mcpController = new MCPController(context, this, 'MCPController');
|
||||
this.sdkMcpController = new SdkMcpController(
|
||||
context,
|
||||
this,
|
||||
'SdkMcpController',
|
||||
);
|
||||
// this.hookController = new HookController(context, this, 'HookController');
|
||||
|
||||
// Listen for main abort signal
|
||||
@@ -228,10 +235,10 @@ export class ControlDispatcher implements IPendingRequestRegistry {
|
||||
}
|
||||
this.pendingOutgoingRequests.clear();
|
||||
|
||||
// Cleanup controllers (MCP controller will close all clients)
|
||||
// Cleanup controllers
|
||||
this.systemController.cleanup();
|
||||
this.permissionController.cleanup();
|
||||
// this.mcpController.cleanup();
|
||||
this.sdkMcpController.cleanup();
|
||||
// this.hookController.cleanup();
|
||||
}
|
||||
|
||||
@@ -291,6 +298,47 @@ export class ControlDispatcher implements IPendingRequestRegistry {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get count of pending incoming requests (for debugging)
|
||||
*/
|
||||
getPendingIncomingRequestCount(): number {
|
||||
return this.pendingIncomingRequests.size;
|
||||
}
|
||||
|
||||
/**
|
||||
* Wait for all incoming request handlers to complete.
|
||||
*
|
||||
* Uses polling since we don't have direct Promise references to handlers.
|
||||
* The pendingIncomingRequests map is managed by BaseController:
|
||||
* - Registered when handler starts (in handleRequest)
|
||||
* - Deregistered when handler completes (success or error)
|
||||
*
|
||||
* @param pollIntervalMs - How often to check (default 50ms)
|
||||
* @param timeoutMs - Maximum wait time (default 30s)
|
||||
*/
|
||||
async waitForPendingIncomingRequests(
|
||||
pollIntervalMs: number = 50,
|
||||
timeoutMs: number = 30000,
|
||||
): Promise<void> {
|
||||
const startTime = Date.now();
|
||||
|
||||
while (this.pendingIncomingRequests.size > 0) {
|
||||
if (Date.now() - startTime > timeoutMs) {
|
||||
if (this.context.debugMode) {
|
||||
console.error(
|
||||
`[ControlDispatcher] Timeout waiting for ${this.pendingIncomingRequests.size} pending incoming requests`,
|
||||
);
|
||||
}
|
||||
break;
|
||||
}
|
||||
await new Promise((resolve) => setTimeout(resolve, pollIntervalMs));
|
||||
}
|
||||
|
||||
if (this.context.debugMode && this.pendingIncomingRequests.size === 0) {
|
||||
console.error('[ControlDispatcher] All incoming requests completed');
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the controller that handles the given request subtype
|
||||
*/
|
||||
@@ -306,9 +354,8 @@ export class ControlDispatcher implements IPendingRequestRegistry {
|
||||
case 'set_permission_mode':
|
||||
return this.permissionController;
|
||||
|
||||
// case 'mcp_message':
|
||||
// case 'mcp_server_status':
|
||||
// return this.mcpController;
|
||||
case 'mcp_server_status':
|
||||
return this.sdkMcpController;
|
||||
|
||||
// case 'hook_callback':
|
||||
// return this.hookController;
|
||||
|
||||
@@ -117,16 +117,41 @@ export abstract class BaseController {
|
||||
* Send an outgoing control request to SDK
|
||||
*
|
||||
* Manages lifecycle: register -> send -> wait for response -> deregister
|
||||
* Respects the provided AbortSignal for cancellation.
|
||||
*/
|
||||
async sendControlRequest(
|
||||
payload: ControlRequestPayload,
|
||||
timeoutMs: number = DEFAULT_REQUEST_TIMEOUT_MS,
|
||||
signal?: AbortSignal,
|
||||
): Promise<ControlResponse> {
|
||||
// Check if already aborted
|
||||
if (signal?.aborted) {
|
||||
throw new Error('Request aborted');
|
||||
}
|
||||
|
||||
const requestId = randomUUID();
|
||||
|
||||
return new Promise<ControlResponse>((resolve, reject) => {
|
||||
// Setup abort handler
|
||||
const abortHandler = () => {
|
||||
this.registry.deregisterOutgoingRequest(requestId);
|
||||
reject(new Error('Request aborted'));
|
||||
if (this.context.debugMode) {
|
||||
console.error(
|
||||
`[${this.controllerName}] Outgoing request aborted: ${requestId}`,
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
if (signal) {
|
||||
signal.addEventListener('abort', abortHandler, { once: true });
|
||||
}
|
||||
|
||||
// Setup timeout
|
||||
const timeoutId = setTimeout(() => {
|
||||
if (signal) {
|
||||
signal.removeEventListener('abort', abortHandler);
|
||||
}
|
||||
this.registry.deregisterOutgoingRequest(requestId);
|
||||
reject(new Error('Control request timeout'));
|
||||
if (this.context.debugMode) {
|
||||
@@ -136,12 +161,27 @@ export abstract class BaseController {
|
||||
}
|
||||
}, timeoutMs);
|
||||
|
||||
// Wrap resolve/reject to clean up abort listener
|
||||
const wrappedResolve = (response: ControlResponse) => {
|
||||
if (signal) {
|
||||
signal.removeEventListener('abort', abortHandler);
|
||||
}
|
||||
resolve(response);
|
||||
};
|
||||
|
||||
const wrappedReject = (error: Error) => {
|
||||
if (signal) {
|
||||
signal.removeEventListener('abort', abortHandler);
|
||||
}
|
||||
reject(error);
|
||||
};
|
||||
|
||||
// Register with central registry
|
||||
this.registry.registerOutgoingRequest(
|
||||
requestId,
|
||||
this.controllerName,
|
||||
resolve,
|
||||
reject,
|
||||
wrappedResolve,
|
||||
wrappedReject,
|
||||
timeoutId,
|
||||
);
|
||||
|
||||
@@ -155,6 +195,9 @@ export abstract class BaseController {
|
||||
try {
|
||||
this.context.streamJson.send(request);
|
||||
} catch (error) {
|
||||
if (signal) {
|
||||
signal.removeEventListener('abort', abortHandler);
|
||||
}
|
||||
this.registry.deregisterOutgoingRequest(requestId);
|
||||
reject(error);
|
||||
}
|
||||
|
||||
@@ -1,287 +0,0 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen Team
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
/**
|
||||
* MCP Controller
|
||||
*
|
||||
* Handles MCP-related control requests:
|
||||
* - mcp_message: Route MCP messages
|
||||
* - mcp_server_status: Return MCP server status
|
||||
*/
|
||||
|
||||
import { BaseController } from './baseController.js';
|
||||
import type { Client } from '@modelcontextprotocol/sdk/client/index.js';
|
||||
import { ResultSchema } from '@modelcontextprotocol/sdk/types.js';
|
||||
import type {
|
||||
ControlRequestPayload,
|
||||
CLIControlMcpMessageRequest,
|
||||
} from '../../types.js';
|
||||
import type {
|
||||
MCPServerConfig,
|
||||
WorkspaceContext,
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
import {
|
||||
connectToMcpServer,
|
||||
MCP_DEFAULT_TIMEOUT_MSEC,
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
|
||||
export class MCPController extends BaseController {
|
||||
/**
|
||||
* Handle MCP control requests
|
||||
*/
|
||||
protected async handleRequestPayload(
|
||||
payload: ControlRequestPayload,
|
||||
_signal: AbortSignal,
|
||||
): Promise<Record<string, unknown>> {
|
||||
switch (payload.subtype) {
|
||||
case 'mcp_message':
|
||||
return this.handleMcpMessage(payload as CLIControlMcpMessageRequest);
|
||||
|
||||
case 'mcp_server_status':
|
||||
return this.handleMcpStatus();
|
||||
|
||||
default:
|
||||
throw new Error(`Unsupported request subtype in MCPController`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle mcp_message request
|
||||
*
|
||||
* Routes JSON-RPC messages to MCP servers
|
||||
*/
|
||||
private async handleMcpMessage(
|
||||
payload: CLIControlMcpMessageRequest,
|
||||
): Promise<Record<string, unknown>> {
|
||||
const serverNameRaw = payload.server_name;
|
||||
if (
|
||||
typeof serverNameRaw !== 'string' ||
|
||||
serverNameRaw.trim().length === 0
|
||||
) {
|
||||
throw new Error('Missing server_name in mcp_message request');
|
||||
}
|
||||
|
||||
const message = payload.message;
|
||||
if (!message || typeof message !== 'object') {
|
||||
throw new Error(
|
||||
'Missing or invalid message payload for mcp_message request',
|
||||
);
|
||||
}
|
||||
|
||||
// Get or create MCP client
|
||||
let clientEntry: { client: Client; config: MCPServerConfig };
|
||||
try {
|
||||
clientEntry = await this.getOrCreateMcpClient(serverNameRaw.trim());
|
||||
} catch (error) {
|
||||
throw new Error(
|
||||
error instanceof Error
|
||||
? error.message
|
||||
: 'Failed to connect to MCP server',
|
||||
);
|
||||
}
|
||||
|
||||
const method = message.method;
|
||||
if (typeof method !== 'string' || method.trim().length === 0) {
|
||||
throw new Error('Invalid MCP message: missing method');
|
||||
}
|
||||
|
||||
const jsonrpcVersion =
|
||||
typeof message.jsonrpc === 'string' ? message.jsonrpc : '2.0';
|
||||
const messageId = message.id;
|
||||
const params = message.params;
|
||||
const timeout =
|
||||
typeof clientEntry.config.timeout === 'number'
|
||||
? clientEntry.config.timeout
|
||||
: MCP_DEFAULT_TIMEOUT_MSEC;
|
||||
|
||||
try {
|
||||
// Handle notification (no id)
|
||||
if (messageId === undefined) {
|
||||
await clientEntry.client.notification({
|
||||
method,
|
||||
params,
|
||||
});
|
||||
return {
|
||||
subtype: 'mcp_message',
|
||||
mcp_response: {
|
||||
jsonrpc: jsonrpcVersion,
|
||||
id: null,
|
||||
result: { success: true, acknowledged: true },
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Handle request (with id)
|
||||
const result = await clientEntry.client.request(
|
||||
{
|
||||
method,
|
||||
params,
|
||||
},
|
||||
ResultSchema,
|
||||
{ timeout },
|
||||
);
|
||||
|
||||
return {
|
||||
subtype: 'mcp_message',
|
||||
mcp_response: {
|
||||
jsonrpc: jsonrpcVersion,
|
||||
id: messageId,
|
||||
result,
|
||||
},
|
||||
};
|
||||
} catch (error) {
|
||||
// If connection closed, remove from cache
|
||||
if (error instanceof Error && /closed/i.test(error.message)) {
|
||||
this.context.mcpClients.delete(serverNameRaw.trim());
|
||||
}
|
||||
|
||||
const errorCode =
|
||||
typeof (error as { code?: unknown })?.code === 'number'
|
||||
? ((error as { code: number }).code as number)
|
||||
: -32603;
|
||||
const errorMessage =
|
||||
error instanceof Error
|
||||
? error.message
|
||||
: 'Failed to execute MCP request';
|
||||
const errorData = (error as { data?: unknown })?.data;
|
||||
|
||||
const errorBody: Record<string, unknown> = {
|
||||
code: errorCode,
|
||||
message: errorMessage,
|
||||
};
|
||||
if (errorData !== undefined) {
|
||||
errorBody['data'] = errorData;
|
||||
}
|
||||
|
||||
return {
|
||||
subtype: 'mcp_message',
|
||||
mcp_response: {
|
||||
jsonrpc: jsonrpcVersion,
|
||||
id: messageId ?? null,
|
||||
error: errorBody,
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle mcp_server_status request
|
||||
*
|
||||
* Returns status of registered MCP servers
|
||||
*/
|
||||
private async handleMcpStatus(): Promise<Record<string, unknown>> {
|
||||
const status: Record<string, string> = {};
|
||||
|
||||
// Include SDK MCP servers
|
||||
for (const serverName of this.context.sdkMcpServers) {
|
||||
status[serverName] = 'connected';
|
||||
}
|
||||
|
||||
// Include CLI-managed MCP clients
|
||||
for (const serverName of this.context.mcpClients.keys()) {
|
||||
status[serverName] = 'connected';
|
||||
}
|
||||
|
||||
if (this.context.debugMode) {
|
||||
console.error(
|
||||
`[MCPController] MCP status: ${Object.keys(status).length} servers`,
|
||||
);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get or create MCP client for a server
|
||||
*
|
||||
* Implements lazy connection and caching
|
||||
*/
|
||||
private async getOrCreateMcpClient(
|
||||
serverName: string,
|
||||
): Promise<{ client: Client; config: MCPServerConfig }> {
|
||||
// Check cache first
|
||||
const cached = this.context.mcpClients.get(serverName);
|
||||
if (cached) {
|
||||
return cached;
|
||||
}
|
||||
|
||||
// Get server configuration
|
||||
const provider = this.context.config as unknown as {
|
||||
getMcpServers?: () => Record<string, MCPServerConfig> | undefined;
|
||||
getDebugMode?: () => boolean;
|
||||
getWorkspaceContext?: () => unknown;
|
||||
};
|
||||
|
||||
if (typeof provider.getMcpServers !== 'function') {
|
||||
throw new Error(`MCP server "${serverName}" is not configured`);
|
||||
}
|
||||
|
||||
const servers = provider.getMcpServers() ?? {};
|
||||
const serverConfig = servers[serverName];
|
||||
if (!serverConfig) {
|
||||
throw new Error(`MCP server "${serverName}" is not configured`);
|
||||
}
|
||||
|
||||
const debugMode =
|
||||
typeof provider.getDebugMode === 'function'
|
||||
? provider.getDebugMode()
|
||||
: false;
|
||||
|
||||
const workspaceContext =
|
||||
typeof provider.getWorkspaceContext === 'function'
|
||||
? provider.getWorkspaceContext()
|
||||
: undefined;
|
||||
|
||||
if (!workspaceContext) {
|
||||
throw new Error('Workspace context is not available for MCP connection');
|
||||
}
|
||||
|
||||
// Connect to MCP server
|
||||
const client = await connectToMcpServer(
|
||||
serverName,
|
||||
serverConfig,
|
||||
debugMode,
|
||||
workspaceContext as WorkspaceContext,
|
||||
);
|
||||
|
||||
// Cache the client
|
||||
const entry = { client, config: serverConfig };
|
||||
this.context.mcpClients.set(serverName, entry);
|
||||
|
||||
if (this.context.debugMode) {
|
||||
console.error(`[MCPController] Connected to MCP server: ${serverName}`);
|
||||
}
|
||||
|
||||
return entry;
|
||||
}
|
||||
|
||||
/**
|
||||
* Cleanup MCP clients
|
||||
*/
|
||||
override cleanup(): void {
|
||||
if (this.context.debugMode) {
|
||||
console.error(
|
||||
`[MCPController] Cleaning up ${this.context.mcpClients.size} MCP clients`,
|
||||
);
|
||||
}
|
||||
|
||||
// Close all MCP clients
|
||||
for (const [serverName, { client }] of this.context.mcpClients.entries()) {
|
||||
try {
|
||||
client.close();
|
||||
} catch (error) {
|
||||
if (this.context.debugMode) {
|
||||
console.error(
|
||||
`[MCPController] Failed to close MCP client ${serverName}:`,
|
||||
error,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.context.mcpClients.clear();
|
||||
}
|
||||
}
|
||||
@@ -44,15 +44,23 @@ export class PermissionController extends BaseController {
|
||||
*/
|
||||
protected async handleRequestPayload(
|
||||
payload: ControlRequestPayload,
|
||||
_signal: AbortSignal,
|
||||
signal: AbortSignal,
|
||||
): Promise<Record<string, unknown>> {
|
||||
if (signal.aborted) {
|
||||
throw new Error('Request aborted');
|
||||
}
|
||||
|
||||
switch (payload.subtype) {
|
||||
case 'can_use_tool':
|
||||
return this.handleCanUseTool(payload as CLIControlPermissionRequest);
|
||||
return this.handleCanUseTool(
|
||||
payload as CLIControlPermissionRequest,
|
||||
signal,
|
||||
);
|
||||
|
||||
case 'set_permission_mode':
|
||||
return this.handleSetPermissionMode(
|
||||
payload as CLIControlSetPermissionModeRequest,
|
||||
signal,
|
||||
);
|
||||
|
||||
default:
|
||||
@@ -70,7 +78,12 @@ export class PermissionController extends BaseController {
|
||||
*/
|
||||
private async handleCanUseTool(
|
||||
payload: CLIControlPermissionRequest,
|
||||
signal: AbortSignal,
|
||||
): Promise<Record<string, unknown>> {
|
||||
if (signal.aborted) {
|
||||
throw new Error('Request aborted');
|
||||
}
|
||||
|
||||
const toolName = payload.tool_name;
|
||||
if (
|
||||
!toolName ||
|
||||
@@ -192,7 +205,12 @@ export class PermissionController extends BaseController {
|
||||
*/
|
||||
private async handleSetPermissionMode(
|
||||
payload: CLIControlSetPermissionModeRequest,
|
||||
signal: AbortSignal,
|
||||
): Promise<Record<string, unknown>> {
|
||||
if (signal.aborted) {
|
||||
throw new Error('Request aborted');
|
||||
}
|
||||
|
||||
const mode = payload.mode;
|
||||
const validModes: PermissionMode[] = [
|
||||
'default',
|
||||
@@ -373,6 +391,14 @@ export class PermissionController extends BaseController {
|
||||
toolCall: WaitingToolCall,
|
||||
): Promise<void> {
|
||||
try {
|
||||
// Check if already aborted
|
||||
if (this.context.abortSignal?.aborted) {
|
||||
await toolCall.confirmationDetails.onConfirm(
|
||||
ToolConfirmationOutcome.Cancel,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const inputFormat = this.context.config.getInputFormat?.();
|
||||
const isStreamJsonMode = inputFormat === InputFormat.STREAM_JSON;
|
||||
|
||||
@@ -392,14 +418,18 @@ export class PermissionController extends BaseController {
|
||||
toolCall.confirmationDetails,
|
||||
);
|
||||
|
||||
const response = await this.sendControlRequest({
|
||||
const response = await this.sendControlRequest(
|
||||
{
|
||||
subtype: 'can_use_tool',
|
||||
tool_name: toolCall.request.name,
|
||||
tool_use_id: toolCall.request.callId,
|
||||
input: toolCall.request.args,
|
||||
permission_suggestions: permissionSuggestions,
|
||||
blocked_path: null,
|
||||
} as CLIControlPermissionRequest);
|
||||
} as CLIControlPermissionRequest,
|
||||
undefined, // use default timeout
|
||||
this.context.abortSignal,
|
||||
);
|
||||
|
||||
if (response.subtype !== 'success') {
|
||||
await toolCall.confirmationDetails.onConfirm(
|
||||
|
||||
@@ -0,0 +1,138 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen Team
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
/**
|
||||
* SDK MCP Controller
|
||||
*
|
||||
* Handles MCP communication between CLI MCP clients and SDK MCP servers:
|
||||
* - Provides sendSdkMcpMessage callback for CLI → SDK MCP message routing
|
||||
* - mcp_server_status: Returns status of SDK MCP servers
|
||||
*
|
||||
* Message Flow (CLI MCP Client → SDK MCP Server):
|
||||
* CLI MCP Client → SdkControlClientTransport.send() →
|
||||
* sendSdkMcpMessage callback → control_request (mcp_message) → SDK →
|
||||
* SDK MCP Server processes → control_response → CLI MCP Client
|
||||
*/
|
||||
|
||||
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
|
||||
import { BaseController } from './baseController.js';
|
||||
import type {
|
||||
ControlRequestPayload,
|
||||
CLIControlMcpMessageRequest,
|
||||
} from '../../types.js';
|
||||
|
||||
const MCP_REQUEST_TIMEOUT = 30_000; // 30 seconds
|
||||
|
||||
export class SdkMcpController extends BaseController {
|
||||
/**
|
||||
* Handle SDK MCP control requests from ControlDispatcher
|
||||
*
|
||||
* Note: mcp_message requests are NOT handled here. CLI MCP clients
|
||||
* send messages via the sendSdkMcpMessage callback directly, not
|
||||
* through the control dispatcher.
|
||||
*/
|
||||
protected async handleRequestPayload(
|
||||
payload: ControlRequestPayload,
|
||||
signal: AbortSignal,
|
||||
): Promise<Record<string, unknown>> {
|
||||
if (signal.aborted) {
|
||||
throw new Error('Request aborted');
|
||||
}
|
||||
|
||||
switch (payload.subtype) {
|
||||
case 'mcp_server_status':
|
||||
return this.handleMcpStatus();
|
||||
|
||||
default:
|
||||
throw new Error(`Unsupported request subtype in SdkMcpController`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle mcp_server_status request
|
||||
*
|
||||
* Returns status of all registered SDK MCP servers.
|
||||
* SDK servers are considered "connected" if they are registered.
|
||||
*/
|
||||
private async handleMcpStatus(): Promise<Record<string, unknown>> {
|
||||
const status: Record<string, string> = {};
|
||||
|
||||
for (const serverName of this.context.sdkMcpServers) {
|
||||
// SDK MCP servers are "connected" once registered since they run in SDK process
|
||||
status[serverName] = 'connected';
|
||||
}
|
||||
|
||||
return {
|
||||
subtype: 'mcp_server_status',
|
||||
status,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Send MCP message to SDK server via control plane
|
||||
*
|
||||
* @param serverName - Name of the SDK MCP server
|
||||
* @param message - MCP JSON-RPC message to send
|
||||
* @returns MCP JSON-RPC response from SDK server
|
||||
*/
|
||||
private async sendMcpMessageToSdk(
|
||||
serverName: string,
|
||||
message: JSONRPCMessage,
|
||||
): Promise<JSONRPCMessage> {
|
||||
if (this.context.debugMode) {
|
||||
console.error(
|
||||
`[SdkMcpController] Sending MCP message to SDK server '${serverName}':`,
|
||||
JSON.stringify(message),
|
||||
);
|
||||
}
|
||||
|
||||
// Send control request to SDK with the MCP message
|
||||
const response = await this.sendControlRequest(
|
||||
{
|
||||
subtype: 'mcp_message',
|
||||
server_name: serverName,
|
||||
message: message as CLIControlMcpMessageRequest['message'],
|
||||
},
|
||||
MCP_REQUEST_TIMEOUT,
|
||||
this.context.abortSignal,
|
||||
);
|
||||
|
||||
// Extract MCP response from control response
|
||||
const responsePayload = response.response as Record<string, unknown>;
|
||||
const mcpResponse = responsePayload?.['mcp_response'] as JSONRPCMessage;
|
||||
|
||||
if (!mcpResponse) {
|
||||
throw new Error(
|
||||
`Invalid MCP response from SDK for server '${serverName}'`,
|
||||
);
|
||||
}
|
||||
|
||||
if (this.context.debugMode) {
|
||||
console.error(
|
||||
`[SdkMcpController] Received MCP response from SDK server '${serverName}':`,
|
||||
JSON.stringify(mcpResponse),
|
||||
);
|
||||
}
|
||||
|
||||
return mcpResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a callback function for sending MCP messages to SDK servers.
|
||||
*
|
||||
* This callback is used by McpClientManager/SdkControlClientTransport to send
|
||||
* MCP messages from CLI MCP clients to SDK MCP servers via the control plane.
|
||||
*
|
||||
* @returns A function that sends MCP messages to SDK and returns the response
|
||||
*/
|
||||
createSendSdkMcpMessage(): (
|
||||
serverName: string,
|
||||
message: JSONRPCMessage,
|
||||
) => Promise<JSONRPCMessage> {
|
||||
return (serverName: string, message: JSONRPCMessage) =>
|
||||
this.sendMcpMessageToSdk(serverName, message);
|
||||
}
|
||||
}
|
||||
@@ -18,9 +18,15 @@ import type {
|
||||
ControlRequestPayload,
|
||||
CLIControlInitializeRequest,
|
||||
CLIControlSetModelRequest,
|
||||
CLIMcpServerConfig,
|
||||
} from '../../types.js';
|
||||
import { CommandService } from '../../../services/CommandService.js';
|
||||
import { BuiltinCommandLoader } from '../../../services/BuiltinCommandLoader.js';
|
||||
import {
|
||||
MCPServerConfig,
|
||||
AuthProviderType,
|
||||
type MCPOAuthConfig,
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
|
||||
export class SystemController extends BaseController {
|
||||
/**
|
||||
@@ -28,20 +34,30 @@ export class SystemController extends BaseController {
|
||||
*/
|
||||
protected async handleRequestPayload(
|
||||
payload: ControlRequestPayload,
|
||||
_signal: AbortSignal,
|
||||
signal: AbortSignal,
|
||||
): Promise<Record<string, unknown>> {
|
||||
if (signal.aborted) {
|
||||
throw new Error('Request aborted');
|
||||
}
|
||||
|
||||
switch (payload.subtype) {
|
||||
case 'initialize':
|
||||
return this.handleInitialize(payload as CLIControlInitializeRequest);
|
||||
return this.handleInitialize(
|
||||
payload as CLIControlInitializeRequest,
|
||||
signal,
|
||||
);
|
||||
|
||||
case 'interrupt':
|
||||
return this.handleInterrupt();
|
||||
|
||||
case 'set_model':
|
||||
return this.handleSetModel(payload as CLIControlSetModelRequest);
|
||||
return this.handleSetModel(
|
||||
payload as CLIControlSetModelRequest,
|
||||
signal,
|
||||
);
|
||||
|
||||
case 'supported_commands':
|
||||
return this.handleSupportedCommands();
|
||||
return this.handleSupportedCommands(signal);
|
||||
|
||||
default:
|
||||
throw new Error(`Unsupported request subtype in SystemController`);
|
||||
@@ -51,23 +67,65 @@ export class SystemController extends BaseController {
|
||||
/**
|
||||
* Handle initialize request
|
||||
*
|
||||
* Registers SDK MCP servers and returns capabilities
|
||||
* Processes SDK MCP servers config.
|
||||
* SDK servers are registered in context.sdkMcpServers
|
||||
* and added to config.mcpServers with the sdk type flag.
|
||||
* External MCP servers are configured separately in settings.
|
||||
*/
|
||||
private async handleInitialize(
|
||||
payload: CLIControlInitializeRequest,
|
||||
signal: AbortSignal,
|
||||
): Promise<Record<string, unknown>> {
|
||||
this.context.config.setSdkMode(true);
|
||||
|
||||
if (payload.sdkMcpServers && typeof payload.sdkMcpServers === 'object') {
|
||||
for (const serverName of Object.keys(payload.sdkMcpServers)) {
|
||||
this.context.sdkMcpServers.add(serverName);
|
||||
if (signal.aborted) {
|
||||
throw new Error('Request aborted');
|
||||
}
|
||||
|
||||
this.context.config.setSdkMode(true);
|
||||
|
||||
// Process SDK MCP servers
|
||||
if (
|
||||
payload.sdkMcpServers &&
|
||||
typeof payload.sdkMcpServers === 'object' &&
|
||||
payload.sdkMcpServers !== null
|
||||
) {
|
||||
const sdkServers: Record<string, MCPServerConfig> = {};
|
||||
for (const [key, wireConfig] of Object.entries(payload.sdkMcpServers)) {
|
||||
const name =
|
||||
typeof wireConfig?.name === 'string' && wireConfig.name.trim().length
|
||||
? wireConfig.name
|
||||
: key;
|
||||
|
||||
this.context.sdkMcpServers.add(name);
|
||||
sdkServers[name] = new MCPServerConfig(
|
||||
undefined, // command
|
||||
undefined, // args
|
||||
undefined, // env
|
||||
undefined, // cwd
|
||||
undefined, // url
|
||||
undefined, // httpUrl
|
||||
undefined, // headers
|
||||
undefined, // tcp
|
||||
undefined, // timeout
|
||||
true, // trust - SDK servers are trusted
|
||||
undefined, // description
|
||||
undefined, // includeTools
|
||||
undefined, // excludeTools
|
||||
undefined, // extensionName
|
||||
undefined, // oauth
|
||||
undefined, // authProviderType
|
||||
undefined, // targetAudience
|
||||
undefined, // targetServiceAccount
|
||||
'sdk', // type
|
||||
);
|
||||
}
|
||||
|
||||
const sdkServerCount = Object.keys(sdkServers).length;
|
||||
if (sdkServerCount > 0) {
|
||||
try {
|
||||
this.context.config.addMcpServers(payload.sdkMcpServers);
|
||||
this.context.config.addMcpServers(sdkServers);
|
||||
if (this.context.debugMode) {
|
||||
console.error(
|
||||
`[SystemController] Added ${Object.keys(payload.sdkMcpServers).length} SDK MCP servers to config`,
|
||||
`[SystemController] Added ${sdkServerCount} SDK MCP servers to config`,
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
@@ -79,18 +137,40 @@ export class SystemController extends BaseController {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (payload.mcpServers && typeof payload.mcpServers === 'object') {
|
||||
if (
|
||||
payload.mcpServers &&
|
||||
typeof payload.mcpServers === 'object' &&
|
||||
payload.mcpServers !== null
|
||||
) {
|
||||
const externalServers: Record<string, MCPServerConfig> = {};
|
||||
for (const [name, serverConfig] of Object.entries(payload.mcpServers)) {
|
||||
const normalized = this.normalizeMcpServerConfig(
|
||||
name,
|
||||
serverConfig as CLIMcpServerConfig | undefined,
|
||||
);
|
||||
if (normalized) {
|
||||
externalServers[name] = normalized;
|
||||
}
|
||||
}
|
||||
|
||||
const externalCount = Object.keys(externalServers).length;
|
||||
if (externalCount > 0) {
|
||||
try {
|
||||
this.context.config.addMcpServers(payload.mcpServers);
|
||||
this.context.config.addMcpServers(externalServers);
|
||||
if (this.context.debugMode) {
|
||||
console.error(
|
||||
`[SystemController] Added ${Object.keys(payload.mcpServers).length} MCP servers to config`,
|
||||
`[SystemController] Added ${externalCount} external MCP servers to config`,
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
if (this.context.debugMode) {
|
||||
console.error('[SystemController] Failed to add MCP servers:', error);
|
||||
console.error(
|
||||
'[SystemController] Failed to add external MCP servers:',
|
||||
error,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -143,13 +223,96 @@ export class SystemController extends BaseController {
|
||||
can_set_permission_mode:
|
||||
typeof this.context.config.setApprovalMode === 'function',
|
||||
can_set_model: typeof this.context.config.setModel === 'function',
|
||||
/* TODO: sdkMcpServers support */
|
||||
can_handle_mcp_message: false,
|
||||
// SDK MCP servers are supported - messages routed through control plane
|
||||
can_handle_mcp_message: true,
|
||||
};
|
||||
|
||||
return capabilities;
|
||||
}
|
||||
|
||||
private normalizeMcpServerConfig(
|
||||
serverName: string,
|
||||
config?: CLIMcpServerConfig,
|
||||
): MCPServerConfig | null {
|
||||
if (!config || typeof config !== 'object') {
|
||||
if (this.context.debugMode) {
|
||||
console.error(
|
||||
`[SystemController] Ignoring invalid MCP server config for '${serverName}'`,
|
||||
);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
const authProvider = this.normalizeAuthProviderType(
|
||||
config.authProviderType,
|
||||
);
|
||||
const oauthConfig = this.normalizeOAuthConfig(config.oauth);
|
||||
|
||||
return new MCPServerConfig(
|
||||
config.command,
|
||||
config.args,
|
||||
config.env,
|
||||
config.cwd,
|
||||
config.url,
|
||||
config.httpUrl,
|
||||
config.headers,
|
||||
config.tcp,
|
||||
config.timeout,
|
||||
config.trust,
|
||||
config.description,
|
||||
config.includeTools,
|
||||
config.excludeTools,
|
||||
config.extensionName,
|
||||
oauthConfig,
|
||||
authProvider,
|
||||
config.targetAudience,
|
||||
config.targetServiceAccount,
|
||||
);
|
||||
}
|
||||
|
||||
private normalizeAuthProviderType(
|
||||
value?: string,
|
||||
): AuthProviderType | undefined {
|
||||
if (!value) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
switch (value) {
|
||||
case AuthProviderType.DYNAMIC_DISCOVERY:
|
||||
case AuthProviderType.GOOGLE_CREDENTIALS:
|
||||
case AuthProviderType.SERVICE_ACCOUNT_IMPERSONATION:
|
||||
return value;
|
||||
default:
|
||||
if (this.context.debugMode) {
|
||||
console.error(
|
||||
`[SystemController] Unsupported authProviderType '${value}', skipping`,
|
||||
);
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
private normalizeOAuthConfig(
|
||||
oauth?: CLIMcpServerConfig['oauth'],
|
||||
): MCPOAuthConfig | undefined {
|
||||
if (!oauth) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return {
|
||||
enabled: oauth.enabled,
|
||||
clientId: oauth.clientId,
|
||||
clientSecret: oauth.clientSecret,
|
||||
authorizationUrl: oauth.authorizationUrl,
|
||||
tokenUrl: oauth.tokenUrl,
|
||||
scopes: oauth.scopes,
|
||||
audiences: oauth.audiences,
|
||||
redirectUri: oauth.redirectUri,
|
||||
tokenParamName: oauth.tokenParamName,
|
||||
registrationUrl: oauth.registrationUrl,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle interrupt request
|
||||
*
|
||||
@@ -183,7 +346,12 @@ export class SystemController extends BaseController {
|
||||
*/
|
||||
private async handleSetModel(
|
||||
payload: CLIControlSetModelRequest,
|
||||
signal: AbortSignal,
|
||||
): Promise<Record<string, unknown>> {
|
||||
if (signal.aborted) {
|
||||
throw new Error('Request aborted');
|
||||
}
|
||||
|
||||
const model = payload.model;
|
||||
|
||||
// Validate model parameter
|
||||
@@ -223,8 +391,14 @@ export class SystemController extends BaseController {
|
||||
*
|
||||
* Returns list of supported slash commands loaded dynamically
|
||||
*/
|
||||
private async handleSupportedCommands(): Promise<Record<string, unknown>> {
|
||||
const slashCommands = await this.loadSlashCommandNames();
|
||||
private async handleSupportedCommands(
|
||||
signal: AbortSignal,
|
||||
): Promise<Record<string, unknown>> {
|
||||
if (signal.aborted) {
|
||||
throw new Error('Request aborted');
|
||||
}
|
||||
|
||||
const slashCommands = await this.loadSlashCommandNames(signal);
|
||||
|
||||
return {
|
||||
subtype: 'supported_commands',
|
||||
@@ -235,15 +409,24 @@ export class SystemController extends BaseController {
|
||||
/**
|
||||
* Load slash command names using CommandService
|
||||
*
|
||||
* @param signal - AbortSignal to respect for cancellation
|
||||
* @returns Promise resolving to array of slash command names
|
||||
*/
|
||||
private async loadSlashCommandNames(): Promise<string[]> {
|
||||
const controller = new AbortController();
|
||||
private async loadSlashCommandNames(signal: AbortSignal): Promise<string[]> {
|
||||
if (signal.aborted) {
|
||||
return [];
|
||||
}
|
||||
|
||||
try {
|
||||
const service = await CommandService.create(
|
||||
[new BuiltinCommandLoader(this.context.config)],
|
||||
controller.signal,
|
||||
signal,
|
||||
);
|
||||
|
||||
if (signal.aborted) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const names = new Set<string>();
|
||||
const commands = service.getCommands();
|
||||
for (const command of commands) {
|
||||
@@ -251,6 +434,11 @@ export class SystemController extends BaseController {
|
||||
}
|
||||
return Array.from(names).sort();
|
||||
} catch (error) {
|
||||
// Check if the error is due to abort
|
||||
if (signal.aborted) {
|
||||
return [];
|
||||
}
|
||||
|
||||
if (this.context.debugMode) {
|
||||
console.error(
|
||||
'[SystemController] Failed to load slash commands:',
|
||||
@@ -258,8 +446,6 @@ export class SystemController extends BaseController {
|
||||
);
|
||||
}
|
||||
return [];
|
||||
} finally {
|
||||
controller.abort();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -153,6 +153,11 @@ describe('runNonInteractiveStreamJson', () => {
|
||||
handleControlResponse: ReturnType<typeof vi.fn>;
|
||||
handleCancel: ReturnType<typeof vi.fn>;
|
||||
shutdown: ReturnType<typeof vi.fn>;
|
||||
getPendingIncomingRequestCount: ReturnType<typeof vi.fn>;
|
||||
waitForPendingIncomingRequests: ReturnType<typeof vi.fn>;
|
||||
sdkMcpController: {
|
||||
createSendSdkMcpMessage: ReturnType<typeof vi.fn>;
|
||||
};
|
||||
};
|
||||
let mockConsolePatcher: {
|
||||
patch: ReturnType<typeof vi.fn>;
|
||||
@@ -187,6 +192,11 @@ describe('runNonInteractiveStreamJson', () => {
|
||||
handleControlResponse: vi.fn(),
|
||||
handleCancel: vi.fn(),
|
||||
shutdown: vi.fn(),
|
||||
getPendingIncomingRequestCount: vi.fn().mockReturnValue(0),
|
||||
waitForPendingIncomingRequests: vi.fn().mockResolvedValue(undefined),
|
||||
sdkMcpController: {
|
||||
createSendSdkMcpMessage: vi.fn().mockReturnValue(vi.fn()),
|
||||
},
|
||||
};
|
||||
(
|
||||
ControlDispatcher as unknown as ReturnType<typeof vi.fn>
|
||||
|
||||
@@ -4,7 +4,10 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Config } from '@qwen-code/qwen-code-core';
|
||||
import type {
|
||||
Config,
|
||||
ConfigInitializeOptions,
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
import { StreamJsonInputReader } from './io/StreamJsonInputReader.js';
|
||||
import { StreamJsonOutputAdapter } from './io/StreamJsonOutputAdapter.js';
|
||||
import { ControlContext } from './control/ControlContext.js';
|
||||
@@ -50,6 +53,12 @@ class Session {
|
||||
private isShuttingDown: boolean = false;
|
||||
private configInitialized: boolean = false;
|
||||
|
||||
// Single initialization promise that resolves when session is ready for user messages.
|
||||
// Created lazily once initialization actually starts.
|
||||
private initializationPromise: Promise<void> | null = null;
|
||||
private initializationResolve: (() => void) | null = null;
|
||||
private initializationReject: ((error: Error) => void) | null = null;
|
||||
|
||||
constructor(config: Config, initialPrompt?: CLIUserMessage) {
|
||||
this.config = config;
|
||||
this.sessionId = config.getSessionId();
|
||||
@@ -66,12 +75,32 @@ class Session {
|
||||
this.setupSignalHandlers();
|
||||
}
|
||||
|
||||
private ensureInitializationPromise(): void {
|
||||
if (this.initializationPromise) {
|
||||
return;
|
||||
}
|
||||
this.initializationPromise = new Promise<void>((resolve, reject) => {
|
||||
this.initializationResolve = () => {
|
||||
resolve();
|
||||
this.initializationResolve = null;
|
||||
this.initializationReject = null;
|
||||
};
|
||||
this.initializationReject = (error: Error) => {
|
||||
reject(error);
|
||||
this.initializationResolve = null;
|
||||
this.initializationReject = null;
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
private getNextPromptId(): string {
|
||||
this.promptIdCounter++;
|
||||
return `${this.sessionId}########${this.promptIdCounter}`;
|
||||
}
|
||||
|
||||
private async ensureConfigInitialized(): Promise<void> {
|
||||
private async ensureConfigInitialized(
|
||||
options?: ConfigInitializeOptions,
|
||||
): Promise<void> {
|
||||
if (this.configInitialized) {
|
||||
return;
|
||||
}
|
||||
@@ -81,7 +110,7 @@ class Session {
|
||||
}
|
||||
|
||||
try {
|
||||
await this.config.initialize();
|
||||
await this.config.initialize(options);
|
||||
this.configInitialized = true;
|
||||
} catch (error) {
|
||||
if (this.debugMode) {
|
||||
@@ -91,6 +120,44 @@ class Session {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Mark initialization as complete
|
||||
*/
|
||||
private completeInitialization(): void {
|
||||
if (this.initializationResolve) {
|
||||
if (this.debugMode) {
|
||||
console.error('[Session] Initialization complete');
|
||||
}
|
||||
this.initializationResolve();
|
||||
this.initializationResolve = null;
|
||||
this.initializationReject = null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Mark initialization as failed
|
||||
*/
|
||||
private failInitialization(error: Error): void {
|
||||
if (this.initializationReject) {
|
||||
if (this.debugMode) {
|
||||
console.error('[Session] Initialization failed:', error);
|
||||
}
|
||||
this.initializationReject(error);
|
||||
this.initializationResolve = null;
|
||||
this.initializationReject = null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Wait for session to be ready for user messages
|
||||
*/
|
||||
private async waitForInitialization(): Promise<void> {
|
||||
if (!this.initializationPromise) {
|
||||
return;
|
||||
}
|
||||
await this.initializationPromise;
|
||||
}
|
||||
|
||||
private ensureControlSystem(): void {
|
||||
if (this.controlContext && this.dispatcher && this.controlService) {
|
||||
return;
|
||||
@@ -120,49 +187,114 @@ class Session {
|
||||
return this.dispatcher;
|
||||
}
|
||||
|
||||
private async handleFirstMessage(
|
||||
/**
|
||||
* Handle the first message to determine session mode (SDK vs direct).
|
||||
* This is synchronous from the message loop's perspective - it starts
|
||||
* async work but does not return a promise that the loop awaits.
|
||||
*
|
||||
* The initialization completes asynchronously and resolves initializationPromise
|
||||
* when ready for user messages.
|
||||
*/
|
||||
private handleFirstMessage(
|
||||
message:
|
||||
| CLIMessage
|
||||
| CLIControlRequest
|
||||
| CLIControlResponse
|
||||
| ControlCancelRequest,
|
||||
): Promise<boolean> {
|
||||
): void {
|
||||
if (isControlRequest(message)) {
|
||||
const request = message as CLIControlRequest;
|
||||
this.controlSystemEnabled = true;
|
||||
this.ensureControlSystem();
|
||||
if (request.request.subtype === 'initialize') {
|
||||
// Dispatch the initialize request first
|
||||
await this.dispatcher?.dispatch(request);
|
||||
|
||||
// After handling initialize control request, initialize the config
|
||||
// This is the SDK mode where config initialization is deferred
|
||||
await this.ensureConfigInitialized();
|
||||
return true;
|
||||
if (request.request.subtype === 'initialize') {
|
||||
// Start SDK mode initialization (fire-and-forget from loop perspective)
|
||||
void this.initializeSdkMode(request);
|
||||
return;
|
||||
}
|
||||
|
||||
if (this.debugMode) {
|
||||
console.error(
|
||||
'[Session] Ignoring non-initialize control request during initialization',
|
||||
);
|
||||
}
|
||||
return true;
|
||||
return;
|
||||
}
|
||||
|
||||
if (isCLIUserMessage(message)) {
|
||||
this.controlSystemEnabled = false;
|
||||
// For non-SDK mode (direct user message), initialize config if not already done
|
||||
await this.ensureConfigInitialized();
|
||||
this.enqueueUserMessage(message as CLIUserMessage);
|
||||
return true;
|
||||
// Start direct mode initialization (fire-and-forget from loop perspective)
|
||||
void this.initializeDirectMode(message as CLIUserMessage);
|
||||
return;
|
||||
}
|
||||
|
||||
this.controlSystemEnabled = false;
|
||||
return false;
|
||||
}
|
||||
|
||||
private async handleControlRequest(
|
||||
request: CLIControlRequest,
|
||||
/**
|
||||
* SDK mode initialization flow
|
||||
* Dispatches initialize request and initializes config with MCP support
|
||||
*/
|
||||
private async initializeSdkMode(request: CLIControlRequest): Promise<void> {
|
||||
this.ensureInitializationPromise();
|
||||
try {
|
||||
// Dispatch the initialize request first
|
||||
// This registers SDK MCP servers in the control context
|
||||
await this.dispatcher?.dispatch(request);
|
||||
|
||||
// Get sendSdkMcpMessage callback from SdkMcpController
|
||||
// This callback is used by McpClientManager to send MCP messages
|
||||
// from CLI MCP clients to SDK MCP servers via the control plane
|
||||
const sendSdkMcpMessage =
|
||||
this.dispatcher?.sdkMcpController.createSendSdkMcpMessage();
|
||||
|
||||
// Initialize config with SDK MCP message support
|
||||
await this.ensureConfigInitialized({ sendSdkMcpMessage });
|
||||
|
||||
// Initialization complete!
|
||||
this.completeInitialization();
|
||||
} catch (error) {
|
||||
if (this.debugMode) {
|
||||
console.error('[Session] SDK mode initialization failed:', error);
|
||||
}
|
||||
this.failInitialization(
|
||||
error instanceof Error ? error : new Error(String(error)),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Direct mode initialization flow
|
||||
* Initializes config and enqueues the first user message
|
||||
*/
|
||||
private async initializeDirectMode(
|
||||
userMessage: CLIUserMessage,
|
||||
): Promise<void> {
|
||||
this.ensureInitializationPromise();
|
||||
try {
|
||||
// Initialize config
|
||||
await this.ensureConfigInitialized();
|
||||
|
||||
// Initialization complete!
|
||||
this.completeInitialization();
|
||||
|
||||
// Enqueue the first user message for processing
|
||||
this.enqueueUserMessage(userMessage);
|
||||
} catch (error) {
|
||||
if (this.debugMode) {
|
||||
console.error('[Session] Direct mode initialization failed:', error);
|
||||
}
|
||||
this.failInitialization(
|
||||
error instanceof Error ? error : new Error(String(error)),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle control request asynchronously (fire-and-forget from main loop).
|
||||
* Errors are handled internally and responses sent by dispatcher.
|
||||
*/
|
||||
private handleControlRequestAsync(request: CLIControlRequest): void {
|
||||
const dispatcher = this.getDispatcher();
|
||||
if (!dispatcher) {
|
||||
if (this.debugMode) {
|
||||
@@ -171,9 +303,20 @@ class Session {
|
||||
return;
|
||||
}
|
||||
|
||||
await dispatcher.dispatch(request);
|
||||
// Fire-and-forget: dispatch runs concurrently
|
||||
// The dispatcher's pendingIncomingRequests tracks completion
|
||||
void dispatcher.dispatch(request).catch((error) => {
|
||||
if (this.debugMode) {
|
||||
console.error('[Session] Control request dispatch error:', error);
|
||||
}
|
||||
// Error response is already sent by dispatcher.dispatch()
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle control response - MUST be synchronous
|
||||
* This resolves pending outgoing requests, breaking the deadlock cycle.
|
||||
*/
|
||||
private handleControlResponse(response: CLIControlResponse): void {
|
||||
const dispatcher = this.getDispatcher();
|
||||
if (!dispatcher) {
|
||||
@@ -201,8 +344,8 @@ class Session {
|
||||
return;
|
||||
}
|
||||
|
||||
// Ensure config is initialized before processing user messages
|
||||
await this.ensureConfigInitialized();
|
||||
// Wait for initialization to complete before processing user messages
|
||||
await this.waitForInitialization();
|
||||
|
||||
const promptId = this.getNextPromptId();
|
||||
|
||||
@@ -307,6 +450,45 @@ class Session {
|
||||
process.on('SIGTERM', this.shutdownHandler);
|
||||
}
|
||||
|
||||
/**
|
||||
* Wait for all pending work to complete before shutdown
|
||||
*/
|
||||
private async waitForAllPendingWork(): Promise<void> {
|
||||
// 1. Wait for initialization to complete (or fail)
|
||||
try {
|
||||
await this.waitForInitialization();
|
||||
} catch (error) {
|
||||
if (this.debugMode) {
|
||||
console.error('[Session] Initialization error during shutdown:', error);
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Wait for all control request handlers using dispatcher's tracking
|
||||
if (this.dispatcher) {
|
||||
const pendingCount = this.dispatcher.getPendingIncomingRequestCount();
|
||||
if (pendingCount > 0 && this.debugMode) {
|
||||
console.error(
|
||||
`[Session] Waiting for ${pendingCount} pending control request handlers`,
|
||||
);
|
||||
}
|
||||
await this.dispatcher.waitForPendingIncomingRequests();
|
||||
}
|
||||
|
||||
// 3. Wait for user message processing queue
|
||||
while (this.processingPromise) {
|
||||
if (this.debugMode) {
|
||||
console.error('[Session] Waiting for user message processing');
|
||||
}
|
||||
try {
|
||||
await this.processingPromise;
|
||||
} catch (error) {
|
||||
if (this.debugMode) {
|
||||
console.error('[Session] Error in user message processing:', error);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private async shutdown(): Promise<void> {
|
||||
if (this.debugMode) {
|
||||
console.error('[Session] Shutting down');
|
||||
@@ -314,18 +496,8 @@ class Session {
|
||||
|
||||
this.isShuttingDown = true;
|
||||
|
||||
if (this.processingPromise) {
|
||||
try {
|
||||
await this.processingPromise;
|
||||
} catch (error) {
|
||||
if (this.debugMode) {
|
||||
console.error(
|
||||
'[Session] Error waiting for processing to complete:',
|
||||
error,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Wait for all pending work
|
||||
await this.waitForAllPendingWork();
|
||||
|
||||
this.dispatcher?.shutdown();
|
||||
this.cleanupSignalHandlers();
|
||||
@@ -339,18 +511,30 @@ class Session {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Main message processing loop
|
||||
*
|
||||
* CRITICAL: This loop must NEVER await handlers that might need to
|
||||
* send control requests and wait for responses. Such handlers must
|
||||
* be started in fire-and-forget mode, allowing the loop to continue
|
||||
* reading responses that resolve pending requests.
|
||||
*
|
||||
* Message handling order:
|
||||
* 1. control_response - FIRST, synchronously resolves pending requests
|
||||
* 2. First message - determines mode, starts async initialization
|
||||
* 3. control_request - fire-and-forget, tracked by dispatcher
|
||||
* 4. control_cancel - synchronous
|
||||
* 5. user_message - enqueued for processing
|
||||
*/
|
||||
async run(): Promise<void> {
|
||||
try {
|
||||
if (this.debugMode) {
|
||||
console.error('[Session] Starting session', this.sessionId);
|
||||
}
|
||||
|
||||
// Handle initial prompt if provided (fire-and-forget)
|
||||
if (this.initialPrompt !== null) {
|
||||
const handled = await this.handleFirstMessage(this.initialPrompt);
|
||||
if (handled && this.isShuttingDown) {
|
||||
await this.shutdown();
|
||||
return;
|
||||
}
|
||||
this.handleFirstMessage(this.initialPrompt);
|
||||
}
|
||||
|
||||
try {
|
||||
@@ -359,23 +543,33 @@ class Session {
|
||||
break;
|
||||
}
|
||||
|
||||
if (this.controlSystemEnabled === null) {
|
||||
const handled = await this.handleFirstMessage(message);
|
||||
if (handled) {
|
||||
if (this.isShuttingDown) {
|
||||
break;
|
||||
}
|
||||
// ============================================================
|
||||
// CRITICAL: Handle control_response FIRST and SYNCHRONOUSLY
|
||||
// This resolves pending outgoing requests, breaking deadlock.
|
||||
// ============================================================
|
||||
if (isControlResponse(message)) {
|
||||
this.handleControlResponse(message as CLIControlResponse);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle first message to determine session mode
|
||||
if (this.controlSystemEnabled === null) {
|
||||
this.handleFirstMessage(message);
|
||||
continue;
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// CRITICAL: Handle control_request in FIRE-AND-FORGET mode
|
||||
// DON'T await - let handler run concurrently while loop continues
|
||||
// Dispatcher's pendingIncomingRequests tracks completion
|
||||
// ============================================================
|
||||
if (isControlRequest(message)) {
|
||||
await this.handleControlRequest(message as CLIControlRequest);
|
||||
} else if (isControlResponse(message)) {
|
||||
this.handleControlResponse(message as CLIControlResponse);
|
||||
this.handleControlRequestAsync(message as CLIControlRequest);
|
||||
} else if (isControlCancel(message)) {
|
||||
// Cancel is synchronous - OK to handle inline
|
||||
this.handleControlCancel(message as ControlCancelRequest);
|
||||
} else if (isCLIUserMessage(message)) {
|
||||
// User messages are enqueued, processing runs separately
|
||||
this.enqueueUserMessage(message as CLIUserMessage);
|
||||
} else if (this.debugMode) {
|
||||
if (
|
||||
@@ -402,19 +596,8 @@ class Session {
|
||||
throw streamError;
|
||||
}
|
||||
|
||||
while (this.processingPromise) {
|
||||
if (this.debugMode) {
|
||||
console.error('[Session] Waiting for final processing to complete');
|
||||
}
|
||||
try {
|
||||
await this.processingPromise;
|
||||
} catch (error) {
|
||||
if (this.debugMode) {
|
||||
console.error('[Session] Error in final processing:', error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stream ended - wait for all pending work before shutdown
|
||||
await this.waitForAllPendingWork();
|
||||
await this.shutdown();
|
||||
} catch (error) {
|
||||
if (this.debugMode) {
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import type {
|
||||
MCPServerConfig,
|
||||
SubagentConfig,
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
import type { SubagentConfig } from '@qwen-code/qwen-code-core';
|
||||
|
||||
/**
|
||||
* Annotation for attaching metadata to content blocks
|
||||
@@ -298,11 +295,68 @@ export interface CLIControlPermissionRequest {
|
||||
blocked_path: string | null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Wire format for SDK MCP server config in initialization request.
|
||||
* The actual Server instance stays in the SDK process.
|
||||
*/
|
||||
export interface SDKMcpServerConfig {
|
||||
type: 'sdk';
|
||||
name: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Wire format for external MCP server config in initialization request.
|
||||
* Represents stdio/SSE/HTTP/TCP transports that must run in the CLI process.
|
||||
*/
|
||||
export interface CLIMcpServerConfig {
|
||||
command?: string;
|
||||
args?: string[];
|
||||
env?: Record<string, string>;
|
||||
cwd?: string;
|
||||
url?: string;
|
||||
httpUrl?: string;
|
||||
headers?: Record<string, string>;
|
||||
tcp?: string;
|
||||
timeout?: number;
|
||||
trust?: boolean;
|
||||
description?: string;
|
||||
includeTools?: string[];
|
||||
excludeTools?: string[];
|
||||
extensionName?: string;
|
||||
oauth?: {
|
||||
enabled?: boolean;
|
||||
clientId?: string;
|
||||
clientSecret?: string;
|
||||
authorizationUrl?: string;
|
||||
tokenUrl?: string;
|
||||
scopes?: string[];
|
||||
audiences?: string[];
|
||||
redirectUri?: string;
|
||||
tokenParamName?: string;
|
||||
registrationUrl?: string;
|
||||
};
|
||||
authProviderType?:
|
||||
| 'dynamic_discovery'
|
||||
| 'google_credentials'
|
||||
| 'service_account_impersonation';
|
||||
targetAudience?: string;
|
||||
targetServiceAccount?: string;
|
||||
}
|
||||
|
||||
export interface CLIControlInitializeRequest {
|
||||
subtype: 'initialize';
|
||||
hooks?: HookRegistration[] | null;
|
||||
sdkMcpServers?: Record<string, MCPServerConfig>;
|
||||
mcpServers?: Record<string, MCPServerConfig>;
|
||||
/**
|
||||
* SDK MCP servers config
|
||||
* These are MCP servers running in the SDK process, connected via control plane.
|
||||
* External MCP servers are configured separately in settings, not via initialization.
|
||||
*/
|
||||
sdkMcpServers?: Record<string, Omit<SDKMcpServerConfig, 'instance'>>;
|
||||
/**
|
||||
* External MCP servers that the SDK wants the CLI to manage.
|
||||
* These run outside the SDK process and require CLI-side transport setup.
|
||||
*/
|
||||
mcpServers?: Record<string, CLIMcpServerConfig>;
|
||||
agents?: SubagentConfig[];
|
||||
}
|
||||
|
||||
|
||||
@@ -63,6 +63,7 @@ vi.mock('../tools/tool-registry', () => {
|
||||
ToolRegistryMock.prototype.registerTool = vi.fn();
|
||||
ToolRegistryMock.prototype.discoverAllTools = vi.fn();
|
||||
ToolRegistryMock.prototype.getAllTools = vi.fn(() => []); // Mock methods if needed
|
||||
ToolRegistryMock.prototype.getAllToolNames = vi.fn(() => []);
|
||||
ToolRegistryMock.prototype.getTool = vi.fn();
|
||||
ToolRegistryMock.prototype.getFunctionDeclarations = vi.fn(() => []);
|
||||
return { ToolRegistry: ToolRegistryMock };
|
||||
|
||||
@@ -46,6 +46,7 @@ import { ExitPlanModeTool } from '../tools/exitPlanMode.js';
|
||||
import { GlobTool } from '../tools/glob.js';
|
||||
import { GrepTool } from '../tools/grep.js';
|
||||
import { LSTool } from '../tools/ls.js';
|
||||
import type { SendSdkMcpMessage } from '../tools/mcp-client.js';
|
||||
import { MemoryTool, setGeminiMdFilename } from '../tools/memoryTool.js';
|
||||
import { ReadFileTool } from '../tools/read-file.js';
|
||||
import { ReadManyFilesTool } from '../tools/read-many-files.js';
|
||||
@@ -239,9 +240,18 @@ export class MCPServerConfig {
|
||||
readonly targetAudience?: string,
|
||||
/* targetServiceAccount format: <service-account-name>@<project-num>.iam.gserviceaccount.com */
|
||||
readonly targetServiceAccount?: string,
|
||||
// SDK MCP server type - 'sdk' indicates server runs in SDK process
|
||||
readonly type?: 'sdk',
|
||||
) {}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if an MCP server config represents an SDK server
|
||||
*/
|
||||
export function isSdkMcpServerConfig(config: MCPServerConfig): boolean {
|
||||
return config.type === 'sdk';
|
||||
}
|
||||
|
||||
export enum AuthProviderType {
|
||||
DYNAMIC_DISCOVERY = 'dynamic_discovery',
|
||||
GOOGLE_CREDENTIALS = 'google_credentials',
|
||||
@@ -360,6 +370,17 @@ function normalizeConfigOutputFormat(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Options for Config.initialize()
|
||||
*/
|
||||
export interface ConfigInitializeOptions {
|
||||
/**
|
||||
* Callback for sending MCP messages to SDK servers via control plane.
|
||||
* Required for SDK MCP server support in SDK mode.
|
||||
*/
|
||||
sendSdkMcpMessage?: SendSdkMcpMessage;
|
||||
}
|
||||
|
||||
export class Config {
|
||||
private sessionId: string;
|
||||
private sessionData?: ResumedSessionData;
|
||||
@@ -599,8 +620,9 @@ export class Config {
|
||||
|
||||
/**
|
||||
* Must only be called once, throws if called again.
|
||||
* @param options Optional initialization options including sendSdkMcpMessage callback
|
||||
*/
|
||||
async initialize(): Promise<void> {
|
||||
async initialize(options?: ConfigInitializeOptions): Promise<void> {
|
||||
if (this.initialized) {
|
||||
throw Error('Config was already initialized');
|
||||
}
|
||||
@@ -619,7 +641,9 @@ export class Config {
|
||||
this.subagentManager.loadSessionSubagents(this.sessionSubagents);
|
||||
}
|
||||
|
||||
this.toolRegistry = await this.createToolRegistry();
|
||||
this.toolRegistry = await this.createToolRegistry(
|
||||
options?.sendSdkMcpMessage,
|
||||
);
|
||||
|
||||
await this.geminiClient.initialize();
|
||||
|
||||
@@ -1261,8 +1285,14 @@ export class Config {
|
||||
return this.subagentManager;
|
||||
}
|
||||
|
||||
async createToolRegistry(): Promise<ToolRegistry> {
|
||||
const registry = new ToolRegistry(this, this.eventEmitter);
|
||||
async createToolRegistry(
|
||||
sendSdkMcpMessage?: SendSdkMcpMessage,
|
||||
): Promise<ToolRegistry> {
|
||||
const registry = new ToolRegistry(
|
||||
this,
|
||||
this.eventEmitter,
|
||||
sendSdkMcpMessage,
|
||||
);
|
||||
|
||||
const coreToolsConfig = this.getCoreTools();
|
||||
const excludeToolsConfig = this.getExcludeTools();
|
||||
@@ -1347,6 +1377,7 @@ export class Config {
|
||||
}
|
||||
|
||||
await registry.discoverAllTools();
|
||||
console.debug('ToolRegistry created', registry.getAllToolNames());
|
||||
return registry;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,7 +102,9 @@ export * from './tools/shell.js';
|
||||
export * from './tools/web-search/index.js';
|
||||
export * from './tools/read-many-files.js';
|
||||
export * from './tools/mcp-client.js';
|
||||
export * from './tools/mcp-client-manager.js';
|
||||
export * from './tools/mcp-tool.js';
|
||||
export * from './tools/sdk-control-client-transport.js';
|
||||
export * from './tools/task.js';
|
||||
export * from './tools/todoWrite.js';
|
||||
export * from './tools/exitPlanMode.js';
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
*/
|
||||
|
||||
import type { Config, MCPServerConfig } from '../config/config.js';
|
||||
import { isSdkMcpServerConfig } from '../config/config.js';
|
||||
import type { ToolRegistry } from './tool-registry.js';
|
||||
import type { PromptRegistry } from '../prompts/prompt-registry.js';
|
||||
import {
|
||||
@@ -12,6 +13,7 @@ import {
|
||||
MCPDiscoveryState,
|
||||
populateMcpServerCommand,
|
||||
} from './mcp-client.js';
|
||||
import type { SendSdkMcpMessage } from './mcp-client.js';
|
||||
import { getErrorMessage } from '../utils/errors.js';
|
||||
import type { EventEmitter } from 'node:events';
|
||||
import type { WorkspaceContext } from '../utils/workspaceContext.js';
|
||||
@@ -31,6 +33,7 @@ export class McpClientManager {
|
||||
private readonly workspaceContext: WorkspaceContext;
|
||||
private discoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED;
|
||||
private readonly eventEmitter?: EventEmitter;
|
||||
private readonly sendSdkMcpMessage?: SendSdkMcpMessage;
|
||||
|
||||
constructor(
|
||||
mcpServers: Record<string, MCPServerConfig>,
|
||||
@@ -40,6 +43,7 @@ export class McpClientManager {
|
||||
debugMode: boolean,
|
||||
workspaceContext: WorkspaceContext,
|
||||
eventEmitter?: EventEmitter,
|
||||
sendSdkMcpMessage?: SendSdkMcpMessage,
|
||||
) {
|
||||
this.mcpServers = mcpServers;
|
||||
this.mcpServerCommand = mcpServerCommand;
|
||||
@@ -48,6 +52,7 @@ export class McpClientManager {
|
||||
this.debugMode = debugMode;
|
||||
this.workspaceContext = workspaceContext;
|
||||
this.eventEmitter = eventEmitter;
|
||||
this.sendSdkMcpMessage = sendSdkMcpMessage;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -71,6 +76,11 @@ export class McpClientManager {
|
||||
this.eventEmitter?.emit('mcp-client-update', this.clients);
|
||||
const discoveryPromises = Object.entries(servers).map(
|
||||
async ([name, config]) => {
|
||||
// For SDK MCP servers, pass the sendSdkMcpMessage callback
|
||||
const sdkCallback = isSdkMcpServerConfig(config)
|
||||
? this.sendSdkMcpMessage
|
||||
: undefined;
|
||||
|
||||
const client = new McpClient(
|
||||
name,
|
||||
config,
|
||||
@@ -78,6 +88,7 @@ export class McpClientManager {
|
||||
this.promptRegistry,
|
||||
this.workspaceContext,
|
||||
this.debugMode,
|
||||
sdkCallback,
|
||||
);
|
||||
this.clients.set(name, client);
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/
|
||||
import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js';
|
||||
import type {
|
||||
GetPromptResult,
|
||||
JSONRPCMessage,
|
||||
Prompt,
|
||||
} from '@modelcontextprotocol/sdk/types.js';
|
||||
import {
|
||||
@@ -22,10 +23,11 @@ import {
|
||||
} from '@modelcontextprotocol/sdk/types.js';
|
||||
import { parse } from 'shell-quote';
|
||||
import type { Config, MCPServerConfig } from '../config/config.js';
|
||||
import { AuthProviderType } from '../config/config.js';
|
||||
import { AuthProviderType, isSdkMcpServerConfig } from '../config/config.js';
|
||||
import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js';
|
||||
import { ServiceAccountImpersonationProvider } from '../mcp/sa-impersonation-provider.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
import { SdkControlClientTransport } from './sdk-control-client-transport.js';
|
||||
|
||||
import type { FunctionDeclaration } from '@google/genai';
|
||||
import { mcpToTool } from '@google/genai';
|
||||
@@ -42,6 +44,14 @@ import type {
|
||||
} from '../utils/workspaceContext.js';
|
||||
import type { ToolRegistry } from './tool-registry.js';
|
||||
|
||||
/**
|
||||
* Callback type for sending MCP messages to SDK servers via control plane
|
||||
*/
|
||||
export type SendSdkMcpMessage = (
|
||||
serverName: string,
|
||||
message: JSONRPCMessage,
|
||||
) => Promise<JSONRPCMessage>;
|
||||
|
||||
export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes
|
||||
|
||||
export type DiscoveredMCPPrompt = Prompt & {
|
||||
@@ -92,6 +102,7 @@ export class McpClient {
|
||||
private readonly promptRegistry: PromptRegistry,
|
||||
private readonly workspaceContext: WorkspaceContext,
|
||||
private readonly debugMode: boolean,
|
||||
private readonly sendSdkMcpMessage?: SendSdkMcpMessage,
|
||||
) {
|
||||
this.client = new Client({
|
||||
name: `qwen-cli-mcp-client-${this.serverName}`,
|
||||
@@ -189,7 +200,12 @@ export class McpClient {
|
||||
}
|
||||
|
||||
private async createTransport(): Promise<Transport> {
|
||||
return createTransport(this.serverName, this.serverConfig, this.debugMode);
|
||||
return createTransport(
|
||||
this.serverName,
|
||||
this.serverConfig,
|
||||
this.debugMode,
|
||||
this.sendSdkMcpMessage,
|
||||
);
|
||||
}
|
||||
|
||||
private async discoverTools(cliConfig: Config): Promise<DiscoveredMCPTool[]> {
|
||||
@@ -501,6 +517,7 @@ export function populateMcpServerCommand(
|
||||
* @param mcpServerName The name identifier for this MCP server
|
||||
* @param mcpServerConfig Configuration object containing connection details
|
||||
* @param toolRegistry The registry to register discovered tools with
|
||||
* @param sendSdkMcpMessage Optional callback for SDK MCP servers to route messages via control plane.
|
||||
* @returns Promise that resolves when discovery is complete
|
||||
*/
|
||||
export async function connectAndDiscover(
|
||||
@@ -511,6 +528,7 @@ export async function connectAndDiscover(
|
||||
debugMode: boolean,
|
||||
workspaceContext: WorkspaceContext,
|
||||
cliConfig: Config,
|
||||
sendSdkMcpMessage?: SendSdkMcpMessage,
|
||||
): Promise<void> {
|
||||
updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING);
|
||||
|
||||
@@ -521,6 +539,7 @@ export async function connectAndDiscover(
|
||||
mcpServerConfig,
|
||||
debugMode,
|
||||
workspaceContext,
|
||||
sendSdkMcpMessage,
|
||||
);
|
||||
|
||||
mcpClient.onerror = (error) => {
|
||||
@@ -744,6 +763,7 @@ export function hasNetworkTransport(config: MCPServerConfig): boolean {
|
||||
*
|
||||
* @param mcpServerName The name of the MCP server, used for logging and identification.
|
||||
* @param mcpServerConfig The configuration specifying how to connect to the server.
|
||||
* @param sendSdkMcpMessage Optional callback for SDK MCP servers to route messages via control plane.
|
||||
* @returns A promise that resolves to a connected MCP `Client` instance.
|
||||
* @throws An error if the connection fails or the configuration is invalid.
|
||||
*/
|
||||
@@ -752,6 +772,7 @@ export async function connectToMcpServer(
|
||||
mcpServerConfig: MCPServerConfig,
|
||||
debugMode: boolean,
|
||||
workspaceContext: WorkspaceContext,
|
||||
sendSdkMcpMessage?: SendSdkMcpMessage,
|
||||
): Promise<Client> {
|
||||
const mcpClient = new Client({
|
||||
name: 'qwen-code-mcp-client',
|
||||
@@ -808,6 +829,7 @@ export async function connectToMcpServer(
|
||||
mcpServerName,
|
||||
mcpServerConfig,
|
||||
debugMode,
|
||||
sendSdkMcpMessage,
|
||||
);
|
||||
try {
|
||||
await mcpClient.connect(transport, {
|
||||
@@ -1172,7 +1194,21 @@ export async function createTransport(
|
||||
mcpServerName: string,
|
||||
mcpServerConfig: MCPServerConfig,
|
||||
debugMode: boolean,
|
||||
sendSdkMcpMessage?: SendSdkMcpMessage,
|
||||
): Promise<Transport> {
|
||||
if (isSdkMcpServerConfig(mcpServerConfig)) {
|
||||
if (!sendSdkMcpMessage) {
|
||||
throw new Error(
|
||||
`SDK MCP server '${mcpServerName}' requires sendSdkMcpMessage callback`,
|
||||
);
|
||||
}
|
||||
return new SdkControlClientTransport({
|
||||
serverName: mcpServerName,
|
||||
sendMcpMessage: sendSdkMcpMessage,
|
||||
debugMode,
|
||||
});
|
||||
}
|
||||
|
||||
if (
|
||||
mcpServerConfig.authProviderType ===
|
||||
AuthProviderType.SERVICE_ACCOUNT_IMPERSONATION
|
||||
|
||||
163
packages/core/src/tools/sdk-control-client-transport.ts
Normal file
163
packages/core/src/tools/sdk-control-client-transport.ts
Normal file
@@ -0,0 +1,163 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen Team
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
/**
|
||||
* SdkControlClientTransport - MCP Client transport for SDK MCP servers
|
||||
*
|
||||
* This transport enables CLI's MCP client to connect to SDK MCP servers
|
||||
* through the control plane. Messages are routed:
|
||||
*
|
||||
* CLI MCP Client → SdkControlClientTransport → sendMcpMessage() →
|
||||
* control_request (mcp_message) → SDK → control_response → onmessage → CLI
|
||||
*
|
||||
* Unlike StdioClientTransport which spawns a subprocess, this transport
|
||||
* communicates with SDK MCP servers running in the SDK process.
|
||||
*/
|
||||
|
||||
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
|
||||
|
||||
/**
|
||||
* Callback to send MCP messages to SDK via control plane
|
||||
* Returns the MCP response from the SDK
|
||||
*/
|
||||
export type SendMcpMessageCallback = (
|
||||
serverName: string,
|
||||
message: JSONRPCMessage,
|
||||
) => Promise<JSONRPCMessage>;
|
||||
|
||||
export interface SdkControlClientTransportOptions {
|
||||
serverName: string;
|
||||
sendMcpMessage: SendMcpMessageCallback;
|
||||
debugMode?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* MCP Client Transport for SDK MCP servers
|
||||
*
|
||||
* Implements the @modelcontextprotocol/sdk Transport interface to enable
|
||||
* CLI's MCP client to connect to SDK MCP servers via the control plane.
|
||||
*/
|
||||
export class SdkControlClientTransport {
|
||||
private serverName: string;
|
||||
private sendMcpMessage: SendMcpMessageCallback;
|
||||
private debugMode: boolean;
|
||||
private started = false;
|
||||
|
||||
// Transport interface callbacks
|
||||
onmessage?: (message: JSONRPCMessage) => void;
|
||||
onerror?: (error: Error) => void;
|
||||
onclose?: () => void;
|
||||
|
||||
constructor(options: SdkControlClientTransportOptions) {
|
||||
this.serverName = options.serverName;
|
||||
this.sendMcpMessage = options.sendMcpMessage;
|
||||
this.debugMode = options.debugMode ?? false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Start the transport
|
||||
* For SDK transport, this just marks it as ready - no subprocess to spawn
|
||||
*/
|
||||
async start(): Promise<void> {
|
||||
if (this.started) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.started = true;
|
||||
|
||||
if (this.debugMode) {
|
||||
console.error(
|
||||
`[SdkControlClientTransport] Started for server '${this.serverName}'`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Send a message to the SDK MCP server via control plane
|
||||
*
|
||||
* Routes the message through the control plane and delivers
|
||||
* the response via onmessage callback.
|
||||
*/
|
||||
async send(message: JSONRPCMessage): Promise<void> {
|
||||
if (!this.started) {
|
||||
throw new Error(
|
||||
`SdkControlClientTransport (${this.serverName}) not started. Call start() first.`,
|
||||
);
|
||||
}
|
||||
|
||||
if (this.debugMode) {
|
||||
console.error(
|
||||
`[SdkControlClientTransport] Sending message to '${this.serverName}':`,
|
||||
JSON.stringify(message),
|
||||
);
|
||||
}
|
||||
|
||||
try {
|
||||
// Send message to SDK and wait for response
|
||||
const response = await this.sendMcpMessage(this.serverName, message);
|
||||
|
||||
if (this.debugMode) {
|
||||
console.error(
|
||||
`[SdkControlClientTransport] Received response from '${this.serverName}':`,
|
||||
JSON.stringify(response),
|
||||
);
|
||||
}
|
||||
|
||||
// Deliver response via onmessage callback
|
||||
if (this.onmessage) {
|
||||
this.onmessage(response);
|
||||
}
|
||||
} catch (error) {
|
||||
if (this.debugMode) {
|
||||
console.error(
|
||||
`[SdkControlClientTransport] Error sending to '${this.serverName}':`,
|
||||
error,
|
||||
);
|
||||
}
|
||||
|
||||
if (this.onerror) {
|
||||
this.onerror(error instanceof Error ? error : new Error(String(error)));
|
||||
}
|
||||
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Close the transport
|
||||
*/
|
||||
async close(): Promise<void> {
|
||||
if (!this.started) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.started = false;
|
||||
|
||||
if (this.debugMode) {
|
||||
console.error(
|
||||
`[SdkControlClientTransport] Closed for server '${this.serverName}'`,
|
||||
);
|
||||
}
|
||||
|
||||
if (this.onclose) {
|
||||
this.onclose();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if transport is started
|
||||
*/
|
||||
isStarted(): boolean {
|
||||
return this.started;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get server name
|
||||
*/
|
||||
getServerName(): string {
|
||||
return this.serverName;
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import type { Config } from '../config/config.js';
|
||||
import { spawn } from 'node:child_process';
|
||||
import { StringDecoder } from 'node:string_decoder';
|
||||
import { connectAndDiscover } from './mcp-client.js';
|
||||
import type { SendSdkMcpMessage } from './mcp-client.js';
|
||||
import { McpClientManager } from './mcp-client-manager.js';
|
||||
import { DiscoveredMCPTool } from './mcp-tool.js';
|
||||
import { parse } from 'shell-quote';
|
||||
@@ -173,7 +174,11 @@ export class ToolRegistry {
|
||||
private config: Config;
|
||||
private mcpClientManager: McpClientManager;
|
||||
|
||||
constructor(config: Config, eventEmitter?: EventEmitter) {
|
||||
constructor(
|
||||
config: Config,
|
||||
eventEmitter?: EventEmitter,
|
||||
sendSdkMcpMessage?: SendSdkMcpMessage,
|
||||
) {
|
||||
this.config = config;
|
||||
this.mcpClientManager = new McpClientManager(
|
||||
this.config.getMcpServers() ?? {},
|
||||
@@ -183,6 +188,7 @@ export class ToolRegistry {
|
||||
this.config.getDebugMode(),
|
||||
this.config.getWorkspaceContext(),
|
||||
eventEmitter,
|
||||
sendSdkMcpMessage,
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ npm install @qwen-code/sdk-typescript
|
||||
## Requirements
|
||||
|
||||
- Node.js >= 20.0.0
|
||||
- [Qwen Code](https://github.com/QwenLM/qwen-code) installed and accessible in PATH
|
||||
- [Qwen Code](https://github.com/QwenLM/qwen-code) >= 0.4.0 (stable) installed and accessible in PATH
|
||||
|
||||
> **Note for nvm users**: If you use nvm to manage Node.js versions, the SDK may not be able to auto-detect the Qwen Code executable. You should explicitly set the `pathToQwenExecutable` option to the full path of the `qwen` binary.
|
||||
|
||||
@@ -61,7 +61,7 @@ Creates a new query session with the Qwen Code.
|
||||
| `permissionMode` | `'default' \| 'plan' \| 'auto-edit' \| 'yolo'` | `'default'` | Permission mode controlling tool execution approval. See [Permission Modes](#permission-modes) for details. |
|
||||
| `canUseTool` | `CanUseTool` | - | Custom permission handler for tool execution approval. Invoked when a tool requires confirmation. Must respond within 30 seconds or the request will be auto-denied. See [Custom Permission Handler](#custom-permission-handler). |
|
||||
| `env` | `Record<string, string>` | - | Environment variables to pass to the Qwen Code process. Merged with the current process environment. |
|
||||
| `mcpServers` | `Record<string, ExternalMcpServerConfig>` | - | External MCP (Model Context Protocol) servers to connect. Each server is identified by a unique name and configured with `command`, `args`, and `env`. |
|
||||
| `mcpServers` | `Record<string, McpServerConfig>` | - | MCP (Model Context Protocol) servers to connect. Supports external servers (stdio/SSE/HTTP) and SDK-embedded servers. External servers are configured with transport options like `command`, `args`, `url`, `httpUrl`, etc. SDK servers use `{ type: 'sdk', name: string, instance: Server }`. |
|
||||
| `abortController` | `AbortController` | - | Controller to cancel the query session. Call `abortController.abort()` to terminate the session and cleanup resources. |
|
||||
| `debug` | `boolean` | `false` | Enable debug mode for verbose logging from the CLI process. |
|
||||
| `maxSessionTurns` | `number` | `-1` (unlimited) | Maximum number of conversation turns before the session automatically terminates. A turn consists of a user message and an assistant response. |
|
||||
@@ -74,12 +74,27 @@ Creates a new query session with the Qwen Code.
|
||||
|
||||
### Timeouts
|
||||
|
||||
The SDK enforces the following timeouts:
|
||||
The SDK enforces the following default timeouts:
|
||||
|
||||
| Timeout | Duration | Description |
|
||||
| ------------------- | ---------- | ---------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Permission Callback | 30 seconds | Maximum time for `canUseTool` callback to respond. If exceeded, the tool request is auto-denied. |
|
||||
| Control Request | 30 seconds | Maximum time for control operations like `initialize()`, `setModel()`, `setPermissionMode()`, and `interrupt()` to complete. |
|
||||
| Timeout | Default | Description |
|
||||
| ---------------- | ---------- | ---------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `canUseTool` | 30 seconds | Maximum time for `canUseTool` callback to respond. If exceeded, the tool request is auto-denied. |
|
||||
| `mcpRequest` | 1 minute | Maximum time for SDK MCP tool calls to complete. |
|
||||
| `controlRequest` | 30 seconds | Maximum time for control operations like `initialize()`, `setModel()`, `setPermissionMode()`, and `interrupt()` to complete. |
|
||||
| `streamClose` | 1 minute | Maximum time to wait for initialization to complete before closing CLI stdin in multi-turn mode with SDK MCP servers. |
|
||||
|
||||
You can customize these timeouts via the `timeout` option:
|
||||
|
||||
```typescript
|
||||
const query = qwen.query('Your prompt', {
|
||||
timeout: {
|
||||
canUseTool: 60000, // 60 seconds for permission callback
|
||||
mcpRequest: 600000, // 10 minutes for MCP tool calls
|
||||
controlRequest: 60000, // 60 seconds for control requests
|
||||
streamClose: 15000, // 15 seconds for stream close wait
|
||||
},
|
||||
});
|
||||
```
|
||||
|
||||
### Message Types
|
||||
|
||||
@@ -212,7 +227,7 @@ const result = query({
|
||||
});
|
||||
```
|
||||
|
||||
### With MCP Servers
|
||||
### With External MCP Servers
|
||||
|
||||
```typescript
|
||||
import { query } from '@qwen-code/sdk-typescript';
|
||||
@@ -231,6 +246,84 @@ const result = query({
|
||||
});
|
||||
```
|
||||
|
||||
### With SDK-Embedded MCP Servers
|
||||
|
||||
The SDK provides `tool` and `createSdkMcpServer` to create MCP servers that run in the same process as your SDK application. This is useful when you want to expose custom tools to the AI without running a separate server process.
|
||||
|
||||
#### `tool(name, description, inputSchema, handler)`
|
||||
|
||||
Creates a tool definition with Zod schema type inference.
|
||||
|
||||
| Parameter | Type | Description |
|
||||
| ------------- | ---------------------------------- | ------------------------------------------------------------------------ |
|
||||
| `name` | `string` | Tool name (1-64 chars, starts with letter, alphanumeric and underscores) |
|
||||
| `description` | `string` | Human-readable description of what the tool does |
|
||||
| `inputSchema` | `ZodRawShape` | Zod schema object defining the tool's input parameters |
|
||||
| `handler` | `(args, extra) => Promise<Result>` | Async function that executes the tool and returns MCP content blocks |
|
||||
|
||||
The handler must return a `CallToolResult` object with the following structure:
|
||||
|
||||
```typescript
|
||||
{
|
||||
content: Array<
|
||||
| { type: 'text'; text: string }
|
||||
| { type: 'image'; data: string; mimeType: string }
|
||||
| { type: 'resource'; uri: string; mimeType?: string; text?: string }
|
||||
>;
|
||||
isError?: boolean;
|
||||
}
|
||||
```
|
||||
|
||||
#### `createSdkMcpServer(options)`
|
||||
|
||||
Creates an SDK-embedded MCP server instance.
|
||||
|
||||
| Option | Type | Default | Description |
|
||||
| --------- | ------------------------ | --------- | ------------------------------------ |
|
||||
| `name` | `string` | Required | Unique name for the MCP server |
|
||||
| `version` | `string` | `'1.0.0'` | Server version |
|
||||
| `tools` | `SdkMcpToolDefinition[]` | - | Array of tools created with `tool()` |
|
||||
|
||||
Returns a `McpSdkServerConfigWithInstance` object that can be passed directly to the `mcpServers` option.
|
||||
|
||||
#### Example
|
||||
|
||||
```typescript
|
||||
import { z } from 'zod';
|
||||
import { query, tool, createSdkMcpServer } from '@qwen-code/sdk-typescript';
|
||||
|
||||
// Define a tool with Zod schema
|
||||
const calculatorTool = tool(
|
||||
'calculate_sum',
|
||||
'Add two numbers',
|
||||
{ a: z.number(), b: z.number() },
|
||||
async (args) => ({
|
||||
content: [{ type: 'text', text: String(args.a + args.b) }],
|
||||
}),
|
||||
);
|
||||
|
||||
// Create the MCP server
|
||||
const server = createSdkMcpServer({
|
||||
name: 'calculator',
|
||||
tools: [calculatorTool],
|
||||
});
|
||||
|
||||
// Use the server in a query
|
||||
const result = query({
|
||||
prompt: 'What is 42 + 17?',
|
||||
options: {
|
||||
permissionMode: 'yolo',
|
||||
mcpServers: {
|
||||
calculator: server,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
for await (const message of result) {
|
||||
console.log(message);
|
||||
}
|
||||
```
|
||||
|
||||
### Abort a Query
|
||||
|
||||
```typescript
|
||||
|
||||
@@ -3,6 +3,17 @@ export { AbortError, isAbortError } from './types/errors.js';
|
||||
export { Query } from './query/Query.js';
|
||||
export { SdkLogger } from './utils/logger.js';
|
||||
|
||||
// SDK MCP Server exports
|
||||
export { tool } from './mcp/tool.js';
|
||||
export { createSdkMcpServer } from './mcp/createSdkMcpServer.js';
|
||||
|
||||
export type { SdkMcpToolDefinition } from './mcp/tool.js';
|
||||
|
||||
export type {
|
||||
CreateSdkMcpServerOptions,
|
||||
McpSdkServerConfigWithInstance,
|
||||
} from './mcp/createSdkMcpServer.js';
|
||||
|
||||
export type { QueryOptions } from './query/createQuery.js';
|
||||
export type { LogLevel, LoggerConfig, ScopedLogger } from './utils/logger.js';
|
||||
|
||||
@@ -18,6 +29,7 @@ export type {
|
||||
SDKResultMessage,
|
||||
SDKPartialAssistantMessage,
|
||||
SDKMessage,
|
||||
SDKMcpServerConfig,
|
||||
ControlMessage,
|
||||
CLIControlRequest,
|
||||
CLIControlResponse,
|
||||
@@ -43,6 +55,10 @@ export type {
|
||||
PermissionMode,
|
||||
CanUseTool,
|
||||
PermissionResult,
|
||||
ExternalMcpServerConfig,
|
||||
SdkMcpServerConfig,
|
||||
CLIMcpServerConfig,
|
||||
McpServerConfig,
|
||||
McpOAuthConfig,
|
||||
McpAuthProviderType,
|
||||
} from './types/types.js';
|
||||
|
||||
export { isSdkMcpServerConfig } from './types/types.js';
|
||||
|
||||
@@ -103,9 +103,3 @@ export class SdkControlServerTransport {
|
||||
return this.serverName;
|
||||
}
|
||||
}
|
||||
|
||||
export function createSdkControlServerTransport(
|
||||
options: SdkControlServerTransportOptions,
|
||||
): SdkControlServerTransport {
|
||||
return new SdkControlServerTransport(options);
|
||||
}
|
||||
|
||||
@@ -1,29 +1,63 @@
|
||||
/**
|
||||
* Factory function to create SDK-embedded MCP servers
|
||||
*
|
||||
* Creates MCP Server instances that run in the user's Node.js process
|
||||
* and are proxied to the CLI via the control plane.
|
||||
* @license
|
||||
* Copyright 2025 Qwen Team
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { Server } from '@modelcontextprotocol/sdk/server/index.js';
|
||||
import {
|
||||
ListToolsRequestSchema,
|
||||
CallToolRequestSchema,
|
||||
type CallToolResultSchema,
|
||||
} from '@modelcontextprotocol/sdk/types.js';
|
||||
import type { ToolDefinition } from '../types/types.js';
|
||||
import { formatToolResult, formatToolError } from './formatters.js';
|
||||
/**
|
||||
* Factory function to create SDK-embedded MCP servers
|
||||
*/
|
||||
|
||||
import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
|
||||
import type { SdkMcpToolDefinition } from './tool.js';
|
||||
import { validateToolName } from './tool.js';
|
||||
import type { z } from 'zod';
|
||||
|
||||
type CallToolResult = z.infer<typeof CallToolResultSchema>;
|
||||
/**
|
||||
* Options for creating an SDK MCP server
|
||||
*/
|
||||
export type CreateSdkMcpServerOptions = {
|
||||
name: string;
|
||||
version?: string;
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
tools?: Array<SdkMcpToolDefinition<any>>;
|
||||
};
|
||||
|
||||
/**
|
||||
* SDK MCP Server configuration with instance
|
||||
*/
|
||||
export type McpSdkServerConfigWithInstance = {
|
||||
type: 'sdk';
|
||||
name: string;
|
||||
instance: McpServer;
|
||||
};
|
||||
|
||||
/**
|
||||
* Creates an MCP server instance that can be used with the SDK transport.
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* import { z } from 'zod';
|
||||
* import { tool, createSdkMcpServer } from '@qwen-code/sdk-typescript';
|
||||
*
|
||||
* const calculatorTool = tool(
|
||||
* 'calculate_sum',
|
||||
* 'Add two numbers',
|
||||
* { a: z.number(), b: z.number() },
|
||||
* async (args) => ({ content: [{ type: 'text', text: String(args.a + args.b) }] })
|
||||
* );
|
||||
*
|
||||
* const server = createSdkMcpServer({
|
||||
* name: 'calculator',
|
||||
* version: '1.0.0',
|
||||
* tools: [calculatorTool],
|
||||
* });
|
||||
* ```
|
||||
*/
|
||||
export function createSdkMcpServer(
|
||||
name: string,
|
||||
version: string,
|
||||
tools: ToolDefinition[],
|
||||
): Server {
|
||||
// Validate server name
|
||||
options: CreateSdkMcpServerOptions,
|
||||
): McpSdkServerConfigWithInstance {
|
||||
const { name, version = '1.0.0', tools } = options;
|
||||
|
||||
if (!name || typeof name !== 'string') {
|
||||
throw new Error('MCP server name must be a non-empty string');
|
||||
}
|
||||
@@ -32,78 +66,42 @@ export function createSdkMcpServer(
|
||||
throw new Error('MCP server version must be a non-empty string');
|
||||
}
|
||||
|
||||
if (!Array.isArray(tools)) {
|
||||
if (tools !== undefined && !Array.isArray(tools)) {
|
||||
throw new Error('Tools must be an array');
|
||||
}
|
||||
|
||||
// Validate tool names are unique
|
||||
const toolNames = new Set<string>();
|
||||
for (const tool of tools) {
|
||||
validateToolName(tool.name);
|
||||
|
||||
if (toolNames.has(tool.name)) {
|
||||
if (tools) {
|
||||
for (const t of tools) {
|
||||
validateToolName(t.name);
|
||||
if (toolNames.has(t.name)) {
|
||||
throw new Error(
|
||||
`Duplicate tool name '${tool.name}' in MCP server '${name}'`,
|
||||
`Duplicate tool name '${t.name}' in MCP server '${name}'`,
|
||||
);
|
||||
}
|
||||
toolNames.add(tool.name);
|
||||
toolNames.add(t.name);
|
||||
}
|
||||
}
|
||||
|
||||
// Create MCP Server instance
|
||||
const server = new Server(
|
||||
{
|
||||
name,
|
||||
version,
|
||||
},
|
||||
const server = new McpServer(
|
||||
{ name, version },
|
||||
{
|
||||
capabilities: {
|
||||
tools: {},
|
||||
tools: tools ? {} : undefined,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
// Create tool map for fast lookup
|
||||
const toolMap = new Map<string, ToolDefinition>();
|
||||
for (const tool of tools) {
|
||||
toolMap.set(tool.name, tool);
|
||||
}
|
||||
|
||||
// Register list_tools handler
|
||||
server.setRequestHandler(ListToolsRequestSchema, async () => ({
|
||||
tools: tools.map((tool) => ({
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
inputSchema: tool.inputSchema,
|
||||
})),
|
||||
}));
|
||||
|
||||
// Register call_tool handler
|
||||
server.setRequestHandler(CallToolRequestSchema, async (request) => {
|
||||
const { name: toolName, arguments: toolArgs } = request.params;
|
||||
|
||||
// Find tool
|
||||
const tool = toolMap.get(toolName);
|
||||
if (!tool) {
|
||||
return formatToolError(
|
||||
new Error(`Tool '${toolName}' not found in server '${name}'`),
|
||||
) as CallToolResult;
|
||||
}
|
||||
|
||||
try {
|
||||
// Invoke tool handler
|
||||
const result = await tool.handler(toolArgs);
|
||||
|
||||
// Format result
|
||||
return formatToolResult(result) as CallToolResult;
|
||||
} catch (error) {
|
||||
// Handle tool execution error
|
||||
return formatToolError(
|
||||
error instanceof Error
|
||||
? error
|
||||
: new Error(`Tool '${toolName}' failed: ${String(error)}`),
|
||||
) as CallToolResult;
|
||||
}
|
||||
if (tools) {
|
||||
tools.forEach((toolDef) => {
|
||||
server.tool(
|
||||
toolDef.name,
|
||||
toolDef.description,
|
||||
toolDef.inputSchema,
|
||||
toolDef.handler,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
return server;
|
||||
return { type: 'sdk', name, instance: server };
|
||||
}
|
||||
|
||||
@@ -1,39 +1,76 @@
|
||||
/**
|
||||
* Tool definition helper for SDK-embedded MCP servers
|
||||
*
|
||||
* Provides type-safe tool definitions with generic input/output types.
|
||||
* @license
|
||||
* Copyright 2025 Qwen Team
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { ToolDefinition } from '../types/types.js';
|
||||
/**
|
||||
* Tool definition helper for SDK-embedded MCP servers
|
||||
*/
|
||||
|
||||
export function tool<TInput = unknown, TOutput = unknown>(
|
||||
def: ToolDefinition<TInput, TOutput>,
|
||||
): ToolDefinition<TInput, TOutput> {
|
||||
// Validate tool definition
|
||||
if (!def.name || typeof def.name !== 'string') {
|
||||
throw new Error('Tool definition must have a name (string)');
|
||||
import type { CallToolResultSchema } from '@modelcontextprotocol/sdk/types.js';
|
||||
import type { z, ZodRawShape, ZodObject, ZodTypeAny } from 'zod';
|
||||
|
||||
type CallToolResult = z.infer<typeof CallToolResultSchema>;
|
||||
|
||||
/**
|
||||
* SDK MCP Tool Definition with Zod schema type inference
|
||||
*/
|
||||
export type SdkMcpToolDefinition<Schema extends ZodRawShape = ZodRawShape> = {
|
||||
name: string;
|
||||
description: string;
|
||||
inputSchema: Schema;
|
||||
handler: (
|
||||
args: z.infer<ZodObject<Schema, 'strip', ZodTypeAny>>,
|
||||
extra: unknown,
|
||||
) => Promise<CallToolResult>;
|
||||
};
|
||||
|
||||
/**
|
||||
* Create an SDK MCP tool definition with Zod schema inference
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* import { z } from 'zod';
|
||||
* import { tool } from '@qwen-code/sdk-typescript';
|
||||
*
|
||||
* const calculatorTool = tool(
|
||||
* 'calculate_sum',
|
||||
* 'Calculate the sum of two numbers',
|
||||
* { a: z.number(), b: z.number() },
|
||||
* async (args) => {
|
||||
* // args is inferred as { a: number, b: number }
|
||||
* return { content: [{ type: 'text', text: String(args.a + args.b) }] };
|
||||
* }
|
||||
* );
|
||||
* ```
|
||||
*/
|
||||
export function tool<Schema extends ZodRawShape>(
|
||||
name: string,
|
||||
description: string,
|
||||
inputSchema: Schema,
|
||||
handler: (
|
||||
args: z.infer<ZodObject<Schema, 'strip', ZodTypeAny>>,
|
||||
extra: unknown,
|
||||
) => Promise<CallToolResult>,
|
||||
): SdkMcpToolDefinition<Schema> {
|
||||
if (!name || typeof name !== 'string') {
|
||||
throw new Error('Tool name must be a non-empty string');
|
||||
}
|
||||
|
||||
if (!def.description || typeof def.description !== 'string') {
|
||||
throw new Error(
|
||||
`Tool definition for '${def.name}' must have a description (string)`,
|
||||
);
|
||||
if (!description || typeof description !== 'string') {
|
||||
throw new Error(`Tool '${name}' must have a description (string)`);
|
||||
}
|
||||
|
||||
if (!def.inputSchema || typeof def.inputSchema !== 'object') {
|
||||
throw new Error(
|
||||
`Tool definition for '${def.name}' must have an inputSchema (object)`,
|
||||
);
|
||||
if (!inputSchema || typeof inputSchema !== 'object') {
|
||||
throw new Error(`Tool '${name}' must have an inputSchema (object)`);
|
||||
}
|
||||
|
||||
if (!def.handler || typeof def.handler !== 'function') {
|
||||
throw new Error(
|
||||
`Tool definition for '${def.name}' must have a handler (function)`,
|
||||
);
|
||||
if (!handler || typeof handler !== 'function') {
|
||||
throw new Error(`Tool '${name}' must have a handler (function)`);
|
||||
}
|
||||
|
||||
// Return definition (pass-through for type safety)
|
||||
return def;
|
||||
return { name, description, inputSchema, handler };
|
||||
}
|
||||
|
||||
export function validateToolName(name: string): void {
|
||||
@@ -53,39 +90,3 @@ export function validateToolName(name: string): void {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
export function validateInputSchema(schema: unknown): void {
|
||||
if (!schema || typeof schema !== 'object') {
|
||||
throw new Error('Input schema must be an object');
|
||||
}
|
||||
|
||||
const schemaObj = schema as Record<string, unknown>;
|
||||
|
||||
if (!schemaObj.type) {
|
||||
throw new Error('Input schema must have a type field');
|
||||
}
|
||||
|
||||
// For object schemas, validate properties
|
||||
if (schemaObj.type === 'object') {
|
||||
if (schemaObj.properties && typeof schemaObj.properties !== 'object') {
|
||||
throw new Error('Input schema properties must be an object');
|
||||
}
|
||||
|
||||
if (schemaObj.required && !Array.isArray(schemaObj.required)) {
|
||||
throw new Error('Input schema required must be an array');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function createTool<TInput = unknown, TOutput = unknown>(
|
||||
def: ToolDefinition<TInput, TOutput>,
|
||||
): ToolDefinition<TInput, TOutput> {
|
||||
// Validate via tool() function
|
||||
const validated = tool(def);
|
||||
|
||||
// Additional validation
|
||||
validateToolName(validated.name);
|
||||
validateInputSchema(validated.inputSchema);
|
||||
|
||||
return validated;
|
||||
}
|
||||
|
||||
@@ -5,10 +5,10 @@
|
||||
* Implements AsyncIterator protocol for message consumption.
|
||||
*/
|
||||
|
||||
const PERMISSION_CALLBACK_TIMEOUT = 30000;
|
||||
const MCP_REQUEST_TIMEOUT = 30000;
|
||||
const CONTROL_REQUEST_TIMEOUT = 30000;
|
||||
const STREAM_CLOSE_TIMEOUT = 10000;
|
||||
const DEFAULT_CAN_USE_TOOL_TIMEOUT = 30_000;
|
||||
const DEFAULT_MCP_REQUEST_TIMEOUT = 60_000;
|
||||
const DEFAULT_CONTROL_REQUEST_TIMEOUT = 30_000;
|
||||
const DEFAULT_STREAM_CLOSE_TIMEOUT = 60_000;
|
||||
|
||||
import { randomUUID } from 'node:crypto';
|
||||
import { SdkLogger } from '../utils/logger.js';
|
||||
@@ -19,6 +19,7 @@ import type {
|
||||
CLIControlResponse,
|
||||
ControlCancelRequest,
|
||||
PermissionSuggestion,
|
||||
WireSDKMcpServerConfig,
|
||||
} from '../types/protocol.js';
|
||||
import {
|
||||
isSDKUserMessage,
|
||||
@@ -31,12 +32,17 @@ import {
|
||||
isControlCancel,
|
||||
} from '../types/protocol.js';
|
||||
import type { Transport } from '../transport/Transport.js';
|
||||
import type { QueryOptions } from '../types/types.js';
|
||||
import type { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
|
||||
import type { QueryOptions, CLIMcpServerConfig } from '../types/types.js';
|
||||
import { isSdkMcpServerConfig } from '../types/types.js';
|
||||
import { Stream } from '../utils/Stream.js';
|
||||
import { serializeJsonLine } from '../utils/jsonLines.js';
|
||||
import { AbortError } from '../types/errors.js';
|
||||
import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js';
|
||||
import type { SdkControlServerTransport } from '../mcp/SdkControlServerTransport.js';
|
||||
import {
|
||||
SdkControlServerTransport,
|
||||
type SdkControlServerTransportOptions,
|
||||
} from '../mcp/SdkControlServerTransport.js';
|
||||
import { ControlRequestType } from '../types/protocol.js';
|
||||
|
||||
interface PendingControlRequest {
|
||||
@@ -46,6 +52,11 @@ interface PendingControlRequest {
|
||||
abortController: AbortController;
|
||||
}
|
||||
|
||||
interface PendingMcpResponse {
|
||||
resolve: (response: JSONRPCMessage) => void;
|
||||
reject: (error: Error) => void;
|
||||
}
|
||||
|
||||
interface TransportWithEndInput extends Transport {
|
||||
endInput(): void;
|
||||
}
|
||||
@@ -61,7 +72,9 @@ export class Query implements AsyncIterable<SDKMessage> {
|
||||
private abortController: AbortController;
|
||||
private pendingControlRequests: Map<string, PendingControlRequest> =
|
||||
new Map();
|
||||
private pendingMcpResponses: Map<string, PendingMcpResponse> = new Map();
|
||||
private sdkMcpTransports: Map<string, SdkControlServerTransport> = new Map();
|
||||
private sdkMcpServers: Map<string, McpServer> = new Map();
|
||||
readonly initialized: Promise<void>;
|
||||
private closed = false;
|
||||
private messageRouterStarted = false;
|
||||
@@ -92,6 +105,11 @@ export class Query implements AsyncIterable<SDKMessage> {
|
||||
*/
|
||||
this.sdkMessages = this.readSdkMessages();
|
||||
|
||||
/**
|
||||
* Promise that resolves when the first SDKResultMessage is received.
|
||||
* Used to coordinate endInput() timing - ensures all initialization
|
||||
* (SDK MCP servers, control responses) is complete before closing CLI stdin.
|
||||
*/
|
||||
this.firstResultReceivedPromise = new Promise((resolve) => {
|
||||
this.firstResultReceivedResolve = resolve;
|
||||
});
|
||||
@@ -121,17 +139,152 @@ export class Query implements AsyncIterable<SDKMessage> {
|
||||
this.startMessageRouter();
|
||||
}
|
||||
|
||||
private async initializeSdkMcpServers(): Promise<void> {
|
||||
if (!this.options.mcpServers) {
|
||||
return;
|
||||
}
|
||||
|
||||
const connectionPromises: Array<Promise<void>> = [];
|
||||
|
||||
// Extract SDK MCP servers from the unified mcpServers config
|
||||
for (const [key, config] of Object.entries(this.options.mcpServers)) {
|
||||
if (!isSdkMcpServerConfig(config)) {
|
||||
continue; // Skip external MCP servers
|
||||
}
|
||||
|
||||
// Use the name from SDKMcpServerConfig, fallback to key for backwards compatibility
|
||||
const serverName = config.name || key;
|
||||
const server = config.instance;
|
||||
|
||||
// Create transport options with callback to route MCP server responses
|
||||
const transportOptions: SdkControlServerTransportOptions = {
|
||||
sendToQuery: async (message: JSONRPCMessage) => {
|
||||
this.handleMcpServerResponse(serverName, message);
|
||||
},
|
||||
serverName,
|
||||
};
|
||||
|
||||
const sdkTransport = new SdkControlServerTransport(transportOptions);
|
||||
|
||||
// Connect server to transport and only register on success
|
||||
const connectionPromise = server
|
||||
.connect(sdkTransport)
|
||||
.then(() => {
|
||||
// Only add to maps after successful connection
|
||||
this.sdkMcpServers.set(serverName, server);
|
||||
this.sdkMcpTransports.set(serverName, sdkTransport);
|
||||
logger.debug(`SDK MCP server '${serverName}' connected to transport`);
|
||||
})
|
||||
.catch((error) => {
|
||||
logger.error(
|
||||
`Failed to connect SDK MCP server '${serverName}' to transport:`,
|
||||
error,
|
||||
);
|
||||
// Don't throw - one failed server shouldn't prevent others
|
||||
});
|
||||
|
||||
connectionPromises.push(connectionPromise);
|
||||
}
|
||||
|
||||
// Wait for all connection attempts to complete
|
||||
await Promise.all(connectionPromises);
|
||||
|
||||
if (this.sdkMcpServers.size > 0) {
|
||||
logger.info(
|
||||
`Initialized ${this.sdkMcpServers.size} SDK MCP server(s): ${Array.from(this.sdkMcpServers.keys()).join(', ')}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle response messages from SDK MCP servers
|
||||
*
|
||||
* When an MCP server sends a response via transport.send(), this callback
|
||||
* routes it back to the pending request that's waiting for it.
|
||||
*/
|
||||
private handleMcpServerResponse(
|
||||
serverName: string,
|
||||
message: JSONRPCMessage,
|
||||
): void {
|
||||
// Check if this is a response with an id
|
||||
if ('id' in message && message.id !== null && message.id !== undefined) {
|
||||
const key = `${serverName}:${message.id}`;
|
||||
const pending = this.pendingMcpResponses.get(key);
|
||||
if (pending) {
|
||||
logger.debug(
|
||||
`Routing MCP response for server '${serverName}', id: ${message.id}`,
|
||||
);
|
||||
pending.resolve(message);
|
||||
this.pendingMcpResponses.delete(key);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// If no pending request found, log a warning (this shouldn't happen normally)
|
||||
logger.warn(
|
||||
`Received MCP server response with no pending request: server='${serverName}'`,
|
||||
message,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get SDK MCP servers config for CLI initialization
|
||||
*
|
||||
* Only SDK servers are sent in the initialize request.
|
||||
*/
|
||||
private getSdkMcpServersForCli(): Record<string, WireSDKMcpServerConfig> {
|
||||
const sdkServers: Record<string, WireSDKMcpServerConfig> = {};
|
||||
|
||||
for (const [name] of this.sdkMcpServers.entries()) {
|
||||
sdkServers[name] = { type: 'sdk', name };
|
||||
}
|
||||
|
||||
return sdkServers;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get external MCP servers (non-SDK) that should be managed by the CLI
|
||||
*/
|
||||
private getMcpServersForCli(): Record<string, CLIMcpServerConfig> {
|
||||
if (!this.options.mcpServers) {
|
||||
return {};
|
||||
}
|
||||
|
||||
const externalServers: Record<string, CLIMcpServerConfig> = {};
|
||||
|
||||
for (const [name, config] of Object.entries(this.options.mcpServers)) {
|
||||
if (isSdkMcpServerConfig(config)) {
|
||||
continue;
|
||||
}
|
||||
externalServers[name] = config as CLIMcpServerConfig;
|
||||
}
|
||||
|
||||
return externalServers;
|
||||
}
|
||||
|
||||
private async initialize(): Promise<void> {
|
||||
try {
|
||||
logger.debug('Initializing Query');
|
||||
|
||||
const sdkMcpServerNames = Array.from(this.sdkMcpTransports.keys());
|
||||
// Initialize SDK MCP servers and wait for connections
|
||||
await this.initializeSdkMcpServers();
|
||||
|
||||
// Get only successfully connected SDK servers for CLI
|
||||
const sdkMcpServersForCli = this.getSdkMcpServersForCli();
|
||||
const mcpServersForCli = this.getMcpServersForCli();
|
||||
logger.debug('SDK MCP servers for CLI:', sdkMcpServersForCli);
|
||||
logger.debug('External MCP servers for CLI:', mcpServersForCli);
|
||||
|
||||
await this.sendControlRequest(ControlRequestType.INITIALIZE, {
|
||||
hooks: null,
|
||||
sdkMcpServers:
|
||||
sdkMcpServerNames.length > 0 ? sdkMcpServerNames : undefined,
|
||||
mcpServers: this.options.mcpServers,
|
||||
Object.keys(sdkMcpServersForCli).length > 0
|
||||
? sdkMcpServersForCli
|
||||
: undefined,
|
||||
mcpServers:
|
||||
Object.keys(mcpServersForCli).length > 0
|
||||
? mcpServersForCli
|
||||
: undefined,
|
||||
agents: this.options.agents,
|
||||
});
|
||||
logger.info('Query initialized successfully');
|
||||
@@ -279,10 +432,12 @@ export class Query implements AsyncIterable<SDKMessage> {
|
||||
}
|
||||
|
||||
try {
|
||||
const canUseToolTimeout =
|
||||
this.options.timeout?.canUseTool ?? DEFAULT_CAN_USE_TOOL_TIMEOUT;
|
||||
const timeoutPromise = new Promise<never>((_, reject) => {
|
||||
setTimeout(
|
||||
() => reject(new Error('Permission callback timeout')),
|
||||
PERMISSION_CALLBACK_TIMEOUT,
|
||||
canUseToolTimeout,
|
||||
);
|
||||
});
|
||||
|
||||
@@ -361,32 +516,45 @@ export class Query implements AsyncIterable<SDKMessage> {
|
||||
}
|
||||
|
||||
private handleMcpRequest(
|
||||
_serverName: string,
|
||||
serverName: string,
|
||||
message: JSONRPCMessage,
|
||||
transport: SdkControlServerTransport,
|
||||
): Promise<JSONRPCMessage> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const timeout = setTimeout(() => {
|
||||
reject(new Error('MCP request timeout'));
|
||||
}, MCP_REQUEST_TIMEOUT);
|
||||
|
||||
const messageId = 'id' in message ? message.id : null;
|
||||
const key = `${serverName}:${messageId}`;
|
||||
|
||||
/**
|
||||
* Hook into transport to capture response.
|
||||
* Temporarily replace sendToQuery to intercept the response message
|
||||
* matching this request's ID, then restore the original handler.
|
||||
*/
|
||||
const originalSend = transport.sendToQuery;
|
||||
transport.sendToQuery = async (responseMessage: JSONRPCMessage) => {
|
||||
if ('id' in responseMessage && responseMessage.id === messageId) {
|
||||
return new Promise((resolve, reject) => {
|
||||
const mcpRequestTimeout =
|
||||
this.options.timeout?.mcpRequest ?? DEFAULT_MCP_REQUEST_TIMEOUT;
|
||||
const timeout = setTimeout(() => {
|
||||
this.pendingMcpResponses.delete(key);
|
||||
reject(new Error('MCP request timeout'));
|
||||
}, mcpRequestTimeout);
|
||||
|
||||
const cleanup = () => {
|
||||
clearTimeout(timeout);
|
||||
transport.sendToQuery = originalSend;
|
||||
resolve(responseMessage);
|
||||
}
|
||||
return originalSend(responseMessage);
|
||||
this.pendingMcpResponses.delete(key);
|
||||
};
|
||||
|
||||
const resolveAndCleanup = (response: JSONRPCMessage) => {
|
||||
cleanup();
|
||||
resolve(response);
|
||||
};
|
||||
|
||||
const rejectAndCleanup = (error: Error) => {
|
||||
cleanup();
|
||||
reject(error);
|
||||
};
|
||||
|
||||
// Register pending response handler
|
||||
this.pendingMcpResponses.set(key, {
|
||||
resolve: resolveAndCleanup,
|
||||
reject: rejectAndCleanup,
|
||||
});
|
||||
|
||||
// Deliver message to MCP server via transport.onmessage
|
||||
// The server will process it and call transport.send() with the response,
|
||||
// which triggers handleMcpServerResponse to resolve our pending promise
|
||||
transport.handleMessage(message);
|
||||
});
|
||||
}
|
||||
@@ -452,6 +620,10 @@ export class Query implements AsyncIterable<SDKMessage> {
|
||||
subtype: string,
|
||||
data: Record<string, unknown> = {},
|
||||
): Promise<Record<string, unknown> | null> {
|
||||
if (this.closed) {
|
||||
return Promise.reject(new Error('Query is closed'));
|
||||
}
|
||||
|
||||
const requestId = randomUUID();
|
||||
|
||||
const request: CLIControlRequest = {
|
||||
@@ -466,10 +638,13 @@ export class Query implements AsyncIterable<SDKMessage> {
|
||||
const responsePromise = new Promise<Record<string, unknown> | null>(
|
||||
(resolve, reject) => {
|
||||
const abortController = new AbortController();
|
||||
const controlRequestTimeout =
|
||||
this.options.timeout?.controlRequest ??
|
||||
DEFAULT_CONTROL_REQUEST_TIMEOUT;
|
||||
const timeout = setTimeout(() => {
|
||||
this.pendingControlRequests.delete(requestId);
|
||||
reject(new Error(`Control request timeout: ${subtype}`));
|
||||
}, CONTROL_REQUEST_TIMEOUT);
|
||||
}, controlRequestTimeout);
|
||||
|
||||
this.pendingControlRequests.set(requestId, {
|
||||
resolve,
|
||||
@@ -517,9 +692,16 @@ export class Query implements AsyncIterable<SDKMessage> {
|
||||
for (const pending of this.pendingControlRequests.values()) {
|
||||
pending.abortController.abort();
|
||||
clearTimeout(pending.timeout);
|
||||
pending.reject(new Error('Query is closed'));
|
||||
}
|
||||
this.pendingControlRequests.clear();
|
||||
|
||||
// Clean up pending MCP responses
|
||||
for (const pending of this.pendingMcpResponses.values()) {
|
||||
pending.reject(new Error('Query is closed'));
|
||||
}
|
||||
this.pendingMcpResponses.clear();
|
||||
|
||||
await this.transport.close();
|
||||
|
||||
/**
|
||||
@@ -542,7 +724,7 @@ export class Query implements AsyncIterable<SDKMessage> {
|
||||
}
|
||||
}
|
||||
this.sdkMcpTransports.clear();
|
||||
logger.info('Query closed');
|
||||
logger.info('Query is closed');
|
||||
}
|
||||
|
||||
private async *readSdkMessages(): AsyncGenerator<SDKMessage> {
|
||||
@@ -588,22 +770,31 @@ export class Query implements AsyncIterable<SDKMessage> {
|
||||
}
|
||||
|
||||
/**
|
||||
* In multi-turn mode with MCP servers, wait for first result
|
||||
* to ensure MCP servers have time to process before next input.
|
||||
* This prevents race conditions where the next input arrives before
|
||||
* MCP servers have finished processing the current request.
|
||||
* After all user messages are sent (for-await loop ended), determine when to
|
||||
* close the CLI's stdin via endInput().
|
||||
*
|
||||
* - If a result message was already received: All initialization (SDK MCP servers,
|
||||
* control responses, etc.) is complete, safe to close stdin immediately.
|
||||
* - If no result yet: Wait for either the result to arrive, or the timeout to expire.
|
||||
* This gives pending control_responses from SDK MCP servers or other modules
|
||||
* time to complete their initialization before we close the input stream.
|
||||
*
|
||||
* The timeout ensures we don't hang indefinitely - either the turn proceeds
|
||||
* normally, or it fails with a timeout, but Promise.race will always resolve.
|
||||
*/
|
||||
if (
|
||||
!this.isSingleTurn &&
|
||||
this.sdkMcpTransports.size > 0 &&
|
||||
this.firstResultReceivedPromise
|
||||
) {
|
||||
const streamCloseTimeout =
|
||||
this.options.timeout?.streamClose ?? DEFAULT_STREAM_CLOSE_TIMEOUT;
|
||||
await Promise.race([
|
||||
this.firstResultReceivedPromise,
|
||||
new Promise<void>((resolve) => {
|
||||
setTimeout(() => {
|
||||
resolve();
|
||||
}, STREAM_CLOSE_TIMEOUT);
|
||||
}, streamCloseTimeout);
|
||||
}),
|
||||
]);
|
||||
}
|
||||
@@ -635,28 +826,16 @@ export class Query implements AsyncIterable<SDKMessage> {
|
||||
}
|
||||
|
||||
async interrupt(): Promise<void> {
|
||||
if (this.closed) {
|
||||
throw new Error('Query is closed');
|
||||
}
|
||||
|
||||
await this.sendControlRequest(ControlRequestType.INTERRUPT);
|
||||
}
|
||||
|
||||
async setPermissionMode(mode: string): Promise<void> {
|
||||
if (this.closed) {
|
||||
throw new Error('Query is closed');
|
||||
}
|
||||
|
||||
await this.sendControlRequest(ControlRequestType.SET_PERMISSION_MODE, {
|
||||
mode,
|
||||
});
|
||||
}
|
||||
|
||||
async setModel(model: string): Promise<void> {
|
||||
if (this.closed) {
|
||||
throw new Error('Query is closed');
|
||||
}
|
||||
|
||||
await this.sendControlRequest(ControlRequestType.SET_MODEL, { model });
|
||||
}
|
||||
|
||||
@@ -667,10 +846,6 @@ export class Query implements AsyncIterable<SDKMessage> {
|
||||
* @throws Error if query is closed
|
||||
*/
|
||||
async supportedCommands(): Promise<Record<string, unknown> | null> {
|
||||
if (this.closed) {
|
||||
throw new Error('Query is closed');
|
||||
}
|
||||
|
||||
return this.sendControlRequest(ControlRequestType.SUPPORTED_COMMANDS);
|
||||
}
|
||||
|
||||
@@ -681,10 +856,6 @@ export class Query implements AsyncIterable<SDKMessage> {
|
||||
* @throws Error if query is closed
|
||||
*/
|
||||
async mcpServerStatus(): Promise<Record<string, unknown> | null> {
|
||||
if (this.closed) {
|
||||
throw new Error('Query is closed');
|
||||
}
|
||||
|
||||
return this.sendControlRequest(ControlRequestType.MCP_SERVER_STATUS);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
|
||||
import type { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js';
|
||||
export interface Annotation {
|
||||
type: string;
|
||||
value: string;
|
||||
@@ -293,10 +294,44 @@ export interface MCPServerConfig {
|
||||
targetServiceAccount?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* SDK MCP Server configuration
|
||||
*
|
||||
* SDK MCP servers run in the SDK process and are connected via in-memory transport.
|
||||
* Tool calls are routed through the control plane between SDK and CLI.
|
||||
*/
|
||||
export interface SDKMcpServerConfig {
|
||||
/**
|
||||
* Type identifier for SDK MCP servers
|
||||
*/
|
||||
type: 'sdk';
|
||||
/**
|
||||
* Server name for identification and routing
|
||||
*/
|
||||
name: string;
|
||||
/**
|
||||
* The MCP Server instance created by createSdkMcpServer()
|
||||
*/
|
||||
instance: McpServer;
|
||||
}
|
||||
|
||||
/**
|
||||
* Wire format for SDK MCP servers sent to the CLI
|
||||
*/
|
||||
export type WireSDKMcpServerConfig = Omit<SDKMcpServerConfig, 'instance'>;
|
||||
|
||||
export interface CLIControlInitializeRequest {
|
||||
subtype: 'initialize';
|
||||
hooks?: HookRegistration[] | null;
|
||||
sdkMcpServers?: Record<string, MCPServerConfig>;
|
||||
/**
|
||||
* SDK MCP servers config
|
||||
* These are MCP servers running in the SDK process, connected via control plane.
|
||||
* External MCP servers are configured separately in settings, not via initialization.
|
||||
*/
|
||||
sdkMcpServers?: Record<string, WireSDKMcpServerConfig>;
|
||||
/**
|
||||
* External MCP servers that should be managed by the CLI.
|
||||
*/
|
||||
mcpServers?: Record<string, MCPServerConfig>;
|
||||
agents?: SubagentConfig[];
|
||||
}
|
||||
|
||||
@@ -2,19 +2,98 @@ import { z } from 'zod';
|
||||
import type { CanUseTool } from './types.js';
|
||||
import type { SubagentConfig } from './protocol.js';
|
||||
|
||||
export const ExternalMcpServerConfigSchema = z.object({
|
||||
command: z.string().min(1, 'Command must be a non-empty string'),
|
||||
/**
|
||||
* OAuth configuration for MCP servers
|
||||
*/
|
||||
export const McpOAuthConfigSchema = z
|
||||
.object({
|
||||
enabled: z.boolean().optional(),
|
||||
clientId: z
|
||||
.string()
|
||||
.min(1, 'clientId must be a non-empty string')
|
||||
.optional(),
|
||||
clientSecret: z.string().optional(),
|
||||
scopes: z.array(z.string()).optional(),
|
||||
redirectUri: z.string().optional(),
|
||||
authorizationUrl: z.string().optional(),
|
||||
tokenUrl: z.string().optional(),
|
||||
audiences: z.array(z.string()).optional(),
|
||||
tokenParamName: z.string().optional(),
|
||||
registrationUrl: z.string().optional(),
|
||||
})
|
||||
.strict();
|
||||
|
||||
/**
|
||||
* CLI MCP Server configuration schema
|
||||
*
|
||||
* Supports multiple transport types:
|
||||
* - stdio: command, args, env, cwd
|
||||
* - SSE: url
|
||||
* - Streamable HTTP: httpUrl, headers
|
||||
* - WebSocket: tcp
|
||||
*/
|
||||
export const CLIMcpServerConfigSchema = z.object({
|
||||
// For stdio transport
|
||||
command: z.string().optional(),
|
||||
args: z.array(z.string()).optional(),
|
||||
env: z.record(z.string(), z.string()).optional(),
|
||||
cwd: z.string().optional(),
|
||||
// For SSE transport
|
||||
url: z.string().optional(),
|
||||
// For streamable HTTP transport
|
||||
httpUrl: z.string().optional(),
|
||||
headers: z.record(z.string(), z.string()).optional(),
|
||||
// For WebSocket transport
|
||||
tcp: z.string().optional(),
|
||||
// Common
|
||||
timeout: z.number().optional(),
|
||||
trust: z.boolean().optional(),
|
||||
// Metadata
|
||||
description: z.string().optional(),
|
||||
includeTools: z.array(z.string()).optional(),
|
||||
excludeTools: z.array(z.string()).optional(),
|
||||
extensionName: z.string().optional(),
|
||||
// OAuth configuration
|
||||
oauth: McpOAuthConfigSchema.optional(),
|
||||
authProviderType: z
|
||||
.enum([
|
||||
'dynamic_discovery',
|
||||
'google_credentials',
|
||||
'service_account_impersonation',
|
||||
])
|
||||
.optional(),
|
||||
// Service Account Configuration
|
||||
targetAudience: z.string().optional(),
|
||||
targetServiceAccount: z.string().optional(),
|
||||
});
|
||||
|
||||
/**
|
||||
* SDK MCP Server configuration schema
|
||||
*/
|
||||
export const SdkMcpServerConfigSchema = z.object({
|
||||
connect: z.custom<(transport: unknown) => Promise<void>>(
|
||||
(val) => typeof val === 'function',
|
||||
{ message: 'connect must be a function' },
|
||||
type: z.literal('sdk'),
|
||||
name: z.string().min(1, 'name must be a non-empty string'),
|
||||
instance: z.custom<{
|
||||
connect(transport: unknown): Promise<void>;
|
||||
close(): Promise<void>;
|
||||
}>(
|
||||
(val) =>
|
||||
val &&
|
||||
typeof val === 'object' &&
|
||||
'connect' in val &&
|
||||
typeof val.connect === 'function',
|
||||
{ message: 'instance must be an MCP Server with connect method' },
|
||||
),
|
||||
});
|
||||
|
||||
/**
|
||||
* Unified MCP Server configuration schema
|
||||
*/
|
||||
export const McpServerConfigSchema = z.union([
|
||||
CLIMcpServerConfigSchema,
|
||||
SdkMcpServerConfigSchema,
|
||||
]);
|
||||
|
||||
export const ModelConfigSchema = z.object({
|
||||
model: z.string().optional(),
|
||||
temp: z.number().optional(),
|
||||
@@ -37,6 +116,13 @@ export const SubagentConfigSchema = z.object({
|
||||
isBuiltin: z.boolean().optional(),
|
||||
});
|
||||
|
||||
export const TimeoutConfigSchema = z.object({
|
||||
canUseTool: z.number().positive().optional(),
|
||||
mcpRequest: z.number().positive().optional(),
|
||||
controlRequest: z.number().positive().optional(),
|
||||
streamClose: z.number().positive().optional(),
|
||||
});
|
||||
|
||||
export const QueryOptionsSchema = z
|
||||
.object({
|
||||
cwd: z.string().optional(),
|
||||
@@ -49,7 +135,7 @@ export const QueryOptionsSchema = z
|
||||
message: 'canUseTool must be a function',
|
||||
})
|
||||
.optional(),
|
||||
mcpServers: z.record(z.string(), ExternalMcpServerConfigSchema).optional(),
|
||||
mcpServers: z.record(z.string(), McpServerConfigSchema).optional(),
|
||||
abortController: z.instanceof(AbortController).optional(),
|
||||
debug: z.boolean().optional(),
|
||||
stderr: z
|
||||
@@ -78,5 +164,6 @@ export const QueryOptionsSchema = z
|
||||
)
|
||||
.optional(),
|
||||
includePartialMessages: z.boolean().optional(),
|
||||
timeout: TimeoutConfigSchema.optional(),
|
||||
})
|
||||
.strict();
|
||||
|
||||
@@ -2,25 +2,11 @@ import type {
|
||||
PermissionMode,
|
||||
PermissionSuggestion,
|
||||
SubagentConfig,
|
||||
SDKMcpServerConfig,
|
||||
} from './protocol.js';
|
||||
|
||||
export type { PermissionMode };
|
||||
|
||||
type JSONSchema = {
|
||||
type: string;
|
||||
properties?: Record<string, unknown>;
|
||||
required?: string[];
|
||||
description?: string;
|
||||
[key: string]: unknown;
|
||||
};
|
||||
|
||||
export type ToolDefinition<TInput = unknown, TOutput = unknown> = {
|
||||
name: string;
|
||||
description: string;
|
||||
inputSchema: JSONSchema;
|
||||
handler: (input: TInput) => Promise<TOutput>;
|
||||
};
|
||||
|
||||
export type TransportOptions = {
|
||||
pathToQwenExecutable: string;
|
||||
cwd?: string;
|
||||
@@ -61,14 +47,115 @@ export type PermissionResult =
|
||||
interrupt?: boolean;
|
||||
};
|
||||
|
||||
export interface ExternalMcpServerConfig {
|
||||
command: string;
|
||||
args?: string[];
|
||||
env?: Record<string, string>;
|
||||
/**
|
||||
* OAuth configuration for MCP servers
|
||||
*/
|
||||
export interface McpOAuthConfig {
|
||||
enabled?: boolean;
|
||||
clientId?: string;
|
||||
clientSecret?: string;
|
||||
scopes?: string[];
|
||||
redirectUri?: string;
|
||||
authorizationUrl?: string;
|
||||
tokenUrl?: string;
|
||||
audiences?: string[];
|
||||
tokenParamName?: string;
|
||||
registrationUrl?: string;
|
||||
}
|
||||
|
||||
export interface SdkMcpServerConfig {
|
||||
connect: (transport: unknown) => Promise<void>;
|
||||
/**
|
||||
* Auth provider type for MCP servers
|
||||
*/
|
||||
export type McpAuthProviderType =
|
||||
| 'dynamic_discovery'
|
||||
| 'google_credentials'
|
||||
| 'service_account_impersonation';
|
||||
|
||||
/**
|
||||
* CLI MCP Server configuration
|
||||
*
|
||||
* Supports multiple transport types:
|
||||
* - stdio: command, args, env, cwd
|
||||
* - SSE: url
|
||||
* - Streamable HTTP: httpUrl, headers
|
||||
* - WebSocket: tcp
|
||||
*
|
||||
* This interface aligns with MCPServerConfig in @qwen-code/qwen-code-core.
|
||||
*/
|
||||
export interface CLIMcpServerConfig {
|
||||
// For stdio transport
|
||||
command?: string;
|
||||
args?: string[];
|
||||
env?: Record<string, string>;
|
||||
cwd?: string;
|
||||
// For SSE transport
|
||||
url?: string;
|
||||
// For streamable HTTP transport
|
||||
httpUrl?: string;
|
||||
headers?: Record<string, string>;
|
||||
// For WebSocket transport
|
||||
tcp?: string;
|
||||
// Common
|
||||
timeout?: number;
|
||||
trust?: boolean;
|
||||
// Metadata
|
||||
description?: string;
|
||||
includeTools?: string[];
|
||||
excludeTools?: string[];
|
||||
extensionName?: string;
|
||||
// OAuth configuration
|
||||
oauth?: McpOAuthConfig;
|
||||
authProviderType?: McpAuthProviderType;
|
||||
// Service Account Configuration
|
||||
/** targetAudience format: CLIENT_ID.apps.googleusercontent.com */
|
||||
targetAudience?: string;
|
||||
/** targetServiceAccount format: <service-account-name>@<project-num>.iam.gserviceaccount.com */
|
||||
targetServiceAccount?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Unified MCP Server configuration
|
||||
*
|
||||
* Supports both external MCP servers (stdio/SSE/HTTP/WebSocket) and SDK-embedded MCP servers.
|
||||
*
|
||||
* @example External MCP server (stdio)
|
||||
* ```typescript
|
||||
* mcpServers: {
|
||||
* 'my-server': { command: 'node', args: ['server.js'] }
|
||||
* }
|
||||
* ```
|
||||
*
|
||||
* @example External MCP server (SSE)
|
||||
* ```typescript
|
||||
* mcpServers: {
|
||||
* 'remote-server': { url: 'http://localhost:3000/sse' }
|
||||
* }
|
||||
* ```
|
||||
*
|
||||
* @example External MCP server (Streamable HTTP)
|
||||
* ```typescript
|
||||
* mcpServers: {
|
||||
* 'http-server': { httpUrl: 'http://localhost:3000/mcp', headers: { 'Authorization': 'Bearer token' } }
|
||||
* }
|
||||
* ```
|
||||
*
|
||||
* @example SDK MCP server
|
||||
* ```typescript
|
||||
* const server = createSdkMcpServer('weather', '1.0.0', [weatherTool]);
|
||||
* mcpServers: {
|
||||
* 'weather': { type: 'sdk', name: 'weather', instance: server }
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
export type McpServerConfig = CLIMcpServerConfig | SDKMcpServerConfig;
|
||||
|
||||
/**
|
||||
* Type guard to check if a config is an SDK MCP server
|
||||
*/
|
||||
export function isSdkMcpServerConfig(
|
||||
config: McpServerConfig,
|
||||
): config is SDKMcpServerConfig {
|
||||
return 'type' in config && config.type === 'sdk';
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -174,11 +261,36 @@ export interface QueryOptions {
|
||||
canUseTool?: CanUseTool;
|
||||
|
||||
/**
|
||||
* External MCP (Model Context Protocol) servers to connect to.
|
||||
* Each server is identified by a unique name and configured with command, args, and environment.
|
||||
* @example { 'my-server': { command: 'node', args: ['server.js'], env: { PORT: '3000' } } }
|
||||
* MCP (Model Context Protocol) servers to connect to.
|
||||
*
|
||||
* Supports both external MCP servers and SDK-embedded MCP servers:
|
||||
*
|
||||
* **External MCP servers** - Run in separate processes, connected via stdio/SSE/HTTP:
|
||||
* ```typescript
|
||||
* mcpServers: {
|
||||
* 'stdio-server': { command: 'node', args: ['server.js'], env: { PORT: '3000' } },
|
||||
* 'sse-server': { url: 'http://localhost:3000/sse' },
|
||||
* 'http-server': { httpUrl: 'http://localhost:3000/mcp' }
|
||||
* }
|
||||
* ```
|
||||
*
|
||||
* **SDK MCP servers** - Run in the SDK process, connected via in-memory transport:
|
||||
* ```typescript
|
||||
* const myTool = tool({
|
||||
* name: 'my_tool',
|
||||
* description: 'My custom tool',
|
||||
* inputSchema: { type: 'object', properties: { input: { type: 'string' } } },
|
||||
* handler: async (input) => ({ result: input.input.toUpperCase() }),
|
||||
* });
|
||||
*
|
||||
* const server = createSdkMcpServer('my-server', '1.0.0', [myTool]);
|
||||
*
|
||||
* mcpServers: {
|
||||
* 'my-server': { type: 'sdk', name: 'my-server', instance: server }
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
mcpServers?: Record<string, ExternalMcpServerConfig>;
|
||||
mcpServers?: Record<string, McpServerConfig>;
|
||||
|
||||
/**
|
||||
* AbortController to cancel the query session.
|
||||
@@ -294,4 +406,43 @@ export interface QueryOptions {
|
||||
* @default false
|
||||
*/
|
||||
includePartialMessages?: boolean;
|
||||
|
||||
/**
|
||||
* Timeout configuration for various SDK operations.
|
||||
* All values are in milliseconds.
|
||||
*/
|
||||
timeout?: {
|
||||
/**
|
||||
* Timeout for the `canUseTool` callback.
|
||||
* If the callback doesn't resolve within this time, the permission request
|
||||
* will be denied with a timeout error (fail-safe behavior).
|
||||
* @default 60000 (1 minute)
|
||||
*/
|
||||
canUseTool?: number;
|
||||
|
||||
/**
|
||||
* Timeout for SDK MCP tool calls.
|
||||
* This applies to tool calls made to SDK-embedded MCP servers.
|
||||
* @default 60000 (1 minute)
|
||||
*/
|
||||
mcpRequest?: number;
|
||||
|
||||
/**
|
||||
* Timeout for SDK→CLI control requests.
|
||||
* This applies to internal control operations like initialize, interrupt,
|
||||
* setPermissionMode, setModel, etc.
|
||||
* @default 60000 (1 minute)
|
||||
*/
|
||||
controlRequest?: number;
|
||||
|
||||
/**
|
||||
* Timeout for waiting before closing CLI's stdin after user messages are sent.
|
||||
* In multi-turn mode with SDK MCP servers, after all user messages are processed,
|
||||
* the SDK waits for the first result message to ensure all initialization
|
||||
* (control responses, MCP server setup, etc.) is complete before closing stdin.
|
||||
* This timeout is a fallback to avoid hanging indefinitely.
|
||||
* @default 60000 (1 minute)
|
||||
*/
|
||||
streamClose?: number;
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,3 +1,9 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Qwen Team
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
/**
|
||||
* Unit tests for createSdkMcpServer
|
||||
*
|
||||
@@ -5,93 +11,112 @@
|
||||
*/
|
||||
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
import { z } from 'zod';
|
||||
import { createSdkMcpServer } from '../../src/mcp/createSdkMcpServer.js';
|
||||
import { tool } from '../../src/mcp/tool.js';
|
||||
import type { ToolDefinition } from '../../src/types/config.js';
|
||||
import type { SdkMcpToolDefinition } from '../../src/mcp/tool.js';
|
||||
|
||||
describe('createSdkMcpServer', () => {
|
||||
describe('Server Creation', () => {
|
||||
it('should create server with name and version', () => {
|
||||
const server = createSdkMcpServer('test-server', '1.0.0', []);
|
||||
const server = createSdkMcpServer({
|
||||
name: 'test-server',
|
||||
version: '1.0.0',
|
||||
tools: [],
|
||||
});
|
||||
|
||||
expect(server).toBeDefined();
|
||||
expect(server.type).toBe('sdk');
|
||||
expect(server.name).toBe('test-server');
|
||||
expect(server.instance).toBeDefined();
|
||||
});
|
||||
|
||||
it('should create server with default version', () => {
|
||||
const server = createSdkMcpServer({
|
||||
name: 'test-server',
|
||||
});
|
||||
|
||||
expect(server).toBeDefined();
|
||||
expect(server.name).toBe('test-server');
|
||||
});
|
||||
|
||||
it('should throw error with invalid name', () => {
|
||||
expect(() => createSdkMcpServer('', '1.0.0', [])).toThrow(
|
||||
'name must be a non-empty string',
|
||||
expect(() => createSdkMcpServer({ name: '', version: '1.0.0' })).toThrow(
|
||||
'MCP server name must be a non-empty string',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw error with invalid version', () => {
|
||||
expect(() => createSdkMcpServer('test', '', [])).toThrow(
|
||||
'version must be a non-empty string',
|
||||
expect(() => createSdkMcpServer({ name: 'test', version: '' })).toThrow(
|
||||
'MCP server version must be a non-empty string',
|
||||
);
|
||||
});
|
||||
|
||||
it('should throw error with non-array tools', () => {
|
||||
expect(() =>
|
||||
createSdkMcpServer('test', '1.0.0', {} as unknown as ToolDefinition[]),
|
||||
createSdkMcpServer({
|
||||
name: 'test',
|
||||
version: '1.0.0',
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
tools: {} as unknown as SdkMcpToolDefinition<any>[],
|
||||
}),
|
||||
).toThrow('Tools must be an array');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Tool Registration', () => {
|
||||
it('should register single tool', () => {
|
||||
const testTool = tool({
|
||||
name: 'test_tool',
|
||||
description: 'A test tool',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
input: { type: 'string' },
|
||||
},
|
||||
},
|
||||
handler: async () => 'result',
|
||||
});
|
||||
const testTool = tool(
|
||||
'test_tool',
|
||||
'A test tool',
|
||||
{ input: z.string() },
|
||||
async () => ({
|
||||
content: [{ type: 'text', text: 'result' }],
|
||||
}),
|
||||
);
|
||||
|
||||
const server = createSdkMcpServer('test-server', '1.0.0', [testTool]);
|
||||
const server = createSdkMcpServer({
|
||||
name: 'test-server',
|
||||
version: '1.0.0',
|
||||
tools: [testTool],
|
||||
});
|
||||
|
||||
expect(server).toBeDefined();
|
||||
});
|
||||
|
||||
it('should register multiple tools', () => {
|
||||
const tool1 = tool({
|
||||
name: 'tool1',
|
||||
description: 'Tool 1',
|
||||
inputSchema: { type: 'object' },
|
||||
handler: async () => 'result1',
|
||||
});
|
||||
const tool1 = tool('tool1', 'Tool 1', {}, async () => ({
|
||||
content: [{ type: 'text', text: 'result1' }],
|
||||
}));
|
||||
|
||||
const tool2 = tool({
|
||||
name: 'tool2',
|
||||
description: 'Tool 2',
|
||||
inputSchema: { type: 'object' },
|
||||
handler: async () => 'result2',
|
||||
});
|
||||
const tool2 = tool('tool2', 'Tool 2', {}, async () => ({
|
||||
content: [{ type: 'text', text: 'result2' }],
|
||||
}));
|
||||
|
||||
const server = createSdkMcpServer('test-server', '1.0.0', [tool1, tool2]);
|
||||
const server = createSdkMcpServer({
|
||||
name: 'test-server',
|
||||
version: '1.0.0',
|
||||
tools: [tool1, tool2],
|
||||
});
|
||||
|
||||
expect(server).toBeDefined();
|
||||
});
|
||||
|
||||
it('should throw error for duplicate tool names', () => {
|
||||
const tool1 = tool({
|
||||
name: 'duplicate',
|
||||
description: 'Tool 1',
|
||||
inputSchema: { type: 'object' },
|
||||
handler: async () => 'result1',
|
||||
});
|
||||
const tool1 = tool('duplicate', 'Tool 1', {}, async () => ({
|
||||
content: [{ type: 'text', text: 'result1' }],
|
||||
}));
|
||||
|
||||
const tool2 = tool({
|
||||
name: 'duplicate',
|
||||
description: 'Tool 2',
|
||||
inputSchema: { type: 'object' },
|
||||
handler: async () => 'result2',
|
||||
});
|
||||
const tool2 = tool('duplicate', 'Tool 2', {}, async () => ({
|
||||
content: [{ type: 'text', text: 'result2' }],
|
||||
}));
|
||||
|
||||
expect(() =>
|
||||
createSdkMcpServer('test-server', '1.0.0', [tool1, tool2]),
|
||||
createSdkMcpServer({
|
||||
name: 'test-server',
|
||||
version: '1.0.0',
|
||||
tools: [tool1, tool2],
|
||||
}),
|
||||
).toThrow("Duplicate tool name 'duplicate'");
|
||||
});
|
||||
|
||||
@@ -99,36 +124,41 @@ describe('createSdkMcpServer', () => {
|
||||
const invalidTool = {
|
||||
name: '123invalid', // Starts with number
|
||||
description: 'Invalid tool',
|
||||
inputSchema: { type: 'object' },
|
||||
handler: async () => 'result',
|
||||
inputSchema: {},
|
||||
handler: async () => ({
|
||||
content: [{ type: 'text' as const, text: 'result' }],
|
||||
}),
|
||||
};
|
||||
|
||||
expect(() =>
|
||||
createSdkMcpServer('test-server', '1.0.0', [
|
||||
invalidTool as unknown as ToolDefinition,
|
||||
]),
|
||||
createSdkMcpServer({
|
||||
name: 'test-server',
|
||||
version: '1.0.0',
|
||||
/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
|
||||
tools: [invalidTool as unknown as SdkMcpToolDefinition<any>],
|
||||
}),
|
||||
).toThrow('Tool name');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Tool Handler Invocation', () => {
|
||||
it('should invoke tool handler with correct input', async () => {
|
||||
const handler = vi.fn().mockResolvedValue({ result: 'success' });
|
||||
|
||||
const testTool = tool({
|
||||
name: 'test_tool',
|
||||
description: 'A test tool',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
value: { type: 'string' },
|
||||
},
|
||||
required: ['value'],
|
||||
},
|
||||
handler,
|
||||
const handler = vi.fn().mockResolvedValue({
|
||||
content: [{ type: 'text', text: 'success' }],
|
||||
});
|
||||
|
||||
createSdkMcpServer('test-server', '1.0.0', [testTool]);
|
||||
const testTool = tool(
|
||||
'test_tool',
|
||||
'A test tool',
|
||||
{ value: z.string() },
|
||||
handler,
|
||||
);
|
||||
|
||||
createSdkMcpServer({
|
||||
name: 'test-server',
|
||||
version: '1.0.0',
|
||||
tools: [testTool],
|
||||
});
|
||||
|
||||
// Note: Actual invocation testing requires MCP SDK integration
|
||||
// This test verifies the handler was properly registered
|
||||
@@ -140,17 +170,18 @@ describe('createSdkMcpServer', () => {
|
||||
.fn()
|
||||
.mockImplementation(async (input: { value: string }) => {
|
||||
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||
return { processed: input.value };
|
||||
return {
|
||||
content: [{ type: 'text', text: `processed: ${input.value}` }],
|
||||
};
|
||||
});
|
||||
|
||||
const testTool = tool({
|
||||
name: 'async_tool',
|
||||
description: 'An async tool',
|
||||
inputSchema: { type: 'object' },
|
||||
handler,
|
||||
});
|
||||
const testTool = tool('async_tool', 'An async tool', {}, handler);
|
||||
|
||||
const server = createSdkMcpServer('test-server', '1.0.0', [testTool]);
|
||||
const server = createSdkMcpServer({
|
||||
name: 'test-server',
|
||||
version: '1.0.0',
|
||||
tools: [testTool],
|
||||
});
|
||||
|
||||
expect(server).toBeDefined();
|
||||
});
|
||||
@@ -158,40 +189,29 @@ describe('createSdkMcpServer', () => {
|
||||
|
||||
describe('Type Safety', () => {
|
||||
it('should preserve input type in handler', async () => {
|
||||
type ToolInput = {
|
||||
name: string;
|
||||
age: number;
|
||||
};
|
||||
|
||||
type ToolOutput = {
|
||||
greeting: string;
|
||||
};
|
||||
|
||||
const handler = vi
|
||||
.fn()
|
||||
.mockImplementation(async (input: ToolInput): Promise<ToolOutput> => {
|
||||
const handler = vi.fn().mockImplementation(async (input) => {
|
||||
return {
|
||||
greeting: `Hello ${input.name}, age ${input.age}`,
|
||||
content: [
|
||||
{ type: 'text', text: `Hello ${input.name}, age ${input.age}` },
|
||||
],
|
||||
};
|
||||
});
|
||||
|
||||
const typedTool = tool<ToolInput, ToolOutput>({
|
||||
name: 'typed_tool',
|
||||
description: 'A typed tool',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
name: { type: 'string' },
|
||||
age: { type: 'number' },
|
||||
},
|
||||
required: ['name', 'age'],
|
||||
const typedTool = tool(
|
||||
'typed_tool',
|
||||
'A typed tool',
|
||||
{
|
||||
name: z.string(),
|
||||
age: z.number(),
|
||||
},
|
||||
handler,
|
||||
});
|
||||
);
|
||||
|
||||
const server = createSdkMcpServer('test-server', '1.0.0', [
|
||||
typedTool as ToolDefinition,
|
||||
]);
|
||||
const server = createSdkMcpServer({
|
||||
name: 'test-server',
|
||||
version: '1.0.0',
|
||||
tools: [typedTool],
|
||||
});
|
||||
|
||||
expect(server).toBeDefined();
|
||||
});
|
||||
@@ -201,14 +221,13 @@ describe('createSdkMcpServer', () => {
|
||||
it('should handle tool handler errors gracefully', async () => {
|
||||
const handler = vi.fn().mockRejectedValue(new Error('Tool failed'));
|
||||
|
||||
const errorTool = tool({
|
||||
name: 'error_tool',
|
||||
description: 'A tool that errors',
|
||||
inputSchema: { type: 'object' },
|
||||
handler,
|
||||
});
|
||||
const errorTool = tool('error_tool', 'A tool that errors', {}, handler);
|
||||
|
||||
const server = createSdkMcpServer('test-server', '1.0.0', [errorTool]);
|
||||
const server = createSdkMcpServer({
|
||||
name: 'test-server',
|
||||
version: '1.0.0',
|
||||
tools: [errorTool],
|
||||
});
|
||||
|
||||
expect(server).toBeDefined();
|
||||
// Error handling occurs during tool invocation
|
||||
@@ -219,14 +238,18 @@ describe('createSdkMcpServer', () => {
|
||||
throw new Error('Sync error');
|
||||
});
|
||||
|
||||
const errorTool = tool({
|
||||
name: 'sync_error_tool',
|
||||
description: 'A tool that errors synchronously',
|
||||
inputSchema: { type: 'object' },
|
||||
const errorTool = tool(
|
||||
'sync_error_tool',
|
||||
'A tool that errors synchronously',
|
||||
{},
|
||||
handler,
|
||||
});
|
||||
);
|
||||
|
||||
const server = createSdkMcpServer('test-server', '1.0.0', [errorTool]);
|
||||
const server = createSdkMcpServer({
|
||||
name: 'test-server',
|
||||
version: '1.0.0',
|
||||
tools: [errorTool],
|
||||
});
|
||||
|
||||
expect(server).toBeDefined();
|
||||
});
|
||||
@@ -234,49 +257,51 @@ describe('createSdkMcpServer', () => {
|
||||
|
||||
describe('Complex Tool Scenarios', () => {
|
||||
it('should support tool with complex input schema', () => {
|
||||
const complexTool = tool({
|
||||
name: 'complex_tool',
|
||||
description: 'A tool with complex schema',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
query: { type: 'string' },
|
||||
filters: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
category: { type: 'string' },
|
||||
minPrice: { type: 'number' },
|
||||
const complexTool = tool(
|
||||
'complex_tool',
|
||||
'A tool with complex schema',
|
||||
{
|
||||
query: z.string(),
|
||||
filters: z
|
||||
.object({
|
||||
category: z.string().optional(),
|
||||
minPrice: z.number().optional(),
|
||||
})
|
||||
.optional(),
|
||||
options: z.array(z.string()).optional(),
|
||||
},
|
||||
},
|
||||
options: {
|
||||
type: 'array',
|
||||
items: { type: 'string' },
|
||||
},
|
||||
},
|
||||
required: ['query'],
|
||||
},
|
||||
handler: async (input: { filters?: unknown[] }) => {
|
||||
async (input) => {
|
||||
return {
|
||||
results: [],
|
||||
filters: input.filters,
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: JSON.stringify({ results: [], filters: input.filters }),
|
||||
},
|
||||
],
|
||||
};
|
||||
},
|
||||
});
|
||||
);
|
||||
|
||||
const server = createSdkMcpServer('test-server', '1.0.0', [
|
||||
complexTool as ToolDefinition,
|
||||
]);
|
||||
const server = createSdkMcpServer({
|
||||
name: 'test-server',
|
||||
version: '1.0.0',
|
||||
tools: [complexTool],
|
||||
});
|
||||
|
||||
expect(server).toBeDefined();
|
||||
});
|
||||
|
||||
it('should support tool returning complex output', () => {
|
||||
const complexOutputTool = tool({
|
||||
name: 'complex_output_tool',
|
||||
description: 'Returns complex data',
|
||||
inputSchema: { type: 'object' },
|
||||
handler: async () => {
|
||||
const complexOutputTool = tool(
|
||||
'complex_output_tool',
|
||||
'Returns complex data',
|
||||
{},
|
||||
async () => {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: JSON.stringify({
|
||||
data: [
|
||||
{ id: 1, name: 'Item 1' },
|
||||
{ id: 2, name: 'Item 2' },
|
||||
@@ -290,13 +315,18 @@ describe('createSdkMcpServer', () => {
|
||||
value: 'test',
|
||||
},
|
||||
},
|
||||
}),
|
||||
},
|
||||
],
|
||||
};
|
||||
},
|
||||
});
|
||||
);
|
||||
|
||||
const server = createSdkMcpServer('test-server', '1.0.0', [
|
||||
complexOutputTool,
|
||||
]);
|
||||
const server = createSdkMcpServer({
|
||||
name: 'test-server',
|
||||
version: '1.0.0',
|
||||
tools: [complexOutputTool],
|
||||
});
|
||||
|
||||
expect(server).toBeDefined();
|
||||
});
|
||||
@@ -304,44 +334,50 @@ describe('createSdkMcpServer', () => {
|
||||
|
||||
describe('Multiple Servers', () => {
|
||||
it('should create multiple independent servers', () => {
|
||||
const tool1 = tool({
|
||||
name: 'tool1',
|
||||
description: 'Tool in server 1',
|
||||
inputSchema: { type: 'object' },
|
||||
handler: async () => 'result1',
|
||||
});
|
||||
const tool1 = tool('tool1', 'Tool in server 1', {}, async () => ({
|
||||
content: [{ type: 'text', text: 'result1' }],
|
||||
}));
|
||||
|
||||
const tool2 = tool({
|
||||
name: 'tool2',
|
||||
description: 'Tool in server 2',
|
||||
inputSchema: { type: 'object' },
|
||||
handler: async () => 'result2',
|
||||
});
|
||||
const tool2 = tool('tool2', 'Tool in server 2', {}, async () => ({
|
||||
content: [{ type: 'text', text: 'result2' }],
|
||||
}));
|
||||
|
||||
const server1 = createSdkMcpServer('server1', '1.0.0', [tool1]);
|
||||
const server2 = createSdkMcpServer('server2', '1.0.0', [tool2]);
|
||||
const server1 = createSdkMcpServer({
|
||||
name: 'server1',
|
||||
version: '1.0.0',
|
||||
tools: [tool1],
|
||||
});
|
||||
const server2 = createSdkMcpServer({
|
||||
name: 'server2',
|
||||
version: '1.0.0',
|
||||
tools: [tool2],
|
||||
});
|
||||
|
||||
expect(server1).toBeDefined();
|
||||
expect(server2).toBeDefined();
|
||||
expect(server1.name).toBe('server1');
|
||||
expect(server2.name).toBe('server2');
|
||||
});
|
||||
|
||||
it('should allow same tool name in different servers', () => {
|
||||
const tool1 = tool({
|
||||
name: 'shared_name',
|
||||
description: 'Tool in server 1',
|
||||
inputSchema: { type: 'object' },
|
||||
handler: async () => 'result1',
|
||||
});
|
||||
const tool1 = tool('shared_name', 'Tool in server 1', {}, async () => ({
|
||||
content: [{ type: 'text', text: 'result1' }],
|
||||
}));
|
||||
|
||||
const tool2 = tool({
|
||||
name: 'shared_name',
|
||||
description: 'Tool in server 2',
|
||||
inputSchema: { type: 'object' },
|
||||
handler: async () => 'result2',
|
||||
});
|
||||
const tool2 = tool('shared_name', 'Tool in server 2', {}, async () => ({
|
||||
content: [{ type: 'text', text: 'result2' }],
|
||||
}));
|
||||
|
||||
const server1 = createSdkMcpServer('server1', '1.0.0', [tool1]);
|
||||
const server2 = createSdkMcpServer('server2', '1.0.0', [tool2]);
|
||||
const server1 = createSdkMcpServer({
|
||||
name: 'server1',
|
||||
version: '1.0.0',
|
||||
tools: [tool1],
|
||||
});
|
||||
const server2 = createSdkMcpServer({
|
||||
name: 'server2',
|
||||
version: '1.0.0',
|
||||
tools: [tool2],
|
||||
});
|
||||
|
||||
expect(server1).toBeDefined();
|
||||
expect(server2).toBeDefined();
|
||||
|
||||
Reference in New Issue
Block a user