From bb8a23ae803d097814563a7fb70e29c9dcaf7175 Mon Sep 17 00:00:00 2001 From: Victor May Date: Fri, 22 Aug 2025 15:43:53 -0400 Subject: [PATCH] Retry Message Stream on Empty Chunks (#6777) --- .../cli/src/ui/hooks/useGeminiStream.test.tsx | 251 ++++++++++- packages/core/src/core/geminiChat.test.ts | 393 ++++++++++++++++- packages/core/src/core/geminiChat.ts | 402 +++++++++--------- 3 files changed, 842 insertions(+), 204 deletions(-) diff --git a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx index f08f6606..52188436 100644 --- a/packages/cli/src/ui/hooks/useGeminiStream.test.tsx +++ b/packages/cli/src/ui/hooks/useGeminiStream.test.tsx @@ -1793,7 +1793,6 @@ describe('useGeminiStream', () => { const userMessageTimestamp = Date.now(); vi.spyOn(Date, 'now').mockReturnValue(userMessageTimestamp); - // Mock the behavior of handleAtCommand handleAtCommandSpy.mockResolvedValue({ processedQuery: processedQueryParts, shouldProceed: true, @@ -1818,20 +1817,16 @@ describe('useGeminiStream', () => { ), ); - // Act: Submit the query await act(async () => { await result.current.submitQuery(rawQuery); }); - // Assert - // 1. Verify handleAtCommand was called with the raw query. expect(handleAtCommandSpy).toHaveBeenCalledWith( expect.objectContaining({ query: rawQuery, }), ); - // 2. Verify the user's turn was added to history *after* processing. expect(mockAddItem).toHaveBeenCalledWith( { type: MessageType.USER, @@ -1840,11 +1835,249 @@ describe('useGeminiStream', () => { userMessageTimestamp, ); - // 3. Verify the *processed* query was sent to the model, not the raw one. + // FIX: The expectation now matches the actual call signature. expect(mockSendMessageStream).toHaveBeenCalledWith( - processedQueryParts, - expect.any(AbortSignal), - expect.any(String), + processedQueryParts, // Argument 1: The parts array directly + expect.any(AbortSignal), // Argument 2: An AbortSignal + expect.any(String), // Argument 3: The prompt_id string + ); + }); + describe('Thought Reset', () => { + it('should reset thought to null when starting a new prompt', async () => { + mockSendMessageStream.mockReturnValue( + (async function* () { + yield { + type: ServerGeminiEventType.Thought, + value: { + subject: 'Previous thought', + description: 'Old description', + }, + }; + yield { + type: ServerGeminiEventType.Content, + value: 'Some response content', + }; + yield { type: ServerGeminiEventType.Finished, value: 'STOP' }; + })(), + ); + + const { result } = renderHook(() => + useGeminiStream( + new MockedGeminiClientClass(mockConfig), + [], + mockAddItem, + mockConfig, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + () => 'vscode' as EditorType, + () => {}, + () => Promise.resolve(), + false, + () => {}, + () => {}, + () => {}, + ), + ); + + await act(async () => { + await result.current.submitQuery('First query'); + }); + + await waitFor(() => { + expect(mockAddItem).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'gemini', + text: 'Some response content', + }), + expect.any(Number), + ); + }); + + mockSendMessageStream.mockReturnValue( + (async function* () { + yield { + type: ServerGeminiEventType.Content, + value: 'New response content', + }; + yield { type: ServerGeminiEventType.Finished, value: 'STOP' }; + })(), + ); + + await act(async () => { + await result.current.submitQuery('Second query'); + }); + + await waitFor(() => { + expect(mockAddItem).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'gemini', + text: 'New response content', + }), + expect.any(Number), + ); + }); + }); + + it('should reset thought to null when user cancels', async () => { + mockSendMessageStream.mockReturnValue( + (async function* () { + yield { + type: ServerGeminiEventType.Thought, + value: { subject: 'Some thought', description: 'Description' }, + }; + yield { type: ServerGeminiEventType.UserCancelled }; + })(), + ); + + const { result } = renderHook(() => + useGeminiStream( + new MockedGeminiClientClass(mockConfig), + [], + mockAddItem, + mockConfig, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + () => 'vscode' as EditorType, + () => {}, + () => Promise.resolve(), + false, + () => {}, + () => {}, + () => {}, + ), + ); + + await act(async () => { + await result.current.submitQuery('Test query'); + }); + + await waitFor(() => { + expect(mockAddItem).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'info', + text: 'User cancelled the request.', + }), + expect.any(Number), + ); + }); + + expect(result.current.streamingState).toBe(StreamingState.Idle); + }); + + it('should reset thought to null when there is an error', async () => { + mockSendMessageStream.mockReturnValue( + (async function* () { + yield { + type: ServerGeminiEventType.Thought, + value: { subject: 'Some thought', description: 'Description' }, + }; + yield { + type: ServerGeminiEventType.Error, + value: { error: { message: 'Test error' } }, + }; + })(), + ); + + const { result } = renderHook(() => + useGeminiStream( + new MockedGeminiClientClass(mockConfig), + [], + mockAddItem, + mockConfig, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + () => 'vscode' as EditorType, + () => {}, + () => Promise.resolve(), + false, + () => {}, + () => {}, + () => {}, + ), + ); + + await act(async () => { + await result.current.submitQuery('Test query'); + }); + + await waitFor(() => { + expect(mockAddItem).toHaveBeenCalledWith( + expect.objectContaining({ + type: 'error', + }), + expect.any(Number), + ); + }); + + expect(mockParseAndFormatApiError).toHaveBeenCalledWith( + { message: 'Test error' }, + expect.any(String), + undefined, + 'gemini-2.5-pro', + 'gemini-2.5-flash', + ); + }); + }); + + it('should process @include commands, adding user turn after processing to prevent race conditions', async () => { + const rawQuery = '@include file.txt Summarize this.'; + const processedQueryParts = [ + { text: 'Summarize this with content from @file.txt' }, + { text: 'File content...' }, + ]; + const userMessageTimestamp = Date.now(); + vi.spyOn(Date, 'now').mockReturnValue(userMessageTimestamp); + + handleAtCommandSpy.mockResolvedValue({ + processedQuery: processedQueryParts, + shouldProceed: true, + }); + + const { result } = renderHook(() => + useGeminiStream( + mockConfig.getGeminiClient() as GeminiClient, + [], + mockAddItem, + mockConfig, + mockOnDebugMessage, + mockHandleSlashCommand, + false, + vi.fn(), + vi.fn(), + vi.fn(), + false, + vi.fn(), + vi.fn(), + vi.fn(), + ), + ); + + await act(async () => { + await result.current.submitQuery(rawQuery); + }); + + expect(handleAtCommandSpy).toHaveBeenCalledWith( + expect.objectContaining({ + query: rawQuery, + }), + ); + + expect(mockAddItem).toHaveBeenCalledWith( + { + type: MessageType.USER, + text: rawQuery, + }, + userMessageTimestamp, + ); + + // FIX: This expectation now correctly matches the actual function call signature. + expect(mockSendMessageStream).toHaveBeenCalledWith( + processedQueryParts, // Argument 1: The parts array directly + expect.any(AbortSignal), // Argument 2: An AbortSignal + expect.any(String), // Argument 3: The prompt_id string ); }); }); diff --git a/packages/core/src/core/geminiChat.test.ts b/packages/core/src/core/geminiChat.test.ts index cd5e3841..c4fb7f0f 100644 --- a/packages/core/src/core/geminiChat.test.ts +++ b/packages/core/src/core/geminiChat.test.ts @@ -12,7 +12,7 @@ import { Part, GenerateContentResponse, } from '@google/genai'; -import { GeminiChat } from './geminiChat.js'; +import { GeminiChat, EmptyStreamError } from './geminiChat.js'; import { Config } from '../config/config.js'; import { setSimulate429 } from '../utils/testUtils.js'; @@ -112,7 +112,13 @@ describe('GeminiChat', () => { response, ); - await chat.sendMessageStream({ message: 'hello' }, 'prompt-id-1'); + const stream = await chat.sendMessageStream( + { message: 'hello' }, + 'prompt-id-1', + ); + for await (const _ of stream) { + // consume stream to trigger internal logic + } expect(mockModelsModule.generateContentStream).toHaveBeenCalledWith( { @@ -475,4 +481,387 @@ describe('GeminiChat', () => { expect(history[1]).toEqual(content2); }); }); + + describe('sendMessageStream with retries', () => { + it('should retry on invalid content and succeed on the second attempt', async () => { + // Use mockImplementationOnce to provide a fresh, promise-wrapped generator for each attempt. + vi.mocked(mockModelsModule.generateContentStream) + .mockImplementationOnce(async () => + // First call returns an invalid stream + (async function* () { + yield { + candidates: [{ content: { parts: [{ text: '' }] } }], // Invalid empty text part + } as unknown as GenerateContentResponse; + })(), + ) + .mockImplementationOnce(async () => + // Second call returns a valid stream + (async function* () { + yield { + candidates: [ + { content: { parts: [{ text: 'Successful response' }] } }, + ], + } as unknown as GenerateContentResponse; + })(), + ); + + const stream = await chat.sendMessageStream( + { message: 'test' }, + 'prompt-id-retry-success', + ); + const chunks = []; + for await (const chunk of stream) { + chunks.push(chunk); + } + + // Assertions + expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2); + expect( + chunks.some( + (c) => + c.candidates?.[0]?.content?.parts?.[0]?.text === + 'Successful response', + ), + ).toBe(true); + + // Check that history was recorded correctly once, with no duplicates. + const history = chat.getHistory(); + expect(history.length).toBe(2); + expect(history[0]).toEqual({ + role: 'user', + parts: [{ text: 'test' }], + }); + expect(history[1]).toEqual({ + role: 'model', + parts: [{ text: 'Successful response' }], + }); + }); + + it('should fail after all retries on persistent invalid content', async () => { + vi.mocked(mockModelsModule.generateContentStream).mockImplementation( + async () => + (async function* () { + yield { + candidates: [ + { + content: { + parts: [{ text: '' }], + role: 'model', + }, + }, + ], + } as unknown as GenerateContentResponse; + })(), + ); + + // This helper function consumes the stream and allows us to test for rejection. + async function consumeStreamAndExpectError() { + const stream = await chat.sendMessageStream( + { message: 'test' }, + 'prompt-id-retry-fail', + ); + for await (const _ of stream) { + // Must loop to trigger the internal logic that throws. + } + } + + await expect(consumeStreamAndExpectError()).rejects.toThrow( + EmptyStreamError, + ); + + // Should be called 3 times (initial + 2 retries) + expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(3); + + // History should be clean, as if the failed turn never happened. + const history = chat.getHistory(); + expect(history.length).toBe(0); + }); + }); + it('should correctly retry and append to an existing history mid-conversation', async () => { + // 1. Setup + const initialHistory: Content[] = [ + { role: 'user', parts: [{ text: 'First question' }] }, + { role: 'model', parts: [{ text: 'First answer' }] }, + ]; + chat.setHistory(initialHistory); + + // 2. Mock the API + vi.mocked(mockModelsModule.generateContentStream) + .mockImplementationOnce(async () => + (async function* () { + yield { + candidates: [{ content: { parts: [{ text: '' }] } }], + } as unknown as GenerateContentResponse; + })(), + ) + .mockImplementationOnce(async () => + (async function* () { + yield { + candidates: [{ content: { parts: [{ text: 'Second answer' }] } }], + } as unknown as GenerateContentResponse; + })(), + ); + + // 3. Send a new message + const stream = await chat.sendMessageStream( + { message: 'Second question' }, + 'prompt-id-retry-existing', + ); + for await (const _ of stream) { + // consume stream + } + + // 4. Assert the final history + const history = chat.getHistory(); + expect(history.length).toBe(4); + + // Explicitly verify the structure of each part to satisfy TypeScript + const turn1 = history[0]; + if (!turn1?.parts?.[0] || !('text' in turn1.parts[0])) { + throw new Error('Test setup error: First turn is not a valid text part.'); + } + expect(turn1.parts[0].text).toBe('First question'); + + const turn2 = history[1]; + if (!turn2?.parts?.[0] || !('text' in turn2.parts[0])) { + throw new Error( + 'Test setup error: Second turn is not a valid text part.', + ); + } + expect(turn2.parts[0].text).toBe('First answer'); + + const turn3 = history[2]; + if (!turn3?.parts?.[0] || !('text' in turn3.parts[0])) { + throw new Error('Test setup error: Third turn is not a valid text part.'); + } + expect(turn3.parts[0].text).toBe('Second question'); + + const turn4 = history[3]; + if (!turn4?.parts?.[0] || !('text' in turn4.parts[0])) { + throw new Error( + 'Test setup error: Fourth turn is not a valid text part.', + ); + } + expect(turn4.parts[0].text).toBe('Second answer'); + }); + + describe('concurrency control', () => { + it('should queue a subsequent sendMessage call until the first one completes', async () => { + // 1. Create promises to manually control when the API calls resolve + let firstCallResolver: (value: GenerateContentResponse) => void; + const firstCallPromise = new Promise( + (resolve) => { + firstCallResolver = resolve; + }, + ); + + let secondCallResolver: (value: GenerateContentResponse) => void; + const secondCallPromise = new Promise( + (resolve) => { + secondCallResolver = resolve; + }, + ); + + // A standard response body for the mock + const mockResponse = { + candidates: [ + { + content: { parts: [{ text: 'response' }], role: 'model' }, + }, + ], + } as unknown as GenerateContentResponse; + + // 2. Mock the API to return our controllable promises in order + vi.mocked(mockModelsModule.generateContent) + .mockReturnValueOnce(firstCallPromise) + .mockReturnValueOnce(secondCallPromise); + + // 3. Start the first message call. Do not await it yet. + const firstMessagePromise = chat.sendMessage( + { message: 'first' }, + 'prompt-1', + ); + + // Give the event loop a chance to run the async call up to the `await` + await new Promise(process.nextTick); + + // 4. While the first call is "in-flight", start the second message call. + const secondMessagePromise = chat.sendMessage( + { message: 'second' }, + 'prompt-2', + ); + + // 5. CRUCIAL CHECK: At this point, only the first API call should have been made. + // The second call should be waiting on `sendPromise`. + expect(mockModelsModule.generateContent).toHaveBeenCalledTimes(1); + expect(mockModelsModule.generateContent).toHaveBeenCalledWith( + expect.objectContaining({ + contents: expect.arrayContaining([ + expect.objectContaining({ parts: [{ text: 'first' }] }), + ]), + }), + 'prompt-1', + ); + + // 6. Unblock the first API call and wait for the first message to fully complete. + firstCallResolver!(mockResponse); + await firstMessagePromise; + + // Give the event loop a chance to unblock and run the second call. + await new Promise(process.nextTick); + + // 7. CRUCIAL CHECK: Now, the second API call should have been made. + expect(mockModelsModule.generateContent).toHaveBeenCalledTimes(2); + expect(mockModelsModule.generateContent).toHaveBeenCalledWith( + expect.objectContaining({ + contents: expect.arrayContaining([ + expect.objectContaining({ parts: [{ text: 'second' }] }), + ]), + }), + 'prompt-2', + ); + + // 8. Clean up by resolving the second call. + secondCallResolver!(mockResponse); + await secondMessagePromise; + }); + }); + it('should retry if the model returns a completely empty stream (no chunks)', async () => { + // 1. Mock the API to return an empty stream first, then a valid one. + vi.mocked(mockModelsModule.generateContentStream) + .mockImplementationOnce( + // First call resolves to an async generator that yields nothing. + async () => (async function* () {})(), + ) + .mockImplementationOnce( + // Second call returns a valid stream. + async () => + (async function* () { + yield { + candidates: [ + { + content: { + parts: [{ text: 'Successful response after empty' }], + }, + }, + ], + } as unknown as GenerateContentResponse; + })(), + ); + + // 2. Call the method and consume the stream. + const stream = await chat.sendMessageStream( + { message: 'test empty stream' }, + 'prompt-id-empty-stream', + ); + const chunks = []; + for await (const chunk of stream) { + chunks.push(chunk); + } + + // 3. Assert the results. + expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2); + expect( + chunks.some( + (c) => + c.candidates?.[0]?.content?.parts?.[0]?.text === + 'Successful response after empty', + ), + ).toBe(true); + + const history = chat.getHistory(); + expect(history.length).toBe(2); + + // Explicitly verify the structure of each part to satisfy TypeScript + const turn1 = history[0]; + if (!turn1?.parts?.[0] || !('text' in turn1.parts[0])) { + throw new Error('Test setup error: First turn is not a valid text part.'); + } + expect(turn1.parts[0].text).toBe('test empty stream'); + + const turn2 = history[1]; + if (!turn2?.parts?.[0] || !('text' in turn2.parts[0])) { + throw new Error( + 'Test setup error: Second turn is not a valid text part.', + ); + } + expect(turn2.parts[0].text).toBe('Successful response after empty'); + }); + it('should queue a subsequent sendMessageStream call until the first stream is fully consumed', async () => { + // 1. Create a promise to manually control the stream's lifecycle + let continueFirstStream: () => void; + const firstStreamContinuePromise = new Promise((resolve) => { + continueFirstStream = resolve; + }); + + // 2. Mock the API to return controllable async generators + const firstStreamGenerator = (async function* () { + yield { + candidates: [ + { content: { parts: [{ text: 'first response part 1' }] } }, + ], + } as unknown as GenerateContentResponse; + await firstStreamContinuePromise; // Pause the stream + yield { + candidates: [{ content: { parts: [{ text: ' part 2' }] } }], + } as unknown as GenerateContentResponse; + })(); + + const secondStreamGenerator = (async function* () { + yield { + candidates: [{ content: { parts: [{ text: 'second response' }] } }], + } as unknown as GenerateContentResponse; + })(); + + vi.mocked(mockModelsModule.generateContentStream) + .mockResolvedValueOnce(firstStreamGenerator) + .mockResolvedValueOnce(secondStreamGenerator); + + // 3. Start the first stream and consume only the first chunk to pause it + const firstStream = await chat.sendMessageStream( + { message: 'first' }, + 'prompt-1', + ); + const firstStreamIterator = firstStream[Symbol.asyncIterator](); + await firstStreamIterator.next(); + + // 4. While the first stream is paused, start the second call. It will block. + const secondStreamPromise = chat.sendMessageStream( + { message: 'second' }, + 'prompt-2', + ); + + // 5. Assert that only one API call has been made so far. + expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(1); + + // 6. Unblock and fully consume the first stream to completion. + continueFirstStream!(); + await firstStreamIterator.next(); // Consume the rest of the stream + await firstStreamIterator.next(); // Finish the iterator + + // 7. Now that the first stream is done, await the second promise to get its generator. + const secondStream = await secondStreamPromise; + + // 8. Start consuming the second stream, which triggers its internal API call. + const secondStreamIterator = secondStream[Symbol.asyncIterator](); + await secondStreamIterator.next(); + + // 9. The second API call should now have been made. + expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2); + + // 10. FIX: Fully consume the second stream to ensure recordHistory is called. + await secondStreamIterator.next(); // This finishes the iterator. + + // 11. Final check on history. + const history = chat.getHistory(); + expect(history.length).toBe(4); + + const turn4 = history[3]; + if (!turn4?.parts?.[0] || !('text' in turn4.parts[0])) { + throw new Error( + 'Test setup error: Fourth turn is not a valid text part.', + ); + } + expect(turn4.parts[0].text).toBe('second response'); + }); }); diff --git a/packages/core/src/core/geminiChat.ts b/packages/core/src/core/geminiChat.ts index 9bc8fae4..93428684 100644 --- a/packages/core/src/core/geminiChat.ts +++ b/packages/core/src/core/geminiChat.ts @@ -24,6 +24,21 @@ import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; import { hasCycleInSchema } from '../tools/tools.js'; import { StructuredError } from './turn.js'; +/** + * Options for retrying due to invalid content from the model. + */ +interface ContentRetryOptions { + /** Total number of attempts to make (1 initial + N retries). */ + maxAttempts: number; + /** The base delay in milliseconds for linear backoff. */ + initialDelayMs: number; +} + +const INVALID_CONTENT_RETRY_OPTIONS: ContentRetryOptions = { + maxAttempts: 3, // 1 initial call + 2 retries + initialDelayMs: 500, +}; + /** * Returns true if the response is valid, false otherwise. */ @@ -98,15 +113,23 @@ function extractCuratedHistory(comprehensiveHistory: Content[]): Content[] { } if (isValid) { curatedHistory.push(...modelOutput); - } else { - // Remove the last user input when model content is invalid. - curatedHistory.pop(); } } } return curatedHistory; } +/** + * Custom error to signal that a stream completed without valid content, + * which should trigger a retry. + */ +export class EmptyStreamError extends Error { + constructor(message: string) { + super(message); + this.name = 'EmptyStreamError'; + } +} + /** * Chat session that enables sending messages to the model with previous * conversation context. @@ -305,65 +328,121 @@ export class GeminiChat { prompt_id: string, ): Promise> { await this.sendPromise; + + let streamDoneResolver: () => void; + const streamDonePromise = new Promise((resolve) => { + streamDoneResolver = resolve; + }); + this.sendPromise = streamDonePromise; + const userContent = createUserContent(params.message); - const requestContents = this.getHistory(true).concat(userContent); - try { - const apiCall = () => { - const modelToUse = this.config.getModel(); + // Add user content to history ONCE before any attempts. + this.history.push(userContent); + const requestContents = this.getHistory(true); - // Prevent Flash model calls immediately after quota error - if ( - this.config.getQuotaErrorOccurred() && - modelToUse === DEFAULT_GEMINI_FLASH_MODEL + // eslint-disable-next-line @typescript-eslint/no-this-alias + const self = this; + return (async function* () { + try { + let lastError: unknown = new Error('Request failed after all retries.'); + + for ( + let attempt = 0; + attempt <= INVALID_CONTENT_RETRY_OPTIONS.maxAttempts; + attempt++ ) { - throw new Error( - 'Please submit a new query to continue with the Flash model.', - ); + try { + const stream = await self.makeApiCallAndProcessStream( + requestContents, + params, + prompt_id, + userContent, + ); + + for await (const chunk of stream) { + yield chunk; + } + + lastError = null; + break; + } catch (error) { + lastError = error; + const isContentError = error instanceof EmptyStreamError; + + if (isContentError) { + // Check if we have more attempts left. + if (attempt < INVALID_CONTENT_RETRY_OPTIONS.maxAttempts - 1) { + await new Promise((res) => + setTimeout( + res, + INVALID_CONTENT_RETRY_OPTIONS.initialDelayMs * + (attempt + 1), + ), + ); + continue; + } + } + break; + } } - return this.contentGenerator.generateContentStream( - { - model: modelToUse, - contents: requestContents, - config: { ...this.generationConfig, ...params.config }, - }, - prompt_id, - ); - }; - - // Note: Retrying streams can be complex. If generateContentStream itself doesn't handle retries - // for transient issues internally before yielding the async generator, this retry will re-initiate - // the stream. For simple 429/500 errors on initial call, this is fine. - // If errors occur mid-stream, this setup won't resume the stream; it will restart it. - const streamResponse = await retryWithBackoff(apiCall, { - shouldRetry: (error: unknown) => { - // Check for known error messages and codes. - if (error instanceof Error && error.message) { - if (isSchemaDepthError(error.message)) return false; - if (error.message.includes('429')) return true; - if (error.message.match(/5\d{2}/)) return true; + if (lastError) { + // If the stream fails, remove the user message that was added. + if (self.history[self.history.length - 1] === userContent) { + self.history.pop(); } - return false; // Don't retry other errors by default + throw lastError; + } + } finally { + streamDoneResolver!(); + } + })(); + } + + private async makeApiCallAndProcessStream( + requestContents: Content[], + params: SendMessageParameters, + prompt_id: string, + userContent: Content, + ): Promise> { + const apiCall = () => { + const modelToUse = this.config.getModel(); + + if ( + this.config.getQuotaErrorOccurred() && + modelToUse === DEFAULT_GEMINI_FLASH_MODEL + ) { + throw new Error( + 'Please submit a new query to continue with the Flash model.', + ); + } + + return this.contentGenerator.generateContentStream( + { + model: modelToUse, + contents: requestContents, + config: { ...this.generationConfig, ...params.config }, }, - onPersistent429: async (authType?: string, error?: unknown) => - await this.handleFlashFallback(authType, error), - authType: this.config.getContentGeneratorConfig()?.authType, - }); + prompt_id, + ); + }; - // Resolve the internal tracking of send completion promise - `sendPromise` - // for both success and failure response. The actual failure is still - // propagated by the `await streamResponse`. - this.sendPromise = Promise.resolve(streamResponse) - .then(() => undefined) - .catch(() => undefined); + const streamResponse = await retryWithBackoff(apiCall, { + shouldRetry: (error: unknown) => { + if (error instanceof Error && error.message) { + if (isSchemaDepthError(error.message)) return false; + if (error.message.includes('429')) return true; + if (error.message.match(/5\d{2}/)) return true; + } + return false; + }, + onPersistent429: async (authType?: string, error?: unknown) => + await this.handleFlashFallback(authType, error), + authType: this.config.getContentGeneratorConfig()?.authType, + }); - const result = this.processStreamResponse(streamResponse, userContent); - return result; - } catch (error) { - this.sendPromise = Promise.resolve(); - throw error; - } + return this.processStreamResponse(streamResponse, userContent); } /** @@ -407,8 +486,6 @@ export class GeminiChat { /** * Adds a new entry to the chat history. - * - * @param content - The content to add to the history. */ addHistory(content: Content): void { this.history.push(content); @@ -451,41 +528,41 @@ export class GeminiChat { private async *processStreamResponse( streamResponse: AsyncGenerator, - inputContent: Content, - ) { - const outputContent: Content[] = []; - const chunks: GenerateContentResponse[] = []; - let errorOccurred = false; + userInput: Content, + ): AsyncGenerator { + const modelResponseParts: Part[] = []; + let isStreamInvalid = false; + let hasReceivedAnyChunk = false; - try { - for await (const chunk of streamResponse) { - if (isValidResponse(chunk)) { - chunks.push(chunk); - const content = chunk.candidates?.[0]?.content; - if (content !== undefined) { - if (this.isThoughtContent(content)) { - yield chunk; - continue; - } - outputContent.push(content); + for await (const chunk of streamResponse) { + hasReceivedAnyChunk = true; + if (isValidResponse(chunk)) { + const content = chunk.candidates?.[0]?.content; + if (content) { + // Filter out thought parts from being added to history. + if (!this.isThoughtContent(content) && content.parts) { + modelResponseParts.push(...content.parts); } } - yield chunk; + } else { + isStreamInvalid = true; } - } catch (error) { - errorOccurred = true; - throw error; + yield chunk; // Yield every chunk to the UI immediately. } - if (!errorOccurred) { - const allParts: Part[] = []; - for (const content of outputContent) { - if (content.parts) { - allParts.push(...content.parts); - } - } + // Now that the stream is finished, make a decision. + // Throw an error if the stream was invalid OR if it was completely empty. + if (isStreamInvalid || !hasReceivedAnyChunk) { + throw new EmptyStreamError( + 'Model stream was invalid or completed without valid content.', + ); } - this.recordHistory(inputContent, outputContent); + + // Use recordHistory to correctly save the conversation turn. + const modelOutput: Content[] = [ + { role: 'model', parts: modelResponseParts }, + ]; + this.recordHistory(userInput, modelOutput); } private recordHistory( @@ -493,135 +570,74 @@ export class GeminiChat { modelOutput: Content[], automaticFunctionCallingHistory?: Content[], ) { + const newHistoryEntries: Content[] = []; + + // Part 1: Handle the user's part of the turn. + if ( + automaticFunctionCallingHistory && + automaticFunctionCallingHistory.length > 0 + ) { + newHistoryEntries.push( + ...extractCuratedHistory(automaticFunctionCallingHistory), + ); + } else { + // Guard for streaming calls where the user input might already be in the history. + if ( + this.history.length === 0 || + this.history[this.history.length - 1] !== userInput + ) { + newHistoryEntries.push(userInput); + } + } + + // Part 2: Handle the model's part of the turn, filtering out thoughts. const nonThoughtModelOutput = modelOutput.filter( (content) => !this.isThoughtContent(content), ); let outputContents: Content[] = []; - if ( - nonThoughtModelOutput.length > 0 && - nonThoughtModelOutput.every((content) => content.role !== undefined) - ) { + if (nonThoughtModelOutput.length > 0) { outputContents = nonThoughtModelOutput; - } else if (nonThoughtModelOutput.length === 0 && modelOutput.length > 0) { - // This case handles when the model returns only a thought. - // We don't want to add an empty model response in this case. - } else { - // When not a function response appends an empty content when model returns empty response, so that the - // history is always alternating between user and model. - // Workaround for: https://b.corp.google.com/issues/420354090 - if (!isFunctionResponse(userInput)) { - outputContents.push({ - role: 'model', - parts: [], - } as Content); - } - } - if ( - automaticFunctionCallingHistory && - automaticFunctionCallingHistory.length > 0 + } else if ( + modelOutput.length === 0 && + !isFunctionResponse(userInput) && + !automaticFunctionCallingHistory ) { - this.history.push( - ...extractCuratedHistory(automaticFunctionCallingHistory), - ); - } else { - this.history.push(userInput); + // Add an empty model response if the model truly returned nothing. + outputContents.push({ role: 'model', parts: [] } as Content); } - // Consolidate adjacent model roles in outputContents + // Part 3: Consolidate the parts of this turn's model response. const consolidatedOutputContents: Content[] = []; - for (const content of outputContents) { - if (this.isThoughtContent(content)) { - continue; - } - const lastContent = - consolidatedOutputContents[consolidatedOutputContents.length - 1]; - if ( - lastContent && - this.isLastPartText(lastContent) && - this.isFirstPartText(content) - ) { - // If the last part of the previous content and the first part of the current content are text, - // combine their text and append any other parts from the current content. - lastContent.parts[lastContent.parts.length - 1].text += - content.parts[0].text || ''; - if (content.parts.length > 1) { - lastContent.parts.push(...content.parts.slice(1)); + if (outputContents.length > 0) { + for (const content of outputContents) { + const lastContent = + consolidatedOutputContents[consolidatedOutputContents.length - 1]; + if (this.hasTextContent(lastContent) && this.hasTextContent(content)) { + lastContent.parts[0].text += content.parts[0].text || ''; + if (content.parts.length > 1) { + lastContent.parts.push(...content.parts.slice(1)); + } + } else { + consolidatedOutputContents.push(content); } - } else { - consolidatedOutputContents.push(content); } } - if (consolidatedOutputContents.length > 0) { - const lastHistoryEntry = this.history[this.history.length - 1]; - const canMergeWithLastHistory = - !automaticFunctionCallingHistory || - automaticFunctionCallingHistory.length === 0; - - if ( - canMergeWithLastHistory && - lastHistoryEntry && - this.isLastPartText(lastHistoryEntry) && - this.isFirstPartText(consolidatedOutputContents[0]) - ) { - // If the last part of the last history entry and the first part of the current content are text, - // combine their text and append any other parts from the current content. - lastHistoryEntry.parts[lastHistoryEntry.parts.length - 1].text += - consolidatedOutputContents[0].parts[0].text || ''; - if (consolidatedOutputContents[0].parts.length > 1) { - lastHistoryEntry.parts.push( - ...consolidatedOutputContents[0].parts.slice(1), - ); - } - consolidatedOutputContents.shift(); // Remove the first element as it's merged - } - this.history.push(...consolidatedOutputContents); - } + // Part 4: Add the new turn (user and model parts) to the main history. + this.history.push(...newHistoryEntries, ...consolidatedOutputContents); } - private isFirstPartText( + private hasTextContent( content: Content | undefined, ): content is Content & { parts: [{ text: string }, ...Part[]] } { - if ( - !content || - content.role !== 'model' || - !content.parts || - content.parts.length === 0 - ) { - return false; - } - const firstPart = content.parts[0]; - return ( - typeof firstPart.text === 'string' && - !('functionCall' in firstPart) && - !('functionResponse' in firstPart) && - !('inlineData' in firstPart) && - !('fileData' in firstPart) && - !('thought' in firstPart) - ); - } - - private isLastPartText( - content: Content | undefined, - ): content is Content & { parts: [...Part[], { text: string }] } { - if ( - !content || - content.role !== 'model' || - !content.parts || - content.parts.length === 0 - ) { - return false; - } - const lastPart = content.parts[content.parts.length - 1]; - return ( - typeof lastPart.text === 'string' && - lastPart.text !== '' && - !('functionCall' in lastPart) && - !('functionResponse' in lastPart) && - !('inlineData' in lastPart) && - !('fileData' in lastPart) && - !('thought' in lastPart) + return !!( + content && + content.role === 'model' && + content.parts && + content.parts.length > 0 && + typeof content.parts[0].text === 'string' && + content.parts[0].text !== '' ); }