Refine stream validation to prevent unnecessary retries (#7278)

Victor May
2025-08-28 13:21:27 -04:00
committed by Sandy Tao
parent 4c9746561c
commit dfa9dda1dd
2 changed files with 232 additions and 5 deletions
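
The diff below refines GeminiChat's stream validation: rather than treating any invalid chunk as a failed stream, the new logic fails only when an invalid chunk is followed by further valid content, or when the final chunk is invalid and there is neither a tool call nor a successful finish reason. A rough paraphrase of that rule as a pure function (the helper below is illustrative only; its name and shape are not part of this commit):

// Illustrative decision rule; hypothetical helper, not the committed code.
function streamShouldFail(
  hasReceivedAnyChunk: boolean,
  isStreamInvalid: boolean,
  validChunkAfterInvalid: boolean,
  hasToolCall: boolean,
  lastFinishReason?: string,
): boolean {
  if (!hasReceivedAnyChunk) return true; // nothing arrived at all
  if (!isStreamInvalid) return false; // every chunk was valid
  if (validChunkAfterInvalid) return true; // invalid chunk mid-stream
  if (hasToolCall) return false; // a tool call excuses a trailing empty chunk
  // Otherwise only the trailing chunk was invalid: accept a successful finish.
  return lastFinishReason !== 'STOP' && lastFinishReason !== 'MAX_TOKENS';
}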


@@ -73,6 +73,92 @@ describe('GeminiChat', () => {
});
describe('sendMessage', () => {
it('should retain the initial user message when an automatic function call occurs', async () => {
// 1. Define the user's initial text message. This is the turn that gets dropped by the buggy logic.
const userInitialMessage: Content = {
role: 'user',
parts: [{ text: 'How is the weather in Boston?' }],
};
// 2. Mock the full API response, including the automaticFunctionCallingHistory.
// This history represents the full turn: user asks, model calls tool, tool responds, model answers.
const mockAfcResponse = {
candidates: [
{
content: {
role: 'model',
parts: [
{ text: 'The weather in Boston is 72 degrees and sunny.' },
],
},
},
],
automaticFunctionCallingHistory: [
userInitialMessage, // The user's turn
{
// The model's first response: a tool call
role: 'model',
parts: [
{
functionCall: {
name: 'get_weather',
args: { location: 'Boston' },
},
},
],
},
{
// The tool's response, which has a 'user' role
role: 'user',
parts: [
{
functionResponse: {
name: 'get_weather',
response: { temperature: 72, condition: 'sunny' },
},
},
],
},
],
} as unknown as GenerateContentResponse;
vi.mocked(mockModelsModule.generateContent).mockResolvedValue(
mockAfcResponse,
);
// 3. Action: Send the initial message.
await chat.sendMessage(
{ message: 'How is the weather in Boston?' },
'prompt-id-afc-bug',
);
// 4. Assert: Check the final state of the history.
const history = chat.getHistory();
// Without the fix, history.length would be 3, because the first user message is dropped.
// The correct behavior is for the history to contain all 4 turns.
expect(history.length).toBe(4);
// Crucially, assert that the very first turn in the history matches the user's initial message.
// This is the assertion that fails without the fix.
const firstTurn = history[0]!;
expect(firstTurn.role).toBe('user');
expect(firstTurn?.parts![0]!.text).toBe('How is the weather in Boston?');
// Verify the rest of the history is also correct.
const secondTurn = history[1]!;
expect(secondTurn.role).toBe('model');
expect(secondTurn?.parts![0]!.functionCall).toBeDefined();
const thirdTurn = history[2]!;
expect(thirdTurn.role).toBe('user');
expect(thirdTurn?.parts![0]!.functionResponse).toBeDefined();
const fourthTurn = history[3]!;
expect(fourthTurn.role).toBe('model');
expect(fourthTurn?.parts![0]!.text).toContain('72 degrees and sunny');
});
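
The failure mode this test guards against is the chat recording only the model's final answer and dropping the user's opening turn whenever automaticFunctionCallingHistory is present. A minimal sketch of the intended recording behavior (the helper name and shape are hypothetical; the real logic lives in GeminiChat's history recording):

import type { Content } from '@google/genai';

// When the SDK returns automaticFunctionCallingHistory, those turns already
// include the user's original message, so they must replace the bare user
// input rather than be appended after it.
function turnsToRecord(
  userInput: Content,
  afcHistory: Content[] | undefined,
  modelOutput: Content,
): Content[] {
  const inputTurns = afcHistory?.length ? afcHistory : [userInput];
  return [...inputTurns, modelOutput];
}

With the mocked response above, turnsToRecord yields the four turns the test asserts: user question, model tool call, tool response, and the model's final answer.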
it('should throw an error when attempting to add a user turn after another user turn', async () => {
// 1. Setup: Create a history that already ends with a user turn (a functionResponse).
const initialHistory: Content[] = [
@@ -240,6 +326,110 @@ describe('GeminiChat', () => {
});
describe('sendMessageStream', () => {
it('should succeed if a tool call is followed by an empty part', async () => {
// 1. Mock a stream that contains a tool call, then an invalid (empty) part.
const streamWithToolCall = (async function* () {
yield {
candidates: [
{
content: {
role: 'model',
parts: [{ functionCall: { name: 'test_tool', args: {} } }],
},
},
],
} as unknown as GenerateContentResponse;
// This second chunk is invalid according to isValidResponse
yield {
candidates: [
{
content: {
role: 'model',
parts: [{ text: '' }],
},
},
],
} as unknown as GenerateContentResponse;
})();
vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue(
streamWithToolCall,
);
// 2. Action & Assert: The stream processing should complete without throwing an error
// because the presence of a tool call makes the empty final chunk acceptable.
const stream = await chat.sendMessageStream(
{ message: 'test message' },
'prompt-id-tool-call-empty-end',
);
await expect(
(async () => {
for await (const _ of stream) {
/* consume stream */
}
})(),
).resolves.not.toThrow();
// 3. Verify history was recorded correctly
const history = chat.getHistory();
expect(history.length).toBe(2); // user turn + model turn
const modelTurn = history[1]!;
expect(modelTurn?.parts?.length).toBe(1); // The empty part is discarded
expect(modelTurn?.parts![0]!.functionCall).toBeDefined();
});
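
Both stream tests hinge on what isValidResponse rejects. The real predicate is imported by GeminiChat; the sketch below is only an assumption of its shape, sufficient to explain why a chunk whose only part is an empty text string counts as invalid:

import type { GenerateContentResponse } from '@google/genai';

// Assumed shape: a chunk is valid when it carries at least one meaningful
// part, i.e. a function call or non-empty text.
function isValidChunkSketch(chunk: GenerateContentResponse): boolean {
  const parts = chunk.candidates?.[0]?.content?.parts ?? [];
  return parts.some(
    (part) => part.functionCall !== undefined || (part.text ?? '') !== '',
  );
}

This also explains the "empty part is discarded" assertions: in the implementation below, parts are pushed into modelResponseParts only for chunks that pass the validity check, so an empty trailing part never reaches the recorded history.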
it('should succeed if the stream ends with an empty part but has a valid finishReason', async () => {
// 1. Mock a stream that ends with an invalid part but has a 'STOP' finish reason.
const streamWithValidFinish = (async function* () {
yield {
candidates: [
{
content: {
role: 'model',
parts: [{ text: 'Initial content...' }],
},
},
],
} as unknown as GenerateContentResponse;
// This second chunk is invalid, but its 'STOP' finishReason means the stream must not be treated as failed and retried.
yield {
candidates: [
{
content: {
role: 'model',
parts: [{ text: '' }],
},
finishReason: 'STOP',
},
],
} as unknown as GenerateContentResponse;
})();
vi.mocked(mockModelsModule.generateContentStream).mockResolvedValue(
streamWithValidFinish,
);
// 2. Action & Assert: The stream should complete successfully because the valid
// finishReason overrides the invalid final chunk.
const stream = await chat.sendMessageStream(
{ message: 'test message' },
'prompt-id-valid-finish-empty-end',
);
await expect(
(async () => {
for await (const _ of stream) {
/* consume stream */
}
})(),
).resolves.not.toThrow();
// 3. Verify history was recorded correctly
const history = chat.getHistory();
expect(history.length).toBe(2);
const modelTurn = history[1]!;
expect(modelTurn?.parts?.length).toBe(1); // The empty part is discarded
expect(modelTurn?.parts![0]!.text).toBe('Initial content...');
});
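
Taken together, the two tests above pin down the non-retry paths: a trailing empty chunk is tolerated either after a tool call or when the final chunk reports a successful finish reason. A consumer simply drains the stream, as in this sketch (assuming a GeminiChat instance constructed as in these tests, with a hypothetical prompt id):

const stream = await chat.sendMessageStream(
  { message: 'How is the weather in Boston?' },
  'prompt-id-sketch',
);
for await (const _ of stream) {
  // Drain the stream. Chunks are yielded as they arrive, and any
  // EmptyStreamError surfaces only once the stream has ended, so partial
  // output is still delivered to the consumer.
}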
it('should not consolidate text into a part that also contains a functionCall', async () => {
// 1. Mock the API to stream a malformed part followed by a valid text part.
const multiChunkStream = (async function* () {


@@ -562,15 +562,30 @@ export class GeminiChat {
userInput: Content,
): AsyncGenerator<GenerateContentResponse> {
const modelResponseParts: Part[] = [];
let hasReceivedAnyChunk = false;
let hasToolCall = false;
let lastChunk: GenerateContentResponse | null = null;
let isStreamInvalid = false;
let firstInvalidChunkEncountered = false;
let validChunkAfterInvalidEncountered = false;
for await (const chunk of streamResponse) {
hasReceivedAnyChunk = true;
lastChunk = chunk;
if (isValidResponse(chunk)) {
if (firstInvalidChunkEncountered) {
// A valid chunk appeared *after* an invalid one.
validChunkAfterInvalidEncountered = true;
}
const content = chunk.candidates?.[0]?.content;
if (content?.parts) {
modelResponseParts.push(...content.parts);
if (content.parts.some((part) => part.functionCall)) {
hasToolCall = true;
}
}
} else {
logInvalidChunk(
@@ -578,14 +593,36 @@ export class GeminiChat {
new InvalidChunkEvent('Invalid chunk received from stream.'),
);
isStreamInvalid = true;
firstInvalidChunkEncountered = true;
}
yield chunk;
}
if (!hasReceivedAnyChunk) {
throw new EmptyStreamError('Model stream completed without any chunks.');
}
// --- FIX: The entire validation block was restructured for clarity and correctness ---
// Only apply complex validation if an invalid chunk was actually found.
if (isStreamInvalid) {
// Fail immediately if an invalid chunk was not the absolute last chunk.
if (validChunkAfterInvalidEncountered) {
throw new EmptyStreamError(
'Model stream had invalid intermediate chunks without a tool call.',
);
}
if (!hasToolCall) {
// If the *only* invalid part was the last chunk, we still check its finish reason.
const finishReason = lastChunk?.candidates?.[0]?.finishReason;
const isSuccessfulFinish =
finishReason === 'STOP' || finishReason === 'MAX_TOKENS';
if (!isSuccessfulFinish) {
throw new EmptyStreamError(
'Model stream ended with an invalid chunk and a failed finish reason.',
);
}
}
}
// Bundle all streamed parts into a single Content object
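
The bundling step that follows this comment is not shown in the hunk; a plausible shape, assuming the surrounding recording logic accepts a single Content:

const modelOutput: Content = {
  role: 'model',
  parts: modelResponseParts,
};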