mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-21 17:27:54 +00:00
feat: Allow cancellation of in-progress Gemini requests and pre-execution checks
- Implements cancellation for Gemini requests while they are actively being processed by the model. - Extends cancellation support to the logic within tools. This allows users to cancel operations during the phase where the system is determining if a tool execution requires user confirmation, which can include potentially long-running pre-flight checks or LLM-based corrections. - Underlying LLM calls for edit corrections (within and ) and next speaker checks can now also be cancelled. - Previously, cancellation of the main request was not possible until text started streaming, and pre-execution checks were not cancellable. - This change leverages the updated SDK's ability to accept an abort token and threads s throughout the request, tool execution, and pre-execution check lifecycle. Fixes https://github.com/google-gemini/gemini-cli/issues/531
This commit is contained in:
committed by
N. Taylor Mullen
parent
bfeaac8441
commit
f2f2ecf9d8
@@ -132,6 +132,7 @@ describe('editCorrector', () => {
|
||||
let mockGeminiClientInstance: Mocked<GeminiClient>;
|
||||
let mockToolRegistry: Mocked<ToolRegistry>;
|
||||
let mockConfigInstance: Config;
|
||||
const abortSignal = new AbortController().signal;
|
||||
|
||||
beforeEach(() => {
|
||||
mockToolRegistry = new ToolRegistry({} as Config) as Mocked<ToolRegistry>;
|
||||
@@ -187,12 +188,18 @@ describe('editCorrector', () => {
|
||||
|
||||
callCount = 0;
|
||||
mockResponses.length = 0;
|
||||
mockGenerateJson = vi.fn().mockImplementation(() => {
|
||||
const response = mockResponses[callCount];
|
||||
callCount++;
|
||||
if (response === undefined) return Promise.resolve({});
|
||||
return Promise.resolve(response);
|
||||
});
|
||||
mockGenerateJson = vi
|
||||
.fn()
|
||||
.mockImplementation((_contents, _schema, signal) => {
|
||||
// Check if the signal is aborted. If so, throw an error or return a specific response.
|
||||
if (signal && signal.aborted) {
|
||||
return Promise.reject(new Error('Aborted')); // Or some other specific error/response
|
||||
}
|
||||
const response = mockResponses[callCount];
|
||||
callCount++;
|
||||
if (response === undefined) return Promise.resolve({});
|
||||
return Promise.resolve(response);
|
||||
});
|
||||
mockStartChat = vi.fn();
|
||||
mockSendMessageStream = vi.fn();
|
||||
|
||||
@@ -217,6 +224,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
expect(result.params.new_string).toBe('replace with "this"');
|
||||
@@ -234,6 +242,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
||||
expect(result.params.new_string).toBe('replace with this');
|
||||
@@ -254,6 +263,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
expect(result.params.new_string).toBe('replace with "this"');
|
||||
@@ -271,6 +281,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
||||
expect(result.params.new_string).toBe('replace with this');
|
||||
@@ -292,6 +303,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
expect(result.params.new_string).toBe('replace with "this"');
|
||||
@@ -309,6 +321,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
||||
expect(result.params.new_string).toBe('replace with this');
|
||||
@@ -329,6 +342,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
expect(result.params.new_string).toBe('replace with foobar');
|
||||
@@ -351,6 +365,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
expect(result.params.new_string).toBe(llmNewString);
|
||||
@@ -372,6 +387,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(2);
|
||||
expect(result.params.new_string).toBe(llmNewString);
|
||||
@@ -391,6 +407,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
expect(result.params.new_string).toBe('replace with "this"');
|
||||
@@ -412,6 +429,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
expect(result.params.new_string).toBe(newStringForLLMAndReturnedByLLM);
|
||||
@@ -432,6 +450,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(1);
|
||||
expect(result.params).toEqual(originalParams);
|
||||
@@ -449,6 +468,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(0);
|
||||
expect(result.params).toEqual(originalParams);
|
||||
@@ -471,6 +491,7 @@ describe('editCorrector', () => {
|
||||
currentContent,
|
||||
originalParams,
|
||||
mockGeminiClientInstance,
|
||||
abortSignal,
|
||||
);
|
||||
expect(mockGenerateJson).toHaveBeenCalledTimes(2);
|
||||
expect(result.params.old_string).toBe(currentContent);
|
||||
|
||||
@@ -63,6 +63,7 @@ export async function ensureCorrectEdit(
|
||||
currentContent: string,
|
||||
originalParams: EditToolParams, // This is the EditToolParams from edit.ts, without \'corrected\'
|
||||
client: GeminiClient,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<CorrectedEditResult> {
|
||||
const cacheKey = `${currentContent}---${originalParams.old_string}---${originalParams.new_string}`;
|
||||
const cachedResult = editCorrectionCache.get(cacheKey);
|
||||
@@ -84,6 +85,7 @@ export async function ensureCorrectEdit(
|
||||
client,
|
||||
finalOldString,
|
||||
originalParams.new_string,
|
||||
abortSignal,
|
||||
);
|
||||
}
|
||||
} else if (occurrences > 1) {
|
||||
@@ -108,6 +110,7 @@ export async function ensureCorrectEdit(
|
||||
originalParams.old_string, // original old
|
||||
unescapedOldStringAttempt, // corrected old
|
||||
originalParams.new_string, // original new (which is potentially escaped)
|
||||
abortSignal,
|
||||
);
|
||||
}
|
||||
} else if (occurrences === 0) {
|
||||
@@ -115,6 +118,7 @@ export async function ensureCorrectEdit(
|
||||
client,
|
||||
currentContent,
|
||||
unescapedOldStringAttempt,
|
||||
abortSignal,
|
||||
);
|
||||
const llmOldOccurrences = countOccurrences(
|
||||
currentContent,
|
||||
@@ -134,6 +138,7 @@ export async function ensureCorrectEdit(
|
||||
originalParams.old_string, // original old
|
||||
llmCorrectedOldString, // corrected old
|
||||
baseNewStringForLLMCorrection, // base new for correction
|
||||
abortSignal,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
@@ -180,6 +185,7 @@ export async function ensureCorrectEdit(
|
||||
export async function ensureCorrectFileContent(
|
||||
content: string,
|
||||
client: GeminiClient,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<string> {
|
||||
const cachedResult = fileContentCorrectionCache.get(content);
|
||||
if (cachedResult) {
|
||||
@@ -193,7 +199,11 @@ export async function ensureCorrectFileContent(
|
||||
return content;
|
||||
}
|
||||
|
||||
const correctedContent = await correctStringEscaping(content, client);
|
||||
const correctedContent = await correctStringEscaping(
|
||||
content,
|
||||
client,
|
||||
abortSignal,
|
||||
);
|
||||
fileContentCorrectionCache.set(content, correctedContent);
|
||||
return correctedContent;
|
||||
}
|
||||
@@ -215,6 +225,7 @@ export async function correctOldStringMismatch(
|
||||
geminiClient: GeminiClient,
|
||||
fileContent: string,
|
||||
problematicSnippet: string,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<string> {
|
||||
const prompt = `
|
||||
Context: A process needs to find an exact literal, unique match for a specific text snippet within a file's content. The provided snippet failed to match exactly. This is most likely because it has been overly escaped.
|
||||
@@ -243,6 +254,7 @@ Return ONLY the corrected target snippet in the specified JSON format with the k
|
||||
const result = await geminiClient.generateJson(
|
||||
contents,
|
||||
OLD_STRING_CORRECTION_SCHEMA,
|
||||
abortSignal,
|
||||
EditModel,
|
||||
EditConfig,
|
||||
);
|
||||
@@ -257,10 +269,15 @@ Return ONLY the corrected target snippet in the specified JSON format with the k
|
||||
return problematicSnippet;
|
||||
}
|
||||
} catch (error) {
|
||||
if (abortSignal.aborted) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
console.error(
|
||||
'Error during LLM call for old string snippet correction:',
|
||||
error,
|
||||
);
|
||||
|
||||
return problematicSnippet;
|
||||
}
|
||||
}
|
||||
@@ -286,6 +303,7 @@ export async function correctNewString(
|
||||
originalOldString: string,
|
||||
correctedOldString: string,
|
||||
originalNewString: string,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<string> {
|
||||
if (originalOldString === correctedOldString) {
|
||||
return originalNewString;
|
||||
@@ -324,6 +342,7 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
|
||||
const result = await geminiClient.generateJson(
|
||||
contents,
|
||||
NEW_STRING_CORRECTION_SCHEMA,
|
||||
abortSignal,
|
||||
EditModel,
|
||||
EditConfig,
|
||||
);
|
||||
@@ -338,6 +357,10 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
|
||||
return originalNewString;
|
||||
}
|
||||
} catch (error) {
|
||||
if (abortSignal.aborted) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
console.error('Error during LLM call for new_string correction:', error);
|
||||
return originalNewString;
|
||||
}
|
||||
@@ -359,6 +382,7 @@ export async function correctNewStringEscaping(
|
||||
geminiClient: GeminiClient,
|
||||
oldString: string,
|
||||
potentiallyProblematicNewString: string,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<string> {
|
||||
const prompt = `
|
||||
Context: A text replacement operation is planned. The text to be replaced (old_string) has been correctly identified in the file. However, the replacement text (new_string) might have been improperly escaped by a previous LLM generation (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessarily quotes like \\"Hello\\" instead of "Hello").
|
||||
@@ -387,6 +411,7 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
|
||||
const result = await geminiClient.generateJson(
|
||||
contents,
|
||||
CORRECT_NEW_STRING_ESCAPING_SCHEMA,
|
||||
abortSignal,
|
||||
EditModel,
|
||||
EditConfig,
|
||||
);
|
||||
@@ -401,6 +426,10 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
|
||||
return potentiallyProblematicNewString;
|
||||
}
|
||||
} catch (error) {
|
||||
if (abortSignal.aborted) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
console.error(
|
||||
'Error during LLM call for new_string escaping correction:',
|
||||
error,
|
||||
@@ -424,6 +453,7 @@ const CORRECT_STRING_ESCAPING_SCHEMA: SchemaUnion = {
|
||||
export async function correctStringEscaping(
|
||||
potentiallyProblematicString: string,
|
||||
client: GeminiClient,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<string> {
|
||||
const prompt = `
|
||||
Context: An LLM has just generated potentially_problematic_string and the text might have been improperly escaped (e.g. too many backslashes for newlines like \\n instead of \n, or unnecessarily quotes like \\"Hello\\" instead of "Hello").
|
||||
@@ -447,6 +477,7 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
|
||||
const result = await client.generateJson(
|
||||
contents,
|
||||
CORRECT_STRING_ESCAPING_SCHEMA,
|
||||
abortSignal,
|
||||
EditModel,
|
||||
EditConfig,
|
||||
);
|
||||
@@ -461,6 +492,10 @@ Return ONLY the corrected string in the specified JSON format with the key 'corr
|
||||
return potentiallyProblematicString;
|
||||
}
|
||||
} catch (error) {
|
||||
if (abortSignal.aborted) {
|
||||
throw error;
|
||||
}
|
||||
|
||||
console.error(
|
||||
'Error during LLM call for string escaping correction:',
|
||||
error,
|
||||
|
||||
@@ -44,6 +44,7 @@ describe('checkNextSpeaker', () => {
|
||||
let chatInstance: GeminiChat;
|
||||
let mockGeminiClient: GeminiClient;
|
||||
let MockConfig: Mock;
|
||||
const abortSignal = new AbortController().signal;
|
||||
|
||||
beforeEach(() => {
|
||||
MockConfig = vi.mocked(Config);
|
||||
@@ -71,7 +72,7 @@ describe('checkNextSpeaker', () => {
|
||||
mockGoogleGenAIInstance, // This will be the instance returned by the mocked GoogleGenAI constructor
|
||||
mockModelsInstance, // This is the instance returned by mockGoogleGenAIInstance.getGenerativeModel
|
||||
'gemini-pro', // model name
|
||||
{}, // config
|
||||
{},
|
||||
[], // initial history
|
||||
);
|
||||
|
||||
@@ -85,7 +86,11 @@ describe('checkNextSpeaker', () => {
|
||||
|
||||
it('should return null if history is empty', async () => {
|
||||
(chatInstance.getHistory as Mock).mockReturnValue([]);
|
||||
const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
|
||||
const result = await checkNextSpeaker(
|
||||
chatInstance,
|
||||
mockGeminiClient,
|
||||
abortSignal,
|
||||
);
|
||||
expect(result).toBeNull();
|
||||
expect(mockGeminiClient.generateJson).not.toHaveBeenCalled();
|
||||
});
|
||||
@@ -94,7 +99,11 @@ describe('checkNextSpeaker', () => {
|
||||
(chatInstance.getHistory as Mock).mockReturnValue([
|
||||
{ role: 'user', parts: [{ text: 'Hello' }] },
|
||||
] as Content[]);
|
||||
const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
|
||||
const result = await checkNextSpeaker(
|
||||
chatInstance,
|
||||
mockGeminiClient,
|
||||
abortSignal,
|
||||
);
|
||||
expect(result).toBeNull();
|
||||
expect(mockGeminiClient.generateJson).not.toHaveBeenCalled();
|
||||
});
|
||||
@@ -109,7 +118,11 @@ describe('checkNextSpeaker', () => {
|
||||
};
|
||||
(mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse);
|
||||
|
||||
const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
|
||||
const result = await checkNextSpeaker(
|
||||
chatInstance,
|
||||
mockGeminiClient,
|
||||
abortSignal,
|
||||
);
|
||||
expect(result).toEqual(mockApiResponse);
|
||||
expect(mockGeminiClient.generateJson).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
@@ -124,7 +137,11 @@ describe('checkNextSpeaker', () => {
|
||||
};
|
||||
(mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse);
|
||||
|
||||
const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
|
||||
const result = await checkNextSpeaker(
|
||||
chatInstance,
|
||||
mockGeminiClient,
|
||||
abortSignal,
|
||||
);
|
||||
expect(result).toEqual(mockApiResponse);
|
||||
});
|
||||
|
||||
@@ -138,7 +155,11 @@ describe('checkNextSpeaker', () => {
|
||||
};
|
||||
(mockGeminiClient.generateJson as Mock).mockResolvedValue(mockApiResponse);
|
||||
|
||||
const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
|
||||
const result = await checkNextSpeaker(
|
||||
chatInstance,
|
||||
mockGeminiClient,
|
||||
abortSignal,
|
||||
);
|
||||
expect(result).toEqual(mockApiResponse);
|
||||
});
|
||||
|
||||
@@ -153,7 +174,11 @@ describe('checkNextSpeaker', () => {
|
||||
new Error('API Error'),
|
||||
);
|
||||
|
||||
const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
|
||||
const result = await checkNextSpeaker(
|
||||
chatInstance,
|
||||
mockGeminiClient,
|
||||
abortSignal,
|
||||
);
|
||||
expect(result).toBeNull();
|
||||
consoleWarnSpy.mockRestore();
|
||||
});
|
||||
@@ -166,7 +191,11 @@ describe('checkNextSpeaker', () => {
|
||||
reasoning: 'This is incomplete.',
|
||||
} as unknown as NextSpeakerResponse); // Type assertion to simulate invalid response
|
||||
|
||||
const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
|
||||
const result = await checkNextSpeaker(
|
||||
chatInstance,
|
||||
mockGeminiClient,
|
||||
abortSignal,
|
||||
);
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
@@ -179,7 +208,11 @@ describe('checkNextSpeaker', () => {
|
||||
next_speaker: 123, // Invalid type
|
||||
} as unknown as NextSpeakerResponse);
|
||||
|
||||
const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
|
||||
const result = await checkNextSpeaker(
|
||||
chatInstance,
|
||||
mockGeminiClient,
|
||||
abortSignal,
|
||||
);
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
@@ -192,7 +225,11 @@ describe('checkNextSpeaker', () => {
|
||||
next_speaker: 'neither', // Invalid enum value
|
||||
} as unknown as NextSpeakerResponse);
|
||||
|
||||
const result = await checkNextSpeaker(chatInstance, mockGeminiClient);
|
||||
const result = await checkNextSpeaker(
|
||||
chatInstance,
|
||||
mockGeminiClient,
|
||||
abortSignal,
|
||||
);
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -61,6 +61,7 @@ export interface NextSpeakerResponse {
|
||||
export async function checkNextSpeaker(
|
||||
chat: GeminiChat,
|
||||
geminiClient: GeminiClient,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<NextSpeakerResponse | null> {
|
||||
// We need to capture the curated history because there are many moments when the model will return invalid turns
|
||||
// that when passed back up to the endpoint will break subsequent calls. An example of this is when the model decides
|
||||
@@ -129,6 +130,7 @@ export async function checkNextSpeaker(
|
||||
const parsedResponse = (await geminiClient.generateJson(
|
||||
contents,
|
||||
RESPONSE_SCHEMA,
|
||||
abortSignal,
|
||||
)) as unknown as NextSpeakerResponse;
|
||||
|
||||
if (
|
||||
|
||||
Reference in New Issue
Block a user