diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index 9fdd7e06..c858876e 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -51,6 +51,8 @@ vi.mock('./turn', () => { GeminiEventType: { MaxSessionTurns: 'MaxSessionTurns', ChatCompressed: 'ChatCompressed', + Error: 'error', + Content: 'content', }, }; }); @@ -1887,6 +1889,89 @@ ${JSON.stringify( expect(JSON.stringify(finalCall)).toContain('fileC.ts'); }); }); + + it('should not call checkNextSpeaker when turn.run() yields an error', async () => { + // Arrange + const { checkNextSpeaker } = await import( + '../utils/nextSpeakerChecker.js' + ); + const mockCheckNextSpeaker = vi.mocked(checkNextSpeaker); + + const mockStream = (async function* () { + yield { + type: GeminiEventType.Error, + value: { error: { message: 'test error' } }, + }; + })(); + mockTurnRunFn.mockReturnValue(mockStream); + + const mockChat: Partial = { + addHistory: vi.fn(), + getHistory: vi.fn().mockReturnValue([]), + }; + client['chat'] = mockChat as GeminiChat; + + const mockGenerator: Partial = { + countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), + generateContent: mockGenerateContentFn, + }; + client['contentGenerator'] = mockGenerator as ContentGenerator; + + // Act + const stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + 'prompt-id-error', + ); + for await (const _ of stream) { + // consume stream + } + + // Assert + expect(mockCheckNextSpeaker).not.toHaveBeenCalled(); + }); + + it('should not call checkNextSpeaker when turn.run() yields a value then an error', async () => { + // Arrange + const { checkNextSpeaker } = await import( + '../utils/nextSpeakerChecker.js' + ); + const mockCheckNextSpeaker = vi.mocked(checkNextSpeaker); + + const mockStream = (async function* () { + yield { type: GeminiEventType.Content, value: 'some content' }; + yield { + type: GeminiEventType.Error, + value: { error: { message: 'test error' } }, + }; + })(); + mockTurnRunFn.mockReturnValue(mockStream); + + const mockChat: Partial = { + addHistory: vi.fn(), + getHistory: vi.fn().mockReturnValue([]), + }; + client['chat'] = mockChat as GeminiChat; + + const mockGenerator: Partial = { + countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), + generateContent: mockGenerateContentFn, + }; + client['contentGenerator'] = mockGenerator as ContentGenerator; + + // Act + const stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + 'prompt-id-error', + ); + for await (const _ of stream) { + // consume stream + } + + // Assert + expect(mockCheckNextSpeaker).not.toHaveBeenCalled(); + }); }); describe('generateContent', () => { diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index d6a30c94..bc1054a2 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -513,6 +513,9 @@ export class GeminiClient { return turn; } yield event; + if (event.type === GeminiEventType.Error) { + return turn; + } } if (!turn.pendingToolCalls.length && signal && !signal.aborted) { // Check if model was switched during the call (likely due to quota error)