Mirror of https://github.com/QwenLM/qwen-code.git (synced 2025-12-19 09:33:53 +00:00)
fix(compression): Discard compression result if it results in more token usage (#7047)
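In short: GeminiClient.tryCompressChat no longer returns ChatCompressionInfo | null. It always resolves to a ChatCompressionInfo whose compressionStatus (a new enum in turn.ts) records whether compression succeeded, was a no-op, or failed, and when the summarized history would cost more tokens than the original, the original history is restored and the result is marked COMPRESSION_FAILED_INFLATED_TOKEN_COUNT. The sketch below condenses that control flow; it is a simplified stand-in (plain string history, injected countTokens/summarize callbacks), not the literal client.ts code, which additionally tracks hasFailedCompressionAttempt and a force flag.

// Simplified sketch of the decision flow this patch introduces (not the literal code).
enum CompressionStatus {
  COMPRESSED = 1,
  COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
  COMPRESSION_FAILED_TOKEN_COUNT_ERROR,
  NOOP,
}

interface ChatCompressionInfo {
  originalTokenCount: number;
  newTokenCount: number;
  compressionStatus: CompressionStatus;
}

// countTokens and summarize stand in for the content generator and the
// compression-prompt round trip in GeminiClient.
async function tryCompress(
  history: string[],
  countTokens: (h: string[]) => Promise<number | undefined>,
  summarize: (h: string[]) => Promise<string>,
): Promise<ChatCompressionInfo> {
  if (history.length === 0) {
    // Nothing to do: report a no-op instead of returning null.
    return { originalTokenCount: 0, newTokenCount: 0, compressionStatus: CompressionStatus.NOOP };
  }

  const originalTokenCount = await countTokens(history);
  if (originalTokenCount === undefined) {
    return {
      originalTokenCount: 0,
      newTokenCount: 0,
      compressionStatus: CompressionStatus.COMPRESSION_FAILED_TOKEN_COUNT_ERROR,
    };
  }

  const compressed = [await summarize(history)];
  const newTokenCount = await countTokens(compressed);
  if (newTokenCount === undefined) {
    return {
      originalTokenCount,
      newTokenCount: originalTokenCount,
      compressionStatus: CompressionStatus.COMPRESSION_FAILED_TOKEN_COUNT_ERROR,
    };
  }

  // The fix: if "compression" made the context larger, keep the original
  // history and report the inflation instead of silently adopting it.
  if (newTokenCount > originalTokenCount) {
    return {
      originalTokenCount,
      newTokenCount,
      compressionStatus: CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
    };
  }

  return { originalTokenCount, newTokenCount, compressionStatus: CompressionStatus.COMPRESSED };
}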
@@ -4,7 +4,11 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
-import type { GeminiClient } from '@google/gemini-cli-core';
+import {
+  CompressionStatus,
+  type ChatCompressionInfo,
+  type GeminiClient,
+} from '@google/gemini-cli-core';
 import { vi, describe, it, expect, beforeEach } from 'vitest';
 import { compressCommand } from './compressCommand.js';
 import { createMockCommandContext } from '../../test-utils/mockCommandContext.js';
@@ -35,6 +39,7 @@ describe('compressCommand', () => {
         isPending: true,
         originalTokenCount: null,
         newTokenCount: null,
+        compressionStatus: null,
       },
     };
     await compressCommand.action!(context, '');
@@ -50,25 +55,24 @@ describe('compressCommand', () => {
   });
 
   it('should set pending item, call tryCompressChat, and add result on success', async () => {
-    const compressedResult = {
+    const compressedResult: ChatCompressionInfo = {
       originalTokenCount: 200,
+      compressionStatus: CompressionStatus.COMPRESSED,
       newTokenCount: 100,
     };
     mockTryCompressChat.mockResolvedValue(compressedResult);
 
     await compressCommand.action!(context, '');
 
-    expect(context.ui.setPendingItem).toHaveBeenNthCalledWith(
-      1,
-      expect.objectContaining({
+    expect(context.ui.setPendingItem).toHaveBeenNthCalledWith(1, {
      type: MessageType.COMPRESSION,
      compression: {
        isPending: true,
+        compressionStatus: null,
        originalTokenCount: null,
        newTokenCount: null,
      },
-      }),
-    );
+    });
 
    expect(mockTryCompressChat).toHaveBeenCalledWith(
      expect.stringMatching(/^compress-\d+$/),
@@ -76,14 +80,15 @@ describe('compressCommand', () => {
     );
 
     expect(context.ui.addItem).toHaveBeenCalledWith(
-      expect.objectContaining({
+      {
        type: MessageType.COMPRESSION,
        compression: {
          isPending: false,
+          compressionStatus: CompressionStatus.COMPRESSED,
          originalTokenCount: 200,
          newTokenCount: 100,
        },
-      }),
+      },
      expect.any(Number),
    );
 
@@ -33,6 +33,7 @@ export const compressCommand: SlashCommand = {
         isPending: true,
         originalTokenCount: null,
         newTokenCount: null,
+        compressionStatus: null,
       },
     };
 
@@ -50,6 +51,7 @@ export const compressCommand: SlashCommand = {
           isPending: false,
           originalTokenCount: compressed.originalTokenCount,
           newTokenCount: compressed.newTokenCount,
+          compressionStatus: compressed.compressionStatus,
         },
       } as HistoryItemCompression,
       Date.now(),
@@ -5,6 +5,7 @@
  */
 
 import type {
+  CompressionStatus,
   ToolCallConfirmationDetails,
   ToolResultDisplay,
 } from '@google/gemini-cli-core';
@@ -56,6 +57,7 @@ export interface CompressionProps {
   isPending: boolean;
   originalTokenCount: number | null;
   newTokenCount: number | null;
+  compressionStatus: CompressionStatus | null;
 }
 
 export interface HistoryItemBase {
@@ -4,21 +4,38 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
-import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
+import {
+  describe,
+  it,
+  expect,
+  vi,
+  beforeEach,
+  afterEach,
+  type Mocked,
+} from 'vitest';
 
 import type {
   Chat,
   Content,
   EmbedContentResponse,
   GenerateContentResponse,
+  Part,
 } from '@google/genai';
 import { GoogleGenAI } from '@google/genai';
 import { findIndexAfterFraction, GeminiClient } from './client.js';
-import type { ContentGenerator } from './contentGenerator.js';
-import { AuthType } from './contentGenerator.js';
-import type { GeminiChat } from './geminiChat.js';
+import {
+  AuthType,
+  type ContentGenerator,
+  type ContentGeneratorConfig,
+} from './contentGenerator.js';
+import { type GeminiChat } from './geminiChat.js';
 import { Config } from '../config/config.js';
-import { GeminiEventType, Turn } from './turn.js';
+import {
+  CompressionStatus,
+  GeminiEventType,
+  Turn,
+  type ChatCompressionInfo,
+} from './turn.js';
 import { getCoreSystemPrompt } from './prompts.js';
 import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
 import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
@@ -34,7 +51,8 @@ const mockEmbedContentFn = vi.fn();
 const mockTurnRunFn = vi.fn();
 
 vi.mock('@google/genai');
-vi.mock('./turn', () => {
+vi.mock('./turn', async (importOriginal) => {
+  const actual = await importOriginal<typeof import('./turn.js')>();
   // Define a mock class that has the same shape as the real Turn
   class MockTurn {
     pendingToolCalls = [];
@@ -47,13 +65,8 @@ vi.mock('./turn', () => {
   }
   // Export the mock class as 'Turn'
   return {
+    ...actual,
     Turn: MockTurn,
-    GeminiEventType: {
-      MaxSessionTurns: 'MaxSessionTurns',
-      ChatCompressed: 'ChatCompressed',
-      Error: 'error',
-      Content: 'content',
-    },
   };
 });
 
@@ -78,6 +91,19 @@ vi.mock('../telemetry/index.js', () => ({
 }));
 vi.mock('../ide/ideContext.js');
 
+/**
+ * Array.fromAsync ponyfill, which will be available in es 2024.
+ *
+ * Buffers an async generator into an array and returns the result.
+ */
+async function fromAsync<T>(promise: AsyncGenerator<T>): Promise<readonly T[]> {
+  const results: T[] = [];
+  for await (const result of promise) {
+    results.push(result);
+  }
+  return results;
+}
+
 describe('findIndexAfterFraction', () => {
   const history: Content[] = [
     { role: 'user', parts: [{ text: 'This is the first message.' }] }, // JSON length: 66
@@ -176,7 +202,7 @@ describe('Gemini Client (client.ts)', () => {
       getTool: vi.fn().mockReturnValue(null),
     };
     const fileService = new FileDiscoveryService('/test/dir');
-    const contentGeneratorConfig = {
+    const contentGeneratorConfig: ContentGeneratorConfig = {
       model: 'test-model',
       apiKey: 'test-key',
       vertexai: false,
@@ -380,7 +406,9 @@ describe('Gemini Client (client.ts)', () => {
     });
 
     it('should allow overriding model and config', async () => {
-      const contents = [{ role: 'user', parts: [{ text: 'hello' }] }];
+      const contents: Content[] = [
+        { role: 'user', parts: [{ text: 'hello' }] },
+      ];
       const schema = { type: 'string' };
       const abortSignal = new AbortController().signal;
       const customModel = 'custom-json-model';
@@ -421,10 +449,10 @@ describe('Gemini Client (client.ts)', () => {
 
   describe('addHistory', () => {
     it('should call chat.addHistory with the provided content', async () => {
-      const mockChat = {
+      const mockChat: Partial<GeminiChat> = {
        addHistory: vi.fn(),
      };
-      client['chat'] = mockChat as unknown as GeminiChat;
+      client['chat'] = mockChat as GeminiChat;
 
      const newContent = {
        role: 'user',
@@ -486,6 +514,139 @@ describe('Gemini Client (client.ts)', () => {
       } as unknown as GeminiChat;
     });
 
+    function setup({
+      chatHistory = [
+        { role: 'user', parts: [{ text: 'Long conversation' }] },
+        { role: 'model', parts: [{ text: 'Long response' }] },
+      ] as Content[],
+    } = {}) {
+      const mockChat: Partial<GeminiChat> = {
+        getHistory: vi.fn().mockReturnValue(chatHistory),
+        setHistory: vi.fn(),
+        sendMessage: vi.fn().mockResolvedValue({ text: 'Summary' }),
+      };
+      const mockCountTokens = vi
+        .fn()
+        .mockResolvedValueOnce({ totalTokens: 1000 })
+        .mockResolvedValueOnce({ totalTokens: 5000 });
+
+      const mockGenerator: Partial<Mocked<ContentGenerator>> = {
+        countTokens: mockCountTokens,
+      };
+
+      client['chat'] = mockChat as GeminiChat;
+      client['contentGenerator'] = mockGenerator as ContentGenerator;
+      client['startChat'] = vi.fn().mockResolvedValue({ ...mockChat });
+
+      return { client, mockChat, mockGenerator };
+    }
+
+    describe('when compression inflates the token count', () => {
+      it('uses the truncated history for compression');
+      it('allows compression to be forced/manual after a failure', async () => {
+        const { client, mockGenerator } = setup();
+        mockGenerator.countTokens?.mockResolvedValue({
+          totalTokens: 1000,
+        });
+        await client.tryCompressChat('prompt-id-4'); // Fails
+        const result = await client.tryCompressChat('prompt-id-4', true);
+
+        expect(result).toEqual({
+          compressionStatus: CompressionStatus.COMPRESSED,
+          newTokenCount: 1000,
+          originalTokenCount: 1000,
+        });
+      });
+
+      it('yields the result even if the compression inflated the tokens', async () => {
+        const { client } = setup();
+        const result = await client.tryCompressChat('prompt-id-4', true);
+
+        expect(result).toEqual({
+          compressionStatus:
+            CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
+          newTokenCount: 5000,
+          originalTokenCount: 1000,
+        });
+      });
+
+      it('does not manipulate the source chat', async () => {
+        const { client, mockChat } = setup();
+        await client.tryCompressChat('prompt-id-4', true);
+
+        expect(client['chat']).toBe(mockChat); // a new chat session was not created
+      });
+
+      it('restores the history back to the original', async () => {
+        vi.mocked(tokenLimit).mockReturnValue(1000);
+        mockCountTokens.mockResolvedValue({
+          totalTokens: 999,
+        });
+
+        const originalHistory: Content[] = [
+          { role: 'user', parts: [{ text: 'what is your wisdom?' }] },
+          { role: 'model', parts: [{ text: 'some wisdom' }] },
+          { role: 'user', parts: [{ text: 'ahh that is a good a wisdom' }] },
+        ];
+
+        const { client } = setup({
+          chatHistory: originalHistory,
+        });
+        const { compressionStatus } =
+          await client.tryCompressChat('prompt-id-4');
+
+        expect(compressionStatus).toBe(
+          CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
+        );
+        expect(client['chat']?.setHistory).toHaveBeenCalledWith(
+          originalHistory,
+        );
+      });
+
+      it('will not attempt to compress context after a failure', async () => {
+        const { client, mockGenerator } = setup();
+        await client.tryCompressChat('prompt-id-4');
+
+        const result = await client.tryCompressChat('prompt-id-5');
+
+        // it counts tokens for {original, compressed} and then never again
+        expect(mockGenerator.countTokens).toHaveBeenCalledTimes(2);
+        expect(result).toEqual({
+          compressionStatus: CompressionStatus.NOOP,
+          newTokenCount: 0,
+          originalTokenCount: 0,
+        });
+      });
+    });
+
+    it('attempts to compress with a maxOutputTokens set to the original token count', async () => {
+      vi.mocked(tokenLimit).mockReturnValue(1000);
+      mockCountTokens.mockResolvedValue({
+        totalTokens: 999,
+      });
+
+      mockGetHistory.mockReturnValue([
+        { role: 'user', parts: [{ text: '...history...' }] },
+      ]);
+
+      // Mock the summary response from the chat
+      mockSendMessage.mockResolvedValue({
+        role: 'model',
+        parts: [{ text: 'This is a summary.' }],
+      });
+
+      await client.tryCompressChat('prompt-id-2', true);
+
+      expect(mockSendMessage).toHaveBeenCalledWith(
+        expect.objectContaining({
+          config: expect.objectContaining({
+            maxOutputTokens: 999,
+          }),
+        }),
+        'prompt-id-2',
+      );
+    });
+
     it('should not trigger summarization if token count is below threshold', async () => {
       const MOCKED_TOKEN_LIMIT = 1000;
       vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT);
@@ -502,7 +663,11 @@ describe('Gemini Client (client.ts)', () => {
       const newChat = client.getChat();
 
       expect(tokenLimit).toHaveBeenCalled();
-      expect(result).toBeNull();
+      expect(result).toEqual({
+        compressionStatus: CompressionStatus.NOOP,
+        newTokenCount: 699,
+        originalTokenCount: 699,
+      });
       expect(newChat).toBe(initialChat);
     });
 
@@ -579,6 +744,7 @@ describe('Gemini Client (client.ts)', () => {
 
       // Assert that summarization happened and returned the correct stats
       expect(result).toEqual({
+        compressionStatus: CompressionStatus.COMPRESSED,
        originalTokenCount,
        newTokenCount,
      });
@@ -631,6 +797,7 @@ describe('Gemini Client (client.ts)', () => {
 
       // Assert that summarization happened and returned the correct stats
       expect(result).toEqual({
+        compressionStatus: CompressionStatus.COMPRESSED,
        originalTokenCount,
        newTokenCount,
      });
@@ -670,6 +837,7 @@ describe('Gemini Client (client.ts)', () => {
       expect(mockSendMessage).toHaveBeenCalled();
 
       expect(result).toEqual({
+        compressionStatus: CompressionStatus.COMPRESSED,
        originalTokenCount,
        newTokenCount,
      });
@@ -727,6 +895,7 @@ describe('Gemini Client (client.ts)', () => {
       });
 
       expect(result).toEqual({
+        compressionStatus: CompressionStatus.COMPRESSED,
        originalTokenCount: 100000,
        newTokenCount: 5000,
      });
@@ -734,6 +903,109 @@ describe('Gemini Client (client.ts)', () => {
   });
 
   describe('sendMessageStream', () => {
+    it('emits a compression event when the context was automatically compressed', async () => {
+      // Arrange
+      const mockStream = (async function* () {
+        yield { type: 'content', value: 'Hello' };
+      })();
+      mockTurnRunFn.mockReturnValue(mockStream);
+
+      const mockChat: Partial<GeminiChat> = {
+        addHistory: vi.fn(),
+        getHistory: vi.fn().mockReturnValue([]),
+      };
+      client['chat'] = mockChat as GeminiChat;
+
+      const mockGenerator: Partial<ContentGenerator> = {
+        countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
+        generateContent: mockGenerateContentFn,
+      };
+      client['contentGenerator'] = mockGenerator as ContentGenerator;
+
+      const compressionInfo: ChatCompressionInfo = {
+        compressionStatus: CompressionStatus.COMPRESSED,
+        originalTokenCount: 1000,
+        newTokenCount: 500,
+      };
+
+      vi.spyOn(client, 'tryCompressChat').mockResolvedValueOnce(
+        compressionInfo,
+      );
+
+      // Act
+      const stream = client.sendMessageStream(
+        [{ text: 'Hi' }],
+        new AbortController().signal,
+        'prompt-id-1',
+      );
+
+      const events = await fromAsync(stream);
+
+      // Assert
+      expect(events).toContainEqual({
+        type: GeminiEventType.ChatCompressed,
+        value: compressionInfo,
+      });
+    });
+
+    it.each([
+      {
+        compressionStatus:
+          CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
+      },
+      { compressionStatus: CompressionStatus.NOOP },
+      {
+        compressionStatus:
+          CompressionStatus.COMPRESSION_FAILED_TOKEN_COUNT_ERROR,
+      },
+    ])(
+      'does not emit a compression event when the status is $compressionStatus',
+      async ({ compressionStatus }) => {
+        // Arrange
+        const mockStream = (async function* () {
+          yield { type: 'content', value: 'Hello' };
+        })();
+        mockTurnRunFn.mockReturnValue(mockStream);
+
+        const mockChat: Partial<GeminiChat> = {
+          addHistory: vi.fn(),
+          getHistory: vi.fn().mockReturnValue([]),
+        };
+        client['chat'] = mockChat as GeminiChat;
+
+        const mockGenerator: Partial<ContentGenerator> = {
+          countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
+          generateContent: mockGenerateContentFn,
+        };
+        client['contentGenerator'] = mockGenerator as ContentGenerator;
+
+        const compressionInfo: ChatCompressionInfo = {
+          compressionStatus,
+          originalTokenCount: 1000,
+          newTokenCount: 500,
+        };
+
+        vi.spyOn(client, 'tryCompressChat').mockResolvedValueOnce(
+          compressionInfo,
+        );
+
+        // Act
+        const stream = client.sendMessageStream(
+          [{ text: 'Hi' }],
+          new AbortController().signal,
+          'prompt-id-1',
+        );
+
+        const events = await fromAsync(stream);
+
+        // Assert
+        expect(events).not.toContainEqual({
+          type: GeminiEventType.ChatCompressed,
+          value: expect.anything(),
+        });
+      },
+    );
+
     it('should include editor context when ideMode is enabled', async () => {
       // Arrange
       vi.mocked(ideContext.getIdeContext).mockReturnValue({
@@ -777,7 +1049,7 @@ describe('Gemini Client (client.ts)', () => {
       };
       client['contentGenerator'] = mockGenerator as ContentGenerator;
 
-      const initialRequest = [{ text: 'Hi' }];
+      const initialRequest: Part[] = [{ text: 'Hi' }];
 
      // Act
      const stream = client.sendMessageStream(
@@ -1292,7 +1564,11 @@ ${JSON.stringify(
 
     beforeEach(() => {
       client['forceFullIdeContext'] = false; // Reset before each delta test
-      vi.spyOn(client, 'tryCompressChat').mockResolvedValue(null);
+      vi.spyOn(client, 'tryCompressChat').mockResolvedValue({
+        originalTokenCount: 0,
+        newTokenCount: 0,
+        compressionStatus: CompressionStatus.COMPRESSED,
+      });
      vi.spyOn(client['config'], 'getIdeMode').mockReturnValue(true);
      mockTurnRunFn.mockReturnValue(mockStream);
 
@@ -1551,7 +1827,11 @@ ${JSON.stringify(
     let mockChat: Partial<GeminiChat>;
 
     beforeEach(() => {
-      vi.spyOn(client, 'tryCompressChat').mockResolvedValue(null);
+      vi.spyOn(client, 'tryCompressChat').mockResolvedValue({
+        originalTokenCount: 0,
+        newTokenCount: 0,
+        compressionStatus: CompressionStatus.COMPRESSED,
+      });
 
      const mockStream = (async function* () {
        yield { type: 'content', value: 'response' };
@@ -17,6 +17,7 @@ import {
   getEnvironmentContext,
 } from '../utils/environmentContext.js';
 import type { ServerGeminiStreamEvent, ChatCompressionInfo } from './turn.js';
+import { CompressionStatus } from './turn.js';
 import { Turn, GeminiEventType } from './turn.js';
 import type { Config } from '../config/config.js';
 import type { UserTierId } from '../code_assist/types.js';
@@ -105,8 +106,8 @@ const COMPRESSION_PRESERVE_THRESHOLD = 0.3;
 export class GeminiClient {
   private chat?: GeminiChat;
   private contentGenerator?: ContentGenerator;
-  private embeddingModel: string;
-  private generateContentConfig: GenerateContentConfig = {
+  private readonly embeddingModel: string;
+  private readonly generateContentConfig: GenerateContentConfig = {
     temperature: 0,
     topP: 1,
   };
@@ -117,7 +118,13 @@ export class GeminiClient {
   private lastSentIdeContext: IdeContext | undefined;
   private forceFullIdeContext = true;
 
-  constructor(private config: Config) {
+  /**
+   * At any point in this conversation, was compression triggered without
+   * being forced and did it fail?
+   */
+  private hasFailedCompressionAttempt = false;
+
+  constructor(private readonly config: Config) {
     if (config.getProxy()) {
       setGlobalDispatcher(new ProxyAgent(config.getProxy() as string));
     }
@@ -219,6 +226,7 @@ export class GeminiClient {
 
   async startChat(extraHistory?: Content[]): Promise<GeminiChat> {
     this.forceFullIdeContext = true;
+    this.hasFailedCompressionAttempt = false;
     const envParts = await getEnvironmentContext(this.config);
     const toolRegistry = this.config.getToolRegistry();
     const toolDeclarations = toolRegistry.getFunctionDeclarations();
@@ -467,7 +475,7 @@ export class GeminiClient {
 
       const compressed = await this.tryCompressChat(prompt_id);
 
-      if (compressed) {
+      if (compressed.compressionStatus === CompressionStatus.COMPRESSED) {
        yield { type: GeminiEventType.ChatCompressed, value: compressed };
      }
 
@@ -764,12 +772,19 @@ export class GeminiClient {
   async tryCompressChat(
     prompt_id: string,
     force: boolean = false,
-  ): Promise<ChatCompressionInfo | null> {
+  ): Promise<ChatCompressionInfo> {
     const curatedHistory = this.getChat().getHistory(true);
 
     // Regardless of `force`, don't do anything if the history is empty.
-    if (curatedHistory.length === 0) {
-      return null;
+    if (
+      curatedHistory.length === 0 ||
+      (this.hasFailedCompressionAttempt && !force)
+    ) {
+      return {
+        originalTokenCount: 0,
+        newTokenCount: 0,
+        compressionStatus: CompressionStatus.NOOP,
+      };
     }
 
     const model = this.config.getModel();
@@ -781,7 +796,13 @@ export class GeminiClient {
     });
     if (originalTokenCount === undefined) {
       console.warn(`Could not determine token count for model ${model}.`);
-      return null;
+      this.hasFailedCompressionAttempt = !force && true;
+      return {
+        originalTokenCount: 0,
+        newTokenCount: 0,
+        compressionStatus:
+          CompressionStatus.COMPRESSION_FAILED_TOKEN_COUNT_ERROR,
+      };
     }
 
     const contextPercentageThreshold =
@@ -792,7 +813,11 @@ export class GeminiClient {
       const threshold =
         contextPercentageThreshold ?? COMPRESSION_TOKEN_THRESHOLD;
       if (originalTokenCount < threshold * tokenLimit(model)) {
-        return null;
+        return {
+          originalTokenCount,
+          newTokenCount: originalTokenCount,
+          compressionStatus: CompressionStatus.NOOP,
+        };
       }
     }
 
@@ -821,11 +846,12 @@ export class GeminiClient {
         },
         config: {
           systemInstruction: { text: getCompressionPrompt() },
+          maxOutputTokens: originalTokenCount,
         },
       },
       prompt_id,
     );
-    this.chat = await this.startChat([
+    const chat = await this.startChat([
       {
         role: 'user',
         parts: [{ text: summary }],
@@ -842,11 +868,17 @@ export class GeminiClient {
       await this.getContentGenerator().countTokens({
         // model might change after calling `sendMessage`, so we get the newest value from config
         model: this.config.getModel(),
-        contents: this.getChat().getHistory(),
+        contents: chat.getHistory(),
       });
     if (newTokenCount === undefined) {
       console.warn('Could not determine compressed history token count.');
-      return null;
+      this.hasFailedCompressionAttempt = !force && true;
+      return {
+        originalTokenCount,
+        newTokenCount: originalTokenCount,
+        compressionStatus:
+          CompressionStatus.COMPRESSION_FAILED_TOKEN_COUNT_ERROR,
+      };
     }
 
     logChatCompression(
@@ -857,9 +889,23 @@ export class GeminiClient {
       }),
     );
 
+    if (newTokenCount > originalTokenCount) {
+      this.getChat().setHistory(curatedHistory);
+      this.hasFailedCompressionAttempt = !force && true;
+      return {
+        originalTokenCount,
+        newTokenCount,
+        compressionStatus:
+          CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
+      };
+    } else {
+      this.chat = chat; // Chat compression successful, set new state.
+    }
+
     return {
       originalTokenCount,
       newTokenCount,
+      compressionStatus: CompressionStatus.COMPRESSED,
     };
   }
 
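Because tryCompressChat now always resolves to a ChatCompressionInfo, callers branch on the status instead of a null check. A hypothetical caller, assuming a GeminiClient instance `client` and a prompt id are in scope (names here are illustrative, not from the patch):

// Hypothetical usage; mirrors the COMPRESSED guard sendMessageStream applies above.
const info = await client.tryCompressChat(promptId, /* force */ true);
if (info.compressionStatus === CompressionStatus.COMPRESSED) {
  console.log(`Compressed ${info.originalTokenCount} -> ${info.newTokenCount} tokens.`);
} else {
  // NOOP or a failure status: the original history is still in place.
  console.log(`Compression skipped or failed (status ${CompressionStatus[info.compressionStatus]}).`);
}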
@@ -125,9 +125,24 @@ export type ServerGeminiErrorEvent = {
   value: GeminiErrorEventValue;
 };
 
+export enum CompressionStatus {
+  /** The compression was successful */
+  COMPRESSED = 1,
+
+  /** The compression failed due to the compression inflating the token count */
+  COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
+
+  /** The compression failed due to an error counting tokens */
+  COMPRESSION_FAILED_TOKEN_COUNT_ERROR,
+
+  /** The compression was not necessary and no action was taken */
+  NOOP,
+}
+
 export interface ChatCompressionInfo {
   originalTokenCount: number;
   newTokenCount: number;
+  compressionStatus: CompressionStatus;
 }
 
 export type ServerGeminiChatCompressedEvent = {
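One benefit of encoding the outcome as an enum instead of null is that display code can handle every case exhaustively. A hypothetical rendering helper (not part of this patch) built on the types above:

function describeCompression(info: ChatCompressionInfo): string {
  switch (info.compressionStatus) {
    case CompressionStatus.COMPRESSED:
      return `Compressed chat history from ${info.originalTokenCount} to ${info.newTokenCount} tokens.`;
    case CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT:
      return 'Compression discarded: it would have increased token usage.';
    case CompressionStatus.COMPRESSION_FAILED_TOKEN_COUNT_ERROR:
      return 'Compression failed: token counts could not be determined.';
    case CompressionStatus.NOOP:
      return 'No compression was necessary.';
  }
}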