This is page 4 of 6. Use http://codebase.md/bsmi021/mcp-gemini-server?page={x} to view the full context.
# Directory Structure
```
├── .env.example
├── .eslintignore
├── .eslintrc.json
├── .gitignore
├── .prettierrc.json
├── Dockerfile
├── LICENSE
├── package-lock.json
├── package.json
├── README.md
├── review-prompt.txt
├── scripts
│ ├── gemini-review.sh
│ └── run-with-health-check.sh
├── smithery.yaml
├── src
│ ├── config
│ │ └── ConfigurationManager.ts
│ ├── createServer.ts
│ ├── index.ts
│ ├── resources
│ │ └── system-prompt.md
│ ├── server.ts
│ ├── services
│ │ ├── ExampleService.ts
│ │ ├── gemini
│ │ │ ├── GeminiCacheService.ts
│ │ │ ├── GeminiChatService.ts
│ │ │ ├── GeminiContentService.ts
│ │ │ ├── GeminiGitDiffService.ts
│ │ │ ├── GeminiPromptTemplates.ts
│ │ │ ├── GeminiTypes.ts
│ │ │ ├── GeminiUrlContextService.ts
│ │ │ ├── GeminiValidationSchemas.ts
│ │ │ ├── GitHubApiService.ts
│ │ │ ├── GitHubUrlParser.ts
│ │ │ └── ModelMigrationService.ts
│ │ ├── GeminiService.ts
│ │ ├── index.ts
│ │ ├── mcp
│ │ │ ├── index.ts
│ │ │ └── McpClientService.ts
│ │ ├── ModelSelectionService.ts
│ │ ├── session
│ │ │ ├── index.ts
│ │ │ ├── InMemorySessionStore.ts
│ │ │ ├── SessionStore.ts
│ │ │ └── SQLiteSessionStore.ts
│ │ └── SessionService.ts
│ ├── tools
│ │ ├── exampleToolParams.ts
│ │ ├── geminiCacheParams.ts
│ │ ├── geminiCacheTool.ts
│ │ ├── geminiChatParams.ts
│ │ ├── geminiChatTool.ts
│ │ ├── geminiCodeReviewParams.ts
│ │ ├── geminiCodeReviewTool.ts
│ │ ├── geminiGenerateContentConsolidatedParams.ts
│ │ ├── geminiGenerateContentConsolidatedTool.ts
│ │ ├── geminiGenerateImageParams.ts
│ │ ├── geminiGenerateImageTool.ts
│ │ ├── geminiGenericParamSchemas.ts
│ │ ├── geminiRouteMessageParams.ts
│ │ ├── geminiRouteMessageTool.ts
│ │ ├── geminiUrlAnalysisTool.ts
│ │ ├── index.ts
│ │ ├── mcpClientParams.ts
│ │ ├── mcpClientTool.ts
│ │ ├── registration
│ │ │ ├── index.ts
│ │ │ ├── registerAllTools.ts
│ │ │ ├── ToolAdapter.ts
│ │ │ └── ToolRegistry.ts
│ │ ├── schemas
│ │ │ ├── BaseToolSchema.ts
│ │ │ ├── CommonSchemas.ts
│ │ │ ├── index.ts
│ │ │ ├── ToolSchemas.ts
│ │ │ └── writeToFileParams.ts
│ │ └── writeToFileTool.ts
│ ├── types
│ │ ├── exampleServiceTypes.ts
│ │ ├── geminiServiceTypes.ts
│ │ ├── gitdiff-parser.d.ts
│ │ ├── googleGenAI.d.ts
│ │ ├── googleGenAITypes.ts
│ │ ├── index.ts
│ │ ├── micromatch.d.ts
│ │ ├── modelcontextprotocol-sdk.d.ts
│ │ ├── node-fetch.d.ts
│ │ └── serverTypes.ts
│ └── utils
│ ├── errors.ts
│ ├── filePathSecurity.ts
│ ├── FileSecurityService.ts
│ ├── geminiErrors.ts
│ ├── healthCheck.ts
│ ├── index.ts
│ ├── logger.ts
│ ├── RetryService.ts
│ ├── ToolError.ts
│ └── UrlSecurityService.ts
├── tests
│ ├── .env.test.example
│ ├── basic-router.test.vitest.ts
│ ├── e2e
│ │ ├── clients
│ │ │ └── mcp-test-client.ts
│ │ ├── README.md
│ │ └── streamableHttpTransport.test.vitest.ts
│ ├── integration
│ │ ├── dummyMcpServerSse.ts
│ │ ├── dummyMcpServerStdio.ts
│ │ ├── geminiRouterIntegration.test.vitest.ts
│ │ ├── mcpClientIntegration.test.vitest.ts
│ │ ├── multiModelIntegration.test.vitest.ts
│ │ └── urlContextIntegration.test.vitest.ts
│ ├── tsconfig.test.json
│ ├── unit
│ │ ├── config
│ │ │ └── ConfigurationManager.multimodel.test.vitest.ts
│ │ ├── server
│ │ │ └── transportLogic.test.vitest.ts
│ │ ├── services
│ │ │ ├── gemini
│ │ │ │ ├── GeminiChatService.test.vitest.ts
│ │ │ │ ├── GeminiGitDiffService.test.vitest.ts
│ │ │ │ ├── geminiImageGeneration.test.vitest.ts
│ │ │ │ ├── GeminiPromptTemplates.test.vitest.ts
│ │ │ │ ├── GeminiUrlContextService.test.vitest.ts
│ │ │ │ ├── GeminiValidationSchemas.test.vitest.ts
│ │ │ │ ├── GitHubApiService.test.vitest.ts
│ │ │ │ ├── GitHubUrlParser.test.vitest.ts
│ │ │ │ └── ThinkingBudget.test.vitest.ts
│ │ │ ├── mcp
│ │ │ │ └── McpClientService.test.vitest.ts
│ │ │ ├── ModelSelectionService.test.vitest.ts
│ │ │ └── session
│ │ │ └── SQLiteSessionStore.test.vitest.ts
│ │ ├── tools
│ │ │ ├── geminiCacheTool.test.vitest.ts
│ │ │ ├── geminiChatTool.test.vitest.ts
│ │ │ ├── geminiCodeReviewTool.test.vitest.ts
│ │ │ ├── geminiGenerateContentConsolidatedTool.test.vitest.ts
│ │ │ ├── geminiGenerateImageTool.test.vitest.ts
│ │ │ ├── geminiRouteMessageTool.test.vitest.ts
│ │ │ ├── mcpClientTool.test.vitest.ts
│ │ │ ├── mcpToolsTests.test.vitest.ts
│ │ │ └── schemas
│ │ │ ├── BaseToolSchema.test.vitest.ts
│ │ │ ├── ToolParamSchemas.test.vitest.ts
│ │ │ └── ToolSchemas.test.vitest.ts
│ │ └── utils
│ │ ├── errors.test.vitest.ts
│ │ ├── FileSecurityService.test.vitest.ts
│ │ ├── FileSecurityService.vitest.ts
│ │ ├── FileSecurityServiceBasics.test.vitest.ts
│ │ ├── healthCheck.test.vitest.ts
│ │ ├── RetryService.test.vitest.ts
│ │ └── UrlSecurityService.test.vitest.ts
│ └── utils
│ ├── assertions.ts
│ ├── debug-error.ts
│ ├── env-check.ts
│ ├── environment.ts
│ ├── error-helpers.ts
│ ├── express-mocks.ts
│ ├── integration-types.ts
│ ├── mock-types.ts
│ ├── test-fixtures.ts
│ ├── test-generators.ts
│ ├── test-setup.ts
│ └── vitest.d.ts
├── tsconfig.json
├── tsconfig.test.json
├── vitest-globals.d.ts
├── vitest.config.ts
└── vitest.setup.ts
```
# Files
--------------------------------------------------------------------------------
/src/tools/geminiUrlAnalysisTool.ts:
--------------------------------------------------------------------------------
```typescript
import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import { z } from "zod";
import { GeminiService } from "../services/index.js";
import { logger } from "../utils/index.js";
import { mapAnyErrorToMcpError } from "../utils/errors.js";
// Tool Name and Description
/** MCP tool name under which the URL analysis tool is registered. */
export const GEMINI_URL_ANALYSIS_TOOL_NAME = "gemini_url_analysis";
/** Human-readable description advertised to MCP clients for this tool. */
export const GEMINI_URL_ANALYSIS_TOOL_DESCRIPTION = `
Advanced URL analysis tool that fetches content from web pages and performs specialized analysis tasks.
Supports various analysis types including summarization, comparison, information extraction, and Q&A.
Automatically handles URL fetching, content processing, and intelligent model selection for optimal results.
`;
// Analysis types enum
/**
 * Zod enum of the supported analysis modes. The selected value drives the
 * prompt wording (buildAnalysisPrompt), the system instruction
 * (getSystemInstructionForAnalysis) and the model task type
 * (getTaskTypeForAnalysis).
 */
const analysisTypeSchema = z
  .enum([
    "summary",
    "comparison",
    "extraction",
    "qa",
    "sentiment",
    "fact-check",
    "content-classification",
    "readability",
    "seo-analysis",
  ])
  .describe("Type of analysis to perform on the URL content");
// Extraction schema for structured data extraction
/**
 * Optional free-form object describing the structure to extract.
 * It is not validated as JSON Schema here. For "extraction" analyses it is
 * serialized into the prompt verbatim with JSON.stringify.
 */
const extractionSchemaSchema = z
  .record(z.unknown())
  .optional()
  .describe(
    "JSON schema or structure definition for extracting specific information from content"
  );
// Parameters for the URL analysis tool
/**
 * Raw Zod shape of the gemini_url_analysis tool's parameters.
 * It is passed to `server.tool` for registration and wrapped in `z.object`
 * for parsing inside the handler.
 *
 * NOTE(review): several fields are declared as `.default(x).optional()`.
 * With zod's outer `.optional()` a missing value is likely returned as
 * `undefined` before the inner `.default()` runs, so these defaults may not
 * be applied at parse time. Confirm against the zod version in use. The
 * handler falls back explicitly (`includeMetadata ?? true`, a 100 KB default
 * for `maxContentKb`); other defaults are left to downstream consumers.
 */
export const GEMINI_URL_ANALYSIS_PARAMS = {
  urls: z
    .array(z.string().url())
    .min(1)
    .max(20)
    .describe("URLs to analyze (1-20 URLs supported)"),
  analysisType: analysisTypeSchema,
  query: z
    .string()
    .min(1)
    .optional()
    .describe("Specific query or instruction for the analysis"),
  extractionSchema: extractionSchemaSchema,
  questions: z
    .array(z.string())
    .optional()
    .describe("List of specific questions to answer (for Q&A analysis)"),
  compareBy: z
    .array(z.string())
    .optional()
    .describe("Specific aspects to compare when using comparison analysis"),
  outputFormat: z
    .enum(["text", "json", "markdown", "structured"])
    .default("text")
    .optional()
    .describe("Desired output format for the analysis results"),
  includeMetadata: z
    .boolean()
    .default(true)
    .optional()
    .describe(
      "Include URL metadata (title, description, etc.) in the analysis"
    ),
  // Per-request fetch tuning; forwarded as urlContext.fetchOptions.
  fetchOptions: z
    .object({
      maxContentKb: z
        .number()
        .min(1)
        .max(1000)
        .default(100)
        .optional()
        .describe("Maximum content size per URL in KB"),
      timeoutMs: z
        .number()
        .min(1000)
        .max(30000)
        .default(10000)
        .optional()
        .describe("Fetch timeout per URL in milliseconds"),
      allowedDomains: z
        .array(z.string())
        .optional()
        .describe("Specific domains to allow for this request"),
      userAgent: z
        .string()
        .optional()
        .describe("Custom User-Agent header for URL requests"),
    })
    .optional()
    .describe("Advanced options for URL fetching"),
  modelName: z
    .string()
    .optional()
    .describe("Specific Gemini model to use (auto-selected if not specified)"),
};
/**
 * Registers the gemini_url_analysis tool with the MCP server.
 * Provides specialized URL analysis capabilities with intelligent content processing.
 *
 * The registered handler:
 * - validates its arguments against `GEMINI_URL_ANALYSIS_PARAMS`;
 * - builds an analysis-specific prompt and system instruction;
 * - delegates generation to `serviceInstance.generateContent`, passing the URLs
 *   as `urlContext`.
 *
 * Every failure is logged and rethrown through `mapAnyErrorToMcpError`,
 * including argument validation.
 *
 * @param server The MCP server to register the tool on.
 * @param serviceInstance Gemini service used to run the analysis.
 */
export const geminiUrlAnalysisTool = (
  server: McpServer,
  serviceInstance: GeminiService
): void => {
  const processRequest = async (args: unknown) => {
    try {
      // Parse inside the try block. Schema failures are then logged and mapped
      // to an MCP error like service failures, instead of escaping unmapped.
      const parsedArgs = z.object(GEMINI_URL_ANALYSIS_PARAMS).parse(args);

      logger.debug(`Received ${GEMINI_URL_ANALYSIS_TOOL_NAME} request:`, {
        urls: parsedArgs.urls,
        analysisType: parsedArgs.analysisType,
        urlCount: parsedArgs.urls.length,
      });

      const {
        urls,
        analysisType,
        query,
        extractionSchema,
        questions,
        compareBy,
        outputFormat,
        includeMetadata,
        fetchOptions,
        modelName,
      } = parsedArgs;

      // Build the analysis prompt based on the analysis type
      const prompt = buildAnalysisPrompt({
        analysisType,
        query,
        extractionSchema,
        questions,
        compareBy,
        outputFormat,
        urlCount: urls.length,
      });

      // Prepare URL context for content generation
      const urlContext = {
        urls,
        fetchOptions: {
          ...fetchOptions,
          includeMetadata: includeMetadata ?? true,
          convertToMarkdown: true, // Always convert to markdown for better analysis
        },
      };

      // Rough upper bound of fetched content, used as a model-selection hint.
      // The schema restricts maxContentKb to [1, 1000]; fall back to 100 KB when omitted.
      const urlCount = urls.length;
      const maxContentKb = fetchOptions?.maxContentKb ?? 100;
      const estimatedUrlContentSize = urlCount * maxContentKb * 1024;

      // Select task type based on analysis type
      const taskType = getTaskTypeForAnalysis(analysisType);

      // Generate analysis using the service
      const analysisResult = await serviceInstance.generateContent({
        prompt,
        modelName,
        urlContext,
        taskType: taskType as
          | "text-generation"
          | "image-generation"
          | "video-generation"
          | "code-review"
          | "multimodal"
          | "reasoning",
        preferQuality: true, // Prefer quality for analysis tasks
        complexityHint: urlCount > 5 ? "complex" : "medium",
        urlCount,
        estimatedUrlContentSize,
        systemInstruction: getSystemInstructionForAnalysis(
          analysisType,
          outputFormat
        ),
      });

      // Format the result based on output format
      const formattedResult = formatAnalysisResult(
        analysisResult,
        outputFormat
      );

      return {
        content: [
          {
            type: "text" as const,
            text: formattedResult,
          },
        ],
      };
    } catch (error: unknown) {
      logger.error(`Error processing ${GEMINI_URL_ANALYSIS_TOOL_NAME}:`, error);
      throw mapAnyErrorToMcpError(error, GEMINI_URL_ANALYSIS_TOOL_NAME);
    }
  };

  // Register the tool with the server
  server.tool(
    GEMINI_URL_ANALYSIS_TOOL_NAME,
    GEMINI_URL_ANALYSIS_TOOL_DESCRIPTION,
    GEMINI_URL_ANALYSIS_PARAMS,
    processRequest
  );

  logger.info(`Tool registered: ${GEMINI_URL_ANALYSIS_TOOL_NAME}`);
};
/**
 * Builds the analysis prompt based on the requested analysis type and parameters.
 *
 * The optional free-form `query` is honored for every analysis type:
 * - types with a dedicated phrasing weave it in directly;
 * - the remaining types append it as an additional focus.
 *
 * This way a caller-supplied instruction is never silently dropped.
 *
 * @param params.analysisType Analysis mode; unknown values fall back to the raw query.
 * @param params.query Optional free-form instruction or focus.
 * @param params.extractionSchema Optional structure, serialized into the prompt for "extraction".
 * @param params.questions Optional explicit questions for "qa".
 * @param params.compareBy Optional aspects to focus on for "comparison".
 * @param params.outputFormat Optional output format ("text", "json", "markdown", "structured").
 * @param params.urlCount Number of URLs being analyzed (affects wording).
 * @returns The complete user prompt.
 */
function buildAnalysisPrompt(params: {
  analysisType: string;
  query?: string;
  extractionSchema?: Record<string, unknown>;
  questions?: string[];
  compareBy?: string[];
  outputFormat?: string;
  urlCount: number;
}): string {
  const {
    analysisType,
    query,
    extractionSchema,
    questions,
    compareBy,
    outputFormat,
    urlCount,
  } = params;

  let prompt = `Perform a ${analysisType} analysis on the provided URL content${urlCount > 1 ? "s" : ""}.\n\n`;

  switch (analysisType) {
    case "summary":
      prompt += `Provide a comprehensive summary of the main points, key information, and important insights from the content. `;
      if (query) {
        prompt += `Focus particularly on: ${query}. `;
      }
      break;

    case "comparison":
      if (urlCount < 2) {
        prompt += `Since only one URL is provided, analyze the different aspects or sections within the content. `;
      } else {
        prompt += `Compare and contrast the content from the different URLs, highlighting similarities, differences, and unique aspects. `;
      }
      if (compareBy && compareBy.length > 0) {
        prompt += `Focus your comparison on these specific aspects: ${compareBy.join(", ")}. `;
      }
      // Previously ignored: keep the caller's instruction.
      if (query) {
        prompt += `Additional focus: ${query}. `;
      }
      break;

    case "extraction":
      prompt += `Extract specific information from the content. `;
      if (extractionSchema) {
        prompt += `Structure the extracted information according to this schema: ${JSON.stringify(extractionSchema, null, 2)}. `;
      }
      if (query) {
        prompt += `Focus on extracting: ${query}. `;
      }
      break;

    case "qa":
      prompt += `Answer the following questions based on the content:\n`;
      if (questions && questions.length > 0) {
        questions.forEach((question, index) => {
          prompt += `${index + 1}. ${question}\n`;
        });
        // When explicit questions are given, the query still adds context
        // instead of being discarded.
        if (query) {
          prompt += `Additional context: ${query}\n`;
        }
      } else if (query) {
        prompt += `Question: ${query}\n`;
      } else {
        prompt += `Provide comprehensive answers to common questions that would arise from this content.\n`;
      }
      break;

    case "sentiment":
      prompt += `Analyze the sentiment and emotional tone of the content. Identify the overall sentiment (positive, negative, neutral) and specific emotional indicators. `;
      if (query) {
        prompt += `Pay special attention to sentiment regarding: ${query}. `;
      }
      break;

    case "fact-check":
      prompt += `Evaluate the factual accuracy and credibility of claims made in the content. Identify verifiable facts, questionable claims, and potential misinformation. `;
      if (query) {
        prompt += `Focus particularly on claims about: ${query}. `;
      }
      break;

    case "content-classification":
      prompt += `Classify and categorize the content by topic, type, audience, and other relevant dimensions. `;
      if (query) {
        prompt += `Use this classification framework: ${query}. `;
      }
      break;

    case "readability":
      prompt += `Analyze the readability, writing quality, and accessibility of the content. Evaluate complexity, clarity, structure, and target audience. `;
      if (query) {
        prompt += `Focus particularly on: ${query}. `;
      }
      break;

    case "seo-analysis":
      prompt += `Perform an SEO analysis of the content, evaluating keyword usage, content structure, meta information, and optimization opportunities. `;
      if (query) {
        prompt += `Focus particularly on: ${query}. `;
      }
      break;

    default:
      if (query) {
        prompt += `Based on the following instruction: ${query}. `;
      }
  }

  // Add output format instructions ("text" needs none)
  if (outputFormat && outputFormat !== "text") {
    switch (outputFormat) {
      case "json":
        prompt += `\n\nFormat your response as valid JSON with appropriate structure and fields.`;
        break;
      case "markdown":
        prompt += `\n\nFormat your response in well-structured Markdown with appropriate headers, lists, and formatting.`;
        break;
      case "structured":
        prompt += `\n\nOrganize your response in a clear, structured format with distinct sections and subsections.`;
        break;
    }
  }

  prompt += `\n\nBe thorough, accurate, and insightful in your analysis.`;

  return prompt;
}
/**
 * Maps analysis types to task types for model selection.
 *
 * Comparison, fact-checking and SEO analysis need multi-step reasoning and
 * map to "reasoning". Every other type maps to general "text-generation",
 * including unknown values.
 *
 * @param analysisType The requested analysis mode.
 * @returns The task type hint passed to model selection.
 */
function getTaskTypeForAnalysis(
  analysisType: string
): "reasoning" | "text-generation" {
  switch (analysisType) {
    case "comparison":
    case "fact-check":
    case "seo-analysis":
      return "reasoning";
    default:
      // summary, extraction, qa, sentiment, content-classification, readability, ...
      return "text-generation";
  }
}
/**
 * Per-analysis-type guidance sentence appended after the analyst persona.
 * Analysis types without an entry contribute no extra guidance.
 */
const ANALYSIS_GUIDANCE: ReadonlyMap<string, string> = new Map([
  [
    "summary",
    `Provide concise yet comprehensive summaries that capture the essence and key insights of the content.`,
  ],
  [
    "comparison",
    `Excel at identifying similarities, differences, and patterns across different content sources.`,
  ],
  [
    "extraction",
    `Focus on accurately identifying and extracting specific information while maintaining context and relevance.`,
  ],
  [
    "qa",
    `Provide clear, accurate, and well-supported answers based on the available content.`,
  ],
  [
    "sentiment",
    `Accurately identify emotional tone, sentiment indicators, and subjective language patterns.`,
  ],
  [
    "fact-check",
    `Evaluate claims critically, distinguish between facts and opinions, and identify potential misinformation.`,
  ],
  [
    "content-classification",
    `Categorize content accurately using relevant taxonomies and classification frameworks.`,
  ],
  [
    "readability",
    `Assess content accessibility, complexity, and effectiveness for target audiences.`,
  ],
  [
    "seo-analysis",
    `Evaluate content from an SEO perspective, focusing on optimization opportunities and best practices.`,
  ],
]);

/**
 * Generates the system instruction for an analysis request.
 *
 * The instruction is assembled from up to four pieces, in order:
 * 1. an expert-analyst persona naming the analysis type;
 * 2. type-specific guidance (if any);
 * 3. a formatting rule for "json" or "markdown" output;
 * 4. a rule to stay grounded in the provided content.
 *
 * @param analysisType The requested analysis mode.
 * @param outputFormat Optional requested output format.
 * @returns The system instruction text.
 */
function getSystemInstructionForAnalysis(
  analysisType: string,
  outputFormat?: string
): string {
  const persona = `You are an expert content analyst specializing in ${analysisType} analysis. `;
  // Map lookup (not a plain object) so names like "constructor" yield nothing.
  const guidance = ANALYSIS_GUIDANCE.get(analysisType) ?? "";

  const formatRule =
    outputFormat === "json"
      ? ` Always respond with valid, well-structured JSON.`
      : outputFormat === "markdown"
        ? ` Use proper Markdown formatting with clear headers and structure.`
        : "";

  const groundingRule = ` Base your analysis strictly on the provided content and clearly distinguish between what is explicitly stated and what is inferred.`;

  return [persona, guidance, formatRule, groundingRule].join("");
}
/**
 * Post-processes the model output for the requested output format.
 *
 * This is currently a pass-through for every format:
 * - plain text needs no work;
 * - for "json", "markdown" and "structured", the formatting is requested from
 *   the model through the prompt and system instruction.
 *
 * @param result Raw text returned by the model.
 * @param outputFormat Requested output format (missing or empty means "text").
 * @returns The text to send back to the client.
 */
function formatAnalysisResult(result: string, outputFormat?: string): string {
  switch (outputFormat) {
    case undefined:
    case "":
    case "text":
      // Plain text: nothing to transform.
      return result;
    default:
      // Structured formats were already produced by the model itself.
      return result;
  }
}
```
--------------------------------------------------------------------------------
/tests/unit/utils/UrlSecurityService.test.vitest.ts:
--------------------------------------------------------------------------------
```typescript
// Using vitest globals - see vitest.config.ts globals: true
import { UrlSecurityService } from "../../../src/utils/UrlSecurityService.js";
import { ConfigurationManager } from "../../../src/config/ConfigurationManager.js";
import { GeminiUrlValidationError } from "../../../src/utils/geminiErrors.js";
// Mock dependencies
// Auto-mock the configuration manager and logger modules so these tests do not
// depend on their real implementations.
vi.mock("../../../src/config/ConfigurationManager.js");
vi.mock("../../../src/utils/logger.js");

/**
 * Minimal stand-in for ConfigurationManager.
 * Only the URL-context config accessor that these tests stub is declared;
 * it is cast to ConfigurationManager when constructing the service.
 */
interface MockConfigManager {
  getUrlContextConfig: ReturnType<typeof vi.fn>;
}
describe("UrlSecurityService", () => {
  let service: UrlSecurityService;
  let mockConfig: MockConfigManager;

  beforeEach(() => {
    vi.clearAllMocks();
    // Default config: every domain allowed, nothing blocklisted.
    mockConfig = {
      getUrlContextConfig: vi.fn().mockReturnValue({
        allowedDomains: ["*"],
        blocklistedDomains: [],
      }),
    };
    service = new UrlSecurityService(mockConfig as ConfigurationManager);
  });

  describe("URL format validation", () => {
    it("should accept valid HTTP URLs", async () => {
      await expect(
        service.validateUrl("http://example.com")
      ).resolves.not.toThrow();
    });

    it("should accept valid HTTPS URLs", async () => {
      await expect(
        service.validateUrl("https://example.com")
      ).resolves.not.toThrow();
    });

    it("should reject invalid URL formats", async () => {
      await expect(service.validateUrl("not-a-url")).rejects.toThrow(
        GeminiUrlValidationError
      );
      await expect(service.validateUrl("")).rejects.toThrow(
        GeminiUrlValidationError
      );
      await expect(service.validateUrl(" ")).rejects.toThrow(
        GeminiUrlValidationError
      );
    });

    it("should reject non-HTTP protocols", async () => {
      await expect(service.validateUrl("ftp://example.com")).rejects.toThrow(
        GeminiUrlValidationError
      );
      await expect(service.validateUrl("file:///etc/passwd")).rejects.toThrow(
        GeminiUrlValidationError
      );
      await expect(service.validateUrl("javascript:alert(1)")).rejects.toThrow(
        GeminiUrlValidationError
      );
      await expect(
        service.validateUrl("data:text/plain;base64,SGVsbG8=")
      ).rejects.toThrow(GeminiUrlValidationError);
    });
  });

  describe("Domain validation", () => {
    it("should allow domains in allowlist", async () => {
      mockConfig.getUrlContextConfig.mockReturnValue({
        allowedDomains: ["example.com", "test.org"],
        blocklistedDomains: [],
      });
      await expect(
        service.validateUrl("https://example.com")
      ).resolves.not.toThrow();
      await expect(
        service.validateUrl("https://test.org")
      ).resolves.not.toThrow();
    });

    it("should reject domains not in allowlist", async () => {
      mockConfig.getUrlContextConfig.mockReturnValue({
        allowedDomains: ["example.com"],
        blocklistedDomains: [],
      });
      await expect(
        service.validateUrl("https://malicious.com")
      ).rejects.toThrow(GeminiUrlValidationError);
    });

    it("should handle wildcard allowlist", async () => {
      mockConfig.getUrlContextConfig.mockReturnValue({
        allowedDomains: ["*"],
        blocklistedDomains: [],
      });
      await expect(
        service.validateUrl("https://any-domain.com")
      ).resolves.not.toThrow();
    });

    it("should handle subdomain patterns", async () => {
      mockConfig.getUrlContextConfig.mockReturnValue({
        allowedDomains: ["*.example.com"],
        blocklistedDomains: [],
      });
      await expect(
        service.validateUrl("https://sub.example.com")
      ).resolves.not.toThrow();
      await expect(
        service.validateUrl("https://deep.sub.example.com")
      ).resolves.not.toThrow();
      await expect(
        service.validateUrl("https://example.com")
      ).resolves.not.toThrow();
      await expect(service.validateUrl("https://other.com")).rejects.toThrow(
        GeminiUrlValidationError
      );
    });

    it("should block domains in blocklist", async () => {
      mockConfig.getUrlContextConfig.mockReturnValue({
        allowedDomains: ["*"],
        blocklistedDomains: ["malicious.com", "spam.net"],
      });
      await expect(
        service.validateUrl("https://malicious.com")
      ).rejects.toThrow(GeminiUrlValidationError);
      await expect(service.validateUrl("https://spam.net")).rejects.toThrow(
        GeminiUrlValidationError
      );
      await expect(
        service.validateUrl("https://safe.com")
      ).resolves.not.toThrow();
    });

    it("should block subdomains of blocklisted domains", async () => {
      mockConfig.getUrlContextConfig.mockReturnValue({
        allowedDomains: ["*"],
        blocklistedDomains: ["malicious.com"],
      });
      await expect(
        service.validateUrl("https://sub.malicious.com")
      ).rejects.toThrow(GeminiUrlValidationError);
    });
  });

  describe("Private network protection", () => {
    it("should block localhost addresses", async () => {
      await expect(service.validateUrl("http://localhost")).rejects.toThrow(
        GeminiUrlValidationError
      );
      await expect(service.validateUrl("http://127.0.0.1")).rejects.toThrow(
        GeminiUrlValidationError
      );
      await expect(service.validateUrl("http://0.0.0.0")).rejects.toThrow(
        GeminiUrlValidationError
      );
    });

    it("should block private IP ranges", async () => {
      await expect(service.validateUrl("http://192.168.1.1")).rejects.toThrow(
        GeminiUrlValidationError
      );
      await expect(service.validateUrl("http://10.0.0.1")).rejects.toThrow(
        GeminiUrlValidationError
      );
      await expect(service.validateUrl("http://172.16.0.1")).rejects.toThrow(
        GeminiUrlValidationError
      );
    });

    it("should block internal domain extensions", async () => {
      await expect(service.validateUrl("http://server.local")).rejects.toThrow(
        GeminiUrlValidationError
      );
      await expect(service.validateUrl("http://api.internal")).rejects.toThrow(
        GeminiUrlValidationError
      );
      await expect(service.validateUrl("http://db.corp")).rejects.toThrow(
        GeminiUrlValidationError
      );
    });

    it("should allow public IP addresses", async () => {
      await expect(
        service.validateUrl("http://8.8.8.8")
      ).resolves.not.toThrow();
      await expect(
        service.validateUrl("http://1.1.1.1")
      ).resolves.not.toThrow();
    });
  });

  describe("Suspicious pattern detection", () => {
    it("should detect path traversal attempts", async () => {
      await expect(
        service.validateUrl("http://example.com/../../../etc/passwd")
      ).rejects.toThrow(GeminiUrlValidationError);
      await expect(
        service.validateUrl("http://example.com/path/with/../dots")
      ).rejects.toThrow(GeminiUrlValidationError);
    });

    it("should detect dangerous characters", async () => {
      await expect(
        service.validateUrl("http://example.com/path<script>")
      ).rejects.toThrow(GeminiUrlValidationError);
      await expect(
        service.validateUrl("http://example.com/path{malicious}")
      ).rejects.toThrow(GeminiUrlValidationError);
    });

    it("should detect multiple @ symbols", async () => {
      // Two '@' in the authority: credentials-style host confusion.
      await expect(
        service.validateUrl("http://user@evil@example.com")
      ).rejects.toThrow(GeminiUrlValidationError);
    });

    it("should allow normal URLs with safe characters", async () => {
      await expect(
        service.validateUrl(
          "https://example.com/path/to/resource?param=value&other=123"
        )
      ).resolves.not.toThrow();
      await expect(
        service.validateUrl("https://api.example.com/v1/users/123")
      ).resolves.not.toThrow();
    });
  });

  describe("URL shortener detection", () => {
    it("should detect known URL shorteners", async () => {
      const shorteners = [
        "https://bit.ly/abc123",
        "https://tinyurl.com/abc123",
        "https://t.co/abc123",
        "https://goo.gl/abc123",
      ];
      // Note: These should not throw errors, but should be logged as warnings
      for (const url of shorteners) {
        await expect(service.validateUrl(url)).resolves.not.toThrow();
      }
    });
  });

  describe("IDN homograph attack detection", () => {
    it("should detect potentially confusing Unicode domains", async () => {
      // Cyrillic characters that look like Latin
      await expect(service.validateUrl("https://gоogle.com")).rejects.toThrow(
        GeminiUrlValidationError
      ); // 'о' is Cyrillic
      await expect(service.validateUrl("https://аpple.com")).rejects.toThrow(
        GeminiUrlValidationError
      ); // 'а' is Cyrillic
    });

    it("should allow legitimate Unicode domains", async () => {
      await expect(
        service.validateUrl("https://example.com")
      ).resolves.not.toThrow();
      await expect(
        service.validateUrl("https://测试.example.com")
      ).resolves.not.toThrow();
    });
  });

  describe("Port validation", () => {
    it("should allow standard HTTP/HTTPS ports", async () => {
      await expect(
        service.validateUrl("http://example.com:80")
      ).resolves.not.toThrow();
      await expect(
        service.validateUrl("https://example.com:443")
      ).resolves.not.toThrow();
      await expect(
        service.validateUrl("http://example.com:8080")
      ).resolves.not.toThrow();
      await expect(
        service.validateUrl("https://example.com:8443")
      ).resolves.not.toThrow();
    });

    it("should reject non-standard ports", async () => {
      await expect(
        service.validateUrl("http://example.com:22")
      ).rejects.toThrow(GeminiUrlValidationError);
      await expect(
        service.validateUrl("http://example.com:3389")
      ).rejects.toThrow(GeminiUrlValidationError);
      await expect(
        service.validateUrl("http://example.com:1337")
      ).rejects.toThrow(GeminiUrlValidationError);
    });
  });

  describe("URL length validation", () => {
    it("should reject extremely long URLs", async () => {
      const longPath = "a".repeat(3000);
      const longUrl = `https://example.com/${longPath}`;
      await expect(service.validateUrl(longUrl)).rejects.toThrow(
        GeminiUrlValidationError
      );
    });

    it("should accept reasonable length URLs", async () => {
      const normalPath = "a".repeat(100);
      const normalUrl = `https://example.com/${normalPath}`;
      await expect(service.validateUrl(normalUrl)).resolves.not.toThrow();
    });
  });

  describe("Random domain detection", () => {
    it("should flag potentially randomly generated domains", async () => {
      // These should log warnings but not necessarily throw errors
      const suspiciousDomains = [
        "https://xkcd123456789.com",
        "https://aaaaaaaaaaaa.com",
        "https://1234567890abcd.com",
      ];
      for (const url of suspiciousDomains) {
        // Should not throw, but may log warnings
        await expect(service.validateUrl(url)).resolves.not.toThrow();
      }
    });
  });

  describe("Security metrics", () => {
    it("should track validation attempts and failures", async () => {
      const initialMetrics = service.getSecurityMetrics();
      expect(initialMetrics.validationAttempts).toBe(0);
      expect(initialMetrics.validationFailures).toBe(0);

      // Valid URL
      await service.validateUrl("https://example.com").catch(() => {});
      // Invalid URL
      await service.validateUrl("invalid-url").catch(() => {});

      const updatedMetrics = service.getSecurityMetrics();
      expect(updatedMetrics.validationAttempts).toBe(2);
      expect(updatedMetrics.validationFailures).toBe(1);
    });

    it("should track blocked domains", async () => {
      mockConfig.getUrlContextConfig.mockReturnValue({
        allowedDomains: ["*"],
        blocklistedDomains: ["malicious.com"],
      });
      await service.validateUrl("https://malicious.com").catch(() => {});
      const metrics = service.getSecurityMetrics();
      expect(metrics.blockedDomains.has("malicious.com")).toBe(true);
    });

    it("should allow resetting metrics", async () => {
      // Record at least one attempt/failure first so the reset is observable;
      // resetting a fresh service would pass even if reset were a no-op.
      await service.validateUrl("invalid-url").catch(() => {});
      expect(service.getSecurityMetrics().validationAttempts).toBeGreaterThan(0);

      service.resetSecurityMetrics();
      const metrics = service.getSecurityMetrics();
      expect(metrics.validationAttempts).toBe(0);
      expect(metrics.validationFailures).toBe(0);
      expect(metrics.blockedDomains.size).toBe(0);
      expect(metrics.suspiciousPatterns).toHaveLength(0);
    });
  });

  describe("Custom domain management", () => {
    it("should allow adding custom malicious domains", () => {
      service.addMaliciousDomain("custom-malicious.com");
      // This should not throw immediately since domain checking happens in validateUrl
      expect(() => service.addMaliciousDomain("another-bad.com")).not.toThrow();
    });
  });

  describe("URL accessibility checking", () => {
    // These tests replace the global fetch. Restore it afterwards so the stub
    // does not leak into other tests sharing the same worker/global scope.
    const originalFetch = global.fetch;
    afterEach(() => {
      global.fetch = originalFetch;
    });

    it("should check URL accessibility", async () => {
      // Mock fetch for accessibility check
      const mockFetch = vi.fn().mockResolvedValue({
        ok: true,
        status: 200,
      });
      global.fetch = mockFetch;

      const isAccessible = await service.checkUrlAccessibility(
        "https://example.com"
      );
      expect(isAccessible).toBe(true);
      expect(mockFetch).toHaveBeenCalledWith(
        "https://example.com",
        expect.objectContaining({
          method: "HEAD",
        })
      );
    });

    it("should handle inaccessible URLs", async () => {
      const mockFetch = vi.fn().mockRejectedValue(new Error("Network error"));
      global.fetch = mockFetch;

      const isAccessible = await service.checkUrlAccessibility(
        "https://unreachable.com"
      );
      expect(isAccessible).toBe(false);
    });
  });
});
```
--------------------------------------------------------------------------------
/src/services/gemini/GeminiContentService.ts:
--------------------------------------------------------------------------------
```typescript
import { GoogleGenAI } from "@google/genai";
import {
GeminiApiError,
GeminiValidationError,
mapGeminiError,
} from "../../utils/geminiErrors.js";
import { logger } from "../../utils/logger.js";
import {
Content,
GenerationConfig,
SafetySetting,
Part,
ThinkingConfig,
} from "./GeminiTypes.js";
import { ZodError } from "zod";
import { validateGenerateContentParams } from "./GeminiValidationSchemas.js";
import { RetryService } from "../../utils/RetryService.js";
import { GeminiUrlContextService } from "./GeminiUrlContextService.js";
import { ConfigurationManager } from "../../config/ConfigurationManager.js";
// Request configuration type definition for reuse
/**
 * Shape of the request object sent to the Gemini content-generation API.
 *
 * NOTE(review): this is presumably what `createRequestConfig` builds and what
 * is passed to `genAI.models.generateContentStream`. Confirm against the full
 * class.
 */
