mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-22 01:37:50 +00:00
feat: Allow cancellation of in-progress Gemini requests and pre-execution checks
- Implements cancellation for Gemini requests while they are actively being processed by the model. - Extends cancellation support to the `shouldConfirmExecute` logic within tools. This allows users to cancel operations during the phase where the system is determining if a tool execution requires user confirmation, which can include potentially long-running pre-flight checks or LLM-based corrections. - Underlying LLM calls for edit corrections (within the edit-correction helpers) and next speaker checks can now also be cancelled. - Previously, cancellation of the main request was not possible until text started streaming, and pre-execution checks were not cancellable. - This change leverages the updated SDK's ability to accept an abort token and threads `AbortSignal`s throughout the request, tool execution, and pre-execution check lifecycle. Fixes https://github.com/google-gemini/gemini-cli/issues/531
This commit is contained in:
committed by
N. Taylor Mullen
parent
bfeaac8441
commit
f2f2ecf9d8
@@ -157,7 +157,7 @@ export class GeminiClient {
|
||||
async *sendMessageStream(
|
||||
chat: GeminiChat,
|
||||
request: PartListUnion,
|
||||
signal?: AbortSignal,
|
||||
signal: AbortSignal,
|
||||
turns: number = this.MAX_TURNS,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent> {
|
||||
if (!turns) {
|
||||
@@ -169,8 +169,8 @@ export class GeminiClient {
|
||||
for await (const event of resultStream) {
|
||||
yield event;
|
||||
}
|
||||
if (!turn.pendingToolCalls.length) {
|
||||
const nextSpeakerCheck = await checkNextSpeaker(chat, this);
|
||||
if (!turn.pendingToolCalls.length && signal && !signal.aborted) {
|
||||
const nextSpeakerCheck = await checkNextSpeaker(chat, this, signal);
|
||||
if (nextSpeakerCheck?.next_speaker === 'model') {
|
||||
const nextRequest = [{ text: 'Please continue.' }];
|
||||
yield* this.sendMessageStream(chat, nextRequest, signal, turns - 1);
|
||||
@@ -181,6 +181,7 @@ export class GeminiClient {
|
||||
async generateJson(
|
||||
contents: Content[],
|
||||
schema: SchemaUnion,
|
||||
abortSignal: AbortSignal,
|
||||
model: string = 'gemini-2.0-flash',
|
||||
config: GenerateContentConfig = {},
|
||||
): Promise<Record<string, unknown>> {
|
||||
@@ -188,6 +189,7 @@ export class GeminiClient {
|
||||
const userMemory = this.config.getUserMemory();
|
||||
const systemInstruction = getCoreSystemPrompt(userMemory);
|
||||
const requestConfig = {
|
||||
abortSignal,
|
||||
...this.generateContentConfig,
|
||||
...config,
|
||||
};
|
||||
@@ -232,6 +234,11 @@ export class GeminiClient {
|
||||
);
|
||||
}
|
||||
} catch (error) {
|
||||
if (abortSignal.aborted) {
|
||||
// Regular cancellation error, fail normally
|
||||
throw error;
|
||||
}
|
||||
|
||||
// Avoid double reporting for the empty response case handled above
|
||||
if (
|
||||
error instanceof Error &&
|
||||
|
||||
@@ -155,7 +155,7 @@ export class GeminiChat {
|
||||
const responsePromise = this.modelsModule.generateContent({
|
||||
model: this.model,
|
||||
contents: this.getHistory(true).concat(userContent),
|
||||
config: params.config ?? this.config,
|
||||
config: { ...this.config, ...params.config },
|
||||
});
|
||||
this.sendPromise = (async () => {
|
||||
const response = await responsePromise;
|
||||
@@ -219,7 +219,7 @@ export class GeminiChat {
|
||||
const streamResponse = this.modelsModule.generateContentStream({
|
||||
model: this.model,
|
||||
contents: this.getHistory(true).concat(userContent),
|
||||
config: params.config ?? this.config,
|
||||
config: { ...this.config, ...params.config },
|
||||
});
|
||||
// Resolve the internal tracking of send completion promise - `sendPromise`
|
||||
// for both success and failure response. The actual failure is still
|
||||
|
||||
@@ -85,11 +85,17 @@ describe('Turn', () => {
|
||||
|
||||
const events = [];
|
||||
const reqParts: Part[] = [{ text: 'Hi' }];
|
||||
for await (const event of turn.run(reqParts)) {
|
||||
for await (const event of turn.run(
|
||||
reqParts,
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
expect(mockSendMessageStream).toHaveBeenCalledWith({ message: reqParts });
|
||||
expect(mockSendMessageStream).toHaveBeenCalledWith({
|
||||
message: reqParts,
|
||||
config: { abortSignal: expect.any(AbortSignal) },
|
||||
});
|
||||
expect(events).toEqual([
|
||||
{ type: GeminiEventType.Content, value: 'Hello' },
|
||||
{ type: GeminiEventType.Content, value: ' world' },
|
||||
@@ -110,7 +116,10 @@ describe('Turn', () => {
|
||||
|
||||
const events = [];
|
||||
const reqParts: Part[] = [{ text: 'Use tools' }];
|
||||
for await (const event of turn.run(reqParts)) {
|
||||
for await (const event of turn.run(
|
||||
reqParts,
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
@@ -179,7 +188,10 @@ describe('Turn', () => {
|
||||
mockGetHistory.mockReturnValue(historyContent);
|
||||
|
||||
const events = [];
|
||||
for await (const event of turn.run(reqParts)) {
|
||||
for await (const event of turn.run(
|
||||
reqParts,
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
@@ -210,7 +222,10 @@ describe('Turn', () => {
|
||||
|
||||
const events = [];
|
||||
const reqParts: Part[] = [{ text: 'Test undefined tool parts' }];
|
||||
for await (const event of turn.run(reqParts)) {
|
||||
for await (const event of turn.run(
|
||||
reqParts,
|
||||
new AbortController().signal,
|
||||
)) {
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
@@ -261,7 +276,7 @@ describe('Turn', () => {
|
||||
})();
|
||||
mockSendMessageStream.mockResolvedValue(mockResponseStream);
|
||||
const reqParts: Part[] = [{ text: 'Hi' }];
|
||||
for await (const _ of turn.run(reqParts)) {
|
||||
for await (const _ of turn.run(reqParts, new AbortController().signal)) {
|
||||
// consume stream
|
||||
}
|
||||
expect(turn.getDebugResponses()).toEqual([resp1, resp2]);
|
||||
|
||||
@@ -32,6 +32,7 @@ export interface ServerTool {
|
||||
): Promise<ToolResult>;
|
||||
shouldConfirmExecute(
|
||||
params: Record<string, unknown>,
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false>;
|
||||
}
|
||||
|
||||
@@ -120,11 +121,14 @@ export class Turn {
|
||||
// The run method yields simpler events suitable for server logic
|
||||
async *run(
|
||||
req: PartListUnion,
|
||||
signal?: AbortSignal,
|
||||
signal: AbortSignal,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent> {
|
||||
try {
|
||||
const responseStream = await this.chat.sendMessageStream({
|
||||
message: req,
|
||||
config: {
|
||||
abortSignal: signal,
|
||||
},
|
||||
});
|
||||
|
||||
for await (const resp of responseStream) {
|
||||
@@ -150,6 +154,12 @@ export class Turn {
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
if (signal.aborted) {
|
||||
yield { type: GeminiEventType.UserCancelled };
|
||||
// Regular cancellation error, fail gracefully.
|
||||
return;
|
||||
}
|
||||
|
||||
const contextForReport = [...this.chat.getHistory(/*curated*/ true), req];
|
||||
await reportError(
|
||||
error,
|
||||
|
||||
Reference in New Issue
Block a user