diff --git a/packages/cli/src/gemini.test.tsx b/packages/cli/src/gemini.test.tsx index 12b22de2..94414caa 100644 --- a/packages/cli/src/gemini.test.tsx +++ b/packages/cli/src/gemini.test.tsx @@ -230,6 +230,146 @@ describe('gemini.tsx main function', () => { // Avoid the process.exit error from being thrown. processExitSpy.mockRestore(); }); + + it('invokes runStreamJsonSession and performs cleanup in stream-json mode', async () => { + const originalIsTTY = Object.getOwnPropertyDescriptor( + process.stdin, + 'isTTY', + ); + const originalIsRaw = Object.getOwnPropertyDescriptor( + process.stdin, + 'isRaw', + ); + Object.defineProperty(process.stdin, 'isTTY', { + value: true, + configurable: true, + }); + Object.defineProperty(process.stdin, 'isRaw', { + value: false, + configurable: true, + }); + + const processExitSpy = vi + .spyOn(process, 'exit') + .mockImplementation((code) => { + throw new MockProcessExitError(code); + }); + + const { loadCliConfig, parseArguments } = await import( + './config/config.js' + ); + const { loadSettings } = await import('./config/settings.js'); + const cleanupModule = await import('./utils/cleanup.js'); + const extensionModule = await import('./config/extension.js'); + const validatorModule = await import('./validateNonInterActiveAuth.js'); + const sessionModule = await import('./streamJson/session.js'); + const initializerModule = await import('./core/initializer.js'); + const startupWarningsModule = await import('./utils/startupWarnings.js'); + const userStartupWarningsModule = await import( + './utils/userStartupWarnings.js' + ); + + vi.mocked(cleanupModule.cleanupCheckpoints).mockResolvedValue(undefined); + vi.mocked(cleanupModule.registerCleanup).mockImplementation(() => {}); + const runExitCleanupMock = vi.mocked(cleanupModule.runExitCleanup); + runExitCleanupMock.mockResolvedValue(undefined); + vi.spyOn(extensionModule, 'loadExtensions').mockReturnValue([]); + vi.spyOn( + extensionModule.ExtensionStorage, + 'getUserExtensionsDir', + ).mockReturnValue('/tmp/extensions'); + vi.spyOn(initializerModule, 'initializeApp').mockResolvedValue({ + authError: null, + themeError: null, + shouldOpenAuthDialog: false, + geminiMdFileCount: 0, + }); + vi.spyOn(startupWarningsModule, 'getStartupWarnings').mockResolvedValue([]); + vi.spyOn( + userStartupWarningsModule, + 'getUserStartupWarnings', + ).mockResolvedValue([]); + + const validatedConfig = { validated: true } as unknown as Config; + const validateAuthSpy = vi + .spyOn(validatorModule, 'validateNonInteractiveAuth') + .mockResolvedValue(validatedConfig); + const runSessionSpy = vi + .spyOn(sessionModule, 'runStreamJsonSession') + .mockResolvedValue(undefined); + + vi.mocked(loadSettings).mockReturnValue({ + errors: [], + merged: { + advanced: {}, + security: { auth: {} }, + ui: {}, + }, + setValue: vi.fn(), + forScope: () => ({ settings: {}, originalSettings: {}, path: '' }), + } as never); + + vi.mocked(parseArguments).mockResolvedValue({ + extensions: [], + } as never); + + const configStub = { + isInteractive: () => false, + getQuestion: () => ' hello stream ', + getSandbox: () => false, + getDebugMode: () => false, + getListExtensions: () => false, + getMcpServers: () => ({}), + initialize: vi.fn().mockResolvedValue(undefined), + getIdeMode: () => false, + getExperimentalZedIntegration: () => false, + getScreenReader: () => false, + getGeminiMdFileCount: () => 0, + getProjectRoot: () => '/', + getInputFormat: () => 'stream-json', + getContentGeneratorConfig: () => ({ authType: 'test-auth' }), + } as unknown as Config; + + vi.mocked(loadCliConfig).mockResolvedValue(configStub); + + process.env['SANDBOX'] = '1'; + try { + await main(); + } catch (error) { + if (!(error instanceof MockProcessExitError)) { + throw error; + } + } finally { + processExitSpy.mockRestore(); + if (originalIsTTY) { + Object.defineProperty(process.stdin, 'isTTY', originalIsTTY); + } else { + delete (process.stdin as { isTTY?: unknown }).isTTY; + } + if (originalIsRaw) { + Object.defineProperty(process.stdin, 'isRaw', originalIsRaw); + } else { + delete (process.stdin as { isRaw?: unknown }).isRaw; + } + delete process.env['SANDBOX']; + } + + expect(runSessionSpy).toHaveBeenCalledTimes(1); + const [configArg, settingsArg, promptArg] = runSessionSpy.mock.calls[0]; + expect(configArg).toBe(validatedConfig); + expect(settingsArg).toMatchObject({ + merged: expect.objectContaining({ security: expect.any(Object) }), + }); + expect(promptArg).toBe('hello stream'); + + expect(validateAuthSpy).toHaveBeenCalledWith( + undefined, + undefined, + configStub, + expect.any(Object), + ); + expect(runExitCleanupMock).toHaveBeenCalledTimes(1); + }); }); describe('gemini.tsx main function kitty protocol', () => { @@ -410,6 +550,7 @@ describe('startInteractiveUI', () => { vi.mock('./utils/cleanup.js', () => ({ cleanupCheckpoints: vi.fn(() => Promise.resolve()), registerCleanup: vi.fn(), + runExitCleanup: vi.fn(() => Promise.resolve()), })); vi.mock('ink', () => ({ diff --git a/packages/cli/src/streamJson/session.test.ts b/packages/cli/src/streamJson/session.test.ts index a9fa3dd9..a4a18c4d 100644 --- a/packages/cli/src/streamJson/session.test.ts +++ b/packages/cli/src/streamJson/session.test.ts @@ -4,39 +4,238 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { PassThrough } from 'node:stream'; +import { PassThrough, Readable } from 'node:stream'; import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import type { Config } from '@qwen-code/qwen-code-core'; +import type { LoadedSettings } from '../config/settings.js'; import { runStreamJsonSession } from './session.js'; import { StreamJsonController } from './controller.js'; import { StreamJsonWriter } from './writer.js'; -import type { LoadedSettings } from '../config/settings.js'; + +const runNonInteractiveMock = vi.fn(); +const logUserPromptMock = vi.fn(); vi.mock('../nonInteractiveCli.js', () => ({ - runNonInteractive: vi.fn().mockResolvedValue(undefined), + runNonInteractive: (...args: unknown[]) => runNonInteractiveMock(...args), })); -function createConfig(): Config { +vi.mock('@qwen-code/qwen-code-core', async (importOriginal) => { + const actual = + await importOriginal(); return { + ...actual, + logUserPrompt: (...args: unknown[]) => logUserPromptMock(...args), + }; +}); + +interface ConfigOverrides { + getIncludePartialMessages?: () => boolean; + getSessionId?: () => string; + getModel?: () => string; + getContentGeneratorConfig?: () => { authType?: string }; + [key: string]: unknown; +} + +function createConfig(overrides: ConfigOverrides = {}): Config { + const base = { getIncludePartialMessages: () => false, getSessionId: () => 'session-test', getModel: () => 'model-test', - } as unknown as Config; + getContentGeneratorConfig: () => ({ authType: 'test-auth' }), + getOutputFormat: () => 'stream-json', + }; + return { ...base, ...overrides } as unknown as Config; +} + +function createSettings(): LoadedSettings { + return { + merged: { + security: { auth: {} }, + }, + } as unknown as LoadedSettings; +} + +function createWriter() { + return { + emitResult: vi.fn(), + writeEnvelope: vi.fn(), + emitSystemMessage: vi.fn(), + } as unknown as StreamJsonWriter; } describe('runStreamJsonSession', () => { let settings: LoadedSettings; beforeEach(() => { - vi.spyOn(process.stdout, 'write').mockImplementation(() => true); - settings = {} as LoadedSettings; + settings = createSettings(); + runNonInteractiveMock.mockReset(); + logUserPromptMock.mockReset(); }); afterEach(() => { vi.restoreAllMocks(); }); - it('delegates incoming control requests to the controller', async () => { + it('runs initial prompt before reading stream and logs it', async () => { + const config = createConfig(); + const writer = createWriter(); + const stream = Readable.from([]); + runNonInteractiveMock.mockResolvedValueOnce(undefined); + + await runStreamJsonSession(config, settings, 'Hello world', { + input: stream, + writer, + }); + + expect(runNonInteractiveMock).toHaveBeenCalledTimes(1); + const call = runNonInteractiveMock.mock.calls[0]; + expect(call[0]).toBe(config); + expect(call[1]).toBe(settings); + expect(call[2]).toBe('Hello world'); + expect(typeof call[3]).toBe('string'); + expect(call[4]).toEqual( + expect.objectContaining({ + streamJson: expect.objectContaining({ writer }), + abortController: expect.any(AbortController), + }), + ); + expect(logUserPromptMock).toHaveBeenCalledTimes(1); + const loggedPrompt = logUserPromptMock.mock.calls[0][1] as + | Record + | undefined; + expect(loggedPrompt).toMatchObject({ + prompt: 'Hello world', + prompt_length: 11, + }); + expect(loggedPrompt?.['prompt_id']).toBe(call[3]); + }); + + it('handles user envelope when no initial prompt is provided', async () => { + const config = createConfig(); + const writer = createWriter(); + const envelope = { + type: 'user' as const, + message: { + content: ' Stream mode ready ', + }, + }; + const stream = Readable.from([`${JSON.stringify(envelope)}\n`]); + runNonInteractiveMock.mockResolvedValueOnce(undefined); + + await runStreamJsonSession(config, settings, undefined, { + input: stream, + writer, + }); + + expect(runNonInteractiveMock).toHaveBeenCalledTimes(1); + const call = runNonInteractiveMock.mock.calls[0]; + expect(call[2]).toBe('Stream mode ready'); + expect(call[4]).toEqual( + expect.objectContaining({ + userEnvelope: envelope, + streamJson: expect.objectContaining({ writer }), + abortController: expect.any(AbortController), + }), + ); + }); + + it('processes multiple user messages sequentially', async () => { + const config = createConfig(); + const writer = createWriter(); + const lines = [ + JSON.stringify({ + type: 'user', + message: { content: 'first request' }, + }), + JSON.stringify({ + type: 'user', + message: { content: 'second request' }, + }), + ].map((line) => `${line}\n`); + const stream = Readable.from(lines); + runNonInteractiveMock.mockResolvedValue(undefined); + + await runStreamJsonSession(config, settings, undefined, { + input: stream, + writer, + }); + + expect(runNonInteractiveMock).toHaveBeenCalledTimes(2); + expect(runNonInteractiveMock.mock.calls[0][2]).toBe('first request'); + expect(runNonInteractiveMock.mock.calls[1][2]).toBe('second request'); + }); + + it('emits stream_event when partial messages are enabled', async () => { + const config = createConfig({ + getIncludePartialMessages: () => true, + getSessionId: () => 'partial-session', + getModel: () => 'partial-model', + }); + const stream = Readable.from([ + `${JSON.stringify({ + type: 'user', + message: { content: 'show partial' }, + })}\n`, + ]); + const writeSpy = vi + .spyOn(process.stdout, 'write') + .mockImplementation(() => true); + + runNonInteractiveMock.mockImplementationOnce( + async ( + _config, + _settings, + _prompt, + _promptId, + options?: { + streamJson?: { writer?: StreamJsonWriter }; + }, + ) => { + const builder = options?.streamJson?.writer?.createAssistantBuilder(); + builder?.appendText('partial'); + builder?.finalize(); + }, + ); + + await runStreamJsonSession(config, settings, undefined, { + input: stream, + }); + + const outputs = writeSpy.mock.calls + .map(([chunk]) => chunk as string) + .join('') + .split('\n') + .map((line) => line.trim()) + .filter((line) => line.length > 0) + .map((line) => JSON.parse(line)); + + expect(outputs.some((envelope) => envelope.type === 'stream_event')).toBe( + true, + ); + writeSpy.mockRestore(); + }); + + it('emits error result when JSON parsing fails', async () => { + const config = createConfig(); + const writer = createWriter(); + const stream = Readable.from(['{invalid json\n']); + + await runStreamJsonSession(config, settings, undefined, { + input: stream, + writer, + }); + + expect(writer.emitResult).toHaveBeenCalledWith( + expect.objectContaining({ + isError: true, + }), + ); + expect(runNonInteractiveMock).not.toHaveBeenCalled(); + }); + + it('delegates control requests to the controller', async () => { + const config = createConfig(); + const writer = new StreamJsonWriter(config, false); const controllerPrototype = StreamJsonController.prototype as unknown as { handleIncomingControlRequest: (...args: unknown[]) => unknown; }; @@ -46,8 +245,6 @@ describe('runStreamJsonSession', () => { ); const inputStream = new PassThrough(); - const config = createConfig(); - const controlRequest = { type: 'control_request', request_id: 'req-1', @@ -58,7 +255,7 @@ describe('runStreamJsonSession', () => { await runStreamJsonSession(config, settings, undefined, { input: inputStream, - writer: new StreamJsonWriter(config, false), + writer, }); expect(handleSpy).toHaveBeenCalledTimes(1); diff --git a/packages/cli/src/streamJson/session.ts b/packages/cli/src/streamJson/session.ts index 187b4bde..a6f7e35a 100644 --- a/packages/cli/src/streamJson/session.ts +++ b/packages/cli/src/streamJson/session.ts @@ -6,6 +6,7 @@ import readline from 'node:readline'; import type { Config } from '@qwen-code/qwen-code-core'; +import { logUserPrompt } from '@qwen-code/qwen-code-core'; import { parseStreamJsonEnvelope, type StreamJsonEnvelope, @@ -140,6 +141,13 @@ export async function runStreamJsonSession( } } } finally { + while (activeRun) { + try { + await activeRun; + } catch { + // 忽略已记录的运行错误。 + } + } rl.close(); controller.cancelPendingRequests('Session terminated'); } @@ -164,6 +172,32 @@ async function handleUserPrompt( : undefined; const promptId = envelopePromptId ?? `stream-json-${Date.now()}`; + if (prompt.length > 0) { + const authType = + typeof ( + config as { + getContentGeneratorConfig?: () => { authType?: string }; + } + ).getContentGeneratorConfig === 'function' + ? ( + ( + config as { + getContentGeneratorConfig: () => { authType?: string }; + } + ).getContentGeneratorConfig() ?? {} + ).authType + : undefined; + + logUserPrompt(config, { + 'event.name': 'user_prompt', + 'event.timestamp': new Date().toISOString(), + prompt, + prompt_id: promptId, + auth_type: authType, + prompt_length: prompt.length, + }); + } + await runNonInteractive(config, settings, prompt, promptId, { abortController, streamJson: {