interface RequestConfig {
  /** Model identifier to call. */
  model: string;
  /** Prompt / conversation contents sent to the model. */
  contents: Content[];
  /** Optional generation settings. */
  generationConfig?: GenerationConfig;
  /** Optional safety settings. */
  safetySettings?: SafetySetting[];
  /** Optional system instruction, already in structured Content form. */
  systemInstruction?: Content;
  /** Optional name of a cached content entry to reuse. */
  cachedContent?: string;
  /** Optional reasoning ("thinking") configuration. */
  thinkingConfig?: ThinkingConfig;
}
/**
 * Interface for URL context parameters.
 * Lists the URLs whose content should accompany the prompt, plus optional
 * per-request fetch tuning.
 */
interface UrlContextParams {
  /** URLs to fetch and include as context. */
  urls: string[];
  /** Optional fetch tuning; every field is optional. */
  fetchOptions?: {
    /** Maximum content size per URL, in KB. */
    maxContentKb?: number;
    /** Fetch timeout per URL, in milliseconds. */
    timeoutMs?: number;
    /** Whether to include page metadata (title, description, ...). */
    includeMetadata?: boolean;
    /** Whether fetched HTML should be converted to Markdown. */
    convertToMarkdown?: boolean;
    /** Domains allowed for this specific request. */
    allowedDomains?: string[];
    /** Custom User-Agent header for the fetch requests. */
    userAgent?: string;
  };
}
/**
 * Interface for the parameters of the generateContent method.
 * This interface is used internally, while the parent GeminiService exports a compatible version.
 */
interface GenerateContentParams {
  /** User prompt text. */
  prompt: string;
  /** Optional model override; presumably falls back to the service default when omitted. */
  modelName?: string;
  /** Optional generation settings. */
  generationConfig?: GenerationConfig;
  /** Optional safety settings. */
  safetySettings?: SafetySetting[];
  /** Optional system instruction, as structured Content or plain text. */
  systemInstruction?: Content | string;
  /** Optional name of cached content to reuse. */
  cachedContentName?: string;
  /** Optional URLs (and fetch options) to include as context. */
  urlContext?: UrlContextParams;
}
/**
 * Default retry options for Gemini API calls.
 *
 * Handed to the RetryService created by GeminiContentService. Every retry
 * is reported through `logger.warn`, with the attempt number, the delay and
 * the error message.
 */
const DEFAULT_RETRY_OPTIONS = {
  maxAttempts: 3,
  initialDelayMs: 500,
  maxDelayMs: 10000,
  backoffFactor: 2,
  jitter: true,
  onRetry(error: unknown, attempt: number, delayMs: number): void {
    // Non-Error throwables are stringified so the log line is always readable.
    const reason = error instanceof Error ? error.message : String(error);
    logger.warn(
      `Retrying Gemini API call after error (attempt ${attempt}, delay: ${delayMs}ms): ${reason}`
    );
  },
};
/**
 * Service for handling content generation related operations for the Gemini service.
 * Manages content generation in both streaming and non-streaming modes; both modes
 * share the same parameter validation and request construction.
 */
export class GeminiContentService {
  /**
   * Thinking-token budgets used when a request sets `reasoningEffort`.
   * A Map (rather than an object literal) guarantees that an unexpected value such
   * as "toString" cannot resolve to an inherited Object.prototype member.
   */
  private static readonly REASONING_EFFORT_BUDGETS: ReadonlyMap<string, number> =
    new Map([
      ["none", 0],
      ["low", 1024], // 1K tokens
      ["medium", 8192], // 8K tokens
      ["high", 24576], // 24K tokens
    ]);

  private readonly genAI: GoogleGenAI;
  private readonly defaultModelName?: string;
  private readonly defaultThinkingBudget?: number;
  private readonly retryService: RetryService;
  private readonly configManager: ConfigurationManager;
  private readonly urlContextService: GeminiUrlContextService;

  /**
   * Creates a new instance of the GeminiContentService.
   * @param genAI The GoogleGenAI instance to use for API calls
   * @param defaultModelName Optional default model name to use if not specified in method calls
   * @param defaultThinkingBudget Optional default budget for reasoning (thinking) tokens,
   *   applied only when a request carries no thinkingConfig of its own
   */
  constructor(
    genAI: GoogleGenAI,
    defaultModelName?: string,
    defaultThinkingBudget?: number
  ) {
    this.genAI = genAI;
    this.defaultModelName = defaultModelName;
    this.defaultThinkingBudget = defaultThinkingBudget;
    this.retryService = new RetryService(DEFAULT_RETRY_OPTIONS);
    this.configManager = ConfigurationManager.getInstance();
    this.urlContextService = new GeminiUrlContextService(this.configManager);
  }

  /**
   * Streams content generation using the Gemini model.
   * Returns an async generator that yields text chunks as they are generated;
   * chunks without text are skipped.
   *
   * Only the initial API call is retried. Once chunks have been yielded, a retry
   * would re-emit text the consumer already received, so errors raised while
   * iterating the stream are not retried.
   *
   * @param params An object containing all necessary parameters for content generation
   * @returns An async generator yielding text chunks as they become available
   * @throws Any validation, request-building, API or stream error, after it has been
   *   passed through mapGeminiError
   */
  public async *generateContentStream(
    params: GenerateContentParams
  ): AsyncGenerator<string> {
    // Log with truncated prompt for privacy/security
    logger.debug(
      `generateContentStream called with prompt: ${GeminiContentService.previewPrompt(params.prompt)}...`
    );
    try {
      this.validateParams(params);
      const requestConfig = await this.createRequestConfig(params);

      const streamResult = await this.retryService.execute(async () => {
        return this.genAI.models.generateContentStream(requestConfig);
      });

      for await (const chunk of streamResult) {
        // `text` is a getter on the chunk, not a method
        const chunkText = chunk.text;
        if (chunkText) {
          yield chunkText;
        }
      }
    } catch (error: unknown) {
      // Single mapping point, so API/stream errors are not mapped twice
      throw mapGeminiError(error, "generateContentStream");
    }
  }

  /**
   * Creates the request configuration object shared by generateContent and
   * generateContentStream.
   *
   * The caller's `generationConfig` and its nested `thinkingConfig` are never
   * mutated: a present thinkingConfig is copied before `reasoningEffort` is
   * translated into a `thinkingBudget`.
   *
   * @param params The content generation parameters
   * @returns A properly formatted request configuration object
   * @throws GeminiValidationError if no model name is available, or if URL context
   *   is requested while the feature is disabled
   * @throws The error returned by mapGeminiError if URL context processing fails
   */
  private async createRequestConfig(
    params: GenerateContentParams
  ): Promise<RequestConfig> {
    const {
      prompt,
      modelName,
      generationConfig,
      safetySettings,
      systemInstruction,
      cachedContentName,
      urlContext,
    } = params;

    const effectiveModelName = modelName ?? this.defaultModelName;
    if (!effectiveModelName) {
      throw new GeminiValidationError(
        "Model name must be provided either as a parameter or via the GOOGLE_GEMINI_MODEL environment variable.",
        "modelName"
      );
    }
    logger.debug(`Creating request config for model: ${effectiveModelName}`);

    // URL-derived parts are placed before the user's prompt
    const contentParts: Part[] = [];
    if (urlContext?.urls && urlContext.urls.length > 0) {
      const urlParts = await this.buildUrlContextParts(
        urlContext.urls,
        urlContext.fetchOptions
      );
      contentParts.push(...urlParts);
    }
    contentParts.push({ text: prompt });

    // NOTE(review): in @google/genai, models.generateContent(Stream) reads generation,
    // safety, thinking and system-instruction settings from a nested `config` object;
    // confirm that RequestConfig's top-level fields are honored by the SDK in use.
    const requestConfig: RequestConfig = {
      model: effectiveModelName,
      contents: [{ role: "user", parts: contentParts }],
    };

    if (generationConfig) {
      if (generationConfig.thinkingConfig) {
        // Copy before the reasoningEffort mapping below writes thinkingBudget, so
        // the caller's object stays untouched. The copy is shared by both locations
        // so they remain consistent with each other.
        const thinkingCopy = { ...generationConfig.thinkingConfig };
        requestConfig.generationConfig = {
          ...generationConfig,
          thinkingConfig: thinkingCopy,
        };
        requestConfig.thinkingConfig = thinkingCopy;
      } else {
        requestConfig.generationConfig = generationConfig;
      }
    }

    // Map reasoningEffort to thinkingBudget if provided
    const requestedThinking = requestConfig.thinkingConfig;
    if (requestedThinking?.reasoningEffort) {
      const budget = GeminiContentService.REASONING_EFFORT_BUDGETS.get(
        requestedThinking.reasoningEffort
      );
      if (budget !== undefined) {
        requestedThinking.thinkingBudget = budget;
        logger.debug(
          `Mapped reasoning effort '${requestedThinking.reasoningEffort}' to thinking budget: ${budget} tokens`
        );
      } else {
        // Keep any explicit thinkingBudget instead of overwriting it with undefined
        logger.warn(
          `Unrecognized reasoning effort '${requestedThinking.reasoningEffort}'; thinking budget left unchanged`
        );
      }
    }

    // Apply default thinking budget only if the request has no thinking config
    if (
      this.defaultThinkingBudget !== undefined &&
      !requestConfig.thinkingConfig
    ) {
      requestConfig.thinkingConfig = {
        thinkingBudget: this.defaultThinkingBudget,
      };
      logger.debug(
        `Applied default thinking budget: ${this.defaultThinkingBudget} tokens`
      );
    }

    if (safetySettings) {
      requestConfig.safetySettings = safetySettings;
    }
    if (systemInstruction) {
      // A plain string is wrapped into a single-part Content
      const formattedInstruction: Content =
        typeof systemInstruction === "string"
          ? { parts: [{ text: systemInstruction }] }
          : systemInstruction;
      requestConfig.systemInstruction = formattedInstruction;
    }
    if (cachedContentName) {
      requestConfig.cachedContent = cachedContentName;
    }
    return requestConfig;
  }

  /**
   * Fetches the URLs requested as context and returns their content parts in the
   * order produced by the URL context service. Per-request fetch options override
   * the configured defaults.
   *
   * @param urls URLs to fetch (the caller only invokes this for a non-empty list)
   * @param fetchOptions Optional per-request overrides of the configured defaults
   * @returns Content parts to be placed before the user's prompt
   * @throws GeminiValidationError if the URL context feature is disabled
   * @throws The error returned by mapGeminiError if fetching or processing fails
   */
  private async buildUrlContextParts(
    urls: NonNullable<UrlContextParams["urls"]>,
    fetchOptions: UrlContextParams["fetchOptions"]
  ): Promise<Part[]> {
    const urlConfig = this.configManager.getUrlContextConfig();
    if (!urlConfig.enabled) {
      throw new GeminiValidationError(
        "URL context feature is not enabled. Set GOOGLE_GEMINI_ENABLE_URL_CONTEXT=true to enable.",
        "urlContext"
      );
    }

    try {
      logger.debug(`Processing ${urls.length} URLs for context`);
      // With `||`, a 0 size/timeout or empty user agent falls back to the default
      const urlFetchOptions = {
        maxContentLength:
          (fetchOptions?.maxContentKb || urlConfig.defaultMaxContentKb) * 1024,
        timeout: fetchOptions?.timeoutMs || urlConfig.defaultTimeoutMs,
        includeMetadata:
          fetchOptions?.includeMetadata ?? urlConfig.includeMetadata,
        convertToMarkdown:
          fetchOptions?.convertToMarkdown ?? urlConfig.convertToMarkdown,
        allowedDomains: fetchOptions?.allowedDomains || urlConfig.allowedDomains,
        userAgent: fetchOptions?.userAgent || urlConfig.userAgent,
      };

      const { contents: urlContents, batchResult } =
        await this.urlContextService.processUrlsForContext(
          urls,
          urlFetchOptions
        );

      // Log the batch result for monitoring
      logger.info("URL context processing completed", {
        totalUrls: batchResult.summary.totalUrls,
        successful: batchResult.summary.successCount,
        failed: batchResult.summary.failureCount,
        totalContentSize: batchResult.summary.totalContentSize,
        avgResponseTime: batchResult.summary.averageResponseTime,
      });

      const parts: Part[] = [];
      for (const urlContent of urlContents) {
        if (urlContent.parts) {
          parts.push(...urlContent.parts);
        }
      }

      // Individual failed URLs are only logged; they do not fail the request
      for (const failure of batchResult.failed) {
        logger.warn("Failed to fetch URL for context", {
          url: failure.url,
          error: failure.error.message,
          errorCode: failure.errorCode,
        });
      }
      return parts;
    } catch (error) {
      logger.error("URL context processing failed", { error });
      // Fail fast: an error from the URL pipeline itself aborts the whole request
      throw mapGeminiError(error, "URL context processing");
    }
  }

  /**
   * Validates generation parameters against the Zod schema.
   * `urlContext` is not included in this schema check.
   *
   * @throws GeminiValidationError listing every failing field, with the first
   *   failing path as the field name
   */
  private validateParams(params: GenerateContentParams): void {
    const validationParams: Record<string, unknown> = {
      prompt: params.prompt,
      modelName: params.modelName,
      generationConfig: params.generationConfig,
      safetySettings: params.safetySettings,
      systemInstruction: params.systemInstruction,
      cachedContentName: params.cachedContentName,
    };
    try {
      validateGenerateContentParams(validationParams);
    } catch (validationError: unknown) {
      if (validationError instanceof ZodError) {
        const fieldErrors = validationError.errors
          .map((err) => `${err.path.join(".")}: ${err.message}`)
          .join(", ");
        throw new GeminiValidationError(
          `Invalid parameters for content generation: ${fieldErrors}`,
          validationError.errors[0]?.path.join(".")
        );
      }
      throw validationError;
    }
  }

  /**
   * Returns the first 30 characters of a prompt for debug logging. A non-string
   * prompt is stringified instead, so logging cannot throw before validation has
   * had a chance to report a proper GeminiValidationError.
   */
  private static previewPrompt(prompt: unknown): string {
    return typeof prompt === "string"
      ? prompt.substring(0, 30)
      : String(prompt);
  }

