From 415a36a195f34b3fcde7e08f0d17f3dea1ffb327 Mon Sep 17 00:00:00 2001 From: Silvio Junior Date: Tue, 26 Aug 2025 00:10:53 -0400 Subject: [PATCH] Do not call nextSpeakerCheck if there was an error processing the stream. (#7048) Co-authored-by: christine betts Co-authored-by: Antonio Scandurra Co-authored-by: Arya Gummadi Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Shreya Keshive Co-authored-by: Abhi <43648792+abhipatel12@users.noreply.github.com> Co-authored-by: shishu314 Co-authored-by: Shi Shu Co-authored-by: Steven Co-authored-by: Pascal Birchler Co-authored-by: N. Taylor Mullen --- packages/core/src/core/client.test.ts | 85 +++++++++++++++++++++++++++ packages/core/src/core/client.ts | 3 + 2 files changed, 88 insertions(+) diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index 9fdd7e06..c858876e 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -51,6 +51,8 @@ vi.mock('./turn', () => { GeminiEventType: { MaxSessionTurns: 'MaxSessionTurns', ChatCompressed: 'ChatCompressed', + Error: 'error', + Content: 'content', }, }; }); @@ -1887,6 +1889,89 @@ ${JSON.stringify( expect(JSON.stringify(finalCall)).toContain('fileC.ts'); }); }); + + it('should not call checkNextSpeaker when turn.run() yields an error', async () => { + // Arrange + const { checkNextSpeaker } = await import( + '../utils/nextSpeakerChecker.js' + ); + const mockCheckNextSpeaker = vi.mocked(checkNextSpeaker); + + const mockStream = (async function* () { + yield { + type: GeminiEventType.Error, + value: { error: { message: 'test error' } }, + }; + })(); + mockTurnRunFn.mockReturnValue(mockStream); + + const mockChat: Partial = { + addHistory: vi.fn(), + getHistory: vi.fn().mockReturnValue([]), + }; + client['chat'] = mockChat as GeminiChat; + + const mockGenerator: Partial = { + countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), + generateContent: mockGenerateContentFn, + }; + client['contentGenerator'] = mockGenerator as ContentGenerator; + + // Act + const stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + 'prompt-id-error', + ); + for await (const _ of stream) { + // consume stream + } + + // Assert + expect(mockCheckNextSpeaker).not.toHaveBeenCalled(); + }); + + it('should not call checkNextSpeaker when turn.run() yields a value then an error', async () => { + // Arrange + const { checkNextSpeaker } = await import( + '../utils/nextSpeakerChecker.js' + ); + const mockCheckNextSpeaker = vi.mocked(checkNextSpeaker); + + const mockStream = (async function* () { + yield { type: GeminiEventType.Content, value: 'some content' }; + yield { + type: GeminiEventType.Error, + value: { error: { message: 'test error' } }, + }; + })(); + mockTurnRunFn.mockReturnValue(mockStream); + + const mockChat: Partial = { + addHistory: vi.fn(), + getHistory: vi.fn().mockReturnValue([]), + }; + client['chat'] = mockChat as GeminiChat; + + const mockGenerator: Partial = { + countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), + generateContent: mockGenerateContentFn, + }; + client['contentGenerator'] = mockGenerator as ContentGenerator; + + // Act + const stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + 'prompt-id-error', + ); + for await (const _ of stream) { + // consume stream + } + + // Assert + expect(mockCheckNextSpeaker).not.toHaveBeenCalled(); + }); }); describe('generateContent', () => { diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index d6a30c94..bc1054a2 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -513,6 +513,9 @@ export class GeminiClient { return turn; } yield event; + if (event.type === GeminiEventType.Error) { + return turn; + } } if (!turn.pendingToolCalls.length && signal && !signal.aborted) { // Check if model was switched during the call (likely due to quota error)