diff --git a/package-lock.json b/package-lock.json index 6b90c240..9cc924b3 100644 --- a/package-lock.json +++ b/package-lock.json @@ -11609,7 +11609,7 @@ }, "packages/a2a-server": { "name": "@google/gemini-cli-a2a-server", - "version": "0.3.0", + "version": "0.3.4", "dependencies": { "@a2a-js/sdk": "^0.3.2", "@google-cloud/storage": "^7.16.0", diff --git a/packages/a2a-server/package.json b/packages/a2a-server/package.json index 71709278..ca5ef370 100644 --- a/packages/a2a-server/package.json +++ b/packages/a2a-server/package.json @@ -1,6 +1,6 @@ { "name": "@google/gemini-cli-a2a-server", - "version": "0.3.0", + "version": "0.3.4", "private": true, "description": "Gemini CLI A2A Server", "repository": { diff --git a/packages/cli/src/config/config.test.ts b/packages/cli/src/config/config.test.ts index c8874579..a4296943 100644 --- a/packages/cli/src/config/config.test.ts +++ b/packages/cli/src/config/config.test.ts @@ -83,7 +83,7 @@ vi.mock('@qwen-code/qwen-code-core', async () => { return { ...actualServer, IdeClient: { - getInstance: vi.fn().mockReturnValue({ + getInstance: vi.fn().mockResolvedValue({ getConnectionStatus: vi.fn(), initialize: vi.fn(), shutdown: vi.fn(), diff --git a/packages/cli/src/ui/hooks/useGeminiStream.ts b/packages/cli/src/ui/hooks/useGeminiStream.ts index 2d9c036e..967a8682 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.ts +++ b/packages/cli/src/ui/hooks/useGeminiStream.ts @@ -638,6 +638,9 @@ export const useGeminiStream = ( // before we add loop detected message to history loopDetectedRef.current = true; break; + case ServerGeminiEventType.Retry: + // Will add the missing logic later + break; default: { // enforces exhaustive switch-case const unreachable: never = event; diff --git a/packages/cli/src/zed-integration/zedIntegration.ts b/packages/cli/src/zed-integration/zedIntegration.ts index e3c9714a..1668b53d 100644 --- a/packages/cli/src/zed-integration/zedIntegration.ts +++ b/packages/cli/src/zed-integration/zedIntegration.ts 
@@ -24,6 +24,7 @@ import { isWithinRoot, logToolCall, MCPServerConfig, + StreamEventType, ToolConfirmationOutcome, } from '@qwen-code/qwen-code-core'; import * as fs from 'node:fs/promises'; @@ -269,8 +270,12 @@ class Session { return { stopReason: 'cancelled' }; } - if (resp.candidates && resp.candidates.length > 0) { - const candidate = resp.candidates[0]; + if ( + resp.type === StreamEventType.CHUNK && + resp.value.candidates && + resp.value.candidates.length > 0 + ) { + const candidate = resp.value.candidates[0]; for (const part of candidate.content?.parts ?? []) { if (!part.text) { continue; @@ -290,8 +295,8 @@ class Session { } } - if (resp.functionCalls) { - functionCalls.push(...resp.functionCalls); + if (resp.type === StreamEventType.CHUNK && resp.value.functionCalls) { + functionCalls.push(...resp.value.functionCalls); } } } catch (error) { diff --git a/packages/core/src/config/config.test.ts b/packages/core/src/config/config.test.ts index d72e3a6b..8d18b89a 100644 --- a/packages/core/src/config/config.test.ts +++ b/packages/core/src/config/config.test.ts @@ -111,6 +111,16 @@ vi.mock('../services/gitService.js', () => { return { GitService: GitServiceMock }; }); +vi.mock('../ide/ide-client.js', () => ({ + IdeClient: { + getInstance: vi.fn().mockResolvedValue({ + getConnectionStatus: vi.fn(), + initialize: vi.fn(), + shutdown: vi.fn(), + }), + }, +})); + describe('Server Config (config.ts)', () => { const MODEL = 'gemini-pro'; const SANDBOX: SandboxConfig = { diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index deacea37..a69292fb 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -12,7 +12,12 @@ import type { Part, GenerateContentResponse, } from '@google/genai'; -import { GeminiChat, EmptyStreamError } from './geminiChat.js'; +import { + GeminiChat, + EmptyStreamError, + StreamEventType, + type StreamEvent, +} from './geminiChat.js'; import type { 
Config } from '../config/config.js'; import { setSimulate429 } from '../utils/testUtils.js'; @@ -73,6 +78,92 @@ describe('GeminiChat', () => { }); describe('sendMessage', () => { + it('should retain the initial user message when an automatic function call occurs', async () => { + // 1. Define the user's initial text message. This is the turn that gets dropped by the buggy logic. + const userInitialMessage: Content = { + role: 'user', + parts: [{ text: 'How is the weather in Boston?' }], + }; + + // 2. Mock the full API response, including the automaticFunctionCallingHistory. + // This history represents the full turn: user asks, model calls tool, tool responds, model answers. + const mockAfcResponse = { + candidates: [ + { + content: { + role: 'model', + parts: [ + { text: 'The weather in Boston is 72 degrees and sunny.' }, + ], + }, + }, + ], + automaticFunctionCallingHistory: [ + userInitialMessage, // The user's turn + { + // The model's first response: a tool call + role: 'model', + parts: [ + { + functionCall: { + name: 'get_weather', + args: { location: 'Boston' }, + }, + }, + ], + }, + { + // The tool's response, which has a 'user' role + role: 'user', + parts: [ + { + functionResponse: { + name: 'get_weather', + response: { temperature: 72, condition: 'sunny' }, + }, + }, + ], + }, + ], + } as unknown as GenerateContentResponse; + + vi.mocked(mockModelsModule.generateContent).mockResolvedValue( + mockAfcResponse, + ); + + // 3. Action: Send the initial message. + await chat.sendMessage( + { message: 'How is the weather in Boston?' }, + 'prompt-id-afc-bug', + ); + + // 4. Assert: Check the final state of the history. + const history = chat.getHistory(); + + // With the bug, history.length will be 3, because the first user message is dropped. + // The correct behavior is for the history to contain all 4 turns. + expect(history.length).toBe(4); + + // Crucially, assert that the very first turn in the history matches the user's initial message. 
+ // This is the assertion that will fail. + const firstTurn = history[0]!; + expect(firstTurn.role).toBe('user'); + expect(firstTurn?.parts![0]!.text).toBe('How is the weather in Boston?'); + + // Verify the rest of the history is also correct. + const secondTurn = history[1]!; + expect(secondTurn.role).toBe('model'); + expect(secondTurn?.parts![0]!.functionCall).toBeDefined(); + + const thirdTurn = history[2]!; + expect(thirdTurn.role).toBe('user'); + expect(thirdTurn?.parts![0]!.functionResponse).toBeDefined(); + + const fourthTurn = history[3]!; + expect(fourthTurn.role).toBe('model'); + expect(fourthTurn?.parts![0]!.text).toContain('72 degrees and sunny'); + }); + it('should throw an error when attempting to add a user turn after another user turn', async () => { // 1. Setup: Create a history that already ends with a user turn (a functionResponse). const initialHistory: Content[] = [ @@ -240,6 +331,153 @@ describe('GeminiChat', () => { }); describe('sendMessageStream', () => { + it('should succeed if a tool call is followed by an empty part', async () => { + // 1. Mock a stream that contains a tool call, then an invalid (empty) part. + const streamWithToolCall = (async function* () { + yield { + candidates: [ + { + content: { + role: 'model', + parts: [{ functionCall: { name: 'test_tool', args: {} } }], + }, + }, + ], + } as unknown as GenerateContentResponse; + // This second chunk is invalid according to isValidResponse + yield { + candidates: [ + { + content: { + role: 'model', + parts: [{ text: '' }], + }, + }, + ], + } as unknown as GenerateContentResponse; + })(); + + vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue( + streamWithToolCall, + ); + + // 2. Action & Assert: The stream processing should complete without throwing an error + // because the presence of a tool call makes the empty final chunk acceptable. 
+ const stream = await chat.sendMessageStream( + { message: 'test message' }, + 'prompt-id-tool-call-empty-end', + ); + await expect( + (async () => { + for await (const _ of stream) { + /* consume stream */ + } + })(), + ).resolves.not.toThrow(); + + // 3. Verify history was recorded correctly + const history = chat.getHistory(); + expect(history.length).toBe(2); // user turn + model turn + const modelTurn = history[1]!; + expect(modelTurn?.parts?.length).toBe(1); // The empty part is discarded + expect(modelTurn?.parts![0]!.functionCall).toBeDefined(); + }); + + it('should fail if the stream ends with an empty part and has no finishReason', async () => { + // 1. Mock a stream that ends with an invalid part and has no finish reason. + const streamWithNoFinish = (async function* () { + yield { + candidates: [ + { + content: { + role: 'model', + parts: [{ text: 'Initial content...' }], + }, + }, + ], + } as unknown as GenerateContentResponse; + // This second chunk is invalid and has no finishReason, so it should fail. + yield { + candidates: [ + { + content: { + role: 'model', + parts: [{ text: '' }], + }, + }, + ], + } as unknown as GenerateContentResponse; + })(); + + vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue( + streamWithNoFinish, + ); + + // 2. Action & Assert: The stream should fail because there's no finish reason. + const stream = await chat.sendMessageStream( + { message: 'test message' }, + 'prompt-id-no-finish-empty-end', + ); + await expect( + (async () => { + for await (const _ of stream) { + /* consume stream */ + } + })(), + ).rejects.toThrow(EmptyStreamError); + }); + + it('should succeed if the stream ends with an invalid part but has a finishReason and contained a valid part', async () => { + // 1. Mock a stream that sends a valid chunk, then an invalid one, but has a finish reason. 
+ const streamWithInvalidEnd = (async function* () { + yield { + candidates: [ + { + content: { + role: 'model', + parts: [{ text: 'Initial valid content...' }], + }, + }, + ], + } as unknown as GenerateContentResponse; + // This second chunk is invalid, but the response has a finishReason. + yield { + candidates: [ + { + content: { + role: 'model', + parts: [{ text: '' }], // Invalid part + }, + finishReason: 'STOP', + }, + ], + } as unknown as GenerateContentResponse; + })(); + + vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue( + streamWithInvalidEnd, + ); + + // 2. Action & Assert: The stream should complete without throwing an error. + const stream = await chat.sendMessageStream( + { message: 'test message' }, + 'prompt-id-valid-then-invalid-end', + ); + await expect( + (async () => { + for await (const _ of stream) { + /* consume stream */ + } + })(), + ).resolves.not.toThrow(); + + // 3. Verify history was recorded correctly with only the valid part. + const history = chat.getHistory(); + expect(history.length).toBe(2); // user turn + model turn + const modelTurn = history[1]!; + expect(modelTurn?.parts?.length).toBe(1); + expect(modelTurn?.parts![0]!.text).toBe('Initial valid content...'); + }); it('should not consolidate text into a part that also contains a functionCall', async () => { // 1. Mock the API to stream a malformed part followed by a valid text part. const multiChunkStream = (async function* () { @@ -314,7 +552,10 @@ describe('GeminiChat', () => { // as the important part is consolidating what comes after. yield { candidates: [ - { content: { role: 'model', parts: [{ text: ' World!' }] } }, + { + content: { role: 'model', parts: [{ text: ' World!' }] }, + finishReason: 'STOP', + }, ], } as unknown as GenerateContentResponse; })(); @@ -417,6 +658,7 @@ describe('GeminiChat', () => { { text: 'This is the visible text that should not be lost.' 
}, ], }, + finishReason: 'STOP', }, ], } as unknown as GenerateContentResponse; @@ -477,7 +719,10 @@ describe('GeminiChat', () => { const emptyStreamResponse = (async function* () { yield { candidates: [ - { content: { role: 'model', parts: [{ thought: true }] } }, + { + content: { role: 'model', parts: [{ thought: true }] }, + finishReason: 'STOP', + }, ], } as unknown as GenerateContentResponse; })(); @@ -732,6 +977,47 @@ describe('GeminiChat', () => { }); describe('sendMessageStream with retries', () => { + it('should yield a RETRY event when an invalid stream is encountered', async () => { + // ARRANGE: Mock the stream to fail once, then succeed. + vi.mocked(mockModelsModule.generateContentStream) + .mockImplementationOnce(async () => + // First attempt: An invalid stream with an empty text part. + (async function* () { + yield { + candidates: [{ content: { parts: [{ text: '' }] } }], + } as unknown as GenerateContentResponse; + })(), + ) + .mockImplementationOnce(async () => + // Second attempt (the retry): A minimal valid stream. + (async function* () { + yield { + candidates: [ + { + content: { parts: [{ text: 'Success' }] }, + finishReason: 'STOP', + }, + ], + } as unknown as GenerateContentResponse; + })(), + ); + + // ACT: Send a message and collect all events from the stream. + const stream = await chat.sendMessageStream( + { message: 'test' }, + 'prompt-id-yield-retry', + ); + const events: StreamEvent[] = []; + for await (const event of stream) { + events.push(event); + } + + // ASSERT: Check that a RETRY event was present in the stream's output. + const retryEvent = events.find((e) => e.type === StreamEventType.RETRY); + + expect(retryEvent).toBeDefined(); + expect(retryEvent?.type).toBe(StreamEventType.RETRY); + }); it('should retry on invalid content, succeed, and report metrics', async () => { // Use mockImplementationOnce to provide a fresh, promise-wrapped generator for each attempt. 
vi.mocked(mockModelsModule.generateContentStream) @@ -748,7 +1034,10 @@ describe('GeminiChat', () => { (async function* () { yield { candidates: [ - { content: { parts: [{ text: 'Successful response' }] } }, + { + content: { parts: [{ text: 'Successful response' }] }, + finishReason: 'STOP', + }, ], } as unknown as GenerateContentResponse; })(), @@ -758,7 +1047,7 @@ describe('GeminiChat', () => { { message: 'test' }, 'prompt-id-retry-success', ); - const chunks = []; + const chunks: StreamEvent[] = []; for await (const chunk of stream) { chunks.push(chunk); } @@ -768,11 +1057,17 @@ describe('GeminiChat', () => { expect(mockLogContentRetry).toHaveBeenCalledTimes(1); expect(mockLogContentRetryFailure).not.toHaveBeenCalled(); expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2); + + // Check for a retry event + expect(chunks.some((c) => c.type === StreamEventType.RETRY)).toBe(true); + + // Check for the successful content chunk expect( chunks.some( (c) => - c.candidates?.[0]?.content?.parts?.[0]?.text === - 'Successful response', + c.type === StreamEventType.CHUNK && + c.value.candidates?.[0]?.content?.parts?.[0]?.text === + 'Successful response', ), ).toBe(true); @@ -853,7 +1148,12 @@ describe('GeminiChat', () => { // Second attempt succeeds (async function* () { yield { - candidates: [{ content: { parts: [{ text: 'Second answer' }] } }], + candidates: [ + { + content: { parts: [{ text: 'Second answer' }] }, + finishReason: 'STOP', + }, + ], } as unknown as GenerateContentResponse; })(), ); @@ -1002,6 +1302,7 @@ describe('GeminiChat', () => { content: { parts: [{ text: 'Successful response after empty' }], }, + finishReason: 'STOP', }, ], } as unknown as GenerateContentResponse; @@ -1013,7 +1314,7 @@ describe('GeminiChat', () => { { message: 'test empty stream' }, 'prompt-id-empty-stream', ); - const chunks = []; + const chunks: StreamEvent[] = []; for await (const chunk of stream) { chunks.push(chunk); } @@ -1023,8 +1324,9 @@ 
describe('GeminiChat', () => { expect( chunks.some( (c) => - c.candidates?.[0]?.content?.parts?.[0]?.text === - 'Successful response after empty', + c.type === StreamEventType.CHUNK && + c.value.candidates?.[0]?.content?.parts?.[0]?.text === + 'Successful response after empty', ), ).toBe(true); @@ -1062,13 +1364,23 @@ describe('GeminiChat', () => { } as unknown as GenerateContentResponse; await firstStreamContinuePromise; // Pause the stream yield { - candidates: [{ content: { parts: [{ text: ' part 2' }] } }], + candidates: [ + { + content: { parts: [{ text: ' part 2' }] }, + finishReason: 'STOP', + }, + ], } as unknown as GenerateContentResponse; })(); const secondStreamGenerator = (async function* () { yield { - candidates: [{ content: { parts: [{ text: 'second response' }] } }], + candidates: [ + { + content: { parts: [{ text: 'second response' }] }, + finishReason: 'STOP', + }, + ], } as unknown as GenerateContentResponse; })(); @@ -1123,4 +1435,68 @@ describe('GeminiChat', () => { } expect(turn4.parts[0].text).toBe('second response'); }); + + it('should discard valid partial content from a failed attempt upon retry', async () => { + // ARRANGE: Mock the stream to fail on the first attempt after yielding some valid content. 
+ vi.mocked(mockModelsModule.generateContentStream) + .mockImplementationOnce(async () => + // First attempt: yields one valid chunk, then one invalid chunk + (async function* () { + yield { + candidates: [ + { + content: { + parts: [{ text: 'This valid part should be discarded' }], + }, + }, + ], + } as unknown as GenerateContentResponse; + yield { + candidates: [{ content: { parts: [{ text: '' }] } }], // Invalid chunk triggers retry + } as unknown as GenerateContentResponse; + })(), + ) + .mockImplementationOnce(async () => + // Second attempt (the retry): succeeds + (async function* () { + yield { + candidates: [ + { + content: { + parts: [{ text: 'Successful final response' }], + }, + finishReason: 'STOP', + }, + ], + } as unknown as GenerateContentResponse; + })(), + ); + + // ACT: Send a message and consume the stream + const stream = await chat.sendMessageStream( + { message: 'test' }, + 'prompt-id-discard-test', + ); + const events: StreamEvent[] = []; + for await (const event of stream) { + events.push(event); + } + + // ASSERT + // Check that a retry happened + expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2); + expect(events.some((e) => e.type === StreamEventType.RETRY)).toBe(true); + + // Check the final recorded history + const history = chat.getHistory(); + expect(history.length).toBe(2); // user turn + final model turn + + const modelTurn = history[1]!; + // The model turn should only contain the text from the successful attempt + expect(modelTurn!.parts![0]!.text).toBe('Successful final response'); + // It should NOT contain any text from the failed attempt + expect(modelTurn!.parts![0]!.text).not.toContain( + 'This valid part should be discarded', + ); + }); }); diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index 76386aa1..e5df53b9 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -34,6 +34,18 @@ import { InvalidChunkEvent, } from 
'../telemetry/types.js'; +export enum StreamEventType { + /** A regular content chunk from the API. */ + CHUNK = 'chunk', + /** A signal that a retry is about to happen. The UI should discard any partial + * content from the attempt that just failed. */ + RETRY = 'retry', +} + +export type StreamEvent = + | { type: StreamEventType.CHUNK; value: GenerateContentResponse } + | { type: StreamEventType.RETRY }; + /** * Options for retrying due to invalid content from the model. */ @@ -352,7 +364,7 @@ export class GeminiChat { async sendMessageStream( params: SendMessageParameters, prompt_id: string, - ): Promise<AsyncGenerator<GenerateContentResponse>> { + ): Promise<AsyncGenerator<StreamEvent>> { await this.sendPromise; let streamDoneResolver: () => void; @@ -379,6 +391,10 @@ export class GeminiChat { attempt++ ) { try { + if (attempt > 0) { + yield { type: StreamEventType.RETRY }; + } + const stream = await self.makeApiCallAndProcessStream( requestContents, params, @@ -387,7 +403,7 @@ export class GeminiChat { ); for await (const chunk of stream) { - yield chunk; + yield { type: StreamEventType.CHUNK, value: chunk }; } lastError = null; @@ -574,29 +590,57 @@ export class GeminiChat { userInput: Content, ): AsyncGenerator<GenerateContentResponse> { const modelResponseParts: Part[] = []; - let isStreamInvalid = false; let hasReceivedAnyChunk = false; + let hasReceivedValidChunk = false; + let hasToolCall = false; + let lastChunk: GenerateContentResponse | null = null; + let lastChunkIsInvalid = false; for await (const chunk of streamResponse) { hasReceivedAnyChunk = true; + lastChunk = chunk; + if (isValidResponse(chunk)) { + hasReceivedValidChunk = true; + lastChunkIsInvalid = false; const content = chunk.candidates?.[0]?.content; if (content?.parts) { modelResponseParts.push(...content.parts); + if (content.parts.some((part) => part.functionCall)) { + hasToolCall = true; + } } } else { logInvalidChunk( this.config, new InvalidChunkEvent('Invalid chunk received from stream.'), ); - isStreamInvalid = true; + lastChunkIsInvalid = true; } yield chunk; } - if 
(isStreamInvalid || !hasReceivedAnyChunk) { + if (!hasReceivedAnyChunk) { + throw new EmptyStreamError('Model stream completed without any chunks.'); + } + + const hasFinishReason = lastChunk?.candidates?.some( + (candidate) => candidate.finishReason, + ); + + // Stream validation logic: A stream is considered successful if: + // 1. There's a tool call (tool calls can end without explicit finish reasons), OR + // 2. The last chunk has a finish reason AND at least one valid chunk was received (a trailing invalid chunk is then tolerated) + // + // We throw an error only when there's no tool call AND: + // - No finish reason on the last chunk, OR + // - Last chunk is invalid and no valid chunk was received at all + if ( + !hasToolCall && + (!hasFinishReason || (lastChunkIsInvalid && !hasReceivedValidChunk)) + ) { throw new EmptyStreamError( - 'Model stream was invalid or completed without valid content.', + 'Model stream ended with an invalid chunk or missing finish reason.', ); } diff --git a/packages/core/src/core/subagent.test.ts b/packages/core/src/core/subagent.test.ts index f2073306..cc54037b 100644 --- a/packages/core/src/core/subagent.test.ts +++ b/packages/core/src/core/subagent.test.ts @@ -21,7 +21,7 @@ import type { } from './subagent.js'; import { Config } from '../config/config.js'; import type { ConfigParameters } from '../config/config.js'; -import { GeminiChat } from './geminiChat.js'; +import { GeminiChat, StreamEventType } from './geminiChat.js'; import { createContentGenerator } from './contentGenerator.js'; import { getEnvironmentContext } from '../utils/environmentContext.js'; import { executeToolCall } from './nonInteractiveToolExecutor.js'; @@ -33,6 +33,7 @@ import type { FunctionCall, FunctionDeclaration, GenerateContentConfig, + GenerateContentResponse, } from '@google/genai'; import { ToolErrorType } from '../tools/tool-error.js'; @@ -73,18 +74,33 @@ const createMockStream = ( functionCallsList: Array<FunctionCall[] | 'stop'>, ) => { let index = 0; - return vi.fn().mockImplementation(() => { + // This mock now returns a 
Promise that resolves to the async generator, + // matching the new signature for sendMessageStream. + return vi.fn().mockImplementation(async () => { const response = functionCallsList[index] || 'stop'; index++; + return (async function* () { - if (response === 'stop') { - // When stopping, the model might return text, but the subagent logic primarily cares about the absence of functionCalls. - yield { text: 'Done.' }; - } else if (response.length > 0) { - yield { functionCalls: response }; + let mockResponseValue: Partial<GenerateContentResponse>; + + if (response === 'stop' || response.length === 0) { + // Simulate a text response for stop/empty conditions. + mockResponseValue = { + candidates: [{ content: { parts: [{ text: 'Done.' }] } }], + }; } else { - yield { text: 'Done.' }; // Handle empty array also as stop + // Simulate a tool call response. + mockResponseValue = { + candidates: [], // Good practice to include for safety. + functionCalls: response, + }; } + + // The stream must now yield a StreamEvent object of type CHUNK. + yield { + type: StreamEventType.CHUNK, + value: mockResponseValue as GenerateContentResponse, + }; })(); }); }; diff --git a/packages/core/src/core/subagent.ts b/packages/core/src/core/subagent.ts index 75f77c57..41de5978 100644 --- a/packages/core/src/core/subagent.ts +++ b/packages/core/src/core/subagent.ts @@ -20,7 +20,7 @@ import type { FunctionDeclaration, } from '@google/genai'; import { Type } from '@google/genai'; -import { GeminiChat } from './geminiChat.js'; +import { GeminiChat, StreamEventType } from './geminiChat.js'; /** * @fileoverview Defines the configuration interfaces for a subagent. 
@@ -439,12 +439,11 @@ export class SubAgentScope { let textResponse = ''; for await (const resp of responseStream) { if (abortController.signal.aborted) return; - if (resp.functionCalls) { - functionCalls.push(...resp.functionCalls); + if (resp.type === StreamEventType.CHUNK && resp.value.functionCalls) { + functionCalls.push(...resp.value.functionCalls); } - const text = resp.text; - if (text) { - textResponse += text; + if (resp.type === StreamEventType.CHUNK && resp.value.text) { + textResponse += resp.value.text; } } diff --git a/packages/core/src/core/turn.test.ts b/packages/core/src/core/turn.test.ts index c749baba..87be432b 100644 --- a/packages/core/src/core/turn.test.ts +++ b/packages/core/src/core/turn.test.ts @@ -13,6 +13,7 @@ import { Turn, GeminiEventType } from './turn.js'; import type { GenerateContentResponse, Part, Content } from '@google/genai'; import { reportError } from '../utils/errorReporting.js'; import type { GeminiChat } from './geminiChat.js'; +import { StreamEventType } from './geminiChat.js'; const mockSendMessageStream = vi.fn(); const mockGetHistory = vi.fn(); @@ -35,6 +36,7 @@ vi.mock('../utils/errorReporting', () => ({ reportError: vi.fn(), })); +// Use the actual implementation from partUtils now that it's provided. 
vi.mock('../utils/generateContentResponseUtilities', () => ({ getResponseText: (resp: GenerateContentResponse) => resp.candidates?.[0]?.content?.parts?.map((part) => part.text).join('') || @@ -78,11 +80,17 @@ describe('Turn', () => { it('should yield content events for text parts', async () => { const mockResponseStream = (async function* () { yield { - candidates: [{ content: { parts: [{ text: 'Hello' }] } }], - } as unknown as GenerateContentResponse; + type: StreamEventType.CHUNK, + value: { + candidates: [{ content: { parts: [{ text: 'Hello' }] } }], + } as GenerateContentResponse, + }; yield { - candidates: [{ content: { parts: [{ text: ' world' }] } }], - } as unknown as GenerateContentResponse; + type: StreamEventType.CHUNK, + value: { + candidates: [{ content: { parts: [{ text: ' world' }] } }], + } as GenerateContentResponse, + }; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); @@ -113,16 +121,23 @@ describe('Turn', () => { it('should yield tool_call_request events for function calls', async () => { const mockResponseStream = (async function* () { yield { - functionCalls: [ - { - id: 'fc1', - name: 'tool1', - args: { arg1: 'val1' }, - isClientInitiated: false, - }, - { name: 'tool2', args: { arg2: 'val2' }, isClientInitiated: false }, // No ID - ], - } as unknown as GenerateContentResponse; + type: StreamEventType.CHUNK, + value: { + functionCalls: [ + { + id: 'fc1', + name: 'tool1', + args: { arg1: 'val1' }, + isClientInitiated: false, + }, + { + name: 'tool2', + args: { arg2: 'val2' }, + isClientInitiated: false, + }, // No ID + ], + } as unknown as GenerateContentResponse, + }; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); @@ -168,18 +183,24 @@ describe('Turn', () => { const abortController = new AbortController(); const mockResponseStream = (async function* () { yield { - candidates: [{ content: { parts: [{ text: 'First part' }] } }], - } as unknown as GenerateContentResponse; + type: StreamEventType.CHUNK, + 
value: { + candidates: [{ content: { parts: [{ text: 'First part' }] } }], + } as GenerateContentResponse, + }; abortController.abort(); yield { - candidates: [ - { - content: { - parts: [{ text: 'Second part - should not be processed' }], + type: StreamEventType.CHUNK, + value: { + candidates: [ + { + content: { + parts: [{ text: 'Second part - should not be processed' }], + }, }, - }, - ], - } as unknown as GenerateContentResponse; + ], + } as GenerateContentResponse, + }; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); @@ -230,79 +251,79 @@ describe('Turn', () => { it('should handle function calls with undefined name or args', async () => { const mockResponseStream = (async function* () { yield { - functionCalls: [ - { id: 'fc1', name: undefined, args: { arg1: 'val1' } }, - { id: 'fc2', name: 'tool2', args: undefined }, - { id: 'fc3', name: undefined, args: undefined }, - ], - } as unknown as GenerateContentResponse; + type: StreamEventType.CHUNK, + value: { + candidates: [], + functionCalls: [ + // Add `id` back to the mock to match what the code expects + { id: 'fc1', name: undefined, args: { arg1: 'val1' } }, + { id: 'fc2', name: 'tool2', args: undefined }, + { id: 'fc3', name: undefined, args: undefined }, + ], + }, + }; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); + const events = []; - const reqParts: Part[] = [{ text: 'Test undefined tool parts' }]; for await (const event of turn.run( - reqParts, + [{ text: 'Test undefined tool parts' }], new AbortController().signal, )) { events.push(event); } expect(events.length).toBe(3); + + // Assertions for each specific tool call event const event1 = events[0] as ServerGeminiToolCallRequestEvent; - expect(event1.type).toBe(GeminiEventType.ToolCallRequest); - expect(event1.value).toEqual( - expect.objectContaining({ - callId: 'fc1', - name: 'undefined_tool_name', - args: { arg1: 'val1' }, - isClientInitiated: false, - }), - ); - 
expect(turn.pendingToolCalls[0]).toEqual(event1.value); + expect(event1.value).toMatchObject({ + callId: 'fc1', + name: 'undefined_tool_name', + args: { arg1: 'val1' }, + }); const event2 = events[1] as ServerGeminiToolCallRequestEvent; - expect(event2.type).toBe(GeminiEventType.ToolCallRequest); - expect(event2.value).toEqual( - expect.objectContaining({ - callId: 'fc2', - name: 'tool2', - args: {}, - isClientInitiated: false, - }), - ); - expect(turn.pendingToolCalls[1]).toEqual(event2.value); + expect(event2.value).toMatchObject({ + callId: 'fc2', + name: 'tool2', + args: {}, + }); const event3 = events[2] as ServerGeminiToolCallRequestEvent; - expect(event3.type).toBe(GeminiEventType.ToolCallRequest); - expect(event3.value).toEqual( - expect.objectContaining({ - callId: 'fc3', - name: 'undefined_tool_name', - args: {}, - isClientInitiated: false, - }), - ); - expect(turn.pendingToolCalls[2]).toEqual(event3.value); - expect(turn.getDebugResponses().length).toBe(1); + expect(event3.value).toMatchObject({ + callId: 'fc3', + name: 'undefined_tool_name', + args: {}, + }); }); it('should yield finished event when response has finish reason', async () => { const mockResponseStream = (async function* () { yield { - candidates: [ - { - content: { parts: [{ text: 'Partial response' }] }, - finishReason: 'STOP', + type: StreamEventType.CHUNK, + value: { + candidates: [ + { + content: { parts: [{ text: 'Partial response' }] }, + finishReason: 'STOP', + }, + ], + usageMetadata: { + promptTokenCount: 17, + candidatesTokenCount: 50, + cachedContentTokenCount: 10, + thoughtsTokenCount: 5, + toolUsePromptTokenCount: 2, }, - ], - } as unknown as GenerateContentResponse; + } as GenerateContentResponse, + }; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); const events = []; - const reqParts: Part[] = [{ text: 'Test finish reason' }]; for await (const event of turn.run( - reqParts, + [{ text: 'Test finish reason' }], new AbortController().signal, )) { 
events.push(event); @@ -317,17 +338,20 @@ describe('Turn', () => { it('should yield finished event for MAX_TOKENS finish reason', async () => { const mockResponseStream = (async function* () { yield { - candidates: [ - { - content: { - parts: [ - { text: 'This is a long response that was cut off...' }, - ], + type: StreamEventType.CHUNK, + value: { + candidates: [ + { + content: { + parts: [ + { text: 'This is a long response that was cut off...' }, + ], + }, + finishReason: 'MAX_TOKENS', }, - finishReason: 'MAX_TOKENS', - }, - ], - } as unknown as GenerateContentResponse; + ], + }, + }; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); @@ -352,13 +376,16 @@ describe('Turn', () => { it('should yield finished event for SAFETY finish reason', async () => { const mockResponseStream = (async function* () { yield { - candidates: [ - { - content: { parts: [{ text: 'Content blocked' }] }, - finishReason: 'SAFETY', - }, - ], - } as unknown as GenerateContentResponse; + type: StreamEventType.CHUNK, + value: { + candidates: [ + { + content: { parts: [{ text: 'Content blocked' }] }, + finishReason: 'SAFETY', + }, + ], + }, + }; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); @@ -380,13 +407,18 @@ describe('Turn', () => { it('should not yield finished event when there is no finish reason', async () => { const mockResponseStream = (async function* () { yield { - candidates: [ - { - content: { parts: [{ text: 'Response without finish reason' }] }, - // No finishReason property - }, - ], - } as unknown as GenerateContentResponse; + type: StreamEventType.CHUNK, + value: { + candidates: [ + { + content: { + parts: [{ text: 'Response without finish reason' }], + }, + // No finishReason property + }, + ], + }, + }; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); @@ -411,21 +443,27 @@ describe('Turn', () => { it('should handle multiple responses with different finish reasons', async () => { const mockResponseStream = (async 
function* () { yield { - candidates: [ - { - content: { parts: [{ text: 'First part' }] }, - // No finish reason on first response - }, - ], - } as unknown as GenerateContentResponse; + type: StreamEventType.CHUNK, + value: { + candidates: [ + { + content: { parts: [{ text: 'First part' }] }, + // No finish reason on first response + }, + ], + }, + }; yield { - candidates: [ - { - content: { parts: [{ text: 'Second part' }] }, - finishReason: 'OTHER', - }, - ], - } as unknown as GenerateContentResponse; + type: StreamEventType.CHUNK, + value: { + candidates: [ + { + content: { parts: [{ text: 'Second part' }] }, + finishReason: 'OTHER', + }, + ], + }, + }; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); @@ -470,6 +508,29 @@ describe('Turn', () => { expect(reportError).not.toHaveBeenCalled(); }); + + it('should yield a Retry event when it receives one from the chat stream', async () => { + const mockResponseStream = (async function* () { + yield { type: StreamEventType.RETRY }; + yield { + type: StreamEventType.CHUNK, + value: { + candidates: [{ content: { parts: [{ text: 'Success' }] } }], + }, + }; + })(); + mockSendMessageStream.mockResolvedValue(mockResponseStream); + + const events = []; + for await (const event of turn.run([], new AbortController().signal)) { + events.push(event); + } + + expect(events).toEqual([ + { type: GeminiEventType.Retry }, + { type: GeminiEventType.Content, value: 'Success' }, + ]); + }); }); describe('getDebugResponses', () => { @@ -481,8 +542,8 @@ describe('Turn', () => { functionCalls: [{ name: 'debugTool' }], } as unknown as GenerateContentResponse; const mockResponseStream = (async function* () { - yield resp1; - yield resp2; + yield { type: StreamEventType.CHUNK, value: resp1 }; + yield { type: StreamEventType.CHUNK, value: resp2 }; })(); mockSendMessageStream.mockResolvedValue(mockResponseStream); const reqParts: Part[] = [{ text: 'Hi' }]; diff --git a/packages/core/src/core/turn.ts 
b/packages/core/src/core/turn.ts index 06c4af9d..8dd8377c 100644 --- a/packages/core/src/core/turn.ts +++ b/packages/core/src/core/turn.ts @@ -55,8 +55,14 @@ export enum GeminiEventType { SessionTokenLimitExceeded = 'session_token_limit_exceeded', Finished = 'finished', LoopDetected = 'loop_detected', + Citation = 'citation', + Retry = 'retry', } +export type ServerGeminiRetryEvent = { + type: GeminiEventType.Retry; +}; + export interface StructuredError { message: string; status?: number; @@ -188,7 +194,8 @@ export type ServerGeminiStreamEvent = | ServerGeminiMaxSessionTurnsEvent | ServerGeminiSessionTokenLimitExceededEvent | ServerGeminiFinishedEvent - | ServerGeminiLoopDetectedEvent; + | ServerGeminiLoopDetectedEvent + | ServerGeminiRetryEvent; // A turn manages the agentic loop turn within the server context. export class Turn { @@ -210,6 +217,8 @@ export class Turn { signal: AbortSignal, ): AsyncGenerator { try { + // Note: This assumes `sendMessageStream` yields events like + // { type: StreamEventType.RETRY } or { type: StreamEventType.CHUNK, value: GenerateContentResponse } const responseStream = await this.chat.sendMessageStream( { message: req, @@ -220,12 +229,22 @@ export class Turn { this.prompt_id, ); - for await (const resp of responseStream) { + for await (const streamEvent of responseStream) { if (signal?.aborted) { yield { type: GeminiEventType.UserCancelled }; - // Do not add resp to debugResponses if aborted before processing return; } + + // Handle the new RETRY event + if (streamEvent.type === 'retry') { + yield { type: GeminiEventType.Retry }; + continue; // Skip to the next event in the stream + } + + // Assuming other events are chunks with a `value` property + const resp = streamEvent.value as GenerateContentResponse; + if (!resp) continue; // Skip if there's no response body + this.debugResponses.push(resp); const thoughtPart = resp.candidates?.[0]?.content?.parts?.[0]; @@ -267,6 +286,7 @@ export class Turn { // Check if response was 
truncated or stopped for various reasons const finishReason = resp.candidates?.[0]?.finishReason; + // This is the key change: Only yield 'Finished' if there is a finishReason. if (finishReason) { this.finishReason = finishReason; yield { diff --git a/packages/core/src/ide/process-utils.test.ts b/packages/core/src/ide/process-utils.test.ts index 9ac56424..e6c68f14 100644 --- a/packages/core/src/ide/process-utils.test.ts +++ b/packages/core/src/ide/process-utils.test.ts @@ -66,13 +66,34 @@ describe('getIdeProcessInfo', () => { it('should traverse up and find the great-grandchild of the root process', async () => { (os.platform as Mock).mockReturnValue('win32'); const processInfoMap = new Map([ - [1000, { stdout: 'ParentProcessId=900\r\nCommandLine=node.exe\r\n' }], + [ + 1000, + { + stdout: + '{"Name":"node.exe","ParentProcessId":900,"CommandLine":"node.exe"}', + }, + ], [ 900, - { stdout: 'ParentProcessId=800\r\nCommandLine=powershell.exe\r\n' }, + { + stdout: + '{"Name":"powershell.exe","ParentProcessId":800,"CommandLine":"powershell.exe"}', + }, + ], + [ + 800, + { + stdout: + '{"Name":"code.exe","ParentProcessId":700,"CommandLine":"code.exe"}', + }, + ], + [ + 700, + { + stdout: + '{"Name":"wininit.exe","ParentProcessId":0,"CommandLine":"wininit.exe"}', + }, ], - [800, { stdout: 'ParentProcessId=700\r\nCommandLine=code.exe\r\n' }], - [700, { stdout: 'ParentProcessId=0\r\nCommandLine=wininit.exe\r\n' }], ]); mockedExec.mockImplementation((command: string) => { const pidMatch = command.match(/ProcessId=(\d+)/); @@ -86,5 +107,90 @@ describe('getIdeProcessInfo', () => { const result = await getIdeProcessInfo(); expect(result).toEqual({ pid: 900, command: 'powershell.exe' }); }); + + it('should handle non-existent process gracefully', async () => { + (os.platform as Mock).mockReturnValue('win32'); + mockedExec + .mockResolvedValueOnce({ stdout: '' }) // Non-existent PID returns empty due to -ErrorAction SilentlyContinue + .mockResolvedValueOnce({ + stdout: + 
'{"Name":"fallback.exe","ParentProcessId":0,"CommandLine":"fallback.exe"}', + }); // Fallback call + + const result = await getIdeProcessInfo(); + expect(result).toEqual({ pid: 1000, command: 'fallback.exe' }); + }); + + it('should handle malformed JSON output gracefully', async () => { + (os.platform as Mock).mockReturnValue('win32'); + mockedExec + .mockResolvedValueOnce({ stdout: '{"invalid":json}' }) // Malformed JSON + .mockResolvedValueOnce({ + stdout: + '{"Name":"fallback.exe","ParentProcessId":0,"CommandLine":"fallback.exe"}', + }); // Fallback call + + const result = await getIdeProcessInfo(); + expect(result).toEqual({ pid: 1000, command: 'fallback.exe' }); + }); + + it('should handle PowerShell errors without crashing the process chain', async () => { + (os.platform as Mock).mockReturnValue('win32'); + const processInfoMap = new Map([ + [1000, { stdout: '' }], // First process doesn't exist (empty due to -ErrorAction) + [ + 1001, + { + stdout: + '{"Name":"parent.exe","ParentProcessId":800,"CommandLine":"parent.exe"}', + }, + ], + [ + 800, + { + stdout: + '{"Name":"ide.exe","ParentProcessId":0,"CommandLine":"ide.exe"}', + }, + ], + ]); + + // Mock the process.pid to test traversal with missing processes + Object.defineProperty(process, 'pid', { + value: 1001, + configurable: true, + }); + + mockedExec.mockImplementation((command: string) => { + const pidMatch = command.match(/ProcessId=(\d+)/); + if (pidMatch) { + const pid = parseInt(pidMatch[1], 10); + return Promise.resolve(processInfoMap.get(pid) || { stdout: '' }); + } + return Promise.reject(new Error('Invalid command for mock')); + }); + + const result = await getIdeProcessInfo(); + // Should return the current process command since traversal continues despite missing processes + expect(result).toEqual({ pid: 1001, command: 'parent.exe' }); + + // Reset process.pid + Object.defineProperty(process, 'pid', { + value: 1000, + configurable: true, + }); + }); + + it('should handle partial JSON data with 
defaults', async () => { + (os.platform as Mock).mockReturnValue('win32'); + mockedExec + .mockResolvedValueOnce({ stdout: '{"Name":"partial.exe"}' }) // Missing ParentProcessId, defaults to 0 + .mockResolvedValueOnce({ + stdout: + '{"Name":"root.exe","ParentProcessId":0,"CommandLine":"root.exe"}', + }); // Get grandparent info + + const result = await getIdeProcessInfo(); + expect(result).toEqual({ pid: 1000, command: 'root.exe' }); + }); }); }); diff --git a/packages/core/src/ide/process-utils.ts b/packages/core/src/ide/process-utils.ts index 17d1d520..ecb93781 100644 --- a/packages/core/src/ide/process-utils.ts +++ b/packages/core/src/ide/process-utils.ts @@ -24,30 +24,44 @@ async function getProcessInfo(pid: number): Promise<{ name: string; command: string; }> { - const platform = os.platform(); - if (platform === 'win32') { - const command = `wmic process where "ProcessId=${pid}" get Name,ParentProcessId,CommandLine /value`; - const { stdout } = await execAsync(command); - const nameMatch = stdout.match(/Name=([^\n]*)/); - const processName = nameMatch ? nameMatch[1].trim() : ''; - const ppidMatch = stdout.match(/ParentProcessId=(\d+)/); - const parentPid = ppidMatch ? parseInt(ppidMatch[1], 10) : 0; - const commandLineMatch = stdout.match(/CommandLine=([^\n]*)/); - const commandLine = commandLineMatch ? commandLineMatch[1].trim() : ''; - return { parentPid, name: processName, command: commandLine }; - } else { - const command = `ps -o ppid=,command= -p ${pid}`; - const { stdout } = await execAsync(command); - const trimmedStdout = stdout.trim(); - const ppidString = trimmedStdout.split(/\s+/)[0]; - const parentPid = parseInt(ppidString, 10); - const fullCommand = trimmedStdout.substring(ppidString.length).trim(); - const processName = path.basename(fullCommand.split(' ')[0]); - return { - parentPid: isNaN(parentPid) ? 
1 : parentPid, - name: processName, - command: fullCommand, - }; + try { + const platform = os.platform(); + if (platform === 'win32') { + const powershellCommand = [ + '$p = Get-CimInstance Win32_Process', + `-Filter 'ProcessId=${pid}'`, + '-ErrorAction SilentlyContinue;', + 'if ($p) {', + '@{Name=$p.Name;ParentProcessId=$p.ParentProcessId;CommandLine=$p.CommandLine}', + '| ConvertTo-Json', + '}', + ].join(' '); + const { stdout } = await execAsync(`powershell "${powershellCommand}"`); + const output = stdout.trim(); + if (!output) return { parentPid: 0, name: '', command: '' }; + const { + Name = '', + ParentProcessId = 0, + CommandLine = '', + } = JSON.parse(output); + return { parentPid: ParentProcessId, name: Name, command: CommandLine }; + } else { + const command = `ps -o ppid=,command= -p ${pid}`; + const { stdout } = await execAsync(command); + const trimmedStdout = stdout.trim(); + const ppidString = trimmedStdout.split(/\s+/)[0]; + const parentPid = parseInt(ppidString, 10); + const fullCommand = trimmedStdout.substring(ppidString.length).trim(); + const processName = path.basename(fullCommand.split(' ')[0]); + return { + parentPid: isNaN(parentPid) ? 1 : parentPid, + name: processName, + command: fullCommand, + }; + } + } catch (_e) { + console.debug(`Failed to get process info for pid ${pid}:`, _e); + return { parentPid: 0, name: '', command: '' }; } } @@ -160,7 +174,6 @@ async function getIdeProcessInfoForWindows(): Promise<{ * top-level ancestor process ID and command as a fallback. * * @returns A promise that resolves to the PID and command of the IDE process. - * @throws Will throw an error if the underlying shell commands fail. */ export async function getIdeProcessInfo(): Promise<{ pid: number;