feat: Display initial token usage metrics in /stats (#879)

This commit is contained in:
Abhi
2025-06-09 20:25:37 -04:00
committed by GitHub
parent 6484dc9008
commit 7f1252d364
11 changed files with 608 additions and 63 deletions

View File

@@ -13,14 +13,32 @@ import {
GoogleGenAI,
} from '@google/genai';
import { GeminiClient } from './client.js';
import { ContentGenerator } from './contentGenerator.js';
import { GeminiChat } from './geminiChat.js';
import { Config } from '../config/config.js';
import { Turn } from './turn.js';
// --- Mocks ---
const mockChatCreateFn = vi.fn();
const mockGenerateContentFn = vi.fn();
const mockEmbedContentFn = vi.fn();
const mockTurnRunFn = vi.fn();
vi.mock('@google/genai');
vi.mock('./turn', () => {
// Define a mock class that has the same shape as the real Turn
class MockTurn {
pendingToolCalls = [];
// The run method is a property that holds our mock function
run = mockTurnRunFn;
constructor() {
// The constructor can be empty or do some mock setup
}
}
// Export the mock class as 'Turn'
return { Turn: MockTurn };
});
vi.mock('../config/config.js');
vi.mock('./prompts');
@@ -237,4 +255,44 @@ describe('Gemini Client (client.ts)', () => {
expect(mockChat.addHistory).toHaveBeenCalledWith(newContent);
});
});
describe('sendMessageStream', () => {
it('should return the turn instance after the stream is complete', async () => {
// Arrange
const mockStream = (async function* () {
yield { type: 'content', value: 'Hello' };
})();
mockTurnRunFn.mockReturnValue(mockStream);
const mockChat: Partial<GeminiChat> = {
addHistory: vi.fn(),
getHistory: vi.fn().mockReturnValue([]),
};
client['chat'] = Promise.resolve(mockChat as GeminiChat);
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
// Act
const stream = client.sendMessageStream(
[{ text: 'Hi' }],
new AbortController().signal,
);
// Consume the stream manually to get the final return value.
let finalResult: Turn | undefined;
while (true) {
const result = await stream.next();
if (result.done) {
finalResult = result.value;
break;
}
}
// Assert
expect(finalResult).toBeInstanceOf(Turn);
});
});
});

View File

@@ -174,9 +174,10 @@ export class GeminiClient {
request: PartListUnion,
signal: AbortSignal,
turns: number = this.MAX_TURNS,
): AsyncGenerator<ServerGeminiStreamEvent> {
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
if (!turns) {
return;
const chat = await this.chat;
return new Turn(chat);
}
const compressed = await this.tryCompressChat();
@@ -193,9 +194,12 @@ export class GeminiClient {
const nextSpeakerCheck = await checkNextSpeaker(chat, this, signal);
if (nextSpeakerCheck?.next_speaker === 'model') {
const nextRequest = [{ text: 'Please continue.' }];
// This recursive call's events will be yielded out, but the final
// turn object will be from the top-level call.
yield* this.sendMessageStream(nextRequest, signal, turns - 1);
}
}
return turn;
}
private _logApiRequest(model: string, inputTokenCount: number): void {
@@ -423,6 +427,10 @@ export class GeminiClient {
});
const result = await retryWithBackoff(apiCall);
console.log(
'Raw API Response in client.ts:',
JSON.stringify(result, null, 2),
);
const durationMs = Date.now() - startTime;
this._logApiResponse(modelToUse, durationMs, attempt, result);
return result;

View File