  /**
   * Generates content using the Gemini model with automatic retries for transient errors.
   * Uses exponential backoff to avoid overwhelming the API during temporary issues.
   *
   * @param params An object containing all necessary parameters for content generation
   * @returns A promise resolving to the generated text content
   * @throws Any validation, request-building or API error (including an empty
   *   response, raised as GeminiApiError) after it has been passed through mapGeminiError
   */
  public async generateContent(params: GenerateContentParams): Promise<string> {
    // Log with truncated prompt for privacy/security
    logger.debug(
      `generateContent called with prompt: ${GeminiContentService.previewPrompt(params.prompt)}...`
    );
    try {
      this.validateParams(params);
      const requestConfig = await this.createRequestConfig(params);

      return await this.retryService.execute(async () => {
        const result = await this.genAI.models.generateContent(requestConfig);
        // `text` is a getter; read it once
        const text = result.text;
        if (!text) {
          throw new GeminiApiError("No text was generated in the response");
        }
        return text;
      });
    } catch (error: unknown) {
      throw mapGeminiError(error, "generateContent");
    }
  }
}
```
--------------------------------------------------------------------------------
/tests/unit/tools/geminiCacheTool.test.vitest.ts:
--------------------------------------------------------------------------------
```typescript
// Using vitest globals - see vitest.config.ts globals: true
import { geminiCacheTool } from "../../../src/tools/geminiCacheTool.js";
import { GeminiApiError } from "../../../src/utils/errors.js";
import { McpError } from "@modelcontextprotocol/sdk/types.js";
import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import { GeminiService } from "../../../src/services/index.js";
describe("geminiCacheTool", () => {
  // Mock server and service instances
  const mockTool = vi.fn();
  const mockServer = {
    tool: mockTool,
  } as unknown as McpServer;

  // Create mock functions for the service methods
  const mockCreateCache = vi.fn();
  const mockListCaches = vi.fn();
  const mockGetCache = vi.fn();
  const mockUpdateCache = vi.fn();
  const mockDeleteCache = vi.fn();

  // Create a minimal mock service with just the necessary methods for testing
  const mockService = {
    createCache: mockCreateCache,
    listCaches: mockListCaches,
    getCache: mockGetCache,
    updateCache: mockUpdateCache,
    deleteCache: mockDeleteCache,
    // Add empty implementations for required GeminiService methods
    generateContent: () => Promise.resolve("mock"),
  } as unknown as GeminiService;

  /**
   * Registers the tool on the mock server and returns the request handler passed
   * as the fourth argument of the first `server.tool(...)` call. beforeEach resets
   * all mocks, so that first call is always the current test's registration.
   */
  const registerAndGetHandler = () => {
    geminiCacheTool(mockServer, mockService);
    const [, , , handler] = mockTool.mock.calls[0];
    return handler;
  };

  // Reset mocks (including recorded registrations) before each test
  beforeEach(() => {
    vi.resetAllMocks();
  });

  it("should register the tool with the server", () => {
    // Call the tool registration function
    geminiCacheTool(mockServer, mockService);

    // Verify tool was registered
    expect(mockTool).toHaveBeenCalledTimes(1);
    const [name, description, params, handler] = mockTool.mock.calls[0];

    // Check tool registration parameters
    expect(name).toBe("gemini_cache");
    expect(description).toContain("Manages cached content resources");
    expect(params).toBeDefined();
    expect(typeof handler).toBe("function");
  });

  describe("create operation", () => {
    it("should create a cache successfully", async () => {
      const handler = registerAndGetHandler();

      // Mock successful response
      const mockCacheMetadata = {
        name: "cachedContents/abc123xyz",
        displayName: "Test Cache",
        model: "gemini-1.5-flash",
        createTime: "2024-01-01T00:00:00Z",
        updateTime: "2024-01-01T00:00:00Z",
        expirationTime: "2024-01-02T00:00:00Z",
        state: "ACTIVE",
        usageMetadata: {
          totalTokenCount: 1000,
        },
      };
      mockCreateCache.mockResolvedValueOnce(mockCacheMetadata);

      const testRequest = {
        operation: "create",
        model: "gemini-1.5-flash",
        contents: [
          {
            role: "user",
            parts: [{ text: "This is cached content" }],
          },
        ],
        displayName: "Test Cache",
        ttl: "3600s",
      };

      const result = await handler(testRequest);

      // Verify the service method was called with correct parameters
      expect(mockCreateCache).toHaveBeenCalledWith(
        "gemini-1.5-flash",
        testRequest.contents,
        {
          displayName: "Test Cache",
          ttl: "3600s",
        }
      );

      expect(result).toEqual({
        content: [
          {
            type: "text",
            text: JSON.stringify(mockCacheMetadata, null, 2),
          },
        ],
      });
    });

    it("should create cache with system instruction and tools", async () => {
      const handler = registerAndGetHandler();

      const mockCacheMetadata = {
        name: "cachedContents/def456xyz",
        model: "gemini-1.5-pro",
        createTime: "2024-01-01T00:00:00Z",
        updateTime: "2024-01-01T00:00:00Z",
      };
      mockCreateCache.mockResolvedValueOnce(mockCacheMetadata);

      // Prepare test request with optional parameters
      const testRequest = {
        operation: "create",
        model: "gemini-1.5-pro",
        contents: [
          {
            role: "user",
            parts: [{ text: "Cached content" }],
          },
        ],
        systemInstruction: {
          role: "system",
          parts: [{ text: "You are a helpful assistant" }],
        },
        tools: [
          {
            functionDeclarations: [
              {
                name: "get_weather",
                description: "Get weather information",
                parameters: {
                  type: "OBJECT",
                  properties: {
                    location: {
                      type: "STRING",
                      description: "The location",
                    },
                  },
                },
              },
            ],
          },
        ],
        toolConfig: {
          functionCallingConfig: {
            mode: "AUTO",
          },
        },
      };

      const result = await handler(testRequest);
      expect(result).toBeDefined();

      // Verify all optional parameters were forwarded
      expect(mockCreateCache).toHaveBeenCalledWith(
        "gemini-1.5-pro",
        testRequest.contents,
        expect.objectContaining({
          systemInstruction: testRequest.systemInstruction,
          tools: testRequest.tools,
          toolConfig: testRequest.toolConfig,
        })
      );
    });

    it("should throw error if contents is missing", async () => {
      const handler = registerAndGetHandler();

      const testRequest = {
        operation: "create",
        model: "gemini-1.5-flash",
      };

      await expect(handler(testRequest)).rejects.toThrow(
        "contents is required for operation 'create'"
      );
    });
  });

  describe("list operation", () => {
    it("should list caches successfully", async () => {
      const handler = registerAndGetHandler();

      const mockListResult = {
        cachedContents: [
          {
            name: "cachedContents/cache1",
            displayName: "Cache 1",
            model: "gemini-1.5-flash",
            state: "ACTIVE",
          },
          {
            name: "cachedContents/cache2",
            displayName: "Cache 2",
            model: "gemini-1.5-pro",
            state: "ACTIVE",
          },
        ],
        nextPageToken: "token123",
      };
      mockListCaches.mockResolvedValueOnce(mockListResult);

      const testRequest = {
        operation: "list",
        pageSize: 50,
        pageToken: "previousToken",
      };

      const result = await handler(testRequest);

      expect(mockListCaches).toHaveBeenCalledWith(50, "previousToken");
      expect(result).toEqual({
        content: [
          {
            type: "text",
            text: JSON.stringify(mockListResult, null, 2),
          },
        ],
      });
    });
  });

  describe("get operation", () => {
    it("should get cache metadata successfully", async () => {
      const handler = registerAndGetHandler();

      const mockCacheMetadata = {
        name: "cachedContents/abc123xyz",
        displayName: "Test Cache",
        model: "gemini-1.5-flash",
        createTime: "2024-01-01T00:00:00Z",
        updateTime: "2024-01-01T00:00:00Z",
        expirationTime: "2024-01-02T00:00:00Z",
        state: "ACTIVE",
      };
      mockGetCache.mockResolvedValueOnce(mockCacheMetadata);

      const testRequest = {
        operation: "get",
        cacheName: "cachedContents/abc123xyz",
      };

      const result = await handler(testRequest);

      expect(mockGetCache).toHaveBeenCalledWith("cachedContents/abc123xyz");
      expect(result).toEqual({
        content: [
          {
            type: "text",
            text: JSON.stringify(mockCacheMetadata, null, 2),
          },
        ],
      });
    });

    it("should throw error if cacheName is missing", async () => {
      const handler = registerAndGetHandler();

      const testRequest = {
        operation: "get",
      };

      await expect(handler(testRequest)).rejects.toThrow(
        "cacheName is required for operation 'get'"
      );
    });

    it("should throw error if cacheName format is invalid", async () => {
      const handler = registerAndGetHandler();

      const testRequest = {
        operation: "get",
        cacheName: "invalid-format",
      };

      await expect(handler(testRequest)).rejects.toThrow(
        "cacheName must start with 'cachedContents/'"
      );
    });
  });

  describe("update operation", () => {
    it("should update cache with TTL successfully", async () => {
      const handler = registerAndGetHandler();

      const mockUpdatedMetadata = {
        name: "cachedContents/abc123xyz",
        displayName: "Test Cache",
        model: "gemini-1.5-flash",
        updateTime: "2024-01-01T01:00:00Z",
        expirationTime: "2024-01-03T00:00:00Z",
      };
      mockUpdateCache.mockResolvedValueOnce(mockUpdatedMetadata);

      const testRequest = {
        operation: "update",
        cacheName: "cachedContents/abc123xyz",
        ttl: "7200s",
      };

      const result = await handler(testRequest);

      expect(mockUpdateCache).toHaveBeenCalledWith("cachedContents/abc123xyz", {
        ttl: "7200s",
      });
      expect(result).toEqual({
        content: [
          {
            type: "text",
            text: JSON.stringify(mockUpdatedMetadata, null, 2),
          },
        ],
      });
    });

    it("should update cache with displayName successfully", async () => {
      const handler = registerAndGetHandler();

      const mockUpdatedMetadata = {
        name: "cachedContents/abc123xyz",
        displayName: "Updated Cache Name",
        model: "gemini-1.5-flash",
        updateTime: "2024-01-01T01:00:00Z",
      };
      mockUpdateCache.mockResolvedValueOnce(mockUpdatedMetadata);

      const testRequest = {
        operation: "update",
        cacheName: "cachedContents/abc123xyz",
        displayName: "Updated Cache Name",
      };

      const result = await handler(testRequest);
      expect(result).toBeDefined();

      expect(mockUpdateCache).toHaveBeenCalledWith("cachedContents/abc123xyz", {
        displayName: "Updated Cache Name",
      });
    });

    it("should throw error if neither ttl nor displayName is provided", async () => {
      const handler = registerAndGetHandler();

      const testRequest = {
        operation: "update",
        cacheName: "cachedContents/abc123xyz",
      };

      await expect(handler(testRequest)).rejects.toThrow(
        "At least one of 'ttl' or 'displayName' must be provided for update operation"
      );
    });
  });

  describe("delete operation", () => {
    it("should delete cache successfully", async () => {
      const handler = registerAndGetHandler();

      mockDeleteCache.mockResolvedValueOnce({ success: true });

      const testRequest = {
        operation: "delete",
        cacheName: "cachedContents/abc123xyz",
      };

      const result = await handler(testRequest);

      expect(mockDeleteCache).toHaveBeenCalledWith("cachedContents/abc123xyz");
      expect(result).toEqual({
        content: [
          {
            type: "text",
            text: JSON.stringify({
              success: true,
              message: "Cache cachedContents/abc123xyz deleted successfully",
            }),
          },
        ],
      });
    });
  });

  describe("error handling", () => {
    it("should map GeminiApiError to McpError", async () => {
      const handler = registerAndGetHandler();

      // Mock service to throw GeminiApiError
      mockListCaches.mockRejectedValueOnce(
        new GeminiApiError("API error occurred")
      );

      // Assert on the single settled promise; a resolved call fails these
      // assertions directly instead of being caught by a try/catch
      const rejection = handler({ operation: "list" });
      await expect(rejection).rejects.toBeInstanceOf(McpError);
      await expect(rejection).rejects.toThrow("API error occurred");
    });

    it("should handle invalid operation", async () => {
      const handler = registerAndGetHandler();

      const testRequest = {
        operation: "invalid_operation",
      };

      await expect(handler(testRequest)).rejects.toThrow(
        "Invalid operation: invalid_operation"
      );
    });
  });
});
```
--------------------------------------------------------------------------------
/tests/integration/urlContextIntegration.test.vitest.ts:
--------------------------------------------------------------------------------
```typescript
// Using vitest globals - see vitest.config.ts globals: true
import { GeminiService } from "../../src/services/GeminiService.js";
import { ConfigurationManager } from "../../src/config/ConfigurationManager.js";
// Mock external dependencies: ConfigurationManager (each test installs its own
// getInstance below) and @google/genai (each test supplies the client via
// vi.mocked(GoogleGenAI).mockImplementation in beforeEach).
vi.mock("../../src/config/ConfigurationManager.js");
vi.mock("@google/genai");
// Mock fetch globally for URL fetching tests; individual tests queue responses
// with mockResolvedValueOnce / mockRejectedValueOnce.
const mockFetch = vi.fn();
global.fetch = mockFetch;
// Shape of the ConfigurationManager instance returned by the mocked getInstance().
interface MockConfigInstance {
getGeminiServiceConfig: ReturnType<typeof vi.fn>;
getUrlContextConfig: ReturnType<typeof vi.fn>;
}
// Static surface of the mocked ConfigurationManager used by these tests.
// NOTE(review): `vi.fn<[], MockConfigInstance>` is the Vitest 1.x two-argument
// generic form; Vitest 2 expects a single function type such as
// `vi.fn<() => MockConfigInstance>` — confirm against the installed version.
interface MockConfig {
getInstance: ReturnType<typeof vi.fn<[], MockConfigInstance>>;
}
describe("URL Context Integration Tests", () => {
let geminiService: GeminiService;
let mockConfig: MockConfig;
beforeEach(async () => {
vi.clearAllMocks();
// Mock configuration with URL context enabled
mockConfig = {
getInstance: vi.fn().mockReturnValue({
getGeminiServiceConfig: vi.fn().mockReturnValue({
apiKey: "test-api-key",
defaultModel: "gemini-2.5-flash-preview-05-20",
}),
getUrlContextConfig: vi.fn().mockReturnValue({
enabled: true,
maxUrlsPerRequest: 20,
defaultMaxContentKb: 100,
defaultTimeoutMs: 10000,
allowedDomains: ["*"],
blocklistedDomains: [],
convertToMarkdown: true,
includeMetadata: true,
enableCaching: true,
cacheExpiryMinutes: 15,
maxCacheSize: 1000,
rateLimitPerDomainPerMinute: 10,
userAgent: "MCP-Gemini-Server/1.0",
}),
}),
};
ConfigurationManager.getInstance = mockConfig.getInstance;
// Mock Gemini API
const mockGenAI = {
models: {
generateContent: vi.fn().mockResolvedValue({
text: "Generated response based on URL content",
}),
generateContentStream: vi.fn().mockImplementation(async function* () {
yield "Generated ";
yield "response ";
yield "based on ";
yield "URL content";
}),
},
};
const { GoogleGenAI } = await import("@google/genai");
vi.mocked(GoogleGenAI).mockImplementation(() => mockGenAI);
geminiService = new GeminiService();
});
afterEach(() => {
vi.resetAllMocks();
});
describe("URL Context with Content Generation", () => {
it("should successfully generate content with single URL context", async () => {
const mockHtmlContent = `
<!DOCTYPE html>
<html>
<head>
<title>Test Article</title>
<meta name="description" content="A comprehensive guide to testing">
</head>
<body>
<h1>Introduction to Testing</h1>
<p>Testing is essential for software quality assurance.</p>
<h2>Types of Testing</h2>
<ul>
<li>Unit Testing</li>
<li>Integration Testing</li>
<li>End-to-End Testing</li>
</ul>
</body>
</html>
`;
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
statusText: "OK",
url: "https://example.com/testing-guide",
headers: new Map([
["content-type", "text/html; charset=utf-8"],
["content-length", mockHtmlContent.length.toString()],
]),
text: () => Promise.resolve(mockHtmlContent),
});
const result = await geminiService.generateContent({
prompt: "Summarize the main points from the provided article",
urlContext: {
urls: ["https://example.com/testing-guide"],
fetchOptions: {
maxContentKb: 50,
includeMetadata: true,
},
},
});
expect(result).toBeDefined();
expect(result).toBe("Generated response based on URL content");
expect(mockFetch).toHaveBeenCalledTimes(1);
});
it("should handle multiple URLs in context", async () => {
const mockContent1 = `
<html>
<head><title>Article 1</title></head>
<body><p>Content from first article about React development.</p></body>
</html>
`;
const mockContent2 = `
<html>
<head><title>Article 2</title></head>
<body><p>Content from second article about Vue.js development.</p></body>
</html>
`;
mockFetch
.mockResolvedValueOnce({
ok: true,
status: 200,
url: "https://example1.com/react",
headers: new Map([["content-type", "text/html"]]),
text: () => Promise.resolve(mockContent1),
})
.mockResolvedValueOnce({
ok: true,
status: 200,
url: "https://example2.com/vue",
headers: new Map([["content-type", "text/html"]]),
text: () => Promise.resolve(mockContent2),
});
const result = await geminiService.generateContent({
prompt:
"Compare the development approaches mentioned in these articles",
urlContext: {
urls: ["https://example1.com/react", "https://example2.com/vue"],
},
});
expect(result).toBeDefined();
expect(mockFetch).toHaveBeenCalledTimes(2);
});
it("should work with streaming content generation", async () => {
const mockJsonContent = JSON.stringify({
title: "API Documentation",
endpoints: [
{ path: "/users", method: "GET" },
{ path: "/users", method: "POST" },
],
});
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
url: "https://api.example.com/docs",
headers: new Map([["content-type", "application/json"]]),
text: () => Promise.resolve(mockJsonContent),
});
const chunks: string[] = [];
for await (const chunk of geminiService.generateContentStream({
prompt: "Explain the API endpoints described in the documentation",
urlContext: {
urls: ["https://api.example.com/docs"],
fetchOptions: {
convertToMarkdown: false, // Keep JSON as-is
},
},
})) {
chunks.push(chunk);
}
const fullResponse = chunks.join("");
expect(fullResponse).toBe("Generated response based on URL content");
expect(mockFetch).toHaveBeenCalledTimes(1);
});
});
describe("URL Context Error Handling", () => {
it("should handle URL fetch failures gracefully", async () => {
mockFetch.mockRejectedValueOnce(new Error("Network error"));
await expect(
geminiService.generateContent({
prompt: "Analyze the content from this URL",
urlContext: {
urls: ["https://unreachable.com"],
},
})
).rejects.toThrow();
});
it("should handle mixed success/failure scenarios", async () => {
const mockSuccessContent =
"<html><body><p>Successful content</p></body></html>";
mockFetch
.mockResolvedValueOnce({
ok: true,
status: 200,
url: "https://success.com",
headers: new Map([["content-type", "text/html"]]),
text: () => Promise.resolve(mockSuccessContent),
})
.mockRejectedValueOnce(new Error("Failed to fetch"))
.mockResolvedValueOnce({
ok: true,
status: 200,
url: "https://success2.com",
headers: new Map([["content-type", "text/html"]]),
text: () => Promise.resolve(mockSuccessContent),
});
// This should continue processing successful URLs despite some failures
const result = await geminiService.generateContent({
prompt: "Summarize the available content",
urlContext: {
urls: [
"https://success.com",
"https://failed.com",
"https://success2.com",
],
},
});
expect(result).toBeDefined();
expect(mockFetch).toHaveBeenCalledTimes(3);
});
it("should respect URL context disabled configuration", async () => {
mockConfig.getInstance().getUrlContextConfig.mockReturnValue({
enabled: false,
maxUrlsPerRequest: 20,
defaultMaxContentKb: 100,
defaultTimeoutMs: 10000,
allowedDomains: ["*"],
blocklistedDomains: [],
});
await expect(
geminiService.generateContent({
prompt: "Analyze this content",
urlContext: {
urls: ["https://example.com"],
},
})
).rejects.toThrow("URL context feature is not enabled");
expect(mockFetch).not.toHaveBeenCalled();
});
});
describe("URL Security Integration", () => {
it("should block access to private networks", async () => {
await expect(
geminiService.generateContent({
prompt: "Analyze the content",
urlContext: {
urls: ["http://192.168.1.1/admin"],
},
})
).rejects.toThrow();
expect(mockFetch).not.toHaveBeenCalled();
});
it("should respect domain restrictions", async () => {
mockConfig.getInstance().getUrlContextConfig.mockReturnValue({
enabled: true,
maxUrlsPerRequest: 20,
defaultMaxContentKb: 100,
defaultTimeoutMs: 10000,
allowedDomains: ["example.com"],
blocklistedDomains: [],
});
// Allowed domain should work
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
url: "https://example.com",
headers: new Map([["content-type", "text/html"]]),
text: () => Promise.resolve("<html><body>Content</body></html>"),
});
await geminiService.generateContent({
prompt: "Analyze this content",
urlContext: {
urls: ["https://example.com"],
},
});
expect(mockFetch).toHaveBeenCalledTimes(1);
// Disallowed domain should fail
await expect(
geminiService.generateContent({
prompt: "Analyze this content",
urlContext: {
urls: ["https://other.com"],
},
})
).rejects.toThrow();
});
it("should enforce URL count limits", async () => {
const manyUrls = Array.from(
{ length: 25 },
(_, i) => `https://example${i}.com`
);
await expect(
geminiService.generateContent({
prompt: "Analyze all these URLs",
urlContext: {
urls: manyUrls,
},
})
).rejects.toThrow("Too many URLs");
expect(mockFetch).not.toHaveBeenCalled();
});
});
describe("Content Processing Integration", () => {
it("should correctly convert HTML to Markdown", async () => {
const complexHtml = `
<html>
<head><title>Complex Document</title></head>
<body>
<h1>Main Title</h1>
<p>Paragraph with <strong>bold</strong> and <em>italic</em> text.</p>
<ul>
<li>List item 1</li>
<li>List item 2 with <a href="https://example.com">link</a></li>
</ul>
<blockquote>This is a quote</blockquote>
<code>inline code</code>
<pre>code block</pre>
</body>
</html>
`;
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
url: "https://example.com/complex",
headers: new Map([["content-type", "text/html"]]),
text: () => Promise.resolve(complexHtml),
});
await geminiService.generateContent({
prompt: "Process this complex document",
urlContext: {
urls: ["https://example.com/complex"],
fetchOptions: {
convertToMarkdown: true,
includeMetadata: true,
},
},
});
expect(mockFetch).toHaveBeenCalledTimes(1);
// The actual content processing is tested in unit tests
});
it("should handle large content with truncation", async () => {
const largeContent = "x".repeat(500 * 1024); // 500KB content
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
url: "https://example.com/large",
headers: new Map([
["content-type", "text/html"],
["content-length", largeContent.length.toString()],
]),
text: () => Promise.resolve(largeContent),
});
await geminiService.generateContent({
prompt: "Summarize this large document",
urlContext: {
urls: ["https://example.com/large"],
fetchOptions: {
maxContentKb: 100, // Limit to 100KB
},
},
});
expect(mockFetch).toHaveBeenCalledTimes(1);
});
});
describe("Model Selection Integration", () => {
  it("should prefer models with larger context windows for URL-heavy requests", async () => {
    const urlCount = 15;
    const urls = Array.from(
      { length: urlCount },
      (_, index) => `https://example${index}.com`
    );

    // Queue one successful HTML response per URL, in request order.
    urls.forEach((url, index) => {
      mockFetch.mockResolvedValueOnce({
        ok: true,
        status: 200,
        url,
        headers: new Map([["content-type", "text/html"]]),
        text: () => Promise.resolve(`<html><body>Content ${index}</body></html>`),
      });
    });

    // No explicit model is passed, so the service chooses one itself.
    const result = await geminiService.generateContent({
      prompt: "Analyze and compare all these sources",
      urlContext: { urls },
      taskType: "reasoning",
      complexityHint: "complex",
    });

    // NOTE(review): the chosen model is not asserted; only success and the fetch count are checked.
    expect(result).toBeDefined();
    expect(mockFetch).toHaveBeenCalledTimes(urlCount);
  });
});
describe("Caching Integration", () => {
  it("should cache URL content between requests", async () => {
    const cachedUrl = "https://example.com/cached";
    const mockContent = "<html><body><p>Cached content</p></body></html>";

    // Persistent (not "Once") mock: a second network fetch would also succeed,
    // so only the call count reveals whether the cache was used.
    mockFetch.mockResolvedValue({
      ok: true,
      status: 200,
      url: cachedUrl,
      headers: new Map([["content-type", "text/html"]]),
      text: () => Promise.resolve(mockContent),
    });

    // Two sequential requests for the same URL with different prompts.
    const prompts = [
      "Analyze this content",
      "Different analysis of the same content",
    ];
    for (const prompt of prompts) {
      await geminiService.generateContent({
        prompt,
        urlContext: { urls: [cachedUrl] },
      });
    }

    // Should only fetch once due to caching
    expect(mockFetch).toHaveBeenCalledTimes(1);
  });
});
describe("Rate Limiting Integration", () => {
  it("should enforce rate limits per domain", async () => {
    const baseUrl = "https://example.com/page";
    const pageUrl = (page: number): string => `${baseUrl}${page}`;

    // Queue twelve successful HTML responses so the fetch layer itself never fails.
    Array.from({ length: 12 }, (_, page) => page).forEach((page) => {
      mockFetch.mockResolvedValueOnce({
        ok: true,
        status: 200,
        url: pageUrl(page),
        headers: new Map([["content-type", "text/html"]]),
        text: () => Promise.resolve("<html><body>Content</body></html>"),
      });
    });

    // Ten sequential requests against the same domain are expected to succeed.
    for (let page = 0; page < 10; page += 1) {
      await geminiService.generateContent({
        prompt: `Analyze page ${page}`,
        urlContext: { urls: [pageUrl(page)] },
      });
    }

    // The next request on the same domain must be rejected.
    const overLimit = geminiService.generateContent({
      prompt: "Analyze page 11",
      urlContext: { urls: [`${baseUrl}11`] },
    });
    await expect(overLimit).rejects.toThrow();
  });
});
});
```
--------------------------------------------------------------------------------
/src/utils/errors.ts:
--------------------------------------------------------------------------------
```typescript
/**
 * Base custom error class for application-specific errors.
 *
 * Every subclass carries a machine-readable `code`, an HTTP-style `status`
 * and optional free-form `details`.
 */
export class BaseError extends Error {
  /** Machine-readable error code; left mutable so subclasses can refine it after `super()`. */
  public code: string;
  /** HTTP status code equivalent (e.g. 400, 404, 500). */
  public readonly status: number;
  /** Additional structured context for the failure, if any. */
  public readonly details?: unknown;

  constructor(
    message: string,
    code: string,
    status: number,
    details?: unknown
  ) {
    super(message);
    // Re-establish the prototype chain so `instanceof` keeps working if the
    // compiler down-levels classes to ES5; a no-op for native ES2015+ classes.
    Object.setPrototypeOf(this, new.target.prototype);
    this.name = this.constructor.name; // Set the error name to the class name
    this.code = code;
    this.status = status;
    this.details = details;
    // `Error.captureStackTrace` is a V8 extension. Guard it so constructing an
    // error never throws on engines that lack it. It captures the stack
    // trace, excluding the constructor frames.
    if (typeof Error.captureStackTrace === "function") {
      Error.captureStackTrace(this, this.constructor);
    }
  }
}
/**
 * Raised when caller-supplied input fails validation.
 * Carries code "VALIDATION_ERROR" and status 400; `mapToMcpError` reports it
 * as MCP `InvalidParams`.
 */
export class ValidationError extends BaseError {
  constructor(message: string, details?: unknown) {
    const errorCode = "VALIDATION_ERROR";
    const httpStatus = 400;
    super(message, errorCode, httpStatus, details);
  }
}
/**
 * Raised when an expected entity or resource does not exist.
 * Carries code "NOT_FOUND" and status 404. When no message is supplied, it
 * defaults to "Resource not found".
 */
export class NotFoundError extends BaseError {
  constructor(message: string = "Resource not found") {
    const errorCode = "NOT_FOUND";
    const httpStatus = 404;
    super(message, errorCode, httpStatus);
  }
}
/**
 * Raised for configuration problems.
 * Carries code "CONFIG_ERROR" and status 500.
 */
