mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-21 01:07:46 +00:00
feat: Display initial token usage metrics in /stats (#879)
This commit is contained in:
@@ -4,26 +4,181 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { type MutableRefObject } from 'react';
|
||||
import { render } from 'ink-testing-library';
|
||||
import { Text } from 'ink';
|
||||
import { SessionProvider, useSession } from './SessionContext.js';
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { act } from 'react-dom/test-utils';
|
||||
import { SessionStatsProvider, useSessionStats } from './SessionContext.js';
|
||||
import { describe, it, expect, vi } from 'vitest';
|
||||
import { GenerateContentResponseUsageMetadata } from '@google/genai';
|
||||
|
||||
const TestComponent = () => {
|
||||
const { startTime } = useSession();
|
||||
return <Text>{startTime.toISOString()}</Text>;
|
||||
// Mock data that simulates what the Gemini API would return.
|
||||
const mockMetadata1: GenerateContentResponseUsageMetadata = {
|
||||
promptTokenCount: 100,
|
||||
candidatesTokenCount: 200,
|
||||
totalTokenCount: 300,
|
||||
cachedContentTokenCount: 50,
|
||||
toolUsePromptTokenCount: 10,
|
||||
thoughtsTokenCount: 20,
|
||||
};
|
||||
|
||||
describe('SessionContext', () => {
|
||||
it('should provide a start time', () => {
|
||||
const { lastFrame } = render(
|
||||
<SessionProvider>
|
||||
<TestComponent />
|
||||
</SessionProvider>,
|
||||
const mockMetadata2: GenerateContentResponseUsageMetadata = {
|
||||
promptTokenCount: 10,
|
||||
candidatesTokenCount: 20,
|
||||
totalTokenCount: 30,
|
||||
cachedContentTokenCount: 5,
|
||||
toolUsePromptTokenCount: 1,
|
||||
thoughtsTokenCount: 2,
|
||||
};
|
||||
|
||||
/**
|
||||
* A test harness component that uses the hook and exposes the context value
|
||||
* via a mutable ref. This allows us to interact with the context's functions
|
||||
* and assert against its state directly in our tests.
|
||||
*/
|
||||
const TestHarness = ({
|
||||
contextRef,
|
||||
}: {
|
||||
contextRef: MutableRefObject<ReturnType<typeof useSessionStats> | undefined>;
|
||||
}) => {
|
||||
contextRef.current = useSessionStats();
|
||||
return null;
|
||||
};
|
||||
|
||||
describe('SessionStatsContext', () => {
|
||||
it('should provide the correct initial state', () => {
|
||||
const contextRef: MutableRefObject<
|
||||
ReturnType<typeof useSessionStats> | undefined
|
||||
> = { current: undefined };
|
||||
|
||||
render(
|
||||
<SessionStatsProvider>
|
||||
<TestHarness contextRef={contextRef} />
|
||||
</SessionStatsProvider>,
|
||||
);
|
||||
|
||||
const frameText = lastFrame();
|
||||
// Check if the output is a valid ISO string, which confirms it's a Date object.
|
||||
expect(new Date(frameText!).toString()).not.toBe('Invalid Date');
|
||||
const stats = contextRef.current?.stats;
|
||||
|
||||
expect(stats?.sessionStartTime).toBeInstanceOf(Date);
|
||||
expect(stats?.lastTurn).toBeNull();
|
||||
expect(stats?.cumulative.turnCount).toBe(0);
|
||||
expect(stats?.cumulative.totalTokenCount).toBe(0);
|
||||
expect(stats?.cumulative.promptTokenCount).toBe(0);
|
||||
});
|
||||
|
||||
it('should increment turnCount when startNewTurn is called', () => {
|
||||
const contextRef: MutableRefObject<
|
||||
ReturnType<typeof useSessionStats> | undefined
|
||||
> = { current: undefined };
|
||||
|
||||
render(
|
||||
<SessionStatsProvider>
|
||||
<TestHarness contextRef={contextRef} />
|
||||
</SessionStatsProvider>,
|
||||
);
|
||||
|
||||
act(() => {
|
||||
contextRef.current?.startNewTurn();
|
||||
});
|
||||
|
||||
const stats = contextRef.current?.stats;
|
||||
expect(stats?.cumulative.turnCount).toBe(1);
|
||||
// Ensure token counts are unaffected
|
||||
expect(stats?.cumulative.totalTokenCount).toBe(0);
|
||||
});
|
||||
|
||||
it('should aggregate token usage correctly when addUsage is called', () => {
|
||||
const contextRef: MutableRefObject<
|
||||
ReturnType<typeof useSessionStats> | undefined
|
||||
> = { current: undefined };
|
||||
|
||||
render(
|
||||
<SessionStatsProvider>
|
||||
<TestHarness contextRef={contextRef} />
|
||||
</SessionStatsProvider>,
|
||||
);
|
||||
|
||||
act(() => {
|
||||
contextRef.current?.addUsage(mockMetadata1);
|
||||
});
|
||||
|
||||
const stats = contextRef.current?.stats;
|
||||
|
||||
// Check that token counts are updated
|
||||
expect(stats?.cumulative.totalTokenCount).toBe(
|
||||
mockMetadata1.totalTokenCount ?? 0,
|
||||
);
|
||||
expect(stats?.cumulative.promptTokenCount).toBe(
|
||||
mockMetadata1.promptTokenCount ?? 0,
|
||||
);
|
||||
|
||||
// Check that turn count is NOT incremented
|
||||
expect(stats?.cumulative.turnCount).toBe(0);
|
||||
|
||||
// Check that lastTurn is updated
|
||||
expect(stats?.lastTurn?.metadata).toEqual(mockMetadata1);
|
||||
});
|
||||
|
||||
it('should correctly track a full logical turn with multiple API calls', () => {
|
||||
const contextRef: MutableRefObject<
|
||||
ReturnType<typeof useSessionStats> | undefined
|
||||
> = { current: undefined };
|
||||
|
||||
render(
|
||||
<SessionStatsProvider>
|
||||
<TestHarness contextRef={contextRef} />
|
||||
</SessionStatsProvider>,
|
||||
);
|
||||
|
||||
// 1. User starts a new turn
|
||||
act(() => {
|
||||
contextRef.current?.startNewTurn();
|
||||
});
|
||||
|
||||
// 2. First API call (e.g., prompt with a tool request)
|
||||
act(() => {
|
||||
contextRef.current?.addUsage(mockMetadata1);
|
||||
});
|
||||
|
||||
// 3. Second API call (e.g., sending tool response back)
|
||||
act(() => {
|
||||
contextRef.current?.addUsage(mockMetadata2);
|
||||
});
|
||||
|
||||
const stats = contextRef.current?.stats;
|
||||
|
||||
// Turn count should only be 1
|
||||
expect(stats?.cumulative.turnCount).toBe(1);
|
||||
|
||||
// These fields should be the SUM of both calls
|
||||
expect(stats?.cumulative.totalTokenCount).toBe(330); // 300 + 30
|
||||
expect(stats?.cumulative.candidatesTokenCount).toBe(220); // 200 + 20
|
||||
expect(stats?.cumulative.thoughtsTokenCount).toBe(22); // 20 + 2
|
||||
|
||||
// These fields should ONLY be from the FIRST call, because isNewTurnForAggregation was true
|
||||
expect(stats?.cumulative.promptTokenCount).toBe(100);
|
||||
expect(stats?.cumulative.cachedContentTokenCount).toBe(50);
|
||||
expect(stats?.cumulative.toolUsePromptTokenCount).toBe(10);
|
||||
|
||||
// Last turn should hold the metadata from the most recent call
|
||||
expect(stats?.lastTurn?.metadata).toEqual(mockMetadata2);
|
||||
});
|
||||
|
||||
it('should throw an error when useSessionStats is used outside of a provider', () => {
|
||||
// Suppress the expected console error during this test.
|
||||
const errorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
|
||||
|
||||
const contextRef = { current: undefined };
|
||||
|
||||
// We expect rendering to fail, which React will catch and log as an error.
|
||||
render(<TestHarness contextRef={contextRef} />);
|
||||
|
||||
// Assert that the first argument of the first call to console.error
|
||||
// contains the expected message. This is more robust than checking
|
||||
// the exact arguments, which can be affected by React/JSDOM internals.
|
||||
expect(errorSpy.mock.calls[0][0]).toContain(
|
||||
'useSessionStats must be used within a SessionStatsProvider',
|
||||
);
|
||||
|
||||
errorSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -4,35 +4,140 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import React, { createContext, useContext, useState, useMemo } from 'react';
|
||||
import React, {
|
||||
createContext,
|
||||
useContext,
|
||||
useState,
|
||||
useMemo,
|
||||
useCallback,
|
||||
} from 'react';
|
||||
|
||||
interface SessionContextType {
|
||||
startTime: Date;
|
||||
import { type GenerateContentResponseUsageMetadata } from '@google/genai';
|
||||
|
||||
// --- Interface Definitions ---
|
||||
|
||||
interface CumulativeStats {
|
||||
turnCount: number;
|
||||
promptTokenCount: number;
|
||||
candidatesTokenCount: number;
|
||||
totalTokenCount: number;
|
||||
cachedContentTokenCount: number;
|
||||
toolUsePromptTokenCount: number;
|
||||
thoughtsTokenCount: number;
|
||||
}
|
||||
|
||||
const SessionContext = createContext<SessionContextType | null>(null);
|
||||
interface LastTurnStats {
|
||||
metadata: GenerateContentResponseUsageMetadata;
|
||||
// TODO(abhipatel12): Add apiTime, etc. here in a future step.
|
||||
}
|
||||
|
||||
export const SessionProvider: React.FC<{ children: React.ReactNode }> = ({
|
||||
interface SessionStatsState {
|
||||
sessionStartTime: Date;
|
||||
cumulative: CumulativeStats;
|
||||
lastTurn: LastTurnStats | null;
|
||||
isNewTurnForAggregation: boolean;
|
||||
}
|
||||
|
||||
// Defines the final "value" of our context, including the state
|
||||
// and the functions to update it.
|
||||
interface SessionStatsContextValue {
|
||||
stats: SessionStatsState;
|
||||
startNewTurn: () => void;
|
||||
addUsage: (metadata: GenerateContentResponseUsageMetadata) => void;
|
||||
}
|
||||
|
||||
// --- Context Definition ---
|
||||
|
||||
const SessionStatsContext = createContext<SessionStatsContextValue | undefined>(
|
||||
undefined,
|
||||
);
|
||||
|
||||
// --- Provider Component ---
|
||||
|
||||
export const SessionStatsProvider: React.FC<{ children: React.ReactNode }> = ({
|
||||
children,
|
||||
}) => {
|
||||
const [startTime] = useState(new Date());
|
||||
const [stats, setStats] = useState<SessionStatsState>({
|
||||
sessionStartTime: new Date(),
|
||||
cumulative: {
|
||||
turnCount: 0,
|
||||
promptTokenCount: 0,
|
||||
candidatesTokenCount: 0,
|
||||
totalTokenCount: 0,
|
||||
cachedContentTokenCount: 0,
|
||||
toolUsePromptTokenCount: 0,
|
||||
thoughtsTokenCount: 0,
|
||||
},
|
||||
lastTurn: null,
|
||||
isNewTurnForAggregation: true,
|
||||
});
|
||||
|
||||
// A single, internal worker function to handle all metadata aggregation.
|
||||
const aggregateTokens = useCallback(
|
||||
(metadata: GenerateContentResponseUsageMetadata) => {
|
||||
setStats((prevState) => {
|
||||
const { isNewTurnForAggregation } = prevState;
|
||||
const newCumulative = { ...prevState.cumulative };
|
||||
|
||||
newCumulative.candidatesTokenCount +=
|
||||
metadata.candidatesTokenCount ?? 0;
|
||||
newCumulative.thoughtsTokenCount += metadata.thoughtsTokenCount ?? 0;
|
||||
newCumulative.totalTokenCount += metadata.totalTokenCount ?? 0;
|
||||
|
||||
if (isNewTurnForAggregation) {
|
||||
newCumulative.promptTokenCount += metadata.promptTokenCount ?? 0;
|
||||
newCumulative.cachedContentTokenCount +=
|
||||
metadata.cachedContentTokenCount ?? 0;
|
||||
newCumulative.toolUsePromptTokenCount +=
|
||||
metadata.toolUsePromptTokenCount ?? 0;
|
||||
}
|
||||
|
||||
return {
|
||||
...prevState,
|
||||
cumulative: newCumulative,
|
||||
lastTurn: { metadata },
|
||||
isNewTurnForAggregation: false,
|
||||
};
|
||||
});
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
const startNewTurn = useCallback(() => {
|
||||
setStats((prevState) => ({
|
||||
...prevState,
|
||||
cumulative: {
|
||||
...prevState.cumulative,
|
||||
turnCount: prevState.cumulative.turnCount + 1,
|
||||
},
|
||||
isNewTurnForAggregation: true,
|
||||
}));
|
||||
}, []);
|
||||
|
||||
const value = useMemo(
|
||||
() => ({
|
||||
startTime,
|
||||
stats,
|
||||
startNewTurn,
|
||||
addUsage: aggregateTokens,
|
||||
}),
|
||||
[startTime],
|
||||
[stats, startNewTurn, aggregateTokens],
|
||||
);
|
||||
|
||||
return (
|
||||
<SessionContext.Provider value={value}>{children}</SessionContext.Provider>
|
||||
<SessionStatsContext.Provider value={value}>
|
||||
{children}
|
||||
</SessionStatsContext.Provider>
|
||||
);
|
||||
};
|
||||
|
||||
export const useSession = () => {
|
||||
const context = useContext(SessionContext);
|
||||
if (!context) {
|
||||
throw new Error('useSession must be used within a SessionProvider');
|
||||
// --- Consumer Hook ---
|
||||
|
||||
export const useSessionStats = () => {
|
||||
const context = useContext(SessionStatsContext);
|
||||
if (context === undefined) {
|
||||
throw new Error(
|
||||
'useSessionStats must be used within a SessionStatsProvider',
|
||||
);
|
||||
}
|
||||
return context;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user