Load and use MCP server prompts as slash commands in the CLI (#4828)

Co-authored-by: harold <haroldmciver@google.com>
Co-authored-by: N. Taylor Mullen <ntaylormullen@google.com>
This commit is contained in:
christine betts
2025-07-25 20:56:33 +00:00
committed by GitHub
parent de96887789
commit eb65034117
19 changed files with 761 additions and 100 deletions

View File

@@ -40,13 +40,19 @@ vi.mock('../ui/commands/extensionsCommand.js', () => ({
extensionsCommand: {},
}));
vi.mock('../ui/commands/helpCommand.js', () => ({ helpCommand: {} }));
vi.mock('../ui/commands/mcpCommand.js', () => ({ mcpCommand: {} }));
vi.mock('../ui/commands/memoryCommand.js', () => ({ memoryCommand: {} }));
vi.mock('../ui/commands/privacyCommand.js', () => ({ privacyCommand: {} }));
vi.mock('../ui/commands/quitCommand.js', () => ({ quitCommand: {} }));
vi.mock('../ui/commands/statsCommand.js', () => ({ statsCommand: {} }));
vi.mock('../ui/commands/themeCommand.js', () => ({ themeCommand: {} }));
vi.mock('../ui/commands/toolsCommand.js', () => ({ toolsCommand: {} }));
vi.mock('../ui/commands/mcpCommand.js', () => ({
mcpCommand: {
name: 'mcp',
description: 'MCP command',
kind: 'BUILT_IN',
},
}));
describe('BuiltinCommandLoader', () => {
let mockConfig: Config;
@@ -114,5 +120,8 @@ describe('BuiltinCommandLoader', () => {
const ideCmd = commands.find((c) => c.name === 'ide');
expect(ideCmd).toBeDefined();
const mcpCmd = commands.find((c) => c.name === 'mcp');
expect(mcpCmd).toBeDefined();
});
});

View File

@@ -58,9 +58,9 @@ export class BuiltinCommandLoader implements ICommandLoader {
extensionsCommand,
helpCommand,
ideCommand(this.config),
mcpCommand,
memoryCommand,
privacyCommand,
mcpCommand,
quitCommand,
restoreCommand(this.config),
statsCommand,

View File

@@ -0,0 +1,231 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/
import {
Config,
getErrorMessage,
getMCPServerPrompts,
} from '@google/gemini-cli-core';
import {
CommandContext,
CommandKind,
SlashCommand,
SlashCommandActionReturn,
} from '../ui/commands/types.js';
import { ICommandLoader } from './types.js';
import { PromptArgument } from '@modelcontextprotocol/sdk/types.js';
/**
* Discovers and loads executable slash commands from prompts exposed by
* Model-Context-Protocol (MCP) servers.
*/
export class McpPromptLoader implements ICommandLoader {
constructor(private readonly config: Config | null) {}
/**
* Loads all available prompts from all configured MCP servers and adapts
* them into executable SlashCommand objects.
*
* @param _signal An AbortSignal (unused for this synchronous loader).
* @returns A promise that resolves to an array of loaded SlashCommands.
*/
loadCommands(_signal: AbortSignal): Promise<SlashCommand[]> {
const promptCommands: SlashCommand[] = [];
if (!this.config) {
return Promise.resolve([]);
}
const mcpServers = this.config.getMcpServers() || {};
for (const serverName in mcpServers) {
const prompts = getMCPServerPrompts(this.config, serverName) || [];
for (const prompt of prompts) {
const commandName = `${prompt.name}`;
const newPromptCommand: SlashCommand = {
name: commandName,
description: prompt.description || `Invoke prompt ${prompt.name}`,
kind: CommandKind.MCP_PROMPT,
subCommands: [
{
name: 'help',
description: 'Show help for this prompt',
kind: CommandKind.MCP_PROMPT,
action: async (): Promise<SlashCommandActionReturn> => {
if (!prompt.arguments || prompt.arguments.length === 0) {
return {
type: 'message',
messageType: 'info',
content: `Prompt "${prompt.name}" has no arguments.`,
};
}
let helpMessage = `Arguments for "${prompt.name}":\n\n`;
if (prompt.arguments && prompt.arguments.length > 0) {
helpMessage += `You can provide arguments by name (e.g., --argName="value") or by position.\n\n`;
helpMessage += `e.g., ${prompt.name} ${prompt.arguments?.map((_) => `"foo"`)} is equivalent to ${prompt.name} ${prompt.arguments?.map((arg) => `--${arg.name}="foo"`)}\n\n`;
}
for (const arg of prompt.arguments) {
helpMessage += ` --${arg.name}\n`;
if (arg.description) {
helpMessage += ` ${arg.description}\n`;
}
helpMessage += ` (required: ${
arg.required ? 'yes' : 'no'
})\n\n`;
}
return {
type: 'message',
messageType: 'info',
content: helpMessage,
};
},
},
],
action: async (
context: CommandContext,
args: string,
): Promise<SlashCommandActionReturn> => {
if (!this.config) {
return {
type: 'message',
messageType: 'error',
content: 'Config not loaded.',
};
}
const promptInputs = this.parseArgs(args, prompt.arguments);
if (promptInputs instanceof Error) {
return {
type: 'message',
messageType: 'error',
content: promptInputs.message,
};
}
try {
const mcpServers = this.config.getMcpServers() || {};
const mcpServerConfig = mcpServers[serverName];
if (!mcpServerConfig) {
return {
type: 'message',
messageType: 'error',
content: `MCP server config not found for '${serverName}'.`,
};
}
const result = await prompt.invoke(promptInputs);
if (result.error) {
return {
type: 'message',
messageType: 'error',
content: `Error invoking prompt: ${result.error}`,
};
}
if (!result.messages?.[0]?.content?.text) {
return {
type: 'message',
messageType: 'error',
content:
'Received an empty or invalid prompt response from the server.',
};
}
return {
type: 'submit_prompt',
content: JSON.stringify(result.messages[0].content.text),
};
} catch (error) {
return {
type: 'message',
messageType: 'error',
content: `Error: ${getErrorMessage(error)}`,
};
}
},
completion: async (_: CommandContext, partialArg: string) => {
if (!prompt || !prompt.arguments) {
return [];
}
const suggestions: string[] = [];
const usedArgNames = new Set(
(partialArg.match(/--([^=]+)/g) || []).map((s) => s.substring(2)),
);
for (const arg of prompt.arguments) {
if (!usedArgNames.has(arg.name)) {
suggestions.push(`--${arg.name}=""`);
}
}
return suggestions;
},
};
promptCommands.push(newPromptCommand);
}
}
return Promise.resolve(promptCommands);
}
private parseArgs(
userArgs: string,
promptArgs: PromptArgument[] | undefined,
): Record<string, unknown> | Error {
const argValues: { [key: string]: string } = {};
const promptInputs: Record<string, unknown> = {};
// arg parsing: --key="value" or --key=value
const namedArgRegex = /--([^=]+)=(?:"((?:\\.|[^"\\])*)"|([^ ]*))/g;
let match;
const remainingArgs: string[] = [];
let lastIndex = 0;
while ((match = namedArgRegex.exec(userArgs)) !== null) {
const key = match[1];
const value = match[2] ?? match[3]; // Quoted or unquoted value
argValues[key] = value;
// Capture text between matches as potential positional args
if (match.index > lastIndex) {
remainingArgs.push(userArgs.substring(lastIndex, match.index).trim());
}
lastIndex = namedArgRegex.lastIndex;
}
// Capture any remaining text after the last named arg
if (lastIndex < userArgs.length) {
remainingArgs.push(userArgs.substring(lastIndex).trim());
}
const positionalArgs = remainingArgs.join(' ').split(/ +/);
if (!promptArgs) {
return promptInputs;
}
for (const arg of promptArgs) {
if (argValues[arg.name]) {
promptInputs[arg.name] = argValues[arg.name];
}
}
const unfilledArgs = promptArgs.filter(
(arg) => arg.required && !promptInputs[arg.name],
);
const missingArgs: string[] = [];
for (let i = 0; i < unfilledArgs.length; i++) {
if (positionalArgs.length > i && positionalArgs[i]) {
promptInputs[unfilledArgs[i].name] = positionalArgs[i];
} else {
missingArgs.push(unfilledArgs[i].name);
}
}
if (missingArgs.length > 0) {
const missingArgNames = missingArgs.map((name) => `--${name}`).join(', ');
return new Error(`Missing required argument(s): ${missingArgNames}`);
}
return promptInputs;
}
}

