mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-20 16:57:46 +00:00
# 🚀 Sync Gemini CLI v0.2.1 - Major Feature Update (#483)
This commit is contained in:
@@ -11,6 +11,7 @@ import {
|
||||
FileDiscoveryService,
|
||||
GlobTool,
|
||||
ReadManyFilesTool,
|
||||
StandardFileSystemService,
|
||||
ToolRegistry,
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
import * as os from 'os';
|
||||
@@ -56,11 +57,18 @@ describe('handleAtCommand', () => {
|
||||
respectGitIgnore: true,
|
||||
respectGeminiIgnore: true,
|
||||
}),
|
||||
getFileSystemService: () => new StandardFileSystemService(),
|
||||
getEnableRecursiveFileSearch: vi.fn(() => true),
|
||||
getWorkspaceContext: () => ({
|
||||
isPathWithinWorkspace: () => true,
|
||||
getDirectories: () => [testRootDir],
|
||||
}),
|
||||
getMcpServers: () => ({}),
|
||||
getMcpServerCommand: () => undefined,
|
||||
getPromptRegistry: () => ({
|
||||
getPromptsByServer: () => [],
|
||||
}),
|
||||
getDebugMode: () => false,
|
||||
} as unknown as Config;
|
||||
|
||||
const registry = new ToolRegistry(mockConfig);
|
||||
@@ -90,10 +98,6 @@ describe('handleAtCommand', () => {
|
||||
processedQuery: [{ text: query }],
|
||||
shouldProceed: true,
|
||||
});
|
||||
expect(mockAddItem).toHaveBeenCalledWith(
|
||||
{ type: 'user', text: query },
|
||||
123,
|
||||
);
|
||||
});
|
||||
|
||||
it('should pass through original query if only a lone @ symbol is present', async () => {
|
||||
@@ -112,10 +116,6 @@ describe('handleAtCommand', () => {
|
||||
processedQuery: [{ text: queryWithSpaces }],
|
||||
shouldProceed: true,
|
||||
});
|
||||
expect(mockAddItem).toHaveBeenCalledWith(
|
||||
{ type: 'user', text: queryWithSpaces },
|
||||
124,
|
||||
);
|
||||
expect(mockOnDebugMessage).toHaveBeenCalledWith(
|
||||
'Lone @ detected, will be treated as text in the modified query.',
|
||||
);
|
||||
@@ -148,10 +148,6 @@ describe('handleAtCommand', () => {
|
||||
],
|
||||
shouldProceed: true,
|
||||
});
|
||||
expect(mockAddItem).toHaveBeenCalledWith(
|
||||
{ type: 'user', text: query },
|
||||
125,
|
||||
);
|
||||
expect(mockAddItem).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'tool_group',
|
||||
@@ -190,10 +186,6 @@ describe('handleAtCommand', () => {
|
||||
],
|
||||
shouldProceed: true,
|
||||
});
|
||||
expect(mockAddItem).toHaveBeenCalledWith(
|
||||
{ type: 'user', text: query },
|
||||
126,
|
||||
);
|
||||
expect(mockOnDebugMessage).toHaveBeenCalledWith(
|
||||
`Path ${dirPath} resolved to directory, using glob: ${resolvedGlob}`,
|
||||
);
|
||||
@@ -228,10 +220,6 @@ describe('handleAtCommand', () => {
|
||||
],
|
||||
shouldProceed: true,
|
||||
});
|
||||
expect(mockAddItem).toHaveBeenCalledWith(
|
||||
{ type: 'user', text: query },
|
||||
128,
|
||||
);
|
||||
});
|
||||
|
||||
it('should correctly unescape paths with escaped spaces', async () => {
|
||||
@@ -262,10 +250,6 @@ describe('handleAtCommand', () => {
|
||||
],
|
||||
shouldProceed: true,
|
||||
});
|
||||
expect(mockAddItem).toHaveBeenCalledWith(
|
||||
{ type: 'user', text: query },
|
||||
125,
|
||||
);
|
||||
expect(mockAddItem).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
type: 'tool_group',
|
||||
@@ -1082,4 +1066,37 @@ describe('handleAtCommand', () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it("should not add the user's turn to history, as that is the caller's responsibility", async () => {
|
||||
// Arrange
|
||||
const fileContent = 'This is the file content.';
|
||||
const filePath = await createTestFile(
|
||||
path.join(testRootDir, 'path', 'to', 'another-file.txt'),
|
||||
fileContent,
|
||||
);
|
||||
const query = `A query with @${filePath}`;
|
||||
|
||||
// Act
|
||||
await handleAtCommand({
|
||||
query,
|
||||
config: mockConfig,
|
||||
addItem: mockAddItem,
|
||||
onDebugMessage: mockOnDebugMessage,
|
||||
messageId: 999,
|
||||
signal: abortController.signal,
|
||||
});
|
||||
|
||||
// Assert
|
||||
// It SHOULD be called for the tool_group
|
||||
expect(mockAddItem).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ type: 'tool_group' }),
|
||||
999,
|
||||
);
|
||||
|
||||
// It should NOT have been called for the user turn
|
||||
const userTurnCalls = mockAddItem.mock.calls.filter(
|
||||
(call) => call[0].type === 'user',
|
||||
);
|
||||
expect(userTurnCalls).toHaveLength(0);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -137,12 +137,9 @@ export async function handleAtCommand({
|
||||
);
|
||||
|
||||
if (atPathCommandParts.length === 0) {
|
||||
addItem({ type: 'user', text: query }, userMessageTimestamp);
|
||||
return { processedQuery: [{ text: query }], shouldProceed: true };
|
||||
}
|
||||
|
||||
addItem({ type: 'user', text: query }, userMessageTimestamp);
|
||||
|
||||
// Get centralized file discovery service
|
||||
const fileDiscovery = config.getFileService();
|
||||
|
||||
@@ -157,7 +154,7 @@ export async function handleAtCommand({
|
||||
both: [],
|
||||
};
|
||||
|
||||
const toolRegistry = await config.getToolRegistry();
|
||||
const toolRegistry = config.getToolRegistry();
|
||||
const readManyFilesTool = toolRegistry.getTool('read_many_files');
|
||||
const globTool = toolRegistry.getTool('glob');
|
||||
|
||||
@@ -362,20 +359,20 @@ export async function handleAtCommand({
|
||||
|
||||
// Inform user about ignored paths
|
||||
const totalIgnored =
|
||||
ignoredByReason.git.length +
|
||||
ignoredByReason.gemini.length +
|
||||
ignoredByReason.both.length;
|
||||
ignoredByReason['git'].length +
|
||||
ignoredByReason['gemini'].length +
|
||||
ignoredByReason['both'].length;
|
||||
|
||||
if (totalIgnored > 0) {
|
||||
const messages = [];
|
||||
if (ignoredByReason.git.length) {
|
||||
messages.push(`Git-ignored: ${ignoredByReason.git.join(', ')}`);
|
||||
if (ignoredByReason['git'].length) {
|
||||
messages.push(`Git-ignored: ${ignoredByReason['git'].join(', ')}`);
|
||||
}
|
||||
if (ignoredByReason.gemini.length) {
|
||||
messages.push(`Gemini-ignored: ${ignoredByReason.gemini.join(', ')}`);
|
||||
if (ignoredByReason['gemini'].length) {
|
||||
messages.push(`Gemini-ignored: ${ignoredByReason['gemini'].join(', ')}`);
|
||||
}
|
||||
if (ignoredByReason.both.length) {
|
||||
messages.push(`Ignored by both: ${ignoredByReason.both.join(', ')}`);
|
||||
if (ignoredByReason['both'].length) {
|
||||
messages.push(`Ignored by both: ${ignoredByReason['both'].join(', ')}`);
|
||||
}
|
||||
|
||||
const message = `Ignored ${totalIgnored} files:\n${messages.join('\n')}`;
|
||||
|
||||
@@ -65,7 +65,10 @@ describe('useShellCommandProcessor', () => {
|
||||
setPendingHistoryItemMock = vi.fn();
|
||||
onExecMock = vi.fn();
|
||||
onDebugMessageMock = vi.fn();
|
||||
mockConfig = { getTargetDir: () => '/test/dir' } as Config;
|
||||
mockConfig = {
|
||||
getTargetDir: () => '/test/dir',
|
||||
getShouldUseNodePtyShell: () => false,
|
||||
} as Config;
|
||||
mockGeminiClient = { addHistory: vi.fn() } as unknown as GeminiClient;
|
||||
|
||||
vi.mocked(os.platform).mockReturnValue('linux');
|
||||
@@ -104,13 +107,12 @@ describe('useShellCommandProcessor', () => {
|
||||
): ShellExecutionResult => ({
|
||||
rawOutput: Buffer.from(overrides.output || ''),
|
||||
output: 'Success',
|
||||
stdout: 'Success',
|
||||
stderr: '',
|
||||
exitCode: 0,
|
||||
signal: null,
|
||||
error: null,
|
||||
aborted: false,
|
||||
pid: 12345,
|
||||
executionMethod: 'child_process',
|
||||
...overrides,
|
||||
});
|
||||
|
||||
@@ -141,6 +143,7 @@ describe('useShellCommandProcessor', () => {
|
||||
'/test/dir',
|
||||
expect.any(Function),
|
||||
expect.any(Object),
|
||||
false,
|
||||
);
|
||||
expect(onExecMock).toHaveBeenCalledWith(expect.any(Promise));
|
||||
});
|
||||
@@ -223,7 +226,6 @@ describe('useShellCommandProcessor', () => {
|
||||
act(() => {
|
||||
mockShellOutputCallback({
|
||||
type: 'data',
|
||||
stream: 'stdout',
|
||||
chunk: 'hello',
|
||||
});
|
||||
});
|
||||
@@ -238,7 +240,6 @@ describe('useShellCommandProcessor', () => {
|
||||
act(() => {
|
||||
mockShellOutputCallback({
|
||||
type: 'data',
|
||||
stream: 'stdout',
|
||||
chunk: ' world',
|
||||
});
|
||||
});
|
||||
@@ -319,6 +320,7 @@ describe('useShellCommandProcessor', () => {
|
||||
'/test/dir',
|
||||
expect.any(Function),
|
||||
expect.any(Object),
|
||||
false,
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
@@ -101,10 +101,11 @@ export const useShellCommandProcessor = (
|
||||
commandToExecute = `{ ${command} }; __code=$?; pwd > "${pwdFilePath}"; exit $__code`;
|
||||
}
|
||||
|
||||
const execPromise = new Promise<void>((resolve) => {
|
||||
const executeCommand = async (
|
||||
resolve: (value: void | PromiseLike<void>) => void,
|
||||
) => {
|
||||
let lastUpdateTime = Date.now();
|
||||
let cumulativeStdout = '';
|
||||
let cumulativeStderr = '';
|
||||
let isBinaryStream = false;
|
||||
let binaryBytesReceived = 0;
|
||||
|
||||
@@ -134,7 +135,7 @@ export const useShellCommandProcessor = (
|
||||
onDebugMessage(`Executing in ${targetDir}: ${commandToExecute}`);
|
||||
|
||||
try {
|
||||
const { pid, result } = ShellExecutionService.execute(
|
||||
const { pid, result } = await ShellExecutionService.execute(
|
||||
commandToExecute,
|
||||
targetDir,
|
||||
(event) => {
|
||||
@@ -142,11 +143,7 @@ export const useShellCommandProcessor = (
|
||||
case 'data':
|
||||
// Do not process text data if we've already switched to binary mode.
|
||||
if (isBinaryStream) break;
|
||||
if (event.stream === 'stdout') {
|
||||
cumulativeStdout += event.chunk;
|
||||
} else {
|
||||
cumulativeStderr += event.chunk;
|
||||
}
|
||||
cumulativeStdout += event.chunk;
|
||||
break;
|
||||
case 'binary_detected':
|
||||
isBinaryStream = true;
|
||||
@@ -172,9 +169,7 @@ export const useShellCommandProcessor = (
|
||||
'[Binary output detected. Halting stream...]';
|
||||
}
|
||||
} else {
|
||||
currentDisplayOutput =
|
||||
cumulativeStdout +
|
||||
(cumulativeStderr ? `\n${cumulativeStderr}` : '');
|
||||
currentDisplayOutput = cumulativeStdout;
|
||||
}
|
||||
|
||||
// Throttle pending UI updates to avoid excessive re-renders.
|
||||
@@ -192,6 +187,7 @@ export const useShellCommandProcessor = (
|
||||
}
|
||||
},
|
||||
abortSignal,
|
||||
config.getShouldUseNodePtyShell(),
|
||||
);
|
||||
|
||||
executionPid = pid;
|
||||
@@ -295,6 +291,10 @@ export const useShellCommandProcessor = (
|
||||
|
||||
resolve(); // Resolve the promise to unblock `onExec`
|
||||
}
|
||||
};
|
||||
|
||||
const execPromise = new Promise<void>((resolve) => {
|
||||
executeCommand(resolve);
|
||||
});
|
||||
|
||||
onExec(execPromise);
|
||||
|
||||
@@ -8,7 +8,6 @@ import { useCallback, useMemo, useEffect, useState } from 'react';
|
||||
import { type PartListUnion } from '@google/genai';
|
||||
import process from 'node:process';
|
||||
import { UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||
import { useStateAndRef } from './useStateAndRef.js';
|
||||
import {
|
||||
Config,
|
||||
GitService,
|
||||
@@ -93,16 +92,16 @@ export const useSlashCommandProcessor = (
|
||||
return l;
|
||||
}, [config]);
|
||||
|
||||
const [pendingCompressionItemRef, setPendingCompressionItem] =
|
||||
useStateAndRef<HistoryItemWithoutId | null>(null);
|
||||
const [pendingCompressionItem, setPendingCompressionItem] =
|
||||
useState<HistoryItemWithoutId | null>(null);
|
||||
|
||||
const pendingHistoryItems = useMemo(() => {
|
||||
const items: HistoryItemWithoutId[] = [];
|
||||
if (pendingCompressionItemRef.current != null) {
|
||||
items.push(pendingCompressionItemRef.current);
|
||||
if (pendingCompressionItem != null) {
|
||||
items.push(pendingCompressionItem);
|
||||
}
|
||||
return items;
|
||||
}, [pendingCompressionItemRef]);
|
||||
}, [pendingCompressionItem]);
|
||||
|
||||
const addMessage = useCallback(
|
||||
(message: Message) => {
|
||||
@@ -117,6 +116,7 @@ export const useSlashCommandProcessor = (
|
||||
modelVersion: message.modelVersion,
|
||||
selectedAuthType: message.selectedAuthType,
|
||||
gcpProject: message.gcpProject,
|
||||
ideClient: message.ideClient,
|
||||
};
|
||||
} else if (message.type === MessageType.HELP) {
|
||||
historyItemContent = {
|
||||
@@ -173,7 +173,7 @@ export const useSlashCommandProcessor = (
|
||||
},
|
||||
loadHistory,
|
||||
setDebugMessage: onDebugMessage,
|
||||
pendingItem: pendingCompressionItemRef.current,
|
||||
pendingItem: pendingCompressionItem,
|
||||
setPendingItem: setPendingCompressionItem,
|
||||
toggleCorgiMode,
|
||||
toggleVimEnabled,
|
||||
@@ -183,7 +183,6 @@ export const useSlashCommandProcessor = (
|
||||
session: {
|
||||
stats: session.stats,
|
||||
sessionShellAllowlist,
|
||||
resetSession: session.resetSession,
|
||||
},
|
||||
}),
|
||||
[
|
||||
@@ -196,9 +195,8 @@ export const useSlashCommandProcessor = (
|
||||
clearItems,
|
||||
refreshStatic,
|
||||
session.stats,
|
||||
session.resetSession,
|
||||
onDebugMessage,
|
||||
pendingCompressionItemRef,
|
||||
pendingCompressionItem,
|
||||
setPendingCompressionItem,
|
||||
toggleCorgiMode,
|
||||
toggleVimEnabled,
|
||||
@@ -208,7 +206,22 @@ export const useSlashCommandProcessor = (
|
||||
],
|
||||
);
|
||||
|
||||
const ideMode = config?.getIdeMode();
|
||||
useEffect(() => {
|
||||
if (!config) {
|
||||
return;
|
||||
}
|
||||
|
||||
const ideClient = config.getIdeClient();
|
||||
const listener = () => {
|
||||
reloadCommands();
|
||||
};
|
||||
|
||||
ideClient.addStatusChangeListener(listener);
|
||||
|
||||
return () => {
|
||||
ideClient.removeStatusChangeListener(listener);
|
||||
};
|
||||
}, [config, reloadCommands]);
|
||||
|
||||
useEffect(() => {
|
||||
const controller = new AbortController();
|
||||
@@ -230,7 +243,7 @@ export const useSlashCommandProcessor = (
|
||||
return () => {
|
||||
controller.abort();
|
||||
};
|
||||
}, [config, ideMode, reloadTrigger]);
|
||||
}, [config, reloadTrigger]);
|
||||
|
||||
const handleSlashCommand = useCallback(
|
||||
async (
|
||||
|
||||
@@ -9,7 +9,11 @@
|
||||
import { describe, it, expect, beforeEach, vi, afterEach } from 'vitest';
|
||||
import { renderHook, waitFor, act } from '@testing-library/react';
|
||||
import { useAtCompletion } from './useAtCompletion.js';
|
||||
import { Config, FileSearch } from '@qwen-code/qwen-code-core';
|
||||
import {
|
||||
Config,
|
||||
FileSearch,
|
||||
FileSearchFactory,
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
import {
|
||||
createTmpDir,
|
||||
cleanupTmpDir,
|
||||
@@ -190,14 +194,25 @@ describe('useAtCompletion', () => {
|
||||
const structure: FileSystemStructure = { 'a.txt': '', 'b.txt': '' };
|
||||
testRootDir = await createTmpDir(structure);
|
||||
|
||||
// Spy on the search method to introduce an artificial delay
|
||||
const originalSearch = FileSearch.prototype.search;
|
||||
vi.spyOn(FileSearch.prototype, 'search').mockImplementation(
|
||||
async function (...args) {
|
||||
const realFileSearch = FileSearchFactory.create({
|
||||
projectRoot: testRootDir,
|
||||
ignoreDirs: [],
|
||||
useGitignore: true,
|
||||
useGeminiignore: true,
|
||||
cache: false,
|
||||
cacheTtl: 0,
|
||||
enableRecursiveFileSearch: true,
|
||||
});
|
||||
await realFileSearch.initialize();
|
||||
|
||||
const mockFileSearch: FileSearch = {
|
||||
initialize: vi.fn().mockResolvedValue(undefined),
|
||||
search: vi.fn().mockImplementation(async (...args) => {
|
||||
await new Promise((resolve) => setTimeout(resolve, 300));
|
||||
return originalSearch.apply(this, args);
|
||||
},
|
||||
);
|
||||
return realFileSearch.search(...args);
|
||||
}),
|
||||
};
|
||||
vi.spyOn(FileSearchFactory, 'create').mockReturnValue(mockFileSearch);
|
||||
|
||||
const { result, rerender } = renderHook(
|
||||
({ pattern }) =>
|
||||
@@ -241,14 +256,15 @@ describe('useAtCompletion', () => {
|
||||
testRootDir = await createTmpDir(structure);
|
||||
|
||||
const abortSpy = vi.spyOn(AbortController.prototype, 'abort');
|
||||
const searchSpy = vi
|
||||
.spyOn(FileSearch.prototype, 'search')
|
||||
.mockImplementation(async (...args) => {
|
||||
const delay = args[0] === 'a' ? 500 : 50;
|
||||
const mockFileSearch: FileSearch = {
|
||||
initialize: vi.fn().mockResolvedValue(undefined),
|
||||
search: vi.fn().mockImplementation(async (pattern: string) => {
|
||||
const delay = pattern === 'a' ? 500 : 50;
|
||||
await new Promise((resolve) => setTimeout(resolve, delay));
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
return [args[0] as any];
|
||||
});
|
||||
return [pattern];
|
||||
}),
|
||||
};
|
||||
vi.spyOn(FileSearchFactory, 'create').mockReturnValue(mockFileSearch);
|
||||
|
||||
const { result, rerender } = renderHook(
|
||||
({ pattern }) =>
|
||||
@@ -258,7 +274,10 @@ describe('useAtCompletion', () => {
|
||||
|
||||
// Wait for the hook to be ready (initialization is complete)
|
||||
await waitFor(() => {
|
||||
expect(searchSpy).toHaveBeenCalledWith('a', expect.any(Object));
|
||||
expect(mockFileSearch.search).toHaveBeenCalledWith(
|
||||
'a',
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
|
||||
// Now that the first search is in-flight, trigger the second one.
|
||||
@@ -278,9 +297,10 @@ describe('useAtCompletion', () => {
|
||||
);
|
||||
|
||||
// The search spy should have been called for both patterns.
|
||||
expect(searchSpy).toHaveBeenCalledWith('b', expect.any(Object));
|
||||
|
||||
vi.restoreAllMocks();
|
||||
expect(mockFileSearch.search).toHaveBeenCalledWith(
|
||||
'b',
|
||||
expect.any(Object),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -313,9 +333,13 @@ describe('useAtCompletion', () => {
|
||||
testRootDir = await createTmpDir({});
|
||||
|
||||
// Force an error during initialization
|
||||
vi.spyOn(FileSearch.prototype, 'initialize').mockRejectedValueOnce(
|
||||
new Error('Initialization failed'),
|
||||
);
|
||||
const mockFileSearch: FileSearch = {
|
||||
initialize: vi
|
||||
.fn()
|
||||
.mockRejectedValue(new Error('Initialization failed')),
|
||||
search: vi.fn(),
|
||||
};
|
||||
vi.spyOn(FileSearchFactory, 'create').mockReturnValue(mockFileSearch);
|
||||
|
||||
const { result, rerender } = renderHook(
|
||||
({ enabled }) =>
|
||||
|
||||
@@ -5,7 +5,12 @@
|
||||
*/
|
||||
|
||||
import { useEffect, useReducer, useRef } from 'react';
|
||||
import { Config, FileSearch, escapePath } from '@qwen-code/qwen-code-core';
|
||||
import {
|
||||
Config,
|
||||
FileSearch,
|
||||
FileSearchFactory,
|
||||
escapePath,
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
import {
|
||||
Suggestion,
|
||||
MAX_SUGGESTIONS_TO_SHOW,
|
||||
@@ -156,7 +161,7 @@ export function useAtCompletion(props: UseAtCompletionProps): void {
|
||||
useEffect(() => {
|
||||
const initialize = async () => {
|
||||
try {
|
||||
const searcher = new FileSearch({
|
||||
const searcher = FileSearchFactory.create({
|
||||
projectRoot: cwd,
|
||||
ignoreDirs: [],
|
||||
useGitignore:
|
||||
@@ -165,9 +170,8 @@ export function useAtCompletion(props: UseAtCompletionProps): void {
|
||||
config?.getFileFilteringOptions()?.respectGeminiIgnore ?? true,
|
||||
cache: true,
|
||||
cacheTtl: 30, // 30 seconds
|
||||
maxDepth: !(config?.getEnableRecursiveFileSearch() ?? true)
|
||||
? 0
|
||||
: undefined,
|
||||
enableRecursiveFileSearch:
|
||||
config?.getEnableRecursiveFileSearch() ?? true,
|
||||
});
|
||||
await searcher.initialize();
|
||||
fileSearch.current = searcher;
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
import { vi } from 'vitest';
|
||||
import { renderHook, act } from '@testing-library/react';
|
||||
import { useFolderTrust } from './useFolderTrust.js';
|
||||
import { type Config } from '@qwen-code/qwen-code-core';
|
||||
import { LoadedSettings } from '../../config/settings.js';
|
||||
import { FolderTrustChoice } from '../components/FolderTrustDialog.js';
|
||||
import {
|
||||
@@ -25,9 +24,10 @@ vi.mock('process', () => ({
|
||||
|
||||
describe('useFolderTrust', () => {
|
||||
let mockSettings: LoadedSettings;
|
||||
let mockConfig: Config;
|
||||
let mockTrustedFolders: LoadedTrustedFolders;
|
||||
let loadTrustedFoldersSpy: vi.SpyInstance;
|
||||
let isWorkspaceTrustedSpy: vi.SpyInstance;
|
||||
let onTrustChange: (isTrusted: boolean | undefined) => void;
|
||||
|
||||
beforeEach(() => {
|
||||
mockSettings = {
|
||||
@@ -38,10 +38,6 @@ describe('useFolderTrust', () => {
|
||||
setValue: vi.fn(),
|
||||
} as unknown as LoadedSettings;
|
||||
|
||||
mockConfig = {
|
||||
isTrustedFolder: vi.fn().mockReturnValue(undefined),
|
||||
} as unknown as Config;
|
||||
|
||||
mockTrustedFolders = {
|
||||
setValue: vi.fn(),
|
||||
} as unknown as LoadedTrustedFolders;
|
||||
@@ -49,7 +45,9 @@ describe('useFolderTrust', () => {
|
||||
loadTrustedFoldersSpy = vi
|
||||
.spyOn(trustedFolders, 'loadTrustedFolders')
|
||||
.mockReturnValue(mockTrustedFolders);
|
||||
isWorkspaceTrustedSpy = vi.spyOn(trustedFolders, 'isWorkspaceTrusted');
|
||||
(process.cwd as vi.Mock).mockReturnValue('/test/path');
|
||||
onTrustChange = vi.fn();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -57,34 +55,39 @@ describe('useFolderTrust', () => {
|
||||
});
|
||||
|
||||
it('should not open dialog when folder is already trusted', () => {
|
||||
(mockConfig.isTrustedFolder as vi.Mock).mockReturnValue(true);
|
||||
isWorkspaceTrustedSpy.mockReturnValue(true);
|
||||
const { result } = renderHook(() =>
|
||||
useFolderTrust(mockSettings, mockConfig),
|
||||
useFolderTrust(mockSettings, onTrustChange),
|
||||
);
|
||||
expect(result.current.isFolderTrustDialogOpen).toBe(false);
|
||||
expect(onTrustChange).toHaveBeenCalledWith(true);
|
||||
});
|
||||
|
||||
it('should not open dialog when folder is already untrusted', () => {
|
||||
(mockConfig.isTrustedFolder as vi.Mock).mockReturnValue(false);
|
||||
isWorkspaceTrustedSpy.mockReturnValue(false);
|
||||
const { result } = renderHook(() =>
|
||||
useFolderTrust(mockSettings, mockConfig),
|
||||
useFolderTrust(mockSettings, onTrustChange),
|
||||
);
|
||||
expect(result.current.isFolderTrustDialogOpen).toBe(false);
|
||||
expect(onTrustChange).toHaveBeenCalledWith(false);
|
||||
});
|
||||
|
||||
it('should open dialog when folder trust is undefined', () => {
|
||||
(mockConfig.isTrustedFolder as vi.Mock).mockReturnValue(undefined);
|
||||
isWorkspaceTrustedSpy.mockReturnValue(undefined);
|
||||
const { result } = renderHook(() =>
|
||||
useFolderTrust(mockSettings, mockConfig),
|
||||
useFolderTrust(mockSettings, onTrustChange),
|
||||
);
|
||||
expect(result.current.isFolderTrustDialogOpen).toBe(true);
|
||||
expect(onTrustChange).toHaveBeenCalledWith(undefined);
|
||||
});
|
||||
|
||||
it('should handle TRUST_FOLDER choice', () => {
|
||||
isWorkspaceTrustedSpy.mockReturnValue(undefined);
|
||||
const { result } = renderHook(() =>
|
||||
useFolderTrust(mockSettings, mockConfig),
|
||||
useFolderTrust(mockSettings, onTrustChange),
|
||||
);
|
||||
|
||||
isWorkspaceTrustedSpy.mockReturnValue(true);
|
||||
act(() => {
|
||||
result.current.handleFolderTrustSelect(FolderTrustChoice.TRUST_FOLDER);
|
||||
});
|
||||
@@ -95,13 +98,16 @@ describe('useFolderTrust', () => {
|
||||
TrustLevel.TRUST_FOLDER,
|
||||
);
|
||||
expect(result.current.isFolderTrustDialogOpen).toBe(false);
|
||||
expect(onTrustChange).toHaveBeenLastCalledWith(true);
|
||||
});
|
||||
|
||||
it('should handle TRUST_PARENT choice', () => {
|
||||
isWorkspaceTrustedSpy.mockReturnValue(undefined);
|
||||
const { result } = renderHook(() =>
|
||||
useFolderTrust(mockSettings, mockConfig),
|
||||
useFolderTrust(mockSettings, onTrustChange),
|
||||
);
|
||||
|
||||
isWorkspaceTrustedSpy.mockReturnValue(true);
|
||||
act(() => {
|
||||
result.current.handleFolderTrustSelect(FolderTrustChoice.TRUST_PARENT);
|
||||
});
|
||||
@@ -111,13 +117,16 @@ describe('useFolderTrust', () => {
|
||||
TrustLevel.TRUST_PARENT,
|
||||
);
|
||||
expect(result.current.isFolderTrustDialogOpen).toBe(false);
|
||||
expect(onTrustChange).toHaveBeenLastCalledWith(true);
|
||||
});
|
||||
|
||||
it('should handle DO_NOT_TRUST choice', () => {
|
||||
isWorkspaceTrustedSpy.mockReturnValue(undefined);
|
||||
const { result } = renderHook(() =>
|
||||
useFolderTrust(mockSettings, mockConfig),
|
||||
useFolderTrust(mockSettings, onTrustChange),
|
||||
);
|
||||
|
||||
isWorkspaceTrustedSpy.mockReturnValue(false);
|
||||
act(() => {
|
||||
result.current.handleFolderTrustSelect(FolderTrustChoice.DO_NOT_TRUST);
|
||||
});
|
||||
@@ -127,11 +136,13 @@ describe('useFolderTrust', () => {
|
||||
TrustLevel.DO_NOT_TRUST,
|
||||
);
|
||||
expect(result.current.isFolderTrustDialogOpen).toBe(false);
|
||||
expect(onTrustChange).toHaveBeenLastCalledWith(false);
|
||||
});
|
||||
|
||||
it('should do nothing for default choice', () => {
|
||||
isWorkspaceTrustedSpy.mockReturnValue(undefined);
|
||||
const { result } = renderHook(() =>
|
||||
useFolderTrust(mockSettings, mockConfig),
|
||||
useFolderTrust(mockSettings, onTrustChange),
|
||||
);
|
||||
|
||||
act(() => {
|
||||
@@ -143,5 +154,6 @@ describe('useFolderTrust', () => {
|
||||
expect(mockTrustedFolders.setValue).not.toHaveBeenCalled();
|
||||
expect(mockSettings.setValue).not.toHaveBeenCalled();
|
||||
expect(result.current.isFolderTrustDialogOpen).toBe(true);
|
||||
expect(onTrustChange).toHaveBeenCalledWith(undefined);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -4,42 +4,68 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { useState, useCallback } from 'react';
|
||||
import { type Config } from '@qwen-code/qwen-code-core';
|
||||
import { LoadedSettings } from '../../config/settings.js';
|
||||
import { useState, useCallback, useEffect } from 'react';
|
||||
import { Settings, LoadedSettings } from '../../config/settings.js';
|
||||
import { FolderTrustChoice } from '../components/FolderTrustDialog.js';
|
||||
import { loadTrustedFolders, TrustLevel } from '../../config/trustedFolders.js';
|
||||
import {
|
||||
loadTrustedFolders,
|
||||
TrustLevel,
|
||||
isWorkspaceTrusted,
|
||||
} from '../../config/trustedFolders.js';
|
||||
import * as process from 'process';
|
||||
|
||||
export const useFolderTrust = (settings: LoadedSettings, config: Config) => {
|
||||
const [isFolderTrustDialogOpen, setIsFolderTrustDialogOpen] = useState(
|
||||
config.isTrustedFolder() === undefined,
|
||||
export const useFolderTrust = (
|
||||
settings: LoadedSettings,
|
||||
onTrustChange: (isTrusted: boolean | undefined) => void,
|
||||
) => {
|
||||
const [isTrusted, setIsTrusted] = useState<boolean | undefined>(undefined);
|
||||
const [isFolderTrustDialogOpen, setIsFolderTrustDialogOpen] = useState(false);
|
||||
|
||||
const { folderTrust, folderTrustFeature } = settings.merged;
|
||||
useEffect(() => {
|
||||
const trusted = isWorkspaceTrusted({
|
||||
folderTrust,
|
||||
folderTrustFeature,
|
||||
} as Settings);
|
||||
setIsTrusted(trusted);
|
||||
setIsFolderTrustDialogOpen(trusted === undefined);
|
||||
onTrustChange(trusted);
|
||||
}, [onTrustChange, folderTrust, folderTrustFeature]);
|
||||
|
||||
const handleFolderTrustSelect = useCallback(
|
||||
(choice: FolderTrustChoice) => {
|
||||
const trustedFolders = loadTrustedFolders();
|
||||
const cwd = process.cwd();
|
||||
let trustLevel: TrustLevel;
|
||||
|
||||
switch (choice) {
|
||||
case FolderTrustChoice.TRUST_FOLDER:
|
||||
trustLevel = TrustLevel.TRUST_FOLDER;
|
||||
break;
|
||||
case FolderTrustChoice.TRUST_PARENT:
|
||||
trustLevel = TrustLevel.TRUST_PARENT;
|
||||
break;
|
||||
case FolderTrustChoice.DO_NOT_TRUST:
|
||||
trustLevel = TrustLevel.DO_NOT_TRUST;
|
||||
break;
|
||||
default:
|
||||
return;
|
||||
}
|
||||
|
||||
trustedFolders.setValue(cwd, trustLevel);
|
||||
const trusted = isWorkspaceTrusted({
|
||||
folderTrust,
|
||||
folderTrustFeature,
|
||||
} as Settings);
|
||||
setIsTrusted(trusted);
|
||||
setIsFolderTrustDialogOpen(false);
|
||||
onTrustChange(trusted);
|
||||
},
|
||||
[onTrustChange, folderTrust, folderTrustFeature],
|
||||
);
|
||||
|
||||
const handleFolderTrustSelect = useCallback((choice: FolderTrustChoice) => {
|
||||
const trustedFolders = loadTrustedFolders();
|
||||
const cwd = process.cwd();
|
||||
let trustLevel: TrustLevel;
|
||||
|
||||
switch (choice) {
|
||||
case FolderTrustChoice.TRUST_FOLDER:
|
||||
trustLevel = TrustLevel.TRUST_FOLDER;
|
||||
break;
|
||||
case FolderTrustChoice.TRUST_PARENT:
|
||||
trustLevel = TrustLevel.TRUST_PARENT;
|
||||
break;
|
||||
case FolderTrustChoice.DO_NOT_TRUST:
|
||||
trustLevel = TrustLevel.DO_NOT_TRUST;
|
||||
break;
|
||||
default:
|
||||
return;
|
||||
}
|
||||
|
||||
trustedFolders.setValue(cwd, trustLevel);
|
||||
setIsFolderTrustDialogOpen(false);
|
||||
}, []);
|
||||
|
||||
return {
|
||||
isTrusted,
|
||||
isFolderTrustDialogOpen,
|
||||
handleFolderTrustSelect,
|
||||
};
|
||||
|
||||
@@ -5,10 +5,19 @@
|
||||
*/
|
||||
|
||||
/* eslint-disable @typescript-eslint/no-explicit-any */
|
||||
import { describe, it, expect, vi, beforeEach, Mock } from 'vitest';
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
vi,
|
||||
beforeEach,
|
||||
Mock,
|
||||
MockInstance,
|
||||
} from 'vitest';
|
||||
import { renderHook, act, waitFor } from '@testing-library/react';
|
||||
import { useGeminiStream, mergePartListUnions } from './useGeminiStream.js';
|
||||
import { useKeypress } from './useKeypress.js';
|
||||
import * as atCommandProcessor from './atCommandProcessor.js';
|
||||
import {
|
||||
useReactToolScheduler,
|
||||
TrackedToolCall,
|
||||
@@ -20,8 +29,10 @@ import {
|
||||
Config,
|
||||
EditorType,
|
||||
AuthType,
|
||||
GeminiClient,
|
||||
GeminiEventType as ServerGeminiEventType,
|
||||
AnyToolInvocation,
|
||||
ToolErrorType,
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
import { Part, PartListUnion } from '@google/genai';
|
||||
import { UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||
@@ -83,11 +94,7 @@ vi.mock('./shellCommandProcessor.js', () => ({
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock('./atCommandProcessor.js', () => ({
|
||||
handleAtCommand: vi
|
||||
.fn()
|
||||
.mockResolvedValue({ shouldProceed: true, processedQuery: 'mocked' }),
|
||||
}));
|
||||
vi.mock('./atCommandProcessor.js');
|
||||
|
||||
vi.mock('../utils/markdownUtilities.js', () => ({
|
||||
findLastSafeSplitPoint: vi.fn((s: string) => s.length),
|
||||
@@ -259,6 +266,7 @@ describe('useGeminiStream', () => {
|
||||
let mockScheduleToolCalls: Mock;
|
||||
let mockCancelAllToolCalls: Mock;
|
||||
let mockMarkToolsAsSubmitted: Mock;
|
||||
let handleAtCommandSpy: MockInstance;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks(); // Clear mocks before each test
|
||||
@@ -342,6 +350,7 @@ describe('useGeminiStream', () => {
|
||||
mockSendMessageStream
|
||||
.mockClear()
|
||||
.mockReturnValue((async function* () {})());
|
||||
handleAtCommandSpy = vi.spyOn(atCommandProcessor, 'handleAtCommand');
|
||||
});
|
||||
|
||||
const mockLoadedSettings: LoadedSettings = {
|
||||
@@ -513,7 +522,11 @@ describe('useGeminiStream', () => {
|
||||
},
|
||||
status: 'success',
|
||||
responseSubmittedToGemini: false,
|
||||
response: { callId: 'call1', responseParts: toolCall1ResponseParts },
|
||||
response: {
|
||||
callId: 'call1',
|
||||
responseParts: toolCall1ResponseParts,
|
||||
errorType: undefined, // FIX: Added missing property
|
||||
},
|
||||
tool: {
|
||||
displayName: 'MockTool',
|
||||
},
|
||||
@@ -531,7 +544,11 @@ describe('useGeminiStream', () => {
|
||||
},
|
||||
status: 'error',
|
||||
responseSubmittedToGemini: false,
|
||||
response: { callId: 'call2', responseParts: toolCall2ResponseParts },
|
||||
response: {
|
||||
callId: 'call2',
|
||||
responseParts: toolCall2ResponseParts,
|
||||
errorType: ToolErrorType.UNHANDLED_EXCEPTION, // FIX: Added missing property
|
||||
},
|
||||
} as TrackedCompletedToolCall, // Treat error as a form of completion for submission
|
||||
];
|
||||
|
||||
@@ -598,7 +615,11 @@ describe('useGeminiStream', () => {
|
||||
prompt_id: 'prompt-id-3',
|
||||
},
|
||||
status: 'cancelled',
|
||||
response: { callId: '1', responseParts: [{ text: 'cancelled' }] },
|
||||
response: {
|
||||
callId: '1',
|
||||
responseParts: [{ text: 'cancelled' }],
|
||||
errorType: undefined, // FIX: Added missing property
|
||||
},
|
||||
responseSubmittedToGemini: false,
|
||||
tool: {
|
||||
displayName: 'mock tool',
|
||||
@@ -1902,173 +1923,76 @@ describe('useGeminiStream', () => {
|
||||
});
|
||||
|
||||
// Second call should work normally
|
||||
mockSendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
yield {
|
||||
type: ServerGeminiEventType.Content,
|
||||
value: 'Valid response',
|
||||
};
|
||||
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
|
||||
})(),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.submitQuery('Valid query');
|
||||
});
|
||||
|
||||
// The second call should have been made
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(1);
|
||||
expect(mockSendMessageStream).toHaveBeenCalledWith(
|
||||
'Valid query',
|
||||
expect.any(AbortSignal),
|
||||
expect.any(String),
|
||||
);
|
||||
});
|
||||
|
||||
it('should reset execution flag when user cancels', async () => {
|
||||
let resolveCancelledStream!: () => void;
|
||||
const cancelledStreamPromise = new Promise<void>((resolve) => {
|
||||
resolveCancelledStream = resolve;
|
||||
});
|
||||
|
||||
// Mock a stream that can be cancelled
|
||||
const cancelledStream = (async function* () {
|
||||
yield {
|
||||
type: ServerGeminiEventType.Content,
|
||||
value: 'Cancelled content',
|
||||
};
|
||||
await cancelledStreamPromise;
|
||||
yield { type: ServerGeminiEventType.UserCancelled };
|
||||
})();
|
||||
|
||||
mockSendMessageStream.mockReturnValueOnce(cancelledStream);
|
||||
|
||||
const { result } = renderTestHook();
|
||||
|
||||
// Start first call
|
||||
const firstCallResult = act(async () => {
|
||||
await result.current.submitQuery('First query');
|
||||
});
|
||||
|
||||
// Wait a bit then resolve to trigger cancellation
|
||||
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||
resolveCancelledStream();
|
||||
await firstCallResult;
|
||||
|
||||
// Now try a second call - should work
|
||||
mockSendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
yield {
|
||||
type: ServerGeminiEventType.Content,
|
||||
value: 'Second response',
|
||||
};
|
||||
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
|
||||
})(),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.submitQuery('Second query');
|
||||
});
|
||||
|
||||
// Both calls should have been made
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should reset execution flag when an error occurs', async () => {
|
||||
// Mock a stream that throws an error
|
||||
mockSendMessageStream.mockReturnValueOnce(
|
||||
(async function* () {
|
||||
yield { type: ServerGeminiEventType.Content, value: 'Error content' };
|
||||
throw new Error('Stream error');
|
||||
})(),
|
||||
);
|
||||
|
||||
const { result } = renderTestHook();
|
||||
|
||||
// First call that will error
|
||||
await act(async () => {
|
||||
await result.current.submitQuery('Error query');
|
||||
});
|
||||
|
||||
// Second call should work normally
|
||||
mockSendMessageStream.mockReturnValue(
|
||||
(async function* () {
|
||||
yield {
|
||||
type: ServerGeminiEventType.Content,
|
||||
value: 'Success response',
|
||||
};
|
||||
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
|
||||
})(),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.submitQuery('Success query');
|
||||
});
|
||||
|
||||
// Both calls should have been attempted
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it('should handle rapid multiple concurrent calls correctly', async () => {
|
||||
let resolveStream!: () => void;
|
||||
const streamPromise = new Promise<void>((resolve) => {
|
||||
resolveStream = resolve;
|
||||
});
|
||||
|
||||
// Mock a long-running stream
|
||||
const longStream = (async function* () {
|
||||
yield {
|
||||
type: ServerGeminiEventType.Content,
|
||||
value: 'Long running content',
|
||||
};
|
||||
await streamPromise;
|
||||
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
|
||||
})();
|
||||
|
||||
mockSendMessageStream.mockReturnValue(longStream);
|
||||
|
||||
const { result } = renderTestHook();
|
||||
|
||||
// Start multiple concurrent calls
|
||||
const calls = [
|
||||
act(async () => {
|
||||
await result.current.submitQuery('Query 1');
|
||||
}),
|
||||
act(async () => {
|
||||
await result.current.submitQuery('Query 2');
|
||||
}),
|
||||
act(async () => {
|
||||
await result.current.submitQuery('Query 3');
|
||||
}),
|
||||
act(async () => {
|
||||
await result.current.submitQuery('Query 4');
|
||||
}),
|
||||
act(async () => {
|
||||
await result.current.submitQuery('Query 5');
|
||||
}),
|
||||
];
|
||||
|
||||
// Wait a bit then resolve the stream
|
||||
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||
resolveStream();
|
||||
|
||||
// Wait for all calls to complete
|
||||
await Promise.all(calls);
|
||||
|
||||
// Only the first call should have been made
|
||||
// Verify that only the second call was made (empty query is filtered out)
|
||||
expect(mockSendMessageStream).toHaveBeenCalledTimes(1);
|
||||
expect(mockSendMessageStream).toHaveBeenCalledWith(
|
||||
'Query 1',
|
||||
'Second query',
|
||||
expect.any(AbortSignal),
|
||||
expect.any(String),
|
||||
);
|
||||
|
||||
// Only one user message should have been added
|
||||
const userMessages = mockAddItem.mock.calls.filter(
|
||||
(call) => call[0].type === MessageType.USER,
|
||||
);
|
||||
expect(userMessages).toHaveLength(1);
|
||||
expect(userMessages[0][0].text).toBe('Query 1');
|
||||
});
|
||||
});
|
||||
|
||||
it('should process @include commands, adding user turn after processing to prevent race conditions', async () => {
|
||||
const rawQuery = '@include file.txt Summarize this.';
|
||||
const processedQueryParts = [
|
||||
{ text: 'Summarize this with content from @file.txt' },
|
||||
{ text: 'File content...' },
|
||||
];
|
||||
const userMessageTimestamp = Date.now();
|
||||
vi.spyOn(Date, 'now').mockReturnValue(userMessageTimestamp);
|
||||
|
||||
handleAtCommandSpy.mockResolvedValue({
|
||||
processedQuery: processedQueryParts,
|
||||
shouldProceed: true,
|
||||
});
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useGeminiStream(
|
||||
mockConfig.getGeminiClient() as GeminiClient,
|
||||
[],
|
||||
mockAddItem,
|
||||
mockConfig,
|
||||
mockOnDebugMessage,
|
||||
mockHandleSlashCommand,
|
||||
false,
|
||||
vi.fn(),
|
||||
vi.fn(),
|
||||
vi.fn(),
|
||||
false,
|
||||
vi.fn(),
|
||||
vi.fn(),
|
||||
vi.fn(),
|
||||
),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await result.current.submitQuery(rawQuery);
|
||||
});
|
||||
|
||||
expect(handleAtCommandSpy).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
query: rawQuery,
|
||||
}),
|
||||
);
|
||||
|
||||
expect(mockAddItem).toHaveBeenCalledWith(
|
||||
{
|
||||
type: MessageType.USER,
|
||||
text: rawQuery,
|
||||
},
|
||||
userMessageTimestamp,
|
||||
);
|
||||
|
||||
// FIX: This expectation now correctly matches the actual function call signature.
|
||||
expect(mockSendMessageStream).toHaveBeenCalledWith(
|
||||
processedQueryParts, // Argument 1: The parts array directly
|
||||
expect.any(AbortSignal), // Argument 2: An AbortSignal
|
||||
expect.any(String), // Argument 3: The prompt_id string
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -308,6 +308,13 @@ export const useGeminiStream = (
|
||||
messageId: userMessageTimestamp,
|
||||
signal: abortSignal,
|
||||
});
|
||||
|
||||
// Add user's turn after @ command processing is done.
|
||||
addItem(
|
||||
{ type: MessageType.USER, text: trimmedQuery },
|
||||
userMessageTimestamp,
|
||||
);
|
||||
|
||||
if (!atCommandResult.shouldProceed) {
|
||||
return { queryToSend: null, shouldProceed: false };
|
||||
}
|
||||
|
||||
@@ -4,8 +4,10 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { renderHook, act } from '@testing-library/react';
|
||||
import { useKeypress, Key } from './useKeypress.js';
|
||||
import { KeypressProvider } from '../contexts/KeypressContext.js';
|
||||
import { useStdin } from 'ink';
|
||||
import { EventEmitter } from 'events';
|
||||
import { PassThrough } from 'stream';
|
||||
@@ -102,6 +104,9 @@ describe('useKeypress', () => {
|
||||
const onKeypress = vi.fn();
|
||||
let originalNodeVersion: string;
|
||||
|
||||
const wrapper = ({ children }: { children: React.ReactNode }) =>
|
||||
React.createElement(KeypressProvider, null, children);
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
stdin = new MockStdin();
|
||||
@@ -111,7 +116,7 @@ describe('useKeypress', () => {
|
||||
});
|
||||
|
||||
originalNodeVersion = process.versions.node;
|
||||
delete process.env['PASTE_WORKAROUND'];
|
||||
vi.unstubAllEnvs();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -129,7 +134,9 @@ describe('useKeypress', () => {
|
||||
};
|
||||
|
||||
it('should not listen if isActive is false', () => {
|
||||
renderHook(() => useKeypress(onKeypress, { isActive: false }));
|
||||
renderHook(() => useKeypress(onKeypress, { isActive: false }), {
|
||||
wrapper,
|
||||
});
|
||||
act(() => stdin.pressKey({ name: 'a' }));
|
||||
expect(onKeypress).not.toHaveBeenCalled();
|
||||
});
|
||||
@@ -141,14 +148,15 @@ describe('useKeypress', () => {
|
||||
{ key: { name: 'up', sequence: '\x1b[A' } },
|
||||
{ key: { name: 'down', sequence: '\x1b[B' } },
|
||||
])('should listen for keypress when active for key $key.name', ({ key }) => {
|
||||
renderHook(() => useKeypress(onKeypress, { isActive: true }));
|
||||
renderHook(() => useKeypress(onKeypress, { isActive: true }), { wrapper });
|
||||
act(() => stdin.pressKey(key));
|
||||
expect(onKeypress).toHaveBeenCalledWith(expect.objectContaining(key));
|
||||
});
|
||||
|
||||
it('should set and release raw mode', () => {
|
||||
const { unmount } = renderHook(() =>
|
||||
useKeypress(onKeypress, { isActive: true }),
|
||||
const { unmount } = renderHook(
|
||||
() => useKeypress(onKeypress, { isActive: true }),
|
||||
{ wrapper },
|
||||
);
|
||||
expect(mockSetRawMode).toHaveBeenCalledWith(true);
|
||||
unmount();
|
||||
@@ -156,8 +164,9 @@ describe('useKeypress', () => {
|
||||
});
|
||||
|
||||
it('should stop listening after being unmounted', () => {
|
||||
const { unmount } = renderHook(() =>
|
||||
useKeypress(onKeypress, { isActive: true }),
|
||||
const { unmount } = renderHook(
|
||||
() => useKeypress(onKeypress, { isActive: true }),
|
||||
{ wrapper },
|
||||
);
|
||||
unmount();
|
||||
act(() => stdin.pressKey({ name: 'a' }));
|
||||
@@ -165,7 +174,7 @@ describe('useKeypress', () => {
|
||||
});
|
||||
|
||||
it('should correctly identify alt+enter (meta key)', () => {
|
||||
renderHook(() => useKeypress(onKeypress, { isActive: true }));
|
||||
renderHook(() => useKeypress(onKeypress, { isActive: true }), { wrapper });
|
||||
const key = { name: 'return', sequence: '\x1B\r' };
|
||||
act(() => stdin.pressKey(key));
|
||||
expect(onKeypress).toHaveBeenCalledWith(
|
||||
@@ -188,7 +197,7 @@ describe('useKeypress', () => {
|
||||
description: 'Workaround Env Var',
|
||||
setup: () => {
|
||||
setNodeVersion('20.0.0');
|
||||
process.env['PASTE_WORKAROUND'] = 'true';
|
||||
vi.stubEnv('PASTE_WORKAROUND', 'true');
|
||||
},
|
||||
isLegacy: true,
|
||||
},
|
||||
@@ -199,7 +208,9 @@ describe('useKeypress', () => {
|
||||
});
|
||||
|
||||
it('should process a paste as a single event', () => {
|
||||
renderHook(() => useKeypress(onKeypress, { isActive: true }));
|
||||
renderHook(() => useKeypress(onKeypress, { isActive: true }), {
|
||||
wrapper,
|
||||
});
|
||||
const pasteText = 'hello world';
|
||||
act(() => stdin.paste(pasteText));
|
||||
|
||||
@@ -215,7 +226,9 @@ describe('useKeypress', () => {
|
||||
});
|
||||
|
||||
it('should handle keypress interspersed with pastes', () => {
|
||||
renderHook(() => useKeypress(onKeypress, { isActive: true }));
|
||||
renderHook(() => useKeypress(onKeypress, { isActive: true }), {
|
||||
wrapper,
|
||||
});
|
||||
|
||||
const keyA = { name: 'a', sequence: 'a' };
|
||||
act(() => stdin.pressKey(keyA));
|
||||
@@ -239,8 +252,9 @@ describe('useKeypress', () => {
|
||||
});
|
||||
|
||||
it('should emit partial paste content if unmounted mid-paste', () => {
|
||||
const { unmount } = renderHook(() =>
|
||||
useKeypress(onKeypress, { isActive: true }),
|
||||
const { unmount } = renderHook(
|
||||
() => useKeypress(onKeypress, { isActive: true }),
|
||||
{ wrapper },
|
||||
);
|
||||
const pasteText = 'incomplete paste';
|
||||
|
||||
|
||||
@@ -4,414 +4,36 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { useEffect, useRef } from 'react';
|
||||
import { useStdin } from 'ink';
|
||||
import readline from 'readline';
|
||||
import { PassThrough } from 'stream';
|
||||
import { useEffect } from 'react';
|
||||
import {
|
||||
KITTY_CTRL_C,
|
||||
BACKSLASH_ENTER_DETECTION_WINDOW_MS,
|
||||
MAX_KITTY_SEQUENCE_LENGTH,
|
||||
} from '../utils/platformConstants.js';
|
||||
import {
|
||||
KittySequenceOverflowEvent,
|
||||
logKittySequenceOverflow,
|
||||
Config,
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
import { FOCUS_IN, FOCUS_OUT } from './useFocus.js';
|
||||
useKeypressContext,
|
||||
KeypressHandler,
|
||||
Key,
|
||||
} from '../contexts/KeypressContext.js';
|
||||
|
||||
const ESC = '\u001B';
|
||||
export const PASTE_MODE_PREFIX = `${ESC}[200~`;
|
||||
export const PASTE_MODE_SUFFIX = `${ESC}[201~`;
|
||||
|
||||
export interface Key {
|
||||
name: string;
|
||||
ctrl: boolean;
|
||||
meta: boolean;
|
||||
shift: boolean;
|
||||
paste: boolean;
|
||||
sequence: string;
|
||||
kittyProtocol?: boolean;
|
||||
}
|
||||
export { Key };
|
||||
|
||||
/**
|
||||
* A hook that listens for keypress events from stdin, providing a
|
||||
* key object that mirrors the one from Node's `readline` module,
|
||||
* adding a 'paste' flag for characters input as part of a bracketed
|
||||
* paste (when enabled).
|
||||
*
|
||||
* Pastes are currently sent as a single key event where the full paste
|
||||
* is in the sequence field.
|
||||
* A hook that listens for keypress events from stdin.
|
||||
*
|
||||
* @param onKeypress - The callback function to execute on each keypress.
|
||||
* @param options - Options to control the hook's behavior.
|
||||
* @param options.isActive - Whether the hook should be actively listening for input.
|
||||
* @param options.kittyProtocolEnabled - Whether Kitty keyboard protocol is enabled.
|
||||
* @param options.config - Optional config for telemetry logging.
|
||||
*/
|
||||
export function useKeypress(
|
||||
onKeypress: (key: Key) => void,
|
||||
{
|
||||
isActive,
|
||||
kittyProtocolEnabled = false,
|
||||
config,
|
||||
}: { isActive: boolean; kittyProtocolEnabled?: boolean; config?: Config },
|
||||
onKeypress: KeypressHandler,
|
||||
{ isActive }: { isActive: boolean },
|
||||
) {
|
||||
const { stdin, setRawMode } = useStdin();
|
||||
const onKeypressRef = useRef(onKeypress);
|
||||
const { subscribe, unsubscribe } = useKeypressContext();
|
||||
|
||||
useEffect(() => {
|
||||
onKeypressRef.current = onKeypress;
|
||||
}, [onKeypress]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isActive || !stdin.isTTY) {
|
||||
if (!isActive) {
|
||||
return;
|
||||
}
|
||||
|
||||
setRawMode(true);
|
||||
|
||||
const keypressStream = new PassThrough();
|
||||
let usePassthrough = false;
|
||||
const nodeMajorVersion = parseInt(process.versions.node.split('.')[0], 10);
|
||||
if (
|
||||
nodeMajorVersion < 20 ||
|
||||
process.env['PASTE_WORKAROUND'] === '1' ||
|
||||
process.env['PASTE_WORKAROUND'] === 'true'
|
||||
) {
|
||||
// Prior to node 20, node's built-in readline does not support bracketed
|
||||
// paste mode. We hack by detecting it with our own handler.
|
||||
usePassthrough = true;
|
||||
}
|
||||
|
||||
let isPaste = false;
|
||||
let pasteBuffer = Buffer.alloc(0);
|
||||
let kittySequenceBuffer = '';
|
||||
let backslashTimeout: NodeJS.Timeout | null = null;
|
||||
let waitingForEnterAfterBackslash = false;
|
||||
|
||||
// Parse Kitty protocol sequences
|
||||
const parseKittySequence = (sequence: string): Key | null => {
|
||||
// Match CSI <number> ; <modifiers> u or ~
|
||||
// Format: ESC [ <keycode> ; <modifiers> u/~
|
||||
const kittyPattern = new RegExp(`^${ESC}\\[(\\d+)(;(\\d+))?([u~])$`);
|
||||
const match = sequence.match(kittyPattern);
|
||||
if (!match) return null;
|
||||
|
||||
const keyCode = parseInt(match[1], 10);
|
||||
const modifiers = match[3] ? parseInt(match[3], 10) : 1;
|
||||
|
||||
// Decode modifiers (subtract 1 as per Kitty protocol spec)
|
||||
const modifierBits = modifiers - 1;
|
||||
const shift = (modifierBits & 1) === 1;
|
||||
const alt = (modifierBits & 2) === 2;
|
||||
const ctrl = (modifierBits & 4) === 4;
|
||||
|
||||
// Handle Escape key (code 27)
|
||||
if (keyCode === 27) {
|
||||
return {
|
||||
name: 'escape',
|
||||
ctrl,
|
||||
meta: alt,
|
||||
shift,
|
||||
paste: false,
|
||||
sequence,
|
||||
kittyProtocol: true,
|
||||
};
|
||||
}
|
||||
|
||||
// Handle Enter key (code 13)
|
||||
if (keyCode === 13) {
|
||||
return {
|
||||
name: 'return',
|
||||
ctrl,
|
||||
meta: alt,
|
||||
shift,
|
||||
paste: false,
|
||||
sequence,
|
||||
kittyProtocol: true,
|
||||
};
|
||||
}
|
||||
|
||||
// Handle Ctrl+letter combinations (a-z)
|
||||
// ASCII codes: a=97, b=98, c=99, ..., z=122
|
||||
if (keyCode >= 97 && keyCode <= 122 && ctrl) {
|
||||
const letter = String.fromCharCode(keyCode);
|
||||
return {
|
||||
name: letter,
|
||||
ctrl: true,
|
||||
meta: alt,
|
||||
shift,
|
||||
paste: false,
|
||||
sequence,
|
||||
kittyProtocol: true,
|
||||
};
|
||||
}
|
||||
|
||||
// Handle other keys as needed
|
||||
return null;
|
||||
};
|
||||
|
||||
const handleKeypress = (_: unknown, key: Key) => {
|
||||
// Handle VS Code's backslash+return pattern (Shift+Enter)
|
||||
if (key.name === 'return' && waitingForEnterAfterBackslash) {
|
||||
// Cancel the timeout since we got the Enter
|
||||
if (backslashTimeout) {
|
||||
clearTimeout(backslashTimeout);
|
||||
backslashTimeout = null;
|
||||
}
|
||||
waitingForEnterAfterBackslash = false;
|
||||
|
||||
// Convert to Shift+Enter
|
||||
onKeypressRef.current({
|
||||
...key,
|
||||
shift: true,
|
||||
sequence: '\\\r', // VS Code's Shift+Enter representation
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle backslash - hold it to see if Enter follows
|
||||
if (key.sequence === '\\' && !key.name) {
|
||||
// Don't pass through the backslash yet - wait to see if Enter follows
|
||||
waitingForEnterAfterBackslash = true;
|
||||
|
||||
// Set up a timeout to pass through the backslash if no Enter follows
|
||||
backslashTimeout = setTimeout(() => {
|
||||
waitingForEnterAfterBackslash = false;
|
||||
backslashTimeout = null;
|
||||
// Pass through the backslash since no Enter followed
|
||||
onKeypressRef.current(key);
|
||||
}, BACKSLASH_ENTER_DETECTION_WINDOW_MS);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// If we're waiting for Enter after backslash but got something else,
|
||||
// pass through the backslash first, then the new key
|
||||
if (waitingForEnterAfterBackslash && key.name !== 'return') {
|
||||
if (backslashTimeout) {
|
||||
clearTimeout(backslashTimeout);
|
||||
backslashTimeout = null;
|
||||
}
|
||||
waitingForEnterAfterBackslash = false;
|
||||
|
||||
// Pass through the backslash that was held
|
||||
onKeypressRef.current({
|
||||
name: '',
|
||||
sequence: '\\',
|
||||
ctrl: false,
|
||||
meta: false,
|
||||
shift: false,
|
||||
paste: false,
|
||||
});
|
||||
|
||||
// Then continue processing the current key normally
|
||||
}
|
||||
|
||||
// If readline has already identified an arrow key, pass it through
|
||||
// immediately, bypassing the Kitty protocol sequence buffering.
|
||||
if (['up', 'down', 'left', 'right'].includes(key.name)) {
|
||||
onKeypressRef.current(key);
|
||||
return;
|
||||
}
|
||||
|
||||
// Always pass through Ctrl+C immediately, regardless of protocol state
|
||||
// Check both standard format and Kitty protocol sequence
|
||||
if (
|
||||
(key.ctrl && key.name === 'c') ||
|
||||
key.sequence === `${ESC}${KITTY_CTRL_C}`
|
||||
) {
|
||||
kittySequenceBuffer = '';
|
||||
// If it's the Kitty sequence, create a proper key object
|
||||
if (key.sequence === `${ESC}${KITTY_CTRL_C}`) {
|
||||
onKeypressRef.current({
|
||||
name: 'c',
|
||||
ctrl: true,
|
||||
meta: false,
|
||||
shift: false,
|
||||
paste: false,
|
||||
sequence: key.sequence,
|
||||
kittyProtocol: true,
|
||||
});
|
||||
} else {
|
||||
onKeypressRef.current(key);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// If Kitty protocol is enabled, handle CSI sequences
|
||||
if (kittyProtocolEnabled) {
|
||||
// If we have a buffer or this starts a CSI sequence
|
||||
if (
|
||||
kittySequenceBuffer ||
|
||||
(key.sequence.startsWith(`${ESC}[`) &&
|
||||
!key.sequence.startsWith(PASTE_MODE_PREFIX) &&
|
||||
!key.sequence.startsWith(PASTE_MODE_SUFFIX) &&
|
||||
!key.sequence.startsWith(FOCUS_IN) &&
|
||||
!key.sequence.startsWith(FOCUS_OUT))
|
||||
) {
|
||||
kittySequenceBuffer += key.sequence;
|
||||
|
||||
// Try to parse the buffer as a Kitty sequence
|
||||
const kittyKey = parseKittySequence(kittySequenceBuffer);
|
||||
if (kittyKey) {
|
||||
kittySequenceBuffer = '';
|
||||
onKeypressRef.current(kittyKey);
|
||||
return;
|
||||
}
|
||||
|
||||
if (config?.getDebugMode()) {
|
||||
const codes = Array.from(kittySequenceBuffer).map((ch) =>
|
||||
ch.charCodeAt(0),
|
||||
);
|
||||
// Unless the user is sshing over a slow connection, this likely
|
||||
// indicates this is not a kitty sequence but we have incorrectly
|
||||
// interpreted it as such. See the examples above for sequences
|
||||
// such as FOCUS_IN that are not Kitty sequences.
|
||||
console.warn('Kitty sequence buffer has char codes:', codes);
|
||||
}
|
||||
|
||||
// If buffer doesn't match expected pattern and is getting long, flush it
|
||||
if (kittySequenceBuffer.length > MAX_KITTY_SEQUENCE_LENGTH) {
|
||||
// Log telemetry for buffer overflow
|
||||
if (config) {
|
||||
const event = new KittySequenceOverflowEvent(
|
||||
kittySequenceBuffer.length,
|
||||
kittySequenceBuffer,
|
||||
);
|
||||
logKittySequenceOverflow(config, event);
|
||||
}
|
||||
// Not a Kitty sequence, treat as regular key
|
||||
kittySequenceBuffer = '';
|
||||
} else {
|
||||
// Wait for more characters
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (key.name === 'paste-start') {
|
||||
isPaste = true;
|
||||
} else if (key.name === 'paste-end') {
|
||||
isPaste = false;
|
||||
onKeypressRef.current({
|
||||
name: '',
|
||||
ctrl: false,
|
||||
meta: false,
|
||||
shift: false,
|
||||
paste: true,
|
||||
sequence: pasteBuffer.toString(),
|
||||
});
|
||||
pasteBuffer = Buffer.alloc(0);
|
||||
} else {
|
||||
if (isPaste) {
|
||||
pasteBuffer = Buffer.concat([pasteBuffer, Buffer.from(key.sequence)]);
|
||||
} else {
|
||||
// Handle special keys
|
||||
if (key.name === 'return' && key.sequence === `${ESC}\r`) {
|
||||
key.meta = true;
|
||||
}
|
||||
onKeypressRef.current({ ...key, paste: isPaste });
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const handleRawKeypress = (data: Buffer) => {
|
||||
const pasteModePrefixBuffer = Buffer.from(PASTE_MODE_PREFIX);
|
||||
const pasteModeSuffixBuffer = Buffer.from(PASTE_MODE_SUFFIX);
|
||||
|
||||
let pos = 0;
|
||||
while (pos < data.length) {
|
||||
const prefixPos = data.indexOf(pasteModePrefixBuffer, pos);
|
||||
const suffixPos = data.indexOf(pasteModeSuffixBuffer, pos);
|
||||
|
||||
// Determine which marker comes first, if any.
|
||||
const isPrefixNext =
|
||||
prefixPos !== -1 && (suffixPos === -1 || prefixPos < suffixPos);
|
||||
const isSuffixNext =
|
||||
suffixPos !== -1 && (prefixPos === -1 || suffixPos < prefixPos);
|
||||
|
||||
let nextMarkerPos = -1;
|
||||
let markerLength = 0;
|
||||
|
||||
if (isPrefixNext) {
|
||||
nextMarkerPos = prefixPos;
|
||||
} else if (isSuffixNext) {
|
||||
nextMarkerPos = suffixPos;
|
||||
}
|
||||
markerLength = pasteModeSuffixBuffer.length;
|
||||
|
||||
if (nextMarkerPos === -1) {
|
||||
keypressStream.write(data.slice(pos));
|
||||
return;
|
||||
}
|
||||
|
||||
const nextData = data.slice(pos, nextMarkerPos);
|
||||
if (nextData.length > 0) {
|
||||
keypressStream.write(nextData);
|
||||
}
|
||||
const createPasteKeyEvent = (
|
||||
name: 'paste-start' | 'paste-end',
|
||||
): Key => ({
|
||||
name,
|
||||
ctrl: false,
|
||||
meta: false,
|
||||
shift: false,
|
||||
paste: false,
|
||||
sequence: '',
|
||||
});
|
||||
if (isPrefixNext) {
|
||||
handleKeypress(undefined, createPasteKeyEvent('paste-start'));
|
||||
} else if (isSuffixNext) {
|
||||
handleKeypress(undefined, createPasteKeyEvent('paste-end'));
|
||||
}
|
||||
pos = nextMarkerPos + markerLength;
|
||||
}
|
||||
};
|
||||
|
||||
let rl: readline.Interface;
|
||||
if (usePassthrough) {
|
||||
rl = readline.createInterface({
|
||||
input: keypressStream,
|
||||
escapeCodeTimeout: 0,
|
||||
});
|
||||
readline.emitKeypressEvents(keypressStream, rl);
|
||||
keypressStream.on('keypress', handleKeypress);
|
||||
stdin.on('data', handleRawKeypress);
|
||||
} else {
|
||||
rl = readline.createInterface({ input: stdin, escapeCodeTimeout: 0 });
|
||||
readline.emitKeypressEvents(stdin, rl);
|
||||
stdin.on('keypress', handleKeypress);
|
||||
}
|
||||
|
||||
subscribe(onKeypress);
|
||||
return () => {
|
||||
if (usePassthrough) {
|
||||
keypressStream.removeListener('keypress', handleKeypress);
|
||||
stdin.removeListener('data', handleRawKeypress);
|
||||
} else {
|
||||
stdin.removeListener('keypress', handleKeypress);
|
||||
}
|
||||
rl.close();
|
||||
setRawMode(false);
|
||||
|
||||
// Clean up any pending backslash timeout
|
||||
if (backslashTimeout) {
|
||||
clearTimeout(backslashTimeout);
|
||||
backslashTimeout = null;
|
||||
}
|
||||
|
||||
// If we are in the middle of a paste, send what we have.
|
||||
if (isPaste) {
|
||||
onKeypressRef.current({
|
||||
name: '',
|
||||
ctrl: false,
|
||||
meta: false,
|
||||
shift: false,
|
||||
paste: true,
|
||||
sequence: pasteBuffer.toString(),
|
||||
});
|
||||
pasteBuffer = Buffer.alloc(0);
|
||||
}
|
||||
unsubscribe(onKeypress);
|
||||
};
|
||||
}, [isActive, stdin, setRawMode, kittyProtocolEnabled, config]);
|
||||
}, [isActive, onKeypress, subscribe, unsubscribe]);
|
||||
}
|
||||
|
||||
226
packages/cli/src/ui/hooks/useMessageQueue.test.ts
Normal file
226
packages/cli/src/ui/hooks/useMessageQueue.test.ts
Normal file
@@ -0,0 +1,226 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import { renderHook, act } from '@testing-library/react';
|
||||
import { useMessageQueue } from './useMessageQueue.js';
|
||||
import { StreamingState } from '../types.js';
|
||||
|
||||
describe('useMessageQueue', () => {
|
||||
let mockSubmitQuery: ReturnType<typeof vi.fn>;
|
||||
|
||||
beforeEach(() => {
|
||||
mockSubmitQuery = vi.fn();
|
||||
vi.useFakeTimers();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers();
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should initialize with empty queue', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useMessageQueue({
|
||||
streamingState: StreamingState.Idle,
|
||||
submitQuery: mockSubmitQuery,
|
||||
}),
|
||||
);
|
||||
|
||||
expect(result.current.messageQueue).toEqual([]);
|
||||
expect(result.current.getQueuedMessagesText()).toBe('');
|
||||
});
|
||||
|
||||
it('should add messages to queue', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useMessageQueue({
|
||||
streamingState: StreamingState.Responding,
|
||||
submitQuery: mockSubmitQuery,
|
||||
}),
|
||||
);
|
||||
|
||||
act(() => {
|
||||
result.current.addMessage('Test message 1');
|
||||
result.current.addMessage('Test message 2');
|
||||
});
|
||||
|
||||
expect(result.current.messageQueue).toEqual([
|
||||
'Test message 1',
|
||||
'Test message 2',
|
||||
]);
|
||||
});
|
||||
|
||||
it('should filter out empty messages', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useMessageQueue({
|
||||
streamingState: StreamingState.Responding,
|
||||
submitQuery: mockSubmitQuery,
|
||||
}),
|
||||
);
|
||||
|
||||
act(() => {
|
||||
result.current.addMessage('Valid message');
|
||||
result.current.addMessage(' '); // Only whitespace
|
||||
result.current.addMessage(''); // Empty
|
||||
result.current.addMessage('Another valid message');
|
||||
});
|
||||
|
||||
expect(result.current.messageQueue).toEqual([
|
||||
'Valid message',
|
||||
'Another valid message',
|
||||
]);
|
||||
});
|
||||
|
||||
it('should clear queue', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useMessageQueue({
|
||||
streamingState: StreamingState.Responding,
|
||||
submitQuery: mockSubmitQuery,
|
||||
}),
|
||||
);
|
||||
|
||||
act(() => {
|
||||
result.current.addMessage('Test message');
|
||||
});
|
||||
|
||||
expect(result.current.messageQueue).toEqual(['Test message']);
|
||||
|
||||
act(() => {
|
||||
result.current.clearQueue();
|
||||
});
|
||||
|
||||
expect(result.current.messageQueue).toEqual([]);
|
||||
});
|
||||
|
||||
it('should return queued messages as text with double newlines', () => {
|
||||
const { result } = renderHook(() =>
|
||||
useMessageQueue({
|
||||
streamingState: StreamingState.Responding,
|
||||
submitQuery: mockSubmitQuery,
|
||||
}),
|
||||
);
|
||||
|
||||
act(() => {
|
||||
result.current.addMessage('Message 1');
|
||||
result.current.addMessage('Message 2');
|
||||
result.current.addMessage('Message 3');
|
||||
});
|
||||
|
||||
expect(result.current.getQueuedMessagesText()).toBe(
|
||||
'Message 1\n\nMessage 2\n\nMessage 3',
|
||||
);
|
||||
});
|
||||
|
||||
it('should auto-submit queued messages when transitioning to Idle', () => {
|
||||
const { result, rerender } = renderHook(
|
||||
({ streamingState }) =>
|
||||
useMessageQueue({
|
||||
streamingState,
|
||||
submitQuery: mockSubmitQuery,
|
||||
}),
|
||||
{
|
||||
initialProps: { streamingState: StreamingState.Responding },
|
||||
},
|
||||
);
|
||||
|
||||
// Add some messages
|
||||
act(() => {
|
||||
result.current.addMessage('Message 1');
|
||||
result.current.addMessage('Message 2');
|
||||
});
|
||||
|
||||
expect(result.current.messageQueue).toEqual(['Message 1', 'Message 2']);
|
||||
|
||||
// Transition to Idle
|
||||
rerender({ streamingState: StreamingState.Idle });
|
||||
|
||||
expect(mockSubmitQuery).toHaveBeenCalledWith('Message 1\n\nMessage 2');
|
||||
expect(result.current.messageQueue).toEqual([]);
|
||||
});
|
||||
|
||||
it('should not auto-submit when queue is empty', () => {
|
||||
const { rerender } = renderHook(
|
||||
({ streamingState }) =>
|
||||
useMessageQueue({
|
||||
streamingState,
|
||||
submitQuery: mockSubmitQuery,
|
||||
}),
|
||||
{
|
||||
initialProps: { streamingState: StreamingState.Responding },
|
||||
},
|
||||
);
|
||||
|
||||
// Transition to Idle with empty queue
|
||||
rerender({ streamingState: StreamingState.Idle });
|
||||
|
||||
expect(mockSubmitQuery).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should not auto-submit when not transitioning to Idle', () => {
|
||||
const { result, rerender } = renderHook(
|
||||
({ streamingState }) =>
|
||||
useMessageQueue({
|
||||
streamingState,
|
||||
submitQuery: mockSubmitQuery,
|
||||
}),
|
||||
{
|
||||
initialProps: { streamingState: StreamingState.Responding },
|
||||
},
|
||||
);
|
||||
|
||||
// Add messages
|
||||
act(() => {
|
||||
result.current.addMessage('Message 1');
|
||||
});
|
||||
|
||||
// Transition to WaitingForConfirmation (not Idle)
|
||||
rerender({ streamingState: StreamingState.WaitingForConfirmation });
|
||||
|
||||
expect(mockSubmitQuery).not.toHaveBeenCalled();
|
||||
expect(result.current.messageQueue).toEqual(['Message 1']);
|
||||
});
|
||||
|
||||
it('should handle multiple state transitions correctly', () => {
|
||||
const { result, rerender } = renderHook(
|
||||
({ streamingState }) =>
|
||||
useMessageQueue({
|
||||
streamingState,
|
||||
submitQuery: mockSubmitQuery,
|
||||
}),
|
||||
{
|
||||
initialProps: { streamingState: StreamingState.Idle },
|
||||
},
|
||||
);
|
||||
|
||||
// Start responding
|
||||
rerender({ streamingState: StreamingState.Responding });
|
||||
|
||||
// Add messages while responding
|
||||
act(() => {
|
||||
result.current.addMessage('First batch');
|
||||
});
|
||||
|
||||
// Go back to idle - should submit
|
||||
rerender({ streamingState: StreamingState.Idle });
|
||||
|
||||
expect(mockSubmitQuery).toHaveBeenCalledWith('First batch');
|
||||
expect(result.current.messageQueue).toEqual([]);
|
||||
|
||||
// Start responding again
|
||||
rerender({ streamingState: StreamingState.Responding });
|
||||
|
||||
// Add more messages
|
||||
act(() => {
|
||||
result.current.addMessage('Second batch');
|
||||
});
|
||||
|
||||
// Go back to idle - should submit again
|
||||
rerender({ streamingState: StreamingState.Idle });
|
||||
|
||||
expect(mockSubmitQuery).toHaveBeenCalledWith('Second batch');
|
||||
expect(mockSubmitQuery).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
69
packages/cli/src/ui/hooks/useMessageQueue.ts
Normal file
69
packages/cli/src/ui/hooks/useMessageQueue.ts
Normal file
@@ -0,0 +1,69 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
import { StreamingState } from '../types.js';
|
||||
|
||||
export interface UseMessageQueueOptions {
|
||||
streamingState: StreamingState;
|
||||
submitQuery: (query: string) => void;
|
||||
}
|
||||
|
||||
export interface UseMessageQueueReturn {
|
||||
messageQueue: string[];
|
||||
addMessage: (message: string) => void;
|
||||
clearQueue: () => void;
|
||||
getQueuedMessagesText: () => string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook for managing message queuing during streaming responses.
|
||||
* Allows users to queue messages while the AI is responding and automatically
|
||||
* sends them when streaming completes.
|
||||
*/
|
||||
export function useMessageQueue({
|
||||
streamingState,
|
||||
submitQuery,
|
||||
}: UseMessageQueueOptions): UseMessageQueueReturn {
|
||||
const [messageQueue, setMessageQueue] = useState<string[]>([]);
|
||||
|
||||
// Add a message to the queue
|
||||
const addMessage = useCallback((message: string) => {
|
||||
const trimmedMessage = message.trim();
|
||||
if (trimmedMessage.length > 0) {
|
||||
setMessageQueue((prev) => [...prev, trimmedMessage]);
|
||||
}
|
||||
}, []);
|
||||
|
||||
// Clear the entire queue
|
||||
const clearQueue = useCallback(() => {
|
||||
setMessageQueue([]);
|
||||
}, []);
|
||||
|
||||
// Get all queued messages as a single text string
|
||||
const getQueuedMessagesText = useCallback(() => {
|
||||
if (messageQueue.length === 0) return '';
|
||||
return messageQueue.join('\n\n');
|
||||
}, [messageQueue]);
|
||||
|
||||
// Process queued messages when streaming becomes idle
|
||||
useEffect(() => {
|
||||
if (streamingState === StreamingState.Idle && messageQueue.length > 0) {
|
||||
// Combine all messages with double newlines for clarity
|
||||
const combinedMessage = messageQueue.join('\n\n');
|
||||
// Clear the queue and submit
|
||||
setMessageQueue([]);
|
||||
submitQuery(combinedMessage);
|
||||
}
|
||||
}, [streamingState, messageQueue, submitQuery]);
|
||||
|
||||
return {
|
||||
messageQueue,
|
||||
addMessage,
|
||||
clearQueue,
|
||||
getQueuedMessagesText,
|
||||
};
|
||||
}
|
||||
242
packages/cli/src/ui/hooks/usePrivacySettings.test.ts
Normal file
242
packages/cli/src/ui/hooks/usePrivacySettings.test.ts
Normal file
@@ -0,0 +1,242 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, vi } from 'vitest';
|
||||
import { renderHook, waitFor } from '@testing-library/react';
|
||||
import {
|
||||
Config,
|
||||
CodeAssistServer,
|
||||
LoggingContentGenerator,
|
||||
UserTierId,
|
||||
GeminiClient,
|
||||
ContentGenerator,
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
import { OAuth2Client } from 'google-auth-library';
|
||||
import { usePrivacySettings } from './usePrivacySettings.js';
|
||||
|
||||
// Mock the dependencies
|
||||
vi.mock('@qwen-code/qwen-code-core', () => {
|
||||
// Mock classes for instanceof checks
|
||||
class MockCodeAssistServer {
|
||||
projectId = 'test-project-id';
|
||||
loadCodeAssist = vi.fn();
|
||||
getCodeAssistGlobalUserSetting = vi.fn();
|
||||
setCodeAssistGlobalUserSetting = vi.fn();
|
||||
|
||||
constructor(
|
||||
_client?: GeminiClient,
|
||||
_projectId?: string,
|
||||
_httpOptions?: Record<string, unknown>,
|
||||
_sessionId?: string,
|
||||
_userTier?: UserTierId,
|
||||
) {}
|
||||
}
|
||||
|
||||
class MockLoggingContentGenerator {
|
||||
getWrapped = vi.fn();
|
||||
|
||||
constructor(
|
||||
_wrapped?: ContentGenerator,
|
||||
_config?: Record<string, unknown>,
|
||||
) {}
|
||||
}
|
||||
|
||||
return {
|
||||
Config: vi.fn(),
|
||||
CodeAssistServer: MockCodeAssistServer,
|
||||
LoggingContentGenerator: MockLoggingContentGenerator,
|
||||
GeminiClient: vi.fn(),
|
||||
UserTierId: {
|
||||
FREE: 'free-tier',
|
||||
LEGACY: 'legacy-tier',
|
||||
STANDARD: 'standard-tier',
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
describe('usePrivacySettings', () => {
|
||||
let mockConfig: Config;
|
||||
let mockClient: GeminiClient;
|
||||
let mockCodeAssistServer: CodeAssistServer;
|
||||
let mockLoggingContentGenerator: LoggingContentGenerator;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Create mock CodeAssistServer instance
|
||||
mockCodeAssistServer = new CodeAssistServer(
|
||||
null as unknown as OAuth2Client,
|
||||
'test-project-id',
|
||||
) as unknown as CodeAssistServer;
|
||||
(
|
||||
mockCodeAssistServer.loadCodeAssist as ReturnType<typeof vi.fn>
|
||||
).mockResolvedValue({
|
||||
currentTier: { id: UserTierId.FREE },
|
||||
});
|
||||
(
|
||||
mockCodeAssistServer.getCodeAssistGlobalUserSetting as ReturnType<
|
||||
typeof vi.fn
|
||||
>
|
||||
).mockResolvedValue({
|
||||
freeTierDataCollectionOptin: true,
|
||||
});
|
||||
(
|
||||
mockCodeAssistServer.setCodeAssistGlobalUserSetting as ReturnType<
|
||||
typeof vi.fn
|
||||
>
|
||||
).mockResolvedValue({
|
||||
freeTierDataCollectionOptin: false,
|
||||
});
|
||||
|
||||
// Create mock LoggingContentGenerator that wraps the CodeAssistServer
|
||||
mockLoggingContentGenerator = new LoggingContentGenerator(
|
||||
mockCodeAssistServer,
|
||||
null as unknown as Config,
|
||||
) as unknown as LoggingContentGenerator;
|
||||
(
|
||||
mockLoggingContentGenerator.getWrapped as ReturnType<typeof vi.fn>
|
||||
).mockReturnValue(mockCodeAssistServer);
|
||||
|
||||
// Create mock GeminiClient
|
||||
mockClient = {
|
||||
getContentGenerator: vi.fn().mockReturnValue(mockLoggingContentGenerator),
|
||||
} as unknown as GeminiClient;
|
||||
|
||||
// Create mock Config
|
||||
mockConfig = {
|
||||
getGeminiClient: vi.fn().mockReturnValue(mockClient),
|
||||
} as unknown as Config;
|
||||
});
|
||||
|
||||
it('should handle LoggingContentGenerator wrapper correctly and not throw "Oauth not being used" error', async () => {
|
||||
const { result } = renderHook(() => usePrivacySettings(mockConfig));
|
||||
|
||||
// Initial state should be loading
|
||||
expect(result.current.privacyState.isLoading).toBe(true);
|
||||
expect(result.current.privacyState.error).toBeUndefined();
|
||||
|
||||
// Wait for the hook to complete
|
||||
await waitFor(() => {
|
||||
expect(result.current.privacyState.isLoading).toBe(false);
|
||||
});
|
||||
|
||||
// Should not have the "Oauth not being used" error
|
||||
expect(result.current.privacyState.error).toBeUndefined();
|
||||
expect(result.current.privacyState.isFreeTier).toBe(true);
|
||||
expect(result.current.privacyState.dataCollectionOptIn).toBe(true);
|
||||
|
||||
// Verify that getWrapped was called to unwrap the LoggingContentGenerator
|
||||
expect(mockLoggingContentGenerator.getWrapped).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should work with direct CodeAssistServer (no wrapper)', async () => {
|
||||
// Test case where the content generator is directly a CodeAssistServer
|
||||
const directServer = new CodeAssistServer(
|
||||
null as unknown as OAuth2Client,
|
||||
'test-project-id',
|
||||
) as unknown as CodeAssistServer;
|
||||
(directServer.loadCodeAssist as ReturnType<typeof vi.fn>).mockResolvedValue(
|
||||
{
|
||||
currentTier: { id: UserTierId.FREE },
|
||||
},
|
||||
);
|
||||
(
|
||||
directServer.getCodeAssistGlobalUserSetting as ReturnType<typeof vi.fn>
|
||||
).mockResolvedValue({
|
||||
freeTierDataCollectionOptin: true,
|
||||
});
|
||||
|
||||
mockClient.getContentGenerator = vi.fn().mockReturnValue(directServer);
|
||||
|
||||
const { result } = renderHook(() => usePrivacySettings(mockConfig));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.privacyState.isLoading).toBe(false);
|
||||
});
|
||||
|
||||
expect(result.current.privacyState.error).toBeUndefined();
|
||||
expect(result.current.privacyState.isFreeTier).toBe(true);
|
||||
expect(result.current.privacyState.dataCollectionOptIn).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle paid tier users correctly', async () => {
|
||||
// Mock paid tier response
|
||||
(
|
||||
mockCodeAssistServer.loadCodeAssist as ReturnType<typeof vi.fn>
|
||||
).mockResolvedValue({
|
||||
currentTier: { id: UserTierId.STANDARD },
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => usePrivacySettings(mockConfig));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.privacyState.isLoading).toBe(false);
|
||||
});
|
||||
|
||||
expect(result.current.privacyState.error).toBeUndefined();
|
||||
expect(result.current.privacyState.isFreeTier).toBe(false);
|
||||
expect(result.current.privacyState.dataCollectionOptIn).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should throw error when content generator is not a CodeAssistServer', async () => {
|
||||
// Mock a non-CodeAssistServer content generator
|
||||
const mockOtherGenerator = { someOtherMethod: vi.fn() };
|
||||
(
|
||||
mockLoggingContentGenerator.getWrapped as ReturnType<typeof vi.fn>
|
||||
).mockReturnValue(mockOtherGenerator);
|
||||
|
||||
const { result } = renderHook(() => usePrivacySettings(mockConfig));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.privacyState.isLoading).toBe(false);
|
||||
});
|
||||
|
||||
expect(result.current.privacyState.error).toBe('Oauth not being used');
|
||||
});
|
||||
|
||||
it('should throw error when CodeAssistServer has no projectId', async () => {
|
||||
// Mock CodeAssistServer without projectId
|
||||
const mockServerNoProject = {
|
||||
...mockCodeAssistServer,
|
||||
projectId: undefined,
|
||||
};
|
||||
(
|
||||
mockLoggingContentGenerator.getWrapped as ReturnType<typeof vi.fn>
|
||||
).mockReturnValue(mockServerNoProject);
|
||||
|
||||
const { result } = renderHook(() => usePrivacySettings(mockConfig));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.privacyState.isLoading).toBe(false);
|
||||
});
|
||||
|
||||
expect(result.current.privacyState.error).toBe('Oauth not being used');
|
||||
});
|
||||
|
||||
it('should update data collection opt-in setting', async () => {
|
||||
const { result } = renderHook(() => usePrivacySettings(mockConfig));
|
||||
|
||||
// Wait for initial load
|
||||
await waitFor(() => {
|
||||
expect(result.current.privacyState.isLoading).toBe(false);
|
||||
});
|
||||
|
||||
// Update the setting
|
||||
await result.current.updateDataCollectionOptIn(false);
|
||||
|
||||
// Wait for update to complete
|
||||
await waitFor(() => {
|
||||
expect(result.current.privacyState.dataCollectionOptIn).toBe(false);
|
||||
});
|
||||
|
||||
expect(
|
||||
mockCodeAssistServer.setCodeAssistGlobalUserSetting,
|
||||
).toHaveBeenCalledWith({
|
||||
cloudaicompanionProject: 'test-project-id',
|
||||
freeTierDataCollectionOptin: false,
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
Config,
|
||||
CodeAssistServer,
|
||||
UserTierId,
|
||||
LoggingContentGenerator,
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
|
||||
export interface PrivacyState {
|
||||
@@ -84,7 +85,13 @@ export const usePrivacySettings = (config: Config) => {
|
||||
};
|
||||
|
||||
function getCodeAssistServer(config: Config): CodeAssistServer {
|
||||
const server = config.getGeminiClient().getContentGenerator();
|
||||
let server = config.getGeminiClient().getContentGenerator();
|
||||
|
||||
// Unwrap LoggingContentGenerator if present
|
||||
if (server instanceof LoggingContentGenerator) {
|
||||
server = server.getWrapped();
|
||||
}
|
||||
|
||||
// Neither of these cases should ever happen.
|
||||
if (!(server instanceof CodeAssistServer)) {
|
||||
throw new Error('Oauth not being used');
|
||||
|
||||
@@ -39,7 +39,7 @@ export const useThemeCommand = (
|
||||
}, [loadedSettings.merged.theme, setThemeError]);
|
||||
|
||||
const openThemeDialog = useCallback(() => {
|
||||
if (process.env.NO_COLOR) {
|
||||
if (process.env['NO_COLOR']) {
|
||||
addItem(
|
||||
{
|
||||
type: MessageType.INFO,
|
||||
|
||||
@@ -24,7 +24,9 @@ import {
|
||||
Status as ToolCallStatusType,
|
||||
ApprovalMode,
|
||||
Kind,
|
||||
BaseTool,
|
||||
BaseDeclarativeTool,
|
||||
BaseToolInvocation,
|
||||
ToolInvocation,
|
||||
AnyDeclarativeTool,
|
||||
AnyToolInvocation,
|
||||
} from '@qwen-code/qwen-code-core';
|
||||
@@ -53,9 +55,48 @@ const mockConfig = {
|
||||
getApprovalMode: vi.fn(() => ApprovalMode.DEFAULT),
|
||||
getUsageStatisticsEnabled: () => true,
|
||||
getDebugMode: () => false,
|
||||
getSessionId: () => 'test-session-id',
|
||||
getContentGeneratorConfig: () => ({
|
||||
model: 'test-model',
|
||||
authType: 'oauth-personal',
|
||||
}),
|
||||
};
|
||||
|
||||
class MockTool extends BaseTool<object, ToolResult> {
|
||||
class MockToolInvocation extends BaseToolInvocation<object, ToolResult> {
|
||||
constructor(
|
||||
private readonly tool: MockTool,
|
||||
params: object,
|
||||
) {
|
||||
super(params);
|
||||
}
|
||||
|
||||
getDescription(): string {
|
||||
return JSON.stringify(this.params);
|
||||
}
|
||||
|
||||
override shouldConfirmExecute(
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
return this.tool.shouldConfirmExecute(this.params, abortSignal);
|
||||
}
|
||||
|
||||
execute(
|
||||
signal: AbortSignal,
|
||||
updateOutput?: (output: string) => void,
|
||||
terminalColumns?: number,
|
||||
terminalRows?: number,
|
||||
): Promise<ToolResult> {
|
||||
return this.tool.execute(
|
||||
this.params,
|
||||
signal,
|
||||
updateOutput,
|
||||
terminalColumns,
|
||||
terminalRows,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
class MockTool extends BaseDeclarativeTool<object, ToolResult> {
|
||||
constructor(
|
||||
name: string,
|
||||
displayName: string,
|
||||
@@ -73,11 +114,12 @@ class MockTool extends BaseTool<object, ToolResult> {
|
||||
canUpdateOutput,
|
||||
);
|
||||
if (shouldConfirm) {
|
||||
this.shouldConfirmExecute = vi.fn(
|
||||
this.shouldConfirmExecute.mockImplementation(
|
||||
async (): Promise<ToolCallConfirmationDetails | false> => ({
|
||||
type: 'edit',
|
||||
title: 'Mock Tool Requires Confirmation',
|
||||
onConfirm: mockOnUserConfirmForToolConfirmation,
|
||||
filePath: 'mock',
|
||||
fileName: 'mockToolRequiresConfirmation.ts',
|
||||
fileDiff: 'Mock tool requires confirmation',
|
||||
originalContent: 'Original content',
|
||||
@@ -89,6 +131,12 @@ class MockTool extends BaseTool<object, ToolResult> {
|
||||
|
||||
execute = vi.fn();
|
||||
shouldConfirmExecute = vi.fn();
|
||||
|
||||
protected createInvocation(
|
||||
params: object,
|
||||
): ToolInvocation<object, ToolResult> {
|
||||
return new MockToolInvocation(this, params);
|
||||
}
|
||||
}
|
||||
|
||||
const mockTool = new MockTool('mockTool', 'Mock Tool');
|
||||
@@ -135,6 +183,8 @@ describe('useReactToolScheduler in YOLO Mode', () => {
|
||||
onComplete,
|
||||
mockConfig as unknown as Config,
|
||||
setPendingHistoryItem,
|
||||
() => undefined,
|
||||
() => {},
|
||||
),
|
||||
);
|
||||
|
||||
@@ -153,7 +203,7 @@ describe('useReactToolScheduler in YOLO Mode', () => {
|
||||
callId: 'yoloCall',
|
||||
name: 'mockToolRequiresConfirmation',
|
||||
args: { data: 'any data' },
|
||||
};
|
||||
} as any;
|
||||
|
||||
act(() => {
|
||||
schedule(request, new AbortController().signal);
|
||||
@@ -179,6 +229,8 @@ describe('useReactToolScheduler in YOLO Mode', () => {
|
||||
request.args,
|
||||
expect.any(AbortSignal),
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
|
||||
// Check that onComplete was called with success
|
||||
@@ -261,13 +313,14 @@ describe('useReactToolScheduler', () => {
|
||||
(
|
||||
mockToolRequiresConfirmation.shouldConfirmExecute as Mock
|
||||
).mockImplementation(
|
||||
async (): Promise<ToolCallConfirmationDetails | null> => ({
|
||||
onConfirm: mockOnUserConfirmForToolConfirmation,
|
||||
fileName: 'mockToolRequiresConfirmation.ts',
|
||||
fileDiff: 'Mock tool requires confirmation',
|
||||
type: 'edit',
|
||||
title: 'Mock Tool Requires Confirmation',
|
||||
}),
|
||||
async (): Promise<ToolCallConfirmationDetails | null> =>
|
||||
({
|
||||
onConfirm: mockOnUserConfirmForToolConfirmation,
|
||||
fileName: 'mockToolRequiresConfirmation.ts',
|
||||
fileDiff: 'Mock tool requires confirmation',
|
||||
type: 'edit',
|
||||
title: 'Mock Tool Requires Confirmation',
|
||||
}) as any,
|
||||
);
|
||||
|
||||
vi.useFakeTimers();
|
||||
@@ -284,6 +337,8 @@ describe('useReactToolScheduler', () => {
|
||||
onComplete,
|
||||
mockConfig as unknown as Config,
|
||||
setPendingHistoryItem,
|
||||
() => undefined,
|
||||
() => {},
|
||||
),
|
||||
);
|
||||
|
||||
@@ -307,7 +362,7 @@ describe('useReactToolScheduler', () => {
|
||||
callId: 'call1',
|
||||
name: 'mockTool',
|
||||
args: { param: 'value' },
|
||||
};
|
||||
} as any;
|
||||
|
||||
act(() => {
|
||||
schedule(request, new AbortController().signal);
|
||||
@@ -326,6 +381,8 @@ describe('useReactToolScheduler', () => {
|
||||
request.args,
|
||||
expect.any(AbortSignal),
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
expect(onComplete).toHaveBeenCalledWith([
|
||||
expect.objectContaining({
|
||||
@@ -354,7 +411,7 @@ describe('useReactToolScheduler', () => {
|
||||
callId: 'call1',
|
||||
name: 'nonexistentTool',
|
||||
args: {},
|
||||
};
|
||||
} as any;
|
||||
|
||||
act(() => {
|
||||
schedule(request, new AbortController().signal);
|
||||
@@ -391,7 +448,7 @@ describe('useReactToolScheduler', () => {
|
||||
callId: 'call1',
|
||||
name: 'mockTool',
|
||||
args: {},
|
||||
};
|
||||
} as any;
|
||||
|
||||
act(() => {
|
||||
schedule(request, new AbortController().signal);
|
||||
@@ -427,7 +484,7 @@ describe('useReactToolScheduler', () => {
|
||||
callId: 'call1',
|
||||
name: 'mockTool',
|
||||
args: {},
|
||||
};
|
||||
} as any;
|
||||
|
||||
act(() => {
|
||||
schedule(request, new AbortController().signal);
|
||||
@@ -469,7 +526,7 @@ describe('useReactToolScheduler', () => {
|
||||
callId: 'callConfirm',
|
||||
name: 'mockToolRequiresConfirmation',
|
||||
args: { data: 'sensitive' },
|
||||
};
|
||||
} as any;
|
||||
|
||||
act(() => {
|
||||
schedule(request, new AbortController().signal);
|
||||
@@ -525,7 +582,7 @@ describe('useReactToolScheduler', () => {
|
||||
callId: 'callConfirmCancel',
|
||||
name: 'mockToolRequiresConfirmation',
|
||||
args: {},
|
||||
};
|
||||
} as any;
|
||||
|
||||
act(() => {
|
||||
schedule(request, new AbortController().signal);
|
||||
@@ -597,7 +654,7 @@ describe('useReactToolScheduler', () => {
|
||||
callId: 'liveCall',
|
||||
name: 'mockToolWithLiveOutput',
|
||||
args: {},
|
||||
};
|
||||
} as any;
|
||||
|
||||
act(() => {
|
||||
schedule(request, new AbortController().signal);
|
||||
@@ -682,8 +739,8 @@ describe('useReactToolScheduler', () => {
|
||||
const { result } = renderScheduler();
|
||||
const schedule = result.current[1];
|
||||
const requests: ToolCallRequestInfo[] = [
|
||||
{ callId: 'multi1', name: 'tool1', args: { p: 1 } },
|
||||
{ callId: 'multi2', name: 'tool2', args: { p: 2 } },
|
||||
{ callId: 'multi1', name: 'tool1', args: { p: 1 } } as any,
|
||||
{ callId: 'multi2', name: 'tool2', args: { p: 2 } } as any,
|
||||
];
|
||||
|
||||
act(() => {
|
||||
@@ -766,12 +823,12 @@ describe('useReactToolScheduler', () => {
|
||||
callId: 'run1',
|
||||
name: 'mockTool',
|
||||
args: {},
|
||||
};
|
||||
} as any;
|
||||
const request2: ToolCallRequestInfo = {
|
||||
callId: 'run2',
|
||||
name: 'mockTool',
|
||||
args: {},
|
||||
};
|
||||
} as any;
|
||||
|
||||
act(() => {
|
||||
schedule(request1, new AbortController().signal);
|
||||
@@ -807,7 +864,7 @@ describe('mapToDisplay', () => {
|
||||
callId: 'testCallId',
|
||||
name: 'testTool',
|
||||
args: { foo: 'bar' },
|
||||
};
|
||||
} as any;
|
||||
|
||||
const baseTool = new MockTool('testTool', 'Test Tool Display');
|
||||
|
||||
@@ -823,9 +880,8 @@ describe('mapToDisplay', () => {
|
||||
} as PartUnion,
|
||||
],
|
||||
resultDisplay: 'Test display output',
|
||||
summary: 'Test summary',
|
||||
error: undefined,
|
||||
};
|
||||
} as any;
|
||||
|
||||
// Define a more specific type for extraProps for these tests
|
||||
// This helps ensure that tool and confirmationDetails are only accessed when they are expected to exist.
|
||||
@@ -871,7 +927,7 @@ describe('mapToDisplay', () => {
|
||||
extraProps: { tool: baseTool, invocation: baseInvocation },
|
||||
expectedStatus: ToolCallStatus.Executing,
|
||||
expectedName: baseTool.displayName,
|
||||
expectedDescription: baseTool.getDescription(baseRequest.args),
|
||||
expectedDescription: baseInvocation.getDescription(),
|
||||
},
|
||||
{
|
||||
name: 'awaiting_approval',
|
||||
@@ -886,6 +942,7 @@ describe('mapToDisplay', () => {
|
||||
serverName: 'testTool',
|
||||
toolName: 'testTool',
|
||||
toolDisplayName: 'Test Tool Display',
|
||||
filePath: 'mock',
|
||||
fileName: 'test.ts',
|
||||
fileDiff: 'Test diff',
|
||||
originalContent: 'Original content',
|
||||
@@ -894,7 +951,7 @@ describe('mapToDisplay', () => {
|
||||
},
|
||||
expectedStatus: ToolCallStatus.Confirming,
|
||||
expectedName: baseTool.displayName,
|
||||
expectedDescription: baseTool.getDescription(baseRequest.args),
|
||||
expectedDescription: baseInvocation.getDescription(),
|
||||
},
|
||||
{
|
||||
name: 'scheduled',
|
||||
@@ -902,7 +959,7 @@ describe('mapToDisplay', () => {
|
||||
extraProps: { tool: baseTool, invocation: baseInvocation },
|
||||
expectedStatus: ToolCallStatus.Pending,
|
||||
expectedName: baseTool.displayName,
|
||||
expectedDescription: baseTool.getDescription(baseRequest.args),
|
||||
expectedDescription: baseInvocation.getDescription(),
|
||||
},
|
||||
{
|
||||
name: 'executing no live output',
|
||||
@@ -910,7 +967,7 @@ describe('mapToDisplay', () => {
|
||||
extraProps: { tool: baseTool, invocation: baseInvocation },
|
||||
expectedStatus: ToolCallStatus.Executing,
|
||||
expectedName: baseTool.displayName,
|
||||
expectedDescription: baseTool.getDescription(baseRequest.args),
|
||||
expectedDescription: baseInvocation.getDescription(),
|
||||
},
|
||||
{
|
||||
name: 'executing with live output',
|
||||
@@ -923,7 +980,7 @@ describe('mapToDisplay', () => {
|
||||
expectedStatus: ToolCallStatus.Executing,
|
||||
expectedResultDisplay: 'Live test output',
|
||||
expectedName: baseTool.displayName,
|
||||
expectedDescription: baseTool.getDescription(baseRequest.args),
|
||||
expectedDescription: baseInvocation.getDescription(),
|
||||
},
|
||||
{
|
||||
name: 'success',
|
||||
@@ -936,7 +993,7 @@ describe('mapToDisplay', () => {
|
||||
expectedStatus: ToolCallStatus.Success,
|
||||
expectedResultDisplay: baseResponse.resultDisplay as any,
|
||||
expectedName: baseTool.displayName,
|
||||
expectedDescription: baseTool.getDescription(baseRequest.args),
|
||||
expectedDescription: baseInvocation.getDescription(),
|
||||
},
|
||||
{
|
||||
name: 'error tool not found',
|
||||
@@ -967,7 +1024,7 @@ describe('mapToDisplay', () => {
|
||||
expectedStatus: ToolCallStatus.Error,
|
||||
expectedResultDisplay: 'Execution failed display',
|
||||
expectedName: baseTool.displayName, // Changed from baseTool.name
|
||||
expectedDescription: baseTool.getDescription(baseRequest.args),
|
||||
expectedDescription: baseInvocation.getDescription(),
|
||||
},
|
||||
{
|
||||
name: 'cancelled',
|
||||
@@ -983,7 +1040,7 @@ describe('mapToDisplay', () => {
|
||||
expectedStatus: ToolCallStatus.Canceled,
|
||||
expectedResultDisplay: 'Cancelled display',
|
||||
expectedName: baseTool.displayName,
|
||||
expectedDescription: baseTool.getDescription(baseRequest.args),
|
||||
expectedDescription: baseInvocation.getDescription(),
|
||||
},
|
||||
];
|
||||
|
||||
|
||||
Reference in New Issue
Block a user