Propagate abort signal to ccpa generateContent. (#1106)

This commit is contained in:
Tommaso Sciortino
2025-06-16 13:24:42 -07:00
committed by GitHub
parent 42329e0258
commit 11f524c125
2 changed files with 12 additions and 1 deletions

View File

@@ -45,6 +45,7 @@ describe('CodeAssistServer', () => {
expect(server.callEndpoint).toHaveBeenCalledWith( expect(server.callEndpoint).toHaveBeenCalledWith(
'generateContent', 'generateContent',
expect.any(Object), expect.any(Object),
undefined,
); );
expect(response.candidates?.[0]?.content?.parts?.[0]?.text).toBe( expect(response.candidates?.[0]?.content?.parts?.[0]?.text).toBe(
'response', 'response',
@@ -82,6 +83,7 @@ describe('CodeAssistServer', () => {
expect(server.streamEndpoint).toHaveBeenCalledWith( expect(server.streamEndpoint).toHaveBeenCalledWith(
'streamGenerateContent', 'streamGenerateContent',
expect.any(Object), expect.any(Object),
undefined,
); );
expect(res.candidates?.[0]?.content?.parts?.[0]?.text).toBe('response'); expect(res.candidates?.[0]?.content?.parts?.[0]?.text).toBe('response');
} }

View File

@@ -53,6 +53,7 @@ export class CodeAssistServer implements ContentGenerator {
const resps = await this.streamEndpoint<CodeAssistResponse>( const resps = await this.streamEndpoint<CodeAssistResponse>(
'streamGenerateContent', 'streamGenerateContent',
toCodeAssistRequest(req, this.projectId), toCodeAssistRequest(req, this.projectId),
req.config?.abortSignal,
); );
return (async function* (): AsyncGenerator<GenerateContentResponse> { return (async function* (): AsyncGenerator<GenerateContentResponse> {
for await (const resp of resps) { for await (const resp of resps) {
@@ -67,6 +68,7 @@ export class CodeAssistServer implements ContentGenerator {
const resp = await this.callEndpoint<CodeAssistResponse>( const resp = await this.callEndpoint<CodeAssistResponse>(
'generateContent', 'generateContent',
toCodeAssistRequest(req, this.projectId), toCodeAssistRequest(req, this.projectId),
req.config?.abortSignal,
); );
return fromCodeAsistResponse(resp); return fromCodeAsistResponse(resp);
} }
@@ -99,7 +101,11 @@ export class CodeAssistServer implements ContentGenerator {
throw Error(); throw Error();
} }
async callEndpoint<T>(method: string, req: object): Promise<T> { async callEndpoint<T>(
method: string,
req: object,
signal?: AbortSignal,
): Promise<T> {
const res = await this.auth.request({ const res = await this.auth.request({
url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:${method}`, url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:${method}`,
method: 'POST', method: 'POST',
@@ -109,6 +115,7 @@ export class CodeAssistServer implements ContentGenerator {
}, },
responseType: 'json', responseType: 'json',
body: JSON.stringify(req), body: JSON.stringify(req),
signal,
}); });
return res.data as T; return res.data as T;
} }
@@ -116,6 +123,7 @@ export class CodeAssistServer implements ContentGenerator {
async streamEndpoint<T>( async streamEndpoint<T>(
method: string, method: string,
req: object, req: object,
signal?: AbortSignal,
): Promise<AsyncGenerator<T>> { ): Promise<AsyncGenerator<T>> {
const res = await this.auth.request({ const res = await this.auth.request({
url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:${method}`, url: `${CODE_ASSIST_ENDPOINT}/${CODE_ASSIST_API_VERSION}:${method}`,
@@ -129,6 +137,7 @@ export class CodeAssistServer implements ContentGenerator {
}, },
responseType: 'stream', responseType: 'stream',
body: JSON.stringify(req), body: JSON.stringify(req),
signal,
}); });
return (async function* (): AsyncGenerator<T> { return (async function* (): AsyncGenerator<T> {