Pure refactor: Consolidate isWithinRoot() function calling. (#4163)

This commit is contained in:
Tommaso Sciortino
2025-07-14 22:55:49 -07:00
committed by GitHub
parent e584241141
commit fefa7ecbea
13 changed files with 96 additions and 179 deletions

View File

@@ -11,6 +11,7 @@ import { SchemaValidator } from '../utils/schemaValidator.js';
import { BaseTool, ToolResult } from './tools.js';
import { Type } from '@google/genai';
import { shortenPath, makeRelative } from '../utils/paths.js';
import { isWithinRoot } from '../utils/fileUtils.js';
import { Config } from '../config/config.js';
// Subset of 'Path' interface provided by 'glob' that we can implement for testing
@@ -79,14 +80,8 @@ export interface GlobToolParams {
*/
export class GlobTool extends BaseTool<GlobToolParams, ToolResult> {
static readonly Name = 'glob';
/**
* Creates a new instance of the GlobLogic
* @param rootDirectory Root directory to ground this tool in.
*/
constructor(
private rootDirectory: string,
private config: Config,
) {
constructor(private config: Config) {
super(
GlobTool.Name,
'FindFiles',
@@ -118,28 +113,6 @@ export class GlobTool extends BaseTool<GlobToolParams, ToolResult> {
type: Type.OBJECT,
},
);
this.rootDirectory = path.resolve(rootDirectory);
}
/**
* Checks if a given path is within the root directory bounds.
* This security check prevents accessing files outside the designated root directory.
*
* @param pathToCheck The absolute path to validate
* @returns True if the path is within the root directory, false otherwise
*/
private isWithinRoot(pathToCheck: string): boolean {
const absolutePathToCheck = path.resolve(pathToCheck);
const normalizedPath = path.normalize(absolutePathToCheck);
const normalizedRoot = path.normalize(this.rootDirectory);
const rootWithSep = normalizedRoot.endsWith(path.sep)
? normalizedRoot
: normalizedRoot + path.sep;
return (
normalizedPath === normalizedRoot ||
normalizedPath.startsWith(rootWithSep)
);
}
/**
@@ -152,15 +125,15 @@ export class GlobTool extends BaseTool<GlobToolParams, ToolResult> {
}
const searchDirAbsolute = path.resolve(
this.rootDirectory,
this.config.getTargetDir(),
params.path || '.',
);
if (!this.isWithinRoot(searchDirAbsolute)) {
return `Search path ("${searchDirAbsolute}") resolves outside the tool's root directory ("${this.rootDirectory}").`;
if (!isWithinRoot(searchDirAbsolute, this.config.getTargetDir())) {
return `Search path ("${searchDirAbsolute}") resolves outside the tool's root directory ("${this.config.getTargetDir()}").`;
}
const targetDir = searchDirAbsolute || this.rootDirectory;
const targetDir = searchDirAbsolute || this.config.getTargetDir();
try {
if (!fs.existsSync(targetDir)) {
return `Search path does not exist ${targetDir}`;
@@ -189,8 +162,11 @@ export class GlobTool extends BaseTool<GlobToolParams, ToolResult> {
getDescription(params: GlobToolParams): string {
let description = `'${params.pattern}'`;
if (params.path) {
const searchDir = path.resolve(this.rootDirectory, params.path || '.');
const relativePath = makeRelative(searchDir, this.rootDirectory);
const searchDir = path.resolve(
this.config.getTargetDir(),
params.path || '.',
);
const relativePath = makeRelative(searchDir, this.config.getTargetDir());
description += ` within ${shortenPath(relativePath)}`;
}
return description;
@@ -213,7 +189,7 @@ export class GlobTool extends BaseTool<GlobToolParams, ToolResult> {
try {
const searchDirAbsolute = path.resolve(
this.rootDirectory,
this.config.getTargetDir(),
params.path || '.',
);
@@ -241,13 +217,15 @@ export class GlobTool extends BaseTool<GlobToolParams, ToolResult> {
if (respectGitIgnore) {
const relativePaths = entries.map((p) =>
path.relative(this.rootDirectory, p.fullpath()),
path.relative(this.config.getTargetDir(), p.fullpath()),
);
const filteredRelativePaths = fileDiscovery.filterFiles(relativePaths, {
respectGitIgnore,
});
const filteredAbsolutePaths = new Set(
filteredRelativePaths.map((p) => path.resolve(this.rootDirectory, p)),
filteredRelativePaths.map((p) =>
path.resolve(this.config.getTargetDir(), p),
),
);
filteredEntries = entries.filter((entry) =>