mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-19 09:33:53 +00:00
Handle cleaning up the response text in the UI when a response stream retry occurs (#7416)
This commit is contained in:
@@ -616,6 +616,9 @@ export const useGeminiStream = (
|
||||
// before we add loop detected message to history
|
||||
loopDetectedRef.current = true;
|
||||
break;
|
||||
case ServerGeminiEventType.Retry:
|
||||
// TODO: discard the partial response text already shown for the failed attempt when a retry occurs
|
||||
break;
|
||||
default: {
|
||||
// enforces exhaustive switch-case
|
||||
const unreachable: never = event;
|
||||
|
||||
@@ -24,6 +24,7 @@ import {
|
||||
getErrorStatus,
|
||||
MCPServerConfig,
|
||||
DiscoveredMCPTool,
|
||||
StreamEventType,
|
||||
} from '@google/gemini-cli-core';
|
||||
import * as acp from './acp.js';
|
||||
import { AcpFileSystemService } from './fileSystemService.js';
|
||||
@@ -269,8 +270,12 @@ class Session {
|
||||
return { stopReason: 'cancelled' };
|
||||
}
|
||||
|
||||
if (resp.candidates && resp.candidates.length > 0) {
|
||||
const candidate = resp.candidates[0];
|
||||
if (
|
||||
resp.type === StreamEventType.CHUNK &&
|
||||
resp.value.candidates &&
|
||||
resp.value.candidates.length > 0
|
||||
) {
|
||||
const candidate = resp.value.candidates[0];
|
||||
for (const part of candidate.content?.parts ?? []) {
|
||||
if (!part.text) {
|
||||
continue;
|
||||
@@ -290,8 +295,8 @@ class Session {
|
||||
}
|
||||
}
|
||||
|
||||
if (resp.functionCalls) {
|
||||
functionCalls.push(...resp.functionCalls);
|
||||
if (resp.type === StreamEventType.CHUNK && resp.value.functionCalls) {
|
||||
functionCalls.push(...resp.value.functionCalls);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
|
||||
@@ -12,7 +12,12 @@ import type {
|
||||
Part,
|
||||
GenerateContentResponse,
|
||||
} from '@google/genai';
|
||||
import { GeminiChat, EmptyStreamError } from './geminiChat.js';
|
||||
import {
|
||||
GeminiChat,
|
||||
EmptyStreamError,
|
||||
StreamEventType,
|
||||
type StreamEvent,
|
||||
} from './geminiChat.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import { setSimulate429 } from '../utils/testUtils.js';
|
||||
|
||||
@@ -922,6 +927,42 @@ describe('GeminiChat', () => {
|
||||
});
|
||||
|
||||
describe('sendMessageStream with retries', () => {
|
||||
it('should yield a RETRY event when an invalid stream is encountered', async () => {
|
||||
// ARRANGE: Mock the stream to fail once, then succeed.
|
||||
vi.mocked(mockModelsModule.generateContentStream)
|
||||
.mockImplementationOnce(async () =>
|
||||
// First attempt: An invalid stream with an empty text part.
|
||||
(async function* () {
|
||||
yield {
|
||||
candidates: [{ content: { parts: [{ text: '' }] } }],
|
||||
} as unknown as GenerateContentResponse;
|
||||
})(),
|
||||
)
|
||||
.mockImplementationOnce(async () =>
|
||||
// Second attempt (the retry): A minimal valid stream.
|
||||
(async function* () {
|
||||
yield {
|
||||
candidates: [{ content: { parts: [{ text: 'Success' }] } }],
|
||||
} as unknown as GenerateContentResponse;
|
||||
})(),
|
||||
);
|
||||
|
||||
// ACT: Send a message and collect all events from the stream.
|
||||
const stream = await chat.sendMessageStream(
|
||||
{ message: 'test' },
|
||||
'prompt-id-yield-retry',
|
||||
);
|
||||
const events: StreamEvent[] = [];
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
// ASSERT: Check that a RETRY event was present in the stream's output.
|
||||
const retryEvent = events.find((e) => e.type === StreamEventType.RETRY);
|
||||
|
||||
expect(retryEvent).toBeDefined();
|
||||
expect(retryEvent?.type).toBe(StreamEventType.RETRY);
|
||||
});
|
||||
it('should retry on invalid content, succeed, and report metrics', async () => {
|
||||
// Use mockImplementationOnce to provide a fresh, promise-wrapped generator for each attempt.
|
||||
vi.mocked(mockModelsModule.generateContentStream)
|
||||
@@ -948,7 +989,7 @@ describe('GeminiChat', () => {
|
||||
{ message: 'test' },
|
||||
'prompt-id-retry-success',
|
||||
);
|
||||
const chunks = [];
|
||||
const chunks: StreamEvent[] = [];
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
@@ -958,11 +999,17 @@ describe('GeminiChat', () => {
|
||||
expect(mockLogContentRetry).toHaveBeenCalledTimes(1);
|
||||
expect(mockLogContentRetryFailure).not.toHaveBeenCalled();
|
||||
expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2);
|
||||
|
||||
// Check for a retry event
|
||||
expect(chunks.some((c) => c.type === StreamEventType.RETRY)).toBe(true);
|
||||
|
||||
// Check for the successful content chunk
|
||||
expect(
|
||||
chunks.some(
|
||||
(c) =>
|
||||
c.candidates?.[0]?.content?.parts?.[0]?.text ===
|
||||
'Successful response',
|
||||
c.type === StreamEventType.CHUNK &&
|
||||
c.value.candidates?.[0]?.content?.parts?.[0]?.text ===
|
||||
'Successful response',
|
||||
),
|
||||
).toBe(true);
|
||||
|
||||
@@ -1203,7 +1250,7 @@ describe('GeminiChat', () => {
|
||||
{ message: 'test empty stream' },
|
||||
'prompt-id-empty-stream',
|
||||
);
|
||||
const chunks = [];
|
||||
const chunks: StreamEvent[] = [];
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk);
|
||||
}
|
||||
@@ -1213,8 +1260,9 @@ describe('GeminiChat', () => {
|
||||
expect(
|
||||
chunks.some(
|
||||
(c) =>
|
||||
c.candidates?.[0]?.content?.parts?.[0]?.text ===
|
||||
'Successful response after empty',
|
||||
c.type === StreamEventType.CHUNK &&
|
||||
c.value.candidates?.[0]?.content?.parts?.[0]?.text ===
|
||||
'Successful response after empty',
|
||||
),
|
||||
).toBe(true);
|
||||
|
||||
@@ -1313,4 +1361,67 @@ describe('GeminiChat', () => {
|
||||
}
|
||||
expect(turn4.parts[0].text).toBe('second response');
|
||||
});
|
||||
|
||||
it('should discard valid partial content from a failed attempt upon retry', async () => {
|
||||
// ARRANGE: Mock the stream to fail on the first attempt after yielding some valid content.
|
||||
vi.mocked(mockModelsModule.generateContentStream)
|
||||
.mockImplementationOnce(async () =>
|
||||
// First attempt: yields one valid chunk, then one invalid chunk
|
||||
(async function* () {
|
||||
yield {
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [{ text: 'This valid part should be discarded' }],
|
||||
},
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
yield {
|
||||
candidates: [{ content: { parts: [{ text: '' }] } }], // Invalid chunk triggers retry
|
||||
} as unknown as GenerateContentResponse;
|
||||
})(),
|
||||
)
|
||||
.mockImplementationOnce(async () =>
|
||||
// Second attempt (the retry): succeeds
|
||||
(async function* () {
|
||||
yield {
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [{ text: 'Successful final response' }],
|
||||
},
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
})(),
|
||||
);
|
||||
|
||||
// ACT: Send a message and consume the stream
|
||||
const stream = await chat.sendMessageStream(
|
||||
{ message: 'test' },
|
||||
'prompt-id-discard-test',
|
||||
);
|
||||
const events: StreamEvent[] = [];
|
||||
for await (const event of stream) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
// ASSERT
|
||||
// Check that a retry happened
|
||||
expect(mockModelsModule.generateContentStream).toHaveBeenCalledTimes(2);
|
||||
expect(events.some((e) => e.type === StreamEventType.RETRY)).toBe(true);
|
||||
|
||||
// Check the final recorded history
|
||||
const history = chat.getHistory();
|
||||
expect(history.length).toBe(2); // user turn + final model turn
|
||||
|
||||
const modelTurn = history[1]!;
|
||||
// The model turn should only contain the text from the successful attempt
|
||||
expect(modelTurn!.parts![0]!.text).toBe('Successful final response');
|
||||
// It should NOT contain any text from the failed attempt
|
||||
expect(modelTurn!.parts![0]!.text).not.toContain(
|
||||
'This valid part should be discarded',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -34,6 +34,18 @@ import {
|
||||
InvalidChunkEvent,
|
||||
} from '../telemetry/types.js';
|
||||
|
||||
export enum StreamEventType {
|
||||
/** A regular content chunk from the API. */
|
||||
CHUNK = 'chunk',
|
||||
/** A signal that a retry is about to happen. The UI should discard any partial
|
||||
* content from the attempt that just failed. */
|
||||
RETRY = 'retry',
|
||||
}
|
||||
|
||||
export type StreamEvent =
|
||||
| { type: StreamEventType.CHUNK; value: GenerateContentResponse }
|
||||
| { type: StreamEventType.RETRY };
|
||||
|
||||
/**
|
||||
* Options for retrying due to invalid content from the model.
|
||||
*/
|
||||
@@ -340,7 +352,7 @@ export class GeminiChat {
|
||||
async sendMessageStream(
|
||||
params: SendMessageParameters,
|
||||
prompt_id: string,
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
): Promise<AsyncGenerator<StreamEvent>> {
|
||||
await this.sendPromise;
|
||||
|
||||
let streamDoneResolver: () => void;
|
||||
@@ -367,6 +379,10 @@ export class GeminiChat {
|
||||
attempt++
|
||||
) {
|
||||
try {
|
||||
if (attempt > 0) {
|
||||
yield { type: StreamEventType.RETRY };
|
||||
}
|
||||
|
||||
const stream = await self.makeApiCallAndProcessStream(
|
||||
requestContents,
|
||||
params,
|
||||
@@ -375,7 +391,7 @@ export class GeminiChat {
|
||||
);
|
||||
|
||||
for await (const chunk of stream) {
|
||||
yield chunk;
|
||||
yield { type: StreamEventType.CHUNK, value: chunk };
|
||||
}
|
||||
|
||||
lastError = null;
|
||||
|
||||
@@ -21,7 +21,7 @@ import type {
|
||||
} from './subagent.js';
|
||||
import { Config } from '../config/config.js';
|
||||
import type { ConfigParameters } from '../config/config.js';
|
||||
import { GeminiChat } from './geminiChat.js';
|
||||
import { GeminiChat, StreamEventType } from './geminiChat.js';
|
||||
import { createContentGenerator } from './contentGenerator.js';
|
||||
import { getEnvironmentContext } from '../utils/environmentContext.js';
|
||||
import { executeToolCall } from './nonInteractiveToolExecutor.js';
|
||||
@@ -33,6 +33,7 @@ import type {
|
||||
FunctionCall,
|
||||
FunctionDeclaration,
|
||||
GenerateContentConfig,
|
||||
GenerateContentResponse,
|
||||
} from '@google/genai';
|
||||
import { ToolErrorType } from '../tools/tool-error.js';
|
||||
|
||||
@@ -73,18 +74,33 @@ const createMockStream = (
|
||||
functionCallsList: Array<FunctionCall[] | 'stop'>,
|
||||
) => {
|
||||
let index = 0;
|
||||
return vi.fn().mockImplementation(() => {
|
||||
// This mock now returns a Promise that resolves to the async generator,
|
||||
// matching the new signature for sendMessageStream.
|
||||
return vi.fn().mockImplementation(async () => {
|
||||
const response = functionCallsList[index] || 'stop';
|
||||
index++;
|
||||
|
||||
return (async function* () {
|
||||
if (response === 'stop') {
|
||||
// When stopping, the model might return text, but the subagent logic primarily cares about the absence of functionCalls.
|
||||
yield { text: 'Done.' };
|
||||
} else if (response.length > 0) {
|
||||
yield { functionCalls: response };
|
||||
let mockResponseValue: Partial<GenerateContentResponse>;
|
||||
|
||||
if (response === 'stop' || response.length === 0) {
|
||||
// Simulate a text response for stop/empty conditions.
|
||||
mockResponseValue = {
|
||||
candidates: [{ content: { parts: [{ text: 'Done.' }] } }],
|
||||
};
|
||||
} else {
|
||||
yield { text: 'Done.' }; // Handle empty array also as stop
|
||||
// Simulate a tool call response.
|
||||
mockResponseValue = {
|
||||
candidates: [], // Good practice to include for safety.
|
||||
functionCalls: response,
|
||||
};
|
||||
}
|
||||
|
||||
// The stream must now yield a StreamEvent object of type CHUNK.
|
||||
yield {
|
||||
type: StreamEventType.CHUNK,
|
||||
value: mockResponseValue as GenerateContentResponse,
|
||||
};
|
||||
})();
|
||||
});
|
||||
};
|
||||
|
||||
@@ -20,7 +20,7 @@ import type {
|
||||
FunctionDeclaration,
|
||||
} from '@google/genai';
|
||||
import { Type } from '@google/genai';
|
||||
import { GeminiChat } from './geminiChat.js';
|
||||
import { GeminiChat, StreamEventType } from './geminiChat.js';
|
||||
|
||||
/**
|
||||
* @fileoverview Defines the configuration interfaces for a subagent.
|
||||
@@ -439,12 +439,11 @@ export class SubAgentScope {
|
||||
let textResponse = '';
|
||||
for await (const resp of responseStream) {
|
||||
if (abortController.signal.aborted) return;
|
||||
if (resp.functionCalls) {
|
||||
functionCalls.push(...resp.functionCalls);
|
||||
if (resp.type === StreamEventType.CHUNK && resp.value.functionCalls) {
|
||||
functionCalls.push(...resp.value.functionCalls);
|
||||
}
|
||||
const text = resp.text;
|
||||
if (text) {
|
||||
textResponse += text;
|
||||
if (resp.type === StreamEventType.CHUNK && resp.value.text) {
|
||||
textResponse += resp.value.text;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ import { Turn, GeminiEventType } from './turn.js';
|
||||
import type { GenerateContentResponse, Part, Content } from '@google/genai';
|
||||
import { reportError } from '../utils/errorReporting.js';
|
||||
import type { GeminiChat } from './geminiChat.js';
|
||||
import { StreamEventType } from './geminiChat.js';
|
||||
|
||||
const mockSendMessageStream = vi.fn();
|
||||
const mockGetHistory = vi.fn();
|
||||
@@ -35,6 +36,7 @@ vi.mock('../utils/errorReporting', () => ({
|
||||
reportError: vi.fn(),
|
||||
}));
|
||||
|
||||
// Use the actual implementation from partUtils now that it's provided.
|
||||
vi.mock('../utils/generateContentResponseUtilities', () => ({
|
||||
getResponseText: (resp: GenerateContentResponse) =>
|
||||
resp.candidates?.[0]?.content?.parts?.map((part) => part.text).join('') ||
|
||||
@@ -78,11 +80,17 @@ describe('Turn', () => {
|
||||
it('should yield content events for text parts', async () => {
|
||||
const mockResponseStream = (async function* () {
|
||||
yield {
|
||||
candidates: [{ content: { parts: [{ text: 'Hello' }] } }],
|
||||
} as unknown as GenerateContentResponse;
|
||||
type: StreamEventType.CHUNK,
|
||||
value: {
|
||||
candidates: [{ content: { parts: [{ text: 'Hello' }] } }],
|
||||
} as GenerateContentResponse,
|
||||
};
|
||||
yield {
|
||||
candidates: [{ content: { parts: [{ text: ' world' }] } }],
|
||||
} as unknown as GenerateContentResponse;
|
||||
type: StreamEventType.CHUNK,
|
||||
value: {
|
||||
candidates: [{ content: { parts: [{ text: ' world' }] } }],
|
||||
} as GenerateContentResponse,
|
||||
};
|
||||
})();
|
||||
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
||||
|
||||
@@ -113,16 +121,23 @@ describe('Turn', () => {
|
||||
it('should yield tool_call_request events for function calls', async () => {
|
||||
const mockResponseStream = (async function* () {
|
||||
yield {
|
||||
functionCalls: [
|
||||
{
|
||||
id: 'fc1',
|
||||
name: 'tool1',
|
||||
args: { arg1: 'val1' },
|
||||
isClientInitiated: false,
|
||||
},
|
||||
{ name: 'tool2', args: { arg2: 'val2' }, isClientInitiated: false }, // No ID
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
type: StreamEventType.CHUNK,
|
||||
value: {
|
||||
functionCalls: [
|
||||
{
|
||||
id: 'fc1',
|
||||
name: 'tool1',
|
||||
args: { arg1: 'val1' },
|
||||
isClientInitiated: false,
|
||||
},
|
||||
{
|
||||
name: 'tool2',
|
||||
args: { arg2: 'val2' },
|
||||
isClientInitiated: false,
|
||||
}, // No ID
|
||||
],
|
||||
} as unknown as GenerateContentResponse,
|
||||
};
|
||||
})();
|
||||
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
||||
|
||||
@@ -168,18 +183,24 @@ describe('Turn', () => {
|
||||
const abortController = new AbortController();
|
||||
const mockResponseStream = (async function* () {
|
||||
yield {
|
||||
candidates: [{ content: { parts: [{ text: 'First part' }] } }],
|
||||
} as unknown as GenerateContentResponse;
|
||||
type: StreamEventType.CHUNK,
|
||||
value: {
|
||||
candidates: [{ content: { parts: [{ text: 'First part' }] } }],
|
||||
} as GenerateContentResponse,
|
||||
};
|
||||
abortController.abort();
|
||||
yield {
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [{ text: 'Second part - should not be processed' }],
|
||||
type: StreamEventType.CHUNK,
|
||||
value: {
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [{ text: 'Second part - should not be processed' }],
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
],
|
||||
} as GenerateContentResponse,
|
||||
};
|
||||
})();
|
||||
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
||||
|
||||
@@ -230,79 +251,79 @@ describe('Turn', () => {
|
||||
it('should handle function calls with undefined name or args', async () => {
|
||||
const mockResponseStream = (async function* () {
|
||||
yield {
|
||||
functionCalls: [
|
||||
{ id: 'fc1', name: undefined, args: { arg1: 'val1' } },
|
||||
{ id: 'fc2', name: 'tool2', args: undefined },
|
||||
{ id: 'fc3', name: undefined, args: undefined },
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
type: StreamEventType.CHUNK,
|
||||
value: {
|
||||
candidates: [],
|
||||
functionCalls: [
|
||||
// Add `id` back to the mock to match what the code expects
|
||||
{ id: 'fc1', name: undefined, args: { arg1: 'val1' } },
|
||||
{ id: 'fc2', name: 'tool2', args: undefined },
|
||||
{ id: 'fc3', name: undefined, args: undefined },
|
||||
],
|
||||
},
|
||||
};
|
||||
})();
|
||||
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
||||
|
||||
const events = [];
|
||||
const reqParts: Part[] = [{ text: 'Test undefined tool parts' }];
|
||||
for await (const event of turn.run(
|
||||
reqParts,
|
||||
[{ text: 'Test undefined tool parts' }],
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
expect(events.length).toBe(3);
|
||||
|
||||
// Assertions for each specific tool call event
|
||||
const event1 = events[0] as ServerGeminiToolCallRequestEvent;
|
||||
expect(event1.type).toBe(GeminiEventType.ToolCallRequest);
|
||||
expect(event1.value).toEqual(
|
||||
expect.objectContaining({
|
||||
callId: 'fc1',
|
||||
name: 'undefined_tool_name',
|
||||
args: { arg1: 'val1' },
|
||||
isClientInitiated: false,
|
||||
}),
|
||||
);
|
||||
expect(turn.pendingToolCalls[0]).toEqual(event1.value);
|
||||
expect(event1.value).toMatchObject({
|
||||
callId: 'fc1',
|
||||
name: 'undefined_tool_name',
|
||||
args: { arg1: 'val1' },
|
||||
});
|
||||
|
||||
const event2 = events[1] as ServerGeminiToolCallRequestEvent;
|
||||
expect(event2.type).toBe(GeminiEventType.ToolCallRequest);
|
||||
expect(event2.value).toEqual(
|
||||
expect.objectContaining({
|
||||
callId: 'fc2',
|
||||
name: 'tool2',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
}),
|
||||
);
|
||||
expect(turn.pendingToolCalls[1]).toEqual(event2.value);
|
||||
expect(event2.value).toMatchObject({
|
||||
callId: 'fc2',
|
||||
name: 'tool2',
|
||||
args: {},
|
||||
});
|
||||
|
||||
const event3 = events[2] as ServerGeminiToolCallRequestEvent;
|
||||
expect(event3.type).toBe(GeminiEventType.ToolCallRequest);
|
||||
expect(event3.value).toEqual(
|
||||
expect.objectContaining({
|
||||
callId: 'fc3',
|
||||
name: 'undefined_tool_name',
|
||||
args: {},
|
||||
isClientInitiated: false,
|
||||
}),
|
||||
);
|
||||
expect(turn.pendingToolCalls[2]).toEqual(event3.value);
|
||||
expect(turn.getDebugResponses().length).toBe(1);
|
||||
expect(event3.value).toMatchObject({
|
||||
callId: 'fc3',
|
||||
name: 'undefined_tool_name',
|
||||
args: {},
|
||||
});
|
||||
});
|
||||
|
||||
it('should yield finished event when response has finish reason', async () => {
|
||||
const mockResponseStream = (async function* () {
|
||||
yield {
|
||||
candidates: [
|
||||
{
|
||||
content: { parts: [{ text: 'Partial response' }] },
|
||||
finishReason: 'STOP',
|
||||
type: StreamEventType.CHUNK,
|
||||
value: {
|
||||
candidates: [
|
||||
{
|
||||
content: { parts: [{ text: 'Partial response' }] },
|
||||
finishReason: 'STOP',
|
||||
},
|
||||
],
|
||||
usageMetadata: {
|
||||
promptTokenCount: 17,
|
||||
candidatesTokenCount: 50,
|
||||
cachedContentTokenCount: 10,
|
||||
thoughtsTokenCount: 5,
|
||||
toolUsePromptTokenCount: 2,
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
} as GenerateContentResponse,
|
||||
};
|
||||
})();
|
||||
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
||||
|
||||
const events = [];
|
||||
const reqParts: Part[] = [{ text: 'Test finish reason' }];
|
||||
for await (const event of turn.run(
|
||||
reqParts,
|
||||
[{ text: 'Test finish reason' }],
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
events.push(event);
|
||||
@@ -317,17 +338,20 @@ describe('Turn', () => {
|
||||
it('should yield finished event for MAX_TOKENS finish reason', async () => {
|
||||
const mockResponseStream = (async function* () {
|
||||
yield {
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [
|
||||
{ text: 'This is a long response that was cut off...' },
|
||||
],
|
||||
type: StreamEventType.CHUNK,
|
||||
value: {
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [
|
||||
{ text: 'This is a long response that was cut off...' },
|
||||
],
|
||||
},
|
||||
finishReason: 'MAX_TOKENS',
|
||||
},
|
||||
finishReason: 'MAX_TOKENS',
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
],
|
||||
},
|
||||
};
|
||||
})();
|
||||
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
||||
|
||||
@@ -352,13 +376,16 @@ describe('Turn', () => {
|
||||
it('should yield finished event for SAFETY finish reason', async () => {
|
||||
const mockResponseStream = (async function* () {
|
||||
yield {
|
||||
candidates: [
|
||||
{
|
||||
content: { parts: [{ text: 'Content blocked' }] },
|
||||
finishReason: 'SAFETY',
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
type: StreamEventType.CHUNK,
|
||||
value: {
|
||||
candidates: [
|
||||
{
|
||||
content: { parts: [{ text: 'Content blocked' }] },
|
||||
finishReason: 'SAFETY',
|
||||
},
|
||||
],
|
||||
},
|
||||
};
|
||||
})();
|
||||
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
||||
|
||||
@@ -380,13 +407,18 @@ describe('Turn', () => {
|
||||
it('should not yield finished event when there is no finish reason', async () => {
|
||||
const mockResponseStream = (async function* () {
|
||||
yield {
|
||||
candidates: [
|
||||
{
|
||||
content: { parts: [{ text: 'Response without finish reason' }] },
|
||||
// No finishReason property
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
type: StreamEventType.CHUNK,
|
||||
value: {
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [{ text: 'Response without finish reason' }],
|
||||
},
|
||||
// No finishReason property
|
||||
},
|
||||
],
|
||||
},
|
||||
};
|
||||
})();
|
||||
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
||||
|
||||
@@ -411,21 +443,27 @@ describe('Turn', () => {
|
||||
it('should handle multiple responses with different finish reasons', async () => {
|
||||
const mockResponseStream = (async function* () {
|
||||
yield {
|
||||
candidates: [
|
||||
{
|
||||
content: { parts: [{ text: 'First part' }] },
|
||||
// No finish reason on first response
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
type: StreamEventType.CHUNK,
|
||||
value: {
|
||||
candidates: [
|
||||
{
|
||||
content: { parts: [{ text: 'First part' }] },
|
||||
// No finish reason on first response
|
||||
},
|
||||
],
|
||||
},
|
||||
};
|
||||
yield {
|
||||
candidates: [
|
||||
{
|
||||
content: { parts: [{ text: 'Second part' }] },
|
||||
finishReason: 'OTHER',
|
||||
},
|
||||
],
|
||||
} as unknown as GenerateContentResponse;
|
||||
value: {
|
||||
type: StreamEventType.CHUNK,
|
||||
candidates: [
|
||||
{
|
||||
content: { parts: [{ text: 'Second part' }] },
|
||||
finishReason: 'OTHER',
|
||||
},
|
||||
],
|
||||
},
|
||||
};
|
||||
})();
|
||||
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
||||
|
||||
@@ -470,6 +508,29 @@ describe('Turn', () => {
|
||||
|
||||
expect(reportError).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should yield a Retry event when it receives one from the chat stream', async () => {
|
||||
const mockResponseStream = (async function* () {
|
||||
yield { type: StreamEventType.RETRY };
|
||||
yield {
|
||||
type: StreamEventType.CHUNK,
|
||||
value: {
|
||||
candidates: [{ content: { parts: [{ text: 'Success' }] } }],
|
||||
},
|
||||
};
|
||||
})();
|
||||
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
||||
|
||||
const events = [];
|
||||
for await (const event of turn.run([], new AbortController().signal)) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
expect(events).toEqual([
|
||||
{ type: GeminiEventType.Retry },
|
||||
{ type: GeminiEventType.Content, value: 'Success' },
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getDebugResponses', () => {
|
||||
@@ -481,8 +542,8 @@ describe('Turn', () => {
|
||||
functionCalls: [{ name: 'debugTool' }],
|
||||
} as unknown as GenerateContentResponse;
|
||||
const mockResponseStream = (async function* () {
|
||||
yield resp1;
|
||||
yield resp2;
|
||||
yield { type: StreamEventType.CHUNK, value: resp1 };
|
||||
yield { type: StreamEventType.CHUNK, value: resp2 };
|
||||
})();
|
||||
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
||||
const reqParts: Part[] = [{ text: 'Hi' }];
|
||||
|
||||
@@ -54,8 +54,14 @@ export enum GeminiEventType {
|
||||
MaxSessionTurns = 'max_session_turns',
|
||||
Finished = 'finished',
|
||||
LoopDetected = 'loop_detected',
|
||||
Citation = 'citation',
|
||||
Retry = 'retry',
|
||||
}
|
||||
|
||||
export type ServerGeminiRetryEvent = {
|
||||
type: GeminiEventType.Retry;
|
||||
};
|
||||
|
||||
export interface StructuredError {
|
||||
message: string;
|
||||
status?: number;
|
||||
@@ -175,7 +181,8 @@ export type ServerGeminiStreamEvent =
|
||||
| ServerGeminiThoughtEvent
|
||||
| ServerGeminiMaxSessionTurnsEvent
|
||||
| ServerGeminiFinishedEvent
|
||||
| ServerGeminiLoopDetectedEvent;
|
||||
| ServerGeminiLoopDetectedEvent
|
||||
| ServerGeminiRetryEvent;
|
||||
|
||||
// A turn manages the agentic loop turn within the server context.
|
||||
export class Turn {
|
||||
@@ -197,6 +204,8 @@ export class Turn {
|
||||
signal: AbortSignal,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent> {
|
||||
try {
|
||||
// Note: the stream returned by `sendMessageStream` yields StreamEvent objects, i.e.
|
||||
// { type: StreamEventType.RETRY } or { type: StreamEventType.CHUNK, value: GenerateContentResponse }
|
||||
const responseStream = await this.chat.sendMessageStream(
|
||||
{
|
||||
message: req,
|
||||
@@ -207,12 +216,22 @@ export class Turn {
|
||||
this.prompt_id,
|
||||
);
|
||||
|
||||
for await (const resp of responseStream) {
|
||||
for await (const streamEvent of responseStream) {
|
||||
if (signal?.aborted) {
|
||||
yield { type: GeminiEventType.UserCancelled };
|
||||
// Do not add resp to debugResponses if aborted before processing
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle the new RETRY event
|
||||
if (streamEvent.type === 'retry') {
|
||||
yield { type: GeminiEventType.Retry };
|
||||
continue; // Skip to the next event in the stream
|
||||
}
|
||||
|
||||
// Every non-retry StreamEvent is a CHUNK whose `value` carries the GenerateContentResponse
|
||||
const resp = streamEvent.value as GenerateContentResponse;
|
||||
if (!resp) continue; // Skip if there's no response body
|
||||
|
||||
this.debugResponses.push(resp);
|
||||
|
||||
const thoughtPart = resp.candidates?.[0]?.content?.parts?.[0];
|
||||
@@ -254,6 +273,7 @@ export class Turn {
|
||||
// Check if response was truncated or stopped for various reasons
|
||||
const finishReason = resp.candidates?.[0]?.finishReason;
|
||||
|
||||
// Only yield 'Finished' when the first candidate reports a finishReason.
|
||||
if (finishReason) {
|
||||
this.finishReason = finishReason;
|
||||
yield {
|
||||
|
||||
Reference in New Issue
Block a user