This is page 5 of 6. Use http://codebase.md/bsmi021/mcp-gemini-server?lines=false&page={x} to view the full context.
# Directory Structure
```
├── .env.example
├── .eslintignore
├── .eslintrc.json
├── .gitignore
├── .prettierrc.json
├── Dockerfile
├── LICENSE
├── package-lock.json
├── package.json
├── README.md
├── review-prompt.txt
├── scripts
│ ├── gemini-review.sh
│ └── run-with-health-check.sh
├── smithery.yaml
├── src
│ ├── config
│ │ └── ConfigurationManager.ts
│ ├── createServer.ts
│ ├── index.ts
│ ├── resources
│ │ └── system-prompt.md
│ ├── server.ts
│ ├── services
│ │ ├── ExampleService.ts
│ │ ├── gemini
│ │ │ ├── GeminiCacheService.ts
│ │ │ ├── GeminiChatService.ts
│ │ │ ├── GeminiContentService.ts
│ │ │ ├── GeminiGitDiffService.ts
│ │ │ ├── GeminiPromptTemplates.ts
│ │ │ ├── GeminiTypes.ts
│ │ │ ├── GeminiUrlContextService.ts
│ │ │ ├── GeminiValidationSchemas.ts
│ │ │ ├── GitHubApiService.ts
│ │ │ ├── GitHubUrlParser.ts
│ │ │ └── ModelMigrationService.ts
│ │ ├── GeminiService.ts
│ │ ├── index.ts
│ │ ├── mcp
│ │ │ ├── index.ts
│ │ │ └── McpClientService.ts
│ │ ├── ModelSelectionService.ts
│ │ ├── session
│ │ │ ├── index.ts
│ │ │ ├── InMemorySessionStore.ts
│ │ │ ├── SessionStore.ts
│ │ │ └── SQLiteSessionStore.ts
│ │ └── SessionService.ts
│ ├── tools
│ │ ├── exampleToolParams.ts
│ │ ├── geminiCacheParams.ts
│ │ ├── geminiCacheTool.ts
│ │ ├── geminiChatParams.ts
│ │ ├── geminiChatTool.ts
│ │ ├── geminiCodeReviewParams.ts
│ │ ├── geminiCodeReviewTool.ts
│ │ ├── geminiGenerateContentConsolidatedParams.ts
│ │ ├── geminiGenerateContentConsolidatedTool.ts
│ │ ├── geminiGenerateImageParams.ts
│ │ ├── geminiGenerateImageTool.ts
│ │ ├── geminiGenericParamSchemas.ts
│ │ ├── geminiRouteMessageParams.ts
│ │ ├── geminiRouteMessageTool.ts
│ │ ├── geminiUrlAnalysisTool.ts
│ │ ├── index.ts
│ │ ├── mcpClientParams.ts
│ │ ├── mcpClientTool.ts
│ │ ├── registration
│ │ │ ├── index.ts
│ │ │ ├── registerAllTools.ts
│ │ │ ├── ToolAdapter.ts
│ │ │ └── ToolRegistry.ts
│ │ ├── schemas
│ │ │ ├── BaseToolSchema.ts
│ │ │ ├── CommonSchemas.ts
│ │ │ ├── index.ts
│ │ │ ├── ToolSchemas.ts
│ │ │ └── writeToFileParams.ts
│ │ └── writeToFileTool.ts
│ ├── types
│ │ ├── exampleServiceTypes.ts
│ │ ├── geminiServiceTypes.ts
│ │ ├── gitdiff-parser.d.ts
│ │ ├── googleGenAI.d.ts
│ │ ├── googleGenAITypes.ts
│ │ ├── index.ts
│ │ ├── micromatch.d.ts
│ │ ├── modelcontextprotocol-sdk.d.ts
│ │ ├── node-fetch.d.ts
│ │ └── serverTypes.ts
│ └── utils
│ ├── errors.ts
│ ├── filePathSecurity.ts
│ ├── FileSecurityService.ts
│ ├── geminiErrors.ts
│ ├── healthCheck.ts
│ ├── index.ts
│ ├── logger.ts
│ ├── RetryService.ts
│ ├── ToolError.ts
│ └── UrlSecurityService.ts
├── tests
│ ├── .env.test.example
│ ├── basic-router.test.vitest.ts
│ ├── e2e
│ │ ├── clients
│ │ │ └── mcp-test-client.ts
│ │ ├── README.md
│ │ └── streamableHttpTransport.test.vitest.ts
│ ├── integration
│ │ ├── dummyMcpServerSse.ts
│ │ ├── dummyMcpServerStdio.ts
│ │ ├── geminiRouterIntegration.test.vitest.ts
│ │ ├── mcpClientIntegration.test.vitest.ts
│ │ ├── multiModelIntegration.test.vitest.ts
│ │ └── urlContextIntegration.test.vitest.ts
│ ├── tsconfig.test.json
│ ├── unit
│ │ ├── config
│ │ │ └── ConfigurationManager.multimodel.test.vitest.ts
│ │ ├── server
│ │ │ └── transportLogic.test.vitest.ts
│ │ ├── services
│ │ │ ├── gemini
│ │ │ │ ├── GeminiChatService.test.vitest.ts
│ │ │ │ ├── GeminiGitDiffService.test.vitest.ts
│ │ │ │ ├── geminiImageGeneration.test.vitest.ts
│ │ │ │ ├── GeminiPromptTemplates.test.vitest.ts
│ │ │ │ ├── GeminiUrlContextService.test.vitest.ts
│ │ │ │ ├── GeminiValidationSchemas.test.vitest.ts
│ │ │ │ ├── GitHubApiService.test.vitest.ts
│ │ │ │ ├── GitHubUrlParser.test.vitest.ts
│ │ │ │ └── ThinkingBudget.test.vitest.ts
│ │ │ ├── mcp
│ │ │ │ └── McpClientService.test.vitest.ts
│ │ │ ├── ModelSelectionService.test.vitest.ts
│ │ │ └── session
│ │ │ └── SQLiteSessionStore.test.vitest.ts
│ │ ├── tools
│ │ │ ├── geminiCacheTool.test.vitest.ts
│ │ │ ├── geminiChatTool.test.vitest.ts
│ │ │ ├── geminiCodeReviewTool.test.vitest.ts
│ │ │ ├── geminiGenerateContentConsolidatedTool.test.vitest.ts
│ │ │ ├── geminiGenerateImageTool.test.vitest.ts
│ │ │ ├── geminiRouteMessageTool.test.vitest.ts
│ │ │ ├── mcpClientTool.test.vitest.ts
│ │ │ ├── mcpToolsTests.test.vitest.ts
│ │ │ └── schemas
│ │ │ ├── BaseToolSchema.test.vitest.ts
│ │ │ ├── ToolParamSchemas.test.vitest.ts
│ │ │ └── ToolSchemas.test.vitest.ts
│ │ └── utils
│ │ ├── errors.test.vitest.ts
│ │ ├── FileSecurityService.test.vitest.ts
│ │ ├── FileSecurityService.vitest.ts
│ │ ├── FileSecurityServiceBasics.test.vitest.ts
│ │ ├── healthCheck.test.vitest.ts
│ │ ├── RetryService.test.vitest.ts
│ │ └── UrlSecurityService.test.vitest.ts
│ └── utils
│ ├── assertions.ts
│ ├── debug-error.ts
│ ├── env-check.ts
│ ├── environment.ts
│ ├── error-helpers.ts
│ ├── express-mocks.ts
│ ├── integration-types.ts
│ ├── mock-types.ts
│ ├── test-fixtures.ts
│ ├── test-generators.ts
│ ├── test-setup.ts
│ └── vitest.d.ts
├── tsconfig.json
├── tsconfig.test.json
├── vitest-globals.d.ts
├── vitest.config.ts
└── vitest.setup.ts
```
# Files
--------------------------------------------------------------------------------
/tests/unit/tools/geminiCodeReviewTool.test.vitest.ts:
--------------------------------------------------------------------------------
```typescript
// Using vitest globals - see vitest.config.ts globals: true
import {
geminiCodeReviewTool,
geminiCodeReviewStreamTool,
} from "../../../src/tools/geminiCodeReviewTool.js";
import { GeminiService } from "../../../src/services/index.js";
// Mock dependencies
vi.mock("../../../src/services/index.js");
type MockGeminiService = {
reviewGitDiff: ReturnType<typeof vi.fn>;
reviewGitDiffStream: ReturnType<typeof vi.fn>;
reviewGitHubRepository: ReturnType<typeof vi.fn>;
reviewGitHubPullRequest: ReturnType<typeof vi.fn>;
};
describe("geminiCodeReviewTool", () => {
let mockGeminiService: MockGeminiService;
beforeEach(() => {
vi.clearAllMocks();
// Setup mock GeminiService
mockGeminiService = {
reviewGitDiff: vi.fn(),
reviewGitDiffStream: vi.fn(),
reviewGitHubRepository: vi.fn(),
reviewGitHubPullRequest: vi.fn(),
};
vi.mocked(GeminiService).mockImplementation(() => mockGeminiService as any);
});
describe("Tool Configuration", () => {
it("should have correct name and description", () => {
expect(geminiCodeReviewTool.name).toBe("gemini_code_review");
expect(geminiCodeReviewTool.description).toContain(
"Performs comprehensive code reviews"
);
});
it("should have valid input schema", () => {
expect(geminiCodeReviewTool.inputSchema).toBeDefined();
expect((geminiCodeReviewTool.inputSchema as any)._def.discriminator).toBe(
"source"
);
});
});
describe("Local Diff Review", () => {
it("should handle local diff review", async () => {
const mockReview =
"Code Review:\n- Good use of types\n- Consider error handling";
mockGeminiService.reviewGitDiff.mockResolvedValue(mockReview);
const args = {
source: "local_diff" as const,
diffContent: "diff --git a/file.ts b/file.ts\n+const x = 1;",
model: "gemini-2.5-pro-preview-05-06",
reviewFocus: "security" as const,
customPrompt: "Focus on TypeScript best practices",
};
const result = await geminiCodeReviewTool.execute(
args,
mockGeminiService as any
);
expect(mockGeminiService.reviewGitDiff).toHaveBeenCalledWith({
diffContent: args.diffContent,
modelName: args.model,
reviewFocus: "security", // Should take first value from array
customPrompt: args.customPrompt,
diffOptions: {
maxFilesToInclude: undefined,
excludePatterns: undefined,
prioritizeFiles: undefined,
},
reasoningEffort: undefined,
repositoryContext: undefined,
});
expect(result.content[0].type).toBe("text");
expect(result.content[0].text).toBe(mockReview);
});
it("should handle local diff with repository context", async () => {
const mockReview = "Review complete";
mockGeminiService.reviewGitDiff.mockResolvedValue(mockReview);
const args = {
source: "local_diff" as const,
diffContent: "diff content",
repositoryContext: {
name: "my-project",
description: "A TypeScript project",
languages: ["TypeScript", "JavaScript"],
frameworks: ["React", "Node.js"],
},
maxFilesToInclude: 50,
excludePatterns: ["*.test.ts", "dist/**"],
prioritizeFiles: ["src/**/*.ts"],
};
const result = await geminiCodeReviewTool.execute(
args,
mockGeminiService as any
);
expect(result).toBeDefined();
expect(result.content[0].text).toBe(mockReview);
expect(mockGeminiService.reviewGitDiff).toHaveBeenCalledWith(
expect.objectContaining({
repositoryContext: JSON.stringify(args.repositoryContext),
diffOptions: {
maxFilesToInclude: 50,
excludePatterns: ["*.test.ts", "dist/**"],
prioritizeFiles: ["src/**/*.ts"],
},
})
);
});
});
describe("GitHub Repository Review", () => {
it("should handle GitHub repository review", async () => {
const mockReview = "Repository Review:\n- Well-structured codebase";
mockGeminiService.reviewGitHubRepository.mockResolvedValue(mockReview);
const args = {
source: "github_repo" as const,
repoUrl: "https://github.com/owner/repo",
branch: "main",
maxFiles: 50,
reasoningEffort: "high" as const,
reviewFocus: "architecture" as const,
};
const result = await geminiCodeReviewTool.execute(
args,
mockGeminiService as any
);
expect(mockGeminiService.reviewGitHubRepository).toHaveBeenCalledWith({
owner: "owner",
repo: "repo",
branch: args.branch,
maxFilesToInclude: args.maxFiles,
modelName: undefined,
reasoningEffort: args.reasoningEffort,
reviewFocus: "architecture", // Should take first value from array
excludePatterns: undefined,
prioritizeFiles: undefined,
customPrompt: undefined,
});
expect(result.content[0].text).toBe(mockReview);
});
});
describe("GitHub Pull Request Review", () => {
it("should handle GitHub PR review", async () => {
const mockReview =
"PR Review:\n- Changes look good\n- Tests are comprehensive";
mockGeminiService.reviewGitHubPullRequest.mockResolvedValue(mockReview);
const args = {
source: "github_pr" as const,
prUrl: "https://github.com/owner/repo/pull/123",
model: "gemini-2.5-flash-preview-05-20",
filesOnly: true,
excludePatterns: ["*.generated.ts"],
};
const result = await geminiCodeReviewTool.execute(
args,
mockGeminiService as any
);
expect(mockGeminiService.reviewGitHubPullRequest).toHaveBeenCalledWith({
owner: "owner",
repo: "repo",
prNumber: 123,
modelName: args.model,
reasoningEffort: undefined,
reviewFocus: undefined,
excludePatterns: args.excludePatterns,
customPrompt: undefined,
});
expect(result.content[0].text).toBe(mockReview);
});
it("should handle GitHub PR review with all optional parameters", async () => {
const mockReview =
"Comprehensive PR Review:\n- Code quality is excellent\n- Security considerations addressed";
mockGeminiService.reviewGitHubPullRequest.mockResolvedValue(mockReview);
const args = {
source: "github_pr" as const,
prUrl: "https://github.com/owner/repo/pull/456",
model: "gemini-2.5-pro-preview-05-06",
reasoningEffort: "high" as const,
reviewFocus: "security" as const,
excludePatterns: ["*.test.ts", "*.spec.ts", "dist/**"],
customPrompt:
"Focus on security vulnerabilities and performance optimizations",
filesOnly: false,
};
const result = await geminiCodeReviewTool.execute(
args,
mockGeminiService as any
);
expect(mockGeminiService.reviewGitHubPullRequest).toHaveBeenCalledWith({
owner: "owner",
repo: "repo",
prNumber: 456,
modelName: args.model,
reasoningEffort: args.reasoningEffort,
reviewFocus: "security", // Should take first value from array
excludePatterns: args.excludePatterns,
customPrompt: args.customPrompt,
});
expect(result.content[0].text).toBe(mockReview);
});
it("should handle GitHub PR review with deprecated filesOnly parameter", async () => {
const mockReview = "Files-only PR Review";
mockGeminiService.reviewGitHubPullRequest.mockResolvedValue(mockReview);
const args = {
source: "github_pr" as const,
prUrl: "https://github.com/owner/repo/pull/789",
filesOnly: true,
};
const result = await geminiCodeReviewTool.execute(
args,
mockGeminiService as any
);
expect(mockGeminiService.reviewGitHubPullRequest).toHaveBeenCalledWith({
owner: "owner",
repo: "repo",
prNumber: 789,
modelName: undefined,
reasoningEffort: undefined,
reviewFocus: undefined,
excludePatterns: undefined,
customPrompt: undefined,
});
expect(result.content[0].text).toBe(mockReview);
});
it("should handle GitHub PR review with minimal parameters", async () => {
const mockReview = "Basic PR Review";
mockGeminiService.reviewGitHubPullRequest.mockResolvedValue(mockReview);
const args = {
source: "github_pr" as const,
prUrl: "https://github.com/owner/repo/pull/101",
};
const result = await geminiCodeReviewTool.execute(
args,
mockGeminiService as any
);
expect(mockGeminiService.reviewGitHubPullRequest).toHaveBeenCalledWith({
owner: "owner",
repo: "repo",
prNumber: 101,
modelName: undefined,
reasoningEffort: undefined,
reviewFocus: undefined,
excludePatterns: undefined,
customPrompt: undefined,
});
expect(result.content[0].text).toBe(mockReview);
});
});
describe("URL Parsing and Validation", () => {
it("should handle invalid GitHub repository URL", async () => {
const args = {
source: "github_repo" as const,
repoUrl: "https://invalid-url.com/not-github",
maxFiles: 100,
};
await expect(
geminiCodeReviewTool.execute(args, mockGeminiService as any)
).rejects.toThrow("Invalid GitHub repository URL format");
});
it("should handle invalid GitHub PR URL", async () => {
const args = {
source: "github_pr" as const,
prUrl: "https://github.com/owner/repo/issues/123", // issues instead of pull
};
await expect(
geminiCodeReviewTool.execute(args, mockGeminiService as any)
).rejects.toThrow("Invalid GitHub pull request URL format");
});
it("should handle malformed GitHub PR URL", async () => {
const args = {
source: "github_pr" as const,
prUrl: "https://github.com/owner/repo/pull/invalid-number",
};
await expect(
geminiCodeReviewTool.execute(args, mockGeminiService as any)
).rejects.toThrow("Invalid GitHub pull request URL format");
});
it("should correctly parse GitHub repository URL", async () => {
const mockReview = "Repository parsed correctly";
mockGeminiService.reviewGitHubRepository.mockResolvedValue(mockReview);
const args = {
source: "github_repo" as const,
repoUrl: "https://github.com/microsoft/typescript",
branch: "main",
maxFiles: 100,
};
await geminiCodeReviewTool.execute(args, mockGeminiService as any);
expect(mockGeminiService.reviewGitHubRepository).toHaveBeenCalledWith(
expect.objectContaining({
owner: "microsoft",
repo: "typescript",
branch: "main",
})
);
});
it("should correctly parse GitHub PR URL and extract PR number", async () => {
const mockReview = "PR parsed correctly";
mockGeminiService.reviewGitHubPullRequest.mockResolvedValue(mockReview);
const args = {
source: "github_pr" as const,
prUrl: "https://github.com/facebook/react/pull/12345",
};
await geminiCodeReviewTool.execute(args, mockGeminiService as any);
expect(mockGeminiService.reviewGitHubPullRequest).toHaveBeenCalledWith(
expect.objectContaining({
owner: "facebook",
repo: "react",
prNumber: 12345,
})
);
});
});
describe("Review Focus Array Handling", () => {
it("should handle multiple review focus areas for local diff", async () => {
const mockReview = "Multi-focus review";
mockGeminiService.reviewGitDiff.mockResolvedValue(mockReview);
const args = {
source: "local_diff" as const,
diffContent: "diff content",
reviewFocus: "security" as const,
};
await geminiCodeReviewTool.execute(args, mockGeminiService as any);
expect(mockGeminiService.reviewGitDiff).toHaveBeenCalledWith(
expect.objectContaining({
reviewFocus: "security", // Should take first value
})
);
});
it("should handle empty review focus array", async () => {
const mockReview = "Default focus review";
mockGeminiService.reviewGitDiff.mockResolvedValue(mockReview);
const args = {
source: "local_diff" as const,
diffContent: "diff content",
// No reviewFocus to test undefined behavior
};
await geminiCodeReviewTool.execute(args, mockGeminiService as any);
expect(mockGeminiService.reviewGitDiff).toHaveBeenCalledWith(
expect.objectContaining({
reviewFocus: undefined, // Should be undefined when no focus is provided
})
);
});
it("should handle single review focus area", async () => {
const mockReview = "Single focus review";
mockGeminiService.reviewGitDiff.mockResolvedValue(mockReview);
const args = {
source: "local_diff" as const,
diffContent: "diff content",
reviewFocus: "architecture" as const,
};
await geminiCodeReviewTool.execute(args, mockGeminiService as any);
expect(mockGeminiService.reviewGitDiff).toHaveBeenCalledWith(
expect.objectContaining({
reviewFocus: "architecture",
})
);
});
});
describe("Error Handling", () => {
it("should handle GitHub API service errors for PR review", async () => {
mockGeminiService.reviewGitHubPullRequest.mockRejectedValue(
new Error("GitHub API rate limit exceeded")
);
const args = {
source: "github_pr" as const,
prUrl: "https://github.com/owner/repo/pull/123",
};
await expect(
geminiCodeReviewTool.execute(args, mockGeminiService as any)
).rejects.toThrow();
});
it("should handle GitHub API service errors for repo review", async () => {
mockGeminiService.reviewGitHubRepository.mockRejectedValue(
new Error("Repository not found")
);
const args = {
source: "github_repo" as const,
repoUrl: "https://github.com/owner/nonexistent-repo",
maxFiles: 100,
};
await expect(
geminiCodeReviewTool.execute(args, mockGeminiService as any)
).rejects.toThrow();
});
});
describe("Error Handling", () => {
it("should handle service errors", async () => {
mockGeminiService.reviewGitDiff.mockRejectedValue(new Error("API error"));
const args = {
source: "local_diff" as const,
diffContent: "diff content",
};
await expect(
geminiCodeReviewTool.execute(args, mockGeminiService as any)
).rejects.toThrow();
});
it("should handle unknown source type", async () => {
const args = {
source: "unknown" as unknown as
| "local_diff"
| "github_pr"
| "github_repo",
diffContent: "diff",
};
await expect(
geminiCodeReviewTool.execute(args as any, mockGeminiService as any)
).rejects.toThrow("Unknown review source");
});
});
});
describe("geminiCodeReviewStreamTool", () => {
let mockGeminiService: Pick<MockGeminiService, "reviewGitDiffStream">;
beforeEach(() => {
vi.clearAllMocks();
mockGeminiService = {
reviewGitDiffStream: vi.fn(),
};
vi.mocked(GeminiService).mockImplementation(() => mockGeminiService as any);
});
it("should stream local diff review", async () => {
const mockChunks = ["Review chunk 1", "Review chunk 2", "Review chunk 3"];
// Create an async generator mock
mockGeminiService.reviewGitDiffStream.mockImplementation(
async function* () {
for (const chunk of mockChunks) {
yield chunk;
}
}
);
const args = {
source: "local_diff" as const,
diffContent: "diff content",
model: "gemini-2.5-pro-preview-05-06",
};
const results: Array<any> = [];
const generator = await geminiCodeReviewStreamTool.execute(
args,
mockGeminiService as any
);
for await (const chunk of generator) {
results.push(chunk);
}
expect(results).toHaveLength(3);
expect(results[0].content[0].text).toBe("Review chunk 1");
expect(results[1].content[0].text).toBe("Review chunk 2");
expect(results[2].content[0].text).toBe("Review chunk 3");
});
it("should reject non-local_diff sources", async () => {
const args = {
source: "github_repo" as const,
repoUrl: "https://github.com/owner/repo",
maxFiles: 100,
};
await expect(async () => {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const generator = await geminiCodeReviewStreamTool.execute(
args,
mockGeminiService as any
);
// eslint-disable-next-line @typescript-eslint/no-unused-vars
for await (const _chunk of generator) {
// Should not reach here - this line should never execute
break;
}
}).rejects.toThrow("Streaming is only supported for local_diff source");
});
});
```
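The `_def.discriminator` assertion above indicates that `geminiCodeReviewTool.inputSchema` is a Zod discriminated union keyed on `source`. The real schema lives in `src/tools/geminiCodeReviewParams.ts`, which is not included on this page; the sketch below is only an editorial reconstruction from the fields these tests pass, so the field names, enums, and optionality are assumptions rather than the actual definition.

```typescript
// Hypothetical sketch of a "source"-discriminated Zod union, inferred from the
// test arguments above. Not the actual geminiCodeReviewParams.ts definition.
import { z } from "zod";

// Only the focus values exercised by the tests; the real schema likely allows more.
const reviewFocus = z.enum(["security", "architecture"]).optional();

export const sketchedCodeReviewParams = z.discriminatedUnion("source", [
  z.object({
    source: z.literal("local_diff"),
    diffContent: z.string().min(1),
    model: z.string().optional(),
    reviewFocus,
    customPrompt: z.string().optional(),
    maxFilesToInclude: z.number().int().positive().optional(),
    excludePatterns: z.array(z.string()).optional(),
    prioritizeFiles: z.array(z.string()).optional(),
  }),
  z.object({
    source: z.literal("github_repo"),
    repoUrl: z.string().url(),
    branch: z.string().optional(),
    maxFiles: z.number().int().positive().optional(),
    reasoningEffort: z.enum(["low", "medium", "high"]).optional(), // only "high" appears in the tests
    reviewFocus,
  }),
  z.object({
    source: z.literal("github_pr"),
    prUrl: z.string().url(),
    model: z.string().optional(),
    excludePatterns: z.array(z.string()).optional(),
    customPrompt: z.string().optional(),
    filesOnly: z.boolean().optional(), // noted as deprecated in the tests
    reviewFocus,
  }),
]);
```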
--------------------------------------------------------------------------------
/tests/unit/services/mcp/McpClientService.test.vitest.ts:
--------------------------------------------------------------------------------
```typescript
/// <reference types="../../../../vitest-globals.d.ts" />
// Using vitest globals - see vitest.config.ts globals: true
// Fixed UUID for testing
const TEST_UUID = "test-uuid-value";
// Import mock types first
import {
EVENT_SOURCE_STATES,
MockEvent,
MockEventSource,
} from "../../../../tests/utils/mock-types.js";
// Store the mock instance for access in tests
let mockEventSourceInstance: MockEventSource;
// Create mock objects for child_process
export const mockStdout = {
on: vi.fn(),
removeAllListeners: vi.fn(),
};
export const mockStderr = {
on: vi.fn(),
removeAllListeners: vi.fn(),
};
export const mockStdin = {
write: vi.fn(),
};
export const mockChildProcess = {
stdout: mockStdout,
stderr: mockStderr,
stdin: mockStdin,
on: vi.fn(),
kill: vi.fn(),
removeAllListeners: vi.fn(),
};
// Store the EventSource constructor for test expectations
let EventSourceConstructor: any;
// Setup mocks using doMock to avoid hoisting issues
vi.doMock("eventsource", () => {
EventSourceConstructor = vi.fn().mockImplementation(function (
url: string,
_options?: any
) {
// Create mock instance
const instance = {
onopen: null,
onmessage: null,
onerror: null,
readyState: 0,
url: url,
withCredentials: false,
close: vi.fn(),
addEventListener: vi.fn(),
removeEventListener: vi.fn(),
dispatchEvent: vi.fn().mockReturnValue(true),
};
// Store instance for test access
mockEventSourceInstance = instance as any;
return instance;
});
return {
default: EventSourceConstructor,
};
});
vi.doMock("uuid", () => ({
v4: vi.fn(() => TEST_UUID),
}));
const mockSpawn = vi.fn(() => mockChildProcess);
vi.doMock("child_process", () => ({
spawn: mockSpawn,
}));
vi.doMock("node-fetch", () => ({
default: vi.fn().mockResolvedValue({
ok: true,
status: 200,
statusText: "OK",
json: vi.fn().mockResolvedValue({ result: {} }),
}),
}));
// Type helper for accessing private properties in tests
type McpClientServicePrivate = any;
describe("McpClientService", () => {
let McpClientService: typeof import("../../../../src/services/mcp/McpClientService.js").McpClientService;
let SdkMcpError: any;
let logger: any;
let service: any;
let originalSetInterval: typeof global.setInterval;
let originalClearInterval: typeof global.clearInterval;
beforeAll(async () => {
// Dynamic imports after mocks are set up
const mcpService = await import(
"../../../../src/services/mcp/McpClientService.js"
);
McpClientService = mcpService.McpClientService;
const sdkTypes = await import("@modelcontextprotocol/sdk/types.js");
SdkMcpError = sdkTypes.McpError;
const loggerModule = await import("../../../../src/utils/logger.js");
logger = loggerModule.logger;
});
beforeEach(() => {
// Reset all mocks
vi.clearAllMocks();
// Save original timing functions
originalSetInterval = global.setInterval;
originalClearInterval = global.clearInterval;
// Mock timers
vi.useFakeTimers();
// Mock logger
vi.spyOn(logger, "info").mockImplementation(vi.fn());
vi.spyOn(logger, "warn").mockImplementation(vi.fn());
vi.spyOn(logger, "error").mockImplementation(vi.fn());
vi.spyOn(logger, "debug").mockImplementation(vi.fn());
// Create a new instance of the service
service = new McpClientService();
});
afterEach(() => {
// Restore originals
global.setInterval = originalSetInterval;
global.clearInterval = originalClearInterval;
// Restore all mocks
vi.restoreAllMocks();
vi.useRealTimers();
});
describe("Constructor", () => {
it("should initialize with empty connection maps", () => {
expect(
(service as McpClientServicePrivate).activeSseConnections.size
).toBe(0);
expect(
(service as McpClientServicePrivate).activeStdioConnections.size
).toBe(0);
expect(
(service as McpClientServicePrivate).pendingStdioRequests.size
).toBe(0);
});
it("should set up a cleanup interval", () => {
expect(vi.getTimerCount()).toBeGreaterThan(0);
});
});
describe("connect", () => {
it("should validate serverId properly", async () => {
await expect(
service.connect("", { type: "sse", sseUrl: "http://test-url.com" })
).rejects.toThrow(SdkMcpError);
await expect(
service.connect("", { type: "sse", sseUrl: "http://test-url.com" })
).rejects.toThrow(/Server ID must be a non-empty string/);
});
it("should validate connection details properly", async () => {
await expect(service.connect("server1", null as any)).rejects.toThrow(
SdkMcpError
);
await expect(service.connect("server1", null as any)).rejects.toThrow(
/Connection details must be an object/
);
});
it("should validate connection type properly", async () => {
// Using type assertion to test invalid inputs
await expect(
service.connect("server1", {
type: "invalid" as unknown as "sse" | "stdio",
})
).rejects.toThrow(SdkMcpError);
// Using type assertion to test invalid inputs
await expect(
service.connect("server1", {
type: "invalid" as unknown as "sse" | "stdio",
})
).rejects.toThrow(/Connection type must be 'sse' or 'stdio'/);
});
it("should validate SSE URL properly", async () => {
await expect(
service.connect("server1", { type: "sse", sseUrl: "" })
).rejects.toThrow(SdkMcpError);
await expect(
service.connect("server1", { type: "sse", sseUrl: "" })
).rejects.toThrow(/sseUrl must be a non-empty string/);
await expect(
service.connect("server1", { type: "sse", sseUrl: "invalid-url" })
).rejects.toThrow(SdkMcpError);
await expect(
service.connect("server1", { type: "sse", sseUrl: "invalid-url" })
).rejects.toThrow(/valid URL format/);
});
it("should validate stdio command properly", async () => {
await expect(
service.connect("server1", { type: "stdio", stdioCommand: "" })
).rejects.toThrow(SdkMcpError);
await expect(
service.connect("server1", { type: "stdio", stdioCommand: "" })
).rejects.toThrow(/stdioCommand must be a non-empty string/);
});
it("should establish an SSE connection successfully", async () => {
const connectPromise = service.connect("server1", {
type: "sse",
sseUrl: "http://test-server.com/sse",
});
// Wait for the EventSource to be created and callbacks to be assigned
await new Promise((resolve) => setTimeout(resolve, 10));
// Simulate successful connection by calling the onopen callback
if (mockEventSourceInstance && mockEventSourceInstance.onopen) {
mockEventSourceInstance.onopen({} as MockEvent);
}
const connectionId = await connectPromise;
expect(connectionId).toBe(TEST_UUID);
expect(
(service as McpClientServicePrivate).activeSseConnections.size
).toBe(1);
// Check correct parameters were used
expect(EventSourceConstructor).toHaveBeenCalledWith(
"http://test-server.com/sse"
);
});
it("should establish a stdio connection successfully", async () => {
const connectionId = await service.connect("server1", {
type: "stdio",
stdioCommand: "test-command",
stdioArgs: ["arg1", "arg2"],
});
expect(connectionId).toBe(TEST_UUID);
expect(
(service as McpClientServicePrivate).activeStdioConnections.size
).toBe(1);
// Check correct parameters were used
expect(mockSpawn).toHaveBeenCalledWith(
"test-command",
["arg1", "arg2"],
expect.anything()
);
});
});
describe("cleanupStaleConnections", () => {
it("should close stale SSE connections", async () => {
// Create a connection
const connectPromise = (service as McpClientServicePrivate).connectSse(
"http://test-server.com/sse"
);
mockEventSourceInstance.onopen &&
mockEventSourceInstance.onopen({} as MockEvent);
const connectionId = await connectPromise;
// Verify connection exists
expect(
(service as McpClientServicePrivate).activeSseConnections.size
).toBe(1);
// Set the last activity timestamp to be stale (10 minutes + 1 second ago)
const staleTimestamp = Date.now() - (600000 + 1000);
(service as McpClientServicePrivate).activeSseConnections.get(
connectionId
).lastActivityTimestamp = staleTimestamp;
// Call the cleanup method
(service as McpClientServicePrivate).cleanupStaleConnections();
// Verify connection was closed
expect(mockEventSourceInstance.close).toHaveBeenCalled();
expect(
(service as McpClientServicePrivate).activeSseConnections.size
).toBe(0);
});
it("should close stale stdio connections", async () => {
// Create a connection
const connectionId = await (
service as McpClientServicePrivate
).connectStdio("test-command");
// Verify connection exists
expect(
(service as McpClientServicePrivate).activeStdioConnections.size
).toBe(1);
// Set the last activity timestamp to be stale (10 minutes + 1 second ago)
const staleTimestamp = Date.now() - (600000 + 1000);
(service as McpClientServicePrivate).activeStdioConnections.get(
connectionId
).lastActivityTimestamp = staleTimestamp;
// Call the cleanup method
(service as McpClientServicePrivate).cleanupStaleConnections();
// Verify connection was closed
expect(mockChildProcess.kill).toHaveBeenCalled();
expect(
(service as McpClientServicePrivate).activeStdioConnections.size
).toBe(0);
});
it("should not close active connections", async () => {
// Create a connection
const connectPromise = (service as McpClientServicePrivate).connectSse(
"http://test-server.com/sse"
);
mockEventSourceInstance.onopen &&
mockEventSourceInstance.onopen({} as MockEvent);
await connectPromise;
// Verify connection exists (with current timestamp)
expect(
(service as McpClientServicePrivate).activeSseConnections.size
).toBe(1);
// Call the cleanup method
(service as McpClientServicePrivate).cleanupStaleConnections();
// Verify connection was not closed
expect(mockEventSourceInstance.close).not.toHaveBeenCalled();
expect(
(service as McpClientServicePrivate).activeSseConnections.size
).toBe(1);
});
});
describe("SSE Connections", () => {
const testUrl = "http://test-server.com/sse";
it("should create an EventSource and return a connection ID when successful", async () => {
const connectPromise = (service as McpClientServicePrivate).connectSse(
testUrl
);
// Trigger the onopen event to simulate successful connection
mockEventSourceInstance.onopen &&
mockEventSourceInstance.onopen({} as MockEvent);
const connectionId = await connectPromise;
// Check EventSource was constructed with the correct URL
expect(EventSourceConstructor).toHaveBeenCalledWith(testUrl);
// Check the connection ID is returned
expect(connectionId).toBe(TEST_UUID);
// Check the connection was stored with last activity timestamp
expect(
(service as McpClientServicePrivate).activeSseConnections.size
).toBe(1);
expect(
(service as McpClientServicePrivate).activeSseConnections.has(TEST_UUID)
).toBe(true);
const connection = (
service as McpClientServicePrivate
).activeSseConnections.get(TEST_UUID);
expect(connection.lastActivityTimestamp).toBeGreaterThan(0);
});
it("should handle SSE messages and pass them to the messageHandler", async () => {
const messageHandler = vi.fn();
const testData = { foo: "bar" };
const connectPromise = (service as McpClientServicePrivate).connectSse(
testUrl,
messageHandler
);
// Manually trigger the onopen callback to resolve the connection promise
mockEventSourceInstance.onopen &&
mockEventSourceInstance.onopen({} as MockEvent);
await connectPromise;
// Get the initial activity timestamp
const initialTimestamp = (
service as McpClientServicePrivate
).activeSseConnections.get(TEST_UUID).lastActivityTimestamp;
// Save the original Date.now function so it can be restored after mocking
const originalTimestamp = Date.now;
// Mock Date.now to return a later timestamp
Date.now = vi.fn().mockReturnValue(initialTimestamp + 1000);
// Trigger the onmessage event with test data
const messageEvent = { data: JSON.stringify(testData) };
mockEventSourceInstance.onmessage &&
mockEventSourceInstance.onmessage(messageEvent as MessageEvent);
// Verify message handler was called with parsed data
expect(messageHandler).toHaveBeenCalledWith(testData);
// Verify last activity timestamp was updated
const newTimestamp = (
service as McpClientServicePrivate
).activeSseConnections.get(TEST_UUID).lastActivityTimestamp;
expect(newTimestamp).toBeGreaterThan(initialTimestamp);
// Restore original Date.now
Date.now = originalTimestamp;
});
it("should handle SSE message parse errors and pass raw data to the messageHandler", async () => {
const messageHandler = vi.fn();
const invalidJson = "{ not valid json";
const connectPromise = (service as McpClientServicePrivate).connectSse(
testUrl,
messageHandler
);
// Manually trigger the onopen callback to resolve the connection promise
mockEventSourceInstance.onopen &&
mockEventSourceInstance.onopen({} as MockEvent);
await connectPromise;
// Save the original Date.now function and mock it to return a later timestamp
const originalTimestamp = Date.now;
Date.now = vi.fn().mockReturnValue(Date.now() + 1000);
// Trigger the onmessage event with invalid JSON
const messageEvent = { data: invalidJson };
mockEventSourceInstance.onmessage &&
mockEventSourceInstance.onmessage(messageEvent as MessageEvent);
// Verify message handler was called with raw data
expect(messageHandler).toHaveBeenCalledWith(invalidJson);
// Restore original Date.now
Date.now = originalTimestamp;
});
it("should reject the promise when an SSE error occurs before connection", async () => {
const connectPromise = (service as McpClientServicePrivate).connectSse(
testUrl
);
// Trigger the onerror event before onopen
const errorEvent: MockEvent = {
type: "error",
message: "Connection failed",
};
mockEventSourceInstance.onerror &&
mockEventSourceInstance.onerror(errorEvent);
// Expect the promise to reject
await expect(connectPromise).rejects.toThrow(SdkMcpError);
await expect(connectPromise).rejects.toThrow(
/Failed to establish SSE connection/
);
// Verify no connection was stored
expect(
(service as McpClientServicePrivate).activeSseConnections.size
).toBe(0);
});
it("should close and remove the connection when an SSE error occurs after connection", async () => {
// Successfully connect first
const connectPromise = (service as McpClientServicePrivate).connectSse(
testUrl
);
mockEventSourceInstance.onopen &&
mockEventSourceInstance.onopen({} as MockEvent);
const connectionId = await connectPromise;
// Verify connection exists before error
expect(
(service as McpClientServicePrivate).activeSseConnections.size
).toBe(1);
expect(
(service as McpClientServicePrivate).activeSseConnections.has(
connectionId
)
).toBe(true);
// Update readyState to simulate a connected then closed state
mockEventSourceInstance.readyState = EVENT_SOURCE_STATES.CLOSED;
// Trigger an error after successful connection
const errorEvent: MockEvent = {
type: "error",
message: "Connection lost",
};
mockEventSourceInstance.onerror &&
mockEventSourceInstance.onerror(errorEvent);
// Verify connection was removed
expect(
(service as McpClientServicePrivate).activeSseConnections.size
).toBe(0);
expect(
(service as McpClientServicePrivate).activeSseConnections.has(
connectionId
)
).toBe(false);
});
it("should close an SSE connection on disconnect", async () => {
// Reset mocks before this test to ensure clean state
vi.clearAllMocks();
service = new McpClientService();
// In this test we're going to directly set up the activeSseConnections map to match the test scenario
// This is necessary because the implementation uses the connectionId for storage and lookup
const connectionId = TEST_UUID;
// Manually set up the connection in the map
(service as McpClientServicePrivate).activeSseConnections.set(
connectionId,
{
eventSource: mockEventSourceInstance,
baseUrl: testUrl,
lastActivityTimestamp: Date.now(),
}
);
// Verify connection exists before disconnecting
expect(
(service as McpClientServicePrivate).activeSseConnections.size
).toBe(1);
expect(
(service as McpClientServicePrivate).activeSseConnections.has(
connectionId
)
).toBe(true);
// Disconnect
const result = service.disconnect(connectionId);
// Verify connection was closed
expect(result).toBe(true);
expect(mockEventSourceInstance.close).toHaveBeenCalled();
expect(
(service as McpClientServicePrivate).activeSseConnections.size
).toBe(0);
});
it("should throw an error when disconnecting from a non-existent connection", () => {
expect(() => service.disconnect("non-existent-server")).toThrow(
SdkMcpError
);
expect(() => service.disconnect("non-existent-server")).toThrow(
/Connection not found/
);
});
});
// Additional tests for callTool, listTools, etc. would follow the same pattern
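// A minimal editorial sketch of one such test, not part of the original suite.
// It assumes a callTool(connectionId, toolName, args) signature that rejects for
// unknown connection IDs, mirroring the disconnect() behaviour above; marked as
// skipped because the real McpClientService API should be confirmed first.
it.skip("should reject callTool for a non-existent connection (sketch)", async () => {
await expect(
service.callTool("non-existent-connection", "exampleTool", {})
).rejects.toThrow();
});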
});
```
--------------------------------------------------------------------------------
/tests/unit/services/gemini/GeminiChatService.test.vitest.ts:
--------------------------------------------------------------------------------
```typescript
// Using vitest globals - see vitest.config.ts globals: true
import { GeminiChatService } from "../../../../src/services/gemini/GeminiChatService.js";
import {
GeminiApiError,
ValidationError as GeminiValidationError,
} from "../../../../src/utils/errors.js";
// Import necessary types
import type {
GenerateContentResponse,
GenerationConfig,
Content,
SafetySetting,
GoogleGenAI,
} from "@google/genai";
// Import the ChatSession type from our service
import { ChatSession } from "../../../../src/services/gemini/GeminiTypes.js";
import { FinishReason } from "../../../../src/types/googleGenAITypes.js";
// Helper type for accessing private properties in tests
type GeminiChatServiceTestAccess = {
chatSessions: Map<string, ChatSession>;
};
// Define a partial version of GenerateContentResponse for mocking
interface PartialGenerateContentResponse
extends Partial<GenerateContentResponse> {
response?: {
candidates?: Array<{
content?: {
role?: string;
parts?: Array<{
text?: string;
functionCall?: Record<string, unknown>;
}>;
};
finishReason?: FinishReason;
}>;
promptFeedback?: {
blockReason?: string;
};
};
model?: string;
contents?: Array<Content>;
generationConfig?: GenerationConfig;
safetySettings?: Array<SafetySetting>;
candidates?: Array<{
content?: {
role?: string;
parts?: Array<{
text?: string;
functionCall?: Record<string, unknown>;
}>;
};
finishReason?: FinishReason;
}>;
text?: string;
}
// Mock uuid
vi.mock("uuid", () => ({
v4: () => "test-session-id",
}));
describe("GeminiChatService", () => {
let chatService: GeminiChatService;
const defaultModel = "gemini-1.5-pro";
// Mock the GoogleGenAI class
const mockGenerateContent = vi
.fn()
.mockResolvedValue({} as PartialGenerateContentResponse);
const mockGoogleGenAI = {
models: {
generateContent: mockGenerateContent,
getGenerativeModel: vi.fn().mockReturnValue({
generateContent: mockGenerateContent,
}),
// Mock the required internal methods
apiClient: {} as unknown,
generateContentInternal: vi.fn(),
generateContentStreamInternal: vi.fn(),
},
// Add other required properties for GoogleGenAI
apiClient: {} as unknown,
vertexai: {} as unknown,
live: {} as unknown,
chats: {} as unknown,
upload: {} as unknown,
caching: {} as unknown,
} as unknown as GoogleGenAI;
beforeEach(() => {
// Reset mocks before each test
vi.clearAllMocks();
// Initialize chat service with mocked dependencies
chatService = new GeminiChatService(mockGoogleGenAI, defaultModel);
});
describe("startChatSession", () => {
it("should create a new chat session with default model when no model is provided", () => {
const sessionId = chatService.startChatSession({});
expect(sessionId).toBe("test-session-id");
// Get the session from private map using proper type assertion
const sessions = (chatService as unknown as GeminiChatServiceTestAccess)
.chatSessions;
const session = sessions.get("test-session-id") as ChatSession;
expect(session.model).toBe(defaultModel);
expect(session.history).toEqual([]);
expect(session.config).toBeDefined();
});
it("should create a new chat session with provided model", () => {
const customModel = "gemini-1.5-flash";
const sessionId = chatService.startChatSession({
modelName: customModel,
});
expect(sessionId).toBe("test-session-id");
// Get the session from private map with proper type assertion
const sessions = (chatService as unknown as GeminiChatServiceTestAccess)
.chatSessions;
const session = sessions.get("test-session-id") as ChatSession;
expect(session.model).toBe(customModel);
});
it("should include history if provided", () => {
const history = [
{ role: "user", parts: [{ text: "Hello" }] },
{ role: "model", parts: [{ text: "Hi there" }] },
];
chatService.startChatSession({ history });
// Get the session from private map with proper type assertion
const sessions = (chatService as unknown as GeminiChatServiceTestAccess)
.chatSessions;
const session = sessions.get("test-session-id") as ChatSession;
expect(session.history).toEqual(history);
expect(session.config.history).toEqual(history);
});
it("should convert string systemInstruction to Content object", () => {
const systemInstruction = "You are a helpful assistant";
chatService.startChatSession({ systemInstruction });
// Get the session from private map with proper type assertion
const sessions = (chatService as unknown as GeminiChatServiceTestAccess)
.chatSessions;
const session = sessions.get("test-session-id") as ChatSession;
expect(session.config.systemInstruction).toEqual({
parts: [{ text: systemInstruction }],
});
});
it("should throw when no model name is available", () => {
// Create a service with no default model
const noDefaultService = new GeminiChatService(
mockGoogleGenAI as GoogleGenAI
);
expect(() => noDefaultService.startChatSession({})).toThrow(
GeminiApiError
);
expect(() => noDefaultService.startChatSession({})).toThrow(
"Model name must be provided"
);
});
});
describe("sendMessageToSession", () => {
beforeEach(() => {
// Create a test session first
chatService.startChatSession({});
});
it("should send a message to an existing session", async () => {
// Mock successful response with proper typing
const mockResponse: PartialGenerateContentResponse = {
candidates: [
{
content: {
role: "model",
parts: [{ text: "Hello, how can I help you?" }],
},
},
],
text: "Hello, how can I help you?",
};
mockGenerateContent.mockResolvedValueOnce(mockResponse);
const response = await chatService.sendMessageToSession({
sessionId: "test-session-id",
message: "Hi there",
});
// Verify generateContent was called with correct params
expect(mockGenerateContent).toHaveBeenCalledTimes(1);
const requestConfig = (
mockGenerateContent.mock.calls[0] as unknown[]
)[0] as Record<string, unknown>;
expect(requestConfig.model).toBe(defaultModel);
expect(requestConfig.contents).toBeDefined();
// Just verify the message exists somewhere in the contents
const contents = requestConfig.contents as Array<Record<string, unknown>>;
const userContent = contents.find(
(content: Record<string, unknown>) =>
content.role === "user" &&
(content.parts as Array<{ text?: string }>)?.[0]?.text === "Hi there"
);
expect(userContent).toBeDefined();
// Verify response
expect(response).toEqual(mockResponse);
// Check that history was updated in the session
const sessions = (chatService as unknown as GeminiChatServiceTestAccess)
.chatSessions;
const session = sessions.get("test-session-id") as ChatSession;
expect(session.history.length).toBe(2); // User + model response
});
it("should throw when session doesn't exist", async () => {
await expect(
chatService.sendMessageToSession({
sessionId: "non-existent-session",
message: "Hi there",
})
).rejects.toThrow(GeminiApiError);
await expect(
chatService.sendMessageToSession({
sessionId: "non-existent-session",
message: "Hi there",
})
).rejects.toThrow("Chat session not found");
});
it("should apply per-message configuration options", async () => {
// Mock successful response with proper typing
const emptyResponse: PartialGenerateContentResponse = {};
mockGenerateContent.mockResolvedValueOnce(emptyResponse);
const generationConfig = { temperature: 0.7 };
const safetySettings = [
{
category: "HARM_CATEGORY_HARASSMENT",
threshold: "BLOCK_MEDIUM_AND_ABOVE",
},
];
await chatService.sendMessageToSession({
sessionId: "test-session-id",
message: "Hi there",
generationConfig,
safetySettings: safetySettings as SafetySetting[],
});
// Verify configuration was applied
const requestConfig = (
mockGenerateContent.mock.calls[0] as unknown[]
)[0] as Record<string, unknown>;
expect(requestConfig.generationConfig).toEqual(generationConfig);
expect(requestConfig.safetySettings).toEqual(safetySettings);
});
});
describe("sendFunctionResultToSession", () => {
beforeEach(() => {
// Create a test session first
chatService.startChatSession({});
});
it("should send a function result to an existing session", async () => {
// Mock successful response with proper typing
const mockResponse: PartialGenerateContentResponse = {
candidates: [
{
content: {
role: "model",
parts: [{ text: "I've processed that function result" }],
},
},
],
};
mockGenerateContent.mockResolvedValueOnce(mockResponse);
const response = await chatService.sendFunctionResultToSession({
sessionId: "test-session-id",
functionResponse: '{"result": "success"}',
functionCall: { name: "testFunction" },
});
// Verify generateContent was called with correct params
expect(mockGenerateContent).toHaveBeenCalledTimes(1);
const requestConfig = (
mockGenerateContent.mock.calls[0] as unknown[]
)[0] as Record<string, unknown>;
// Verify content contains function response
const contents = requestConfig.contents as Array<Record<string, unknown>>;
const functionResponseContent = contents.find(
(c: Record<string, unknown>) => c.role === "function"
);
expect(functionResponseContent).toBeDefined();
const parts = (functionResponseContent as Record<string, unknown>)
.parts as Array<Record<string, unknown>>;
const functionResponse = parts[0].functionResponse as Record<
string,
unknown
>;
expect(functionResponse.name).toBe("testFunction");
// Verify response
expect(response).toEqual(mockResponse);
// Check that history was updated in the session
const sessions = (chatService as unknown as GeminiChatServiceTestAccess)
.chatSessions;
const session = sessions.get("test-session-id") as ChatSession;
expect(session.history.length).toBe(2); // Function response + model response
});
it("should throw when session doesn't exist", async () => {
await expect(
chatService.sendFunctionResultToSession({
sessionId: "non-existent-session",
functionResponse: "{}",
})
).rejects.toThrow(GeminiApiError);
await expect(
chatService.sendFunctionResultToSession({
sessionId: "non-existent-session",
functionResponse: "{}",
})
).rejects.toThrow("Chat session not found");
});
});
describe("routeMessage", () => {
it("should validate input parameters", async () => {
// Invalid parameters to trigger validation error
await expect(
chatService.routeMessage({
message: "", // Empty message
models: [], // Empty models array
} as Parameters<typeof chatService.routeMessage>[0])
).rejects.toThrow(GeminiValidationError);
});
it("should use the first model to do routing and selected model for the message", async () => {
// Mock successful routing response
const routingResponse: PartialGenerateContentResponse = {
text: "gemini-1.5-flash",
};
mockGenerateContent.mockResolvedValueOnce(routingResponse);
// Mock successful content response
const contentResponse: PartialGenerateContentResponse = {
text: "Response from flash model",
candidates: [
{
content: {
parts: [{ text: "Response from flash model" }],
},
},
],
};
mockGenerateContent.mockResolvedValueOnce(contentResponse);
const result = await chatService.routeMessage({
message: "What is the capital of France?",
models: ["gemini-1.5-pro", "gemini-1.5-flash"],
});
// Verify routing was done with the first model
expect(mockGenerateContent).toHaveBeenCalledTimes(2);
const routingConfig = (
mockGenerateContent.mock.calls[0] as unknown[]
)[0] as Record<string, unknown>;
expect(routingConfig.model).toBe("gemini-1.5-pro");
const routingContents = routingConfig.contents as Array<
Record<string, unknown>
>;
const parts = routingContents[0].parts as Array<Record<string, unknown>>;
expect(parts[0].text).toContain("router");
// Verify final request used the chosen model
const messageConfig = (
mockGenerateContent.mock.calls[1] as unknown[]
)[0] as Record<string, unknown>;
expect(messageConfig.model).toBe("gemini-1.5-flash");
// Verify result contains both response and chosen model
expect(result.response).toBeDefined();
expect(result.chosenModel).toBe("gemini-1.5-flash");
});
it("should use default model if routing fails to identify a model", async () => {
// Mock routing response that doesn't match any model
const unknownModelResponse: PartialGenerateContentResponse = {
text: "unknown-model",
};
mockGenerateContent.mockResolvedValueOnce(unknownModelResponse);
// Mock successful content response
const defaultModelResponse: PartialGenerateContentResponse = {
text: "Response from default model",
candidates: [
{
content: {
parts: [{ text: "Response from default model" }],
},
},
],
};
mockGenerateContent.mockResolvedValueOnce(defaultModelResponse);
const result = await chatService.routeMessage({
message: "What is the capital of France?",
models: ["gemini-1.5-pro", "gemini-1.5-flash"],
defaultModel: "gemini-1.5-pro",
});
// Verify final request used the default model
const messageConfig = (
mockGenerateContent.mock.calls[1] as unknown[]
)[0] as Record<string, unknown>;
expect(messageConfig.model).toBe("gemini-1.5-pro");
expect(result.chosenModel).toBe("gemini-1.5-pro");
});
it("should throw if routing fails and no default model is provided", async () => {
// Mock routing response that doesn't match any model
const failedRoutingResponse: PartialGenerateContentResponse = {
text: "unknown-model",
};
mockGenerateContent.mockResolvedValueOnce(failedRoutingResponse);
await expect(
chatService.routeMessage({
message: "What is the capital of France?",
models: ["gemini-1.5-pro", "gemini-1.5-flash"],
})
).rejects.toThrow(GeminiApiError);
await expect(
chatService.routeMessage({
message: "What is the capital of France?",
models: ["gemini-1.5-pro", "gemini-1.5-flash"],
})
).rejects.toThrow(/Routing failed|Failed to route message/);
});
it("should use custom routing prompt if provided", async () => {
// Mock successful routing and content responses
const customPromptRoutingResponse: PartialGenerateContentResponse = {
text: "gemini-1.5-flash",
};
mockGenerateContent.mockResolvedValueOnce(customPromptRoutingResponse);
const customPromptContentResponse: PartialGenerateContentResponse = {
text: "Response",
};
mockGenerateContent.mockResolvedValueOnce(customPromptContentResponse);
const customPrompt = "Choose the most performant model for this request";
await chatService.routeMessage({
message: "What is the capital of France?",
models: ["gemini-1.5-pro", "gemini-1.5-flash"],
routingPrompt: customPrompt,
});
// Verify routing was done with the custom prompt
const routingConfig = (
mockGenerateContent.mock.calls[0] as unknown[]
)[0] as Record<string, unknown>;
const routingContents = routingConfig.contents as Array<
Record<string, unknown>
>;
const parts = routingContents[0].parts as Array<Record<string, unknown>>;
const promptText = parts[0].text;
expect(promptText).toContain(customPrompt);
});
it("should pass system instruction to both routing and content requests", async () => {
// Mock successful routing and content responses
const customPromptRoutingResponse: PartialGenerateContentResponse = {
text: "gemini-1.5-flash",
};
mockGenerateContent.mockResolvedValueOnce(customPromptRoutingResponse);
const customPromptContentResponse: PartialGenerateContentResponse = {
text: "Response",
};
mockGenerateContent.mockResolvedValueOnce(customPromptContentResponse);
const systemInstruction = "You are a helpful assistant";
await chatService.routeMessage({
message: "What is the capital of France?",
models: ["gemini-1.5-pro", "gemini-1.5-flash"],
systemInstruction,
});
// Verify system instruction was added to routing request
const routingConfig = (
mockGenerateContent.mock.calls[0] as unknown[]
)[0] as Record<string, unknown>;
const routingContents = routingConfig.contents as Array<
Record<string, unknown>
>;
expect(routingContents[0].role).toBe("system");
const routingParts = routingContents[0].parts as Array<
Record<string, unknown>
>;
expect(routingParts[0].text).toBe(systemInstruction);
// Verify system instruction was added to content request
const messageConfig = (
mockGenerateContent.mock.calls[1] as unknown[]
)[0] as Record<string, unknown>;
const messageContents = messageConfig.contents as Array<
Record<string, unknown>
>;
expect(messageContents[0].role).toBe("system");
const messageParts = messageContents[0].parts as Array<
Record<string, unknown>
>;
expect(messageParts[0].text).toBe(systemInstruction);
});
});
});
```
--------------------------------------------------------------------------------
/src/services/gemini/GeminiUrlContextService.ts:
--------------------------------------------------------------------------------
```typescript
import { ConfigurationManager } from "../../config/ConfigurationManager.js";
import { logger } from "../../utils/logger.js";
import {
GeminiUrlFetchError,
GeminiUrlValidationError,
} from "../../utils/geminiErrors.js";
import { UrlSecurityService } from "../../utils/UrlSecurityService.js";
import { RetryService } from "../../utils/RetryService.js";
import type { Content } from "@google/genai";
export interface UrlFetchOptions {
maxContentLength?: number; // Max bytes to fetch
timeout?: number; // Fetch timeout in ms
headers?: Record<string, string>;
allowedDomains?: string[]; // Domain whitelist
includeMetadata?: boolean; // Include URL metadata in response
convertToMarkdown?: boolean; // Convert HTML to markdown
followRedirects?: number; // Max redirects to follow
userAgent?: string; // Custom user agent
}
export interface UrlContentMetadata {
url: string;
finalUrl?: string; // After redirects
title?: string;
description?: string;
contentType: string;
contentLength?: number;
fetchedAt: Date;
truncated: boolean;
responseTime: number; // ms
statusCode: number;
encoding?: string;
language?: string;
canonicalUrl?: string;
ogImage?: string;
favicon?: string;
}
export interface UrlContentResult {
content: string;
metadata: UrlContentMetadata;
}
export interface UrlBatchResult {
successful: UrlContentResult[];
failed: Array<{
url: string;
error: Error;
errorCode: string;
}>;
summary: {
totalUrls: number;
successCount: number;
failureCount: number;
totalContentSize: number;
averageResponseTime: number;
};
}
/**
* Advanced URL Context Service for Gemini API integration
* Handles URL fetching, content extraction, security validation, and metadata processing
*/
export class GeminiUrlContextService {
private readonly securityService: UrlSecurityService;
private readonly retryService: RetryService;
private readonly urlCache = new Map<
string,
{ result: UrlContentResult; expiry: number }
>();
private readonly rateLimiter = new Map<
string,
{ count: number; resetTime: number }
>();
constructor(private readonly config: ConfigurationManager) {
this.securityService = new UrlSecurityService(config);
this.retryService = new RetryService({
maxAttempts: 3,
initialDelayMs: 1000,
maxDelayMs: 5000,
backoffFactor: 2,
});
}
/**
* Fetch content from a single URL with comprehensive error handling and metadata extraction
*/
async fetchUrlContent(
url: string,
options: UrlFetchOptions = {}
): Promise<UrlContentResult> {
const startTime = Date.now();
try {
// Validate URL security and format
await this.securityService.validateUrl(url, options.allowedDomains);
// Check rate limiting
this.checkRateLimit(url);
// Check cache first
const cached = this.getCachedResult(url);
if (cached) {
logger.debug("Returning cached URL content", { url });
return cached;
}
// Fetch with retry logic
const result = await this.retryService.execute(() =>
this.performUrlFetch(url, options, startTime)
);
// Cache successful result
this.cacheResult(url, result);
// Update rate limiter
this.updateRateLimit(url);
return result;
} catch (error) {
const responseTime = Date.now() - startTime;
logger.error("Failed to fetch URL content", {
url,
error: error instanceof Error ? error.message : String(error),
responseTime,
});
if (
error instanceof GeminiUrlFetchError ||
error instanceof GeminiUrlValidationError
) {
throw error;
}
throw new GeminiUrlFetchError(
`Failed to fetch URL: ${error instanceof Error ? error.message : String(error)}`,
url,
undefined,
error instanceof Error ? error : undefined
);
}
}
/**
* Process multiple URLs in parallel with intelligent batching and error handling
*/
async processUrlsForContext(
urls: string[],
options: UrlFetchOptions = {}
): Promise<{ contents: Content[]; batchResult: UrlBatchResult }> {
if (urls.length === 0) {
throw new Error("No URLs provided for processing");
}
const urlConfig = this.config.getUrlContextConfig();
if (urls.length > urlConfig.maxUrlsPerRequest) {
throw new Error(
`Too many URLs: ${urls.length}. Maximum allowed: ${urlConfig.maxUrlsPerRequest}`
);
}
const startTime = Date.now();
const successful: UrlContentResult[] = [];
const failed: Array<{ url: string; error: Error; errorCode: string }> = [];
// Process URLs in controlled batches to prevent overwhelming target servers
const batchSize = Math.min(5, urls.length);
const batches = this.createBatches(urls, batchSize);
for (const batch of batches) {
const batchPromises = batch.map(async (url) => {
try {
const result = await this.fetchUrlContent(url, options);
successful.push(result);
return { success: true, url, result };
} catch (error) {
const errorInfo = {
url,
error: error instanceof Error ? error : new Error(String(error)),
errorCode: this.getErrorCode(error),
};
failed.push(errorInfo);
return { success: false, url, error: errorInfo };
}
});
// Wait for current batch before processing next
await Promise.allSettled(batchPromises);
// Small delay between batches to be respectful to servers
if (batches.indexOf(batch) < batches.length - 1) {
await this.delay(200);
}
}
const totalTime = Date.now() - startTime;
const totalContentSize = successful.reduce(
(sum, result) => sum + result.content.length,
0
);
const averageResponseTime =
successful.length > 0
? successful.reduce(
(sum, result) => sum + result.metadata.responseTime,
0
) / successful.length
: 0;
const batchResult: UrlBatchResult = {
successful,
failed,
summary: {
totalUrls: urls.length,
successCount: successful.length,
failureCount: failed.length,
totalContentSize,
averageResponseTime,
},
};
// Convert successful results to Gemini Content format
const contents = this.convertToGeminiContent(successful, options);
logger.info("URL batch processing completed", {
totalUrls: urls.length,
successful: successful.length,
failed: failed.length,
totalTime,
totalContentSize,
});
return { contents, batchResult };
}
/**
* Perform the actual URL fetch with comprehensive metadata extraction
*/
private async performUrlFetch(
url: string,
options: UrlFetchOptions,
startTime: number
): Promise<UrlContentResult> {
const urlConfig = this.config.getUrlContextConfig();
const fetchOptions = {
method: "GET",
timeout: options.timeout || urlConfig.defaultTimeoutMs,
headers: {
"User-Agent":
options.userAgent ||
"MCP-Gemini-Server/1.0 (+hhttps://github.com/bsmi021/mcp-gemini-server)",
Accept:
"text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
"Accept-Language": "en-US,en;q=0.5",
"Accept-Encoding": "gzip, deflate, br",
"Cache-Control": "no-cache",
Pragma: "no-cache",
...options.headers,
},
redirect: "follow" as RequestRedirect,
follow: options.followRedirects || 3,
size: options.maxContentLength || urlConfig.defaultMaxContentKb * 1024,
};
const response = await fetch(url, fetchOptions);
const responseTime = Date.now() - startTime;
if (!response.ok) {
throw new GeminiUrlFetchError(
`HTTP ${response.status}: ${response.statusText}`,
url,
response.status
);
}
const contentType = response.headers.get("content-type") || "text/html";
const encoding = this.extractEncodingFromContentType(contentType);
// Check content type - only process text-based content
if (!this.isTextBasedContent(contentType)) {
throw new GeminiUrlFetchError(
`Unsupported content type: ${contentType}`,
url,
response.status
);
}
let rawContent = await response.text();
const actualSize = Buffer.byteLength(rawContent, "utf8");
const maxSize =
options.maxContentLength || urlConfig.defaultMaxContentKb * 1024;
let truncated = false;
// Truncate if content is too large
if (actualSize > maxSize) {
rawContent = rawContent.substring(0, maxSize);
truncated = true;
}
// Extract metadata from HTML
const metadata = await this.extractMetadata(
rawContent,
url,
response,
responseTime,
truncated,
encoding
);
// Process content based on type and options
let processedContent = rawContent;
if (
contentType.includes("text/html") &&
(options.convertToMarkdown ?? urlConfig.convertToMarkdown)
) {
processedContent = this.convertHtmlToMarkdown(rawContent);
}
// Clean and optimize content
processedContent = this.cleanContent(processedContent);
return {
content: processedContent,
metadata,
};
}
/**
* Extract comprehensive metadata from HTML content and HTTP response
*/
private async extractMetadata(
content: string,
originalUrl: string,
response: Response,
responseTime: number,
truncated: boolean,
encoding?: string
): Promise<UrlContentMetadata> {
const contentType = response.headers.get("content-type") || "";
const contentLength = parseInt(
response.headers.get("content-length") || "0"
);
const metadata: UrlContentMetadata = {
url: originalUrl,
finalUrl: response.url !== originalUrl ? response.url : undefined,
contentType,
contentLength: contentLength || content.length,
fetchedAt: new Date(),
truncated,
responseTime,
statusCode: response.status,
encoding,
};
// Extract HTML metadata if content is HTML
if (contentType.includes("text/html")) {
const htmlMetadata = this.extractHtmlMetadata(content);
Object.assign(metadata, htmlMetadata);
}
return metadata;
}
/**
* Extract structured metadata from HTML content
*/
private extractHtmlMetadata(html: string): Partial<UrlContentMetadata> {
const metadata: Partial<UrlContentMetadata> = {};
// Extract title
const titleMatch = html.match(/<title[^>]*>([^<]+)<\/title>/i);
if (titleMatch) {
metadata.title = this.cleanText(titleMatch[1]);
}
// Extract meta description
const descMatch = html.match(
/<meta[^>]+name=["']description["'][^>]+content=["']([^"']+)["']/i
);
if (descMatch) {
metadata.description = this.cleanText(descMatch[1]);
}
// Extract language
const langMatch =
html.match(/<html[^>]+lang=["']([^"']+)["']/i) ||
html.match(
/<meta[^>]+http-equiv=["']content-language["'][^>]+content=["']([^"']+)["']/i
);
if (langMatch) {
metadata.language = langMatch[1];
}
// Extract canonical URL
const canonicalMatch = html.match(
/<link[^>]+rel=["']canonical["'][^>]+href=["']([^"']+)["']/i
);
if (canonicalMatch) {
metadata.canonicalUrl = canonicalMatch[1];
}
// Extract Open Graph image
const ogImageMatch = html.match(
/<meta[^>]+property=["']og:image["'][^>]+content=["']([^"']+)["']/i
);
if (ogImageMatch) {
metadata.ogImage = ogImageMatch[1];
}
// Extract favicon
const faviconMatch = html.match(
/<link[^>]+rel=["'](?:icon|shortcut icon)["'][^>]+href=["']([^"']+)["']/i
);
if (faviconMatch) {
metadata.favicon = faviconMatch[1];
}
return metadata;
}
/**
* Convert HTML content to clean markdown
*/
private convertHtmlToMarkdown(html: string): string {
// Remove script and style tags entirely
html = html.replace(/<(script|style)[^>]*>[\s\S]*?<\/\1>/gi, "");
// Remove comments
html = html.replace(/<!--[\s\S]*?-->/g, "");
// Convert headings
html = html.replace(
/<h([1-6])[^>]*>(.*?)<\/h\1>/gi,
(_, level, content) => {
const hashes = "#".repeat(parseInt(level));
return `\n\n${hashes} ${this.cleanText(content)}\n\n`;
}
);
// Convert paragraphs
html = html.replace(/<p[^>]*>(.*?)<\/p>/gi, "\n\n$1\n\n");
// Convert line breaks
html = html.replace(/<br\s*\/?>/gi, "\n");
// Convert lists
html = html.replace(/<ul[^>]*>([\s\S]*?)<\/ul>/gi, (_, content) => {
return content.replace(/<li[^>]*>(.*?)<\/li>/gi, "- $1\n");
});
html = html.replace(/<ol[^>]*>([\s\S]*?)<\/ol>/gi, (_, content) => {
let counter = 1;
return content.replace(
/<li[^>]*>(.*?)<\/li>/gi,
(_: string, itemContent: string) => `${counter++}. ${itemContent}\n`
);
});
// Convert links
html = html.replace(
/<a[^>]+href=["']([^"']+)["'][^>]*>(.*?)<\/a>/gi,
"[$2]($1)"
);
// Convert emphasis
html = html.replace(/<(strong|b)[^>]*>(.*?)<\/\1>/gi, "**$2**");
html = html.replace(/<(em|i)[^>]*>(.*?)<\/\1>/gi, "*$2*");
// Convert code
html = html.replace(/<code[^>]*>(.*?)<\/code>/gi, "`$1`");
html = html.replace(/<pre[^>]*>(.*?)<\/pre>/gi, "\n```\n$1\n```\n");
// Convert blockquotes
html = html.replace(
/<blockquote[^>]*>(.*?)<\/blockquote>/gi,
(_, content) => {
return content
.split("\n")
.map((line: string) => `> ${line}`)
.join("\n");
}
);
// Remove remaining HTML tags
html = html.replace(/<[^>]+>/g, "");
// Clean up the text
return this.cleanContent(html);
}
/**
* Clean and normalize text content
*/
private cleanContent(content: string): string {
// Decode HTML entities
content = content
.replace(/&amp;/g, "&")
.replace(/&lt;/g, "<")
.replace(/&gt;/g, ">")
.replace(/&quot;/g, '"')
.replace(/&#39;/g, "'")
.replace(/&nbsp;/g, " ")
.replace(/&mdash;/g, "—")
.replace(/&ndash;/g, "–")
.replace(/&hellip;/g, "…");
// Normalize whitespace
content = content
.replace(/\r\n/g, "\n")
.replace(/\r/g, "\n")
.replace(/\t/g, " ")
.replace(/[ ]+/g, " ")
.replace(/\n[ ]+/g, "\n")
.replace(/[ ]+\n/g, "\n")
.replace(/\n{3,}/g, "\n\n");
// Trim and return
return content.trim();
}
/**
* Clean text by decoding common HTML entities and collapsing extra whitespace
*/
private cleanText(text: string): string {
return text
.replace(/&amp;/g, "&")
.replace(/&lt;/g, "<")
.replace(/&gt;/g, ">")
.replace(/&quot;/g, '"')
.replace(/&#39;/g, "'")
.replace(/&nbsp;/g, " ")
.replace(/\s+/g, " ")
.trim();
}
/**
* Convert URL content results to Gemini Content format
*/
private convertToGeminiContent(
results: UrlContentResult[],
options: UrlFetchOptions
): Content[] {
const includeMetadata = options.includeMetadata ?? true;
const contents: Content[] = [];
for (const result of results) {
// Create content with URL context header
let contentText = `## Content from ${result.metadata.url}\n\n`;
if (includeMetadata && result.metadata.title) {
contentText += `**Title:** ${result.metadata.title}\n\n`;
}
if (includeMetadata && result.metadata.description) {
contentText += `**Description:** ${result.metadata.description}\n\n`;
}
contentText += result.content;
contents.push({
role: "user",
parts: [
{
text: contentText,
},
],
});
}
return contents;
}
/**
* Utility methods for caching, rate limiting, and validation
*/
private getCachedResult(url: string): UrlContentResult | null {
const cached = this.urlCache.get(url);
if (cached && Date.now() < cached.expiry) {
return cached.result;
}
this.urlCache.delete(url);
return null;
}
private cacheResult(url: string, result: UrlContentResult): void {
const cacheExpiry = Date.now() + 15 * 60 * 1000; // 15 minutes
this.urlCache.set(url, { result, expiry: cacheExpiry });
// Clean up expired cache entries once the cache grows large
if (this.urlCache.size > 1000) {
const now = Date.now();
for (const [key, value] of this.urlCache.entries()) {
if (now >= value.expiry) {
this.urlCache.delete(key);
}
}
}
}
private checkRateLimit(url: string): void {
const domain = new URL(url).hostname;
const now = Date.now();
const limit = this.rateLimiter.get(domain);
if (limit) {
if (now < limit.resetTime) {
if (limit.count >= 10) {
// Max 10 requests per minute per domain
throw new GeminiUrlFetchError(
`Rate limit exceeded for domain: ${domain}`,
url
);
}
} else {
// Reset counter
this.rateLimiter.set(domain, { count: 0, resetTime: now + 60000 });
}
} else {
this.rateLimiter.set(domain, { count: 0, resetTime: now + 60000 });
}
}
private updateRateLimit(url: string): void {
const domain = new URL(url).hostname;
const limit = this.rateLimiter.get(domain);
if (limit) {
limit.count++;
}
}
private shouldRetryFetch(error: unknown): boolean {
if (error instanceof GeminiUrlValidationError) {
return false; // Don't retry validation errors
}
if (error instanceof GeminiUrlFetchError) {
const status = error.statusCode;
// Retry on server errors and certain client errors
return !status || status >= 500 || status === 429 || status === 408;
}
return true; // Retry network errors
}
private createBatches<T>(items: T[], batchSize: number): T[][] {
const batches: T[][] = [];
for (let i = 0; i < items.length; i += batchSize) {
batches.push(items.slice(i, i + batchSize));
}
return batches;
}
private delay(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms));
}
private extractEncodingFromContentType(
contentType: string
): string | undefined {
const match = contentType.match(/charset=([^;]+)/i);
return match ? match[1].toLowerCase() : undefined;
}
private isTextBasedContent(contentType: string): boolean {
const textTypes = [
"text/html",
"text/plain",
"text/xml",
"text/markdown",
"application/xml",
"application/xhtml+xml",
"application/json",
"application/ld+json",
];
return textTypes.some((type) => contentType.toLowerCase().includes(type));
}
private getErrorCode(error: unknown): string {
if (error instanceof GeminiUrlValidationError) {
return "VALIDATION_ERROR";
}
if (error instanceof GeminiUrlFetchError) {
return error.statusCode ? `HTTP_${error.statusCode}` : "FETCH_ERROR";
}
return "UNKNOWN_ERROR";
}
}
```
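A minimal usage sketch for the URL context service above. The import paths, URLs, and option values are illustrative assumptions; only the method names, option fields, and return shapes come from the file itself.

```typescript
// Hypothetical caller; paths and config wiring are assumptions, not part of the file above.
import { ConfigurationManager } from "../../config/ConfigurationManager.js";
import { GeminiUrlContextService } from "./GeminiUrlContextService.js";

async function summarizeDocs(): Promise<void> {
  const config = ConfigurationManager.getInstance();
  const urlService = new GeminiUrlContextService(config);

  // Fetch a single page, converting HTML to markdown and capping the content size.
  const single = await urlService.fetchUrlContent("https://example.com/docs", {
    convertToMarkdown: true,
    maxContentLength: 100 * 1024, // 100 KB
  });
  console.log(single.metadata.title, single.metadata.responseTime);

  // Batch-process several URLs into Gemini Content parts plus a summary report.
  const { contents, batchResult } = await urlService.processUrlsForContext([
    "https://example.com/a",
    "https://example.com/b",
  ]);
  console.log(batchResult.summary.successCount, contents.length);
}
```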
--------------------------------------------------------------------------------
/src/services/gemini/GeminiChatService.ts:
--------------------------------------------------------------------------------
```typescript
import {
GoogleGenAI,
GenerateContentResponse,
HarmCategory,
HarmBlockThreshold,
} from "@google/genai";
import { v4 as uuidv4 } from "uuid";
import {
GeminiApiError,
ValidationError as GeminiValidationError,
} from "../../utils/errors.js";
import { logger } from "../../utils/logger.js";
import {
Content,
GenerationConfig,
SafetySetting,
Tool,
ToolConfig,
FunctionCall,
ChatSession,
ThinkingConfig,
} from "./GeminiTypes.js";
import { RouteMessageParams } from "../GeminiService.js";
import { validateRouteMessageParams } from "./GeminiValidationSchemas.js";
import { ZodError } from "zod";
/**
* Maps reasoningEffort string values to token budgets
*/
const REASONING_EFFORT_MAP: Record<string, number> = {
none: 0,
low: 1024, // 1K tokens
medium: 8192, // 8K tokens
high: 24576, // 24K tokens
};
/**
* Helper function to process thinkingConfig, mapping reasoningEffort to thinkingBudget if needed
* @param thinkingConfig The thinking configuration object to process
* @returns Processed thinking configuration
*/
function processThinkingConfig(
thinkingConfig?: ThinkingConfig
): ThinkingConfig | undefined {
if (!thinkingConfig) return undefined;
const processedConfig = { ...thinkingConfig };
// Map reasoningEffort to thinkingBudget if provided
if (
processedConfig.reasoningEffort &&
REASONING_EFFORT_MAP[processedConfig.reasoningEffort] !== undefined
) {
processedConfig.thinkingBudget =
REASONING_EFFORT_MAP[processedConfig.reasoningEffort];
logger.debug(
`Mapped reasoning effort '${processedConfig.reasoningEffort}' to thinking budget: ${processedConfig.thinkingBudget} tokens`
);
}
return processedConfig;
}
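// Example (illustrative): processThinkingConfig({ reasoningEffort: "low" }) returns
// { reasoningEffort: "low", thinkingBudget: 1024 }, since "low" maps to 1K tokens in
// REASONING_EFFORT_MAP above; note that any caller-supplied thinkingBudget is overwritten
// whenever a recognized reasoningEffort is present.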
/**
* Helper function to transform validated safety settings to use actual enum values
* @param safetySettings The validated safety settings from Zod
* @returns Safety settings with actual enum values
*/
function transformSafetySettings(
safetySettings?: Array<{ category: string; threshold: string }>
): SafetySetting[] | undefined {
if (!safetySettings) return undefined;
return safetySettings.map((setting) => ({
category: HarmCategory[setting.category as keyof typeof HarmCategory],
threshold:
HarmBlockThreshold[setting.threshold as keyof typeof HarmBlockThreshold],
}));
}
/**
* Interface for the parameters of the startChatSession method
*/
export interface StartChatParams {
modelName?: string;
history?: Content[];
generationConfig?: GenerationConfig;
safetySettings?: SafetySetting[];
tools?: Tool[];
systemInstruction?: Content | string;
cachedContentName?: string;
}
/**
* Interface for the parameters of the sendMessageToSession method
*/
export interface SendMessageParams {
sessionId: string;
message: string;
generationConfig?: GenerationConfig;
safetySettings?: SafetySetting[];
tools?: Tool[];
toolConfig?: ToolConfig;
cachedContentName?: string;
}
/**
* Interface for the parameters of the sendFunctionResultToSession method
*/
export interface SendFunctionResultParams {
sessionId: string;
functionResponse: string;
functionCall?: FunctionCall;
}
/**
* Service for handling chat session related operations for the Gemini service.
* Manages chat sessions, sending messages, and handling function responses.
*/
export class GeminiChatService {
private genAI: GoogleGenAI;
private defaultModelName?: string;
private chatSessions: Map<string, ChatSession> = new Map();
/**
* Creates a new instance of the GeminiChatService.
* @param genAI The GoogleGenAI instance to use for API calls
* @param defaultModelName Optional default model name to use if not specified in method calls
*/
constructor(genAI: GoogleGenAI, defaultModelName?: string) {
this.genAI = genAI;
this.defaultModelName = defaultModelName;
}
/**
* Starts a new stateful chat session with the Gemini model.
*
* @param params Parameters for starting a chat session
* @returns A unique session ID to identify this chat session
*/
public startChatSession(params: StartChatParams = {}): string {
const {
modelName,
history,
generationConfig,
safetySettings,
tools,
systemInstruction,
cachedContentName,
} = params;
const effectiveModelName = modelName ?? this.defaultModelName;
if (!effectiveModelName) {
throw new GeminiApiError(
"Model name must be provided either as a parameter or via the GOOGLE_GEMINI_MODEL environment variable."
);
}
// Process systemInstruction if it's a string
let formattedSystemInstruction: Content | undefined;
if (systemInstruction) {
if (typeof systemInstruction === "string") {
formattedSystemInstruction = {
parts: [{ text: systemInstruction }],
};
} else {
formattedSystemInstruction = systemInstruction;
}
}
try {
// Create the chat session using the models API
logger.debug(`Creating chat session with model: ${effectiveModelName}`);
// Create chat configuration for v0.10.0
const chatConfig: {
history?: Content[];
generationConfig?: GenerationConfig;
safetySettings?: SafetySetting[];
tools?: Tool[];
systemInstruction?: Content;
cachedContent?: string;
thinkingConfig?: ThinkingConfig;
} = {};
// Add optional parameters if provided
if (history && Array.isArray(history)) {
chatConfig.history = history;
}
if (generationConfig) {
chatConfig.generationConfig = generationConfig;
// Extract thinking config if it exists within generation config
if (generationConfig.thinkingConfig) {
chatConfig.thinkingConfig = processThinkingConfig(
generationConfig.thinkingConfig
);
}
}
if (safetySettings && Array.isArray(safetySettings)) {
chatConfig.safetySettings = safetySettings;
}
if (tools && Array.isArray(tools)) {
chatConfig.tools = tools;
}
if (formattedSystemInstruction) {
chatConfig.systemInstruction = formattedSystemInstruction;
}
if (cachedContentName) {
chatConfig.cachedContent = cachedContentName;
}
// Generate a unique session ID
const sessionId = uuidv4();
// Create a mock chat session for storing configuration
// In v0.10.0, we don't have direct chat session objects,
// but we'll store the configuration to use for future messages
this.chatSessions.set(sessionId, {
model: effectiveModelName,
config: chatConfig,
history: history || [],
});
logger.info(
`Chat session created: ${sessionId} using model ${effectiveModelName}`
);
return sessionId;
} catch (error: unknown) {
logger.error("Error creating chat session:", error);
throw new GeminiApiError(
`Failed to create chat session: ${(error as Error).message}`,
error
);
}
}
/**
* Sends a message to an existing chat session.
* Uses the generateContent API directly since we're managing chat state ourselves.
*
* @param params Parameters for sending a message
* @returns Promise resolving to the chat response
*/
public async sendMessageToSession(
params: SendMessageParams
): Promise<GenerateContentResponse> {
const {
sessionId,
message,
generationConfig,
safetySettings,
tools,
toolConfig,
cachedContentName,
} = params;
// Get the chat session
const session = this.chatSessions.get(sessionId);
if (!session) {
throw new GeminiApiError(`Chat session not found: ${sessionId}`);
}
// Create user content from the message
const userContent: Content = {
role: "user",
parts: [{ text: message }],
};
// Add the user message to the session history
session.history.push(userContent);
try {
// Prepare the request configuration
const requestConfig: {
model: string;
contents: Content[];
generationConfig?: GenerationConfig;
safetySettings?: SafetySetting[];
tools?: Tool[];
toolConfig?: ToolConfig;
systemInstruction?: Content;
cachedContent?: string;
thinkingConfig?: ThinkingConfig;
} = {
model: session.model,
contents: session.history,
};
// Add configuration from the original session configuration
if (session.config.systemInstruction) {
requestConfig.systemInstruction = session.config.systemInstruction;
}
// Override with any per-message configuration options
if (generationConfig) {
requestConfig.generationConfig = generationConfig;
// Extract thinking config if it exists within generation config
if (generationConfig.thinkingConfig) {
requestConfig.thinkingConfig = processThinkingConfig(
generationConfig.thinkingConfig
);
}
} else if (session.config.generationConfig) {
requestConfig.generationConfig = session.config.generationConfig;
// Use thinking config from session if available
if (session.config.thinkingConfig) {
requestConfig.thinkingConfig = processThinkingConfig(
session.config.thinkingConfig
);
}
}
if (safetySettings) {
requestConfig.safetySettings = safetySettings;
} else if (session.config.safetySettings) {
requestConfig.safetySettings = session.config.safetySettings;
}
if (tools) {
requestConfig.tools = tools;
} else if (session.config.tools) {
requestConfig.tools = session.config.tools;
}
if (toolConfig) {
requestConfig.toolConfig = toolConfig;
}
if (cachedContentName) {
requestConfig.cachedContent = cachedContentName;
} else if (session.config.cachedContent) {
requestConfig.cachedContent = session.config.cachedContent;
}
logger.debug(
`Sending message to session ${sessionId} using model ${session.model}`
);
// Call the generateContent API
const response = await this.genAI.models.generateContent(requestConfig);
// Process the response
if (response.candidates && response.candidates.length > 0) {
const assistantMessage = response.candidates[0].content;
if (assistantMessage) {
// Add the assistant response to the session history
session.history.push(assistantMessage);
}
}
return response;
} catch (error: unknown) {
logger.error(`Error sending message to session ${sessionId}:`, error);
throw new GeminiApiError(
`Failed to send message to session ${sessionId}: ${(error as Error).message}`,
error
);
}
}
/**
* Sends the result of a function call back to the chat session.
*
* @param params Parameters for sending a function result
* @returns Promise resolving to the chat response
*/
public async sendFunctionResultToSession(
params: SendFunctionResultParams
): Promise<GenerateContentResponse> {
const { sessionId, functionResponse, functionCall } = params;
// Get the chat session
const session = this.chatSessions.get(sessionId);
if (!session) {
throw new GeminiApiError(`Chat session not found: ${sessionId}`);
}
// Create function response message
const responseContent: Content = {
role: "function",
parts: [
{
functionResponse: {
name: functionCall?.name || "function",
response: { content: functionResponse },
},
},
],
};
// Add the function response to the session history
session.history.push(responseContent);
try {
// Prepare the request configuration
const requestConfig: {
model: string;
contents: Content[];
generationConfig?: GenerationConfig;
safetySettings?: SafetySetting[];
tools?: Tool[];
toolConfig?: ToolConfig;
systemInstruction?: Content;
cachedContent?: string;
thinkingConfig?: ThinkingConfig;
} = {
model: session.model,
contents: session.history,
};
// Add configuration from the session
if (session.config.systemInstruction) {
requestConfig.systemInstruction = session.config.systemInstruction;
}
if (session.config.generationConfig) {
requestConfig.generationConfig = session.config.generationConfig;
// Use thinking config from session if available
if (session.config.thinkingConfig) {
requestConfig.thinkingConfig = processThinkingConfig(
session.config.thinkingConfig
);
}
}
if (session.config.safetySettings) {
requestConfig.safetySettings = session.config.safetySettings;
}
if (session.config.tools) {
requestConfig.tools = session.config.tools;
}
if (session.config.cachedContent) {
requestConfig.cachedContent = session.config.cachedContent;
}
logger.debug(
`Sending function result to session ${sessionId} using model ${session.model}`
);
// Call the generateContent API directly
const response = await this.genAI.models.generateContent(requestConfig);
// Process the response
if (response.candidates && response.candidates.length > 0) {
const assistantMessage = response.candidates[0].content;
if (assistantMessage) {
// Add the assistant response to the session history
session.history.push(assistantMessage);
}
}
return response;
} catch (error: unknown) {
logger.error(
`Error sending function result to session ${sessionId}:`,
error
);
throw new GeminiApiError(
`Failed to send function result to session ${sessionId}: ${(error as Error).message}`,
error
);
}
}
/**
* Routes a message to the most appropriate model based on a routing prompt.
* Uses a two-step process:
* 1. First asks a routing model to determine which model to use
* 2. Then sends the original message to the chosen model
*
* @param params Parameters for routing a message
* @returns Promise resolving to the chat response from the chosen model
* @throws {GeminiApiError} If routing fails or all models are unavailable
*/
public async routeMessage(
params: RouteMessageParams
): Promise<{ response: GenerateContentResponse; chosenModel: string }> {
let validatedParams;
try {
// Validate all parameters using Zod schema
validatedParams = validateRouteMessageParams(params);
} catch (validationError) {
if (validationError instanceof ZodError) {
const fieldErrors = validationError.errors
.map((err) => `${err.path.join(".")}: ${err.message}`)
.join(", ");
throw new GeminiValidationError(
`Invalid parameters for message routing: ${fieldErrors}`,
validationError.errors[0]?.path.join(".")
);
}
throw validationError;
}
const {
message,
models,
routingPrompt,
defaultModel,
generationConfig,
safetySettings,
systemInstruction,
} = validatedParams;
try {
// Use either a specific routing prompt or a default one
const effectiveRoutingPrompt =
routingPrompt ||
`You are a sophisticated model router. Your task is to analyze the following message and determine which AI model would be best suited to handle it. Choose exactly one model from this list: ${models.join(", ")}. Respond with ONLY the name of the chosen model, nothing else.`;
// Step 1: Determine the appropriate model using routing prompt
// For routing decisions, we'll use a low temperature to ensure deterministic routing
const routingConfig = {
model: models[0], // Use the first model as the router by default
contents: [
{
role: "user",
parts: [
{
text: `${effectiveRoutingPrompt}\n\nUser message: "${message}"`,
},
],
},
],
generationConfig: {
temperature: 0.2,
maxOutputTokens: 20, // Keep it short, we just need the model name
...generationConfig,
},
safetySettings: transformSafetySettings(safetySettings),
};
// If system instruction is provided, add it to the routing request
if (systemInstruction) {
if (typeof systemInstruction === "string") {
routingConfig.contents.unshift({
role: "system" as const,
parts: [{ text: systemInstruction }],
});
} else {
const formattedInstruction = {
...systemInstruction,
role: systemInstruction.role || ("system" as const),
};
routingConfig.contents.unshift(
formattedInstruction as { role: string; parts: { text: string }[] }
);
}
}
logger.debug(`Routing message using model ${routingConfig.model}`);
// Send the routing request
const routingResponse =
await this.genAI.models.generateContent(routingConfig);
if (!routingResponse?.text) {
throw new GeminiApiError("Routing model did not return any text");
}
// Extract the chosen model from the routing response
// Normalize the text and check whether it mentions one of the candidate model names
const routingResponseText = routingResponse.text.trim();
const chosenModel =
models.find((model) => routingResponseText.includes(model)) ||
defaultModel;
if (!chosenModel) {
throw new GeminiApiError(
`Routing failed: couldn't identify a valid model from response "${routingResponseText}"`
);
}
logger.info(
`Routing complete. Selected model: ${chosenModel} for message`
);
// Step 2: Send the original message to the chosen model
const requestConfig: {
model: string;
contents: Content[];
generationConfig?: GenerationConfig;
safetySettings?: SafetySetting[];
thinkingConfig?: ThinkingConfig;
} = {
model: chosenModel,
contents: [
{
role: "user",
parts: [{ text: message }],
},
],
generationConfig: generationConfig,
safetySettings: transformSafetySettings(safetySettings),
};
// Extract thinking config if it exists within generation config
if (generationConfig?.thinkingConfig) {
requestConfig.thinkingConfig = processThinkingConfig(
generationConfig.thinkingConfig
);
}
// If system instruction is provided, add it to the final request
if (systemInstruction) {
if (typeof systemInstruction === "string") {
requestConfig.contents.unshift({
role: "system" as const,
parts: [{ text: systemInstruction }],
});
} else {
const formattedInstruction = {
...systemInstruction,
role: systemInstruction.role || ("system" as const),
};
requestConfig.contents.unshift(
formattedInstruction as { role: string; parts: { text: string }[] }
);
}
}
logger.debug(`Sending routed message to model ${chosenModel}`);
// Call the generateContent API with the chosen model
const response = await this.genAI.models.generateContent(requestConfig);
return {
response,
chosenModel,
};
} catch (error: unknown) {
logger.error(`Error routing message: ${(error as Error).message}`, error);
throw new GeminiApiError(
`Failed to route message: ${(error as Error).message}`,
error
);
}
}
}
```
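A brief, hypothetical usage sketch of the chat service above. The API key plumbing and the model name are assumptions for illustration; the constructor and method signatures are taken from the file.

```typescript
// Illustrative only: the API key handling and model name here are assumptions.
import { GoogleGenAI } from "@google/genai";
import { GeminiChatService } from "./GeminiChatService.js";

async function chatExample(apiKey: string): Promise<void> {
  const genAI = new GoogleGenAI({ apiKey });
  const chatService = new GeminiChatService(genAI, "gemini-2.0-flash");

  // Start a session with a system instruction, then send a message into it.
  const sessionId = chatService.startChatSession({
    systemInstruction: "You are a concise assistant.",
  });
  const response = await chatService.sendMessageToSession({
    sessionId,
    message: "Summarize the benefits of streaming APIs in two sentences.",
  });
  console.log(response.candidates?.[0]?.content?.parts?.[0]?.text);
}
```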
--------------------------------------------------------------------------------
/src/services/gemini/GeminiGitDiffService.ts:
--------------------------------------------------------------------------------
```typescript
import { GoogleGenAI } from "@google/genai";
import { logger } from "../../utils/logger.js";
import {
GeminiModelError,
GeminiValidationError,
mapGeminiError,
} from "../../utils/geminiErrors.js";
import {
Content,
GenerationConfig,
SafetySetting,
Tool,
} from "./GeminiTypes.js";
import gitdiffParser from "gitdiff-parser";
import micromatch from "micromatch";
import {
getReviewTemplate,
processTemplate,
getFocusInstructions,
} from "./GeminiPromptTemplates.js";
// Define interface for gitdiff-parser return type
interface GitDiffParserFile {
oldPath: string;
newPath: string;
oldRevision: string;
newRevision: string;
hunks: Array<{
content: string;
oldStart: number;
newStart: number;
oldLines: number;
newLines: number;
changes: Array<{
content: string;
type: "insert" | "delete" | "normal";
lineNumber?: number;
oldLineNumber?: number;
newLineNumber?: number;
}>;
}>;
isBinary?: boolean;
oldEndingNewLine?: boolean;
newEndingNewLine?: boolean;
oldMode?: string;
newMode?: string;
similarity?: number;
}
// Define our interface matching the original GoogleGenAI interface
interface GenerativeModel {
generateContent(options: { contents: Content[] }): Promise<{
response: {
text(): string;
};
}>;
generateContentStream(options: { contents: Content[] }): Promise<{
stream: AsyncGenerator<{
text(): string;
}>;
}>;
startChat(options?: {
history?: Content[];
generationConfig?: GenerationConfig;
safetySettings?: SafetySetting[];
tools?: Tool[];
systemInstruction?: Content;
cachedContent?: string;
}): {
sendMessage(text: string): Promise<{ response: { text(): string } }>;
sendMessageStream(
text: string
): Promise<{ stream: AsyncGenerator<{ text(): string }> }>;
getHistory(): Content[];
};
generateImages(params: {
prompt: string;
safetySettings?: SafetySetting[];
[key: string]: unknown;
}): Promise<{
images?: Array<{ data?: string; mimeType?: string }>;
promptSafetyMetadata?: {
blocked?: boolean;
safetyRatings?: Array<{ category: string; probability: string }>;
};
}>;
}
// Define interface for GoogleGenAI with getGenerativeModel method
interface ExtendedGoogleGenAI extends GoogleGenAI {
getGenerativeModel(options: {
model: string;
generationConfig?: GenerationConfig;
safetySettings?: SafetySetting[];
}): GenerativeModel;
}
/**
* Interface for parsed git diff files
*/
interface ParsedDiffFile {
oldPath: string;
newPath: string;
oldRevision: string;
newRevision: string;
hunks: Array<{
content: string;
oldStart: number;
newStart: number;
oldLines: number;
newLines: number;
changes: Array<{
content: string;
type: "insert" | "delete" | "normal";
lineNumber?: number;
oldLineNumber?: number;
newLineNumber?: number;
}>;
}>;
isBinary?: boolean;
type: "add" | "delete" | "modify" | "rename";
oldEndingNewLine?: boolean;
newEndingNewLine?: boolean;
oldMode?: string;
newMode?: string;
similarity?: number;
}
/**
* Options for processing git diffs
*/
interface DiffProcessingOptions {
maxFilesToInclude?: number;
excludePatterns?: string[];
prioritizeFiles?: string[];
includeContextLines?: number;
maxDiffSize?: number;
}
/**
* Parameters for reviewing git diffs
*/
export interface GitDiffReviewParams {
diffContent: string;
modelName?: string;
reviewFocus?:
| "security"
| "performance"
| "architecture"
| "bugs"
| "general";
repositoryContext?: string;
diffOptions?: DiffProcessingOptions;
generationConfig?: GenerationConfig;
safetySettings?: SafetySetting[];
systemInstruction?: Content | string;
reasoningEffort?: "none" | "low" | "medium" | "high";
customPrompt?: string;
}
/**
* Service for processing and analyzing git diffs using Gemini models
*/
export class GeminiGitDiffService {
private genAI: ExtendedGoogleGenAI;
private defaultModelName?: string;
private maxDiffSizeBytes: number;
private defaultExcludePatterns: string[];
/**
* Creates a new instance of GeminiGitDiffService
*
* @param genAI The GoogleGenAI instance
* @param defaultModelName Optional default model name
* @param maxDiffSizeBytes Maximum allowed size for diff content in bytes
* @param defaultExcludePatterns Default patterns to exclude from diff analysis
* @param defaultThinkingBudget Optional default thinking budget in tokens (0-24576)
*/
constructor(
genAI: ExtendedGoogleGenAI,
defaultModelName?: string,
maxDiffSizeBytes: number = 1024 * 1024, // 1MB default
defaultExcludePatterns: string[] = [
"package-lock.json",
"yarn.lock",
"*.min.js",
"*.min.css",
"node_modules/**",
"dist/**",
"build/**",
"*.lock",
"**/*.map",
],
private defaultThinkingBudget?: number
) {
this.genAI = genAI;
this.defaultModelName = defaultModelName;
this.maxDiffSizeBytes = maxDiffSizeBytes;
this.defaultExcludePatterns = defaultExcludePatterns;
}
/**
* Parse raw git diff content into a structured format using gitdiff-parser
*
* @param diffContent Raw git diff content as string
* @returns Array of parsed diff files with additional type information
* @throws GeminiValidationError if diff parsing fails
*/
private async parseGitDiff(diffContent: string): Promise<ParsedDiffFile[]> {
try {
// Check diff size limits
if (diffContent.length > this.maxDiffSizeBytes) {
throw new GeminiValidationError(
`Diff content exceeds maximum size (${this.maxDiffSizeBytes} bytes)`,
"diffContent"
);
}
// Parse using gitdiff-parser
// The gitdiff-parser module doesn't export types properly, but we know its structure
const parsedFiles = (
gitdiffParser as { parse: (diffStr: string) => GitDiffParserFile[] }
).parse(diffContent);
// Extend with additional type information
return parsedFiles.map((file) => {
// Determine file type based on paths and changes
let type: "add" | "delete" | "modify" | "rename" = "modify";
if (file.oldPath === "/dev/null") {
type = "add";
} else if (file.newPath === "/dev/null") {
type = "delete";
} else if (file.oldPath !== file.newPath) {
type = "rename";
} else {
type = "modify";
}
return {
...file,
type,
};
});
} catch (error: unknown) {
if (error instanceof GeminiValidationError) {
throw error;
}
logger.error("Failed to parse git diff:", error);
throw new GeminiValidationError(
"Failed to parse git diff content. Ensure it's valid output from git diff.",
"diffContent"
);
}
}
/**
* Prioritize and filter diff content based on importance using micromatch
*
* @param parsedDiff Array of parsed diff files
* @param options Options for prioritization and filtering
* @returns Filtered and prioritized diff files
*/
private filterAndPrioritizeDiff(
parsedDiff: ParsedDiffFile[],
options: DiffProcessingOptions = {}
): ParsedDiffFile[] {
let result = [...parsedDiff];
// Apply exclude patterns
const excludePatterns = [...this.defaultExcludePatterns];
if (options.excludePatterns && options.excludePatterns.length > 0) {
excludePatterns.push(...options.excludePatterns);
}
if (excludePatterns.length > 0) {
// Use micromatch for glob pattern matching
result = result.filter((file) => {
// For each file path, check if it matches any exclude pattern
return !micromatch.isMatch(file.newPath, excludePatterns);
});
}
// Apply priority patterns if specified
if (options.prioritizeFiles && options.prioritizeFiles.length > 0) {
// Score files based on prioritization patterns
const scoredFiles = result.map((file) => {
// Calculate a priority score based on matching patterns
// Higher score = higher priority
const priorityScore = options.prioritizeFiles!.reduce(
(score, pattern) => {
// If file path matches the pattern, increase its score
if (micromatch.isMatch(file.newPath, pattern)) {
return score + 1;
}
return score;
},
0
);
return { file, priorityScore };
});
// Sort by priority score (descending)
scoredFiles.sort((a, b) => b.priorityScore - a.priorityScore);
// Extract the sorted files
result = scoredFiles.map((item) => item.file);
}
// Filter to max files if specified
if (
options.maxFilesToInclude &&
options.maxFilesToInclude > 0 &&
result.length > options.maxFilesToInclude
) {
// Take only the specified number of files (already sorted by priority if applicable)
result = result.slice(0, options.maxFilesToInclude);
}
return result;
}
/**
* Generate a review prompt for the Gemini model based on the processed diff
*
* @param parsedDiff Processed diff files
* @param repositoryContext Optional context about the repository
* @param reviewFocus Optional focus area for the review
* @returns Formatted prompt string
*/
private generateReviewPrompt(
parsedDiff: ParsedDiffFile[],
repositoryContext?: string,
reviewFocus:
| "security"
| "performance"
| "architecture"
| "bugs"
| "general" = "general"
): string {
// Create file summary
const fileSummary = parsedDiff
.map((file) => {
const hunksCount = file.hunks.length;
const addedLines = file.hunks.reduce((count, hunk) => {
return (
count +
hunk.changes.filter((change) => change.type === "insert").length
);
}, 0);
const removedLines = file.hunks.reduce((count, hunk) => {
return (
count +
hunk.changes.filter((change) => change.type === "delete").length
);
}, 0);
return `- ${file.newPath}: ${hunksCount} chunk(s), +${addedLines} -${removedLines} lines`;
})
.join("\n");
// Generate diff content with context
let diffContent = "";
for (const file of parsedDiff) {
diffContent += `\n\nFile: ${file.newPath}\n`;
for (const hunk of file.hunks) {
diffContent += `@@ -${hunk.oldStart},${hunk.oldLines} +${hunk.newStart},${hunk.newLines} @@\n`;
diffContent += hunk.changes
.map((change) => {
if (change.type === "insert") {
return `+${change.content}`;
} else if (change.type === "delete") {
return `-${change.content}`;
} else {
return ` ${change.content}`;
}
})
.join("\n");
}
}
// Format repository context if provided
const formattedContext = repositoryContext
? `Repository context:\n${repositoryContext}`
: "";
// Include file summary in repository context
const fullContext = formattedContext
? `${formattedContext}\n\nSummary of changes:\n${fileSummary}`
: `Summary of changes:\n${fileSummary}`;
// Get the appropriate template based on review focus
const template = getReviewTemplate(reviewFocus);
// Process the template with the context and diff content
return processTemplate(template, {
repositoryContext: fullContext,
diffContent,
focusInstructions: getFocusInstructions(reviewFocus),
});
}
/**
* Review a git diff and generate analysis using Gemini models
*
* @param params Parameters for the review operation
* @returns Promise resolving to review text
*/
public async reviewDiff(params: GitDiffReviewParams): Promise<string> {
try {
const {
diffContent,
modelName,
reviewFocus = "general",
repositoryContext,
diffOptions = {},
generationConfig = {},
safetySettings,
systemInstruction,
reasoningEffort = "medium",
customPrompt,
} = params;
// Validate input
if (!diffContent || diffContent.trim().length === 0) {
throw new GeminiValidationError(
"Diff content is required",
"diffContent"
);
}
// Parse the diff
const parsedDiff = await this.parseGitDiff(diffContent);
// Filter and prioritize diff content
const processedDiff = this.filterAndPrioritizeDiff(
parsedDiff,
diffOptions
);
if (processedDiff.length === 0) {
return "No files to review after applying filters.";
}
// Generate the review prompt
let prompt: string;
if (customPrompt) {
// Use custom prompt if provided
prompt = customPrompt;
// Add the diff content to the custom prompt
prompt += `\n\nAnalyze the following git diff:\n\`\`\`diff\n`;
// Format diff content for the prompt
for (const file of processedDiff) {
prompt += `\n\nFile: ${file.newPath}\n`;
for (const hunk of file.hunks) {
prompt += `@@ -${hunk.oldStart},${hunk.oldLines} +${hunk.newStart},${hunk.newLines} @@\n`;
prompt += hunk.changes
.map((change) => {
if (change.type === "insert") {
return `+${change.content}`;
} else if (change.type === "delete") {
return `-${change.content}`;
} else {
return ` ${change.content}`;
}
})
.join("\n");
}
}
prompt += `\n\`\`\``;
} else {
// Use the standard prompt generator
prompt = this.generateReviewPrompt(
processedDiff,
repositoryContext,
reviewFocus
);
}
// Select the model to use
const effectiveModelName =
modelName || this.defaultModelName || "gemini-2.0-flash"; // Use the cheaper Gemini 2.0 Flash model as the default
// Map reasoning effort to thinking budget
let thinkingBudget: number | undefined;
switch (reasoningEffort) {
case "none":
thinkingBudget = 0;
break;
case "low":
thinkingBudget = 2048;
break;
case "medium":
thinkingBudget = 4096;
break;
case "high":
thinkingBudget = 8192;
break;
default:
thinkingBudget = this.defaultThinkingBudget;
}
// Update generation config with thinking budget if specified
const updatedGenerationConfig = {
...generationConfig,
};
if (thinkingBudget !== undefined) {
updatedGenerationConfig.thinkingBudget = thinkingBudget;
}
// Get model instance
const model = this.genAI.getGenerativeModel({
model: effectiveModelName,
generationConfig: updatedGenerationConfig,
safetySettings,
});
// Create the content parts with system instructions if provided
const contentParts: Content[] = [];
if (systemInstruction) {
if (typeof systemInstruction === "string") {
contentParts.push({
role: "system",
parts: [{ text: systemInstruction }],
});
} else {
contentParts.push(systemInstruction);
}
}
contentParts.push({
role: "user",
parts: [{ text: prompt }],
});
// Generate content
const result = await model.generateContent({
contents: contentParts,
});
// Extract text from response
if (!result.response.text()) {
throw new GeminiModelError(
"Model returned empty response",
effectiveModelName
);
}
return result.response.text();
} catch (error: unknown) {
logger.error("Error reviewing git diff:", error);
throw mapGeminiError(error, "reviewGitDiff");
}
}
/**
* Stream review content for a git diff
*
* @param params Parameters for the review operation
* @returns AsyncGenerator yielding review content chunks
*/
public async *reviewDiffStream(
params: GitDiffReviewParams
): AsyncGenerator<string> {
try {
const {
diffContent,
modelName,
reviewFocus = "general",
repositoryContext,
diffOptions = {},
generationConfig = {},
safetySettings,
systemInstruction,
reasoningEffort = "medium",
customPrompt,
} = params;
// Validate input
if (!diffContent || diffContent.trim().length === 0) {
throw new GeminiValidationError(
"Diff content is required",
"diffContent"
);
}
// Parse the diff
const parsedDiff = await this.parseGitDiff(diffContent);
// Filter and prioritize diff content
const processedDiff = this.filterAndPrioritizeDiff(
parsedDiff,
diffOptions
);
if (processedDiff.length === 0) {
yield "No files to review after applying filters.";
return;
}
// Generate the review prompt
let prompt: string;
if (customPrompt) {
// Use custom prompt if provided
prompt = customPrompt;
// Add the diff content to the custom prompt
prompt += `\n\nAnalyze the following git diff:\n\`\`\`diff\n`;
// Format diff content for the prompt
for (const file of processedDiff) {
prompt += `\n\nFile: ${file.newPath}\n`;
for (const hunk of file.hunks) {
prompt += `@@ -${hunk.oldStart},${hunk.oldLines} +${hunk.newStart},${hunk.newLines} @@\n`;
prompt += hunk.changes
.map((change) => {
if (change.type === "insert") {
return `+${change.content}`;
} else if (change.type === "delete") {
return `-${change.content}`;
} else {
return ` ${change.content}`;
}
})
.join("\n");
}
}
prompt += `\n\`\`\``;
} else {
// Use the standard prompt generator
prompt = this.generateReviewPrompt(
processedDiff,
repositoryContext,
reviewFocus
);
}
// Select the model to use
const effectiveModelName =
modelName || this.defaultModelName || "gemini-2.0-flash"; // Use the cheaper Gemini 2.0 Flash model as the default
// Map reasoning effort to thinking budget
let thinkingBudget: number | undefined;
switch (reasoningEffort) {
case "none":
thinkingBudget = 0;
break;
case "low":
thinkingBudget = 2048;
break;
case "medium":
thinkingBudget = 4096;
break;
case "high":
thinkingBudget = 8192;
break;
default:
thinkingBudget = this.defaultThinkingBudget;
}
// Update generation config with thinking budget if specified
const updatedGenerationConfig = {
...generationConfig,
};
if (thinkingBudget !== undefined) {
updatedGenerationConfig.thinkingBudget = thinkingBudget;
}
// Get model instance
const model = this.genAI.getGenerativeModel({
model: effectiveModelName,
generationConfig: updatedGenerationConfig,
safetySettings,
});
// Create the content parts with system instructions if provided
const contentParts: Content[] = [];
if (systemInstruction) {
if (typeof systemInstruction === "string") {
contentParts.push({
role: "system",
parts: [{ text: systemInstruction }],
});
} else {
contentParts.push(systemInstruction);
}
}
contentParts.push({
role: "user",
parts: [{ text: prompt }],
});
// Generate content with streaming
const result = await model.generateContentStream({
contents: contentParts,
});
// Stream chunks
for await (const chunk of result.stream) {
const chunkText = chunk.text();
if (chunkText) {
yield chunkText;
}
}
} catch (error: unknown) {
logger.error("Error streaming git diff review:", error);
throw mapGeminiError(error, "reviewGitDiffStream");
}
}
}
```
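A hedged usage sketch for the diff review service above. The `genAI` client and the raw diff text are placeholders supplied by the caller; the one-shot and streaming calls mirror the `reviewDiff` and `reviewDiffStream` methods defined in the file.

```typescript
// Illustrative only: genAI must satisfy the ExtendedGoogleGenAI shape used above
// (a GoogleGenAI client exposing getGenerativeModel); its construction is out of scope here.
import { GeminiGitDiffService } from "./GeminiGitDiffService.js";

async function reviewExample(
  genAI: ConstructorParameters<typeof GeminiGitDiffService>[0],
  rawDiff: string
): Promise<void> {
  const diffService = new GeminiGitDiffService(genAI, "gemini-2.0-flash");

  // One-shot review focused on security issues in the diff.
  const review = await diffService.reviewDiff({
    diffContent: rawDiff,
    reviewFocus: "security",
    reasoningEffort: "low",
  });
  console.log(review);

  // Streaming variant: chunks are printed as the model produces them.
  for await (const chunk of diffService.reviewDiffStream({
    diffContent: rawDiff,
  })) {
    process.stdout.write(chunk);
  }
}
```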
--------------------------------------------------------------------------------
/src/services/gemini/GitHubApiService.ts:
--------------------------------------------------------------------------------
```typescript
import { Octokit } from "@octokit/rest";
import { graphql } from "@octokit/graphql";
import { RequestError } from "@octokit/request-error";
import { logger } from "../../utils/logger.js";
import { ConfigurationManager } from "../../config/ConfigurationManager.js";
import { GeminiValidationError } from "../../utils/geminiErrors.js";
import { GitHubUrlParser } from "./GitHubUrlParser.js";
import KeyV from "keyv";
/**
* Interface for repository content
*/
export interface RepoContent {
name: string;
path: string;
content: string;
type: "file" | "dir" | "symlink";
size: number;
sha: string;
url: string;
html_url: string;
}
/**
* Interface for a pull request
*/
export interface PullRequest {
number: number;
title: string;
body: string;
state: string;
head: {
ref: string;
sha: string;
repo: {
full_name: string;
};
};
base: {
ref: string;
sha: string;
repo: {
full_name: string;
};
};
user: {
login: string;
};
html_url: string;
created_at: string;
updated_at: string;
merged_at: string | null;
mergeable: boolean | null;
mergeable_state: string;
changed_files: number;
additions: number;
deletions: number;
}
/**
* Interface for a PR file
*/
export interface PrFile {
filename: string;
status: string;
additions: number;
deletions: number;
changes: number;
patch?: string;
contents_url: string;
}
/**
* Interface for cache configuration
*/
interface CacheConfig {
enabled: boolean;
ttl: number; // Time-to-live in seconds
}
/**
* Service for interacting with the GitHub API
* Provides methods for fetching repository content, PR information, and diffs
*/
export class GitHubApiService {
private octokit: Octokit;
private graphqlWithAuth: typeof graphql;
private cache: KeyV;
private cacheConfig: CacheConfig;
private rateLimitRemaining: number = 5000; // Default for authenticated users
private rateLimitResetTime: Date = new Date();
private requestCount: number = 0;
/**
* Creates a new instance of GitHubApiService
* @param apiToken Optional GitHub API token, will use token from ConfigurationManager if not provided
* @param cacheEnabled Whether to enable caching (default: true)
* @param cacheTtl Time-to-live for cache entries in seconds (default: 3600 = 1 hour)
*/
constructor(
apiToken?: string,
cacheEnabled: boolean = true,
cacheTtl: number = 3600
) {
// Get token from ConfigurationManager if not provided
if (!apiToken) {
const configManager = ConfigurationManager.getInstance();
apiToken = configManager.getGitHubApiToken();
if (!apiToken) {
logger.warn(
"GitHub API token not provided. Some operations may be rate-limited or fail for private repositories."
);
}
}
// Initialize Octokit
this.octokit = new Octokit({
auth: apiToken,
});
// Initialize GraphQL with auth
this.graphqlWithAuth = graphql.defaults({
headers: {
authorization: `token ${apiToken}`,
},
});
// Configure caching
this.cacheConfig = {
enabled: cacheEnabled,
ttl: cacheTtl,
};
// Initialize cache
this.cache = new KeyV({
namespace: "github-api-cache",
ttl: cacheTtl * 1000, // Convert to milliseconds
});
// Check the rate limit initially
this.checkRateLimit().catch((error) => {
logger.warn("Failed to check initial rate limit", { error });
});
}
/**
* Check the current rate limit status
* @returns Promise resolving to the rate limit info
*/
public async checkRateLimit(): Promise<{
limit: number;
remaining: number;
resetDate: Date;
}> {
try {
const response = await this.octokit.rateLimit.get();
const { limit, remaining, reset } = response.data.resources.core;
this.rateLimitRemaining = remaining;
this.rateLimitResetTime = new Date(reset * 1000);
// Log warning if rate limit is getting low
if (remaining < limit * 0.2) {
logger.warn(
`GitHub API rate limit is getting low: ${remaining}/${limit} remaining, resets at ${this.rateLimitResetTime.toISOString()}`
);
}
return {
limit,
remaining,
resetDate: this.rateLimitResetTime,
};
} catch (error: unknown) {
logger.error("Failed to check rate limit", { error });
throw new Error("Failed to check GitHub API rate limit");
}
}
/**
* Check if we can make a request, considering rate limits
* @throws Error if rate limit is exceeded
*/
private async checkBeforeRequest(): Promise<void> {
this.requestCount++;
// Periodically check the rate limit (every 20 requests)
if (this.requestCount % 20 === 0) {
await this.checkRateLimit();
}
// Check if we're close to the rate limit
if (this.rateLimitRemaining < 10) {
const now = new Date();
const minutesUntilReset = Math.ceil(
(this.rateLimitResetTime.getTime() - now.getTime()) / (60 * 1000)
);
throw new Error(
`GitHub API rate limit nearly exceeded. ${this.rateLimitRemaining} requests remaining. Resets in ${minutesUntilReset} minutes.`
);
}
}
/**
* Get the cached value or fetch it if not in cache
* @param cacheKey The cache key
* @param fetchFn Function to fetch the value if not in cache
* @returns The cached or freshly fetched value
*/
private async getCachedOrFetch<T>(
cacheKey: string,
fetchFn: () => Promise<T>
): Promise<T> {
if (this.cacheConfig.enabled) {
// Try to get from cache
const cachedValue = await this.cache.get(cacheKey);
if (cachedValue !== undefined) {
logger.debug(`Cache hit for ${cacheKey}`);
return cachedValue as T;
}
}
// Not in cache or caching disabled, fetch fresh data
logger.debug(`Cache miss for ${cacheKey}, fetching fresh data`);
const freshValue = await fetchFn();
// Store in cache if enabled
if (this.cacheConfig.enabled) {
await this.cache.set(cacheKey, freshValue);
}
return freshValue;
}
/**
* Get the contents of a file in a repository
* @param owner Repository owner
* @param repo Repository name
* @param path Path to the file
* @param ref Optional reference (branch, tag, or commit SHA)
* @returns Promise resolving to the file content
*/
public async getFileContent(
owner: string,
repo: string,
path: string,
ref?: string
): Promise<string> {
const cacheKey = `file:${owner}/${repo}/${path}${ref ? `@${ref}` : ""}`;
return this.getCachedOrFetch(cacheKey, async () => {
await this.checkBeforeRequest();
try {
const response = await this.octokit.repos.getContent({
owner,
repo,
path,
ref,
});
// Handle directory case
if (Array.isArray(response.data)) {
throw new Error(`Path ${path} is a directory, not a file`);
}
// Handle file case
const fileData = response.data as {
type: string;
content?: string;
encoding?: string;
};
if (fileData.type !== "file" || !fileData.content) {
throw new Error(`Path ${path} is not a file or has no content`);
}
// Decode content (usually base64)
if (fileData.encoding === "base64") {
return Buffer.from(fileData.content, "base64").toString("utf-8");
}
return fileData.content;
} catch (error: unknown) {
if (error instanceof RequestError && error.status === 404) {
throw new GeminiValidationError(
`File not found: ${path} in ${owner}/${repo}`,
"path"
);
}
logger.error("Error fetching file content", { error });
throw new Error(
`Failed to fetch file content for ${path} in ${owner}/${repo}`
);
}
});
}
/**
* List files in a repository directory
* @param owner Repository owner
* @param repo Repository name
* @param path Path to the directory
* @param ref Optional reference (branch, tag, or commit SHA)
* @returns Promise resolving to an array of repository content items
*/
public async listDirectory(
owner: string,
repo: string,
path: string = "",
ref?: string
): Promise<RepoContent[]> {
const cacheKey = `dir:${owner}/${repo}/${path}${ref ? `@${ref}` : ""}`;
return this.getCachedOrFetch(cacheKey, async () => {
await this.checkBeforeRequest();
try {
const response = await this.octokit.repos.getContent({
owner,
repo,
path,
ref,
});
// Handle file case (should be a directory)
if (!Array.isArray(response.data)) {
throw new Error(`Path ${path} is a file, not a directory`);
}
// Map to standardized structure and ensure html_url is never null
return response.data.map((item) => ({
name: item.name,
path: item.path,
content: "",
type: item.type as "file" | "dir" | "symlink",
size: item.size,
sha: item.sha,
url: item.url,
html_url: item.html_url || "", // Convert null to empty string
}));
} catch (error: unknown) {
if (error instanceof RequestError && error.status === 404) {
throw new GeminiValidationError(
`Directory not found: ${path} in ${owner}/${repo}`,
"path"
);
}
logger.error("Error listing directory", { error });
throw new Error(
`Failed to list directory for ${path} in ${owner}/${repo}`
);
}
});
}
/**
* Get Pull Request details
* @param owner Repository owner
* @param repo Repository name
* @param prNumber Pull request number
* @returns Promise resolving to pull request details
*/
public async getPullRequest(
owner: string,
repo: string,
prNumber: number
): Promise<PullRequest> {
const cacheKey = `pr:${owner}/${repo}/${prNumber}`;
return this.getCachedOrFetch(cacheKey, async () => {
await this.checkBeforeRequest();
try {
const response = await this.octokit.pulls.get({
owner,
repo,
pull_number: prNumber,
});
return response.data as PullRequest;
} catch (error: unknown) {
if (error instanceof RequestError && error.status === 404) {
throw new GeminiValidationError(
`Pull request not found: #${prNumber} in ${owner}/${repo}`,
"prNumber"
);
}
logger.error("Error fetching pull request", { error });
throw new Error(
`Failed to fetch pull request #${prNumber} from ${owner}/${repo}`
);
}
});
}
/**
* Get files changed in a Pull Request
* @param owner Repository owner
* @param repo Repository name
* @param prNumber Pull request number
* @returns Promise resolving to an array of changed files
*/
public async getPullRequestFiles(
owner: string,
repo: string,
prNumber: number
): Promise<PrFile[]> {
const cacheKey = `pr-files:${owner}/${repo}/${prNumber}`;
return this.getCachedOrFetch(cacheKey, async () => {
await this.checkBeforeRequest();
try {
const response = await this.octokit.pulls.listFiles({
owner,
repo,
pull_number: prNumber,
per_page: 100, // Get up to 100 files per page
});
return response.data as PrFile[];
} catch (error: unknown) {
if (error instanceof RequestError && error.status === 404) {
throw new GeminiValidationError(
`Pull request not found: #${prNumber} in ${owner}/${repo}`,
"prNumber"
);
}
logger.error("Error fetching pull request files", { error });
throw new Error(
`Failed to fetch files for PR #${prNumber} from ${owner}/${repo}`
);
}
});
}
/**
* Get the git diff for a Pull Request
* @param owner Repository owner
* @param repo Repository name
* @param prNumber Pull request number
* @returns Promise resolving to the PR diff as a string
*/
public async getPullRequestDiff(
owner: string,
repo: string,
prNumber: number
): Promise<string> {
const cacheKey = `pr-diff:${owner}/${repo}/${prNumber}`;
return this.getCachedOrFetch(cacheKey, async () => {
await this.checkBeforeRequest();
try {
// Get the diff directly using the GitHub API's raw format
const response = await this.octokit.request(
`GET /repos/{owner}/{repo}/pulls/{pull_number}`,
{
owner,
repo,
pull_number: prNumber,
headers: {
accept: "application/vnd.github.v3.diff",
},
}
);
// The API returns a diff as text when using the diff content type
return String(response.data);
} catch (error: unknown) {
if (error instanceof RequestError && error.status === 404) {
throw new GeminiValidationError(
`Pull request not found: #${prNumber} in ${owner}/${repo}`,
"prNumber"
);
}
logger.error("Error fetching pull request diff", { error });
throw new Error(
`Failed to fetch diff for PR #${prNumber} from ${owner}/${repo}`
);
}
});
}
/**
   * Get the name of the repository's default branch
* @param owner Repository owner
* @param repo Repository name
* @returns Promise resolving to the default branch name
*/
public async getDefaultBranch(owner: string, repo: string): Promise<string> {
const cacheKey = `default-branch:${owner}/${repo}`;
return this.getCachedOrFetch(cacheKey, async () => {
await this.checkBeforeRequest();
try {
const response = await this.octokit.repos.get({
owner,
repo,
});
return response.data.default_branch;
} catch (error: unknown) {
if (error instanceof RequestError && error.status === 404) {
throw new GeminiValidationError(
`Repository not found: ${owner}/${repo}`,
"repo"
);
}
logger.error("Error fetching repository info", { error });
throw new Error(
`Failed to fetch repository information for ${owner}/${repo}`
);
}
});
}
/**
   * Parse a GitHub URL into repository information
   * @param githubUrl GitHub URL (repo, branch, PR, issue, etc.)
   * @returns Promise resolving to the owner, repo, URL type, and any branch/PR/issue identifiers
*/
public async getRepositoryInfoFromUrl(githubUrl: string): Promise<{
owner: string;
repo: string;
type: string;
branch?: string;
prNumber?: number;
issueNumber?: number;
}> {
// Parse the GitHub URL
const parsedUrl = GitHubUrlParser.parse(githubUrl);
if (!parsedUrl) {
throw new GeminiValidationError(
`Invalid GitHub URL: ${githubUrl}`,
"githubUrl"
);
}
const { owner, repo, type } = parsedUrl;
const result: {
owner: string;
repo: string;
type: string;
branch?: string;
prNumber?: number;
issueNumber?: number;
} = { owner, repo, type };
// Add type-specific information
if (parsedUrl.branch) {
result.branch = parsedUrl.branch;
} else if (parsedUrl.prNumber) {
result.prNumber = parseInt(parsedUrl.prNumber, 10);
} else if (parsedUrl.issueNumber) {
result.issueNumber = parseInt(parsedUrl.issueNumber, 10);
}
return result;
}
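  // Illustrative example (hypothetical URL): parsing
  // "https://github.com/octocat/Hello-World/pull/42" resolves to
  // { owner: "octocat", repo: "Hello-World", prNumber: 42 } plus whatever
  // `type` string GitHubUrlParser.parse reports for pull request URLs.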
/**
   * Get a high-level repository overview using GraphQL for more efficient querying
* @param owner Repository owner
* @param repo Repository name
* @returns Promise resolving to repository information
*/
public async getRepositoryOverview(
owner: string,
repo: string
): Promise<{
name: string;
description: string;
defaultBranch: string;
language: string;
languages: Array<{ name: string; percentage: number }>;
stars: number;
forks: number;
openIssues: number;
openPRs: number;
lastUpdated: string;
}> {
const cacheKey = `repo-overview:${owner}/${repo}`;
return this.getCachedOrFetch(cacheKey, async () => {
await this.checkBeforeRequest();
try {
// Define the expected type of the GraphQL result
interface GraphQLRepoResult {
repository: {
name: string;
description: string | null;
defaultBranchRef: {
name: string;
};
primaryLanguage: {
name: string;
} | null;
languages: {
edges: Array<{
node: {
name: string;
};
size: number;
}>;
totalSize: number;
};
stargazerCount: number;
forkCount: number;
issues: {
totalCount: number;
};
pullRequests: {
totalCount: number;
};
updatedAt: string;
};
}
const result = await this.graphqlWithAuth<GraphQLRepoResult>(
`
query getRepoOverview($owner: String!, $repo: String!) {
repository(owner: $owner, name: $repo) {
name
description
defaultBranchRef {
name
}
primaryLanguage {
name
}
languages(first: 10, orderBy: {field: SIZE, direction: DESC}) {
edges {
node {
name
}
size
}
totalSize
}
stargazerCount
forkCount
issues(states: OPEN) {
totalCount
}
pullRequests(states: OPEN) {
totalCount
}
updatedAt
}
}
`,
{
owner,
repo,
}
);
// Process languages data
const totalSize = result.repository.languages.totalSize;
const languages = result.repository.languages.edges.map((edge) => ({
name: edge.node.name,
percentage: Math.round((edge.size / totalSize) * 100),
}));
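        // Worked example: with totalSize 200000 and a TypeScript edge of size
        // 150000, the resulting entry is { name: "TypeScript", percentage: 75 }.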
return {
name: result.repository.name,
description: result.repository.description || "",
defaultBranch: result.repository.defaultBranchRef.name,
language: result.repository.primaryLanguage?.name || "Unknown",
languages,
stars: result.repository.stargazerCount,
forks: result.repository.forkCount,
openIssues: result.repository.issues.totalCount,
openPRs: result.repository.pullRequests.totalCount,
lastUpdated: result.repository.updatedAt,
};
} catch (error: unknown) {
if (error instanceof RequestError && error.status === 404) {
throw new GeminiValidationError(
`Repository not found: ${owner}/${repo}`,
"repo"
);
}
logger.error("Error fetching repository overview", { error });
throw new Error(
`Failed to fetch repository overview for ${owner}/${repo}`
);
}
});
}
/**
* Get a combined diff from comparing two branches
* @param owner Repository owner
* @param repo Repository name
* @param baseBranch Base branch name
* @param headBranch Head branch name
* @returns Promise resolving to the diff as a string
*/
public async getComparisonDiff(
owner: string,
repo: string,
baseBranch: string,
headBranch: string
): Promise<string> {
const cacheKey = `comparison-diff:${owner}/${repo}/${baseBranch}...${headBranch}`;
return this.getCachedOrFetch(cacheKey, async () => {
await this.checkBeforeRequest();
try {
// Get the diff using the comparison API with diff format
const response = await this.octokit.request(
`GET /repos/{owner}/{repo}/compare/{basehead}`,
{
owner,
repo,
basehead: `${baseBranch}...${headBranch}`,
headers: {
accept: "application/vnd.github.v3.diff",
},
}
);
// The API returns a diff as text when using the diff content type
return String(response.data);
} catch (error: unknown) {
if (error instanceof RequestError) {
if (error.status === 404) {
throw new GeminiValidationError(
`Repository or branches not found: ${owner}/${repo} ${baseBranch}...${headBranch}`,
"branches"
);
}
          // A 422 response means the branches have no common history
if (error.status === 422) {
throw new GeminiValidationError(
`Cannot compare branches: ${baseBranch} and ${headBranch} don't have common history`,
"branches"
);
}
}
logger.error("Error fetching comparison diff", { error });
throw new Error(
`Failed to fetch comparison diff for ${baseBranch}...${headBranch} in ${owner}/${repo}`
);
}
});
}
/**
* Invalidate a cache entry manually
* @param cacheKey The key to invalidate
*/
public async invalidateCache(cacheKey: string): Promise<void> {
if (this.cacheConfig.enabled) {
await this.cache.delete(cacheKey);
logger.debug(`Invalidated cache for ${cacheKey}`);
}
}
/**
* Clear the entire cache
*/
public async clearCache(): Promise<void> {
if (this.cacheConfig.enabled) {
await this.cache.clear();
logger.info("Cleared GitHub API cache");
}
}
}
```
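For orientation, the following is a hedged usage sketch of the `GitHubApiService` methods shown above; it is not part of the repository. The constructor argument and the `GITHUB_API_TOKEN` environment variable are assumptions (the service is constructed as `new GitHubApiService(githubApiToken)` in `GeminiService.ts` later in this document), and the import path is illustrative.
```typescript
// Hedged usage sketch of GitHubApiService; names marked as assumptions above.
import { GitHubApiService } from "./src/services/gemini/GitHubApiService.js";

async function summarizeGitHubUrl(url: string): Promise<void> {
  const github = new GitHubApiService(process.env.GITHUB_API_TOKEN);

  // Parse the URL into owner/repo plus any branch/PR/issue identifiers.
  const info = await github.getRepositoryInfoFromUrl(url);

  if (info.prNumber !== undefined) {
    // Fetch PR metadata and its unified diff; both calls go through the cache.
    const pr = await github.getPullRequest(info.owner, info.repo, info.prNumber);
    const diff = await github.getPullRequestDiff(info.owner, info.repo, info.prNumber);
    console.log(`${pr.title}: ${diff.split("\n").length} diff lines`);
  } else {
    // Otherwise fall back to the GraphQL-backed repository overview.
    const overview = await github.getRepositoryOverview(info.owner, info.repo);
    console.log(`${overview.name} (${overview.language}), ${overview.stars} stars`);
  }
}
```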
--------------------------------------------------------------------------------
/tests/integration/mcpClientIntegration.test.vitest.ts:
--------------------------------------------------------------------------------
```typescript
// Using vitest globals - see vitest.config.ts globals: true
// Skip these flaky integration tests for now
const itSkipIntegration = it.skip;
import { spawn, ChildProcess } from "child_process";
import path from "path";
import fs from "fs/promises";
import os from "os";
// Import MCP client service
import { McpClientService } from "../../src/services/mcp/McpClientService.js";
// Import tool processors for direct invocation
import { mcpClientTool } from "../../src/tools/mcpClientTool.js";
import { writeToFile } from "../../src/tools/writeToFileTool.js";
// Import Configuration manager
import { ConfigurationManager } from "../../src/config/ConfigurationManager.js";
// Import integration test types
import { ToolProcessor, ToolProcessors } from "../utils/integration-types.js";
// Response types are defined inline where needed to avoid unused variable warnings
// Helper functions to set up integration environment
function createTempOutputDir(): Promise<string> {
// Create a temporary directory for test file outputs
const tempDir = path.join(os.tmpdir(), `mcp-client-test-${Date.now()}`);
return fs.mkdir(tempDir, { recursive: true }).then(() => tempDir);
}
async function cleanupTempDir(tempDir: string): Promise<void> {
try {
// Recursively delete the temporary directory
await fs.rm(tempDir, { recursive: true, force: true });
} catch (error) {
console.error(`Error cleaning up temp directory: ${error}`);
}
}
// Helper to mock the ConfigurationManager
function mockConfigurationManager(tempDir: string): void {
  // Back up the original getInstance method
const originalGetInstance = ConfigurationManager.getInstance;
// Mock getInstance
ConfigurationManager.getInstance = function getInstance() {
const instance = originalGetInstance.call(ConfigurationManager);
// Mock the getAllowedOutputPaths method
instance.getAllowedOutputPaths = () => [tempDir];
// Mock the getMcpConfig method
instance.getMcpConfig = () => ({
host: "localhost",
port: 3456,
connectionToken: "test-token",
clientId: "test-client",
logLevel: "info",
transport: "stdio",
enableStreaming: false,
sessionTimeoutSeconds: 60,
});
return instance;
};
}
// Helper to restore the original ConfigurationManager
function restoreConfigurationManager(): void {
// Restore the original getInstance method
delete (
ConfigurationManager as unknown as {
getInstance?: () => ConfigurationManager;
}
).getInstance;
}
// Generic function to create a tool processor from current tool objects
function createToolProcessor(
tool: {
execute: (
args: any,
service: McpClientService
) => Promise<{ content: { type: string; text: string }[] }>;
},
mcpClientService: McpClientService
): ToolProcessor;
function createToolProcessor(
tool: {
execute: (
args: any
) => Promise<{ content: { type: string; text: string }[] }>;
},
mcpClientService: McpClientService
): ToolProcessor;
function createToolProcessor(
tool: any,
mcpClientService: McpClientService
): ToolProcessor {
return async (args: any) => {
if (tool.execute.length === 1) {
// Tool doesn't need service parameter (like writeToFile)
return await tool.execute(args);
} else {
// Tool needs service parameter (like mcpClientTool)
return await tool.execute(args, mcpClientService);
}
};
}
// Start dummy MCP server (stdio)
async function startDummyMcpServerStdio(): Promise<ChildProcess> {
const currentDir = path.dirname(new URL(import.meta.url).pathname);
const serverPath = path.resolve(currentDir, "./dummyMcpServerStdio.ts");
console.debug(`Starting STDIO server at path: ${serverPath}`);
// Verify the file exists
try {
await fs.access(serverPath);
} catch (error) {
throw new Error(`Dummy server file not found at: ${serverPath}`);
}
// Start the child process with ts-node for TypeScript execution
const nodeProcess = spawn("node", ["--loader", "ts-node/esm", serverPath], {
stdio: ["pipe", "pipe", "pipe"],
env: {
...process.env,
NODE_OPTIONS: "--no-warnings --experimental-specifier-resolution=node",
},
});
// Create a Promise that resolves when the server is ready
return new Promise((resolve, reject) => {
let errorOutput = "";
// Listen for data on stderr to detect when server is ready
nodeProcess.stderr.on("data", (data) => {
const message = data.toString();
errorOutput += message;
console.debug(`[STDIO Server stderr]: ${message}`);
// When we see the server ready message, resolve
if (message.includes("Dummy MCP Server (stdio) started")) {
resolve(nodeProcess);
}
});
// Also listen on stdout for any output
nodeProcess.stdout.on("data", (data) => {
console.debug(`[STDIO Server stdout]: ${data.toString()}`);
});
// Handle startup failure
nodeProcess.on("error", (err) => {
reject(new Error(`Failed to start dummy server: ${err.message}`));
});
// Set a timeout in case the server doesn't start
const timeout = setTimeout(() => {
nodeProcess.kill();
reject(
new Error(
`Timeout waiting for dummy server to start. Last output: ${errorOutput}`
)
);
}, 15000); // Increased timeout to 15 seconds
// Clear the timeout if we resolve or reject
nodeProcess.on("exit", () => {
clearTimeout(timeout);
});
});
}
// Start dummy MCP server (SSE)
async function startDummyMcpServerSse(port = 3456): Promise<ChildProcess> {
const currentDir = path.dirname(new URL(import.meta.url).pathname);
const serverPath = path.resolve(currentDir, "./dummyMcpServerSse.ts");
console.debug(`Starting SSE server at path: ${serverPath}`);
// Start the child process with ts-node for TypeScript execution
const nodeProcess = spawn(
"node",
["--loader", "ts-node/esm", serverPath, port.toString()],
{
stdio: ["pipe", "pipe", "pipe"],
env: {
...process.env,
NODE_OPTIONS: "--no-warnings --experimental-specifier-resolution=node",
},
}
);
// Create a Promise that resolves when the server is ready
return new Promise((resolve, reject) => {
let errorOutput = "";
// Listen for data on stderr to detect when server is ready
nodeProcess.stderr.on("data", (data) => {
const message = data.toString();
errorOutput += message;
console.debug(`[SSE Server stderr]: ${message}`);
// When we see the server ready message, resolve
if (message.includes(`Dummy MCP Server (SSE) started on port ${port}`)) {
resolve(nodeProcess);
}
});
// Also listen on stdout for any output
nodeProcess.stdout.on("data", (data) => {
console.debug(`[SSE Server stdout]: ${data.toString()}`);
});
// Handle startup failure
nodeProcess.on("error", (err) => {
reject(new Error(`Failed to start dummy server: ${err.message}`));
});
// Set a timeout in case the server doesn't start
const timeout = setTimeout(() => {
nodeProcess.kill();
reject(
new Error(
`Timeout waiting for dummy server to start. Last output: ${errorOutput}`
)
);
}, 15000); // Increased timeout to 15 seconds
// Clear the timeout if we resolve or reject
nodeProcess.on("exit", () => {
clearTimeout(timeout);
});
});
}
describe("MCP Client Integration Tests", () => {
let mcpClientService: McpClientService;
let processors: ToolProcessors;
let tempDir: string;
let stdioServer: ChildProcess | null = null;
let sseServer: ChildProcess | null = null;
// Set up test environment before all tests
beforeAll(async function () {
// Create a temporary directory for test outputs
tempDir = await createTempOutputDir();
// Set the environment variable for file security
process.env.GEMINI_SAFE_FILE_BASE_DIR = tempDir;
// Mock ConfigurationManager to use our test settings
mockConfigurationManager(tempDir);
// Initialize the MCP client service
mcpClientService = new McpClientService();
// Create tool processors for testing
processors = {
connect: createToolProcessor(mcpClientTool, mcpClientService),
listTools: createToolProcessor(mcpClientTool, mcpClientService),
callServerTool: createToolProcessor(mcpClientTool, mcpClientService),
disconnect: createToolProcessor(mcpClientTool, mcpClientService),
writeToFile: createToolProcessor(writeToFile, mcpClientService),
};
});
// Clean up after all tests
afterAll(async function () {
// Close any open MCP connections
mcpClientService.closeAllConnections();
// Kill the server processes if they're still running
if (stdioServer) {
stdioServer.kill();
}
if (sseServer) {
sseServer.kill();
}
// Restore the original ConfigurationManager
restoreConfigurationManager();
// Clean up environment variable
delete process.env.GEMINI_SAFE_FILE_BASE_DIR;
// Clean up temporary directory
await cleanupTempDir(tempDir);
});
describe("STDIO Transport Tests", () => {
// Set up stdio server before each test in this group
beforeEach(async function () {
// Start the dummy stdio server
stdioServer = await startDummyMcpServerStdio();
}, 20000); // Increase timeout to 20 seconds
// Clean up stdio server after each test
afterEach(function () {
// Kill the stdio server
if (stdioServer) {
stdioServer.kill();
stdioServer = null;
}
// Close any connections
mcpClientService.closeAllConnections();
});
itSkipIntegration(
"should connect to a stdio server, list tools, call a tool, and disconnect",
async () => {
// Step 1: Call the connect processor to connect to the dummy stdio server
const connectArgs = {
transport: "stdio",
connectionDetails: {
transport: "stdio",
command: "node",
args: [
"--loader",
"ts-node/esm",
"./tests/integration/dummyMcpServerStdio.ts",
],
},
};
// Connect to the server
const connectResult = await processors.connect(connectArgs);
// Extract the connection ID from the result
const resultJson = JSON.parse(connectResult.content[0].text);
const connectionId = resultJson.connectionId;
// Verify connection ID was returned and is a string
expect(connectionId).toBeTruthy();
expect(typeof connectionId).toBe("string");
// Step 2: List tools on the connected server
const listToolsArgs = {
connectionId,
};
const listToolsResult = await processors.listTools(listToolsArgs);
// Parse the tools list
const toolsList = JSON.parse(listToolsResult.content[0].text);
// Verify tools list
expect(Array.isArray(toolsList)).toBeTruthy();
expect(toolsList.length).toBeGreaterThanOrEqual(3);
// Verify expected tools are in the list
const toolNames = toolsList.map((tool: { name: string }) => tool.name);
expect(toolNames.includes("echoTool")).toBeTruthy();
expect(toolNames.includes("addTool")).toBeTruthy();
expect(toolNames.includes("complexDataTool")).toBeTruthy();
// Step 3: Call the echo tool
const echoMessage = "Hello from integration test";
const callToolArgs = {
connectionId,
toolName: "echoTool",
toolParameters: {
message: echoMessage,
},
};
const callToolResult = await processors.callServerTool(callToolArgs);
// Parse the result
const echoResult = JSON.parse(callToolResult.content[0].text);
// Verify echo result
expect(echoResult.message).toBe(echoMessage);
expect(echoResult.timestamp).toBeTruthy();
// Step 4: Call the add tool
const addArgs = {
connectionId,
toolName: "addTool",
toolParameters: {
a: 5,
b: 7,
},
};
const addResult = await processors.callServerTool(addArgs);
// Parse the result
const addOutput = JSON.parse(addResult.content[0].text);
// Verify add result
expect(addOutput.sum).toBe(12);
expect(addOutput.inputs).toEqual({ a: 5, b: 7 });
// Step 5: Disconnect from the server
const disconnectArgs = {
connectionId,
};
const disconnectResult = await processors.disconnect(disconnectArgs);
// Parse the disconnect result
const disconnectOutput = JSON.parse(disconnectResult.content[0].text);
// Verify disconnect result
expect(disconnectOutput.connectionId).toBe(connectionId);
expect(
disconnectOutput.message.includes("Connection closed")
).toBeTruthy();
// Verify the connection is no longer in the active connections list
expect(mcpClientService.getActiveStdioConnectionIds().length).toBe(0);
}
);
itSkipIntegration(
"should call a tool and write output to a file",
async () => {
// Step 1: Connect to the dummy stdio server
const connectArgs = {
transport: "stdio",
connectionDetails: {
transport: "stdio",
command: "node",
args: [
"--loader",
"ts-node/esm",
"./tests/integration/dummyMcpServerStdio.ts",
],
},
};
const connectResult = await processors.connect(connectArgs);
const resultJson = JSON.parse(connectResult.content[0].text);
const connectionId = resultJson.connectionId;
// Step 2: Call the complexDataTool and write output to a file
const outputPath = path.join(tempDir, "complex-data-output.json");
const callToolArgs = {
connectionId,
toolName: "complexDataTool",
toolParameters: {
depth: 2,
itemCount: 3,
},
outputFilePath: outputPath,
};
const callToolResult = await processors.callServerTool(callToolArgs);
// Parse the result
const callToolOutput = JSON.parse(callToolResult.content[0].text);
// Verify the result contains the expected information
expect(callToolOutput.message).toBe("Output written to file");
expect(callToolOutput.filePath).toBe(outputPath);
// Verify the file exists and contains the expected data
const fileExists = await fs
.access(outputPath)
.then(() => true)
.catch(() => false);
expect(fileExists).toBeTruthy();
// Read the file contents
const fileContent = await fs.readFile(outputPath, "utf8");
const fileData = JSON.parse(fileContent);
// Verify file content structure
expect(fileData.level).toBe(1);
expect(fileData.items.length).toBe(3);
expect(fileData.items[0].level).toBe(2);
// Clean up - disconnect from the server
await processors.disconnect({ connectionId });
}
);
});
describe("SSE Transport Tests", () => {
// Set up SSE server before each test in this group
beforeEach(async function () {
// Start the dummy SSE server
sseServer = await startDummyMcpServerSse();
}, 20000); // Increase timeout to 20 seconds
// Clean up SSE server after each test
afterEach(function () {
// Kill the SSE server
if (sseServer) {
sseServer.kill();
sseServer = null;
}
// Close any connections
mcpClientService.closeAllConnections();
});
itSkipIntegration(
"should connect to an SSE server, list tools, call a tool, and disconnect",
async () => {
// Step 1: Call the connect processor to connect to the dummy SSE server
const ssePort = 3456;
const connectArgs = {
transport: "sse",
connectionDetails: {
transport: "sse",
url: `http://localhost:${ssePort}/mcp`,
},
};
// Connect to the server
const connectResult = await processors.connect(connectArgs);
// Extract the connection ID from the result
const resultJson = JSON.parse(connectResult.content[0].text);
const connectionId = resultJson.connectionId;
// Verify connection ID was returned and is a string
expect(connectionId).toBeTruthy();
expect(typeof connectionId).toBe("string");
// Step 2: List tools on the connected server
const listToolsArgs = {
connectionId,
};
const listToolsResult = await processors.listTools(listToolsArgs);
// Parse the tools list
const toolsList = JSON.parse(listToolsResult.content[0].text);
// Verify tools list
expect(Array.isArray(toolsList)).toBeTruthy();
expect(toolsList.length).toBeGreaterThanOrEqual(3);
// Verify expected tools are in the list
const toolNames = toolsList.map((tool: { name: string }) => tool.name);
expect(toolNames.includes("echoTool")).toBeTruthy();
expect(toolNames.includes("addTool")).toBeTruthy();
expect(toolNames.includes("complexDataTool")).toBeTruthy();
// Step 3: Call the echo tool
const echoMessage = "Hello from SSE integration test";
const callToolArgs = {
connectionId,
toolName: "echoTool",
toolParameters: {
message: echoMessage,
},
};
const callToolResult = await processors.callServerTool(callToolArgs);
// Parse the result
const echoResult = JSON.parse(callToolResult.content[0].text);
// Verify echo result
expect(echoResult.message).toBe(echoMessage);
expect(echoResult.timestamp).toBeTruthy();
// Step 4: Disconnect from the server
const disconnectArgs = {
connectionId,
};
const disconnectResult = await processors.disconnect(disconnectArgs);
// Parse the disconnect result
const disconnectOutput = JSON.parse(disconnectResult.content[0].text);
// Verify disconnect result
expect(disconnectOutput.connectionId).toBe(connectionId);
expect(
disconnectOutput.message.includes("Connection closed")
).toBeTruthy();
// Verify the connection is no longer in the active connections list
expect(mcpClientService.getActiveSseConnectionIds().length).toBe(0);
}
);
});
describe("Write to File Tool Tests", () => {
let writeToFileProcessor: ToolProcessor;
beforeEach(function () {
// Create the writeToFile processor
writeToFileProcessor = createToolProcessor(writeToFile, mcpClientService);
});
itSkipIntegration("should write a string to a file", async () => {
// Create the file path for the test
const testFilePath = path.join(tempDir, "test-utf8-output.txt");
const testContent =
"This is a test string to write to a file\nWith multiple lines\nAnd special chars: €£¥©®™";
// Call the writeToFile processor
const args = {
filePath: testFilePath,
content: testContent,
encoding: "utf8",
};
const result = await writeToFileProcessor(args);
// Parse the result
const resultJson = JSON.parse(result.content[0].text);
// Verify the result contains the expected information
expect(resultJson.message).toBe("Content written to file successfully.");
expect(resultJson.filePath).toBe(testFilePath);
// Verify the file exists and contains the correct data
const fileExists = await fs
.access(testFilePath)
.then(() => true)
.catch(() => false);
expect(fileExists).toBeTruthy();
// Read the file and compare the content
const fileContent = await fs.readFile(testFilePath, "utf8");
expect(fileContent).toBe(testContent);
});
itSkipIntegration(
"should write a base64 encoded string to a file",
async () => {
// Create the file path for the test
const testFilePath = path.join(tempDir, "test-base64-output.txt");
// Create a test string and encode it to base64
const originalString =
"This is a test string that will be base64 encoded\nWith multiple lines\nAnd special chars: €£¥©®™";
const base64Content = Buffer.from(originalString).toString("base64");
// Call the writeToFile processor
const args = {
filePath: testFilePath,
content: base64Content,
encoding: "base64",
};
const result = await writeToFileProcessor(args);
// Parse the result
const resultJson = JSON.parse(result.content[0].text);
// Verify the result contains the expected information
expect(resultJson.message).toBe(
"Content written to file successfully."
);
expect(resultJson.filePath).toBe(testFilePath);
// Verify the file exists and contains the correct data
const fileExists = await fs
.access(testFilePath)
.then(() => true)
.catch(() => false);
expect(fileExists).toBeTruthy();
// Read the file and compare the content
const fileContent = await fs.readFile(testFilePath, "utf8");
expect(fileContent).toBe(originalString);
}
);
itSkipIntegration(
"should fail when writing to a path outside allowed directories",
async () => {
// Try to write to an absolute path outside the allowed directory
const nonAllowedPath = path.join(
os.tmpdir(),
"..",
"non-allowed-dir",
"test.txt"
);
const args = {
filePath: nonAllowedPath,
content: "This should not be written",
encoding: "utf8",
};
// The call should reject because the path is not allowed
await expect(writeToFileProcessor(args)).rejects.toThrow(
/Security error|not within the allowed output|InvalidParams/
);
// Verify the file does not exist
const fileExists = await fs
.access(nonAllowedPath)
.then(() => true)
.catch(() => false);
expect(fileExists).toBe(false);
}
);
});
});
```
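The test imports `ToolProcessor` and `ToolProcessors` from `tests/utils/integration-types.js`, which is not reproduced on this page. The sketch below is inferred purely from how `createToolProcessor` and the `processors` object are used above; the actual definitions in the repository may differ.
```typescript
// Hypothetical sketch of tests/utils/integration-types.ts, inferred from usage
// in the test above; the real file may define these types differently.
export interface ToolCallResult {
  content: { type: string; text: string }[];
}

// A processor wraps a tool's execute function behind a uniform one-argument call.
export type ToolProcessor = (args: any) => Promise<ToolCallResult>;

// The named processors assembled in beforeAll() of the integration test.
export interface ToolProcessors {
  connect: ToolProcessor;
  listTools: ToolProcessor;
  callServerTool: ToolProcessor;
  disconnect: ToolProcessor;
  writeToFile: ToolProcessor;
}
```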
--------------------------------------------------------------------------------
/src/services/GeminiService.ts:
--------------------------------------------------------------------------------
```typescript
import { GoogleGenAI, GenerateContentResponse } from "@google/genai";
import { ConfigurationManager } from "../config/ConfigurationManager.js";
import { ModelSelectionService } from "./ModelSelectionService.js";
import { logger } from "../utils/logger.js";
import {
CachedContentMetadata,
ModelSelectionCriteria,
ImageGenerationResult,
} from "../types/index.js";
import {
GeminiGitDiffService,
GitDiffReviewParams,
} from "./gemini/GeminiGitDiffService.js";
import { GitHubApiService } from "./gemini/GitHubApiService.js";
// Import specialized services
import { GeminiChatService } from "./gemini/GeminiChatService.js";
import { GeminiContentService } from "./gemini/GeminiContentService.js";
import { GeminiCacheService } from "./gemini/GeminiCacheService.js";
import {
Content,
Tool,
ToolConfig,
GenerationConfig,
SafetySetting,
CacheId,
FunctionCall,
} from "./gemini/GeminiTypes.js";
/**
* Service for interacting with the Google Gemini API.
* This is a facade that delegates to specialized services for different functionality.
*/
export class GeminiService {
private genAI: GoogleGenAI;
private defaultModelName?: string;
private modelSelector: ModelSelectionService;
private configManager: ConfigurationManager;
private chatService: GeminiChatService;
private contentService: GeminiContentService;
private cacheService: GeminiCacheService;
private gitDiffService: GeminiGitDiffService;
private gitHubApiService: GitHubApiService;
constructor() {
this.configManager = ConfigurationManager.getInstance();
const config = this.configManager.getGeminiServiceConfig();
this.modelSelector = new ModelSelectionService(
this.configManager.getModelConfiguration()
);
if (!config.apiKey) {
throw new Error("Gemini API key is required");
}
// Initialize with the apiKey property in an object as required in v0.10.0
this.genAI = new GoogleGenAI({ apiKey: config.apiKey });
this.defaultModelName = config.defaultModel;
// File security service is no longer needed since file operations were removed
// Initialize specialized services
this.contentService = new GeminiContentService(
this.genAI,
this.defaultModelName,
config.defaultThinkingBudget
);
this.chatService = new GeminiChatService(this.genAI, this.defaultModelName);
this.cacheService = new GeminiCacheService(this.genAI);
this.gitDiffService = new GeminiGitDiffService(
this.genAI,
this.defaultModelName,
1024 * 1024, // 1MB default
[
"package-lock.json",
"yarn.lock",
"*.min.js",
"*.min.css",
"node_modules/**",
"dist/**",
"build/**",
"*.lock",
"**/*.map",
],
config.defaultThinkingBudget
);
const githubApiToken = this.configManager.getGitHubApiToken();
this.gitHubApiService = new GitHubApiService(githubApiToken);
}
public async *generateContentStream(
params: GenerateContentParams & {
preferQuality?: boolean;
preferSpeed?: boolean;
preferCost?: boolean;
complexityHint?: "simple" | "medium" | "complex";
taskType?: ModelSelectionCriteria["taskType"];
}
): AsyncGenerator<string> {
const selectedModel = await this.selectModelForGeneration(params);
yield* this.contentService.generateContentStream({
...params,
modelName: selectedModel,
});
}
public async generateContent(
params: GenerateContentParams & {
preferQuality?: boolean;
preferSpeed?: boolean;
preferCost?: boolean;
complexityHint?: "simple" | "medium" | "complex";
taskType?: ModelSelectionCriteria["taskType"];
}
): Promise<string> {
const selectedModel = await this.selectModelForGeneration(params);
const result = await this.contentService.generateContent({
...params,
modelName: selectedModel,
});
return result;
}
/**
* Starts a new stateful chat session with the Gemini model.
*
* @param params Parameters for starting a chat session
* @returns A unique session ID to identify this chat session
*/
public startChatSession(params: StartChatParams = {}): string {
return this.chatService.startChatSession(params);
}
/**
* Sends a message to an existing chat session.
   * Uses the generateContent API directly since the service manages chat state itself.
*
* @param params Parameters for sending a message
* @returns Promise resolving to the chat response
*/
public async sendMessageToSession(
params: SendMessageParams
): Promise<GenerateContentResponse> {
return this.chatService.sendMessageToSession(params);
}
/**
* Sends the result of a function call back to the chat session.
*
* @param params Parameters for sending a function result
* @returns Promise resolving to the chat response
*/
public async sendFunctionResultToSession(
params: SendFunctionResultParams
): Promise<GenerateContentResponse> {
return this.chatService.sendFunctionResultToSession(params);
}
/**
* Creates a cached content entry in the Gemini API.
*
* @param modelName The model to use for this cached content
* @param contents The conversation contents to cache
* @param options Additional options for the cache (displayName, systemInstruction, ttl, tools, toolConfig)
* @returns Promise resolving to the cached content metadata
*/
public async createCache(
modelName: string,
contents: Content[],
options?: {
displayName?: string;
systemInstruction?: Content | string;
ttl?: string;
tools?: Tool[];
toolConfig?: ToolConfig;
}
): Promise<CachedContentMetadata> {
return this.cacheService.createCache(modelName, contents, options);
}
/**
* Lists cached content entries in the Gemini API.
*
* @param pageSize Optional maximum number of entries to return
* @param pageToken Optional token for pagination
* @returns Promise resolving to an object with caches array and optional nextPageToken
*/
public async listCaches(
pageSize?: number,
pageToken?: string
): Promise<{ caches: CachedContentMetadata[]; nextPageToken?: string }> {
return this.cacheService.listCaches(pageSize, pageToken);
}
/**
* Gets a specific cached content entry's metadata from the Gemini API.
*
* @param cacheId The ID of the cached content to retrieve (format: "cachedContents/{id}")
* @returns Promise resolving to the cached content metadata
*/
public async getCache(cacheId: CacheId): Promise<CachedContentMetadata> {
return this.cacheService.getCache(cacheId);
}
/**
* Updates a cached content entry in the Gemini API.
*
* @param cacheId The ID of the cached content to update (format: "cachedContents/{id}")
* @param updates The updates to apply to the cached content (ttl, displayName)
* @returns Promise resolving to the updated cached content metadata
*/
public async updateCache(
cacheId: CacheId,
updates: { ttl?: string; displayName?: string }
): Promise<CachedContentMetadata> {
return this.cacheService.updateCache(cacheId, updates);
}
/**
* Deletes a cached content entry from the Gemini API.
*
* @param cacheId The ID of the cached content to delete (format: "cachedContents/{id}")
   * @returns Promise resolving to an object with a success flag
*/
public async deleteCache(cacheId: CacheId): Promise<{ success: boolean }> {
return this.cacheService.deleteCache(cacheId);
}
/**
* Routes a message to the most appropriate model based on a routing prompt.
* This is useful when you have multiple specialized models and want to automatically
* select the best one for the specific query type.
*
* @param params Parameters for routing a message across models
* @returns Promise resolving to an object with the chat response and the chosen model
* @throws {GeminiApiError} If routing fails or all models are unavailable
*/
public async routeMessage(
params: RouteMessageParams
): Promise<{ response: GenerateContentResponse; chosenModel: string }> {
return this.chatService.routeMessage(params);
}
/**
* Reviews a git diff and generates analysis using Gemini models
*
* @param params Parameters for the git diff review
* @returns Promise resolving to the review text
*/
public async reviewGitDiff(params: GitDiffReviewParams): Promise<string> {
return this.gitDiffService.reviewDiff(params);
}
/**
* Streams a git diff review content using Gemini models
*
* @param params Parameters for the git diff review
* @returns AsyncGenerator yielding review content chunks as they become available
*/
public async *reviewGitDiffStream(
params: GitDiffReviewParams
): AsyncGenerator<string> {
yield* this.gitDiffService.reviewDiffStream(params);
}
/**
* Reviews a GitHub repository and generates analysis using Gemini models.
*
* IMPORTANT: This method uses a special approach to analyze repository contents by
* creating a diff against an empty tree. While effective for getting an overview of
* the repository, be aware of these limitations:
*
* 1. Token Usage: This approach consumes a significant number of tokens, especially
* for large repositories, as it treats the entire repository as one large diff.
*
* 2. Performance Impact: For very large repositories, this may result in slow
* response times and potential timeout errors.
*
* 3. Cost Considerations: The token consumption directly impacts API costs.
* Consider using the maxFilesToInclude and excludePatterns options to limit scope.
*
* 4. Scale Issues: Repositories with many files or large files may exceed context
* limits of the model, resulting in incomplete analysis.
*
* For large repositories, consider reviewing specific directories or files instead,
* or focusing on a particular branch or PR.
*
* @param params Parameters for the GitHub repository review
* @returns Promise resolving to the review text
*/
public async reviewGitHubRepository(params: {
owner: string;
repo: string;
branch?: string;
modelName?: string;
reasoningEffort?: "none" | "low" | "medium" | "high";
reviewFocus?:
| "security"
| "performance"
| "architecture"
| "bugs"
| "general";
maxFilesToInclude?: number;
excludePatterns?: string[];
prioritizeFiles?: string[];
customPrompt?: string;
}): Promise<string> {
try {
const {
owner,
repo,
branch,
modelName,
reasoningEffort = "medium",
reviewFocus = "general",
maxFilesToInclude = 50,
excludePatterns = [],
prioritizeFiles,
customPrompt,
} = params;
// Get repository overview using GitHub API
const repoOverview = await this.gitHubApiService.getRepositoryOverview(
owner,
repo
);
// Get default branch if not specified
const targetBranch = branch || repoOverview.defaultBranch;
// Create repository context for Gemini prompt
const repositoryContext = `Repository: ${owner}/${repo}
Primary Language: ${repoOverview.language}
Languages: ${repoOverview.languages.map((l) => `${l.name} (${l.percentage}%)`).join(", ")}
Description: ${repoOverview.description || "No description"}
Default Branch: ${repoOverview.defaultBranch}
Target Branch: ${targetBranch}
Stars: ${repoOverview.stars}
Forks: ${repoOverview.forks}`;
// Get content from repository files that match our criteria
// For now, we'll use a git diff approach by getting a comparison diff
// between the empty state and the target branch
const diff = await this.gitHubApiService.getComparisonDiff(
owner,
repo,
// Use a known empty reference as the base
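        // 4b825dc642cb6eb9a060e54bf8d69288fbee4904 is git's well-known hash of
        // the empty tree, so this comparison yields a diff containing the full
        // contents of the target branch (hence the token-usage caveats in the
        // doc comment above).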
"4b825dc642cb6eb9a060e54bf8d69288fbee4904",
targetBranch
);
// Use the git diff service to analyze the repository content
return this.gitDiffService.reviewDiff({
diffContent: diff,
modelName,
reviewFocus,
repositoryContext,
diffOptions: {
maxFilesToInclude,
excludePatterns,
prioritizeFiles,
},
reasoningEffort,
customPrompt,
});
} catch (error: unknown) {
logger.error("Error reviewing GitHub repository:", error);
throw error;
}
}
/**
* Reviews a GitHub Pull Request and generates analysis using Gemini models
*
* @param params Parameters for the GitHub PR review
* @returns Promise resolving to the review text
*/
public async reviewGitHubPullRequest(params: {
owner: string;
repo: string;
prNumber: number;
modelName?: string;
reasoningEffort?: "none" | "low" | "medium" | "high";
reviewFocus?:
| "security"
| "performance"
| "architecture"
| "bugs"
| "general";
excludePatterns?: string[];
customPrompt?: string;
}): Promise<string> {
try {
const {
owner,
repo,
prNumber,
modelName,
reasoningEffort = "medium",
reviewFocus = "general",
excludePatterns = [],
customPrompt,
} = params;
// Get PR details using GitHub API
const pullRequest = await this.gitHubApiService.getPullRequest(
owner,
repo,
prNumber
);
// Create repository context for Gemini prompt
const repositoryContext = `Repository: ${owner}/${repo}
Pull Request: #${prNumber} - ${pullRequest.title}
Author: ${pullRequest.user.login}
Base Branch: ${pullRequest.base.ref}
Head Branch: ${pullRequest.head.ref}
Files Changed: ${pullRequest.changed_files}
Additions: ${pullRequest.additions}
Deletions: ${pullRequest.deletions}
Description: ${pullRequest.body || "No description"}`;
// Get PR diff using GitHub API
const diff = await this.gitHubApiService.getPullRequestDiff(
owner,
repo,
prNumber
);
// Use the git diff service to analyze the PR
return this.gitDiffService.reviewDiff({
diffContent: diff,
modelName,
reviewFocus,
repositoryContext,
diffOptions: {
excludePatterns,
},
reasoningEffort,
customPrompt,
});
} catch (error: unknown) {
logger.error("Error reviewing GitHub Pull Request:", error);
throw error;
}
}
/**
* Generates images from text prompts using Google's image generation models
* Supports both Gemini and Imagen models for image generation
*
* @param prompt - The text prompt describing the desired image
* @param modelName - Optional model name (defaults to optimal model selection)
* @param resolution - Optional image resolution (512x512, 1024x1024, 1536x1536)
* @param numberOfImages - Optional number of images to generate (1-8)
* @param safetySettings - Optional safety settings for content filtering
* @param negativePrompt - Optional text describing what to avoid in the image
* @param stylePreset - Optional visual style to apply
* @param seed - Optional seed for reproducible generation
* @param styleStrength - Optional strength of style preset (0.0-1.0)
* @param preferQuality - Optional preference for quality over speed
* @param preferSpeed - Optional preference for speed over quality
* @returns Promise resolving to image generation result with base64 data
* @throws {GeminiValidationError} If parameters are invalid
* @throws {GeminiContentFilterError} If content is blocked by safety filters
* @throws {GeminiQuotaError} If API quota is exceeded
* @throws {GeminiModelError} If the model encounters issues generating the image
* @throws {GeminiNetworkError} If a network error occurs
* @throws {GeminiApiError} For any other errors
*/
public async generateImage(
prompt: string,
modelName?: string,
resolution?: "512x512" | "1024x1024" | "1536x1536",
numberOfImages?: number,
safetySettings?: SafetySetting[],
negativePrompt?: string,
stylePreset?: string,
seed?: number,
styleStrength?: number,
preferQuality?: boolean,
preferSpeed?: boolean
): Promise<ImageGenerationResult> {
// Log with truncated prompt for privacy/security
logger.debug(`Generating image with prompt: ${prompt.substring(0, 30)}...`);
try {
// Import validation schemas and error handling
const { validateImageGenerationParams, DEFAULT_SAFETY_SETTINGS } =
await import("./gemini/GeminiValidationSchemas.js");
const {
GeminiContentFilterError,
GeminiModelError,
GeminiErrorMessages,
} = await import("../utils/geminiErrors.js");
// Validate parameters using Zod schemas
const validatedParams = validateImageGenerationParams(
prompt,
modelName,
resolution,
numberOfImages,
safetySettings,
negativePrompt,
stylePreset,
seed,
styleStrength
);
const effectiveModel =
validatedParams.modelName ||
(await this.modelSelector.selectOptimalModel({
taskType: "image-generation",
preferQuality,
preferSpeed,
fallbackModel: "imagen-3.0-generate-002",
}));
// Get the model from the SDK
const model = this.genAI.getGenerativeModel({
model: effectiveModel,
});
// Build generation config based on validated parameters
const generationConfig: {
numberOfImages: number;
width?: number;
height?: number;
negativePrompt?: string;
stylePreset?: string;
seed?: number;
styleStrength?: number;
} = {
numberOfImages: validatedParams.numberOfImages,
};
if (validatedParams.resolution) {
const [width, height] = validatedParams.resolution
.split("x")
.map((dim) => parseInt(dim, 10));
generationConfig.width = width;
generationConfig.height = height;
}
if (validatedParams.negativePrompt) {
generationConfig.negativePrompt = validatedParams.negativePrompt;
}
if (validatedParams.stylePreset) {
generationConfig.stylePreset = validatedParams.stylePreset;
}
if (validatedParams.seed !== undefined) {
generationConfig.seed = validatedParams.seed;
}
if (validatedParams.styleStrength !== undefined) {
generationConfig.styleStrength = validatedParams.styleStrength;
}
// Apply default safety settings if none provided
const effectiveSafetySettings =
validatedParams.safetySettings ||
(DEFAULT_SAFETY_SETTINGS as SafetySetting[]);
// Generate the images using the correct generateImages API
const result = await model.generateImages({
prompt: validatedParams.prompt,
safetySettings: effectiveSafetySettings as SafetySetting[],
...generationConfig,
});
// Check for safety blocks first (higher priority than empty results)
if (result.promptSafetyMetadata?.blocked) {
const safetyRatings = result.promptSafetyMetadata.safetyRatings || [];
throw new GeminiContentFilterError(
GeminiErrorMessages.CONTENT_FILTERED,
safetyRatings.map((rating) => rating.category)
);
}
// Check if images were generated successfully
if (!result.images || result.images.length === 0) {
throw new GeminiModelError(
"No images were generated by the model",
"image_generation"
);
}
// Parse resolution for width/height
const [width, height] = (validatedParams.resolution || "1024x1024")
.split("x")
.map((dim) => parseInt(dim, 10));
// Format the images according to our expected structure
const formattedImages = result.images.map((image) => ({
base64Data: image.data || "",
mimeType: image.mimeType || "image/png",
width,
height,
}));
const formattedResult: ImageGenerationResult = {
images: formattedImages,
};
// Validate the generated images
await this.validateGeneratedImages(formattedResult);
return formattedResult;
} catch (error: unknown) {
const { mapGeminiError } = await import("../utils/geminiErrors.js");
// Map to appropriate error type
throw mapGeminiError(error, "generateImage");
}
}
/**
* Validates generated images to ensure they meet quality and safety standards
* @param result - The image generation result to validate
* @throws {GeminiValidationError} If validation fails
*/
private async validateGeneratedImages(
result: ImageGenerationResult
): Promise<void> {
// Import validation error from utils
const { GeminiValidationError } = await import("../utils/geminiErrors.js");
// Check that each image has proper data
for (const [index, image] of result.images.entries()) {
// Verify base64 data is present and valid
if (!image.base64Data || image.base64Data.length < 100) {
throw new GeminiValidationError(
`Image ${index} has invalid or missing data`,
"base64Data"
);
}
// Verify MIME type is supported
const supportedMimeTypes = ["image/png", "image/jpeg", "image/webp"];
if (!supportedMimeTypes.includes(image.mimeType)) {
throw new GeminiValidationError(
`Image ${index} has unsupported MIME type: ${image.mimeType}`,
"mimeType"
);
}
// Verify dimensions are positive numbers
if (image.width <= 0 || image.height <= 0) {
throw new GeminiValidationError(
`Image ${index} has invalid dimensions: ${image.width}x${image.height}`,
"dimensions"
);
}
}
}
private async selectModelForGeneration(params: {
modelName?: string;
preferQuality?: boolean;
preferSpeed?: boolean;
preferCost?: boolean;
complexityHint?: "simple" | "medium" | "complex";
taskType?: ModelSelectionCriteria["taskType"];
prompt?: string;
}): Promise<string> {
if (params.modelName) {
return params.modelName;
}
const complexity =
params.complexityHint ||
this.analyzePromptComplexity(params.prompt || "");
return this.modelSelector.selectOptimalModel({
taskType: params.taskType || "text-generation",
complexityLevel: complexity,
preferQuality: params.preferQuality,
preferSpeed: params.preferSpeed,
preferCost: params.preferCost,
fallbackModel:
this.defaultModelName ||
this.configManager.getModelConfiguration().default,
});
}
private analyzePromptComplexity(
prompt: string
): "simple" | "medium" | "complex" {
const complexKeywords = [
"analyze",
"compare",
"evaluate",
"synthesize",
"reasoning",
"complex",
"detailed analysis",
"comprehensive",
"explain why",
"what are the implications",
"trade-offs",
"pros and cons",
"algorithm",
"architecture",
"design pattern",
];
const codeKeywords = [
"function",
"class",
"import",
"export",
"const",
"let",
"var",
"if",
"else",
"for",
"while",
"return",
"async",
"await",
];
const wordCount = prompt.split(/\s+/).length;
const hasComplexKeywords = complexKeywords.some((keyword) =>
prompt.toLowerCase().includes(keyword.toLowerCase())
);
const hasCodeKeywords = codeKeywords.some((keyword) =>
prompt.toLowerCase().includes(keyword.toLowerCase())
);
if (hasComplexKeywords || hasCodeKeywords || wordCount > 100) {
return "complex";
} else if (wordCount > 20) {
return "medium";
} else {
return "simple";
}
}
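  // Illustrative behaviour of the heuristic above: a short keyword-free prompt
  // (20 words or fewer) is "simple"; a keyword-free prompt of 21-100 words is
  // "medium"; any prompt containing a complexity/code keyword or exceeding 100
  // words is "complex".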
public getModelSelector(): ModelSelectionService {
return this.modelSelector;
}
public async getOptimalModelForTask(
criteria: ModelSelectionCriteria
): Promise<string> {
return this.modelSelector.selectOptimalModel(criteria);
}
public isModelAvailable(modelName: string): boolean {
return this.modelSelector.isModelAvailable(modelName);
}
public getAvailableModels(): string[] {
return this.modelSelector.getAvailableModels();
}
public validateModelForTask(
modelName: string,
taskType: ModelSelectionCriteria["taskType"]
): boolean {
return this.modelSelector.validateModelForTask(modelName, taskType);
}
// Model selection history and performance metrics methods removed
// These were not implemented in ModelSelectionService
}
// Define interfaces directly to avoid circular dependencies
export interface GenerateContentParams {
prompt: string;
modelName?: string;
generationConfig?: GenerationConfig;
safetySettings?: SafetySetting[];
systemInstruction?: Content | string;
cachedContentName?: string;
urlContext?: {
urls: string[];
fetchOptions?: {
maxContentKb?: number;
timeoutMs?: number;
includeMetadata?: boolean;
convertToMarkdown?: boolean;
allowedDomains?: string[];
userAgent?: string;
};
};
preferQuality?: boolean;
preferSpeed?: boolean;
preferCost?: boolean;
complexityHint?: "simple" | "medium" | "complex";
taskType?: ModelSelectionCriteria["taskType"];
urlCount?: number;
estimatedUrlContentSize?: number;
}
export interface StartChatParams {
modelName?: string;
history?: Content[];
generationConfig?: GenerationConfig;
safetySettings?: SafetySetting[];
tools?: Tool[];
systemInstruction?: Content | string;
cachedContentName?: string;
}
export interface SendMessageParams {
sessionId: string;
message: string;
generationConfig?: GenerationConfig;
safetySettings?: SafetySetting[];
tools?: Tool[];
toolConfig?: ToolConfig;
cachedContentName?: string;
}
export interface SendFunctionResultParams {
sessionId: string;
functionResponse: string;
functionCall?: FunctionCall;
}
/**
* Interface for the routing parameters when sending messages to multiple models
*/
export interface RouteMessageParams {
message: string;
models: string[];
routingPrompt?: string;
defaultModel?: string;
generationConfig?: GenerationConfig;
safetySettings?: SafetySetting[];
systemInstruction?: Content | string;
}
// Re-export other types for backwards compatibility
export type {
Content,
Tool,
ToolConfig,
GenerationConfig,
SafetySetting,
Part,
FunctionCall,
CacheId,
} from "./gemini/GeminiTypes.js";
```
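Finally, a hedged usage sketch of the `GeminiService` facade above. It is not part of the source: it assumes `ConfigurationManager` can resolve a Gemini API key from the environment (the constructor throws otherwise) and uses an illustrative import path.
```typescript
// Hedged usage sketch of the GeminiService facade; see assumptions above.
import { GeminiService } from "./src/services/GeminiService.js";

async function demo(): Promise<void> {
  const gemini = new GeminiService();

  // One-shot generation; model selection is delegated to ModelSelectionService
  // unless a modelName is supplied explicitly.
  const summary = await gemini.generateContent({
    prompt: "Summarize the benefits of streaming APIs in two sentences.",
    preferSpeed: true,
  });
  console.log(summary);

  // Stateful chat: start a session, then send messages against its ID.
  const sessionId = gemini.startChatSession({
    systemInstruction: "You are a concise assistant.",
  });
  const reply = await gemini.sendMessageToSession({
    sessionId,
    message: "What does the routeMessage method do?",
  });
  console.log(reply);
}
```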