Merge pull request #610 from QwenLM/feat/skip-loop-detection

Add `skipLoopDetection` Configuration Option
This commit is contained in:
pomelo
2025-09-16 15:31:51 +08:00
committed by GitHub
6 changed files with 141 additions and 7 deletions

View File

@@ -233,6 +233,7 @@ export interface ConfigParameters {
trustedFolder?: boolean;
shouldUseNodePtyShell?: boolean;
skipNextSpeakerCheck?: boolean;
skipLoopDetection?: boolean;
}
export class Config {
@@ -318,6 +319,7 @@ export class Config {
private readonly trustedFolder: boolean | undefined;
private readonly shouldUseNodePtyShell: boolean;
private readonly skipNextSpeakerCheck: boolean;
private readonly skipLoopDetection: boolean;
private initialized: boolean = false;
constructor(params: ConfigParameters) {
@@ -399,6 +401,7 @@ export class Config {
this.trustedFolder = params.trustedFolder;
this.shouldUseNodePtyShell = params.shouldUseNodePtyShell ?? false;
this.skipNextSpeakerCheck = params.skipNextSpeakerCheck ?? false;
this.skipLoopDetection = params.skipLoopDetection ?? false;
// Web search
this.tavilyApiKey = params.tavilyApiKey;
@@ -861,6 +864,10 @@ export class Config {
return this.skipNextSpeakerCheck;
}
getSkipLoopDetection(): boolean {
return this.skipLoopDetection;
}
async getGitService(): Promise<GitService> {
if (!this.gitService) {
this.gitService = new GitService(this.targetDir);

View File

@@ -233,6 +233,7 @@ describe('Gemini Client (client.ts)', () => {
getCliVersion: vi.fn().mockReturnValue('1.0.0'),
getChatCompression: vi.fn().mockReturnValue(undefined),
getSkipNextSpeakerCheck: vi.fn().mockReturnValue(false),
getSkipLoopDetection: vi.fn().mockReturnValue(false),
};
const MockedConfig = vi.mocked(Config, true);
MockedConfig.mockImplementation(
@@ -1987,6 +1988,100 @@ ${JSON.stringify(
// Assert
expect(mockCheckNextSpeaker).not.toHaveBeenCalled();
});
it('does not run loop checks when skipLoopDetection is true', async () => {
// Arrange
// Ensure config returns true for skipLoopDetection
vi.spyOn(client['config'], 'getSkipLoopDetection').mockReturnValue(true);
// Replace loop detector with spies
const ldMock = {
turnStarted: vi.fn().mockResolvedValue(false),
addAndCheck: vi.fn().mockReturnValue(false),
reset: vi.fn(),
};
// @ts-expect-error override private for testing
client['loopDetector'] = ldMock;
const mockStream = (async function* () {
yield { type: 'content', value: 'Hello' };
yield { type: 'content', value: 'World' };
})();
mockTurnRunFn.mockReturnValue(mockStream);
const mockChat: Partial<GeminiChat> = {
addHistory: vi.fn(),
getHistory: vi.fn().mockReturnValue([]),
};
client['chat'] = mockChat as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
// Act
const stream = client.sendMessageStream(
[{ text: 'Hi' }],
new AbortController().signal,
'prompt-id-loop-skip',
);
for await (const _ of stream) {
// consume
}
// Assert: methods not called due to skip
const ld = client['loopDetector'] as unknown as {
turnStarted: ReturnType<typeof vi.fn>;
addAndCheck: ReturnType<typeof vi.fn>;
};
expect(ld.turnStarted).not.toHaveBeenCalled();
expect(ld.addAndCheck).not.toHaveBeenCalled();
});
it('runs loop checks when skipLoopDetection is false', async () => {
// Arrange
vi.spyOn(client['config'], 'getSkipLoopDetection').mockReturnValue(false);
const turnStarted = vi.fn().mockResolvedValue(false);
const addAndCheck = vi.fn().mockReturnValue(false);
const reset = vi.fn();
// @ts-expect-error override private for testing
client['loopDetector'] = { turnStarted, addAndCheck, reset };
const mockStream = (async function* () {
yield { type: 'content', value: 'Hello' };
yield { type: 'content', value: 'World' };
})();
mockTurnRunFn.mockReturnValue(mockStream);
const mockChat: Partial<GeminiChat> = {
addHistory: vi.fn(),
getHistory: vi.fn().mockReturnValue([]),
};
client['chat'] = mockChat as GeminiChat;
const mockGenerator: Partial<ContentGenerator> = {
countTokens: vi.fn().mockResolvedValue({ totalTokens: 0 }),
generateContent: mockGenerateContentFn,
};
client['contentGenerator'] = mockGenerator as ContentGenerator;
// Act
const stream = client.sendMessageStream(
[{ text: 'Hi' }],
new AbortController().signal,
'prompt-id-loop-run',
);
for await (const _ of stream) {
// consume
}
// Assert
expect(turnStarted).toHaveBeenCalledTimes(1);
expect(addAndCheck).toHaveBeenCalled();
});
});
describe('generateContent', () => {

View File

@@ -551,17 +551,21 @@ export class GeminiClient {
const turn = new Turn(this.getChat(), prompt_id);
const loopDetected = await this.loopDetector.turnStarted(signal);
if (loopDetected) {
yield { type: GeminiEventType.LoopDetected };
return turn;
if (!this.config.getSkipLoopDetection()) {
const loopDetected = await this.loopDetector.turnStarted(signal);
if (loopDetected) {
yield { type: GeminiEventType.LoopDetected };
return turn;
}
}
const resultStream = turn.run(request, signal);
for await (const event of resultStream) {
if (this.loopDetector.addAndCheck(event)) {
yield { type: GeminiEventType.LoopDetected };
return turn;
if (!this.config.getSkipLoopDetection()) {
if (this.loopDetector.addAndCheck(event)) {
yield { type: GeminiEventType.LoopDetected };
return turn;
}
}
yield event;
if (event.type === GeminiEventType.Error) {