mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-20 16:57:46 +00:00
feat: Display initial token usage metrics in /stats (#879)
This commit is contained in:
@@ -13,14 +13,32 @@ import {
|
||||
GoogleGenAI,
|
||||
} from '@google/genai';
|
||||
import { GeminiClient } from './client.js';
|
||||
import { ContentGenerator } from './contentGenerator.js';
|
||||
import { GeminiChat } from './geminiChat.js';
|
||||
import { Config } from '../config/config.js';
|
||||
import { Turn } from './turn.js';
|
||||
|
||||
// --- Mocks ---
|
||||
const mockChatCreateFn = vi.fn();
|
||||
const mockGenerateContentFn = vi.fn();
|
||||
const mockEmbedContentFn = vi.fn();
|
||||
const mockTurnRunFn = vi.fn();
|
||||
|
||||
vi.mock('@google/genai');
|
||||
vi.mock('./turn', () => {
|
||||
// Define a mock class that has the same shape as the real Turn
|
||||
class MockTurn {
|
||||
pendingToolCalls = [];
|
||||
// The run method is a property that holds our mock function
|
||||
run = mockTurnRunFn;
|
||||
|
||||
constructor() {
|
||||
// The constructor can be empty or do some mock setup
|
||||
}
|
||||
}
|
||||
// Export the mock class as 'Turn'
|
||||
return { Turn: MockTurn };
|
||||
});
|
||||
|
||||
vi.mock('../config/config.js');
|
||||
vi.mock('./prompts');
|
||||
@@ -237,4 +255,44 @@ describe('Gemini Client (client.ts)', () => {
|
||||
expect(mockChat.addHistory).toHaveBeenCalledWith(newContent);
|
||||
});
|
||||
});
|
||||
|
||||
describe('sendMessageStream', () => {
|
||||
it('should return the turn instance after the stream is complete', async () => {
|
||||
// Arrange
|
||||
const mockStream = (async function* () {
|
||||
yield { type: 'content', value: 'Hello' };
|
||||
})();
|
||||
mockTurnRunFn.mockReturnValue(mockStream);
|
||||
|
||||
const mockChat: Partial<GeminiChat> = {
|
||||
addHistory: vi.fn(),
|
||||
getHistory: vi.fn().mockReturnValue([]),
|
||||
};
|
||||
client['chat'] = Promise.resolve(mockChat as GeminiChat);
|
||||
|
||||
const mockGenerator: Partial<ContentGenerator> = {
|
||||
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
|
||||
};
|
||||
client['contentGenerator'] = mockGenerator as ContentGenerator;
|
||||
|
||||
// Act
|
||||
const stream = client.sendMessageStream(
|
||||
[{ text: 'Hi' }],
|
||||
new AbortController().signal,
|
||||
);
|
||||
|
||||
// Consume the stream manually to get the final return value.
|
||||
let finalResult: Turn | undefined;
|
||||
while (true) {
|
||||
const result = await stream.next();
|
||||
if (result.done) {
|
||||
finalResult = result.value;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Assert
|
||||
expect(finalResult).toBeInstanceOf(Turn);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -174,9 +174,10 @@ export class GeminiClient {
|
||||
request: PartListUnion,
|
||||
signal: AbortSignal,
|
||||
turns: number = this.MAX_TURNS,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent> {
|
||||
): AsyncGenerator<ServerGeminiStreamEvent, Turn> {
|
||||
if (!turns) {
|
||||
return;
|
||||
const chat = await this.chat;
|
||||
return new Turn(chat);
|
||||
}
|
||||
|
||||
const compressed = await this.tryCompressChat();
|
||||
@@ -193,9 +194,12 @@ export class GeminiClient {
|
||||
const nextSpeakerCheck = await checkNextSpeaker(chat, this, signal);
|
||||
if (nextSpeakerCheck?.next_speaker === 'model') {
|
||||
const nextRequest = [{ text: 'Please continue.' }];
|
||||
// This recursive call's events will be yielded out, but the final
|
||||
// turn object will be from the top-level call.
|
||||
yield* this.sendMessageStream(nextRequest, signal, turns - 1);
|
||||
}
|
||||
}
|
||||
return turn;
|
||||
}
|
||||
|
||||
private _logApiRequest(model: string, inputTokenCount: number): void {
|
||||
@@ -423,6 +427,10 @@ export class GeminiClient {
|
||||
});
|
||||
|
||||
const result = await retryWithBackoff(apiCall);
|
||||
console.log(
|
||||
'Raw API Response in client.ts:',
|
||||
JSON.stringify(result, null, 2),
|
||||
);
|
||||
const durationMs = Date.now() - startTime;
|
||||
this._logApiResponse(modelToUse, durationMs, attempt, result);
|
||||
return result;
|
||||
|
||||
@@ -10,8 +10,14 @@ import {
|
||||
GeminiEventType,
|
||||
ServerGeminiToolCallRequestEvent,
|
||||
ServerGeminiErrorEvent,
|
||||
ServerGeminiUsageMetadataEvent,
|
||||
} from './turn.js';
|
||||
import { GenerateContentResponse, Part, Content } from '@google/genai';
|
||||
import {
|
||||
GenerateContentResponse,
|
||||
Part,
|
||||
Content,
|
||||
GenerateContentResponseUsageMetadata,
|
||||
} from '@google/genai';
|
||||
import { reportError } from '../utils/errorReporting.js';
|
||||
import { GeminiChat } from './geminiChat.js';
|
||||
|
||||
@@ -49,6 +55,24 @@ describe('Turn', () => {
|
||||
};
|
||||
let mockChatInstance: MockedChatInstance;
|
||||
|
||||
const mockMetadata1: GenerateContentResponseUsageMetadata = {
|
||||
promptTokenCount: 10,
|
||||
candidatesTokenCount: 20,
|
||||
totalTokenCount: 30,
|
||||
cachedContentTokenCount: 5,
|
||||
toolUsePromptTokenCount: 2,
|
||||
thoughtsTokenCount: 3,
|
||||
};
|
||||
|
||||
const mockMetadata2: GenerateContentResponseUsageMetadata = {
|
||||
promptTokenCount: 100,
|
||||
candidatesTokenCount: 200,
|
||||
totalTokenCount: 300,
|
||||
cachedContentTokenCount: 50,
|
||||
toolUsePromptTokenCount: 20,
|
||||
thoughtsTokenCount: 30,
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks();
|
||||
mockChatInstance = {
|
||||
@@ -96,6 +120,7 @@ describe('Turn', () => {
|
||||
message: reqParts,
|
||||
config: { abortSignal: expect.any(AbortSignal) },
|
||||
});
|
||||
|
||||
expect(events).toEqual([
|
||||
{ type: GeminiEventType.Content, value: 'Hello' },
|
||||
{ type: GeminiEventType.Content, value: ' world' },
|
||||
@@ -208,6 +233,41 @@ describe('Turn', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should yield the last UsageMetadata event from the stream', async () => {
|
||||
const mockResponseStream = (async function* () {
|
||||
yield {
|
||||
candidates: [{ content: { parts: [{ text: 'First response' }] } }],
|
||||
usageMetadata: mockMetadata1,
|
||||
} as unknown as GenerateContentResponse;
|
||||
yield {
|
||||
functionCalls: [{ name: 'aTool' }],
|
||||
usageMetadata: mockMetadata2,
|
||||
} as unknown as GenerateContentResponse;
|
||||
})();
|
||||
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
||||
|
||||
const events = [];
|
||||
const reqParts: Part[] = [{ text: 'Test metadata' }];
|
||||
for await (const event of turn.run(
|
||||
reqParts,
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
// There should be a content event, a tool call, and our metadata event
|
||||
expect(events.length).toBe(3);
|
||||
|
||||
const metadataEvent = events[2] as ServerGeminiUsageMetadataEvent;
|
||||
expect(metadataEvent.type).toBe(GeminiEventType.UsageMetadata);
|
||||
|
||||
// The value should be the *last* metadata object received.
|
||||
expect(metadataEvent.value).toEqual(mockMetadata2);
|
||||
|
||||
// Also check the public getter
|
||||
expect(turn.getUsageMetadata()).toEqual(mockMetadata2);
|
||||
});
|
||||
|
||||
it('should handle function calls with undefined name or args', async () => {
|
||||
const mockResponseStream = (async function* () {
|
||||
yield {
|
||||
@@ -219,7 +279,6 @@ describe('Turn', () => {
|
||||
} as unknown as GenerateContentResponse;
|
||||
})();
|
||||
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
||||
|
||||
const events = [];
|
||||
const reqParts: Part[] = [{ text: 'Test undefined tool parts' }];
|
||||
for await (const event of turn.run(
|
||||
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
GenerateContentResponse,
|
||||
FunctionCall,
|
||||
FunctionDeclaration,
|
||||
GenerateContentResponseUsageMetadata,
|
||||
} from '@google/genai';
|
||||
import {
|
||||
ToolCallConfirmationDetails,
|
||||
@@ -43,6 +44,7 @@ export enum GeminiEventType {
|
||||
UserCancelled = 'user_cancelled',
|
||||
Error = 'error',
|
||||
ChatCompressed = 'chat_compressed',
|
||||
UsageMetadata = 'usage_metadata',
|
||||
}
|
||||
|
||||
export interface GeminiErrorEventValue {
|
||||
@@ -100,6 +102,11 @@ export type ServerGeminiChatCompressedEvent = {
|
||||
type: GeminiEventType.ChatCompressed;
|
||||
};
|
||||
|
||||
export type ServerGeminiUsageMetadataEvent = {
|
||||
type: GeminiEventType.UsageMetadata;
|
||||
value: GenerateContentResponseUsageMetadata;
|
||||
};
|
||||
|
||||
// The original union type, now composed of the individual types
|
||||
export type ServerGeminiStreamEvent =
|
||||
| ServerGeminiContentEvent
|
||||
@@ -108,7 +115,8 @@ export type ServerGeminiStreamEvent =
|
||||
| ServerGeminiToolCallConfirmationEvent
|
||||
| ServerGeminiUserCancelledEvent
|
||||
| ServerGeminiErrorEvent
|
||||
| ServerGeminiChatCompressedEvent;
|
||||
| ServerGeminiChatCompressedEvent
|
||||
| ServerGeminiUsageMetadataEvent;
|
||||
|
||||
// A turn manages the agentic loop turn within the server context.
|
||||
export class Turn {
|
||||
@@ -118,6 +126,7 @@ export class Turn {
|
||||
args: Record<string, unknown>;
|
||||
}>;
|
||||
private debugResponses: GenerateContentResponse[];
|
||||
private lastUsageMetadata: GenerateContentResponseUsageMetadata | null = null;
|
||||
|
||||
constructor(private readonly chat: GeminiChat) {
|
||||
this.pendingToolCalls = [];
|
||||
@@ -157,6 +166,18 @@ export class Turn {
|
||||
yield event;
|
||||
}
|
||||
}
|
||||
|
||||
if (resp.usageMetadata) {
|
||||
this.lastUsageMetadata =
|
||||
resp.usageMetadata as GenerateContentResponseUsageMetadata;
|
||||
}
|
||||
}
|
||||
|
||||
if (this.lastUsageMetadata) {
|
||||
yield {
|
||||
type: GeminiEventType.UsageMetadata,
|
||||
value: this.lastUsageMetadata,
|
||||
};
|
||||
}
|
||||
} catch (error) {
|
||||
if (signal.aborted) {
|
||||
@@ -197,4 +218,8 @@ export class Turn {
|
||||
getDebugResponses(): GenerateContentResponse[] {
|
||||
return this.debugResponses;
|
||||
}
|
||||
|
||||
getUsageMetadata(): GenerateContentResponseUsageMetadata | null {
|
||||
return this.lastUsageMetadata;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user