export class ConfigurationError extends BaseError {
  constructor(message: string) {
    const errorCode = "CONFIG_ERROR";
    const httpStatus = 500;
    super(message, errorCode, httpStatus);
  }
}
/**
 * Raised for failures during service processing that are not input-validation
 * problems. Carries code "SERVICE_ERROR" and status 500; `mapToMcpError`
 * reports it as MCP `InternalError`.
 */
export class ServiceError extends BaseError {
  constructor(message: string, details?: unknown) {
    const errorCode = "SERVICE_ERROR";
    const httpStatus = 500;
    super(message, errorCode, httpStatus, details);
  }
}
/**
 * Raised when interacting with the Google Gemini API fails.
 *
 * The message is prefixed with "Gemini API Error: ". The error keeps the
 * "SERVICE_ERROR" code and 500 status inherited from ServiceError; the more
 * specific Gemini subclasses overwrite `code` after construction.
 */
export class GeminiApiError extends ServiceError {
  constructor(message: string, details?: unknown) {
    const prefixedMessage = `Gemini API Error: ${message}`;
    super(prefixedMessage, details);
  }
}
/**
 * Raised when a file or other resource is missing in the Gemini API.
 * The message reads "<resourceType> not found: <resourceId>" and the code is
 * "GEMINI_RESOURCE_NOT_FOUND".
 */
export class GeminiResourceNotFoundError extends GeminiApiError {
  constructor(resourceType: string, resourceId: string, details?: unknown) {
    const description = `${resourceType} not found: ${resourceId}`;
    super(description, details);
    this.code = "GEMINI_RESOURCE_NOT_FOUND";
  }
}
/**
 * Raised when parameters passed to the Gemini API are invalid.
 * The message is prefixed with "Invalid parameter: " and the code is
 * "GEMINI_INVALID_PARAMETER".
 */
export class GeminiInvalidParameterError extends GeminiApiError {
  constructor(message: string, details?: unknown) {
    const description = `Invalid parameter: ${message}`;
    super(description, details);
    this.code = "GEMINI_INVALID_PARAMETER";
  }
}
/**
 * Raised when authentication with the Gemini API fails.
 * The message is prefixed with "Authentication error: " and the code is
 * "GEMINI_AUTHENTICATION_ERROR".
 */
export class GeminiAuthenticationError extends GeminiApiError {
  constructor(message: string, details?: unknown) {
    const description = `Authentication error: ${message}`;
    super(description, details);
    this.code = "GEMINI_AUTHENTICATION_ERROR";
  }
}
/**
 * Raised when the Gemini API quota is exhausted or a rate limit is hit.
 * The message is prefixed with "Quota exceeded: " and the code is
 * "GEMINI_QUOTA_EXCEEDED".
 */
export class GeminiQuotaExceededError extends GeminiApiError {
  constructor(message: string, details?: unknown) {
    const description = `Quota exceeded: ${message}`;
    super(description, details);
    this.code = "GEMINI_QUOTA_EXCEEDED";
  }
}
/**
 * Raised when Gemini's safety settings block content.
 * The message is prefixed with "Content blocked by safety settings: " and the
 * code is "GEMINI_SAFETY_ERROR".
 */
export class GeminiSafetyError extends GeminiApiError {
  constructor(message: string, details?: unknown) {
    const description = `Content blocked by safety settings: ${message}`;
    super(description, details);
    this.code = "GEMINI_SAFETY_ERROR";
  }
}
// Import the McpError and ErrorCode from the MCP SDK for use in the mapping function
import { McpError, ErrorCode } from "@modelcontextprotocol/sdk/types.js";
import { ToolError } from "./ToolError.js";
// Re-export ToolError for use by tools
export { ToolError };
/**
 * Maps internal application errors to standardized MCP errors.
 * This function ensures consistent error mapping across all tool handlers.
 *
 * Resolution order:
 * 1. McpError is passed through unchanged.
 * 2. ValidationError, NotFoundError and ConfigurationError.
 * 3. The specific Gemini error subclasses.
 * 4. Message-pattern detection for any other GeminiApiError.
 * 5. ServiceError.
 * 6. A generic InternalError prefixed with the tool name.
 *
 * @param error - The error to be mapped to an MCP error
 * @param toolName - The name of the tool where the error occurred (used in the fallback message)
 * @returns McpError - A properly mapped MCP error; BaseError `details` are passed
 *   as the McpError's third constructor argument
 */
export function mapToMcpError(error: unknown, toolName: string): McpError {
  // If error is already an McpError, return it directly
  if (error instanceof McpError) {
    return error;
  }

  // Defaults used when the thrown value carries no usable message/details.
  let errorMessage = "An unknown error occurred";
  let errorDetails: unknown = undefined;

  if (error instanceof Error) {
    errorMessage = error.message;
    // Forward BaseError details whenever they are set, including falsy but
    // meaningful values such as 0, "" or false (a truthiness test drops them).
    if (error instanceof BaseError && error.details !== undefined) {
      errorDetails = error.details;
    }
  } else if (typeof error === "string") {
    errorMessage = error;
  } else if (error !== null && typeof error === "object") {
    // Describe unknown object errors by their JSON form when possible.
    try {
      const serialized = JSON.stringify(error);
      // JSON.stringify can yield undefined (e.g. a toJSON() returning undefined).
      if (typeof serialized === "string") {
        errorMessage = serialized;
      }
    } catch {
      // Circular or otherwise unserializable object: keep the default message.
    }
  }

  // ValidationError mapping
  if (error instanceof ValidationError) {
    return new McpError(
      ErrorCode.InvalidParams,
      `Validation error: ${errorMessage}`,
      errorDetails
    );
  }

  // NotFoundError mapping
  if (error instanceof NotFoundError) {
    return new McpError(
      ErrorCode.InvalidRequest,
      `Resource not found: ${errorMessage}`,
      errorDetails
    );
  }

  // ConfigurationError mapping
  if (error instanceof ConfigurationError) {
    return new McpError(
      ErrorCode.InternalError, // Changed from FailedPrecondition which is not in MCP SDK
      `Configuration error: ${errorMessage}`,
      errorDetails
    );
  }

  // Handle more specific Gemini API error subtypes first
  if (error instanceof GeminiResourceNotFoundError) {
    return new McpError(
      ErrorCode.InvalidRequest, // MCP SDK lacks NotFound, mapping to InvalidRequest
      `Resource not found: ${errorMessage}`,
      errorDetails
    );
  }

  if (error instanceof GeminiInvalidParameterError) {
    return new McpError(
      ErrorCode.InvalidParams,
      `Invalid parameters: ${errorMessage}`,
      errorDetails
    );
  }

  if (error instanceof GeminiAuthenticationError) {
    return new McpError(
      ErrorCode.InvalidRequest, // Changed from PermissionDenied which is not in MCP SDK
      `Authentication failed: ${errorMessage}`,
      errorDetails
    );
  }

  if (error instanceof GeminiQuotaExceededError) {
    return new McpError(
      ErrorCode.InternalError, // Changed from ResourceExhausted which is not in MCP SDK
      `Quota exceeded or rate limit hit: ${errorMessage}`,
      errorDetails
    );
  }

  if (error instanceof GeminiSafetyError) {
    return new McpError(
      ErrorCode.InvalidRequest,
      `Content blocked by safety settings: ${errorMessage}`,
      errorDetails
    );
  }

  // Generic GeminiApiError mapping with message-pattern detection
  if (error instanceof GeminiApiError) {
    // Convert message to lowercase for case-insensitive pattern matching
    const lowerCaseMessage = errorMessage.toLowerCase();
    // HTTP status codes are matched as whole numbers (\b...\b). This stops
    // "4000 tokens" from being taken for an HTTP 400, or "14290" for a 429.

    // Handle rate limiting and quota errors
    if (
      lowerCaseMessage.includes("quota") ||
      lowerCaseMessage.includes("rate limit") ||
      lowerCaseMessage.includes("resource has been exhausted") ||
      lowerCaseMessage.includes("resource exhausted") ||
      /\b429\b/.test(lowerCaseMessage) ||
      lowerCaseMessage.includes("too many requests")
    ) {
      return new McpError(
        ErrorCode.InternalError, // Changed from ResourceExhausted which is not in MCP SDK
        `Quota exceeded or rate limit hit: ${errorMessage}`,
        errorDetails
      );
    }

    // Handle permission and authorization errors
    if (
      lowerCaseMessage.includes("permission") ||
      lowerCaseMessage.includes("not authorized") ||
      lowerCaseMessage.includes("unauthorized") ||
      lowerCaseMessage.includes("forbidden") ||
      /\b403\b/.test(lowerCaseMessage) ||
      lowerCaseMessage.includes("access denied")
    ) {
      return new McpError(
        ErrorCode.InvalidRequest, // Changed from PermissionDenied which is not in MCP SDK
        `Permission denied: ${errorMessage}`,
        errorDetails
      );
    }

    // Handle not found errors
    if (
      lowerCaseMessage.includes("not found") ||
      lowerCaseMessage.includes("does not exist") ||
      /\b404\b/.test(lowerCaseMessage) ||
      lowerCaseMessage.includes("could not find") ||
      lowerCaseMessage.includes("no such file")
    ) {
      return new McpError(
        ErrorCode.InvalidRequest, // MCP SDK lacks NotFound, mapping to InvalidRequest
        `Resource not found: ${errorMessage}`,
        errorDetails
      );
    }

    // Handle invalid argument/parameter errors
    if (
      lowerCaseMessage.includes("invalid argument") ||
      lowerCaseMessage.includes("invalid parameter") ||
      lowerCaseMessage.includes("invalid request") ||
      lowerCaseMessage.includes("failed precondition") ||
      /\b400\b/.test(lowerCaseMessage) ||
      lowerCaseMessage.includes("bad request") ||
      lowerCaseMessage.includes("malformed")
    ) {
      return new McpError(
        ErrorCode.InvalidParams,
        `Invalid parameters: ${errorMessage}`,
        errorDetails
      );
    }

    // Handle safety-related errors
    if (
      lowerCaseMessage.includes("safety") ||
      lowerCaseMessage.includes("blocked") ||
      lowerCaseMessage.includes("content policy") ||
      lowerCaseMessage.includes("harmful") ||
      lowerCaseMessage.includes("inappropriate") ||
      lowerCaseMessage.includes("offensive")
    ) {
      return new McpError(
        ErrorCode.InvalidRequest,
        `Content blocked by safety settings: ${errorMessage}`,
        errorDetails
      );
    }

    // Handle File API and other unsupported feature errors
    if (
      lowerCaseMessage.includes("file api is not supported") ||
      lowerCaseMessage.includes("not supported") ||
      lowerCaseMessage.includes("unsupported") ||
      lowerCaseMessage.includes("not implemented")
    ) {
      return new McpError(
        ErrorCode.InvalidRequest, // Changed from FailedPrecondition which is not in MCP SDK
        `Operation not supported: ${errorMessage}`,
        errorDetails
      );
    }

    // Default case for GeminiApiError - map to internal error
    return new McpError(
      ErrorCode.InternalError,
      `Gemini API Error: ${errorMessage}`,
      errorDetails
    );
  }

  // Generic ServiceError mapping
  if (error instanceof ServiceError) {
    return new McpError(
      ErrorCode.InternalError,
      `Service error: ${errorMessage}`,
      errorDetails
    );
  }

  // Default case for all other errors. Details are still forwarded so that
  // BaseError subclasses without a dedicated mapping do not lose them.
  return new McpError(
    ErrorCode.InternalError,
    `[${toolName}] Failed: ${errorMessage}`,
    errorDetails
  );
}
/**
 * Combined error mapping function that handles both standard errors and ToolError instances.
 * This function accommodates the different error types used across different tool implementations.
 *
 * Routing:
 * - ToolError instances go to `mapToolErrorToMcpError`.
 * - McpError and the BaseError hierarchy go to `mapToMcpError`.
 * - Any other object with a string `code` goes to `mapToolErrorToMcpError`.
 * - Everything else goes to `mapToMcpError`.
 *
 * @param error - Any error type, including McpError, BaseError, ToolError, or standard Error
 * @param toolName - The name of the tool where the error occurred
 * @returns McpError - A consistently mapped MCP error
 */
export function mapAnyErrorToMcpError(
  error: unknown,
  toolName: string
): McpError {
  // Genuine ToolError instances use the code-based compatibility mapping.
  if (error instanceof ToolError) {
    return mapToolErrorToMcpError(error, toolName);
  }

  // McpError and the BaseError hierarchy have dedicated class-based mappings.
  // This check must precede the duck-typed `code` test below. Every BaseError
  // carries a string `code` (e.g. "VALIDATION_ERROR"), and the code-substring
  // mapping would, for example, turn a ValidationError into an InternalError
  // instead of InvalidParams.
  if (error instanceof McpError || error instanceof BaseError) {
    return mapToMcpError(error, toolName);
  }

  // Other objects shaped like a ToolError (string `code` property)
  if (
    error !== null &&
    typeof error === "object" &&
    "code" in error &&
    typeof (error as ToolErrorLike).code === "string"
  ) {
    return mapToolErrorToMcpError(error as ToolErrorLike, toolName);
  }

  // For standard errors, strings and anything else
  return mapToMcpError(error, toolName);
}
/**
 * Structural stand-in for ToolError.
 *
 * It describes plain objects (or instances of other classes) that look like a
 * ToolError without being one, so the error mappers can read them type-safely.
 */
export interface ToolErrorLike {
  /** Human-readable description of the failure. */
  message?: string;
  /** Machine-readable code; `mapToolErrorToMcpError` scans it for keywords. */
  code?: string;
  /** Extra context; `mapToolErrorToMcpError` passes it through to the McpError. */
  details?: unknown;
  /** Further properties are allowed; the mappers in this file do not read them. */
  [key: string]: unknown;
}
// Some image-feature tools throw errors with a different structure (ToolError or
// ToolError-like objects) than the rest of the application, yet those errors
// still need consistent mapping to McpError.
/**
 * Compatibility mapper for ToolError instances and ToolError-shaped objects.
 *
 * Behavior:
 * - Only truthy objects are inspected. The message defaults to
 *   "Error in <toolName>" unless the object has a string `message`; `details`
 *   is taken as-is when present.
 * - A string `code` is upper-cased and scanned for keywords, and the first
 *   matching rule wins. The order is: SAFETY/BLOCKED, QUOTA/RATE_LIMIT,
 *   PERMISSION/AUTH, NOT_FOUND, INVALID/ARGUMENT, UNSUPPORTED/NOT_SUPPORTED.
 * - Anything unmatched becomes an InternalError prefixed with the tool name.
 *
 * @param toolError - The ToolError instance or object with code/details properties
 * @param toolName - The name of the tool, used in the default message and fallback prefix
 * @returns McpError - A consistent MCP error carrying the extracted details
 */