@@ -10,8 +10,14 @@ import {
GeminiEventType,
ServerGeminiToolCallRequestEvent,
ServerGeminiErrorEvent,
ServerGeminiUsageMetadataEvent,
} from './turn.js';
import { GenerateContentResponse, Part, Content } from '@google/genai';
import {
GenerateContentResponse,
Part,
Content,
GenerateContentResponseUsageMetadata,
} from '@google/genai';
import { reportError } from '../utils/errorReporting.js';
import { GeminiChat } from './geminiChat.js';
@@ -49,6 +55,24 @@ describe('Turn', () => {
};
let mockChatInstance: MockedChatInstance;
const mockMetadata1: GenerateContentResponseUsageMetadata = {
promptTokenCount: 10,
candidatesTokenCount: 20,
totalTokenCount: 30,
cachedContentTokenCount: 5,
toolUsePromptTokenCount: 2,
thoughtsTokenCount: 3,
};
const mockMetadata2: GenerateContentResponseUsageMetadata = {
promptTokenCount: 100,
candidatesTokenCount: 200,
totalTokenCount: 300,
cachedContentTokenCount: 50,
toolUsePromptTokenCount: 20,
thoughtsTokenCount: 30,
};
beforeEach(() => {
vi.resetAllMocks();
mockChatInstance = {
@@ -96,6 +120,7 @@ describe('Turn', () => {
message: reqParts,
config: { abortSignal: expect.any(AbortSignal) },
});
expect(events).toEqual([
{ type: GeminiEventType.Content, value: 'Hello' },
{ type: GeminiEventType.Content, value: ' world' },
@@ -208,6 +233,41 @@ describe('Turn', () => {
);
});
it('should yield the last UsageMetadata event from the stream', async () => {
const mockResponseStream = (async function* () {
yield {
candidates: [{ content: { parts: [{ text: 'First response' }] } }],
usageMetadata: mockMetadata1,
} as unknown as GenerateContentResponse;
yield {
functionCalls: [{ name: 'aTool' }],
usageMetadata: mockMetadata2,
} as unknown as GenerateContentResponse;
})();
mockSendMessageStream.mockResolvedValue(mockResponseStream);
const events = [];
const reqParts: Part[] = [{ text: 'Test metadata' }];
for await (const event of turn.run(
reqParts,
new AbortController().signal,
)) {
events.push(event);
}
// There should be a content event, a tool call, and our metadata event
expect(events.length).toBe(3);
const metadataEvent = events[2] as ServerGeminiUsageMetadataEvent;
expect(metadataEvent.type).toBe(GeminiEventType.UsageMetadata);
// The value should be the *last* metadata object received.
expect(metadataEvent.value).toEqual(mockMetadata2);
// Also check the public getter
expect(turn.getUsageMetadata()).toEqual(mockMetadata2);
});
it('should handle function calls with undefined name or args', async () => {
const mockResponseStream = (async function* () {
yield {
@@ -219,7 +279,6 @@ describe('Turn', () => {
} as unknown as GenerateContentResponse;
})();
mockSendMessageStream.mockResolvedValue(mockResponseStream);
const events = [];
const reqParts: Part[] = [{ text: 'Test undefined tool parts' }];
for await (const event of turn.run(

View File

@@ -9,6 +9,7 @@ import {
GenerateContentResponse,
FunctionCall,
FunctionDeclaration,
GenerateContentResponseUsageMetadata,
} from '@google/genai';
import {
ToolCallConfirmationDetails,
@@ -43,6 +44,7 @@ export enum GeminiEventType {
UserCancelled = 'user_cancelled',
Error = 'error',
ChatCompressed = 'chat_compressed',
UsageMetadata = 'usage_metadata',
}
export interface GeminiErrorEventValue {
@@ -100,6 +102,11 @@ export type ServerGeminiChatCompressedEvent = {
type: GeminiEventType.ChatCompressed;
};
export type ServerGeminiUsageMetadataEvent = {
type: GeminiEventType.UsageMetadata;
value: GenerateContentResponseUsageMetadata;
};
// The original union type, now composed of the individual types
export type ServerGeminiStreamEvent =
| ServerGeminiContentEvent
@@ -108,7 +115,8 @@ export type ServerGeminiStreamEvent =
| ServerGeminiToolCallConfirmationEvent
| ServerGeminiUserCancelledEvent
| ServerGeminiErrorEvent
| ServerGeminiChatCompressedEvent;
| ServerGeminiChatCompressedEvent
| ServerGeminiUsageMetadataEvent;
// A turn manages the agentic loop turn within the server context.
export class Turn {
@@ -118,6 +126,7 @@ export class Turn {
args: Record<string, unknown>;
}>;
private debugResponses: GenerateContentResponse[];
private lastUsageMetadata: GenerateContentResponseUsageMetadata | null = null;
constructor(private readonly chat: GeminiChat) {
this.pendingToolCalls = [];
@@ -157,6 +166,18 @@ export class Turn {
yield event;
}
}
if (resp.usageMetadata) {
this.lastUsageMetadata =
resp.usageMetadata as GenerateContentResponseUsageMetadata;
}
}
if (this.lastUsageMetadata) {
yield {
type: GeminiEventType.UsageMetadata,
value: this.lastUsageMetadata,
};
}
} catch (error) {
if (signal.aborted) {
@@ -197,4 +218,8 @@ export class Turn {
getDebugResponses(): GenerateContentResponse[] {
return this.debugResponses;
}
getUsageMetadata(): GenerateContentResponseUsageMetadata | null {
return this.lastUsageMetadata;
}
}