mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-19 09:33:53 +00:00
fix: compression tool (#935)
This commit is contained in:
@@ -16,11 +16,11 @@ import {
|
|||||||
|
|
||||||
import type { Content, GenerateContentResponse, Part } from '@google/genai';
|
import type { Content, GenerateContentResponse, Part } from '@google/genai';
|
||||||
import {
|
import {
|
||||||
findCompressSplitPoint,
|
|
||||||
isThinkingDefault,
|
isThinkingDefault,
|
||||||
isThinkingSupported,
|
isThinkingSupported,
|
||||||
GeminiClient,
|
GeminiClient,
|
||||||
} from './client.js';
|
} from './client.js';
|
||||||
|
import { findCompressSplitPoint } from '../services/chatCompressionService.js';
|
||||||
import {
|
import {
|
||||||
AuthType,
|
AuthType,
|
||||||
type ContentGenerator,
|
type ContentGenerator,
|
||||||
@@ -42,7 +42,6 @@ import { setSimulate429 } from '../utils/testUtils.js';
|
|||||||
import { tokenLimit } from './tokenLimits.js';
|
import { tokenLimit } from './tokenLimits.js';
|
||||||
import { ideContextStore } from '../ide/ideContext.js';
|
import { ideContextStore } from '../ide/ideContext.js';
|
||||||
import { uiTelemetryService } from '../telemetry/uiTelemetry.js';
|
import { uiTelemetryService } from '../telemetry/uiTelemetry.js';
|
||||||
import { QwenLogger } from '../telemetry/index.js';
|
|
||||||
|
|
||||||
// Mock fs module to prevent actual file system operations during tests
|
// Mock fs module to prevent actual file system operations during tests
|
||||||
const mockFileSystem = new Map<string, string>();
|
const mockFileSystem = new Map<string, string>();
|
||||||
@@ -101,6 +100,22 @@ vi.mock('../utils/errorReporting', () => ({ reportError: vi.fn() }));
|
|||||||
vi.mock('../utils/nextSpeakerChecker', () => ({
|
vi.mock('../utils/nextSpeakerChecker', () => ({
|
||||||
checkNextSpeaker: vi.fn().mockResolvedValue(null),
|
checkNextSpeaker: vi.fn().mockResolvedValue(null),
|
||||||
}));
|
}));
|
||||||
|
vi.mock('../utils/environmentContext', () => ({
|
||||||
|
getEnvironmentContext: vi
|
||||||
|
.fn()
|
||||||
|
.mockResolvedValue([{ text: 'Mocked env context' }]),
|
||||||
|
getInitialChatHistory: vi.fn(async (_config, extraHistory) => [
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
parts: [{ text: 'Mocked env context' }],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ text: 'Got it. Thanks for the context!' }],
|
||||||
|
},
|
||||||
|
...(extraHistory ?? []),
|
||||||
|
]),
|
||||||
|
}));
|
||||||
vi.mock('../utils/generateContentResponseUtilities', () => ({
|
vi.mock('../utils/generateContentResponseUtilities', () => ({
|
||||||
getResponseText: (result: GenerateContentResponse) =>
|
getResponseText: (result: GenerateContentResponse) =>
|
||||||
result.candidates?.[0]?.content?.parts?.map((part) => part.text).join('') ||
|
result.candidates?.[0]?.content?.parts?.map((part) => part.text).join('') ||
|
||||||
@@ -136,6 +151,10 @@ vi.mock('../ide/ideContext.js');
|
|||||||
vi.mock('../telemetry/uiTelemetry.js', () => ({
|
vi.mock('../telemetry/uiTelemetry.js', () => ({
|
||||||
uiTelemetryService: mockUiTelemetryService,
|
uiTelemetryService: mockUiTelemetryService,
|
||||||
}));
|
}));
|
||||||
|
vi.mock('../telemetry/loggers.js', () => ({
|
||||||
|
logChatCompression: vi.fn(),
|
||||||
|
logNextSpeakerCheck: vi.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Array.fromAsync ponyfill, which will be available in es 2024.
|
* Array.fromAsync ponyfill, which will be available in es 2024.
|
||||||
@@ -619,7 +638,8 @@ describe('Gemini Client (client.ts)', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('logs a telemetry event when compressing', async () => {
|
it('logs a telemetry event when compressing', async () => {
|
||||||
vi.spyOn(QwenLogger.prototype, 'logChatCompressionEvent');
|
const { logChatCompression } = await import('../telemetry/loggers.js');
|
||||||
|
vi.mocked(logChatCompression).mockClear();
|
||||||
|
|
||||||
const MOCKED_TOKEN_LIMIT = 1000;
|
const MOCKED_TOKEN_LIMIT = 1000;
|
||||||
const MOCKED_CONTEXT_PERCENTAGE_THRESHOLD = 0.5;
|
const MOCKED_CONTEXT_PERCENTAGE_THRESHOLD = 0.5;
|
||||||
@@ -627,19 +647,37 @@ describe('Gemini Client (client.ts)', () => {
|
|||||||
vi.spyOn(client['config'], 'getChatCompression').mockReturnValue({
|
vi.spyOn(client['config'], 'getChatCompression').mockReturnValue({
|
||||||
contextPercentageThreshold: MOCKED_CONTEXT_PERCENTAGE_THRESHOLD,
|
contextPercentageThreshold: MOCKED_CONTEXT_PERCENTAGE_THRESHOLD,
|
||||||
});
|
});
|
||||||
const history = [{ role: 'user', parts: [{ text: '...history...' }] }];
|
// Need multiple history items so there's something to compress
|
||||||
|
const history = [
|
||||||
|
{ role: 'user', parts: [{ text: '...history 1...' }] },
|
||||||
|
{ role: 'model', parts: [{ text: '...history 2...' }] },
|
||||||
|
{ role: 'user', parts: [{ text: '...history 3...' }] },
|
||||||
|
{ role: 'model', parts: [{ text: '...history 4...' }] },
|
||||||
|
];
|
||||||
mockGetHistory.mockReturnValue(history);
|
mockGetHistory.mockReturnValue(history);
|
||||||
|
|
||||||
|
// Token count needs to be ABOVE the threshold to trigger compression
|
||||||
const originalTokenCount =
|
const originalTokenCount =
|
||||||
MOCKED_TOKEN_LIMIT * MOCKED_CONTEXT_PERCENTAGE_THRESHOLD;
|
MOCKED_TOKEN_LIMIT * MOCKED_CONTEXT_PERCENTAGE_THRESHOLD + 1;
|
||||||
|
|
||||||
vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(
|
vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(
|
||||||
originalTokenCount,
|
originalTokenCount,
|
||||||
);
|
);
|
||||||
|
|
||||||
// We need to control the estimated new token count.
|
// Mock the summary response from the chat
|
||||||
// We mock startChat to return a chat with a known history.
|
|
||||||
const summaryText = 'This is a summary.';
|
const summaryText = 'This is a summary.';
|
||||||
|
mockGenerateContentFn.mockResolvedValue({
|
||||||
|
candidates: [
|
||||||
|
{
|
||||||
|
content: {
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ text: summaryText }],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
} as unknown as GenerateContentResponse);
|
||||||
|
|
||||||
|
// Mock startChat to complete the compression flow
|
||||||
const splitPoint = findCompressSplitPoint(history, 0.7);
|
const splitPoint = findCompressSplitPoint(history, 0.7);
|
||||||
const historyToKeep = history.slice(splitPoint);
|
const historyToKeep = history.slice(splitPoint);
|
||||||
const newCompressedHistory: Content[] = [
|
const newCompressedHistory: Content[] = [
|
||||||
@@ -659,52 +697,36 @@ describe('Gemini Client (client.ts)', () => {
|
|||||||
.fn()
|
.fn()
|
||||||
.mockResolvedValue(mockNewChat as GeminiChat);
|
.mockResolvedValue(mockNewChat as GeminiChat);
|
||||||
|
|
||||||
const totalChars = newCompressedHistory.reduce(
|
|
||||||
(total, content) => total + JSON.stringify(content).length,
|
|
||||||
0,
|
|
||||||
);
|
|
||||||
const newTokenCount = Math.floor(totalChars / 4);
|
|
||||||
|
|
||||||
// Mock the summary response from the chat
|
|
||||||
mockGenerateContentFn.mockResolvedValue({
|
|
||||||
candidates: [
|
|
||||||
{
|
|
||||||
content: {
|
|
||||||
role: 'model',
|
|
||||||
parts: [{ text: summaryText }],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
],
|
|
||||||
} as unknown as GenerateContentResponse);
|
|
||||||
|
|
||||||
await client.tryCompressChat('prompt-id-3', false);
|
await client.tryCompressChat('prompt-id-3', false);
|
||||||
|
|
||||||
expect(QwenLogger.prototype.logChatCompressionEvent).toHaveBeenCalledWith(
|
expect(logChatCompression).toHaveBeenCalledWith(
|
||||||
|
expect.anything(),
|
||||||
expect.objectContaining({
|
expect.objectContaining({
|
||||||
tokens_before: originalTokenCount,
|
tokens_before: originalTokenCount,
|
||||||
tokens_after: newTokenCount,
|
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
expect(uiTelemetryService.setLastPromptTokenCount).toHaveBeenCalledWith(
|
expect(uiTelemetryService.setLastPromptTokenCount).toHaveBeenCalled();
|
||||||
newTokenCount,
|
|
||||||
);
|
|
||||||
expect(uiTelemetryService.setLastPromptTokenCount).toHaveBeenCalledTimes(
|
|
||||||
1,
|
|
||||||
);
|
|
||||||
});
|
});
|
||||||
|
|
||||||
it('should trigger summarization if token count is at threshold with contextPercentageThreshold setting', async () => {
|
it('should trigger summarization if token count is above threshold with contextPercentageThreshold setting', async () => {
|
||||||
const MOCKED_TOKEN_LIMIT = 1000;
|
const MOCKED_TOKEN_LIMIT = 1000;
|
||||||
const MOCKED_CONTEXT_PERCENTAGE_THRESHOLD = 0.5;
|
const MOCKED_CONTEXT_PERCENTAGE_THRESHOLD = 0.5;
|
||||||
vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT);
|
vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT);
|
||||||
vi.spyOn(client['config'], 'getChatCompression').mockReturnValue({
|
vi.spyOn(client['config'], 'getChatCompression').mockReturnValue({
|
||||||
contextPercentageThreshold: MOCKED_CONTEXT_PERCENTAGE_THRESHOLD,
|
contextPercentageThreshold: MOCKED_CONTEXT_PERCENTAGE_THRESHOLD,
|
||||||
});
|
});
|
||||||
const history = [{ role: 'user', parts: [{ text: '...history...' }] }];
|
// Need multiple history items so there's something to compress
|
||||||
|
const history = [
|
||||||
|
{ role: 'user', parts: [{ text: '...history 1...' }] },
|
||||||
|
{ role: 'model', parts: [{ text: '...history 2...' }] },
|
||||||
|
{ role: 'user', parts: [{ text: '...history 3...' }] },
|
||||||
|
{ role: 'model', parts: [{ text: '...history 4...' }] },
|
||||||
|
];
|
||||||
mockGetHistory.mockReturnValue(history);
|
mockGetHistory.mockReturnValue(history);
|
||||||
|
|
||||||
|
// Token count needs to be ABOVE the threshold to trigger compression
|
||||||
const originalTokenCount =
|
const originalTokenCount =
|
||||||
MOCKED_TOKEN_LIMIT * MOCKED_CONTEXT_PERCENTAGE_THRESHOLD;
|
MOCKED_TOKEN_LIMIT * MOCKED_CONTEXT_PERCENTAGE_THRESHOLD + 1;
|
||||||
|
|
||||||
vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(
|
vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(
|
||||||
originalTokenCount,
|
originalTokenCount,
|
||||||
@@ -864,7 +886,13 @@ describe('Gemini Client (client.ts)', () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
it('should always trigger summarization when force is true, regardless of token count', async () => {
|
it('should always trigger summarization when force is true, regardless of token count', async () => {
|
||||||
const history = [{ role: 'user', parts: [{ text: '...history...' }] }];
|
// Need multiple history items so there's something to compress
|
||||||
|
const history = [
|
||||||
|
{ role: 'user', parts: [{ text: '...history 1...' }] },
|
||||||
|
{ role: 'model', parts: [{ text: '...history 2...' }] },
|
||||||
|
{ role: 'user', parts: [{ text: '...history 3...' }] },
|
||||||
|
{ role: 'model', parts: [{ text: '...history 4...' }] },
|
||||||
|
];
|
||||||
mockGetHistory.mockReturnValue(history);
|
mockGetHistory.mockReturnValue(history);
|
||||||
|
|
||||||
const originalTokenCount = 100; // Well below threshold, but > estimated new count
|
const originalTokenCount = 100; // Well below threshold, but > estimated new count
|
||||||
|
|||||||
@@ -25,13 +25,11 @@ import {
|
|||||||
import type { ContentGenerator } from './contentGenerator.js';
|
import type { ContentGenerator } from './contentGenerator.js';
|
||||||
import { GeminiChat } from './geminiChat.js';
|
import { GeminiChat } from './geminiChat.js';
|
||||||
import {
|
import {
|
||||||
getCompressionPrompt,
|
|
||||||
getCoreSystemPrompt,
|
getCoreSystemPrompt,
|
||||||
getCustomSystemPrompt,
|
getCustomSystemPrompt,
|
||||||
getPlanModeSystemReminder,
|
getPlanModeSystemReminder,
|
||||||
getSubagentSystemReminder,
|
getSubagentSystemReminder,
|
||||||
} from './prompts.js';
|
} from './prompts.js';
|
||||||
import { tokenLimit } from './tokenLimits.js';
|
|
||||||
import {
|
import {
|
||||||
CompressionStatus,
|
CompressionStatus,
|
||||||
GeminiEventType,
|
GeminiEventType,
|
||||||
@@ -42,6 +40,11 @@ import {
|
|||||||
|
|
||||||
// Services
|
// Services
|
||||||
import { type ChatRecordingService } from '../services/chatRecordingService.js';
|
import { type ChatRecordingService } from '../services/chatRecordingService.js';
|
||||||
|
import {
|
||||||
|
ChatCompressionService,
|
||||||
|
COMPRESSION_PRESERVE_THRESHOLD,
|
||||||
|
COMPRESSION_TOKEN_THRESHOLD,
|
||||||
|
} from '../services/chatCompressionService.js';
|
||||||
import { LoopDetectionService } from '../services/loopDetectionService.js';
|
import { LoopDetectionService } from '../services/loopDetectionService.js';
|
||||||
|
|
||||||
// Tools
|
// Tools
|
||||||
@@ -50,21 +53,18 @@ import { TaskTool } from '../tools/task.js';
|
|||||||
// Telemetry
|
// Telemetry
|
||||||
import {
|
import {
|
||||||
NextSpeakerCheckEvent,
|
NextSpeakerCheckEvent,
|
||||||
logChatCompression,
|
|
||||||
logNextSpeakerCheck,
|
logNextSpeakerCheck,
|
||||||
makeChatCompressionEvent,
|
|
||||||
uiTelemetryService,
|
|
||||||
} from '../telemetry/index.js';
|
} from '../telemetry/index.js';
|
||||||
|
|
||||||
// Utilities
|
// Utilities
|
||||||
import {
|
import {
|
||||||
getDirectoryContextString,
|
getDirectoryContextString,
|
||||||
getEnvironmentContext,
|
getInitialChatHistory,
|
||||||
} from '../utils/environmentContext.js';
|
} from '../utils/environmentContext.js';
|
||||||
import { reportError } from '../utils/errorReporting.js';
|
import { reportError } from '../utils/errorReporting.js';
|
||||||
import { getErrorMessage } from '../utils/errors.js';
|
import { getErrorMessage } from '../utils/errors.js';
|
||||||
import { checkNextSpeaker } from '../utils/nextSpeakerChecker.js';
|
import { checkNextSpeaker } from '../utils/nextSpeakerChecker.js';
|
||||||
import { flatMapTextParts, getResponseText } from '../utils/partUtils.js';
|
import { flatMapTextParts } from '../utils/partUtils.js';
|
||||||
import { retryWithBackoff } from '../utils/retry.js';
|
import { retryWithBackoff } from '../utils/retry.js';
|
||||||
|
|
||||||
// IDE integration
|
// IDE integration
|
||||||
@@ -85,68 +85,8 @@ export function isThinkingDefault(model: string) {
|
|||||||
return model.startsWith('gemini-2.5') || model === DEFAULT_GEMINI_MODEL_AUTO;
|
return model.startsWith('gemini-2.5') || model === DEFAULT_GEMINI_MODEL_AUTO;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns the index of the oldest item to keep when compressing. May return
|
|
||||||
* contents.length which indicates that everything should be compressed.
|
|
||||||
*
|
|
||||||
* Exported for testing purposes.
|
|
||||||
*/
|
|
||||||
export function findCompressSplitPoint(
|
|
||||||
contents: Content[],
|
|
||||||
fraction: number,
|
|
||||||
): number {
|
|
||||||
if (fraction <= 0 || fraction >= 1) {
|
|
||||||
throw new Error('Fraction must be between 0 and 1');
|
|
||||||
}
|
|
||||||
|
|
||||||
const charCounts = contents.map((content) => JSON.stringify(content).length);
|
|
||||||
const totalCharCount = charCounts.reduce((a, b) => a + b, 0);
|
|
||||||
const targetCharCount = totalCharCount * fraction;
|
|
||||||
|
|
||||||
let lastSplitPoint = 0; // 0 is always valid (compress nothing)
|
|
||||||
let cumulativeCharCount = 0;
|
|
||||||
for (let i = 0; i < contents.length; i++) {
|
|
||||||
const content = contents[i];
|
|
||||||
if (
|
|
||||||
content.role === 'user' &&
|
|
||||||
!content.parts?.some((part) => !!part.functionResponse)
|
|
||||||
) {
|
|
||||||
if (cumulativeCharCount >= targetCharCount) {
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
lastSplitPoint = i;
|
|
||||||
}
|
|
||||||
cumulativeCharCount += charCounts[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
// We found no split points after targetCharCount.
|
|
||||||
// Check if it's safe to compress everything.
|
|
||||||
const lastContent = contents[contents.length - 1];
|
|
||||||
if (
|
|
||||||
lastContent?.role === 'model' &&
|
|
||||||
!lastContent?.parts?.some((part) => part.functionCall)
|
|
||||||
) {
|
|
||||||
return contents.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Can't compress everything so just compress at last splitpoint.
|
|
||||||
return lastSplitPoint;
|
|
||||||
}
|
|
||||||
|
|
||||||
const MAX_TURNS = 100;
|
const MAX_TURNS = 100;
|
||||||
|
|
||||||
/**
|
|
||||||
* Threshold for compression token count as a fraction of the model's token limit.
|
|
||||||
* If the chat history exceeds this threshold, it will be compressed.
|
|
||||||
*/
|
|
||||||
const COMPRESSION_TOKEN_THRESHOLD = 0.7;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The fraction of the latest chat history to keep. A value of 0.3
|
|
||||||
* means that only the last 30% of the chat history will be kept after compression.
|
|
||||||
*/
|
|
||||||
const COMPRESSION_PRESERVE_THRESHOLD = 0.3;
|
|
||||||
|
|
||||||
export class GeminiClient {
|
export class GeminiClient {
|
||||||
private chat?: GeminiChat;
|
private chat?: GeminiChat;
|
||||||
private readonly generateContentConfig: GenerateContentConfig = {
|
private readonly generateContentConfig: GenerateContentConfig = {
|
||||||
@@ -243,23 +183,13 @@ export class GeminiClient {
|
|||||||
async startChat(extraHistory?: Content[]): Promise<GeminiChat> {
|
async startChat(extraHistory?: Content[]): Promise<GeminiChat> {
|
||||||
this.forceFullIdeContext = true;
|
this.forceFullIdeContext = true;
|
||||||
this.hasFailedCompressionAttempt = false;
|
this.hasFailedCompressionAttempt = false;
|
||||||
const envParts = await getEnvironmentContext(this.config);
|
|
||||||
|
|
||||||
const toolRegistry = this.config.getToolRegistry();
|
const toolRegistry = this.config.getToolRegistry();
|
||||||
const toolDeclarations = toolRegistry.getFunctionDeclarations();
|
const toolDeclarations = toolRegistry.getFunctionDeclarations();
|
||||||
const tools: Tool[] = [{ functionDeclarations: toolDeclarations }];
|
const tools: Tool[] = [{ functionDeclarations: toolDeclarations }];
|
||||||
|
|
||||||
const history: Content[] = [
|
const history = await getInitialChatHistory(this.config, extraHistory);
|
||||||
{
|
|
||||||
role: 'user',
|
|
||||||
parts: envParts,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
role: 'model',
|
|
||||||
parts: [{ text: 'Got it. Thanks for the context!' }],
|
|
||||||
},
|
|
||||||
...(extraHistory ?? []),
|
|
||||||
];
|
|
||||||
try {
|
try {
|
||||||
const userMemory = this.config.getUserMemory();
|
const userMemory = this.config.getUserMemory();
|
||||||
const model = this.config.getModel();
|
const model = this.config.getModel();
|
||||||
@@ -503,14 +433,15 @@ export class GeminiClient {
|
|||||||
userMemory,
|
userMemory,
|
||||||
this.config.getModel(),
|
this.config.getModel(),
|
||||||
);
|
);
|
||||||
const environment = await getEnvironmentContext(this.config);
|
const initialHistory = await getInitialChatHistory(this.config);
|
||||||
|
|
||||||
// Create a mock request content to count total tokens
|
// Create a mock request content to count total tokens
|
||||||
const mockRequestContent = [
|
const mockRequestContent = [
|
||||||
{
|
{
|
||||||
role: 'system' as const,
|
role: 'system' as const,
|
||||||
parts: [{ text: systemPrompt }, ...environment],
|
parts: [{ text: systemPrompt }],
|
||||||
},
|
},
|
||||||
|
...initialHistory,
|
||||||
...currentHistory,
|
...currentHistory,
|
||||||
];
|
];
|
||||||
|
|
||||||
@@ -732,127 +663,37 @@ export class GeminiClient {
|
|||||||
prompt_id: string,
|
prompt_id: string,
|
||||||
force: boolean = false,
|
force: boolean = false,
|
||||||
): Promise<ChatCompressionInfo> {
|
): Promise<ChatCompressionInfo> {
|
||||||
const model = this.config.getModel();
|
const compressionService = new ChatCompressionService();
|
||||||
|
|
||||||
const curatedHistory = this.getChat().getHistory(true);
|
const { newHistory, info } = await compressionService.compress(
|
||||||
|
this.getChat(),
|
||||||
|
prompt_id,
|
||||||
|
force,
|
||||||
|
this.config.getModel(),
|
||||||
|
this.config,
|
||||||
|
this.hasFailedCompressionAttempt,
|
||||||
|
);
|
||||||
|
|
||||||
// Regardless of `force`, don't do anything if the history is empty.
|
// Handle compression result
|
||||||
if (
|
if (info.compressionStatus === CompressionStatus.COMPRESSED) {
|
||||||
curatedHistory.length === 0 ||
|
// Success: update chat with new compressed history
|
||||||
(this.hasFailedCompressionAttempt && !force)
|
if (newHistory) {
|
||||||
|
this.chat = await this.startChat(newHistory);
|
||||||
|
this.forceFullIdeContext = true;
|
||||||
|
}
|
||||||
|
} else if (
|
||||||
|
info.compressionStatus ===
|
||||||
|
CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT ||
|
||||||
|
info.compressionStatus ===
|
||||||
|
CompressionStatus.COMPRESSION_FAILED_EMPTY_SUMMARY
|
||||||
) {
|
) {
|
||||||
return {
|
// Track failed attempts (only mark as failed if not forced)
|
||||||
originalTokenCount: 0,
|
if (!force) {
|
||||||
newTokenCount: 0,
|
this.hasFailedCompressionAttempt = true;
|
||||||
compressionStatus: CompressionStatus.NOOP,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
const originalTokenCount = uiTelemetryService.getLastPromptTokenCount();
|
|
||||||
|
|
||||||
const contextPercentageThreshold =
|
|
||||||
this.config.getChatCompression()?.contextPercentageThreshold;
|
|
||||||
|
|
||||||
// Don't compress if not forced and we are under the limit.
|
|
||||||
if (!force) {
|
|
||||||
const threshold =
|
|
||||||
contextPercentageThreshold ?? COMPRESSION_TOKEN_THRESHOLD;
|
|
||||||
if (originalTokenCount < threshold * tokenLimit(model)) {
|
|
||||||
return {
|
|
||||||
originalTokenCount,
|
|
||||||
newTokenCount: originalTokenCount,
|
|
||||||
compressionStatus: CompressionStatus.NOOP,
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const splitPoint = findCompressSplitPoint(
|
return info;
|
||||||
curatedHistory,
|
|
||||||
1 - COMPRESSION_PRESERVE_THRESHOLD,
|
|
||||||
);
|
|
||||||
|
|
||||||
const historyToCompress = curatedHistory.slice(0, splitPoint);
|
|
||||||
const historyToKeep = curatedHistory.slice(splitPoint);
|
|
||||||
|
|
||||||
const summaryResponse = await this.config
|
|
||||||
.getContentGenerator()
|
|
||||||
.generateContent(
|
|
||||||
{
|
|
||||||
model,
|
|
||||||
contents: [
|
|
||||||
...historyToCompress,
|
|
||||||
{
|
|
||||||
role: 'user',
|
|
||||||
parts: [
|
|
||||||
{
|
|
||||||
text: 'First, reason in your scratchpad. Then, generate the <state_snapshot>.',
|
|
||||||
},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
],
|
|
||||||
config: {
|
|
||||||
systemInstruction: { text: getCompressionPrompt() },
|
|
||||||
},
|
|
||||||
},
|
|
||||||
prompt_id,
|
|
||||||
);
|
|
||||||
const summary = getResponseText(summaryResponse) ?? '';
|
|
||||||
|
|
||||||
const chat = await this.startChat([
|
|
||||||
{
|
|
||||||
role: 'user',
|
|
||||||
parts: [{ text: summary }],
|
|
||||||
},
|
|
||||||
{
|
|
||||||
role: 'model',
|
|
||||||
parts: [{ text: 'Got it. Thanks for the additional context!' }],
|
|
||||||
},
|
|
||||||
...historyToKeep,
|
|
||||||
]);
|
|
||||||
this.forceFullIdeContext = true;
|
|
||||||
|
|
||||||
// Estimate token count 1 token ≈ 4 characters
|
|
||||||
const newTokenCount = Math.floor(
|
|
||||||
chat
|
|
||||||
.getHistory()
|
|
||||||
.reduce((total, content) => total + JSON.stringify(content).length, 0) /
|
|
||||||
4,
|
|
||||||
);
|
|
||||||
|
|
||||||
logChatCompression(
|
|
||||||
this.config,
|
|
||||||
makeChatCompressionEvent({
|
|
||||||
tokens_before: originalTokenCount,
|
|
||||||
tokens_after: newTokenCount,
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
|
|
||||||
if (newTokenCount > originalTokenCount) {
|
|
||||||
this.hasFailedCompressionAttempt = !force && true;
|
|
||||||
return {
|
|
||||||
originalTokenCount,
|
|
||||||
newTokenCount,
|
|
||||||
compressionStatus:
|
|
||||||
CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
this.chat = chat; // Chat compression successful, set new state.
|
|
||||||
uiTelemetryService.setLastPromptTokenCount(newTokenCount);
|
|
||||||
}
|
|
||||||
|
|
||||||
logChatCompression(
|
|
||||||
this.config,
|
|
||||||
makeChatCompressionEvent({
|
|
||||||
tokens_before: originalTokenCount,
|
|
||||||
tokens_after: newTokenCount,
|
|
||||||
}),
|
|
||||||
);
|
|
||||||
|
|
||||||
return {
|
|
||||||
originalTokenCount,
|
|
||||||
newTokenCount,
|
|
||||||
compressionStatus: CompressionStatus.COMPRESSED,
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -153,6 +153,9 @@ export enum CompressionStatus {
|
|||||||
/** The compression failed due to an error counting tokens */
|
/** The compression failed due to an error counting tokens */
|
||||||
COMPRESSION_FAILED_TOKEN_COUNT_ERROR,
|
COMPRESSION_FAILED_TOKEN_COUNT_ERROR,
|
||||||
|
|
||||||
|
/** The compression failed due to receiving an empty or null summary */
|
||||||
|
COMPRESSION_FAILED_EMPTY_SUMMARY,
|
||||||
|
|
||||||
/** The compression was not necessary and no action was taken */
|
/** The compression was not necessary and no action was taken */
|
||||||
NOOP,
|
NOOP,
|
||||||
}
|
}
|
||||||
|
|||||||
372
packages/core/src/services/chatCompressionService.test.ts
Normal file
372
packages/core/src/services/chatCompressionService.test.ts
Normal file
@@ -0,0 +1,372 @@
|
|||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||||
|
import {
|
||||||
|
ChatCompressionService,
|
||||||
|
findCompressSplitPoint,
|
||||||
|
} from './chatCompressionService.js';
|
||||||
|
import type { Content, GenerateContentResponse } from '@google/genai';
|
||||||
|
import { CompressionStatus } from '../core/turn.js';
|
||||||
|
import { uiTelemetryService } from '../telemetry/uiTelemetry.js';
|
||||||
|
import { tokenLimit } from '../core/tokenLimits.js';
|
||||||
|
import type { GeminiChat } from '../core/geminiChat.js';
|
||||||
|
import type { Config } from '../config/config.js';
|
||||||
|
import { getInitialChatHistory } from '../utils/environmentContext.js';
|
||||||
|
import type { ContentGenerator } from '../core/contentGenerator.js';
|
||||||
|
|
||||||
|
vi.mock('../telemetry/uiTelemetry.js');
|
||||||
|
vi.mock('../core/tokenLimits.js');
|
||||||
|
vi.mock('../telemetry/loggers.js');
|
||||||
|
vi.mock('../utils/environmentContext.js');
|
||||||
|
|
||||||
|
describe('findCompressSplitPoint', () => {
|
||||||
|
it('should throw an error for non-positive numbers', () => {
|
||||||
|
expect(() => findCompressSplitPoint([], 0)).toThrow(
|
||||||
|
'Fraction must be between 0 and 1',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should throw an error for a fraction greater than or equal to 1', () => {
|
||||||
|
expect(() => findCompressSplitPoint([], 1)).toThrow(
|
||||||
|
'Fraction must be between 0 and 1',
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle an empty history', () => {
|
||||||
|
expect(findCompressSplitPoint([], 0.5)).toBe(0);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle a fraction in the middle', () => {
|
||||||
|
const history: Content[] = [
|
||||||
|
{ role: 'user', parts: [{ text: 'This is the first message.' }] }, // JSON length: 66 (19%)
|
||||||
|
{ role: 'model', parts: [{ text: 'This is the second message.' }] }, // JSON length: 68 (40%)
|
||||||
|
{ role: 'user', parts: [{ text: 'This is the third message.' }] }, // JSON length: 66 (60%)
|
||||||
|
{ role: 'model', parts: [{ text: 'This is the fourth message.' }] }, // JSON length: 68 (80%)
|
||||||
|
{ role: 'user', parts: [{ text: 'This is the fifth message.' }] }, // JSON length: 65 (100%)
|
||||||
|
];
|
||||||
|
expect(findCompressSplitPoint(history, 0.5)).toBe(4);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle a fraction of last index', () => {
|
||||||
|
const history: Content[] = [
|
||||||
|
{ role: 'user', parts: [{ text: 'This is the first message.' }] }, // JSON length: 66 (19%)
|
||||||
|
{ role: 'model', parts: [{ text: 'This is the second message.' }] }, // JSON length: 68 (40%)
|
||||||
|
{ role: 'user', parts: [{ text: 'This is the third message.' }] }, // JSON length: 66 (60%)
|
||||||
|
{ role: 'model', parts: [{ text: 'This is the fourth message.' }] }, // JSON length: 68 (80%)
|
||||||
|
{ role: 'user', parts: [{ text: 'This is the fifth message.' }] }, // JSON length: 65 (100%)
|
||||||
|
];
|
||||||
|
expect(findCompressSplitPoint(history, 0.9)).toBe(4);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle a fraction of after last index', () => {
|
||||||
|
const history: Content[] = [
|
||||||
|
{ role: 'user', parts: [{ text: 'This is the first message.' }] }, // JSON length: 66 (24%)
|
||||||
|
{ role: 'model', parts: [{ text: 'This is the second message.' }] }, // JSON length: 68 (50%)
|
||||||
|
{ role: 'user', parts: [{ text: 'This is the third message.' }] }, // JSON length: 66 (74%)
|
||||||
|
{ role: 'model', parts: [{ text: 'This is the fourth message.' }] }, // JSON length: 68 (100%)
|
||||||
|
];
|
||||||
|
expect(findCompressSplitPoint(history, 0.8)).toBe(4);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return earlier splitpoint if no valid ones are after threshhold', () => {
|
||||||
|
const history: Content[] = [
|
||||||
|
{ role: 'user', parts: [{ text: 'This is the first message.' }] },
|
||||||
|
{ role: 'model', parts: [{ text: 'This is the second message.' }] },
|
||||||
|
{ role: 'user', parts: [{ text: 'This is the third message.' }] },
|
||||||
|
{ role: 'model', parts: [{ functionCall: { name: 'foo', args: {} } }] },
|
||||||
|
];
|
||||||
|
// Can't return 4 because the previous item has a function call.
|
||||||
|
expect(findCompressSplitPoint(history, 0.99)).toBe(2);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle a history with only one item', () => {
|
||||||
|
const historyWithEmptyParts: Content[] = [
|
||||||
|
{ role: 'user', parts: [{ text: 'Message 1' }] },
|
||||||
|
];
|
||||||
|
expect(findCompressSplitPoint(historyWithEmptyParts, 0.5)).toBe(0);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should handle history with weird parts', () => {
|
||||||
|
const historyWithEmptyParts: Content[] = [
|
||||||
|
{ role: 'user', parts: [{ text: 'Message 1' }] },
|
||||||
|
{
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ fileData: { fileUri: 'derp', mimeType: 'text/plain' } }],
|
||||||
|
},
|
||||||
|
{ role: 'user', parts: [{ text: 'Message 2' }] },
|
||||||
|
];
|
||||||
|
expect(findCompressSplitPoint(historyWithEmptyParts, 0.5)).toBe(2);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('ChatCompressionService', () => {
|
||||||
|
let service: ChatCompressionService;
|
||||||
|
let mockChat: GeminiChat;
|
||||||
|
let mockConfig: Config;
|
||||||
|
const mockModel = 'gemini-pro';
|
||||||
|
const mockPromptId = 'test-prompt-id';
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
service = new ChatCompressionService();
|
||||||
|
mockChat = {
|
||||||
|
getHistory: vi.fn(),
|
||||||
|
} as unknown as GeminiChat;
|
||||||
|
mockConfig = {
|
||||||
|
getChatCompression: vi.fn(),
|
||||||
|
getContentGenerator: vi.fn(),
|
||||||
|
} as unknown as Config;
|
||||||
|
|
||||||
|
vi.mocked(tokenLimit).mockReturnValue(1000);
|
||||||
|
vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(500);
|
||||||
|
vi.mocked(getInitialChatHistory).mockImplementation(
|
||||||
|
async (_config, extraHistory) => extraHistory || [],
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
vi.restoreAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return NOOP if history is empty', async () => {
|
||||||
|
vi.mocked(mockChat.getHistory).mockReturnValue([]);
|
||||||
|
const result = await service.compress(
|
||||||
|
mockChat,
|
||||||
|
mockPromptId,
|
||||||
|
false,
|
||||||
|
mockModel,
|
||||||
|
mockConfig,
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
expect(result.info.compressionStatus).toBe(CompressionStatus.NOOP);
|
||||||
|
expect(result.newHistory).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return NOOP if previously failed and not forced', async () => {
|
||||||
|
vi.mocked(mockChat.getHistory).mockReturnValue([
|
||||||
|
{ role: 'user', parts: [{ text: 'hi' }] },
|
||||||
|
]);
|
||||||
|
const result = await service.compress(
|
||||||
|
mockChat,
|
||||||
|
mockPromptId,
|
||||||
|
false,
|
||||||
|
mockModel,
|
||||||
|
mockConfig,
|
||||||
|
true,
|
||||||
|
);
|
||||||
|
expect(result.info.compressionStatus).toBe(CompressionStatus.NOOP);
|
||||||
|
expect(result.newHistory).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return NOOP if under token threshold and not forced', async () => {
|
||||||
|
vi.mocked(mockChat.getHistory).mockReturnValue([
|
||||||
|
{ role: 'user', parts: [{ text: 'hi' }] },
|
||||||
|
]);
|
||||||
|
vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(600);
|
||||||
|
vi.mocked(tokenLimit).mockReturnValue(1000);
|
||||||
|
// Threshold is 0.7 * 1000 = 700. 600 < 700, so NOOP.
|
||||||
|
|
||||||
|
const result = await service.compress(
|
||||||
|
mockChat,
|
||||||
|
mockPromptId,
|
||||||
|
false,
|
||||||
|
mockModel,
|
||||||
|
mockConfig,
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
expect(result.info.compressionStatus).toBe(CompressionStatus.NOOP);
|
||||||
|
expect(result.newHistory).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should compress if over token threshold', async () => {
|
||||||
|
const history: Content[] = [
|
||||||
|
{ role: 'user', parts: [{ text: 'msg1' }] },
|
||||||
|
{ role: 'model', parts: [{ text: 'msg2' }] },
|
||||||
|
{ role: 'user', parts: [{ text: 'msg3' }] },
|
||||||
|
{ role: 'model', parts: [{ text: 'msg4' }] },
|
||||||
|
];
|
||||||
|
vi.mocked(mockChat.getHistory).mockReturnValue(history);
|
||||||
|
vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(800);
|
||||||
|
vi.mocked(tokenLimit).mockReturnValue(1000);
|
||||||
|
const mockGenerateContent = vi.fn().mockResolvedValue({
|
||||||
|
candidates: [
|
||||||
|
{
|
||||||
|
content: {
|
||||||
|
parts: [{ text: 'Summary' }],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
} as unknown as GenerateContentResponse);
|
||||||
|
vi.mocked(mockConfig.getContentGenerator).mockReturnValue({
|
||||||
|
generateContent: mockGenerateContent,
|
||||||
|
} as unknown as ContentGenerator);
|
||||||
|
|
||||||
|
const result = await service.compress(
|
||||||
|
mockChat,
|
||||||
|
mockPromptId,
|
||||||
|
false,
|
||||||
|
mockModel,
|
||||||
|
mockConfig,
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(result.info.compressionStatus).toBe(CompressionStatus.COMPRESSED);
|
||||||
|
expect(result.newHistory).not.toBeNull();
|
||||||
|
expect(result.newHistory![0].parts![0].text).toBe('Summary');
|
||||||
|
expect(mockGenerateContent).toHaveBeenCalled();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should force compress even if under threshold', async () => {
|
||||||
|
const history: Content[] = [
|
||||||
|
{ role: 'user', parts: [{ text: 'msg1' }] },
|
||||||
|
{ role: 'model', parts: [{ text: 'msg2' }] },
|
||||||
|
{ role: 'user', parts: [{ text: 'msg3' }] },
|
||||||
|
{ role: 'model', parts: [{ text: 'msg4' }] },
|
||||||
|
];
|
||||||
|
vi.mocked(mockChat.getHistory).mockReturnValue(history);
|
||||||
|
vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(100);
|
||||||
|
vi.mocked(tokenLimit).mockReturnValue(1000);
|
||||||
|
|
||||||
|
const mockGenerateContent = vi.fn().mockResolvedValue({
|
||||||
|
candidates: [
|
||||||
|
{
|
||||||
|
content: {
|
||||||
|
parts: [{ text: 'Summary' }],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
} as unknown as GenerateContentResponse);
|
||||||
|
vi.mocked(mockConfig.getContentGenerator).mockReturnValue({
|
||||||
|
generateContent: mockGenerateContent,
|
||||||
|
} as unknown as ContentGenerator);
|
||||||
|
|
||||||
|
const result = await service.compress(
|
||||||
|
mockChat,
|
||||||
|
mockPromptId,
|
||||||
|
true, // forced
|
||||||
|
mockModel,
|
||||||
|
mockConfig,
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(result.info.compressionStatus).toBe(CompressionStatus.COMPRESSED);
|
||||||
|
expect(result.newHistory).not.toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return FAILED if new token count is inflated', async () => {
|
||||||
|
const history: Content[] = [
|
||||||
|
{ role: 'user', parts: [{ text: 'msg1' }] },
|
||||||
|
{ role: 'model', parts: [{ text: 'msg2' }] },
|
||||||
|
];
|
||||||
|
vi.mocked(mockChat.getHistory).mockReturnValue(history);
|
||||||
|
vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(10);
|
||||||
|
vi.mocked(tokenLimit).mockReturnValue(1000);
|
||||||
|
|
||||||
|
const longSummary = 'a'.repeat(1000); // Long summary to inflate token count
|
||||||
|
const mockGenerateContent = vi.fn().mockResolvedValue({
|
||||||
|
candidates: [
|
||||||
|
{
|
||||||
|
content: {
|
||||||
|
parts: [{ text: longSummary }],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
} as unknown as GenerateContentResponse);
|
||||||
|
vi.mocked(mockConfig.getContentGenerator).mockReturnValue({
|
||||||
|
generateContent: mockGenerateContent,
|
||||||
|
} as unknown as ContentGenerator);
|
||||||
|
|
||||||
|
const result = await service.compress(
|
||||||
|
mockChat,
|
||||||
|
mockPromptId,
|
||||||
|
true,
|
||||||
|
mockModel,
|
||||||
|
mockConfig,
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(result.info.compressionStatus).toBe(
|
||||||
|
CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
|
||||||
|
);
|
||||||
|
expect(result.newHistory).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return FAILED if summary is empty string', async () => {
|
||||||
|
const history: Content[] = [
|
||||||
|
{ role: 'user', parts: [{ text: 'msg1' }] },
|
||||||
|
{ role: 'model', parts: [{ text: 'msg2' }] },
|
||||||
|
];
|
||||||
|
vi.mocked(mockChat.getHistory).mockReturnValue(history);
|
||||||
|
vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(100);
|
||||||
|
vi.mocked(tokenLimit).mockReturnValue(1000);
|
||||||
|
|
||||||
|
const mockGenerateContent = vi.fn().mockResolvedValue({
|
||||||
|
candidates: [
|
||||||
|
{
|
||||||
|
content: {
|
||||||
|
parts: [{ text: '' }], // Empty summary
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
} as unknown as GenerateContentResponse);
|
||||||
|
vi.mocked(mockConfig.getContentGenerator).mockReturnValue({
|
||||||
|
generateContent: mockGenerateContent,
|
||||||
|
} as unknown as ContentGenerator);
|
||||||
|
|
||||||
|
const result = await service.compress(
|
||||||
|
mockChat,
|
||||||
|
mockPromptId,
|
||||||
|
true,
|
||||||
|
mockModel,
|
||||||
|
mockConfig,
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(result.info.compressionStatus).toBe(
|
||||||
|
CompressionStatus.COMPRESSION_FAILED_EMPTY_SUMMARY,
|
||||||
|
);
|
||||||
|
expect(result.newHistory).toBeNull();
|
||||||
|
expect(result.info.originalTokenCount).toBe(100);
|
||||||
|
expect(result.info.newTokenCount).toBe(100);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('should return FAILED if summary is only whitespace', async () => {
|
||||||
|
const history: Content[] = [
|
||||||
|
{ role: 'user', parts: [{ text: 'msg1' }] },
|
||||||
|
{ role: 'model', parts: [{ text: 'msg2' }] },
|
||||||
|
];
|
||||||
|
vi.mocked(mockChat.getHistory).mockReturnValue(history);
|
||||||
|
vi.mocked(uiTelemetryService.getLastPromptTokenCount).mockReturnValue(100);
|
||||||
|
vi.mocked(tokenLimit).mockReturnValue(1000);
|
||||||
|
|
||||||
|
const mockGenerateContent = vi.fn().mockResolvedValue({
|
||||||
|
candidates: [
|
||||||
|
{
|
||||||
|
content: {
|
||||||
|
parts: [{ text: ' \n\t ' }], // Only whitespace
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
} as unknown as GenerateContentResponse);
|
||||||
|
vi.mocked(mockConfig.getContentGenerator).mockReturnValue({
|
||||||
|
generateContent: mockGenerateContent,
|
||||||
|
} as unknown as ContentGenerator);
|
||||||
|
|
||||||
|
const result = await service.compress(
|
||||||
|
mockChat,
|
||||||
|
mockPromptId,
|
||||||
|
true,
|
||||||
|
mockModel,
|
||||||
|
mockConfig,
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
|
||||||
|
expect(result.info.compressionStatus).toBe(
|
||||||
|
CompressionStatus.COMPRESSION_FAILED_EMPTY_SUMMARY,
|
||||||
|
);
|
||||||
|
expect(result.newHistory).toBeNull();
|
||||||
|
});
|
||||||
|
});
|
||||||
235
packages/core/src/services/chatCompressionService.ts
Normal file
235
packages/core/src/services/chatCompressionService.ts
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
import type { Content } from '@google/genai';
|
||||||
|
import type { Config } from '../config/config.js';
|
||||||
|
import type { GeminiChat } from '../core/geminiChat.js';
|
||||||
|
import { type ChatCompressionInfo, CompressionStatus } from '../core/turn.js';
|
||||||
|
import { uiTelemetryService } from '../telemetry/uiTelemetry.js';
|
||||||
|
import { tokenLimit } from '../core/tokenLimits.js';
|
||||||
|
import { getCompressionPrompt } from '../core/prompts.js';
|
||||||
|
import { getResponseText } from '../utils/partUtils.js';
|
||||||
|
import { logChatCompression } from '../telemetry/loggers.js';
|
||||||
|
import { makeChatCompressionEvent } from '../telemetry/types.js';
|
||||||
|
import { getInitialChatHistory } from '../utils/environmentContext.js';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Threshold for compression token count as a fraction of the model's token limit.
|
||||||
|
* If the chat history exceeds this threshold, it will be compressed.
|
||||||
|
*/
|
||||||
|
export const COMPRESSION_TOKEN_THRESHOLD = 0.7;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The fraction of the latest chat history to keep. A value of 0.3
|
||||||
|
* means that only the last 30% of the chat history will be kept after compression.
|
||||||
|
*/
|
||||||
|
export const COMPRESSION_PRESERVE_THRESHOLD = 0.3;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the index of the oldest item to keep when compressing. May return
|
||||||
|
* contents.length which indicates that everything should be compressed.
|
||||||
|
*
|
||||||
|
* Exported for testing purposes.
|
||||||
|
*/
|
||||||
|
export function findCompressSplitPoint(
|
||||||
|
contents: Content[],
|
||||||
|
fraction: number,
|
||||||
|
): number {
|
||||||
|
if (fraction <= 0 || fraction >= 1) {
|
||||||
|
throw new Error('Fraction must be between 0 and 1');
|
||||||
|
}
|
||||||
|
|
||||||
|
const charCounts = contents.map((content) => JSON.stringify(content).length);
|
||||||
|
const totalCharCount = charCounts.reduce((a, b) => a + b, 0);
|
||||||
|
const targetCharCount = totalCharCount * fraction;
|
||||||
|
|
||||||
|
let lastSplitPoint = 0; // 0 is always valid (compress nothing)
|
||||||
|
let cumulativeCharCount = 0;
|
||||||
|
for (let i = 0; i < contents.length; i++) {
|
||||||
|
const content = contents[i];
|
||||||
|
if (
|
||||||
|
content.role === 'user' &&
|
||||||
|
!content.parts?.some((part) => !!part.functionResponse)
|
||||||
|
) {
|
||||||
|
if (cumulativeCharCount >= targetCharCount) {
|
||||||
|
return i;
|
||||||
|
}
|
||||||
|
lastSplitPoint = i;
|
||||||
|
}
|
||||||
|
cumulativeCharCount += charCounts[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// We found no split points after targetCharCount.
|
||||||
|
// Check if it's safe to compress everything.
|
||||||
|
const lastContent = contents[contents.length - 1];
|
||||||
|
if (
|
||||||
|
lastContent?.role === 'model' &&
|
||||||
|
!lastContent?.parts?.some((part) => part.functionCall)
|
||||||
|
) {
|
||||||
|
return contents.length;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Can't compress everything so just compress at last splitpoint.
|
||||||
|
return lastSplitPoint;
|
||||||
|
}
|
||||||
|
|
||||||
|
export class ChatCompressionService {
|
||||||
|
async compress(
|
||||||
|
chat: GeminiChat,
|
||||||
|
promptId: string,
|
||||||
|
force: boolean,
|
||||||
|
model: string,
|
||||||
|
config: Config,
|
||||||
|
hasFailedCompressionAttempt: boolean,
|
||||||
|
): Promise<{ newHistory: Content[] | null; info: ChatCompressionInfo }> {
|
||||||
|
const curatedHistory = chat.getHistory(true);
|
||||||
|
|
||||||
|
// Regardless of `force`, don't do anything if the history is empty.
|
||||||
|
if (
|
||||||
|
curatedHistory.length === 0 ||
|
||||||
|
(hasFailedCompressionAttempt && !force)
|
||||||
|
) {
|
||||||
|
return {
|
||||||
|
newHistory: null,
|
||||||
|
info: {
|
||||||
|
originalTokenCount: 0,
|
||||||
|
newTokenCount: 0,
|
||||||
|
compressionStatus: CompressionStatus.NOOP,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const originalTokenCount = uiTelemetryService.getLastPromptTokenCount();
|
||||||
|
|
||||||
|
const contextPercentageThreshold =
|
||||||
|
config.getChatCompression()?.contextPercentageThreshold;
|
||||||
|
|
||||||
|
// Don't compress if not forced and we are under the limit.
|
||||||
|
if (!force) {
|
||||||
|
const threshold =
|
||||||
|
contextPercentageThreshold ?? COMPRESSION_TOKEN_THRESHOLD;
|
||||||
|
if (originalTokenCount < threshold * tokenLimit(model)) {
|
||||||
|
return {
|
||||||
|
newHistory: null,
|
||||||
|
info: {
|
||||||
|
originalTokenCount,
|
||||||
|
newTokenCount: originalTokenCount,
|
||||||
|
compressionStatus: CompressionStatus.NOOP,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const splitPoint = findCompressSplitPoint(
|
||||||
|
curatedHistory,
|
||||||
|
1 - COMPRESSION_PRESERVE_THRESHOLD,
|
||||||
|
);
|
||||||
|
|
||||||
|
const historyToCompress = curatedHistory.slice(0, splitPoint);
|
||||||
|
const historyToKeep = curatedHistory.slice(splitPoint);
|
||||||
|
|
||||||
|
if (historyToCompress.length === 0) {
|
||||||
|
return {
|
||||||
|
newHistory: null,
|
||||||
|
info: {
|
||||||
|
originalTokenCount,
|
||||||
|
newTokenCount: originalTokenCount,
|
||||||
|
compressionStatus: CompressionStatus.NOOP,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const summaryResponse = await config.getContentGenerator().generateContent(
|
||||||
|
{
|
||||||
|
model,
|
||||||
|
contents: [
|
||||||
|
...historyToCompress,
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
parts: [
|
||||||
|
{
|
||||||
|
text: 'First, reason in your scratchpad. Then, generate the <state_snapshot>.',
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
config: {
|
||||||
|
systemInstruction: getCompressionPrompt(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
promptId,
|
||||||
|
);
|
||||||
|
const summary = getResponseText(summaryResponse) ?? '';
|
||||||
|
const isSummaryEmpty = !summary || summary.trim().length === 0;
|
||||||
|
|
||||||
|
let newTokenCount = originalTokenCount;
|
||||||
|
let extraHistory: Content[] = [];
|
||||||
|
|
||||||
|
if (!isSummaryEmpty) {
|
||||||
|
extraHistory = [
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
parts: [{ text: summary }],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ text: 'Got it. Thanks for the additional context!' }],
|
||||||
|
},
|
||||||
|
...historyToKeep,
|
||||||
|
];
|
||||||
|
|
||||||
|
// Use a shared utility to construct the initial history for an accurate token count.
|
||||||
|
const fullNewHistory = await getInitialChatHistory(config, extraHistory);
|
||||||
|
|
||||||
|
// Estimate token count 1 token ≈ 4 characters
|
||||||
|
newTokenCount = Math.floor(
|
||||||
|
fullNewHistory.reduce(
|
||||||
|
(total, content) => total + JSON.stringify(content).length,
|
||||||
|
0,
|
||||||
|
) / 4,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
logChatCompression(
|
||||||
|
config,
|
||||||
|
makeChatCompressionEvent({
|
||||||
|
tokens_before: originalTokenCount,
|
||||||
|
tokens_after: newTokenCount,
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
if (isSummaryEmpty) {
|
||||||
|
return {
|
||||||
|
newHistory: null,
|
||||||
|
info: {
|
||||||
|
originalTokenCount,
|
||||||
|
newTokenCount: originalTokenCount,
|
||||||
|
compressionStatus: CompressionStatus.COMPRESSION_FAILED_EMPTY_SUMMARY,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
} else if (newTokenCount > originalTokenCount) {
|
||||||
|
return {
|
||||||
|
newHistory: null,
|
||||||
|
info: {
|
||||||
|
originalTokenCount,
|
||||||
|
newTokenCount,
|
||||||
|
compressionStatus:
|
||||||
|
CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
} else {
|
||||||
|
uiTelemetryService.setLastPromptTokenCount(newTokenCount);
|
||||||
|
return {
|
||||||
|
newHistory: extraHistory,
|
||||||
|
info: {
|
||||||
|
originalTokenCount,
|
||||||
|
newTokenCount,
|
||||||
|
compressionStatus: CompressionStatus.COMPRESSED,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -32,7 +32,6 @@ import { GeminiChat } from '../core/geminiChat.js';
|
|||||||
import { executeToolCall } from '../core/nonInteractiveToolExecutor.js';
|
import { executeToolCall } from '../core/nonInteractiveToolExecutor.js';
|
||||||
import type { ToolRegistry } from '../tools/tool-registry.js';
|
import type { ToolRegistry } from '../tools/tool-registry.js';
|
||||||
import { type AnyDeclarativeTool } from '../tools/tools.js';
|
import { type AnyDeclarativeTool } from '../tools/tools.js';
|
||||||
import { getEnvironmentContext } from '../utils/environmentContext.js';
|
|
||||||
import { ContextState, SubAgentScope } from './subagent.js';
|
import { ContextState, SubAgentScope } from './subagent.js';
|
||||||
import type {
|
import type {
|
||||||
ModelConfig,
|
ModelConfig,
|
||||||
@@ -44,7 +43,20 @@ import { SubagentTerminateMode } from './types.js';
|
|||||||
|
|
||||||
vi.mock('../core/geminiChat.js');
|
vi.mock('../core/geminiChat.js');
|
||||||
vi.mock('../core/contentGenerator.js');
|
vi.mock('../core/contentGenerator.js');
|
||||||
vi.mock('../utils/environmentContext.js');
|
vi.mock('../utils/environmentContext.js', () => ({
|
||||||
|
getEnvironmentContext: vi.fn().mockResolvedValue([{ text: 'Env Context' }]),
|
||||||
|
getInitialChatHistory: vi.fn(async (_config, extraHistory) => [
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
parts: [{ text: 'Env Context' }],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ text: 'Got it. Thanks for the context!' }],
|
||||||
|
},
|
||||||
|
...(extraHistory ?? []),
|
||||||
|
]),
|
||||||
|
}));
|
||||||
vi.mock('../core/nonInteractiveToolExecutor.js');
|
vi.mock('../core/nonInteractiveToolExecutor.js');
|
||||||
vi.mock('../ide/ide-client.js');
|
vi.mock('../ide/ide-client.js');
|
||||||
vi.mock('../core/client.js');
|
vi.mock('../core/client.js');
|
||||||
@@ -174,9 +186,6 @@ describe('subagent.ts', () => {
|
|||||||
beforeEach(async () => {
|
beforeEach(async () => {
|
||||||
vi.clearAllMocks();
|
vi.clearAllMocks();
|
||||||
|
|
||||||
vi.mocked(getEnvironmentContext).mockResolvedValue([
|
|
||||||
{ text: 'Env Context' },
|
|
||||||
]);
|
|
||||||
vi.mocked(createContentGenerator).mockResolvedValue({
|
vi.mocked(createContentGenerator).mockResolvedValue({
|
||||||
getGenerativeModel: vi.fn(),
|
getGenerativeModel: vi.fn(),
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import type {
|
|||||||
ToolConfirmationOutcome,
|
ToolConfirmationOutcome,
|
||||||
ToolCallConfirmationDetails,
|
ToolCallConfirmationDetails,
|
||||||
} from '../tools/tools.js';
|
} from '../tools/tools.js';
|
||||||
import { getEnvironmentContext } from '../utils/environmentContext.js';
|
import { getInitialChatHistory } from '../utils/environmentContext.js';
|
||||||
import type {
|
import type {
|
||||||
Content,
|
Content,
|
||||||
Part,
|
Part,
|
||||||
@@ -807,11 +807,7 @@ export class SubAgentScope {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
const envParts = await getEnvironmentContext(this.runtimeContext);
|
const envHistory = await getInitialChatHistory(this.runtimeContext);
|
||||||
const envHistory: Content[] = [
|
|
||||||
{ role: 'user', parts: envParts },
|
|
||||||
{ role: 'model', parts: [{ text: 'Got it. Thanks for the context!' }] },
|
|
||||||
];
|
|
||||||
|
|
||||||
const start_history = [
|
const start_history = [
|
||||||
...envHistory,
|
...envHistory,
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import type { Part } from '@google/genai';
|
import type { Content, Part } from '@google/genai';
|
||||||
import type { Config } from '../config/config.js';
|
import type { Config } from '../config/config.js';
|
||||||
import { getFolderStructure } from './getFolderStructure.js';
|
import { getFolderStructure } from './getFolderStructure.js';
|
||||||
|
|
||||||
@@ -107,3 +107,23 @@ ${directoryContext}
|
|||||||
|
|
||||||
return initialParts;
|
return initialParts;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function getInitialChatHistory(
|
||||||
|
config: Config,
|
||||||
|
extraHistory?: Content[],
|
||||||
|
): Promise<Content[]> {
|
||||||
|
const envParts = await getEnvironmentContext(config);
|
||||||
|
const envContextString = envParts.map((part) => part.text || '').join('\n\n');
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
parts: [{ text: envContextString }],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'model',
|
||||||
|
parts: [{ text: 'Got it. Thanks for the context!' }],
|
||||||
|
},
|
||||||
|
...(extraHistory ?? []),
|
||||||
|
];
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user