From cd2e237c73c5e60933c4bf3dee5a3759ac7c3364 Mon Sep 17 00:00:00 2001 From: Richie Foreman Date: Wed, 27 Aug 2025 17:00:45 -0400 Subject: [PATCH] fix(compression): Discard compression result if it results in more token usage (#7047) --- .../src/ui/commands/compressCommand.test.ts | 35 +- .../cli/src/ui/commands/compressCommand.ts | 2 + packages/cli/src/ui/types.ts | 2 + packages/core/src/core/client.test.ts | 320 ++++++++++++++++-- packages/core/src/core/client.ts | 70 +++- packages/core/src/core/turn.ts | 15 + 6 files changed, 397 insertions(+), 47 deletions(-) diff --git a/packages/cli/src/ui/commands/compressCommand.test.ts b/packages/cli/src/ui/commands/compressCommand.test.ts index 7508bc9f..ed1e1345 100644 --- a/packages/cli/src/ui/commands/compressCommand.test.ts +++ b/packages/cli/src/ui/commands/compressCommand.test.ts @@ -4,7 +4,11 @@ * SPDX-License-Identifier: Apache-2.0 */ -import type { GeminiClient } from '@google/gemini-cli-core'; +import { + CompressionStatus, + type ChatCompressionInfo, + type GeminiClient, +} from '@google/gemini-cli-core'; import { vi, describe, it, expect, beforeEach } from 'vitest'; import { compressCommand } from './compressCommand.js'; import { createMockCommandContext } from '../../test-utils/mockCommandContext.js'; @@ -35,6 +39,7 @@ describe('compressCommand', () => { isPending: true, originalTokenCount: null, newTokenCount: null, + compressionStatus: null, }, }; await compressCommand.action!(context, ''); @@ -50,25 +55,24 @@ describe('compressCommand', () => { }); it('should set pending item, call tryCompressChat, and add result on success', async () => { - const compressedResult = { + const compressedResult: ChatCompressionInfo = { originalTokenCount: 200, + compressionStatus: CompressionStatus.COMPRESSED, newTokenCount: 100, }; mockTryCompressChat.mockResolvedValue(compressedResult); await compressCommand.action!(context, ''); - expect(context.ui.setPendingItem).toHaveBeenNthCalledWith( - 1, - expect.objectContaining({ - type: MessageType.COMPRESSION, - compression: { - isPending: true, - originalTokenCount: null, - newTokenCount: null, - }, - }), - ); + expect(context.ui.setPendingItem).toHaveBeenNthCalledWith(1, { + type: MessageType.COMPRESSION, + compression: { + isPending: true, + compressionStatus: null, + originalTokenCount: null, + newTokenCount: null, + }, + }); expect(mockTryCompressChat).toHaveBeenCalledWith( expect.stringMatching(/^compress-\d+$/), @@ -76,14 +80,15 @@ describe('compressCommand', () => { ); expect(context.ui.addItem).toHaveBeenCalledWith( - expect.objectContaining({ + { type: MessageType.COMPRESSION, compression: { isPending: false, + compressionStatus: CompressionStatus.COMPRESSED, originalTokenCount: 200, newTokenCount: 100, }, - }), + }, expect.any(Number), ); diff --git a/packages/cli/src/ui/commands/compressCommand.ts b/packages/cli/src/ui/commands/compressCommand.ts index 78e8fa63..45dc6a46 100644 --- a/packages/cli/src/ui/commands/compressCommand.ts +++ b/packages/cli/src/ui/commands/compressCommand.ts @@ -33,6 +33,7 @@ export const compressCommand: SlashCommand = { isPending: true, originalTokenCount: null, newTokenCount: null, + compressionStatus: null, }, }; @@ -50,6 +51,7 @@ export const compressCommand: SlashCommand = { isPending: false, originalTokenCount: compressed.originalTokenCount, newTokenCount: compressed.newTokenCount, + compressionStatus: compressed.compressionStatus, }, } as HistoryItemCompression, Date.now(), diff --git a/packages/cli/src/ui/types.ts b/packages/cli/src/ui/types.ts index 5073aa0c..d453bec9 100644 --- a/packages/cli/src/ui/types.ts +++ b/packages/cli/src/ui/types.ts @@ -5,6 +5,7 @@ */ import type { + CompressionStatus, ToolCallConfirmationDetails, ToolResultDisplay, } from '@google/gemini-cli-core'; @@ -56,6 +57,7 @@ export interface CompressionProps { isPending: boolean; originalTokenCount: number | null; newTokenCount: number | null; + compressionStatus: CompressionStatus | null; } export interface HistoryItemBase { diff --git a/packages/core/src/core/client.test.ts b/packages/core/src/core/client.test.ts index c858876e..8f37c2dc 100644 --- a/packages/core/src/core/client.test.ts +++ b/packages/core/src/core/client.test.ts @@ -4,21 +4,38 @@ * SPDX-License-Identifier: Apache-2.0 */ -import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'; +import { + describe, + it, + expect, + vi, + beforeEach, + afterEach, + type Mocked, +} from 'vitest'; import type { Chat, Content, EmbedContentResponse, GenerateContentResponse, + Part, } from '@google/genai'; import { GoogleGenAI } from '@google/genai'; import { findIndexAfterFraction, GeminiClient } from './client.js'; -import type { ContentGenerator } from './contentGenerator.js'; -import { AuthType } from './contentGenerator.js'; -import type { GeminiChat } from './geminiChat.js'; +import { + AuthType, + type ContentGenerator, + type ContentGeneratorConfig, +} from './contentGenerator.js'; +import { type GeminiChat } from './geminiChat.js'; import { Config } from '../config/config.js'; -import { GeminiEventType, Turn } from './turn.js'; +import { + CompressionStatus, + GeminiEventType, + Turn, + type ChatCompressionInfo, +} from './turn.js'; import { getCoreSystemPrompt } from './prompts.js'; import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js'; import { FileDiscoveryService } from '../services/fileDiscoveryService.js'; @@ -34,7 +51,8 @@ const mockEmbedContentFn = vi.fn(); const mockTurnRunFn = vi.fn(); vi.mock('@google/genai'); -vi.mock('./turn', () => { +vi.mock('./turn', async (importOriginal) => { + const actual = await importOriginal(); // Define a mock class that has the same shape as the real Turn class MockTurn { pendingToolCalls = []; @@ -47,13 +65,8 @@ vi.mock('./turn', () => { } // Export the mock class as 'Turn' return { + ...actual, Turn: MockTurn, - GeminiEventType: { - MaxSessionTurns: 'MaxSessionTurns', - ChatCompressed: 'ChatCompressed', - Error: 'error', - Content: 'content', - }, }; }); @@ -78,6 +91,19 @@ vi.mock('../telemetry/index.js', () => ({ })); vi.mock('../ide/ideContext.js'); +/** + * Array.fromAsync ponyfill, which will be available in es 2024. + * + * Buffers an async generator into an array and returns the result. + */ +async function fromAsync(promise: AsyncGenerator): Promise { + const results: T[] = []; + for await (const result of promise) { + results.push(result); + } + return results; +} + describe('findIndexAfterFraction', () => { const history: Content[] = [ { role: 'user', parts: [{ text: 'This is the first message.' }] }, // JSON length: 66 @@ -176,7 +202,7 @@ describe('Gemini Client (client.ts)', () => { getTool: vi.fn().mockReturnValue(null), }; const fileService = new FileDiscoveryService('/test/dir'); - const contentGeneratorConfig = { + const contentGeneratorConfig: ContentGeneratorConfig = { model: 'test-model', apiKey: 'test-key', vertexai: false, @@ -380,7 +406,9 @@ describe('Gemini Client (client.ts)', () => { }); it('should allow overriding model and config', async () => { - const contents = [{ role: 'user', parts: [{ text: 'hello' }] }]; + const contents: Content[] = [ + { role: 'user', parts: [{ text: 'hello' }] }, + ]; const schema = { type: 'string' }; const abortSignal = new AbortController().signal; const customModel = 'custom-json-model'; @@ -421,10 +449,10 @@ describe('Gemini Client (client.ts)', () => { describe('addHistory', () => { it('should call chat.addHistory with the provided content', async () => { - const mockChat = { + const mockChat: Partial = { addHistory: vi.fn(), }; - client['chat'] = mockChat as unknown as GeminiChat; + client['chat'] = mockChat as GeminiChat; const newContent = { role: 'user', @@ -486,6 +514,139 @@ describe('Gemini Client (client.ts)', () => { } as unknown as GeminiChat; }); + function setup({ + chatHistory = [ + { role: 'user', parts: [{ text: 'Long conversation' }] }, + { role: 'model', parts: [{ text: 'Long response' }] }, + ] as Content[], + } = {}) { + const mockChat: Partial = { + getHistory: vi.fn().mockReturnValue(chatHistory), + setHistory: vi.fn(), + sendMessage: vi.fn().mockResolvedValue({ text: 'Summary' }), + }; + const mockCountTokens = vi + .fn() + .mockResolvedValueOnce({ totalTokens: 1000 }) + .mockResolvedValueOnce({ totalTokens: 5000 }); + + const mockGenerator: Partial> = { + countTokens: mockCountTokens, + }; + + client['chat'] = mockChat as GeminiChat; + client['contentGenerator'] = mockGenerator as ContentGenerator; + client['startChat'] = vi.fn().mockResolvedValue({ ...mockChat }); + + return { client, mockChat, mockGenerator }; + } + + describe('when compression inflates the token count', () => { + it('uses the truncated history for compression'); + it('allows compression to be forced/manual after a failure', async () => { + const { client, mockGenerator } = setup(); + mockGenerator.countTokens?.mockResolvedValue({ + totalTokens: 1000, + }); + await client.tryCompressChat('prompt-id-4'); // Fails + const result = await client.tryCompressChat('prompt-id-4', true); + + expect(result).toEqual({ + compressionStatus: CompressionStatus.COMPRESSED, + newTokenCount: 1000, + originalTokenCount: 1000, + }); + }); + + it('yields the result even if the compression inflated the tokens', async () => { + const { client } = setup(); + const result = await client.tryCompressChat('prompt-id-4', true); + + expect(result).toEqual({ + compressionStatus: + CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT, + newTokenCount: 5000, + originalTokenCount: 1000, + }); + }); + + it('does not manipulate the source chat', async () => { + const { client, mockChat } = setup(); + await client.tryCompressChat('prompt-id-4', true); + + expect(client['chat']).toBe(mockChat); // a new chat session was not created + }); + + it('restores the history back to the original', async () => { + vi.mocked(tokenLimit).mockReturnValue(1000); + mockCountTokens.mockResolvedValue({ + totalTokens: 999, + }); + + const originalHistory: Content[] = [ + { role: 'user', parts: [{ text: 'what is your wisdom?' }] }, + { role: 'model', parts: [{ text: 'some wisdom' }] }, + { role: 'user', parts: [{ text: 'ahh that is a good a wisdom' }] }, + ]; + + const { client } = setup({ + chatHistory: originalHistory, + }); + const { compressionStatus } = + await client.tryCompressChat('prompt-id-4'); + + expect(compressionStatus).toBe( + CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT, + ); + expect(client['chat']?.setHistory).toHaveBeenCalledWith( + originalHistory, + ); + }); + + it('will not attempt to compress context after a failure', async () => { + const { client, mockGenerator } = setup(); + await client.tryCompressChat('prompt-id-4'); + + const result = await client.tryCompressChat('prompt-id-5'); + + // it counts tokens for {original, compressed} and then never again + expect(mockGenerator.countTokens).toHaveBeenCalledTimes(2); + expect(result).toEqual({ + compressionStatus: CompressionStatus.NOOP, + newTokenCount: 0, + originalTokenCount: 0, + }); + }); + }); + + it('attempts to compress with a maxOutputTokens set to the original token count', async () => { + vi.mocked(tokenLimit).mockReturnValue(1000); + mockCountTokens.mockResolvedValue({ + totalTokens: 999, + }); + + mockGetHistory.mockReturnValue([ + { role: 'user', parts: [{ text: '...history...' }] }, + ]); + + // Mock the summary response from the chat + mockSendMessage.mockResolvedValue({ + role: 'model', + parts: [{ text: 'This is a summary.' }], + }); + + await client.tryCompressChat('prompt-id-2', true); + + expect(mockSendMessage).toHaveBeenCalledWith( + expect.objectContaining({ + config: expect.objectContaining({ + maxOutputTokens: 999, + }), + }), + 'prompt-id-2', + ); + }); + it('should not trigger summarization if token count is below threshold', async () => { const MOCKED_TOKEN_LIMIT = 1000; vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT); @@ -502,7 +663,11 @@ describe('Gemini Client (client.ts)', () => { const newChat = client.getChat(); expect(tokenLimit).toHaveBeenCalled(); - expect(result).toBeNull(); + expect(result).toEqual({ + compressionStatus: CompressionStatus.NOOP, + newTokenCount: 699, + originalTokenCount: 699, + }); expect(newChat).toBe(initialChat); }); @@ -579,6 +744,7 @@ describe('Gemini Client (client.ts)', () => { // Assert that summarization happened and returned the correct stats expect(result).toEqual({ + compressionStatus: CompressionStatus.COMPRESSED, originalTokenCount, newTokenCount, }); @@ -631,6 +797,7 @@ describe('Gemini Client (client.ts)', () => { // Assert that summarization happened and returned the correct stats expect(result).toEqual({ + compressionStatus: CompressionStatus.COMPRESSED, originalTokenCount, newTokenCount, }); @@ -670,6 +837,7 @@ describe('Gemini Client (client.ts)', () => { expect(mockSendMessage).toHaveBeenCalled(); expect(result).toEqual({ + compressionStatus: CompressionStatus.COMPRESSED, originalTokenCount, newTokenCount, }); @@ -727,6 +895,7 @@ describe('Gemini Client (client.ts)', () => { }); expect(result).toEqual({ + compressionStatus: CompressionStatus.COMPRESSED, originalTokenCount: 100000, newTokenCount: 5000, }); @@ -734,6 +903,109 @@ describe('Gemini Client (client.ts)', () => { }); describe('sendMessageStream', () => { + it('emits a compression event when the context was automatically compressed', async () => { + // Arrange + const mockStream = (async function* () { + yield { type: 'content', value: 'Hello' }; + })(); + mockTurnRunFn.mockReturnValue(mockStream); + + const mockChat: Partial = { + addHistory: vi.fn(), + getHistory: vi.fn().mockReturnValue([]), + }; + client['chat'] = mockChat as GeminiChat; + + const mockGenerator: Partial = { + countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), + generateContent: mockGenerateContentFn, + }; + client['contentGenerator'] = mockGenerator as ContentGenerator; + + const compressionInfo: ChatCompressionInfo = { + compressionStatus: CompressionStatus.COMPRESSED, + originalTokenCount: 1000, + newTokenCount: 500, + }; + + vi.spyOn(client, 'tryCompressChat').mockResolvedValueOnce( + compressionInfo, + ); + + // Act + const stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + 'prompt-id-1', + ); + + const events = await fromAsync(stream); + + // Assert + expect(events).toContainEqual({ + type: GeminiEventType.ChatCompressed, + value: compressionInfo, + }); + }); + + it.each([ + { + compressionStatus: + CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT, + }, + { compressionStatus: CompressionStatus.NOOP }, + { + compressionStatus: + CompressionStatus.COMPRESSION_FAILED_TOKEN_COUNT_ERROR, + }, + ])( + 'does not emit a compression event when the status is $compressionStatus', + async ({ compressionStatus }) => { + // Arrange + const mockStream = (async function* () { + yield { type: 'content', value: 'Hello' }; + })(); + mockTurnRunFn.mockReturnValue(mockStream); + + const mockChat: Partial = { + addHistory: vi.fn(), + getHistory: vi.fn().mockReturnValue([]), + }; + client['chat'] = mockChat as GeminiChat; + + const mockGenerator: Partial = { + countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }), + generateContent: mockGenerateContentFn, + }; + client['contentGenerator'] = mockGenerator as ContentGenerator; + + const compressionInfo: ChatCompressionInfo = { + compressionStatus, + originalTokenCount: 1000, + newTokenCount: 500, + }; + + vi.spyOn(client, 'tryCompressChat').mockResolvedValueOnce( + compressionInfo, + ); + + // Act + const stream = client.sendMessageStream( + [{ text: 'Hi' }], + new AbortController().signal, + 'prompt-id-1', + ); + + const events = await fromAsync(stream); + + // Assert + expect(events).not.toContainEqual({ + type: GeminiEventType.ChatCompressed, + value: expect.anything(), + }); + }, + ); + it('should include editor context when ideMode is enabled', async () => { // Arrange vi.mocked(ideContext.getIdeContext).mockReturnValue({ @@ -777,7 +1049,7 @@ describe('Gemini Client (client.ts)', () => { }; client['contentGenerator'] = mockGenerator as ContentGenerator; - const initialRequest = [{ text: 'Hi' }]; + const initialRequest: Part[] = [{ text: 'Hi' }]; // Act const stream = client.sendMessageStream( @@ -1292,7 +1564,11 @@ ${JSON.stringify( beforeEach(() => { client['forceFullIdeContext'] = false; // Reset before each delta test - vi.spyOn(client, 'tryCompressChat').mockResolvedValue(null); + vi.spyOn(client, 'tryCompressChat').mockResolvedValue({ + originalTokenCount: 0, + newTokenCount: 0, + compressionStatus: CompressionStatus.COMPRESSED, + }); vi.spyOn(client['config'], 'getIdeMode').mockReturnValue(true); mockTurnRunFn.mockReturnValue(mockStream); @@ -1551,7 +1827,11 @@ ${JSON.stringify( let mockChat: Partial; beforeEach(() => { - vi.spyOn(client, 'tryCompressChat').mockResolvedValue(null); + vi.spyOn(client, 'tryCompressChat').mockResolvedValue({ + originalTokenCount: 0, + newTokenCount: 0, + compressionStatus: CompressionStatus.COMPRESSED, + }); const mockStream = (async function* () { yield { type: 'content', value: 'response' }; diff --git a/packages/core/src/core/client.ts b/packages/core/src/core/client.ts index bc1054a2..7778da6b 100644 --- a/packages/core/src/core/client.ts +++ b/packages/core/src/core/client.ts @@ -17,6 +17,7 @@ import { getEnvironmentContext, } from '../utils/environmentContext.js'; import type { ServerGeminiStreamEvent, ChatCompressionInfo } from './turn.js'; +import { CompressionStatus } from './turn.js'; import { Turn, GeminiEventType } from './turn.js'; import type { Config } from '../config/config.js'; import type { UserTierId } from '../code_assist/types.js'; @@ -105,8 +106,8 @@ const COMPRESSION_PRESERVE_THRESHOLD = 0.3; export class GeminiClient { private chat?: GeminiChat; private contentGenerator?: ContentGenerator; - private embeddingModel: string; - private generateContentConfig: GenerateContentConfig = { + private readonly embeddingModel: string; + private readonly generateContentConfig: GenerateContentConfig = { temperature: 0, topP: 1, }; @@ -117,7 +118,13 @@ export class GeminiClient { private lastSentIdeContext: IdeContext | undefined; private forceFullIdeContext = true; - constructor(private config: Config) { + /** + * At any point in this conversation, was compression triggered without + * being forced and did it fail? + */ + private hasFailedCompressionAttempt = false; + + constructor(private readonly config: Config) { if (config.getProxy()) { setGlobalDispatcher(new ProxyAgent(config.getProxy() as string)); } @@ -219,6 +226,7 @@ export class GeminiClient { async startChat(extraHistory?: Content[]): Promise { this.forceFullIdeContext = true; + this.hasFailedCompressionAttempt = false; const envParts = await getEnvironmentContext(this.config); const toolRegistry = this.config.getToolRegistry(); const toolDeclarations = toolRegistry.getFunctionDeclarations(); @@ -467,7 +475,7 @@ export class GeminiClient { const compressed = await this.tryCompressChat(prompt_id); - if (compressed) { + if (compressed.compressionStatus === CompressionStatus.COMPRESSED) { yield { type: GeminiEventType.ChatCompressed, value: compressed }; } @@ -764,12 +772,19 @@ export class GeminiClient { async tryCompressChat( prompt_id: string, force: boolean = false, - ): Promise { + ): Promise { const curatedHistory = this.getChat().getHistory(true); // Regardless of `force`, don't do anything if the history is empty. - if (curatedHistory.length === 0) { - return null; + if ( + curatedHistory.length === 0 || + (this.hasFailedCompressionAttempt && !force) + ) { + return { + originalTokenCount: 0, + newTokenCount: 0, + compressionStatus: CompressionStatus.NOOP, + }; } const model = this.config.getModel(); @@ -781,7 +796,13 @@ export class GeminiClient { }); if (originalTokenCount === undefined) { console.warn(`Could not determine token count for model ${model}.`); - return null; + this.hasFailedCompressionAttempt = !force && true; + return { + originalTokenCount: 0, + newTokenCount: 0, + compressionStatus: + CompressionStatus.COMPRESSION_FAILED_TOKEN_COUNT_ERROR, + }; } const contextPercentageThreshold = @@ -792,7 +813,11 @@ export class GeminiClient { const threshold = contextPercentageThreshold ?? COMPRESSION_TOKEN_THRESHOLD; if (originalTokenCount < threshold * tokenLimit(model)) { - return null; + return { + originalTokenCount, + newTokenCount: originalTokenCount, + compressionStatus: CompressionStatus.NOOP, + }; } } @@ -821,11 +846,12 @@ export class GeminiClient { }, config: { systemInstruction: { text: getCompressionPrompt() }, + maxOutputTokens: originalTokenCount, }, }, prompt_id, ); - this.chat = await this.startChat([ + const chat = await this.startChat([ { role: 'user', parts: [{ text: summary }], @@ -842,11 +868,17 @@ export class GeminiClient { await this.getContentGenerator().countTokens({ // model might change after calling `sendMessage`, so we get the newest value from config model: this.config.getModel(), - contents: this.getChat().getHistory(), + contents: chat.getHistory(), }); if (newTokenCount === undefined) { console.warn('Could not determine compressed history token count.'); - return null; + this.hasFailedCompressionAttempt = !force && true; + return { + originalTokenCount, + newTokenCount: originalTokenCount, + compressionStatus: + CompressionStatus.COMPRESSION_FAILED_TOKEN_COUNT_ERROR, + }; } logChatCompression( @@ -857,9 +889,23 @@ export class GeminiClient { }), ); + if (newTokenCount > originalTokenCount) { + this.getChat().setHistory(curatedHistory); + this.hasFailedCompressionAttempt = !force && true; + return { + originalTokenCount, + newTokenCount, + compressionStatus: + CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT, + }; + } else { + this.chat = chat; // Chat compression successful, set new state. + } + return { originalTokenCount, newTokenCount, + compressionStatus: CompressionStatus.COMPRESSED, }; } diff --git a/packages/core/src/core/turn.ts b/packages/core/src/core/turn.ts index 6d1aa294..24db4940 100644 --- a/packages/core/src/core/turn.ts +++ b/packages/core/src/core/turn.ts @@ -125,9 +125,24 @@ export type ServerGeminiErrorEvent = { value: GeminiErrorEventValue; }; +export enum CompressionStatus { + /** The compression was successful */ + COMPRESSED = 1, + + /** The compression failed due to the compression inflating the token count */ + COMPRESSION_FAILED_INFLATED_TOKEN_COUNT, + + /** The compression failed due to an error counting tokens */ + COMPRESSION_FAILED_TOKEN_COUNT_ERROR, + + /** The compression was not necessary and no action was taken */ + NOOP, +} + export interface ChatCompressionInfo { originalTokenCount: number; newTokenCount: number; + compressionStatus: CompressionStatus; } export type ServerGeminiChatCompressedEvent = {