mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-19 09:33:53 +00:00
fix(compression): Discard compression result if it results in more token usage (#7047)
This commit is contained in:
@@ -4,7 +4,11 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { GeminiClient } from '@google/gemini-cli-core';
|
||||
import {
|
||||
CompressionStatus,
|
||||
type ChatCompressionInfo,
|
||||
type GeminiClient,
|
||||
} from '@google/gemini-cli-core';
|
||||
import { vi, describe, it, expect, beforeEach } from 'vitest';
|
||||
import { compressCommand } from './compressCommand.js';
|
||||
import { createMockCommandContext } from '../../test-utils/mockCommandContext.js';
|
||||
@@ -35,6 +39,7 @@ describe('compressCommand', () => {
|
||||
isPending: true,
|
||||
originalTokenCount: null,
|
||||
newTokenCount: null,
|
||||
compressionStatus: null,
|
||||
},
|
||||
};
|
||||
await compressCommand.action!(context, '');
|
||||
@@ -50,25 +55,24 @@ describe('compressCommand', () => {
|
||||
});
|
||||
|
||||
it('should set pending item, call tryCompressChat, and add result on success', async () => {
|
||||
const compressedResult = {
|
||||
const compressedResult: ChatCompressionInfo = {
|
||||
originalTokenCount: 200,
|
||||
compressionStatus: CompressionStatus.COMPRESSED,
|
||||
newTokenCount: 100,
|
||||
};
|
||||
mockTryCompressChat.mockResolvedValue(compressedResult);
|
||||
|
||||
await compressCommand.action!(context, '');
|
||||
|
||||
expect(context.ui.setPendingItem).toHaveBeenNthCalledWith(
|
||||
1,
|
||||
expect.objectContaining({
|
||||
type: MessageType.COMPRESSION,
|
||||
compression: {
|
||||
isPending: true,
|
||||
originalTokenCount: null,
|
||||
newTokenCount: null,
|
||||
},
|
||||
}),
|
||||
);
|
||||
expect(context.ui.setPendingItem).toHaveBeenNthCalledWith(1, {
|
||||
type: MessageType.COMPRESSION,
|
||||
compression: {
|
||||
isPending: true,
|
||||
compressionStatus: null,
|
||||
originalTokenCount: null,
|
||||
newTokenCount: null,
|
||||
},
|
||||
});
|
||||
|
||||
expect(mockTryCompressChat).toHaveBeenCalledWith(
|
||||
expect.stringMatching(/^compress-\d+$/),
|
||||
@@ -76,14 +80,15 @@ describe('compressCommand', () => {
|
||||
);
|
||||
|
||||
expect(context.ui.addItem).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
{
|
||||
type: MessageType.COMPRESSION,
|
||||
compression: {
|
||||
isPending: false,
|
||||
compressionStatus: CompressionStatus.COMPRESSED,
|
||||
originalTokenCount: 200,
|
||||
newTokenCount: 100,
|
||||
},
|
||||
}),
|
||||
},
|
||||
expect.any(Number),
|
||||
);
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ export const compressCommand: SlashCommand = {
|
||||
isPending: true,
|
||||
originalTokenCount: null,
|
||||
newTokenCount: null,
|
||||
compressionStatus: null,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -50,6 +51,7 @@ export const compressCommand: SlashCommand = {
|
||||
isPending: false,
|
||||
originalTokenCount: compressed.originalTokenCount,
|
||||
newTokenCount: compressed.newTokenCount,
|
||||
compressionStatus: compressed.compressionStatus,
|
||||
},
|
||||
} as HistoryItemCompression,
|
||||
Date.now(),
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
*/
|
||||
|
||||
import type {
|
||||
CompressionStatus,
|
||||
ToolCallConfirmationDetails,
|
||||
ToolResultDisplay,
|
||||
} from '@google/gemini-cli-core';
|
||||
@@ -56,6 +57,7 @@ export interface CompressionProps {
|
||||
isPending: boolean;
|
||||
originalTokenCount: number | null;
|
||||
newTokenCount: number | null;
|
||||
compressionStatus: CompressionStatus | null;
|
||||
}
|
||||
|
||||
export interface HistoryItemBase {
|
||||
|
||||
@@ -4,21 +4,38 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
|
||||
import {
|
||||
describe,
|
||||
it,
|
||||
expect,
|
||||
vi,
|
||||
beforeEach,
|
||||
afterEach,
|
||||
type Mocked,
|
||||
} from 'vitest';
|
||||
|
||||
import type {
|
||||
Chat,
|
||||
Content,
|
||||
EmbedContentResponse,
|
||||
GenerateContentResponse,
|
||||
Part,
|
||||
} from '@google/genai';
|
||||
import { GoogleGenAI } from '@google/genai';
|
||||
import { findIndexAfterFraction, GeminiClient } from './client.js';
|
||||
import type { ContentGenerator } from './contentGenerator.js';
|
||||
import { AuthType } from './contentGenerator.js';
|
||||
import type { GeminiChat } from './geminiChat.js';
|
||||
import {
|
||||
AuthType,
|
||||
type ContentGenerator,
|
||||
type ContentGeneratorConfig,
|
||||
} from './contentGenerator.js';
|
||||
import { type GeminiChat } from './geminiChat.js';
|
||||
import { Config } from '../config/config.js';
|
||||
import { GeminiEventType, Turn } from './turn.js';
|
||||
import {
|
||||
CompressionStatus,
|
||||
GeminiEventType,
|
||||
Turn,
|
||||
type ChatCompressionInfo,
|
||||
} from './turn.js';
|
||||
import { getCoreSystemPrompt } from './prompts.js';
|
||||
import { DEFAULT_GEMINI_FLASH_MODEL } from '../config/models.js';
|
||||
import { FileDiscoveryService } from '../services/fileDiscoveryService.js';
|
||||
@@ -34,7 +51,8 @@ const mockEmbedContentFn = vi.fn();
|
||||
const mockTurnRunFn = vi.fn();
|
||||
|
||||
vi.mock('@google/genai');
|
||||
vi.mock('./turn', () => {
|
||||
vi.mock('./turn', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('./turn.js')>();
|
||||
// Define a mock class that has the same shape as the real Turn
|
||||
class MockTurn {
|
||||
pendingToolCalls = [];
|
||||
@@ -47,13 +65,8 @@ vi.mock('./turn', () => {
|
||||
}
|
||||
// Export the mock class as 'Turn'
|
||||
return {
|
||||
...actual,
|
||||
Turn: MockTurn,
|
||||
GeminiEventType: {
|
||||
MaxSessionTurns: 'MaxSessionTurns',
|
||||
ChatCompressed: 'ChatCompressed',
|
||||
Error: 'error',
|
||||
Content: 'content',
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
@@ -78,6 +91,19 @@ vi.mock('../telemetry/index.js', () => ({
|
||||
}));
|
||||
vi.mock('../ide/ideContext.js');
|
||||
|
||||
/**
|
||||
* Ponyfill for `Array.fromAsync`.
|
||||
*
|
||||
* Buffers an async generator into an array and returns the result.
|
||||
*
* NOTE(review): the earlier wording tied `Array.fromAsync` to ES2024; confirm
* which ES version / runtime actually provides it before dropping this helper.
*
* @param promise The async generator to drain. Despite its name it is consumed
*   with `for await`, not awaited as a single promise.
* @returns Every yielded value, in the order the generator produced them.
*/
|
||||
async function fromAsync<T>(promise: AsyncGenerator<T>): Promise<readonly T[]> {
|
||||
const results: T[] = [];
|
||||
// Pull values one at a time until the generator completes.
for await (const result of promise) {
|
||||
results.push(result);
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
describe('findIndexAfterFraction', () => {
|
||||
const history: Content[] = [
|
||||
{ role: 'user', parts: [{ text: 'This is the first message.' }] }, // JSON length: 66
|
||||
@@ -176,7 +202,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
getTool: vi.fn().mockReturnValue(null),
|
||||
};
|
||||
const fileService = new FileDiscoveryService('/test/dir');
|
||||
const contentGeneratorConfig = {
|
||||
const contentGeneratorConfig: ContentGeneratorConfig = {
|
||||
model: 'test-model',
|
||||
apiKey: 'test-key',
|
||||
vertexai: false,
|
||||
@@ -380,7 +406,9 @@ describe('Gemini Client (client.ts)', () => {
|
||||
});
|
||||
|
||||
it('should allow overriding model and config', async () => {
|
||||
const contents = [{ role: 'user', parts: [{ text: 'hello' }] }];
|
||||
const contents: Content[] = [
|
||||
{ role: 'user', parts: [{ text: 'hello' }] },
|
||||
];
|
||||
const schema = { type: 'string' };
|
||||
const abortSignal = new AbortController().signal;
|
||||
const customModel = 'custom-json-model';
|
||||
@@ -421,10 +449,10 @@ describe('Gemini Client (client.ts)', () => {
|
||||
|
||||
describe('addHistory', () => {
|
||||
it('should call chat.addHistory with the provided content', async () => {
|
||||
const mockChat = {
|
||||
const mockChat: Partial<GeminiChat> = {
|
||||
addHistory: vi.fn(),
|
||||
};
|
||||
client['chat'] = mockChat as unknown as GeminiChat;
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
const newContent = {
|
||||
role: 'user',
|
||||
@@ -486,6 +514,139 @@ describe('Gemini Client (client.ts)', () => {
|
||||
} as unknown as GeminiChat;
|
||||
});
|
||||
|
||||
/**
 * Test helper: wires `client` up with a mocked chat and a mocked content
 * generator so that a compression attempt appears to inflate the token count.
 *
 * @param chatHistory History returned by the mocked chat's `getHistory`;
 *   defaults to a two-message user/model exchange.
 * @returns The shared `client` plus the `mockChat` and `mockGenerator` that
 *   were installed on it, so tests can assert against them.
 */
function setup({
|
||||
chatHistory = [
|
||||
{ role: 'user', parts: [{ text: 'Long conversation' }] },
|
||||
{ role: 'model', parts: [{ text: 'Long response' }] },
|
||||
] as Content[],
|
||||
} = {}) {
|
||||
const mockChat: Partial<GeminiChat> = {
|
||||
getHistory: vi.fn().mockReturnValue(chatHistory),
|
||||
setHistory: vi.fn(),
|
||||
sendMessage: vi.fn().mockResolvedValue({ text: 'Summary' }),
|
||||
};
|
||||
// Scripted token counts: the first call reports 1000 and the second 5000,
// i.e. the second measurement is larger than the first. Later calls have no
// scripted value unless a test adds one.
// NOTE(review): this local appears to shadow a suite-level `mockCountTokens`
// that other tests configure directly — confirm those tests intend that.
const mockCountTokens = vi
|
||||
.fn()
|
||||
.mockResolvedValueOnce({ totalTokens: 1000 })
|
||||
.mockResolvedValueOnce({ totalTokens: 5000 });
|
||||
|
|
||||
const mockGenerator: Partial<Mocked<ContentGenerator>> = {
|
||||
countTokens: mockCountTokens,
|
||||
};
|
||||
|
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
// `startChat` is stubbed to resolve to a shallow copy of the mock chat, so the
// chat it produces is a different object from `mockChat`.
client['startChat'] = vi.fn().mockResolvedValue({ ...mockChat });
|
||||
|
|
||||
return { client, mockChat, mockGenerator };
|
||||
}
|
||||
|
||||
describe('when compression inflates the token count', () => {
|
||||
it('uses the truncated history for compression');
|
||||
it('allows compression to be forced/manual after a failure', async () => {
|
||||
const { client, mockGenerator } = setup();
|
||||
mockGenerator.countTokens?.mockResolvedValue({
|
||||
totalTokens: 1000,
|
||||
});
|
||||
await client.tryCompressChat('prompt-id-4'); // Fails
|
||||
const result = await client.tryCompressChat('prompt-id-4', true);
|
||||
|
||||
expect(result).toEqual({
|
||||
compressionStatus: CompressionStatus.COMPRESSED,
|
||||
newTokenCount: 1000,
|
||||
originalTokenCount: 1000,
|
||||
});
|
||||
});
|
||||
|
||||
it('yields the result even if the compression inflated the tokens', async () => {
|
||||
const { client } = setup();
|
||||
const result = await client.tryCompressChat('prompt-id-4', true);
|
||||
|
||||
expect(result).toEqual({
|
||||
compressionStatus:
|
||||
CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
|
||||
newTokenCount: 5000,
|
||||
originalTokenCount: 1000,
|
||||
});
|
||||
});
|
||||
|
||||
it('does not manipulate the source chat', async () => {
|
||||
const { client, mockChat } = setup();
|
||||
await client.tryCompressChat('prompt-id-4', true);
|
||||
|
||||
expect(client['chat']).toBe(mockChat); // a new chat session was not created
|
||||
});
|
||||
|
||||
it('restores the history back to the original', async () => {
|
||||
vi.mocked(tokenLimit).mockReturnValue(1000);
|
||||
mockCountTokens.mockResolvedValue({
|
||||
totalTokens: 999,
|
||||
});
|
||||
|
||||
const originalHistory: Content[] = [
|
||||
{ role: 'user', parts: [{ text: 'what is your wisdom?' }] },
|
||||
{ role: 'model', parts: [{ text: 'some wisdom' }] },
|
||||
{ role: 'user', parts: [{ text: 'ahh that is a good a wisdom' }] },
|
||||
];
|
||||
|
||||
const { client } = setup({
|
||||
chatHistory: originalHistory,
|
||||
});
|
||||
const { compressionStatus } =
|
||||
await client.tryCompressChat('prompt-id-4');
|
||||
|
||||
expect(compressionStatus).toBe(
|
||||
CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
|
||||
);
|
||||
expect(client['chat']?.setHistory).toHaveBeenCalledWith(
|
||||
originalHistory,
|
||||
);
|
||||
});
|
||||
|
||||
it('will not attempt to compress context after a failure', async () => {
|
||||
const { client, mockGenerator } = setup();
|
||||
await client.tryCompressChat('prompt-id-4');
|
||||
|
||||
const result = await client.tryCompressChat('prompt-id-5');
|
||||
|
||||
// it counts tokens for {original, compressed} and then never again
|
||||
expect(mockGenerator.countTokens).toHaveBeenCalledTimes(2);
|
||||
expect(result).toEqual({
|
||||
compressionStatus: CompressionStatus.NOOP,
|
||||
newTokenCount: 0,
|
||||
originalTokenCount: 0,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it('attempts to compress with a maxOutputTokens set to the original token count', async () => {
|
||||
vi.mocked(tokenLimit).mockReturnValue(1000);
|
||||
mockCountTokens.mockResolvedValue({
|
||||
totalTokens: 999,
|
||||
});
|
||||
|
||||
mockGetHistory.mockReturnValue([
|
||||
{ role: 'user', parts: [{ text: '...history...' }] },
|
||||
]);
|
||||
|
||||
// Mock the summary response from the chat
|
||||
mockSendMessage.mockResolvedValue({
|
||||
role: 'model',
|
||||
parts: [{ text: 'This is a summary.' }],
|
||||
});
|
||||
|
||||
await client.tryCompressChat('prompt-id-2', true);
|
||||
|
||||
expect(mockSendMessage).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
config: expect.objectContaining({
|
||||
maxOutputTokens: 999,
|
||||
}),
|
||||
}),
|
||||
'prompt-id-2',
|
||||
);
|
||||
});
|
||||
|
||||
it('should not trigger summarization if token count is below threshold', async () => {
|
||||
const MOCKED_TOKEN_LIMIT = 1000;
|
||||
vi.mocked(tokenLimit).mockReturnValue(MOCKED_TOKEN_LIMIT);
|
||||
@@ -502,7 +663,11 @@ describe('Gemini Client (client.ts)', () => {
|
||||
const newChat = client.getChat();
|
||||
|
||||
expect(tokenLimit).toHaveBeenCalled();
|
||||
expect(result).toBeNull();
|
||||
expect(result).toEqual({
|
||||
compressionStatus: CompressionStatus.NOOP,
|
||||
newTokenCount: 699,
|
||||
originalTokenCount: 699,
|
||||
});
|
||||
expect(newChat).toBe(initialChat);
|
||||
});
|
||||
|
||||
@@ -579,6 +744,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
|
||||
// Assert that summarization happened and returned the correct stats
|
||||
expect(result).toEqual({
|
||||
compressionStatus: CompressionStatus.COMPRESSED,
|
||||
originalTokenCount,
|
||||
newTokenCount,
|
||||
});
|
||||
@@ -631,6 +797,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
|
||||
// Assert that summarization happened and returned the correct stats
|
||||
expect(result).toEqual({
|
||||
compressionStatus: CompressionStatus.COMPRESSED,
|
||||
originalTokenCount,
|
||||
newTokenCount,
|
||||
});
|
||||
@@ -670,6 +837,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
expect(mockSendMessage).toHaveBeenCalled();
|
||||
|
||||
expect(result).toEqual({
|
||||
compressionStatus: CompressionStatus.COMPRESSED,
|
||||
originalTokenCount,
|
||||
newTokenCount,
|
||||
});
|
||||
@@ -727,6 +895,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
});
|
||||
|
||||
expect(result).toEqual({
|
||||
compressionStatus: CompressionStatus.COMPRESSED,
|
||||
originalTokenCount: 100000,
|
||||
newTokenCount: 5000,
|
||||
});
|
||||
@@ -734,6 +903,109 @@ describe('Gemini Client (client.ts)', () => {
|
||||
});
|
||||
|
||||
describe('sendMessageStream', () => {
|
||||
it('emits a compression event when the context was automatically compressed', async () => {
|
||||
// Arrange
|
||||
const mockStream = (async function* () {
|
||||
yield { type: 'content', value: 'Hello' };
|
||||
})();
|
||||
mockTurnRunFn.mockReturnValue(mockStream);
|
||||
|
||||
const mockChat: Partial<GeminiChat> = {
|
||||
addHistory: vi.fn(),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
|
||||
generateContent: mockGenerateContentFn,
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
|
||||
const compressionInfo: ChatCompressionInfo = {
|
||||
compressionStatus: CompressionStatus.COMPRESSED,
|
||||
originalTokenCount: 1000,
|
||||
newTokenCount: 500,
|
||||
};
|
||||
|
||||
vi.spyOn(client, 'tryCompressChat').mockResolvedValueOnce(
|
||||
compressionInfo,
|
||||
);
|
||||
|
||||
// Act
|
||||
const stream = client.sendMessageStream(
|
||||
[{ text: 'Hi' }],
|
||||
new AbortController().signal,
|
||||
'prompt-id-1',
|
||||
);
|
||||
|
||||
const events = await fromAsync(stream);
|
||||
|
||||
// Assert
|
||||
expect(events).toContainEqual({
|
||||
type: GeminiEventType.ChatCompressed,
|
||||
value: compressionInfo,
|
||||
});
|
||||
});
|
||||
|
||||
it.each([
|
||||
{
|
||||
compressionStatus:
|
||||
CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
|
||||
},
|
||||
{ compressionStatus: CompressionStatus.NOOP },
|
||||
{
|
||||
compressionStatus:
|
||||
CompressionStatus.COMPRESSION_FAILED_TOKEN_COUNT_ERROR,
|
||||
},
|
||||
])(
|
||||
'does not emit a compression event when the status is $compressionStatus',
|
||||
async ({ compressionStatus }) => {
|
||||
// Arrange
|
||||
const mockStream = (async function* () {
|
||||
yield { type: 'content', value: 'Hello' };
|
||||
})();
|
||||
mockTurnRunFn.mockReturnValue(mockStream);
|
||||
|
||||
const mockChat: Partial<GeminiChat> = {
|
||||
addHistory: vi.fn(),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
};
|
||||
client['chat'] = mockChat as GeminiChat;
|
||||
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
|
||||
generateContent: mockGenerateContentFn,
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
|
||||
const compressionInfo: ChatCompressionInfo = {
|
||||
compressionStatus,
|
||||
originalTokenCount: 1000,
|
||||
newTokenCount: 500,
|
||||
};
|
||||
|
||||
vi.spyOn(client, 'tryCompressChat').mockResolvedValueOnce(
|
||||
compressionInfo,
|
||||
);
|
||||
|
||||
// Act
|
||||
const stream = client.sendMessageStream(
|
||||
[{ text: 'Hi' }],
|
||||
new AbortController().signal,
|
||||
'prompt-id-1',
|
||||
);
|
||||
|
||||
const events = await fromAsync(stream);
|
||||
|
||||
// Assert
|
||||
expect(events).not.toContainEqual({
|
||||
type: GeminiEventType.ChatCompressed,
|
||||
value: expect.anything(),
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
it('should include editor context when ideMode is enabled', async () => {
|
||||
// Arrange
|
||||
vi.mocked(ideContext.getIdeContext).mockReturnValue({
|
||||
@@ -777,7 +1049,7 @@ describe('Gemini Client (client.ts)', () => {
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
|
||||
const initialRequest = [{ text: 'Hi' }];
|
||||
const initialRequest: Part[] = [{ text: 'Hi' }];
|
||||
|
||||
// Act
|
||||
const stream = client.sendMessageStream(
|
||||
@@ -1292,7 +1564,11 @@ ${JSON.stringify(
|
||||
|
||||
beforeEach(() => {
|
||||
client['forceFullIdeContext'] = false; // Reset before each delta test
|
||||
vi.spyOn(client, 'tryCompressChat').mockResolvedValue(null);
|
||||
vi.spyOn(client, 'tryCompressChat').mockResolvedValue({
|
||||
originalTokenCount: 0,
|
||||
newTokenCount: 0,
|
||||
compressionStatus: CompressionStatus.COMPRESSED,
|
||||
});
|
||||
vi.spyOn(client['config'], 'getIdeMode').mockReturnValue(true);
|
||||
mockTurnRunFn.mockReturnValue(mockStream);
|
||||
|
||||
@@ -1551,7 +1827,11 @@ ${JSON.stringify(
|
||||
let mockChat: Partial<GeminiChat>;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.spyOn(client, 'tryCompressChat').mockResolvedValue(null);
|
||||
vi.spyOn(client, 'tryCompressChat').mockResolvedValue({
|
||||
originalTokenCount: 0,
|
||||
newTokenCount: 0,
|
||||
compressionStatus: CompressionStatus.COMPRESSED,
|
||||
});
|
||||
|
||||
const mockStream = (async function* () {
|
||||
yield { type: 'content', value: 'response' };
|
||||
|
||||
@@ -17,6 +17,7 @@ import {
|
||||
getEnvironmentContext,
|
||||
} from '../utils/environmentContext.js';
|
||||
import type { ServerGeminiStreamEvent, ChatCompressionInfo } from './turn.js';
|
||||
import { CompressionStatus } from './turn.js';
|
||||
import { Turn, GeminiEventType } from './turn.js';
|
||||
import type { Config } from '../config/config.js';
|
||||
import type { UserTierId } from '../code_assist/types.js';
|
||||
@@ -105,8 +106,8 @@ const COMPRESSION_PRESERVE_THRESHOLD = 0.3;
|
||||
export class GeminiClient {
|
||||
private chat?: GeminiChat;
|
||||
private contentGenerator?: ContentGenerator;
|
||||
private embeddingModel: string;
|
||||
private generateContentConfig: GenerateContentConfig = {
|
||||
private readonly embeddingModel: string;
|
||||
private readonly generateContentConfig: GenerateContentConfig = {
|
||||
temperature: 0,
|
||||
topP: 1,
|
||||
};
|
||||
@@ -117,7 +118,13 @@ export class GeminiClient {
|
||||
private lastSentIdeContext: IdeContext | undefined;
|
||||
private forceFullIdeContext = true;
|
||||
|
||||
constructor(private config: Config) {
|
||||
/**
|
||||
* At any point in this conversation, was compression triggered without
|
||||
* being forced and did it fail?
|
||||
*/
|
||||
private hasFailedCompressionAttempt = false;
|
||||
|
||||
constructor(private readonly config: Config) {
|
||||
if (config.getProxy()) {
|
||||
setGlobalDispatcher(new ProxyAgent(config.getProxy() as string));
|
||||
}
|
||||
@@ -219,6 +226,7 @@ export class GeminiClient {
|
||||
|
||||
async startChat(extraHistory?: Content[]): Promise<GeminiChat> {
|
||||
this.forceFullIdeContext = true;
|
||||
this.hasFailedCompressionAttempt = false;
|
||||
const envParts = await getEnvironmentContext(this.config);
|
||||
const toolRegistry = this.config.getToolRegistry();
|
||||
const toolDeclarations = toolRegistry.getFunctionDeclarations();
|
||||
@@ -467,7 +475,7 @@ export class GeminiClient {
|
||||
|
||||
const compressed = await this.tryCompressChat(prompt_id);
|
||||
|
||||
if (compressed) {
|
||||
if (compressed.compressionStatus === CompressionStatus.COMPRESSED) {
|
||||
yield { type: GeminiEventType.ChatCompressed, value: compressed };
|
||||
}
|
||||
|
||||
@@ -764,12 +772,19 @@ export class GeminiClient {
|
||||
async tryCompressChat(
|
||||
prompt_id: string,
|
||||
force: boolean = false,
|
||||
): Promise<ChatCompressionInfo | null> {
|
||||
): Promise<ChatCompressionInfo> {
|
||||
const curatedHistory = this.getChat().getHistory(true);
|
||||
|
||||
// Regardless of `force`, don't do anything if the history is empty. Also skip when an earlier non-forced attempt failed, unless `force` is set.
|
||||
if (curatedHistory.length === 0) {
|
||||
return null;
|
||||
if (
|
||||
curatedHistory.length === 0 ||
|
||||
(this.hasFailedCompressionAttempt && !force)
|
||||
) {
|
||||
return {
|
||||
originalTokenCount: 0,
|
||||
newTokenCount: 0,
|
||||
compressionStatus: CompressionStatus.NOOP,
|
||||
};
|
||||
}
|
||||
|
||||
const model = this.config.getModel();
|
||||
@@ -781,7 +796,13 @@ export class GeminiClient {
|
||||
});
|
||||
if (originalTokenCount === undefined) {
|
||||
console.warn(`Could not determine token count for model ${model}.`);
|
||||
return null;
|
||||
this.hasFailedCompressionAttempt = !force && true;
|
||||
return {
|
||||
originalTokenCount: 0,
|
||||
newTokenCount: 0,
|
||||
compressionStatus:
|
||||
CompressionStatus.COMPRESSION_FAILED_TOKEN_COUNT_ERROR,
|
||||
};
|
||||
}
|
||||
|
||||
const contextPercentageThreshold =
|
||||
@@ -792,7 +813,11 @@ export class GeminiClient {
|
||||
const threshold =
|
||||
contextPercentageThreshold ?? COMPRESSION_TOKEN_THRESHOLD;
|
||||
if (originalTokenCount < threshold * tokenLimit(model)) {
|
||||
return null;
|
||||
return {
|
||||
originalTokenCount,
|
||||
newTokenCount: originalTokenCount,
|
||||
compressionStatus: CompressionStatus.NOOP,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -821,11 +846,12 @@ export class GeminiClient {
|
||||
},
|
||||
config: {
|
||||
systemInstruction: { text: getCompressionPrompt() },
|
||||
maxOutputTokens: originalTokenCount,
|
||||
},
|
||||
},
|
||||
prompt_id,
|
||||
);
|
||||
this.chat = await this.startChat([
|
||||
const chat = await this.startChat([
|
||||
{
|
||||
role: 'user',
|
||||
parts: [{ text: summary }],
|
||||
@@ -842,11 +868,17 @@ export class GeminiClient {
|
||||
await this.getContentGenerator().countTokens({
|
||||
// model might change after calling `sendMessage`, so we get the newest value from config
|
||||
model: this.config.getModel(),
|
||||
contents: this.getChat().getHistory(),
|
||||
contents: chat.getHistory(),
|
||||
});
|
||||
if (newTokenCount === undefined) {
|
||||
console.warn('Could not determine compressed history token count.');
|
||||
return null;
|
||||
this.hasFailedCompressionAttempt = !force && true;
|
||||
return {
|
||||
originalTokenCount,
|
||||
newTokenCount: originalTokenCount,
|
||||
compressionStatus:
|
||||
CompressionStatus.COMPRESSION_FAILED_TOKEN_COUNT_ERROR,
|
||||
};
|
||||
}
|
||||
|
||||
logChatCompression(
|
||||
@@ -857,9 +889,23 @@ export class GeminiClient {
|
||||
}),
|
||||
);
|
||||
|
||||
if (newTokenCount > originalTokenCount) {
|
||||
this.getChat().setHistory(curatedHistory);
|
||||
this.hasFailedCompressionAttempt = !force && true;
|
||||
return {
|
||||
originalTokenCount,
|
||||
newTokenCount,
|
||||
compressionStatus:
|
||||
CompressionStatus.COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
|
||||
};
|
||||
} else {
|
||||
this.chat = chat; // Chat compression successful, set new state.
|
||||
}
|
||||
|
||||
return {
|
||||
originalTokenCount,
|
||||
newTokenCount,
|
||||
compressionStatus: CompressionStatus.COMPRESSED,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -125,9 +125,24 @@ export type ServerGeminiErrorEvent = {
|
||||
value: GeminiErrorEventValue;
|
||||
};
|
||||
|
||||
/**
 * Outcome of a chat-history compression attempt, carried in
 * `ChatCompressionInfo.compressionStatus`.
 *
 * Numeric enum whose first member is explicitly 1, so the remaining members
 * auto-increment from there and no member has a falsy value.
 */
export enum CompressionStatus {
|
||||
/** The compression was successful */
|
||||
COMPRESSED = 1,
|
||||
|
|
||||
/** The compression failed due to the compression inflating the token count */
|
||||
COMPRESSION_FAILED_INFLATED_TOKEN_COUNT,
|
||||
|
|
||||
/** The compression failed due to an error counting tokens (the token count came back undefined) */
|
||||
COMPRESSION_FAILED_TOKEN_COUNT_ERROR,
|
||||
|
|
||||
/**
 * The compression was not necessary and no action was taken. Also reported
 * when the history is empty or when an earlier non-forced attempt failed and
 * the current call is not forced.
 */
|
||||
NOOP,
|
||||
}
|
||||
|
||||
/** Result of a chat compression attempt: token counts plus the outcome. */
export interface ChatCompressionInfo {
|
||||
/**
 * Token count of the history before the attempt. Reported as 0 when the
 * attempt returned before a count was obtained.
 */
originalTokenCount: number;
|
||||
/**
 * Token count after the attempt. Equals `originalTokenCount` when nothing was
 * compressed or the new count could not be determined; for an inflated-count
 * failure it is the larger count of the discarded result.
 */
newTokenCount: number;
|
||||
/** How the attempt ended; see `CompressionStatus`. */
compressionStatus: CompressionStatus;
|
||||
}
|
||||
|
||||
export type ServerGeminiChatCompressedEvent = {
|
||||
|
||||
Reference in New Issue
Block a user