Restore Checkpoint Feature (#934)

This commit is contained in:
Louis Jimenez
2025-06-11 15:33:09 -04:00
committed by GitHub
parent f75c48323c
commit e0f4f428fc
19 changed files with 837 additions and 63 deletions

View File

@@ -43,6 +43,7 @@ interface CliArgs {
show_memory_usage: boolean | undefined;
yolo: boolean | undefined;
telemetry: boolean | undefined;
checkpoint: boolean | undefined;
}
async function parseArguments(): Promise<CliArgs> {
@@ -91,6 +92,12 @@ async function parseArguments(): Promise<CliArgs> {
type: 'boolean',
description: 'Enable telemetry?',
})
.option('checkpoint', {
alias: 'c',
type: 'boolean',
description: 'Enables checkpointing of file edits',
default: false,
})
.version(process.env.CLI_VERSION || '0.0.0') // This will enable the --version flag based on package.json
.help()
.alias('h', 'help')
@@ -178,6 +185,7 @@ export async function loadCliConfig(
fileFilteringAllowBuildArtifacts:
settings.fileFiltering?.allowBuildArtifacts,
enableModifyWithExternalEditors: settings.enableModifyWithExternalEditors,
checkpoint: argv.checkpoint,
});
}

View File

@@ -17,6 +17,7 @@ import { getStartupWarnings } from './utils/startupWarnings.js';
import { runNonInteractive } from './nonInteractiveCli.js';
import { loadGeminiIgnorePatterns } from './utils/loadIgnorePatterns.js';
import { loadExtensions, ExtensionConfig } from './config/extension.js';
import { cleanupCheckpoints } from './utils/cleanup.js';
import {
ApprovalMode,
Config,
@@ -40,7 +41,7 @@ export async function main() {
setWindowTitle(basename(workspaceRoot), settings);
const geminiIgnorePatterns = loadGeminiIgnorePatterns(workspaceRoot);
await cleanupCheckpoints();
if (settings.errors.length > 0) {
for (const error of settings.errors) {
let errorMessage = `Error in ${error.path}: ${error.message}`;
@@ -63,6 +64,13 @@ export async function main() {
// Initialize centralized FileDiscoveryService
await config.getFileService();
if (config.getCheckpointEnabled()) {
try {
await config.getGitService();
} catch {
// For now swallow the error, later log it.
}
}
if (settings.merged.theme) {
if (!themeManager.setActiveTheme(settings.merged.theme)) {

View File

@@ -63,6 +63,7 @@ interface MockServerConfig {
getVertexAI: Mock<() => boolean | undefined>;
getShowMemoryUsage: Mock<() => boolean>;
getAccessibility: Mock<() => AccessibilitySettings>;
getProjectRoot: Mock<() => string | undefined>;
}
// Mock @gemini-cli/core and its Config class
@@ -120,7 +121,9 @@ vi.mock('@gemini-cli/core', async (importOriginal) => {
getVertexAI: vi.fn(() => opts.vertexai),
getShowMemoryUsage: vi.fn(() => opts.showMemoryUsage ?? false),
getAccessibility: vi.fn(() => opts.accessibility ?? {}),
getProjectRoot: vi.fn(() => opts.projectRoot),
getGeminiClient: vi.fn(() => ({})),
getCheckpointEnabled: vi.fn(() => opts.checkpoint ?? true),
};
});
return {

View File

@@ -66,7 +66,7 @@ export const AppWrapper = (props: AppProps) => (
);
const App = ({ config, settings, startupWarnings = [] }: AppProps) => {
const { history, addItem, clearItems } = useHistory();
const { history, addItem, clearItems, loadHistory } = useHistory();
const {
consoleMessages,
handleNewMessage,
@@ -151,8 +151,10 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => {
const { handleSlashCommand, slashCommands } = useSlashCommandProcessor(
config,
history,
addItem,
clearItems,
loadHistory,
refreshStatic,
setShowHelp,
setDebugMessage,
@@ -217,6 +219,7 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => {
const { streamingState, submitQuery, initError, pendingHistoryItems } =
useGeminiStream(
config.getGeminiClient(),
history,
addItem,
setShowHelp,
config,
@@ -512,7 +515,6 @@ const App = ({ config, settings, startupWarnings = [] }: AppProps) => {
)}
</Box>
)}
<Footer
model={config.getModel()}
targetDir={config.getTargetDir()}

View File

@@ -65,6 +65,14 @@ import {
} from '@gemini-cli/core';
import { useSessionStats } from '../contexts/SessionContext.js';
vi.mock('@gemini-code/core', async (importOriginal) => {
const actual = await importOriginal<typeof import('@gemini-code/core')>();
return {
...actual,
GitService: vi.fn(),
};
});
import * as ShowMemoryCommandModule from './useShowMemoryCommand.js';
import { GIT_COMMIT_INFO } from '../../generated/git-commit.js';
@@ -84,6 +92,7 @@ vi.mock('open', () => ({
describe('useSlashCommandProcessor', () => {
let mockAddItem: ReturnType<typeof vi.fn>;
let mockClearItems: ReturnType<typeof vi.fn>;
let mockLoadHistory: ReturnType<typeof vi.fn>;
let mockRefreshStatic: ReturnType<typeof vi.fn>;
let mockSetShowHelp: ReturnType<typeof vi.fn>;
let mockOnDebugMessage: ReturnType<typeof vi.fn>;
@@ -96,6 +105,7 @@ describe('useSlashCommandProcessor', () => {
beforeEach(() => {
mockAddItem = vi.fn();
mockClearItems = vi.fn();
mockLoadHistory = vi.fn();
mockRefreshStatic = vi.fn();
mockSetShowHelp = vi.fn();
mockOnDebugMessage = vi.fn();
@@ -105,6 +115,8 @@ describe('useSlashCommandProcessor', () => {
getDebugMode: vi.fn(() => false),
getSandbox: vi.fn(() => 'test-sandbox'),
getModel: vi.fn(() => 'test-model'),
getProjectRoot: vi.fn(() => '/test/dir'),
getCheckpointEnabled: vi.fn(() => true),
} as unknown as Config;
mockCorgiMode = vi.fn();
mockUseSessionStats.mockReturnValue({
@@ -133,8 +145,10 @@ describe('useSlashCommandProcessor', () => {
const { result } = renderHook(() =>
useSlashCommandProcessor(
mockConfig,
[],
mockAddItem,
mockClearItems,
mockLoadHistory,
mockRefreshStatic,
mockSetShowHelp,
mockOnDebugMessage,
@@ -153,7 +167,7 @@ describe('useSlashCommandProcessor', () => {
const fact = 'Remember this fact';
let commandResult: SlashCommandActionReturn | boolean = false;
await act(async () => {
commandResult = handleSlashCommand(`/memory add ${fact}`);
commandResult = await handleSlashCommand(`/memory add ${fact}`);
});
expect(mockAddItem).toHaveBeenNthCalledWith(
@@ -187,7 +201,7 @@ describe('useSlashCommandProcessor', () => {
const { handleSlashCommand } = getProcessor();
let commandResult: SlashCommandActionReturn | boolean = false;
await act(async () => {
commandResult = handleSlashCommand('/memory add ');
commandResult = await handleSlashCommand('/memory add ');
});
expect(mockAddItem).toHaveBeenNthCalledWith(
@@ -211,7 +225,7 @@ describe('useSlashCommandProcessor', () => {
const { handleSlashCommand } = getProcessor();
let commandResult: SlashCommandActionReturn | boolean = false;
await act(async () => {
commandResult = handleSlashCommand('/memory show');
commandResult = await handleSlashCommand('/memory show');
});
expect(
ShowMemoryCommandModule.createShowMemoryAction,
@@ -226,7 +240,7 @@ describe('useSlashCommandProcessor', () => {
const { handleSlashCommand } = getProcessor();
let commandResult: SlashCommandActionReturn | boolean = false;
await act(async () => {
commandResult = handleSlashCommand('/memory refresh');
commandResult = await handleSlashCommand('/memory refresh');
});
expect(mockPerformMemoryRefresh).toHaveBeenCalled();
expect(commandResult).toBe(true);
@@ -238,7 +252,7 @@ describe('useSlashCommandProcessor', () => {
const { handleSlashCommand } = getProcessor();
let commandResult: SlashCommandActionReturn | boolean = false;
await act(async () => {
commandResult = handleSlashCommand('/memory foobar');
commandResult = await handleSlashCommand('/memory foobar');
});
expect(mockAddItem).toHaveBeenNthCalledWith(
2,
@@ -300,7 +314,7 @@ describe('useSlashCommandProcessor', () => {
const { handleSlashCommand } = getProcessor();
let commandResult: SlashCommandActionReturn | boolean = false;
await act(async () => {
commandResult = handleSlashCommand('/help');
commandResult = await handleSlashCommand('/help');
});
expect(mockSetShowHelp).toHaveBeenCalledWith(true);
expect(commandResult).toBe(true);
@@ -373,7 +387,7 @@ Add any other context about the problem here.
);
let commandResult: SlashCommandActionReturn | boolean = false;
await act(async () => {
commandResult = handleSlashCommand(`/bug ${bugDescription}`);
commandResult = await handleSlashCommand(`/bug ${bugDescription}`);
});
expect(mockAddItem).toHaveBeenCalledTimes(2);
@@ -387,7 +401,7 @@ Add any other context about the problem here.
const { handleSlashCommand } = getProcessor();
let commandResult: SlashCommandActionReturn | boolean = false;
await act(async () => {
commandResult = handleSlashCommand('/unknowncommand');
commandResult = await handleSlashCommand('/unknowncommand');
});
expect(mockAddItem).toHaveBeenNthCalledWith(
2,
@@ -410,7 +424,7 @@ Add any other context about the problem here.
const { handleSlashCommand } = getProcessor();
let commandResult: SlashCommandActionReturn | boolean = false;
await act(async () => {
commandResult = handleSlashCommand('/tools');
commandResult = await handleSlashCommand('/tools');
});
expect(mockAddItem).toHaveBeenNthCalledWith(
@@ -434,7 +448,7 @@ Add any other context about the problem here.
const { handleSlashCommand } = getProcessor();
let commandResult: SlashCommandActionReturn | boolean = false;
await act(async () => {
commandResult = handleSlashCommand('/tools');
commandResult = await handleSlashCommand('/tools');
});
expect(mockAddItem).toHaveBeenNthCalledWith(
@@ -467,7 +481,7 @@ Add any other context about the problem here.
const { handleSlashCommand } = getProcessor();
let commandResult: SlashCommandActionReturn | boolean = false;
await act(async () => {
commandResult = handleSlashCommand('/tools');
commandResult = await handleSlashCommand('/tools');
});
// Should only show tool1 and tool2, not the MCP tools
@@ -499,7 +513,7 @@ Add any other context about the problem here.
const { handleSlashCommand } = getProcessor();
let commandResult: SlashCommandActionReturn | boolean = false;
await act(async () => {
commandResult = handleSlashCommand('/tools');
commandResult = await handleSlashCommand('/tools');
});
expect(mockAddItem).toHaveBeenNthCalledWith(
@@ -545,7 +559,7 @@ Add any other context about the problem here.
const { handleSlashCommand } = getProcessor();
let commandResult: SlashCommandActionReturn | boolean = false;
await act(async () => {
commandResult = handleSlashCommand('/mcp');
commandResult = await handleSlashCommand('/mcp');
});
expect(mockAddItem).toHaveBeenNthCalledWith(
@@ -571,7 +585,7 @@ Add any other context about the problem here.
const { handleSlashCommand } = getProcessor();
let commandResult: SlashCommandActionReturn | boolean = false;
await act(async () => {
commandResult = handleSlashCommand('/mcp');
commandResult = await handleSlashCommand('/mcp');
});
expect(mockAddItem).toHaveBeenNthCalledWith(
@@ -633,7 +647,7 @@ Add any other context about the problem here.
const { handleSlashCommand } = getProcessor();
let commandResult: SlashCommandActionReturn | boolean = false;
await act(async () => {
commandResult = handleSlashCommand('/mcp');
commandResult = await handleSlashCommand('/mcp');
});
expect(mockAddItem).toHaveBeenNthCalledWith(
@@ -706,7 +720,7 @@ Add any other context about the problem here.
const { handleSlashCommand } = getProcessor(true);
let commandResult: SlashCommandActionReturn | boolean = false;
await act(async () => {
commandResult = handleSlashCommand('/mcp');
commandResult = await handleSlashCommand('/mcp');
});
expect(mockAddItem).toHaveBeenNthCalledWith(
@@ -780,7 +794,7 @@ Add any other context about the problem here.
const { handleSlashCommand } = getProcessor();
let commandResult: SlashCommandActionReturn | boolean = false;
await act(async () => {
commandResult = handleSlashCommand('/mcp');
commandResult = await handleSlashCommand('/mcp');
});
expect(mockAddItem).toHaveBeenNthCalledWith(
@@ -846,7 +860,7 @@ Add any other context about the problem here.
const { handleSlashCommand } = getProcessor();
let commandResult: SlashCommandActionReturn | boolean = false;
await act(async () => {
commandResult = handleSlashCommand('/mcp');
commandResult = await handleSlashCommand('/mcp');
});
const message = mockAddItem.mock.calls[1][0].text;

View File

@@ -11,14 +11,22 @@ import process from 'node:process';
import { UseHistoryManagerReturn } from './useHistoryManager.js';
import {
Config,
GitService,
Logger,
MCPDiscoveryState,
MCPServerStatus,
getMCPDiscoveryState,
getMCPServerStatus,
} from '@gemini-cli/core';
import { Message, MessageType, HistoryItemWithoutId } from '../types.js';
import { useSessionStats } from '../contexts/SessionContext.js';
import {
Message,
MessageType,
HistoryItemWithoutId,
HistoryItem,
} from '../types.js';
import { promises as fs } from 'fs';
import path from 'path';
import { createShowMemoryAction } from './useShowMemoryCommand.js';
import { GIT_COMMIT_INFO } from '../../generated/git-commit.js';
import { formatDuration, formatMemoryUsage } from '../utils/formatters.js';
@@ -39,7 +47,10 @@ export interface SlashCommand {
mainCommand: string,
subCommand?: string,
args?: string,
) => void | SlashCommandActionReturn; // Action can now return this object
) =>
| void
| SlashCommandActionReturn
| Promise<void | SlashCommandActionReturn>; // Action can now return this object
}
/**
@@ -47,8 +58,10 @@ export interface SlashCommand {
*/
export const useSlashCommandProcessor = (
config: Config | null,
history: HistoryItem[],
addItem: UseHistoryManagerReturn['addItem'],
clearItems: UseHistoryManagerReturn['clearItems'],
loadHistory: UseHistoryManagerReturn['loadHistory'],
refreshStatic: () => void,
setShowHelp: React.Dispatch<React.SetStateAction<boolean>>,
onDebugMessage: (message: string) => void,
@@ -58,6 +71,13 @@ export const useSlashCommandProcessor = (
showToolDescriptions: boolean = false,
) => {
const session = useSessionStats();
const gitService = useMemo(() => {
if (!config?.getProjectRoot()) {
return;
}
return new GitService(config.getProjectRoot());
}, [config]);
const addMessage = useCallback(
(message: Message) => {
// Convert Message to HistoryItemWithoutId
@@ -126,8 +146,8 @@ export const useSlashCommandProcessor = (
[addMessage],
);
const slashCommands: SlashCommand[] = useMemo(
() => [
const slashCommands: SlashCommand[] = useMemo(() => {
const commands: SlashCommand[] = [
{
name: 'help',
altName: '?',
@@ -408,7 +428,9 @@ export const useSlashCommandProcessor = (
if (process.env.SANDBOX && process.env.SANDBOX !== 'sandbox-exec') {
sandboxEnv = process.env.SANDBOX;
} else if (process.env.SANDBOX === 'sandbox-exec') {
sandboxEnv = `sandbox-exec (${process.env.SEATBELT_PROFILE || 'unknown'})`;
sandboxEnv = `sandbox-exec (${
process.env.SEATBELT_PROFILE || 'unknown'
})`;
}
const modelVersion = config?.getModel() || 'Unknown';
const cliVersion = getCliVersion();
@@ -437,7 +459,9 @@ export const useSlashCommandProcessor = (
if (process.env.SANDBOX && process.env.SANDBOX !== 'sandbox-exec') {
sandboxEnv = process.env.SANDBOX.replace(/^gemini-(?:code-)?/, '');
} else if (process.env.SANDBOX === 'sandbox-exec') {
sandboxEnv = `sandbox-exec (${process.env.SEATBELT_PROFILE || 'unknown'})`;
sandboxEnv = `sandbox-exec (${
process.env.SEATBELT_PROFILE || 'unknown'
})`;
}
const modelVersion = config?.getModel() || 'Unknown';
const memoryUsage = formatMemoryUsage(process.memoryUsage().rss);
@@ -569,31 +593,140 @@ Add any other context about the problem here.
name: 'quit',
altName: 'exit',
description: 'exit the cli',
action: (_mainCommand, _subCommand, _args) => {
action: async (_mainCommand, _subCommand, _args) => {
onDebugMessage('Quitting. Good-bye.');
process.exit(0);
},
},
],
[
onDebugMessage,
setShowHelp,
refreshStatic,
openThemeDialog,
clearItems,
performMemoryRefresh,
showMemoryAction,
addMemoryAction,
addMessage,
toggleCorgiMode,
config,
showToolDescriptions,
session,
],
);
];
if (config?.getCheckpointEnabled()) {
commands.push({
name: 'restore',
description:
'restore a tool call. This will reset the conversation and file history to the state it was in when the tool call was suggested',
action: async (_mainCommand, subCommand, _args) => {
const checkpointDir = config?.getGeminiDir()
? path.join(config.getGeminiDir(), 'checkpoints')
: undefined;
if (!checkpointDir) {
addMessage({
type: MessageType.ERROR,
content: 'Could not determine the .gemini directory path.',
timestamp: new Date(),
});
return;
}
try {
// Ensure the directory exists before trying to read it.
await fs.mkdir(checkpointDir, { recursive: true });
const files = await fs.readdir(checkpointDir);
const jsonFiles = files.filter((file) => file.endsWith('.json'));
if (!subCommand) {
if (jsonFiles.length === 0) {
addMessage({
type: MessageType.INFO,
content: 'No restorable tool calls found.',
timestamp: new Date(),
});
return;
}
const truncatedFiles = jsonFiles.map((file) => {
const components = file.split('.');
if (components.length <= 1) {
return file;
}
components.pop();
return components.join('.');
});
const fileList = truncatedFiles.join('\n');
addMessage({
type: MessageType.INFO,
content: `Available tool calls to restore:\n\n${fileList}`,
timestamp: new Date(),
});
return;
}
const selectedFile = subCommand.endsWith('.json')
? subCommand
: `${subCommand}.json`;
if (!jsonFiles.includes(selectedFile)) {
addMessage({
type: MessageType.ERROR,
content: `File not found: ${selectedFile}`,
timestamp: new Date(),
});
return;
}
const filePath = path.join(checkpointDir, selectedFile);
const data = await fs.readFile(filePath, 'utf-8');
const toolCallData = JSON.parse(data);
if (toolCallData.history) {
loadHistory(toolCallData.history);
}
if (toolCallData.clientHistory) {
await config
?.getGeminiClient()
?.setHistory(toolCallData.clientHistory);
}
if (toolCallData.commitHash) {
await gitService?.restoreProjectFromSnapshot(
toolCallData.commitHash,
);
addMessage({
type: MessageType.INFO,
content: `Restored project to the state before the tool call.`,
timestamp: new Date(),
});
}
return {
shouldScheduleTool: true,
toolName: toolCallData.toolCall.name,
toolArgs: toolCallData.toolCall.args,
};
} catch (error) {
addMessage({
type: MessageType.ERROR,
content: `Could not read restorable tool calls. This is the error: ${error}`,
timestamp: new Date(),
});
}
},
});
}
return commands;
}, [
onDebugMessage,
setShowHelp,
refreshStatic,
openThemeDialog,
clearItems,
performMemoryRefresh,
showMemoryAction,
addMemoryAction,
addMessage,
toggleCorgiMode,
config,
showToolDescriptions,
session,
gitService,
loadHistory,
]);
const handleSlashCommand = useCallback(
(rawQuery: PartListUnion): SlashCommandActionReturn | boolean => {
async (
rawQuery: PartListUnion,
): Promise<SlashCommandActionReturn | boolean> => {
if (typeof rawQuery !== 'string') {
return false;
}
@@ -625,7 +758,7 @@ Add any other context about the problem here.
for (const cmd of slashCommands) {
if (mainCommand === cmd.name || mainCommand === cmd.altName) {
const actionResult = cmd.action(mainCommand, subCommand, args);
const actionResult = await cmd.action(mainCommand, subCommand, args);
if (
typeof actionResult === 'object' &&
actionResult?.shouldScheduleTool

View File

@@ -18,6 +18,7 @@ import {
import { Config } from '@gemini-cli/core';
import { Part, PartListUnion } from '@google/genai';
import { UseHistoryManagerReturn } from './useHistoryManager.js';
import { HistoryItem } from '../types.js';
import { Dispatch, SetStateAction } from 'react';
// --- MOCKS ---
@@ -38,9 +39,9 @@ const MockedGeminiClientClass = vi.hoisted(() =>
vi.mock('@gemini-cli/core', async (importOriginal) => {
const actualCoreModule = (await importOriginal()) as any;
return {
...(actualCoreModule || {}),
GeminiClient: MockedGeminiClientClass, // Export the class for type checking or other direct uses
Config: actualCoreModule.Config, // Ensure Config is passed through
...actualCoreModule,
GitService: vi.fn(),
GeminiClient: MockedGeminiClientClass,
};
});
@@ -277,11 +278,13 @@ describe('useGeminiStream', () => {
getToolRegistry: vi.fn(
() => ({ getToolSchemaList: vi.fn(() => []) }) as any,
),
getProjectRoot: vi.fn(() => '/test/dir'),
getCheckpointEnabled: vi.fn(() => false),
getGeminiClient: mockGetGeminiClient,
addHistory: vi.fn(),
} as unknown as Config;
mockOnDebugMessage = vi.fn();
mockHandleSlashCommand = vi.fn().mockReturnValue(false);
mockHandleSlashCommand = vi.fn().mockResolvedValue(false);
// Mock return value for useReactToolScheduler
mockScheduleToolCalls = vi.fn();
@@ -322,19 +325,22 @@ describe('useGeminiStream', () => {
const { result, rerender } = renderHook(
(props: {
client: any;
history: HistoryItem[];
addItem: UseHistoryManagerReturn['addItem'];
setShowHelp: Dispatch<SetStateAction<boolean>>;
config: Config;
onDebugMessage: (message: string) => void;
handleSlashCommand: (
command: PartListUnion,
) =>
cmd: PartListUnion,
) => Promise<
| import('./slashCommandProcessor.js').SlashCommandActionReturn
| boolean;
| boolean
>;
shellModeActive: boolean;
}) =>
useGeminiStream(
props.client,
props.history,
props.addItem,
props.setShowHelp,
props.config,
@@ -345,12 +351,17 @@ describe('useGeminiStream', () => {
{
initialProps: {
client,
history: [],
addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
setShowHelp: mockSetShowHelp,
config: mockConfig,
onDebugMessage: mockOnDebugMessage,
handleSlashCommand:
mockHandleSlashCommand as unknown as typeof mockHandleSlashCommand,
handleSlashCommand: mockHandleSlashCommand as unknown as (
cmd: PartListUnion,
) => Promise<
| import('./slashCommandProcessor.js').SlashCommandActionReturn
| boolean
>,
shellModeActive: false,
},
},
@@ -467,7 +478,8 @@ describe('useGeminiStream', () => {
act(() => {
rerender({
client,
addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
history: [],
addItem: mockAddItem,
setShowHelp: mockSetShowHelp,
config: mockConfig,
onDebugMessage: mockOnDebugMessage,
@@ -521,7 +533,8 @@ describe('useGeminiStream', () => {
act(() => {
rerender({
client,
addItem: mockAddItem as unknown as UseHistoryManagerReturn['addItem'],
history: [],
addItem: mockAddItem,
setShowHelp: mockSetShowHelp,
config: mockConfig,
onDebugMessage: mockOnDebugMessage,

View File

@@ -7,6 +7,7 @@
import { useState, useRef, useCallback, useEffect, useMemo } from 'react';
import { useInput } from 'ink';
import {
Config,
GeminiClient,
GeminiEventType as ServerGeminiEventType,
ServerGeminiStreamEvent as GeminiEvent,
@@ -14,14 +15,15 @@ import {
ServerGeminiErrorEvent as ErrorEvent,
getErrorMessage,
isNodeError,
Config,
MessageSenderType,
ToolCallRequestInfo,
logUserPrompt,
GitService,
} from '@gemini-cli/core';
import { type Part, type PartListUnion } from '@google/genai';
import {
StreamingState,
HistoryItem,
HistoryItemWithoutId,
HistoryItemToolGroup,
MessageType,
@@ -35,6 +37,8 @@ import { findLastSafeSplitPoint } from '../utils/markdownUtilities.js';
import { useStateAndRef } from './useStateAndRef.js';
import { UseHistoryManagerReturn } from './useHistoryManager.js';
import { useLogger } from './useLogger.js';
import { promises as fs } from 'fs';
import path from 'path';
import {
useReactToolScheduler,
mapToDisplay as mapTrackedToolCallsToDisplay,
@@ -68,13 +72,16 @@ enum StreamProcessingStatus {
*/
export const useGeminiStream = (
geminiClient: GeminiClient | null,
history: HistoryItem[],
addItem: UseHistoryManagerReturn['addItem'],
setShowHelp: React.Dispatch<React.SetStateAction<boolean>>,
config: Config,
onDebugMessage: (message: string) => void,
handleSlashCommand: (
cmd: PartListUnion,
) => import('./slashCommandProcessor.js').SlashCommandActionReturn | boolean,
) => Promise<
import('./slashCommandProcessor.js').SlashCommandActionReturn | boolean
>,
shellModeActive: boolean,
) => {
const [initError, setInitError] = useState<string | null>(null);
@@ -84,6 +91,12 @@ export const useGeminiStream = (
useStateAndRef<HistoryItemWithoutId | null>(null);
const logger = useLogger();
const { startNewTurn, addUsage } = useSessionStats();
const gitService = useMemo(() => {
if (!config.getProjectRoot()) {
return;
}
return new GitService(config.getProjectRoot());
}, [config]);
const [toolCalls, scheduleToolCalls, markToolsAsSubmitted] =
useReactToolScheduler(
@@ -178,7 +191,7 @@ export const useGeminiStream = (
await logger?.logMessage(MessageSenderType.USER, trimmedQuery);
// Handle UI-only commands first
const slashCommandResult = handleSlashCommand(trimmedQuery);
const slashCommandResult = await handleSlashCommand(trimmedQuery);
if (typeof slashCommandResult === 'boolean' && slashCommandResult) {
// Command was handled, and it doesn't require a tool call from here
return { queryToSend: null, shouldProceed: false };
@@ -605,6 +618,106 @@ export const useGeminiStream = (
pendingToolCallGroupDisplay,
].filter((i) => i !== undefined && i !== null);
useEffect(() => {
const saveRestorableToolCalls = async () => {
if (!config.getCheckpointEnabled()) {
return;
}
const restorableToolCalls = toolCalls.filter(
(toolCall) =>
(toolCall.request.name === 'replace' ||
toolCall.request.name === 'write_file') &&
toolCall.status === 'awaiting_approval',
);
if (restorableToolCalls.length > 0) {
const checkpointDir = config.getGeminiDir()
? path.join(config.getGeminiDir(), 'checkpoints')
: undefined;
if (!checkpointDir) {
return;
}
try {
await fs.mkdir(checkpointDir, { recursive: true });
} catch (error) {
if (!isNodeError(error) || error.code !== 'EEXIST') {
onDebugMessage(
`Failed to create checkpoint directory: ${getErrorMessage(error)}`,
);
return;
}
}
for (const toolCall of restorableToolCalls) {
const filePath = toolCall.request.args['file_path'] as string;
if (!filePath) {
onDebugMessage(
`Skipping restorable tool call due to missing file_path: ${toolCall.request.name}`,
);
continue;
}
try {
let commitHash = await gitService?.createFileSnapshot(
`Snapshot for ${toolCall.request.name}`,
);
if (!commitHash) {
commitHash = await gitService?.getCurrentCommitHash();
}
if (!commitHash) {
onDebugMessage(
`Failed to create snapshot for ${filePath}. Skipping restorable tool call.`,
);
continue;
}
const timestamp = new Date()
.toISOString()
.replace(/:/g, '-')
.replace(/\./g, '_');
const toolName = toolCall.request.name;
const fileName = path.basename(filePath);
const toolCallWithSnapshotFileName = `${timestamp}-${fileName}-${toolName}.json`;
const clientHistory = await geminiClient?.getHistory();
const toolCallWithSnapshotFilePath = path.join(
checkpointDir,
toolCallWithSnapshotFileName,
);
await fs.writeFile(
toolCallWithSnapshotFilePath,
JSON.stringify(
{
history,
clientHistory,
toolCall: {
name: toolCall.request.name,
args: toolCall.request.args,
},
commitHash,
filePath,
},
null,
2,
),
);
} catch (error) {
onDebugMessage(
`Failed to write restorable tool call file: ${getErrorMessage(
error,
)}`,
);
}
}
}
};
saveRestorableToolCalls();
}, [toolCalls, config, onDebugMessage, gitService, history, geminiClient]);
return {
streamingState,
submitQuery,

View File

@@ -20,6 +20,7 @@ export interface UseHistoryManagerReturn {
updates: Partial<Omit<HistoryItem, 'id'>> | HistoryItemUpdater,
) => void;
clearItems: () => void;
loadHistory: (newHistory: HistoryItem[]) => void;
}
/**
@@ -38,6 +39,10 @@ export function useHistory(): UseHistoryManagerReturn {
return baseTimestamp + messageIdCounterRef.current;
}, []);
const loadHistory = useCallback((newHistory: HistoryItem[]) => {
setHistory(newHistory);
}, []);
// Adds a new item to the history state with a unique ID.
const addItem = useCallback(
(itemData: Omit<HistoryItem, 'id'>, baseTimestamp: number): number => {
@@ -101,5 +106,6 @@ export function useHistory(): UseHistoryManagerReturn {
addItem,
updateItem,
clearItems,
loadHistory,
};
}

View File

@@ -5,8 +5,7 @@
*/
import { useState, useEffect } from 'react';
import { sessionId } from '@gemini-cli/core';
import { Logger } from '@gemini-cli/core';
import { sessionId, Logger } from '@gemini-cli/core';
/**
* Hook to manage the logger instance.

View File

@@ -0,0 +1,18 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import { promises as fs } from 'fs';
import { join } from 'path';
export async function cleanupCheckpoints() {
const geminiDir = join(process.cwd(), '.gemini');
const checkpointsDir = join(geminiDir, 'checkpoints');
try {
await fs.rm(checkpointsDir, { recursive: true, force: true });
} catch {
// Ignore errors if the directory doesn't exist or fails to delete.
}
}