mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-20 08:47:44 +00:00
Refactor(chat): Introduce custom Chat class for future modifications
- Copied the `Chat` class from `@google/genai` into `packages/server/src/core/geminiChat.ts`. - This change is in preparation for future modifications to the chat handling logic. - Updated relevant files to use the new `GeminiChat` class. Part of https://github.com/google-gemini/gemini-cli/issues/551
This commit is contained in:
committed by
N. Taylor Mullen
parent
02503a3248
commit
480549e02e
@@ -23,7 +23,7 @@ import {
|
|||||||
ToolResultDisplay,
|
ToolResultDisplay,
|
||||||
ToolCallRequestInfo,
|
ToolCallRequestInfo,
|
||||||
} from '@gemini-code/server';
|
} from '@gemini-code/server';
|
||||||
import { type Chat, type PartListUnion, type Part } from '@google/genai';
|
import { type PartListUnion, type Part } from '@google/genai';
|
||||||
import {
|
import {
|
||||||
StreamingState,
|
StreamingState,
|
||||||
ToolCallStatus,
|
ToolCallStatus,
|
||||||
@@ -39,6 +39,7 @@ import { useStateAndRef } from './useStateAndRef.js';
|
|||||||
import { UseHistoryManagerReturn } from './useHistoryManager.js';
|
import { UseHistoryManagerReturn } from './useHistoryManager.js';
|
||||||
import { useLogger } from './useLogger.js';
|
import { useLogger } from './useLogger.js';
|
||||||
import { useToolScheduler, mapToDisplay } from './useToolScheduler.js';
|
import { useToolScheduler, mapToDisplay } from './useToolScheduler.js';
|
||||||
|
import { GeminiChat } from '@gemini-code/server/src/core/geminiChat.js';
|
||||||
|
|
||||||
enum StreamProcessingStatus {
|
enum StreamProcessingStatus {
|
||||||
Completed,
|
Completed,
|
||||||
@@ -63,7 +64,7 @@ export const useGeminiStream = (
|
|||||||
) => {
|
) => {
|
||||||
const [initError, setInitError] = useState<string | null>(null);
|
const [initError, setInitError] = useState<string | null>(null);
|
||||||
const abortControllerRef = useRef<AbortController | null>(null);
|
const abortControllerRef = useRef<AbortController | null>(null);
|
||||||
const chatSessionRef = useRef<Chat | null>(null);
|
const chatSessionRef = useRef<GeminiChat | null>(null);
|
||||||
const geminiClientRef = useRef<GeminiClient | null>(null);
|
const geminiClientRef = useRef<GeminiClient | null>(null);
|
||||||
const [isResponding, setIsResponding] = useState<boolean>(false);
|
const [isResponding, setIsResponding] = useState<boolean>(false);
|
||||||
const [pendingHistoryItemRef, setPendingHistoryItem] =
|
const [pendingHistoryItemRef, setPendingHistoryItem] =
|
||||||
@@ -235,7 +236,7 @@ export const useGeminiStream = (
|
|||||||
|
|
||||||
const ensureChatSession = useCallback(async (): Promise<{
|
const ensureChatSession = useCallback(async (): Promise<{
|
||||||
client: GeminiClient | null;
|
client: GeminiClient | null;
|
||||||
chat: Chat | null;
|
chat: GeminiChat | null;
|
||||||
}> => {
|
}> => {
|
||||||
const currentClient = geminiClientRef.current;
|
const currentClient = geminiClientRef.current;
|
||||||
if (!currentClient) {
|
if (!currentClient) {
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import {
|
|||||||
GenerateContentConfig,
|
GenerateContentConfig,
|
||||||
GoogleGenAI,
|
GoogleGenAI,
|
||||||
Part,
|
Part,
|
||||||
Chat,
|
|
||||||
SchemaUnion,
|
SchemaUnion,
|
||||||
PartListUnion,
|
PartListUnion,
|
||||||
Content,
|
Content,
|
||||||
@@ -23,6 +22,7 @@ import { ReadManyFilesTool } from '../tools/read-many-files.js';
|
|||||||
import { getResponseText } from '../utils/generateContentResponseUtilities.js';
|
import { getResponseText } from '../utils/generateContentResponseUtilities.js';
|
||||||
import { checkNextSpeaker } from '../utils/nextSpeakerChecker.js';
|
import { checkNextSpeaker } from '../utils/nextSpeakerChecker.js';
|
||||||
import { reportError } from '../utils/errorReporting.js';
|
import { reportError } from '../utils/errorReporting.js';
|
||||||
|
import { GeminiChat } from './geminiChat.js';
|
||||||
|
|
||||||
export class GeminiClient {
|
export class GeminiClient {
|
||||||
private client: GoogleGenAI;
|
private client: GoogleGenAI;
|
||||||
@@ -108,7 +108,7 @@ export class GeminiClient {
|
|||||||
return initialParts;
|
return initialParts;
|
||||||
}
|
}
|
||||||
|
|
||||||
async startChat(): Promise<Chat> {
|
async startChat(): Promise<GeminiChat> {
|
||||||
const envParts = await this.getEnvironment();
|
const envParts = await this.getEnvironment();
|
||||||
const toolDeclarations = this.config
|
const toolDeclarations = this.config
|
||||||
.getToolRegistry()
|
.getToolRegistry()
|
||||||
@@ -128,15 +128,17 @@ export class GeminiClient {
|
|||||||
const userMemory = this.config.getUserMemory();
|
const userMemory = this.config.getUserMemory();
|
||||||
const systemInstruction = getCoreSystemPrompt(userMemory);
|
const systemInstruction = getCoreSystemPrompt(userMemory);
|
||||||
|
|
||||||
return this.client.chats.create({
|
return new GeminiChat(
|
||||||
model: this.model,
|
this.client,
|
||||||
config: {
|
this.client.models,
|
||||||
|
this.model,
|
||||||
|
{
|
||||||
systemInstruction,
|
systemInstruction,
|
||||||
...this.generateContentConfig,
|
...this.generateContentConfig,
|
||||||
tools,
|
tools,
|
||||||
},
|
},
|
||||||
history,
|
history,
|
||||||
});
|
);
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
await reportError(
|
await reportError(
|
||||||
error,
|
error,
|
||||||
@@ -150,7 +152,7 @@ export class GeminiClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async *sendMessageStream(
|
async *sendMessageStream(
|
||||||
chat: Chat,
|
chat: GeminiChat,
|
||||||
request: PartListUnion,
|
request: PartListUnion,
|
||||||
signal?: AbortSignal,
|
signal?: AbortSignal,
|
||||||
turns: number = this.MAX_TURNS,
|
turns: number = this.MAX_TURNS,
|
||||||
|
|||||||
314
packages/server/src/core/geminiChat.ts
Normal file
314
packages/server/src/core/geminiChat.ts
Normal file
@@ -0,0 +1,314 @@
|
|||||||
|
/**
|
||||||
|
* @license
|
||||||
|
* Copyright 2025 Google LLC
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
*/
|
||||||
|
|
||||||
|
// DISCLAIMER: This is a copied version of https://github.com/googleapis/js-genai/blob/main/src/chats.ts with the intention of working around a key bug
|
||||||
|
// where function responses are not treated as "valid" responses: https://b.corp.google.com/issues/420354090
|
||||||
|
|
||||||
|
import {
|
||||||
|
GenerateContentResponse,
|
||||||
|
Content,
|
||||||
|
Models,
|
||||||
|
GenerateContentConfig,
|
||||||
|
SendMessageParameters,
|
||||||
|
GoogleGenAI,
|
||||||
|
createUserContent,
|
||||||
|
} from '@google/genai';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns true if the response is valid, false otherwise.
|
||||||
|
*/
|
||||||
|
function isValidResponse(response: GenerateContentResponse): boolean {
|
||||||
|
if (response.candidates === undefined || response.candidates.length === 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
const content = response.candidates[0]?.content;
|
||||||
|
if (content === undefined) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return isValidContent(content);
|
||||||
|
}
|
||||||
|
|
||||||
|
function isValidContent(content: Content): boolean {
|
||||||
|
if (content.parts === undefined || content.parts.length === 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (const part of content.parts) {
|
||||||
|
if (part === undefined || Object.keys(part).length === 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!part.thought && part.text !== undefined && part.text === '') {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validates the history contains the correct roles.
|
||||||
|
*
|
||||||
|
* @throws Error if the history does not start with a user turn.
|
||||||
|
* @throws Error if the history contains an invalid role.
|
||||||
|
*/
|
||||||
|
function validateHistory(history: Content[]) {
|
||||||
|
// Empty history is valid.
|
||||||
|
if (history.length === 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (const content of history) {
|
||||||
|
if (content.role !== 'user' && content.role !== 'model') {
|
||||||
|
throw new Error(`Role must be user or model, but got ${content.role}.`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extracts the curated (valid) history from a comprehensive history.
|
||||||
|
*
|
||||||
|
* @remarks
|
||||||
|
* The model may sometimes generate invalid or empty contents(e.g., due to safty
|
||||||
|
* filters or recitation). Extracting valid turns from the history
|
||||||
|
* ensures that subsequent requests could be accpeted by the model.
|
||||||
|
*/
|
||||||
|
function extractCuratedHistory(comprehensiveHistory: Content[]): Content[] {
|
||||||
|
if (comprehensiveHistory === undefined || comprehensiveHistory.length === 0) {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
const curatedHistory: Content[] = [];
|
||||||
|
const length = comprehensiveHistory.length;
|
||||||
|
let i = 0;
|
||||||
|
while (i < length) {
|
||||||
|
if (comprehensiveHistory[i].role === 'user') {
|
||||||
|
curatedHistory.push(comprehensiveHistory[i]);
|
||||||
|
i++;
|
||||||
|
} else {
|
||||||
|
const modelOutput: Content[] = [];
|
||||||
|
let isValid = true;
|
||||||
|
while (i < length && comprehensiveHistory[i].role === 'model') {
|
||||||
|
modelOutput.push(comprehensiveHistory[i]);
|
||||||
|
if (isValid && !isValidContent(comprehensiveHistory[i])) {
|
||||||
|
isValid = false;
|
||||||
|
}
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
if (isValid) {
|
||||||
|
curatedHistory.push(...modelOutput);
|
||||||
|
} else {
|
||||||
|
// Remove the last user input when model content is invalid.
|
||||||
|
curatedHistory.pop();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return curatedHistory;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Chat session that enables sending messages to the model with previous
|
||||||
|
* conversation context.
|
||||||
|
*
|
||||||
|
* @remarks
|
||||||
|
* The session maintains all the turns between user and model.
|
||||||
|
*/
|
||||||
|
export class GeminiChat {
|
||||||
|
// A promise to represent the current state of the message being sent to the
|
||||||
|
// model.
|
||||||
|
private sendPromise: Promise<void> = Promise.resolve();
|
||||||
|
|
||||||
|
constructor(
|
||||||
|
private readonly apiClient: GoogleGenAI,
|
||||||
|
private readonly modelsModule: Models,
|
||||||
|
private readonly model: string,
|
||||||
|
private readonly config: GenerateContentConfig = {},
|
||||||
|
private history: Content[] = [],
|
||||||
|
) {
|
||||||
|
validateHistory(history);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sends a message to the model and returns the response.
|
||||||
|
*
|
||||||
|
* @remarks
|
||||||
|
* This method will wait for the previous message to be processed before
|
||||||
|
* sending the next message.
|
||||||
|
*
|
||||||
|
* @see {@link Chat#sendMessageStream} for streaming method.
|
||||||
|
* @param params - parameters for sending messages within a chat session.
|
||||||
|
* @returns The model's response.
|
||||||
|
*
|
||||||
|
* @example
|
||||||
|
* ```ts
|
||||||
|
* const chat = ai.chats.create({model: 'gemini-2.0-flash'});
|
||||||
|
* const response = await chat.sendMessage({
|
||||||
|
* message: 'Why is the sky blue?'
|
||||||
|
* });
|
||||||
|
* console.log(response.text);
|
||||||
|
* ```
|
||||||
|
*/
|
||||||
|
async sendMessage(
|
||||||
|
params: SendMessageParameters,
|
||||||
|
): Promise<GenerateContentResponse> {
|
||||||
|
await this.sendPromise;
|
||||||
|
const userContent = createUserContent(params.message);
|
||||||
|
const responsePromise = this.modelsModule.generateContent({
|
||||||
|
model: this.model,
|
||||||
|
contents: this.getHistory(true).concat(userContent),
|
||||||
|
config: params.config ?? this.config,
|
||||||
|
});
|
||||||
|
this.sendPromise = (async () => {
|
||||||
|
const response = await responsePromise;
|
||||||
|
const outputContent = response.candidates?.[0]?.content;
|
||||||
|
|
||||||
|
// Because the AFC input contains the entire curated chat history in
|
||||||
|
// addition to the new user input, we need to truncate the AFC history
|
||||||
|
// to deduplicate the existing chat history.
|
||||||
|
const fullAutomaticFunctionCallingHistory =
|
||||||
|
response.automaticFunctionCallingHistory;
|
||||||
|
const index = this.getHistory(true).length;
|
||||||
|
|
||||||
|
let automaticFunctionCallingHistory: Content[] = [];
|
||||||
|
if (fullAutomaticFunctionCallingHistory != null) {
|
||||||
|
automaticFunctionCallingHistory =
|
||||||
|
fullAutomaticFunctionCallingHistory.slice(index) ?? [];
|
||||||
|
}
|
||||||
|
|
||||||
|
const modelOutput = outputContent ? [outputContent] : [];
|
||||||
|
this.recordHistory(
|
||||||
|
userContent,
|
||||||
|
modelOutput,
|
||||||
|
automaticFunctionCallingHistory,
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
})();
|
||||||
|
await this.sendPromise.catch(() => {
|
||||||
|
// Resets sendPromise to avoid subsequent calls failing
|
||||||
|
this.sendPromise = Promise.resolve();
|
||||||
|
});
|
||||||
|
return responsePromise;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sends a message to the model and returns the response in chunks.
|
||||||
|
*
|
||||||
|
* @remarks
|
||||||
|
* This method will wait for the previous message to be processed before
|
||||||
|
* sending the next message.
|
||||||
|
*
|
||||||
|
* @see {@link Chat#sendMessage} for non-streaming method.
|
||||||
|
* @param params - parameters for sending the message.
|
||||||
|
* @return The model's response.
|
||||||
|
*
|
||||||
|
* @example
|
||||||
|
* ```ts
|
||||||
|
* const chat = ai.chats.create({model: 'gemini-2.0-flash'});
|
||||||
|
* const response = await chat.sendMessageStream({
|
||||||
|
* message: 'Why is the sky blue?'
|
||||||
|
* });
|
||||||
|
* for await (const chunk of response) {
|
||||||
|
* console.log(chunk.text);
|
||||||
|
* }
|
||||||
|
* ```
|
||||||
|
*/
|
||||||
|
async sendMessageStream(
|
||||||
|
params: SendMessageParameters,
|
||||||
|
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||||
|
await this.sendPromise;
|
||||||
|
const userContent = createUserContent(params.message);
|
||||||
|
const streamResponse = this.modelsModule.generateContentStream({
|
||||||
|
model: this.model,
|
||||||
|
contents: this.getHistory(true).concat(userContent),
|
||||||
|
config: params.config ?? this.config,
|
||||||
|
});
|
||||||
|
// Resolve the internal tracking of send completion promise - `sendPromise`
|
||||||
|
// for both success and failure response. The actual failure is still
|
||||||
|
// propagated by the `await streamResponse`.
|
||||||
|
this.sendPromise = streamResponse
|
||||||
|
.then(() => undefined)
|
||||||
|
.catch(() => undefined);
|
||||||
|
const response = await streamResponse;
|
||||||
|
const result = this.processStreamResponse(response, userContent);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the chat history.
|
||||||
|
*
|
||||||
|
* @remarks
|
||||||
|
* The history is a list of contents alternating between user and model.
|
||||||
|
*
|
||||||
|
* There are two types of history:
|
||||||
|
* - The `curated history` contains only the valid turns between user and
|
||||||
|
* model, which will be included in the subsequent requests sent to the model.
|
||||||
|
* - The `comprehensive history` contains all turns, including invalid or
|
||||||
|
* empty model outputs, providing a complete record of the history.
|
||||||
|
*
|
||||||
|
* The history is updated after receiving the response from the model,
|
||||||
|
* for streaming response, it means receiving the last chunk of the response.
|
||||||
|
*
|
||||||
|
* The `comprehensive history` is returned by default. To get the `curated
|
||||||
|
* history`, set the `curated` parameter to `true`.
|
||||||
|
*
|
||||||
|
* @param curated - whether to return the curated history or the comprehensive
|
||||||
|
* history.
|
||||||
|
* @return History contents alternating between user and model for the entire
|
||||||
|
* chat session.
|
||||||
|
*/
|
||||||
|
getHistory(curated: boolean = false): Content[] {
|
||||||
|
const history = curated
|
||||||
|
? extractCuratedHistory(this.history)
|
||||||
|
: this.history;
|
||||||
|
// Deep copy the history to avoid mutating the history outside of the
|
||||||
|
// chat session.
|
||||||
|
return structuredClone(history);
|
||||||
|
}
|
||||||
|
|
||||||
|
private async *processStreamResponse(
|
||||||
|
streamResponse: AsyncGenerator<GenerateContentResponse>,
|
||||||
|
inputContent: Content,
|
||||||
|
) {
|
||||||
|
const outputContent: Content[] = [];
|
||||||
|
for await (const chunk of streamResponse) {
|
||||||
|
if (isValidResponse(chunk)) {
|
||||||
|
const content = chunk.candidates?.[0]?.content;
|
||||||
|
if (content !== undefined) {
|
||||||
|
outputContent.push(content);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
yield chunk;
|
||||||
|
}
|
||||||
|
this.recordHistory(inputContent, outputContent);
|
||||||
|
}
|
||||||
|
|
||||||
|
private recordHistory(
|
||||||
|
userInput: Content,
|
||||||
|
modelOutput: Content[],
|
||||||
|
automaticFunctionCallingHistory?: Content[],
|
||||||
|
) {
|
||||||
|
let outputContents: Content[] = [];
|
||||||
|
if (
|
||||||
|
modelOutput.length > 0 &&
|
||||||
|
modelOutput.every((content) => content.role !== undefined)
|
||||||
|
) {
|
||||||
|
outputContents = modelOutput;
|
||||||
|
} else {
|
||||||
|
// Appends an empty content when model returns empty response, so that the
|
||||||
|
// history is always alternating between user and model.
|
||||||
|
outputContents.push({
|
||||||
|
role: 'model',
|
||||||
|
parts: [],
|
||||||
|
} as Content);
|
||||||
|
}
|
||||||
|
if (
|
||||||
|
automaticFunctionCallingHistory &&
|
||||||
|
automaticFunctionCallingHistory.length > 0
|
||||||
|
) {
|
||||||
|
this.history.push(
|
||||||
|
...extractCuratedHistory(automaticFunctionCallingHistory!),
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
this.history.push(userInput);
|
||||||
|
}
|
||||||
|
this.history.push(...outputContents);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -11,8 +11,9 @@ import {
|
|||||||
ServerGeminiToolCallRequestEvent,
|
ServerGeminiToolCallRequestEvent,
|
||||||
ServerGeminiErrorEvent,
|
ServerGeminiErrorEvent,
|
||||||
} from './turn.js';
|
} from './turn.js';
|
||||||
import { Chat, GenerateContentResponse, Part, Content } from '@google/genai';
|
import { GenerateContentResponse, Part, Content } from '@google/genai';
|
||||||
import { reportError } from '../utils/errorReporting.js';
|
import { reportError } from '../utils/errorReporting.js';
|
||||||
|
import { GeminiChat } from './geminiChat.js';
|
||||||
|
|
||||||
const mockSendMessageStream = vi.fn();
|
const mockSendMessageStream = vi.fn();
|
||||||
const mockGetHistory = vi.fn();
|
const mockGetHistory = vi.fn();
|
||||||
@@ -54,7 +55,7 @@ describe('Turn', () => {
|
|||||||
sendMessageStream: mockSendMessageStream,
|
sendMessageStream: mockSendMessageStream,
|
||||||
getHistory: mockGetHistory,
|
getHistory: mockGetHistory,
|
||||||
};
|
};
|
||||||
turn = new Turn(mockChatInstance as unknown as Chat);
|
turn = new Turn(mockChatInstance as unknown as GeminiChat);
|
||||||
mockGetHistory.mockReturnValue([]);
|
mockGetHistory.mockReturnValue([]);
|
||||||
mockSendMessageStream.mockResolvedValue((async function* () {})());
|
mockSendMessageStream.mockResolvedValue((async function* () {})());
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -6,7 +6,6 @@
|
|||||||
|
|
||||||
import {
|
import {
|
||||||
Part,
|
Part,
|
||||||
Chat,
|
|
||||||
PartListUnion,
|
PartListUnion,
|
||||||
GenerateContentResponse,
|
GenerateContentResponse,
|
||||||
FunctionCall,
|
FunctionCall,
|
||||||
@@ -20,6 +19,7 @@ import {
|
|||||||
import { getResponseText } from '../utils/generateContentResponseUtilities.js';
|
import { getResponseText } from '../utils/generateContentResponseUtilities.js';
|
||||||
import { reportError } from '../utils/errorReporting.js';
|
import { reportError } from '../utils/errorReporting.js';
|
||||||
import { getErrorMessage } from '../utils/errors.js';
|
import { getErrorMessage } from '../utils/errors.js';
|
||||||
|
import { GeminiChat } from './geminiChat.js';
|
||||||
|
|
||||||
// Define a structure for tools passed to the server
|
// Define a structure for tools passed to the server
|
||||||
export interface ServerTool {
|
export interface ServerTool {
|
||||||
@@ -113,7 +113,7 @@ export class Turn {
|
|||||||
}>;
|
}>;
|
||||||
private debugResponses: GenerateContentResponse[];
|
private debugResponses: GenerateContentResponse[];
|
||||||
|
|
||||||
constructor(private readonly chat: Chat) {
|
constructor(private readonly chat: GeminiChat) {
|
||||||
this.pendingToolCalls = [];
|
this.pendingToolCalls = [];
|
||||||
this.debugResponses = [];
|
this.debugResponses = [];
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,10 +5,11 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
import { describe, it, expect, vi, beforeEach, Mock, afterEach } from 'vitest';
|
import { describe, it, expect, vi, beforeEach, Mock, afterEach } from 'vitest';
|
||||||
import { Chat, Content } from '@google/genai';
|
import { Content } from '@google/genai';
|
||||||
import { GeminiClient } from '../core/client.js';
|
import { GeminiClient } from '../core/client.js';
|
||||||
import { Config } from '../config/config.js'; // Added Config import
|
import { Config } from '../config/config.js'; // Added Config import
|
||||||
import { checkNextSpeaker, NextSpeakerResponse } from './nextSpeakerChecker.js';
|
import { checkNextSpeaker, NextSpeakerResponse } from './nextSpeakerChecker.js';
|
||||||
|
import { GeminiChat } from '../core/geminiChat.js';
|
||||||
|
|
||||||
// Mock GeminiClient and Config constructor
|
// Mock GeminiClient and Config constructor
|
||||||
vi.mock('../core/client.js');
|
vi.mock('../core/client.js');
|
||||||
@@ -39,7 +40,7 @@ vi.mock('@google/genai', async () => {
|
|||||||
});
|
});
|
||||||
|
|
||||||
describe('checkNextSpeaker', () => {
|
describe('checkNextSpeaker', () => {
|
||||||
let mockChat: Chat;
|
let mockChat: GeminiChat;
|
||||||
let mockGeminiClient: GeminiClient;
|
let mockGeminiClient: GeminiClient;
|
||||||
let MockConfig: Mock;
|
let MockConfig: Mock;
|
||||||
|
|
||||||
@@ -64,7 +65,7 @@ describe('checkNextSpeaker', () => {
|
|||||||
|
|
||||||
mockGeminiClient = new GeminiClient(mockConfigInstance);
|
mockGeminiClient = new GeminiClient(mockConfigInstance);
|
||||||
// Simulate chat creation as done in GeminiClient
|
// Simulate chat creation as done in GeminiClient
|
||||||
mockChat = { getHistory: mockGetHistory } as unknown as Chat;
|
mockChat = { getHistory: mockGetHistory } as unknown as GeminiChat;
|
||||||
});
|
});
|
||||||
|
|
||||||
afterEach(() => {
|
afterEach(() => {
|
||||||
|
|||||||
@@ -4,8 +4,9 @@
|
|||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import { Chat, Content, SchemaUnion, Type } from '@google/genai';
|
import { Content, SchemaUnion, Type } from '@google/genai';
|
||||||
import { GeminiClient } from '../core/client.js';
|
import { GeminiClient } from '../core/client.js';
|
||||||
|
import { GeminiChat } from '../core/geminiChat.js';
|
||||||
|
|
||||||
const CHECK_PROMPT = `Analyze *only* the content and structure of your immediately preceding response (your last turn in the conversation history). Based *strictly* on that response, determine who should logically speak next: the 'user' or the 'model' (you).
|
const CHECK_PROMPT = `Analyze *only* the content and structure of your immediately preceding response (your last turn in the conversation history). Based *strictly* on that response, determine who should logically speak next: the 'user' or the 'model' (you).
|
||||||
**Decision Rules (apply in order):**
|
**Decision Rules (apply in order):**
|
||||||
@@ -57,7 +58,7 @@ export interface NextSpeakerResponse {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export async function checkNextSpeaker(
|
export async function checkNextSpeaker(
|
||||||
chat: Chat,
|
chat: GeminiChat,
|
||||||
geminiClient: GeminiClient,
|
geminiClient: GeminiClient,
|
||||||
): Promise<NextSpeakerResponse | null> {
|
): Promise<NextSpeakerResponse | null> {
|
||||||
// We need to capture the curated history because there are many moments when the model will return invalid turns
|
// We need to capture the curated history because there are many moments when the model will return invalid turns
|
||||||
|
|||||||
Reference in New Issue
Block a user