feat: prevent concurrent query submissions in useGeminiStream hook

This commit is contained in:
tanzhenxin
2025-08-14 16:39:26 +08:00
parent 6516d0d136
commit c58106079e
2 changed files with 275 additions and 0 deletions

View File

@@ -448,6 +448,7 @@ describe('useGeminiStream', () => {
callId: 'call1',
responseParts: [{ text: 'tool 1 response' }],
error: undefined,
errorType: undefined,
resultDisplay: 'Tool 1 success display',
},
tool: {
@@ -655,6 +656,7 @@ describe('useGeminiStream', () => {
],
resultDisplay: undefined,
error: undefined,
errorType: undefined,
},
responseSubmittedToGemini: false,
};
@@ -679,6 +681,7 @@ describe('useGeminiStream', () => {
],
resultDisplay: undefined,
error: undefined,
errorType: undefined,
},
responseSubmittedToGemini: false,
};
@@ -775,6 +778,7 @@ describe('useGeminiStream', () => {
callId: 'call1',
responseParts: toolCallResponseParts,
error: undefined,
errorType: undefined,
resultDisplay: 'Tool 1 success display',
},
endTime: Date.now(),
@@ -1128,6 +1132,7 @@ describe('useGeminiStream', () => {
responseParts: [{ text: 'Memory saved' }],
resultDisplay: 'Success: Memory saved',
error: undefined,
errorType: undefined,
},
tool: {
name: 'save_memory',
@@ -1649,4 +1654,262 @@ describe('useGeminiStream', () => {
);
});
});
describe('Concurrent Execution Prevention', () => {
it('should prevent concurrent submitQuery calls', async () => {
let resolveFirstCall!: () => void;
let resolveSecondCall!: () => void;
const firstCallPromise = new Promise<void>((resolve) => {
resolveFirstCall = resolve;
});
const secondCallPromise = new Promise<void>((resolve) => {
resolveSecondCall = resolve;
});
// Mock a long-running stream for the first call
const firstStream = (async function* () {
yield { type: ServerGeminiEventType.Content, value: 'First call content' };
await firstCallPromise; // Wait until we manually resolve
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
})();
// Mock a stream for the second call (should not be used)
const secondStream = (async function* () {
yield { type: ServerGeminiEventType.Content, value: 'Second call content' };
await secondCallPromise;
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
})();
let callCount = 0;
mockSendMessageStream.mockImplementation(() => {
callCount++;
if (callCount === 1) {
return firstStream;
} else {
return secondStream;
}
});
const { result } = renderTestHook();
// Start first call
const firstCallResult = act(async () => {
await result.current.submitQuery('First query');
});
// Wait a bit to ensure first call has started
await new Promise(resolve => setTimeout(resolve, 10));
// Try to start second call while first is still running
const secondCallResult = act(async () => {
await result.current.submitQuery('Second query');
});
// Resolve both calls
resolveFirstCall();
resolveSecondCall();
await Promise.all([firstCallResult, secondCallResult]);
// Verify only one call was made to sendMessageStream
expect(mockSendMessageStream).toHaveBeenCalledTimes(1);
expect(mockSendMessageStream).toHaveBeenCalledWith(
'First query',
expect.any(AbortSignal),
expect.any(String),
);
// Verify only the first query was added to history
const userMessages = mockAddItem.mock.calls.filter(
call => call[0].type === MessageType.USER
);
expect(userMessages).toHaveLength(1);
expect(userMessages[0][0].text).toBe('First query');
});
it('should allow subsequent calls after first call completes', async () => {
// Mock streams that complete immediately
mockSendMessageStream
.mockReturnValueOnce((async function* () {
yield { type: ServerGeminiEventType.Content, value: 'First response' };
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
})())
.mockReturnValueOnce((async function* () {
yield { type: ServerGeminiEventType.Content, value: 'Second response' };
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
})());
const { result } = renderTestHook();
// First call
await act(async () => {
await result.current.submitQuery('First query');
});
// Second call after first completes
await act(async () => {
await result.current.submitQuery('Second query');
});
// Both calls should have been made
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
expect(mockSendMessageStream).toHaveBeenNthCalledWith(1,
'First query',
expect.any(AbortSignal),
expect.any(String),
);
expect(mockSendMessageStream).toHaveBeenNthCalledWith(2,
'Second query',
expect.any(AbortSignal),
expect.any(String),
);
});
it('should reset execution flag even when query preparation fails', async () => {
const { result } = renderTestHook();
// First call with empty query (should fail in preparation)
await act(async () => {
await result.current.submitQuery(' '); // Empty trimmed query
});
// Second call should work normally
mockSendMessageStream.mockReturnValue((async function* () {
yield { type: ServerGeminiEventType.Content, value: 'Valid response' };
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
})());
await act(async () => {
await result.current.submitQuery('Valid query');
});
// The second call should have been made
expect(mockSendMessageStream).toHaveBeenCalledTimes(1);
expect(mockSendMessageStream).toHaveBeenCalledWith(
'Valid query',
expect.any(AbortSignal),
expect.any(String),
);
});
it('should reset execution flag when user cancels', async () => {
let resolveCancelledStream!: () => void;
const cancelledStreamPromise = new Promise<void>((resolve) => {
resolveCancelledStream = resolve;
});
// Mock a stream that can be cancelled
const cancelledStream = (async function* () {
yield { type: ServerGeminiEventType.Content, value: 'Cancelled content' };
await cancelledStreamPromise;
yield { type: ServerGeminiEventType.UserCancelled };
})();
mockSendMessageStream.mockReturnValueOnce(cancelledStream);
const { result } = renderTestHook();
// Start first call
const firstCallResult = act(async () => {
await result.current.submitQuery('First query');
});
// Wait a bit then resolve to trigger cancellation
await new Promise(resolve => setTimeout(resolve, 10));
resolveCancelledStream();
await firstCallResult;
// Now try a second call - should work
mockSendMessageStream.mockReturnValue((async function* () {
yield { type: ServerGeminiEventType.Content, value: 'Second response' };
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
})());
await act(async () => {
await result.current.submitQuery('Second query');
});
// Both calls should have been made
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
});
it('should reset execution flag when an error occurs', async () => {
// Mock a stream that throws an error
mockSendMessageStream.mockReturnValueOnce((async function* () {
yield { type: ServerGeminiEventType.Content, value: 'Error content' };
throw new Error('Stream error');
})());
const { result } = renderTestHook();
// First call that will error
await act(async () => {
await result.current.submitQuery('Error query');
});
// Second call should work normally
mockSendMessageStream.mockReturnValue((async function* () {
yield { type: ServerGeminiEventType.Content, value: 'Success response' };
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
})());
await act(async () => {
await result.current.submitQuery('Success query');
});
// Both calls should have been attempted
expect(mockSendMessageStream).toHaveBeenCalledTimes(2);
});
it('should handle rapid multiple concurrent calls correctly', async () => {
let resolveStream!: () => void;
const streamPromise = new Promise<void>((resolve) => {
resolveStream = resolve;
});
// Mock a long-running stream
const longStream = (async function* () {
yield { type: ServerGeminiEventType.Content, value: 'Long running content' };
await streamPromise;
yield { type: ServerGeminiEventType.Finished, value: 'STOP' };
})();
mockSendMessageStream.mockReturnValue(longStream);
const { result } = renderTestHook();
// Start multiple concurrent calls
const calls = [
act(async () => { await result.current.submitQuery('Query 1'); }),
act(async () => { await result.current.submitQuery('Query 2'); }),
act(async () => { await result.current.submitQuery('Query 3'); }),
act(async () => { await result.current.submitQuery('Query 4'); }),
act(async () => { await result.current.submitQuery('Query 5'); }),
];
// Wait a bit then resolve the stream
await new Promise(resolve => setTimeout(resolve, 10));
resolveStream();
// Wait for all calls to complete
await Promise.all(calls);
// Only the first call should have been made
expect(mockSendMessageStream).toHaveBeenCalledTimes(1);
expect(mockSendMessageStream).toHaveBeenCalledWith(
'Query 1',
expect.any(AbortSignal),
expect.any(String),
);
// Only one user message should have been added
const userMessages = mockAddItem.mock.calls.filter(
call => call[0].type === MessageType.USER
);
expect(userMessages).toHaveLength(1);
expect(userMessages[0][0].text).toBe('Query 1');
});
});
});

