mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-20 08:47:44 +00:00
Restore Checkpoint Feature (#934)
This commit is contained in:
@@ -65,6 +65,14 @@ import {
|
||||
} from '@gemini-cli/core';
|
||||
import { useSessionStats } from '../contexts/SessionContext.js';
|
||||
|
||||
vi.mock('@gemini-code/core', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('@gemini-code/core')>();
|
||||
return {
|
||||
...actual,
|
||||
GitService: vi.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
import * as ShowMemoryCommandModule from './useShowMemoryCommand.js';
|
||||
import { GIT_COMMIT_INFO } from '../../generated/git-commit.js';
|
||||
|
||||
@@ -84,6 +92,7 @@ vi.mock('open', () => ({
|
||||
describe('useSlashCommandProcessor', () => {
|
||||
let mockAddItem: ReturnType<typeof vi.fn>;
|
||||
let mockClearItems: ReturnType<typeof vi.fn>;
|
||||
let mockLoadHistory: ReturnType<typeof vi.fn>;
|
||||
let mockRefreshStatic: ReturnType<typeof vi.fn>;
|
||||
let mockSetShowHelp: ReturnType<typeof vi.fn>;
|
||||
let mockOnDebugMessage: ReturnType<typeof vi.fn>;
|
||||
@@ -96,6 +105,7 @@ describe('useSlashCommandProcessor', () => {
|
||||
beforeEach(() => {
|
||||
mockAddItem = vi.fn();
|
||||
mockClearItems = vi.fn();
|
||||
mockLoadHistory = vi.fn();
|
||||
mockRefreshStatic = vi.fn();
|
||||
mockSetShowHelp = vi.fn();
|
||||
mockOnDebugMessage = vi.fn();
|
||||
@@ -105,6 +115,8 @@ describe('useSlashCommandProcessor', () => {
|
||||
getDebugMode: vi.fn(() => false),
|
||||
getSandbox: vi.fn(() => 'test-sandbox'),
|
||||
getModel: vi.fn(() => 'test-model'),
|
||||
getProjectRoot: vi.fn(() => '/test/dir'),
|
||||
getCheckpointEnabled: vi.fn(() => true),
|
||||
} as unknown as Config;
|
||||
mockCorgiMode = vi.fn();
|
||||
mockUseSessionStats.mockReturnValue({
|
||||
@@ -133,8 +145,10 @@ describe('useSlashCommandProcessor', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useSlashCommandProcessor(
|
||||
mockConfig,
|
||||
[],
|
||||
mockAddItem,
|
||||
mockClearItems,
|
||||
mockLoadHistory,
|
||||
mockRefreshStatic,
|
||||
mockSetShowHelp,
|
||||
mockOnDebugMessage,
|
||||
@@ -153,7 +167,7 @@ describe('useSlashCommandProcessor', () => {
|
||||
const fact = 'Remember this fact';
|
||||
let commandResult: SlashCommandActionReturn | boolean = false;
|
||||
await act(async () => {
|
||||
commandResult = handleSlashCommand(`/memory add ${fact}`);
|
||||
commandResult = await handleSlashCommand(`/memory add ${fact}`);
|
||||
});
|
||||
|
||||
expect(mockAddItem).toHaveBeenNthCalledWith(
|
||||
@@ -187,7 +201,7 @@ describe('useSlashCommandProcessor', () => {
|
||||
const { handleSlashCommand } = getProcessor();
|
||||
let commandResult: SlashCommandActionReturn | boolean = false;
|
||||
await act(async () => {
|
||||
commandResult = handleSlashCommand('/memory add ');
|
||||
commandResult = await handleSlashCommand('/memory add ');
|
||||
});
|
||||
|
||||
expect(mockAddItem).toHaveBeenNthCalledWith(
|
||||
@@ -211,7 +225,7 @@ describe('useSlashCommandProcessor', () => {
|
||||
const { handleSlashCommand } = getProcessor();
|
||||
let commandResult: SlashCommandActionReturn | boolean = false;
|
||||
await act(async () => {
|
||||
commandResult = handleSlashCommand('/memory show');
|
||||
commandResult = await handleSlashCommand('/memory show');
|
||||
});
|
||||
expect(
|
||||
ShowMemoryCommandModule.createShowMemoryAction,
|
||||
@@ -226,7 +240,7 @@ describe('useSlashCommandProcessor', () => {
|
||||
const { handleSlashCommand } = getProcessor();
|
||||
let commandResult: SlashCommandActionReturn | boolean = false;
|
||||
await act(async () => {
|
||||
commandResult = handleSlashCommand('/memory refresh');
|
||||
commandResult = await handleSlashCommand('/memory refresh');
|
||||
});
|
||||
expect(mockPerformMemoryRefresh).toHaveBeenCalled();
|
||||
expect(commandResult).toBe(true);
|
||||
@@ -238,7 +252,7 @@ describe('useSlashCommandProcessor', () => {
|
||||
const { handleSlashCommand } = getProcessor();
|
||||
let commandResult: SlashCommandActionReturn | boolean = false;
|
||||
await act(async () => {
|
||||
commandResult = handleSlashCommand('/memory foobar');
|
||||
commandResult = await handleSlashCommand('/memory foobar');
|
||||
});
|
||||
expect(mockAddItem).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
@@ -300,7 +314,7 @@ describe('useSlashCommandProcessor', () => {
|
||||
const { handleSlashCommand } = getProcessor();
|
||||
let commandResult: SlashCommandActionReturn | boolean = false;
|
||||
await act(async () => {
|
||||
commandResult = handleSlashCommand('/help');
|
||||
commandResult = await handleSlashCommand('/help');
|
||||
});
|
||||
expect(mockSetShowHelp).toHaveBeenCalledWith(true);
|
||||
expect(commandResult).toBe(true);
|
||||
@@ -373,7 +387,7 @@ Add any other context about the problem here.
|
||||
);
|
||||
let commandResult: SlashCommandActionReturn | boolean = false;
|
||||
await act(async () => {
|
||||
commandResult = handleSlashCommand(`/bug ${bugDescription}`);
|
||||
commandResult = await handleSlashCommand(`/bug ${bugDescription}`);
|
||||
});
|
||||
|
||||
expect(mockAddItem).toHaveBeenCalledTimes(2);
|
||||
@@ -387,7 +401,7 @@ Add any other context about the problem here.
|
||||
const { handleSlashCommand } = getProcessor();
|
||||
let commandResult: SlashCommandActionReturn | boolean = false;
|
||||
await act(async () => {
|
||||
commandResult = handleSlashCommand('/unknowncommand');
|
||||
commandResult = await handleSlashCommand('/unknowncommand');
|
||||
});
|
||||
expect(mockAddItem).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
@@ -410,7 +424,7 @@ Add any other context about the problem here.
|
||||
const { handleSlashCommand } = getProcessor();
|
||||
let commandResult: SlashCommandActionReturn | boolean = false;
|
||||
await act(async () => {
|
||||
commandResult = handleSlashCommand('/tools');
|
||||
commandResult = await handleSlashCommand('/tools');
|
||||
});
|
||||
|
||||
expect(mockAddItem).toHaveBeenNthCalledWith(
|
||||
@@ -434,7 +448,7 @@ Add any other context about the problem here.
|
||||
const { handleSlashCommand } = getProcessor();
|
||||
let commandResult: SlashCommandActionReturn | boolean = false;
|
||||
await act(async () => {
|
||||
commandResult = handleSlashCommand('/tools');
|
||||
commandResult = await handleSlashCommand('/tools');
|
||||
});
|
||||
|
||||
expect(mockAddItem).toHaveBeenNthCalledWith(
|
||||
@@ -467,7 +481,7 @@ Add any other context about the problem here.
|
||||
const { handleSlashCommand } = getProcessor();
|
||||
let commandResult: SlashCommandActionReturn | boolean = false;
|
||||
await act(async () => {
|
||||
commandResult = handleSlashCommand('/tools');
|
||||
commandResult = await handleSlashCommand('/tools');
|
||||
});
|
||||
|
||||
// Should only show tool1 and tool2, not the MCP tools
|
||||
@@ -499,7 +513,7 @@ Add any other context about the problem here.
|
||||
const { handleSlashCommand } = getProcessor();
|
||||
let commandResult: SlashCommandActionReturn | boolean = false;
|
||||
await act(async () => {
|
||||
commandResult = handleSlashCommand('/tools');
|
||||
commandResult = await handleSlashCommand('/tools');
|
||||
});
|
||||
|
||||
expect(mockAddItem).toHaveBeenNthCalledWith(
|
||||
@@ -545,7 +559,7 @@ Add any other context about the problem here.
|
||||
const { handleSlashCommand } = getProcessor();
|
||||
let commandResult: SlashCommandActionReturn | boolean = false;
|
||||
await act(async () => {
|
||||
commandResult = handleSlashCommand('/mcp');
|
||||
commandResult = await handleSlashCommand('/mcp');
|
||||
});
|
||||
|
||||
expect(mockAddItem).toHaveBeenNthCalledWith(
|
||||
@@ -571,7 +585,7 @@ Add any other context about the problem here.
|
||||
const { handleSlashCommand } = getProcessor();
|
||||
let commandResult: SlashCommandActionReturn | boolean = false;
|
||||
await act(async () => {
|
||||
commandResult = handleSlashCommand('/mcp');
|
||||
commandResult = await handleSlashCommand('/mcp');
|
||||
});
|
||||
|
||||
expect(mockAddItem).toHaveBeenNthCalledWith(
|
||||
@@ -633,7 +647,7 @@ Add any other context about the problem here.
|
||||
const { handleSlashCommand } = getProcessor();
|
||||
let commandResult: SlashCommandActionReturn | boolean = false;
|
||||
await act(async () => {
|
||||
commandResult = handleSlashCommand('/mcp');
|
||||
commandResult = await handleSlashCommand('/mcp');
|
||||
});
|
||||
|
||||
expect(mockAddItem).toHaveBeenNthCalledWith(
|
||||
@@ -706,7 +720,7 @@ Add any other context about the problem here.
|
||||
const { handleSlashCommand } = getProcessor(true);
|
||||
let commandResult: SlashCommandActionReturn | boolean = false;
|
||||
await act(async () => {
|
||||
commandResult = handleSlashCommand('/mcp');
|
||||
commandResult = await handleSlashCommand('/mcp');
|
||||
});
|
||||
|
||||
expect(mockAddItem).toHaveBeenNthCalledWith(
|
||||
@@ -780,7 +794,7 @@ Add any other context about the problem here.
|
||||
const { handleSlashCommand } = getProcessor();
|
||||
let commandResult: SlashCommandActionReturn | boolean = false;
|
||||
await act(async () => {
|
||||
commandResult = handleSlashCommand('/mcp');
|
||||
commandResult = await handleSlashCommand('/mcp');
|
||||
});
|
||||
|
||||
expect(mockAddItem).toHaveBeenNthCalledWith(
|
||||
@@ -846,7 +860,7 @@ Add any other context about the problem here.
|
||||
const { handleSlashCommand } = getProcessor();
|
||||
let commandResult: SlashCommandActionReturn | boolean = false;
|
||||
await act(async () => {
|
||||
commandResult = handleSlashCommand('/mcp');
|
||||
commandResult = await handleSlashCommand('/mcp');
|
||||
});
|
||||
|
||||
const message = mockAddItem.mock.calls[1][0].text;
|
||||
|
||||
@@ -11,14 +11,22 @@ import process from 'node:process';
|
||||
import { UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||
import {
|
||||
Config,
|
||||
GitService,
|
||||
Logger,
|
||||
MCPDiscoveryState,
|
||||
MCPServerStatus,
|
||||
getMCPDiscoveryState,
|
||||
getMCPServerStatus,
|
||||
} from '@gemini-cli/core';
|
||||
import { Message, MessageType, HistoryItemWithoutId } from '../types.js';
|
||||
import { useSessionStats } from '../contexts/SessionContext.js';
|
||||
import {
|
||||
Message,
|
||||
MessageType,
|
||||
HistoryItemWithoutId,
|
||||
HistoryItem,
|
||||
} from '../types.js';
|
||||
import { promises as fs } from 'fs';
|
||||
import path from 'path';
|
||||
import { createShowMemoryAction } from './useShowMemoryCommand.js';
|
||||
import { GIT_COMMIT_INFO } from '../../generated/git-commit.js';
|
||||
import { formatDuration, formatMemoryUsage } from '../utils/formatters.js';
|
||||
@@ -39,7 +47,10 @@ export interface SlashCommand {
|
||||
mainCommand: string,
|
||||
subCommand?: string,
|
||||
args?: string,
|
||||
) => void | SlashCommandActionReturn; // Action can now return this object
|
||||
) =>
|
||||
| void
|
||||
| SlashCommandActionReturn
|
||||
| Promise<void | SlashCommandActionReturn>; // Action can now return this object
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -47,8 +58,10 @@ export interface SlashCommand {
|
||||
*/
|
||||
export const useSlashCommandProcessor = (
|
||||
config: Config | null,
|
||||
history: HistoryItem[],
|
||||
addItem: UseHistoryManagerReturn['addItem'],
|
||||
clearItems: UseHistoryManagerReturn['clearItems'],
|
||||
loadHistory: UseHistoryManagerReturn['loadHistory'],
|
||||
refreshStatic: () => void,
|
||||
setShowHelp: React.Dispatch<React.SetStateAction<boolean>>,
|
||||
onDebugMessage: (message: string) => void,
|
||||
@@ -58,6 +71,13 @@ export const useSlashCommandProcessor = (
|
||||
showToolDescriptions: boolean = false,
|
||||
) => {
|
||||
const session = useSessionStats();
|
||||
const gitService = useMemo(() => {
|
||||
if (!config?.getProjectRoot()) {
|
||||
return;
|
||||
}
|
||||
return new GitService(config.getProjectRoot());
|
||||
}, [config]);
|
||||
|
||||
const addMessage = useCallback(
|
||||
(message: Message) => {
|
||||
// Convert Message to HistoryItemWithoutId
|
||||
@@ -126,8 +146,8 @@ export const useSlashCommandProcessor = (
|
||||
[addMessage],
|
||||
);
|
||||
|
||||
const slashCommands: SlashCommand[] = useMemo(
|
||||
() => [
|
||||
const slashCommands: SlashCommand[] = useMemo(() => {
|
||||
const commands: SlashCommand[] = [
|
||||
{
|
||||
name: 'help',
|
||||
altName: '?',
|
||||
@@ -408,7 +428,9 @@ export const useSlashCommandProcessor = (
|
||||
if (process.env.SANDBOX && process.env.SANDBOX !== 'sandbox-exec') {
|
||||
sandboxEnv = process.env.SANDBOX;
|
||||
} else if (process.env.SANDBOX === 'sandbox-exec') {
|
||||
sandboxEnv = `sandbox-exec (${process.env.SEATBELT_PROFILE || 'unknown'})`;
|
||||
sandboxEnv = `sandbox-exec (${
|
||||
process.env.SEATBELT_PROFILE || 'unknown'
|
||||
})`;
|
||||
}
|
||||
const modelVersion = config?.getModel() || 'Unknown';
|
||||
const cliVersion = getCliVersion();
|
||||
@@ -437,7 +459,9 @@ export const useSlashCommandProcessor = (
|
||||
if (process.env.SANDBOX && process.env.SANDBOX !== 'sandbox-exec') {
|
||||
sandboxEnv = process.env.SANDBOX.replace(/^gemini-(?:code-)?/, '');
|
||||
} else if (process.env.SANDBOX === 'sandbox-exec') {
|
||||
sandboxEnv = `sandbox-exec (${process.env.SEATBELT_PROFILE || 'unknown'})`;
|
||||
sandboxEnv = `sandbox-exec (${
|
||||
process.env.SEATBELT_PROFILE || 'unknown'
|
||||
})`;
|
||||
}
|
||||
const modelVersion = config?.getModel() || 'Unknown';
|
||||
const memoryUsage = formatMemoryUsage(process.memoryUsage().rss);
|
||||
@@ -569,31 +593,140 @@ Add any other context about the problem here.
|
||||
name: 'quit',
|
||||
altName: 'exit',
|
||||
description: 'exit the cli',
|
||||
action: (_mainCommand, _subCommand, _args) => {
|
||||
action: async (_mainCommand, _subCommand, _args) => {
|
||||
onDebugMessage('Quitting. Good-bye.');
|
||||
process.exit(0);
|
||||
},
|
||||
},
|
||||
],
|
||||
[
|
||||
onDebugMessage,
|
||||
setShowHelp,
|
||||
refreshStatic,
|
||||
openThemeDialog,
|
||||
clearItems,
|
||||
performMemoryRefresh,
|
||||
showMemoryAction,
|
||||
addMemoryAction,
|
||||
addMessage,
|
||||
toggleCorgiMode,
|
||||
config,
|
||||
showToolDescriptions,
|
||||
session,
|
||||
],
|
||||
);
|
||||
];
|
||||
|
||||
if (config?.getCheckpointEnabled()) {
|
||||
commands.push({
|
||||
name: 'restore',
|
||||
description:
|
||||
'restore a tool call. This will reset the conversation and file history to the state it was in when the tool call was suggested',
|
||||
action: async (_mainCommand, subCommand, _args) => {
|
||||
const checkpointDir = config?.getGeminiDir()
|
||||
? path.join(config.getGeminiDir(), 'checkpoints')
|
||||
: undefined;
|
||||
|
||||
if (!checkpointDir) {
|
||||
addMessage({
|
||||
type: MessageType.ERROR,
|
||||
content: 'Could not determine the .gemini directory path.',
|
||||
timestamp: new Date(),
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// Ensure the directory exists before trying to read it.
|
||||
await fs.mkdir(checkpointDir, { recursive: true });
|
||||
const files = await fs.readdir(checkpointDir);
|
||||
const jsonFiles = files.filter((file) => file.endsWith('.json'));
|
||||
|
||||
if (!subCommand) {
|
||||
if (jsonFiles.length === 0) {
|
||||
addMessage({
|
||||
type: MessageType.INFO,
|
||||
content: 'No restorable tool calls found.',
|
||||
timestamp: new Date(),
|
||||
});
|
||||
return;
|
||||
}
|
||||
const truncatedFiles = jsonFiles.map((file) => {
|
||||
const components = file.split('.');
|
||||
if (components.length <= 1) {
|
||||
return file;
|
||||
}
|
||||
components.pop();
|
||||
return components.join('.');
|
||||
});
|
||||
const fileList = truncatedFiles.join('\n');
|
||||
addMessage({
|
||||
type: MessageType.INFO,
|
||||
content: `Available tool calls to restore:\n\n${fileList}`,
|
||||
timestamp: new Date(),
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const selectedFile = subCommand.endsWith('.json')
|
||||
? subCommand
|
||||
: `${subCommand}.json`;
|
||||
|
||||
if (!jsonFiles.includes(selectedFile)) {
|
||||
addMessage({
|
||||
type: MessageType.ERROR,
|
||||
content: `File not found: ${selectedFile}`,
|
||||
timestamp: new Date(),
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const filePath = path.join(checkpointDir, selectedFile);
|
||||
const data = await fs.readFile(filePath, 'utf-8');
|
||||
const toolCallData = JSON.parse(data);
|
||||
|
||||
if (toolCallData.history) {
|
||||
loadHistory(toolCallData.history);
|
||||
}
|
||||
|
||||
if (toolCallData.clientHistory) {
|
||||
await config
|
||||
?.getGeminiClient()
|
||||
?.setHistory(toolCallData.clientHistory);
|
||||
}
|
||||
|
||||
if (toolCallData.commitHash) {
|
||||
await gitService?.restoreProjectFromSnapshot(
|
||||
toolCallData.commitHash,
|
||||
);
|
||||
addMessage({
|
||||
type: MessageType.INFO,
|
||||
content: `Restored project to the state before the tool call.`,
|
||||
timestamp: new Date(),
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
shouldScheduleTool: true,
|
||||
toolName: toolCallData.toolCall.name,
|
||||
toolArgs: toolCallData.toolCall.args,
|
||||
};
|
||||
} catch (error) {
|
||||
addMessage({
|
||||
type: MessageType.ERROR,
|
||||
content: `Could not read restorable tool calls. This is the error: ${error}`,
|
||||
timestamp: new Date(),
|
||||
});
|
||||
}
|
||||
},
|
||||
});
|
||||
}
|
||||
return commands;
|
||||
}, [
|
||||
onDebugMessage,
|
||||
setShowHelp,
|
||||
refreshStatic,
|
||||
openThemeDialog,
|
||||
clearItems,
|
||||
performMemoryRefresh,
|
||||
showMemoryAction,
|
||||
addMemoryAction,
|
||||
addMessage,
|
||||
toggleCorgiMode,
|
||||
config,
|
||||
showToolDescriptions,
|
||||
session,
|
||||
gitService,
|
||||
loadHistory,
|
||||
]);
|
||||
|
||||
const handleSlashCommand = useCallback(
|
||||
(rawQuery: PartListUnion): SlashCommandActionReturn | boolean => {
|
||||
async (
|
||||
rawQuery: PartListUnion,
|
||||
): Promise<SlashCommandActionReturn | boolean> => {
|
||||
if (typeof rawQuery !== 'string') {
|
||||
return false;
|
||||
}
|
||||
@@ -625,7 +758,7 @@ Add any other context about the problem here.
|
||||
|
||||
for (const cmd of slashCommands) {
|
||||
if (mainCommand === cmd.name || mainCommand === cmd.altName) {
|
||||
const actionResult = cmd.action(mainCommand, subCommand, args);
|
||||
const actionResult = await cmd.action(mainCommand, subCommand, args);
|
||||
if (
|
||||
typeof actionResult === 'object' &&
|
||||
actionResult?.shouldScheduleTool
|
||||
|
||||
@@ -18,6 +18,7 @@ import {
|
||||
import { Config } from '@gemini-cli/core';
|
||||
import { Part, PartListUnion } from '@google/genai';
|
||||
import { UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||
import { HistoryItem } from '../types.js';
|
||||
import { Dispatch, SetStateAction } from 'react';
|
||||
|
||||
// --- MOCKS ---
|
||||
@@ -38,9 +39,9 @@ const MockedGeminiClientClass = vi.hoisted(() =>
|
||||
vi.mock('@gemini-cli/core', async (importOriginal) => {
|
||||
const actualCoreModule = (await importOriginal()) as any;
|
||||
return {
|
||||
...(actualCoreModule || {}),
|
||||
GeminiClient: MockedGeminiClientClass, // Export the class for type checking or other direct uses
|
||||
Config: actualCoreModule.Config, // Ensure Config is passed through
|
||||
...actualCoreModule,
|
||||
GitService: vi.fn(),
|
||||
GeminiClient: MockedGeminiClientClass,
|
||||
};
|
||||
});
|
||||
|
||||
@@ -277,11 +278,13 @@ describe('useGeminiStream', () => {
|
||||
getToolRegistry: vi.fn(
|
||||
() => ({ getToolSchemaList: vi.fn(() => []) }) as any,
|
||||
),
|
||||
getProjectRoot: vi.fn(() => '/test/dir'),
|
||||
getCheckpointEnabled: vi.fn(() => false),
|
||||
getGeminiClient: mockGetGeminiClient,
|
||||
addHistory: vi.fn(),
|
||||
} as unknown as Config;
|
||||
mockOnDebugMessage = vi.fn();
|
||||
mockHandleSlashCommand = vi.fn().mockReturnValue(false);
|
||||
mockHandleSlashCommand = vi.fn().mockResolvedValue(false);
|
||||
|
||||
// Mock return value for useReactToolScheduler
|
||||
mockScheduleToolCalls = vi.fn();
|
||||
@@ -322,19 +325,22 @@ describe('useGeminiStream', () => {
|
||||
const { result, rerender } = renderHook(
|
||||
(props: {
|
||||
client: any;
|
||||
history: HistoryItem[];
|
||||
addItem: UseHistoryManagerReturn['addItem'];
|
||||
setShowHelp: Dispatch<SetStateAction<boolean>>;
|
||||
config: Config;
|
||||
onDebugMessage: (message: string) => void;
|
||||
handleSlashCommand: (
|
||||
command: PartListUnion,
|
||||
) =>
|
||||
cmd: PartListUnion,
|
||||
) => Promise<
|
||||
| import('./slashCommandProcessor.js').SlashCommandActionReturn
|
||||
| boolean;
|
||||
| boolean
|
||||
>;
|
||||
shellModeActive: boolean;
|
||||
}) =>
|
||||
useGeminiStream(
|
||||
props.client,
|
||||
props.history,
|
||||
props.addItem,
|
||||
props.setShowHelp,
|
||||
props.config,
|
||||
@@ -345,12 +351,17 @@ describe('useGeminiStream', () => {
|
||||
{
|
||||
initialProps: {
|
||||
client,
|
||||
history: [],
|
||||
addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
|
||||
setShowHelp: mockSetShowHelp,
|
||||
config: mockConfig,
|
||||
onDebugMessage: mockOnDebugMessage,
|
||||
handleSlashCommand:
|
||||
mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand,
|
||||
handleSlashCommand: mockHandleSlashCommand as unknown as (
|
||||
cmd: PartListUnion,
|
||||
) => Promise<
|
||||
| import('./slashCommandProcessor.js').SlashCommandActionReturn
|
||||
| boolean
|
||||
>,
|
||||
shellModeActive: false,
|
||||
},
|
||||
},
|
||||
@@ -467,7 +478,8 @@ describe('useGeminiStream', () => {
|
||||
act(() => {
|
||||
rerender({
|
||||
client,
|
||||
addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
|
||||
history: [],
|
||||
addItem: mockAddItem,
|
||||
setShowHelp: mockSetShowHelp,
|
||||
config: mockConfig,
|
||||
onDebugMessage: mockOnDebugMessage,
|
||||
@@ -521,7 +533,8 @@ describe('useGeminiStream', () => {
|
||||
act(() => {
|
||||
rerender({
|
||||
client,
|
||||
addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
|
||||
history: [],
|
||||
addItem: mockAddItem,
|
||||
setShowHelp: mockSetShowHelp,
|
||||
config: mockConfig,
|
||||
onDebugMessage: mockOnDebugMessage,
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
import { useState, useRef, useCallback, useEffect, useMemo } from 'react';
|
||||
import { useInput } from 'ink';
|
||||
import {
|
||||
Config,
|
||||
GeminiClient,
|
||||
GeminiEventType as ServerGeminiEventType,
|
||||
ServerGeminiStreamEvent as GeminiEvent,
|
||||
@@ -14,14 +15,15 @@ import {
|
||||
ServerGeminiErrorEvent as ErrorEvent,
|
||||
getErrorMessage,
|
||||
isNodeError,
|
||||
Config,
|
||||
MessageSenderType,
|
||||
ToolCallRequestInfo,
|
||||
logUserPrompt,
|
||||
GitService,
|
||||
} from '@gemini-cli/core';
|
||||
import { type Part, type PartListUnion } from '@google/genai';
|
||||
import {
|
||||
StreamingState,
|
||||
HistoryItem,
|
||||
HistoryItemWithoutId,
|
||||
HistoryItemToolGroup,
|
||||
MessageType,
|
||||
@@ -35,6 +37,8 @@ import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js';
|
||||
import { useStateAndRef } from './useStateAndRef.js';
|
||||
import { UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||
import { useLogger } from './useLogger.js';
|
||||
import { promises as fs } from 'fs';
|
||||
import path from 'path';
|
||||
import {
|
||||
useReactToolScheduler,
|
||||
mapToDisplay as mapTrackedToolCallsToDisplay,
|
||||
@@ -68,13 +72,16 @@ enum StreamProcessingStatus {
|
||||
*/
|
||||
export const useGeminiStream = (
|
||||
geminiClient: GeminiClient | null,
|
||||
history: HistoryItem[],
|
||||
addItem: UseHistoryManagerReturn['addItem'],
|
||||
setShowHelp: React.Dispatch<React.SetStateAction<boolean>>,
|
||||
config: Config,
|
||||
onDebugMessage: (message: string) => void,
|
||||
handleSlashCommand: (
|
||||
cmd: PartListUnion,
|
||||
) => import('./slashCommandProcessor.js').SlashCommandActionReturn | boolean,
|
||||
) => Promise<
|
||||
import('./slashCommandProcessor.js').SlashCommandActionReturn | boolean
|
||||
>,
|
||||
shellModeActive: boolean,
|
||||
) => {
|
||||
const [initError, setInitError] = useState<string | null>(null);
|
||||
@@ -84,6 +91,12 @@ export const useGeminiStream = (
|
||||
useStateAndRef<HistoryItemWithoutId | null>(null);
|
||||
const logger = useLogger();
|
||||
const { startNewTurn, addUsage } = useSessionStats();
|
||||
const gitService = useMemo(() => {
|
||||
if (!config.getProjectRoot()) {
|
||||
return;
|
||||
}
|
||||
return new GitService(config.getProjectRoot());
|
||||
}, [config]);
|
||||
|
||||
const [toolCalls, scheduleToolCalls, markToolsAsSubmitted] =
|
||||
useReactToolScheduler(
|
||||
@@ -178,7 +191,7 @@ export const useGeminiStream = (
|
||||
await logger?.logMessage(MessageSenderType.USER, trimmedQuery);
|
||||
|
||||
// Handle UI-only commands first
|
||||
const slashCommandResult = handleSlashCommand(trimmedQuery);
|
||||
const slashCommandResult = await handleSlashCommand(trimmedQuery);
|
||||
if (typeof slashCommandResult === 'boolean' && slashCommandResult) {
|
||||
// Command was handled, and it doesn't require a tool call from here
|
||||
return { queryToSend: null, shouldProceed: false };
|
||||
@@ -605,6 +618,106 @@ export const useGeminiStream = (
|
||||
pendingToolCallGroupDisplay,
|
||||
].filter((i) => i !== undefined && i !== null);
|
||||
|
||||
useEffect(() => {
|
||||
const saveRestorableToolCalls = async () => {
|
||||
if (!config.getCheckpointEnabled()) {
|
||||
return;
|
||||
}
|
||||
const restorableToolCalls = toolCalls.filter(
|
||||
(toolCall) =>
|
||||
(toolCall.request.name === 'replace' ||
|
||||
toolCall.request.name === 'write_file') &&
|
||||
toolCall.status === 'awaiting_approval',
|
||||
);
|
||||
|
||||
if (restorableToolCalls.length > 0) {
|
||||
const checkpointDir = config.getGeminiDir()
|
||||
? path.join(config.getGeminiDir(), 'checkpoints')
|
||||
: undefined;
|
||||
|
||||
if (!checkpointDir) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
await fs.mkdir(checkpointDir, { recursive: true });
|
||||
} catch (error) {
|
||||
if (!isNodeError(error) || error.code !== 'EEXIST') {
|
||||
onDebugMessage(
|
||||
`Failed to create checkpoint directory: ${getErrorMessage(error)}`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
for (const toolCall of restorableToolCalls) {
|
||||
const filePath = toolCall.request.args['file_path'] as string;
|
||||
if (!filePath) {
|
||||
onDebugMessage(
|
||||
`Skipping restorable tool call due to missing file_path: ${toolCall.request.name}`,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
try {
|
||||
let commitHash = await gitService?.createFileSnapshot(
|
||||
`Snapshot for ${toolCall.request.name}`,
|
||||
);
|
||||
|
||||
if (!commitHash) {
|
||||
commitHash = await gitService?.getCurrentCommitHash();
|
||||
}
|
||||
|
||||
if (!commitHash) {
|
||||
onDebugMessage(
|
||||
`Failed to create snapshot for ${filePath}. Skipping restorable tool call.`,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
const timestamp = new Date()
|
||||
.toISOString()
|
||||
.replace(/:/g, '-')
|
||||
.replace(/\./g, '_');
|
||||
const toolName = toolCall.request.name;
|
||||
const fileName = path.basename(filePath);
|
||||
const toolCallWithSnapshotFileName = `${timestamp}-${fileName}-${toolName}.json`;
|
||||
const clientHistory = await geminiClient?.getHistory();
|
||||
const toolCallWithSnapshotFilePath = path.join(
|
||||
checkpointDir,
|
||||
toolCallWithSnapshotFileName,
|
||||
);
|
||||
|
||||
await fs.writeFile(
|
||||
toolCallWithSnapshotFilePath,
|
||||
JSON.stringify(
|
||||
{
|
||||
history,
|
||||
clientHistory,
|
||||
toolCall: {
|
||||
name: toolCall.request.name,
|
||||
args: toolCall.request.args,
|
||||
},
|
||||
commitHash,
|
||||
filePath,
|
||||
},
|
||||
null,
|
||||
2,
|
||||
),
|
||||
);
|
||||
} catch (error) {
|
||||
onDebugMessage(
|
||||
`Failed to write restorable tool call file: ${getErrorMessage(
|
||||
error,
|
||||
)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
saveRestorableToolCalls();
|
||||
}, [toolCalls, config, onDebugMessage, gitService, history, geminiClient]);
|
||||
|
||||
return {
|
||||
streamingState,
|
||||
submitQuery,
|
||||
|
||||
@@ -20,6 +20,7 @@ export interface UseHistoryManagerReturn {
|
||||
updates: Partial<Omit<HistoryItem, 'id'>> | HistoryItemUpdater,
|
||||
) => void;
|
||||
clearItems: () => void;
|
||||
loadHistory: (newHistory: HistoryItem[]) => void;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -38,6 +39,10 @@ export function useHistory(): UseHistoryManagerReturn {
|
||||
return baseTimestamp + messageIdCounterRef.current;
|
||||
}, []);
|
||||
|
||||
const loadHistory = useCallback((newHistory: HistoryItem[]) => {
|
||||
setHistory(newHistory);
|
||||
}, []);
|
||||
|
||||
// Adds a new item to the history state with a unique ID.
|
||||
const addItem = useCallback(
|
||||
(itemData: Omit<HistoryItem, 'id'>, baseTimestamp: number): number => {
|
||||
@@ -101,5 +106,6 @@ export function useHistory(): UseHistoryManagerReturn {
|
||||
addItem,
|
||||
updateItem,
|
||||
clearItems,
|
||||
loadHistory,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -5,8 +5,7 @@
|
||||
*/
|
||||
|
||||
import { useState, useEffect } from 'react';
|
||||
import { sessionId } from '@gemini-cli/core';
|
||||
import { Logger } from '@gemini-cli/core';
|
||||
import { sessionId, Logger } from '@gemini-cli/core';
|
||||
|
||||
/**
|
||||
* Hook to manage the logger instance.
|
||||
|
||||
Reference in New Issue
Block a user