mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-20 08:47:44 +00:00
feat: useToolScheduler hook to manage parallel tool calls (#448)
This commit is contained in:
@@ -4,34 +4,28 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { useState, useRef, useCallback, useEffect } from 'react';
|
||||
import { useState, useRef, useCallback, useEffect, useMemo } from 'react';
|
||||
import { useInput } from 'ink';
|
||||
import {
|
||||
GeminiClient,
|
||||
GeminiEventType as ServerGeminiEventType,
|
||||
ServerGeminiStreamEvent as GeminiEvent,
|
||||
ServerGeminiContentEvent as ContentEvent,
|
||||
ServerGeminiToolCallRequestEvent as ToolCallRequestEvent,
|
||||
ServerGeminiToolCallResponseEvent as ToolCallResponseEvent,
|
||||
ServerGeminiToolCallConfirmationEvent as ToolCallConfirmationEvent,
|
||||
ServerGeminiErrorEvent as ErrorEvent,
|
||||
getErrorMessage,
|
||||
isNodeError,
|
||||
Config,
|
||||
MessageSenderType,
|
||||
ServerToolCallConfirmationDetails,
|
||||
ToolCallConfirmationDetails,
|
||||
ToolCallResponseInfo,
|
||||
ToolConfirmationOutcome,
|
||||
ToolEditConfirmationDetails,
|
||||
ToolExecuteConfirmationDetails,
|
||||
ToolResultDisplay,
|
||||
partListUnionToString,
|
||||
ToolCallRequestInfo,
|
||||
} from '@gemini-code/server';
|
||||
import { type Chat, type PartListUnion, type Part } from '@google/genai';
|
||||
import {
|
||||
StreamingState,
|
||||
IndividualToolCallDisplay,
|
||||
ToolCallStatus,
|
||||
HistoryItemWithoutId,
|
||||
HistoryItemToolGroup,
|
||||
@@ -44,6 +38,7 @@ import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js';
|
||||
import { useStateAndRef } from './useStateAndRef.js';
|
||||
import { UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||
import { useLogger } from './useLogger.js';
|
||||
import { useToolScheduler, mapToDisplay } from './useToolScheduler.js';
|
||||
|
||||
enum StreamProcessingStatus {
|
||||
Completed,
|
||||
@@ -65,7 +60,6 @@ export const useGeminiStream = (
|
||||
handleSlashCommand: (cmd: PartListUnion) => boolean,
|
||||
shellModeActive: boolean,
|
||||
) => {
|
||||
const toolRegistry = config.getToolRegistry();
|
||||
const [initError, setInitError] = useState<string | null>(null);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
const chatSessionRef = useRef<Chat | null>(null);
|
||||
@@ -74,6 +68,25 @@ export const useGeminiStream = (
|
||||
const [pendingHistoryItemRef, setPendingHistoryItem] =
|
||||
useStateAndRef<HistoryItemWithoutId | null>(null);
|
||||
const logger = useLogger();
|
||||
const [toolCalls, schedule, cancel] = useToolScheduler((tools) => {
|
||||
if (tools.length) {
|
||||
addItem(mapToDisplay(tools), Date.now());
|
||||
submitQuery(
|
||||
tools
|
||||
.filter(
|
||||
(t) =>
|
||||
t.status === 'error' ||
|
||||
t.status === 'cancelled' ||
|
||||
t.status === 'success',
|
||||
)
|
||||
.map((t) => t.response.responsePart),
|
||||
);
|
||||
}
|
||||
}, config);
|
||||
const pendingToolCalls = useMemo(
|
||||
() => (toolCalls.length ? mapToDisplay(toolCalls) : undefined),
|
||||
[toolCalls],
|
||||
);
|
||||
|
||||
const onExec = useCallback(async (done: Promise<void>) => {
|
||||
setIsResponding(true);
|
||||
@@ -104,6 +117,7 @@ export const useGeminiStream = (
|
||||
useInput((_input, key) => {
|
||||
if (streamingState !== StreamingState.Idle && key.escape) {
|
||||
abortControllerRef.current?.abort();
|
||||
cancel();
|
||||
}
|
||||
});
|
||||
|
||||
@@ -215,157 +229,48 @@ export const useGeminiStream = (
|
||||
);
|
||||
};
|
||||
|
||||
const updateConfirmingFunctionStatusUI = (
|
||||
callId: string,
|
||||
confirmationDetails: ToolCallConfirmationDetails | undefined,
|
||||
) => {
|
||||
setPendingHistoryItem((item) =>
|
||||
item?.type === 'tool_group'
|
||||
? {
|
||||
...item,
|
||||
tools: item.tools.map((tool) =>
|
||||
tool.callId === callId
|
||||
? {
|
||||
...tool,
|
||||
status: ToolCallStatus.Confirming,
|
||||
confirmationDetails,
|
||||
}
|
||||
: tool,
|
||||
),
|
||||
}
|
||||
: item,
|
||||
);
|
||||
};
|
||||
|
||||
const wireConfirmationSubmission = (
|
||||
confirmationDetails: ServerToolCallConfirmationDetails,
|
||||
): ToolCallConfirmationDetails => {
|
||||
const originalConfirmationDetails = confirmationDetails.details;
|
||||
const request = confirmationDetails.request;
|
||||
const resubmittingConfirm = async (outcome: ToolConfirmationOutcome) => {
|
||||
originalConfirmationDetails.onConfirm(outcome);
|
||||
if (pendingHistoryItemRef?.current?.type === 'tool_group') {
|
||||
setPendingHistoryItem((item) =>
|
||||
item?.type === 'tool_group'
|
||||
? {
|
||||
...item,
|
||||
tools: item.tools.map((tool) =>
|
||||
tool.callId === request.callId
|
||||
? {
|
||||
...tool,
|
||||
confirmationDetails: undefined,
|
||||
status: ToolCallStatus.Executing,
|
||||
}
|
||||
: tool,
|
||||
),
|
||||
}
|
||||
: item,
|
||||
);
|
||||
refreshStatic();
|
||||
}
|
||||
|
||||
if (outcome === ToolConfirmationOutcome.Cancel) {
|
||||
declineToolExecution(
|
||||
'User rejected function call.',
|
||||
ToolCallStatus.Error,
|
||||
request,
|
||||
originalConfirmationDetails,
|
||||
);
|
||||
} else {
|
||||
const tool = toolRegistry.getTool(request.name);
|
||||
if (!tool) {
|
||||
throw new Error(
|
||||
`Tool "${request.name}" not found or is not registered.`,
|
||||
);
|
||||
}
|
||||
try {
|
||||
abortControllerRef.current = new AbortController();
|
||||
const result = await tool.execute(
|
||||
request.args,
|
||||
abortControllerRef.current.signal,
|
||||
);
|
||||
if (abortControllerRef.current.signal.aborted) {
|
||||
declineToolExecution(
|
||||
partListUnionToString(result.llmContent),
|
||||
ToolCallStatus.Canceled,
|
||||
request,
|
||||
originalConfirmationDetails,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const functionResponse: Part = {
|
||||
functionResponse: {
|
||||
name: request.name,
|
||||
id: request.callId,
|
||||
response: { output: result.llmContent },
|
||||
},
|
||||
};
|
||||
const responseInfo: ToolCallResponseInfo = {
|
||||
callId: request.callId,
|
||||
responsePart: functionResponse,
|
||||
resultDisplay: result.returnDisplay,
|
||||
error: undefined,
|
||||
};
|
||||
updateFunctionResponseUI(responseInfo, ToolCallStatus.Success);
|
||||
if (pendingHistoryItemRef.current) {
|
||||
addItem(pendingHistoryItemRef.current, Date.now());
|
||||
setPendingHistoryItem(null);
|
||||
}
|
||||
setIsResponding(false);
|
||||
await submitQuery(functionResponse); // Recursive call
|
||||
} finally {
|
||||
if (streamingState !== StreamingState.WaitingForConfirmation) {
|
||||
abortControllerRef.current = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Extracted declineToolExecution to be part of wireConfirmationSubmission's closure
|
||||
// or could be a standalone helper if more params are passed.
|
||||
function declineToolExecution(
|
||||
declineMessage: string,
|
||||
status: ToolCallStatus,
|
||||
request: ServerToolCallConfirmationDetails['request'],
|
||||
originalDetails: ServerToolCallConfirmationDetails['details'],
|
||||
) {
|
||||
let resultDisplay: ToolResultDisplay | undefined;
|
||||
if ('fileDiff' in originalDetails) {
|
||||
resultDisplay = {
|
||||
fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff,
|
||||
fileName: (originalDetails as ToolEditConfirmationDetails).fileName,
|
||||
};
|
||||
} else {
|
||||
resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`;
|
||||
}
|
||||
const functionResponse: Part = {
|
||||
functionResponse: {
|
||||
id: request.callId,
|
||||
name: request.name,
|
||||
response: { error: declineMessage },
|
||||
},
|
||||
// Extracted declineToolExecution to be part of wireConfirmationSubmission's closure
|
||||
// or could be a standalone helper if more params are passed.
|
||||
// TODO: handle file diff result display stuff
|
||||
function _declineToolExecution(
|
||||
declineMessage: string,
|
||||
status: ToolCallStatus,
|
||||
request: ServerToolCallConfirmationDetails['request'],
|
||||
originalDetails: ServerToolCallConfirmationDetails['details'],
|
||||
) {
|
||||
let resultDisplay: ToolResultDisplay | undefined;
|
||||
if ('fileDiff' in originalDetails) {
|
||||
resultDisplay = {
|
||||
fileDiff: (originalDetails as ToolEditConfirmationDetails).fileDiff,
|
||||
fileName: (originalDetails as ToolEditConfirmationDetails).fileName,
|
||||
};
|
||||
const responseInfo: ToolCallResponseInfo = {
|
||||
callId: request.callId,
|
||||
responsePart: functionResponse,
|
||||
resultDisplay,
|
||||
error: new Error(declineMessage),
|
||||
};
|
||||
const history = chatSessionRef.current?.getHistory();
|
||||
if (history) {
|
||||
history.push({ role: 'model', parts: [functionResponse] });
|
||||
}
|
||||
updateFunctionResponseUI(responseInfo, status);
|
||||
if (pendingHistoryItemRef.current) {
|
||||
addItem(pendingHistoryItemRef.current, Date.now());
|
||||
setPendingHistoryItem(null);
|
||||
}
|
||||
setIsResponding(false);
|
||||
} else {
|
||||
resultDisplay = `~~${(originalDetails as ToolExecuteConfirmationDetails).command}~~`;
|
||||
}
|
||||
|
||||
return { ...originalConfirmationDetails, onConfirm: resubmittingConfirm };
|
||||
};
|
||||
const functionResponse: Part = {
|
||||
functionResponse: {
|
||||
id: request.callId,
|
||||
name: request.name,
|
||||
response: { error: declineMessage },
|
||||
},
|
||||
};
|
||||
const responseInfo: ToolCallResponseInfo = {
|
||||
callId: request.callId,
|
||||
responsePart: functionResponse,
|
||||
resultDisplay,
|
||||
error: new Error(declineMessage),
|
||||
};
|
||||
const history = chatSessionRef.current?.getHistory();
|
||||
if (history) {
|
||||
history.push({ role: 'model', parts: [functionResponse] });
|
||||
}
|
||||
updateFunctionResponseUI(responseInfo, status);
|
||||
if (pendingHistoryItemRef.current) {
|
||||
addItem(pendingHistoryItemRef.current, Date.now());
|
||||
setPendingHistoryItem(null);
|
||||
}
|
||||
setIsResponding(false);
|
||||
}
|
||||
|
||||
// --- Stream Event Handlers ---
|
||||
const handleContentEvent = (
|
||||
@@ -419,62 +324,6 @@ export const useGeminiStream = (
|
||||
return newGeminiMessageBuffer;
|
||||
};
|
||||
|
||||
const handleToolCallRequestEvent = (
|
||||
eventValue: ToolCallRequestEvent['value'],
|
||||
userMessageTimestamp: number,
|
||||
) => {
|
||||
const { callId, name, args } = eventValue;
|
||||
const cliTool = toolRegistry.getTool(name);
|
||||
if (!cliTool) {
|
||||
console.error(`CLI Tool "${name}" not found!`);
|
||||
return; // Skip this event if tool is not found
|
||||
}
|
||||
if (pendingHistoryItemRef.current?.type !== 'tool_group') {
|
||||
if (pendingHistoryItemRef.current) {
|
||||
addItem(pendingHistoryItemRef.current, userMessageTimestamp);
|
||||
}
|
||||
setPendingHistoryItem({ type: 'tool_group', tools: [] });
|
||||
}
|
||||
let description: string;
|
||||
try {
|
||||
description = cliTool.getDescription(args);
|
||||
} catch (e) {
|
||||
description = `Error: Unable to get description: ${getErrorMessage(e)}`;
|
||||
}
|
||||
const toolCallDisplay: IndividualToolCallDisplay = {
|
||||
callId,
|
||||
name: cliTool.displayName,
|
||||
description,
|
||||
status: ToolCallStatus.Pending,
|
||||
resultDisplay: undefined,
|
||||
confirmationDetails: undefined,
|
||||
};
|
||||
setPendingHistoryItem((pending) =>
|
||||
pending?.type === 'tool_group'
|
||||
? { ...pending, tools: [...pending.tools, toolCallDisplay] }
|
||||
: null,
|
||||
);
|
||||
};
|
||||
|
||||
const handleToolCallResponseEvent = (
|
||||
eventValue: ToolCallResponseEvent['value'],
|
||||
) => {
|
||||
const status = eventValue.error
|
||||
? ToolCallStatus.Error
|
||||
: ToolCallStatus.Success;
|
||||
updateFunctionResponseUI(eventValue, status);
|
||||
};
|
||||
|
||||
const handleToolCallConfirmationEvent = (
|
||||
eventValue: ToolCallConfirmationEvent['value'],
|
||||
) => {
|
||||
const confirmationDetails = wireConfirmationSubmission(eventValue);
|
||||
updateConfirmingFunctionStatusUI(
|
||||
eventValue.request.callId,
|
||||
confirmationDetails,
|
||||
);
|
||||
};
|
||||
|
||||
const handleUserCancelledEvent = (userMessageTimestamp: number) => {
|
||||
if (pendingHistoryItemRef.current) {
|
||||
if (pendingHistoryItemRef.current.type === 'tool_group') {
|
||||
@@ -500,6 +349,7 @@ export const useGeminiStream = (
|
||||
userMessageTimestamp,
|
||||
);
|
||||
setIsResponding(false);
|
||||
cancel();
|
||||
};
|
||||
|
||||
const handleErrorEvent = (
|
||||
@@ -521,7 +371,7 @@ export const useGeminiStream = (
|
||||
userMessageTimestamp: number,
|
||||
): Promise<StreamProcessingStatus> => {
|
||||
let geminiMessageBuffer = '';
|
||||
|
||||
const toolCallRequests: ToolCallRequestInfo[] = [];
|
||||
for await (const event of stream) {
|
||||
if (event.type === ServerGeminiEventType.Content) {
|
||||
geminiMessageBuffer = handleContentEvent(
|
||||
@@ -530,12 +380,7 @@ export const useGeminiStream = (
|
||||
userMessageTimestamp,
|
||||
);
|
||||
} else if (event.type === ServerGeminiEventType.ToolCallRequest) {
|
||||
handleToolCallRequestEvent(event.value, userMessageTimestamp);
|
||||
} else if (event.type === ServerGeminiEventType.ToolCallResponse) {
|
||||
handleToolCallResponseEvent(event.value);
|
||||
} else if (event.type === ServerGeminiEventType.ToolCallConfirmation) {
|
||||
handleToolCallConfirmationEvent(event.value);
|
||||
return StreamProcessingStatus.PausedForConfirmation;
|
||||
toolCallRequests.push(event.value);
|
||||
} else if (event.type === ServerGeminiEventType.UserCancelled) {
|
||||
handleUserCancelledEvent(userMessageTimestamp);
|
||||
return StreamProcessingStatus.UserCancelled;
|
||||
@@ -544,9 +389,18 @@ export const useGeminiStream = (
|
||||
return StreamProcessingStatus.Error;
|
||||
}
|
||||
}
|
||||
schedule(toolCallRequests);
|
||||
return StreamProcessingStatus.Completed;
|
||||
};
|
||||
|
||||
const streamingState: StreamingState = isResponding
|
||||
? StreamingState.Responding
|
||||
: pendingToolCalls?.tools.some(
|
||||
(t) => t.status === ToolCallStatus.Confirming,
|
||||
)
|
||||
? StreamingState.WaitingForConfirmation
|
||||
: StreamingState.Idle;
|
||||
|
||||
const submitQuery = useCallback(
|
||||
async (query: PartListUnion) => {
|
||||
if (isResponding) return;
|
||||
@@ -625,20 +479,15 @@ export const useGeminiStream = (
|
||||
],
|
||||
);
|
||||
|
||||
const streamingState: StreamingState = isResponding
|
||||
? StreamingState.Responding
|
||||
: pendingConfirmations(pendingHistoryItemRef.current)
|
||||
? StreamingState.WaitingForConfirmation
|
||||
: StreamingState.Idle;
|
||||
const pendingHistoryItems = [
|
||||
pendingHistoryItemRef.current,
|
||||
pendingToolCalls,
|
||||
].filter((i) => i !== undefined && i !== null);
|
||||
|
||||
return {
|
||||
streamingState,
|
||||
submitQuery,
|
||||
initError,
|
||||
pendingHistoryItem: pendingHistoryItemRef.current,
|
||||
pendingHistoryItems,
|
||||
};
|
||||
};
|
||||
|
||||
const pendingConfirmations = (item: HistoryItemWithoutId | null): boolean =>
|
||||
item?.type === 'tool_group' &&
|
||||
item.tools.some((t) => t.status === ToolCallStatus.Confirming);
|
||||
|
||||
Reference in New Issue
Block a user