From e25d68afe9ffaa354c665d6dd072f987e73f7305 Mon Sep 17 00:00:00 2001 From: x22x22 Date: Thu, 30 Oct 2025 12:04:58 +0800 Subject: [PATCH] openspec/lightweight-tasks/task1-2-4.md feat: implement stream-json session handling and control requests --- packages/cli/src/gemini.tsx | 24 ++ packages/cli/src/nonInteractiveCli.test.ts | 133 ++++++++- packages/cli/src/nonInteractiveCli.ts | 296 ++++++++++++++++++--- packages/cli/src/streamJson/controller.ts | 165 ++++++++++++ packages/cli/src/streamJson/input.ts | 132 +++++++++ packages/cli/src/streamJson/session.ts | 214 +++++++++++++++ packages/cli/src/streamJson/writer.test.ts | 15 +- 7 files changed, 928 insertions(+), 51 deletions(-) create mode 100644 packages/cli/src/streamJson/controller.ts create mode 100644 packages/cli/src/streamJson/input.ts create mode 100644 packages/cli/src/streamJson/session.ts diff --git a/packages/cli/src/gemini.tsx b/packages/cli/src/gemini.tsx index 46c1052f..99a9f732 100644 --- a/packages/cli/src/gemini.tsx +++ b/packages/cli/src/gemini.tsx @@ -23,6 +23,7 @@ import { getStartupWarnings } from './utils/startupWarnings.js'; import { getUserStartupWarnings } from './utils/userStartupWarnings.js'; import { ConsolePatcher } from './ui/utils/ConsolePatcher.js'; import { runNonInteractive } from './nonInteractiveCli.js'; +import { runStreamJsonSession } from './streamJson/session.js'; import { ExtensionStorage, loadExtensions } from './config/extension.js'; import { cleanupCheckpoints, @@ -414,6 +415,29 @@ export async function main() { input = `${stdinData}\n\n${input}`; } } + const inputFormat = + typeof config.getInputFormat === 'function' + ? config.getInputFormat() + : 'text'; + + if (inputFormat === 'stream-json') { + const trimmedInput = (input ?? '').trim(); + const nonInteractiveConfig = await validateNonInteractiveAuth( + settings.merged.security?.auth?.selectedType, + settings.merged.security?.auth?.useExternal, + config, + settings, + ); + + await runStreamJsonSession( + nonInteractiveConfig, + settings, + trimmedInput.length > 0 ? trimmedInput : undefined, + ); + await runExitCleanup(); + process.exit(0); + } + if (!input) { console.error( `No input provided via stdin. Input can be provided by piping data into gemini or using the --prompt option.`, diff --git a/packages/cli/src/nonInteractiveCli.test.ts b/packages/cli/src/nonInteractiveCli.test.ts index 1aad835e..d92ef0f8 100644 --- a/packages/cli/src/nonInteractiveCli.test.ts +++ b/packages/cli/src/nonInteractiveCli.test.ts @@ -68,12 +68,13 @@ describe('runNonInteractive', () => { let mockGeminiClient: { sendMessageStream: vi.Mock; getChatRecordingService: vi.Mock; + getChat: vi.Mock; }; + let mockGetDebugResponses: vi.Mock; beforeEach(async () => { mockCoreExecuteToolCall = vi.mocked(executeToolCall); mockShutdownTelemetry = vi.mocked(shutdownTelemetry); - mockCommandServiceCreate.mockResolvedValue({ getCommands: mockGetCommands, }); @@ -91,6 +92,8 @@ describe('runNonInteractive', () => { getFunctionDeclarations: vi.fn().mockReturnValue([]), } as unknown as ToolRegistry; + mockGetDebugResponses = vi.fn(() => []); + mockGeminiClient = { sendMessageStream: vi.fn(), getChatRecordingService: vi.fn(() => ({ @@ -99,14 +102,18 @@ describe('runNonInteractive', () => { recordMessageTokens: vi.fn(), recordToolCalls: vi.fn(), })), + getChat: vi.fn(() => ({ + getDebugResponses: mockGetDebugResponses, + })), }; + let currentModel = 'test-model'; + mockConfig = { initialize: vi.fn().mockResolvedValue(undefined), getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient), getToolRegistry: vi.fn().mockReturnValue(mockToolRegistry), getMaxSessionTurns: vi.fn().mockReturnValue(10), - getSessionId: vi.fn().mockReturnValue('test-session-id'), getProjectRoot: vi.fn().mockReturnValue('/test/project'), storage: { getProjectTempDir: vi.fn().mockReturnValue('/test/project/.gemini/tmp'), @@ -118,6 +125,12 @@ describe('runNonInteractive', () => { getOutputFormat: vi.fn().mockReturnValue('text'), getFolderTrustFeature: vi.fn().mockReturnValue(false), getFolderTrust: vi.fn().mockReturnValue(false), + getIncludePartialMessages: vi.fn().mockReturnValue(false), + getSessionId: vi.fn().mockReturnValue('test-session-id'), + getModel: vi.fn(() => currentModel), + setModel: vi.fn(async (model: string) => { + currentModel = model; + }), } as unknown as Config; mockSettings = { @@ -873,4 +886,120 @@ describe('runNonInteractive', () => { expect(processStdoutSpy).toHaveBeenCalledWith('Acknowledged'); }); + + it('should emit stream-json envelopes when output format is stream-json', async () => { + (mockConfig.getOutputFormat as vi.Mock).mockReturnValue('stream-json'); + (mockConfig.getIncludePartialMessages as vi.Mock).mockReturnValue(false); + + const writes: string[] = []; + processStdoutSpy.mockImplementation((chunk: string | Uint8Array) => { + if (typeof chunk === 'string') { + writes.push(chunk); + } else { + writes.push(Buffer.from(chunk).toString('utf8')); + } + return true; + }); + + const events: ServerGeminiStreamEvent[] = [ + { type: GeminiEventType.Content, value: 'Hello stream' }, + { + type: GeminiEventType.Finished, + value: { reason: undefined, usageMetadata: { totalTokenCount: 4 } }, + }, + ]; + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents(events), + ); + + await runNonInteractive( + mockConfig, + mockSettings, + 'Stream input', + 'prompt-stream', + ); + + const envelopes = writes + .join('') + .split('\n') + .filter((line) => line.trim().length > 0) + .map((line) => JSON.parse(line)); + + expect(envelopes[0]).toMatchObject({ + type: 'user', + message: { content: 'Stream input' }, + }); + const assistantEnvelope = envelopes.find((env) => env.type === 'assistant'); + expect(assistantEnvelope).toBeTruthy(); + expect(assistantEnvelope?.message?.content?.[0]).toMatchObject({ + type: 'text', + text: 'Hello stream', + }); + const resultEnvelope = envelopes.at(-1); + expect(resultEnvelope).toMatchObject({ + type: 'result', + is_error: false, + num_turns: 1, + }); + }); + + it('should include usage metadata and API duration in stream-json result', async () => { + (mockConfig.getOutputFormat as vi.Mock).mockReturnValue('stream-json'); + (mockConfig.getIncludePartialMessages as vi.Mock).mockReturnValue(false); + + const writes: string[] = []; + processStdoutSpy.mockImplementation((chunk: string | Uint8Array) => { + if (typeof chunk === 'string') { + writes.push(chunk); + } else { + writes.push(Buffer.from(chunk).toString('utf8')); + } + return true; + }); + + const usageMetadata = { + promptTokenCount: 11, + candidatesTokenCount: 5, + totalTokenCount: 16, + cachedContentTokenCount: 3, + }; + mockGetDebugResponses.mockReturnValue([{ usageMetadata }]); + + const nowSpy = vi.spyOn(Date, 'now'); + let current = 0; + nowSpy.mockImplementation(() => { + current += 500; + return current; + }); + + mockGeminiClient.sendMessageStream.mockReturnValue( + createStreamFromEvents([ + { type: GeminiEventType.Content, value: 'All done' }, + ]), + ); + + await runNonInteractive( + mockConfig, + mockSettings, + 'usage test', + 'prompt-usage', + ); + + const envelopes = writes + .join('') + .split('\n') + .filter((line) => line.trim().length > 0) + .map((line) => JSON.parse(line)); + const resultEnvelope = envelopes.at(-1); + expect(resultEnvelope?.type).toBe('result'); + expect(resultEnvelope?.duration_api_ms).toBeGreaterThan(0); + expect(resultEnvelope?.usage).toEqual({ + input_tokens: 11, + output_tokens: 5, + total_tokens: 16, + cache_read_input_tokens: 3, + }); + + nowSpy.mockRestore(); + }); }); diff --git a/packages/cli/src/nonInteractiveCli.ts b/packages/cli/src/nonInteractiveCli.ts index 37f02fab..ffd7b9bd 100644 --- a/packages/cli/src/nonInteractiveCli.ts +++ b/packages/cli/src/nonInteractiveCli.ts @@ -18,8 +18,13 @@ import { JsonFormatter, uiTelemetryService, } from '@qwen-code/qwen-code-core'; - -import type { Content, Part } from '@google/genai'; +import type { Content, Part, PartListUnion } from '@google/genai'; +import { StreamJsonWriter } from './streamJson/writer.js'; +import type { + StreamJsonUsage, + StreamJsonUserEnvelope, +} from './streamJson/types.js'; +import type { StreamJsonController } from './streamJson/controller.js'; import { handleSlashCommand } from './nonInteractiveCliCommands.js'; import { ConsolePatcher } from './ui/utils/ConsolePatcher.js'; @@ -31,11 +36,134 @@ import { handleMaxTurnsExceededError, } from './utils/errors.js'; +export interface RunNonInteractiveOptions { + abortController?: AbortController; + streamJson?: { + writer?: StreamJsonWriter; + controller?: StreamJsonController; + }; + userEnvelope?: StreamJsonUserEnvelope; +} + +function normalizePartList(parts: PartListUnion | null): Part[] { + if (!parts) { + return []; + } + + if (typeof parts === 'string') { + return [{ text: parts }]; + } + + if (Array.isArray(parts)) { + return parts.map((part) => + typeof part === 'string' ? { text: part } : (part as Part), + ); + } + + return [parts as Part]; +} + +function extractPartsFromEnvelope( + envelope: StreamJsonUserEnvelope | undefined, +): PartListUnion | null { + if (!envelope) { + return null; + } + + const content = envelope.message?.content; + if (typeof content === 'string') { + return content; + } + + if (Array.isArray(content)) { + const parts: Part[] = []; + for (const block of content) { + if (!block || typeof block !== 'object' || !('type' in block)) { + continue; + } + if (block.type === 'text' && block.text) { + parts.push({ text: block.text }); + } else { + parts.push({ text: JSON.stringify(block) }); + } + } + return parts.length > 0 ? parts : null; + } + + return null; +} + +function extractUsageFromGeminiClient( + geminiClient: unknown, +): StreamJsonUsage | undefined { + if ( + !geminiClient || + typeof geminiClient !== 'object' || + typeof (geminiClient as { getChat?: unknown }).getChat !== 'function' + ) { + return undefined; + } + + try { + const chat = (geminiClient as { getChat: () => unknown }).getChat(); + if ( + !chat || + typeof chat !== 'object' || + typeof (chat as { getDebugResponses?: unknown }).getDebugResponses !== + 'function' + ) { + return undefined; + } + + const responses = ( + chat as { + getDebugResponses: () => Array>; + } + ).getDebugResponses(); + for (let i = responses.length - 1; i >= 0; i--) { + const metadata = responses[i]?.['usageMetadata'] as + | Record + | undefined; + if (metadata) { + const promptTokens = metadata['promptTokenCount']; + const completionTokens = metadata['candidatesTokenCount']; + const totalTokens = metadata['totalTokenCount']; + const cachedTokens = metadata['cachedContentTokenCount']; + + return { + input_tokens: + typeof promptTokens === 'number' ? promptTokens : undefined, + output_tokens: + typeof completionTokens === 'number' ? completionTokens : undefined, + total_tokens: + typeof totalTokens === 'number' ? totalTokens : undefined, + cache_read_input_tokens: + typeof cachedTokens === 'number' ? cachedTokens : undefined, + }; + } + } + } catch (error) { + console.debug('Failed to extract usage metadata:', error); + } + + return undefined; +} + +function calculateApproximateCost( + usage: StreamJsonUsage | undefined, +): number | undefined { + if (!usage) { + return undefined; + } + return 0; +} + export async function runNonInteractive( config: Config, settings: LoadedSettings, input: string, prompt_id: string, + options: RunNonInteractiveOptions = {}, ): Promise { return promptIdContext.run(prompt_id, async () => { const consolePatcher = new ConsolePatcher({ @@ -43,6 +171,17 @@ export async function runNonInteractive( debugMode: config.getDebugMode(), }); + const isStreamJsonOutput = config.getOutputFormat() === 'stream-json'; + const streamJsonContext = options.streamJson; + const streamJsonWriter = isStreamJsonOutput + ? (streamJsonContext?.writer ?? + new StreamJsonWriter(config, config.getIncludePartialMessages())) + : undefined; + + let turnCount = 0; + let totalApiDurationMs = 0; + const startTime = Date.now(); + try { consolePatcher.patch(); // Handle EPIPE errors when the output is piped to a command that closes early. @@ -54,49 +193,63 @@ export async function runNonInteractive( }); const geminiClient = config.getGeminiClient(); + const abortController = options.abortController ?? new AbortController(); + streamJsonContext?.controller?.setActiveRunAbortController?.( + abortController, + ); - const abortController = new AbortController(); + let initialPartList: PartListUnion | null = extractPartsFromEnvelope( + options.userEnvelope, + ); - let query: Part[] | undefined; - - if (isSlashCommand(input)) { - const slashCommandResult = await handleSlashCommand( - input, - abortController, - config, - settings, - ); - // If a slash command is found and returns a prompt, use it. - // Otherwise, slashCommandResult fall through to the default prompt - // handling. - if (slashCommandResult) { - query = slashCommandResult as Part[]; - } - } - - if (!query) { - const { processedQuery, shouldProceed } = await handleAtCommand({ - query: input, - config, - addItem: (_item, _timestamp) => 0, - onDebugMessage: () => {}, - messageId: Date.now(), - signal: abortController.signal, - }); - - if (!shouldProceed || !processedQuery) { - // An error occurred during @include processing (e.g., file not found). - // The error message is already logged by handleAtCommand. - throw new FatalInputError( - 'Exiting due to an error processing the @ command.', + if (!initialPartList) { + let slashHandled = false; + if (isSlashCommand(input)) { + const slashCommandResult = await handleSlashCommand( + input, + abortController, + config, + settings, ); + if (slashCommandResult) { + // A slash command can replace the prompt entirely; fall back to @-command processing otherwise. + initialPartList = slashCommandResult as PartListUnion; + slashHandled = true; + } + } + + if (!slashHandled) { + const { processedQuery, shouldProceed } = await handleAtCommand({ + query: input, + config, + addItem: (_item, _timestamp) => 0, + onDebugMessage: () => {}, + messageId: Date.now(), + signal: abortController.signal, + }); + + if (!shouldProceed || !processedQuery) { + // An error occurred during @include processing (e.g., file not found). + // The error message is already logged by handleAtCommand. + throw new FatalInputError( + 'Exiting due to an error processing the @ command.', + ); + } + initialPartList = processedQuery as PartListUnion; } - query = processedQuery as Part[]; } - let currentMessages: Content[] = [{ role: 'user', parts: query }]; + if (!initialPartList) { + initialPartList = [{ text: input }]; + } + + const initialParts = normalizePartList(initialPartList); + let currentMessages: Content[] = [{ role: 'user', parts: initialParts }]; + + if (streamJsonWriter) { + streamJsonWriter.emitUserMessageFromParts(initialParts); + } - let turnCount = 0; while (true) { turnCount++; if ( @@ -105,31 +258,53 @@ export async function runNonInteractive( ) { handleMaxTurnsExceededError(config); } - const toolCallRequests: ToolCallRequestInfo[] = []; + const toolCallRequests: ToolCallRequestInfo[] = []; + const apiStartTime = Date.now(); const responseStream = geminiClient.sendMessageStream( currentMessages[0]?.parts || [], abortController.signal, prompt_id, ); + const assistantBuilder = streamJsonWriter?.createAssistantBuilder(); let responseText = ''; + for await (const event of responseStream) { if (abortController.signal.aborted) { handleCancellationError(config); } if (event.type === GeminiEventType.Content) { - if (config.getOutputFormat() === OutputFormat.JSON) { + if (streamJsonWriter) { + assistantBuilder?.appendText(event.value); + } else if (config.getOutputFormat() === OutputFormat.JSON) { responseText += event.value; } else { process.stdout.write(event.value); } + } else if (event.type === GeminiEventType.Thought) { + if (streamJsonWriter) { + const subject = event.value.subject?.trim(); + const description = event.value.description?.trim(); + const combined = [subject, description] + .filter((part) => part && part.length > 0) + .join(': '); + if (combined.length > 0) { + assistantBuilder?.appendThinking(combined); + } + } } else if (event.type === GeminiEventType.ToolCallRequest) { toolCallRequests.push(event.value); + if (streamJsonWriter) { + assistantBuilder?.appendToolUse(event.value); + } } } + assistantBuilder?.finalize(); + totalApiDurationMs += Date.now() - apiStartTime; + if (toolCallRequests.length > 0) { const toolResponseParts: Part[] = []; for (const requestInfo of toolCallRequests) { @@ -149,6 +324,18 @@ export async function runNonInteractive( ? toolResponse.resultDisplay : undefined, ); + if (streamJsonWriter) { + const message = + toolResponse.resultDisplay || toolResponse.error.message; + streamJsonWriter.emitSystemMessage('tool_error', { + tool: requestInfo.name, + message, + }); + } + } + + if (streamJsonWriter) { + streamJsonWriter.emitToolResult(requestInfo, toolResponse); } if (toolResponse.responseParts) { @@ -157,19 +344,44 @@ export async function runNonInteractive( } currentMessages = [{ role: 'user', parts: toolResponseParts }]; } else { - if (config.getOutputFormat() === OutputFormat.JSON) { + if (streamJsonWriter) { + const usage = extractUsageFromGeminiClient(geminiClient); + streamJsonWriter.emitResult({ + isError: false, + durationMs: Date.now() - startTime, + apiDurationMs: totalApiDurationMs, + numTurns: turnCount, + usage, + totalCostUsd: calculateApproximateCost(usage), + }); + } else if (config.getOutputFormat() === OutputFormat.JSON) { const formatter = new JsonFormatter(); const stats = uiTelemetryService.getMetrics(); process.stdout.write(formatter.format(responseText, stats)); } else { - process.stdout.write('\n'); // Ensure a final newline + // Preserve the historical newline after a successful non-interactive run. + process.stdout.write('\n'); } return; } } } catch (error) { + if (streamJsonWriter) { + const usage = extractUsageFromGeminiClient(config.getGeminiClient()); + const message = error instanceof Error ? error.message : String(error); + streamJsonWriter.emitResult({ + isError: true, + durationMs: Date.now() - startTime, + apiDurationMs: totalApiDurationMs, + numTurns: turnCount, + errorMessage: message, + usage, + totalCostUsd: calculateApproximateCost(usage), + }); + } handleError(error, config); } finally { + streamJsonContext?.controller?.setActiveRunAbortController?.(null); consolePatcher.cleanup(); if (isTelemetrySdkInitialized()) { await shutdownTelemetry(config); diff --git a/packages/cli/src/streamJson/controller.ts b/packages/cli/src/streamJson/controller.ts new file mode 100644 index 00000000..4214b89b --- /dev/null +++ b/packages/cli/src/streamJson/controller.ts @@ -0,0 +1,165 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { randomUUID } from 'node:crypto'; +import type { StreamJsonWriter } from './writer.js'; +import type { + StreamJsonControlCancelRequestEnvelope, + StreamJsonControlResponseEnvelope, + StreamJsonOutputEnvelope, +} from './types.js'; + +interface PendingControlRequest { + resolve: (envelope: StreamJsonControlResponseEnvelope) => void; + reject: (error: Error) => void; + timeout?: NodeJS.Timeout; +} + +export interface ControlRequestOptions { + timeoutMs?: number; +} + +export class StreamJsonController { + private readonly pendingRequests = new Map(); + private activeAbortController: AbortController | null = null; + + constructor(private readonly writer: StreamJsonWriter) {} + + sendControlRequest( + subtype: string, + payload: Record, + options: ControlRequestOptions = {}, + ): Promise { + const requestId = randomUUID(); + const envelope: StreamJsonOutputEnvelope = { + type: 'control_request', + request_id: requestId, + request: { + subtype, + ...payload, + }, + }; + + const promise = new Promise( + (resolve, reject) => { + const pending: PendingControlRequest = { resolve, reject }; + + if (options.timeoutMs && options.timeoutMs > 0) { + pending.timeout = setTimeout(() => { + this.pendingRequests.delete(requestId); + reject( + new Error(`Timed out waiting for control_response to ${subtype}`), + ); + }, options.timeoutMs); + } + + this.pendingRequests.set(requestId, pending); + }, + ); + + this.writer.writeEnvelope(envelope); + return promise; + } + + handleControlResponse(envelope: StreamJsonControlResponseEnvelope): void { + const pending = this.pendingRequests.get(envelope.request_id); + if (!pending) { + return; + } + + if (pending.timeout) { + clearTimeout(pending.timeout); + } + + this.pendingRequests.delete(envelope.request_id); + pending.resolve(envelope); + } + + handleControlCancel(envelope: StreamJsonControlCancelRequestEnvelope): void { + if (envelope.request_id) { + this.rejectPending( + envelope.request_id, + new Error( + envelope.reason + ? `Control request cancelled: ${envelope.reason}` + : 'Control request cancelled', + ), + ); + return; + } + + for (const requestId of [...this.pendingRequests.keys()]) { + this.rejectPending( + requestId, + new Error( + envelope.reason + ? `Control request cancelled: ${envelope.reason}` + : 'Control request cancelled', + ), + ); + } + } + + setActiveRunAbortController(controller: AbortController | null): void { + this.activeAbortController = controller; + } + + interruptActiveRun(): void { + this.activeAbortController?.abort(); + } + + cancelPendingRequests(reason?: string, requestId?: string): void { + if (requestId) { + if (!this.pendingRequests.has(requestId)) { + return; + } + this.writer.writeEnvelope({ + type: 'control_cancel_request', + request_id: requestId, + reason, + }); + this.rejectPending( + requestId, + new Error( + reason + ? `Control request cancelled: ${reason}` + : 'Control request cancelled', + ), + ); + return; + } + + for (const pendingId of [...this.pendingRequests.keys()]) { + this.writer.writeEnvelope({ + type: 'control_cancel_request', + request_id: pendingId, + reason, + }); + this.rejectPending( + pendingId, + new Error( + reason + ? `Control request cancelled: ${reason}` + : 'Control request cancelled', + ), + ); + } + } + + private rejectPending(requestId: string, error: Error): void { + const pending = this.pendingRequests.get(requestId); + if (!pending) { + return; + } + + if (pending.timeout) { + clearTimeout(pending.timeout); + } + + this.pendingRequests.delete(requestId); + pending.reject(error); + } +} diff --git a/packages/cli/src/streamJson/input.ts b/packages/cli/src/streamJson/input.ts new file mode 100644 index 00000000..0da040f7 --- /dev/null +++ b/packages/cli/src/streamJson/input.ts @@ -0,0 +1,132 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { createInterface } from 'node:readline/promises'; +import process from 'node:process'; +import { + parseStreamJsonEnvelope, + serializeStreamJsonEnvelope, + type StreamJsonControlRequestEnvelope, + type StreamJsonOutputEnvelope, + type StreamJsonUserEnvelope, +} from './types.js'; +import { FatalInputError } from '@qwen-code/qwen-code-core'; + +export interface ParsedStreamJsonInput { + prompt: string; +} + +export async function readStreamJsonInput(): Promise { + const rl = createInterface({ + input: process.stdin, + crlfDelay: Number.POSITIVE_INFINITY, + terminal: false, + }); + + try { + return await parseStreamJsonInputFromIterable(rl); + } finally { + rl.close(); + } +} + +export async function parseStreamJsonInputFromIterable( + lines: AsyncIterable, + emitEnvelope: (envelope: StreamJsonOutputEnvelope) => void = writeEnvelope, +): Promise { + const promptParts: string[] = []; + let receivedUserMessage = false; + + for await (const rawLine of lines) { + const line = rawLine.trim(); + if (!line) { + continue; + } + + const envelope = parseStreamJsonEnvelope(line); + + switch (envelope.type) { + case 'user': + promptParts.push(extractUserMessageText(envelope)); + receivedUserMessage = true; + break; + case 'control_request': + handleControlRequest(envelope, emitEnvelope); + break; + case 'control_response': + case 'control_cancel_request': + // Currently ignored on CLI side. + break; + default: + throw new FatalInputError( + `Unsupported stream-json input type: ${envelope.type}`, + ); + } + } + + if (!receivedUserMessage) { + throw new FatalInputError( + 'No user message provided via stream-json input.', + ); + } + + return { + prompt: promptParts.join('\n').trim(), + }; +} + +function handleControlRequest( + envelope: StreamJsonControlRequestEnvelope, + emitEnvelope: (envelope: StreamJsonOutputEnvelope) => void, +) { + const subtype = envelope.request?.subtype; + if (subtype === 'initialize') { + emitEnvelope({ + type: 'control_response', + request_id: envelope.request_id, + success: true, + response: { + subtype, + capabilities: {}, + }, + }); + return; + } + + emitEnvelope({ + type: 'control_response', + request_id: envelope.request_id, + success: false, + error: `Unsupported control_request subtype: ${subtype ?? 'unknown'}`, + }); +} + +export function extractUserMessageText( + envelope: StreamJsonUserEnvelope, +): string { + const content = envelope.message?.content; + if (typeof content === 'string') { + return content; + } + if (Array.isArray(content)) { + return content + .map((block) => { + if (block && typeof block === 'object' && 'type' in block) { + if (block.type === 'text' && 'text' in block) { + return block.text ?? ''; + } + return JSON.stringify(block); + } + return ''; + }) + .join('\n'); + } + return ''; +} + +function writeEnvelope(envelope: StreamJsonOutputEnvelope): void { + process.stdout.write(`${serializeStreamJsonEnvelope(envelope)}\n`); +} diff --git a/packages/cli/src/streamJson/session.ts b/packages/cli/src/streamJson/session.ts new file mode 100644 index 00000000..de9ac9c6 --- /dev/null +++ b/packages/cli/src/streamJson/session.ts @@ -0,0 +1,214 @@ +/** + * @license + * Copyright 2025 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import readline from 'node:readline'; +import type { Config } from '@qwen-code/qwen-code-core'; +import { + parseStreamJsonEnvelope, + type StreamJsonEnvelope, + type StreamJsonControlRequestEnvelope, + type StreamJsonUserEnvelope, +} from './types.js'; +import { extractUserMessageText } from './input.js'; +import { StreamJsonWriter } from './writer.js'; +import { StreamJsonController } from './controller.js'; +import { runNonInteractive } from '../nonInteractiveCli.js'; +import type { LoadedSettings } from '../config/settings.js'; + +export interface StreamJsonSessionOptions { + input?: NodeJS.ReadableStream; + writer?: StreamJsonWriter; +} + +interface PromptJob { + prompt: string; + envelope?: StreamJsonUserEnvelope; +} + +export async function runStreamJsonSession( + config: Config, + settings: LoadedSettings, + initialPrompt: string | undefined, + options: StreamJsonSessionOptions = {}, +): Promise { + const inputStream = options.input ?? process.stdin; + const writer = + options.writer ?? + new StreamJsonWriter(config, config.getIncludePartialMessages()); + + const controller = new StreamJsonController(writer); + const promptQueue: PromptJob[] = []; + let activeRun: Promise | null = null; + + const processQueue = async (): Promise => { + if (activeRun || promptQueue.length === 0) { + return; + } + + const job = promptQueue.shift(); + if (!job) { + void processQueue(); + return; + } + + const abortController = new AbortController(); + controller.setActiveRunAbortController(abortController); + + const runPromise = handleUserPrompt( + config, + settings, + writer, + controller, + job, + abortController, + ) + .catch((error) => { + console.error('Failed to handle stream-json prompt:', error); + }) + .finally(() => { + controller.setActiveRunAbortController(null); + }); + + activeRun = runPromise; + try { + await runPromise; + } finally { + activeRun = null; + void processQueue(); + } + }; + + const enqueuePrompt = (job: PromptJob): void => { + promptQueue.push(job); + void processQueue(); + }; + + if (initialPrompt && initialPrompt.trim().length > 0) { + enqueuePrompt({ prompt: initialPrompt.trim() }); + } + + const rl = readline.createInterface({ + input: inputStream, + crlfDelay: Number.POSITIVE_INFINITY, + terminal: false, + }); + + try { + for await (const rawLine of rl) { + const line = rawLine.trim(); + if (!line) { + continue; + } + + let envelope: StreamJsonEnvelope; + try { + envelope = parseStreamJsonEnvelope(line); + } catch (error) { + writer.emitResult({ + isError: true, + numTurns: 0, + errorMessage: + error instanceof Error ? error.message : 'Failed to parse JSON', + }); + continue; + } + + switch (envelope.type) { + case 'user': + enqueuePrompt({ + prompt: extractUserMessageText(envelope).trim(), + envelope, + }); + break; + case 'control_request': + await handleControlRequest(config, controller, envelope, writer); + break; + case 'control_response': + controller.handleControlResponse(envelope); + break; + case 'control_cancel_request': + controller.handleControlCancel(envelope); + break; + default: + writer.emitResult({ + isError: true, + numTurns: 0, + errorMessage: `Unsupported stream-json input type: ${envelope.type}`, + }); + } + } + } finally { + rl.close(); + controller.cancelPendingRequests('Session terminated'); + } +} + +async function handleUserPrompt( + config: Config, + settings: LoadedSettings, + writer: StreamJsonWriter, + controller: StreamJsonController, + job: PromptJob, + abortController: AbortController, +): Promise { + const prompt = job.prompt ?? ''; + const messageRecord = + job.envelope && typeof job.envelope.message === 'object' + ? (job.envelope.message as Record) + : undefined; + const envelopePromptId = + messageRecord && typeof messageRecord['prompt_id'] === 'string' + ? String(messageRecord['prompt_id']).trim() + : undefined; + const promptId = envelopePromptId ?? `stream-json-${Date.now()}`; + + await runNonInteractive(config, settings, prompt, promptId, { + abortController, + streamJson: { + writer, + controller, + }, + userEnvelope: job.envelope, + }); +} + +async function handleControlRequest( + config: Config, + controller: StreamJsonController, + envelope: StreamJsonControlRequestEnvelope, + writer: StreamJsonWriter, +): Promise { + const subtype = envelope.request?.subtype; + switch (subtype) { + case 'initialize': + writer.emitSystemMessage('session_initialized', { + session_id: config.getSessionId(), + }); + controller.handleControlResponse({ + type: 'control_response', + request_id: envelope.request_id, + success: true, + response: { subtype: 'initialize' }, + }); + break; + case 'interrupt': + controller.interruptActiveRun(); + controller.handleControlResponse({ + type: 'control_response', + request_id: envelope.request_id, + success: true, + response: { subtype: 'interrupt' }, + }); + break; + default: + controller.handleControlResponse({ + type: 'control_response', + request_id: envelope.request_id, + success: false, + error: `Unsupported control_request subtype: ${subtype ?? 'unknown'}`, + }); + } +} diff --git a/packages/cli/src/streamJson/writer.test.ts b/packages/cli/src/streamJson/writer.test.ts index bc598496..6c9ece11 100644 --- a/packages/cli/src/streamJson/writer.test.ts +++ b/packages/cli/src/streamJson/writer.test.ts @@ -7,6 +7,7 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import type { Config, ToolCallRequestInfo } from '@qwen-code/qwen-code-core'; import { StreamJsonWriter } from './writer.js'; +import type { StreamJsonOutputEnvelope } from './types.js'; function createConfig(): Config { return { @@ -15,12 +16,12 @@ function createConfig(): Config { } as unknown as Config; } -function parseEnvelopes(writes: string[]): unknown[] { +function parseEnvelopes(writes: string[]): StreamJsonOutputEnvelope[] { return writes .join('') .split('\n') .filter((line) => line.trim().length > 0) - .map((line) => JSON.parse(line)); + .map((line) => JSON.parse(line) as StreamJsonOutputEnvelope); } describe('StreamJsonWriter', () => { @@ -62,7 +63,7 @@ describe('StreamJsonWriter', () => { subtype: 'session_summary', }); - const [envelope] = parseEnvelopes(writes) as Array>; + const [envelope] = parseEnvelopes(writes); expect(envelope).toMatchObject({ type: 'result', duration_ms: 1200, @@ -87,7 +88,7 @@ describe('StreamJsonWriter', () => { builder.appendThinking(' more'); builder.finalize(); - const envelopes = parseEnvelopes(writes) as Array>; + const envelopes = parseEnvelopes(writes); expect( envelopes.some( @@ -99,7 +100,7 @@ describe('StreamJsonWriter', () => { ).toBe(true); const assistantEnvelope = envelopes.find((env) => env.type === 'assistant'); - expect(assistantEnvelope?.message?.content?.[0]).toEqual({ + expect(assistantEnvelope?.message.content?.[0]).toEqual({ type: 'thinking', thinking: 'Reflecting more', }); @@ -119,7 +120,7 @@ describe('StreamJsonWriter', () => { builder.appendToolUse(request); builder.finalize(); - const envelopes = parseEnvelopes(writes) as Array>; + const envelopes = parseEnvelopes(writes); expect( envelopes.some( @@ -135,7 +136,7 @@ describe('StreamJsonWriter', () => { const writer = new StreamJsonWriter(createConfig(), false); writer.emitSystemMessage('init', { foo: 'bar' }); - const [envelope] = parseEnvelopes(writes) as Array>; + const [envelope] = parseEnvelopes(writes); expect(envelope).toMatchObject({ type: 'system', subtype: 'init',