Change the type of ToolResult.responseParts (#6875)

This commit is contained in:
Tommaso Sciortino
2025-08-22 14:12:05 -07:00
committed by GitHub
parent 9a0722625b
commit 75822d3506
13 changed files with 205 additions and 324 deletions

View File

@@ -111,16 +111,7 @@ export async function runNonInteractive(
}
if (toolResponse.responseParts) {
const parts = Array.isArray(toolResponse.responseParts)
? toolResponse.responseParts
: [toolResponse.responseParts];
for (const part of parts) {
if (typeof part === 'string') {
toolResponseParts.push({ text: part });
} else if (part) {
toolResponseParts.push(part);
}
}
toolResponseParts.push(...toolResponse.responseParts);
}
}
currentMessages = [{ role: 'user', parts: toolResponseParts }];

View File

@@ -15,7 +15,7 @@ import {
MockInstance,
} from 'vitest';
import { renderHook, act, waitFor } from '@testing-library/react';
import { useGeminiStream, mergePartListUnions } from './useGeminiStream.js';
import { useGeminiStream } from './useGeminiStream.js';
import { useKeypress } from './useKeypress.js';
import * as atCommandProcessor from './atCommandProcessor.js';
import {
@@ -138,125 +138,6 @@ vi.mock('./slashCommandProcessor.js', () => ({
// --- END MOCKS ---
describe('mergePartListUnions', () => {
it('should merge multiple PartListUnion arrays', () => {
const list1: PartListUnion = [{ text: 'Hello' }];
const list2: PartListUnion = [
{ inlineData: { mimeType: 'image/png', data: 'abc' } },
];
const list3: PartListUnion = [{ text: 'World' }, { text: '!' }];
const result = mergePartListUnions([list1, list2, list3]);
expect(result).toEqual([
{ text: 'Hello' },
{ inlineData: { mimeType: 'image/png', data: 'abc' } },
{ text: 'World' },
{ text: '!' },
]);
});
it('should handle empty arrays in the input list', () => {
const list1: PartListUnion = [{ text: 'First' }];
const list2: PartListUnion = [];
const list3: PartListUnion = [{ text: 'Last' }];
const result = mergePartListUnions([list1, list2, list3]);
expect(result).toEqual([{ text: 'First' }, { text: 'Last' }]);
});
it('should handle a single PartListUnion array', () => {
const list1: PartListUnion = [
{ text: 'One' },
{ inlineData: { mimeType: 'image/jpeg', data: 'xyz' } },
];
const result = mergePartListUnions([list1]);
expect(result).toEqual(list1);
});
it('should return an empty array if all input arrays are empty', () => {
const list1: PartListUnion = [];
const list2: PartListUnion = [];
const result = mergePartListUnions([list1, list2]);
expect(result).toEqual([]);
});
it('should handle input list being empty', () => {
const result = mergePartListUnions([]);
expect(result).toEqual([]);
});
it('should correctly merge when PartListUnion items are single Parts not in arrays', () => {
const part1: Part = { text: 'Single part 1' };
const part2: Part = { inlineData: { mimeType: 'image/gif', data: 'gif' } };
const listContainingSingleParts: PartListUnion[] = [
part1,
[part2],
{ text: 'Another single part' },
];
const result = mergePartListUnions(listContainingSingleParts);
expect(result).toEqual([
{ text: 'Single part 1' },
{ inlineData: { mimeType: 'image/gif', data: 'gif' } },
{ text: 'Another single part' },
]);
});
it('should handle a mix of arrays and single parts, including empty arrays and undefined/null parts if they were possible (though PartListUnion typing restricts this)', () => {
const list1: PartListUnion = [{ text: 'A' }];
const list2: PartListUnion = [];
const part3: Part = { text: 'B' };
const list4: PartListUnion = [
{ text: 'C' },
{ inlineData: { mimeType: 'text/plain', data: 'D' } },
];
const result = mergePartListUnions([list1, list2, part3, list4]);
expect(result).toEqual([
{ text: 'A' },
{ text: 'B' },
{ text: 'C' },
{ inlineData: { mimeType: 'text/plain', data: 'D' } },
]);
});
it('should preserve the order of parts from the input arrays', () => {
const listA: PartListUnion = [{ text: '1' }, { text: '2' }];
const listB: PartListUnion = [{ text: '3' }];
const listC: PartListUnion = [{ text: '4' }, { text: '5' }];
const result = mergePartListUnions([listA, listB, listC]);
expect(result).toEqual([
{ text: '1' },
{ text: '2' },
{ text: '3' },
{ text: '4' },
{ text: '5' },
]);
});
it('should handle cases where some PartListUnion items are single Parts and others are arrays of Parts', () => {
const singlePart1: Part = { text: 'First single' };
const arrayPart1: Part[] = [
{ text: 'Array item 1' },
{ text: 'Array item 2' },
];
const singlePart2: Part = {
inlineData: { mimeType: 'application/json', data: 'e30=' },
}; // {}
const arrayPart2: Part[] = [{ text: 'Last array item' }];
const result = mergePartListUnions([
singlePart1,
arrayPart1,
singlePart2,
arrayPart2,
]);
expect(result).toEqual([
{ text: 'First single' },
{ text: 'Array item 1' },
{ text: 'Array item 2' },
{ inlineData: { mimeType: 'application/json', data: 'e30=' } },
{ text: 'Last array item' },
]);
});
});
// --- Tests for useGeminiStream Hook ---
describe('useGeminiStream', () => {
let mockAddItem: Mock;
@@ -505,12 +386,8 @@ describe('useGeminiStream', () => {
});
it('should submit tool responses when all tool calls are completed and ready', async () => {
const toolCall1ResponseParts: PartListUnion = [
{ text: 'tool 1 final response' },
];
const toolCall2ResponseParts: PartListUnion = [
{ text: 'tool 2 final response' },
];
const toolCall1ResponseParts: Part[] = [{ text: 'tool 1 final response' }];
const toolCall2ResponseParts: Part[] = [{ text: 'tool 2 final response' }];
const completedToolCalls: TrackedToolCall[] = [
{
request: {
@@ -593,10 +470,10 @@ describe('useGeminiStream', () => {
expect(mockSendMessageStream).toHaveBeenCalledTimes(1);
});
const expectedMergedResponse = mergePartListUnions([
toolCall1ResponseParts,
toolCall2ResponseParts,
]);
const expectedMergedResponse = [
...toolCall1ResponseParts,
...toolCall2ResponseParts,
];
expect(mockSendMessageStream).toHaveBeenCalledWith(
expectedMergedResponse,
expect.any(AbortSignal),

View File

@@ -56,18 +56,6 @@ import {
import { useSessionStats } from '../contexts/SessionContext.js';
import { useKeypress } from './useKeypress.js';
export function mergePartListUnions(list: PartListUnion[]): PartListUnion {
const resultParts: PartListUnion = [];
for (const item of list) {
if (Array.isArray(item)) {
resultParts.push(...item);
} else {
resultParts.push(item);
}
}
return resultParts;
}
enum StreamProcessingStatus {
Completed,
UserCancelled,
@@ -805,19 +793,9 @@ export const useGeminiStream = (
if (geminiClient) {
// We need to manually add the function responses to the history
// so the model knows the tools were cancelled.
const responsesToAdd = geminiTools.flatMap(
const combinedParts = geminiTools.flatMap(
(toolCall) => toolCall.response.responseParts,
);
const combinedParts: Part[] = [];
for (const response of responsesToAdd) {
if (Array.isArray(response)) {
combinedParts.push(...response);
} else if (typeof response === 'string') {
combinedParts.push({ text: response });
} else {
combinedParts.push(response);
}
}
geminiClient.addHistory({
role: 'user',
parts: combinedParts,
@@ -831,7 +809,7 @@ export const useGeminiStream = (
return;
}
const responsesToSend: PartListUnion[] = geminiTools.map(
const responsesToSend: Part[] = geminiTools.flatMap(
(toolCall) => toolCall.response.responseParts,
);
const callIdsToMarkAsSubmitted = geminiTools.map(
@@ -850,7 +828,7 @@ export const useGeminiStream = (
}
submitQuery(
mergePartListUnions(responsesToSend),
responsesToSend,
{
isContinuation: true,
},

View File

@@ -239,13 +239,15 @@ describe('useReactToolScheduler in YOLO Mode', () => {
request,
response: expect.objectContaining({
resultDisplay: 'YOLO Formatted tool output',
responseParts: {
functionResponse: {
id: 'yoloCall',
name: 'mockToolRequiresConfirmation',
response: { output: expectedOutput },
responseParts: [
{
functionResponse: {
id: 'yoloCall',
name: 'mockToolRequiresConfirmation',
response: { output: expectedOutput },
},
},
},
],
}),
}),
]);
@@ -388,13 +390,15 @@ describe('useReactToolScheduler', () => {
request,
response: expect.objectContaining({
resultDisplay: 'Formatted tool output',
responseParts: {
functionResponse: {
id: 'call1',
name: 'mockTool',
response: { output: 'Tool output' },
responseParts: [
{
functionResponse: {
id: 'call1',
name: 'mockTool',
response: { output: 'Tool output' },
},
},
},
],
}),
}),
]);
@@ -769,13 +773,15 @@ describe('useReactToolScheduler', () => {
request: requests[0],
response: expect.objectContaining({
resultDisplay: 'Display 1',
responseParts: {
functionResponse: {
id: 'multi1',
name: 'tool1',
response: { output: 'Output 1' },
responseParts: [
{
functionResponse: {
id: 'multi1',
name: 'tool1',
response: { output: 'Output 1' },
},
},
},
],
}),
});
expect(call2Result).toMatchObject({
@@ -783,13 +789,15 @@ describe('useReactToolScheduler', () => {
request: requests[1],
response: expect.objectContaining({
resultDisplay: 'Display 2',
responseParts: {
functionResponse: {
id: 'multi2',
name: 'tool2',
response: { output: 'Output 2' },
responseParts: [
{
functionResponse: {
id: 'multi2',
name: 'tool2',
response: { output: 'Output 2' },
},
},
},
],
}),
});
expect(result.current[0]).toEqual([]);

View File

@@ -26,7 +26,7 @@ import {
import * as acp from './acp.js';
import { AcpFileSystemService } from './fileSystemService.js';
import { Readable, Writable } from 'node:stream';
import { Content, Part, FunctionCall, PartListUnion } from '@google/genai';
import { Content, Part, FunctionCall } from '@google/genai';
import { LoadedSettings, SettingScope } from '../config/settings.js';
import * as fs from 'fs/promises';
import * as path from 'path';
@@ -300,16 +300,7 @@ class Session {
for (const fc of functionCalls) {
const response = await this.runTool(pendingSend.signal, promptId, fc);
const parts = Array.isArray(response) ? response : [response];
for (const part of parts) {
if (typeof part === 'string') {
toolResponseParts.push({ text: part });
} else if (part) {
toolResponseParts.push(part);
}
}
toolResponseParts.push(...response);
}
nextMessage = { role: 'user', parts: toolResponseParts };
@@ -332,7 +323,7 @@ class Session {
abortSignal: AbortSignal,
promptId: string,
fc: FunctionCall,
): Promise<PartListUnion> {
): Promise<Part[]> {
const callId = fc.id ?? `${fc.name}-${Date.now()}`;
const args = (fc.args ?? {}) as Record<string, unknown>;