feat: useToolScheduler hook to manage parallel tool calls (#448)

This commit is contained in:
Brandon Keiji
2025-05-22 05:57:53 +00:00
committed by GitHub
parent efee7c6cce
commit 02eec5c8ca
6 changed files with 109 additions and 369 deletions

View File

@@ -134,7 +134,7 @@ export const App = ({
cliVersion,
);
const { streamingState, submitQuery, initError, pendingHistoryItem } =
const { streamingState, submitQuery, initError, pendingHistoryItems } =
useGeminiStream(
addItem,
refreshStatic,
@@ -209,7 +209,7 @@ export const App = ({
}, [terminalHeight, footerHeight]);
useEffect(() => {
if (!pendingHistoryItem) {
if (!pendingHistoryItems.length) {
return;
}
@@ -223,7 +223,7 @@ export const App = ({
if (pendingItemDimensions.height > availableTerminalHeight) {
setStaticNeedsRefresh(true);
}
}, [pendingHistoryItem, availableTerminalHeight, streamingState]);
}, [pendingHistoryItems.length, availableTerminalHeight, streamingState]);
useEffect(() => {
if (streamingState === StreamingState.Idle && staticNeedsRefresh) {
@@ -264,17 +264,18 @@ export const App = ({
>
{(item) => item}
</Static>
{pendingHistoryItem && (
<Box ref={pendingHistoryItemRef}>
<Box ref={pendingHistoryItemRef}>
{pendingHistoryItems.map((item, i) => (
<HistoryItemDisplay
key={i}
availableTerminalHeight={availableTerminalHeight}
// TODO(taehykim): It seems like references to ids aren't necessary in
// HistoryItemDisplay. Refactor later. Use a fake id for now.
item={{ ...pendingHistoryItem, id: 0 }}
item={{ ...item, id: 0 }}
isPending={true}
/>
</Box>
)}
))}
</Box>
{showHelp && <Help commands={slashCommands} />}
<Box flexDirection="column" ref={mainControlsRef}>

View File

@@ -4,7 +4,7 @@
* SPDX-License-Identifier: Apache-2.0
*/
import React from 'react';
import React, { useMemo } from 'react';
import { Box } from 'ink';
import { IndividualToolCallDisplay, ToolCallStatus } from '../../types.js';
import { ToolMessage } from './ToolMessage.js';
@@ -19,7 +19,6 @@ interface ToolGroupMessageProps {
// Main component renders the border and maps the tools using ToolMessage
export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({
groupId,
toolCalls,
availableTerminalHeight,
}) => {
@@ -30,9 +29,13 @@ export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({
const staticHeight = /* border */ 2 + /* marginBottom */ 1;
const toolAwaitingApproval = useMemo(
() => toolCalls.find((tc) => tc.status === ToolCallStatus.Confirming),
[toolCalls],
);
return (
<Box
key={groupId}
flexDirection="column"
borderStyle="round"
/*
@@ -48,7 +51,7 @@ export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({
marginBottom={1}
>
{toolCalls.map((tool) => (
<Box key={groupId + '-' + tool.callId} flexDirection="column">
<Box key={tool.callId} flexDirection="column">
<ToolMessage
key={tool.callId}
callId={tool.callId}
@@ -60,6 +63,7 @@ export const ToolGroupMessage: React.FC<ToolGroupMessageProps> = ({
availableTerminalHeight={availableTerminalHeight - staticHeight}
/>
{tool.status === ToolCallStatus.Confirming &&
tool.callId === toolAwaitingApproval?.callId &&
tool.confirmationDetails && (
<ToolConfirmationMessage
confirmationDetails={tool.confirmationDetails}

View File

@@ -4,34 +4,28 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { useState, useRef, useCallback, useEffect } from 'react';
import { useState, useRef, useCallback, useEffect, useMemo } from 'react';
import { useInput } from 'ink';
import {
GeminiClient,
GeminiEventType as ServerGeminiEventType,
ServerGeminiStreamEvent as GeminiEvent,
ServerGeminiContentEvent as ContentEvent,
ServerGeminiToolCallRequestEvent as ToolCallRequestEvent,
ServerGeminiToolCallResponseEvent as ToolCallResponseEvent,
ServerGeminiToolCallConfirmationEvent as ToolCallConfirmationEvent,
ServerGeminiErrorEvent as ErrorEvent,
getErrorMessage,
isNodeError,
Config,
MessageSenderType,
ServerToolCallConfirmationDetails,
ToolCallConfirmationDetails,
ToolCallResponseInfo,
ToolConfirmationOutcome,
ToolEditConfirmationDetails,
ToolExecuteConfirmationDetails,
ToolResultDisplay,
partListUnionToString,
ToolCallRequestInfo,
} from '@gemini-code/server';
import { type Chat, type PartListUnion, type Part } from '@google/genai';
import {
StreamingState,
IndividualToolCallDisplay,
ToolCallStatus,
HistoryItemWithoutId,
HistoryItemToolGroup,
@@ -44,6 +38,7 @@ import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js';
import { useStateAndRef } from './useStateAndRef.js';
import { UseHistoryManagerReturn } from './useHistoryManager.js';
import { useLogger } from './useLogger.js';
import { useToolScheduler, mapToDisplay } from './useToolScheduler.js';
enum StreamProcessingStatus {
Completed,
@@ -65,7 +60,6 @@ export const useGeminiStream = (
handleSlashCommand: (cmd: PartListUnion) => boolean,
shellModeActive: boolean,
) => {
const toolRegistry = config.getToolRegistry();
const [initError, setInitError] = useState<string | null>(null);
const abortControllerRef = useRef<AbortController | null>(null);
const chatSessionRef = useRef<Chat | null>(null);
@@ -74,6 +68,25 @@ export const useGeminiStream = (
const [pendingHistoryItemRef, setPendingHistoryItem] =
useStateAndRef<HistoryItemWithoutId | null>(null);
const logger = useLogger();
const [toolCalls, schedule, cancel] = useToolScheduler((tools) => {
if (tools.length) {
addItem(mapToDisplay(tools), Date.now());
submitQuery(
tools
.filter(
(t) =>
t.status === 'error' ||
t.status === 'cancelled' ||
t.status === 'success',
)
.map((t) => t.response.responsePart),
);
}
}, config);
const pendingToolCalls = useMemo(
() => (toolCalls.length ? mapToDisplay(toolCalls) : undefined),
[toolCalls],
);
const onExec = useCallback(async (done: Promise<void>) => {
setIsResponding(true);
@@ -104,6 +117,7 @@ export const useGeminiStream = (
useInput((_input, key) => {
if (streamingState !== StreamingState.Idle && key.escape) {
abortControllerRef.current?.abort();
cancel();
}
});
@@ -215,157 +229,48 @@ export const useGeminiStream = (
);
};
const updateConfirmingFunctionStatusUI = (
callId: string,
confirmationDetails: ToolCallConfirmationDetails | undefined,
) => {
setPendingHistoryItem((item) =>
item?.type === 'tool_group'
? {
...item,
tools: item.tools.map((tool) =>
tool.callId === callId
? {
...tool,
status: ToolCallStatus.Confirming,
confirmationDetails,
}
: tool,
),
}
: item,
);
};
const wireConfirmationSubmission = (
confirmationDetails: ServerToolCallConfirmationDetails,
): ToolCallConfirmationDetails => {
const originalConfirmationDetails = confirmationDetails.details;
const request = confirmationDetails.request;
const resubmittingConfirm = async (outcome: ToolConfirmationOutcome) => {
originalConfirmationDetails.onConfirm(outcome);
if (pendingHistoryItemRef?.current?.type === 'tool_group') {
setPendingHistoryItem((item) =>
item?.type === 'tool_group'
? {
...item,
tools: item.tools.map((tool) =>
tool.callId === request.callId
? {
...tool,
confirmationDetails: undefined,
status: ToolCallStatus.Executing,
}
: tool,
),
}
: item,
);
refreshStatic();
}
if (outcome === ToolConfirmationOutcome.Cancel) {
declineToolExecution(
'User rejected function call.',
ToolCallStatus.Error,
request,
originalConfirmationDetails,
);
} else {
const tool = toolRegistry.getTool(request.name);
if (!tool) {
throw new Error(
`Tool "${request.name}" not found or is not registered.`,
);
}
try {
abortControllerRef.current = new AbortController();
const result = await tool.execute(
request.args,
abortControllerRef.current.signal,
);
if (abortControllerRef.current.signal.aborted) {
declineToolExecution(
partListUnionToString(result.llmContent),
ToolCallStatus.Canceled,
request,
originalConfirmationDetails,
);
return;
}
const functionResponse: Part = {
functionResponse: {
name: request.name,
id: request.callId,
response: { output: result.llmContent },
},
};
const responseInfo: ToolCallResponseInfo = {
callId: request.callId,
responsePart: functionResponse,
resultDisplay: result.returnDisplay,
error: undefined,
};
updateFunctionResponseUI(responseInfo, ToolCallStatus.Success);
if (pendingHistoryItemRef.current) {
addItem(pendingHistoryItemRef.current, Date.now());
setPendingHistoryItem(null);
}
setIsResponding(false);
await submitQuery(functionResponse); // Recursive call
} finally {
if (streamingState !== StreamingState.WaitingForConfirmation) {
abortControllerRef.current = null;
}
}
}
};
// Extracted declineToolExecution to be part of wireConfirmationSubmission's closure
// or could be a standalone helper if more params are passed.
function declineToolExecution(
declineMessage: string,
status: ToolCallStatus,
request: ServerToolCallConfirmationDetails['request'],
originalDetails: ServerToolCallConfirmationDetails['details'],
) {
let resultDisplay: ToolResultDisplay | undefined;
if ('fileDiff' in originalDetails) {
resultDisplay = {
fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff,
fileName: (originalDetails as ToolEditConfirmationDetails).fileName,
};
} else {
resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`;
}
const functionResponse: Part = {
functionResponse: {
id: request.callId,
name: request.name,
response: { error: declineMessage },
},
// Extracted declineToolExecution to be part of wireConfirmationSubmission's closure
// or could be a standalone helper if more params are passed.
// TODO: handle file diff result display stuff
function _declineToolExecution(
declineMessage: string,
status: ToolCallStatus,
request: ServerToolCallConfirmationDetails['request'],
originalDetails: ServerToolCallConfirmationDetails['details'],
) {
let resultDisplay: ToolResultDisplay | undefined;
if ('fileDiff' in originalDetails) {
resultDisplay = {
fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff,
fileName: (originalDetails as ToolEditConfirmationDetails).fileName,
};
const responseInfo: ToolCallResponseInfo = {
callId: request.callId,
responsePart: functionResponse,
resultDisplay,
error: new Error(declineMessage),
};
const history = chatSessionRef.current?.getHistory();
if (history) {
history.push({ role: 'model', parts: [functionResponse] });
}
updateFunctionResponseUI(responseInfo, status);
if (pendingHistoryItemRef.current) {
addItem(pendingHistoryItemRef.current, Date.now());
setPendingHistoryItem(null);
}
setIsResponding(false);
} else {
resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`;
}
return { ...originalConfirmationDetails, onConfirm: resubmittingConfirm };
};
const functionResponse: Part = {
functionResponse: {
id: request.callId,
name: request.name,
response: { error: declineMessage },
},
};
const responseInfo: ToolCallResponseInfo = {
callId: request.callId,
responsePart: functionResponse,
resultDisplay,
error: new Error(declineMessage),
};
const history = chatSessionRef.current?.getHistory();
if (history) {
history.push({ role: 'model', parts: [functionResponse] });
}
updateFunctionResponseUI(responseInfo, status);
if (pendingHistoryItemRef.current) {
addItem(pendingHistoryItemRef.current, Date.now());
setPendingHistoryItem(null);
}
setIsResponding(false);
}
// --- Stream Event Handlers ---
const handleContentEvent = (
@@ -419,62 +324,6 @@ export const useGeminiStream = (
return newGeminiMessageBuffer;
};
const handleToolCallRequestEvent = (
eventValue: ToolCallRequestEvent['value'],
userMessageTimestamp: number,
) => {
const { callId, name, args } = eventValue;
const cliTool = toolRegistry.getTool(name);
if (!cliTool) {
console.error(`CLI Tool "${name}" not found!`);
return; // Skip this event if tool is not found
}
if (pendingHistoryItemRef.current?.type !== 'tool_group') {
if (pendingHistoryItemRef.current) {
addItem(pendingHistoryItemRef.current, userMessageTimestamp);
}
setPendingHistoryItem({ type: 'tool_group', tools: [] });
}
let description: string;
try {
description = cliTool.getDescription(args);
} catch (e) {
description = `Error: Unable to get description: ${getErrorMessage(e)}`;
}
const toolCallDisplay: IndividualToolCallDisplay = {
callId,
name: cliTool.displayName,
description,
status: ToolCallStatus.Pending,
resultDisplay: undefined,
confirmationDetails: undefined,
};
setPendingHistoryItem((pending) =>
pending?.type === 'tool_group'
? { ...pending, tools: [...pending.tools, toolCallDisplay] }
: null,
);
};
const handleToolCallResponseEvent = (
eventValue: ToolCallResponseEvent['value'],
) => {
const status = eventValue.error
? ToolCallStatus.Error
: ToolCallStatus.Success;
updateFunctionResponseUI(eventValue, status);
};
const handleToolCallConfirmationEvent = (
eventValue: ToolCallConfirmationEvent['value'],
) => {
const confirmationDetails = wireConfirmationSubmission(eventValue);
updateConfirmingFunctionStatusUI(
eventValue.request.callId,
confirmationDetails,
);
};
const handleUserCancelledEvent = (userMessageTimestamp: number) => {
if (pendingHistoryItemRef.current) {
if (pendingHistoryItemRef.current.type === 'tool_group') {
@@ -500,6 +349,7 @@ export const useGeminiStream = (
userMessageTimestamp,
);
setIsResponding(false);
cancel();
};
const handleErrorEvent = (
@@ -521,7 +371,7 @@ export const useGeminiStream = (
userMessageTimestamp: number,
): Promise<StreamProcessingStatus> => {
let geminiMessageBuffer = '';
const toolCallRequests: ToolCallRequestInfo[] = [];
for await (const event of stream) {
if (event.type === ServerGeminiEventType.Content) {
geminiMessageBuffer = handleContentEvent(
@@ -530,12 +380,7 @@ export const useGeminiStream = (
userMessageTimestamp,
);
} else if (event.type === ServerGeminiEventType.ToolCallRequest) {
handleToolCallRequestEvent(event.value, userMessageTimestamp);
} else if (event.type === ServerGeminiEventType.ToolCallResponse) {
handleToolCallResponseEvent(event.value);
} else if (event.type === ServerGeminiEventType.ToolCallConfirmation) {
handleToolCallConfirmationEvent(event.value);
return StreamProcessingStatus.PausedForConfirmation;
toolCallRequests.push(event.value);
} else if (event.type === ServerGeminiEventType.UserCancelled) {
handleUserCancelledEvent(userMessageTimestamp);
return StreamProcessingStatus.UserCancelled;
@@ -544,9 +389,18 @@ export const useGeminiStream = (
return StreamProcessingStatus.Error;
}
}
schedule(toolCallRequests);
return StreamProcessingStatus.Completed;
};
const streamingState: StreamingState = isResponding
? StreamingState.Responding
: pendingToolCalls?.tools.some(
(t) => t.status === ToolCallStatus.Confirming,
)
? StreamingState.WaitingForConfirmation
: StreamingState.Idle;
const submitQuery = useCallback(
async (query: PartListUnion) => {
if (isResponding) return;
@@ -625,20 +479,15 @@ export const useGeminiStream = (
],
);
const streamingState: StreamingState = isResponding
? StreamingState.Responding
: pendingConfirmations(pendingHistoryItemRef.current)
? StreamingState.WaitingForConfirmation
: StreamingState.Idle;
const pendingHistoryItems = [
pendingHistoryItemRef.current,
pendingToolCalls,
].filter((i) => i !== undefined && i !== null);
return {
streamingState,
submitQuery,
initError,
pendingHistoryItem: pendingHistoryItemRef.current,
pendingHistoryItems,
};
};
const pendingConfirmations = (item: HistoryItemWithoutId | null): boolean =>
item?.type === 'tool_group' &&
item.tools.some((t) => t.status === ToolCallStatus.Confirming);