export function mapToolErrorToMcpError(
  toolError: ToolErrorLike | unknown,
  toolName: string
): McpError {
  let message = `Error in ${toolName}`;
  let details: unknown = undefined;

  // Primitives, null and other falsy values keep the defaults above.
  const source =
    toolError && typeof toolError === "object"
      ? (toolError as ToolErrorLike)
      : undefined;

  if (source !== undefined) {
    if ("message" in source && typeof source.message === "string") {
      message = source.message;
    }
    if ("details" in source) {
      details = source.details;
    }
    if ("code" in source && typeof source.code === "string") {
      const upperCode = source.code.toUpperCase();
      // Ordered rule table: earlier entries take precedence over later ones.
      const rules = [
        {
          keywords: ["SAFETY", "BLOCKED"],
          mcpCode: ErrorCode.InvalidRequest,
          format: (text: string) => `Content blocked by safety settings: ${text}`,
        },
        {
          keywords: ["QUOTA", "RATE_LIMIT"],
          mcpCode: ErrorCode.InternalError, // MCP SDK has no ResourceExhausted
          format: (text: string) => `API quota or rate limit exceeded: ${text}`,
        },
        {
          keywords: ["PERMISSION", "AUTH"],
          mcpCode: ErrorCode.InvalidRequest, // MCP SDK has no PermissionDenied
          format: (text: string) => `Permission denied: ${text}`,
        },
        {
          keywords: ["NOT_FOUND"],
          mcpCode: ErrorCode.InvalidRequest,
          format: (text: string) => `Resource not found: ${text}`,
        },
        {
          keywords: ["INVALID", "ARGUMENT"],
          mcpCode: ErrorCode.InvalidParams,
          format: (text: string) => `Invalid parameters: ${text}`,
        },
        {
          keywords: ["UNSUPPORTED", "NOT_SUPPORTED"],
          mcpCode: ErrorCode.InvalidRequest, // MCP SDK has no FailedPrecondition
          format: (text: string) => `Operation not supported: ${text}`,
        },
      ];
      const matched = rules.find((rule) =>
        rule.keywords.some((keyword) => upperCode.includes(keyword))
      );
      if (matched !== undefined) {
        return new McpError(matched.mcpCode, matched.format(message), details);
      }
    }
  }

  // Default to internal error for any other case
  return new McpError(
    ErrorCode.InternalError,
    `[${toolName}] Error: ${message}`,
    details
  );
}
```
--------------------------------------------------------------------------------
/tests/unit/tools/schemas/ToolSchemas.test.vitest.ts:
--------------------------------------------------------------------------------
```typescript
// Using vitest globals - see vitest.config.ts globals: true
import {
ToolSchema,
ToolResponseSchema,
FunctionParameterSchema,
FunctionDeclarationSchema,
} from "../../../../src/tools/schemas/ToolSchemas.js";
import {
HarmCategorySchema,
SafetySettingSchema,
ThinkingConfigSchema,
GenerationConfigSchema,
FilePathSchema,
FileOverwriteSchema,
EncodingSchema,
ModelNameSchema,
PromptSchema,
} from "../../../../src/tools/schemas/CommonSchemas.js";
describe("Tool Schemas Validation", () => {
describe("ToolSchema", () => {
  // True when ToolSchema accepts the candidate value.
  const accepts = (candidate: unknown): boolean =>
    ToolSchema.safeParse(candidate).success;

  it("should validate a valid tool definition with function declarations", () => {
    const nameParameter = { type: "STRING", description: "The name parameter" };
    const declaration = {
      name: "testFunction",
      description: "A test function",
      parameters: {
        type: "OBJECT",
        properties: { name: nameParameter },
        required: ["name"],
      },
    };
    expect(accepts({ functionDeclarations: [declaration] })).toBe(true);
  });

  it("should validate a tool with no function declarations", () => {
    expect(accepts({})).toBe(true);
  });

  it("should reject invalid function declarations", () => {
    // The declaration deliberately omits the required `name` field.
    const namelessDeclaration = {
      description: "A test function",
      parameters: { type: "OBJECT", properties: {} },
    };
    expect(accepts({ functionDeclarations: [namelessDeclaration] })).toBe(false);
  });
});
describe("ToolResponseSchema", () => {
  // Fresh response payload for each test.
  const makePayload = () => ({ result: "success" });

  it("should validate a valid tool response", () => {
    const outcome = ToolResponseSchema.safeParse({
      name: "testTool",
      response: makePayload(),
    });
    expect(outcome.success).toBe(true);
  });

  it("should reject response with missing name", () => {
    // `name` is intentionally absent.
    const outcome = ToolResponseSchema.safeParse({ response: makePayload() });
    expect(outcome.success).toBe(false);
  });
});
describe("FunctionParameterSchema", () => {
  const isValidParameter = (parameter: unknown): boolean =>
    FunctionParameterSchema.safeParse(parameter).success;

  it("should validate primitive parameter types", () => {
    const primitiveParameters = [
      { type: "STRING", description: "A string parameter" },
      { type: "NUMBER", description: "A number parameter" },
      { type: "BOOLEAN" },
    ];
    for (const parameter of primitiveParameters) {
      expect(isValidParameter(parameter)).toBe(true);
    }
  });

  it("should validate object parameter with nested properties", () => {
    // An OBJECT property nested inside the top-level OBJECT.
    const details = {
      type: "OBJECT",
      properties: { address: { type: "STRING" } },
    };
    const objectParam = {
      type: "OBJECT",
      description: "An object parameter",
      properties: {
        name: { type: "STRING" },
        age: { type: "INTEGER" },
        details,
      },
      required: ["name"],
    };
    expect(isValidParameter(objectParam)).toBe(true);
  });

  it("should validate array parameter with items", () => {
    const arrayParam = {
      type: "ARRAY",
      description: "An array parameter",
      items: { type: "STRING" },
    };
    expect(isValidParameter(arrayParam)).toBe(true);
  });

  it("should reject parameter with invalid type", () => {
    // "INVALID_TYPE" is not one of the supported parameter types.
    const invalidParam = {
      type: "INVALID_TYPE",
      description: "An invalid parameter",
    };
    expect(isValidParameter(invalidParam)).toBe(false);
  });
});
describe("FunctionDeclarationSchema", () => {
  const check = (declaration: unknown): boolean =>
    FunctionDeclarationSchema.safeParse(declaration).success;

  it("should validate a valid function declaration", () => {
    const properties = {
      name: { type: "STRING", description: "The name parameter" },
      age: { type: "INTEGER" },
    };
    expect(
      check({
        name: "testFunction",
        description: "A test function",
        parameters: { type: "OBJECT", properties, required: ["name"] },
      })
    ).toBe(true);
  });

  it("should reject function declaration with missing required fields", () => {
    // No `name`: the declaration must be rejected.
    expect(
      check({
        description: "A test function",
        parameters: { type: "OBJECT", properties: {} },
      })
    ).toBe(false);
  });

  it("should reject function declaration with invalid parameters type", () => {
    // Top-level parameters are expected to be "OBJECT"; "STRING" is used on purpose.
    expect(
      check({
        name: "testFunction",
        description: "A test function",
        parameters: { type: "STRING", properties: {} },
      })
    ).toBe(false);
  });
});
describe("CommonSchemas", () => {
describe("HarmCategorySchema", () => {
  const acceptedCategories = [
    "HARM_CATEGORY_UNSPECIFIED",
    "HARM_CATEGORY_HATE_SPEECH",
    "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    "HARM_CATEGORY_HARASSMENT",
    "HARM_CATEGORY_DANGEROUS_CONTENT",
  ];

  it("should validate valid harm categories", () => {
    for (const category of acceptedCategories) {
      expect(HarmCategorySchema.safeParse(category).success).toBe(true);
    }
  });

  it("should reject invalid harm categories", () => {
    const outcome = HarmCategorySchema.safeParse("INVALID_CATEGORY");
    expect(outcome.success).toBe(false);
  });
});
describe("SafetySettingSchema", () => {
  const isAccepted = (setting: unknown): boolean =>
    SafetySettingSchema.safeParse(setting).success;

  it("should validate a valid safety setting", () => {
    expect(
      isAccepted({
        category: "HARM_CATEGORY_HATE_SPEECH",
        threshold: "BLOCK_MEDIUM_AND_ABOVE",
      })
    ).toBe(true);
  });

  it("should validate all valid combinations of categories and thresholds", () => {
    const categories = [
      "HARM_CATEGORY_UNSPECIFIED",
      "HARM_CATEGORY_HATE_SPEECH",
      "HARM_CATEGORY_SEXUALLY_EXPLICIT",
      "HARM_CATEGORY_HARASSMENT",
      "HARM_CATEGORY_DANGEROUS_CONTENT",
    ];
    const thresholds = [
      "HARM_BLOCK_THRESHOLD_UNSPECIFIED",
      "BLOCK_LOW_AND_ABOVE",
      "BLOCK_MEDIUM_AND_ABOVE",
      "BLOCK_ONLY_HIGH",
      "BLOCK_NONE",
    ];
    // Exhaustive cross product: every category paired with every threshold.
    const combinations = categories.flatMap((category) =>
      thresholds.map((threshold) => ({ category, threshold }))
    );
    for (const setting of combinations) {
      expect(isAccepted(setting)).toBe(true);
    }
  });

  it("should reject setting with valid structure but invalid category", () => {
    expect(
      isAccepted({
        category: "INVALID_CATEGORY",
        threshold: "BLOCK_MEDIUM_AND_ABOVE",
      })
    ).toBe(false);
  });

  it("should reject setting with valid structure but invalid threshold", () => {
    expect(
      isAccepted({
        category: "HARM_CATEGORY_HATE_SPEECH",
        threshold: "INVALID_THRESHOLD",
      })
    ).toBe(false);
  });

  it("should reject setting with missing required fields", () => {
    // Each object lacks exactly one of the two required fields.
    const incompleteSettings = [
      { threshold: "BLOCK_MEDIUM_AND_ABOVE" },
      { category: "HARM_CATEGORY_HATE_SPEECH" },
    ];
    for (const setting of incompleteSettings) {
      expect(isAccepted(setting)).toBe(false);
    }
  });
});
describe("GenerationConfigSchema", () => {
  const accepts = (config: unknown): boolean =>
    GenerationConfigSchema.safeParse(config).success;

  it("should validate a valid generation config", () => {
    expect(
      accepts({
        temperature: 0.7,
        topP: 0.9,
        topK: 40,
        maxOutputTokens: 1024,
        stopSequences: ["STOP", "END"],
        thinkingConfig: { thinkingBudget: 1000, reasoningEffort: "medium" },
      })
    ).toBe(true);
  });

  it("should validate minimal generation config", () => {
    expect(accepts({})).toBe(true);
  });

  // Table of boundary checks: [test title, config, expected acceptance].
  type BoundaryCase = readonly [string, Record<string, number>, boolean];
  const boundaryGroups: ReadonlyArray<readonly [string, ReadonlyArray<BoundaryCase>]> = [
    [
      "temperature parameter boundary values",
      [
        ["should validate minimum valid temperature (0)", { temperature: 0 }, true],
        ["should validate maximum valid temperature (1)", { temperature: 1 }, true],
        ["should reject temperature below minimum (-0.1)", { temperature: -0.1 }, false],
        ["should reject temperature above maximum (1.01)", { temperature: 1.01 }, false],
      ],
    ],
    [
      "topP parameter boundary values",
      [
        ["should validate minimum valid topP (0)", { topP: 0 }, true],
        ["should validate maximum valid topP (1)", { topP: 1 }, true],
        ["should reject topP below minimum (-0.1)", { topP: -0.1 }, false],
        ["should reject topP above maximum (1.01)", { topP: 1.01 }, false],
      ],
    ],
    [
      "topK parameter boundary values",
      [
        ["should validate minimum valid topK (1)", { topK: 1 }, true],
        ["should reject topK below minimum (0)", { topK: 0 }, false],
        ["should validate large topK values", { topK: 1000 }, true],
      ],
    ],
    [
      "maxOutputTokens parameter boundary values",
      [
        ["should validate minimum valid maxOutputTokens (1)", { maxOutputTokens: 1 }, true],
        ["should reject maxOutputTokens below minimum (0)", { maxOutputTokens: 0 }, false],
        ["should validate large maxOutputTokens values", { maxOutputTokens: 10000 }, true],
      ],
    ],
  ];

  for (const [groupTitle, cases] of boundaryGroups) {
    describe(groupTitle, () => {
      for (const [title, config, expected] of cases) {
        it(title, () => {
          expect(accepts(config)).toBe(expected);
        });
      }
    });
  }
});
describe("ThinkingConfigSchema", () => {
  const accepts = (config: unknown): boolean =>
    ThinkingConfigSchema.safeParse(config).success;

  it("should validate valid thinking configs", () => {
    const acceptedConfigs = [
      { thinkingBudget: 1000 },
      { reasoningEffort: "medium" },
      { thinkingBudget: 5000, reasoningEffort: "high" },
      {}, // An empty config is valid
    ];
    for (const config of acceptedConfigs) {
      expect(accepts(config)).toBe(true);
    }
  });

  describe("thinkingBudget parameter boundary values", () => {
    // [test title, thinkingBudget value, expected acceptance]
    const budgetCases: ReadonlyArray<readonly [string, number, boolean]> = [
      ["should validate minimum valid thinkingBudget (0)", 0, true],
      ["should validate maximum valid thinkingBudget (24576)", 24576, true],
      ["should reject thinkingBudget below minimum (-1)", -1, false],
      ["should reject thinkingBudget above maximum (24577)", 24577, false],
      ["should reject non-integer thinkingBudget (1000.5)", 1000.5, false],
    ];
    for (const [title, thinkingBudget, expected] of budgetCases) {
      it(title, () => {
        expect(accepts({ thinkingBudget })).toBe(expected);
      });
    }
  });

  describe("reasoningEffort parameter values", () => {
    it("should validate all valid reasoningEffort options", () => {
      for (const reasoningEffort of ["none", "low", "medium", "high"]) {
        expect(accepts({ reasoningEffort })).toBe(true);
      }
    });

    it("should reject invalid reasoningEffort options", () => {
      for (const reasoningEffort of ["maximum", "minimal", "very-high", ""]) {
        expect(accepts({ reasoningEffort })).toBe(false);
      }
    });
  });
});
describe("File Operation Schemas", () => {
  it("should validate valid file paths", () => {
    const posixPath = "/path/to/file.txt";
    const windowsPath = "C:\\Windows\\System32\\file.exe";
    for (const candidate of [posixPath, windowsPath]) {
      expect(FilePathSchema.safeParse(candidate).success).toBe(true);
    }
  });

  it("should reject empty file paths", () => {
    const outcome = FilePathSchema.safeParse("");
    expect(outcome.success).toBe(false);
  });

  it("should validate file overwrite options", () => {
    // true, false and an omitted value are all acceptable.
    for (const overwrite of [true, false, undefined]) {
      expect(FileOverwriteSchema.safeParse(overwrite).success).toBe(true);
    }
  });

  it("should validate encoding options", () => {
    // [encoding, expected acceptance]
    const expectations: ReadonlyArray<readonly [string | undefined, boolean]> = [
      ["utf8", true],
      ["base64", true],
      [undefined, true],
      ["binary", false],
    ];
    for (const [encoding, accepted] of expectations) {
      expect(EncodingSchema.safeParse(encoding).success).toBe(accepted);
    }
  });
});
describe("Other Common Schemas", () => {
  it("should validate model names", () => {
    const namedModel = ModelNameSchema.safeParse("gemini-pro");
    const blankModel = ModelNameSchema.safeParse("");
    expect(namedModel.success).toBe(true);
    expect(blankModel.success).toBe(false);
  });

  it("should validate prompts", () => {
    const storyPrompt = PromptSchema.safeParse("Tell me a story");
    const blankPrompt = PromptSchema.safeParse("");
    expect(storyPrompt.success).toBe(true);
    expect(blankPrompt.success).toBe(false);
  });
});
});
});
```
--------------------------------------------------------------------------------
/tests/unit/utils/FileSecurityService.test.vitest.ts:
--------------------------------------------------------------------------------
```typescript
// Using vitest globals - see vitest.config.ts globals: true
import * as path from "path";
import * as fs from "fs/promises";
import * as fsSync from "fs";
// Import the code to test
import { FileSecurityService } from "../../../src/utils/FileSecurityService.js";
import { ValidationError } from "../../../src/utils/errors.js";
import { logger } from "../../../src/utils/logger.js";
describe("FileSecurityService", () => {
  // Mock logger so the service's info/warn/error/debug output stays silent
  const loggerMock = {
    info: vi.fn(),
    warn: vi.fn(),
    error: vi.fn(),
    debug: vi.fn(),
  };

  // The FileSecurityService constructor reads this variable. It is cleared
  // before every test and restored afterwards so a value exported in the
  // developer's shell cannot change results, and a failing test cannot leak it.
  const ENV_BASE_DIR = "GEMINI_SAFE_FILE_BASE_DIR";
  let savedEnvBaseDir: string | undefined;

  // Define test constants for all tests. The directory names carry a per-run
  // suffix so another suite using the same fixture layout (possibly running in
  // a parallel worker) cannot create or delete this suite's directories.
  const TEST_CONTENT = "Test file content";
  const RUN_ID = `${process.pid}-${Math.random().toString(36).slice(2, 10)}`;
  const TEST_DIR = path.resolve(`./test-security-dir-${RUN_ID}`);
  const OUTSIDE_DIR = path.resolve(`./outside-security-dir-${RUN_ID}`);

  // Setup before each test
  beforeEach(() => {
    // Reset mocks
    vi.clearAllMocks();

    // Replace logger with mock
    vi.spyOn(logger, "info").mockImplementation(loggerMock.info);
    vi.spyOn(logger, "warn").mockImplementation(loggerMock.warn);
    vi.spyOn(logger, "error").mockImplementation(loggerMock.error);
    vi.spyOn(logger, "debug").mockImplementation(loggerMock.debug);

    // Isolate tests from the ambient environment
    savedEnvBaseDir = process.env[ENV_BASE_DIR];
    delete process.env[ENV_BASE_DIR];

    // Create test directories
    fsSync.mkdirSync(TEST_DIR, { recursive: true });
    fsSync.mkdirSync(OUTSIDE_DIR, { recursive: true });
  });

  // Cleanup after each test
  afterEach(() => {
    // Restore original logger
    vi.restoreAllMocks();

    // Restore the environment exactly, including an originally-empty value
    if (savedEnvBaseDir === undefined) {
      delete process.env[ENV_BASE_DIR];
    } else {
      process.env[ENV_BASE_DIR] = savedEnvBaseDir;
    }

    // Clean up test directories
    try {
      fsSync.rmSync(TEST_DIR, { recursive: true, force: true });
      fsSync.rmSync(OUTSIDE_DIR, { recursive: true, force: true });
    } catch {
      // Ignore cleanup errors
    }
  });

  describe("Constructor and Configuration", () => {
    it("should initialize with default allowed directories", () => {
      const service = new FileSecurityService();
      const allowedDirs = service.getAllowedDirectories();
      expect(allowedDirs.length).toBeGreaterThan(0);
      expect(allowedDirs).toContain(path.resolve(process.cwd()));
    });

    it("should initialize with custom allowed directories", () => {
      const customDirs = [TEST_DIR, OUTSIDE_DIR];
      const service = new FileSecurityService(customDirs);
      const allowedDirs = service.getAllowedDirectories();
      expect(allowedDirs.length).toBe(2);
      expect(allowedDirs).toContain(path.resolve(TEST_DIR));
      expect(allowedDirs).toContain(path.resolve(OUTSIDE_DIR));
    });

    it("should initialize with a secure base path", () => {
      const service = new FileSecurityService([], TEST_DIR);
      const basePath = service.getSecureBasePath();
      expect(basePath).toBe(path.normalize(TEST_DIR));
      // Verify allowed directories includes the base path
      const allowedDirs = service.getAllowedDirectories();
      expect(allowedDirs).toContain(path.normalize(TEST_DIR));
    });

    it("should set allowed directories", () => {
      const service = new FileSecurityService();
      const newDirs = [TEST_DIR, OUTSIDE_DIR];
      service.setAllowedDirectories(newDirs);
      const allowedDirs = service.getAllowedDirectories();
      expect(allowedDirs.length).toBe(2);
      expect(allowedDirs).toContain(path.normalize(TEST_DIR));
      expect(allowedDirs).toContain(path.normalize(OUTSIDE_DIR));
    });

    it("should throw error when setting empty allowed directories", () => {
      const service = new FileSecurityService();
      expect(() => service.setAllowedDirectories([])).toThrow(ValidationError);
      expect(() => service.setAllowedDirectories([])).toThrow(
        /At least one allowed directory/
      );
    });

    it("should throw error when setting non-absolute allowed directories", () => {
      const service = new FileSecurityService();
      expect(() => service.setAllowedDirectories(["./relative/path"])).toThrow(
        ValidationError
      );
      expect(() => service.setAllowedDirectories(["./relative/path"])).toThrow(
        /Directory path must be absolute/
      );
    });

    it("should set and get secure base path", () => {
      const service = new FileSecurityService();
      service.setSecureBasePath(TEST_DIR);
      const basePath = service.getSecureBasePath();
      expect(basePath).toBe(path.normalize(TEST_DIR));
    });

    it("should throw error when setting non-absolute secure base path", () => {
      const service = new FileSecurityService();
      expect(() => service.setSecureBasePath("./relative/path")).toThrow(
        ValidationError
      );
      expect(() => service.setSecureBasePath("./relative/path")).toThrow(
        /Base path must be absolute/
      );
    });

    it("should configure from environment", () => {
      // The suite-level hooks clear this variable beforehand and restore the
      // original value afterwards, even if an assertion below fails.
      process.env[ENV_BASE_DIR] = TEST_DIR;
      const service = FileSecurityService.configureFromEnvironment();
      const allowedDirs = service.getAllowedDirectories();
      expect(allowedDirs).toContain(path.normalize(TEST_DIR));
    });
  });

  describe("Path Validation", () => {
    let service: FileSecurityService;

    beforeEach(() => {
      service = new FileSecurityService([TEST_DIR]);
    });

    it("should validate path within allowed directory", () => {
      const testFilePath = path.join(TEST_DIR, "test-file.txt");
      const validatedPath = service.validateAndResolvePath(testFilePath);
      expect(validatedPath).toBe(path.normalize(testFilePath));
    });

    it("should validate paths with relative components", () => {
      const complexPath = path.join(
        TEST_DIR,
        ".",
        "subdir",
        "..",
        "test-file.txt"
      );
      const validatedPath = service.validateAndResolvePath(complexPath);
      // Should normalize to TEST_DIR/test-file.txt
      const expectedPath = path.normalize(path.join(TEST_DIR, "test-file.txt"));
      expect(validatedPath).toBe(expectedPath);
    });

    it("should reject paths outside allowed directories", () => {
      const outsidePath = path.join(OUTSIDE_DIR, "test-file.txt");
      expect(() => service.validateAndResolvePath(outsidePath)).toThrow(
        ValidationError
      );
      expect(() => service.validateAndResolvePath(outsidePath)).toThrow(
        /Access denied/
      );
    });

    it("should reject paths with directory traversal", () => {
      const traversalPath = path.join(
        TEST_DIR,
        "..",
        "outside",
        "test-file.txt"
      );
      expect(() => service.validateAndResolvePath(traversalPath)).toThrow(
        ValidationError
      );
      expect(() => service.validateAndResolvePath(traversalPath)).toThrow(
        /Access denied/
      );
    });

    it("should check file existence with mustExist option", () => {
      const nonExistentPath = path.join(TEST_DIR, "non-existent.txt");
      expect(() =>
        service.validateAndResolvePath(nonExistentPath, { mustExist: true })
      ).toThrow(ValidationError);
      expect(() =>
        service.validateAndResolvePath(nonExistentPath, { mustExist: true })
      ).toThrow(/File not found/);
    });

    it("should use custom allowed directories when provided", () => {
      // Path is outside the service's configured directory but inside custom allowed dir
      const customAllowedPath = path.join(OUTSIDE_DIR, "custom-allowed.txt");
      const validatedPath = service.validateAndResolvePath(customAllowedPath, {
        allowedDirs: [OUTSIDE_DIR],
      });
      expect(validatedPath).toBe(path.normalize(customAllowedPath));
    });
  });

  describe("isPathWithinAllowedDirs", () => {
    let service: FileSecurityService;

    beforeEach(() => {
      service = new FileSecurityService([TEST_DIR]);
    });

    it("should return true for paths within allowed directories", () => {
      const insidePath = path.join(TEST_DIR, "test-file.txt");
      const result = service.isPathWithinAllowedDirs(insidePath);
      expect(result).toBe(true);
    });

    it("should return true for exact match with allowed directory", () => {
      const result = service.isPathWithinAllowedDirs(TEST_DIR);
      expect(result).toBe(true);
    });

    it("should return false for paths outside allowed directories", () => {
      const outsidePath = path.join(OUTSIDE_DIR, "test-file.txt");
      const result = service.isPathWithinAllowedDirs(outsidePath);
      expect(result).toBe(false);
    });

    it("should return false for paths with directory traversal", () => {
      const traversalPath = path.join(
        TEST_DIR,
        "..",
        "outside",
        "test-file.txt"
      );
      const result = service.isPathWithinAllowedDirs(traversalPath);
      expect(result).toBe(false);
    });

    it("should use custom allowed directories when provided", () => {
      const outsidePath = path.join(OUTSIDE_DIR, "test-file.txt");
      // Should be false with default allowed dirs
      expect(service.isPathWithinAllowedDirs(outsidePath)).toBe(false);
      // Should be true with custom allowed dirs
      expect(service.isPathWithinAllowedDirs(outsidePath, [OUTSIDE_DIR])).toBe(
        true
      );
    });

    it("should return false when no allowed directories exist", () => {
      const result = service.isPathWithinAllowedDirs(TEST_DIR, []);
      expect(result).toBe(false);
    });
  });

  describe("fullyResolvePath", () => {
    let service: FileSecurityService;

    beforeEach(() => {
      service = new FileSecurityService([TEST_DIR, OUTSIDE_DIR]);
    });

    it("should resolve a normal file path", async () => {
      const testPath = path.join(TEST_DIR, "test-file.txt");
      const resolvedPath = await service.fullyResolvePath(testPath);
      expect(resolvedPath).toBe(path.normalize(testPath));
    });

    it("should handle non-existent paths", async () => {
      const nonExistentPath = path.join(
        TEST_DIR,
        "non-existent",
        "test-file.txt"
      );
      const resolvedPath = await service.fullyResolvePath(nonExistentPath);
      expect(resolvedPath).toBe(path.normalize(nonExistentPath));
    });

    it("should resolve and validate a symlink to a file", async () => {
      // Create target file
      const targetPath = path.join(TEST_DIR, "target.txt");
      await fs.writeFile(targetPath, TEST_CONTENT, "utf8");
      // Create symlink
      const symlinkPath = path.join(TEST_DIR, "symlink.txt");
      await fs.symlink(targetPath, symlinkPath);
      // Resolve the symlink
      const resolvedPath = await service.fullyResolvePath(symlinkPath);
      // Should resolve to the target path
      expect(resolvedPath).toBe(path.normalize(targetPath));
    });

    it("should reject symlinks pointing outside allowed directories", async () => {
      // Create target file in outside (non-allowed) directory
      const targetPath = path.join(OUTSIDE_DIR, "target.txt");
      await fs.writeFile(targetPath, TEST_CONTENT, "utf8");
      // Create symlink in test (allowed) directory pointing to outside
      const symlinkPath = path.join(TEST_DIR, "bad-symlink.txt");
      // Setup service with only TEST_DIR allowed (not OUTSIDE_DIR)
      const restrictedService = new FileSecurityService([TEST_DIR]);
      await fs.symlink(targetPath, symlinkPath);
      // Try to resolve the symlink
      await expect(
        restrictedService.fullyResolvePath(symlinkPath)
      ).rejects.toThrow(ValidationError);
      await expect(
        restrictedService.fullyResolvePath(symlinkPath)
      ).rejects.toThrow(/Security error/);
      await expect(
        restrictedService.fullyResolvePath(symlinkPath)
      ).rejects.toThrow(/outside allowed directories/);
    });

    it("should detect and validate symlinked parent directories", async () => {
      // Create target directory in allowed location
      const targetDir = path.join(TEST_DIR, "target-dir");
      await fs.mkdir(targetDir, { recursive: true });
      // Create symlink to directory
      const symlinkDir = path.join(TEST_DIR, "symlink-dir");
      await fs.symlink(targetDir, symlinkDir);
      // Create a file path inside the symlinked directory
      const filePath = path.join(symlinkDir, "test-file.txt");
      // Resolve the path
      const resolvedPath = await service.fullyResolvePath(filePath);
      // Should resolve to actual path in target directory
      const expectedPath = path.join(targetDir, "test-file.txt");
      expect(resolvedPath).toBe(path.normalize(expectedPath));
    });

    it("should reject symlinked parent directories pointing outside allowed directories", async () => {
      // Create target directory in outside (not allowed) directory
      const targetDir = path.join(OUTSIDE_DIR, "target-dir");
      await fs.mkdir(targetDir, { recursive: true });
      // Create symlink in test directory pointing to outside directory
      const symlinkDir = path.join(TEST_DIR, "bad-symlink-dir");
      await fs.symlink(targetDir, symlinkDir);
      // Create a file path inside the symlinked directory
      const filePath = path.join(symlinkDir, "test-file.txt");
      // Setup service with only TEST_DIR allowed
      const restrictedService = new FileSecurityService([TEST_DIR]);
      // Try to resolve the path
      await expect(
        restrictedService.fullyResolvePath(filePath)
      ).rejects.toThrow(ValidationError);
      await expect(
        restrictedService.fullyResolvePath(filePath)
      ).rejects.toThrow(/Security error/);
    });
  });

  describe("secureWriteFile", () => {
    let service: FileSecurityService;

    beforeEach(() => {
      service = new FileSecurityService([TEST_DIR]);
    });

    it("should write file to an allowed directory", async () => {
      const filePath = path.join(TEST_DIR, "test-file.txt");
      await service.secureWriteFile(filePath, TEST_CONTENT);
      // Verify file was written
      const content = await fs.readFile(filePath, "utf8");
      expect(content).toBe(TEST_CONTENT);
    });

    it("should create directories if they don't exist", async () => {
      const nestedFilePath = path.join(
        TEST_DIR,
        "nested",
        "deep",
        "test-file.txt"
      );
      await service.secureWriteFile(nestedFilePath, TEST_CONTENT);
      // Verify directories were created and file exists
      const content = await fs.readFile(nestedFilePath, "utf8");
      expect(content).toBe(TEST_CONTENT);
    });

    it("should reject writing outside allowed directories", async () => {
      const outsidePath = path.join(OUTSIDE_DIR, "test-file.txt");
      await expect(
        service.secureWriteFile(outsidePath, TEST_CONTENT)
      ).rejects.toThrow(ValidationError);
      await expect(
        service.secureWriteFile(outsidePath, TEST_CONTENT)
      ).rejects.toThrow(/Access denied/);
      // Verify file was not created
      await expect(fs.access(outsidePath)).rejects.toThrow();
    });

    it("should reject overwriting existing files by default", async () => {
      const filePath = path.join(TEST_DIR, "existing-file.txt");
      // Create the file first
      await fs.writeFile(filePath, "Original content", "utf8");
      // Try to overwrite without setting overwrite flag
      await expect(
        service.secureWriteFile(filePath, TEST_CONTENT)
      ).rejects.toThrow(ValidationError);
      await expect(
        service.secureWriteFile(filePath, TEST_CONTENT)
      ).rejects.toThrow(/File already exists/);
      // Verify file wasn't changed
      const content = await fs.readFile(filePath, "utf8");
      expect(content).toBe("Original content");
    });

    it("should allow overwriting existing files with overwrite flag", async () => {
      const filePath = path.join(TEST_DIR, "existing-file.txt");
      // Create the file first
      await fs.writeFile(filePath, "Original content", "utf8");
      // Overwrite with overwrite flag
      await service.secureWriteFile(filePath, TEST_CONTENT, {
        overwrite: true,
      });
      // Verify file was overwritten
      const content = await fs.readFile(filePath, "utf8");
      expect(content).toBe(TEST_CONTENT);
    });

    it("should support custom allowed directories", async () => {
      // Path is outside the service's configured directories
      const customAllowedPath = path.join(OUTSIDE_DIR, "custom-allowed.txt");
      // Use explicit allowedDirs
      await service.secureWriteFile(customAllowedPath, TEST_CONTENT, {
        allowedDirs: [OUTSIDE_DIR],
      });
      // Verify file was written
      const content = await fs.readFile(customAllowedPath, "utf8");
      expect(content).toBe(TEST_CONTENT);
    });
  });
});
```
--------------------------------------------------------------------------------
/tests/unit/utils/FileSecurityService.vitest.ts:
--------------------------------------------------------------------------------
```typescript
// Using vitest globals - see vitest.config.ts globals: true
import * as path from "path";
import * as fs from "fs/promises";
import * as fsSync from "fs";
// Import the code to test
import { FileSecurityService } from "../../../src/utils/FileSecurityService.js";
import { ValidationError } from "../../../src/utils/errors.js";
import { logger } from "../../../src/utils/logger.js";
describe("FileSecurityService", () => {
  // Mock logger so the service's info/warn/error/debug output stays silent
  const loggerMock = {
    info: vi.fn(),
    warn: vi.fn(),
    error: vi.fn(),
    debug: vi.fn(),
  };

  // The FileSecurityService constructor reads this variable. It is cleared
  // before every test and restored afterwards so a value exported in the
  // developer's shell cannot change results, and a failing test cannot leak it.
  const ENV_BASE_DIR = "GEMINI_SAFE_FILE_BASE_DIR";
  let savedEnvBaseDir: string | undefined;

  // Define test constants for all tests. The directory names carry a per-run
  // suffix so another suite using the same fixture layout (possibly running in
  // a parallel worker) cannot create or delete this suite's directories.
  const TEST_CONTENT = "Test file content";
  const RUN_ID = `${process.pid}-${Math.random().toString(36).slice(2, 10)}`;
  const TEST_DIR = path.resolve(`./test-security-dir-${RUN_ID}`);
  const OUTSIDE_DIR = path.resolve(`./outside-security-dir-${RUN_ID}`);

  // Setup before each test
  beforeEach(() => {
    // Reset mocks
    vi.clearAllMocks();

    // Replace logger with mock
    vi.spyOn(logger, "info").mockImplementation(loggerMock.info);
    vi.spyOn(logger, "warn").mockImplementation(loggerMock.warn);
    vi.spyOn(logger, "error").mockImplementation(loggerMock.error);
    vi.spyOn(logger, "debug").mockImplementation(loggerMock.debug);

    // Isolate tests from the ambient environment
    savedEnvBaseDir = process.env[ENV_BASE_DIR];
    delete process.env[ENV_BASE_DIR];

    // Create test directories
    fsSync.mkdirSync(TEST_DIR, { recursive: true });
    fsSync.mkdirSync(OUTSIDE_DIR, { recursive: true });
  });

  // Cleanup after each test
  afterEach(() => {
    // Restore original logger
    vi.restoreAllMocks();

    // Restore the environment exactly, including an originally-empty value
    if (savedEnvBaseDir === undefined) {
      delete process.env[ENV_BASE_DIR];
    } else {
      process.env[ENV_BASE_DIR] = savedEnvBaseDir;
    }

    // Clean up test directories
    try {
      fsSync.rmSync(TEST_DIR, { recursive: true, force: true });
      fsSync.rmSync(OUTSIDE_DIR, { recursive: true, force: true });
    } catch {
      // Ignore cleanup errors
    }
  });

  describe("Constructor and Configuration", () => {
    it("should initialize with default allowed directories", () => {
      const service = new FileSecurityService();
      const allowedDirs = service.getAllowedDirectories();
      expect(allowedDirs.length).toBeGreaterThan(0);
      expect(allowedDirs).toContain(path.resolve(process.cwd()));
    });

    it("should initialize with custom allowed directories", () => {
      const customDirs = [TEST_DIR, OUTSIDE_DIR];
      const service = new FileSecurityService(customDirs);
      const allowedDirs = service.getAllowedDirectories();
      expect(allowedDirs.length).toBe(2);
      expect(allowedDirs).toContain(path.resolve(TEST_DIR));
      expect(allowedDirs).toContain(path.resolve(OUTSIDE_DIR));
    });

    it("should initialize with a secure base path", () => {
      const service = new FileSecurityService([], TEST_DIR);
      const basePath = service.getSecureBasePath();
      expect(basePath).toBe(path.normalize(TEST_DIR));
      // Verify allowed directories includes the base path
      const allowedDirs = service.getAllowedDirectories();
      expect(allowedDirs).toContain(path.normalize(TEST_DIR));
    });

    it("should set allowed directories", () => {
      const service = new FileSecurityService();
      const newDirs = [TEST_DIR, OUTSIDE_DIR];
      service.setAllowedDirectories(newDirs);
      const allowedDirs = service.getAllowedDirectories();
      expect(allowedDirs.length).toBe(2);
      expect(allowedDirs).toContain(path.normalize(TEST_DIR));
      expect(allowedDirs).toContain(path.normalize(OUTSIDE_DIR));
    });

    it("should throw error when setting empty allowed directories", () => {
      const service = new FileSecurityService();
      expect(() => service.setAllowedDirectories([])).toThrow(ValidationError);
      expect(() => service.setAllowedDirectories([])).toThrow(
        /At least one allowed directory/
      );
    });

    it("should throw error when setting non-absolute allowed directories", () => {
      const service = new FileSecurityService();
      expect(() => service.setAllowedDirectories(["./relative/path"])).toThrow(
        ValidationError
      );
      expect(() => service.setAllowedDirectories(["./relative/path"])).toThrow(
        /Directory path must be absolute/
      );
    });

    it("should set and get secure base path", () => {
      const service = new FileSecurityService();
      service.setSecureBasePath(TEST_DIR);
      const basePath = service.getSecureBasePath();
      expect(basePath).toBe(path.normalize(TEST_DIR));
    });

    it("should throw error when setting non-absolute secure base path", () => {
      const service = new FileSecurityService();
      expect(() => service.setSecureBasePath("./relative/path")).toThrow(
        ValidationError
      );
      expect(() => service.setSecureBasePath("./relative/path")).toThrow(
        /Base path must be absolute/
      );
    });

    it("should configure from environment", () => {
      // The suite-level hooks clear this variable beforehand and restore the
      // original value afterwards, even if an assertion below fails.
      process.env[ENV_BASE_DIR] = TEST_DIR;
      const service = FileSecurityService.configureFromEnvironment();
      const allowedDirs = service.getAllowedDirectories();
      expect(allowedDirs).toContain(path.normalize(TEST_DIR));
    });
  });

  describe("Path Validation", () => {
    let service: FileSecurityService;

    beforeEach(() => {
      service = new FileSecurityService([TEST_DIR]);
    });

    it("should validate path within allowed directory", () => {
      const testFilePath = path.join(TEST_DIR, "test-file.txt");
      const validatedPath = service.validateAndResolvePath(testFilePath);
      expect(validatedPath).toBe(path.normalize(testFilePath));
    });

    it("should validate paths with relative components", () => {
      const complexPath = path.join(
        TEST_DIR,
        ".",
        "subdir",
        "..",
        "test-file.txt"
      );
      const validatedPath = service.validateAndResolvePath(complexPath);
      // Should normalize to TEST_DIR/test-file.txt
      const expectedPath = path.normalize(path.join(TEST_DIR, "test-file.txt"));
      expect(validatedPath).toBe(expectedPath);
    });

    it("should reject paths outside allowed directories", () => {
      const outsidePath = path.join(OUTSIDE_DIR, "test-file.txt");
      expect(() => service.validateAndResolvePath(outsidePath)).toThrow(
        ValidationError
      );
      expect(() => service.validateAndResolvePath(outsidePath)).toThrow(
        /Access denied/
      );
    });

    it("should reject paths with directory traversal", () => {
      const traversalPath = path.join(
        TEST_DIR,
        "..",
        "outside",
        "test-file.txt"
      );
      expect(() => service.validateAndResolvePath(traversalPath)).toThrow(
        ValidationError
      );
      expect(() => service.validateAndResolvePath(traversalPath)).toThrow(
        /Access denied/
      );
    });

    it("should check file existence with mustExist option", () => {
      const nonExistentPath = path.join(TEST_DIR, "non-existent.txt");
      expect(() =>
        service.validateAndResolvePath(nonExistentPath, { mustExist: true })
      ).toThrow(ValidationError);
      expect(() =>
        service.validateAndResolvePath(nonExistentPath, { mustExist: true })
      ).toThrow(/File not found/);
    });

    it("should use custom allowed directories when provided", () => {
      // Path is outside the service's configured directory but inside custom allowed dir
      const customAllowedPath = path.join(OUTSIDE_DIR, "custom-allowed.txt");
      const validatedPath = service.validateAndResolvePath(customAllowedPath, {
        allowedDirs: [OUTSIDE_DIR],
      });
      expect(validatedPath).toBe(path.normalize(customAllowedPath));
    });
  });

  describe("isPathWithinAllowedDirs", () => {
    let service: FileSecurityService;

    beforeEach(() => {
      service = new FileSecurityService([TEST_DIR]);
    });

    it("should return true for paths within allowed directories", () => {
      const insidePath = path.join(TEST_DIR, "test-file.txt");
      const result = service.isPathWithinAllowedDirs(insidePath);
      expect(result).toBe(true);
    });

    it("should return true for exact match with allowed directory", () => {
      const result = service.isPathWithinAllowedDirs(TEST_DIR);
      expect(result).toBe(true);
    });

    it("should return false for paths outside allowed directories", () => {
      const outsidePath = path.join(OUTSIDE_DIR, "test-file.txt");
      const result = service.isPathWithinAllowedDirs(outsidePath);
      expect(result).toBe(false);
    });

    it("should return false for paths with directory traversal", () => {
      const traversalPath = path.join(
        TEST_DIR,
        "..",
        "outside",
        "test-file.txt"
      );
      const result = service.isPathWithinAllowedDirs(traversalPath);
      expect(result).toBe(false);
    });

    it("should use custom allowed directories when provided", () => {
      const outsidePath = path.join(OUTSIDE_DIR, "test-file.txt");
      // Should be false with default allowed dirs
      expect(service.isPathWithinAllowedDirs(outsidePath)).toBe(false);
      // Should be true with custom allowed dirs
      expect(service.isPathWithinAllowedDirs(outsidePath, [OUTSIDE_DIR])).toBe(
        true
      );
    });

    it("should return false when no allowed directories exist", () => {
      const result = service.isPathWithinAllowedDirs(TEST_DIR, []);
      expect(result).toBe(false);
    });
  });

  describe("fullyResolvePath", () => {
    let service: FileSecurityService;

    beforeEach(() => {
      service = new FileSecurityService([TEST_DIR, OUTSIDE_DIR]);
    });

    it("should resolve a normal file path", async () => {
      const testPath = path.join(TEST_DIR, "test-file.txt");
      const resolvedPath = await service.fullyResolvePath(testPath);
      expect(resolvedPath).toBe(path.normalize(testPath));
    });

    it("should handle non-existent paths", async () => {
      const nonExistentPath = path.join(
        TEST_DIR,
        "non-existent",
        "test-file.txt"
      );
      const resolvedPath = await service.fullyResolvePath(nonExistentPath);
      expect(resolvedPath).toBe(path.normalize(nonExistentPath));
    });

    it("should resolve and validate a symlink to a file", async () => {
      // Create target file
      const targetPath = path.join(TEST_DIR, "target.txt");
      await fs.writeFile(targetPath, TEST_CONTENT, "utf8");
      // Create symlink
      const symlinkPath = path.join(TEST_DIR, "symlink.txt");
      await fs.symlink(targetPath, symlinkPath);
      // Resolve the symlink
      const resolvedPath = await service.fullyResolvePath(symlinkPath);
      // Should resolve to the target path
      expect(resolvedPath).toBe(path.normalize(targetPath));
    });

    it("should reject symlinks pointing outside allowed directories", async () => {
      // Create target file in outside (non-allowed) directory
      const targetPath = path.join(OUTSIDE_DIR, "target.txt");
      await fs.writeFile(targetPath, TEST_CONTENT, "utf8");
      // Create symlink in test (allowed) directory pointing to outside
      const symlinkPath = path.join(TEST_DIR, "bad-symlink.txt");
      // Setup service with only TEST_DIR allowed (not OUTSIDE_DIR)
      const restrictedService = new FileSecurityService([TEST_DIR]);
      await fs.symlink(targetPath, symlinkPath);
      // Try to resolve the symlink
      await expect(
        restrictedService.fullyResolvePath(symlinkPath)
      ).rejects.toThrow(ValidationError);
      await expect(
        restrictedService.fullyResolvePath(symlinkPath)
      ).rejects.toThrow(/Security error/);
      await expect(
        restrictedService.fullyResolvePath(symlinkPath)
      ).rejects.toThrow(/outside allowed directories/);
    });

    it("should detect and validate symlinked parent directories", async () => {
      // Create target directory in allowed location
      const targetDir = path.join(TEST_DIR, "target-dir");
      await fs.mkdir(targetDir, { recursive: true });
      // Create symlink to directory
      const symlinkDir = path.join(TEST_DIR, "symlink-dir");
      await fs.symlink(targetDir, symlinkDir);
      // Create a file path inside the symlinked directory
      const filePath = path.join(symlinkDir, "test-file.txt");
      // Resolve the path
      const resolvedPath = await service.fullyResolvePath(filePath);
      // Should resolve to actual path in target directory
      const expectedPath = path.join(targetDir, "test-file.txt");
      expect(resolvedPath).toBe(path.normalize(expectedPath));
    });

    it("should reject symlinked parent directories pointing outside allowed directories", async () => {
      // Create target directory in outside (not allowed) directory
      const targetDir = path.join(OUTSIDE_DIR, "target-dir");
      await fs.mkdir(targetDir, { recursive: true });
      // Create symlink in test directory pointing to outside directory
      const symlinkDir = path.join(TEST_DIR, "bad-symlink-dir");
      await fs.symlink(targetDir, symlinkDir);
      // Create a file path inside the symlinked directory
      const filePath = path.join(symlinkDir, "test-file.txt");
      // Setup service with only TEST_DIR allowed
      const restrictedService = new FileSecurityService([TEST_DIR]);
      // Try to resolve the path
      await expect(
        restrictedService.fullyResolvePath(filePath)
      ).rejects.toThrow(ValidationError);
      await expect(
        restrictedService.fullyResolvePath(filePath)
      ).rejects.toThrow(/Security error/);
    });
  });

  describe("secureWriteFile", () => {
    let service: FileSecurityService;

    beforeEach(() => {
      service = new FileSecurityService([TEST_DIR]);
    });

    it("should write file to an allowed directory", async () => {
      const filePath = path.join(TEST_DIR, "test-file.txt");
      await service.secureWriteFile(filePath, TEST_CONTENT);
      // Verify file was written
      const content = await fs.readFile(filePath, "utf8");
      expect(content).toBe(TEST_CONTENT);
    });

    it("should create directories if they don't exist", async () => {
      const nestedFilePath = path.join(
        TEST_DIR,
        "nested",
        "deep",
        "test-file.txt"
      );
      await service.secureWriteFile(nestedFilePath, TEST_CONTENT);
      // Verify directories were created and file exists
      const content = await fs.readFile(nestedFilePath, "utf8");
      expect(content).toBe(TEST_CONTENT);
    });

    it("should reject writing outside allowed directories", async () => {
      const outsidePath = path.join(OUTSIDE_DIR, "test-file.txt");
      await expect(
        service.secureWriteFile(outsidePath, TEST_CONTENT)
      ).rejects.toThrow(ValidationError);
      await expect(
        service.secureWriteFile(outsidePath, TEST_CONTENT)
      ).rejects.toThrow(/Access denied/);
      // Verify file was not created
      await expect(fs.access(outsidePath)).rejects.toThrow();
    });

    it("should reject overwriting existing files by default", async () => {
      const filePath = path.join(TEST_DIR, "existing-file.txt");
      // Create the file first
      await fs.writeFile(filePath, "Original content", "utf8");
      // Try to overwrite without setting overwrite flag
      await expect(
        service.secureWriteFile(filePath, TEST_CONTENT)
      ).rejects.toThrow(ValidationError);
      await expect(
        service.secureWriteFile(filePath, TEST_CONTENT)
      ).rejects.toThrow(/File already exists/);
      // Verify file wasn't changed
      const content = await fs.readFile(filePath, "utf8");
      expect(content).toBe("Original content");
    });

    it("should allow overwriting existing files with overwrite flag", async () => {
      const filePath = path.join(TEST_DIR, "existing-file.txt");
      // Create the file first
      await fs.writeFile(filePath, "Original content", "utf8");
      // Overwrite with overwrite flag
      await service.secureWriteFile(filePath, TEST_CONTENT, {
        overwrite: true,
      });
      // Verify file was overwritten
      const content = await fs.readFile(filePath, "utf8");
      expect(content).toBe(TEST_CONTENT);
    });

    it("should support custom allowed directories", async () => {
      // Path is outside the service's configured directories
      const customAllowedPath = path.join(OUTSIDE_DIR, "custom-allowed.txt");
      // Use explicit allowedDirs
      await service.secureWriteFile(customAllowedPath, TEST_CONTENT, {
        allowedDirs: [OUTSIDE_DIR],
      });
      // Verify file was written
      const content = await fs.readFile(customAllowedPath, "utf8");
      expect(content).toBe(TEST_CONTENT);
    });
  });
});
```
--------------------------------------------------------------------------------
/src/utils/FileSecurityService.ts:
--------------------------------------------------------------------------------
```typescript
import * as fs from "fs/promises";
import * as fsSync from "fs";
import * as path from "path";
import { logger } from "./logger.js";
import { ValidationError } from "./errors.js";
/**
 * Type guard that narrows an unknown caught value to an error-like object
 * whose `code` is `"ENOENT"` (file not found).
 *
 * Declared as a type predicate so callers get real narrowing: inside an
 * `if (isENOENTError(err))` branch, `err.code` is typed as `"ENOENT"`.
 *
 * @param err - The value to check (typically the binding of a `catch` clause)
 * @returns True if `err` is a non-null object with `code === "ENOENT"`
 */
