feat: Display initial token usage metrics in /stats (#879)

This commit is contained in:
Abhi
2025-06-09 20:25:37 -04:00
committed by GitHub
parent 6484dc9008
commit 7f1252d364
11 changed files with 608 additions and 63 deletions

View File

@@ -4,26 +4,181 @@
* SPDX-License-Identifier: Apache-2.0
*/
import { type MutableRefObject } from 'react';
import { render } from 'ink-testing-library';
import { Text } from 'ink';
import { SessionProvider, useSession } from './SessionContext.js';
import { describe, it, expect } from 'vitest';
import { act } from 'react-dom/test-utils';
import { SessionStatsProvider, useSessionStats } from './SessionContext.js';
import { describe, it, expect, vi } from 'vitest';
import { GenerateContentResponseUsageMetadata } from '@google/genai';
const TestComponent = () => {
const { startTime } = useSession();
return <Text>{startTime.toISOString()}</Text>;
// Mock data that simulates what the Gemini API would return.
const mockMetadata1: GenerateContentResponseUsageMetadata = {
promptTokenCount: 100,
candidatesTokenCount: 200,
totalTokenCount: 300,
cachedContentTokenCount: 50,
toolUsePromptTokenCount: 10,
thoughtsTokenCount: 20,
};
describe('SessionContext', () => {
it('should provide a start time', () => {
const { lastFrame } = render(
<SessionProvider>
<TestComponent />
</SessionProvider>,
const mockMetadata2: GenerateContentResponseUsageMetadata = {
promptTokenCount: 10,
candidatesTokenCount: 20,
totalTokenCount: 30,
cachedContentTokenCount: 5,
toolUsePromptTokenCount: 1,
thoughtsTokenCount: 2,
};
/**
* A test harness component that uses the hook and exposes the context value
* via a mutable ref. This allows us to interact with the context's functions
* and assert against its state directly in our tests.
*/
const TestHarness = ({
contextRef,
}: {
contextRef: MutableRefObject<ReturnType<typeof useSessionStats> | undefined>;
}) => {
contextRef.current = useSessionStats();
return null;
};
describe('SessionStatsContext', () => {
it('should provide the correct initial state', () => {
const contextRef: MutableRefObject<
ReturnType<typeof useSessionStats> | undefined
> = { current: undefined };
render(
<SessionStatsProvider>
<TestHarness contextRef={contextRef} />
</SessionStatsProvider>,
);
const frameText = lastFrame();
// Check if the output is a valid ISO string, which confirms it's a Date object.
expect(new Date(frameText!).toString()).not.toBe('Invalid Date');
const stats = contextRef.current?.stats;
expect(stats?.sessionStartTime).toBeInstanceOf(Date);
expect(stats?.lastTurn).toBeNull();
expect(stats?.cumulative.turnCount).toBe(0);
expect(stats?.cumulative.totalTokenCount).toBe(0);
expect(stats?.cumulative.promptTokenCount).toBe(0);
});
it('should increment turnCount when startNewTurn is called', () => {
const contextRef: MutableRefObject<
ReturnType<typeof useSessionStats> | undefined
> = { current: undefined };
render(
<SessionStatsProvider>
<TestHarness contextRef={contextRef} />
</SessionStatsProvider>,
);
act(() => {
contextRef.current?.startNewTurn();
});
const stats = contextRef.current?.stats;
expect(stats?.cumulative.turnCount).toBe(1);
// Ensure token counts are unaffected
expect(stats?.cumulative.totalTokenCount).toBe(0);
});
it('should aggregate token usage correctly when addUsage is called', () => {
const contextRef: MutableRefObject<
ReturnType<typeof useSessionStats> | undefined
> = { current: undefined };
render(
<SessionStatsProvider>
<TestHarness contextRef={contextRef} />
</SessionStatsProvider>,
);
act(() => {
contextRef.current?.addUsage(mockMetadata1);
});
const stats = contextRef.current?.stats;
// Check that token counts are updated
expect(stats?.cumulative.totalTokenCount).toBe(
mockMetadata1.totalTokenCount ?? 0,
);
expect(stats?.cumulative.promptTokenCount).toBe(
mockMetadata1.promptTokenCount ?? 0,
);
// Check that turn count is NOT incremented
expect(stats?.cumulative.turnCount).toBe(0);
// Check that lastTurn is updated
expect(stats?.lastTurn?.metadata).toEqual(mockMetadata1);
});
it('should correctly track a full logical turn with multiple API calls', () => {
const contextRef: MutableRefObject<
ReturnType<typeof useSessionStats> | undefined
> = { current: undefined };
render(
<SessionStatsProvider>
<TestHarness contextRef={contextRef} />
</SessionStatsProvider>,
);
// 1. User starts a new turn
act(() => {
contextRef.current?.startNewTurn();
});
// 2. First API call (e.g., prompt with a tool request)
act(() => {
contextRef.current?.addUsage(mockMetadata1);
});
// 3. Second API call (e.g., sending tool response back)
act(() => {
contextRef.current?.addUsage(mockMetadata2);
});
const stats = contextRef.current?.stats;
// Turn count should only be 1
expect(stats?.cumulative.turnCount).toBe(1);
// These fields should be the SUM of both calls
expect(stats?.cumulative.totalTokenCount).toBe(330); // 300 + 30
expect(stats?.cumulative.candidatesTokenCount).toBe(220); // 200 + 20
expect(stats?.cumulative.thoughtsTokenCount).toBe(22); // 20 + 2
// These fields should ONLY be from the FIRST call, because isNewTurnForAggregation was true
expect(stats?.cumulative.promptTokenCount).toBe(100);
expect(stats?.cumulative.cachedContentTokenCount).toBe(50);
expect(stats?.cumulative.toolUsePromptTokenCount).toBe(10);
// Last turn should hold the metadata from the most recent call
expect(stats?.lastTurn?.metadata).toEqual(mockMetadata2);
});
it('should throw an error when useSessionStats is used outside of a provider', () => {
// Suppress the expected console error during this test.
const errorSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
const contextRef = { current: undefined };
// We expect rendering to fail, which React will catch and log as an error.
render(<TestHarness contextRef={contextRef} />);
// Assert that the first argument of the first call to console.error
// contains the expected message. This is more robust than checking
// the exact arguments, which can be affected by React/JSDOM internals.
expect(errorSpy.mock.calls[0][0]).toContain(
'useSessionStats must be used within a SessionStatsProvider',
);
errorSpy.mockRestore();
});
});

View File

@@ -4,35 +4,140 @@
* SPDX-License-Identifier: Apache-2.0
*/
import React, { createContext, useContext, useState, useMemo } from 'react';
import React, {
createContext,
useContext,
useState,
useMemo,
useCallback,
} from 'react';
interface SessionContextType {
startTime: Date;
import { type GenerateContentResponseUsageMetadata } from '@google/genai';
// --- Interface Definitions ---
interface CumulativeStats {
turnCount: number;
promptTokenCount: number;
candidatesTokenCount: number;
totalTokenCount: number;
cachedContentTokenCount: number;
toolUsePromptTokenCount: number;
thoughtsTokenCount: number;
}
const SessionContext = createContext<SessionContextType | null>(null);
interface LastTurnStats {
metadata: GenerateContentResponseUsageMetadata;
// TODO(abhipatel12): Add apiTime, etc. here in a future step.
}
export const SessionProvider: React.FC<{ children: React.ReactNode }> = ({
interface SessionStatsState {
sessionStartTime: Date;
cumulative: CumulativeStats;
lastTurn: LastTurnStats | null;
isNewTurnForAggregation: boolean;
}
// Defines the final "value" of our context, including the state
// and the functions to update it.
interface SessionStatsContextValue {
stats: SessionStatsState;
startNewTurn: () => void;
addUsage: (metadata: GenerateContentResponseUsageMetadata) => void;
}
// --- Context Definition ---
const SessionStatsContext = createContext<SessionStatsContextValue | undefined>(
undefined,
);
// --- Provider Component ---
export const SessionStatsProvider: React.FC<{ children: React.ReactNode }> = ({
children,
}) => {
const [startTime] = useState(new Date());
const [stats, setStats] = useState<SessionStatsState>({
sessionStartTime: new Date(),
cumulative: {
turnCount: 0,
promptTokenCount: 0,
candidatesTokenCount: 0,
totalTokenCount: 0,
cachedContentTokenCount: 0,
toolUsePromptTokenCount: 0,
thoughtsTokenCount: 0,
},
lastTurn: null,
isNewTurnForAggregation: true,
});
// A single, internal worker function to handle all metadata aggregation.
const aggregateTokens = useCallback(
(metadata: GenerateContentResponseUsageMetadata) => {
setStats((prevState) => {
const { isNewTurnForAggregation } = prevState;
const newCumulative = { ...prevState.cumulative };
newCumulative.candidatesTokenCount +=
metadata.candidatesTokenCount ?? 0;
newCumulative.thoughtsTokenCount += metadata.thoughtsTokenCount ?? 0;
newCumulative.totalTokenCount += metadata.totalTokenCount ?? 0;
if (isNewTurnForAggregation) {
newCumulative.promptTokenCount += metadata.promptTokenCount ?? 0;
newCumulative.cachedContentTokenCount +=
metadata.cachedContentTokenCount ?? 0;
newCumulative.toolUsePromptTokenCount +=
metadata.toolUsePromptTokenCount ?? 0;
}
return {
...prevState,
cumulative: newCumulative,
lastTurn: { metadata },
isNewTurnForAggregation: false,
};
});
},
[],
);
const startNewTurn = useCallback(() => {
setStats((prevState) => ({
...prevState,
cumulative: {
...prevState.cumulative,
turnCount: prevState.cumulative.turnCount + 1,
},
isNewTurnForAggregation: true,
}));
}, []);
const value = useMemo(
() => ({
startTime,
stats,
startNewTurn,
addUsage: aggregateTokens,
}),
[startTime],
[stats, startNewTurn, aggregateTokens],
);
return (
<SessionContext.Provider value={value}>{children}</SessionContext.Provider>
<SessionStatsContext.Provider value={value}>
{children}
</SessionStatsContext.Provider>
);
};
export const useSession = () => {
const context = useContext(SessionContext);
if (!context) {
throw new Error('useSession must be used within a SessionProvider');
// --- Consumer Hook ---
export const useSessionStats = () => {
const context = useContext(SessionStatsContext);
if (context === undefined) {
throw new Error(
'useSessionStats must be used within a SessionStatsProvider',
);
}
return context;
};