Sync upstream Gemini-CLI v0.8.2 (#838)

This commit is contained in:
tanzhenxin
2025-10-23 09:27:04 +08:00
committed by GitHub
parent 096fabb5d6
commit eb95c131be
644 changed files with 70389 additions and 23709 deletions

View File

@@ -26,7 +26,14 @@ import {
vi.mock('node:fs');
vi.mock('node:path');
vi.mock('node:crypto');
vi.mock('node:crypto', () => ({
randomUUID: vi.fn(),
createHash: vi.fn(() => ({
update: vi.fn(() => ({
digest: vi.fn(() => 'mocked-hash'),
})),
})),
}));
vi.mock('../utils/paths.js');
describe('ChatRecordingService', () => {
@@ -47,6 +54,13 @@ describe('ChatRecordingService', () => {
},
getModel: vi.fn().mockReturnValue('gemini-pro'),
getDebugMode: vi.fn().mockReturnValue(false),
getToolRegistry: vi.fn().mockReturnValue({
getTool: vi.fn().mockReturnValue({
displayName: 'Test Tool',
description: 'A test tool',
isOutputMarkdown: false,
}),
}),
} as unknown as Config;
vi.mocked(getProjectHash).mockReturnValue('test-project-hash');
@@ -120,7 +134,11 @@ describe('ChatRecordingService', () => {
const writeFileSyncSpy = vi
.spyOn(fs, 'writeFileSync')
.mockImplementation(() => undefined);
chatRecordingService.recordMessage({ type: 'user', content: 'Hello' });
chatRecordingService.recordMessage({
type: 'user',
content: 'Hello',
model: 'gemini-pro',
});
expect(mkdirSyncSpy).toHaveBeenCalled();
expect(writeFileSyncSpy).toHaveBeenCalled();
const conversation = JSON.parse(
@@ -131,7 +149,7 @@ describe('ChatRecordingService', () => {
expect(conversation.messages[0].type).toBe('user');
});
it('should append to the last message if append is true and types match', () => {
it('should create separate messages when recording multiple messages', () => {
const writeFileSyncSpy = vi
.spyOn(fs, 'writeFileSync')
.mockImplementation(() => undefined);
@@ -153,8 +171,8 @@ describe('ChatRecordingService', () => {
chatRecordingService.recordMessage({
type: 'user',
content: ' World',
append: true,
content: 'World',
model: 'gemini-pro',
});
expect(mkdirSyncSpy).toHaveBeenCalled();
@@ -162,8 +180,9 @@ describe('ChatRecordingService', () => {
const conversation = JSON.parse(
writeFileSyncSpy.mock.calls[0][1] as string,
) as ConversationRecord;
expect(conversation.messages).toHaveLength(1);
expect(conversation.messages[0].content).toBe('Hello World');
expect(conversation.messages).toHaveLength(2);
expect(conversation.messages[0].content).toBe('Hello');
expect(conversation.messages[1].content).toBe('World');
});
});
@@ -200,7 +219,7 @@ describe('ChatRecordingService', () => {
messages: [
{
id: '1',
type: 'gemini',
type: 'qwen',
content: 'Response',
timestamp: new Date().toISOString(),
},
@@ -211,10 +230,10 @@ describe('ChatRecordingService', () => {
);
chatRecordingService.recordMessageTokens({
input: 1,
output: 2,
total: 3,
cached: 0,
promptTokenCount: 1,
candidatesTokenCount: 2,
totalTokenCount: 3,
cachedContentTokenCount: 0,
});
expect(mkdirSyncSpy).toHaveBeenCalled();
@@ -224,7 +243,14 @@ describe('ChatRecordingService', () => {
) as ConversationRecord;
expect(conversation.messages[0]).toEqual({
...initialConversation.messages[0],
tokens: { input: 1, output: 2, total: 3, cached: 0 },
tokens: {
input: 1,
output: 2,
total: 3,
cached: 0,
thoughts: 0,
tool: 0,
},
});
});
@@ -235,7 +261,7 @@ describe('ChatRecordingService', () => {
messages: [
{
id: '1',
type: 'gemini',
type: 'qwen',
content: 'Response',
timestamp: new Date().toISOString(),
tokens: { input: 1, output: 1, total: 2, cached: 0 },
@@ -247,10 +273,10 @@ describe('ChatRecordingService', () => {
);
chatRecordingService.recordMessageTokens({
input: 2,
output: 2,
total: 4,
cached: 0,
promptTokenCount: 2,
candidatesTokenCount: 2,
totalTokenCount: 4,
cachedContentTokenCount: 0,
});
// @ts-expect-error private property
@@ -259,6 +285,8 @@ describe('ChatRecordingService', () => {
output: 2,
total: 4,
cached: 0,
thoughts: 0,
tool: 0,
});
});
});
@@ -278,7 +306,7 @@ describe('ChatRecordingService', () => {
messages: [
{
id: '1',
type: 'gemini',
type: 'qwen',
content: '',
timestamp: new Date().toISOString(),
},
@@ -295,7 +323,7 @@ describe('ChatRecordingService', () => {
status: 'awaiting_approval',
timestamp: new Date().toISOString(),
};
chatRecordingService.recordToolCalls([toolCall]);
chatRecordingService.recordToolCalls('gemini-pro', [toolCall]);
expect(mkdirSyncSpy).toHaveBeenCalled();
expect(writeFileSyncSpy).toHaveBeenCalled();
@@ -304,7 +332,14 @@ describe('ChatRecordingService', () => {
) as ConversationRecord;
expect(conversation.messages[0]).toEqual({
...initialConversation.messages[0],
toolCalls: [toolCall],
toolCalls: [
{
...toolCall,
displayName: 'Test Tool',
description: 'A test tool',
renderOutputAsMarkdown: false,
},
],
});
});
@@ -335,7 +370,7 @@ describe('ChatRecordingService', () => {
status: 'awaiting_approval',
timestamp: new Date().toISOString(),
};
chatRecordingService.recordToolCalls([toolCall]);
chatRecordingService.recordToolCalls('gemini-pro', [toolCall]);
expect(mkdirSyncSpy).toHaveBeenCalled();
expect(writeFileSyncSpy).toHaveBeenCalled();
@@ -347,10 +382,17 @@ describe('ChatRecordingService', () => {
...conversation.messages[1],
id: 'this-is-a-test-uuid',
model: 'gemini-pro',
type: 'gemini',
type: 'qwen',
thoughts: [],
content: '',
toolCalls: [toolCall],
toolCalls: [
{
...toolCall,
displayName: 'Test Tool',
description: 'A test tool',
renderOutputAsMarkdown: false,
},
],
});
});
});

View File

@@ -6,12 +6,15 @@
import { type Config } from '../config/config.js';
import { type Status } from '../core/coreToolScheduler.js';
import { type ThoughtSummary } from '../core/turn.js';
import { type ThoughtSummary } from '../utils/thoughtUtils.js';
import { getProjectHash } from '../utils/paths.js';
import path from 'node:path';
import fs from 'node:fs';
import { randomUUID } from 'node:crypto';
import type { PartListUnion } from '@google/genai';
import type {
PartListUnion,
GenerateContentResponseUsageMetadata,
} from '@google/genai';
/**
* Token usage summary for a message or conversation.
@@ -31,7 +34,7 @@ export interface TokensSummary {
export interface BaseMessageRecord {
id: string;
timestamp: string;
content: string;
content: PartListUnion;
}
/**
@@ -59,7 +62,7 @@ export type ConversationRecordExtra =
type: 'user';
}
| {
type: 'gemini';
type: 'qwen';
toolCalls?: ToolCallRecord[];
thoughts?: Array<ThoughtSummary & { timestamp: string }>;
tokens?: TokensSummary | null;
@@ -99,7 +102,7 @@ export interface ResumedSessionData {
* - Token usage statistics
* - Assistant thoughts and reasoning
*
* Sessions are stored as JSON files in ~/.gemini/tmp/<project_hash>/chats/
* Sessions are stored as JSON files in ~/.qwen/tmp/<project_hash>/chats/
*/
export class ChatRecordingService {
private conversationFile: string | null = null;
@@ -178,7 +181,7 @@ export class ChatRecordingService {
private newMessage(
type: ConversationRecordExtra['type'],
content: string,
content: PartListUnion,
): MessageRecord {
return {
id: randomUUID(),
@@ -192,31 +195,22 @@ export class ChatRecordingService {
* Records a message in the conversation.
*/
recordMessage(message: {
model: string;
type: ConversationRecordExtra['type'];
content: string;
append?: boolean;
content: PartListUnion;
}): void {
if (!this.conversationFile) return;
try {
this.updateConversation((conversation) => {
if (message.append) {
const lastMsg = this.getLastMessage(conversation);
if (lastMsg && lastMsg.type === message.type) {
lastMsg.content += message.content;
return;
}
}
// We're not appending, or we are appending but the last message's type is not the same as
// the specified type, so just create a new message.
const msg = this.newMessage(message.type, message.content);
if (msg.type === 'gemini') {
if (msg.type === 'qwen') {
// If it's a new Gemini message then incorporate any queued thoughts.
conversation.messages.push({
...msg,
thoughts: this.queuedThoughts,
tokens: this.queuedTokens,
model: this.config.getModel(),
model: message.model,
});
this.queuedThoughts = [];
this.queuedTokens = null;
@@ -243,32 +237,33 @@ export class ChatRecordingService {
timestamp: new Date().toISOString(),
});
} catch (error) {
if (this.config.getDebugMode()) {
console.error('Error saving thought:', error);
throw error;
}
console.error('Error saving thought:', error);
throw error;
}
}
/**
* Updates the tokens for the last message in the conversation (which should be by Gemini).
*/
recordMessageTokens(tokens: {
input: number;
output: number;
cached: number;
thoughts?: number;
tool?: number;
total: number;
}): void {
recordMessageTokens(
respUsageMetadata: GenerateContentResponseUsageMetadata,
): void {
if (!this.conversationFile) return;
try {
const tokens = {
input: respUsageMetadata.promptTokenCount ?? 0,
output: respUsageMetadata.candidatesTokenCount ?? 0,
cached: respUsageMetadata.cachedContentTokenCount ?? 0,
thoughts: respUsageMetadata.thoughtsTokenCount ?? 0,
tool: respUsageMetadata.toolUsePromptTokenCount ?? 0,
total: respUsageMetadata.totalTokenCount ?? 0,
};
this.updateConversation((conversation) => {
const lastMsg = this.getLastMessage(conversation);
// If the last message already has token info, it's because this new token info is for a
// new message that hasn't been recorded yet.
if (lastMsg && lastMsg.type === 'gemini' && !lastMsg.tokens) {
if (lastMsg && lastMsg.type === 'qwen' && !lastMsg.tokens) {
lastMsg.tokens = tokens;
this.queuedTokens = null;
} else {
@@ -283,10 +278,23 @@ export class ChatRecordingService {
/**
* Adds tool calls to the last message in the conversation (which should be by Gemini).
* This method enriches tool calls with metadata from the ToolRegistry.
*/
recordToolCalls(toolCalls: ToolCallRecord[]): void {
recordToolCalls(model: string, toolCalls: ToolCallRecord[]): void {
if (!this.conversationFile) return;
// Enrich tool calls with metadata from the ToolRegistry
const toolRegistry = this.config.getToolRegistry();
const enrichedToolCalls = toolCalls.map((toolCall) => {
const toolInstance = toolRegistry.getTool(toolCall.name);
return {
...toolCall,
displayName: toolInstance?.displayName || toolCall.name,
description: toolInstance?.description || '',
renderOutputAsMarkdown: toolInstance?.isOutputMarkdown || false,
};
});
try {
this.updateConversation((conversation) => {
const lastMsg = this.getLastMessage(conversation);
@@ -299,19 +307,19 @@ export class ChatRecordingService {
// message from tool calls, when we dequeued the thoughts.
if (
!lastMsg ||
lastMsg.type !== 'gemini' ||
lastMsg.type !== 'qwen' ||
this.queuedThoughts.length > 0
) {
const newMsg: MessageRecord = {
...this.newMessage('gemini' as const, ''),
...this.newMessage('qwen' as const, ''),
// This isn't strictly necessary, but TypeScript apparently can't
// tell that the first parameter to newMessage() becomes the
// resulting message's type, and so it thinks that toolCalls may
// not be present. Confirming the type here satisfies it.
type: 'gemini' as const,
toolCalls,
type: 'qwen' as const,
toolCalls: enrichedToolCalls,
thoughts: this.queuedThoughts,
model: this.config.getModel(),
model,
};
// If there are any queued thoughts join them to this message.
if (this.queuedThoughts.length > 0) {
@@ -346,7 +354,7 @@ export class ChatRecordingService {
});
// Add any new tools calls that aren't in the message yet.
for (const toolCall of toolCalls) {
for (const toolCall of enrichedToolCalls) {
const existingToolCall = lastMsg.toolCalls.find(
(tc) => tc.id === toolCall.id,
);

View File

@@ -57,8 +57,8 @@ describe('FileDiscoveryService', () => {
await createTestFile('.qwenignore', 'secrets.txt');
const service = new FileDiscoveryService(projectRoot);
expect(service.shouldGeminiIgnoreFile('secrets.txt')).toBe(true);
expect(service.shouldGeminiIgnoreFile('src/index.js')).toBe(false);
expect(service.shouldQwenIgnoreFile('secrets.txt')).toBe(true);
expect(service.shouldQwenIgnoreFile('src/index.js')).toBe(false);
});
});
@@ -69,7 +69,7 @@ describe('FileDiscoveryService', () => {
await createTestFile('.qwenignore', 'logs/');
});
it('should filter out git-ignored and gemini-ignored files by default', () => {
it('should filter out git-ignored and qwen-ignored files by default', () => {
const files = [
'src/index.ts',
'node_modules/package/index.js',
@@ -98,7 +98,7 @@ describe('FileDiscoveryService', () => {
const filtered = service.filterFiles(files, {
respectGitIgnore: false,
respectGeminiIgnore: true, // still respect this one
respectQwenIgnore: true, // still respect this one
});
expect(filtered).toEqual(
@@ -108,7 +108,7 @@ describe('FileDiscoveryService', () => {
);
});
it('should not filter files when respectGeminiIgnore is false', () => {
it('should not filter files when respectQwenIgnore is false', () => {
const files = [
'src/index.ts',
'node_modules/package/index.js',
@@ -119,7 +119,7 @@ describe('FileDiscoveryService', () => {
const filtered = service.filterFiles(files, {
respectGitIgnore: true,
respectGeminiIgnore: false,
respectQwenIgnore: false,
});
expect(filtered).toEqual(
@@ -136,7 +136,7 @@ describe('FileDiscoveryService', () => {
});
});
describe('shouldGitIgnoreFile & shouldGeminiIgnoreFile', () => {
describe('shouldGitIgnoreFile & shouldQwenIgnoreFile', () => {
beforeEach(async () => {
await fs.mkdir(path.join(projectRoot, '.git'));
await createTestFile('.gitignore', 'node_modules/');
@@ -161,19 +161,19 @@ describe('FileDiscoveryService', () => {
).toBe(false);
});
it('should return true for gemini-ignored files', () => {
it('should return true for qwen-ignored files', () => {
const service = new FileDiscoveryService(projectRoot);
expect(
service.shouldGeminiIgnoreFile(path.join(projectRoot, 'debug.log')),
service.shouldQwenIgnoreFile(path.join(projectRoot, 'debug.log')),
).toBe(true);
});
it('should return false for non-gemini-ignored files', () => {
it('should return false for non-qwen-ignored files', () => {
const service = new FileDiscoveryService(projectRoot);
expect(
service.shouldGeminiIgnoreFile(path.join(projectRoot, 'src/index.ts')),
service.shouldQwenIgnoreFile(path.join(projectRoot, 'src/index.ts')),
).toBe(false);
});
});

View File

@@ -5,40 +5,34 @@
*/
import type { GitIgnoreFilter } from '../utils/gitIgnoreParser.js';
import type { QwenIgnoreFilter } from '../utils/qwenIgnoreParser.js';
import { GitIgnoreParser } from '../utils/gitIgnoreParser.js';
import { QwenIgnoreParser } from '../utils/qwenIgnoreParser.js';
import { isGitRepository } from '../utils/gitUtils.js';
import * as path from 'node:path';
const GEMINI_IGNORE_FILE_NAME = '.qwenignore';
export interface FilterFilesOptions {
respectGitIgnore?: boolean;
respectGeminiIgnore?: boolean;
respectQwenIgnore?: boolean;
}
export interface FilterReport {
filteredPaths: string[];
gitIgnoredCount: number;
qwenIgnoredCount: number;
}
export class FileDiscoveryService {
private gitIgnoreFilter: GitIgnoreFilter | null = null;
private geminiIgnoreFilter: GitIgnoreFilter | null = null;
private qwenIgnoreFilter: QwenIgnoreFilter | null = null;
private projectRoot: string;
constructor(projectRoot: string) {
this.projectRoot = path.resolve(projectRoot);
if (isGitRepository(this.projectRoot)) {
const parser = new GitIgnoreParser(this.projectRoot);
try {
parser.loadGitRepoPatterns();
} catch (_error) {
// ignore file not found
}
this.gitIgnoreFilter = parser;
this.gitIgnoreFilter = new GitIgnoreParser(this.projectRoot);
}
const gParser = new GitIgnoreParser(this.projectRoot);
try {
gParser.loadPatterns(GEMINI_IGNORE_FILE_NAME);
} catch (_error) {
// ignore file not found
}
this.geminiIgnoreFilter = gParser;
this.qwenIgnoreFilter = new QwenIgnoreParser(this.projectRoot);
}
/**
@@ -48,23 +42,56 @@ export class FileDiscoveryService {
filePaths: string[],
options: FilterFilesOptions = {
respectGitIgnore: true,
respectGeminiIgnore: true,
respectQwenIgnore: true,
},
): string[] {
return filePaths.filter((filePath) => {
if (options.respectGitIgnore && this.shouldGitIgnoreFile(filePath)) {
return false;
}
if (
options.respectGeminiIgnore &&
this.shouldGeminiIgnoreFile(filePath)
) {
if (options.respectQwenIgnore && this.shouldQwenIgnoreFile(filePath)) {
return false;
}
return true;
});
}
/**
* Filters a list of file paths based on git ignore rules and returns a report
* with counts of ignored files.
*/
filterFilesWithReport(
filePaths: string[],
opts: FilterFilesOptions = {
respectGitIgnore: true,
respectQwenIgnore: true,
},
): FilterReport {
const filteredPaths: string[] = [];
let gitIgnoredCount = 0;
let qwenIgnoredCount = 0;
for (const filePath of filePaths) {
if (opts.respectGitIgnore && this.shouldGitIgnoreFile(filePath)) {
gitIgnoredCount++;
continue;
}
if (opts.respectQwenIgnore && this.shouldQwenIgnoreFile(filePath)) {
qwenIgnoredCount++;
continue;
}
filteredPaths.push(filePath);
}
return {
filteredPaths,
gitIgnoredCount,
qwenIgnoredCount,
};
}
/**
* Checks if a single file should be git-ignored
*/
@@ -76,11 +103,11 @@ export class FileDiscoveryService {
}
/**
* Checks if a single file should be gemini-ignored
* Checks if a single file should be qwen-ignored
*/
shouldGeminiIgnoreFile(filePath: string): boolean {
if (this.geminiIgnoreFilter) {
return this.geminiIgnoreFilter.isIgnored(filePath);
shouldQwenIgnoreFile(filePath: string): boolean {
if (this.qwenIgnoreFilter) {
return this.qwenIgnoreFilter.isIgnored(filePath);
}
return false;
}
@@ -92,12 +119,15 @@ export class FileDiscoveryService {
filePath: string,
options: FilterFilesOptions = {},
): boolean {
const { respectGitIgnore = true, respectGeminiIgnore = true } = options;
const {
respectGitIgnore = true,
respectQwenIgnore: respectQwenIgnore = true,
} = options;
if (respectGitIgnore && this.shouldGitIgnoreFile(filePath)) {
return true;
}
if (respectGeminiIgnore && this.shouldGeminiIgnoreFile(filePath)) {
if (respectQwenIgnore && this.shouldQwenIgnoreFile(filePath)) {
return true;
}
return false;
@@ -106,7 +136,7 @@ export class FileDiscoveryService {
/**
* Returns loaded patterns from .qwenignore
*/
getGeminiIgnorePatterns(): string[] {
return this.geminiIgnoreFilter?.getPatterns() ?? [];
getQwenIgnorePatterns(): string[] {
return this.qwenIgnoreFilter?.getPatterns() ?? [];
}
}

View File

@@ -5,6 +5,8 @@
*/
import fs from 'node:fs/promises';
import * as path from 'node:path';
import { globSync } from 'glob';
/**
* Interface for file system operations that may be delegated to different implementations
@@ -25,6 +27,15 @@ export interface FileSystemService {
* @param content - The content to write
*/
writeTextFile(filePath: string, content: string): Promise<void>;
/**
* Finds files with a given name within specified search paths.
*
* @param fileName - The name of the file to find.
* @param searchPaths - An array of directory paths to search within.
* @returns An array of absolute paths to the found files.
*/
findFiles(fileName: string, searchPaths: readonly string[]): string[];
}
/**
@@ -38,4 +49,14 @@ export class StandardFileSystemService implements FileSystemService {
async writeTextFile(filePath: string, content: string): Promise<void> {
await fs.writeFile(filePath, content, 'utf-8');
}
findFiles(fileName: string, searchPaths: readonly string[]): string[] {
return searchPaths.flatMap((searchPath) => {
const pattern = path.posix.join(searchPath, '**', fileName);
return globSync(pattern, {
nodir: true,
absolute: true,
});
});
}
}

View File

@@ -4,18 +4,25 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import {
describe,
it,
expect,
vi,
beforeEach,
afterEach,
type Mock,
} from 'vitest';
import { GitService } from './gitService.js';
import { Storage } from '../config/storage.js';
import * as path from 'node:path';
import * as fs from 'node:fs/promises';
import * as os from 'node:os';
import type { ChildProcess } from 'node:child_process';
import { getProjectHash, QWEN_DIR } from '../utils/paths.js';
import { spawnAsync } from '../utils/shell-utils.js';
const hoistedMockExec = vi.hoisted(() => vi.fn());
vi.mock('node:child_process', () => ({
exec: hoistedMockExec,
vi.mock('../utils/shell-utils.js', () => ({
spawnAsync: vi.fn(),
}));
const hoistedMockEnv = vi.hoisted(() => vi.fn());
@@ -69,13 +76,9 @@ describe('GitService', () => {
vi.clearAllMocks();
hoistedIsGitRepositoryMock.mockReturnValue(true);
hoistedMockExec.mockImplementation((command, callback) => {
if (command === 'git --version') {
callback(null, 'git version 2.0.0');
} else {
callback(new Error('Command not mocked'));
}
return {};
(spawnAsync as Mock).mockResolvedValue({
stdout: 'git version 2.0.0',
stderr: '',
});
hoistedMockHomedir.mockReturnValue(homedir);
@@ -120,13 +123,11 @@ describe('GitService', () => {
it('should resolve true if git --version command succeeds', async () => {
const service = new GitService(projectRoot, storage);
await expect(service.verifyGitAvailability()).resolves.toBe(true);
expect(spawnAsync).toHaveBeenCalledWith('git', ['--version']);
});
it('should resolve false if git --version command fails', async () => {
hoistedMockExec.mockImplementation((command, callback) => {
callback(new Error('git not found'));
return {} as ChildProcess;
});
(spawnAsync as Mock).mockRejectedValue(new Error('git not found'));
const service = new GitService(projectRoot, storage);
await expect(service.verifyGitAvailability()).resolves.toBe(false);
});
@@ -134,10 +135,7 @@ describe('GitService', () => {
describe('initialize', () => {
it('should throw an error if Git is not available', async () => {
hoistedMockExec.mockImplementation((command, callback) => {
callback(new Error('git not found'));
return {} as ChildProcess;
});
(spawnAsync as Mock).mockRejectedValue(new Error('git not found'));
const service = new GitService(projectRoot, storage);
await expect(service.initialize()).rejects.toThrow(
'Checkpointing is enabled, but Git is not installed. Please install Git or disable checkpointing to continue.',

View File

@@ -4,10 +4,11 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { exec } from 'node:child_process';
import * as fs from 'node:fs/promises';
import * as path from 'node:path';
import { CheckRepoActions, simpleGit, type SimpleGit } from 'simple-git';
import { spawnAsync } from '../utils/shell-utils.js';
import type { SimpleGit } from 'simple-git';
import { simpleGit, CheckRepoActions } from 'simple-git';
import type { Storage } from '../config/storage.js';
import { isNodeError } from '../utils/errors.js';
@@ -40,16 +41,13 @@ export class GitService {
}
}
verifyGitAvailability(): Promise<boolean> {
return new Promise((resolve) => {
exec('git --version', (error) => {
if (error) {
resolve(false);
} else {
resolve(true);
}
});
});
async verifyGitAvailability(): Promise<boolean> {
try {
await spawnAsync('git', ['--version']);
return true;
} catch (_error) {
return false;
}
}
/**

View File

@@ -7,6 +7,7 @@
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import type { Config } from '../config/config.js';
import type { GeminiClient } from '../core/client.js';
import type { BaseLlmClient } from '../core/baseLlmClient.js';
import type {
ServerGeminiContentEvent,
ServerGeminiStreamEvent,
@@ -19,8 +20,7 @@ import { LoopDetectionService } from './loopDetectionService.js';
vi.mock('../telemetry/loggers.js', () => ({
logLoopDetected: vi.fn(),
logApiError: vi.fn(),
logApiResponse: vi.fn(),
logLoopDetectionDisabled: vi.fn(),
}));
const TOOL_CALL_LOOP_THRESHOLD = 5;
@@ -132,6 +132,16 @@ describe('LoopDetectionService', () => {
expect(service.addAndCheck(toolCallEvent)).toBe(true);
expect(loggers.logLoopDetected).toHaveBeenCalledTimes(1);
});
it('should not detect a loop when disabled for session', () => {
service.disableForSession();
expect(loggers.logLoopDetectionDisabled).toHaveBeenCalledTimes(1);
const event = createToolCallRequestEvent('testTool', { param: 'value' });
for (let i = 0; i < TOOL_CALL_LOOP_THRESHOLD; i++) {
expect(service.addAndCheck(event)).toBe(false);
}
expect(loggers.logLoopDetected).not.toHaveBeenCalled();
});
});
describe('Content Loop Detection', () => {
@@ -618,18 +628,24 @@ describe('LoopDetectionService LLM Checks', () => {
let service: LoopDetectionService;
let mockConfig: Config;
let mockGeminiClient: GeminiClient;
let mockBaseLlmClient: BaseLlmClient;
let abortController: AbortController;
beforeEach(() => {
mockGeminiClient = {
getHistory: vi.fn().mockReturnValue([]),
generateJson: vi.fn(),
} as unknown as GeminiClient;
mockBaseLlmClient = {
generateJson: vi.fn(),
} as unknown as BaseLlmClient;
mockConfig = {
getGeminiClient: () => mockGeminiClient,
getBaseLlmClient: () => mockBaseLlmClient,
getDebugMode: () => false,
getTelemetryEnabled: () => true,
getModel: () => 'test-model',
} as unknown as Config;
service = new LoopDetectionService(mockConfig);
@@ -649,30 +665,39 @@ describe('LoopDetectionService LLM Checks', () => {
it('should not trigger LLM check before LLM_CHECK_AFTER_TURNS', async () => {
await advanceTurns(29);
expect(mockGeminiClient.generateJson).not.toHaveBeenCalled();
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
});
it('should trigger LLM check on the 30th turn', async () => {
mockGeminiClient.generateJson = vi
mockBaseLlmClient.generateJson = vi
.fn()
.mockResolvedValue({ confidence: 0.1 });
await advanceTurns(30);
expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledWith(
expect.objectContaining({
systemInstruction: expect.any(String),
contents: expect.any(Array),
model: expect.any(String),
schema: expect.any(Object),
promptId: expect.any(String),
}),
);
});
it('should detect a cognitive loop when confidence is high', async () => {
// First check at turn 30
mockGeminiClient.generateJson = vi
mockBaseLlmClient.generateJson = vi
.fn()
.mockResolvedValue({ confidence: 0.85, reasoning: 'Repetitive actions' });
await advanceTurns(30);
expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
// The confidence of 0.85 will result in a low interval.
// The interval will be: 5 + (15 - 5) * (1 - 0.85) = 5 + 10 * 0.15 = 6.5 -> rounded to 7
await advanceTurns(6); // advance to turn 36
mockGeminiClient.generateJson = vi
mockBaseLlmClient.generateJson = vi
.fn()
.mockResolvedValue({ confidence: 0.95, reasoning: 'Repetitive actions' });
const finalResult = await service.turnStarted(abortController.signal); // This is turn 37
@@ -688,7 +713,7 @@ describe('LoopDetectionService LLM Checks', () => {
});
it('should not detect a loop when confidence is low', async () => {
mockGeminiClient.generateJson = vi
mockBaseLlmClient.generateJson = vi
.fn()
.mockResolvedValue({ confidence: 0.5, reasoning: 'Looks okay' });
await advanceTurns(30);
@@ -699,21 +724,21 @@ describe('LoopDetectionService LLM Checks', () => {
it('should adjust the check interval based on confidence', async () => {
// Confidence is 0.0, so interval should be MAX_LLM_CHECK_INTERVAL (15)
mockGeminiClient.generateJson = vi
mockBaseLlmClient.generateJson = vi
.fn()
.mockResolvedValue({ confidence: 0.0 });
await advanceTurns(30); // First check at turn 30
expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
await advanceTurns(14); // Advance to turn 44
expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(1);
await service.turnStarted(abortController.signal); // Turn 45
expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(2);
expect(mockBaseLlmClient.generateJson).toHaveBeenCalledTimes(2);
});
it('should handle errors from generateJson gracefully', async () => {
mockGeminiClient.generateJson = vi
mockBaseLlmClient.generateJson = vi
.fn()
.mockRejectedValue(new Error('API error'));
await advanceTurns(30);
@@ -721,4 +746,13 @@ describe('LoopDetectionService LLM Checks', () => {
expect(result).toBe(false);
expect(loggers.logLoopDetected).not.toHaveBeenCalled();
});
it('should not trigger LLM check when disabled for session', async () => {
service.disableForSession();
expect(loggers.logLoopDetectionDisabled).toHaveBeenCalledTimes(1);
await advanceTurns(30);
const result = await service.turnStarted(abortController.signal);
expect(result).toBe(false);
expect(mockBaseLlmClient.generateJson).not.toHaveBeenCalled();
});
});

View File

@@ -8,14 +8,21 @@ import type { Content } from '@google/genai';
import { createHash } from 'node:crypto';
import type { ServerGeminiStreamEvent } from '../core/turn.js';
import { GeminiEventType } from '../core/turn.js';
import { logLoopDetected } from '../telemetry/loggers.js';
import { LoopDetectedEvent, LoopType } from '../telemetry/types.js';
import {
logLoopDetected,
logLoopDetectionDisabled,
} from '../telemetry/loggers.js';
import {
LoopDetectedEvent,
LoopDetectionDisabledEvent,
LoopType,
} from '../telemetry/types.js';
import type { Config } from '../config/config.js';
import {
isFunctionCall,
isFunctionResponse,
} from '../utils/messageInspectors.js';
import { DEFAULT_QWEN_FLASH_MODEL } from '../config/models.js';
import { DEFAULT_QWEN_MODEL } from '../config/models.js';
const TOOL_CALL_LOOP_THRESHOLD = 5;
const CONTENT_LOOP_THRESHOLD = 10;
@@ -50,6 +57,17 @@ const MIN_LLM_CHECK_INTERVAL = 5;
*/
const MAX_LLM_CHECK_INTERVAL = 15;
const LOOP_DETECTION_SYSTEM_PROMPT = `You are a sophisticated AI diagnostic agent specializing in identifying when a conversational AI is stuck in an unproductive state. Your task is to analyze the provided conversation history and determine if the assistant has ceased to make meaningful progress.
An unproductive state is characterized by one or more of the following patterns over the last 5 or more assistant turns:
Repetitive Actions: The assistant repeats the same tool calls or conversational responses a decent number of times. This includes simple loops (e.g., tool_A, tool_A, tool_A) and alternating patterns (e.g., tool_A, tool_B, tool_A, tool_B, ...).
Cognitive Loop: The assistant seems unable to determine the next logical step. It might express confusion, repeatedly ask the same questions, or generate responses that don't logically follow from the previous turns, indicating it's stuck and not advancing the task.
Crucially, differentiate between a true unproductive state and legitimate, incremental progress.
For example, a series of 'tool_A' or 'tool_B' tool calls that make small, distinct changes to the same file (like adding docstrings to functions one by one) is considered forward progress and is NOT a loop. A loop would be repeatedly replacing the same text with the same content, or cycling between a small set of files with no net change.`;
/**
* Service for detecting and preventing infinite loops in AI responses.
* Monitors tool call repetitions and content sentence repetitions.
@@ -74,10 +92,24 @@ export class LoopDetectionService {
private llmCheckInterval = DEFAULT_LLM_CHECK_INTERVAL;
private lastCheckTurn = 0;
// Session-level disable flag
private disabledForSession = false;
constructor(config: Config) {
this.config = config;
}
/**
* Disables loop detection for the current session.
*/
disableForSession(): void {
this.disabledForSession = true;
logLoopDetectionDisabled(
this.config,
new LoopDetectionDisabledEvent(this.promptId),
);
}
private getToolCallKey(toolCall: { name: string; args: object }): string {
const argsString = JSON.stringify(toolCall.args);
const keyString = `${toolCall.name}:${argsString}`;
@@ -90,8 +122,8 @@ export class LoopDetectionService {
* @returns true if a loop is detected, false otherwise
*/
addAndCheck(event: ServerGeminiStreamEvent): boolean {
if (this.loopDetected) {
return true;
if (this.loopDetected || this.disabledForSession) {
return this.loopDetected;
}
switch (event.type) {
@@ -121,6 +153,9 @@ export class LoopDetectionService {
* @returns A promise that resolves to `true` if a loop is detected, and `false` otherwise.
*/
async turnStarted(signal: AbortSignal) {
if (this.disabledForSession) {
return false;
}
this.turnsInCurrentPrompt++;
if (
@@ -362,21 +397,11 @@ export class LoopDetectionService {
const trimmedHistory = this.trimRecentHistory(recentHistory);
const prompt = `You are a sophisticated AI diagnostic agent specializing in identifying when a conversational AI is stuck in an unproductive state. Your task is to analyze the provided conversation history and determine if the assistant has ceased to make meaningful progress.
const taskPrompt = `Please analyze the conversation history to determine the possibility that the conversation is stuck in a repetitive, non-productive state. Provide your response in the requested JSON format.`;
An unproductive state is characterized by one or more of the following patterns over the last 5 or more assistant turns:
Repetitive Actions: The assistant repeats the same tool calls or conversational responses a decent number of times. This includes simple loops (e.g., tool_A, tool_A, tool_A) and alternating patterns (e.g., tool_A, tool_B, tool_A, tool_B, ...).
Cognitive Loop: The assistant seems unable to determine the next logical step. It might express confusion, repeatedly ask the same questions, or generate responses that don't logically follow from the previous turns, indicating it's stuck and not advancing the task.
Crucially, differentiate between a true unproductive state and legitimate, incremental progress.
For example, a series of 'tool_A' or 'tool_B' tool calls that make small, distinct changes to the same file (like adding docstrings to functions one by one) is considered forward progress and is NOT a loop. A loop would be repeatedly replacing the same text with the same content, or cycling between a small set of files with no net change.
Please analyze the conversation history to determine the possibility that the conversation is stuck in a repetitive, non-productive state.`;
const contents = [
...trimmedHistory,
{ role: 'user', parts: [{ text: prompt }] },
{ role: 'user', parts: [{ text: taskPrompt }] },
];
const schema: Record<string, unknown> = {
type: 'object',
@@ -396,9 +421,14 @@ Please analyze the conversation history to determine the possibility that the co
};
let result;
try {
result = await this.config
.getGeminiClient()
.generateJson(contents, schema, signal, DEFAULT_QWEN_FLASH_MODEL);
result = await this.config.getBaseLlmClient().generateJson({
contents,
schema,
model: this.config.getModel() || DEFAULT_QWEN_MODEL,
systemInstruction: LOOP_DETECTION_SYSTEM_PROMPT,
abortSignal: signal,
promptId: this.promptId,
});
} catch (e) {
// Do nothing, treat it as a non-loop.
this.config.getDebugMode() ? console.error(e) : console.debug(e);

View File

@@ -4,14 +4,13 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { type ChildProcess } from 'child_process';
import EventEmitter from 'events';
import type { Readable } from 'stream';
import { beforeEach, describe, expect, it, vi, type Mock } from 'vitest';
import {
ShellExecutionService,
type ShellOutputEvent,
} from './shellExecutionService.js';
import { vi, describe, it, expect, beforeEach, type Mock } from 'vitest';
import EventEmitter from 'node:events';
import type { Readable } from 'node:stream';
import { type ChildProcess } from 'node:child_process';
import type { ShellOutputEvent } from './shellExecutionService.js';
import { ShellExecutionService } from './shellExecutionService.js';
import type { AnsiOutput } from '../utils/terminalSerializer.js';
// Hoisted Mocks
const mockPtySpawn = vi.hoisted(() => vi.fn());
@@ -19,6 +18,7 @@ const mockCpSpawn = vi.hoisted(() => vi.fn());
const mockIsBinary = vi.hoisted(() => vi.fn());
const mockPlatform = vi.hoisted(() => vi.fn());
const mockGetPty = vi.hoisted(() => vi.fn());
const mockSerializeTerminalToObject = vi.hoisted(() => vi.fn());
// Top-level Mocks
vi.mock('@lydell/node-pty', () => ({
@@ -51,17 +51,59 @@ vi.mock('os', () => ({
vi.mock('../utils/getPty.js', () => ({
getPty: mockGetPty,
}));
vi.mock('../utils/terminalSerializer.js', () => ({
serializeTerminalToObject: mockSerializeTerminalToObject,
}));
const mockProcessKill = vi
.spyOn(process, 'kill')
.mockImplementation(() => true);
const shellExecutionConfig = {
terminalWidth: 80,
terminalHeight: 24,
pager: 'cat',
showColor: false,
disableDynamicLineTrimming: true,
};
const createExpectedAnsiOutput = (text: string | string[]): AnsiOutput => {
const lines = Array.isArray(text) ? text : text.split('\n');
const expected: AnsiOutput = Array.from(
{ length: shellExecutionConfig.terminalHeight },
(_, i) => [
{
text: expect.stringMatching((lines[i] || '').trim()),
bold: false,
italic: false,
underline: false,
dim: false,
inverse: false,
fg: '',
bg: '',
},
],
);
return expected;
};
describe('ShellExecutionService', () => {
let mockPtyProcess: EventEmitter & {
pid: number;
kill: Mock;
onData: Mock;
onExit: Mock;
write: Mock;
resize: Mock;
};
let mockHeadlessTerminal: {
resize: Mock;
scrollLines: Mock;
buffer: {
active: {
viewportY: number;
};
};
};
let onOutputEventMock: Mock<(event: ShellOutputEvent) => void>;
@@ -82,11 +124,25 @@ describe('ShellExecutionService', () => {
kill: Mock;
onData: Mock;
onExit: Mock;
write: Mock;
resize: Mock;
};
mockPtyProcess.pid = 12345;
mockPtyProcess.kill = vi.fn();
mockPtyProcess.onData = vi.fn();
mockPtyProcess.onExit = vi.fn();
mockPtyProcess.write = vi.fn();
mockPtyProcess.resize = vi.fn();
mockHeadlessTerminal = {
resize: vi.fn(),
scrollLines: vi.fn(),
buffer: {
active: {
viewportY: 0,
},
},
};
mockPtySpawn.mockReturnValue(mockPtyProcess);
});
@@ -98,6 +154,7 @@ describe('ShellExecutionService', () => {
ptyProcess: typeof mockPtyProcess,
ac: AbortController,
) => void,
config = shellExecutionConfig,
) => {
const abortController = new AbortController();
const handle = await ShellExecutionService.execute(
@@ -106,9 +163,10 @@ describe('ShellExecutionService', () => {
onOutputEventMock,
abortController.signal,
true,
config,
);
await new Promise((resolve) => setImmediate(resolve));
await new Promise((resolve) => process.nextTick(resolve));
simulation(mockPtyProcess, abortController);
const result = await handle.result;
return { result, handle, abortController };
@@ -130,12 +188,12 @@ describe('ShellExecutionService', () => {
expect(result.signal).toBeNull();
expect(result.error).toBeNull();
expect(result.aborted).toBe(false);
expect(result.output).toBe('file1.txt');
expect(result.output.trim()).toBe('file1.txt');
expect(handle.pid).toBe(12345);
expect(onOutputEventMock).toHaveBeenCalledWith({
type: 'data',
chunk: 'file1.txt',
chunk: createExpectedAnsiOutput('file1.txt'),
});
});
@@ -145,11 +203,13 @@ describe('ShellExecutionService', () => {
pty.onExit.mock.calls[0][0]({ exitCode: 0, signal: null });
});
expect(result.output).toBe('aredword');
expect(onOutputEventMock).toHaveBeenCalledWith({
type: 'data',
chunk: 'aredword',
});
expect(result.output.trim()).toBe('aredword');
expect(onOutputEventMock).toHaveBeenCalledWith(
expect.objectContaining({
type: 'data',
chunk: createExpectedAnsiOutput('aredword'),
}),
);
});
it('should correctly decode multi-byte characters split across chunks', async () => {
@@ -159,16 +219,81 @@ describe('ShellExecutionService', () => {
pty.onData.mock.calls[0][0](multiByteChar.slice(1));
pty.onExit.mock.calls[0][0]({ exitCode: 0, signal: null });
});
expect(result.output).toBe('你好');
expect(result.output.trim()).toBe('你好');
});
it('should handle commands with no output', async () => {
const { result } = await simulateExecution('touch file', (pty) => {
await simulateExecution('touch file', (pty) => {
pty.onExit.mock.calls[0][0]({ exitCode: 0, signal: null });
});
expect(result.output).toBe('');
expect(onOutputEventMock).not.toHaveBeenCalled();
expect(onOutputEventMock).toHaveBeenCalledWith(
expect.objectContaining({
chunk: createExpectedAnsiOutput(''),
}),
);
});
it('should call onPid with the process id', async () => {
const abortController = new AbortController();
const handle = await ShellExecutionService.execute(
'ls -l',
'/test/dir',
onOutputEventMock,
abortController.signal,
true,
shellExecutionConfig,
);
mockPtyProcess.onExit.mock.calls[0][0]({ exitCode: 0, signal: null });
await handle.result;
expect(handle.pid).toBe(12345);
});
});
describe('pty interaction', () => {
beforeEach(() => {
vi.spyOn(ShellExecutionService['activePtys'], 'get').mockReturnValue({
// eslint-disable-next-line @typescript-eslint/no-explicit-any
ptyProcess: mockPtyProcess as any,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
headlessTerminal: mockHeadlessTerminal as any,
});
});
it('should write to the pty and trigger a render', async () => {
vi.useFakeTimers();
await simulateExecution('interactive-app', (pty) => {
ShellExecutionService.writeToPty(pty.pid!, 'input');
pty.onExit.mock.calls[0][0]({ exitCode: 0, signal: null });
});
expect(mockPtyProcess.write).toHaveBeenCalledWith('input');
// Use fake timers to check for the delayed render
await vi.advanceTimersByTimeAsync(17);
// The render will cause an output event
expect(onOutputEventMock).toHaveBeenCalled();
vi.useRealTimers();
});
it('should resize the pty and the headless terminal', async () => {
await simulateExecution('ls -l', (pty) => {
pty.onData.mock.calls[0][0]('file1.txt\n');
ShellExecutionService.resizePty(pty.pid!, 100, 40);
pty.onExit.mock.calls[0][0]({ exitCode: 0, signal: null });
});
expect(mockPtyProcess.resize).toHaveBeenCalledWith(100, 40);
expect(mockHeadlessTerminal.resize).toHaveBeenCalledWith(100, 40);
});
it('should scroll the headless terminal', async () => {
await simulateExecution('ls -l', (pty) => {
pty.onData.mock.calls[0][0]('file1.txt\n');
ShellExecutionService.scrollPty(pty.pid!, 10);
pty.onExit.mock.calls[0][0]({ exitCode: 0, signal: null });
});
expect(mockHeadlessTerminal.scrollLines).toHaveBeenCalledWith(10);
});
});
@@ -180,7 +305,7 @@ describe('ShellExecutionService', () => {
});
expect(result.exitCode).toBe(127);
expect(result.output).toBe('command not found');
expect(result.output.trim()).toBe('command not found');
expect(result.error).toBeNull();
});
@@ -206,6 +331,7 @@ describe('ShellExecutionService', () => {
onOutputEventMock,
new AbortController().signal,
true,
{},
);
const result = await handle.result;
@@ -228,7 +354,7 @@ describe('ShellExecutionService', () => {
);
expect(result.aborted).toBe(true);
expect(mockPtyProcess.kill).toHaveBeenCalled();
// The process kill is mocked, so we just check that the flag is set.
});
});
@@ -265,7 +391,6 @@ describe('ShellExecutionService', () => {
mockIsBinary.mockImplementation((buffer) => buffer.includes(0x00));
await simulateExecution('cat mixed_file', (pty) => {
pty.onData.mock.calls[0][0](Buffer.from('some text'));
pty.onData.mock.calls[0][0](Buffer.from([0x00, 0x01, 0x02]));
pty.onData.mock.calls[0][0](Buffer.from('more text'));
pty.onExit.mock.calls[0][0]({ exitCode: 0, signal: null });
@@ -275,7 +400,6 @@ describe('ShellExecutionService', () => {
(call: [ShellOutputEvent]) => call[0].type,
);
expect(eventTypes).toEqual([
'data',
'binary_detected',
'binary_progress',
'binary_progress',
@@ -310,6 +434,92 @@ describe('ShellExecutionService', () => {
);
});
});
describe('AnsiOutput rendering', () => {
it('should call onOutputEvent with AnsiOutput when showColor is true', async () => {
const coloredShellExecutionConfig = {
...shellExecutionConfig,
showColor: true,
defaultFg: '#ffffff',
defaultBg: '#000000',
disableDynamicLineTrimming: true,
};
const mockAnsiOutput = [
[{ text: 'hello', fg: '#ffffff', bg: '#000000' }],
];
mockSerializeTerminalToObject.mockReturnValue(mockAnsiOutput);
await simulateExecution(
'ls --color=auto',
(pty) => {
pty.onData.mock.calls[0][0]('a\u001b[31mred\u001b[0mword');
pty.onExit.mock.calls[0][0]({ exitCode: 0, signal: null });
},
coloredShellExecutionConfig,
);
expect(mockSerializeTerminalToObject).toHaveBeenCalledWith(
expect.anything(), // The terminal object
);
expect(onOutputEventMock).toHaveBeenCalledWith(
expect.objectContaining({
type: 'data',
chunk: mockAnsiOutput,
}),
);
});
it('should call onOutputEvent with AnsiOutput when showColor is false', async () => {
await simulateExecution(
'ls --color=auto',
(pty) => {
pty.onData.mock.calls[0][0]('a\u001b[31mred\u001b[0mword');
pty.onExit.mock.calls[0][0]({ exitCode: 0, signal: null });
},
{
...shellExecutionConfig,
showColor: false,
disableDynamicLineTrimming: true,
},
);
const expected = createExpectedAnsiOutput('aredword');
expect(onOutputEventMock).toHaveBeenCalledWith(
expect.objectContaining({
type: 'data',
chunk: expected,
}),
);
});
it('should handle multi-line output correctly when showColor is false', async () => {
await simulateExecution(
'ls --color=auto',
(pty) => {
pty.onData.mock.calls[0][0](
'line 1\n\u001b[32mline 2\u001b[0m\nline 3',
);
pty.onExit.mock.calls[0][0]({ exitCode: 0, signal: null });
},
{
...shellExecutionConfig,
showColor: false,
disableDynamicLineTrimming: true,
},
);
const expected = createExpectedAnsiOutput(['line 1', 'line 2', 'line 3']);
expect(onOutputEventMock).toHaveBeenCalledWith(
expect.objectContaining({
type: 'data',
chunk: expected,
}),
);
});
});
});
describe('ShellExecutionService child_process fallback', () => {
@@ -351,9 +561,10 @@ describe('ShellExecutionService child_process fallback', () => {
onOutputEventMock,
abortController.signal,
true,
shellExecutionConfig,
);
await new Promise((resolve) => setImmediate(resolve));
await new Promise((resolve) => process.nextTick(resolve));
simulation(mockChildProcess, abortController);
const result = await handle.result;
return { result, handle, abortController };
@@ -365,6 +576,7 @@ describe('ShellExecutionService child_process fallback', () => {
cp.stdout?.emit('data', Buffer.from('file1.txt\n'));
cp.stderr?.emit('data', Buffer.from('a warning'));
cp.emit('exit', 0, null);
cp.emit('close', 0, null);
});
expect(mockCpSpawn).toHaveBeenCalledWith(
@@ -377,15 +589,11 @@ describe('ShellExecutionService child_process fallback', () => {
expect(result.error).toBeNull();
expect(result.aborted).toBe(false);
expect(result.output).toBe('file1.txt\na warning');
expect(handle.pid).toBe(12345);
expect(handle.pid).toBe(undefined);
expect(onOutputEventMock).toHaveBeenCalledWith({
type: 'data',
chunk: 'file1.txt\n',
});
expect(onOutputEventMock).toHaveBeenCalledWith({
type: 'data',
chunk: 'a warning',
chunk: 'file1.txt\na warning',
});
});
@@ -393,13 +601,16 @@ describe('ShellExecutionService child_process fallback', () => {
const { result } = await simulateExecution('ls --color=auto', (cp) => {
cp.stdout?.emit('data', Buffer.from('a\u001b[31mred\u001b[0mword'));
cp.emit('exit', 0, null);
cp.emit('close', 0, null);
});
expect(result.output).toBe('aredword');
expect(onOutputEventMock).toHaveBeenCalledWith({
type: 'data',
chunk: 'aredword',
});
expect(result.output.trim()).toBe('aredword');
expect(onOutputEventMock).toHaveBeenCalledWith(
expect.objectContaining({
type: 'data',
chunk: 'aredword',
}),
);
});
it('should correctly decode multi-byte characters split across chunks', async () => {
@@ -408,16 +619,18 @@ describe('ShellExecutionService child_process fallback', () => {
cp.stdout?.emit('data', multiByteChar.slice(0, 2));
cp.stdout?.emit('data', multiByteChar.slice(2));
cp.emit('exit', 0, null);
cp.emit('close', 0, null);
});
expect(result.output).toBe('你好');
expect(result.output.trim()).toBe('你好');
});
it('should handle commands with no output', async () => {
const { result } = await simulateExecution('touch file', (cp) => {
cp.emit('exit', 0, null);
cp.emit('close', 0, null);
});
expect(result.output).toBe('');
expect(result.output.trim()).toBe('');
expect(onOutputEventMock).not.toHaveBeenCalled();
});
});
@@ -427,16 +640,18 @@ describe('ShellExecutionService child_process fallback', () => {
const { result } = await simulateExecution('a-bad-command', (cp) => {
cp.stderr?.emit('data', Buffer.from('command not found'));
cp.emit('exit', 127, null);
cp.emit('close', 127, null);
});
expect(result.exitCode).toBe(127);
expect(result.output).toBe('command not found');
expect(result.output.trim()).toBe('command not found');
expect(result.error).toBeNull();
});
it('should capture a termination signal', async () => {
const { result } = await simulateExecution('long-process', (cp) => {
cp.emit('exit', null, 'SIGTERM');
cp.emit('close', null, 'SIGTERM');
});
expect(result.exitCode).toBeNull();
@@ -448,6 +663,7 @@ describe('ShellExecutionService child_process fallback', () => {
const { result } = await simulateExecution('protected-cmd', (cp) => {
cp.emit('error', spawnError);
cp.emit('exit', 1, null);
cp.emit('close', 1, null);
});
expect(result.error).toBe(spawnError);
@@ -458,6 +674,7 @@ describe('ShellExecutionService child_process fallback', () => {
const error = new Error('spawn abc ENOENT');
const { result } = await simulateExecution('touch cat.jpg', (cp) => {
cp.emit('error', error); // No exit event is fired.
cp.emit('close', 1, null);
});
expect(result.error).toBe(error);
@@ -487,10 +704,14 @@ describe('ShellExecutionService child_process fallback', () => {
'sleep 10',
(cp, abortController) => {
abortController.abort();
if (expectedExit.signal)
if (expectedExit.signal) {
cp.emit('exit', null, expectedExit.signal);
if (typeof expectedExit.code === 'number')
cp.emit('close', null, expectedExit.signal);
}
if (typeof expectedExit.code === 'number') {
cp.emit('exit', expectedExit.code, null);
cp.emit('close', expectedExit.code, null);
}
},
);
@@ -526,6 +747,7 @@ describe('ShellExecutionService child_process fallback', () => {
onOutputEventMock,
abortController.signal,
true,
{},
);
abortController.abort();
@@ -547,14 +769,13 @@ describe('ShellExecutionService child_process fallback', () => {
// Finally, simulate the process exiting and await the result
mockChildProcess.emit('exit', null, 'SIGKILL');
mockChildProcess.emit('close', null, 'SIGKILL');
const result = await handle.result;
vi.useRealTimers();
expect(result.aborted).toBe(true);
expect(result.signal).toBe(9);
// The individual kill calls were already asserted above.
expect(mockProcessKill).toHaveBeenCalledTimes(2);
});
});
@@ -573,18 +794,10 @@ describe('ShellExecutionService child_process fallback', () => {
expect(result.rawOutput).toEqual(
Buffer.concat([binaryChunk1, binaryChunk2]),
);
expect(onOutputEventMock).toHaveBeenCalledTimes(3);
expect(onOutputEventMock).toHaveBeenCalledTimes(1);
expect(onOutputEventMock.mock.calls[0][0]).toEqual({
type: 'binary_detected',
});
expect(onOutputEventMock.mock.calls[1][0]).toEqual({
type: 'binary_progress',
bytesReceived: 4,
});
expect(onOutputEventMock.mock.calls[2][0]).toEqual({
type: 'binary_progress',
bytesReceived: 8,
});
});
it('should not emit data events after binary is detected', async () => {
@@ -600,12 +813,7 @@ describe('ShellExecutionService child_process fallback', () => {
const eventTypes = onOutputEventMock.mock.calls.map(
(call: [ShellOutputEvent]) => call[0].type,
);
expect(eventTypes).toEqual([
'data',
'binary_detected',
'binary_progress',
'binary_progress',
]);
expect(eventTypes).toEqual(['binary_detected']);
});
});
@@ -649,6 +857,8 @@ describe('ShellExecutionService execution method selection', () => {
kill: Mock;
onData: Mock;
onExit: Mock;
write: Mock;
resize: Mock;
};
let mockChildProcess: EventEmitter & Partial<ChildProcess>;
@@ -662,11 +872,16 @@ describe('ShellExecutionService execution method selection', () => {
kill: Mock;
onData: Mock;
onExit: Mock;
write: Mock;
resize: Mock;
};
mockPtyProcess.pid = 12345;
mockPtyProcess.kill = vi.fn();
mockPtyProcess.onData = vi.fn();
mockPtyProcess.onExit = vi.fn();
mockPtyProcess.write = vi.fn();
mockPtyProcess.resize = vi.fn();
mockPtySpawn.mockReturnValue(mockPtyProcess);
mockGetPty.mockResolvedValue({
module: { spawn: mockPtySpawn },
@@ -694,6 +909,7 @@ describe('ShellExecutionService execution method selection', () => {
onOutputEventMock,
abortController.signal,
true, // shouldUseNodePty
shellExecutionConfig,
);
// Simulate exit to allow promise to resolve
@@ -714,6 +930,7 @@ describe('ShellExecutionService execution method selection', () => {
onOutputEventMock,
abortController.signal,
false, // shouldUseNodePty
{},
);
// Simulate exit to allow promise to resolve
@@ -736,6 +953,7 @@ describe('ShellExecutionService execution method selection', () => {
onOutputEventMock,
abortController.signal,
true, // shouldUseNodePty
shellExecutionConfig,
);
// Simulate exit to allow promise to resolve

View File

@@ -4,30 +4,24 @@
* SPDX-License-Identifier: Apache-2.0
*/
import pkg from '@xterm/headless';
import { spawn as cpSpawn } from 'child_process';
import os from 'os';
import stripAnsi from 'strip-ansi';
import { TextDecoder } from 'util';
import type { PtyImplementation } from '../utils/getPty.js';
import { getPty } from '../utils/getPty.js';
import { spawn as cpSpawn } from 'node:child_process';
import { TextDecoder } from 'node:util';
import os from 'node:os';
import type { IPty } from '@lydell/node-pty';
import { getCachedEncodingForBuffer } from '../utils/systemEncoding.js';
import { isBinary } from '../utils/textUtils.js';
import pkg from '@xterm/headless';
import {
serializeTerminalToObject,
type AnsiOutput,
} from '../utils/terminalSerializer.js';
const { Terminal } = pkg;
const SIGKILL_TIMEOUT_MS = 200;
// @ts-expect-error getFullText is not a public API.
const getFullText = (terminal: Terminal) => {
const buffer = terminal.buffer.active;
const lines: string[] = [];
for (let i = 0; i < buffer.length; i++) {
const line = buffer.getLine(i);
lines.push(line ? line.translateToString(true) : '');
}
return lines.join('\n').trim();
};
/** A structured result from a shell command execution. */
export interface ShellExecutionResult {
/** The raw, unprocessed output buffer. */
@@ -56,6 +50,17 @@ export interface ShellExecutionHandle {
result: Promise<ShellExecutionResult>;
}
export interface ShellExecutionConfig {
terminalWidth?: number;
terminalHeight?: number;
pager?: string;
showColor?: boolean;
defaultFg?: string;
defaultBg?: string;
// Used for testing
disableDynamicLineTrimming?: boolean;
}
/**
* Describes a structured event emitted during shell command execution.
*/
@@ -64,7 +69,7 @@ export type ShellOutputEvent =
/** The event contains a chunk of output data. */
type: 'data';
/** The decoded string chunk. */
chunk: string;
chunk: string | AnsiOutput;
}
| {
/** Signals that the output stream has been identified as binary. */
@@ -77,12 +82,30 @@ export type ShellOutputEvent =
bytesReceived: number;
};
interface ActivePty {
ptyProcess: IPty;
headlessTerminal: pkg.Terminal;
}
const getFullBufferText = (terminal: pkg.Terminal): string => {
const buffer = terminal.buffer.active;
const lines: string[] = [];
for (let i = 0; i < buffer.length; i++) {
const line = buffer.getLine(i);
const lineContent = line ? line.translateToString() : '';
lines.push(lineContent);
}
return lines.join('\n').trimEnd();
};
/**
* A centralized service for executing shell commands with robust process
* management, cross-platform compatibility, and streaming output capabilities.
*
*/
export class ShellExecutionService {
private static activePtys = new Map<number, ActivePty>();
/**
* Executes a shell command using `node-pty`, capturing all output and lifecycle events.
*
@@ -99,8 +122,7 @@ export class ShellExecutionService {
onOutputEvent: (event: ShellOutputEvent) => void,
abortSignal: AbortSignal,
shouldUseNodePty: boolean,
terminalColumns?: number,
terminalRows?: number,
shellExecutionConfig: ShellExecutionConfig,
): Promise<ShellExecutionHandle> {
if (shouldUseNodePty) {
const ptyInfo = await getPty();
@@ -111,8 +133,7 @@ export class ShellExecutionService {
cwd,
onOutputEvent,
abortSignal,
terminalColumns,
terminalRows,
shellExecutionConfig,
ptyInfo,
);
} catch (_e) {
@@ -186,31 +207,18 @@ export class ShellExecutionService {
if (isBinary(sniffBuffer)) {
isStreamingRawContent = false;
onOutputEvent({ type: 'binary_detected' });
}
}
const decoder = stream === 'stdout' ? stdoutDecoder : stderrDecoder;
const decodedChunk = decoder.decode(data, { stream: true });
const strippedChunk = stripAnsi(decodedChunk);
if (stream === 'stdout') {
stdout += strippedChunk;
} else {
stderr += strippedChunk;
}
if (isStreamingRawContent) {
onOutputEvent({ type: 'data', chunk: strippedChunk });
} else {
const totalBytes = outputChunks.reduce(
(sum, chunk) => sum + chunk.length,
0,
);
onOutputEvent({
type: 'binary_progress',
bytesReceived: totalBytes,
});
const decoder = stream === 'stdout' ? stdoutDecoder : stderrDecoder;
const decodedChunk = decoder.decode(data, { stream: true });
if (stream === 'stdout') {
stdout += decodedChunk;
} else {
stderr += decodedChunk;
}
}
};
@@ -224,14 +232,24 @@ export class ShellExecutionService {
const combinedOutput =
stdout + (stderr ? (stdout ? separator : '') + stderr : '');
const finalStrippedOutput = stripAnsi(combinedOutput).trim();
if (isStreamingRawContent) {
if (finalStrippedOutput) {
onOutputEvent({ type: 'data', chunk: finalStrippedOutput });
}
} else {
onOutputEvent({ type: 'binary_detected' });
}
resolve({
rawOutput: finalBuffer,
output: combinedOutput.trim(),
output: finalStrippedOutput,
exitCode: code,
signal: signal ? os.constants.signals[signal] : null,
error,
aborted: abortSignal.aborted,
pid: child.pid,
pid: undefined,
executionMethod: 'child_process',
});
};
@@ -264,6 +282,9 @@ export class ShellExecutionService {
abortSignal.addEventListener('abort', abortHandler, { once: true });
child.on('exit', (code, signal) => {
if (child.pid) {
this.activePtys.delete(child.pid);
}
handleExit(code, signal);
});
@@ -273,13 +294,13 @@ export class ShellExecutionService {
if (stdoutDecoder) {
const remaining = stdoutDecoder.decode();
if (remaining) {
stdout += stripAnsi(remaining);
stdout += remaining;
}
}
if (stderrDecoder) {
const remaining = stderrDecoder.decode();
if (remaining) {
stderr += stripAnsi(remaining);
stderr += remaining;
}
}
@@ -289,7 +310,7 @@ export class ShellExecutionService {
}
});
return { pid: child.pid, result };
return { pid: undefined, result };
} catch (e) {
const error = e as Error;
return {
@@ -313,29 +334,32 @@ export class ShellExecutionService {
cwd: string,
onOutputEvent: (event: ShellOutputEvent) => void,
abortSignal: AbortSignal,
terminalColumns: number | undefined,
terminalRows: number | undefined,
ptyInfo: PtyImplementation | undefined,
shellExecutionConfig: ShellExecutionConfig,
ptyInfo: PtyImplementation,
): ShellExecutionHandle {
if (!ptyInfo) {
// This should not happen, but as a safeguard...
throw new Error('PTY implementation not found');
}
try {
const cols = terminalColumns ?? 80;
const rows = terminalRows ?? 30;
const cols = shellExecutionConfig.terminalWidth ?? 80;
const rows = shellExecutionConfig.terminalHeight ?? 30;
const isWindows = os.platform() === 'win32';
const shell = isWindows ? 'cmd.exe' : 'bash';
const args = isWindows
? `/c ${commandToExecute}`
: ['-c', commandToExecute];
const ptyProcess = ptyInfo?.module.spawn(shell, args, {
const ptyProcess = ptyInfo.module.spawn(shell, args, {
cwd,
name: 'xterm-color',
name: 'xterm',
cols,
rows,
env: {
...process.env,
QWEN_CODE: '1',
TERM: 'xterm-256color',
PAGER: 'cat',
PAGER: shellExecutionConfig.pager ?? 'cat',
},
handleFlowControl: true,
});
@@ -346,9 +370,13 @@ export class ShellExecutionService {
cols,
rows,
});
headlessTerminal.scrollToTop();
this.activePtys.set(ptyProcess.pid, { ptyProcess, headlessTerminal });
let processingChain = Promise.resolve();
let decoder: TextDecoder | null = null;
let output = '';
let output: string | AnsiOutput | null = null;
const outputChunks: Buffer[] = [];
const error: Error | null = null;
let exited = false;
@@ -356,6 +384,97 @@ export class ShellExecutionService {
let isStreamingRawContent = true;
const MAX_SNIFF_SIZE = 4096;
let sniffedBytes = 0;
let isWriting = false;
let hasStartedOutput = false;
let renderTimeout: NodeJS.Timeout | null = null;
const render = (finalRender = false) => {
if (renderTimeout) {
clearTimeout(renderTimeout);
}
const renderFn = () => {
if (!isStreamingRawContent) {
return;
}
if (!shellExecutionConfig.disableDynamicLineTrimming) {
if (!hasStartedOutput) {
const bufferText = getFullBufferText(headlessTerminal);
if (bufferText.trim().length === 0) {
return;
}
hasStartedOutput = true;
}
}
let newOutput: AnsiOutput;
if (shellExecutionConfig.showColor) {
newOutput = serializeTerminalToObject(headlessTerminal);
} else {
const buffer = headlessTerminal.buffer.active;
const lines: AnsiOutput = [];
for (let y = 0; y < headlessTerminal.rows; y++) {
const line = buffer.getLine(buffer.viewportY + y);
const lineContent = line ? line.translateToString(true) : '';
lines.push([
{
text: lineContent,
bold: false,
italic: false,
underline: false,
dim: false,
inverse: false,
fg: '',
bg: '',
},
]);
}
newOutput = lines;
}
let lastNonEmptyLine = -1;
for (let i = newOutput.length - 1; i >= 0; i--) {
const line = newOutput[i];
if (
line
.map((segment) => segment.text)
.join('')
.trim().length > 0
) {
lastNonEmptyLine = i;
break;
}
}
const trimmedOutput = newOutput.slice(0, lastNonEmptyLine + 1);
const finalOutput = shellExecutionConfig.disableDynamicLineTrimming
? newOutput
: trimmedOutput;
// Using stringify for a quick deep comparison.
if (JSON.stringify(output) !== JSON.stringify(finalOutput)) {
output = finalOutput;
onOutputEvent({
type: 'data',
chunk: finalOutput,
});
}
};
if (finalRender) {
renderFn();
} else {
renderTimeout = setTimeout(renderFn, 17);
}
};
headlessTerminal.onScroll(() => {
if (!isWriting) {
render();
}
});
const handleOutput = (data: Buffer) => {
processingChain = processingChain.then(
@@ -384,10 +503,10 @@ export class ShellExecutionService {
if (isStreamingRawContent) {
const decodedChunk = decoder.decode(data, { stream: true });
isWriting = true;
headlessTerminal.write(decodedChunk, () => {
const newStrippedOutput = getFullText(headlessTerminal);
output = newStrippedOutput;
onOutputEvent({ type: 'data', chunk: newStrippedOutput });
render();
isWriting = false;
resolve();
});
} else {
@@ -414,19 +533,23 @@ export class ShellExecutionService {
({ exitCode, signal }: { exitCode: number; signal?: number }) => {
exited = true;
abortSignal.removeEventListener('abort', abortHandler);
this.activePtys.delete(ptyProcess.pid);
processingChain.then(() => {
render(true);
const finalBuffer = Buffer.concat(outputChunks);
resolve({
rawOutput: finalBuffer,
output,
output: getFullBufferText(headlessTerminal),
exitCode,
signal: signal ?? null,
error,
aborted: abortSignal.aborted,
pid: ptyProcess.pid,
executionMethod: ptyInfo?.name ?? 'node-pty',
executionMethod:
(ptyInfo?.name as 'node-pty' | 'lydell-node-pty') ??
'node-pty',
});
});
},
@@ -434,7 +557,17 @@ export class ShellExecutionService {
const abortHandler = async () => {
if (ptyProcess.pid && !exited) {
ptyProcess.kill('SIGHUP');
if (os.platform() === 'win32') {
ptyProcess.kill();
} else {
try {
// Kill the entire process group
process.kill(-ptyProcess.pid, 'SIGINT');
} catch (_e) {
// Fallback to killing just the process if the group kill fails
ptyProcess.kill('SIGINT');
}
}
}
};
@@ -459,4 +592,90 @@ export class ShellExecutionService {
};
}
}
/**
* Writes a string to the pseudo-terminal (PTY) of a running process.
*
* @param pid The process ID of the target PTY.
* @param input The string to write to the terminal.
*/
static writeToPty(pid: number, input: string): void {
if (!this.isPtyActive(pid)) {
return;
}
const activePty = this.activePtys.get(pid);
if (activePty) {
activePty.ptyProcess.write(input);
}
}
static isPtyActive(pid: number): boolean {
try {
// process.kill with signal 0 is a way to check for the existence of a process.
// It doesn't actually send a signal.
return process.kill(pid, 0);
} catch (_) {
return false;
}
}
/**
* Resizes the pseudo-terminal (PTY) of a running process.
*
* @param pid The process ID of the target PTY.
* @param cols The new number of columns.
* @param rows The new number of rows.
*/
static resizePty(pid: number, cols: number, rows: number): void {
if (!this.isPtyActive(pid)) {
return;
}
const activePty = this.activePtys.get(pid);
if (activePty) {
try {
activePty.ptyProcess.resize(cols, rows);
activePty.headlessTerminal.resize(cols, rows);
} catch (e) {
// Ignore errors if the pty has already exited, which can happen
// due to a race condition between the exit event and this call.
if (e instanceof Error && 'code' in e && e.code === 'ESRCH') {
// ignore
} else {
throw e;
}
}
}
}
/**
* Scrolls the pseudo-terminal (PTY) of a running process.
*
* @param pid The process ID of the target PTY.
* @param lines The number of lines to scroll.
*/
static scrollPty(pid: number, lines: number): void {
if (!this.isPtyActive(pid)) {
return;
}
const activePty = this.activePtys.get(pid);
if (activePty) {
try {
activePty.headlessTerminal.scrollLines(lines);
if (activePty.headlessTerminal.buffer.active.viewportY < 0) {
activePty.headlessTerminal.scrollToTop();
}
} catch (e) {
// Ignore errors if the pty has already exited, which can happen
// due to a race condition between the exit event and this call.
if (e instanceof Error && 'code' in e && e.code === 'ESRCH') {
// ignore
} else {
throw e;
}
}
}
}
}