function isENOENTError(err: unknown): err is { code: "ENOENT" } {
  return (
    err !== null &&
    typeof err === "object" &&
    "code" in err &&
    err.code === "ENOENT"
  );
}
/**
 * Narrows an unknown value to an object carrying a string `message`
 * property, such as an `Error` instance or an error-like plain object.
 * @param err - The value to inspect
 * @returns True when `err` is a non-null object whose `message` is a string
 */
function hasErrorMessage(err: unknown): err is { message: string } {
  // Primitives, functions and null cannot be treated as error objects
  if (typeof err !== "object" || err === null) {
    return false;
  }
  if (!("message" in err)) {
    return false;
  }
  return typeof err.message === "string";
}
/**
* Centralized service for handling file-related security operations
* Provides comprehensive validation, resolution, and secure file operations
*/
export class FileSecurityService {
private allowedDirectories: string[] = [];
private secureBasePath?: string;
// Default safe base directory - using the project root as the default
private readonly DEFAULT_SAFE_BASE_DIR: string =
  process.env.GEMINI_SAFE_FILE_BASE_DIR || path.resolve(process.cwd());

/**
 * Creates a new instance of the FileSecurityService.
 *
 * The secure base path comes from the `secureBasePath` argument when it is
 * provided, otherwise from the `GEMINI_SAFE_FILE_BASE_DIR` environment
 * variable; whichever is used is normalized. The allowed directories are, in
 * order of preference: a non-empty `allowedDirectories` argument, the secure
 * base path, or the current working directory.
 *
 * @param allowedDirectories Optional array of absolute directories allowed for file operations
 * @param secureBasePath Optional single secure base path (takes precedence over env vars)
 * @throws ValidationError if `allowedDirectories` contains a non-absolute path
 */
constructor(allowedDirectories?: string[], secureBasePath?: string) {
  // The explicit argument wins over the environment variable, as documented.
  // Empty strings are treated as "not provided" for both sources.
  const requestedBasePath =
    secureBasePath || process.env.GEMINI_SAFE_FILE_BASE_DIR;
  this.secureBasePath = requestedBasePath
    ? path.normalize(requestedBasePath)
    : undefined;

  if (allowedDirectories && allowedDirectories.length > 0) {
    // Validates that every entry is absolute and stores normalized copies
    this.setAllowedDirectories(allowedDirectories);
  } else if (this.secureBasePath) {
    this.allowedDirectories = [this.secureBasePath];
  } else {
    this.allowedDirectories = [path.resolve(process.cwd())];
  }

  logger.info(
    `File operations restricted to: ${this.allowedDirectories.join(", ")}`
  );
}
/**
 * Records `basePath` as the secure base directory and makes sure it is one of
 * the allowed directories (appended once, never duplicated).
 * @param basePath Absolute directory to restrict file operations to
 * @throws ValidationError when `basePath` is not absolute
 */
public setSecureBasePath(basePath: string): void {
  if (!path.isAbsolute(basePath)) {
    throw new ValidationError("Base path must be absolute");
  }

  const normalizedBase = path.normalize(basePath);
  this.secureBasePath = normalizedBase;

  // Keep the base reachable without adding a second copy of an existing entry
  const alreadyAllowed = this.allowedDirectories.includes(normalizedBase);
  if (!alreadyAllowed) {
    this.allowedDirectories.push(normalizedBase);
  }

  logger.debug(`Secure base path set to: ${this.secureBasePath}`);
}
/**
 * Returns the normalized secure base directory, or `undefined` when no base
 * path has been configured.
 */
public getSecureBasePath(): string | undefined {
  const configuredBase = this.secureBasePath;
  return configuredBase;
}
/**
 * Replaces the allowed directories with normalized copies of `directories`.
 * @param directories Non-empty array of absolute directory paths
 * @throws ValidationError if the array is empty or any entry is not absolute
 */
public setAllowedDirectories(directories: string[]): void {
  if (!directories || directories.length === 0) {
    throw new ValidationError(
      "At least one allowed directory must be provided"
    );
  }

  // Report the first offending entry, if any
  const firstRelative = directories.find((candidate) => !path.isAbsolute(candidate));
  if (firstRelative !== undefined) {
    throw new ValidationError(`Directory path must be absolute: ${firstRelative}`);
  }

  this.allowedDirectories = directories.map((candidate) =>
    path.normalize(candidate)
  );
  logger.debug(
    `Allowed directories set to: ${this.allowedDirectories.join(", ")}`
  );
}
/**
 * Returns a shallow copy of the allowed directories, so callers cannot mutate
 * the service's internal list.
 */
public getAllowedDirectories(): string[] {
  return this.allowedDirectories.slice();
}
/**
 * Validates that a file path is secure and resolves it to a normalized
 * absolute path.
 *
 * Allowed roots for this call, in priority order: `options.allowedDirs`, then
 * `[options.basePath]`, then the instance's allowed directories. A relative
 * `filePath` is resolved against `options.basePath` when given, otherwise the
 * instance's secure base path, otherwise the current working directory.
 *
 * @param filePath The file path to validate (absolute or relative)
 * @param options Optional configuration:
 *   - `mustExist`: when true, the resolved path must be accessible on disk
 *   - `allowedDirs`: overrides the allowed directories for this call
 *   - `basePath`: single allowed root and resolution base for this call
 * @returns The validated, normalized absolute file path
 * @throws ValidationError if the path is outside the allowed directories, or
 *   `mustExist` is set and the path cannot be accessed
 */
public validateAndResolvePath(
  filePath: string,
  options: {
    mustExist?: boolean;
    allowedDirs?: string[];
    basePath?: string;
  } = {}
): string {
  const { mustExist = false, allowedDirs, basePath } = options;

  // Determine which allowed directories to use
  const effectiveAllowedDirs =
    allowedDirs ||
    (basePath ? [path.normalize(basePath)] : this.allowedDirectories);

  logger.debug(`Validating file path: ${filePath}`);
  logger.debug(
    `Using allowed directories: ${effectiveAllowedDirs.join(", ")}`
  );

  // Anchor relative paths to the per-call basePath first, so that
  // ("out.txt", { basePath: "/data" }) validates as /data/out.txt instead of
  // being resolved under an unrelated directory and then rejected.
  const resolutionBase = basePath || this.secureBasePath || process.cwd();
  const absolutePath = path.isAbsolute(filePath)
    ? filePath
    : path.resolve(resolutionBase, filePath);

  // Normalize path to handle . and .. segments
  const normalizedPath = path.normalize(absolutePath);

  // Check if the path is within any allowed directory
  if (!this.isPathWithinAllowedDirs(normalizedPath, effectiveAllowedDirs)) {
    logger.warn(
      `Access denied: Path not in allowed directories: ${filePath}`
    );
    throw new ValidationError(
      `Access denied: The file path must be within the allowed directories`
    );
  }

  // Check if the file exists (if required)
  if (mustExist) {
    try {
      fsSync.accessSync(normalizedPath, fsSync.constants.F_OK);
    } catch {
      // Any access failure is reported to the caller as "not found"
      logger.warn(`File not found: ${normalizedPath}`);
      throw new ValidationError(`File not found: ${normalizedPath}`);
    }
  }

  logger.debug(`Validated path: ${normalizedPath}`);
  return normalizedPath;
}
/**
 * Checks whether a file path lies within (or equals) any allowed directory.
 *
 * Containment is decided with path.relative, so a filesystem root such as
 * "/" works as an allowed directory. Entries whose names merely start with
 * ".." (e.g. "..hidden") are treated as children, not as upward traversal.
 * Paths on a different drive (Windows) are rejected.
 *
 * @param filePath The relative or absolute path to check. Relative paths
 *   resolve against process.cwd(), not the secure base path.
 * @param allowedDirs Optional allowed directories (defaults to the instance's list)
 * @returns True if the file path is within any of the allowed directories, false otherwise
 */
public isPathWithinAllowedDirs(
  filePath: string,
  allowedDirs?: string[]
): boolean {
  // Use instance's allowed directories if none provided
  const effectiveAllowedDirs = allowedDirs || this.allowedDirectories;
  // An empty allow-list permits nothing
  if (!effectiveAllowedDirs || effectiveAllowedDirs.length === 0) {
    return false;
  }
  // path.resolve both absolutizes and normalizes "." / ".." segments
  const resolvedFilePath = path.resolve(filePath);
  return effectiveAllowedDirs.some((allowedDir) => {
    const resolvedAllowedDir = path.resolve(allowedDir);
    const relativePath = path.relative(resolvedAllowedDir, resolvedFilePath);
    // "" means an exact match; "..", "../x" or an absolute result (different
    // drive) means the file is outside this directory.
    const escapesDir =
      relativePath === ".." ||
      relativePath.startsWith(`..${path.sep}`) ||
      path.isAbsolute(relativePath);
    return !escapesDir;
  });
}
/**
* Fully resolves a file path, handling symlinks and security checks.
*
* Resolution order:
* 1. If the path itself is a symlink, its target (a single readlink hop,
*    resolved relative to the link's directory) is checked against the
*    instance's allowed directories and returned.
* 2. Otherwise each ancestor directory is lstat'ed, walking upward (the
*    filesystem root itself is not examined). If the immediate parent is a
*    symlink, the path re-rooted onto that link's target is checked and
*    returned.
* 3. If a higher ancestor was a symlink, fs.realpath of the path is checked
*    and returned. When the file does not exist yet, realpath of its
*    directory (or the unresolved directory when that is missing too) plus
*    the basename is used instead.
* 4. With no symlinks found, the normalized input path is returned unchanged.
*
* All checks use this.allowedDirectories; there is no per-call override.
*
* NOTE(review): step 1 follows only one hop. If the link target is itself a
* symlink pointing outside the allowed directories, that second hop is not
* checked. The check-then-use sequence is also subject to TOCTOU races.
*
* @param filePath The file path to resolve (relative paths resolve against process.cwd())
* @returns The resolved file path
* @throws ValidationError if a symlink resolves outside the allowed directories
*   (message contains "Security error:"), or a wrapped ValidationError for
*   any other non-ENOENT filesystem error
*/
public async fullyResolvePath(filePath: string): Promise<string> {
const normalizedPath = path.normalize(path.resolve(filePath));
try {
// Check if the target file exists and is a symlink
try {
const stats = await fs.lstat(normalizedPath);
if (stats.isSymbolicLink()) {
logger.warn(`Path is a symlink: ${normalizedPath}`);
// Single hop: a relative link target is resolved lexically against the link's directory
const target = await fs.readlink(normalizedPath);
const resolvedPath = path.resolve(
path.dirname(normalizedPath),
target
);
// Ensure the symlink target is within allowed directories
if (!this.isPathWithinAllowedDirs(resolvedPath)) {
throw new ValidationError(
`Security error: Symlink target is outside allowed directories: ${resolvedPath}`
);
}
return resolvedPath;
}
} catch (err) {
// A missing file (ENOENT) is acceptable here, e.g. a file about to be created
if (!isENOENTError(err)) {
throw err;
}
}
// Also check parent directories to ensure we're not inside a symlinked directory
let currentPath = path.dirname(normalizedPath);
const root = path.parse(currentPath).root;
// Records symlinked ancestors; only its size is consulted below
const resolvedPaths = new Map<string, string>();
while (currentPath !== root) {
try {
const dirStats = await fs.lstat(currentPath);
if (dirStats.isSymbolicLink()) {
// Resolve the symlink
const linkTarget = await fs.readlink(currentPath);
const resolvedPath = path.resolve(
path.dirname(currentPath),
linkTarget
);
logger.warn(
`Parent directory is a symlink: ${currentPath} -> ${resolvedPath}`
);
resolvedPaths.set(currentPath, resolvedPath);
// If this is the immediate parent, update the final path
if (currentPath === path.dirname(normalizedPath)) {
const updatedPath = path.join(
resolvedPath,
path.basename(normalizedPath)
);
// Ensure resolved path is still secure
if (!this.isPathWithinAllowedDirs(updatedPath)) {
throw new ValidationError(
`Security error: Resolved symlink path is outside allowed directories: ${updatedPath}`
);
}
return updatedPath;
}
}
} catch (err) {
// Missing ancestors are skipped; the walk continues upward
if (!isENOENTError(err)) {
throw err;
}
}
currentPath = path.dirname(currentPath);
}
// If we found symlinks in parent directories, perform a final security check
if (resolvedPaths.size > 0) {
try {
// Get fully resolved path including all symlinks
const finalResolvedPath = await fs.realpath(normalizedPath);
// Final security check with the fully resolved path
if (!this.isPathWithinAllowedDirs(finalResolvedPath)) {
throw new ValidationError(
`Security error: Resolved path is outside allowed directories: ${finalResolvedPath}`
);
}
return finalResolvedPath;
} catch (err) {
// Handle case where path doesn't exist yet
if (isENOENTError(err)) {
// Try to resolve just the directory part. If the directory is missing too,
// the unresolved dirname is used, so higher symlinks stay unresolved in the result.
const resolvedDir = await fs
.realpath(path.dirname(normalizedPath))
.catch((dirErr) => {
if (isENOENTError(dirErr)) {
return path.dirname(normalizedPath);
}
throw dirErr;
});
const finalPath = path.join(
resolvedDir,
path.basename(normalizedPath)
);
// Final security check
if (!this.isPathWithinAllowedDirs(finalPath)) {
throw new ValidationError(
`Security error: Resolved path is outside allowed directories: ${finalPath}`
);
}
return finalPath;
}
// Non-ENOENT errors (including the security error above) propagate to the outer catch
throw err;
}
}
// No symlinks found, return the normalized path
return normalizedPath;
} catch (err) {
if (hasErrorMessage(err) && err.message.includes("Security error:")) {
// Re-throw security errors
throw err;
}
// For other errors, provide a clearer message
const errorMsg = hasErrorMessage(err) ? err.message : String(err);
logger.error(`Error resolving path: ${errorMsg}`, err);
throw new ValidationError(`Error validating path security: ${errorMsg}`);
}
}
/**
 * Securely writes content to a file, ensuring the path is within allowed directories.
 *
 * Steps:
 * 1. Validate the path against the allowed directories.
 * 2. Resolve symlinks with fullyResolvePath.
 * 3. Re-check the resolved path against the caller's allowed directories.
 * 4. Refuse to replace an existing file unless `overwrite` is set.
 * 5. Create missing parent directories and write the file as UTF-8.
 *
 * When `overwrite` is false the write uses the exclusive-create flag ("wx").
 * A file that appears between the existence check and the write therefore
 * is not clobbered, and an existing symlink at the final path is not followed.
 *
 * @param filePath The relative or absolute path to the file
 * @param content The string content to write to the file
 * @param options.overwrite Replace an existing file (default false)
 * @param options.allowedDirs Directories to restrict to (defaults to the instance's list)
 * @returns A promise that resolves when the file is written
 * @throws ValidationError if the path is invalid or outside allowed directories,
 *   if the file exists and overwrite is false, or for any other security or
 *   file system error
 */
public async secureWriteFile(
  filePath: string,
  content: string,
  options: {
    overwrite?: boolean;
    allowedDirs?: string[];
  } = {}
): Promise<void> {
  const { overwrite = false, allowedDirs } = options;
  // Use instance's allowed directories if none provided
  const effectiveAllowedDirs = allowedDirs || this.allowedDirectories;
  // 1. Initial validation against allowed directories
  const validatedPath = this.validateAndResolvePath(filePath, {
    allowedDirs: effectiveAllowedDirs,
  });
  // 2. Fully resolve the path handling symlinks and do final security check
  const finalFilePath = await this.fullyResolvePath(validatedPath);
  // 3. fullyResolvePath only checks the instance's directories, so re-check the
  //    symlink-resolved path against the caller-supplied list.
  if (!this.isPathWithinAllowedDirs(finalFilePath, effectiveAllowedDirs)) {
    logger.warn(
      `Access denied: Resolved path not in allowed directories: ${finalFilePath}`
    );
    throw new ValidationError(
      `Access denied: The file path must be within the allowed directories`
    );
  }
  // 4. Check if file exists and overwrite flag is false. The "already exists"
  //    error is raised outside the try so the catch cannot re-wrap it.
  if (!overwrite) {
    let fileExists = false;
    try {
      await fs.access(finalFilePath);
      fileExists = true;
    } catch (err) {
      // ENOENT is the expected outcome for a new file; anything else is a real access error
      if (!isENOENTError(err)) {
        logger.error(`Error checking file existence: ${finalFilePath}`, err);
        const errorMsg = hasErrorMessage(err) ? err.message : String(err);
        throw new ValidationError(`Error checking file access: ${errorMsg}`);
      }
    }
    if (fileExists) {
      logger.error(
        `File already exists and overwrite is false: ${finalFilePath}`
      );
      throw new ValidationError(
        `File already exists: ${filePath}. Set overwrite flag to true to replace it.`
      );
    }
  }
  // 5. Create parent directories if they don't exist
  const dirPath = path.dirname(finalFilePath);
  try {
    await fs.mkdir(dirPath, { recursive: true });
  } catch (err) {
    logger.error(`Error creating directory ${dirPath}:`, err);
    const errorMsg = hasErrorMessage(err) ? err.message : String(err);
    throw new ValidationError(
      `Failed to create directory structure: ${errorMsg}`
    );
  }
  // 6. Write the file; "wx" fails with EEXIST instead of replacing a file
  //    created concurrently after the existence check.
  try {
    await fs.writeFile(finalFilePath, content, {
      encoding: "utf8",
      flag: overwrite ? "w" : "wx",
    });
    logger.info(`Successfully wrote file to ${finalFilePath}`);
  } catch (err) {
    logger.error(`Error writing file ${finalFilePath}:`, err);
    const errorMsg = hasErrorMessage(err) ? err.message : String(err);
    throw new ValidationError(`Failed to write file: ${errorMsg}`);
  }
}
/**
 * Creates a FileSecurityService configured from environment variables.
 * Call this during application startup.
 *
 * Reads GEMINI_SAFE_FILE_BASE_DIR. When it names an existing directory,
 * file operations are restricted to it; relative values are resolved
 * against process.cwd(). When the variable is unset, missing, or not a
 * directory, the service falls back to process.cwd().
 *
 * @returns A new service with its allowed directories set
 */
public static configureFromEnvironment(): FileSecurityService {
  const customBaseDir = process.env.GEMINI_SAFE_FILE_BASE_DIR;
  const service = new FileSecurityService();
  const cwd = process.cwd();
  if (!customBaseDir) {
    logger.info(
      `File operations restricted to current working directory: ${cwd}`
    );
    service.setAllowedDirectories([cwd]);
    return service;
  }
  // setAllowedDirectories rejects relative paths. Resolving here stops a
  // relative setting from being misreported as "does not exist".
  const resolvedBaseDir = path.resolve(customBaseDir);
  let isDirectory = false;
  try {
    isDirectory = fsSync.statSync(resolvedBaseDir).isDirectory();
  } catch {
    // Missing or unreadable: treated the same as "not a directory"
    isDirectory = false;
  }
  if (isDirectory) {
    logger.info(`File operations restricted to: ${resolvedBaseDir}`);
    service.setAllowedDirectories([resolvedBaseDir]);
  } else {
    logger.warn(
      `Configured GEMINI_SAFE_FILE_BASE_DIR does not exist or is not a directory: ${customBaseDir}`
    );
    logger.warn(`Falling back to default directory: ${cwd}`);
    service.setAllowedDirectories([cwd]);
  }
  return service;
}
}
```
--------------------------------------------------------------------------------
/tests/unit/services/gemini/GeminiUrlContextService.test.vitest.ts:
--------------------------------------------------------------------------------
```typescript
// Using vitest globals - see vitest.config.ts globals: true
import { GeminiUrlContextService } from "../../../../src/services/gemini/GeminiUrlContextService.js";
import { ConfigurationManager } from "../../../../src/config/ConfigurationManager.js";
import { GeminiUrlFetchError } from "../../../../src/utils/geminiErrors.js";
// Mock dependencies
// vi.mock swaps each module below for an automatically mocked version, so the
// service under test never uses the real configuration, logger or URL security
// checks. vitest hoists vi.mock calls above the imports.
vi.mock("../../../../src/config/ConfigurationManager.js");
vi.mock("../../../../src/utils/logger.js");
vi.mock("../../../../src/utils/UrlSecurityService.js");
// Mock fetch globally
// Tests queue responses on mockFetch (mockResolvedValueOnce /
// mockRejectedValueOnce) instead of performing real HTTP requests.
const mockFetch = vi.fn();
global.fetch = mockFetch;
// Minimal shape of the ConfigurationManager stub: only getUrlContextConfig is provided.
interface MockConfigManager {
getUrlContextConfig: ReturnType<typeof vi.fn>;
}
describe("GeminiUrlContextService", () => {
let service: GeminiUrlContextService;
let mockConfig: MockConfigManager;
beforeEach(() => {
// Reset all mocks
vi.clearAllMocks();
// Mock configuration
mockConfig = {
getUrlContextConfig: vi.fn().mockReturnValue({
enabled: true,
maxUrlsPerRequest: 20,
defaultMaxContentKb: 100,
defaultTimeoutMs: 10000,
allowedDomains: ["*"],
blocklistedDomains: [],
convertToMarkdown: true,
includeMetadata: true,
enableCaching: true,
cacheExpiryMinutes: 15,
maxCacheSize: 1000,
rateLimitPerDomainPerMinute: 10,
userAgent: "MCP-Gemini-Server/1.0",
}),
};
// Create service instance
service = new GeminiUrlContextService(
mockConfig as unknown as ConfigurationManager
);
});
afterEach(() => {
vi.resetAllMocks();
});
describe("fetchUrlContent", () => {
it("should successfully fetch and process HTML content", async () => {
const mockHtmlContent = `
<!DOCTYPE html>
<html>
<head>
<title>Test Page</title>
<meta name="description" content="A test page">
</head>
<body>
<h1>Main Heading</h1>
<p>This is a test paragraph with <strong>bold text</strong>.</p>
<ul>
<li>Item 1</li>
<li>Item 2</li>
</ul>
</body>
</html>
`;
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
statusText: "OK",
url: "https://example.com/test",
headers: new Map([
["content-type", "text/html; charset=utf-8"],
["content-length", mockHtmlContent.length.toString()],
]),
text: () => Promise.resolve(mockHtmlContent),
});
const result = await service.fetchUrlContent("https://example.com/test");
expect(result).toBeDefined();
expect(result.metadata.url).toBe("https://example.com/test");
expect(result.metadata.statusCode).toBe(200);
expect(result.metadata.title).toBe("Test Page");
expect(result.metadata.description).toBe("A test page");
expect(result.content).toContain("# Main Heading");
expect(result.content).toContain("**bold text**");
expect(result.content).toContain("- Item 1");
});
it("should handle fetch errors gracefully", async () => {
mockFetch.mockRejectedValueOnce(new Error("Network error"));
await expect(
service.fetchUrlContent("https://example.com/error")
).rejects.toThrow(GeminiUrlFetchError);
});
it("should handle HTTP error responses", async () => {
mockFetch.mockResolvedValueOnce({
ok: false,
status: 404,
statusText: "Not Found",
url: "https://example.com/notfound",
headers: new Map(),
text: () => Promise.resolve("Page not found"),
});
await expect(
service.fetchUrlContent("https://example.com/notfound")
).rejects.toThrow(GeminiUrlFetchError);
});
it("should respect content size limits", async () => {
const largeContent = "x".repeat(200 * 1024); // 200KB content
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
statusText: "OK",
url: "https://example.com/large",
headers: new Map([
["content-type", "text/html"],
["content-length", largeContent.length.toString()],
]),
text: () => Promise.resolve(largeContent),
});
const result = await service.fetchUrlContent(
"https://example.com/large",
{
maxContentLength: 100 * 1024, // 100KB limit
}
);
expect(result.metadata.truncated).toBe(true);
expect(result.content.length).toBeLessThanOrEqual(100 * 1024);
});
it("should handle JSON content without conversion", async () => {
const jsonContent = JSON.stringify({
message: "Hello World",
data: [1, 2, 3],
});
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
statusText: "OK",
url: "https://api.example.com/data",
headers: new Map([
["content-type", "application/json"],
["content-length", jsonContent.length.toString()],
]),
text: () => Promise.resolve(jsonContent),
});
const result = await service.fetchUrlContent(
"https://api.example.com/data",
{
convertToMarkdown: false,
}
);
expect(result.content).toBe(jsonContent);
expect(result.metadata.contentType).toBe("application/json");
});
});
describe("processUrlsForContext", () => {
it("should process multiple URLs successfully", async () => {
const urls = ["https://example1.com", "https://example2.com"];
const mockContent1 =
"<html><head><title>Page 1</title></head><body><p>Content 1</p></body></html>";
const mockContent2 =
"<html><head><title>Page 2</title></head><body><p>Content 2</p></body></html>";
mockFetch
.mockResolvedValueOnce({
ok: true,
status: 200,
statusText: "OK",
url: urls[0],
headers: new Map([["content-type", "text/html"]]),
text: () => Promise.resolve(mockContent1),
})
.mockResolvedValueOnce({
ok: true,
status: 200,
statusText: "OK",
url: urls[1],
headers: new Map([["content-type", "text/html"]]),
text: () => Promise.resolve(mockContent2),
});
const result = await service.processUrlsForContext(urls);
expect(result.contents).toHaveLength(2);
expect(result.batchResult.summary.totalUrls).toBe(2);
expect(result.batchResult.summary.successCount).toBe(2);
expect(result.batchResult.summary.failureCount).toBe(0);
expect(result.contents[0]).toBeDefined();
expect(result.contents[0]!.parts).toBeDefined();
expect(result.contents[0]!.parts![0]).toBeDefined();
expect(result.contents[0]!.parts![0]!.text).toContain(
"Content from https://example1.com"
);
expect(result.contents[1]).toBeDefined();
expect(result.contents[1]!.parts).toBeDefined();
expect(result.contents[1]!.parts![0]).toBeDefined();
expect(result.contents[1]!.parts![0]!.text).toContain(
"Content from https://example2.com"
);
});
it("should handle mixed success and failure scenarios", async () => {
const urls = [
"https://example1.com",
"https://failed.com",
"https://example3.com",
];
mockFetch
.mockResolvedValueOnce({
ok: true,
status: 200,
statusText: "OK",
url: urls[0],
headers: new Map([["content-type", "text/html"]]),
text: () => Promise.resolve("<html><body>Content 1</body></html>"),
})
.mockRejectedValueOnce(new Error("Network error"))
.mockResolvedValueOnce({
ok: true,
status: 200,
statusText: "OK",
url: urls[2],
headers: new Map([["content-type", "text/html"]]),
text: () => Promise.resolve("<html><body>Content 3</body></html>"),
});
const result = await service.processUrlsForContext(urls);
expect(result.batchResult.summary.totalUrls).toBe(3);
expect(result.batchResult.summary.successCount).toBe(2);
expect(result.batchResult.summary.failureCount).toBe(1);
expect(result.batchResult.failed).toHaveLength(1);
expect(result.batchResult.failed[0].url).toBe("https://failed.com");
});
it("should reject if too many URLs provided", async () => {
const urls = Array.from(
{ length: 25 },
(_, i) => `https://example${i}.com`
);
await expect(service.processUrlsForContext(urls)).rejects.toThrow(
"Too many URLs: 25. Maximum allowed: 20"
);
});
it("should reject if no URLs provided", async () => {
await expect(service.processUrlsForContext([])).rejects.toThrow(
"No URLs provided for processing"
);
});
});
describe("HTML to Markdown conversion", () => {
it("should convert headings correctly", async () => {
const htmlContent = `
<html>
<body>
<h1>Heading 1</h1>
<h2>Heading 2</h2>
<h3>Heading 3</h3>
</body>
</html>
`;
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
url: "https://example.com",
headers: new Map([["content-type", "text/html"]]),
text: () => Promise.resolve(htmlContent),
});
const result = await service.fetchUrlContent("https://example.com");
expect(result.content).toContain("# Heading 1");
expect(result.content).toContain("## Heading 2");
expect(result.content).toContain("### Heading 3");
});
it("should convert lists correctly", async () => {
const htmlContent = `
<html>
<body>
<ul>
<li>Unordered item 1</li>
<li>Unordered item 2</li>
</ul>
<ol>
<li>Ordered item 1</li>
<li>Ordered item 2</li>
</ol>
</body>
</html>
`;
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
url: "https://example.com",
headers: new Map([["content-type", "text/html"]]),
text: () => Promise.resolve(htmlContent),
});
const result = await service.fetchUrlContent("https://example.com");
expect(result.content).toContain("- Unordered item 1");
expect(result.content).toContain("- Unordered item 2");
expect(result.content).toContain("1. Ordered item 1");
expect(result.content).toContain("2. Ordered item 2");
});
it("should convert links correctly", async () => {
const htmlContent = `
<html>
<body>
<a href="https://example.com">Example Link</a>
<a href="/relative/path">Relative Link</a>
</body>
</html>
`;
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
url: "https://example.com",
headers: new Map([["content-type", "text/html"]]),
text: () => Promise.resolve(htmlContent),
});
const result = await service.fetchUrlContent("https://example.com");
expect(result.content).toContain("[Example Link](https://example.com)");
expect(result.content).toContain("[Relative Link](/relative/path)");
});
it("should remove script and style tags", async () => {
const htmlContent = `
<html>
<head>
<style>body { color: red; }</style>
</head>
<body>
<p>Visible content</p>
<script>console.log('hidden');</script>
<p>More visible content</p>
</body>
</html>
`;
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
url: "https://example.com",
headers: new Map([["content-type", "text/html"]]),
text: () => Promise.resolve(htmlContent),
});
const result = await service.fetchUrlContent("https://example.com");
expect(result.content).toContain("Visible content");
expect(result.content).toContain("More visible content");
expect(result.content).not.toContain("color: red");
expect(result.content).not.toContain("console.log");
});
});
describe("Content metadata extraction", () => {
it("should extract title and description from meta tags", async () => {
const htmlContent = `
<html>
<head>
<title>Test Page Title</title>
<meta name="description" content="Test page description">
<meta property="og:image" content="https://example.com/image.jpg">
<link rel="canonical" href="https://example.com/canonical">
</head>
<body>
<p>Content</p>
</body>
</html>
`;
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
url: "https://example.com",
headers: new Map([["content-type", "text/html"]]),
text: () => Promise.resolve(htmlContent),
});
const result = await service.fetchUrlContent("https://example.com");
expect(result.metadata.title).toBe("Test Page Title");
expect(result.metadata.description).toBe("Test page description");
expect(result.metadata.ogImage).toBe("https://example.com/image.jpg");
expect(result.metadata.canonicalUrl).toBe(
"https://example.com/canonical"
);
});
it("should handle HTML entities in metadata", async () => {
const htmlContent = `
<html>
<head>
<title>Title with & ampersand <tags></title>
<meta name="description" content="Description with "quotes" and spaces">
</head>
<body>
<p>Content</p>
</body>
</html>
`;
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
url: "https://example.com",
headers: new Map([["content-type", "text/html"]]),
text: () => Promise.resolve(htmlContent),
});
const result = await service.fetchUrlContent("https://example.com");
expect(result.metadata.title).toBe("Title with & ampersand <tags>");
expect(result.metadata.description).toBe(
'Description with "quotes" and spaces'
);
});
});
describe("Caching functionality", () => {
it("should cache successful results", async () => {
const htmlContent = "<html><body><p>Cached content</p></body></html>";
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
url: "https://example.com",
headers: new Map([["content-type", "text/html"]]),
text: () => Promise.resolve(htmlContent),
});
// First call - should fetch from network
const result1 = await service.fetchUrlContent("https://example.com");
expect(mockFetch).toHaveBeenCalledTimes(1);
// Second call - should return from cache
const result2 = await service.fetchUrlContent("https://example.com");
expect(mockFetch).toHaveBeenCalledTimes(1); // No additional fetch
expect(result1.content).toBe(result2.content);
expect(result1.metadata.url).toBe(result2.metadata.url);
});
});
describe("Rate limiting", () => {
it("should enforce rate limits per domain", async () => {
const url = "https://example.com/page";
// Mock multiple successful responses
for (let i = 0; i < 15; i++) {
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
url,
headers: new Map([["content-type", "text/html"]]),
text: () => Promise.resolve("<html><body>Content</body></html>"),
});
}
// First 10 requests should succeed
for (let i = 0; i < 10; i++) {
await service.fetchUrlContent(`${url}?page=${i}`);
}
// 11th request should fail due to rate limiting
await expect(service.fetchUrlContent(`${url}?page=11`)).rejects.toThrow(
GeminiUrlFetchError
);
});
});
describe("Error handling", () => {
it("should handle timeout errors", async () => {
mockFetch.mockRejectedValueOnce(new Error("Request timeout"));
await expect(
service.fetchUrlContent("https://example.com/timeout")
).rejects.toThrow(GeminiUrlFetchError);
});
it("should handle unsupported content types", async () => {
mockFetch.mockResolvedValueOnce({
ok: true,
status: 200,
url: "https://example.com/binary",
headers: new Map([["content-type", "application/octet-stream"]]),
text: () => Promise.resolve("binary data"),
});
await expect(
service.fetchUrlContent("https://example.com/binary")
).rejects.toThrow(GeminiUrlFetchError);
});
});
});
```
--------------------------------------------------------------------------------
/src/utils/UrlSecurityService.ts:
--------------------------------------------------------------------------------
```typescript
import { ConfigurationManager } from "../config/ConfigurationManager.js";
import { GeminiUrlValidationError } from "./geminiErrors.js";
import { logger } from "./logger.js";
/**
 * Outcome of a URL check: whether it passed, why it failed, and any
 * non-fatal observations gathered along the way.
 */
