mirror of
https://github.com/QwenLM/qwen-code.git
synced 2025-12-19 09:33:53 +00:00
Add a2a-server package to gemini-cli (#6597)
This commit is contained in:
5
packages/a2a-server/README.md
Normal file
5
packages/a2a-server/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# Gemini CLI A2A Server
|
||||
|
||||
## All code in this package is experimental and under active development
|
||||
|
||||
This package contains the A2A server implementation for the Gemini CLI.
|
||||
7
packages/a2a-server/index.ts
Normal file
7
packages/a2a-server/index.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
export * from './src/index.js';
|
||||
48
packages/a2a-server/package.json
Normal file
48
packages/a2a-server/package.json
Normal file
@@ -0,0 +1,48 @@
|
||||
{
|
||||
"name": "@google/gemini-cli-a2a-server",
|
||||
"version": "0.1.0",
|
||||
"private": true,
|
||||
"description": "Gemini CLI A2A Server",
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "git+https://github.com/google-gemini/gemini-cli.git",
|
||||
"directory": "packages/a2a-server"
|
||||
},
|
||||
"type": "module",
|
||||
"main": "dist/server.js",
|
||||
"scripts": {
|
||||
"start": "node dist/src/server.js",
|
||||
"build": "node ../../scripts/build_package.js",
|
||||
"lint": "eslint . --ext .ts,.tsx",
|
||||
"format": "prettier --write .",
|
||||
"test": "vitest run",
|
||||
"test:ci": "vitest run --coverage",
|
||||
"typecheck": "tsc --noEmit"
|
||||
},
|
||||
"files": [
|
||||
"dist"
|
||||
],
|
||||
"dependencies": {
|
||||
"@a2a-js/sdk": "^0.3.2",
|
||||
"@google-cloud/storage": "^7.16.0",
|
||||
"@google/gemini-cli-core": "file:../core",
|
||||
"express": "^5.1.0",
|
||||
"fs-extra": "^11.3.0",
|
||||
"tar": "^7.4.3",
|
||||
"uuid": "^11.1.0",
|
||||
"winston": "^3.17.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/express": "^5.0.3",
|
||||
"@types/fs-extra": "^11.0.4",
|
||||
"@types/supertest": "^6.0.3",
|
||||
"@types/tar": "^6.1.13",
|
||||
"dotenv": "^16.4.5",
|
||||
"supertest": "^7.1.4",
|
||||
"typescript": "^5.3.3",
|
||||
"vitest": "^3.1.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=20"
|
||||
}
|
||||
}
|
||||
648
packages/a2a-server/src/agent.test.ts
Normal file
648
packages/a2a-server/src/agent.test.ts
Normal file
@@ -0,0 +1,648 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { Config } from '@google/gemini-cli-core';
|
||||
import {
|
||||
GeminiEventType,
|
||||
ApprovalMode,
|
||||
type ToolCallConfirmationDetails,
|
||||
} from '@google/gemini-cli-core';
|
||||
import type {
|
||||
TaskStatusUpdateEvent,
|
||||
SendStreamingMessageSuccessResponse,
|
||||
} from '@a2a-js/sdk';
|
||||
import type express from 'express';
|
||||
import type { Server } from 'node:http';
|
||||
import request from 'supertest';
|
||||
import {
|
||||
afterAll,
|
||||
afterEach,
|
||||
beforeEach,
|
||||
beforeAll,
|
||||
describe,
|
||||
expect,
|
||||
it,
|
||||
vi,
|
||||
} from 'vitest';
|
||||
import { createApp } from './agent.js';
|
||||
import {
|
||||
assertUniqueFinalEventIsLast,
|
||||
assertTaskCreationAndWorkingStatus,
|
||||
createStreamMessageRequest,
|
||||
MockTool,
|
||||
} from './testing_utils.js';
|
||||
|
||||
const mockToolConfirmationFn = async () =>
|
||||
({}) as unknown as ToolCallConfirmationDetails;
|
||||
|
||||
const streamToSSEEvents = (
|
||||
stream: string,
|
||||
): SendStreamingMessageSuccessResponse[] =>
|
||||
stream
|
||||
.split('\n\n')
|
||||
.filter(Boolean) // Remove empty strings from trailing newlines
|
||||
.map((chunk) => {
|
||||
const dataLine = chunk
|
||||
.split('\n')
|
||||
.find((line) => line.startsWith('data: '));
|
||||
if (!dataLine) {
|
||||
throw new Error(`Invalid SSE chunk found: "${chunk}"`);
|
||||
}
|
||||
return JSON.parse(dataLine.substring(6));
|
||||
});
|
||||
|
||||
// Mock the logger to avoid polluting test output
|
||||
// Comment out to debug tests
|
||||
vi.mock('./logger.js', () => ({
|
||||
logger: { info: vi.fn(), warn: vi.fn(), error: vi.fn() },
|
||||
}));
|
||||
|
||||
let config: Config;
|
||||
const getToolRegistrySpy = vi.fn().mockReturnValue(ApprovalMode.DEFAULT);
|
||||
const getApprovalModeSpy = vi.fn();
|
||||
vi.mock('./config.js', async () => {
|
||||
const actual = await vi.importActual('./config.js');
|
||||
return {
|
||||
...actual,
|
||||
loadConfig: vi.fn().mockImplementation(async () => {
|
||||
config = {
|
||||
getToolRegistry: getToolRegistrySpy,
|
||||
getApprovalMode: getApprovalModeSpy,
|
||||
getIdeMode: vi.fn().mockReturnValue(false),
|
||||
getAllowedTools: vi.fn().mockReturnValue([]),
|
||||
getIdeClient: vi.fn(),
|
||||
getWorkspaceContext: vi.fn().mockReturnValue({
|
||||
isPathWithinWorkspace: () => true,
|
||||
}),
|
||||
getTargetDir: () => '/test',
|
||||
getGeminiClient: vi.fn(),
|
||||
getDebugMode: vi.fn().mockReturnValue(false),
|
||||
getContentGeneratorConfig: vi
|
||||
.fn()
|
||||
.mockReturnValue({ model: 'gemini-pro' }),
|
||||
getModel: vi.fn().mockReturnValue('gemini-pro'),
|
||||
getUsageStatisticsEnabled: vi.fn().mockReturnValue(false),
|
||||
setFlashFallbackHandler: vi.fn(),
|
||||
initialize: vi.fn().mockResolvedValue(undefined),
|
||||
} as unknown as Config;
|
||||
return config;
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
// Mock the GeminiClient to avoid actual API calls
|
||||
const sendMessageStreamSpy = vi.fn();
|
||||
vi.mock('@google/gemini-cli-core', async () => {
|
||||
const actual = await vi.importActual('@google/gemini-cli-core');
|
||||
return {
|
||||
...actual,
|
||||
GeminiClient: vi.fn().mockImplementation(() => ({
|
||||
sendMessageStream: sendMessageStreamSpy,
|
||||
getUserTier: vi.fn().mockReturnValue('free'),
|
||||
initialize: vi.fn(),
|
||||
})),
|
||||
};
|
||||
});
|
||||
|
||||
describe('E2E Tests', () => {
|
||||
let app: express.Express;
|
||||
let server: Server;
|
||||
|
||||
beforeAll(async () => {
|
||||
app = await createApp();
|
||||
server = app.listen(0); // Listen on a random available port
|
||||
});
|
||||
|
||||
beforeEach(() => {
|
||||
getApprovalModeSpy.mockReturnValue(ApprovalMode.DEFAULT);
|
||||
});
|
||||
|
||||
afterAll(
|
||||
() =>
|
||||
new Promise<void>((resolve) => {
|
||||
server.close(() => {
|
||||
resolve();
|
||||
});
|
||||
}),
|
||||
);
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('should create a new task and stream status updates (text-content) via POST /', async () => {
|
||||
sendMessageStreamSpy.mockImplementation(async function* () {
|
||||
yield* [{ type: 'content', value: 'Hello how are you?' }];
|
||||
});
|
||||
|
||||
const agent = request.agent(app);
|
||||
const res = await agent
|
||||
.post('/')
|
||||
.send(createStreamMessageRequest('hello', 'a2a-test-message'))
|
||||
.set('Content-Type', 'application/json')
|
||||
.expect(200);
|
||||
|
||||
const events = streamToSSEEvents(res.text);
|
||||
|
||||
assertTaskCreationAndWorkingStatus(events);
|
||||
|
||||
// Status update: text-content
|
||||
const textContentEvent = events[2].result as TaskStatusUpdateEvent;
|
||||
expect(textContentEvent.kind).toBe('status-update');
|
||||
expect(textContentEvent.status.state).toBe('working');
|
||||
expect(textContentEvent.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'text-content',
|
||||
});
|
||||
expect(textContentEvent.status.message?.parts).toMatchObject([
|
||||
{ kind: 'text', text: 'Hello how are you?' },
|
||||
]);
|
||||
|
||||
// Status update: input-required (final)
|
||||
const finalEvent = events[3].result as TaskStatusUpdateEvent;
|
||||
expect(finalEvent.kind).toBe('status-update');
|
||||
expect(finalEvent.status?.state).toBe('input-required');
|
||||
expect(finalEvent.final).toBe(true);
|
||||
|
||||
assertUniqueFinalEventIsLast(events);
|
||||
expect(events.length).toBe(4);
|
||||
});
|
||||
|
||||
it('should create a new task, schedule a tool call, and wait for approval', async () => {
|
||||
// First call yields the tool request
|
||||
sendMessageStreamSpy.mockImplementationOnce(async function* () {
|
||||
yield* [
|
||||
{
|
||||
type: GeminiEventType.ToolCallRequest,
|
||||
value: {
|
||||
callId: 'test-call-id',
|
||||
name: 'test-tool',
|
||||
args: {},
|
||||
},
|
||||
},
|
||||
];
|
||||
});
|
||||
// Subsequent calls yield nothing
|
||||
sendMessageStreamSpy.mockImplementation(async function* () {
|
||||
yield* [];
|
||||
});
|
||||
|
||||
const mockTool = new MockTool(
|
||||
'test-tool',
|
||||
'Test Tool',
|
||||
true,
|
||||
false,
|
||||
mockToolConfirmationFn,
|
||||
);
|
||||
|
||||
getToolRegistrySpy.mockReturnValue({
|
||||
getAllTools: vi.fn().mockReturnValue([mockTool]),
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
getTool: vi.fn().mockReturnValue(mockTool),
|
||||
});
|
||||
|
||||
const agent = request.agent(app);
|
||||
const res = await agent
|
||||
.post('/')
|
||||
.send(createStreamMessageRequest('run a tool', 'a2a-tool-test-message'))
|
||||
.set('Content-Type', 'application/json')
|
||||
.expect(200);
|
||||
|
||||
const events = streamToSSEEvents(res.text);
|
||||
assertTaskCreationAndWorkingStatus(events);
|
||||
|
||||
// Status update: working
|
||||
const workingEvent2 = events[2].result as TaskStatusUpdateEvent;
|
||||
expect(workingEvent2.kind).toBe('status-update');
|
||||
expect(workingEvent2.status.state).toBe('working');
|
||||
expect(workingEvent2.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'state-change',
|
||||
});
|
||||
|
||||
// Status update: tool-call-update
|
||||
const toolCallUpdateEvent = events[3].result as TaskStatusUpdateEvent;
|
||||
expect(toolCallUpdateEvent.kind).toBe('status-update');
|
||||
expect(toolCallUpdateEvent.status.state).toBe('working');
|
||||
expect(toolCallUpdateEvent.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'tool-call-update',
|
||||
});
|
||||
expect(toolCallUpdateEvent.status.message?.parts).toMatchObject([
|
||||
{
|
||||
data: {
|
||||
status: 'validating',
|
||||
request: { callId: 'test-call-id' },
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
// State update: awaiting_approval update
|
||||
const toolCallConfirmationEvent = events[4].result as TaskStatusUpdateEvent;
|
||||
expect(toolCallConfirmationEvent.kind).toBe('status-update');
|
||||
expect(toolCallConfirmationEvent.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'tool-call-confirmation',
|
||||
});
|
||||
expect(toolCallConfirmationEvent.status.message?.parts).toMatchObject([
|
||||
{
|
||||
data: {
|
||||
status: 'awaiting_approval',
|
||||
request: { callId: 'test-call-id' },
|
||||
},
|
||||
},
|
||||
]);
|
||||
expect(toolCallConfirmationEvent.status?.state).toBe('working');
|
||||
|
||||
assertUniqueFinalEventIsLast(events);
|
||||
expect(events.length).toBe(6);
|
||||
});
|
||||
|
||||
it('should handle multiple tool calls in a single turn', async () => {
|
||||
// First call yields the tool request
|
||||
sendMessageStreamSpy.mockImplementationOnce(async function* () {
|
||||
yield* [
|
||||
{
|
||||
type: GeminiEventType.ToolCallRequest,
|
||||
value: {
|
||||
callId: 'test-call-id-1',
|
||||
name: 'test-tool-1',
|
||||
args: {},
|
||||
},
|
||||
},
|
||||
{
|
||||
type: GeminiEventType.ToolCallRequest,
|
||||
value: {
|
||||
callId: 'test-call-id-2',
|
||||
name: 'test-tool-2',
|
||||
args: {},
|
||||
},
|
||||
},
|
||||
];
|
||||
});
|
||||
// Subsequent calls yield nothing
|
||||
sendMessageStreamSpy.mockImplementation(async function* () {
|
||||
yield* [];
|
||||
});
|
||||
|
||||
const mockTool1 = new MockTool(
|
||||
'test-tool-1',
|
||||
'Test Tool 1',
|
||||
false,
|
||||
false,
|
||||
mockToolConfirmationFn,
|
||||
);
|
||||
const mockTool2 = new MockTool(
|
||||
'test-tool-2',
|
||||
'Test Tool 2',
|
||||
false,
|
||||
false,
|
||||
mockToolConfirmationFn,
|
||||
);
|
||||
|
||||
getToolRegistrySpy.mockReturnValue({
|
||||
getAllTools: vi.fn().mockReturnValue([mockTool1, mockTool2]),
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
getTool: vi.fn().mockImplementation((name: string) => {
|
||||
if (name === 'test-tool-1') return mockTool1;
|
||||
if (name === 'test-tool-2') return mockTool2;
|
||||
return undefined;
|
||||
}),
|
||||
});
|
||||
|
||||
const agent = request.agent(app);
|
||||
const res = await agent
|
||||
.post('/')
|
||||
.send(
|
||||
createStreamMessageRequest(
|
||||
'run two tools',
|
||||
'a2a-multi-tool-test-message',
|
||||
),
|
||||
)
|
||||
.set('Content-Type', 'application/json')
|
||||
.expect(200);
|
||||
|
||||
const events = streamToSSEEvents(res.text);
|
||||
assertTaskCreationAndWorkingStatus(events);
|
||||
|
||||
// Second working update
|
||||
const workingEvent = events[2].result as TaskStatusUpdateEvent;
|
||||
expect(workingEvent.kind).toBe('status-update');
|
||||
expect(workingEvent.status.state).toBe('working');
|
||||
|
||||
// State Update: Validate each tool call
|
||||
const toolCallValidateEvent1 = events[3].result as TaskStatusUpdateEvent;
|
||||
expect(toolCallValidateEvent1.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'tool-call-update',
|
||||
});
|
||||
expect(toolCallValidateEvent1.status.message?.parts).toMatchObject([
|
||||
{
|
||||
data: {
|
||||
status: 'validating',
|
||||
request: { callId: 'test-call-id-1' },
|
||||
},
|
||||
},
|
||||
]);
|
||||
const toolCallValidateEvent2 = events[4].result as TaskStatusUpdateEvent;
|
||||
expect(toolCallValidateEvent2.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'tool-call-update',
|
||||
});
|
||||
expect(toolCallValidateEvent2.status.message?.parts).toMatchObject([
|
||||
{
|
||||
data: {
|
||||
status: 'validating',
|
||||
request: { callId: 'test-call-id-2' },
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
// State Update: Set each tool call to awaiting
|
||||
const toolCallAwaitEvent1 = events[5].result as TaskStatusUpdateEvent;
|
||||
expect(toolCallAwaitEvent1.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'tool-call-confirmation',
|
||||
});
|
||||
expect(toolCallAwaitEvent1.status.message?.parts).toMatchObject([
|
||||
{
|
||||
data: {
|
||||
status: 'awaiting_approval',
|
||||
request: { callId: 'test-call-id-1' },
|
||||
},
|
||||
},
|
||||
]);
|
||||
const toolCallAwaitEvent2 = events[6].result as TaskStatusUpdateEvent;
|
||||
expect(toolCallAwaitEvent2.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'tool-call-confirmation',
|
||||
});
|
||||
expect(toolCallAwaitEvent2.status.message?.parts).toMatchObject([
|
||||
{
|
||||
data: {
|
||||
status: 'awaiting_approval',
|
||||
request: { callId: 'test-call-id-2' },
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
assertUniqueFinalEventIsLast(events);
|
||||
expect(events.length).toBe(8);
|
||||
});
|
||||
|
||||
it('should handle tool calls that do not require approval', async () => {
|
||||
// First call yields the tool request
|
||||
sendMessageStreamSpy.mockImplementationOnce(async function* () {
|
||||
yield* [
|
||||
{
|
||||
type: GeminiEventType.ToolCallRequest,
|
||||
value: {
|
||||
callId: 'test-call-id-no-approval',
|
||||
name: 'test-tool-no-approval',
|
||||
args: {},
|
||||
},
|
||||
},
|
||||
];
|
||||
});
|
||||
// Second call, after the tool runs, yields the final text
|
||||
sendMessageStreamSpy.mockImplementationOnce(async function* () {
|
||||
yield* [{ type: 'content', value: 'Tool executed successfully.' }];
|
||||
});
|
||||
|
||||
const mockTool = new MockTool(
|
||||
'test-tool-no-approval',
|
||||
'Test Tool No Approval',
|
||||
);
|
||||
mockTool.execute.mockResolvedValue({
|
||||
llmContent: 'Tool executed successfully.',
|
||||
returnDisplay: 'Tool executed successfully.',
|
||||
});
|
||||
|
||||
getToolRegistrySpy.mockReturnValue({
|
||||
getAllTools: vi.fn().mockReturnValue([mockTool]),
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
getTool: vi.fn().mockReturnValue(mockTool),
|
||||
});
|
||||
|
||||
const agent = request.agent(app);
|
||||
const res = await agent
|
||||
.post('/')
|
||||
.send(
|
||||
createStreamMessageRequest(
|
||||
'run a tool without approval',
|
||||
'a2a-no-approval-test-message',
|
||||
),
|
||||
)
|
||||
.set('Content-Type', 'application/json')
|
||||
.expect(200);
|
||||
|
||||
const events = streamToSSEEvents(res.text);
|
||||
assertTaskCreationAndWorkingStatus(events);
|
||||
|
||||
// Status update: working
|
||||
const workingEvent2 = events[2].result as TaskStatusUpdateEvent;
|
||||
expect(workingEvent2.kind).toBe('status-update');
|
||||
expect(workingEvent2.status.state).toBe('working');
|
||||
|
||||
// Status update: tool-call-update (validating)
|
||||
const validatingEvent = events[3].result as TaskStatusUpdateEvent;
|
||||
expect(validatingEvent.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'tool-call-update',
|
||||
});
|
||||
expect(validatingEvent.status.message?.parts).toMatchObject([
|
||||
{
|
||||
data: {
|
||||
status: 'validating',
|
||||
request: { callId: 'test-call-id-no-approval' },
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
// Status update: tool-call-update (scheduled)
|
||||
const scheduledEvent = events[4].result as TaskStatusUpdateEvent;
|
||||
expect(scheduledEvent.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'tool-call-update',
|
||||
});
|
||||
expect(scheduledEvent.status.message?.parts).toMatchObject([
|
||||
{
|
||||
data: {
|
||||
status: 'scheduled',
|
||||
request: { callId: 'test-call-id-no-approval' },
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
// Status update: tool-call-update (executing)
|
||||
const executingEvent = events[5].result as TaskStatusUpdateEvent;
|
||||
expect(executingEvent.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'tool-call-update',
|
||||
});
|
||||
expect(executingEvent.status.message?.parts).toMatchObject([
|
||||
{
|
||||
data: {
|
||||
status: 'executing',
|
||||
request: { callId: 'test-call-id-no-approval' },
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
// Status update: tool-call-update (success)
|
||||
const successEvent = events[6].result as TaskStatusUpdateEvent;
|
||||
expect(successEvent.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'tool-call-update',
|
||||
});
|
||||
expect(successEvent.status.message?.parts).toMatchObject([
|
||||
{
|
||||
data: {
|
||||
status: 'success',
|
||||
request: { callId: 'test-call-id-no-approval' },
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
// Status update: working (before sending tool result to LLM)
|
||||
const workingEvent3 = events[7].result as TaskStatusUpdateEvent;
|
||||
expect(workingEvent3.kind).toBe('status-update');
|
||||
expect(workingEvent3.status.state).toBe('working');
|
||||
|
||||
// Status update: text-content (final LLM response)
|
||||
const textContentEvent = events[8].result as TaskStatusUpdateEvent;
|
||||
expect(textContentEvent.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'text-content',
|
||||
});
|
||||
expect(textContentEvent.status.message?.parts).toMatchObject([
|
||||
{ text: 'Tool executed successfully.' },
|
||||
]);
|
||||
|
||||
assertUniqueFinalEventIsLast(events);
|
||||
expect(events.length).toBe(10);
|
||||
});
|
||||
|
||||
it('should bypass tool approval in YOLO mode', async () => {
|
||||
// First call yields the tool request
|
||||
sendMessageStreamSpy.mockImplementationOnce(async function* () {
|
||||
yield* [
|
||||
{
|
||||
type: GeminiEventType.ToolCallRequest,
|
||||
value: {
|
||||
callId: 'test-call-id-yolo',
|
||||
name: 'test-tool-yolo',
|
||||
args: {},
|
||||
},
|
||||
},
|
||||
];
|
||||
});
|
||||
// Second call, after the tool runs, yields the final text
|
||||
sendMessageStreamSpy.mockImplementationOnce(async function* () {
|
||||
yield* [{ type: 'content', value: 'Tool executed successfully.' }];
|
||||
});
|
||||
|
||||
// Set approval mode to yolo
|
||||
getApprovalModeSpy.mockReturnValue(ApprovalMode.YOLO);
|
||||
|
||||
const mockTool = new MockTool(
|
||||
'test-tool-yolo',
|
||||
'Test Tool YOLO',
|
||||
false,
|
||||
false,
|
||||
);
|
||||
mockTool.execute.mockResolvedValue({
|
||||
llmContent: 'Tool executed successfully.',
|
||||
returnDisplay: 'Tool executed successfully.',
|
||||
});
|
||||
|
||||
getToolRegistrySpy.mockReturnValue({
|
||||
getAllTools: vi.fn().mockReturnValue([mockTool]),
|
||||
getToolsByServer: vi.fn().mockReturnValue([]),
|
||||
getTool: vi.fn().mockReturnValue(mockTool),
|
||||
});
|
||||
|
||||
const agent = request.agent(app);
|
||||
const res = await agent
|
||||
.post('/')
|
||||
.send(
|
||||
createStreamMessageRequest(
|
||||
'run a tool in yolo mode',
|
||||
'a2a-yolo-mode-test-message',
|
||||
),
|
||||
)
|
||||
.set('Content-Type', 'application/json')
|
||||
.expect(200);
|
||||
|
||||
const events = streamToSSEEvents(res.text);
|
||||
assertTaskCreationAndWorkingStatus(events);
|
||||
|
||||
// Status update: working
|
||||
const workingEvent2 = events[2].result as TaskStatusUpdateEvent;
|
||||
expect(workingEvent2.kind).toBe('status-update');
|
||||
expect(workingEvent2.status.state).toBe('working');
|
||||
|
||||
// Status update: tool-call-update (validating)
|
||||
const validatingEvent = events[3].result as TaskStatusUpdateEvent;
|
||||
expect(validatingEvent.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'tool-call-update',
|
||||
});
|
||||
expect(validatingEvent.status.message?.parts).toMatchObject([
|
||||
{
|
||||
data: {
|
||||
status: 'validating',
|
||||
request: { callId: 'test-call-id-yolo' },
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
// Status update: tool-call-update (scheduled)
|
||||
const awaitingEvent = events[4].result as TaskStatusUpdateEvent;
|
||||
expect(awaitingEvent.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'tool-call-update',
|
||||
});
|
||||
expect(awaitingEvent.status.message?.parts).toMatchObject([
|
||||
{
|
||||
data: {
|
||||
status: 'scheduled',
|
||||
request: { callId: 'test-call-id-yolo' },
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
// Status update: tool-call-update (executing)
|
||||
const executingEvent = events[5].result as TaskStatusUpdateEvent;
|
||||
expect(executingEvent.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'tool-call-update',
|
||||
});
|
||||
expect(executingEvent.status.message?.parts).toMatchObject([
|
||||
{
|
||||
data: {
|
||||
status: 'executing',
|
||||
request: { callId: 'test-call-id-yolo' },
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
// Status update: tool-call-update (success)
|
||||
const successEvent = events[6].result as TaskStatusUpdateEvent;
|
||||
expect(successEvent.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'tool-call-update',
|
||||
});
|
||||
expect(successEvent.status.message?.parts).toMatchObject([
|
||||
{
|
||||
data: {
|
||||
status: 'success',
|
||||
request: { callId: 'test-call-id-yolo' },
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
// Status update: working (before sending tool result to LLM)
|
||||
const workingEvent3 = events[7].result as TaskStatusUpdateEvent;
|
||||
expect(workingEvent3.kind).toBe('status-update');
|
||||
expect(workingEvent3.status.state).toBe('working');
|
||||
|
||||
// Status update: text-content (final LLM response)
|
||||
const textContentEvent = events[8].result as TaskStatusUpdateEvent;
|
||||
expect(textContentEvent.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'text-content',
|
||||
});
|
||||
expect(textContentEvent.status.message?.parts).toMatchObject([
|
||||
{ text: 'Tool executed successfully.' },
|
||||
]);
|
||||
|
||||
assertUniqueFinalEventIsLast(events);
|
||||
expect(events.length).toBe(10);
|
||||
});
|
||||
});
|
||||
785
packages/a2a-server/src/agent.ts
Normal file
785
packages/a2a-server/src/agent.ts
Normal file
@@ -0,0 +1,785 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import express from 'express';
|
||||
import { AsyncLocalStorage } from 'node:async_hooks';
|
||||
|
||||
import type { Message, Task as SDKTask, AgentCard } from '@a2a-js/sdk';
|
||||
import type {
|
||||
TaskStore,
|
||||
AgentExecutor,
|
||||
AgentExecutionEvent,
|
||||
RequestContext,
|
||||
ExecutionEventBus,
|
||||
} from '@a2a-js/sdk/server';
|
||||
import { DefaultRequestHandler, InMemoryTaskStore } from '@a2a-js/sdk/server';
|
||||
import { A2AExpressApp } from '@a2a-js/sdk/server/express'; // Import server components
|
||||
import type {
|
||||
ToolCallRequestInfo,
|
||||
ServerGeminiToolCallRequestEvent,
|
||||
Config,
|
||||
} from '@google/gemini-cli-core';
|
||||
import { GeminiEventType } from '@google/gemini-cli-core';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { logger } from './logger.js';
|
||||
import type { StateChange, AgentSettings } from './types.js';
|
||||
import { CoderAgentEvent } from './types.js';
|
||||
import { loadConfig, loadEnvironment, setTargetDir } from './config.js';
|
||||
import { loadSettings } from './settings.js';
|
||||
import { loadExtensions } from './extension.js';
|
||||
import { Task } from './task.js';
|
||||
import { GCSTaskStore, NoOpTaskStore } from './gcs.js';
|
||||
import type { PersistedStateMetadata } from './metadata_types.js';
|
||||
import { getPersistedState, setPersistedState } from './metadata_types.js';
|
||||
|
||||
const requestStorage = new AsyncLocalStorage<{ req: express.Request }>();
|
||||
|
||||
/**
|
||||
* Provides a wrapper for Task. Passes data from Task to SDKTask.
|
||||
* The idea is to use this class inside CoderAgentExecutor to replace Task.
|
||||
*/
|
||||
class TaskWrapper {
|
||||
task: Task;
|
||||
agentSettings: AgentSettings;
|
||||
|
||||
constructor(task: Task, agentSettings: AgentSettings) {
|
||||
this.task = task;
|
||||
this.agentSettings = agentSettings;
|
||||
}
|
||||
|
||||
get id() {
|
||||
return this.task.id;
|
||||
}
|
||||
|
||||
toSDKTask(): SDKTask {
|
||||
const persistedState: PersistedStateMetadata = {
|
||||
_agentSettings: this.agentSettings,
|
||||
_taskState: this.task.taskState,
|
||||
};
|
||||
|
||||
const sdkTask: SDKTask = {
|
||||
id: this.task.id,
|
||||
contextId: this.task.contextId,
|
||||
kind: 'task',
|
||||
status: {
|
||||
state: this.task.taskState,
|
||||
timestamp: new Date().toISOString(),
|
||||
},
|
||||
metadata: setPersistedState({}, persistedState),
|
||||
history: [],
|
||||
artifacts: [],
|
||||
};
|
||||
sdkTask.metadata!['_contextId'] = this.task.contextId;
|
||||
return sdkTask;
|
||||
}
|
||||
}
|
||||
|
||||
const coderAgentCard: AgentCard = {
|
||||
name: 'Gemini SDLC Agent',
|
||||
description:
|
||||
'An agent that generates code based on natural language instructions and streams file outputs.',
|
||||
url: 'http://localhost:41242/',
|
||||
provider: {
|
||||
organization: 'Google',
|
||||
url: 'https://google.com',
|
||||
},
|
||||
protocolVersion: '0.3.0',
|
||||
version: '0.0.2', // Incremented version
|
||||
capabilities: {
|
||||
streaming: true,
|
||||
pushNotifications: false,
|
||||
stateTransitionHistory: true,
|
||||
},
|
||||
securitySchemes: undefined,
|
||||
security: undefined,
|
||||
defaultInputModes: ['text'],
|
||||
defaultOutputModes: ['text'],
|
||||
skills: [
|
||||
{
|
||||
id: 'code_generation',
|
||||
name: 'Code Generation',
|
||||
description:
|
||||
'Generates code snippets or complete files based on user requests, streaming the results.',
|
||||
tags: ['code', 'development', 'programming'],
|
||||
examples: [
|
||||
'Write a python function to calculate fibonacci numbers.',
|
||||
'Create an HTML file with a basic button that alerts "Hello!" when clicked.',
|
||||
],
|
||||
inputModes: ['text'],
|
||||
outputModes: ['text'],
|
||||
},
|
||||
],
|
||||
supportsAuthenticatedExtendedCard: false,
|
||||
};
|
||||
|
||||
/**
|
||||
* CoderAgentExecutor implements the agent's core logic for code generation.
|
||||
*/
|
||||
class CoderAgentExecutor implements AgentExecutor {
|
||||
private tasks: Map<string, TaskWrapper> = new Map();
|
||||
// Track tasks with an active execution loop.
|
||||
private executingTasks = new Set<string>();
|
||||
|
||||
constructor(private taskStore?: TaskStore) {}
|
||||
|
||||
private async getConfig(
|
||||
agentSettings: AgentSettings,
|
||||
taskId: string,
|
||||
): Promise<Config> {
|
||||
const workspaceRoot = setTargetDir(agentSettings);
|
||||
loadEnvironment(); // Will override any global env with workspace envs
|
||||
const settings = loadSettings(workspaceRoot);
|
||||
const extensions = loadExtensions(workspaceRoot);
|
||||
return await loadConfig(settings, extensions, taskId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Reconstructs TaskWrapper from SDKTask.
|
||||
*/
|
||||
async reconstruct(
|
||||
sdkTask: SDKTask,
|
||||
eventBus?: ExecutionEventBus,
|
||||
): Promise<TaskWrapper> {
|
||||
const metadata = sdkTask.metadata || {};
|
||||
const persistedState = getPersistedState(metadata);
|
||||
|
||||
if (!persistedState) {
|
||||
throw new Error(
|
||||
`Cannot reconstruct task ${sdkTask.id}: missing persisted state in metadata.`,
|
||||
);
|
||||
}
|
||||
|
||||
const agentSettings = persistedState._agentSettings;
|
||||
const config = await this.getConfig(agentSettings, sdkTask.id);
|
||||
const contextId =
|
||||
(metadata['_contextId'] as string) || (sdkTask.contextId as string);
|
||||
const runtimeTask = await Task.create(
|
||||
sdkTask.id,
|
||||
contextId,
|
||||
config,
|
||||
eventBus,
|
||||
);
|
||||
runtimeTask.taskState = persistedState._taskState;
|
||||
await runtimeTask.geminiClient.initialize(
|
||||
runtimeTask.config.getContentGeneratorConfig(),
|
||||
);
|
||||
|
||||
const wrapper = new TaskWrapper(runtimeTask, agentSettings);
|
||||
this.tasks.set(sdkTask.id, wrapper);
|
||||
logger.info(`Task ${sdkTask.id} reconstructed from store.`);
|
||||
return wrapper;
|
||||
}
|
||||
|
||||
async createTask(
|
||||
taskId: string,
|
||||
contextId: string,
|
||||
agentSettingsInput?: AgentSettings,
|
||||
eventBus?: ExecutionEventBus,
|
||||
): Promise<TaskWrapper> {
|
||||
const agentSettings = agentSettingsInput || ({} as AgentSettings);
|
||||
const config = await this.getConfig(agentSettings, taskId);
|
||||
const runtimeTask = await Task.create(taskId, contextId, config, eventBus);
|
||||
await runtimeTask.geminiClient.initialize(
|
||||
runtimeTask.config.getContentGeneratorConfig(),
|
||||
);
|
||||
|
||||
const wrapper = new TaskWrapper(runtimeTask, agentSettings);
|
||||
this.tasks.set(taskId, wrapper);
|
||||
logger.info(`New task ${taskId} created.`);
|
||||
return wrapper;
|
||||
}
|
||||
|
||||
getTask(taskId: string): TaskWrapper | undefined {
|
||||
return this.tasks.get(taskId);
|
||||
}
|
||||
|
||||
getAllTasks(): TaskWrapper[] {
|
||||
return Array.from(this.tasks.values());
|
||||
}
|
||||
|
||||
/**
 * Cancels a task by id and publishes a final status-update on the event bus.
 *
 * Behavior:
 * - Unknown task id: publishes a final 'failed' update (with a fresh
 *   contextId, since the real one is unknown) and returns.
 * - Task already 'canceled' or 'failed': republishes that terminal state as
 *   a final update and returns without further action.
 * - Otherwise: cancels pending tools, transitions the task to 'canceled',
 *   and persists the new state to the task store (when one is configured).
 *   Any error in that path is reported as a final 'failed' update.
 */
cancelTask = async (
  taskId: string,
  eventBus: ExecutionEventBus,
): Promise<void> => {
  logger.info(
    `[CoderAgentExecutor] Received cancel request for task ${taskId}`,
  );
  const wrapper = this.tasks.get(taskId);

  if (!wrapper) {
    logger.warn(
      `[CoderAgentExecutor] Task ${taskId} not found for cancellation.`,
    );
    // The task's real contextId is unknown here, so a fresh one is minted
    // just to satisfy the event shape.
    eventBus.publish({
      kind: 'status-update',
      taskId,
      contextId: uuidv4(),
      status: {
        state: 'failed',
        message: {
          kind: 'message',
          role: 'agent',
          parts: [{ kind: 'text', text: `Task ${taskId} not found.` }],
          messageId: uuidv4(),
          taskId,
        },
      },
      final: true,
    });
    return;
  }

  const { task } = wrapper;

  // Idempotence: a task already in a terminal state just re-announces it.
  if (task.taskState === 'canceled' || task.taskState === 'failed') {
    logger.info(
      `[CoderAgentExecutor] Task ${taskId} is already in a final state: ${task.taskState}. No action needed for cancellation.`,
    );
    eventBus.publish({
      kind: 'status-update',
      taskId,
      contextId: task.contextId,
      status: {
        state: task.taskState,
        message: {
          kind: 'message',
          role: 'agent',
          parts: [
            {
              kind: 'text',
              text: `Task ${taskId} is already ${task.taskState}.`,
            },
          ],
          messageId: uuidv4(),
          taskId,
        },
      },
      final: true,
    });
    return;
  }

  try {
    logger.info(
      `[CoderAgentExecutor] Initiating cancellation for task ${taskId}.`,
    );
    // Stop any in-flight tool calls before flipping the state.
    task.cancelPendingTools('Task canceled by user request.');

    const stateChange: StateChange = {
      kind: CoderAgentEvent.StateChangeEvent,
    };
    task.setTaskStateAndPublishUpdate(
      'canceled',
      stateChange,
      'Task canceled by user request.',
      undefined,
      true,
    );
    logger.info(
      `[CoderAgentExecutor] Task ${taskId} cancellation processed. Saving state.`,
    );
    // taskStore is optional; persistence is best-effort.
    await this.taskStore?.save(wrapper.toSDKTask());
    logger.info(`[CoderAgentExecutor] Task ${taskId} state CANCELED saved.`);
  } catch (error) {
    const errorMessage =
      error instanceof Error ? error.message : 'Unknown error';
    logger.error(
      `[CoderAgentExecutor] Error during task cancellation for ${taskId}: ${errorMessage}`,
      error,
    );
    eventBus.publish({
      kind: 'status-update',
      taskId,
      contextId: task.contextId,
      status: {
        state: 'failed',
        message: {
          kind: 'message',
          role: 'agent',
          parts: [
            {
              kind: 'text',
              text: `Failed to process cancellation for task ${taskId}: ${errorMessage}`,
            },
          ],
          messageId: uuidv4(),
          taskId,
        },
      },
      final: true,
    });
  }
};
|
||||
|
||||
/**
 * Main A2A execution entry point for one incoming user message.
 *
 * Resolves (or mints) the task/context ids, wires a client-disconnect abort
 * handler, then loads the task from the memory cache, rehydrates it from the
 * SDK task, or creates a new one. If this task is already executing, the
 * message is forwarded into the running loop and this call returns; otherwise
 * this call runs the agent turn loop: stream LLM events, schedule tool calls,
 * wait for them, feed results back to the LLM, until no tool calls remain.
 * The final state is always persisted in the `finally` block.
 */
async execute(
  requestContext: RequestContext,
  eventBus: ExecutionEventBus,
): Promise<void> {
  const userMessage = requestContext.userMessage as Message;
  const sdkTask = requestContext.task as SDKTask | undefined;

  // Prefer ids carried by the request; fall back to fresh UUIDs.
  const taskId = sdkTask?.id || userMessage.taskId || uuidv4();
  const contextId =
    userMessage.contextId ||
    sdkTask?.contextId ||
    sdkTask?.metadata?.['_contextId'] ||
    uuidv4();

  logger.info(
    `[CoderAgentExecutor] Executing for taskId: ${taskId}, contextId: ${contextId}`,
  );
  logger.info(
    `[CoderAgentExecutor] userMessage: ${JSON.stringify(userMessage)}`,
  );
  eventBus.on('event', (event: AgentExecutionEvent) =>
    logger.info('[EventBus event]: ', event),
  );

  // The express middleware stores the raw request here (see createApp) so we
  // can observe the client socket. Missing store only disables that feature.
  const store = requestStorage.getStore();
  if (!store) {
    logger.error(
      '[CoderAgentExecutor] Could not get request from async local storage. Cancellation on socket close will not be handled for this request.',
    );
  }

  const abortController = new AbortController();
  const abortSignal = abortController.signal;

  if (store) {
    // Grab the raw socket from the request object
    const socket = store.req.socket;
    const onClientEnd = () => {
      logger.info(
        `[CoderAgentExecutor] Client socket closed for task ${taskId}. Cancelling execution.`,
      );
      if (!abortController.signal.aborted) {
        abortController.abort();
      }
      // Clean up the listener to prevent memory leaks
      socket.removeListener('close', onClientEnd);
    };

    // Listen on the socket's 'end' event (remote closed the connection)
    socket.on('end', onClientEnd);

    // It's also good practice to remove the listener if the task completes successfully
    abortSignal.addEventListener('abort', () => {
      socket.removeListener('end', onClientEnd);
    });
    logger.info(
      `[CoderAgentExecutor] Socket close handler set up for task ${taskId}.`,
    );
  }

  // Resolution order: memory cache -> rehydrate from SDK task -> create new.
  let wrapper: TaskWrapper | undefined = this.tasks.get(taskId);

  if (wrapper) {
    // Re-point the cached task at this request's event bus.
    wrapper.task.eventBus = eventBus;
    logger.info(`[CoderAgentExecutor] Task ${taskId} found in memory cache.`);
  } else if (sdkTask) {
    logger.info(
      `[CoderAgentExecutor] Task ${taskId} found in TaskStore. Reconstructing...`,
    );
    try {
      wrapper = await this.reconstruct(sdkTask, eventBus);
    } catch (e) {
      // Hydration failure is terminal for this request: report and bail.
      logger.error(
        `[CoderAgentExecutor] Failed to hydrate task ${taskId}:`,
        e,
      );
      const stateChange: StateChange = {
        kind: CoderAgentEvent.StateChangeEvent,
      };
      eventBus.publish({
        kind: 'status-update',
        taskId,
        contextId: sdkTask.contextId,
        status: {
          state: 'failed',
          message: {
            kind: 'message',
            role: 'agent',
            parts: [
              {
                kind: 'text',
                text: 'Internal error: Task state lost or corrupted.',
              },
            ],
            messageId: uuidv4(),
            taskId,
            contextId: sdkTask.contextId,
          } as Message,
        },
        final: true,
        metadata: { coderAgent: stateChange },
      });
      return;
    }
  } else {
    logger.info(`[CoderAgentExecutor] Creating new task ${taskId}.`);
    const agentSettings = userMessage.metadata?.[
      'coderAgent'
    ] as AgentSettings;
    wrapper = await this.createTask(
      taskId,
      contextId as string,
      agentSettings,
      eventBus,
    );
    // Announce the freshly submitted task before persisting it.
    const newTaskSDK = wrapper.toSDKTask();
    eventBus.publish({
      ...newTaskSDK,
      kind: 'task',
      status: { state: 'submitted', timestamp: new Date().toISOString() },
      history: [userMessage],
    });
    try {
      await this.taskStore?.save(newTaskSDK);
      logger.info(`[CoderAgentExecutor] New task ${taskId} saved to store.`);
    } catch (saveError) {
      // Persistence failure is logged but does not abort execution.
      logger.error(
        `[CoderAgentExecutor] Failed to save new task ${taskId} to store:`,
        saveError,
      );
    }
  }

  if (!wrapper) {
    logger.error(
      `[CoderAgentExecutor] Task ${taskId} is unexpectedly undefined after load/create.`,
    );
    return;
  }

  const currentTask = wrapper.task;

  // Terminal tasks cannot be re-executed.
  if (['canceled', 'failed', 'completed'].includes(currentTask.taskState)) {
    logger.warn(
      `[CoderAgentExecutor] Attempted to execute task ${taskId} which is already in state ${currentTask.taskState}. Ignoring.`,
    );
    return;
  }

  // If another request is already driving this task, just hand it the new
  // message and return; the original execution loop continues.
  if (this.executingTasks.has(taskId)) {
    logger.info(
      `[CoderAgentExecutor] Task ${taskId} has a pending execution. Processing message and yielding.`,
    );
    currentTask.eventBus = eventBus;
    for await (const _ of currentTask.acceptUserMessage(
      requestContext,
      abortController.signal,
    )) {
      logger.info(
        `[CoderAgentExecutor] Processing user message ${userMessage.messageId} in secondary execution loop for task ${taskId}.`,
      );
    }
    // End this execution-- the original/source will be resumed.
    return;
  }

  logger.info(
    `[CoderAgentExecutor] Starting main execution for message ${userMessage.messageId} for task ${taskId}.`,
  );
  this.executingTasks.add(taskId);

  try {
    let agentTurnActive = true;
    logger.info(`[CoderAgentExecutor] Task ${taskId}: Processing user turn.`);
    let agentEvents = currentTask.acceptUserMessage(
      requestContext,
      abortSignal,
    );

    // Agent turn loop: consume one LLM stream per iteration; tool results
    // feed a new stream until the model stops requesting tools.
    while (agentTurnActive) {
      logger.info(
        `[CoderAgentExecutor] Task ${taskId}: Processing agent turn (LLM stream).`,
      );
      const toolCallRequests: ToolCallRequestInfo[] = [];
      for await (const event of agentEvents) {
        if (abortSignal.aborted) {
          logger.warn(
            `[CoderAgentExecutor] Task ${taskId}: Abort signal received during agent event processing.`,
          );
          throw new Error('Execution aborted');
        }
        // Tool call requests are batched; every other event goes to the task.
        if (event.type === GeminiEventType.ToolCallRequest) {
          toolCallRequests.push(
            (event as ServerGeminiToolCallRequestEvent).value,
          );
          continue;
        }
        await currentTask.acceptAgentMessage(event);
      }

      if (abortSignal.aborted) throw new Error('Execution aborted');

      if (toolCallRequests.length > 0) {
        logger.info(
          `[CoderAgentExecutor] Task ${taskId}: Found ${toolCallRequests.length} tool call requests. Scheduling as a batch.`,
        );
        await currentTask.scheduleToolCalls(toolCallRequests, abortSignal);
      }

      logger.info(
        `[CoderAgentExecutor] Task ${taskId}: Waiting for pending tools if any.`,
      );
      await currentTask.waitForPendingTools();
      logger.info(
        `[CoderAgentExecutor] Task ${taskId}: All pending tools completed or none were pending.`,
      );

      if (abortSignal.aborted) throw new Error('Execution aborted');

      const completedTools = currentTask.getAndClearCompletedTools();

      if (completedTools.length > 0) {
        // If all completed tool calls were canceled, manually add them to history and set state to input-required, final:true
        if (completedTools.every((tool) => tool.status === 'cancelled')) {
          logger.info(
            `[CoderAgentExecutor] Task ${taskId}: All tool calls were cancelled. Updating history and ending agent turn.`,
          );
          currentTask.addToolResponsesToHistory(completedTools);
          agentTurnActive = false;
          const stateChange: StateChange = {
            kind: CoderAgentEvent.StateChangeEvent,
          };
          currentTask.setTaskStateAndPublishUpdate(
            'input-required',
            stateChange,
            undefined,
            undefined,
            true,
          );
        } else {
          logger.info(
            `[CoderAgentExecutor] Task ${taskId}: Found ${completedTools.length} completed tool calls. Sending results back to LLM.`,
          );

          agentEvents = currentTask.sendCompletedToolsToLlm(
            completedTools,
            abortSignal,
          );
          // Continue the loop to process the LLM response to the tool results.
        }
      } else {
        logger.info(
          `[CoderAgentExecutor] Task ${taskId}: No more tool calls to process. Ending agent turn.`,
        );
        agentTurnActive = false;
      }
    }

    logger.info(
      `[CoderAgentExecutor] Task ${taskId}: Agent turn finished, setting to input-required.`,
    );
    const stateChange: StateChange = {
      kind: CoderAgentEvent.StateChangeEvent,
    };
    currentTask.setTaskStateAndPublishUpdate(
      'input-required',
      stateChange,
      undefined,
      undefined,
      true,
    );
  } catch (error) {
    // Aborts (client disconnect) and genuine errors take different paths.
    if (abortSignal.aborted) {
      logger.warn(`[CoderAgentExecutor] Task ${taskId} execution aborted.`);
      currentTask.cancelPendingTools('Execution aborted');
      if (
        currentTask.taskState !== 'canceled' &&
        currentTask.taskState !== 'failed'
      ) {
        currentTask.setTaskStateAndPublishUpdate(
          'input-required',
          { kind: CoderAgentEvent.StateChangeEvent },
          'Execution aborted by client.',
          undefined,
          true,
        );
      }
    } else {
      const errorMessage =
        error instanceof Error ? error.message : 'Agent execution error';
      logger.error(
        `[CoderAgentExecutor] Error executing agent for task ${taskId}:`,
        error,
      );
      currentTask.cancelPendingTools(errorMessage);
      if (currentTask.taskState !== 'failed') {
        const stateChange: StateChange = {
          kind: CoderAgentEvent.StateChangeEvent,
        };
        currentTask.setTaskStateAndPublishUpdate(
          'failed',
          stateChange,
          errorMessage,
          undefined,
          true,
        );
      }
    }
  } finally {
    // Always release the execution slot and persist whatever state we ended in.
    this.executingTasks.delete(taskId);
    logger.info(
      `[CoderAgentExecutor] Saving final state for task ${taskId}.`,
    );
    try {
      await this.taskStore?.save(wrapper.toSDKTask());
      logger.info(`[CoderAgentExecutor] Task ${taskId} state saved.`);
    } catch (saveError) {
      logger.error(
        `[CoderAgentExecutor] Failed to save task ${taskId} state in finally block:`,
        saveError,
      );
    }
  }
}
|
||||
}
|
||||
|
||||
export function updateCoderAgentCardUrl(port: number) {
|
||||
coderAgentCard.url = `http://localhost:${port}/`;
|
||||
}
|
||||
|
||||
export async function main() {
|
||||
try {
|
||||
const expressApp = await createApp();
|
||||
const port = process.env['CODER_AGENT_PORT'] || 0;
|
||||
|
||||
const server = expressApp.listen(port, () => {
|
||||
const address = server.address();
|
||||
let actualPort;
|
||||
if (process.env['CODER_AGENT_PORT']) {
|
||||
actualPort = process.env['CODER_AGENT_PORT'];
|
||||
} else if (address && typeof address !== 'string') {
|
||||
actualPort = address.port;
|
||||
} else {
|
||||
throw new Error('[Core Agent] Could not find port number.');
|
||||
}
|
||||
updateCoderAgentCardUrl(Number(actualPort));
|
||||
logger.info(
|
||||
`[CoreAgent] Agent Server started on http://localhost:${actualPort}`,
|
||||
);
|
||||
logger.info(
|
||||
`[CoreAgent] Agent Card: http://localhost:${actualPort}/.well-known/agent-card.json`,
|
||||
);
|
||||
logger.info('[CoreAgent] Press Ctrl+C to stop the server');
|
||||
});
|
||||
} catch (error) {
|
||||
logger.error('[CoreAgent] Error during startup:', error);
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
export async function createApp() {
|
||||
try {
|
||||
// loadEnvironment() is called within getConfig now
|
||||
const bucketName = process.env['GCS_BUCKET_NAME'];
|
||||
let taskStoreForExecutor: TaskStore;
|
||||
let taskStoreForHandler: TaskStore;
|
||||
|
||||
if (bucketName) {
|
||||
logger.info(`Using GCSTaskStore with bucket: ${bucketName}`);
|
||||
const gcsTaskStore = new GCSTaskStore(bucketName);
|
||||
taskStoreForExecutor = gcsTaskStore;
|
||||
taskStoreForHandler = new NoOpTaskStore(gcsTaskStore);
|
||||
} else {
|
||||
logger.info('Using InMemoryTaskStore');
|
||||
const inMemoryTaskStore = new InMemoryTaskStore();
|
||||
taskStoreForExecutor = inMemoryTaskStore;
|
||||
taskStoreForHandler = inMemoryTaskStore;
|
||||
}
|
||||
|
||||
const agentExecutor = new CoderAgentExecutor(taskStoreForExecutor);
|
||||
|
||||
const requestHandler = new DefaultRequestHandler(
|
||||
coderAgentCard,
|
||||
taskStoreForHandler,
|
||||
agentExecutor,
|
||||
);
|
||||
|
||||
let expressApp = express();
|
||||
expressApp.use((req, res, next) => {
|
||||
requestStorage.run({ req }, next);
|
||||
});
|
||||
|
||||
const appBuilder = new A2AExpressApp(requestHandler);
|
||||
expressApp = appBuilder.setupRoutes(expressApp, '');
|
||||
expressApp.use(express.json());
|
||||
|
||||
expressApp.post('/tasks', async (req, res) => {
|
||||
try {
|
||||
const taskId = uuidv4();
|
||||
const agentSettings = req.body.agentSettings as
|
||||
| AgentSettings
|
||||
| undefined;
|
||||
const contextId = req.body.contextId || uuidv4();
|
||||
const wrapper = await agentExecutor.createTask(
|
||||
taskId,
|
||||
contextId,
|
||||
agentSettings,
|
||||
);
|
||||
await taskStoreForExecutor.save(wrapper.toSDKTask());
|
||||
res.status(201).json(wrapper.id);
|
||||
} catch (error) {
|
||||
logger.error('[CoreAgent] Error creating task:', error);
|
||||
const errorMessage =
|
||||
error instanceof Error
|
||||
? error.message
|
||||
: 'Unknown error creating task';
|
||||
res.status(500).send({ error: errorMessage });
|
||||
}
|
||||
});
|
||||
|
||||
expressApp.get('/tasks/metadata', async (req, res) => {
|
||||
// This endpoint is only meaningful if the task store is in-memory.
|
||||
if (!(taskStoreForExecutor instanceof InMemoryTaskStore)) {
|
||||
res.status(501).send({
|
||||
error:
|
||||
'Listing all task metadata is only supported when using InMemoryTaskStore.',
|
||||
});
|
||||
}
|
||||
try {
|
||||
const wrappers = agentExecutor.getAllTasks();
|
||||
if (wrappers && wrappers.length > 0) {
|
||||
const tasksMetadata = await Promise.all(
|
||||
wrappers.map((wrapper) => wrapper.task.getMetadata()),
|
||||
);
|
||||
res.status(200).json(tasksMetadata);
|
||||
} else {
|
||||
res.status(204).send();
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('[CoreAgent] Error getting all task metadata:', error);
|
||||
const errorMessage =
|
||||
error instanceof Error
|
||||
? error.message
|
||||
: 'Unknown error getting task metadata';
|
||||
res.status(500).send({ error: errorMessage });
|
||||
}
|
||||
});
|
||||
|
||||
expressApp.get('/tasks/:taskId/metadata', async (req, res) => {
|
||||
const taskId = req.params.taskId;
|
||||
let wrapper = agentExecutor.getTask(taskId);
|
||||
if (!wrapper) {
|
||||
const sdkTask = await taskStoreForExecutor.load(taskId);
|
||||
if (sdkTask) {
|
||||
wrapper = await agentExecutor.reconstruct(sdkTask);
|
||||
}
|
||||
}
|
||||
if (!wrapper) {
|
||||
res.status(404).send({ error: 'Task not found' });
|
||||
return;
|
||||
}
|
||||
res.json({ metadata: await wrapper.task.getMetadata() });
|
||||
});
|
||||
return expressApp;
|
||||
} catch (error) {
|
||||
logger.error('[CoreAgent] Error during startup:', error);
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
203
packages/a2a-server/src/config.ts
Normal file
203
packages/a2a-server/src/config.ts
Normal file
@@ -0,0 +1,203 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import * as fs from 'node:fs';
|
||||
import * as path from 'node:path';
|
||||
import { homedir } from 'node:os';
|
||||
import * as dotenv from 'dotenv';
|
||||
|
||||
import type { TelemetryTarget } from '@google/gemini-cli-core';
|
||||
import {
|
||||
AuthType,
|
||||
Config,
|
||||
type ConfigParameters,
|
||||
FileDiscoveryService,
|
||||
ApprovalMode,
|
||||
loadServerHierarchicalMemory,
|
||||
GEMINI_CONFIG_DIR,
|
||||
DEFAULT_GEMINI_EMBEDDING_MODEL,
|
||||
DEFAULT_GEMINI_MODEL,
|
||||
} from '@google/gemini-cli-core';
|
||||
|
||||
import { logger } from './logger.js';
|
||||
import type { Settings } from './settings.js';
|
||||
import type { Extension } from './extension.js';
|
||||
import { type AgentSettings, CoderAgentEvent } from './types.js';
|
||||
|
||||
/**
 * Builds and initializes a core Config for a single task.
 *
 * Merges MCP servers from settings and extensions, loads hierarchical
 * memory from the current working directory, initializes the Config (tool
 * registry, optional git checkpointing), then selects auth: USE_CCPA env var
 * -> LOGIN_WITH_GOOGLE; else GEMINI_API_KEY -> USE_GEMINI; when neither is
 * set, an error is logged and the config is returned without auth refreshed.
 *
 * @param settings Server settings (tools, telemetry, file filtering, ...).
 * @param extensions Loaded extensions contributing MCP servers/context files.
 * @param taskId Used as the session id so logs/telemetry correlate per task.
 */
export async function loadConfig(
  settings: Settings,
  extensions: Extension[],
  taskId: string,
): Promise<Config> {
  const mcpServers = mergeMcpServers(settings, extensions);
  const workspaceDir = process.cwd();
  const adcFilePath = process.env['GOOGLE_APPLICATION_CREDENTIALS'];

  const configParams: ConfigParameters = {
    sessionId: taskId,
    model: DEFAULT_GEMINI_MODEL,
    embeddingModel: DEFAULT_GEMINI_EMBEDDING_MODEL,
    sandbox: undefined, // Sandbox might not be relevant for a server-side agent
    targetDir: workspaceDir, // Or a specific directory the agent operates on
    debugMode: process.env['DEBUG'] === 'true' || false,
    question: '', // Not used in server mode directly like CLI
    fullContext: false, // Server might have different context needs
    coreTools: settings.coreTools || undefined,
    excludeTools: settings.excludeTools || undefined,
    showMemoryUsage: settings.showMemoryUsage || false,
    // GEMINI_YOLO_MODE auto-approves tool calls; default requires approval.
    approvalMode:
      process.env['GEMINI_YOLO_MODE'] === 'true'
        ? ApprovalMode.YOLO
        : ApprovalMode.DEFAULT,
    mcpServers,
    cwd: workspaceDir,
    telemetry: {
      enabled: settings.telemetry?.enabled,
      target: settings.telemetry?.target as TelemetryTarget,
      // Env var takes precedence over the settings file for the endpoint.
      otlpEndpoint:
        process.env['OTEL_EXPORTER_OTLP_ENDPOINT'] ??
        settings.telemetry?.otlpEndpoint,
      logPrompts: settings.telemetry?.logPrompts,
    },
    // Git-aware file filtering settings
    fileFiltering: {
      respectGitIgnore: settings.fileFiltering?.respectGitIgnore,
      enableRecursiveFileSearch:
        settings.fileFiltering?.enableRecursiveFileSearch,
    },
    ideMode: false,
  };

  // Load GEMINI.md-style hierarchical memory for the workspace, including
  // context files contributed by extensions.
  const fileService = new FileDiscoveryService(workspaceDir);
  const extensionContextFilePaths = extensions.flatMap((e) => e.contextFiles);
  const { memoryContent, fileCount } = await loadServerHierarchicalMemory(
    workspaceDir,
    [workspaceDir],
    false,
    fileService,
    extensionContextFilePaths,
  );
  configParams.userMemory = memoryContent;
  configParams.geminiMdFileCount = fileCount;

  const config = new Config({
    ...configParams,
  });
  // Needed to initialize ToolRegistry, and git checkpointing if enabled
  await config.initialize();

  if (process.env['USE_CCPA']) {
    logger.info('[Config] Using CCPA Auth:');
    try {
      if (adcFilePath) {
        // NOTE(review): the resolved path is discarded — this appears to only
        // sanity-check that the ADC path is resolvable; confirm intent.
        path.resolve(adcFilePath);
      }
    } catch (e) {
      logger.error(
        `[Config] USE_CCPA env var is true but unable to resolve GOOGLE_APPLICATION_CREDENTIALS file path ${adcFilePath}. Error ${e}`,
      );
    }
    await config.refreshAuth(AuthType.LOGIN_WITH_GOOGLE);
    logger.info(
      `[Config] GOOGLE_CLOUD_PROJECT: ${process.env['GOOGLE_CLOUD_PROJECT']}`,
    );
  } else if (process.env['GEMINI_API_KEY']) {
    logger.info('[Config] Using Gemini API Key');
    await config.refreshAuth(AuthType.USE_GEMINI);
  } else {
    // No auth configured: the config is still returned, but calls through the
    // generator will not work until auth is provided.
    logger.error(
      `[Config] Unable to set GeneratorConfig. Please provide a GEMINI_API_KEY or set USE_CCPA.`,
    );
  }

  return config;
}
|
||||
|
||||
export function mergeMcpServers(settings: Settings, extensions: Extension[]) {
|
||||
const mcpServers = { ...(settings.mcpServers || {}) };
|
||||
for (const extension of extensions) {
|
||||
Object.entries(extension.config.mcpServers || {}).forEach(
|
||||
([key, server]) => {
|
||||
if (mcpServers[key]) {
|
||||
console.warn(
|
||||
`Skipping extension MCP config for server with key "${key}" as it already exists.`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
mcpServers[key] = server;
|
||||
},
|
||||
);
|
||||
}
|
||||
return mcpServers;
|
||||
}
|
||||
|
||||
export function setTargetDir(agentSettings: AgentSettings | undefined): string {
|
||||
const originalCWD = process.cwd();
|
||||
const targetDir =
|
||||
process.env['CODER_AGENT_WORKSPACE_PATH'] ??
|
||||
(agentSettings?.kind === CoderAgentEvent.StateAgentSettingsEvent
|
||||
? agentSettings.workspacePath
|
||||
: undefined);
|
||||
|
||||
if (!targetDir) {
|
||||
return originalCWD;
|
||||
}
|
||||
|
||||
logger.info(
|
||||
`[CoderAgentExecutor] Overriding workspace path to: ${targetDir}`,
|
||||
);
|
||||
|
||||
try {
|
||||
const resolvedPath = path.resolve(targetDir);
|
||||
process.chdir(resolvedPath);
|
||||
return resolvedPath;
|
||||
} catch (e) {
|
||||
logger.error(
|
||||
`[CoderAgentExecutor] Error resolving workspace path: ${e}, returning original os.cwd()`,
|
||||
);
|
||||
return originalCWD;
|
||||
}
|
||||
}
|
||||
|
||||
export function loadEnvironment(): void {
|
||||
const envFilePath = findEnvFile(process.cwd());
|
||||
if (envFilePath) {
|
||||
dotenv.config({ path: envFilePath, override: true });
|
||||
}
|
||||
}
|
||||
|
||||
function findEnvFile(startDir: string): string | null {
|
||||
let currentDir = path.resolve(startDir);
|
||||
while (true) {
|
||||
// prefer gemini-specific .env under GEMINI_DIR
|
||||
const geminiEnvPath = path.join(currentDir, GEMINI_CONFIG_DIR, '.env');
|
||||
if (fs.existsSync(geminiEnvPath)) {
|
||||
return geminiEnvPath;
|
||||
}
|
||||
const envPath = path.join(currentDir, '.env');
|
||||
if (fs.existsSync(envPath)) {
|
||||
return envPath;
|
||||
}
|
||||
const parentDir = path.dirname(currentDir);
|
||||
if (parentDir === currentDir || !parentDir) {
|
||||
// check .env under home as fallback, again preferring gemini-specific .env
|
||||
const homeGeminiEnvPath = path.join(
|
||||
process.cwd(),
|
||||
GEMINI_CONFIG_DIR,
|
||||
'.env',
|
||||
);
|
||||
if (fs.existsSync(homeGeminiEnvPath)) {
|
||||
return homeGeminiEnvPath;
|
||||
}
|
||||
const homeEnvPath = path.join(homedir(), '.env');
|
||||
if (fs.existsSync(homeEnvPath)) {
|
||||
return homeEnvPath;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
currentDir = parentDir;
|
||||
}
|
||||
}
|
||||
146
packages/a2a-server/src/endpoints.test.ts
Normal file
146
packages/a2a-server/src/endpoints.test.ts
Normal file
@@ -0,0 +1,146 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeAll, afterAll, vi } from 'vitest';
|
||||
import request from 'supertest';
|
||||
import type express from 'express';
|
||||
import { createApp, updateCoderAgentCardUrl } from './agent.js';
|
||||
import * as fs from 'node:fs';
|
||||
import * as path from 'node:path';
|
||||
import * as os from 'node:os';
|
||||
import type { Server } from 'node:http';
|
||||
import type { TaskMetadata } from './types.js';
|
||||
import type { AddressInfo } from 'node:net';
|
||||
|
||||
// Mock the logger to avoid polluting test output
// Comment out to help debug
vi.mock('./logger.js', () => ({
  logger: { info: vi.fn(), warn: vi.fn(), error: vi.fn() },
}));

// Mock Task.create to avoid its complex setup
// The mock mirrors only the surface the server touches: config, geminiClient
// initialization, and getMetadata.
vi.mock('./task.js', () => {
  class MockTask {
    id: string;
    contextId: string;
    taskState = 'submitted';
    config = {
      getContentGeneratorConfig: vi
        .fn()
        .mockReturnValue({ model: 'gemini-pro' }),
    };
    geminiClient = {
      initialize: vi.fn().mockResolvedValue(undefined),
    };
    constructor(id: string, contextId: string) {
      this.id = id;
      this.contextId = contextId;
    }
    // Replaces the real async factory; resolves immediately with a MockTask.
    static create = vi
      .fn()
      .mockImplementation((id, contextId) =>
        Promise.resolve(new MockTask(id, contextId)),
      );
    getMetadata = vi.fn().mockImplementation(async () => ({
      id: this.id,
      contextId: this.contextId,
      taskState: this.taskState,
      model: 'gemini-pro',
      mcpServers: [],
      availableTools: [],
    }));
  }
  return { Task: MockTask };
});

describe('Agent Server Endpoints', () => {
  let app: express.Express;
  let server: Server;
  let testWorkspace: string;

  // Helper: POST /tasks with the given contextId against the temp workspace.
  const createTask = (contextId: string) =>
    request(app)
      .post('/tasks')
      .send({
        contextId,
        agentSettings: {
          kind: 'agent-settings',
          workspacePath: testWorkspace,
        },
      })
      .set('Content-Type', 'application/json');

  beforeAll(async () => {
    // Create a unique temporary directory for the workspace to avoid conflicts
    testWorkspace = fs.mkdtempSync(
      path.join(os.tmpdir(), 'gemini-agent-test-'),
    );
    app = await createApp();
    // Listen on an ephemeral port and sync the agent card URL to it.
    await new Promise<void>((resolve) => {
      server = app.listen(0, () => {
        const port = (server.address() as AddressInfo).port;
        updateCoderAgentCardUrl(port);
        resolve();
      });
    });
  });

  afterAll(
    () =>
      new Promise<void>((resolve, reject) => {
        server.close((err) => {
          if (err) return reject(err);

          // Best-effort cleanup of the temp workspace; failure is non-fatal.
          try {
            fs.rmSync(testWorkspace, { recursive: true, force: true });
          } catch (e) {
            console.warn(`Could not remove temp dir '${testWorkspace}':`, e);
          }
          resolve();
        });
      }),
  );

  it('should create a new task via POST /tasks', async () => {
    const response = await createTask('test-context');
    expect(response.status).toBe(201);
    expect(response.body).toBeTypeOf('string'); // Should return the task ID
  }, 7000);

  it('should get metadata for a specific task via GET /tasks/:taskId/metadata', async () => {
    const createResponse = await createTask('test-context-2');
    const taskId = createResponse.body;
    const response = await request(app).get(`/tasks/${taskId}/metadata`);
    expect(response.status).toBe(200);
    expect(response.body.metadata.id).toBe(taskId);
  }, 6000);

  it('should get metadata for all tasks via GET /tasks/metadata', async () => {
    const createResponse = await createTask('test-context-3');
    const taskId = createResponse.body;
    const response = await request(app).get('/tasks/metadata');
    expect(response.status).toBe(200);
    expect(Array.isArray(response.body)).toBe(true);
    expect(response.body.length).toBeGreaterThan(0);
    const taskMetadata = response.body.find(
      (m: TaskMetadata) => m.id === taskId,
    );
    expect(taskMetadata).toBeDefined();
  });

  it('should return 404 for a non-existent task', async () => {
    const response = await request(app).get('/tasks/fake-task/metadata');
    expect(response.status).toBe(404);
  });

  it('should return agent metadata via GET /.well-known/agent-card.json', async () => {
    const response = await request(app).get('/.well-known/agent-card.json');
    const port = (server.address() as AddressInfo).port;
    expect(response.status).toBe(200);
    expect(response.body.name).toBe('Gemini SDLC Agent');
    expect(response.body.url).toBe(`http://localhost:${port}/`);
  });
});
|
||||
118
packages/a2a-server/src/extension.ts
Normal file
118
packages/a2a-server/src/extension.ts
Normal file
@@ -0,0 +1,118 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
// Copied exactly from packages/cli/src/config/extension.ts, last PR #1026
|
||||
|
||||
import type { MCPServerConfig } from '@google/gemini-cli-core';
|
||||
import * as fs from 'node:fs';
|
||||
import * as path from 'node:path';
|
||||
import * as os from 'node:os';
|
||||
import { logger } from './logger.js';
|
||||
|
||||
// Location of extension definitions, relative to a workspace or home directory.
export const EXTENSIONS_DIRECTORY_NAME = path.join('.gemini', 'extensions');
// Per-extension config file expected inside each extension directory.
export const EXTENSIONS_CONFIG_FILENAME = 'gemini-extension.json';

/** A loaded extension: its parsed config plus resolved context file paths. */
export interface Extension {
  config: ExtensionConfig;
  // Absolute paths of context files (e.g. GEMINI.md) that exist on disk.
  contextFiles: string[];
}

/** Shape of a gemini-extension.json config file. */
export interface ExtensionConfig {
  name: string;
  version: string;
  // Optional MCP servers contributed by the extension.
  mcpServers?: Record<string, MCPServerConfig>;
  // Context file name(s) relative to the extension dir; defaults to GEMINI.md.
  contextFileName?: string | string[];
}
|
||||
|
||||
export function loadExtensions(workspaceDir: string): Extension[] {
|
||||
const allExtensions = [
|
||||
...loadExtensionsFromDir(workspaceDir),
|
||||
...loadExtensionsFromDir(os.homedir()),
|
||||
];
|
||||
|
||||
const uniqueExtensions: Extension[] = [];
|
||||
const seenNames = new Set<string>();
|
||||
for (const extension of allExtensions) {
|
||||
if (!seenNames.has(extension.config.name)) {
|
||||
logger.info(
|
||||
`Loading extension: ${extension.config.name} (version: ${extension.config.version})`,
|
||||
);
|
||||
uniqueExtensions.push(extension);
|
||||
seenNames.add(extension.config.name);
|
||||
}
|
||||
}
|
||||
|
||||
return uniqueExtensions;
|
||||
}
|
||||
|
||||
function loadExtensionsFromDir(dir: string): Extension[] {
|
||||
const extensionsDir = path.join(dir, EXTENSIONS_DIRECTORY_NAME);
|
||||
if (!fs.existsSync(extensionsDir)) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const extensions: Extension[] = [];
|
||||
for (const subdir of fs.readdirSync(extensionsDir)) {
|
||||
const extensionDir = path.join(extensionsDir, subdir);
|
||||
|
||||
const extension = loadExtension(extensionDir);
|
||||
if (extension != null) {
|
||||
extensions.push(extension);
|
||||
}
|
||||
}
|
||||
return extensions;
|
||||
}
|
||||
|
||||
function loadExtension(extensionDir: string): Extension | null {
|
||||
if (!fs.statSync(extensionDir).isDirectory()) {
|
||||
logger.error(
|
||||
`Warning: unexpected file ${extensionDir} in extensions directory.`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
const configFilePath = path.join(extensionDir, EXTENSIONS_CONFIG_FILENAME);
|
||||
if (!fs.existsSync(configFilePath)) {
|
||||
logger.error(
|
||||
`Warning: extension directory ${extensionDir} does not contain a config file ${configFilePath}.`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
try {
|
||||
const configContent = fs.readFileSync(configFilePath, 'utf-8');
|
||||
const config = JSON.parse(configContent) as ExtensionConfig;
|
||||
if (!config.name || !config.version) {
|
||||
logger.error(
|
||||
`Invalid extension config in ${configFilePath}: missing name or version.`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
||||
const contextFiles = getContextFileNames(config)
|
||||
.map((contextFileName) => path.join(extensionDir, contextFileName))
|
||||
.filter((contextFilePath) => fs.existsSync(contextFilePath));
|
||||
|
||||
return {
|
||||
config,
|
||||
contextFiles,
|
||||
};
|
||||
} catch (e) {
|
||||
logger.error(
|
||||
`Warning: error parsing extension config in ${configFilePath}: ${e}`,
|
||||
);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function getContextFileNames(config: ExtensionConfig): string[] {
|
||||
if (!config.contextFileName) {
|
||||
return ['GEMINI.md'];
|
||||
} else if (!Array.isArray(config.contextFileName)) {
|
||||
return [config.contextFileName];
|
||||
}
|
||||
return config.contextFileName;
|
||||
}
|
||||
340
packages/a2a-server/src/gcs.test.ts
Normal file
340
packages/a2a-server/src/gcs.test.ts
Normal file
@@ -0,0 +1,340 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { Storage } from '@google-cloud/storage';
|
||||
import * as fse from 'fs-extra';
|
||||
import { promises as fsPromises, createReadStream } from 'node:fs';
|
||||
import * as tar from 'tar';
|
||||
import { gzipSync, gunzipSync } from 'node:zlib';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import type { Task as SDKTask } from '@a2a-js/sdk';
|
||||
import type { TaskStore } from '@a2a-js/sdk/server';
|
||||
import type { Mocked, MockedClass, Mock } from 'vitest';
|
||||
import { describe, it, expect, beforeEach, vi } from 'vitest';
|
||||
|
||||
import { GCSTaskStore, NoOpTaskStore } from './gcs.js';
|
||||
import { logger } from './logger.js';
|
||||
import * as configModule from './config.js';
|
||||
import * as metadataModule from './metadata_types.js';
|
||||
|
||||
// Mock dependencies
vi.mock('@google-cloud/storage');
vi.mock('fs-extra', () => ({
  pathExists: vi.fn(),
  readdir: vi.fn(),
  remove: vi.fn(),
  ensureDir: vi.fn(),
}));
// node:fs is only partially mocked: the real module is kept, with just the
// pieces the store touches (promises.readdir, createReadStream) stubbed out.
vi.mock('node:fs', async () => {
  const actual = await vi.importActual<typeof import('node:fs')>('node:fs');
  return {
    ...actual,
    promises: {
      ...actual.promises,
      readdir: vi.fn(),
    },
    createReadStream: vi.fn(),
  };
});
vi.mock('tar');
vi.mock('zlib');
vi.mock('uuid');
vi.mock('./logger', () => ({
  logger: {
    info: vi.fn(),
    warn: vi.fn(),
    error: vi.fn(),
    debug: vi.fn(),
  },
}));
vi.mock('./config');
vi.mock('./metadata_types');
vi.mock('node:stream/promises', () => ({
  pipeline: vi.fn(),
}));

// Typed views over the auto-mocked modules.
const mockStorage = Storage as MockedClass<typeof Storage>;
const mockFse = fse as Mocked<typeof fse>;
const mockCreateReadStream = createReadStream as Mock;
const mockTar = tar as Mocked<typeof tar>;
const mockGzipSync = gzipSync as Mock;
const mockGunzipSync = gunzipSync as Mock;
const mockUuidv4 = uuidv4 as Mock;
const mockSetTargetDir = configModule.setTargetDir as Mock;
const mockGetPersistedState = metadataModule.getPersistedState as Mock;
// metadata_types is auto-mocked (exports become undefined), so fall back to
// the literal key value used by the real module.
const METADATA_KEY = metadataModule.METADATA_KEY || '__persistedState';

// Minimal structural stand-ins for the GCS client objects the store uses.
type MockWriteStream = {
  on: Mock<
    (event: string, cb: (error?: Error | null) => void) => MockWriteStream
  >;
  destroy: Mock<() => void>;
  destroyed: boolean;
};

type MockFile = {
  save: Mock<(data: Buffer | string) => Promise<void>>;
  download: Mock<() => Promise<[Buffer]>>;
  exists: Mock<() => Promise<[boolean]>>;
  createWriteStream: Mock<() => MockWriteStream>;
};

type MockBucket = {
  exists: Mock<() => Promise<[boolean]>>;
  file: Mock<(path: string) => MockFile>;
  name: string;
};

type MockStorageInstance = {
  bucket: Mock<(name: string) => MockBucket>;
  getBuckets: Mock<() => Promise<[Array<{ name: string }>]>>;
  createBucket: Mock<(name: string) => Promise<[MockBucket]>>;
};

describe('GCSTaskStore', () => {
  let bucketName: string;
  let mockBucket: MockBucket;
  let mockFile: MockFile;
  let mockWriteStream: MockWriteStream;
  let mockStorageInstance: MockStorageInstance;

  beforeEach(() => {
    vi.clearAllMocks();
    bucketName = 'test-bucket';

    mockWriteStream = {
      on: vi.fn((event, cb) => {
        if (event === 'finish') setTimeout(cb, 0); // Simulate async finish
        return mockWriteStream;
      }),
      destroy: vi.fn(),
      destroyed: false,
    };

    mockFile = {
      save: vi.fn().mockResolvedValue(undefined),
      download: vi.fn().mockResolvedValue([Buffer.from('')]),
      exists: vi.fn().mockResolvedValue([true]),
      createWriteStream: vi.fn().mockReturnValue(mockWriteStream),
    };

    mockBucket = {
      exists: vi.fn().mockResolvedValue([true]),
      file: vi.fn().mockReturnValue(mockFile),
      name: bucketName,
    };

    mockStorageInstance = {
      bucket: vi.fn().mockReturnValue(mockBucket),
      getBuckets: vi.fn().mockResolvedValue([[{ name: bucketName }]]),
      createBucket: vi.fn().mockResolvedValue([mockBucket]),
    };
    mockStorage.mockReturnValue(mockStorageInstance as unknown as Storage);

    // Happy-path defaults; individual tests override these as needed.
    mockUuidv4.mockReturnValue('test-uuid');
    mockSetTargetDir.mockReturnValue('/tmp/workdir');
    mockGetPersistedState.mockReturnValue({
      _agentSettings: {},
      _taskState: 'submitted',
    });
    (fse.pathExists as Mock).mockResolvedValue(true);
    (fsPromises.readdir as Mock).mockResolvedValue(['file1.txt']);
    mockTar.c.mockResolvedValue(undefined);
    mockTar.x.mockResolvedValue(undefined);
    mockFse.remove.mockResolvedValue(undefined);
    mockFse.ensureDir.mockResolvedValue(undefined);
    mockGzipSync.mockReturnValue(Buffer.from('compressed'));
    mockGunzipSync.mockReturnValue(Buffer.from('{}'));
    mockCreateReadStream.mockReturnValue({ on: vi.fn(), pipe: vi.fn() });
  });

  describe('Constructor & Initialization', () => {
    it('should initialize and check bucket existence', async () => {
      const store = new GCSTaskStore(bucketName);
      await store['ensureBucketInitialized']();
      expect(mockStorage).toHaveBeenCalledTimes(1);
      expect(mockStorageInstance.getBuckets).toHaveBeenCalled();
      expect(logger.info).toHaveBeenCalledWith(
        expect.stringContaining('Bucket test-bucket exists'),
      );
    });

    it('should create bucket if it does not exist', async () => {
      mockStorageInstance.getBuckets.mockResolvedValue([[]]);
      const store = new GCSTaskStore(bucketName);
      await store['ensureBucketInitialized']();
      expect(mockStorageInstance.createBucket).toHaveBeenCalledWith(bucketName);
      expect(logger.info).toHaveBeenCalledWith(
        expect.stringContaining('Bucket test-bucket created successfully'),
      );
    });

    it('should throw if bucket creation fails', async () => {
      mockStorageInstance.getBuckets.mockResolvedValue([[]]);
      mockStorageInstance.createBucket.mockRejectedValue(
        new Error('Create failed'),
      );
      const store = new GCSTaskStore(bucketName);
      // toThrow does a substring match; the store re-wraps the creation error.
      await expect(store['ensureBucketInitialized']()).rejects.toThrow(
        'Failed to create GCS bucket test-bucket: Error: Create failed',
      );
    });
  });

  describe('save', () => {
    const mockTask: SDKTask = {
      id: 'task1',
      contextId: 'ctx1',
      kind: 'task',
      status: { state: 'working' },
      metadata: {},
    };

    it('should save metadata and workspace', async () => {
      const store = new GCSTaskStore(bucketName);
      await store.save(mockTask);

      expect(mockFile.save).toHaveBeenCalledTimes(1);
      expect(mockTar.c).toHaveBeenCalledTimes(1);
      expect(mockCreateReadStream).toHaveBeenCalledTimes(1);
      expect(mockFse.remove).toHaveBeenCalledTimes(1);
      expect(logger.info).toHaveBeenCalledWith(
        expect.stringContaining('metadata saved to GCS'),
      );
      expect(logger.info).toHaveBeenCalledWith(
        expect.stringContaining('workspace saved to GCS'),
      );
    });

    it('should handle tar creation failure', async () => {
      // pathExists reports false only for the temp archive, simulating tar.c
      // completing without producing its output file.
      mockFse.pathExists.mockImplementation(
        async (path) =>
          !path.toString().includes('task-task1-workspace-test-uuid.tar.gz'),
      );
      const store = new GCSTaskStore(bucketName);
      await expect(store.save(mockTask)).rejects.toThrow(
        'tar.c command failed to create',
      );
    });
  });

  describe('load', () => {
    it('should load task metadata and workspace', async () => {
      mockGunzipSync.mockReturnValue(
        Buffer.from(
          JSON.stringify({
            [METADATA_KEY]: { _agentSettings: {}, _taskState: 'submitted' },
            _contextId: 'ctx1',
          }),
        ),
      );
      mockFile.download.mockResolvedValue([Buffer.from('compressed metadata')]);
      mockFile.download.mockResolvedValueOnce([
        Buffer.from('compressed metadata'),
      ]);
      // Hand back distinct file mocks per object path so metadata and
      // workspace downloads can be told apart.
      mockBucket.file = vi.fn((path) => {
        const newMockFile = { ...mockFile };
        if (path.includes('metadata')) {
          newMockFile.download = vi
            .fn()
            .mockResolvedValue([Buffer.from('compressed metadata')]);
          newMockFile.exists = vi.fn().mockResolvedValue([true]);
        } else {
          newMockFile.download = vi
            .fn()
            .mockResolvedValue([Buffer.from('compressed workspace')]);
          newMockFile.exists = vi.fn().mockResolvedValue([true]);
        }
        return newMockFile;
      });

      const store = new GCSTaskStore(bucketName);
      const task = await store.load('task1');

      expect(task).toBeDefined();
      expect(task?.id).toBe('task1');
      expect(mockBucket.file).toHaveBeenCalledWith(
        'tasks/task1/metadata.tar.gz',
      );
      expect(mockBucket.file).toHaveBeenCalledWith(
        'tasks/task1/workspace.tar.gz',
      );
      expect(mockTar.x).toHaveBeenCalledTimes(1);
      expect(mockFse.remove).toHaveBeenCalledTimes(1);
    });

    it('should return undefined if metadata not found', async () => {
      mockFile.exists.mockResolvedValue([false]);
      const store = new GCSTaskStore(bucketName);
      const task = await store.load('task1');
      expect(task).toBeUndefined();
      expect(mockBucket.file).toHaveBeenCalledWith(
        'tasks/task1/metadata.tar.gz',
      );
    });

    it('should load metadata even if workspace not found', async () => {
      mockGunzipSync.mockReturnValue(
        Buffer.from(
          JSON.stringify({
            [METADATA_KEY]: { _agentSettings: {}, _taskState: 'submitted' },
            _contextId: 'ctx1',
          }),
        ),
      );

      mockBucket.file = vi.fn((path) => {
        const newMockFile = { ...mockFile };
        if (path.includes('workspace.tar.gz')) {
          newMockFile.exists = vi.fn().mockResolvedValue([false]);
        } else {
          newMockFile.exists = vi.fn().mockResolvedValue([true]);
          newMockFile.download = vi
            .fn()
            .mockResolvedValue([Buffer.from('compressed metadata')]);
        }
        return newMockFile;
      });

      const store = new GCSTaskStore(bucketName);
      const task = await store.load('task1');

      expect(task).toBeDefined();
      expect(mockTar.x).not.toHaveBeenCalled();
      expect(logger.info).toHaveBeenCalledWith(
        expect.stringContaining('workspace archive not found'),
      );
    });
  });
});

describe('NoOpTaskStore', () => {
  let realStore: TaskStore;
  let noOpStore: NoOpTaskStore;

  beforeEach(() => {
    // Create a mock of the real store to delegate to
    realStore = {
      save: vi.fn(),
      load: vi.fn().mockResolvedValue({ id: 'task-123' } as SDKTask),
    };
    noOpStore = new NoOpTaskStore(realStore);
  });

  it("should not call the real store's save method", async () => {
    const mockTask: SDKTask = { id: 'test-task' } as SDKTask;
    await noOpStore.save(mockTask);
    expect(realStore.save).not.toHaveBeenCalled();
  });

  it('should delegate the load method to the real store', async () => {
    const taskId = 'task-123';
    const result = await noOpStore.load(taskId);
    expect(realStore.load).toHaveBeenCalledWith(taskId);
    expect(result).toBeDefined();
    expect(result?.id).toBe(taskId);
  });
});
|
||||
308
packages/a2a-server/src/gcs.ts
Normal file
308
packages/a2a-server/src/gcs.ts
Normal file
@@ -0,0 +1,308 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import { Storage } from '@google-cloud/storage';
|
||||
import { gzipSync, gunzipSync } from 'node:zlib';
|
||||
import * as tar from 'tar';
|
||||
import * as fse from 'fs-extra';
|
||||
import { promises as fsPromises, createReadStream } from 'node:fs';
|
||||
import { tmpdir } from 'node:os';
|
||||
import { join } from 'node:path';
|
||||
import type { Task as SDKTask } from '@a2a-js/sdk';
|
||||
import type { TaskStore } from '@a2a-js/sdk/server';
|
||||
import { logger } from './logger.js';
|
||||
import { setTargetDir } from './config.js';
|
||||
import {
|
||||
getPersistedState,
|
||||
type PersistedTaskMetadata,
|
||||
} from './metadata_types.js';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
type ObjectType = 'metadata' | 'workspace';
|
||||
|
||||
const getTmpArchiveFilename = (taskId: string): string =>
|
||||
`task-${taskId}-workspace-${uuidv4()}.tar.gz`;
|
||||
|
||||
/**
 * TaskStore backed by a Google Cloud Storage bucket.
 *
 * Each task is persisted as two objects under `tasks/<taskId>/`:
 * gzipped JSON metadata, and a gzipped tar of the working directory.
 */
export class GCSTaskStore implements TaskStore {
  private storage: Storage;
  private bucketName: string;
  // Resolves once the bucket is known to exist (created if missing);
  // every public method awaits this before touching GCS.
  private bucketInitialized: Promise<void>;

  constructor(bucketName: string) {
    if (!bucketName) {
      throw new Error('GCS bucket name is required.');
    }
    this.storage = new Storage();
    this.bucketName = bucketName;
    logger.info(`GCSTaskStore initializing with bucket: ${this.bucketName}`);
    // Prerequisites: user account or service account must have storage admin IAM role
    // and the bucket name must be unique.
    this.bucketInitialized = this.initializeBucket();
  }

  /** Ensures the bucket exists, creating it when absent. */
  private async initializeBucket(): Promise<void> {
    try {
      const [buckets] = await this.storage.getBuckets();
      const exists = buckets.some((bucket) => bucket.name === this.bucketName);

      if (!exists) {
        logger.info(
          `Bucket ${this.bucketName} does not exist in the list. Attempting to create...`,
        );
        try {
          await this.storage.createBucket(this.bucketName);
          logger.info(`Bucket ${this.bucketName} created successfully.`);
        } catch (createError) {
          // NOTE(review): failures here are logged at info level — consider
          // logger.error.
          logger.info(
            `Failed to create bucket ${this.bucketName}: ${createError}`,
          );
          // This error is re-caught by the outer catch below and wrapped a
          // second time in the "Failed to initialize" message.
          throw new Error(
            `Failed to create GCS bucket ${this.bucketName}: ${createError}`,
          );
        }
      } else {
        logger.info(`Bucket ${this.bucketName} exists.`);
      }
    } catch (error) {
      logger.info(
        `Error during bucket initialization for ${this.bucketName}: ${error}`,
      );
      throw new Error(
        `Failed to initialize GCS bucket ${this.bucketName}: ${error}`,
      );
    }
  }

  // Awaits (and re-raises any failure of) the constructor-time bucket setup.
  private async ensureBucketInitialized(): Promise<void> {
    await this.bucketInitialized;
  }

  // Object key layout: tasks/<taskId>/metadata.tar.gz | workspace.tar.gz.
  private getObjectPath(taskId: string, type: ObjectType): string {
    return `tasks/${taskId}/${type}.tar.gz`;
  }

  /**
   * Persists the task's metadata (gzipped JSON) and the contents of the
   * current working directory (gzipped tar) to GCS.
   *
   * Throws when the task has no persisted state, or when any GCS/tar step
   * fails. The temp archive is removed in all cases.
   */
  async save(task: SDKTask): Promise<void> {
    await this.ensureBucketInitialized();
    const taskId = task.id;
    const persistedState = getPersistedState(
      task.metadata as PersistedTaskMetadata,
    );

    if (!persistedState) {
      throw new Error(`Task ${taskId} is missing persisted state in metadata.`);
    }
    // The workspace archived is whatever process.cwd() currently is —
    // assumes the target dir was set before save is called (see load()).
    const workDir = process.cwd();

    const metadataObjectPath = this.getObjectPath(taskId, 'metadata');
    const workspaceObjectPath = this.getObjectPath(taskId, 'workspace');

    const dataToStore = task.metadata;

    try {
      // 1) Metadata: gzip the JSON and upload in one shot.
      const jsonString = JSON.stringify(dataToStore);
      const compressedMetadata = gzipSync(Buffer.from(jsonString));
      const metadataFile = this.storage
        .bucket(this.bucketName)
        .file(metadataObjectPath);
      await metadataFile.save(compressedMetadata, {
        contentType: 'application/gzip',
      });
      logger.info(
        `Task ${taskId} metadata saved to GCS: gs://${this.bucketName}/${metadataObjectPath}`,
      );

      // 2) Workspace: tar the work dir to a temp file, then stream it up.
      if (await fse.pathExists(workDir)) {
        const entries = await fsPromises.readdir(workDir);
        if (entries.length > 0) {
          const tmpArchiveFile = join(tmpdir(), getTmpArchiveFilename(taskId));
          try {
            await tar.c(
              {
                gzip: true,
                file: tmpArchiveFile,
                cwd: workDir,
                portable: true,
              },
              entries,
            );

            // tar.c can resolve without producing output; verify explicitly.
            if (!(await fse.pathExists(tmpArchiveFile))) {
              throw new Error(
                `tar.c command failed to create ${tmpArchiveFile}`,
              );
            }

            const workspaceFile = this.storage
              .bucket(this.bucketName)
              .file(workspaceObjectPath);
            const sourceStream = createReadStream(tmpArchiveFile);
            const destStream = workspaceFile.createWriteStream({
              contentType: 'application/gzip',
              resumable: true,
            });

            // Manual pipe with explicit error/finish wiring on both ends.
            await new Promise<void>((resolve, reject) => {
              sourceStream.on('error', (err) => {
                logger.error(
                  `Error in source stream for ${tmpArchiveFile}:`,
                  err,
                );
                // Attempt to close destStream if source fails
                if (!destStream.destroyed) {
                  destStream.destroy(err);
                }
                reject(err);
              });

              destStream.on('error', (err) => {
                logger.error(
                  `Error in GCS dest stream for ${workspaceObjectPath}:`,
                  err,
                );
                reject(err);
              });

              destStream.on('finish', () => {
                logger.info(
                  `GCS destStream finished for ${workspaceObjectPath}`,
                );
                resolve();
              });

              logger.info(
                `Piping ${tmpArchiveFile} to GCS object ${workspaceObjectPath}`,
              );
              sourceStream.pipe(destStream);
            });
            logger.info(
              `Task ${taskId} workspace saved to GCS: gs://${this.bucketName}/${workspaceObjectPath}`,
            );
          } catch (error) {
            logger.error(
              `Error during workspace save process for ${taskId}:`,
              error,
            );
            throw error;
          } finally {
            // Always remove the temp archive; cleanup failures are logged
            // but never mask the original error.
            logger.info(`Cleaning up temporary file: ${tmpArchiveFile}`);
            try {
              if (await fse.pathExists(tmpArchiveFile)) {
                await fse.remove(tmpArchiveFile);
                logger.info(
                  `Successfully removed temporary file: ${tmpArchiveFile}`,
                );
              } else {
                logger.warn(
                  `Temporary file not found for cleanup: ${tmpArchiveFile}`,
                );
              }
            } catch (removeError) {
              logger.error(
                `Error removing temporary file ${tmpArchiveFile}:`,
                removeError,
              );
            }
          }
        } else {
          logger.info(
            `Workspace directory ${workDir} is empty, skipping workspace save for task ${taskId}.`,
          );
        }
      } else {
        logger.info(
          `Workspace directory ${workDir} not found, skipping workspace save for task ${taskId}.`,
        );
      }
    } catch (error) {
      logger.error(`Failed to save task ${taskId} to GCS:`, error);
      throw error;
    }
  }

  /**
   * Restores a task from GCS: downloads and parses the metadata, then (if
   * present) extracts the workspace archive into the task's target dir.
   *
   * Returns undefined when no metadata object exists for the task; throws on
   * any other failure (missing persisted state, download/extract errors).
   */
  async load(taskId: string): Promise<SDKTask | undefined> {
    await this.ensureBucketInitialized();
    const metadataObjectPath = this.getObjectPath(taskId, 'metadata');
    const workspaceObjectPath = this.getObjectPath(taskId, 'workspace');

    try {
      const metadataFile = this.storage
        .bucket(this.bucketName)
        .file(metadataObjectPath);
      const [metadataExists] = await metadataFile.exists();
      if (!metadataExists) {
        logger.info(`Task ${taskId} metadata not found in GCS.`);
        return undefined;
      }
      const [compressedMetadata] = await metadataFile.download();
      const jsonData = gunzipSync(compressedMetadata).toString();
      const loadedMetadata = JSON.parse(jsonData);
      logger.info(`Task ${taskId} metadata loaded from GCS.`);

      const persistedState = getPersistedState(loadedMetadata);
      if (!persistedState) {
        throw new Error(
          `Loaded metadata for task ${taskId} is missing internal persisted state.`,
        );
      }
      const agentSettings = persistedState._agentSettings;

      // Resolve (and create) the working directory for this task before
      // extracting the workspace into it.
      const workDir = setTargetDir(agentSettings);
      await fse.ensureDir(workDir);
      const workspaceFile = this.storage
        .bucket(this.bucketName)
        .file(workspaceObjectPath);
      const [workspaceExists] = await workspaceFile.exists();
      if (workspaceExists) {
        const tmpArchiveFile = join(tmpdir(), getTmpArchiveFilename(taskId));
        try {
          await workspaceFile.download({ destination: tmpArchiveFile });
          await tar.x({ file: tmpArchiveFile, cwd: workDir });
          logger.info(
            `Task ${taskId} workspace restored from GCS to ${workDir}`,
          );
        } finally {
          if (await fse.pathExists(tmpArchiveFile)) {
            await fse.remove(tmpArchiveFile);
          }
        }
      } else {
        // A missing workspace is not an error: the task may have been saved
        // with an empty working directory.
        logger.info(`Task ${taskId} workspace archive not found in GCS.`);
      }

      // Rebuild an SDK task. History/artifacts are not persisted; contextId
      // falls back to a fresh uuid when the stored metadata lacks one.
      return {
        id: taskId,
        contextId: loadedMetadata._contextId || uuidv4(),
        kind: 'task',
        status: {
          state: persistedState._taskState,
          timestamp: new Date().toISOString(),
        },
        metadata: loadedMetadata,
        history: [],
        artifacts: [],
      };
    } catch (error) {
      logger.error(`Failed to load task ${taskId} from GCS:`, error);
      throw error;
    }
  }
}
|
||||
|
||||
export class NoOpTaskStore implements TaskStore {
|
||||
constructor(private realStore: TaskStore) {}
|
||||
|
||||
async save(task: SDKTask): Promise<void> {
|
||||
logger.info(`[NoOpTaskStore] save called for task ${task.id} - IGNORED`);
|
||||
return Promise.resolve();
|
||||
}
|
||||
|
||||
async load(taskId: string): Promise<SDKTask | undefined> {
|
||||
logger.info(
|
||||
`[NoOpTaskStore] load called for task ${taskId}, delegating to real store.`,
|
||||
);
|
||||
return this.realStore.load(taskId);
|
||||
}
|
||||
}
|
||||
8
packages/a2a-server/src/index.ts
Normal file
8
packages/a2a-server/src/index.ts
Normal file
@@ -0,0 +1,8 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
export * from './agent.js';
|
||||
export * from './types.js';
|
||||
28
packages/a2a-server/src/logger.ts
Normal file
28
packages/a2a-server/src/logger.ts
Normal file
@@ -0,0 +1,28 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import winston from 'winston';
|
||||
|
||||
const logger = winston.createLogger({
|
||||
level: 'info',
|
||||
format: winston.format.combine(
|
||||
// First, add a timestamp to the log info object
|
||||
winston.format.timestamp({
|
||||
format: 'YYYY-MM-DD HH:mm:ss.SSS A', // Custom timestamp format
|
||||
}),
|
||||
// Here we define the custom output format
|
||||
winston.format.printf((info) => {
|
||||
const { level, timestamp, message, ...rest } = info;
|
||||
return (
|
||||
`[${level.toUpperCase()}] ${timestamp} -- ${message}` +
|
||||
`${Object.keys(rest).length > 0 ? `\n${JSON.stringify(rest, null, 2)}` : ''}`
|
||||
); // Only print ...rest if present
|
||||
}),
|
||||
),
|
||||
transports: [new winston.transports.Console()],
|
||||
});
|
||||
|
||||
export { logger };
|
||||
33
packages/a2a-server/src/metadata_types.ts
Normal file
33
packages/a2a-server/src/metadata_types.ts
Normal file
@@ -0,0 +1,33 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type { AgentSettings } from './types.js';
|
||||
import type { TaskState } from '@a2a-js/sdk';
|
||||
|
||||
/** Internal state persisted alongside a task's SDK metadata. */
export interface PersistedStateMetadata {
  // Agent settings captured at save time; used to re-resolve the work dir.
  _agentSettings: AgentSettings;
  // Task lifecycle state to restore into the SDK task's status.
  _taskState: TaskState;
}

/** Arbitrary task metadata bag; persisted state lives under METADATA_KEY. */
export type PersistedTaskMetadata = { [k: string]: unknown };

/** Reserved metadata key under which the persisted state is stored. */
export const METADATA_KEY = '__persistedState';

/** Reads the persisted state out of a metadata bag, if present. */
export function getPersistedState(
  metadata: PersistedTaskMetadata,
): PersistedStateMetadata | undefined {
  return metadata?.[METADATA_KEY] as PersistedStateMetadata | undefined;
}

/** Returns a copy of `metadata` with `state` written under METADATA_KEY. */
export function setPersistedState(
  metadata: PersistedTaskMetadata,
  state: PersistedStateMetadata,
): PersistedTaskMetadata {
  return {
    ...metadata,
    [METADATA_KEY]: state,
  };
}
|
||||
33
packages/a2a-server/src/server.ts
Normal file
33
packages/a2a-server/src/server.ts
Normal file
@@ -0,0 +1,33 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import * as url from 'node:url';
|
||||
import * as path from 'node:path';
|
||||
|
||||
import { logger } from './logger.js';
|
||||
import { main } from './agent.js';
|
||||
|
||||
// Check if the module is the main script being run. path.resolve() creates a
// canonical, absolute path, which avoids cross-platform issues.
const isMainModule =
  path.resolve(process.argv[1]) ===
  path.resolve(url.fileURLToPath(import.meta.url));

// Fail fast on any uncaught exception rather than continuing in an unknown
// state; a supervisor is expected to restart the process.
process.on('uncaughtException', (error) => {
  logger.error('Unhandled exception:', error);
  process.exit(1);
});

// Start the agent only when executed directly (not imported) and not during
// test runs. The 'file:' prefix check guards fileURLToPath against non-file
// module URLs.
if (
  import.meta.url.startsWith('file:') &&
  isMainModule &&
  process.env['NODE_ENV'] !== 'test'
) {
  main().catch((error) => {
    logger.error('[CoreAgent] Unhandled error in main:', error);
    process.exit(1);
  });
}
|
||||
154
packages/a2a-server/src/settings.ts
Normal file
154
packages/a2a-server/src/settings.ts
Normal file
@@ -0,0 +1,154 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import * as fs from 'node:fs';
|
||||
import * as path from 'node:path';
|
||||
import { homedir } from 'node:os';
|
||||
|
||||
import type { MCPServerConfig } from '@google/gemini-cli-core';
|
||||
import {
|
||||
getErrorMessage,
|
||||
type TelemetrySettings,
|
||||
} from '@google/gemini-cli-core';
|
||||
import stripJsonComments from 'strip-json-comments';
|
||||
|
||||
// Settings files live in <root>/.gemini/settings.json (user and workspace).
export const SETTINGS_DIRECTORY_NAME = '.gemini';
export const USER_SETTINGS_DIR = path.join(homedir(), SETTINGS_DIRECTORY_NAME);
export const USER_SETTINGS_PATH = path.join(USER_SETTINGS_DIR, 'settings.json');

// Reconcile with https://github.com/google-gemini/gemini-cli/blob/b09bc6656080d4d12e1d06734aae2ec33af5c1ed/packages/cli/src/config/settings.ts#L53
/** Subset of Gemini CLI settings recognized by the a2a-server. */
export interface Settings {
  mcpServers?: Record<string, MCPServerConfig>;
  coreTools?: string[];
  excludeTools?: string[];
  telemetry?: TelemetrySettings;
  showMemoryUsage?: boolean;
  checkpointing?: CheckpointingSettings;

  // Git-aware file filtering settings
  fileFiltering?: {
    respectGitIgnore?: boolean;
    enableRecursiveFileSearch?: boolean;
  };
}

/** A failure encountered while reading or parsing one settings file. */
export interface SettingsError {
  message: string;
  path: string;
}

/** Controls session checkpointing. */
export interface CheckpointingSettings {
  enabled?: boolean;
}
|
||||
|
||||
/**
|
||||
* Loads settings from user and workspace directories.
|
||||
* Project settings override user settings.
|
||||
*
|
||||
* How is it different to gemini-cli/cli: Returns already merged settings rather
|
||||
* than `LoadedSettings` (unnecessary since we are not modifying users
|
||||
* settings.json).
|
||||
*/
|
||||
export function loadSettings(workspaceDir: string): Settings {
|
||||
let userSettings: Settings = {};
|
||||
let workspaceSettings: Settings = {};
|
||||
const settingsErrors: SettingsError[] = [];
|
||||
|
||||
// Load user settings
|
||||
try {
|
||||
if (fs.existsSync(USER_SETTINGS_PATH)) {
|
||||
const userContent = fs.readFileSync(USER_SETTINGS_PATH, 'utf-8');
|
||||
const parsedUserSettings = JSON.parse(
|
||||
stripJsonComments(userContent),
|
||||
) as Settings;
|
||||
userSettings = resolveEnvVarsInObject(parsedUserSettings);
|
||||
}
|
||||
} catch (error: unknown) {
|
||||
settingsErrors.push({
|
||||
message: getErrorMessage(error),
|
||||
path: USER_SETTINGS_PATH,
|
||||
});
|
||||
}
|
||||
|
||||
const workspaceSettingsPath = path.join(
|
||||
workspaceDir,
|
||||
SETTINGS_DIRECTORY_NAME,
|
||||
'settings.json',
|
||||
);
|
||||
|
||||
// Load workspace settings
|
||||
try {
|
||||
if (fs.existsSync(workspaceSettingsPath)) {
|
||||
const projectContent = fs.readFileSync(workspaceSettingsPath, 'utf-8');
|
||||
const parsedWorkspaceSettings = JSON.parse(
|
||||
stripJsonComments(projectContent),
|
||||
) as Settings;
|
||||
workspaceSettings = resolveEnvVarsInObject(parsedWorkspaceSettings);
|
||||
}
|
||||
} catch (error: unknown) {
|
||||
settingsErrors.push({
|
||||
message: getErrorMessage(error),
|
||||
path: workspaceSettingsPath,
|
||||
});
|
||||
}
|
||||
|
||||
if (settingsErrors.length > 0) {
|
||||
console.error('Errors loading settings:');
|
||||
for (const error of settingsErrors) {
|
||||
console.error(` Path: ${error.path}`);
|
||||
console.error(` Message: ${error.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
// If there are overlapping keys, the values of workspaceSettings will
|
||||
// override values from userSettings
|
||||
return {
|
||||
...userSettings,
|
||||
...workspaceSettings,
|
||||
};
|
||||
}
|
||||
|
||||
function resolveEnvVarsInString(value: string): string {
|
||||
const envVarRegex = /\$(?:(\w+)|{([^}]+)})/g; // Find $VAR_NAME or ${VAR_NAME}
|
||||
return value.replace(envVarRegex, (match, varName1, varName2) => {
|
||||
const varName = varName1 || varName2;
|
||||
if (process && process.env && typeof process.env[varName] === 'string') {
|
||||
return process.env[varName]!;
|
||||
}
|
||||
return match;
|
||||
});
|
||||
}
|
||||
|
||||
function resolveEnvVarsInObject<T>(obj: T): T {
|
||||
if (
|
||||
obj === null ||
|
||||
obj === undefined ||
|
||||
typeof obj === 'boolean' ||
|
||||
typeof obj === 'number'
|
||||
) {
|
||||
return obj;
|
||||
}
|
||||
|
||||
if (typeof obj === 'string') {
|
||||
return resolveEnvVarsInString(obj) as unknown as T;
|
||||
}
|
||||
|
||||
if (Array.isArray(obj)) {
|
||||
return obj.map((item) => resolveEnvVarsInObject(item)) as unknown as T;
|
||||
}
|
||||
|
||||
if (typeof obj === 'object') {
|
||||
const newObj = { ...obj } as T;
|
||||
for (const key in newObj) {
|
||||
if (Object.prototype.hasOwnProperty.call(newObj, key)) {
|
||||
newObj[key] = resolveEnvVarsInObject(newObj[key]);
|
||||
}
|
||||
}
|
||||
return newObj;
|
||||
}
|
||||
|
||||
return obj;
|
||||
}
|
||||
930
packages/a2a-server/src/task.ts
Normal file
930
packages/a2a-server/src/task.ts
Normal file
@@ -0,0 +1,930 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import {
|
||||
CoreToolScheduler,
|
||||
GeminiClient,
|
||||
GeminiEventType,
|
||||
ToolConfirmationOutcome,
|
||||
ApprovalMode,
|
||||
getAllMCPServerStatuses,
|
||||
MCPServerStatus,
|
||||
isNodeError,
|
||||
parseAndFormatApiError,
|
||||
} from '@google/gemini-cli-core';
|
||||
import type {
|
||||
ToolConfirmationPayload,
|
||||
CompletedToolCall,
|
||||
ToolCall,
|
||||
ToolCallRequestInfo,
|
||||
ServerGeminiErrorEvent,
|
||||
ServerGeminiStreamEvent,
|
||||
ToolCallConfirmationDetails,
|
||||
Config,
|
||||
UserTierId,
|
||||
} from '@google/gemini-cli-core';
|
||||
import type { RequestContext } from '@a2a-js/sdk/server';
|
||||
import { type ExecutionEventBus } from '@a2a-js/sdk/server';
|
||||
import type {
|
||||
TaskStatusUpdateEvent,
|
||||
TaskArtifactUpdateEvent,
|
||||
TaskState,
|
||||
Message,
|
||||
Part,
|
||||
Artifact,
|
||||
} from '@a2a-js/sdk';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import { logger } from './logger.js';
|
||||
import * as fs from 'node:fs';
|
||||
|
||||
import { CoderAgentEvent } from './types.js';
|
||||
import type {
|
||||
CoderAgentMessage,
|
||||
StateChange,
|
||||
ToolCallUpdate,
|
||||
TextContent,
|
||||
TaskMetadata,
|
||||
Thought,
|
||||
ThoughtSummary,
|
||||
} from './types.js';
|
||||
import type { PartUnion, Part as genAiPart } from '@google/genai';
|
||||
|
||||
export class Task {
  // A2A task id (unique per task).
  id: string;
  // Id of the conversation context this task belongs to.
  contextId: string;
  // Schedules and executes tool calls requested by the agent.
  scheduler: CoreToolScheduler;
  config: Config;
  geminiClient: GeminiClient;
  // Confirmation details keyed by toolCallId for calls awaiting user approval.
  pendingToolConfirmationDetails: Map<string, ToolCallConfirmationDetails>;
  // Current A2A task state (starts as 'submitted').
  taskState: TaskState;
  // Optional bus for publishing status/artifact updates to the A2A client.
  eventBus?: ExecutionEventBus;
  // Buffer of finished tool calls, drained by getAndClearCompletedTools().
  completedToolCalls: CompletedToolCall[];
  // After an inline edit, the edited tool call re-emits 'awaiting_approval'
  // before 'scheduled'; this flag suppresses a premature final update.
  skipFinalTrueAfterInlineEdit = false;

  // For tool waiting logic
  private pendingToolCalls: Map<string, string> = new Map(); //toolCallId --> status
  // Resolves when all pending tool calls reach a terminal state.
  private toolCompletionPromise?: Promise<void>;
  // Resolver/rejecter pair for toolCompletionPromise.
  private toolCompletionNotifier?: {
    resolve: () => void;
    reject: (reason?: Error) => void;
  };

  private constructor(
    id: string,
    contextId: string,
    config: Config,
    eventBus?: ExecutionEventBus,
  ) {
    this.id = id;
    this.contextId = contextId;
    this.config = config;
    this.scheduler = this.createScheduler();
    this.geminiClient = new GeminiClient(this.config);
    this.pendingToolConfirmationDetails = new Map();
    this.taskState = 'submitted';
    this.eventBus = eventBus;
    this.completedToolCalls = [];
    this._resetToolCompletionPromise();
    this.config.setFlashFallbackHandler(
      async (currentModel: string, fallbackModel: string): Promise<boolean> => {
        config.setModel(fallbackModel); // gemini-cli-core sets to DEFAULT_GEMINI_FLASH_MODEL
        // Switch model for future use but return false to stop current retry
        return false;
      },
    );
  }

  // Factory for Task instances; async to allow future asynchronous setup.
  static async create(
    id: string,
    contextId: string,
    config: Config,
    eventBus?: ExecutionEventBus,
  ): Promise<Task> {
    return new Task(id, contextId, config, eventBus);
  }
|
||||
|
||||
// Note: `getAllMCPServerStatuses` retrieves the status of all MCP servers for the entire
|
||||
// process. This is not scoped to the individual task but reflects the global connection
|
||||
// state managed within the @gemini-cli/core module.
|
||||
async getMetadata(): Promise<TaskMetadata> {
|
||||
const toolRegistry = await this.config.getToolRegistry();
|
||||
const mcpServers = this.config.getMcpServers() || {};
|
||||
const serverStatuses = getAllMCPServerStatuses();
|
||||
const servers = Object.keys(mcpServers).map((serverName) => ({
|
||||
name: serverName,
|
||||
status: serverStatuses.get(serverName) || MCPServerStatus.DISCONNECTED,
|
||||
tools: toolRegistry.getToolsByServer(serverName).map((tool) => ({
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameterSchema: tool.schema.parameters,
|
||||
})),
|
||||
}));
|
||||
|
||||
const availableTools = toolRegistry.getAllTools().map((tool) => ({
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameterSchema: tool.schema.parameters,
|
||||
}));
|
||||
|
||||
const metadata: TaskMetadata = {
|
||||
id: this.id,
|
||||
contextId: this.contextId,
|
||||
taskState: this.taskState,
|
||||
model: this.config.getContentGeneratorConfig().model,
|
||||
mcpServers: servers,
|
||||
availableTools,
|
||||
};
|
||||
return metadata;
|
||||
}
|
||||
|
||||
private _resetToolCompletionPromise(): void {
|
||||
this.toolCompletionPromise = new Promise((resolve, reject) => {
|
||||
this.toolCompletionNotifier = { resolve, reject };
|
||||
});
|
||||
// If there are no pending calls when reset, resolve immediately.
|
||||
if (this.pendingToolCalls.size === 0 && this.toolCompletionNotifier) {
|
||||
this.toolCompletionNotifier.resolve();
|
||||
}
|
||||
}
|
||||
|
||||
private _registerToolCall(toolCallId: string, status: string): void {
|
||||
const wasEmpty = this.pendingToolCalls.size === 0;
|
||||
this.pendingToolCalls.set(toolCallId, status);
|
||||
if (wasEmpty) {
|
||||
this._resetToolCompletionPromise();
|
||||
}
|
||||
logger.info(
|
||||
`[Task] Registered tool call: ${toolCallId}. Pending: ${this.pendingToolCalls.size}`,
|
||||
);
|
||||
}
|
||||
|
||||
private _resolveToolCall(toolCallId: string): void {
|
||||
if (this.pendingToolCalls.has(toolCallId)) {
|
||||
this.pendingToolCalls.delete(toolCallId);
|
||||
logger.info(
|
||||
`[Task] Resolved tool call: ${toolCallId}. Pending: ${this.pendingToolCalls.size}`,
|
||||
);
|
||||
if (this.pendingToolCalls.size === 0 && this.toolCompletionNotifier) {
|
||||
this.toolCompletionNotifier.resolve();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async waitForPendingTools(): Promise<void> {
|
||||
if (this.pendingToolCalls.size === 0) {
|
||||
return Promise.resolve();
|
||||
}
|
||||
logger.info(
|
||||
`[Task] Waiting for ${this.pendingToolCalls.size} pending tool(s)...`,
|
||||
);
|
||||
return this.toolCompletionPromise;
|
||||
}
|
||||
|
||||
cancelPendingTools(reason: string): void {
|
||||
if (this.pendingToolCalls.size > 0) {
|
||||
logger.info(
|
||||
`[Task] Cancelling all ${this.pendingToolCalls.size} pending tool calls. Reason: ${reason}`,
|
||||
);
|
||||
}
|
||||
if (this.toolCompletionNotifier) {
|
||||
this.toolCompletionNotifier.reject(new Error(reason));
|
||||
}
|
||||
this.pendingToolCalls.clear();
|
||||
// Reset the promise for any future operations, ensuring it's in a clean state.
|
||||
this._resetToolCompletionPromise();
|
||||
}
|
||||
|
||||
private _createTextMessage(
|
||||
text: string,
|
||||
role: 'agent' | 'user' = 'agent',
|
||||
): Message {
|
||||
return {
|
||||
kind: 'message',
|
||||
role,
|
||||
parts: [{ kind: 'text', text }],
|
||||
messageId: uuidv4(),
|
||||
taskId: this.id,
|
||||
contextId: this.contextId,
|
||||
};
|
||||
}
|
||||
|
||||
private _createStatusUpdateEvent(
|
||||
stateToReport: TaskState,
|
||||
coderAgentMessage: CoderAgentMessage,
|
||||
message?: Message,
|
||||
final = false,
|
||||
timestamp?: string,
|
||||
metadataError?: string,
|
||||
): TaskStatusUpdateEvent {
|
||||
const metadata: {
|
||||
coderAgent: CoderAgentMessage;
|
||||
model: string;
|
||||
userTier?: UserTierId;
|
||||
error?: string;
|
||||
} = {
|
||||
coderAgent: coderAgentMessage,
|
||||
model: this.config.getModel(),
|
||||
userTier: this.geminiClient.getUserTier(),
|
||||
};
|
||||
|
||||
if (metadataError) {
|
||||
metadata.error = metadataError;
|
||||
}
|
||||
|
||||
return {
|
||||
kind: 'status-update',
|
||||
taskId: this.id,
|
||||
contextId: this.contextId,
|
||||
status: {
|
||||
state: stateToReport,
|
||||
message, // Shorthand property
|
||||
timestamp: timestamp || new Date().toISOString(),
|
||||
},
|
||||
final,
|
||||
metadata,
|
||||
};
|
||||
}
|
||||
|
||||
setTaskStateAndPublishUpdate(
|
||||
newState: TaskState,
|
||||
coderAgentMessage: CoderAgentMessage,
|
||||
messageText?: string,
|
||||
messageParts?: Part[], // For more complex messages
|
||||
final = false,
|
||||
metadataError?: string,
|
||||
): void {
|
||||
this.taskState = newState;
|
||||
let message: Message | undefined;
|
||||
|
||||
if (messageText) {
|
||||
message = this._createTextMessage(messageText);
|
||||
} else if (messageParts) {
|
||||
message = {
|
||||
kind: 'message',
|
||||
role: 'agent',
|
||||
parts: messageParts,
|
||||
messageId: uuidv4(),
|
||||
taskId: this.id,
|
||||
contextId: this.contextId,
|
||||
};
|
||||
}
|
||||
|
||||
const event = this._createStatusUpdateEvent(
|
||||
this.taskState,
|
||||
coderAgentMessage,
|
||||
message,
|
||||
final,
|
||||
undefined,
|
||||
metadataError,
|
||||
);
|
||||
this.eventBus?.publish(event);
|
||||
}
|
||||
|
||||
private _schedulerOutputUpdate(
|
||||
toolCallId: string,
|
||||
outputChunk: string,
|
||||
): void {
|
||||
logger.info(
|
||||
'[Task] Scheduler output update for tool call ' +
|
||||
toolCallId +
|
||||
': ' +
|
||||
outputChunk,
|
||||
);
|
||||
const artifact: Artifact = {
|
||||
artifactId: `tool-${toolCallId}-output`,
|
||||
parts: [
|
||||
{
|
||||
kind: 'text',
|
||||
text: outputChunk,
|
||||
} as Part,
|
||||
],
|
||||
};
|
||||
const artifactEvent: TaskArtifactUpdateEvent = {
|
||||
kind: 'artifact-update',
|
||||
taskId: this.id,
|
||||
contextId: this.contextId,
|
||||
artifact,
|
||||
append: true,
|
||||
lastChunk: false,
|
||||
};
|
||||
this.eventBus?.publish(artifactEvent);
|
||||
}
|
||||
|
||||
private async _schedulerAllToolCallsComplete(
|
||||
completedToolCalls: CompletedToolCall[],
|
||||
): Promise<void> {
|
||||
logger.info(
|
||||
'[Task] All tool calls completed by scheduler (batch):',
|
||||
completedToolCalls.map((tc) => tc.request.callId),
|
||||
);
|
||||
this.completedToolCalls.push(...completedToolCalls);
|
||||
completedToolCalls.forEach((tc) => {
|
||||
this._resolveToolCall(tc.request.callId);
|
||||
});
|
||||
}
|
||||
|
||||
  /**
   * Scheduler callback fired whenever the status of any tracked tool call
   * changes. Updates the pending-call tracker, publishes per-call status
   * updates, auto-approves in YOLO mode, and emits a final 'input-required'
   * update once every pending call is stable and at least one awaits
   * approval.
   */
  private _schedulerToolCallsUpdate(toolCalls: ToolCall[]): void {
    logger.info(
      '[Task] Scheduler tool calls updated:',
      toolCalls.map((tc) => `${tc.request.callId} (${tc.status})`),
    );

    // Update state and send continuous, non-final updates
    toolCalls.forEach((tc) => {
      const previousStatus = this.pendingToolCalls.get(tc.request.callId);
      const hasChanged = previousStatus !== tc.status;

      // Resolve tool call if it has reached a terminal state
      if (['success', 'error', 'cancelled'].includes(tc.status)) {
        this._resolveToolCall(tc.request.callId);
      } else {
        // This will update the map
        this._registerToolCall(tc.request.callId, tc.status);
      }

      // Stash confirmation details so a later user reply can act on them.
      if (tc.status === 'awaiting_approval' && tc.confirmationDetails) {
        this.pendingToolConfirmationDetails.set(
          tc.request.callId,
          tc.confirmationDetails,
        );
      }

      // Only send an update if the status has actually changed.
      if (hasChanged) {
        const message = this.toolStatusMessage(tc, this.id, this.contextId);
        const coderAgentMessage: CoderAgentMessage =
          tc.status === 'awaiting_approval'
            ? { kind: CoderAgentEvent.ToolCallConfirmationEvent }
            : { kind: CoderAgentEvent.ToolCallUpdateEvent };

        const event = this._createStatusUpdateEvent(
          this.taskState,
          coderAgentMessage,
          message,
          false, // Always false for these continuous updates
        );
        this.eventBus?.publish(event);
      }
    });

    // In YOLO mode no human is in the loop: confirm every waiting call.
    if (this.config.getApprovalMode() === ApprovalMode.YOLO) {
      logger.info('[Task] YOLO mode enabled. Auto-approving all tool calls.');
      toolCalls.forEach((tc: ToolCall) => {
        if (tc.status === 'awaiting_approval' && tc.confirmationDetails) {
          tc.confirmationDetails.onConfirm(ToolConfirmationOutcome.ProceedOnce);
          this.pendingToolConfirmationDetails.delete(tc.request.callId);
        }
      });
      return;
    }

    const allPendingStatuses = Array.from(this.pendingToolCalls.values());
    const isAwaitingApproval = allPendingStatuses.some(
      (status) => status === 'awaiting_approval',
    );
    const allPendingAreStable = allPendingStatuses.every(
      (status) =>
        status === 'awaiting_approval' ||
        status === 'success' ||
        status === 'error' ||
        status === 'cancelled',
    );

    // 1. Are any pending tool calls awaiting_approval
    // 2. Are all pending tool calls in a stable state (i.e. not in validating or executing)
    // 3. After an inline edit, the edited tool call will send awaiting_approval THEN scheduled. We wait for the next update in this case.
    if (
      isAwaitingApproval &&
      allPendingAreStable &&
      !this.skipFinalTrueAfterInlineEdit
    ) {
      // NOTE(review): this assignment is a no-op — the guard above only
      // enters when the flag is already false. Confirm whether the reset
      // was meant to live on an else branch instead.
      this.skipFinalTrueAfterInlineEdit = false;

      // We don't need to send another message, just a final status update.
      this.setTaskStateAndPublishUpdate(
        'input-required',
        { kind: CoderAgentEvent.StateChangeEvent },
        undefined,
        undefined,
        /*final*/ true,
      );
    }
  }
|
||||
|
||||
private createScheduler(): CoreToolScheduler {
|
||||
const scheduler = new CoreToolScheduler({
|
||||
outputUpdateHandler: this._schedulerOutputUpdate.bind(this),
|
||||
onAllToolCallsComplete: this._schedulerAllToolCallsComplete.bind(this),
|
||||
onToolCallsUpdate: this._schedulerToolCallsUpdate.bind(this),
|
||||
getPreferredEditor: () => 'vscode',
|
||||
config: this.config,
|
||||
onEditorClose: () => {},
|
||||
});
|
||||
return scheduler;
|
||||
}
|
||||
|
||||
private toolStatusMessage(
|
||||
tc: ToolCall,
|
||||
taskId: string,
|
||||
contextId: string,
|
||||
): Message {
|
||||
const messageParts: Part[] = [];
|
||||
|
||||
// Create a serializable version of the ToolCall (pick necesssary
|
||||
// properties/avoic methods causing circular reference errors)
|
||||
const serializableToolCall: { [key: string]: unknown } = {
|
||||
request: tc.request,
|
||||
status: tc.status,
|
||||
};
|
||||
|
||||
// For WaitingToolCall type
|
||||
if ('confirmationDetails' in tc) {
|
||||
serializableToolCall['confirmationDetails'] = tc.confirmationDetails;
|
||||
}
|
||||
|
||||
if (tc.tool) {
|
||||
serializableToolCall['tool'] = {
|
||||
name: tc.tool.name,
|
||||
displayName: tc.tool.displayName,
|
||||
description: tc.tool.description,
|
||||
kind: tc.tool.kind,
|
||||
isOutputMarkdown: tc.tool.isOutputMarkdown,
|
||||
canUpdateOutput: tc.tool.canUpdateOutput,
|
||||
schema: tc.tool.schema,
|
||||
parameterSchema: tc.tool.parameterSchema,
|
||||
};
|
||||
}
|
||||
|
||||
messageParts.push({
|
||||
kind: 'data',
|
||||
data: serializableToolCall as ToolCall,
|
||||
} as Part);
|
||||
|
||||
return {
|
||||
kind: 'message',
|
||||
role: 'agent',
|
||||
parts: messageParts,
|
||||
messageId: uuidv4(),
|
||||
taskId,
|
||||
contextId,
|
||||
};
|
||||
}
|
||||
|
||||
private async getProposedContent(
|
||||
file_path: string,
|
||||
old_string: string,
|
||||
new_string: string,
|
||||
): Promise<string> {
|
||||
try {
|
||||
const currentContent = fs.readFileSync(file_path, 'utf8');
|
||||
return this._applyReplacement(
|
||||
currentContent,
|
||||
old_string,
|
||||
new_string,
|
||||
old_string === '' && currentContent === '',
|
||||
);
|
||||
} catch (err) {
|
||||
if (!isNodeError(err) || err.code !== 'ENOENT') throw err;
|
||||
return '';
|
||||
}
|
||||
}
|
||||
|
||||
private _applyReplacement(
|
||||
currentContent: string | null,
|
||||
oldString: string,
|
||||
newString: string,
|
||||
isNewFile: boolean,
|
||||
): string {
|
||||
if (isNewFile) {
|
||||
return newString;
|
||||
}
|
||||
if (currentContent === null) {
|
||||
// Should not happen if not a new file, but defensively return empty or newString if oldString is also empty
|
||||
return oldString === '' ? newString : '';
|
||||
}
|
||||
// If oldString is empty and it's not a new file, do not modify the content.
|
||||
if (oldString === '' && !isNewFile) {
|
||||
return currentContent;
|
||||
}
|
||||
return currentContent.replaceAll(oldString, newString);
|
||||
}
|
||||
|
||||
async scheduleToolCalls(
|
||||
requests: ToolCallRequestInfo[],
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<void> {
|
||||
if (requests.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (const request of requests) {
|
||||
if (
|
||||
!request.args['newContent'] &&
|
||||
request.name === 'replace' &&
|
||||
request.args &&
|
||||
request.args['file_path'] &&
|
||||
request.args['old_string'] &&
|
||||
request.args['new_string']
|
||||
) {
|
||||
request.args['newContent'] = await this.getProposedContent(
|
||||
request.args['file_path'] as string,
|
||||
request.args['old_string'] as string,
|
||||
request.args['new_string'] as string,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(`[Task] Scheduling batch of ${requests.length} tool calls.`);
|
||||
const stateChange: StateChange = {
|
||||
kind: CoderAgentEvent.StateChangeEvent,
|
||||
};
|
||||
this.setTaskStateAndPublishUpdate('working', stateChange);
|
||||
|
||||
await this.scheduler.schedule(requests, abortSignal);
|
||||
}
|
||||
|
||||
  /**
   * Dispatches a single event from the Gemini stream: forwards content and
   * thoughts to the event bus, tracks confirmation requests, handles user
   * cancellation, and converts stream errors into status updates.
   */
  async acceptAgentMessage(event: ServerGeminiStreamEvent): Promise<void> {
    const stateChange: StateChange = {
      kind: CoderAgentEvent.StateChangeEvent,
    };
    switch (event.type) {
      case GeminiEventType.Content:
        logger.info('[Task] Sending agent message content...');
        this._sendTextContent(event.value);
        break;
      case GeminiEventType.ToolCallRequest:
        // This is now handled by the agent loop, which collects all requests
        // and calls scheduleToolCalls once.
        logger.warn(
          '[Task] A single tool call request was passed to acceptAgentMessage. This should be handled in a batch by the agent. Ignoring.',
        );
        break;
      case GeminiEventType.ToolCallResponse:
        // This event type from ServerGeminiStreamEvent might be for when LLM *generates* a tool response part.
        // The actual execution result comes via user message.
        logger.info(
          '[Task] Received tool call response from LLM (part of generation):',
          event.value,
        );
        break;
      case GeminiEventType.ToolCallConfirmation:
        // This is when LLM requests confirmation, not when user provides it.
        logger.info(
          '[Task] Received tool call confirmation request from LLM:',
          event.value.request.callId,
        );
        this.pendingToolConfirmationDetails.set(
          event.value.request.callId,
          event.value.details,
        );
        // This will be handled by the scheduler and _schedulerToolCallsUpdate will set InputRequired if needed.
        // No direct state change here, scheduler drives it.
        break;
      case GeminiEventType.UserCancelled:
        logger.info('[Task] Received user cancelled event from LLM stream.');
        this.cancelPendingTools('User cancelled via LLM stream event');
        // Final 'input-required' update hands control back to the user.
        this.setTaskStateAndPublishUpdate(
          'input-required',
          stateChange,
          'Task cancelled by user',
          undefined,
          true,
        );
        break;
      case GeminiEventType.Thought:
        logger.info('[Task] Sending agent thought...');
        this._sendThought(event.value);
        break;
      case GeminiEventType.ChatCompressed:
        // Compression is internal to the chat; nothing to surface.
        break;
      case GeminiEventType.Finished:
        logger.info(`[Task ${this.id}] Agent finished its turn.`);
        break;
      case GeminiEventType.Error:
      default: {
        // Block scope for lexical declaration
        const errorEvent = event as ServerGeminiErrorEvent; // Type assertion
        const errorMessage =
          errorEvent.value?.error.message ?? 'Unknown error from LLM stream';
        logger.error(
          '[Task] Received error event from LLM stream:',
          errorMessage,
        );

        let errMessage = 'Unknown error from LLM stream';
        if (errorEvent.value) {
          errMessage = parseAndFormatApiError(errorEvent.value);
        }
        // Pending tools cannot complete once the stream has failed.
        this.cancelPendingTools(`LLM stream error: ${errorMessage}`);
        this.setTaskStateAndPublishUpdate(
          this.taskState,
          stateChange,
          `Agent Error, unknown agent message: ${errorMessage}`,
          undefined,
          false,
          errMessage,
        );
        break;
      }
    }
  }
|
||||
|
||||
  /**
   * If `part` is a tool-confirmation data part (string `callId` and
   * `outcome` fields), applies the user's decision to the matching pending
   * tool call via its confirmation callback.
   *
   * @param part A part from an incoming user message.
   * @returns true when a confirmation was recognized and applied; false when
   *   the part is not a confirmation, references an unknown callId, carries
   *   an unknown outcome, or the confirmation callback threw.
   */
  private async _handleToolConfirmationPart(part: Part): Promise<boolean> {
    if (
      part.kind !== 'data' ||
      !part.data ||
      typeof part.data['callId'] !== 'string' ||
      typeof part.data['outcome'] !== 'string'
    ) {
      return false;
    }

    const callId = part.data['callId'] as string;
    const outcomeString = part.data['outcome'] as string;
    let confirmationOutcome: ToolConfirmationOutcome | undefined;

    // Map the wire-format outcome string onto the core enum.
    if (outcomeString === 'proceed_once') {
      confirmationOutcome = ToolConfirmationOutcome.ProceedOnce;
    } else if (outcomeString === 'cancel') {
      confirmationOutcome = ToolConfirmationOutcome.Cancel;
    } else if (outcomeString === 'proceed_always') {
      confirmationOutcome = ToolConfirmationOutcome.ProceedAlways;
    } else if (outcomeString === 'proceed_always_server') {
      confirmationOutcome = ToolConfirmationOutcome.ProceedAlwaysServer;
    } else if (outcomeString === 'proceed_always_tool') {
      confirmationOutcome = ToolConfirmationOutcome.ProceedAlwaysTool;
    } else if (outcomeString === 'modify_with_editor') {
      confirmationOutcome = ToolConfirmationOutcome.ModifyWithEditor;
    } else {
      logger.warn(
        `[Task] Unknown tool confirmation outcome: "${outcomeString}" for callId: ${callId}`,
      );
      return false;
    }

    const confirmationDetails = this.pendingToolConfirmationDetails.get(callId);

    if (!confirmationDetails) {
      logger.warn(
        `[Task] Received tool confirmation for unknown or already processed callId: ${callId}`,
      );
      return false;
    }

    logger.info(
      `[Task] Handling tool confirmation for callId: ${callId} with outcome: ${outcomeString}`,
    );
    try {
      // Temporarily unset GCP environment variables so they do not leak into
      // tool calls.
      const gcpProject = process.env['GOOGLE_CLOUD_PROJECT'];
      const gcpCreds = process.env['GOOGLE_APPLICATION_CREDENTIALS'];
      try {
        delete process.env['GOOGLE_CLOUD_PROJECT'];
        delete process.env['GOOGLE_APPLICATION_CREDENTIALS'];

        // This will trigger the scheduler to continue or cancel the specific tool.
        // The scheduler's onToolCallsUpdate will then reflect the new state (e.g., executing or cancelled).

        // If `edit` tool call, pass updated payload if present
        if (confirmationDetails.type === 'edit') {
          const payload = part.data['newContent']
            ? ({
                newContent: part.data['newContent'] as string,
              } as ToolConfirmationPayload)
            : undefined;
          // Inline edits re-emit awaiting_approval then scheduled; suppress
          // the premature final update in that window.
          this.skipFinalTrueAfterInlineEdit = !!payload;
          await confirmationDetails.onConfirm(confirmationOutcome, payload);
        } else {
          await confirmationDetails.onConfirm(confirmationOutcome);
        }
      } finally {
        // Restore the variables cleared above (only when they were set).
        if (gcpProject) {
          process.env['GOOGLE_CLOUD_PROJECT'] = gcpProject;
        }
        if (gcpCreds) {
          process.env['GOOGLE_APPLICATION_CREDENTIALS'] = gcpCreds;
        }
      }

      // Do not delete if modifying, a subsequent tool confirmation for the same
      // callId will be passed with ProceedOnce/Cancel/etc
      // Note !== ToolConfirmationOutcome.ModifyWithEditor does not work!
      if (confirmationOutcome !== 'modify_with_editor') {
        this.pendingToolConfirmationDetails.delete(callId);
      }

      // If outcome is Cancel, scheduler should update status to 'cancelled', which then resolves the tool.
      // If ProceedOnce, scheduler updates to 'executing', then eventually 'success'/'error', which resolves.
      return true;
    } catch (error) {
      logger.error(
        `[Task] Error during tool confirmation for callId ${callId}:`,
        error,
      );
      // If confirming fails, we should probably mark this tool as failed
      this._resolveToolCall(callId); // Resolve it as it won't proceed.
      const errorMessageText =
        error instanceof Error
          ? error.message
          : `Error processing tool confirmation for ${callId}`;
      const message = this._createTextMessage(errorMessageText);
      const toolCallUpdate: ToolCallUpdate = {
        kind: CoderAgentEvent.ToolCallUpdateEvent,
      };
      const event = this._createStatusUpdateEvent(
        this.taskState,
        toolCallUpdate,
        message,
        false,
      );
      this.eventBus?.publish(event);
      return false;
    }
  }
|
||||
|
||||
getAndClearCompletedTools(): CompletedToolCall[] {
|
||||
const tools = [...this.completedToolCalls];
|
||||
this.completedToolCalls = [];
|
||||
return tools;
|
||||
}
|
||||
|
||||
addToolResponsesToHistory(completedTools: CompletedToolCall[]): void {
|
||||
logger.info(
|
||||
`[Task] Adding ${completedTools.length} tool responses to history without generating a new response.`,
|
||||
);
|
||||
const responsesToAdd = completedTools.flatMap(
|
||||
(toolCall) => toolCall.response.responseParts,
|
||||
);
|
||||
|
||||
for (const response of responsesToAdd) {
|
||||
let parts: genAiPart[];
|
||||
if (Array.isArray(response)) {
|
||||
parts = response;
|
||||
} else if (typeof response === 'string') {
|
||||
parts = [{ text: response }];
|
||||
} else {
|
||||
parts = [response];
|
||||
}
|
||||
this.geminiClient.addHistory({
|
||||
role: 'user',
|
||||
parts,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async *sendCompletedToolsToLlm(
|
||||
completedToolCalls: CompletedToolCall[],
|
||||
aborted: AbortSignal,
|
||||
): AsyncGenerator<ServerGeminiStreamEvent> {
|
||||
if (completedToolCalls.length === 0) {
|
||||
yield* (async function* () {})(); // Yield nothing
|
||||
return;
|
||||
}
|
||||
|
||||
const llmParts: PartUnion[] = [];
|
||||
logger.info(
|
||||
`[Task] Feeding ${completedToolCalls.length} tool responses to LLM.`,
|
||||
);
|
||||
for (const completedToolCall of completedToolCalls) {
|
||||
logger.info(
|
||||
`[Task] Adding tool response for "${completedToolCall.request.name}" (callId: ${completedToolCall.request.callId}) to LLM input.`,
|
||||
);
|
||||
const responseParts = completedToolCall.response.responseParts;
|
||||
if (Array.isArray(responseParts)) {
|
||||
llmParts.push(...responseParts);
|
||||
} else {
|
||||
llmParts.push(responseParts);
|
||||
}
|
||||
}
|
||||
|
||||
logger.info('[Task] Sending new parts to agent.');
|
||||
const stateChange: StateChange = {
|
||||
kind: CoderAgentEvent.StateChangeEvent,
|
||||
};
|
||||
// Set task state to working as we are about to call LLM
|
||||
this.setTaskStateAndPublishUpdate('working', stateChange);
|
||||
// TODO: Determine what it mean to have, then add a prompt ID.
|
||||
yield* this.geminiClient.sendMessageStream(
|
||||
llmParts,
|
||||
aborted,
|
||||
/*prompt_id*/ '',
|
||||
);
|
||||
}
|
||||
|
||||
  /**
   * Processes an incoming user message: first applies any tool-confirmation
   * data parts (handled by the scheduler, not the LLM), then forwards the
   * remaining text parts to the LLM and yields its stream events. Yields
   * nothing when the message contained only confirmations or nothing usable.
   */
  async *acceptUserMessage(
    requestContext: RequestContext,
    aborted: AbortSignal,
  ): AsyncGenerator<ServerGeminiStreamEvent> {
    const userMessage = requestContext.userMessage;
    const llmParts: PartUnion[] = [];
    let anyConfirmationHandled = false;
    let hasContentForLlm = false;

    for (const part of userMessage.parts) {
      const confirmationHandled = await this._handleToolConfirmationPart(part);
      if (confirmationHandled) {
        anyConfirmationHandled = true;
        // If a confirmation was handled, the scheduler will now run the tool (or cancel it).
        // We don't send anything to the LLM for this part.
        // The subsequent tool execution will eventually lead to resolveToolCall.
        continue;
      }

      // Only text parts are forwarded to the model.
      if (part.kind === 'text') {
        llmParts.push({ text: part.text });
        hasContentForLlm = true;
      }
    }

    if (hasContentForLlm) {
      logger.info('[Task] Sending new parts to LLM.');
      const stateChange: StateChange = {
        kind: CoderAgentEvent.StateChangeEvent,
      };
      // Set task state to working as we are about to call LLM
      this.setTaskStateAndPublishUpdate('working', stateChange);
      // TODO: Determine what it mean to have, then add a prompt ID.
      yield* this.geminiClient.sendMessageStream(
        llmParts,
        aborted,
        /*prompt_id*/ '',
      );
    } else if (anyConfirmationHandled) {
      logger.info(
        '[Task] User message only contained tool confirmations. Scheduler is active. No new input for LLM this turn.',
      );
      // Ensure task state reflects that scheduler might be working due to confirmation.
      // If scheduler is active, it will emit its own status updates.
      // If all pending tools were just confirmed, waitForPendingTools will handle the wait.
      // If some tools are still pending approval, scheduler would have set InputRequired.
      // If not, and no new text, we are just waiting.
      if (
        this.pendingToolCalls.size > 0 &&
        this.taskState !== 'input-required'
      ) {
        const stateChange: StateChange = {
          kind: CoderAgentEvent.StateChangeEvent,
        };
        this.setTaskStateAndPublishUpdate('working', stateChange); // Reflect potential background activity
      }
      yield* (async function* () {})(); // Yield nothing
    } else {
      logger.info(
        '[Task] No relevant parts in user message for LLM interaction or tool confirmation.',
      );
      // If there's no new text and no confirmations, and no pending tools,
      // it implies we might need to signal input required if nothing else is happening.
      // However, the agent.ts will make this determination after waitForPendingTools.
      yield* (async function* () {})(); // Yield nothing
    }
  }
|
||||
|
||||
_sendTextContent(content: string): void {
|
||||
if (content === '') {
|
||||
return;
|
||||
}
|
||||
logger.info('[Task] Sending text content to event bus.');
|
||||
const message = this._createTextMessage(content);
|
||||
const textContent: TextContent = {
|
||||
kind: CoderAgentEvent.TextContentEvent,
|
||||
};
|
||||
this.eventBus?.publish(
|
||||
this._createStatusUpdateEvent(
|
||||
this.taskState,
|
||||
textContent,
|
||||
message,
|
||||
false,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
_sendThought(content: ThoughtSummary): void {
|
||||
if (!content.subject && !content.description) {
|
||||
return;
|
||||
}
|
||||
logger.info('[Task] Sending thought to event bus.');
|
||||
const message: Message = {
|
||||
kind: 'message',
|
||||
role: 'agent',
|
||||
parts: [
|
||||
{
|
||||
kind: 'data',
|
||||
data: content,
|
||||
} as Part,
|
||||
],
|
||||
messageId: uuidv4(),
|
||||
taskId: this.id,
|
||||
contextId: this.contextId,
|
||||
};
|
||||
const thought: Thought = {
|
||||
kind: CoderAgentEvent.ThoughtEvent,
|
||||
};
|
||||
this.eventBus?.publish(
|
||||
this._createStatusUpdateEvent(this.taskState, thought, message, false),
|
||||
);
|
||||
}
|
||||
}
|
||||
180
packages/a2a-server/src/testing_utils.ts
Normal file
180
packages/a2a-server/src/testing_utils.ts
Normal file
@@ -0,0 +1,180 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type {
|
||||
Task as SDKTask,
|
||||
TaskStatusUpdateEvent,
|
||||
SendStreamingMessageSuccessResponse,
|
||||
} from '@a2a-js/sdk';
|
||||
import {
|
||||
BaseDeclarativeTool,
|
||||
BaseToolInvocation,
|
||||
Kind,
|
||||
} from '@google/gemini-cli-core';
|
||||
import type {
|
||||
ToolCallConfirmationDetails,
|
||||
ToolResult,
|
||||
ToolInvocation,
|
||||
} from '@google/gemini-cli-core';
|
||||
import { expect, vi } from 'vitest';
|
||||
|
||||
// Reusable vi.fn() mock for tool-confirmation callbacks; importing tests program its behavior.
export const mockOnUserConfirmForToolConfirmation = vi.fn();
|
||||
|
||||
export class MockToolInvocation extends BaseToolInvocation<object, ToolResult> {
|
||||
constructor(
|
||||
private readonly tool: MockTool,
|
||||
params: object,
|
||||
) {
|
||||
super(params);
|
||||
}
|
||||
|
||||
getDescription(): string {
|
||||
return JSON.stringify(this.params);
|
||||
}
|
||||
|
||||
override shouldConfirmExecute(
|
||||
abortSignal: AbortSignal,
|
||||
): Promise<ToolCallConfirmationDetails | false> {
|
||||
return this.tool.shouldConfirmExecute(this.params, abortSignal);
|
||||
}
|
||||
|
||||
execute(
|
||||
signal: AbortSignal,
|
||||
updateOutput?: (output: string) => void,
|
||||
terminalColumns?: number,
|
||||
terminalRows?: number,
|
||||
): Promise<ToolResult> {
|
||||
return this.tool.execute(
|
||||
this.params,
|
||||
signal,
|
||||
updateOutput,
|
||||
terminalColumns,
|
||||
terminalRows,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: dedup with gemini-cli, add shouldConfirmExecute() support in core
|
||||
export class MockTool extends BaseDeclarativeTool<object, ToolResult> {
|
||||
constructor(
|
||||
name: string,
|
||||
displayName: string,
|
||||
canUpdateOutput = false,
|
||||
isOutputMarkdown = false,
|
||||
shouldConfirmExecute?: () => Promise<ToolCallConfirmationDetails | false>,
|
||||
) {
|
||||
super(
|
||||
name,
|
||||
displayName,
|
||||
'A mock tool for testing',
|
||||
Kind.Other,
|
||||
{},
|
||||
isOutputMarkdown,
|
||||
canUpdateOutput,
|
||||
);
|
||||
|
||||
if (shouldConfirmExecute) {
|
||||
this.shouldConfirmExecute.mockImplementation(shouldConfirmExecute);
|
||||
} else {
|
||||
// Default to no confirmation needed
|
||||
this.shouldConfirmExecute.mockResolvedValue(false);
|
||||
}
|
||||
}
|
||||
|
||||
execute = vi.fn();
|
||||
shouldConfirmExecute = vi.fn();
|
||||
|
||||
protected createInvocation(
|
||||
params: object,
|
||||
): ToolInvocation<object, ToolResult> {
|
||||
return new MockToolInvocation(this, params);
|
||||
}
|
||||
}
|
||||
|
||||
export function createStreamMessageRequest(
|
||||
text: string,
|
||||
messageId: string,
|
||||
taskId?: string,
|
||||
) {
|
||||
const request: {
|
||||
jsonrpc: string;
|
||||
id: string;
|
||||
method: string;
|
||||
params: {
|
||||
message: {
|
||||
kind: string;
|
||||
role: string;
|
||||
parts: [{ kind: string; text: string }];
|
||||
messageId: string;
|
||||
};
|
||||
metadata: {
|
||||
coderAgent: {
|
||||
kind: string;
|
||||
workspacePath: string;
|
||||
};
|
||||
};
|
||||
taskId?: string;
|
||||
};
|
||||
} = {
|
||||
jsonrpc: '2.0',
|
||||
id: '1',
|
||||
method: 'message/stream',
|
||||
params: {
|
||||
message: {
|
||||
kind: 'message',
|
||||
role: 'user',
|
||||
parts: [{ kind: 'text', text }],
|
||||
messageId,
|
||||
},
|
||||
metadata: {
|
||||
coderAgent: {
|
||||
kind: 'agent-settings',
|
||||
workspacePath: '/tmp',
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
if (taskId) {
|
||||
request.params.taskId = taskId;
|
||||
}
|
||||
|
||||
return request;
|
||||
}
|
||||
|
||||
export function assertUniqueFinalEventIsLast(
|
||||
events: SendStreamingMessageSuccessResponse[],
|
||||
) {
|
||||
// Final event is input-required & final
|
||||
const finalEvent = events[events.length - 1].result as TaskStatusUpdateEvent;
|
||||
expect(finalEvent.metadata?.['coderAgent']).toMatchObject({
|
||||
kind: 'state-change',
|
||||
});
|
||||
expect(finalEvent.status?.state).toBe('input-required');
|
||||
expect(finalEvent.final).toBe(true);
|
||||
|
||||
// There is only one event with final and its the last
|
||||
expect(
|
||||
events.filter((e) => (e.result as TaskStatusUpdateEvent).final).length,
|
||||
).toBe(1);
|
||||
expect(
|
||||
events.findIndex((e) => (e.result as TaskStatusUpdateEvent).final),
|
||||
).toBe(events.length - 1);
|
||||
}
|
||||
|
||||
export function assertTaskCreationAndWorkingStatus(
|
||||
events: SendStreamingMessageSuccessResponse[],
|
||||
) {
|
||||
// Initial task creation event
|
||||
const taskEvent = events[0].result as SDKTask;
|
||||
expect(taskEvent.kind).toBe('task');
|
||||
expect(taskEvent.status.state).toBe('submitted');
|
||||
|
||||
// Status update: working
|
||||
const workingEvent = events[1].result as TaskStatusUpdateEvent;
|
||||
expect(workingEvent.kind).toBe('status-update');
|
||||
expect(workingEvent.status.state).toBe('working');
|
||||
}
|
||||
104
packages/a2a-server/src/types.ts
Normal file
104
packages/a2a-server/src/types.ts
Normal file
@@ -0,0 +1,104 @@
|
||||
/**
|
||||
* @license
|
||||
* Copyright 2025 Google LLC
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*/
|
||||
|
||||
import type {
|
||||
MCPServerStatus,
|
||||
ToolConfirmationOutcome,
|
||||
} from '@google/gemini-cli-core';
|
||||
import type { TaskState } from '@a2a-js/sdk';
|
||||
|
||||
// Interfaces and enums for the CoderAgent protocol.
|
||||
|
||||
/**
 * Discriminator values for {@link CoderAgentMessage}. Each member tags the
 * `coderAgent` metadata attached to an A2A task or status-update event.
 */
export enum CoderAgentEvent {
  /**
   * An event requesting one or more tool call confirmations.
   */
  ToolCallConfirmationEvent = 'tool-call-confirmation',
  /**
   * An event updating on the status of one or more tool calls.
   */
  ToolCallUpdateEvent = 'tool-call-update',
  /**
   * An event providing text updates on the task.
   */
  TextContentEvent = 'text-content',
  /**
   * An event that indicates a change in the task's execution state.
   */
  StateChangeEvent = 'state-change',
  /**
   * A user-sent event carrying the settings used to initiate the agent.
   */
  StateAgentSettingsEvent = 'agent-settings',
  /**
   * An event that contains a thought from the agent.
   */
  ThoughtEvent = 'thought',
}
|
||||
|
||||
/** User-sent settings that initiate the agent. */
export interface AgentSettings {
  kind: CoderAgentEvent.StateAgentSettingsEvent;
  // Filesystem path of the workspace the agent operates in.
  workspacePath: string;
}

/** Marks a status update requesting one or more tool call confirmations. */
export interface ToolCallConfirmation {
  kind: CoderAgentEvent.ToolCallConfirmationEvent;
}

/** Marks a status update reporting progress of one or more tool calls. */
export interface ToolCallUpdate {
  kind: CoderAgentEvent.ToolCallUpdateEvent;
}

/** Marks a status update carrying text output for the task. */
export interface TextContent {
  kind: CoderAgentEvent.TextContentEvent;
}

/** Marks a status update reflecting a change in the task's execution state. */
export interface StateChange {
  kind: CoderAgentEvent.StateChangeEvent;
}

/** Marks a status update containing a thought from the agent. */
export interface Thought {
  kind: CoderAgentEvent.ThoughtEvent;
}

/** An agent thought, split into a subject and a description. */
export type ThoughtSummary = {
  subject: string;
  description: string;
};

/** The user's confirmation outcome for the tool call identified by callId. */
export interface ToolConfirmationResponse {
  outcome: ToolConfirmationOutcome;
  callId: string;
}
|
||||
|
||||
/** Union of all CoderAgent protocol payloads, discriminated by `kind`. */
export type CoderAgentMessage =
  | AgentSettings
  | ToolCallConfirmation
  | ToolCallUpdate
  | TextContent
  | StateChange
  | Thought;

/**
 * Snapshot of a task's identity and capabilities: its ids and state, the
 * model in use, configured MCP servers (with their tools), and the tools
 * available to the agent.
 */
export interface TaskMetadata {
  id: string;
  contextId: string;
  taskState: TaskState;
  model: string;
  mcpServers: Array<{
    name: string;
    status: MCPServerStatus;
    tools: Array<{
      name: string;
      description: string;
      parameterSchema: unknown;
    }>;
  }>;
  availableTools: Array<{
    name: string;
    description: string;
    parameterSchema: unknown;
  }>;
}
|
||||
11
packages/a2a-server/tsconfig.json
Normal file
11
packages/a2a-server/tsconfig.json
Normal file
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"extends": "../../tsconfig.json",
|
||||
"compilerOptions": {
|
||||
"outDir": "dist",
|
||||
"lib": ["DOM", "DOM.Iterable", "ES2021"],
|
||||
"composite": true,
|
||||
"types": ["node", "vitest/globals"]
|
||||
},
|
||||
"include": ["index.ts", "src/**/*.ts", "src/**/*.json"],
|
||||
"exclude": ["node_modules", "dist"]
|
||||
}
|
||||
26
packages/a2a-server/vitest.config.ts
Normal file
26
packages/a2a-server/vitest.config.ts
Normal file
@@ -0,0 +1,26 @@
|
||||
/**
 * @license
 * Copyright 2025 Google LLC
 * SPDX-License-Identifier: Apache-2.0
 */

import { defineConfig } from 'vitest/config';

// Vitest configuration for the a2a-server package.
export default defineConfig({
  test: {
    // Console output plus a JUnit XML report for CI ingestion.
    reporters: [['default'], ['junit', { outputFile: 'junit.xml' }]],
    // Don't fail the run while this experimental package has no tests.
    passWithNoTests: true,
    coverage: {
      provider: 'v8',
      reportsDirectory: './coverage',
      // Multiple formats: human-readable summaries plus machine-readable
      // reports (lcov/cobertura/json) for CI tooling.
      reporter: [
        ['text', { file: 'full-text-summary.txt' }],
        'html',
        'json',
        'lcov',
        'cobertura',
        ['json-summary', { outputFile: 'coverage-summary.json' }],
      ],
    },
  },
});
|
||||
@@ -57,6 +57,7 @@
|
||||
},
|
||||
"devDependencies": {
|
||||
"@babel/runtime": "^7.27.6",
|
||||
"@google/gemini-cli-test-utils": "file:../test-utils",
|
||||
"@testing-library/react": "^16.3.0",
|
||||
"@types/command-exists": "^1.2.3",
|
||||
"@types/diff": "^7.0.2",
|
||||
@@ -72,8 +73,7 @@
|
||||
"pretty-format": "^30.0.2",
|
||||
"react-dom": "^19.1.0",
|
||||
"typescript": "^5.3.3",
|
||||
"vitest": "^3.1.1",
|
||||
"@google/gemini-cli-test-utils": "file:../test-utils"
|
||||
"vitest": "^3.1.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=20"
|
||||
|
||||
@@ -84,6 +84,7 @@ import { FileCommandLoader } from '../../services/FileCommandLoader.js';
|
||||
import { McpPromptLoader } from '../../services/McpPromptLoader.js';
|
||||
import {
|
||||
SlashCommandStatus,
|
||||
ToolConfirmationOutcome,
|
||||
makeFakeConfig,
|
||||
ToolConfirmationOutcome,
|
||||
type IdeClient,
|
||||
|
||||
@@ -34,6 +34,7 @@
|
||||
"@types/glob": "^8.1.0",
|
||||
"@types/html-to-text": "^9.0.4",
|
||||
"ajv": "^8.17.1",
|
||||
"fast-uri": "^3.0.6",
|
||||
"ajv-formats": "^3.0.0",
|
||||
"chardet": "^2.1.0",
|
||||
"diff": "^7.0.0",
|
||||
|
||||
@@ -5,6 +5,13 @@
|
||||
"target": "ES2022",
|
||||
"lib": ["ES2022", "dom"],
|
||||
"sourceMap": true,
|
||||
/*
|
||||
* skipLibCheck is necessary because the a2a-server package depends on
|
||||
* @google-cloud/storage which pulls in @types/request which depends on
|
||||
* tough-cookie@4.x while jsdom requires tough-cookie@5.x. This causes a
|
||||
* type checking error in ../../node_modules/@types/request/index.d.ts.
|
||||
*/
|
||||
"skipLibCheck": true,
|
||||
"rootDir": "src",
|
||||
"strict": true /* enable all strict type-checking options */
|
||||
/* Additional Checks */
|
||||
|
||||
Reference in New Issue
Block a user