mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-19 09:33:53 +00:00
pre-release commit
This commit is contained in:
888
packages/cli/src/ui/hooks/useGeminiStream.ts
Normal file
888
packages/cli/src/ui/hooks/useGeminiStream.ts
Normal file
@@ -0,0 +1,888 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { useState, useRef, useCallback, useEffect, useMemo } from 'react';
|
||||
import { useInput } from 'ink';
|
||||
import {
|
||||
Config,
|
||||
GeminiClient,
|
||||
GeminiEventType as ServerGeminiEventType,
|
||||
ServerGeminiStreamEvent as GeminiEvent,
|
||||
ServerGeminiContentEvent as ContentEvent,
|
||||
ServerGeminiErrorEvent as ErrorEvent,
|
||||
ServerGeminiChatCompressedEvent,
|
||||
getErrorMessage,
|
||||
isNodeError,
|
||||
MessageSenderType,
|
||||
ToolCallRequestInfo,
|
||||
logUserPrompt,
|
||||
GitService,
|
||||
EditorType,
|
||||
ThoughtSummary,
|
||||
UnauthorizedError,
|
||||
UserPromptEvent,
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
} from '@qwen/qwen-code-core';
|
||||
import { type Part, type PartListUnion } from '@google/genai';
|
||||
import {
|
||||
StreamingState,
|
||||
HistoryItem,
|
||||
HistoryItemWithoutId,
|
||||
HistoryItemToolGroup,
|
||||
MessageType,
|
||||
SlashCommandProcessorResult,
|
||||
ToolCallStatus,
|
||||
} from '../types.js';
|
||||
import { isAtCommand } from '../utils/commandUtils.js';
|
||||
import { parseAndFormatApiError } from '../utils/errorParsing.js';
|
||||
import { useShellCommandProcessor } from './shellCommandProcessor.js';
|
||||
import { handleAtCommand } from './atCommandProcessor.js';
|
||||
import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js';
|
||||
import { useStateAndRef } from './useStateAndRef.js';
|
||||
import { UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||
import { useLogger } from './useLogger.js';
|
||||
import { promises as fs } from 'fs';
|
||||
import path from 'path';
|
||||
import {
|
||||
useReactToolScheduler,
|
||||
mapToDisplay as mapTrackedToolCallsToDisplay,
|
||||
TrackedToolCall,
|
||||
TrackedCompletedToolCall,
|
||||
TrackedCancelledToolCall,
|
||||
} from './useReactToolScheduler.js';
|
||||
import { useSessionStats } from '../contexts/SessionContext.js';
|
||||
|
||||
export function mergePartListUnions(list: PartListUnion[]): PartListUnion {
|
||||
const resultParts: PartListUnion = [];
|
||||
for (const item of list) {
|
||||
if (Array.isArray(item)) {
|
||||
resultParts.push(...item);
|
||||
} else {
|
||||
resultParts.push(item);
|
||||
}
|
||||
}
|
||||
return resultParts;
|
||||
}
|
||||
|
||||
enum StreamProcessingStatus {
|
||||
Completed,
|
||||
UserCancelled,
|
||||
Error,
|
||||
}
|
||||
|
||||
/**
|
||||
* Manages the Gemini stream, including user input, command processing,
|
||||
* API interaction, and tool call lifecycle.
|
||||
*/
|
||||
export const useGeminiStream = (
|
||||
geminiClient: GeminiClient,
|
||||
history: HistoryItem[],
|
||||
addItem: UseHistoryManagerReturn['addItem'],
|
||||
setShowHelp: React.Dispatch<React.SetStateAction<boolean>>,
|
||||
config: Config,
|
||||
onDebugMessage: (message: string) => void,
|
||||
handleSlashCommand: (
|
||||
cmd: PartListUnion,
|
||||
) => Promise<SlashCommandProcessorResult | false>,
|
||||
shellModeActive: boolean,
|
||||
getPreferredEditor: () => EditorType | undefined,
|
||||
onAuthError: () => void,
|
||||
performMemoryRefresh: () => Promise<void>,
|
||||
modelSwitchedFromQuotaError: boolean,
|
||||
setModelSwitchedFromQuotaError: React.Dispatch<React.SetStateAction<boolean>>,
|
||||
) => {
|
||||
const [initError, setInitError] = useState<string | null>(null);
|
||||
const abortControllerRef = useRef<AbortController | null>(null);
|
||||
const turnCancelledRef = useRef(false);
|
||||
const [isResponding, setIsResponding] = useState<boolean>(false);
|
||||
const [thought, setThought] = useState<ThoughtSummary | null>(null);
|
||||
const [pendingHistoryItemRef, setPendingHistoryItem] =
|
||||
useStateAndRef<HistoryItemWithoutId | null>(null);
|
||||
const processedMemoryToolsRef = useRef<Set<string>>(new Set());
|
||||
const { startNewPrompt, getPromptCount } = useSessionStats();
|
||||
const logger = useLogger();
|
||||
const gitService = useMemo(() => {
|
||||
if (!config.getProjectRoot()) {
|
||||
return;
|
||||
}
|
||||
return new GitService(config.getProjectRoot());
|
||||
}, [config]);
|
||||
|
||||
const [toolCalls, scheduleToolCalls, markToolsAsSubmitted] =
|
||||
useReactToolScheduler(
|
||||
async (completedToolCallsFromScheduler) => {
|
||||
// This onComplete is called when ALL scheduled tools for a given batch are done.
|
||||
if (completedToolCallsFromScheduler.length > 0) {
|
||||
// Add the final state of these tools to the history for display.
|
||||
addItem(
|
||||
mapTrackedToolCallsToDisplay(
|
||||
completedToolCallsFromScheduler as TrackedToolCall[],
|
||||
),
|
||||
Date.now(),
|
||||
);
|
||||
|
||||
// Handle tool response submission immediately when tools complete
|
||||
await handleCompletedTools(
|
||||
completedToolCallsFromScheduler as TrackedToolCall[],
|
||||
);
|
||||
}
|
||||
},
|
||||
config,
|
||||
setPendingHistoryItem,
|
||||
getPreferredEditor,
|
||||
);
|
||||
|
||||
const pendingToolCallGroupDisplay = useMemo(
|
||||
() =>
|
||||
toolCalls.length ? mapTrackedToolCallsToDisplay(toolCalls) : undefined,
|
||||
[toolCalls],
|
||||
);
|
||||
|
||||
const loopDetectedRef = useRef(false);
|
||||
|
||||
const onExec = useCallback(async (done: Promise<void>) => {
|
||||
setIsResponding(true);
|
||||
await done;
|
||||
setIsResponding(false);
|
||||
}, []);
|
||||
const { handleShellCommand } = useShellCommandProcessor(
|
||||
addItem,
|
||||
setPendingHistoryItem,
|
||||
onExec,
|
||||
onDebugMessage,
|
||||
config,
|
||||
geminiClient,
|
||||
);
|
||||
|
||||
const streamingState = useMemo(() => {
|
||||
if (toolCalls.some((tc) => tc.status === 'awaiting_approval')) {
|
||||
return StreamingState.WaitingForConfirmation;
|
||||
}
|
||||
if (
|
||||
isResponding ||
|
||||
toolCalls.some(
|
||||
(tc) =>
|
||||
tc.status === 'executing' ||
|
||||
tc.status === 'scheduled' ||
|
||||
tc.status === 'validating' ||
|
||||
((tc.status === 'success' ||
|
||||
tc.status === 'error' ||
|
||||
tc.status === 'cancelled') &&
|
||||
!(tc as TrackedCompletedToolCall | TrackedCancelledToolCall)
|
||||
.responseSubmittedToGemini),
|
||||
)
|
||||
) {
|
||||
return StreamingState.Responding;
|
||||
}
|
||||
return StreamingState.Idle;
|
||||
}, [isResponding, toolCalls]);
|
||||
|
||||
useInput((_input, key) => {
|
||||
if (streamingState === StreamingState.Responding && key.escape) {
|
||||
if (turnCancelledRef.current) {
|
||||
return;
|
||||
}
|
||||
turnCancelledRef.current = true;
|
||||
abortControllerRef.current?.abort();
|
||||
if (pendingHistoryItemRef.current) {
|
||||
addItem(pendingHistoryItemRef.current, Date.now());
|
||||
}
|
||||
addItem(
|
||||
{
|
||||
type: MessageType.INFO,
|
||||
text: 'Request cancelled.',
|
||||
},
|
||||
Date.now(),
|
||||
);
|
||||
setPendingHistoryItem(null);
|
||||
setIsResponding(false);
|
||||
}
|
||||
});
|
||||
|
||||
const prepareQueryForGemini = useCallback(
|
||||
async (
|
||||
query: PartListUnion,
|
||||
userMessageTimestamp: number,
|
||||
abortSignal: AbortSignal,
|
||||
prompt_id: string,
|
||||
): Promise<{
|
||||
queryToSend: PartListUnion | null;
|
||||
shouldProceed: boolean;
|
||||
}> => {
|
||||
if (turnCancelledRef.current) {
|
||||
return { queryToSend: null, shouldProceed: false };
|
||||
}
|
||||
if (typeof query === 'string' && query.trim().length === 0) {
|
||||
return { queryToSend: null, shouldProceed: false };
|
||||
}
|
||||
|
||||
let localQueryToSendToGemini: PartListUnion | null = null;
|
||||
|
||||
if (typeof query === 'string') {
|
||||
const trimmedQuery = query.trim();
|
||||
logUserPrompt(
|
||||
config,
|
||||
new UserPromptEvent(
|
||||
trimmedQuery.length,
|
||||
prompt_id,
|
||||
config.getContentGeneratorConfig()?.authType,
|
||||
trimmedQuery,
|
||||
),
|
||||
);
|
||||
onDebugMessage(`User query: '${trimmedQuery}'`);
|
||||
await logger?.logMessage(MessageSenderType.USER, trimmedQuery);
|
||||
|
||||
// Handle UI-only commands first
|
||||
const slashCommandResult = await handleSlashCommand(trimmedQuery);
|
||||
|
||||
if (slashCommandResult) {
|
||||
if (slashCommandResult.type === 'schedule_tool') {
|
||||
const { toolName, toolArgs } = slashCommandResult;
|
||||
const toolCallRequest: ToolCallRequestInfo = {
|
||||
callId: `${toolName}-${Date.now()}-${Math.random().toString(16).slice(2)}`,
|
||||
name: toolName,
|
||||
args: toolArgs,
|
||||
isClientInitiated: true,
|
||||
prompt_id,
|
||||
};
|
||||
scheduleToolCalls([toolCallRequest], abortSignal);
|
||||
}
|
||||
|
||||
return { queryToSend: null, shouldProceed: false };
|
||||
}
|
||||
|
||||
if (shellModeActive && handleShellCommand(trimmedQuery, abortSignal)) {
|
||||
return { queryToSend: null, shouldProceed: false };
|
||||
}
|
||||
|
||||
// Handle @-commands (which might involve tool calls)
|
||||
if (isAtCommand(trimmedQuery)) {
|
||||
const atCommandResult = await handleAtCommand({
|
||||
query: trimmedQuery,
|
||||
config,
|
||||
addItem,
|
||||
onDebugMessage,
|
||||
messageId: userMessageTimestamp,
|
||||
signal: abortSignal,
|
||||
});
|
||||
if (!atCommandResult.shouldProceed) {
|
||||
return { queryToSend: null, shouldProceed: false };
|
||||
}
|
||||
localQueryToSendToGemini = atCommandResult.processedQuery;
|
||||
} else {
|
||||
// Normal query for Gemini
|
||||
addItem(
|
||||
{ type: MessageType.USER, text: trimmedQuery },
|
||||
userMessageTimestamp,
|
||||
);
|
||||
localQueryToSendToGemini = trimmedQuery;
|
||||
}
|
||||
} else {
|
||||
// It's a function response (PartListUnion that isn't a string)
|
||||
localQueryToSendToGemini = query;
|
||||
}
|
||||
|
||||
if (localQueryToSendToGemini === null) {
|
||||
onDebugMessage(
|
||||
'Query processing resulted in null, not sending to Gemini.',
|
||||
);
|
||||
return { queryToSend: null, shouldProceed: false };
|
||||
}
|
||||
return { queryToSend: localQueryToSendToGemini, shouldProceed: true };
|
||||
},
|
||||
[
|
||||
config,
|
||||
addItem,
|
||||
onDebugMessage,
|
||||
handleShellCommand,
|
||||
handleSlashCommand,
|
||||
logger,
|
||||
shellModeActive,
|
||||
scheduleToolCalls,
|
||||
],
|
||||
);
|
||||
|
||||
// --- Stream Event Handlers ---
|
||||
|
||||
const handleContentEvent = useCallback(
|
||||
(
|
||||
eventValue: ContentEvent['value'],
|
||||
currentGeminiMessageBuffer: string,
|
||||
userMessageTimestamp: number,
|
||||
): string => {
|
||||
if (turnCancelledRef.current) {
|
||||
// Prevents additional output after a user initiated cancel.
|
||||
return '';
|
||||
}
|
||||
let newGeminiMessageBuffer = currentGeminiMessageBuffer + eventValue;
|
||||
if (
|
||||
pendingHistoryItemRef.current?.type !== 'gemini' &&
|
||||
pendingHistoryItemRef.current?.type !== 'gemini_content'
|
||||
) {
|
||||
if (pendingHistoryItemRef.current) {
|
||||
addItem(pendingHistoryItemRef.current, userMessageTimestamp);
|
||||
}
|
||||
setPendingHistoryItem({ type: 'gemini', text: '' });
|
||||
newGeminiMessageBuffer = eventValue;
|
||||
}
|
||||
// Split large messages for better rendering performance. Ideally,
|
||||
// we should maximize the amount of output sent to <Static />.
|
||||
const splitPoint = findLastSafeSplitPoint(newGeminiMessageBuffer);
|
||||
if (splitPoint === newGeminiMessageBuffer.length) {
|
||||
// Update the existing message with accumulated content
|
||||
setPendingHistoryItem((item) => ({
|
||||
type: item?.type as 'gemini' | 'gemini_content',
|
||||
text: newGeminiMessageBuffer,
|
||||
}));
|
||||
} else {
|
||||
// This indicates that we need to split up this Gemini Message.
|
||||
// Splitting a message is primarily a performance consideration. There is a
|
||||
// <Static> component at the root of App.tsx which takes care of rendering
|
||||
// content statically or dynamically. Everything but the last message is
|
||||
// treated as static in order to prevent re-rendering an entire message history
|
||||
// multiple times per-second (as streaming occurs). Prior to this change you'd
|
||||
// see heavy flickering of the terminal. This ensures that larger messages get
|
||||
// broken up so that there are more "statically" rendered.
|
||||
const beforeText = newGeminiMessageBuffer.substring(0, splitPoint);
|
||||
const afterText = newGeminiMessageBuffer.substring(splitPoint);
|
||||
addItem(
|
||||
{
|
||||
type: pendingHistoryItemRef.current?.type as
|
||||
| 'gemini'
|
||||
| 'gemini_content',
|
||||
text: beforeText,
|
||||
},
|
||||
userMessageTimestamp,
|
||||
);
|
||||
setPendingHistoryItem({ type: 'gemini_content', text: afterText });
|
||||
newGeminiMessageBuffer = afterText;
|
||||
}
|
||||
return newGeminiMessageBuffer;
|
||||
},
|
||||
[addItem, pendingHistoryItemRef, setPendingHistoryItem],
|
||||
);
|
||||
|
||||
const handleUserCancelledEvent = useCallback(
|
||||
(userMessageTimestamp: number) => {
|
||||
if (turnCancelledRef.current) {
|
||||
return;
|
||||
}
|
||||
if (pendingHistoryItemRef.current) {
|
||||
if (pendingHistoryItemRef.current.type === 'tool_group') {
|
||||
const updatedTools = pendingHistoryItemRef.current.tools.map(
|
||||
(tool) =>
|
||||
tool.status === ToolCallStatus.Pending ||
|
||||
tool.status === ToolCallStatus.Confirming ||
|
||||
tool.status === ToolCallStatus.Executing
|
||||
? { ...tool, status: ToolCallStatus.Canceled }
|
||||
: tool,
|
||||
);
|
||||
const pendingItem: HistoryItemToolGroup = {
|
||||
...pendingHistoryItemRef.current,
|
||||
tools: updatedTools,
|
||||
};
|
||||
addItem(pendingItem, userMessageTimestamp);
|
||||
} else {
|
||||
addItem(pendingHistoryItemRef.current, userMessageTimestamp);
|
||||
}
|
||||
setPendingHistoryItem(null);
|
||||
}
|
||||
addItem(
|
||||
{ type: MessageType.INFO, text: 'User cancelled the request.' },
|
||||
userMessageTimestamp,
|
||||
);
|
||||
setIsResponding(false);
|
||||
},
|
||||
[addItem, pendingHistoryItemRef, setPendingHistoryItem],
|
||||
);
|
||||
|
||||
const handleErrorEvent = useCallback(
|
||||
(eventValue: ErrorEvent['value'], userMessageTimestamp: number) => {
|
||||
if (pendingHistoryItemRef.current) {
|
||||
addItem(pendingHistoryItemRef.current, userMessageTimestamp);
|
||||
setPendingHistoryItem(null);
|
||||
}
|
||||
addItem(
|
||||
{
|
||||
type: MessageType.ERROR,
|
||||
text: parseAndFormatApiError(
|
||||
eventValue.error,
|
||||
config.getContentGeneratorConfig()?.authType,
|
||||
undefined,
|
||||
config.getModel(),
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
),
|
||||
},
|
||||
userMessageTimestamp,
|
||||
);
|
||||
},
|
||||
[addItem, pendingHistoryItemRef, setPendingHistoryItem, config],
|
||||
);
|
||||
|
||||
const handleChatCompressionEvent = useCallback(
|
||||
(eventValue: ServerGeminiChatCompressedEvent['value']) =>
|
||||
addItem(
|
||||
{
|
||||
type: 'info',
|
||||
text:
|
||||
`IMPORTANT: This conversation approached the input token limit for ${config.getModel()}. ` +
|
||||
`A compressed context will be sent for future messages (compressed from: ` +
|
||||
`${eventValue?.originalTokenCount ?? 'unknown'} to ` +
|
||||
`${eventValue?.newTokenCount ?? 'unknown'} tokens).`,
|
||||
},
|
||||
Date.now(),
|
||||
),
|
||||
[addItem, config],
|
||||
);
|
||||
|
||||
const handleMaxSessionTurnsEvent = useCallback(
|
||||
() =>
|
||||
addItem(
|
||||
{
|
||||
type: 'info',
|
||||
text:
|
||||
`The session has reached the maximum number of turns: ${config.getMaxSessionTurns()}. ` +
|
||||
`Please update this limit in your setting.json file.`,
|
||||
},
|
||||
Date.now(),
|
||||
),
|
||||
[addItem, config],
|
||||
);
|
||||
|
||||
const handleLoopDetectedEvent = useCallback(() => {
|
||||
addItem(
|
||||
{
|
||||
type: 'info',
|
||||
text: `A potential loop was detected. This can happen due to repetitive tool calls or other model behavior. The request has been halted.`,
|
||||
},
|
||||
Date.now(),
|
||||
);
|
||||
}, [addItem]);
|
||||
|
||||
const processGeminiStreamEvents = useCallback(
|
||||
async (
|
||||
stream: AsyncIterable<GeminiEvent>,
|
||||
userMessageTimestamp: number,
|
||||
signal: AbortSignal,
|
||||
): Promise<StreamProcessingStatus> => {
|
||||
let geminiMessageBuffer = '';
|
||||
const toolCallRequests: ToolCallRequestInfo[] = [];
|
||||
for await (const event of stream) {
|
||||
switch (event.type) {
|
||||
case ServerGeminiEventType.Thought:
|
||||
setThought(event.value);
|
||||
break;
|
||||
case ServerGeminiEventType.Content:
|
||||
geminiMessageBuffer = handleContentEvent(
|
||||
event.value,
|
||||
geminiMessageBuffer,
|
||||
userMessageTimestamp,
|
||||
);
|
||||
break;
|
||||
case ServerGeminiEventType.ToolCallRequest:
|
||||
toolCallRequests.push(event.value);
|
||||
break;
|
||||
case ServerGeminiEventType.UserCancelled:
|
||||
handleUserCancelledEvent(userMessageTimestamp);
|
||||
break;
|
||||
case ServerGeminiEventType.Error:
|
||||
handleErrorEvent(event.value, userMessageTimestamp);
|
||||
break;
|
||||
case ServerGeminiEventType.ChatCompressed:
|
||||
handleChatCompressionEvent(event.value);
|
||||
break;
|
||||
case ServerGeminiEventType.ToolCallConfirmation:
|
||||
case ServerGeminiEventType.ToolCallResponse:
|
||||
// do nothing
|
||||
break;
|
||||
case ServerGeminiEventType.MaxSessionTurns:
|
||||
handleMaxSessionTurnsEvent();
|
||||
break;
|
||||
case ServerGeminiEventType.LoopDetected:
|
||||
// handle later because we want to move pending history to history
|
||||
// before we add loop detected message to history
|
||||
loopDetectedRef.current = true;
|
||||
break;
|
||||
default: {
|
||||
// enforces exhaustive switch-case
|
||||
const unreachable: never = event;
|
||||
return unreachable;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (toolCallRequests.length > 0) {
|
||||
scheduleToolCalls(toolCallRequests, signal);
|
||||
}
|
||||
return StreamProcessingStatus.Completed;
|
||||
},
|
||||
[
|
||||
handleContentEvent,
|
||||
handleUserCancelledEvent,
|
||||
handleErrorEvent,
|
||||
scheduleToolCalls,
|
||||
handleChatCompressionEvent,
|
||||
handleMaxSessionTurnsEvent,
|
||||
],
|
||||
);
|
||||
|
||||
const submitQuery = useCallback(
|
||||
async (
|
||||
query: PartListUnion,
|
||||
options?: { isContinuation: boolean },
|
||||
prompt_id?: string,
|
||||
) => {
|
||||
if (
|
||||
(streamingState === StreamingState.Responding ||
|
||||
streamingState === StreamingState.WaitingForConfirmation) &&
|
||||
!options?.isContinuation
|
||||
)
|
||||
return;
|
||||
|
||||
const userMessageTimestamp = Date.now();
|
||||
setShowHelp(false);
|
||||
|
||||
// Reset quota error flag when starting a new query (not a continuation)
|
||||
if (!options?.isContinuation) {
|
||||
setModelSwitchedFromQuotaError(false);
|
||||
config.setQuotaErrorOccurred(false);
|
||||
}
|
||||
|
||||
abortControllerRef.current = new AbortController();
|
||||
const abortSignal = abortControllerRef.current.signal;
|
||||
turnCancelledRef.current = false;
|
||||
|
||||
if (!prompt_id) {
|
||||
prompt_id = config.getSessionId() + '########' + getPromptCount();
|
||||
}
|
||||
|
||||
const { queryToSend, shouldProceed } = await prepareQueryForGemini(
|
||||
query,
|
||||
userMessageTimestamp,
|
||||
abortSignal,
|
||||
prompt_id!,
|
||||
);
|
||||
|
||||
if (!shouldProceed || queryToSend === null) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!options?.isContinuation) {
|
||||
startNewPrompt();
|
||||
}
|
||||
|
||||
setIsResponding(true);
|
||||
setInitError(null);
|
||||
|
||||
try {
|
||||
const stream = geminiClient.sendMessageStream(
|
||||
queryToSend,
|
||||
abortSignal,
|
||||
prompt_id!,
|
||||
);
|
||||
const processingStatus = await processGeminiStreamEvents(
|
||||
stream,
|
||||
userMessageTimestamp,
|
||||
abortSignal,
|
||||
);
|
||||
|
||||
if (processingStatus === StreamProcessingStatus.UserCancelled) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (pendingHistoryItemRef.current) {
|
||||
addItem(pendingHistoryItemRef.current, userMessageTimestamp);
|
||||
setPendingHistoryItem(null);
|
||||
}
|
||||
if (loopDetectedRef.current) {
|
||||
loopDetectedRef.current = false;
|
||||
handleLoopDetectedEvent();
|
||||
}
|
||||
} catch (error: unknown) {
|
||||
if (error instanceof UnauthorizedError) {
|
||||
onAuthError();
|
||||
} else if (!isNodeError(error) || error.name !== 'AbortError') {
|
||||
addItem(
|
||||
{
|
||||
type: MessageType.ERROR,
|
||||
text: parseAndFormatApiError(
|
||||
getErrorMessage(error) || 'Unknown error',
|
||||
config.getContentGeneratorConfig()?.authType,
|
||||
undefined,
|
||||
config.getModel(),
|
||||
DEFAULT_GEMINI_FLASH_MODEL,
|
||||
),
|
||||
},
|
||||
userMessageTimestamp,
|
||||
);
|
||||
}
|
||||
} finally {
|
||||
setIsResponding(false);
|
||||
}
|
||||
},
|
||||
[
|
||||
streamingState,
|
||||
setShowHelp,
|
||||
setModelSwitchedFromQuotaError,
|
||||
prepareQueryForGemini,
|
||||
processGeminiStreamEvents,
|
||||
pendingHistoryItemRef,
|
||||
addItem,
|
||||
setPendingHistoryItem,
|
||||
setInitError,
|
||||
geminiClient,
|
||||
onAuthError,
|
||||
config,
|
||||
startNewPrompt,
|
||||
getPromptCount,
|
||||
handleLoopDetectedEvent,
|
||||
],
|
||||
);
|
||||
|
||||
const handleCompletedTools = useCallback(
|
||||
async (completedToolCallsFromScheduler: TrackedToolCall[]) => {
|
||||
if (isResponding) {
|
||||
return;
|
||||
}
|
||||
|
||||
const completedAndReadyToSubmitTools =
|
||||
completedToolCallsFromScheduler.filter(
|
||||
(
|
||||
tc: TrackedToolCall,
|
||||
): tc is TrackedCompletedToolCall | TrackedCancelledToolCall => {
|
||||
const isTerminalState =
|
||||
tc.status === 'success' ||
|
||||
tc.status === 'error' ||
|
||||
tc.status === 'cancelled';
|
||||
|
||||
if (isTerminalState) {
|
||||
const completedOrCancelledCall = tc as
|
||||
| TrackedCompletedToolCall
|
||||
| TrackedCancelledToolCall;
|
||||
return (
|
||||
completedOrCancelledCall.response?.responseParts !== undefined
|
||||
);
|
||||
}
|
||||
return false;
|
||||
},
|
||||
);
|
||||
|
||||
// Finalize any client-initiated tools as soon as they are done.
|
||||
const clientTools = completedAndReadyToSubmitTools.filter(
|
||||
(t) => t.request.isClientInitiated,
|
||||
);
|
||||
if (clientTools.length > 0) {
|
||||
markToolsAsSubmitted(clientTools.map((t) => t.request.callId));
|
||||
}
|
||||
|
||||
// Identify new, successful save_memory calls that we haven't processed yet.
|
||||
const newSuccessfulMemorySaves = completedAndReadyToSubmitTools.filter(
|
||||
(t) =>
|
||||
t.request.name === 'save_memory' &&
|
||||
t.status === 'success' &&
|
||||
!processedMemoryToolsRef.current.has(t.request.callId),
|
||||
);
|
||||
|
||||
if (newSuccessfulMemorySaves.length > 0) {
|
||||
// Perform the refresh only if there are new ones.
|
||||
void performMemoryRefresh();
|
||||
// Mark them as processed so we don't do this again on the next render.
|
||||
newSuccessfulMemorySaves.forEach((t) =>
|
||||
processedMemoryToolsRef.current.add(t.request.callId),
|
||||
);
|
||||
}
|
||||
|
||||
const geminiTools = completedAndReadyToSubmitTools.filter(
|
||||
(t) => !t.request.isClientInitiated,
|
||||
);
|
||||
|
||||
if (geminiTools.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// If all the tools were cancelled, don't submit a response to Gemini.
|
||||
const allToolsCancelled = geminiTools.every(
|
||||
(tc) => tc.status === 'cancelled',
|
||||
);
|
||||
|
||||
if (allToolsCancelled) {
|
||||
if (geminiClient) {
|
||||
// We need to manually add the function responses to the history
|
||||
// so the model knows the tools were cancelled.
|
||||
const responsesToAdd = geminiTools.flatMap(
|
||||
(toolCall) => toolCall.response.responseParts,
|
||||
);
|
||||
const combinedParts: Part[] = [];
|
||||
for (const response of responsesToAdd) {
|
||||
if (Array.isArray(response)) {
|
||||
combinedParts.push(...response);
|
||||
} else if (typeof response === 'string') {
|
||||
combinedParts.push({ text: response });
|
||||
} else {
|
||||
combinedParts.push(response);
|
||||
}
|
||||
}
|
||||
geminiClient.addHistory({
|
||||
role: 'user',
|
||||
parts: combinedParts,
|
||||
});
|
||||
}
|
||||
|
||||
const callIdsToMarkAsSubmitted = geminiTools.map(
|
||||
(toolCall) => toolCall.request.callId,
|
||||
);
|
||||
markToolsAsSubmitted(callIdsToMarkAsSubmitted);
|
||||
return;
|
||||
}
|
||||
|
||||
const responsesToSend: PartListUnion[] = geminiTools.map(
|
||||
(toolCall) => toolCall.response.responseParts,
|
||||
);
|
||||
const callIdsToMarkAsSubmitted = geminiTools.map(
|
||||
(toolCall) => toolCall.request.callId,
|
||||
);
|
||||
|
||||
const prompt_ids = geminiTools.map(
|
||||
(toolCall) => toolCall.request.prompt_id,
|
||||
);
|
||||
|
||||
markToolsAsSubmitted(callIdsToMarkAsSubmitted);
|
||||
|
||||
// Don't continue if model was switched due to quota error
|
||||
if (modelSwitchedFromQuotaError) {
|
||||
return;
|
||||
}
|
||||
|
||||
submitQuery(
|
||||
mergePartListUnions(responsesToSend),
|
||||
{
|
||||
isContinuation: true,
|
||||
},
|
||||
prompt_ids[0],
|
||||
);
|
||||
},
|
||||
[
|
||||
isResponding,
|
||||
submitQuery,
|
||||
markToolsAsSubmitted,
|
||||
geminiClient,
|
||||
performMemoryRefresh,
|
||||
modelSwitchedFromQuotaError,
|
||||
],
|
||||
);
|
||||
|
||||
const pendingHistoryItems = [
|
||||
pendingHistoryItemRef.current,
|
||||
pendingToolCallGroupDisplay,
|
||||
].filter((i) => i !== undefined && i !== null);
|
||||
|
||||
useEffect(() => {
|
||||
const saveRestorableToolCalls = async () => {
|
||||
if (!config.getCheckpointingEnabled()) {
|
||||
return;
|
||||
}
|
||||
const restorableToolCalls = toolCalls.filter(
|
||||
(toolCall) =>
|
||||
(toolCall.request.name === 'replace' ||
|
||||
toolCall.request.name === 'write_file') &&
|
||||
toolCall.status === 'awaiting_approval',
|
||||
);
|
||||
|
||||
if (restorableToolCalls.length > 0) {
|
||||
const checkpointDir = config.getProjectTempDir()
|
||||
? path.join(config.getProjectTempDir(), 'checkpoints')
|
||||
: undefined;
|
||||
|
||||
if (!checkpointDir) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
await fs.mkdir(checkpointDir, { recursive: true });
|
||||
} catch (error) {
|
||||
if (!isNodeError(error) || error.code !== 'EEXIST') {
|
||||
onDebugMessage(
|
||||
`Failed to create checkpoint directory: ${getErrorMessage(error)}`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
for (const toolCall of restorableToolCalls) {
|
||||
const filePath = toolCall.request.args['file_path'] as string;
|
||||
if (!filePath) {
|
||||
onDebugMessage(
|
||||
`Skipping restorable tool call due to missing file_path: ${toolCall.request.name}`,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
try {
|
||||
let commitHash = await gitService?.createFileSnapshot(
|
||||
`Snapshot for ${toolCall.request.name}`,
|
||||
);
|
||||
|
||||
if (!commitHash) {
|
||||
commitHash = await gitService?.getCurrentCommitHash();
|
||||
}
|
||||
|
||||
if (!commitHash) {
|
||||
onDebugMessage(
|
||||
`Failed to create snapshot for ${filePath}. Skipping restorable tool call.`,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
const timestamp = new Date()
|
||||
.toISOString()
|
||||
.replace(/:/g, '-')
|
||||
.replace(/\./g, '_');
|
||||
const toolName = toolCall.request.name;
|
||||
const fileName = path.basename(filePath);
|
||||
const toolCallWithSnapshotFileName = `${timestamp}-${fileName}-${toolName}.json`;
|
||||
const clientHistory = await geminiClient?.getHistory();
|
||||
const toolCallWithSnapshotFilePath = path.join(
|
||||
checkpointDir,
|
||||
toolCallWithSnapshotFileName,
|
||||
);
|
||||
|
||||
await fs.writeFile(
|
||||
toolCallWithSnapshotFilePath,
|
||||
JSON.stringify(
|
||||
{
|
||||
history,
|
||||
clientHistory,
|
||||
toolCall: {
|
||||
name: toolCall.request.name,
|
||||
args: toolCall.request.args,
|
||||
},
|
||||
commitHash,
|
||||
filePath,
|
||||
},
|
||||
null,
|
||||
2,
|
||||
),
|
||||
);
|
||||
} catch (error) {
|
||||
onDebugMessage(
|
||||
`Failed to write restorable tool call file: ${getErrorMessage(
|
||||
error,
|
||||
)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
saveRestorableToolCalls();
|
||||
}, [toolCalls, config, onDebugMessage, gitService, history, geminiClient]);
|
||||
|
||||
return {
|
||||
streamingState,
|
||||
submitQuery,
|
||||
initError,
|
||||
pendingHistoryItems,
|
||||
thought,
|
||||
};
|
||||
};
|
||||
Reference in New Issue
Block a user