mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-22 09:47:47 +00:00
Implement loop check with LLM (#4337)
Co-authored-by: N. Taylor Mullen <ntaylormullen@google.com>
This commit is contained in:
@@ -4,16 +4,18 @@
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, vi } from 'vitest';
|
||||
import { LoopDetectionService } from './loopDetectionService.js';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
import { Config } from '../config/config.js';
|
||||
import { GeminiClient } from '../core/client.js';
|
||||
import {
|
||||
GeminiEventType,
|
||||
ServerGeminiContentEvent,
|
||||
ServerGeminiStreamEvent,
|
||||
ServerGeminiToolCallRequestEvent,
|
||||
} from '../core/turn.js';
|
||||
import { ServerGeminiStreamEvent } from '../core/turn.js';
|
||||
import { Config } from '../config/config.js';
|
||||
import * as loggers from '../telemetry/loggers.js';
|
||||
import { LoopType } from '../telemetry/types.js';
|
||||
import { LoopDetectionService } from './loopDetectionService.js';
|
||||
|
||||
vi.mock('../telemetry/loggers.js', () => ({
|
||||
logLoopDetected: vi.fn(),
|
||||
@@ -330,3 +332,112 @@ describe('LoopDetectionService', () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
/**
 * Exercises the LLM-backed loop check that `LoopDetectionService.turnStarted`
 * performs once enough turns have elapsed. The Gemini client is stubbed so
 * each test controls the `confidence` that `generateJson` resolves with.
 */
describe('LoopDetectionService LLM Checks', () => {
  let service: LoopDetectionService;
  let mockConfig: Config;
  let mockGeminiClient: GeminiClient;
  let abortController: AbortController;

  beforeEach(() => {
    // Only the client members these tests touch are stubbed; the double cast
    // is needed because the object is a deliberately partial GeminiClient.
    mockGeminiClient = {
      getHistory: vi.fn().mockReturnValue([]),
      generateJson: vi.fn(),
    } as unknown as GeminiClient;

    mockConfig = {
      getGeminiClient: () => mockGeminiClient,
      getDebugMode: () => false,
      getTelemetryEnabled: () => true,
    } as unknown as Config;

    service = new LoopDetectionService(mockConfig);
    abortController = new AbortController();
    // Reset call history (including the module-mocked telemetry loggers) so
    // assertions only see calls made by the current test.
    vi.clearAllMocks();
  });

  afterEach(() => {
    vi.restoreAllMocks();
  });

  /**
   * Starts `count` consecutive turns on the service, discarding each turn's
   * result. Use it to reach the turn *before* the one whose result a test
   * needs to inspect.
   */
  const advanceTurns = async (count: number) => {
    for (let i = 0; i < count; i++) {
      await service.turnStarted(abortController.signal);
    }
  };

  it('should not trigger LLM check before LLM_CHECK_AFTER_TURNS', async () => {
    await advanceTurns(29);
    expect(mockGeminiClient.generateJson).not.toHaveBeenCalled();
  });

  it('should trigger LLM check on the 30th turn', async () => {
    mockGeminiClient.generateJson = vi
      .fn()
      .mockResolvedValue({ confidence: 0.1 });
    await advanceTurns(30);
    expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);
  });

  it('should detect a cognitive loop when confidence is high', async () => {
    // First check at turn 30
    mockGeminiClient.generateJson = vi
      .fn()
      .mockResolvedValue({ confidence: 0.85, reasoning: 'Repetitive actions' });
    await advanceTurns(30);
    expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);

    // The confidence of 0.85 will result in a low interval.
    // The interval will be: 5 + (15 - 5) * (1 - 0.85) = 5 + 10 * 0.15 = 6.5 -> rounded to 7
    await advanceTurns(6); // advance to turn 36

    mockGeminiClient.generateJson = vi
      .fn()
      .mockResolvedValue({ confidence: 0.95, reasoning: 'Repetitive actions' });
    const finalResult = await service.turnStarted(abortController.signal); // This is turn 37

    expect(finalResult).toBe(true);
    expect(loggers.logLoopDetected).toHaveBeenCalledWith(
      mockConfig,
      expect.objectContaining({
        'event.name': 'loop_detected',
        loop_type: LoopType.LLM_DETECTED_LOOP,
      }),
    );
  });

  it('should not detect a loop when confidence is low', async () => {
    mockGeminiClient.generateJson = vi
      .fn()
      .mockResolvedValue({ confidence: 0.5, reasoning: 'Looks okay' });
    // Stop one turn short so the result captured below belongs to the 30th
    // turn, i.e. the turn on which the LLM check actually runs.
    await advanceTurns(29);
    const result = await service.turnStarted(abortController.signal); // Turn 30

    expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);
    expect(result).toBe(false);
    expect(loggers.logLoopDetected).not.toHaveBeenCalled();
  });

  it('should adjust the check interval based on confidence', async () => {
    // Confidence is 0.0, so interval should be MAX_LLM_CHECK_INTERVAL (15)
    mockGeminiClient.generateJson = vi
      .fn()
      .mockResolvedValue({ confidence: 0.0 });
    await advanceTurns(30); // First check at turn 30
    expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);

    await advanceTurns(14); // Advance to turn 44
    expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);

    await service.turnStarted(abortController.signal); // Turn 45
    expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(2);
  });

  it('should handle errors from generateJson gracefully', async () => {
    mockGeminiClient.generateJson = vi
      .fn()
      .mockRejectedValue(new Error('API error'));
    // Capture the result of the 30th turn itself: that is the turn whose LLM
    // check rejects, so it is the one that must resolve to `false`.
    await advanceTurns(29);
    const result = await service.turnStarted(abortController.signal); // Turn 30

    expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);
    expect(result).toBe(false);
    expect(loggers.logLoopDetected).not.toHaveBeenCalled();
  });
});
|
||||
|
||||
Reference in New Issue
Block a user