View File

@@ -125,6 +125,7 @@ vi.mock('@google/gemini-cli-core', async (importOriginal) => {
getToolCallCommand: vi.fn(() => opts.toolCallCommand),
getMcpServerCommand: vi.fn(() => opts.mcpServerCommand),
getMcpServers: vi.fn(() => opts.mcpServers),
getPromptRegistry: vi.fn(),
getExtensions: vi.fn(() => []),
getBlockedMcpServers: vi.fn(() => []),
getUserAgent: vi.fn(() => opts.userAgent || 'test-agent'),

View File

@@ -71,6 +71,7 @@ describe('mcpCommand', () => {
getToolRegistry: ReturnType<typeof vi.fn>;
getMcpServers: ReturnType<typeof vi.fn>;
getBlockedMcpServers: ReturnType<typeof vi.fn>;
getPromptRegistry: ReturnType<typeof vi.fn>;
};
beforeEach(() => {
@@ -92,6 +93,10 @@ describe('mcpCommand', () => {
}),
getMcpServers: vi.fn().mockReturnValue({}),
getBlockedMcpServers: vi.fn().mockReturnValue([]),
getPromptRegistry: vi.fn().mockResolvedValue({
getAllPrompts: vi.fn().mockReturnValue([]),
getPromptsByServer: vi.fn().mockReturnValue([]),
}),
};
mockContext = createMockCommandContext({
@@ -223,7 +228,7 @@ describe('mcpCommand', () => {
// Server 2 - Connected
expect(message).toContain(
'🟢 \u001b[1mserver2\u001b[0m - Ready (1 tools)',
'🟢 \u001b[1mserver2\u001b[0m - Ready (1 tool)',
);
expect(message).toContain('server2_tool1');
@@ -365,13 +370,13 @@ describe('mcpCommand', () => {
if (isMessageAction(result)) {
const message = result.content;
expect(message).toContain(
'🟢 \u001b[1mserver1\u001b[0m - Ready (1 tools)',
'🟢 \u001b[1mserver1\u001b[0m - Ready (1 tool)',
);
expect(message).toContain('\u001b[36mserver1_tool1\u001b[0m');
expect(message).toContain(
'🔴 \u001b[1mserver2\u001b[0m - Disconnected (0 tools cached)',
);
expect(message).toContain('No tools available');
expect(message).toContain('No tools or prompts available');
}
});
@@ -421,10 +426,10 @@ describe('mcpCommand', () => {
// Check server statuses
expect(message).toContain(
'🟢 \u001b[1mserver1\u001b[0m - Ready (1 tools)',
'🟢 \u001b[1mserver1\u001b[0m - Ready (1 tool)',
);
expect(message).toContain(
'🔄 \u001b[1mserver2\u001b[0m - Starting... (first startup may take longer) (tools will appear when ready)',
'🔄 \u001b[1mserver2\u001b[0m - Starting... (first startup may take longer) (tools and prompts will appear when ready)',
);
}
});
@@ -994,6 +999,9 @@ describe('mcpCommand', () => {
getBlockedMcpServers: vi.fn().mockReturnValue([]),
getToolRegistry: vi.fn().mockResolvedValue(mockToolRegistry),
getGeminiClient: vi.fn().mockReturnValue(mockGeminiClient),
getPromptRegistry: vi.fn().mockResolvedValue({
getPromptsByServer: vi.fn().mockReturnValue([]),
}),
},
},
});

View File

@@ -12,6 +12,7 @@ import {
MessageActionReturn,
} from './types.js';
import {
DiscoveredMCPPrompt,
DiscoveredMCPTool,
getMCPDiscoveryState,
getMCPServerStatus,
@@ -101,6 +102,8 @@ const getMcpStatus = async (
(tool) =>
tool instanceof DiscoveredMCPTool && tool.serverName === serverName,
) as DiscoveredMCPTool[];
const promptRegistry = await config.getPromptRegistry();
const serverPrompts = promptRegistry.getPromptsByServer(serverName) || [];
const status = getMCPServerStatus(serverName);
@@ -160,9 +163,26 @@ const getMcpStatus = async (
// Add tool count with conditional messaging
if (status === MCPServerStatus.CONNECTED) {
message += ` (${serverTools.length} tools)`;
const parts = [];
if (serverTools.length > 0) {
parts.push(
`${serverTools.length} ${serverTools.length === 1 ? 'tool' : 'tools'}`,
);
}
if (serverPrompts.length > 0) {
parts.push(
`${serverPrompts.length} ${
serverPrompts.length === 1 ? 'prompt' : 'prompts'
}`,
);
}
if (parts.length > 0) {
message += ` (${parts.join(', ')})`;
} else {
message += ` (0 tools)`;
}
} else if (status === MCPServerStatus.CONNECTING) {
message += ` (tools will appear when ready)`;
message += ` (tools and prompts will appear when ready)`;
} else {
message += ` (${serverTools.length} tools cached)`;
}
@@ -186,6 +206,7 @@ const getMcpStatus = async (
message += RESET_COLOR;
if (serverTools.length > 0) {
message += ` ${COLOR_CYAN}Tools:${RESET_COLOR}\n`;
serverTools.forEach((tool) => {
if (showDescriptions && tool.description) {
// Format tool name in cyan using simple ANSI cyan color
@@ -222,12 +243,41 @@ const getMcpStatus = async (
}
}
});
} else {
}
if (serverPrompts.length > 0) {
if (serverTools.length > 0) {
message += '\n';
}
message += ` ${COLOR_CYAN}Prompts:${RESET_COLOR}\n`;
serverPrompts.forEach((prompt: DiscoveredMCPPrompt) => {
if (showDescriptions && prompt.description) {
message += ` - ${COLOR_CYAN}${prompt.name}${RESET_COLOR}`;
const descLines = prompt.description.trim().split('\n');
if (descLines) {
message += ':\n';
for (const descLine of descLines) {
message += ` ${COLOR_GREEN}${descLine}${RESET_COLOR}\n`;
}
} else {
message += '\n';
}
} else {
message += ` - ${COLOR_CYAN}${prompt.name}${RESET_COLOR}\n`;
}
});
}
if (serverTools.length === 0 && serverPrompts.length === 0) {
message += ' No tools or prompts available\n';
} else if (serverTools.length === 0) {
message += ' No tools available';
if (status === MCPServerStatus.DISCONNECTED && needsAuthHint) {
message += ` ${COLOR_GREY}(type: "/mcp auth ${serverName}" to authenticate this server)${RESET_COLOR}`;
}
message += '\n';
} else if (status === MCPServerStatus.DISCONNECTED && needsAuthHint) {
// This case is for when serverTools.length > 0
message += ` ${COLOR_GREY}(type: "/mcp auth ${serverName}" to authenticate this server)${RESET_COLOR}\n`;
}
message += '\n';
}
@@ -328,11 +378,10 @@ const authCommand: SlashCommand = {
// Import dynamically to avoid circular dependencies
const { MCPOAuthProvider } = await import('@google/gemini-cli-core');
// Create OAuth config for authentication (will be discovered automatically)
const oauthConfig = server.oauth || {
authorizationUrl: '', // Will be discovered automatically
tokenUrl: '', // Will be discovered automatically
};
let oauthConfig = server.oauth;
if (!oauthConfig) {
oauthConfig = { enabled: false };
}
// Pass the MCP server URL for OAuth discovery
const mcpServerUrl = server.httpUrl || server.url;

View File

@@ -128,6 +128,7 @@ export type SlashCommandActionReturn =
export enum CommandKind {
BUILT_IN = 'built-in',
FILE = 'file',
MCP_PROMPT = 'mcp-prompt',
}
// The standardized contract for any command in the system.

View File

@@ -28,6 +28,13 @@ vi.mock('../../services/FileCommandLoader.js', () => ({
})),
}));
const mockMcpLoadCommands = vi.fn();
vi.mock('../../services/McpPromptLoader.js', () => ({
McpPromptLoader: vi.fn().mockImplementation(() => ({
loadCommands: mockMcpLoadCommands,
})),
}));
vi.mock('../contexts/SessionContext.js', () => ({
useSessionStats: vi.fn(() => ({ stats: {} })),
}));
@@ -41,6 +48,7 @@ import { LoadedSettings } from '../../config/settings.js';
import { MessageType } from '../types.js';
import { BuiltinCommandLoader } from '../../services/BuiltinCommandLoader.js';
import { FileCommandLoader } from '../../services/FileCommandLoader.js';
import { McpPromptLoader } from '../../services/McpPromptLoader.js';
const createTestCommand = (
overrides: Partial<SlashCommand>,
@@ -75,14 +83,17 @@ describe('useSlashCommandProcessor', () => {
(vi.mocked(BuiltinCommandLoader) as Mock).mockClear();
mockBuiltinLoadCommands.mockResolvedValue([]);
mockFileLoadCommands.mockResolvedValue([]);
mockMcpLoadCommands.mockResolvedValue([]);
});
const setupProcessorHook = (
builtinCommands: SlashCommand[] = [],
fileCommands: SlashCommand[] = [],
mcpCommands: SlashCommand[] = [],
) => {
mockBuiltinLoadCommands.mockResolvedValue(Object.freeze(builtinCommands));
mockFileLoadCommands.mockResolvedValue(Object.freeze(fileCommands));
mockMcpLoadCommands.mockResolvedValue(Object.freeze(mcpCommands));
const { result } = renderHook(() =>
useSlashCommandProcessor(
@@ -111,6 +122,7 @@ describe('useSlashCommandProcessor', () => {
setupProcessorHook();
expect(BuiltinCommandLoader).toHaveBeenCalledWith(mockConfig);
expect(FileCommandLoader).toHaveBeenCalledWith(mockConfig);
expect(McpPromptLoader).toHaveBeenCalledWith(mockConfig);
});
it('should call loadCommands and populate state after mounting', async () => {
@@ -124,6 +136,7 @@ describe('useSlashCommandProcessor', () => {
expect(result.current.slashCommands[0]?.name).toBe('test');
expect(mockBuiltinLoadCommands).toHaveBeenCalledTimes(1);
expect(mockFileLoadCommands).toHaveBeenCalledTimes(1);
expect(mockMcpLoadCommands).toHaveBeenCalledTimes(1);
});
it('should provide an immutable array of commands to consumers', async () => {
@@ -369,6 +382,38 @@ describe('useSlashCommandProcessor', () => {
expect.any(Number),
);
});
it('should handle "submit_prompt" action returned from a mcp-based command', async () => {
const mcpCommand = createTestCommand(
{
name: 'mcpcmd',
description: 'A command from mcp',
action: async () => ({
type: 'submit_prompt',
content: 'The actual prompt from the mcp command.',
}),
},
CommandKind.MCP_PROMPT,
);
const result = setupProcessorHook([], [], [mcpCommand]);
await waitFor(() => expect(result.current.slashCommands).toHaveLength(1));
let actionResult;
await act(async () => {
actionResult = await result.current.handleSlashCommand('/mcpcmd');
});
expect(actionResult).toEqual({
type: 'submit_prompt',
content: 'The actual prompt from the mcp command.',
});
expect(mockAddItem).toHaveBeenCalledWith(
{ type: MessageType.USER, text: '/mcpcmd' },
expect.any(Number),
);
});
});
describe('Command Parsing and Matching', () => {
@@ -441,6 +486,39 @@ describe('useSlashCommandProcessor', () => {
});
describe('Command Precedence', () => {
it('should override mcp-based commands with file-based commands of the same name', async () => {
const mcpAction = vi.fn();
const fileAction = vi.fn();
const mcpCommand = createTestCommand(
{
name: 'override',
description: 'mcp',
action: mcpAction,
},
CommandKind.MCP_PROMPT,
);
const fileCommand = createTestCommand(
{ name: 'override', description: 'file', action: fileAction },
CommandKind.FILE,
);
const result = setupProcessorHook([], [fileCommand], [mcpCommand]);
await waitFor(() => {
// The service should only return one command with the name 'override'
expect(result.current.slashCommands).toHaveLength(1);
});
await act(async () => {
await result.current.handleSlashCommand('/override');
});
// Only the file-based command's action should be called.
expect(fileAction).toHaveBeenCalledTimes(1);
expect(mcpAction).not.toHaveBeenCalled();
});
it('should prioritize a command with a primary name over a command with a matching alias', async () => {
const quitAction = vi.fn();
const exitAction = vi.fn();

View File

@@ -23,6 +23,7 @@ import { type CommandContext, type SlashCommand } from '../commands/types.js';
import { CommandService } from '../../services/CommandService.js';
import { BuiltinCommandLoader } from '../../services/BuiltinCommandLoader.js';
import { FileCommandLoader } from '../../services/FileCommandLoader.js';
import { McpPromptLoader } from '../../services/McpPromptLoader.js';
/**
* Hook to define and process slash commands (e.g., /help, /clear).
@@ -164,6 +165,7 @@ export const useSlashCommandProcessor = (
const controller = new AbortController();
const load = async () => {
const loaders = [
new McpPromptLoader(config),
new BuiltinCommandLoader(config),
new FileCommandLoader(config),
];
@@ -246,82 +248,95 @@ export const useSlashCommandProcessor = (
args,
},
};
const result = await commandToExecute.action(
fullCommandContext,
args,
);
try {
const result = await commandToExecute.action(
fullCommandContext,
args,
);
if (result) {
switch (result.type) {
case 'tool':
return {
type: 'schedule_tool',
toolName: result.toolName,
toolArgs: result.toolArgs,
};
case 'message':
addItem(
{
type:
result.messageType === 'error'
? MessageType.ERROR
: MessageType.INFO,
text: result.content,
},
Date.now(),
);
return { type: 'handled' };
case 'dialog':
switch (result.dialog) {
case 'help':
setShowHelp(true);
return { type: 'handled' };
case 'auth':
openAuthDialog();
return { type: 'handled' };
case 'theme':
openThemeDialog();
return { type: 'handled' };
case 'editor':
openEditorDialog();
return { type: 'handled' };
case 'privacy':
openPrivacyNotice();
return { type: 'handled' };
default: {
const unhandled: never = result.dialog;
throw new Error(
`Unhandled slash command result: ${unhandled}`,
);
if (result) {
switch (result.type) {
case 'tool':
return {
type: 'schedule_tool',
toolName: result.toolName,
toolArgs: result.toolArgs,
};
case 'message':
addItem(
{
type:
result.messageType === 'error'
? MessageType.ERROR
: MessageType.INFO,
text: result.content,
},
Date.now(),
);
return { type: 'handled' };
case 'dialog':
switch (result.dialog) {
case 'help':
setShowHelp(true);
return { type: 'handled' };
case 'auth':
openAuthDialog();
return { type: 'handled' };
case 'theme':
openThemeDialog();
return { type: 'handled' };
case 'editor':
openEditorDialog();
return { type: 'handled' };
case 'privacy':
openPrivacyNotice();
return { type: 'handled' };
default: {
const unhandled: never = result.dialog;
throw new Error(
`Unhandled slash command result: ${unhandled}`,
);
}
}
case 'load_history': {
await config
?.getGeminiClient()
?.setHistory(result.clientHistory);
fullCommandContext.ui.clear();
result.history.forEach((item, index) => {
fullCommandContext.ui.addItem(item, index);
});
return { type: 'handled' };
}
case 'load_history': {
await config
?.getGeminiClient()
?.setHistory(result.clientHistory);
fullCommandContext.ui.clear();
result.history.forEach((item, index) => {
fullCommandContext.ui.addItem(item, index);
});
return { type: 'handled' };
}
case 'quit':
setQuittingMessages(result.messages);
setTimeout(() => {
process.exit(0);
}, 100);
return { type: 'handled' };
case 'quit':
setQuittingMessages(result.messages);
setTimeout(() => {
process.exit(0);
}, 100);
return { type: 'handled' };
case 'submit_prompt':
return {
type: 'submit_prompt',
content: result.content,
};
default: {
const unhandled: never = result;
throw new Error(`Unhandled slash command result: ${unhandled}`);
case 'submit_prompt':
return {
type: 'submit_prompt',
content: result.content,
};
default: {
const unhandled: never = result;
throw new Error(
`Unhandled slash command result: ${unhandled}`,
);
}
}
}
} catch (e) {
addItem(
{
type: MessageType.ERROR,
text: e instanceof Error ? e.message : String(e),
},
Date.now(),
);
return { type: 'handled' };
}
return { type: 'handled' };

View File

@@ -1100,7 +1100,7 @@ describe('useCompletion', () => {
result.current.handleAutocomplete(0);
});
expect(mockBuffer.setText).toHaveBeenCalledWith('/memory');
expect(mockBuffer.setText).toHaveBeenCalledWith('/memory ');
});
it('should append a sub-command when the parent is complete', () => {
@@ -1145,7 +1145,7 @@ describe('useCompletion', () => {
result.current.handleAutocomplete(1); // index 1 is 'add'
});
expect(mockBuffer.setText).toHaveBeenCalledWith('/memory add');
expect(mockBuffer.setText).toHaveBeenCalledWith('/memory add ');
});
it('should complete a command with an alternative name', () => {
@@ -1190,7 +1190,7 @@ describe('useCompletion', () => {
result.current.handleAutocomplete(0);
});
expect(mockBuffer.setText).toHaveBeenCalledWith('/help');
expect(mockBuffer.setText).toHaveBeenCalledWith('/help ');
});
it('should complete a file path', async () => {

View File

@@ -638,10 +638,17 @@ export function useCompletion(
// Determine the base path of the command.
// - If there's a trailing space, the whole command is the base.
// - If it's a known parent path, the whole command is the base.
// - If the last part is a complete argument, the whole command is the base.
// - Otherwise, the base is everything EXCEPT the last partial part.
const lastPart = parts.length > 0 ? parts[parts.length - 1] : '';
const isLastPartACompleteArg =
lastPart.startsWith('--') && lastPart.includes('=');
const basePath =
hasTrailingSpace || isParentPath ? parts : parts.slice(0, -1);
const newValue = `/${[...basePath, suggestion].join(' ')}`;
hasTrailingSpace || isParentPath || isLastPartACompleteArg
? parts
: parts.slice(0, -1);
const newValue = `/${[...basePath, suggestion].join(' ')} `;
buffer.setText(newValue);
} else {