export interface UrlValidationResult {
  /** True when the URL passed the check. */
  valid: boolean;
  /** Non-fatal findings (e.g. a risky TLD) that may accompany a passing result. */
  warnings?: Array<string>;
  /** Explanation of the failure when `valid` is false. */
  reason?: string;
}
/**
 * Counters and collections tracked by UrlSecurityService and exposed
 * through getSecurityMetrics.
 */
export interface SecurityMetrics {
  /** Number of validateUrl calls made. */
  validationAttempts: number;
  /** Number of validateUrl calls that rejected the URL. */
  validationFailures: number;
  /** Presumably counts rate-limit violations; only reset to 0 in the code visible here. */
  rateLimitViolations: number;
  /** Hostnames rejected as blocked or known-malicious. */
  blockedDomains: Set<string>;
  /** Reason strings recorded for suspicious-pattern rejections. */
  suspiciousPatterns: Array<string>;
}
/**
* Comprehensive URL Security Service for validating and securing URL access
* Prevents access to malicious, private, or restricted URLs
*/
export class UrlSecurityService {
// Logger instance (assigned from the module-level logger in the constructor)
private readonly logger: typeof logger;
// Mutable counters; copies are handed out by getSecurityMetrics
private readonly securityMetrics: SecurityMetrics;
// Known dangerous TLDs and patterns
// Matched against the last hostname label by checkSuspiciousPatterns; a match
// only adds a warning and does not reject the URL.
private readonly dangerousTlds = new Set([
"tk",
"ml",
"ga",
"cf", // Free domains often used for malicious purposes
"bit",
"link",
"click", // URL shorteners that can hide destinations
"download",
"zip",
"exe", // File-like TLDs
]);
// Suspicious URL patterns
// checkSuspiciousPatterns tests these, in array order, against the FULL URL
// string (scheme, host, path and query); the first match becomes the rejection
// reason, so the order matters.
// NOTE(review): because the whole URL is tested, ordinary path/query text can
// trip these. For example, "profile:" matches /file:/i, and any percent-encoding
// matches /%[0-9a-f]{2}/i. The "$"-anchored internal-domain pattern only matches
// when the URL string itself ends with that suffix.
private readonly suspiciousPatterns = [
/\.\./, // Path traversal
/@.*@/, // Multiple @ symbols
/javascript:/i, // JavaScript protocol
/data:/i, // Data URLs
/file:/i, // File protocol
/ftp:/i, // FTP protocol
/localhost|127\.0\.0\.1|0\.0\.0\.0/i, // Localhost
/\.(local|internal|private|corp|lan)$/i, // Internal domains
/%[0-9a-f]{2}/i, // URL encoding (suspicious in domain names)
/[<>{}\\^`|"]/i, // Dangerous characters
];
// Known malicious domains and patterns (expandable list)
// Extended at runtime by addMaliciousDomain (which lowercases its input)
private readonly knownMaliciousDomains = new Set([
"malware.com",
"phishing.com",
"spam.com",
"virus.com",
"trojan.com",
]);
// Private/internal network ranges
// Tested against IP-literal hostnames by isPrivateOrInternalAddress.
// NOTE(review): this list has no loopback 127.0.0.0/8 or 0.0.0.0/8 entry;
// /^224\./ covers only 224.x, not the whole 224.0.0.0/4 multicast block; and
// the IPv6 entries only match a first group of exactly fc00/fe80/ff00
// (e.g. "fd12:" or "ff02:" do not match).
private readonly privateNetworkRanges = [
/^10\./,
/^172\.(1[6-9]|2[0-9]|3[01])\./,
/^192\.168\./,
/^169\.254\./, // Link-local
/^224\./, // Multicast
/^fc00:/, // IPv6 unique local
/^fe80:/, // IPv6 link-local
/^ff00:/, // IPv6 multicast
];
/**
 * @param config Source of the URL-context configuration (allow/block lists etc.)
 */
constructor(private readonly config: ConfigurationManager) {
  // Every instance starts with zeroed counters and empty collections.
  const initialMetrics: SecurityMetrics = {
    validationAttempts: 0,
    validationFailures: 0,
    blockedDomains: new Set<string>(),
    suspiciousPatterns: [],
    rateLimitViolations: 0,
  };
  this.securityMetrics = initialMetrics;
  this.logger = logger;
}
/**
 * Runs the full set of URL security checks and throws on the first failure.
 *
 * Checks, in order: parseability, protocol (http/https), suspicious
 * patterns, domain rules (blocklist, allowlist, private addresses), the
 * known-malicious list, configuration limits, then the advanced checks.
 * Each call bumps `validationAttempts`; each rejection bumps
 * `validationFailures`.
 *
 * @param url Raw URL string to validate
 * @param allowedDomains Optional allowlist that overrides the configured one
 * @throws GeminiUrlValidationError describing the rejection; unexpected
 *   errors are wrapped with the "invalid_format" reason
 */
async validateUrl(url: string, allowedDomains?: string[]): Promise<void> {
  this.securityMetrics.validationAttempts++;
  try {
    let target: URL;
    try {
      target = new URL(url);
    } catch (parseError) {
      this.logSecurityEvent("Invalid URL format", { url, error: parseError });
      throw new GeminiUrlValidationError(
        `Invalid URL format: ${url}`,
        url,
        "invalid_format"
      );
    }

    // Only web protocols are accepted.
    const { protocol } = target;
    if (!this.isAllowedProtocol(protocol)) {
      this.logSecurityEvent("Blocked protocol", { url, protocol });
      throw new GeminiUrlValidationError(
        `Protocol not allowed: ${protocol}`,
        url,
        "blocked_domain"
      );
    }

    const patternResult = this.checkSuspiciousPatterns(url, target);
    if (!patternResult.valid) {
      const { reason } = patternResult;
      this.logSecurityEvent("Suspicious pattern detected", { url, reason });
      throw new GeminiUrlValidationError(
        reason || "Suspicious URL pattern detected",
        url,
        "suspicious_pattern"
      );
    }

    await this.validateDomain(target, allowedDomains);

    const host = target.hostname;
    if (this.isKnownMaliciousDomain(host)) {
      this.logSecurityEvent("Known malicious domain", { url, domain: host });
      this.securityMetrics.blockedDomains.add(host);
      throw new GeminiUrlValidationError(
        `Access to known malicious domain blocked: ${host}`,
        url,
        "blocked_domain"
      );
    }

    this.validateUrlConfiguration(target);
    await this.performAdvancedSecurityChecks(target);

    this.logger.debug("URL validation passed", {
      url,
      domain: target.hostname,
    });
  } catch (failure) {
    this.securityMetrics.validationFailures++;
    if (failure instanceof GeminiUrlValidationError) {
      throw failure;
    }
    const detail = failure instanceof Error ? failure.message : String(failure);
    throw new GeminiUrlValidationError(
      `URL validation failed: ${detail}`,
      url,
      "invalid_format"
    );
  }
}
/**
 * Probes whether a URL answers a HEAD request with a 2xx status.
 *
 * This performs a real network request, and it does NOT run validateUrl
 * first. Callers must validate untrusted URLs themselves.
 *
 * @param url URL to probe
 * @param timeoutMs Abort the probe after this many milliseconds (default 10000)
 * @returns true for a 2xx response; false for any other status, a network error or a timeout
 */
async checkUrlAccessibility(url: string, timeoutMs = 10000): Promise<boolean> {
  // Bound the probe so an unresponsive host cannot stall the caller.
  const controller = new AbortController();
  const timer = setTimeout(() => controller.abort(), timeoutMs);
  try {
    const response = await fetch(url, {
      method: "HEAD",
      headers: {
        "User-Agent": "MCP-Gemini-Server-HealthCheck/1.0",
      },
      signal: controller.signal,
    });
    return response.ok;
  } catch (error) {
    this.logger.debug("URL accessibility check failed", { url, error });
    return false;
  } finally {
    clearTimeout(timer);
  }
}
/**
 * Returns a snapshot of the security metrics for monitoring. The set and
 * the array are copied, so mutating the result leaves the service's own
 * metrics untouched.
 */
getSecurityMetrics(): SecurityMetrics {
  const current = this.securityMetrics;
  return {
    validationAttempts: current.validationAttempts,
    validationFailures: current.validationFailures,
    blockedDomains: new Set(current.blockedDomains),
    suspiciousPatterns: Array.from(current.suspiciousPatterns),
    rateLimitViolations: current.rateLimitViolations,
  };
}
/**
 * Zeroes every counter and empties the collections in place (useful for
 * testing). The existing Set and array objects are reused, not replaced.
 */
resetSecurityMetrics(): void {
  const metrics = this.securityMetrics;
  metrics.validationAttempts = 0;
  metrics.validationFailures = 0;
  metrics.blockedDomains.clear();
  metrics.suspiciousPatterns.splice(0, metrics.suspiciousPatterns.length);
  metrics.rateLimitViolations = 0;
}
/**
 * Adds a domain (lowercased) to the known-malicious blocklist and logs the
 * addition using the domain as given.
 */
addMaliciousDomain(domain: string): void {
  const lowered = domain.toLowerCase();
  this.knownMaliciousDomains.add(lowered);
  this.logger.info("Added domain to malicious blocklist", { domain });
}
/**
 * Returns true only for the http: and https: schemes (case-insensitive).
 */
private isAllowedProtocol(protocol: string): boolean {
  switch (protocol.toLowerCase()) {
    case "http:":
    case "https:":
      return true;
    default:
      return false;
  }
}
/**
 * Screens a URL for suspicious content.
 *
 * Rejects on control characters, on the first matching entry of
 * `suspiciousPatterns` (tested against the full URL string), or on a
 * suspected IDN homograph. Rejection reasons for the first two are also
 * recorded in the metrics. A dangerous TLD or a known URL shortener only
 * adds a warning to a passing result.
 *
 * @param url Raw URL string
 * @param parsedUrl The same URL, already parsed
 * @returns Validation result with `reason` on failure or `warnings` on success
 */
private checkSuspiciousPatterns(
  url: string,
  parsedUrl: URL
): UrlValidationResult {
  if (this.hasControlCharacters(url)) {
    const reason = "Control characters detected in URL";
    this.securityMetrics.suspiciousPatterns.push(reason);
    return { valid: false, reason };
  }

  // The first pattern to match (in declaration order) supplies the reason.
  const matched = this.suspiciousPatterns.find((pattern) => pattern.test(url));
  if (matched) {
    const reason = `Suspicious pattern detected: ${matched.source}`;
    this.securityMetrics.suspiciousPatterns.push(reason);
    return { valid: false, reason };
  }

  const { hostname } = parsedUrl;
  const notes: string[] = [];

  const labels = hostname.split(".");
  const tld = labels[labels.length - 1]?.toLowerCase();
  if (tld && this.dangerousTlds.has(tld)) {
    notes.push(`Potentially dangerous TLD: .${tld}`);
  }

  if (this.detectIdnHomograph(hostname)) {
    this.logger.warn("IDN homograph attack detected", { hostname });
    return {
      valid: false,
      reason: "Potential IDN homograph attack detected in domain name",
    };
  }

  if (this.isUrlShortener(hostname)) {
    notes.push("URL shortener detected - destination cannot be verified");
  }

  return { valid: true, warnings: notes };
}
/**
 * Validates a URL's host against the configured blocklist, the allowlist
 * and the private/internal address rules.
 *
 * The hostname is lowercased and a single trailing root dot is removed
 * first. "evil.com." reaches the same host as "evil.com" and would
 * otherwise slip past exact and suffix pattern matches.
 *
 * @param parsedUrl Parsed URL whose hostname is checked
 * @param allowedDomains Optional allowlist overriding the configured one.
 *   NOTE(review): an empty array (or one containing "*") allows every domain.
 * @throws GeminiUrlValidationError with reason "blocked_domain" when the host
 *   is blocklisted, not allowlisted, or private/internal
 */
private async validateDomain(
  parsedUrl: URL,
  allowedDomains?: string[]
): Promise<void> {
  const hostname = parsedUrl.hostname.toLowerCase().replace(/\.$/, "");
  const urlConfig = this.config.getUrlContextConfig();
  // Check blocklist first
  if (urlConfig.blocklistedDomains.length > 0) {
    for (const blockedPattern of urlConfig.blocklistedDomains) {
      if (this.matchesDomainPattern(hostname, blockedPattern)) {
        this.securityMetrics.blockedDomains.add(hostname);
        throw new GeminiUrlValidationError(
          `Domain is blocked: ${hostname}`,
          parsedUrl.href,
          "blocked_domain"
        );
      }
    }
  }
  // Check allowlist if specified
  const domainsToCheck = allowedDomains || urlConfig.allowedDomains;
  if (domainsToCheck.length > 0 && !domainsToCheck.includes("*")) {
    const allowed = domainsToCheck.some((allowedPattern) =>
      this.matchesDomainPattern(hostname, allowedPattern)
    );
    if (!allowed) {
      throw new GeminiUrlValidationError(
        `Domain not in allowlist: ${hostname}`,
        parsedUrl.href,
        "blocked_domain"
      );
    }
  }
  // Check for private/internal networks
  if (this.isPrivateOrInternalAddress(hostname)) {
    throw new GeminiUrlValidationError(
      `Access to private/internal addresses blocked: ${hostname}`,
      parsedUrl.href,
      "blocked_domain"
    );
  }
}
/**
 * Check whether a domain matches an allow/block pattern.
 *
 * Supported patterns:
 *  - "*"             matches every domain;
 *  - "*.example.com" matches "example.com" and any subdomain of it;
 *  - "example.com"   matches "example.com" and any subdomain of it, so a
 *                    blocklist entry also blocks "sub.example.com".
 *
 * Both sides are compared case-insensitively, with one trailing dot removed.
 * DNS names are case-insensitive, and "example.com." names the same host as
 * "example.com". A pattern that is empty after normalization (e.g. from a
 * stray comma in a configured list) matches nothing.
 *
 * @param domain - Hostname to test.
 * @param pattern - Configured pattern.
 * @returns `true` when `domain` is covered by `pattern`.
 */
private matchesDomainPattern(domain: string, pattern: string): boolean {
  const normalize = (value: string): string =>
    value.toLowerCase().replace(/\.$/, "");

  const host = normalize(domain);
  let suffix = normalize(pattern);

  if (suffix === "*") {
    return true;
  }
  // "*.example.com" and "example.com" both cover the apex and its subdomains.
  if (suffix.startsWith("*.")) {
    suffix = suffix.slice(2);
  }
  if (suffix === "") {
    return false;
  }
  return host === suffix || host.endsWith("." + suffix);
}
/**
 * Decide whether a hostname targets a private, loopback, or internal host.
 *
 * IP literals (IPv4 dotted-quad or IPv6) are tested against
 * `this.privateNetworkRanges`. Any other name counts as internal when it is
 * "localhost" or ends in one of the internal-looking suffixes below.
 *
 * Normalization before matching:
 *  - lower-case the name and drop one trailing dot, so "LOCALHOST." is still
 *    caught;
 *  - strip the square brackets that `URL#hostname` keeps around IPv6
 *    literals ("[::1]"). Bracketed literals never match the IPv6 pattern,
 *    which would let loopback/link-local IPv6 targets through.
 *
 * NOTE(review): the ranges are defined elsewhere, so a bracketed literal is
 * also tested in its original form as a fallback. Confirm the ranges cover
 * IPv4-mapped IPv6 too (URL serializes e.g. "::ffff:7f00:1").
 * NOTE(review): ".dev" is a public gTLD, so public sites under it are also
 * treated as internal here. Confirm this is intended.
 *
 * @param hostname - Hostname to check (typically `URL#hostname`).
 * @returns `true` when access to the host should be blocked.
 */
private isPrivateOrInternalAddress(hostname: string): boolean {
  const normalized = hostname.toLowerCase().replace(/\.$/, "");
  const bare =
    normalized.startsWith("[") && normalized.endsWith("]")
      ? normalized.slice(1, -1)
      : normalized;

  const ipv4Regex = /^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$/;
  const ipv6Regex = /^([0-9a-f]{0,4}:){2,7}[0-9a-f]{0,4}$/i;
  if (ipv4Regex.test(bare) || ipv6Regex.test(bare)) {
    return this.privateNetworkRanges.some(
      (range) =>
        range.test(bare) || (bare !== normalized && range.test(normalized))
    );
  }

  // Non-IP hostnames: treat well-known internal suffixes as private.
  return (
    /\.(local|internal|private|corp|lan|test|dev|localhost)$/i.test(bare) ||
    bare === "localhost"
  );
}
/**
 * Check whether a hostname belongs to a known URL-shortening service.
 *
 * Subdomains (e.g. "www.bit.ly") are matched too, and a trailing FQDN dot is
 * ignored. Either spelling would otherwise sidestep the "destination cannot
 * be verified" warning.
 *
 * @param hostname - Hostname to test (compared case-insensitively).
 * @returns `true` if the host is, or is a subdomain of, a listed shortener.
 */
private isUrlShortener(hostname: string): boolean {
  const shorteners = [
    "bit.ly",
    "tinyurl.com",
    "short.link",
    "ow.ly",
    "t.co",
    "goo.gl",
    "tiny.cc",
    "is.gd",
    "buff.ly",
    "bitly.com",
  ];
  const host = hostname.toLowerCase().replace(/\.$/, "");
  return shorteners.some(
    (shortener) => host === shortener || host.endsWith(`.${shortener}`)
  );
}
/**
 * Detect potential IDN homograph attacks (best-effort heuristics).
 *
 * The Punycode checks look at the registrable label, i.e. the label
 * immediately left of the TLD. Prefixing a subdomain
 * ("www.xn--pple-43d.com") or appending a trailing dot therefore no longer
 * bypasses them. That label is flagged when it is Punycode ("xn--…") and
 * either:
 *  - its encoded part is at most 10 characters and the TLD is one of
 *    com/org/net/io/co, or
 *  - its encoded part has the shape "<3-8 chars>-<2-4 chars>", typical of
 *    an ASCII word with a few non-ASCII lookalikes mixed in
 *    (e.g. "xn--pple-43d").
 *
 * Separately, a hostname that mixes Latin, Cyrillic, or Greek letters is
 * flagged. `URL#hostname` is already Punycode-encoded, so this branch only
 * fires for raw Unicode hostnames.
 *
 * @param hostname - Hostname to inspect.
 * @returns `true` when the hostname looks like a homograph attack.
 */
private detectIdnHomograph(hostname: string): boolean {
  const labels = hostname.toLowerCase().replace(/\.$/, "").split(".");
  if (labels.length >= 2) {
    const tld = labels[labels.length - 1];
    const registrable = labels[labels.length - 2];
    if (registrable.startsWith("xn--")) {
      const encoded = registrable.slice(4); // drop the "xn--" ACE prefix
      const shortUnderCommonTld =
        encoded.length <= 10 &&
        ["com", "org", "net", "io", "co"].includes(tld);
      const lookalikeShape = /^[a-z0-9]{3,8}-[a-z0-9]{2,4}$/.test(encoded);
      if (shortUnderCommonTld || lookalikeShape) {
        this.logger.warn("Suspicious Punycode domain detected", {
          hostname,
          part: registrable,
        });
        return true;
      }
    }
  }

  // Mixing scripts (Latin/Cyrillic/Greek) in one name is a homograph signal.
  const scriptCount = [
    /[a-zA-Z]/.test(hostname),
    /[\u0400-\u04FF]/.test(hostname),
    /[\u0370-\u03FF]/.test(hostname),
  ].filter(Boolean).length;
  return scriptCount > 1;
}
/**
 * Check whether a hostname is, or is a subdomain of, a known malicious domain.
 *
 * The hostname is lower-cased and one trailing dot is dropped, so
 * "Evil.com." matches an "evil.com" entry. The method then walks up the
 * label hierarchy ("a.b.evil.com" → "b.evil.com" → "evil.com" → "com"). Each
 * step is a Set lookup, rather than a scan over every known domain per call.
 *
 * @param hostname - Hostname to check.
 * @returns `true` if the host or any parent domain is in
 *   `this.knownMaliciousDomains`.
 */
private isKnownMaliciousDomain(hostname: string): boolean {
  const host = hostname.toLowerCase().replace(/\.$/, "");
  if (this.knownMaliciousDomains.has(host)) {
    return true;
  }
  // Each suffix following a "." is a parent domain of `host`.
  for (
    let dot = host.indexOf(".");
    dot !== -1;
    dot = host.indexOf(".", dot + 1)
  ) {
    if (this.knownMaliciousDomains.has(host.slice(dot + 1))) {
      return true;
    }
  }
  return false;
}
/**
 * Enforce structural limits on a parsed URL: overall length and port.
 *
 * @param parsedUrl - URL to check.
 * @throws GeminiUrlValidationError with reason "invalid_format" when `href`
 *   exceeds 2048 characters.
 * @throws GeminiUrlValidationError with reason "blocked_domain" when an
 *   explicit port other than 80, 443, 8080, or 8443 is present.
 */
private validateUrlConfiguration(parsedUrl: URL): void {
  const { href, port } = parsedUrl;

  if (href.length > 2048) {
    throw new GeminiUrlValidationError(
      "URL too long (max 2048 characters)",
      href,
      "invalid_format"
    );
  }

  // `URL#port` is the empty string when no explicit port is present.
  if (port === "") {
    return;
  }

  const permittedPorts = new Set([80, 443, 8080, 8443]);
  if (!permittedPorts.has(Number.parseInt(port, 10))) {
    throw new GeminiUrlValidationError(
      `Port not allowed: ${port}`,
      href,
      "blocked_domain"
    );
  }
}
/**
 * Run heuristic, log-only checks on a URL's hostname.
 *
 * None of these checks reject the URL; each one only emits a warning when:
 *  - the hostname has exactly two labels and the first is shorter than
 *    3 characters;
 *  - the hostname has more than 5 labels (deep subdomain nesting);
 *  - `looksRandomlyGenerated` flags the hostname.
 *
 * Private/internal addresses are rejected separately in `validateDomain`.
 *
 * @param parsedUrl - URL whose hostname is inspected.
 */
private async performAdvancedSecurityChecks(parsedUrl: URL): Promise<void> {
  const { hostname } = parsedUrl;
  const labels = hostname.split(".");
  const labelCount = labels.length;

  const isShortTwoLabelDomain = labelCount === 2 && labels[0].length < 3;
  if (isShortTwoLabelDomain) {
    this.logger.warn("Potentially suspicious short domain", { hostname });
  }

  const hasExcessiveNesting = labelCount > 5;
  if (hasExcessiveNesting) {
    this.logger.warn("Excessive subdomain levels detected", {
      hostname,
      levels: labelCount,
    });
  }

  if (this.looksRandomlyGenerated(hostname)) {
    this.logger.warn("Potentially randomly generated domain", { hostname });
  }
}
/**
 * Heuristically decide whether a hostname's first label looks machine-generated.
 *
 * The label is flagged when any of these holds:
 *  - the same character appears 4+ times in a row (e.g. "aaaa");
 *  - a letter-digit pair immediately repeats (e.g. "a1a1");
 *  - more than half of its characters are digits;
 *  - it is longer than 20 characters;
 *  - it has at least 4 characters and contains no vowel (a, e, i, o, u).
 *
 * @param hostname - Hostname whose first dot-separated label is examined.
 * @returns `true` when the label matches any heuristic above.
 */
private looksRandomlyGenerated(hostname: string): boolean {
  const label = hostname.split(".")[0];
  const length = label.length;
  const digitCount = label.replace(/[^0-9]/g, "").length;

  if (/(.)\1{3,}/.test(label)) {
    return true;
  }
  if (/([a-z])([0-9])\1\2/.test(label)) {
    return true;
  }
  if (digitCount > length * 0.5) {
    return true;
  }
  if (length > 20) {
    return true;
  }
  // Vowel-less labels are only suspicious when not very short.
  return length >= 4 && !/[aeiou]/i.test(label);
}
/**
 * Emit a security-related event at warn level.
 *
 * @param event - Short event name, appended to the "Security event: " prefix.
 * @param details - Structured context forwarded unchanged to the logger.
 */
private logSecurityEvent(
  event: string,
  details: Record<string, unknown>
): void {
  const message = `Security event: ${event}`;
  this.logger.warn(message, details);
}
/**
 * Report whether text contains a control character.
 *
 * Control characters here are UTF-16 code units 0x00-0x1F (C0 controls),
 * 0x7F (DEL), and 0x80-0x9F (C1 controls).
 *
 * @param text - Text to scan.
 * @returns `true` as soon as one control character is found.
 */
private hasControlCharacters(text: string): boolean {
  let index = 0;
  while (index < text.length) {
    const code = text.charCodeAt(index);
    index += 1;
    if (code <= 0x1f || (code >= 0x7f && code <= 0x9f)) {
      return true;
    }
  }
  return false;
}
}
```
--------------------------------------------------------------------------------
/tests/unit/services/gemini/GeminiValidationSchemas.test.vitest.ts:
--------------------------------------------------------------------------------
```typescript
// Using vitest globals - see vitest.config.ts globals: true
import { ZodError } from "zod";
import {
validateImageGenerationParams,
validateGenerateContentParams,
validateRouteMessageParams,
ImageGenerationParamsSchema,
GenerateContentParamsSchema,
RouteMessageParamsSchema,
ThinkingConfigSchema,
GenerationConfigSchema,
} from "../../../../src/services/gemini/GeminiValidationSchemas.js";
/**
 * Assert that `parse` throws a ZodError whose first issue's path starts with
 * `path`.
 *
 * The parse runs exactly once. The path check cannot be silently skipped,
 * unlike a bare `try { parse() } catch { expect(...) }`, which asserts
 * nothing at all when no error is thrown.
 *
 * @param parse - Thunk performing the schema parse under test.
 * @param path - Expected leading segments of the first issue's path.
 */
function expectZodIssueAt(
  parse: () => unknown,
  ...path: (string | number)[]
): void {
  let caught: unknown;
  try {
    parse();
  } catch (err) {
    caught = err;
  }
  expect(caught).toBeInstanceOf(ZodError);
  if (caught instanceof ZodError) {
    expect(caught.errors[0]?.path.slice(0, path.length)).toEqual(path);
  }
}

describe("GeminiValidationSchemas", () => {
  describe("Image Generation Validation", () => {
    it("should validate valid image generation parameters", () => {
      const validParams = {
        prompt: "A beautiful sunset over the ocean",
        modelName: "imagen-3.1-generate-003",
        resolution: "1024x1024",
        numberOfImages: 2,
        safetySettings: [
          {
            category: "HARM_CATEGORY_HARASSMENT",
            threshold: "BLOCK_MEDIUM_AND_ABOVE",
          },
        ],
        negativePrompt: "clouds, rain",
        stylePreset: "photographic",
        seed: 12345,
        styleStrength: 0.75,
      };
      // Should not throw
      const result = ImageGenerationParamsSchema.parse(validParams);
      expect(result.prompt).toBe(validParams.prompt);
      expect(result.modelName).toBe(validParams.modelName);
      expect(result.resolution).toBe(validParams.resolution);
    });

    it("should validate using the validateImageGenerationParams helper", () => {
      const result = validateImageGenerationParams(
        "A beautiful sunset",
        "imagen-3.1-generate-003",
        "1024x1024",
        2
      );
      expect(result.prompt).toBe("A beautiful sunset");
      expect(result.modelName).toBe("imagen-3.1-generate-003");
      expect(result.resolution).toBe("1024x1024");
      expect(result.numberOfImages).toBe(2);
    });

    it("should throw on invalid prompt", () => {
      expectZodIssueAt(
        () => ImageGenerationParamsSchema.parse({ prompt: "" }),
        "prompt"
      );
    });

    it("should throw on invalid resolution", () => {
      expectZodIssueAt(
        () =>
          ImageGenerationParamsSchema.parse({
            prompt: "valid prompt",
            resolution: "invalid-resolution",
          }),
        "resolution"
      );
    });

    it("should throw on invalid numberOfImages", () => {
      expectZodIssueAt(
        () =>
          ImageGenerationParamsSchema.parse({
            prompt: "valid prompt",
            numberOfImages: 20, // Max is 8
          }),
        "numberOfImages"
      );
    });

    it("should throw on invalid styleStrength", () => {
      expectZodIssueAt(
        () =>
          ImageGenerationParamsSchema.parse({
            prompt: "valid prompt",
            styleStrength: 2.5, // Max is 1.0
          }),
        "styleStrength"
      );
    });
  });

  describe("Thinking Budget Validation", () => {
    it("should validate valid thinking budget", () => {
      // Should not throw
      const result = ThinkingConfigSchema.parse({ thinkingBudget: 5000 });
      expect(result?.thinkingBudget).toBe(5000);
    });

    it("should validate empty thinking budget object", () => {
      // Should not throw
      const result = ThinkingConfigSchema.parse({});
      expect(result?.thinkingBudget).toBeUndefined();
    });

    it("should validate valid reasoningEffort values", () => {
      const validValues = ["none", "low", "medium", "high"];
      for (const value of validValues) {
        // Should not throw
        const result = ThinkingConfigSchema.parse({ reasoningEffort: value });
        expect(result?.reasoningEffort).toBe(value);
      }
    });

    it("should throw on invalid reasoningEffort values", () => {
      expectZodIssueAt(
        () => ThinkingConfigSchema.parse({ reasoningEffort: "invalid" }),
        "reasoningEffort"
      );
    });

    it("should validate both thinkingBudget and reasoningEffort in same object", () => {
      // Should not throw
      const result = ThinkingConfigSchema.parse({
        thinkingBudget: 5000,
        reasoningEffort: "medium",
      });
      expect(result?.thinkingBudget).toBe(5000);
      expect(result?.reasoningEffort).toBe("medium");
    });

    it("should validate thinking budget at boundaries", () => {
      // Min value (0)
      expect(() =>
        ThinkingConfigSchema.parse({ thinkingBudget: 0 })
      ).not.toThrow();
      // Max value (24576)
      expect(() =>
        ThinkingConfigSchema.parse({ thinkingBudget: 24576 })
      ).not.toThrow();
    });

    it("should throw on invalid thinking budget values", () => {
      // Below min value
      expectZodIssueAt(
        () => ThinkingConfigSchema.parse({ thinkingBudget: -1 }),
        "thinkingBudget"
      );
      // Above max value
      expectZodIssueAt(
        () => ThinkingConfigSchema.parse({ thinkingBudget: 30000 }),
        "thinkingBudget"
      );
      // Non-integer value
      expectZodIssueAt(
        () => ThinkingConfigSchema.parse({ thinkingBudget: 100.5 }),
        "thinkingBudget"
      );
    });

    it("should validate thinking config within generation config", () => {
      // Should not throw
      const result = GenerationConfigSchema.parse({
        temperature: 0.7,
        thinkingConfig: {
          thinkingBudget: 5000,
        },
      });
      expect(result?.temperature).toBe(0.7);
      expect(result?.thinkingConfig?.thinkingBudget).toBe(5000);
    });

    it("should validate reasoningEffort within generation config", () => {
      // Should not throw
      const result = GenerationConfigSchema.parse({
        temperature: 0.7,
        thinkingConfig: {
          reasoningEffort: "high",
        },
      });
      expect(result?.temperature).toBe(0.7);
      expect(result?.thinkingConfig?.reasoningEffort).toBe("high");
    });

    it("should throw on invalid thinking budget in generation config", () => {
      expectZodIssueAt(
        () =>
          GenerationConfigSchema.parse({
            temperature: 0.7,
            thinkingConfig: {
              thinkingBudget: 30000, // Above max
            },
          }),
        "thinkingConfig",
        "thinkingBudget"
      );
    });
  });

  describe("Content Generation Validation", () => {
    it("should validate valid content generation parameters", () => {
      const validParams = {
        prompt: "Tell me about AI",
        modelName: "gemini-1.5-flash",
        generationConfig: {
          temperature: 0.7,
          topP: 0.9,
          maxOutputTokens: 1000,
          thinkingConfig: {
            thinkingBudget: 4096,
          },
        },
        safetySettings: [
          {
            category: "HARM_CATEGORY_HARASSMENT",
            threshold: "BLOCK_MEDIUM_AND_ABOVE",
          },
        ],
        systemInstruction: "You are a helpful assistant",
      };
      // Should not throw
      const result = GenerateContentParamsSchema.parse(validParams);
      expect(result.prompt).toBe(validParams.prompt);
      expect(result.modelName).toBe(validParams.modelName);
      expect(result.generationConfig).toEqual(validParams.generationConfig);
    });

    it("should validate using the validateGenerateContentParams helper", () => {
      const result = validateGenerateContentParams({
        prompt: "Tell me about AI",
        modelName: "gemini-1.5-flash",
      });
      expect(result.prompt).toBe("Tell me about AI");
      expect(result.modelName).toBe("gemini-1.5-flash");
    });

    it("should throw on invalid prompt", () => {
      expectZodIssueAt(
        () => GenerateContentParamsSchema.parse({ prompt: "" }),
        "prompt"
      );
    });

    it("should throw on invalid temperature", () => {
      expectZodIssueAt(
        () =>
          GenerateContentParamsSchema.parse({
            prompt: "valid prompt",
            generationConfig: {
              temperature: 2.5, // Above the schema maximum
            },
          }),
        "generationConfig",
        "temperature"
      );
    });

    it("should accept string or ContentSchema for systemInstruction", () => {
      // String form
      expect(() =>
        GenerateContentParamsSchema.parse({
          prompt: "valid prompt",
          systemInstruction: "You are a helpful assistant",
        })
      ).not.toThrow();
      // Object form
      expect(() =>
        GenerateContentParamsSchema.parse({
          prompt: "valid prompt",
          systemInstruction: {
            role: "system",
            parts: [{ text: "You are a helpful assistant" }],
          },
        })
      ).not.toThrow();
    });
  });

  describe("Router Validation", () => {
    it("should validate valid router parameters", () => {
      const validParams = {
        message: "What is the capital of France?",
        models: ["gemini-1.5-pro", "gemini-1.5-flash"],
        routingPrompt: "Choose the best model for this question",
        defaultModel: "gemini-1.5-pro",
        generationConfig: {
          temperature: 0.7,
          maxOutputTokens: 1000,
        },
        safetySettings: [
          {
            category: "HARM_CATEGORY_HARASSMENT",
            threshold: "BLOCK_MEDIUM_AND_ABOVE",
          },
        ],
        systemInstruction: "You are a helpful assistant",
      };
      // Should not throw
      const result = RouteMessageParamsSchema.parse(validParams);
      expect(result.message).toBe(validParams.message);
      expect(result.models).toEqual(validParams.models);
      expect(result.routingPrompt).toBe(validParams.routingPrompt);
    });

    it("should validate using the validateRouteMessageParams helper", () => {
      const result = validateRouteMessageParams({
        message: "What is the capital of France?",
        models: ["gemini-1.5-pro", "gemini-1.5-flash"],
      });
      expect(result.message).toBe("What is the capital of France?");
      expect(result.models).toEqual(["gemini-1.5-pro", "gemini-1.5-flash"]);
    });

    it("should throw on empty message", () => {
      expectZodIssueAt(
        () =>
          RouteMessageParamsSchema.parse({
            message: "",
            models: ["gemini-1.5-pro"],
          }),
        "message"
      );
    });

    it("should throw on empty models array", () => {
      expectZodIssueAt(
        () =>
          RouteMessageParamsSchema.parse({
            message: "valid message",
            models: [],
          }),
        "models"
      );
    });

    it("should throw on missing required fields", () => {
      // Missing required message field
      expectZodIssueAt(
        () => RouteMessageParamsSchema.parse({ models: ["gemini-1.5-pro"] }),
        "message"
      );
      // Missing required models field
      expectZodIssueAt(
        () => RouteMessageParamsSchema.parse({ message: "valid message" }),
        "models"
      );
    });

    it("should validate optional fields when provided", () => {
      // Testing with just the required fields
      expect(() =>
        RouteMessageParamsSchema.parse({
          message: "valid message",
          models: ["gemini-1.5-pro"],
        })
      ).not.toThrow();
      // Testing with optional fields
      expect(() =>
        RouteMessageParamsSchema.parse({
          message: "valid message",
          models: ["gemini-1.5-pro"],
          routingPrompt: "custom prompt",
          defaultModel: "gemini-1.5-flash",
        })
      ).not.toThrow();
      // Testing with invalid optional field
      expectZodIssueAt(
        () =>
          RouteMessageParamsSchema.parse({
            message: "valid message",
            models: ["gemini-1.5-pro"],
            defaultModel: "", // Empty string
          }),
        "defaultModel"
      );
    });

    it("should accept string or ContentSchema for systemInstruction", () => {
      // String form
      expect(() =>
        RouteMessageParamsSchema.parse({
          message: "valid message",
          models: ["gemini-1.5-pro"],
          systemInstruction: "You are a helpful assistant",
        })
      ).not.toThrow();
      // Object form
      expect(() =>
        RouteMessageParamsSchema.parse({
          message: "valid message",
          models: ["gemini-1.5-pro"],
          systemInstruction: {
            parts: [{ text: "You are a helpful assistant" }],
          },
        })
      ).not.toThrow();
    });
  });
});
```