View File

@@ -97,6 +97,7 @@ export const useGeminiStream = (
const [initError, setInitError] = useState<string | null>(null);
const abortControllerRef = useRef<AbortController | null>(null);
const turnCancelledRef = useRef(false);
const isSubmittingQueryRef = useRef(false);
const [isResponding, setIsResponding] = useState<boolean>(false);
const [thought, setThought] = useState<ThoughtSummary | null>(null);
const [pendingHistoryItemRef, setPendingHistoryItem] =
@@ -622,6 +623,11 @@ export const useGeminiStream = (
options?: { isContinuation: boolean },
prompt_id?: string,
) => {
// Prevent concurrent executions of submitQuery
if (isSubmittingQueryRef.current) {
return;
}
if (
(streamingState === StreamingState.Responding ||
streamingState === StreamingState.WaitingForConfirmation) &&
@@ -629,6 +635,9 @@ export const useGeminiStream = (
)
return;
// Set the flag to indicate we're now executing
isSubmittingQueryRef.current = true;
const userMessageTimestamp = Date.now();
// Reset quota error flag when starting a new query (not a continuation)
@@ -653,6 +662,7 @@ export const useGeminiStream = (
);
if (!shouldProceed || queryToSend === null) {
isSubmittingQueryRef.current = false;
return;
}
@@ -677,6 +687,7 @@ export const useGeminiStream = (
);
if (processingStatus === StreamProcessingStatus.UserCancelled) {
isSubmittingQueryRef.current = false;
return;
}
@@ -708,6 +719,7 @@ export const useGeminiStream = (
}
} finally {
setIsResponding(false);
isSubmittingQueryRef.current = false;
}
},
[