diff --git a/.github/workflows/release-sdk.yml b/.github/workflows/release-sdk.yml index d0b558f7..69192520 100644 --- a/.github/workflows/release-sdk.yml +++ b/.github/workflows/release-sdk.yml @@ -132,6 +132,24 @@ jobs: OPENAI_BASE_URL: '${{ secrets.OPENAI_BASE_URL }}' OPENAI_MODEL: '${{ secrets.OPENAI_MODEL }}' + - name: 'Build CLI for Integration Tests' + if: |- + ${{ github.event.inputs.force_skip_tests != 'true' }} + run: | + npm run build + npm run bundle + + - name: 'Run SDK Integration Tests' + if: |- + ${{ github.event.inputs.force_skip_tests != 'true' }} + run: | + npm run test:integration:sdk:sandbox:none + npm run test:integration:sdk:sandbox:docker + env: + OPENAI_API_KEY: '${{ secrets.OPENAI_API_KEY }}' + OPENAI_BASE_URL: '${{ secrets.OPENAI_BASE_URL }}' + OPENAI_MODEL: '${{ secrets.OPENAI_MODEL }}' + - name: 'Configure Git User' run: | git config user.name "github-actions[bot]" @@ -184,7 +202,7 @@ jobs: registry-url: 'https://registry.npmjs.org' scope: '@qwen-code' - - name: 'Publish @qwen-code/sdk-typescript' + - name: 'Publish @qwen-code/sdk' working-directory: 'packages/sdk-typescript' run: |- npm publish --access public --tag=${{ steps.version.outputs.NPM_TAG }} ${{ steps.vars.outputs.is_dry_run == 'true' && '--dry-run' || '' }} diff --git a/integration-tests/sdk-typescript/abort-and-lifecycle.test.ts b/integration-tests/sdk-typescript/abort-and-lifecycle.test.ts index b0b4c3fd..93005d4b 100644 --- a/integration-tests/sdk-typescript/abort-and-lifecycle.test.ts +++ b/integration-tests/sdk-typescript/abort-and-lifecycle.test.ts @@ -13,7 +13,7 @@ import { isSDKAssistantMessage, type TextBlock, type ContentBlock, -} from '@qwen-code/sdk-typescript'; +} from '@qwen-code/sdk'; import { SDKTestHelper, createSharedTestOptions } from './test-helper.js'; const SHARED_TEST_OPTIONS = createSharedTestOptions(); diff --git a/integration-tests/sdk-typescript/configuration-options.test.ts b/integration-tests/sdk-typescript/configuration-options.test.ts index 3f825f3e..bc59cd79 100644 --- a/integration-tests/sdk-typescript/configuration-options.test.ts +++ b/integration-tests/sdk-typescript/configuration-options.test.ts @@ -17,7 +17,7 @@ import { isSDKAssistantMessage, isSDKSystemMessage, type SDKMessage, -} from '@qwen-code/sdk-typescript'; +} from '@qwen-code/sdk'; import { SDKTestHelper, extractText, @@ -532,7 +532,6 @@ describe('Configuration Options (E2E)', () => { cwd: testDir, authType: 'openai', debug: true, - logLevel: 'debug', stderr: (msg: string) => { stderrMessages.push(msg); }, diff --git a/integration-tests/sdk-typescript/mcp-server.test.ts b/integration-tests/sdk-typescript/mcp-server.test.ts index 110c1924..9b3f2193 100644 --- a/integration-tests/sdk-typescript/mcp-server.test.ts +++ b/integration-tests/sdk-typescript/mcp-server.test.ts @@ -19,7 +19,7 @@ import { type SDKMessage, type ToolUseBlock, type SDKSystemMessage, -} from '@qwen-code/sdk-typescript'; +} from '@qwen-code/sdk'; import { SDKTestHelper, createMCPServer, diff --git a/integration-tests/sdk-typescript/multi-turn.test.ts b/integration-tests/sdk-typescript/multi-turn.test.ts index 17b6f675..c1b96cc7 100644 --- a/integration-tests/sdk-typescript/multi-turn.test.ts +++ b/integration-tests/sdk-typescript/multi-turn.test.ts @@ -21,7 +21,7 @@ import { type SDKMessage, type ControlMessage, type ToolUseBlock, -} from '@qwen-code/sdk-typescript'; +} from '@qwen-code/sdk'; import { SDKTestHelper, createSharedTestOptions } from './test-helper.js'; const SHARED_TEST_OPTIONS = createSharedTestOptions(); diff --git a/integration-tests/sdk-typescript/permission-control.test.ts b/integration-tests/sdk-typescript/permission-control.test.ts index 31c7768a..e8d201e6 100644 --- a/integration-tests/sdk-typescript/permission-control.test.ts +++ b/integration-tests/sdk-typescript/permission-control.test.ts @@ -22,7 +22,7 @@ import { type SDKUserMessage, type ToolUseBlock, type ContentBlock, -} from '@qwen-code/sdk-typescript'; +} from '@qwen-code/sdk'; import { SDKTestHelper, createSharedTestOptions, @@ -555,6 +555,15 @@ describe('Permission Control (E2E)', () => { ...SHARED_TEST_OPTIONS, cwd: testDir, permissionMode: 'default', + timeout: { + /** + * We use a short control request timeout and + * wait till the time exceeded to test if + * an immediate close() will raise an query close + * error and no other uncaught timeout error + */ + controlRequest: 5000, + }, }, }); @@ -563,7 +572,9 @@ describe('Permission Control (E2E)', () => { await expect(q.setPermissionMode('yolo')).rejects.toThrow( 'Query is closed', ); - }); + + await new Promise((resolve) => setTimeout(resolve, 8000)); + }, 10_000); }); describe('canUseTool and setPermissionMode integration', () => { @@ -1184,7 +1195,7 @@ describe('Permission Control (E2E)', () => { }); describe('mode comparison tests', () => { - it( + it.skip( 'should demonstrate different behaviors across all modes for write operations', async () => { const modes: Array<'default' | 'auto-edit' | 'yolo'> = [ diff --git a/integration-tests/sdk-typescript/sdk-mcp-server.test.ts b/integration-tests/sdk-typescript/sdk-mcp-server.test.ts new file mode 100644 index 00000000..1ce8658e --- /dev/null +++ b/integration-tests/sdk-typescript/sdk-mcp-server.test.ts @@ -0,0 +1,456 @@ +/** + * @license + * Copyright 2025 Qwen Team + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * E2E tests for SDK-embedded MCP servers + * + * Tests that the SDK can create and manage MCP servers running in the SDK process + * using the tool() and createSdkMcpServer() APIs. + */ + +import { describe, it, expect, beforeEach, afterEach } from 'vitest'; +import { z } from 'zod'; +import { + query, + tool, + createSdkMcpServer, + isSDKAssistantMessage, + isSDKResultMessage, + isSDKSystemMessage, + type SDKMessage, + type SDKSystemMessage, +} from '@qwen-code/sdk'; +import { + SDKTestHelper, + extractText, + findToolUseBlocks, + createSharedTestOptions, +} from './test-helper.js'; + +const SHARED_TEST_OPTIONS = { + ...createSharedTestOptions(), + permissionMode: 'yolo' as const, +}; + +describe('SDK MCP Server Integration (E2E)', () => { + let helper: SDKTestHelper; + let testDir: string; + + beforeEach(async () => { + helper = new SDKTestHelper(); + testDir = await helper.setup('sdk-mcp-server-integration'); + }); + + afterEach(async () => { + await helper.cleanup(); + }); + + describe('Basic SDK MCP Tool Usage', () => { + it('should use SDK MCP tool to perform a simple calculation', async () => { + // Define a simple calculator tool using the tool() API with Zod schema + const calculatorTool = tool( + 'calculate_sum', + 'Calculate the sum of two numbers', + z.object({ + a: z.number().describe('First number'), + b: z.number().describe('Second number'), + }).shape, + async (args) => ({ + content: [{ type: 'text', text: String(args.a + args.b) }], + }), + ); + + // Create SDK MCP server with the tool + const serverConfig = createSdkMcpServer({ + name: 'sdk-calculator', + version: '1.0.0', + tools: [calculatorTool], + }); + + const q = query({ + prompt: + 'Use the calculate_sum tool to add 25 and 17. Output the result of tool only.', + options: { + ...SHARED_TEST_OPTIONS, + cwd: testDir, + mcpServers: { + 'sdk-calculator': serverConfig, + }, + }, + }); + + const messages: SDKMessage[] = []; + let assistantText = ''; + let foundToolUse = false; + + try { + for await (const message of q) { + messages.push(message); + + if (isSDKAssistantMessage(message)) { + const toolUseBlocks = findToolUseBlocks(message, 'calculate_sum'); + if (toolUseBlocks.length > 0) { + foundToolUse = true; + } + assistantText += extractText(message.message.content); + } + } + + // Validate tool was called + expect(foundToolUse).toBe(true); + + // Validate result contains expected answer: 25 + 17 = 42 + expect(assistantText).toMatch(/42/); + + // Validate successful completion + const lastMessage = messages[messages.length - 1]; + expect(isSDKResultMessage(lastMessage)).toBe(true); + if (isSDKResultMessage(lastMessage)) { + expect(lastMessage.subtype).toBe('success'); + } + } finally { + await q.close(); + } + }); + + it('should use SDK MCP tool with string operations', async () => { + // Define a string manipulation tool with Zod schema + const stringTool = tool( + 'reverse_string', + 'Reverse a string', + { + text: z.string().describe('The text to reverse'), + }, + async (args) => ({ + content: [ + { type: 'text', text: args.text.split('').reverse().join('') }, + ], + }), + ); + + const serverConfig = createSdkMcpServer({ + name: 'sdk-string-utils', + version: '1.0.0', + tools: [stringTool], + }); + + const q = query({ + prompt: `Use the 'reverse_string' tool to process the word "hello world". Output the tool result only.`, + options: { + ...SHARED_TEST_OPTIONS, + cwd: testDir, + mcpServers: { + 'sdk-string-utils': serverConfig, + }, + }, + }); + + const messages: SDKMessage[] = []; + let assistantText = ''; + let foundToolUse = false; + + try { + for await (const message of q) { + messages.push(message); + + if (isSDKAssistantMessage(message)) { + const toolUseBlocks = findToolUseBlocks(message, 'reverse_string'); + if (toolUseBlocks.length > 0) { + foundToolUse = true; + } + assistantText += extractText(message.message.content); + } + } + + // Validate tool was called + expect(foundToolUse).toBe(true); + + // Validate result contains reversed string: "olleh" + expect(assistantText.toLowerCase()).toMatch(/olleh/); + + // Validate successful completion + const lastMessage = messages[messages.length - 1]; + expect(isSDKResultMessage(lastMessage)).toBe(true); + } finally { + await q.close(); + } + }); + }); + + describe('Multiple SDK MCP Tools', () => { + it('should use multiple tools from the same SDK MCP server', async () => { + // Define the Zod schema shape for two numbers + const twoNumbersSchema = { + a: z.number().describe('First number'), + b: z.number().describe('Second number'), + }; + + // Define multiple tools + const addTool = tool( + 'sdk_add', + 'Add two numbers', + twoNumbersSchema, + async (args) => ({ + content: [{ type: 'text', text: String(args.a + args.b) }], + }), + ); + + const multiplyTool = tool( + 'sdk_multiply', + 'Multiply two numbers', + twoNumbersSchema, + async (args) => ({ + content: [{ type: 'text', text: String(args.a * args.b) }], + }), + ); + + const serverConfig = createSdkMcpServer({ + name: 'sdk-math', + version: '1.0.0', + tools: [addTool, multiplyTool], + }); + + const q = query({ + prompt: + 'First use sdk_add to calculate 10 + 5, then use sdk_multiply to multiply the result by 3. Give me the final answer.', + options: { + ...SHARED_TEST_OPTIONS, + cwd: testDir, + debug: false, + mcpServers: { + 'sdk-math': serverConfig, + }, + }, + }); + + const messages: SDKMessage[] = []; + let assistantText = ''; + const toolCalls: string[] = []; + + try { + for await (const message of q) { + messages.push(message); + + if (isSDKAssistantMessage(message)) { + const toolUseBlocks = findToolUseBlocks(message); + toolUseBlocks.forEach((block) => { + toolCalls.push(block.name); + }); + assistantText += extractText(message.message.content); + } + } + + // Validate both tools were called + expect(toolCalls).toContain('sdk_add'); + expect(toolCalls).toContain('sdk_multiply'); + + // Validate result: (10 + 5) * 3 = 45 + expect(assistantText).toMatch(/45/); + + // Validate successful completion + const lastMessage = messages[messages.length - 1]; + expect(isSDKResultMessage(lastMessage)).toBe(true); + } finally { + await q.close(); + } + }); + }); + + describe('SDK MCP Server Discovery', () => { + it('should list SDK MCP servers in system init message', async () => { + // Define echo tool with Zod schema + const echoTool = tool( + 'echo', + 'Echo a message', + { + message: z.string().describe('Message to echo'), + }, + async (args) => ({ + content: [{ type: 'text', text: args.message }], + }), + ); + + const serverConfig = createSdkMcpServer({ + name: 'sdk-echo', + version: '1.0.0', + tools: [echoTool], + }); + + const q = query({ + prompt: 'Hello', + options: { + ...SHARED_TEST_OPTIONS, + cwd: testDir, + debug: false, + mcpServers: { + 'sdk-echo': serverConfig, + }, + }, + }); + + let systemMessage: SDKSystemMessage | null = null; + + try { + for await (const message of q) { + if (isSDKSystemMessage(message) && message.subtype === 'init') { + systemMessage = message; + break; + } + } + + // Validate MCP server is listed + expect(systemMessage).not.toBeNull(); + expect(systemMessage!.mcp_servers).toBeDefined(); + expect(Array.isArray(systemMessage!.mcp_servers)).toBe(true); + + // Find our SDK MCP server + const sdkServer = systemMessage!.mcp_servers?.find( + (server) => server.name === 'sdk-echo', + ); + expect(sdkServer).toBeDefined(); + } finally { + await q.close(); + } + }); + }); + + describe('SDK MCP Tool Error Handling', () => { + it('should handle tool errors gracefully', async () => { + // Define a tool that throws an error with Zod schema + const errorTool = tool( + 'maybe_fail', + 'A tool that may fail based on input', + { + shouldFail: z.boolean().describe('If true, the tool will fail'), + }, + async (args) => { + if (args.shouldFail) { + throw new Error('Tool intentionally failed'); + } + return { content: [{ type: 'text', text: 'Success!' }] }; + }, + ); + + const serverConfig = createSdkMcpServer({ + name: 'sdk-error-test', + version: '1.0.0', + tools: [errorTool], + }); + + const q = query({ + prompt: + 'Use the maybe_fail tool with shouldFail set to true. Tell me what happens.', + options: { + ...SHARED_TEST_OPTIONS, + cwd: testDir, + debug: false, + mcpServers: { + 'sdk-error-test': serverConfig, + }, + }, + }); + + const messages: SDKMessage[] = []; + let foundToolUse = false; + + try { + for await (const message of q) { + messages.push(message); + + if (isSDKAssistantMessage(message)) { + const toolUseBlocks = findToolUseBlocks(message, 'maybe_fail'); + if (toolUseBlocks.length > 0) { + foundToolUse = true; + } + } + } + + // Tool should be called + expect(foundToolUse).toBe(true); + + // Query should complete (even with tool error) + const lastMessage = messages[messages.length - 1]; + expect(isSDKResultMessage(lastMessage)).toBe(true); + } finally { + await q.close(); + } + }); + }); + + describe('Async Tool Handlers', () => { + it('should handle async tool handlers with delays', async () => { + // Define a tool with async delay using Zod schema + const delayedTool = tool( + 'delayed_response', + 'Returns a value after a delay', + { + delay: z.number().describe('Delay in milliseconds (max 100)'), + value: z.string().describe('Value to return'), + }, + async (args) => { + // Cap delay at 100ms for test performance + const actualDelay = Math.min(args.delay, 100); + await new Promise((resolve) => setTimeout(resolve, actualDelay)); + return { + content: [{ type: 'text', text: `Delayed result: ${args.value}` }], + }; + }, + ); + + const serverConfig = createSdkMcpServer({ + name: 'sdk-async', + version: '1.0.0', + tools: [delayedTool], + }); + + const q = query({ + prompt: + 'Use the delayed_response tool with delay=50 and value="test_async". Tell me the result.', + options: { + ...SHARED_TEST_OPTIONS, + cwd: testDir, + debug: false, + mcpServers: { + 'sdk-async': serverConfig, + }, + }, + }); + + const messages: SDKMessage[] = []; + let assistantText = ''; + let foundToolUse = false; + + try { + for await (const message of q) { + messages.push(message); + + if (isSDKAssistantMessage(message)) { + const toolUseBlocks = findToolUseBlocks( + message, + 'delayed_response', + ); + if (toolUseBlocks.length > 0) { + foundToolUse = true; + } + assistantText += extractText(message.message.content); + } + } + + // Validate tool was called + expect(foundToolUse).toBe(true); + + // Validate result contains the delayed response + expect(assistantText.toLowerCase()).toMatch(/test_async/i); + + // Validate successful completion + const lastMessage = messages[messages.length - 1]; + expect(isSDKResultMessage(lastMessage)).toBe(true); + } finally { + await q.close(); + } + }); + }); +}); diff --git a/integration-tests/sdk-typescript/single-turn.test.ts b/integration-tests/sdk-typescript/single-turn.test.ts index aa2716f3..3608e619 100644 --- a/integration-tests/sdk-typescript/single-turn.test.ts +++ b/integration-tests/sdk-typescript/single-turn.test.ts @@ -13,7 +13,7 @@ import { type SDKMessage, type SDKSystemMessage, type SDKAssistantMessage, -} from '@qwen-code/sdk-typescript'; +} from '@qwen-code/sdk'; import { SDKTestHelper, extractText, @@ -44,7 +44,6 @@ describe('Single-Turn Query (E2E)', () => { ...SHARED_TEST_OPTIONS, cwd: testDir, debug: true, - logLevel: 'debug', }, }); diff --git a/integration-tests/sdk-typescript/subagents.test.ts b/integration-tests/sdk-typescript/subagents.test.ts index 86516053..c327c96e 100644 --- a/integration-tests/sdk-typescript/subagents.test.ts +++ b/integration-tests/sdk-typescript/subagents.test.ts @@ -17,7 +17,7 @@ import { type SubagentConfig, type ContentBlock, type ToolUseBlock, -} from '@qwen-code/sdk-typescript'; +} from '@qwen-code/sdk'; import { SDKTestHelper, extractText, diff --git a/integration-tests/sdk-typescript/system-control.test.ts b/integration-tests/sdk-typescript/system-control.test.ts index 069eccd9..0b0a74d3 100644 --- a/integration-tests/sdk-typescript/system-control.test.ts +++ b/integration-tests/sdk-typescript/system-control.test.ts @@ -9,7 +9,7 @@ import { isSDKAssistantMessage, isSDKSystemMessage, type SDKUserMessage, -} from '@qwen-code/sdk-typescript'; +} from '@qwen-code/sdk'; import { SDKTestHelper, createSharedTestOptions } from './test-helper.js'; const SHARED_TEST_OPTIONS = createSharedTestOptions(); diff --git a/integration-tests/sdk-typescript/test-helper.ts b/integration-tests/sdk-typescript/test-helper.ts index cd95051f..f3005655 100644 --- a/integration-tests/sdk-typescript/test-helper.ts +++ b/integration-tests/sdk-typescript/test-helper.ts @@ -21,12 +21,12 @@ import type { ContentBlock, TextBlock, ToolUseBlock, -} from '@qwen-code/sdk-typescript'; +} from '@qwen-code/sdk'; import { isSDKAssistantMessage, isSDKSystemMessage, isSDKResultMessage, -} from '@qwen-code/sdk-typescript'; +} from '@qwen-code/sdk'; // ============================================================================ // Core Test Helper Class diff --git a/integration-tests/sdk-typescript/tool-control.test.ts b/integration-tests/sdk-typescript/tool-control.test.ts index 036d779e..b2b955a6 100644 --- a/integration-tests/sdk-typescript/tool-control.test.ts +++ b/integration-tests/sdk-typescript/tool-control.test.ts @@ -12,11 +12,7 @@ */ import { describe, it, expect, beforeEach, afterEach } from 'vitest'; -import { - query, - isSDKAssistantMessage, - type SDKMessage, -} from '@qwen-code/sdk-typescript'; +import { query, isSDKAssistantMessage, type SDKMessage } from '@qwen-code/sdk'; import { SDKTestHelper, extractText, diff --git a/integration-tests/tsconfig.json b/integration-tests/tsconfig.json index 7f2a010d..0cd24f82 100644 --- a/integration-tests/tsconfig.json +++ b/integration-tests/tsconfig.json @@ -5,9 +5,7 @@ "allowJs": true, "baseUrl": ".", "paths": { - "@qwen-code/sdk-typescript": [ - "../packages/sdk-typescript/dist/index.d.ts" - ] + "@qwen-code/sdk": ["../packages/sdk-typescript/dist/index.d.ts"] } }, "include": ["**/*.ts"], diff --git a/integration-tests/vitest.config.ts b/integration-tests/vitest.config.ts index a452583c..9be72f50 100644 --- a/integration-tests/vitest.config.ts +++ b/integration-tests/vitest.config.ts @@ -31,7 +31,7 @@ export default defineConfig({ resolve: { alias: { // Use built SDK bundle for e2e tests - '@qwen-code/sdk-typescript': resolve( + '@qwen-code/sdk': resolve( __dirname, '../packages/sdk-typescript/dist/index.mjs', ), diff --git a/package-lock.json b/package-lock.json index 53fe9d46..f3bb0cad 100644 --- a/package-lock.json +++ b/package-lock.json @@ -2793,7 +2793,7 @@ "resolved": "packages/test-utils", "link": true }, - "node_modules/@qwen-code/sdk-typescript": { + "node_modules/@qwen-code/sdk": { "resolved": "packages/sdk-typescript", "link": true }, @@ -16676,7 +16676,7 @@ } }, "packages/sdk-typescript": { - "name": "@qwen-code/sdk-typescript", + "name": "@qwen-code/sdk", "version": "0.1.0", "license": "Apache-2.0", "dependencies": { diff --git a/package.json b/package.json index a8b6857f..563de8f7 100644 --- a/package.json +++ b/package.json @@ -37,6 +37,10 @@ "test:integration:sandbox:none": "cross-env GEMINI_SANDBOX=false vitest run --root ./integration-tests", "test:integration:sandbox:docker": "cross-env GEMINI_SANDBOX=docker npm run build:sandbox && GEMINI_SANDBOX=docker vitest run --root ./integration-tests", "test:integration:sandbox:podman": "cross-env GEMINI_SANDBOX=podman vitest run --root ./integration-tests", + "test:integration:sdk:sandbox:none": "cross-env GEMINI_SANDBOX=false vitest run --root ./integration-tests sdk-typescript", + "test:integration:sdk:sandbox:docker": "cross-env GEMINI_SANDBOX=docker npm run build:sandbox && GEMINI_SANDBOX=docker vitest run --root ./integration-tests sdk-typescript", + "test:integration:cli:sandbox:none": "cross-env GEMINI_SANDBOX=false vitest run --root ./integration-tests --exclude '**/sdk-typescript/**'", + "test:integration:cli:sandbox:docker": "cross-env GEMINI_SANDBOX=docker npm run build:sandbox && GEMINI_SANDBOX=docker vitest run --root ./integration-tests --exclude '**/sdk-typescript/**'", "test:terminal-bench": "cross-env VERBOSE=true KEEP_OUTPUT=true vitest run --config ./vitest.terminal-bench.config.ts --root ./integration-tests", "test:terminal-bench:oracle": "cross-env VERBOSE=true KEEP_OUTPUT=true vitest run --config ./vitest.terminal-bench.config.ts --root ./integration-tests -t 'oracle'", "test:terminal-bench:qwen": "cross-env VERBOSE=true KEEP_OUTPUT=true vitest run --config ./vitest.terminal-bench.config.ts --root ./integration-tests -t 'qwen'", diff --git a/packages/cli/src/gemini.tsx b/packages/cli/src/gemini.tsx index 7171670c..18f191bc 100644 --- a/packages/cli/src/gemini.tsx +++ b/packages/cli/src/gemini.tsx @@ -276,8 +276,11 @@ export async function main() { process.exit(1); } } + // For stream-json mode, don't read stdin here - it should be forwarded to the sandbox + // and consumed by StreamJsonInputReader inside the container + const inputFormat = argv.inputFormat as string | undefined; let stdinData = ''; - if (!process.stdin.isTTY) { + if (!process.stdin.isTTY && inputFormat !== 'stream-json') { stdinData = await readStdin(); } diff --git a/packages/cli/src/nonInteractive/control/ControlDispatcher.ts b/packages/cli/src/nonInteractive/control/ControlDispatcher.ts index b2165ee9..d6dc79a4 100644 --- a/packages/cli/src/nonInteractive/control/ControlDispatcher.ts +++ b/packages/cli/src/nonInteractive/control/ControlDispatcher.ts @@ -16,9 +16,12 @@ * Controllers: * - SystemController: initialize, interrupt, set_model, supported_commands * - PermissionController: can_use_tool, set_permission_mode - * - MCPController: mcp_message, mcp_server_status + * - SdkMcpController: mcp_server_status (mcp_message handled via callback) * - HookController: hook_callback * + * Note: mcp_message requests are NOT routed through the dispatcher. CLI MCP + * clients send messages via SdkMcpController.createSendSdkMcpMessage() callback. + * * Note: Control request types are centrally defined in the ControlRequestType * enum in packages/sdk/typescript/src/types/controlRequests.ts */ @@ -27,7 +30,7 @@ import type { IControlContext } from './ControlContext.js'; import type { IPendingRequestRegistry } from './controllers/baseController.js'; import { SystemController } from './controllers/systemController.js'; import { PermissionController } from './controllers/permissionController.js'; -// import { MCPController } from './controllers/mcpController.js'; +import { SdkMcpController } from './controllers/sdkMcpController.js'; // import { HookController } from './controllers/hookController.js'; import type { CLIControlRequest, @@ -65,7 +68,7 @@ export class ControlDispatcher implements IPendingRequestRegistry { // Make controllers publicly accessible readonly systemController: SystemController; readonly permissionController: PermissionController; - // readonly mcpController: MCPController; + readonly sdkMcpController: SdkMcpController; // readonly hookController: HookController; // Central pending request registries @@ -88,7 +91,11 @@ export class ControlDispatcher implements IPendingRequestRegistry { this, 'PermissionController', ); - // this.mcpController = new MCPController(context, this, 'MCPController'); + this.sdkMcpController = new SdkMcpController( + context, + this, + 'SdkMcpController', + ); // this.hookController = new HookController(context, this, 'HookController'); // Listen for main abort signal @@ -228,10 +235,10 @@ export class ControlDispatcher implements IPendingRequestRegistry { } this.pendingOutgoingRequests.clear(); - // Cleanup controllers (MCP controller will close all clients) + // Cleanup controllers this.systemController.cleanup(); this.permissionController.cleanup(); - // this.mcpController.cleanup(); + this.sdkMcpController.cleanup(); // this.hookController.cleanup(); } @@ -291,6 +298,47 @@ export class ControlDispatcher implements IPendingRequestRegistry { } } + /** + * Get count of pending incoming requests (for debugging) + */ + getPendingIncomingRequestCount(): number { + return this.pendingIncomingRequests.size; + } + + /** + * Wait for all incoming request handlers to complete. + * + * Uses polling since we don't have direct Promise references to handlers. + * The pendingIncomingRequests map is managed by BaseController: + * - Registered when handler starts (in handleRequest) + * - Deregistered when handler completes (success or error) + * + * @param pollIntervalMs - How often to check (default 50ms) + * @param timeoutMs - Maximum wait time (default 30s) + */ + async waitForPendingIncomingRequests( + pollIntervalMs: number = 50, + timeoutMs: number = 30000, + ): Promise { + const startTime = Date.now(); + + while (this.pendingIncomingRequests.size > 0) { + if (Date.now() - startTime > timeoutMs) { + if (this.context.debugMode) { + console.error( + `[ControlDispatcher] Timeout waiting for ${this.pendingIncomingRequests.size} pending incoming requests`, + ); + } + break; + } + await new Promise((resolve) => setTimeout(resolve, pollIntervalMs)); + } + + if (this.context.debugMode && this.pendingIncomingRequests.size === 0) { + console.error('[ControlDispatcher] All incoming requests completed'); + } + } + /** * Returns the controller that handles the given request subtype */ @@ -306,9 +354,8 @@ export class ControlDispatcher implements IPendingRequestRegistry { case 'set_permission_mode': return this.permissionController; - // case 'mcp_message': - // case 'mcp_server_status': - // return this.mcpController; + case 'mcp_server_status': + return this.sdkMcpController; // case 'hook_callback': // return this.hookController; diff --git a/packages/cli/src/nonInteractive/control/controllers/baseController.ts b/packages/cli/src/nonInteractive/control/controllers/baseController.ts index 90b7f56a..dcb9e7c9 100644 --- a/packages/cli/src/nonInteractive/control/controllers/baseController.ts +++ b/packages/cli/src/nonInteractive/control/controllers/baseController.ts @@ -117,16 +117,41 @@ export abstract class BaseController { * Send an outgoing control request to SDK * * Manages lifecycle: register -> send -> wait for response -> deregister + * Respects the provided AbortSignal for cancellation. */ async sendControlRequest( payload: ControlRequestPayload, timeoutMs: number = DEFAULT_REQUEST_TIMEOUT_MS, + signal?: AbortSignal, ): Promise { + // Check if already aborted + if (signal?.aborted) { + throw new Error('Request aborted'); + } + const requestId = randomUUID(); return new Promise((resolve, reject) => { + // Setup abort handler + const abortHandler = () => { + this.registry.deregisterOutgoingRequest(requestId); + reject(new Error('Request aborted')); + if (this.context.debugMode) { + console.error( + `[${this.controllerName}] Outgoing request aborted: ${requestId}`, + ); + } + }; + + if (signal) { + signal.addEventListener('abort', abortHandler, { once: true }); + } + // Setup timeout const timeoutId = setTimeout(() => { + if (signal) { + signal.removeEventListener('abort', abortHandler); + } this.registry.deregisterOutgoingRequest(requestId); reject(new Error('Control request timeout')); if (this.context.debugMode) { @@ -136,12 +161,27 @@ export abstract class BaseController { } }, timeoutMs); + // Wrap resolve/reject to clean up abort listener + const wrappedResolve = (response: ControlResponse) => { + if (signal) { + signal.removeEventListener('abort', abortHandler); + } + resolve(response); + }; + + const wrappedReject = (error: Error) => { + if (signal) { + signal.removeEventListener('abort', abortHandler); + } + reject(error); + }; + // Register with central registry this.registry.registerOutgoingRequest( requestId, this.controllerName, - resolve, - reject, + wrappedResolve, + wrappedReject, timeoutId, ); @@ -155,6 +195,9 @@ export abstract class BaseController { try { this.context.streamJson.send(request); } catch (error) { + if (signal) { + signal.removeEventListener('abort', abortHandler); + } this.registry.deregisterOutgoingRequest(requestId); reject(error); } diff --git a/packages/cli/src/nonInteractive/control/controllers/mcpController.ts b/packages/cli/src/nonInteractive/control/controllers/mcpController.ts deleted file mode 100644 index fccafb67..00000000 --- a/packages/cli/src/nonInteractive/control/controllers/mcpController.ts +++ /dev/null @@ -1,287 +0,0 @@ -/** - * @license - * Copyright 2025 Qwen Team - * SPDX-License-Identifier: Apache-2.0 - */ - -/** - * MCP Controller - * - * Handles MCP-related control requests: - * - mcp_message: Route MCP messages - * - mcp_server_status: Return MCP server status - */ - -import { BaseController } from './baseController.js'; -import type { Client } from '@modelcontextprotocol/sdk/client/index.js'; -import { ResultSchema } from '@modelcontextprotocol/sdk/types.js'; -import type { - ControlRequestPayload, - CLIControlMcpMessageRequest, -} from '../../types.js'; -import type { - MCPServerConfig, - WorkspaceContext, -} from '@qwen-code/qwen-code-core'; -import { - connectToMcpServer, - MCP_DEFAULT_TIMEOUT_MSEC, -} from '@qwen-code/qwen-code-core'; - -export class MCPController extends BaseController { - /** - * Handle MCP control requests - */ - protected async handleRequestPayload( - payload: ControlRequestPayload, - _signal: AbortSignal, - ): Promise> { - switch (payload.subtype) { - case 'mcp_message': - return this.handleMcpMessage(payload as CLIControlMcpMessageRequest); - - case 'mcp_server_status': - return this.handleMcpStatus(); - - default: - throw new Error(`Unsupported request subtype in MCPController`); - } - } - - /** - * Handle mcp_message request - * - * Routes JSON-RPC messages to MCP servers - */ - private async handleMcpMessage( - payload: CLIControlMcpMessageRequest, - ): Promise> { - const serverNameRaw = payload.server_name; - if ( - typeof serverNameRaw !== 'string' || - serverNameRaw.trim().length === 0 - ) { - throw new Error('Missing server_name in mcp_message request'); - } - - const message = payload.message; - if (!message || typeof message !== 'object') { - throw new Error( - 'Missing or invalid message payload for mcp_message request', - ); - } - - // Get or create MCP client - let clientEntry: { client: Client; config: MCPServerConfig }; - try { - clientEntry = await this.getOrCreateMcpClient(serverNameRaw.trim()); - } catch (error) { - throw new Error( - error instanceof Error - ? error.message - : 'Failed to connect to MCP server', - ); - } - - const method = message.method; - if (typeof method !== 'string' || method.trim().length === 0) { - throw new Error('Invalid MCP message: missing method'); - } - - const jsonrpcVersion = - typeof message.jsonrpc === 'string' ? message.jsonrpc : '2.0'; - const messageId = message.id; - const params = message.params; - const timeout = - typeof clientEntry.config.timeout === 'number' - ? clientEntry.config.timeout - : MCP_DEFAULT_TIMEOUT_MSEC; - - try { - // Handle notification (no id) - if (messageId === undefined) { - await clientEntry.client.notification({ - method, - params, - }); - return { - subtype: 'mcp_message', - mcp_response: { - jsonrpc: jsonrpcVersion, - id: null, - result: { success: true, acknowledged: true }, - }, - }; - } - - // Handle request (with id) - const result = await clientEntry.client.request( - { - method, - params, - }, - ResultSchema, - { timeout }, - ); - - return { - subtype: 'mcp_message', - mcp_response: { - jsonrpc: jsonrpcVersion, - id: messageId, - result, - }, - }; - } catch (error) { - // If connection closed, remove from cache - if (error instanceof Error && /closed/i.test(error.message)) { - this.context.mcpClients.delete(serverNameRaw.trim()); - } - - const errorCode = - typeof (error as { code?: unknown })?.code === 'number' - ? ((error as { code: number }).code as number) - : -32603; - const errorMessage = - error instanceof Error - ? error.message - : 'Failed to execute MCP request'; - const errorData = (error as { data?: unknown })?.data; - - const errorBody: Record = { - code: errorCode, - message: errorMessage, - }; - if (errorData !== undefined) { - errorBody['data'] = errorData; - } - - return { - subtype: 'mcp_message', - mcp_response: { - jsonrpc: jsonrpcVersion, - id: messageId ?? null, - error: errorBody, - }, - }; - } - } - - /** - * Handle mcp_server_status request - * - * Returns status of registered MCP servers - */ - private async handleMcpStatus(): Promise> { - const status: Record = {}; - - // Include SDK MCP servers - for (const serverName of this.context.sdkMcpServers) { - status[serverName] = 'connected'; - } - - // Include CLI-managed MCP clients - for (const serverName of this.context.mcpClients.keys()) { - status[serverName] = 'connected'; - } - - if (this.context.debugMode) { - console.error( - `[MCPController] MCP status: ${Object.keys(status).length} servers`, - ); - } - - return status; - } - - /** - * Get or create MCP client for a server - * - * Implements lazy connection and caching - */ - private async getOrCreateMcpClient( - serverName: string, - ): Promise<{ client: Client; config: MCPServerConfig }> { - // Check cache first - const cached = this.context.mcpClients.get(serverName); - if (cached) { - return cached; - } - - // Get server configuration - const provider = this.context.config as unknown as { - getMcpServers?: () => Record | undefined; - getDebugMode?: () => boolean; - getWorkspaceContext?: () => unknown; - }; - - if (typeof provider.getMcpServers !== 'function') { - throw new Error(`MCP server "${serverName}" is not configured`); - } - - const servers = provider.getMcpServers() ?? {}; - const serverConfig = servers[serverName]; - if (!serverConfig) { - throw new Error(`MCP server "${serverName}" is not configured`); - } - - const debugMode = - typeof provider.getDebugMode === 'function' - ? provider.getDebugMode() - : false; - - const workspaceContext = - typeof provider.getWorkspaceContext === 'function' - ? provider.getWorkspaceContext() - : undefined; - - if (!workspaceContext) { - throw new Error('Workspace context is not available for MCP connection'); - } - - // Connect to MCP server - const client = await connectToMcpServer( - serverName, - serverConfig, - debugMode, - workspaceContext as WorkspaceContext, - ); - - // Cache the client - const entry = { client, config: serverConfig }; - this.context.mcpClients.set(serverName, entry); - - if (this.context.debugMode) { - console.error(`[MCPController] Connected to MCP server: ${serverName}`); - } - - return entry; - } - - /** - * Cleanup MCP clients - */ - override cleanup(): void { - if (this.context.debugMode) { - console.error( - `[MCPController] Cleaning up ${this.context.mcpClients.size} MCP clients`, - ); - } - - // Close all MCP clients - for (const [serverName, { client }] of this.context.mcpClients.entries()) { - try { - client.close(); - } catch (error) { - if (this.context.debugMode) { - console.error( - `[MCPController] Failed to close MCP client ${serverName}:`, - error, - ); - } - } - } - - this.context.mcpClients.clear(); - } -} diff --git a/packages/cli/src/nonInteractive/control/controllers/permissionController.ts b/packages/cli/src/nonInteractive/control/controllers/permissionController.ts index 37a9082f..4cec3b00 100644 --- a/packages/cli/src/nonInteractive/control/controllers/permissionController.ts +++ b/packages/cli/src/nonInteractive/control/controllers/permissionController.ts @@ -44,15 +44,23 @@ export class PermissionController extends BaseController { */ protected async handleRequestPayload( payload: ControlRequestPayload, - _signal: AbortSignal, + signal: AbortSignal, ): Promise> { + if (signal.aborted) { + throw new Error('Request aborted'); + } + switch (payload.subtype) { case 'can_use_tool': - return this.handleCanUseTool(payload as CLIControlPermissionRequest); + return this.handleCanUseTool( + payload as CLIControlPermissionRequest, + signal, + ); case 'set_permission_mode': return this.handleSetPermissionMode( payload as CLIControlSetPermissionModeRequest, + signal, ); default: @@ -70,7 +78,12 @@ export class PermissionController extends BaseController { */ private async handleCanUseTool( payload: CLIControlPermissionRequest, + signal: AbortSignal, ): Promise> { + if (signal.aborted) { + throw new Error('Request aborted'); + } + const toolName = payload.tool_name; if ( !toolName || @@ -192,7 +205,12 @@ export class PermissionController extends BaseController { */ private async handleSetPermissionMode( payload: CLIControlSetPermissionModeRequest, + signal: AbortSignal, ): Promise> { + if (signal.aborted) { + throw new Error('Request aborted'); + } + const mode = payload.mode; const validModes: PermissionMode[] = [ 'default', @@ -373,6 +391,14 @@ export class PermissionController extends BaseController { toolCall: WaitingToolCall, ): Promise { try { + // Check if already aborted + if (this.context.abortSignal?.aborted) { + await toolCall.confirmationDetails.onConfirm( + ToolConfirmationOutcome.Cancel, + ); + return; + } + const inputFormat = this.context.config.getInputFormat?.(); const isStreamJsonMode = inputFormat === InputFormat.STREAM_JSON; @@ -392,14 +418,18 @@ export class PermissionController extends BaseController { toolCall.confirmationDetails, ); - const response = await this.sendControlRequest({ - subtype: 'can_use_tool', - tool_name: toolCall.request.name, - tool_use_id: toolCall.request.callId, - input: toolCall.request.args, - permission_suggestions: permissionSuggestions, - blocked_path: null, - } as CLIControlPermissionRequest); + const response = await this.sendControlRequest( + { + subtype: 'can_use_tool', + tool_name: toolCall.request.name, + tool_use_id: toolCall.request.callId, + input: toolCall.request.args, + permission_suggestions: permissionSuggestions, + blocked_path: null, + } as CLIControlPermissionRequest, + undefined, // use default timeout + this.context.abortSignal, + ); if (response.subtype !== 'success') { await toolCall.confirmationDetails.onConfirm( diff --git a/packages/cli/src/nonInteractive/control/controllers/sdkMcpController.ts b/packages/cli/src/nonInteractive/control/controllers/sdkMcpController.ts new file mode 100644 index 00000000..5d0264fb --- /dev/null +++ b/packages/cli/src/nonInteractive/control/controllers/sdkMcpController.ts @@ -0,0 +1,138 @@ +/** + * @license + * Copyright 2025 Qwen Team + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * SDK MCP Controller + * + * Handles MCP communication between CLI MCP clients and SDK MCP servers: + * - Provides sendSdkMcpMessage callback for CLI → SDK MCP message routing + * - mcp_server_status: Returns status of SDK MCP servers + * + * Message Flow (CLI MCP Client → SDK MCP Server): + * CLI MCP Client → SdkControlClientTransport.send() → + * sendSdkMcpMessage callback → control_request (mcp_message) → SDK → + * SDK MCP Server processes → control_response → CLI MCP Client + */ + +import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js'; +import { BaseController } from './baseController.js'; +import type { + ControlRequestPayload, + CLIControlMcpMessageRequest, +} from '../../types.js'; + +const MCP_REQUEST_TIMEOUT = 30_000; // 30 seconds + +export class SdkMcpController extends BaseController { + /** + * Handle SDK MCP control requests from ControlDispatcher + * + * Note: mcp_message requests are NOT handled here. CLI MCP clients + * send messages via the sendSdkMcpMessage callback directly, not + * through the control dispatcher. + */ + protected async handleRequestPayload( + payload: ControlRequestPayload, + signal: AbortSignal, + ): Promise> { + if (signal.aborted) { + throw new Error('Request aborted'); + } + + switch (payload.subtype) { + case 'mcp_server_status': + return this.handleMcpStatus(); + + default: + throw new Error(`Unsupported request subtype in SdkMcpController`); + } + } + + /** + * Handle mcp_server_status request + * + * Returns status of all registered SDK MCP servers. + * SDK servers are considered "connected" if they are registered. + */ + private async handleMcpStatus(): Promise> { + const status: Record = {}; + + for (const serverName of this.context.sdkMcpServers) { + // SDK MCP servers are "connected" once registered since they run in SDK process + status[serverName] = 'connected'; + } + + return { + subtype: 'mcp_server_status', + status, + }; + } + + /** + * Send MCP message to SDK server via control plane + * + * @param serverName - Name of the SDK MCP server + * @param message - MCP JSON-RPC message to send + * @returns MCP JSON-RPC response from SDK server + */ + private async sendMcpMessageToSdk( + serverName: string, + message: JSONRPCMessage, + ): Promise { + if (this.context.debugMode) { + console.error( + `[SdkMcpController] Sending MCP message to SDK server '${serverName}':`, + JSON.stringify(message), + ); + } + + // Send control request to SDK with the MCP message + const response = await this.sendControlRequest( + { + subtype: 'mcp_message', + server_name: serverName, + message: message as CLIControlMcpMessageRequest['message'], + }, + MCP_REQUEST_TIMEOUT, + this.context.abortSignal, + ); + + // Extract MCP response from control response + const responsePayload = response.response as Record; + const mcpResponse = responsePayload?.['mcp_response'] as JSONRPCMessage; + + if (!mcpResponse) { + throw new Error( + `Invalid MCP response from SDK for server '${serverName}'`, + ); + } + + if (this.context.debugMode) { + console.error( + `[SdkMcpController] Received MCP response from SDK server '${serverName}':`, + JSON.stringify(mcpResponse), + ); + } + + return mcpResponse; + } + + /** + * Create a callback function for sending MCP messages to SDK servers. + * + * This callback is used by McpClientManager/SdkControlClientTransport to send + * MCP messages from CLI MCP clients to SDK MCP servers via the control plane. + * + * @returns A function that sends MCP messages to SDK and returns the response + */ + createSendSdkMcpMessage(): ( + serverName: string, + message: JSONRPCMessage, + ) => Promise { + return (serverName: string, message: JSONRPCMessage) => + this.sendMcpMessageToSdk(serverName, message); + } +} diff --git a/packages/cli/src/nonInteractive/control/controllers/systemController.ts b/packages/cli/src/nonInteractive/control/controllers/systemController.ts index c94187e7..e214a881 100644 --- a/packages/cli/src/nonInteractive/control/controllers/systemController.ts +++ b/packages/cli/src/nonInteractive/control/controllers/systemController.ts @@ -18,9 +18,15 @@ import type { ControlRequestPayload, CLIControlInitializeRequest, CLIControlSetModelRequest, + CLIMcpServerConfig, } from '../../types.js'; import { CommandService } from '../../../services/CommandService.js'; import { BuiltinCommandLoader } from '../../../services/BuiltinCommandLoader.js'; +import { + MCPServerConfig, + AuthProviderType, + type MCPOAuthConfig, +} from '@qwen-code/qwen-code-core'; export class SystemController extends BaseController { /** @@ -28,20 +34,30 @@ export class SystemController extends BaseController { */ protected async handleRequestPayload( payload: ControlRequestPayload, - _signal: AbortSignal, + signal: AbortSignal, ): Promise> { + if (signal.aborted) { + throw new Error('Request aborted'); + } + switch (payload.subtype) { case 'initialize': - return this.handleInitialize(payload as CLIControlInitializeRequest); + return this.handleInitialize( + payload as CLIControlInitializeRequest, + signal, + ); case 'interrupt': return this.handleInterrupt(); case 'set_model': - return this.handleSetModel(payload as CLIControlSetModelRequest); + return this.handleSetModel( + payload as CLIControlSetModelRequest, + signal, + ); case 'supported_commands': - return this.handleSupportedCommands(); + return this.handleSupportedCommands(signal); default: throw new Error(`Unsupported request subtype in SystemController`); @@ -51,46 +67,110 @@ export class SystemController extends BaseController { /** * Handle initialize request * - * Registers SDK MCP servers and returns capabilities + * Processes SDK MCP servers config. + * SDK servers are registered in context.sdkMcpServers + * and added to config.mcpServers with the sdk type flag. + * External MCP servers are configured separately in settings. */ private async handleInitialize( payload: CLIControlInitializeRequest, + signal: AbortSignal, ): Promise> { + if (signal.aborted) { + throw new Error('Request aborted'); + } + this.context.config.setSdkMode(true); - if (payload.sdkMcpServers && typeof payload.sdkMcpServers === 'object') { - for (const serverName of Object.keys(payload.sdkMcpServers)) { - this.context.sdkMcpServers.add(serverName); + // Process SDK MCP servers + if ( + payload.sdkMcpServers && + typeof payload.sdkMcpServers === 'object' && + payload.sdkMcpServers !== null + ) { + const sdkServers: Record = {}; + for (const [key, wireConfig] of Object.entries(payload.sdkMcpServers)) { + const name = + typeof wireConfig?.name === 'string' && wireConfig.name.trim().length + ? wireConfig.name + : key; + + this.context.sdkMcpServers.add(name); + sdkServers[name] = new MCPServerConfig( + undefined, // command + undefined, // args + undefined, // env + undefined, // cwd + undefined, // url + undefined, // httpUrl + undefined, // headers + undefined, // tcp + undefined, // timeout + true, // trust - SDK servers are trusted + undefined, // description + undefined, // includeTools + undefined, // excludeTools + undefined, // extensionName + undefined, // oauth + undefined, // authProviderType + undefined, // targetAudience + undefined, // targetServiceAccount + 'sdk', // type + ); } - try { - this.context.config.addMcpServers(payload.sdkMcpServers); - if (this.context.debugMode) { - console.error( - `[SystemController] Added ${Object.keys(payload.sdkMcpServers).length} SDK MCP servers to config`, - ); - } - } catch (error) { - if (this.context.debugMode) { - console.error( - '[SystemController] Failed to add SDK MCP servers:', - error, - ); + const sdkServerCount = Object.keys(sdkServers).length; + if (sdkServerCount > 0) { + try { + this.context.config.addMcpServers(sdkServers); + if (this.context.debugMode) { + console.error( + `[SystemController] Added ${sdkServerCount} SDK MCP servers to config`, + ); + } + } catch (error) { + if (this.context.debugMode) { + console.error( + '[SystemController] Failed to add SDK MCP servers:', + error, + ); + } } } } - if (payload.mcpServers && typeof payload.mcpServers === 'object') { - try { - this.context.config.addMcpServers(payload.mcpServers); - if (this.context.debugMode) { - console.error( - `[SystemController] Added ${Object.keys(payload.mcpServers).length} MCP servers to config`, - ); + if ( + payload.mcpServers && + typeof payload.mcpServers === 'object' && + payload.mcpServers !== null + ) { + const externalServers: Record = {}; + for (const [name, serverConfig] of Object.entries(payload.mcpServers)) { + const normalized = this.normalizeMcpServerConfig( + name, + serverConfig as CLIMcpServerConfig | undefined, + ); + if (normalized) { + externalServers[name] = normalized; } - } catch (error) { - if (this.context.debugMode) { - console.error('[SystemController] Failed to add MCP servers:', error); + } + + const externalCount = Object.keys(externalServers).length; + if (externalCount > 0) { + try { + this.context.config.addMcpServers(externalServers); + if (this.context.debugMode) { + console.error( + `[SystemController] Added ${externalCount} external MCP servers to config`, + ); + } + } catch (error) { + if (this.context.debugMode) { + console.error( + '[SystemController] Failed to add external MCP servers:', + error, + ); + } } } } @@ -143,13 +223,96 @@ export class SystemController extends BaseController { can_set_permission_mode: typeof this.context.config.setApprovalMode === 'function', can_set_model: typeof this.context.config.setModel === 'function', - /* TODO: sdkMcpServers support */ - can_handle_mcp_message: false, + // SDK MCP servers are supported - messages routed through control plane + can_handle_mcp_message: true, }; return capabilities; } + private normalizeMcpServerConfig( + serverName: string, + config?: CLIMcpServerConfig, + ): MCPServerConfig | null { + if (!config || typeof config !== 'object') { + if (this.context.debugMode) { + console.error( + `[SystemController] Ignoring invalid MCP server config for '${serverName}'`, + ); + } + return null; + } + + const authProvider = this.normalizeAuthProviderType( + config.authProviderType, + ); + const oauthConfig = this.normalizeOAuthConfig(config.oauth); + + return new MCPServerConfig( + config.command, + config.args, + config.env, + config.cwd, + config.url, + config.httpUrl, + config.headers, + config.tcp, + config.timeout, + config.trust, + config.description, + config.includeTools, + config.excludeTools, + config.extensionName, + oauthConfig, + authProvider, + config.targetAudience, + config.targetServiceAccount, + ); + } + + private normalizeAuthProviderType( + value?: string, + ): AuthProviderType | undefined { + if (!value) { + return undefined; + } + + switch (value) { + case AuthProviderType.DYNAMIC_DISCOVERY: + case AuthProviderType.GOOGLE_CREDENTIALS: + case AuthProviderType.SERVICE_ACCOUNT_IMPERSONATION: + return value; + default: + if (this.context.debugMode) { + console.error( + `[SystemController] Unsupported authProviderType '${value}', skipping`, + ); + } + return undefined; + } + } + + private normalizeOAuthConfig( + oauth?: CLIMcpServerConfig['oauth'], + ): MCPOAuthConfig | undefined { + if (!oauth) { + return undefined; + } + + return { + enabled: oauth.enabled, + clientId: oauth.clientId, + clientSecret: oauth.clientSecret, + authorizationUrl: oauth.authorizationUrl, + tokenUrl: oauth.tokenUrl, + scopes: oauth.scopes, + audiences: oauth.audiences, + redirectUri: oauth.redirectUri, + tokenParamName: oauth.tokenParamName, + registrationUrl: oauth.registrationUrl, + }; + } + /** * Handle interrupt request * @@ -183,7 +346,12 @@ export class SystemController extends BaseController { */ private async handleSetModel( payload: CLIControlSetModelRequest, + signal: AbortSignal, ): Promise> { + if (signal.aborted) { + throw new Error('Request aborted'); + } + const model = payload.model; // Validate model parameter @@ -223,8 +391,14 @@ export class SystemController extends BaseController { * * Returns list of supported slash commands loaded dynamically */ - private async handleSupportedCommands(): Promise> { - const slashCommands = await this.loadSlashCommandNames(); + private async handleSupportedCommands( + signal: AbortSignal, + ): Promise> { + if (signal.aborted) { + throw new Error('Request aborted'); + } + + const slashCommands = await this.loadSlashCommandNames(signal); return { subtype: 'supported_commands', @@ -235,15 +409,24 @@ export class SystemController extends BaseController { /** * Load slash command names using CommandService * + * @param signal - AbortSignal to respect for cancellation * @returns Promise resolving to array of slash command names */ - private async loadSlashCommandNames(): Promise { - const controller = new AbortController(); + private async loadSlashCommandNames(signal: AbortSignal): Promise { + if (signal.aborted) { + return []; + } + try { const service = await CommandService.create( [new BuiltinCommandLoader(this.context.config)], - controller.signal, + signal, ); + + if (signal.aborted) { + return []; + } + const names = new Set(); const commands = service.getCommands(); for (const command of commands) { @@ -251,6 +434,11 @@ export class SystemController extends BaseController { } return Array.from(names).sort(); } catch (error) { + // Check if the error is due to abort + if (signal.aborted) { + return []; + } + if (this.context.debugMode) { console.error( '[SystemController] Failed to load slash commands:', @@ -258,8 +446,6 @@ export class SystemController extends BaseController { ); } return []; - } finally { - controller.abort(); } } } diff --git a/packages/cli/src/nonInteractive/session.test.ts b/packages/cli/src/nonInteractive/session.test.ts index 6670d4c2..84d7dece 100644 --- a/packages/cli/src/nonInteractive/session.test.ts +++ b/packages/cli/src/nonInteractive/session.test.ts @@ -153,6 +153,11 @@ describe('runNonInteractiveStreamJson', () => { handleControlResponse: ReturnType; handleCancel: ReturnType; shutdown: ReturnType; + getPendingIncomingRequestCount: ReturnType; + waitForPendingIncomingRequests: ReturnType; + sdkMcpController: { + createSendSdkMcpMessage: ReturnType; + }; }; let mockConsolePatcher: { patch: ReturnType; @@ -187,6 +192,11 @@ describe('runNonInteractiveStreamJson', () => { handleControlResponse: vi.fn(), handleCancel: vi.fn(), shutdown: vi.fn(), + getPendingIncomingRequestCount: vi.fn().mockReturnValue(0), + waitForPendingIncomingRequests: vi.fn().mockResolvedValue(undefined), + sdkMcpController: { + createSendSdkMcpMessage: vi.fn().mockReturnValue(vi.fn()), + }, }; ( ControlDispatcher as unknown as ReturnType diff --git a/packages/cli/src/nonInteractive/session.ts b/packages/cli/src/nonInteractive/session.ts index 7cfa92c0..e8e6da12 100644 --- a/packages/cli/src/nonInteractive/session.ts +++ b/packages/cli/src/nonInteractive/session.ts @@ -4,7 +4,10 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { Config } from '@qwen-code/qwen-code-core'; +import type { + Config, + ConfigInitializeOptions, +} from '@qwen-code/qwen-code-core'; import { StreamJsonInputReader } from './io/StreamJsonInputReader.js'; import { StreamJsonOutputAdapter } from './io/StreamJsonOutputAdapter.js'; import { ControlContext } from './control/ControlContext.js'; @@ -50,6 +53,12 @@ class Session { private isShuttingDown: boolean = false; private configInitialized: boolean = false; + // Single initialization promise that resolves when session is ready for user messages. + // Created lazily once initialization actually starts. + private initializationPromise: Promise | null = null; + private initializationResolve: (() => void) | null = null; + private initializationReject: ((error: Error) => void) | null = null; + constructor(config: Config, initialPrompt?: CLIUserMessage) { this.config = config; this.sessionId = config.getSessionId(); @@ -66,12 +75,32 @@ class Session { this.setupSignalHandlers(); } + private ensureInitializationPromise(): void { + if (this.initializationPromise) { + return; + } + this.initializationPromise = new Promise((resolve, reject) => { + this.initializationResolve = () => { + resolve(); + this.initializationResolve = null; + this.initializationReject = null; + }; + this.initializationReject = (error: Error) => { + reject(error); + this.initializationResolve = null; + this.initializationReject = null; + }; + }); + } + private getNextPromptId(): string { this.promptIdCounter++; return `${this.sessionId}########${this.promptIdCounter}`; } - private async ensureConfigInitialized(): Promise { + private async ensureConfigInitialized( + options?: ConfigInitializeOptions, + ): Promise { if (this.configInitialized) { return; } @@ -81,7 +110,7 @@ class Session { } try { - await this.config.initialize(); + await this.config.initialize(options); this.configInitialized = true; } catch (error) { if (this.debugMode) { @@ -91,6 +120,44 @@ class Session { } } + /** + * Mark initialization as complete + */ + private completeInitialization(): void { + if (this.initializationResolve) { + if (this.debugMode) { + console.error('[Session] Initialization complete'); + } + this.initializationResolve(); + this.initializationResolve = null; + this.initializationReject = null; + } + } + + /** + * Mark initialization as failed + */ + private failInitialization(error: Error): void { + if (this.initializationReject) { + if (this.debugMode) { + console.error('[Session] Initialization failed:', error); + } + this.initializationReject(error); + this.initializationResolve = null; + this.initializationReject = null; + } + } + + /** + * Wait for session to be ready for user messages + */ + private async waitForInitialization(): Promise { + if (!this.initializationPromise) { + return; + } + await this.initializationPromise; + } + private ensureControlSystem(): void { if (this.controlContext && this.dispatcher && this.controlService) { return; @@ -120,49 +187,114 @@ class Session { return this.dispatcher; } - private async handleFirstMessage( + /** + * Handle the first message to determine session mode (SDK vs direct). + * This is synchronous from the message loop's perspective - it starts + * async work but does not return a promise that the loop awaits. + * + * The initialization completes asynchronously and resolves initializationPromise + * when ready for user messages. + */ + private handleFirstMessage( message: | CLIMessage | CLIControlRequest | CLIControlResponse | ControlCancelRequest, - ): Promise { + ): void { if (isControlRequest(message)) { const request = message as CLIControlRequest; this.controlSystemEnabled = true; this.ensureControlSystem(); - if (request.request.subtype === 'initialize') { - // Dispatch the initialize request first - await this.dispatcher?.dispatch(request); - // After handling initialize control request, initialize the config - // This is the SDK mode where config initialization is deferred - await this.ensureConfigInitialized(); - return true; + if (request.request.subtype === 'initialize') { + // Start SDK mode initialization (fire-and-forget from loop perspective) + void this.initializeSdkMode(request); + return; } + if (this.debugMode) { console.error( '[Session] Ignoring non-initialize control request during initialization', ); } - return true; + return; } if (isCLIUserMessage(message)) { this.controlSystemEnabled = false; - // For non-SDK mode (direct user message), initialize config if not already done - await this.ensureConfigInitialized(); - this.enqueueUserMessage(message as CLIUserMessage); - return true; + // Start direct mode initialization (fire-and-forget from loop perspective) + void this.initializeDirectMode(message as CLIUserMessage); + return; } this.controlSystemEnabled = false; - return false; } - private async handleControlRequest( - request: CLIControlRequest, + /** + * SDK mode initialization flow + * Dispatches initialize request and initializes config with MCP support + */ + private async initializeSdkMode(request: CLIControlRequest): Promise { + this.ensureInitializationPromise(); + try { + // Dispatch the initialize request first + // This registers SDK MCP servers in the control context + await this.dispatcher?.dispatch(request); + + // Get sendSdkMcpMessage callback from SdkMcpController + // This callback is used by McpClientManager to send MCP messages + // from CLI MCP clients to SDK MCP servers via the control plane + const sendSdkMcpMessage = + this.dispatcher?.sdkMcpController.createSendSdkMcpMessage(); + + // Initialize config with SDK MCP message support + await this.ensureConfigInitialized({ sendSdkMcpMessage }); + + // Initialization complete! + this.completeInitialization(); + } catch (error) { + if (this.debugMode) { + console.error('[Session] SDK mode initialization failed:', error); + } + this.failInitialization( + error instanceof Error ? error : new Error(String(error)), + ); + } + } + + /** + * Direct mode initialization flow + * Initializes config and enqueues the first user message + */ + private async initializeDirectMode( + userMessage: CLIUserMessage, ): Promise { + this.ensureInitializationPromise(); + try { + // Initialize config + await this.ensureConfigInitialized(); + + // Initialization complete! + this.completeInitialization(); + + // Enqueue the first user message for processing + this.enqueueUserMessage(userMessage); + } catch (error) { + if (this.debugMode) { + console.error('[Session] Direct mode initialization failed:', error); + } + this.failInitialization( + error instanceof Error ? error : new Error(String(error)), + ); + } + } + + /** + * Handle control request asynchronously (fire-and-forget from main loop). + * Errors are handled internally and responses sent by dispatcher. + */ + private handleControlRequestAsync(request: CLIControlRequest): void { const dispatcher = this.getDispatcher(); if (!dispatcher) { if (this.debugMode) { @@ -171,9 +303,20 @@ class Session { return; } - await dispatcher.dispatch(request); + // Fire-and-forget: dispatch runs concurrently + // The dispatcher's pendingIncomingRequests tracks completion + void dispatcher.dispatch(request).catch((error) => { + if (this.debugMode) { + console.error('[Session] Control request dispatch error:', error); + } + // Error response is already sent by dispatcher.dispatch() + }); } + /** + * Handle control response - MUST be synchronous + * This resolves pending outgoing requests, breaking the deadlock cycle. + */ private handleControlResponse(response: CLIControlResponse): void { const dispatcher = this.getDispatcher(); if (!dispatcher) { @@ -201,8 +344,8 @@ class Session { return; } - // Ensure config is initialized before processing user messages - await this.ensureConfigInitialized(); + // Wait for initialization to complete before processing user messages + await this.waitForInitialization(); const promptId = this.getNextPromptId(); @@ -307,6 +450,45 @@ class Session { process.on('SIGTERM', this.shutdownHandler); } + /** + * Wait for all pending work to complete before shutdown + */ + private async waitForAllPendingWork(): Promise { + // 1. Wait for initialization to complete (or fail) + try { + await this.waitForInitialization(); + } catch (error) { + if (this.debugMode) { + console.error('[Session] Initialization error during shutdown:', error); + } + } + + // 2. Wait for all control request handlers using dispatcher's tracking + if (this.dispatcher) { + const pendingCount = this.dispatcher.getPendingIncomingRequestCount(); + if (pendingCount > 0 && this.debugMode) { + console.error( + `[Session] Waiting for ${pendingCount} pending control request handlers`, + ); + } + await this.dispatcher.waitForPendingIncomingRequests(); + } + + // 3. Wait for user message processing queue + while (this.processingPromise) { + if (this.debugMode) { + console.error('[Session] Waiting for user message processing'); + } + try { + await this.processingPromise; + } catch (error) { + if (this.debugMode) { + console.error('[Session] Error in user message processing:', error); + } + } + } + } + private async shutdown(): Promise { if (this.debugMode) { console.error('[Session] Shutting down'); @@ -314,18 +496,8 @@ class Session { this.isShuttingDown = true; - if (this.processingPromise) { - try { - await this.processingPromise; - } catch (error) { - if (this.debugMode) { - console.error( - '[Session] Error waiting for processing to complete:', - error, - ); - } - } - } + // Wait for all pending work + await this.waitForAllPendingWork(); this.dispatcher?.shutdown(); this.cleanupSignalHandlers(); @@ -339,18 +511,30 @@ class Session { } } + /** + * Main message processing loop + * + * CRITICAL: This loop must NEVER await handlers that might need to + * send control requests and wait for responses. Such handlers must + * be started in fire-and-forget mode, allowing the loop to continue + * reading responses that resolve pending requests. + * + * Message handling order: + * 1. control_response - FIRST, synchronously resolves pending requests + * 2. First message - determines mode, starts async initialization + * 3. control_request - fire-and-forget, tracked by dispatcher + * 4. control_cancel - synchronous + * 5. user_message - enqueued for processing + */ async run(): Promise { try { if (this.debugMode) { console.error('[Session] Starting session', this.sessionId); } + // Handle initial prompt if provided (fire-and-forget) if (this.initialPrompt !== null) { - const handled = await this.handleFirstMessage(this.initialPrompt); - if (handled && this.isShuttingDown) { - await this.shutdown(); - return; - } + this.handleFirstMessage(this.initialPrompt); } try { @@ -359,23 +543,33 @@ class Session { break; } - if (this.controlSystemEnabled === null) { - const handled = await this.handleFirstMessage(message); - if (handled) { - if (this.isShuttingDown) { - break; - } - continue; - } + // ============================================================ + // CRITICAL: Handle control_response FIRST and SYNCHRONOUSLY + // This resolves pending outgoing requests, breaking deadlock. + // ============================================================ + if (isControlResponse(message)) { + this.handleControlResponse(message as CLIControlResponse); + continue; } + // Handle first message to determine session mode + if (this.controlSystemEnabled === null) { + this.handleFirstMessage(message); + continue; + } + + // ============================================================ + // CRITICAL: Handle control_request in FIRE-AND-FORGET mode + // DON'T await - let handler run concurrently while loop continues + // Dispatcher's pendingIncomingRequests tracks completion + // ============================================================ if (isControlRequest(message)) { - await this.handleControlRequest(message as CLIControlRequest); - } else if (isControlResponse(message)) { - this.handleControlResponse(message as CLIControlResponse); + this.handleControlRequestAsync(message as CLIControlRequest); } else if (isControlCancel(message)) { + // Cancel is synchronous - OK to handle inline this.handleControlCancel(message as ControlCancelRequest); } else if (isCLIUserMessage(message)) { + // User messages are enqueued, processing runs separately this.enqueueUserMessage(message as CLIUserMessage); } else if (this.debugMode) { if ( @@ -402,19 +596,8 @@ class Session { throw streamError; } - while (this.processingPromise) { - if (this.debugMode) { - console.error('[Session] Waiting for final processing to complete'); - } - try { - await this.processingPromise; - } catch (error) { - if (this.debugMode) { - console.error('[Session] Error in final processing:', error); - } - } - } - + // Stream ended - wait for all pending work before shutdown + await this.waitForAllPendingWork(); await this.shutdown(); } catch (error) { if (this.debugMode) { diff --git a/packages/cli/src/nonInteractive/types.ts b/packages/cli/src/nonInteractive/types.ts index 131c1be0..1d5e800d 100644 --- a/packages/cli/src/nonInteractive/types.ts +++ b/packages/cli/src/nonInteractive/types.ts @@ -1,8 +1,5 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ -import type { - MCPServerConfig, - SubagentConfig, -} from '@qwen-code/qwen-code-core'; +import type { SubagentConfig } from '@qwen-code/qwen-code-core'; /** * Annotation for attaching metadata to content blocks @@ -298,11 +295,68 @@ export interface CLIControlPermissionRequest { blocked_path: string | null; } +/** + * Wire format for SDK MCP server config in initialization request. + * The actual Server instance stays in the SDK process. + */ +export interface SDKMcpServerConfig { + type: 'sdk'; + name: string; +} + +/** + * Wire format for external MCP server config in initialization request. + * Represents stdio/SSE/HTTP/TCP transports that must run in the CLI process. + */ +export interface CLIMcpServerConfig { + command?: string; + args?: string[]; + env?: Record; + cwd?: string; + url?: string; + httpUrl?: string; + headers?: Record; + tcp?: string; + timeout?: number; + trust?: boolean; + description?: string; + includeTools?: string[]; + excludeTools?: string[]; + extensionName?: string; + oauth?: { + enabled?: boolean; + clientId?: string; + clientSecret?: string; + authorizationUrl?: string; + tokenUrl?: string; + scopes?: string[]; + audiences?: string[]; + redirectUri?: string; + tokenParamName?: string; + registrationUrl?: string; + }; + authProviderType?: + | 'dynamic_discovery' + | 'google_credentials' + | 'service_account_impersonation'; + targetAudience?: string; + targetServiceAccount?: string; +} + export interface CLIControlInitializeRequest { subtype: 'initialize'; hooks?: HookRegistration[] | null; - sdkMcpServers?: Record; - mcpServers?: Record; + /** + * SDK MCP servers config + * These are MCP servers running in the SDK process, connected via control plane. + * External MCP servers are configured separately in settings, not via initialization. + */ + sdkMcpServers?: Record>; + /** + * External MCP servers that the SDK wants the CLI to manage. + * These run outside the SDK process and require CLI-side transport setup. + */ + mcpServers?: Record; agents?: SubagentConfig[]; } diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index 1c83432d..6aa49306 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -63,6 +63,7 @@ vi.mock('../tools/tool-registry', () => { ToolRegistryMock.prototype.registerTool = vi.fn(); ToolRegistryMock.prototype.discoverAllTools = vi.fn(); ToolRegistryMock.prototype.getAllTools = vi.fn(() => []); // Mock methods if needed + ToolRegistryMock.prototype.getAllToolNames = vi.fn(() => []); ToolRegistryMock.prototype.getTool = vi.fn(); ToolRegistryMock.prototype.getFunctionDeclarations = vi.fn(() => []); return { ToolRegistry: ToolRegistryMock }; diff --git a/packages/core/src/config/config.ts b/packages/core/src/config/config.ts index d3e0fd5c..6383cb17 100644 --- a/packages/core/src/config/config.ts +++ b/packages/core/src/config/config.ts @@ -46,6 +46,7 @@ import { ExitPlanModeTool } from '../tools/exitPlanMode.js'; import { GlobTool } from '../tools/glob.js'; import { GrepTool } from '../tools/grep.js'; import { LSTool } from '../tools/ls.js'; +import type { SendSdkMcpMessage } from '../tools/mcp-client.js'; import { MemoryTool, setGeminiMdFilename } from '../tools/memoryTool.js'; import { ReadFileTool } from '../tools/read-file.js'; import { ReadManyFilesTool } from '../tools/read-many-files.js'; @@ -239,9 +240,18 @@ export class MCPServerConfig { readonly targetAudience?: string, /* targetServiceAccount format: @.iam.gserviceaccount.com */ readonly targetServiceAccount?: string, + // SDK MCP server type - 'sdk' indicates server runs in SDK process + readonly type?: 'sdk', ) {} } +/** + * Check if an MCP server config represents an SDK server + */ +export function isSdkMcpServerConfig(config: MCPServerConfig): boolean { + return config.type === 'sdk'; +} + export enum AuthProviderType { DYNAMIC_DISCOVERY = 'dynamic_discovery', GOOGLE_CREDENTIALS = 'google_credentials', @@ -360,6 +370,17 @@ function normalizeConfigOutputFormat( } } +/** + * Options for Config.initialize() + */ +export interface ConfigInitializeOptions { + /** + * Callback for sending MCP messages to SDK servers via control plane. + * Required for SDK MCP server support in SDK mode. + */ + sendSdkMcpMessage?: SendSdkMcpMessage; +} + export class Config { private sessionId: string; private sessionData?: ResumedSessionData; @@ -599,8 +620,9 @@ export class Config { /** * Must only be called once, throws if called again. + * @param options Optional initialization options including sendSdkMcpMessage callback */ - async initialize(): Promise { + async initialize(options?: ConfigInitializeOptions): Promise { if (this.initialized) { throw Error('Config was already initialized'); } @@ -619,7 +641,9 @@ export class Config { this.subagentManager.loadSessionSubagents(this.sessionSubagents); } - this.toolRegistry = await this.createToolRegistry(); + this.toolRegistry = await this.createToolRegistry( + options?.sendSdkMcpMessage, + ); await this.geminiClient.initialize(); @@ -1261,8 +1285,14 @@ export class Config { return this.subagentManager; } - async createToolRegistry(): Promise { - const registry = new ToolRegistry(this, this.eventEmitter); + async createToolRegistry( + sendSdkMcpMessage?: SendSdkMcpMessage, + ): Promise { + const registry = new ToolRegistry( + this, + this.eventEmitter, + sendSdkMcpMessage, + ); const coreToolsConfig = this.getCoreTools(); const excludeToolsConfig = this.getExcludeTools(); @@ -1347,6 +1377,7 @@ export class Config { } await registry.discoverAllTools(); + console.debug('ToolRegistry created', registry.getAllToolNames()); return registry; } } diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 38ac7ada..738aca57 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -102,7 +102,9 @@ export * from './tools/shell.js'; export * from './tools/web-search/index.js'; export * from './tools/read-many-files.js'; export * from './tools/mcp-client.js'; +export * from './tools/mcp-client-manager.js'; export * from './tools/mcp-tool.js'; +export * from './tools/sdk-control-client-transport.js'; export * from './tools/task.js'; export * from './tools/todoWrite.js'; export * from './tools/exitPlanMode.js'; diff --git a/packages/core/src/tools/mcp-client-manager.ts b/packages/core/src/tools/mcp-client-manager.ts index 93e25ea8..a8b48236 100644 --- a/packages/core/src/tools/mcp-client-manager.ts +++ b/packages/core/src/tools/mcp-client-manager.ts @@ -5,6 +5,7 @@ */ import type { Config, MCPServerConfig } from '../config/config.js'; +import { isSdkMcpServerConfig } from '../config/config.js'; import type { ToolRegistry } from './tool-registry.js'; import type { PromptRegistry } from '../prompts/prompt-registry.js'; import { @@ -12,6 +13,7 @@ import { MCPDiscoveryState, populateMcpServerCommand, } from './mcp-client.js'; +import type { SendSdkMcpMessage } from './mcp-client.js'; import { getErrorMessage } from '../utils/errors.js'; import type { EventEmitter } from 'node:events'; import type { WorkspaceContext } from '../utils/workspaceContext.js'; @@ -31,6 +33,7 @@ export class McpClientManager { private readonly workspaceContext: WorkspaceContext; private discoveryState: MCPDiscoveryState = MCPDiscoveryState.NOT_STARTED; private readonly eventEmitter?: EventEmitter; + private readonly sendSdkMcpMessage?: SendSdkMcpMessage; constructor( mcpServers: Record, @@ -40,6 +43,7 @@ export class McpClientManager { debugMode: boolean, workspaceContext: WorkspaceContext, eventEmitter?: EventEmitter, + sendSdkMcpMessage?: SendSdkMcpMessage, ) { this.mcpServers = mcpServers; this.mcpServerCommand = mcpServerCommand; @@ -48,6 +52,7 @@ export class McpClientManager { this.debugMode = debugMode; this.workspaceContext = workspaceContext; this.eventEmitter = eventEmitter; + this.sendSdkMcpMessage = sendSdkMcpMessage; } /** @@ -71,6 +76,11 @@ export class McpClientManager { this.eventEmitter?.emit('mcp-client-update', this.clients); const discoveryPromises = Object.entries(servers).map( async ([name, config]) => { + // For SDK MCP servers, pass the sendSdkMcpMessage callback + const sdkCallback = isSdkMcpServerConfig(config) + ? this.sendSdkMcpMessage + : undefined; + const client = new McpClient( name, config, @@ -78,6 +88,7 @@ export class McpClientManager { this.promptRegistry, this.workspaceContext, this.debugMode, + sdkCallback, ); this.clients.set(name, client); diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index a6903d13..efea02ad 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -13,6 +13,7 @@ import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/ import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; import type { GetPromptResult, + JSONRPCMessage, Prompt, } from '@modelcontextprotocol/sdk/types.js'; import { @@ -22,10 +23,11 @@ import { } from '@modelcontextprotocol/sdk/types.js'; import { parse } from 'shell-quote'; import type { Config, MCPServerConfig } from '../config/config.js'; -import { AuthProviderType } from '../config/config.js'; +import { AuthProviderType, isSdkMcpServerConfig } from '../config/config.js'; import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js'; import { ServiceAccountImpersonationProvider } from '../mcp/sa-impersonation-provider.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; +import { SdkControlClientTransport } from './sdk-control-client-transport.js'; import type { FunctionDeclaration } from '@google/genai'; import { mcpToTool } from '@google/genai'; @@ -42,6 +44,14 @@ import type { } from '../utils/workspaceContext.js'; import type { ToolRegistry } from './tool-registry.js'; +/** + * Callback type for sending MCP messages to SDK servers via control plane + */ +export type SendSdkMcpMessage = ( + serverName: string, + message: JSONRPCMessage, +) => Promise; + export const MCP_DEFAULT_TIMEOUT_MSEC = 10 * 60 * 1000; // default to 10 minutes export type DiscoveredMCPPrompt = Prompt & { @@ -92,6 +102,7 @@ export class McpClient { private readonly promptRegistry: PromptRegistry, private readonly workspaceContext: WorkspaceContext, private readonly debugMode: boolean, + private readonly sendSdkMcpMessage?: SendSdkMcpMessage, ) { this.client = new Client({ name: `qwen-cli-mcp-client-${this.serverName}`, @@ -189,7 +200,12 @@ export class McpClient { } private async createTransport(): Promise { - return createTransport(this.serverName, this.serverConfig, this.debugMode); + return createTransport( + this.serverName, + this.serverConfig, + this.debugMode, + this.sendSdkMcpMessage, + ); } private async discoverTools(cliConfig: Config): Promise { @@ -501,6 +517,7 @@ export function populateMcpServerCommand( * @param mcpServerName The name identifier for this MCP server * @param mcpServerConfig Configuration object containing connection details * @param toolRegistry The registry to register discovered tools with + * @param sendSdkMcpMessage Optional callback for SDK MCP servers to route messages via control plane. * @returns Promise that resolves when discovery is complete */ export async function connectAndDiscover( @@ -511,6 +528,7 @@ export async function connectAndDiscover( debugMode: boolean, workspaceContext: WorkspaceContext, cliConfig: Config, + sendSdkMcpMessage?: SendSdkMcpMessage, ): Promise { updateMCPServerStatus(mcpServerName, MCPServerStatus.CONNECTING); @@ -521,6 +539,7 @@ export async function connectAndDiscover( mcpServerConfig, debugMode, workspaceContext, + sendSdkMcpMessage, ); mcpClient.onerror = (error) => { @@ -744,6 +763,7 @@ export function hasNetworkTransport(config: MCPServerConfig): boolean { * * @param mcpServerName The name of the MCP server, used for logging and identification. * @param mcpServerConfig The configuration specifying how to connect to the server. + * @param sendSdkMcpMessage Optional callback for SDK MCP servers to route messages via control plane. * @returns A promise that resolves to a connected MCP `Client` instance. * @throws An error if the connection fails or the configuration is invalid. */ @@ -752,6 +772,7 @@ export async function connectToMcpServer( mcpServerConfig: MCPServerConfig, debugMode: boolean, workspaceContext: WorkspaceContext, + sendSdkMcpMessage?: SendSdkMcpMessage, ): Promise { const mcpClient = new Client({ name: 'qwen-code-mcp-client', @@ -808,6 +829,7 @@ export async function connectToMcpServer( mcpServerName, mcpServerConfig, debugMode, + sendSdkMcpMessage, ); try { await mcpClient.connect(transport, { @@ -1172,7 +1194,21 @@ export async function createTransport( mcpServerName: string, mcpServerConfig: MCPServerConfig, debugMode: boolean, + sendSdkMcpMessage?: SendSdkMcpMessage, ): Promise { + if (isSdkMcpServerConfig(mcpServerConfig)) { + if (!sendSdkMcpMessage) { + throw new Error( + `SDK MCP server '${mcpServerName}' requires sendSdkMcpMessage callback`, + ); + } + return new SdkControlClientTransport({ + serverName: mcpServerName, + sendMcpMessage: sendSdkMcpMessage, + debugMode, + }); + } + if ( mcpServerConfig.authProviderType === AuthProviderType.SERVICE_ACCOUNT_IMPERSONATION diff --git a/packages/core/src/tools/sdk-control-client-transport.ts b/packages/core/src/tools/sdk-control-client-transport.ts new file mode 100644 index 00000000..be2f3099 --- /dev/null +++ b/packages/core/src/tools/sdk-control-client-transport.ts @@ -0,0 +1,163 @@ +/** + * @license + * Copyright 2025 Qwen Team + * SPDX-License-Identifier: Apache-2.0 + */ + +/** + * SdkControlClientTransport - MCP Client transport for SDK MCP servers + * + * This transport enables CLI's MCP client to connect to SDK MCP servers + * through the control plane. Messages are routed: + * + * CLI MCP Client → SdkControlClientTransport → sendMcpMessage() → + * control_request (mcp_message) → SDK → control_response → onmessage → CLI + * + * Unlike StdioClientTransport which spawns a subprocess, this transport + * communicates with SDK MCP servers running in the SDK process. + */ + +import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js'; + +/** + * Callback to send MCP messages to SDK via control plane + * Returns the MCP response from the SDK + */ +export type SendMcpMessageCallback = ( + serverName: string, + message: JSONRPCMessage, +) => Promise; + +export interface SdkControlClientTransportOptions { + serverName: string; + sendMcpMessage: SendMcpMessageCallback; + debugMode?: boolean; +} + +/** + * MCP Client Transport for SDK MCP servers + * + * Implements the @modelcontextprotocol/sdk Transport interface to enable + * CLI's MCP client to connect to SDK MCP servers via the control plane. + */ +export class SdkControlClientTransport { + private serverName: string; + private sendMcpMessage: SendMcpMessageCallback; + private debugMode: boolean; + private started = false; + + // Transport interface callbacks + onmessage?: (message: JSONRPCMessage) => void; + onerror?: (error: Error) => void; + onclose?: () => void; + + constructor(options: SdkControlClientTransportOptions) { + this.serverName = options.serverName; + this.sendMcpMessage = options.sendMcpMessage; + this.debugMode = options.debugMode ?? false; + } + + /** + * Start the transport + * For SDK transport, this just marks it as ready - no subprocess to spawn + */ + async start(): Promise { + if (this.started) { + return; + } + + this.started = true; + + if (this.debugMode) { + console.error( + `[SdkControlClientTransport] Started for server '${this.serverName}'`, + ); + } + } + + /** + * Send a message to the SDK MCP server via control plane + * + * Routes the message through the control plane and delivers + * the response via onmessage callback. + */ + async send(message: JSONRPCMessage): Promise { + if (!this.started) { + throw new Error( + `SdkControlClientTransport (${this.serverName}) not started. Call start() first.`, + ); + } + + if (this.debugMode) { + console.error( + `[SdkControlClientTransport] Sending message to '${this.serverName}':`, + JSON.stringify(message), + ); + } + + try { + // Send message to SDK and wait for response + const response = await this.sendMcpMessage(this.serverName, message); + + if (this.debugMode) { + console.error( + `[SdkControlClientTransport] Received response from '${this.serverName}':`, + JSON.stringify(response), + ); + } + + // Deliver response via onmessage callback + if (this.onmessage) { + this.onmessage(response); + } + } catch (error) { + if (this.debugMode) { + console.error( + `[SdkControlClientTransport] Error sending to '${this.serverName}':`, + error, + ); + } + + if (this.onerror) { + this.onerror(error instanceof Error ? error : new Error(String(error))); + } + + throw error; + } + } + + /** + * Close the transport + */ + async close(): Promise { + if (!this.started) { + return; + } + + this.started = false; + + if (this.debugMode) { + console.error( + `[SdkControlClientTransport] Closed for server '${this.serverName}'`, + ); + } + + if (this.onclose) { + this.onclose(); + } + } + + /** + * Check if transport is started + */ + isStarted(): boolean { + return this.started; + } + + /** + * Get server name + */ + getServerName(): string { + return this.serverName; + } +} diff --git a/packages/core/src/tools/tool-registry.ts b/packages/core/src/tools/tool-registry.ts index a0123107..9b641647 100644 --- a/packages/core/src/tools/tool-registry.ts +++ b/packages/core/src/tools/tool-registry.ts @@ -16,6 +16,7 @@ import type { Config } from '../config/config.js'; import { spawn } from 'node:child_process'; import { StringDecoder } from 'node:string_decoder'; import { connectAndDiscover } from './mcp-client.js'; +import type { SendSdkMcpMessage } from './mcp-client.js'; import { McpClientManager } from './mcp-client-manager.js'; import { DiscoveredMCPTool } from './mcp-tool.js'; import { parse } from 'shell-quote'; @@ -173,7 +174,11 @@ export class ToolRegistry { private config: Config; private mcpClientManager: McpClientManager; - constructor(config: Config, eventEmitter?: EventEmitter) { + constructor( + config: Config, + eventEmitter?: EventEmitter, + sendSdkMcpMessage?: SendSdkMcpMessage, + ) { this.config = config; this.mcpClientManager = new McpClientManager( this.config.getMcpServers() ?? {}, @@ -183,6 +188,7 @@ export class ToolRegistry { this.config.getDebugMode(), this.config.getWorkspaceContext(), eventEmitter, + sendSdkMcpMessage, ); } diff --git a/packages/sdk-typescript/README.md b/packages/sdk-typescript/README.md index ed441bc7..bc3ef6aa 100644 --- a/packages/sdk-typescript/README.md +++ b/packages/sdk-typescript/README.md @@ -1,4 +1,4 @@ -# @qwen-code/sdk-typescript +# @qwen-code/sdk A minimum experimental TypeScript SDK for programmatic access to Qwen Code. @@ -7,20 +7,20 @@ Feel free to submit a feature request/issue/PR. ## Installation ```bash -npm install @qwen-code/sdk-typescript +npm install @qwen-code/sdk ``` ## Requirements - Node.js >= 20.0.0 -- [Qwen Code](https://github.com/QwenLM/qwen-code) installed and accessible in PATH +- [Qwen Code](https://github.com/QwenLM/qwen-code) >= 0.4.0 (stable) installed and accessible in PATH > **Note for nvm users**: If you use nvm to manage Node.js versions, the SDK may not be able to auto-detect the Qwen Code executable. You should explicitly set the `pathToQwenExecutable` option to the full path of the `qwen` binary. ## Quick Start ```typescript -import { query } from '@qwen-code/sdk-typescript'; +import { query } from '@qwen-code/sdk'; // Single-turn query const result = query({ @@ -59,9 +59,9 @@ Creates a new query session with the Qwen Code. | `model` | `string` | - | The AI model to use (e.g., `'qwen-max'`, `'qwen-plus'`, `'qwen-turbo'`). Takes precedence over `OPENAI_MODEL` and `QWEN_MODEL` environment variables. | | `pathToQwenExecutable` | `string` | Auto-detected | Path to the Qwen Code executable. Supports multiple formats: `'qwen'` (native binary from PATH), `'/path/to/qwen'` (explicit path), `'/path/to/cli.js'` (Node.js bundle), `'node:/path/to/cli.js'` (force Node.js runtime), `'bun:/path/to/cli.js'` (force Bun runtime). If not provided, auto-detects from: `QWEN_CODE_CLI_PATH` env var, `~/.volta/bin/qwen`, `~/.npm-global/bin/qwen`, `/usr/local/bin/qwen`, `~/.local/bin/qwen`, `~/node_modules/.bin/qwen`, `~/.yarn/bin/qwen`. | | `permissionMode` | `'default' \| 'plan' \| 'auto-edit' \| 'yolo'` | `'default'` | Permission mode controlling tool execution approval. See [Permission Modes](#permission-modes) for details. | -| `canUseTool` | `CanUseTool` | - | Custom permission handler for tool execution approval. Invoked when a tool requires confirmation. Must respond within 30 seconds or the request will be auto-denied. See [Custom Permission Handler](#custom-permission-handler). | +| `canUseTool` | `CanUseTool` | - | Custom permission handler for tool execution approval. Invoked when a tool requires confirmation. Must respond within 60 seconds or the request will be auto-denied. See [Custom Permission Handler](#custom-permission-handler). | | `env` | `Record` | - | Environment variables to pass to the Qwen Code process. Merged with the current process environment. | -| `mcpServers` | `Record` | - | External MCP (Model Context Protocol) servers to connect. Each server is identified by a unique name and configured with `command`, `args`, and `env`. | +| `mcpServers` | `Record` | - | MCP (Model Context Protocol) servers to connect. Supports external servers (stdio/SSE/HTTP) and SDK-embedded servers. External servers are configured with transport options like `command`, `args`, `url`, `httpUrl`, etc. SDK servers use `{ type: 'sdk', name: string, instance: Server }`. | | `abortController` | `AbortController` | - | Controller to cancel the query session. Call `abortController.abort()` to terminate the session and cleanup resources. | | `debug` | `boolean` | `false` | Enable debug mode for verbose logging from the CLI process. | | `maxSessionTurns` | `number` | `-1` (unlimited) | Maximum number of conversation turns before the session automatically terminates. A turn consists of a user message and an assistant response. | @@ -74,12 +74,27 @@ Creates a new query session with the Qwen Code. ### Timeouts -The SDK enforces the following timeouts: +The SDK enforces the following default timeouts: -| Timeout | Duration | Description | -| ------------------- | ---------- | ---------------------------------------------------------------------------------------------------------------------------- | -| Permission Callback | 30 seconds | Maximum time for `canUseTool` callback to respond. If exceeded, the tool request is auto-denied. | -| Control Request | 30 seconds | Maximum time for control operations like `initialize()`, `setModel()`, `setPermissionMode()`, and `interrupt()` to complete. | +| Timeout | Default | Description | +| ---------------- | -------- | ---------------------------------------------------------------------------------------------------------------------------- | +| `canUseTool` | 1 minute | Maximum time for `canUseTool` callback to respond. If exceeded, the tool request is auto-denied. | +| `mcpRequest` | 1 minute | Maximum time for SDK MCP tool calls to complete. | +| `controlRequest` | 1 minute | Maximum time for control operations like `initialize()`, `setModel()`, `setPermissionMode()`, and `interrupt()` to complete. | +| `streamClose` | 1 minute | Maximum time to wait for initialization to complete before closing CLI stdin in multi-turn mode with SDK MCP servers. | + +You can customize these timeouts via the `timeout` option: + +```typescript +const query = qwen.query('Your prompt', { + timeout: { + canUseTool: 60000, // 60 seconds for permission callback + mcpRequest: 600000, // 10 minutes for MCP tool calls + controlRequest: 60000, // 60 seconds for control requests + streamClose: 15000, // 15 seconds for stream close wait + }, +}); +``` ### Message Types @@ -92,7 +107,7 @@ import { isSDKSystemMessage, isSDKResultMessage, isSDKPartialAssistantMessage, -} from '@qwen-code/sdk-typescript'; +} from '@qwen-code/sdk'; for await (const message of result) { if (isSDKAssistantMessage(message)) { @@ -152,7 +167,7 @@ The SDK supports different permission modes for controlling tool execution: ### Multi-turn Conversation ```typescript -import { query, type SDKUserMessage } from '@qwen-code/sdk-typescript'; +import { query, type SDKUserMessage } from '@qwen-code/sdk'; async function* generateMessages(): AsyncIterable { yield { @@ -186,7 +201,7 @@ for await (const message of result) { ### Custom Permission Handler ```typescript -import { query, type CanUseTool } from '@qwen-code/sdk-typescript'; +import { query, type CanUseTool } from '@qwen-code/sdk'; const canUseTool: CanUseTool = async (toolName, input, { signal }) => { // Allow all read operations @@ -212,10 +227,10 @@ const result = query({ }); ``` -### With MCP Servers +### With External MCP Servers ```typescript -import { query } from '@qwen-code/sdk-typescript'; +import { query } from '@qwen-code/sdk'; const result = query({ prompt: 'Use the custom tool from my MCP server', @@ -231,10 +246,88 @@ const result = query({ }); ``` +### With SDK-Embedded MCP Servers + +The SDK provides `tool` and `createSdkMcpServer` to create MCP servers that run in the same process as your SDK application. This is useful when you want to expose custom tools to the AI without running a separate server process. + +#### `tool(name, description, inputSchema, handler)` + +Creates a tool definition with Zod schema type inference. + +| Parameter | Type | Description | +| ------------- | ---------------------------------- | ------------------------------------------------------------------------ | +| `name` | `string` | Tool name (1-64 chars, starts with letter, alphanumeric and underscores) | +| `description` | `string` | Human-readable description of what the tool does | +| `inputSchema` | `ZodRawShape` | Zod schema object defining the tool's input parameters | +| `handler` | `(args, extra) => Promise` | Async function that executes the tool and returns MCP content blocks | + +The handler must return a `CallToolResult` object with the following structure: + +```typescript +{ + content: Array< + | { type: 'text'; text: string } + | { type: 'image'; data: string; mimeType: string } + | { type: 'resource'; uri: string; mimeType?: string; text?: string } + >; + isError?: boolean; +} +``` + +#### `createSdkMcpServer(options)` + +Creates an SDK-embedded MCP server instance. + +| Option | Type | Default | Description | +| --------- | ------------------------ | --------- | ------------------------------------ | +| `name` | `string` | Required | Unique name for the MCP server | +| `version` | `string` | `'1.0.0'` | Server version | +| `tools` | `SdkMcpToolDefinition[]` | - | Array of tools created with `tool()` | + +Returns a `McpSdkServerConfigWithInstance` object that can be passed directly to the `mcpServers` option. + +#### Example + +```typescript +import { z } from 'zod'; +import { query, tool, createSdkMcpServer } from '@qwen-code/sdk'; + +// Define a tool with Zod schema +const calculatorTool = tool( + 'calculate_sum', + 'Add two numbers', + { a: z.number(), b: z.number() }, + async (args) => ({ + content: [{ type: 'text', text: String(args.a + args.b) }], + }), +); + +// Create the MCP server +const server = createSdkMcpServer({ + name: 'calculator', + tools: [calculatorTool], +}); + +// Use the server in a query +const result = query({ + prompt: 'What is 42 + 17?', + options: { + permissionMode: 'yolo', + mcpServers: { + calculator: server, + }, + }, +}); + +for await (const message of result) { + console.log(message); +} +``` + ### Abort a Query ```typescript -import { query, isAbortError } from '@qwen-code/sdk-typescript'; +import { query, isAbortError } from '@qwen-code/sdk'; const abortController = new AbortController(); @@ -266,7 +359,7 @@ try { The SDK provides an `AbortError` class for handling aborted queries: ```typescript -import { AbortError, isAbortError } from '@qwen-code/sdk-typescript'; +import { AbortError, isAbortError } from '@qwen-code/sdk'; try { // ... query operations diff --git a/packages/sdk-typescript/package.json b/packages/sdk-typescript/package.json index 0f234603..b0f35709 100644 --- a/packages/sdk-typescript/package.json +++ b/packages/sdk-typescript/package.json @@ -1,5 +1,5 @@ { - "name": "@qwen-code/sdk-typescript", + "name": "@qwen-code/sdk", "version": "0.1.0", "description": "TypeScript SDK for programmatic access to qwen-code CLI", "main": "./dist/index.cjs", diff --git a/packages/sdk-typescript/scripts/get-release-version.js b/packages/sdk-typescript/scripts/get-release-version.js index 349bfd07..c6b1f665 100644 --- a/packages/sdk-typescript/scripts/get-release-version.js +++ b/packages/sdk-typescript/scripts/get-release-version.js @@ -14,7 +14,7 @@ import { dirname, join } from 'node:path'; const __filename = fileURLToPath(import.meta.url); const __dirname = dirname(__filename); -const PACKAGE_NAME = '@qwen-code/sdk-typescript'; +const PACKAGE_NAME = '@qwen-code/sdk'; const TAG_PREFIX = 'sdk-typescript-v'; function readJson(filePath) { diff --git a/packages/sdk-typescript/src/index.ts b/packages/sdk-typescript/src/index.ts index da40baf2..4ae46597 100644 --- a/packages/sdk-typescript/src/index.ts +++ b/packages/sdk-typescript/src/index.ts @@ -3,6 +3,17 @@ export { AbortError, isAbortError } from './types/errors.js'; export { Query } from './query/Query.js'; export { SdkLogger } from './utils/logger.js'; +// SDK MCP Server exports +export { tool } from './mcp/tool.js'; +export { createSdkMcpServer } from './mcp/createSdkMcpServer.js'; + +export type { SdkMcpToolDefinition } from './mcp/tool.js'; + +export type { + CreateSdkMcpServerOptions, + McpSdkServerConfigWithInstance, +} from './mcp/createSdkMcpServer.js'; + export type { QueryOptions } from './query/createQuery.js'; export type { LogLevel, LoggerConfig, ScopedLogger } from './utils/logger.js'; @@ -18,6 +29,7 @@ export type { SDKResultMessage, SDKPartialAssistantMessage, SDKMessage, + SDKMcpServerConfig, ControlMessage, CLIControlRequest, CLIControlResponse, @@ -43,6 +55,10 @@ export type { PermissionMode, CanUseTool, PermissionResult, - ExternalMcpServerConfig, - SdkMcpServerConfig, + CLIMcpServerConfig, + McpServerConfig, + McpOAuthConfig, + McpAuthProviderType, } from './types/types.js'; + +export { isSdkMcpServerConfig } from './types/types.js'; diff --git a/packages/sdk-typescript/src/mcp/SdkControlServerTransport.ts b/packages/sdk-typescript/src/mcp/SdkControlServerTransport.ts index 06392a4f..28db7b2d 100644 --- a/packages/sdk-typescript/src/mcp/SdkControlServerTransport.ts +++ b/packages/sdk-typescript/src/mcp/SdkControlServerTransport.ts @@ -103,9 +103,3 @@ export class SdkControlServerTransport { return this.serverName; } } - -export function createSdkControlServerTransport( - options: SdkControlServerTransportOptions, -): SdkControlServerTransport { - return new SdkControlServerTransport(options); -} diff --git a/packages/sdk-typescript/src/mcp/createSdkMcpServer.ts b/packages/sdk-typescript/src/mcp/createSdkMcpServer.ts index 841440e1..cf2482d6 100644 --- a/packages/sdk-typescript/src/mcp/createSdkMcpServer.ts +++ b/packages/sdk-typescript/src/mcp/createSdkMcpServer.ts @@ -1,29 +1,63 @@ /** - * Factory function to create SDK-embedded MCP servers - * - * Creates MCP Server instances that run in the user's Node.js process - * and are proxied to the CLI via the control plane. + * @license + * Copyright 2025 Qwen Team + * SPDX-License-Identifier: Apache-2.0 */ -import { Server } from '@modelcontextprotocol/sdk/server/index.js'; -import { - ListToolsRequestSchema, - CallToolRequestSchema, - type CallToolResultSchema, -} from '@modelcontextprotocol/sdk/types.js'; -import type { ToolDefinition } from '../types/types.js'; -import { formatToolResult, formatToolError } from './formatters.js'; +/** + * Factory function to create SDK-embedded MCP servers + */ + +import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; +import type { SdkMcpToolDefinition } from './tool.js'; import { validateToolName } from './tool.js'; -import type { z } from 'zod'; -type CallToolResult = z.infer; +/** + * Options for creating an SDK MCP server + */ +export type CreateSdkMcpServerOptions = { + name: string; + version?: string; + /* eslint-disable-next-line @typescript-eslint/no-explicit-any */ + tools?: Array>; +}; +/** + * SDK MCP Server configuration with instance + */ +export type McpSdkServerConfigWithInstance = { + type: 'sdk'; + name: string; + instance: McpServer; +}; + +/** + * Creates an MCP server instance that can be used with the SDK transport. + * + * @example + * ```typescript + * import { z } from 'zod'; + * import { tool, createSdkMcpServer } from '@qwen-code/sdk'; + * + * const calculatorTool = tool( + * 'calculate_sum', + * 'Add two numbers', + * { a: z.number(), b: z.number() }, + * async (args) => ({ content: [{ type: 'text', text: String(args.a + args.b) }] }) + * ); + * + * const server = createSdkMcpServer({ + * name: 'calculator', + * version: '1.0.0', + * tools: [calculatorTool], + * }); + * ``` + */ export function createSdkMcpServer( - name: string, - version: string, - tools: ToolDefinition[], -): Server { - // Validate server name + options: CreateSdkMcpServerOptions, +): McpSdkServerConfigWithInstance { + const { name, version = '1.0.0', tools } = options; + if (!name || typeof name !== 'string') { throw new Error('MCP server name must be a non-empty string'); } @@ -32,78 +66,42 @@ export function createSdkMcpServer( throw new Error('MCP server version must be a non-empty string'); } - if (!Array.isArray(tools)) { + if (tools !== undefined && !Array.isArray(tools)) { throw new Error('Tools must be an array'); } - // Validate tool names are unique const toolNames = new Set(); - for (const tool of tools) { - validateToolName(tool.name); - - if (toolNames.has(tool.name)) { - throw new Error( - `Duplicate tool name '${tool.name}' in MCP server '${name}'`, - ); + if (tools) { + for (const t of tools) { + validateToolName(t.name); + if (toolNames.has(t.name)) { + throw new Error( + `Duplicate tool name '${t.name}' in MCP server '${name}'`, + ); + } + toolNames.add(t.name); } - toolNames.add(tool.name); } - // Create MCP Server instance - const server = new Server( - { - name, - version, - }, + const server = new McpServer( + { name, version }, { capabilities: { - tools: {}, + tools: tools ? {} : undefined, }, }, ); - // Create tool map for fast lookup - const toolMap = new Map(); - for (const tool of tools) { - toolMap.set(tool.name, tool); + if (tools) { + tools.forEach((toolDef) => { + server.tool( + toolDef.name, + toolDef.description, + toolDef.inputSchema, + toolDef.handler, + ); + }); } - // Register list_tools handler - server.setRequestHandler(ListToolsRequestSchema, async () => ({ - tools: tools.map((tool) => ({ - name: tool.name, - description: tool.description, - inputSchema: tool.inputSchema, - })), - })); - - // Register call_tool handler - server.setRequestHandler(CallToolRequestSchema, async (request) => { - const { name: toolName, arguments: toolArgs } = request.params; - - // Find tool - const tool = toolMap.get(toolName); - if (!tool) { - return formatToolError( - new Error(`Tool '${toolName}' not found in server '${name}'`), - ) as CallToolResult; - } - - try { - // Invoke tool handler - const result = await tool.handler(toolArgs); - - // Format result - return formatToolResult(result) as CallToolResult; - } catch (error) { - // Handle tool execution error - return formatToolError( - error instanceof Error - ? error - : new Error(`Tool '${toolName}' failed: ${String(error)}`), - ) as CallToolResult; - } - }); - - return server; + return { type: 'sdk', name, instance: server }; } diff --git a/packages/sdk-typescript/src/mcp/tool.ts b/packages/sdk-typescript/src/mcp/tool.ts index 667bf5e5..53e00399 100644 --- a/packages/sdk-typescript/src/mcp/tool.ts +++ b/packages/sdk-typescript/src/mcp/tool.ts @@ -1,39 +1,76 @@ /** - * Tool definition helper for SDK-embedded MCP servers - * - * Provides type-safe tool definitions with generic input/output types. + * @license + * Copyright 2025 Qwen Team + * SPDX-License-Identifier: Apache-2.0 */ -import type { ToolDefinition } from '../types/types.js'; +/** + * Tool definition helper for SDK-embedded MCP servers + */ -export function tool( - def: ToolDefinition, -): ToolDefinition { - // Validate tool definition - if (!def.name || typeof def.name !== 'string') { - throw new Error('Tool definition must have a name (string)'); +import type { CallToolResultSchema } from '@modelcontextprotocol/sdk/types.js'; +import type { z, ZodRawShape, ZodObject, ZodTypeAny } from 'zod'; + +type CallToolResult = z.infer; + +/** + * SDK MCP Tool Definition with Zod schema type inference + */ +export type SdkMcpToolDefinition = { + name: string; + description: string; + inputSchema: Schema; + handler: ( + args: z.infer>, + extra: unknown, + ) => Promise; +}; + +/** + * Create an SDK MCP tool definition with Zod schema inference + * + * @example + * ```typescript + * import { z } from 'zod'; + * import { tool } from '@qwen-code/sdk'; + * + * const calculatorTool = tool( + * 'calculate_sum', + * 'Calculate the sum of two numbers', + * { a: z.number(), b: z.number() }, + * async (args) => { + * // args is inferred as { a: number, b: number } + * return { content: [{ type: 'text', text: String(args.a + args.b) }] }; + * } + * ); + * ``` + */ +export function tool( + name: string, + description: string, + inputSchema: Schema, + handler: ( + args: z.infer>, + extra: unknown, + ) => Promise, +): SdkMcpToolDefinition { + if (!name || typeof name !== 'string') { + throw new Error('Tool name must be a non-empty string'); } - if (!def.description || typeof def.description !== 'string') { - throw new Error( - `Tool definition for '${def.name}' must have a description (string)`, - ); + if (!description || typeof description !== 'string') { + throw new Error(`Tool '${name}' must have a description (string)`); } - if (!def.inputSchema || typeof def.inputSchema !== 'object') { - throw new Error( - `Tool definition for '${def.name}' must have an inputSchema (object)`, - ); + if (!inputSchema || typeof inputSchema !== 'object') { + throw new Error(`Tool '${name}' must have an inputSchema (object)`); } - if (!def.handler || typeof def.handler !== 'function') { - throw new Error( - `Tool definition for '${def.name}' must have a handler (function)`, - ); + if (!handler || typeof handler !== 'function') { + throw new Error(`Tool '${name}' must have a handler (function)`); } - // Return definition (pass-through for type safety) - return def; + return { name, description, inputSchema, handler }; } export function validateToolName(name: string): void { @@ -53,39 +90,3 @@ export function validateToolName(name: string): void { ); } } - -export function validateInputSchema(schema: unknown): void { - if (!schema || typeof schema !== 'object') { - throw new Error('Input schema must be an object'); - } - - const schemaObj = schema as Record; - - if (!schemaObj.type) { - throw new Error('Input schema must have a type field'); - } - - // For object schemas, validate properties - if (schemaObj.type === 'object') { - if (schemaObj.properties && typeof schemaObj.properties !== 'object') { - throw new Error('Input schema properties must be an object'); - } - - if (schemaObj.required && !Array.isArray(schemaObj.required)) { - throw new Error('Input schema required must be an array'); - } - } -} - -export function createTool( - def: ToolDefinition, -): ToolDefinition { - // Validate via tool() function - const validated = tool(def); - - // Additional validation - validateToolName(validated.name); - validateInputSchema(validated.inputSchema); - - return validated; -} diff --git a/packages/sdk-typescript/src/query/Query.ts b/packages/sdk-typescript/src/query/Query.ts index 849b0d7b..78bb10b9 100644 --- a/packages/sdk-typescript/src/query/Query.ts +++ b/packages/sdk-typescript/src/query/Query.ts @@ -5,10 +5,10 @@ * Implements AsyncIterator protocol for message consumption. */ -const PERMISSION_CALLBACK_TIMEOUT = 30000; -const MCP_REQUEST_TIMEOUT = 30000; -const CONTROL_REQUEST_TIMEOUT = 30000; -const STREAM_CLOSE_TIMEOUT = 10000; +const DEFAULT_CAN_USE_TOOL_TIMEOUT = 60_000; +const DEFAULT_MCP_REQUEST_TIMEOUT = 60_000; +const DEFAULT_CONTROL_REQUEST_TIMEOUT = 60_000; +const DEFAULT_STREAM_CLOSE_TIMEOUT = 60_000; import { randomUUID } from 'node:crypto'; import { SdkLogger } from '../utils/logger.js'; @@ -19,6 +19,7 @@ import type { CLIControlResponse, ControlCancelRequest, PermissionSuggestion, + WireSDKMcpServerConfig, } from '../types/protocol.js'; import { isSDKUserMessage, @@ -31,12 +32,17 @@ import { isControlCancel, } from '../types/protocol.js'; import type { Transport } from '../transport/Transport.js'; -import type { QueryOptions } from '../types/types.js'; +import type { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; +import type { QueryOptions, CLIMcpServerConfig } from '../types/types.js'; +import { isSdkMcpServerConfig } from '../types/types.js'; import { Stream } from '../utils/Stream.js'; import { serializeJsonLine } from '../utils/jsonLines.js'; import { AbortError } from '../types/errors.js'; import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js'; -import type { SdkControlServerTransport } from '../mcp/SdkControlServerTransport.js'; +import { + SdkControlServerTransport, + type SdkControlServerTransportOptions, +} from '../mcp/SdkControlServerTransport.js'; import { ControlRequestType } from '../types/protocol.js'; interface PendingControlRequest { @@ -46,6 +52,11 @@ interface PendingControlRequest { abortController: AbortController; } +interface PendingMcpResponse { + resolve: (response: JSONRPCMessage) => void; + reject: (error: Error) => void; +} + interface TransportWithEndInput extends Transport { endInput(): void; } @@ -61,7 +72,9 @@ export class Query implements AsyncIterable { private abortController: AbortController; private pendingControlRequests: Map = new Map(); + private pendingMcpResponses: Map = new Map(); private sdkMcpTransports: Map = new Map(); + private sdkMcpServers: Map = new Map(); readonly initialized: Promise; private closed = false; private messageRouterStarted = false; @@ -92,6 +105,11 @@ export class Query implements AsyncIterable { */ this.sdkMessages = this.readSdkMessages(); + /** + * Promise that resolves when the first SDKResultMessage is received. + * Used to coordinate endInput() timing - ensures all initialization + * (SDK MCP servers, control responses) is complete before closing CLI stdin. + */ this.firstResultReceivedPromise = new Promise((resolve) => { this.firstResultReceivedResolve = resolve; }); @@ -121,17 +139,152 @@ export class Query implements AsyncIterable { this.startMessageRouter(); } + private async initializeSdkMcpServers(): Promise { + if (!this.options.mcpServers) { + return; + } + + const connectionPromises: Array> = []; + + // Extract SDK MCP servers from the unified mcpServers config + for (const [key, config] of Object.entries(this.options.mcpServers)) { + if (!isSdkMcpServerConfig(config)) { + continue; // Skip external MCP servers + } + + // Use the name from SDKMcpServerConfig, fallback to key for backwards compatibility + const serverName = config.name || key; + const server = config.instance; + + // Create transport options with callback to route MCP server responses + const transportOptions: SdkControlServerTransportOptions = { + sendToQuery: async (message: JSONRPCMessage) => { + this.handleMcpServerResponse(serverName, message); + }, + serverName, + }; + + const sdkTransport = new SdkControlServerTransport(transportOptions); + + // Connect server to transport and only register on success + const connectionPromise = server + .connect(sdkTransport) + .then(() => { + // Only add to maps after successful connection + this.sdkMcpServers.set(serverName, server); + this.sdkMcpTransports.set(serverName, sdkTransport); + logger.debug(`SDK MCP server '${serverName}' connected to transport`); + }) + .catch((error) => { + logger.error( + `Failed to connect SDK MCP server '${serverName}' to transport:`, + error, + ); + // Don't throw - one failed server shouldn't prevent others + }); + + connectionPromises.push(connectionPromise); + } + + // Wait for all connection attempts to complete + await Promise.all(connectionPromises); + + if (this.sdkMcpServers.size > 0) { + logger.info( + `Initialized ${this.sdkMcpServers.size} SDK MCP server(s): ${Array.from(this.sdkMcpServers.keys()).join(', ')}`, + ); + } + } + + /** + * Handle response messages from SDK MCP servers + * + * When an MCP server sends a response via transport.send(), this callback + * routes it back to the pending request that's waiting for it. + */ + private handleMcpServerResponse( + serverName: string, + message: JSONRPCMessage, + ): void { + // Check if this is a response with an id + if ('id' in message && message.id !== null && message.id !== undefined) { + const key = `${serverName}:${message.id}`; + const pending = this.pendingMcpResponses.get(key); + if (pending) { + logger.debug( + `Routing MCP response for server '${serverName}', id: ${message.id}`, + ); + pending.resolve(message); + this.pendingMcpResponses.delete(key); + return; + } + } + + // If no pending request found, log a warning (this shouldn't happen normally) + logger.warn( + `Received MCP server response with no pending request: server='${serverName}'`, + message, + ); + } + + /** + * Get SDK MCP servers config for CLI initialization + * + * Only SDK servers are sent in the initialize request. + */ + private getSdkMcpServersForCli(): Record { + const sdkServers: Record = {}; + + for (const [name] of this.sdkMcpServers.entries()) { + sdkServers[name] = { type: 'sdk', name }; + } + + return sdkServers; + } + + /** + * Get external MCP servers (non-SDK) that should be managed by the CLI + */ + private getMcpServersForCli(): Record { + if (!this.options.mcpServers) { + return {}; + } + + const externalServers: Record = {}; + + for (const [name, config] of Object.entries(this.options.mcpServers)) { + if (isSdkMcpServerConfig(config)) { + continue; + } + externalServers[name] = config as CLIMcpServerConfig; + } + + return externalServers; + } + private async initialize(): Promise { try { logger.debug('Initializing Query'); - const sdkMcpServerNames = Array.from(this.sdkMcpTransports.keys()); + // Initialize SDK MCP servers and wait for connections + await this.initializeSdkMcpServers(); + + // Get only successfully connected SDK servers for CLI + const sdkMcpServersForCli = this.getSdkMcpServersForCli(); + const mcpServersForCli = this.getMcpServersForCli(); + logger.debug('SDK MCP servers for CLI:', sdkMcpServersForCli); + logger.debug('External MCP servers for CLI:', mcpServersForCli); await this.sendControlRequest(ControlRequestType.INITIALIZE, { hooks: null, sdkMcpServers: - sdkMcpServerNames.length > 0 ? sdkMcpServerNames : undefined, - mcpServers: this.options.mcpServers, + Object.keys(sdkMcpServersForCli).length > 0 + ? sdkMcpServersForCli + : undefined, + mcpServers: + Object.keys(mcpServersForCli).length > 0 + ? mcpServersForCli + : undefined, agents: this.options.agents, }); logger.info('Query initialized successfully'); @@ -279,10 +432,13 @@ export class Query implements AsyncIterable { } try { + const canUseToolTimeout = + this.options.timeout?.canUseTool ?? DEFAULT_CAN_USE_TOOL_TIMEOUT; + let timeoutId: NodeJS.Timeout | undefined; const timeoutPromise = new Promise((_, reject) => { - setTimeout( + timeoutId = setTimeout( () => reject(new Error('Permission callback timeout')), - PERMISSION_CALLBACK_TIMEOUT, + canUseToolTimeout, ); }); @@ -296,6 +452,10 @@ export class Query implements AsyncIterable { timeoutPromise, ]); + if (timeoutId) { + clearTimeout(timeoutId); + } + if (result.behavior === 'allow') { return { behavior: 'allow', @@ -361,32 +521,45 @@ export class Query implements AsyncIterable { } private handleMcpRequest( - _serverName: string, + serverName: string, message: JSONRPCMessage, transport: SdkControlServerTransport, ): Promise { + const messageId = 'id' in message ? message.id : null; + const key = `${serverName}:${messageId}`; + return new Promise((resolve, reject) => { + const mcpRequestTimeout = + this.options.timeout?.mcpRequest ?? DEFAULT_MCP_REQUEST_TIMEOUT; const timeout = setTimeout(() => { + this.pendingMcpResponses.delete(key); reject(new Error('MCP request timeout')); - }, MCP_REQUEST_TIMEOUT); + }, mcpRequestTimeout); - const messageId = 'id' in message ? message.id : null; - - /** - * Hook into transport to capture response. - * Temporarily replace sendToQuery to intercept the response message - * matching this request's ID, then restore the original handler. - */ - const originalSend = transport.sendToQuery; - transport.sendToQuery = async (responseMessage: JSONRPCMessage) => { - if ('id' in responseMessage && responseMessage.id === messageId) { - clearTimeout(timeout); - transport.sendToQuery = originalSend; - resolve(responseMessage); - } - return originalSend(responseMessage); + const cleanup = () => { + clearTimeout(timeout); + this.pendingMcpResponses.delete(key); }; + const resolveAndCleanup = (response: JSONRPCMessage) => { + cleanup(); + resolve(response); + }; + + const rejectAndCleanup = (error: Error) => { + cleanup(); + reject(error); + }; + + // Register pending response handler + this.pendingMcpResponses.set(key, { + resolve: resolveAndCleanup, + reject: rejectAndCleanup, + }); + + // Deliver message to MCP server via transport.onmessage + // The server will process it and call transport.send() with the response, + // which triggers handleMcpServerResponse to resolve our pending promise transport.handleMessage(message); }); } @@ -452,6 +625,10 @@ export class Query implements AsyncIterable { subtype: string, data: Record = {}, ): Promise | null> { + if (this.closed) { + return Promise.reject(new Error('Query is closed')); + } + const requestId = randomUUID(); const request: CLIControlRequest = { @@ -466,10 +643,13 @@ export class Query implements AsyncIterable { const responsePromise = new Promise | null>( (resolve, reject) => { const abortController = new AbortController(); + const controlRequestTimeout = + this.options.timeout?.controlRequest ?? + DEFAULT_CONTROL_REQUEST_TIMEOUT; const timeout = setTimeout(() => { this.pendingControlRequests.delete(requestId); reject(new Error(`Control request timeout: ${subtype}`)); - }, CONTROL_REQUEST_TIMEOUT); + }, controlRequestTimeout); this.pendingControlRequests.set(requestId, { resolve, @@ -517,9 +697,16 @@ export class Query implements AsyncIterable { for (const pending of this.pendingControlRequests.values()) { pending.abortController.abort(); clearTimeout(pending.timeout); + pending.reject(new Error('Query is closed')); } this.pendingControlRequests.clear(); + // Clean up pending MCP responses + for (const pending of this.pendingMcpResponses.values()) { + pending.reject(new Error('Query is closed')); + } + this.pendingMcpResponses.clear(); + await this.transport.close(); /** @@ -542,7 +729,7 @@ export class Query implements AsyncIterable { } } this.sdkMcpTransports.clear(); - logger.info('Query closed'); + logger.info('Query is closed'); } private async *readSdkMessages(): AsyncGenerator { @@ -588,24 +775,39 @@ export class Query implements AsyncIterable { } /** - * In multi-turn mode with MCP servers, wait for first result - * to ensure MCP servers have time to process before next input. - * This prevents race conditions where the next input arrives before - * MCP servers have finished processing the current request. + * After all user messages are sent (for-await loop ended), determine when to + * close the CLI's stdin via endInput(). + * + * - If a result message was already received: All initialization (SDK MCP servers, + * control responses, etc.) is complete, safe to close stdin immediately. + * - If no result yet: Wait for either the result to arrive, or the timeout to expire. + * This gives pending control_responses from SDK MCP servers or other modules + * time to complete their initialization before we close the input stream. + * + * The timeout ensures we don't hang indefinitely - either the turn proceeds + * normally, or it fails with a timeout, but Promise.race will always resolve. */ if ( !this.isSingleTurn && this.sdkMcpTransports.size > 0 && this.firstResultReceivedPromise ) { - await Promise.race([ - this.firstResultReceivedPromise, - new Promise((resolve) => { - setTimeout(() => { - resolve(); - }, STREAM_CLOSE_TIMEOUT); - }), - ]); + const streamCloseTimeout = + this.options.timeout?.streamClose ?? DEFAULT_STREAM_CLOSE_TIMEOUT; + let timeoutId: NodeJS.Timeout | undefined; + + const timeoutPromise = new Promise((resolve) => { + timeoutId = setTimeout(() => { + logger.info('streamCloseTimeout resolved'); + resolve(); + }, streamCloseTimeout); + }); + + await Promise.race([this.firstResultReceivedPromise, timeoutPromise]); + + if (timeoutId) { + clearTimeout(timeoutId); + } } this.endInput(); @@ -635,28 +837,16 @@ export class Query implements AsyncIterable { } async interrupt(): Promise { - if (this.closed) { - throw new Error('Query is closed'); - } - await this.sendControlRequest(ControlRequestType.INTERRUPT); } async setPermissionMode(mode: string): Promise { - if (this.closed) { - throw new Error('Query is closed'); - } - await this.sendControlRequest(ControlRequestType.SET_PERMISSION_MODE, { mode, }); } async setModel(model: string): Promise { - if (this.closed) { - throw new Error('Query is closed'); - } - await this.sendControlRequest(ControlRequestType.SET_MODEL, { model }); } @@ -667,10 +857,6 @@ export class Query implements AsyncIterable { * @throws Error if query is closed */ async supportedCommands(): Promise | null> { - if (this.closed) { - throw new Error('Query is closed'); - } - return this.sendControlRequest(ControlRequestType.SUPPORTED_COMMANDS); } @@ -681,10 +867,6 @@ export class Query implements AsyncIterable { * @throws Error if query is closed */ async mcpServerStatus(): Promise | null> { - if (this.closed) { - throw new Error('Query is closed'); - } - return this.sendControlRequest(ControlRequestType.MCP_SERVER_STATUS); } diff --git a/packages/sdk-typescript/src/types/protocol.ts b/packages/sdk-typescript/src/types/protocol.ts index efb61cb4..e5eeb121 100644 --- a/packages/sdk-typescript/src/types/protocol.ts +++ b/packages/sdk-typescript/src/types/protocol.ts @@ -1,5 +1,6 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ +import type { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; export interface Annotation { type: string; value: string; @@ -293,10 +294,44 @@ export interface MCPServerConfig { targetServiceAccount?: string; } +/** + * SDK MCP Server configuration + * + * SDK MCP servers run in the SDK process and are connected via in-memory transport. + * Tool calls are routed through the control plane between SDK and CLI. + */ +export interface SDKMcpServerConfig { + /** + * Type identifier for SDK MCP servers + */ + type: 'sdk'; + /** + * Server name for identification and routing + */ + name: string; + /** + * The MCP Server instance created by createSdkMcpServer() + */ + instance: McpServer; +} + +/** + * Wire format for SDK MCP servers sent to the CLI + */ +export type WireSDKMcpServerConfig = Omit; + export interface CLIControlInitializeRequest { subtype: 'initialize'; hooks?: HookRegistration[] | null; - sdkMcpServers?: Record; + /** + * SDK MCP servers config + * These are MCP servers running in the SDK process, connected via control plane. + * External MCP servers are configured separately in settings, not via initialization. + */ + sdkMcpServers?: Record; + /** + * External MCP servers that should be managed by the CLI. + */ mcpServers?: Record; agents?: SubagentConfig[]; } diff --git a/packages/sdk-typescript/src/types/queryOptionsSchema.ts b/packages/sdk-typescript/src/types/queryOptionsSchema.ts index 579445cf..a4794b3f 100644 --- a/packages/sdk-typescript/src/types/queryOptionsSchema.ts +++ b/packages/sdk-typescript/src/types/queryOptionsSchema.ts @@ -2,19 +2,98 @@ import { z } from 'zod'; import type { CanUseTool } from './types.js'; import type { SubagentConfig } from './protocol.js'; -export const ExternalMcpServerConfigSchema = z.object({ - command: z.string().min(1, 'Command must be a non-empty string'), +/** + * OAuth configuration for MCP servers + */ +export const McpOAuthConfigSchema = z + .object({ + enabled: z.boolean().optional(), + clientId: z + .string() + .min(1, 'clientId must be a non-empty string') + .optional(), + clientSecret: z.string().optional(), + scopes: z.array(z.string()).optional(), + redirectUri: z.string().optional(), + authorizationUrl: z.string().optional(), + tokenUrl: z.string().optional(), + audiences: z.array(z.string()).optional(), + tokenParamName: z.string().optional(), + registrationUrl: z.string().optional(), + }) + .strict(); + +/** + * CLI MCP Server configuration schema + * + * Supports multiple transport types: + * - stdio: command, args, env, cwd + * - SSE: url + * - Streamable HTTP: httpUrl, headers + * - WebSocket: tcp + */ +export const CLIMcpServerConfigSchema = z.object({ + // For stdio transport + command: z.string().optional(), args: z.array(z.string()).optional(), env: z.record(z.string(), z.string()).optional(), + cwd: z.string().optional(), + // For SSE transport + url: z.string().optional(), + // For streamable HTTP transport + httpUrl: z.string().optional(), + headers: z.record(z.string(), z.string()).optional(), + // For WebSocket transport + tcp: z.string().optional(), + // Common + timeout: z.number().optional(), + trust: z.boolean().optional(), + // Metadata + description: z.string().optional(), + includeTools: z.array(z.string()).optional(), + excludeTools: z.array(z.string()).optional(), + extensionName: z.string().optional(), + // OAuth configuration + oauth: McpOAuthConfigSchema.optional(), + authProviderType: z + .enum([ + 'dynamic_discovery', + 'google_credentials', + 'service_account_impersonation', + ]) + .optional(), + // Service Account Configuration + targetAudience: z.string().optional(), + targetServiceAccount: z.string().optional(), }); +/** + * SDK MCP Server configuration schema + */ export const SdkMcpServerConfigSchema = z.object({ - connect: z.custom<(transport: unknown) => Promise>( - (val) => typeof val === 'function', - { message: 'connect must be a function' }, + type: z.literal('sdk'), + name: z.string().min(1, 'name must be a non-empty string'), + instance: z.custom<{ + connect(transport: unknown): Promise; + close(): Promise; + }>( + (val) => + val && + typeof val === 'object' && + 'connect' in val && + typeof val.connect === 'function', + { message: 'instance must be an MCP Server with connect method' }, ), }); +/** + * Unified MCP Server configuration schema + */ +export const McpServerConfigSchema = z.union([ + CLIMcpServerConfigSchema, + SdkMcpServerConfigSchema, +]); + export const ModelConfigSchema = z.object({ model: z.string().optional(), temp: z.number().optional(), @@ -37,6 +116,13 @@ export const SubagentConfigSchema = z.object({ isBuiltin: z.boolean().optional(), }); +export const TimeoutConfigSchema = z.object({ + canUseTool: z.number().positive().optional(), + mcpRequest: z.number().positive().optional(), + controlRequest: z.number().positive().optional(), + streamClose: z.number().positive().optional(), +}); + export const QueryOptionsSchema = z .object({ cwd: z.string().optional(), @@ -49,7 +135,7 @@ export const QueryOptionsSchema = z message: 'canUseTool must be a function', }) .optional(), - mcpServers: z.record(z.string(), ExternalMcpServerConfigSchema).optional(), + mcpServers: z.record(z.string(), McpServerConfigSchema).optional(), abortController: z.instanceof(AbortController).optional(), debug: z.boolean().optional(), stderr: z @@ -78,5 +164,6 @@ export const QueryOptionsSchema = z ) .optional(), includePartialMessages: z.boolean().optional(), + timeout: TimeoutConfigSchema.optional(), }) .strict(); diff --git a/packages/sdk-typescript/src/types/types.ts b/packages/sdk-typescript/src/types/types.ts index a3f6cd03..24dc0575 100644 --- a/packages/sdk-typescript/src/types/types.ts +++ b/packages/sdk-typescript/src/types/types.ts @@ -2,25 +2,11 @@ import type { PermissionMode, PermissionSuggestion, SubagentConfig, + SDKMcpServerConfig, } from './protocol.js'; export type { PermissionMode }; -type JSONSchema = { - type: string; - properties?: Record; - required?: string[]; - description?: string; - [key: string]: unknown; -}; - -export type ToolDefinition = { - name: string; - description: string; - inputSchema: JSONSchema; - handler: (input: TInput) => Promise; -}; - export type TransportOptions = { pathToQwenExecutable: string; cwd?: string; @@ -61,14 +47,115 @@ export type PermissionResult = interrupt?: boolean; }; -export interface ExternalMcpServerConfig { - command: string; - args?: string[]; - env?: Record; +/** + * OAuth configuration for MCP servers + */ +export interface McpOAuthConfig { + enabled?: boolean; + clientId?: string; + clientSecret?: string; + scopes?: string[]; + redirectUri?: string; + authorizationUrl?: string; + tokenUrl?: string; + audiences?: string[]; + tokenParamName?: string; + registrationUrl?: string; } -export interface SdkMcpServerConfig { - connect: (transport: unknown) => Promise; +/** + * Auth provider type for MCP servers + */ +export type McpAuthProviderType = + | 'dynamic_discovery' + | 'google_credentials' + | 'service_account_impersonation'; + +/** + * CLI MCP Server configuration + * + * Supports multiple transport types: + * - stdio: command, args, env, cwd + * - SSE: url + * - Streamable HTTP: httpUrl, headers + * - WebSocket: tcp + * + * This interface aligns with MCPServerConfig in @qwen-code/qwen-code-core. + */ +export interface CLIMcpServerConfig { + // For stdio transport + command?: string; + args?: string[]; + env?: Record; + cwd?: string; + // For SSE transport + url?: string; + // For streamable HTTP transport + httpUrl?: string; + headers?: Record; + // For WebSocket transport + tcp?: string; + // Common + timeout?: number; + trust?: boolean; + // Metadata + description?: string; + includeTools?: string[]; + excludeTools?: string[]; + extensionName?: string; + // OAuth configuration + oauth?: McpOAuthConfig; + authProviderType?: McpAuthProviderType; + // Service Account Configuration + /** targetAudience format: CLIENT_ID.apps.googleusercontent.com */ + targetAudience?: string; + /** targetServiceAccount format: @.iam.gserviceaccount.com */ + targetServiceAccount?: string; +} + +/** + * Unified MCP Server configuration + * + * Supports both external MCP servers (stdio/SSE/HTTP/WebSocket) and SDK-embedded MCP servers. + * + * @example External MCP server (stdio) + * ```typescript + * mcpServers: { + * 'my-server': { command: 'node', args: ['server.js'] } + * } + * ``` + * + * @example External MCP server (SSE) + * ```typescript + * mcpServers: { + * 'remote-server': { url: 'http://localhost:3000/sse' } + * } + * ``` + * + * @example External MCP server (Streamable HTTP) + * ```typescript + * mcpServers: { + * 'http-server': { httpUrl: 'http://localhost:3000/mcp', headers: { 'Authorization': 'Bearer token' } } + * } + * ``` + * + * @example SDK MCP server + * ```typescript + * const server = createSdkMcpServer('weather', '1.0.0', [weatherTool]); + * mcpServers: { + * 'weather': { type: 'sdk', name: 'weather', instance: server } + * } + * ``` + */ +export type McpServerConfig = CLIMcpServerConfig | SDKMcpServerConfig; + +/** + * Type guard to check if a config is an SDK MCP server + */ +export function isSdkMcpServerConfig( + config: McpServerConfig, +): config is SDKMcpServerConfig { + return 'type' in config && config.type === 'sdk'; } /** @@ -174,11 +261,36 @@ export interface QueryOptions { canUseTool?: CanUseTool; /** - * External MCP (Model Context Protocol) servers to connect to. - * Each server is identified by a unique name and configured with command, args, and environment. - * @example { 'my-server': { command: 'node', args: ['server.js'], env: { PORT: '3000' } } } + * MCP (Model Context Protocol) servers to connect to. + * + * Supports both external MCP servers and SDK-embedded MCP servers: + * + * **External MCP servers** - Run in separate processes, connected via stdio/SSE/HTTP: + * ```typescript + * mcpServers: { + * 'stdio-server': { command: 'node', args: ['server.js'], env: { PORT: '3000' } }, + * 'sse-server': { url: 'http://localhost:3000/sse' }, + * 'http-server': { httpUrl: 'http://localhost:3000/mcp' } + * } + * ``` + * + * **SDK MCP servers** - Run in the SDK process, connected via in-memory transport: + * ```typescript + * const myTool = tool({ + * name: 'my_tool', + * description: 'My custom tool', + * inputSchema: { type: 'object', properties: { input: { type: 'string' } } }, + * handler: async (input) => ({ result: input.input.toUpperCase() }), + * }); + * + * const server = createSdkMcpServer('my-server', '1.0.0', [myTool]); + * + * mcpServers: { + * 'my-server': { type: 'sdk', name: 'my-server', instance: server } + * } + * ``` */ - mcpServers?: Record; + mcpServers?: Record; /** * AbortController to cancel the query session. @@ -204,7 +316,7 @@ export interface QueryOptions { /** * Logging level for the SDK. * Controls the verbosity of log messages output by the SDK. - * @default 'info' + * @default 'error' */ logLevel?: 'debug' | 'info' | 'warn' | 'error'; @@ -294,4 +406,43 @@ export interface QueryOptions { * @default false */ includePartialMessages?: boolean; + + /** + * Timeout configuration for various SDK operations. + * All values are in milliseconds. + */ + timeout?: { + /** + * Timeout for the `canUseTool` callback. + * If the callback doesn't resolve within this time, the permission request + * will be denied with a timeout error (fail-safe behavior). + * @default 60000 (1 minute) + */ + canUseTool?: number; + + /** + * Timeout for SDK MCP tool calls. + * This applies to tool calls made to SDK-embedded MCP servers. + * @default 60000 (1 minute) + */ + mcpRequest?: number; + + /** + * Timeout for SDK→CLI control requests. + * This applies to internal control operations like initialize, interrupt, + * setPermissionMode, setModel, etc. + * @default 60000 (1 minute) + */ + controlRequest?: number; + + /** + * Timeout for waiting before closing CLI's stdin after user messages are sent. + * In multi-turn mode with SDK MCP servers, after all user messages are processed, + * the SDK waits for the first result message to ensure all initialization + * (control responses, MCP server setup, etc.) is complete before closing stdin. + * This timeout is a fallback to avoid hanging indefinitely. + * @default 60000 (1 minute) + */ + streamClose?: number; + }; } diff --git a/packages/sdk-typescript/src/utils/logger.ts b/packages/sdk-typescript/src/utils/logger.ts index afb7a495..caf57ede 100644 --- a/packages/sdk-typescript/src/utils/logger.ts +++ b/packages/sdk-typescript/src/utils/logger.ts @@ -22,7 +22,7 @@ const LOG_LEVEL_PRIORITY: Record = { export class SdkLogger { private static config: LoggerConfig = {}; - private static effectiveLevel: LogLevel = 'info'; + private static effectiveLevel: LogLevel = 'error'; static configure(config: LoggerConfig): void { this.config = config; @@ -47,7 +47,7 @@ export class SdkLogger { return 'debug'; } - return 'info'; + return 'error'; } private static isValidLogLevel(level: string): boolean { diff --git a/packages/sdk-typescript/test/unit/Query.test.ts b/packages/sdk-typescript/test/unit/Query.test.ts index 2b89ca51..1dd0a992 100644 --- a/packages/sdk-typescript/test/unit/Query.test.ts +++ b/packages/sdk-typescript/test/unit/Query.test.ts @@ -542,13 +542,16 @@ describe('Query', () => { const canUseTool = vi.fn().mockImplementation( () => new Promise((resolve) => { - setTimeout(() => resolve({ behavior: 'allow' }), 35000); // Exceeds 30s timeout + setTimeout(() => resolve({ behavior: 'allow' }), 15000); }), ); const query = new Query(transport, { cwd: '/test', canUseTool, + timeout: { + canUseTool: 10000, + }, }); const controlReq = createControlRequest('can_use_tool', 'perm-req-4'); @@ -567,7 +570,7 @@ describe('Query', () => { }); } }, - { timeout: 35000 }, + { timeout: 15000 }, ); await query.close(); @@ -1204,7 +1207,12 @@ describe('Query', () => { }); it('should handle control request timeout', async () => { - const query = new Query(transport, { cwd: '/test' }); + const query = new Query(transport, { + cwd: '/test', + timeout: { + controlRequest: 10000, + }, + }); // Respond to initialize await vi.waitFor(() => { @@ -1224,7 +1232,7 @@ describe('Query', () => { await expect(interruptPromise).rejects.toThrow(/timeout/i); await query.close(); - }, 35000); + }, 15000); it('should handle malformed control responses', async () => { const query = new Query(transport, { cwd: '/test' }); diff --git a/packages/sdk-typescript/test/unit/createSdkMcpServer.test.ts b/packages/sdk-typescript/test/unit/createSdkMcpServer.test.ts index e608ba7b..8f39ad08 100644 --- a/packages/sdk-typescript/test/unit/createSdkMcpServer.test.ts +++ b/packages/sdk-typescript/test/unit/createSdkMcpServer.test.ts @@ -1,3 +1,9 @@ +/** + * @license + * Copyright 2025 Qwen Team + * SPDX-License-Identifier: Apache-2.0 + */ + /** * Unit tests for createSdkMcpServer * @@ -5,93 +11,112 @@ */ import { describe, expect, it, vi } from 'vitest'; +import { z } from 'zod'; import { createSdkMcpServer } from '../../src/mcp/createSdkMcpServer.js'; import { tool } from '../../src/mcp/tool.js'; -import type { ToolDefinition } from '../../src/types/config.js'; +import type { SdkMcpToolDefinition } from '../../src/mcp/tool.js'; describe('createSdkMcpServer', () => { describe('Server Creation', () => { it('should create server with name and version', () => { - const server = createSdkMcpServer('test-server', '1.0.0', []); + const server = createSdkMcpServer({ + name: 'test-server', + version: '1.0.0', + tools: [], + }); expect(server).toBeDefined(); + expect(server.type).toBe('sdk'); + expect(server.name).toBe('test-server'); + expect(server.instance).toBeDefined(); + }); + + it('should create server with default version', () => { + const server = createSdkMcpServer({ + name: 'test-server', + }); + + expect(server).toBeDefined(); + expect(server.name).toBe('test-server'); }); it('should throw error with invalid name', () => { - expect(() => createSdkMcpServer('', '1.0.0', [])).toThrow( - 'name must be a non-empty string', + expect(() => createSdkMcpServer({ name: '', version: '1.0.0' })).toThrow( + 'MCP server name must be a non-empty string', ); }); it('should throw error with invalid version', () => { - expect(() => createSdkMcpServer('test', '', [])).toThrow( - 'version must be a non-empty string', + expect(() => createSdkMcpServer({ name: 'test', version: '' })).toThrow( + 'MCP server version must be a non-empty string', ); }); it('should throw error with non-array tools', () => { expect(() => - createSdkMcpServer('test', '1.0.0', {} as unknown as ToolDefinition[]), + createSdkMcpServer({ + name: 'test', + version: '1.0.0', + /* eslint-disable-next-line @typescript-eslint/no-explicit-any */ + tools: {} as unknown as SdkMcpToolDefinition[], + }), ).toThrow('Tools must be an array'); }); }); describe('Tool Registration', () => { it('should register single tool', () => { - const testTool = tool({ - name: 'test_tool', - description: 'A test tool', - inputSchema: { - type: 'object', - properties: { - input: { type: 'string' }, - }, - }, - handler: async () => 'result', - }); + const testTool = tool( + 'test_tool', + 'A test tool', + { input: z.string() }, + async () => ({ + content: [{ type: 'text', text: 'result' }], + }), + ); - const server = createSdkMcpServer('test-server', '1.0.0', [testTool]); + const server = createSdkMcpServer({ + name: 'test-server', + version: '1.0.0', + tools: [testTool], + }); expect(server).toBeDefined(); }); it('should register multiple tools', () => { - const tool1 = tool({ - name: 'tool1', - description: 'Tool 1', - inputSchema: { type: 'object' }, - handler: async () => 'result1', - }); + const tool1 = tool('tool1', 'Tool 1', {}, async () => ({ + content: [{ type: 'text', text: 'result1' }], + })); - const tool2 = tool({ - name: 'tool2', - description: 'Tool 2', - inputSchema: { type: 'object' }, - handler: async () => 'result2', - }); + const tool2 = tool('tool2', 'Tool 2', {}, async () => ({ + content: [{ type: 'text', text: 'result2' }], + })); - const server = createSdkMcpServer('test-server', '1.0.0', [tool1, tool2]); + const server = createSdkMcpServer({ + name: 'test-server', + version: '1.0.0', + tools: [tool1, tool2], + }); expect(server).toBeDefined(); }); it('should throw error for duplicate tool names', () => { - const tool1 = tool({ - name: 'duplicate', - description: 'Tool 1', - inputSchema: { type: 'object' }, - handler: async () => 'result1', - }); + const tool1 = tool('duplicate', 'Tool 1', {}, async () => ({ + content: [{ type: 'text', text: 'result1' }], + })); - const tool2 = tool({ - name: 'duplicate', - description: 'Tool 2', - inputSchema: { type: 'object' }, - handler: async () => 'result2', - }); + const tool2 = tool('duplicate', 'Tool 2', {}, async () => ({ + content: [{ type: 'text', text: 'result2' }], + })); expect(() => - createSdkMcpServer('test-server', '1.0.0', [tool1, tool2]), + createSdkMcpServer({ + name: 'test-server', + version: '1.0.0', + tools: [tool1, tool2], + }), ).toThrow("Duplicate tool name 'duplicate'"); }); @@ -99,36 +124,41 @@ describe('createSdkMcpServer', () => { const invalidTool = { name: '123invalid', // Starts with number description: 'Invalid tool', - inputSchema: { type: 'object' }, - handler: async () => 'result', + inputSchema: {}, + handler: async () => ({ + content: [{ type: 'text' as const, text: 'result' }], + }), }; expect(() => - createSdkMcpServer('test-server', '1.0.0', [ - invalidTool as unknown as ToolDefinition, - ]), + createSdkMcpServer({ + name: 'test-server', + version: '1.0.0', + /* eslint-disable-next-line @typescript-eslint/no-explicit-any */ + tools: [invalidTool as unknown as SdkMcpToolDefinition], + }), ).toThrow('Tool name'); }); }); describe('Tool Handler Invocation', () => { it('should invoke tool handler with correct input', async () => { - const handler = vi.fn().mockResolvedValue({ result: 'success' }); - - const testTool = tool({ - name: 'test_tool', - description: 'A test tool', - inputSchema: { - type: 'object', - properties: { - value: { type: 'string' }, - }, - required: ['value'], - }, - handler, + const handler = vi.fn().mockResolvedValue({ + content: [{ type: 'text', text: 'success' }], }); - createSdkMcpServer('test-server', '1.0.0', [testTool]); + const testTool = tool( + 'test_tool', + 'A test tool', + { value: z.string() }, + handler, + ); + + createSdkMcpServer({ + name: 'test-server', + version: '1.0.0', + tools: [testTool], + }); // Note: Actual invocation testing requires MCP SDK integration // This test verifies the handler was properly registered @@ -140,17 +170,18 @@ describe('createSdkMcpServer', () => { .fn() .mockImplementation(async (input: { value: string }) => { await new Promise((resolve) => setTimeout(resolve, 10)); - return { processed: input.value }; + return { + content: [{ type: 'text', text: `processed: ${input.value}` }], + }; }); - const testTool = tool({ - name: 'async_tool', - description: 'An async tool', - inputSchema: { type: 'object' }, - handler, - }); + const testTool = tool('async_tool', 'An async tool', {}, handler); - const server = createSdkMcpServer('test-server', '1.0.0', [testTool]); + const server = createSdkMcpServer({ + name: 'test-server', + version: '1.0.0', + tools: [testTool], + }); expect(server).toBeDefined(); }); @@ -158,40 +189,29 @@ describe('createSdkMcpServer', () => { describe('Type Safety', () => { it('should preserve input type in handler', async () => { - type ToolInput = { - name: string; - age: number; - }; - - type ToolOutput = { - greeting: string; - }; - - const handler = vi - .fn() - .mockImplementation(async (input: ToolInput): Promise => { - return { - greeting: `Hello ${input.name}, age ${input.age}`, - }; - }); - - const typedTool = tool({ - name: 'typed_tool', - description: 'A typed tool', - inputSchema: { - type: 'object', - properties: { - name: { type: 'string' }, - age: { type: 'number' }, - }, - required: ['name', 'age'], - }, - handler, + const handler = vi.fn().mockImplementation(async (input) => { + return { + content: [ + { type: 'text', text: `Hello ${input.name}, age ${input.age}` }, + ], + }; }); - const server = createSdkMcpServer('test-server', '1.0.0', [ - typedTool as ToolDefinition, - ]); + const typedTool = tool( + 'typed_tool', + 'A typed tool', + { + name: z.string(), + age: z.number(), + }, + handler, + ); + + const server = createSdkMcpServer({ + name: 'test-server', + version: '1.0.0', + tools: [typedTool], + }); expect(server).toBeDefined(); }); @@ -201,14 +221,13 @@ describe('createSdkMcpServer', () => { it('should handle tool handler errors gracefully', async () => { const handler = vi.fn().mockRejectedValue(new Error('Tool failed')); - const errorTool = tool({ - name: 'error_tool', - description: 'A tool that errors', - inputSchema: { type: 'object' }, - handler, - }); + const errorTool = tool('error_tool', 'A tool that errors', {}, handler); - const server = createSdkMcpServer('test-server', '1.0.0', [errorTool]); + const server = createSdkMcpServer({ + name: 'test-server', + version: '1.0.0', + tools: [errorTool], + }); expect(server).toBeDefined(); // Error handling occurs during tool invocation @@ -219,14 +238,18 @@ describe('createSdkMcpServer', () => { throw new Error('Sync error'); }); - const errorTool = tool({ - name: 'sync_error_tool', - description: 'A tool that errors synchronously', - inputSchema: { type: 'object' }, + const errorTool = tool( + 'sync_error_tool', + 'A tool that errors synchronously', + {}, handler, - }); + ); - const server = createSdkMcpServer('test-server', '1.0.0', [errorTool]); + const server = createSdkMcpServer({ + name: 'test-server', + version: '1.0.0', + tools: [errorTool], + }); expect(server).toBeDefined(); }); @@ -234,69 +257,76 @@ describe('createSdkMcpServer', () => { describe('Complex Tool Scenarios', () => { it('should support tool with complex input schema', () => { - const complexTool = tool({ - name: 'complex_tool', - description: 'A tool with complex schema', - inputSchema: { - type: 'object', - properties: { - query: { type: 'string' }, - filters: { - type: 'object', - properties: { - category: { type: 'string' }, - minPrice: { type: 'number' }, - }, - }, - options: { - type: 'array', - items: { type: 'string' }, - }, - }, - required: ['query'], + const complexTool = tool( + 'complex_tool', + 'A tool with complex schema', + { + query: z.string(), + filters: z + .object({ + category: z.string().optional(), + minPrice: z.number().optional(), + }) + .optional(), + options: z.array(z.string()).optional(), }, - handler: async (input: { filters?: unknown[] }) => { + async (input) => { return { - results: [], - filters: input.filters, + content: [ + { + type: 'text', + text: JSON.stringify({ results: [], filters: input.filters }), + }, + ], }; }, - }); + ); - const server = createSdkMcpServer('test-server', '1.0.0', [ - complexTool as ToolDefinition, - ]); + const server = createSdkMcpServer({ + name: 'test-server', + version: '1.0.0', + tools: [complexTool], + }); expect(server).toBeDefined(); }); it('should support tool returning complex output', () => { - const complexOutputTool = tool({ - name: 'complex_output_tool', - description: 'Returns complex data', - inputSchema: { type: 'object' }, - handler: async () => { + const complexOutputTool = tool( + 'complex_output_tool', + 'Returns complex data', + {}, + async () => { return { - data: [ - { id: 1, name: 'Item 1' }, - { id: 2, name: 'Item 2' }, - ], - metadata: { - total: 2, - page: 1, - }, - nested: { - deep: { - value: 'test', + content: [ + { + type: 'text', + text: JSON.stringify({ + data: [ + { id: 1, name: 'Item 1' }, + { id: 2, name: 'Item 2' }, + ], + metadata: { + total: 2, + page: 1, + }, + nested: { + deep: { + value: 'test', + }, + }, + }), }, - }, + ], }; }, - }); + ); - const server = createSdkMcpServer('test-server', '1.0.0', [ - complexOutputTool, - ]); + const server = createSdkMcpServer({ + name: 'test-server', + version: '1.0.0', + tools: [complexOutputTool], + }); expect(server).toBeDefined(); }); @@ -304,44 +334,50 @@ describe('createSdkMcpServer', () => { describe('Multiple Servers', () => { it('should create multiple independent servers', () => { - const tool1 = tool({ - name: 'tool1', - description: 'Tool in server 1', - inputSchema: { type: 'object' }, - handler: async () => 'result1', - }); + const tool1 = tool('tool1', 'Tool in server 1', {}, async () => ({ + content: [{ type: 'text', text: 'result1' }], + })); - const tool2 = tool({ - name: 'tool2', - description: 'Tool in server 2', - inputSchema: { type: 'object' }, - handler: async () => 'result2', - }); + const tool2 = tool('tool2', 'Tool in server 2', {}, async () => ({ + content: [{ type: 'text', text: 'result2' }], + })); - const server1 = createSdkMcpServer('server1', '1.0.0', [tool1]); - const server2 = createSdkMcpServer('server2', '1.0.0', [tool2]); + const server1 = createSdkMcpServer({ + name: 'server1', + version: '1.0.0', + tools: [tool1], + }); + const server2 = createSdkMcpServer({ + name: 'server2', + version: '1.0.0', + tools: [tool2], + }); expect(server1).toBeDefined(); expect(server2).toBeDefined(); + expect(server1.name).toBe('server1'); + expect(server2.name).toBe('server2'); }); it('should allow same tool name in different servers', () => { - const tool1 = tool({ - name: 'shared_name', - description: 'Tool in server 1', - inputSchema: { type: 'object' }, - handler: async () => 'result1', - }); + const tool1 = tool('shared_name', 'Tool in server 1', {}, async () => ({ + content: [{ type: 'text', text: 'result1' }], + })); - const tool2 = tool({ - name: 'shared_name', - description: 'Tool in server 2', - inputSchema: { type: 'object' }, - handler: async () => 'result2', - }); + const tool2 = tool('shared_name', 'Tool in server 2', {}, async () => ({ + content: [{ type: 'text', text: 'result2' }], + })); - const server1 = createSdkMcpServer('server1', '1.0.0', [tool1]); - const server2 = createSdkMcpServer('server2', '1.0.0', [tool2]); + const server1 = createSdkMcpServer({ + name: 'server1', + version: '1.0.0', + tools: [tool1], + }); + const server2 = createSdkMcpServer({ + name: 'server2', + version: '1.0.0', + tools: [tool2], + }); expect(server1).toBeDefined(); expect(server2).toBeDefined();