fix(compression): Discard compression result if it results in more token usage (#7047)

This commit is contained in:
Richie Foreman
2025-08-27 17:00:45 -04:00
committed by GitHub
parent da7901acaf
commit cd2e237c73
6 changed files with 397 additions and 47 deletions

View File

@@ -4,21 +4,38 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
import {
describe,
it,
expect,
vi,
beforeEach,
afterEach,
type Mocked,
} from 'vitest';
import type {
Chat,
Content,
EmbedContentResponse,
GenerateContentResponse,
Part,
} from '@google/genai';
import { GoogleGenAI } from '@google/genai';
import { findIndexAfterFraction, GeminiClient } from './client.js';
import type { ContentGenerator } from './contentGenerator.js';
import { AuthType } from './contentGenerator.js';
import type { GeminiChat } from './geminiChat.js';
import {
AuthType,
type ContentGenerator,
type ContentGeneratorConfig,
} from './contentGenerator.js';
import { type GeminiChat } from './geminiChat.js';
import { Config } from '../config/config.js';
import { GeminiEventType, Turn } from './turn.js';
import {
CompressionStatus,
GeminiEventType,
Turn,
type ChatCompressionInfo,
} from './turn.js';
import { getCoreSystemPrompt } from './prompts.js';
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
@@ -34,7 +51,8 @@ const mockEmbedContentFn = vi.fn();
const mockTurnRunFn = vi.fn();
vi.mock('@google/genai');
vi.mock('./turn', () => {
vi.mock('./turn', async (importOriginal) => {
const actual = await importOriginal<typeof import('./turn.js')>();
// Define a mock class that has the same shape as the real Turn
class MockTurn {
pendingToolCalls = [];
@@ -47,13 +65,8 @@ vi.mock('./turn', () => {
}
// Export the mock class as 'Turn'
return {
...actual,
Turn: MockTurn,
GeminiEventType: {
MaxSessionTurns: 'MaxSessionTurns',
ChatCompressed: 'ChatCompressed',
Error: 'error',
Content: 'content',
},
};
});
@@ -78,6 +91,19 @@ vi.mock('../telemetry/index.js', () => ({
}));
vi.mock('../ide/ideContext.js');
/**
 * `Array.fromAsync` ponyfill (standardized in ES2024).
 *
 * Buffers an async generator into an array and returns the result.
 *
 * @param generator the async generator to drain to completion
 * @returns every value the generator yields, in yield order
 */
async function fromAsync<T>(
  generator: AsyncGenerator<T>,
): Promise<readonly T[]> {
  const results: T[] = [];
  for await (const result of generator) {
    results.push(result);
  }
  return results;
}
describe('findIndexAfterFraction', () => {
const history: Content[] = [
{ role: 'user', parts: [{ text: 'This is the first message.' }] }, // JSON length: 66
@@ -176,7 +202,7 @@ describe('Gemini Client (client.ts)', () => {
getTool: vi.fn().mockReturnValue(null),
};
const fileService = new FileDiscoveryService('/test/dir');
const contentGeneratorConfig = {
const contentGeneratorConfig: ContentGeneratorConfig = {
model: 'test-model',
apiKey: 'test-key',
vertexai: false,
@@ -380,7 +406,9 @@ describe('Gemini Client (client.ts)', () => {
});
it('should allow overriding model and config', async () => {
const contents = [{ role: 'user', parts: [{ text: 'hello' }] }];
const contents: Content[] = [
{ role: 'user', parts: [{ text: 'hello' }] },
];
const schema = { type: 'string' };
const abortSignal = new AbortController().signal;
const customModel = 'custom-json-model';
@@ -421,10 +449,10 @@ describe('Gemini Client (client.ts)', () => {
describe('addHistory', () => {
it('should call chat.addHistory with the provided content', async () => {
const mockChat = {
const mockChat: Partial<GeminiChat> = {
addHistory: vi.fn(),
};
client['chat'] = mockChat as unknown as GeminiChat;
client['chat'] = mockChat as GeminiChat;
const newContent = {
role: 'user',
@@ -486,6 +514,139 @@ describe('Gemini Client (client.ts)', () => {
} as unknown as GeminiChat;
});
/**
 * Installs stubbed chat and content-generator state on the shared `client`.
 *
 * The token counter resolves 1000 on its first call and 5000 on its
 * second, so by default a compression attempt appears to inflate the
 * token count.
 */
function setup({
  chatHistory = [
    { role: 'user', parts: [{ text: 'Long conversation' }] },
    { role: 'model', parts: [{ text: 'Long response' }] },
  ] as Content[],
} = {}) {
  const chatStub: Partial<GeminiChat> = {
    getHistory: vi.fn().mockReturnValue(chatHistory),
    setHistory: vi.fn(),
    sendMessage: vi.fn().mockResolvedValue({ text: 'Summary' }),
  };
  // First resolution covers the original history, second the compressed one.
  const countTokensStub = vi
    .fn()
    .mockResolvedValueOnce({ totalTokens: 1000 })
    .mockResolvedValueOnce({ totalTokens: 5000 });
  const generatorStub: Partial<Mocked<ContentGenerator>> = {
    countTokens: countTokensStub,
  };
  client['chat'] = chatStub as GeminiChat;
  client['contentGenerator'] = generatorStub as ContentGenerator;
  client['startChat'] = vi.fn().mockResolvedValue({ ...chatStub });
  return { client, mockChat: chatStub, mockGenerator: generatorStub };
}
describe('when compression inflates the token count', () => {
it('uses the truncated history for compression');
it('allows compression to be forced/manual after a failure', async () => {
const { client, mockGenerator } = setup();
mockGenerator.countTokens?.mockResolvedValue({
totalTokens: 1000,
});
await client.tryCompressChat('prompt-id-4'); // Fails
const result = await client.tryCompressChat('prompt-id-4', true);
expect(result).toEqual({
compressionStatus: CompressionStatus.COMPRESSED,
newTokenCount: 1000,
originalTokenCount: 1000,
});
});
it('yields the result even if the compression inflated the tokens', async () => {
const { client } = setup();
const result = await client.tryCompressChat('prompt-id-4', true);
expect(result).toEqual({
compressionStatus:
CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
newTokenCount: 5000,
originalTokenCount: 1000,
});
});
it('does not manipulate the source chat', async () => {
const { client, mockChat } = setup();
await client.tryCompressChat('prompt-id-4', true);
expect(client['chat']).toBe(mockChat); // a new chat session was not created
});
it('restores the history back to the original', async () => {
vi.mocked(tokenLimit).mockReturnValue(1000);
mockCountTokens.mockResolvedValue({
totalTokens: 999,
});
const originalHistory: Content[] = [
{ role: 'user', parts: [{ text: 'what is your wisdom?' }] },
{ role: 'model', parts: [{ text: 'some wisdom' }] },
{ role: 'user', parts: [{ text: 'ahh that is a good a wisdom' }] },
];
const { client } = setup({
chatHistory: originalHistory,
});
const { compressionStatus } =
await client.tryCompressChat('prompt-id-4');
expect(compressionStatus).toBe(
CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
);
expect(client['chat']?.setHistory).toHaveBeenCalledWith(
originalHistory,
);
});
it('will not attempt to compress context after a failure', async () => {
const { client, mockGenerator } = setup();
await client.tryCompressChat('prompt-id-4');
const result = await client.tryCompressChat('prompt-id-5');
// it counts tokens for {original, compressed} and then never again
expect(mockGenerator.countTokens).toHaveBeenCalledTimes(2);
expect(result).toEqual({
compressionStatus: CompressionStatus.NOOP,
newTokenCount: 0,
originalTokenCount: 0,
});
});
});
it('attempts to compress with a maxOutputTokens set to the original token count', async () => {
vi.mocked(tokenLimit).mockReturnValue(1000);
mockCountTokens.mockResolvedValue({
totalTokens: 999,
});
mockGetHistory.mockReturnValue([
{ role: 'user', parts: [{ text: '...history...' }] },
]);
// Mock the summary response from the chat
mockSendMessage.mockResolvedValue({
role: 'model',
parts: [{ text: 'This is a summary.' }],
});
await client.tryCompressChat('prompt-id-2', true);
expect(mockSendMessage).toHaveBeenCalledWith(
expect.objectContaining({
config: expect.objectContaining({
maxOutputTokens: 999,
}),
}),
'prompt-id-2',
);
});
it('should not trigger summarization if token count is below threshold', async () => {
const MOCKED_TOKEN_LIMIT = 1000;
vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT);
@@ -502,7 +663,11 @@ describe('Gemini Client (client.ts)', () => {
const newChat = client.getChat();
expect(tokenLimit).toHaveBeenCalled();
expect(result).toBeNull();
expect(result).toEqual({
compressionStatus: CompressionStatus.NOOP,
newTokenCount: 699,
originalTokenCount: 699,
});
expect(newChat).toBe(initialChat);
});
@@ -579,6 +744,7 @@ describe('Gemini Client (client.ts)', () => {
// Assert that summarization happened and returned the correct stats
expect(result).toEqual({
compressionStatus: CompressionStatus.COMPRESSED,
originalTokenCount,
newTokenCount,
});
@@ -631,6 +797,7 @@ describe('Gemini Client (client.ts)', () => {
// Assert that summarization happened and returned the correct stats
expect(result).toEqual({
compressionStatus: CompressionStatus.COMPRESSED,
originalTokenCount,
newTokenCount,
});
@@ -670,6 +837,7 @@ describe('Gemini Client (client.ts)', () => {
expect(mockSendMessage).toHaveBeenCalled();
expect(result).toEqual({
compressionStatus: CompressionStatus.COMPRESSED,
originalTokenCount,
newTokenCount,
});
@@ -727,6 +895,7 @@ describe('Gemini Client (client.ts)', () => {
});
expect(result).toEqual({
compressionStatus: CompressionStatus.COMPRESSED,
originalTokenCount: 100000,
newTokenCount: 5000,
});
@@ -734,6 +903,109 @@ describe('Gemini Client (client.ts)', () => {
});
describe('sendMessageStream', () => {
it('emits a compression event when the context was automatically compressed', async () => {
// Arrange
const mockStream = (async function* () {
yield { type: 'content', value: 'Hello' };
})();
mockTurnRunFn.mockReturnValue(mockStream);
const mockChat: Partial<GeminiChat> = {
addHistory: vi.fn(),
getHistory: vi.fn().mockReturnValue([]),
};
client['chat'] = mockChat as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
const compressionInfo: ChatCompressionInfo = {
compressionStatus: CompressionStatus.COMPRESSED,
originalTokenCount: 1000,
newTokenCount: 500,
};
vi.spyOn(client, 'tryCompressChat').mockResolvedValueOnce(
compressionInfo,
);
// Act
const stream = client.sendMessageStream(
[{ text: 'Hi' }],
new AbortController().signal,
'prompt-id-1',
);
const events = await fromAsync(stream);
// Assert
expect(events).toContainEqual({
type: GeminiEventType.ChatCompressed,
value: compressionInfo,
});
});
it.each([
{
compressionStatus:
CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
},
{ compressionStatus: CompressionStatus.NOOP },
{
compressionStatus:
CompressionStatus.COMPRESSION_FAILED_TOKEN_COUNT_ERROR,
},
])(
'does not emit a compression event when the status is $compressionStatus',
async ({ compressionStatus }) => {
// Arrange
const mockStream = (async function* () {
yield { type: 'content', value: 'Hello' };
})();
mockTurnRunFn.mockReturnValue(mockStream);
const mockChat: Partial<GeminiChat> = {
addHistory: vi.fn(),
getHistory: vi.fn().mockReturnValue([]),
};
client['chat'] = mockChat as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
const compressionInfo: ChatCompressionInfo = {
compressionStatus,
originalTokenCount: 1000,
newTokenCount: 500,
};
vi.spyOn(client, 'tryCompressChat').mockResolvedValueOnce(
compressionInfo,
);
// Act
const stream = client.sendMessageStream(
[{ text: 'Hi' }],
new AbortController().signal,
'prompt-id-1',
);
const events = await fromAsync(stream);
// Assert
expect(events).not.toContainEqual({
type: GeminiEventType.ChatCompressed,
value: expect.anything(),
});
},
);
it('should include editor context when ideMode is enabled', async () => {
// Arrange
vi.mocked(ideContext.getIdeContext).mockReturnValue({
@@ -777,7 +1049,7 @@ describe('Gemini Client (client.ts)', () => {
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
const initialRequest = [{ text: 'Hi' }];
const initialRequest: Part[] = [{ text: 'Hi' }];
// Act
const stream = client.sendMessageStream(
@@ -1292,7 +1564,11 @@ ${JSON.stringify(
beforeEach(() => {
client['forceFullIdeContext'] = false; // Reset before each delta test
vi.spyOn(client, 'tryCompressChat').mockResolvedValue(null);
vi.spyOn(client, 'tryCompressChat').mockResolvedValue({
originalTokenCount: 0,
newTokenCount: 0,
compressionStatus: CompressionStatus.COMPRESSED,
});
vi.spyOn(client['config'], 'getIdeMode').mockReturnValue(true);
mockTurnRunFn.mockReturnValue(mockStream);
@@ -1551,7 +1827,11 @@ ${JSON.stringify(
let mockChat: Partial<GeminiChat>;
beforeEach(() => {
vi.spyOn(client, 'tryCompressChat').mockResolvedValue(null);
vi.spyOn(client, 'tryCompressChat').mockResolvedValue({
originalTokenCount: 0,
newTokenCount: 0,
compressionStatus: CompressionStatus.COMPRESSED,
});
const mockStream = (async function* () {
yield { type: 'content', value: 'response' };

View File

@@ -17,6 +17,7 @@ import {
getEnvironmentContext,
} from '../utils/environmentContext.js';
import type { ServerGeminiStreamEvent, ChatCompressionInfo } from './turn.js';
import { CompressionStatus } from './turn.js';
import { Turn, GeminiEventType } from './turn.js';
import type { Config } from '../config/config.js';
import type { UserTierId } from '../code_assist/types.js';
@@ -105,8 +106,8 @@ const COMPRESSION_PRESERVE_THRESHOLD = 0.3;
export class GeminiClient {
private chat?: GeminiChat;
private contentGenerator?: ContentGenerator;
private embeddingModel: string;
private generateContentConfig: GenerateContentConfig = {
private readonly embeddingModel: string;
private readonly generateContentConfig: GenerateContentConfig = {
temperature: 0,
topP: 1,
};
@@ -117,7 +118,13 @@ export class GeminiClient {
private lastSentIdeContext: IdeContext | undefined;
private forceFullIdeContext = true;
constructor(private config: Config) {
/**
* At any point in this conversation, was compression triggered without
* being forced and did it fail?
*/
private hasFailedCompressionAttempt = false;
constructor(private readonly config: Config) {
if (config.getProxy()) {
setGlobalDispatcher(new ProxyAgent(config.getProxy() as string));
}
@@ -219,6 +226,7 @@ export class GeminiClient {
async startChat(extraHistory?: Content[]): Promise<GeminiChat> {
this.forceFullIdeContext = true;
this.hasFailedCompressionAttempt = false;
const envParts = await getEnvironmentContext(this.config);
const toolRegistry = this.config.getToolRegistry();
const toolDeclarations = toolRegistry.getFunctionDeclarations();
@@ -467,7 +475,7 @@ export class GeminiClient {
const compressed = await this.tryCompressChat(prompt_id);
if (compressed) {
if (compressed.compressionStatus === CompressionStatus.COMPRESSED) {
yield { type: GeminiEventType.ChatCompressed, value: compressed };
}
@@ -764,12 +772,19 @@ export class GeminiClient {
async tryCompressChat(
prompt_id: string,
force: boolean = false,
): Promise<ChatCompressionInfo | null> {
): Promise<ChatCompressionInfo> {
const curatedHistory = this.getChat().getHistory(true);
// Regardless of `force`, don't do anything if the history is empty.
if (curatedHistory.length === 0) {
return null;
if (
curatedHistory.length === 0 ||
(this.hasFailedCompressionAttempt && !force)
) {
return {
originalTokenCount: 0,
newTokenCount: 0,
compressionStatus: CompressionStatus.NOOP,
};
}
const model = this.config.getModel();
@@ -781,7 +796,13 @@ export class GeminiClient {
});
if (originalTokenCount === undefined) {
console.warn(`Could not determine token count for model ${model}.`);
return null;
this.hasFailedCompressionAttempt = !force && true;
return {
originalTokenCount: 0,
newTokenCount: 0,
compressionStatus:
CompressionStatus.COMPRESSION_FAILED_TOKEN_COUNT_ERROR,
};
}
const contextPercentageThreshold =
@@ -792,7 +813,11 @@ export class GeminiClient {
const threshold =
contextPercentageThreshold ?? COMPRESSION_TOKEN_THRESHOLD;
if (originalTokenCount < threshold * tokenLimit(model)) {
return null;
return {
originalTokenCount,
newTokenCount: originalTokenCount,
compressionStatus: CompressionStatus.NOOP,
};
}
}
@@ -821,11 +846,12 @@ export class GeminiClient {
},
config: {
systemInstruction: { text: getCompressionPrompt() },
maxOutputTokens: originalTokenCount,
},
},
prompt_id,
);
this.chat = await this.startChat([
const chat = await this.startChat([
{
role: 'user',
parts: [{ text: summary }],
@@ -842,11 +868,17 @@ export class GeminiClient {
await this.getContentGenerator().countTokens({
// model might change after calling `sendMessage`, so we get the newest value from config
model: this.config.getModel(),
contents: this.getChat().getHistory(),
contents: chat.getHistory(),
});
if (newTokenCount === undefined) {
console.warn('Could not determine compressed history token count.');
return null;
this.hasFailedCompressionAttempt = !force && true;
return {
originalTokenCount,
newTokenCount: originalTokenCount,
compressionStatus:
CompressionStatus.COMPRESSION_FAILED_TOKEN_COUNT_ERROR,
};
}
logChatCompression(
@@ -857,9 +889,23 @@ export class GeminiClient {
}),
);
if (newTokenCount > originalTokenCount) {
this.getChat().setHistory(curatedHistory);
this.hasFailedCompressionAttempt = !force && true;
return {
originalTokenCount,
newTokenCount,
compressionStatus:
CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
};
} else {
this.chat = chat; // Chat compression successful, set new state.
}
return {
originalTokenCount,
newTokenCount,
compressionStatus: CompressionStatus.COMPRESSED,
};
}

View File

@@ -125,9 +125,24 @@ export type ServerGeminiErrorEvent = {
value: GeminiErrorEventValue;
};
/**
 * The outcome of a chat-history compression attempt.
 *
 * Numeric values start at 1, so every member is truthy.
 */
export enum CompressionStatus {
  /** The compression was successful */
  COMPRESSED = 1,
  /** The compression failed due to the compression inflating the token count */
  COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
  /** The compression failed due to an error counting tokens */
  COMPRESSION_FAILED_TOKEN_COUNT_ERROR,
  /** The compression was not necessary and no action was taken */
  NOOP,
}
/** Token-count bookkeeping for a single chat-compression attempt. */
export interface ChatCompressionInfo {
  /** Token count of the chat history before compression was attempted. */
  originalTokenCount: number;
  /** Token count after the attempt (may exceed the original if compression inflated the history). */
  newTokenCount: number;
  /** How the attempt concluded. */
  compressionStatus: CompressionStatus;
}
export type ServerGeminiChatCompressedEvent = {