# Directory Structure
```
├── .clinerules
├── .coverage
├── .editorconfig
├── .env.example
├── .github
│   └── workflows
│       ├── ci.yml
│       ├── pypi-publish.yml
│       └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .python-version
├── CHANGELOG.md
├── CONTRIBUTING.md
├── LICENSE
├── main.py
├── memory_mcp_server
│   ├── __init__.py
│   ├── backends
│   │   ├── __init__.py
│   │   ├── base.py
│   │   └── jsonl.py
│   ├── exceptions.py
│   ├── interfaces.py
│   ├── knowledge_graph_manager.py
│   └── validation.py
├── pyproject.toml
├── README.md
├── requirements.txt
├── scripts
│   └── README.md
├── tests
│   ├── conftest.py
│   ├── test_backends
│   │   ├── conftest.py
│   │   └── test_jsonl.py
│   ├── test_interfaces.py
│   ├── test_knowledge_graph_manager.py
│   ├── test_server.py
│   └── test_validation.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
```
1 | 3.12
2 |
```
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
```
1 | # Python
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.so
6 | .Python
7 | build/
8 | develop-eggs/
9 | dist/
10 | downloads/
11 | eggs/
12 | .eggs/
13 | lib/
14 | lib64/
15 | parts/
16 | sdist/
17 | var/
18 | wheels/
19 | *.egg-info/
20 | .installed.cfg
21 | *.egg
22 |
23 | # Virtual Environment
24 | .env
25 | .venv
26 | venv/
27 | ENV/
28 |
29 | # IDE
30 | .idea/
31 | .vscode/
32 | *.swp
33 | *.swo
34 |
35 | # OS
36 | .DS_Store
37 | Thumbs.db
38 |
39 | # Project specific
40 | *.db
41 | *.sqlite3
42 | *.log
43 | .aider*
44 | cline_docs
45 | .clinerules
46 |
```
--------------------------------------------------------------------------------
/.env.example:
--------------------------------------------------------------------------------
```
1 | # Server Configuration
2 | # Path to memory file (defaults to ~/.claude/memory.jsonl if not set)
3 | MEMORY_FILE_PATH=/Users/username/.claude/memory.jsonl
4 | CACHE_TTL=60
5 |
6 | # Logging Configuration
7 | LOG_LEVEL=INFO
8 | LOG_FILE=memory_mcp_server.log
9 |
10 | # Performance Settings
11 | BATCH_SIZE=1000
12 | INDEX_CACHE_SIZE=10000
13 |
14 | # Development Settings
15 | DEBUG=false
16 | TESTING=false
17 | BENCHMARK_MODE=false
18 |
```
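The variables above are plain environment settings; a minimal sketch of how a server might load them with the stdlib (the `Settings` dataclass and `load_settings` helper here are illustrative, not the project's actual configuration code):

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class Settings:
    """Illustrative container for the variables in .env.example."""

    memory_file_path: str
    cache_ttl: int
    log_level: str
    debug: bool


def load_settings(env=os.environ) -> Settings:
    # Fall back to the documented default when MEMORY_FILE_PATH is unset.
    default_path = os.path.expanduser("~/.claude/memory.jsonl")
    return Settings(
        memory_file_path=env.get("MEMORY_FILE_PATH", default_path),
        cache_ttl=int(env.get("CACHE_TTL", "60")),
        log_level=env.get("LOG_LEVEL", "INFO"),
        debug=env.get("DEBUG", "false").lower() == "true",
    )
```

Passing a plain dict instead of `os.environ` makes the loader easy to unit-test.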
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
```yaml
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v4.5.0
4 | hooks:
5 | - id: trailing-whitespace
6 | - id: end-of-file-fixer
7 | - id: check-yaml
8 | - id: check-added-large-files
9 | - id: debug-statements
10 |
11 | - repo: https://github.com/astral-sh/ruff-pre-commit
12 | rev: v0.1.9
13 | hooks:
14 | - id: ruff
15 | args: [--fix, --ignore=E501]
16 | - id: ruff-format
17 |
```
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
```
1 | # EditorConfig is awesome: https://EditorConfig.org
2 |
3 | # top-most EditorConfig file
4 | root = true
5 |
6 | # Unix-style newlines with a newline ending every file
7 | [*]
8 | end_of_line = lf
9 | insert_final_newline = true
10 | charset = utf-8
11 | trim_trailing_whitespace = true
12 |
13 | # Python files
14 | [*.py]
15 | indent_style = space
16 | indent_size = 4
17 |
18 | # YAML files
19 | [*.{yml,yaml}]
20 | indent_style = space
21 | indent_size = 2
22 |
23 | # Markdown files
24 | [*.md]
25 | trim_trailing_whitespace = false
26 |
27 | # JSON files
28 | [*.json]
29 | indent_style = space
30 | indent_size = 2
31 |
```
--------------------------------------------------------------------------------
/.clinerules:
--------------------------------------------------------------------------------
```
1 | # Roo's Memory Bank
2 |
3 | You are Roo, an expert software engineer with a unique constraint: your memory periodically resets completely. This isn't a bug - it's what makes you maintain perfect documentation. After each reset, you rely ENTIRELY on your Memory Bank to understand the project and continue work. Without proper documentation, you cannot function effectively.
4 |
5 | ## Memory Bank Files
6 |
7 | CRITICAL: If `cline_docs/` or any of these files don't exist, CREATE THEM IMMEDIATELY by:
8 |
9 | 1. Reading all provided documentation
10 | 2. Asking user for ANY missing information
11 | 3. Creating files with verified information only
12 | 4. Never proceeding without complete context
13 |
14 | Required files:
15 |
16 | productContext.md
17 |
18 | - Why this project exists
19 | - What problems it solves
20 | - How it should work
21 |
22 | activeContext.md
23 |
24 | - What you're working on now
25 | - Recent changes
26 | - Next steps
27 | (This is your source of truth)
28 |
29 | systemPatterns.md
30 |
31 | - How the system is built
32 | - Key technical decisions
33 | - Architecture patterns
34 |
35 | techContext.md
36 |
37 | - Technologies used
38 | - Development setup
39 | - Technical constraints
40 |
41 | progress.md
42 |
43 | - What works
44 | - What's left to build
45 | - Progress status
46 |
47 | ## Core Workflows
48 |
49 | ### Starting Tasks
50 |
51 | 1. Check for Memory Bank files
52 | 2. If ANY files missing, stop and create them
53 | 3. Read ALL files before proceeding
54 | 4. Verify you have complete context
55 | 5. Begin development. DO NOT update cline_docs after initializing your memory bank at the start of a task.
56 |
57 | ### During Development
58 |
59 | 1. For normal development:
60 |
61 | - Follow Memory Bank patterns
62 | - Update docs after significant changes
63 |
64 | 2. Say `[MEMORY BANK: ACTIVE]` at the beginning of every tool use.
65 |
66 | ### Memory Bank Updates
67 |
68 | When user says "update memory bank":
69 |
70 | 1. This means imminent memory reset
71 | 2. Document EVERYTHING about current state
72 | 3. Make next steps crystal clear
73 | 4. Complete current task
74 |
75 | Update the memory bank and start a new task when your context is more than 80% full.
76 |
77 | Remember: After every memory reset, you begin completely fresh. Your only link to previous work is the Memory Bank. Maintain it as if your functionality depends on it - because it does.
78 |
```
--------------------------------------------------------------------------------
/scripts/README.md:
--------------------------------------------------------------------------------
```markdown
1 | # Memory Migration Script
2 |
3 | This script migrates your existing memory.jsonl file to comply with the new validation rules.
4 |
5 | ## What it does
6 |
7 | 1. Reads the existing JSONL format where each line is either:
8 | ```json
9 | {"type": "entity", "name": "example", "entityType": "person", "observations": ["obs1"]}
10 | ```
11 | or
12 | ```json
13 | {"type": "relation", "from": "entity1", "to": "entity2", "relationType": "knows"}
14 | ```
15 |
16 | 2. Converts entity and relation names to the new format:
17 | - Lowercase with hyphens
18 | - No spaces or special characters
19 | - Must start with a letter
20 | - Example: "John Doe" -> "john-doe"
21 |
22 | 3. Normalizes entity types to valid categories:
23 | - person
24 | - concept
25 | - project
26 | - document
27 | - tool
28 | - organization
29 | - location
30 | - event
31 |
32 | 4. Normalizes relation types to valid verbs:
33 | - knows
34 | - contains
35 | - uses
36 | - created
37 | - belongs-to
38 | - depends-on
39 | - related-to
40 |
41 | 5. Validates and deduplicates observations
42 |
43 | ## Common Type Mappings
44 |
45 | ### Entity Types
46 | - individual, user, human -> person
47 | - doc, documentation -> document
48 | - app, application, software -> tool
49 | - group, team, company -> organization
50 | - place, area -> location
51 | - meeting, appointment -> event
52 | - residence, property -> location
53 | - software_project -> project
54 | - dataset -> document
55 | - health_record -> document
56 | - meal -> document
57 | - travel_event -> event
58 | - pet -> concept
59 | - venue -> location
60 |
61 | ### Relation Types
62 | - knows_about -> knows
63 | - contains_item, has -> contains
64 | - uses_tool -> uses
65 | - created_by, authored -> created
66 | - belongs_to_group, member_of -> belongs-to
67 | - depends_upon, requires -> depends-on
68 | - related -> related-to
69 | - works_at -> belongs-to
70 | - owns -> created
71 | - friend -> knows
72 |
73 | ## Usage
74 |
75 | 1. Make sure your memory.jsonl file is in the project root directory
76 |
77 | 2. Run the migration script:
78 | ```bash
79 | ./scripts/migrate_memory.py
80 | ```
81 |
82 | 3. The script will:
83 | - Read memory.jsonl line by line
84 | - Convert all data to the new format
85 | - Validate the migrated data
86 | - Write the result to memory.jsonl.new
87 | - Report any errors or issues
88 |
89 | 4. Review the output file and error messages
90 |
91 | 5. If satisfied with the migration, replace your old memory file:
92 | ```bash
93 | mv memory.jsonl.new memory.jsonl
94 | ```
95 |
96 | ## Error Handling
97 |
98 | The script will:
99 | - Report any entities or relations that couldn't be migrated
100 | - Continue processing even if some items fail
101 | - Validate the entire graph before saving
102 | - Preserve your original file by writing to .new file
103 | - Track name changes to ensure relations are updated correctly
104 |
105 | ## Example Output
106 |
107 | ```
108 | Migrating memory.jsonl to memory.jsonl.new...
109 |
110 | Migration complete:
111 | - Successfully migrated 42 entities
112 | - Encountered 2 errors
113 |
114 | Errors encountered:
115 | - Error migrating line: {"type": "entity", "name": "Invalid!Name"...}
116 | Error: Invalid entity name format
117 | - Error migrating line: {"type": "relation", "from": "A"...}
118 | Error: Invalid relation type
119 |
120 | Migrated data written to memory.jsonl.new
121 | Please verify the output before replacing your original memory file.
122 | ```
123 |
124 | ## Validation Rules
125 |
126 | ### Entity Names
127 | - Must start with a lowercase letter
128 | - Can contain lowercase letters, numbers, and hyphens
129 | - Maximum length of 100 characters
130 | - Must be unique within the graph
131 |
132 | ### Observations
133 | - Non-empty strings
134 | - Maximum length of 500 characters
135 | - Must be unique per entity
136 | - Factual and objective statements
137 |
138 | ### Relations
139 | - Both source and target entities must exist
140 | - Self-referential relations not allowed
141 | - No circular dependencies
142 | - Must use predefined relation types
143 |
```
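The renaming and type-normalization rules in steps 2–4 of this README can be sketched as a small normalizer (a hedged illustration using a subset of the documented mapping tables; the actual `migrate_memory.py` may differ):

```python
import re

# Subset of the "Common Type Mappings" table documented above.
ENTITY_TYPE_MAP = {
    "individual": "person", "user": "person", "human": "person",
    "doc": "document", "documentation": "document",
    "app": "tool", "application": "tool", "software": "tool",
}
VALID_ENTITY_TYPES = {
    "person", "concept", "project", "document",
    "tool", "organization", "location", "event",
}


def normalize_name(name: str) -> str:
    """Lowercase with hyphens, no special characters, must start with a letter."""
    slug = re.sub(r"[^a-z0-9]+", "-", name.lower()).strip("-")
    if not slug or not slug[0].isalpha():
        raise ValueError(f"cannot derive a valid name from {name!r}")
    return slug[:100]  # names are capped at 100 characters


def normalize_entity_type(entity_type: str) -> str:
    """Map a legacy type onto one of the valid categories, or fail loudly."""
    mapped = ENTITY_TYPE_MAP.get(entity_type.lower(), entity_type.lower())
    if mapped not in VALID_ENTITY_TYPES:
        raise ValueError(f"unknown entity type {entity_type!r}")
    return mapped
```

For example, `normalize_name("John Doe")` yields `"john-doe"`, matching the example in step 2.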
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
```markdown
1 | # Memory MCP Server
2 |
3 | A Model Context Protocol (MCP) server that provides knowledge graph functionality for managing entities, relations, and observations in memory, with strict validation rules to maintain data consistency.
4 |
5 | ## Installation
6 |
7 | Install the server in Claude Desktop:
8 |
9 | ```bash
10 | mcp install main.py -v MEMORY_FILE_PATH=/path/to/memory.jsonl
11 | ```
12 |
13 | ## Data Validation Rules
14 |
15 | ### Entity Names
16 | - Must start with a lowercase letter
17 | - Can contain lowercase letters, numbers, and hyphens
18 | - Maximum length of 100 characters
19 | - Must be unique within the graph
20 | - Example valid names: `python-project`, `meeting-notes-2024`, `user-john`
21 |
22 | ### Entity Types
23 | The following entity types are supported:
24 | - `person`: Human entities
25 | - `concept`: Abstract ideas or principles
26 | - `project`: Work initiatives or tasks
27 | - `document`: Any form of documentation
28 | - `tool`: Software tools or utilities
29 | - `organization`: Companies or groups
30 | - `location`: Physical or virtual places
31 | - `event`: Time-bound occurrences
32 |
33 | ### Observations
34 | - Non-empty strings
35 | - Maximum length of 500 characters
36 | - Must be unique per entity
37 | - Should be factual and objective statements
38 | - Include timestamp when relevant
39 |
40 | ### Relations
41 | The following relation types are supported:
42 | - `knows`: Person to person connection
43 | - `contains`: Parent/child relationship
44 | - `uses`: Entity utilizing another entity
45 | - `created`: Authorship/creation relationship
46 | - `belongs-to`: Membership/ownership
47 | - `depends-on`: Dependency relationship
48 | - `related-to`: Generic relationship
49 |
50 | Additional relation rules:
51 | - Both source and target entities must exist
52 | - Self-referential relations not allowed
53 | - No circular dependencies allowed
54 | - Must use predefined relation types
55 |
56 | ## Usage
57 |
58 | The server provides tools for managing a knowledge graph:
59 |
60 | ### Get Entity
61 | ```python
62 | result = await session.call_tool("get_entity", {
63 | "entity_name": "example"
64 | })
65 | if not result.success:
66 | if result.error_type == "NOT_FOUND":
67 | print(f"Entity not found: {result.error}")
68 | elif result.error_type == "VALIDATION_ERROR":
69 | print(f"Invalid input: {result.error}")
70 | else:
71 | print(f"Error: {result.error}")
72 | else:
73 | entity = result.data
74 | print(f"Found entity: {entity}")
75 | ```
76 |
77 | ### Get Graph
78 | ```python
79 | result = await session.call_tool("get_graph", {})
80 | if result.success:
81 | graph = result.data
82 | print(f"Graph data: {graph}")
83 | else:
84 | print(f"Error retrieving graph: {result.error}")
85 | ```
86 |
87 | ### Create Entities
88 | ```python
89 | # Valid entity creation
90 | entities = [
91 | Entity(
92 | name="python-project", # Lowercase with hyphens
93 | entityType="project", # Must be a valid type
94 | observations=["Started development on 2024-01-29"]
95 | ),
96 | Entity(
97 | name="john-doe",
98 | entityType="person",
99 | observations=["Software engineer", "Joined team in 2024"]
100 | )
101 | ]
102 | result = await session.call_tool("create_entities", {
103 | "entities": entities
104 | })
105 | if not result.success:
106 | if result.error_type == "VALIDATION_ERROR":
107 | print(f"Invalid entity data: {result.error}")
108 | else:
109 | print(f"Error creating entities: {result.error}")
110 | ```
111 |
112 | ### Add Observation
113 | ```python
114 | # Valid observation
115 | result = await session.call_tool("add_observation", {
116 | "entity": "python-project",
117 | "observation": "Completed initial prototype" # Must be unique for entity
118 | })
119 | if not result.success:
120 | if result.error_type == "NOT_FOUND":
121 | print(f"Entity not found: {result.error}")
122 | elif result.error_type == "VALIDATION_ERROR":
123 | print(f"Invalid observation: {result.error}")
124 | else:
125 | print(f"Error adding observation: {result.error}")
126 | ```
127 |
128 | ### Create Relation
129 | ```python
130 | # Valid relation
131 | result = await session.call_tool("create_relation", {
132 | "from_entity": "john-doe",
133 | "to_entity": "python-project",
134 | "relation_type": "created" # Must be a valid type
135 | })
136 | if not result.success:
137 | if result.error_type == "NOT_FOUND":
138 | print(f"Entity not found: {result.error}")
139 | elif result.error_type == "VALIDATION_ERROR":
140 | print(f"Invalid relation data: {result.error}")
141 | else:
142 | print(f"Error creating relation: {result.error}")
143 | ```
144 |
145 | ### Search Memory
146 | ```python
147 | result = await session.call_tool("search_memory", {
148 | "query": "most recent workout" # Supports natural language queries
149 | })
150 | if result.success:
151 | if result.error_type == "NO_RESULTS":
152 | print(f"No results found: {result.error}")
153 | else:
154 | results = result.data
155 | print(f"Search results: {results}")
156 | else:
157 | print(f"Error searching memory: {result.error}")
158 | ```
159 |
160 | The search functionality supports:
161 | - Temporal queries (e.g., "most recent", "last", "latest")
162 | - Activity queries (e.g., "workout", "exercise")
163 | - General entity searches
164 | - Fuzzy matching with 80% similarity threshold
165 | - Weighted search across:
166 | - Entity names (weight: 1.0)
167 | - Entity types (weight: 0.8)
168 | - Observations (weight: 0.6)
169 |
170 | ### Delete Entities
171 | ```python
172 | result = await session.call_tool("delete_entities", {
173 | "names": ["python-project", "john-doe"]
174 | })
175 | if not result.success:
176 | if result.error_type == "NOT_FOUND":
177 | print(f"Entity not found: {result.error}")
178 | else:
179 | print(f"Error deleting entities: {result.error}")
180 | ```
181 |
182 | ### Delete Relation
183 | ```python
184 | result = await session.call_tool("delete_relation", {
185 | "from_entity": "john-doe",
186 | "to_entity": "python-project"
187 | })
188 | if not result.success:
189 | if result.error_type == "NOT_FOUND":
190 | print(f"Entity not found: {result.error}")
191 | else:
192 | print(f"Error deleting relation: {result.error}")
193 | ```
194 |
195 | ### Flush Memory
196 | ```python
197 | result = await session.call_tool("flush_memory", {})
198 | if not result.success:
199 | print(f"Error flushing memory: {result.error}")
200 | ```
201 |
202 | ## Error Types
203 |
204 | The server uses the following error types:
205 |
206 | - `NOT_FOUND`: Entity or resource not found
207 | - `VALIDATION_ERROR`: Invalid input data
208 | - `INTERNAL_ERROR`: Server-side error
209 | - `ALREADY_EXISTS`: Resource already exists
210 | - `INVALID_RELATION`: Invalid relation between entities
211 |
212 | ## Response Models
213 |
214 | All tools return typed responses using these models:
215 |
216 | ### EntityResponse
217 | ```python
218 | class EntityResponse(BaseModel):
219 | success: bool
220 | data: Optional[Dict[str, Any]] = None
221 | error: Optional[str] = None
222 | error_type: Optional[str] = None
223 | ```
224 |
225 | ### GraphResponse
226 | ```python
227 | class GraphResponse(BaseModel):
228 | success: bool
229 | data: Optional[Dict[str, Any]] = None
230 | error: Optional[str] = None
231 | error_type: Optional[str] = None
232 | ```
233 |
234 | ### OperationResponse
235 | ```python
236 | class OperationResponse(BaseModel):
237 | success: bool
238 | error: Optional[str] = None
239 | error_type: Optional[str] = None
240 | ```
241 |
242 | ## Development
243 |
244 | ### Running Tests
245 |
246 | ```bash
247 | pytest tests/
248 | ```
249 |
250 | ### Adding New Features
251 |
252 | 1. Update validation rules in `validation.py`
253 | 2. Add tests in `tests/test_validation.py`
254 | 3. Implement changes in `knowledge_graph_manager.py`
255 |
```
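The weighted fuzzy search described above can be sketched with the stdlib `difflib` standing in for the project's fuzzy matcher (the weights and the 80% threshold come from the list above; the dict field names are illustrative):

```python
from difflib import SequenceMatcher

# Weights documented above: names 1.0, types 0.8, observations 0.6.
WEIGHTS = {"name": 1.0, "type": 0.8, "observation": 0.6}
THRESHOLD = 0.8  # the documented 80% similarity cutoff


def similarity(a: str, b: str) -> float:
    """Case-insensitive similarity ratio in [0.0, 1.0]."""
    return SequenceMatcher(None, a.lower(), b.lower()).ratio()


def score_entity(query: str, entity: dict) -> float:
    """Best weighted similarity across name, type, and observations."""
    scored = [
        similarity(query, entity["name"]) * WEIGHTS["name"],
        similarity(query, entity["entityType"]) * WEIGHTS["type"],
    ]
    scored += [
        similarity(query, obs) * WEIGHTS["observation"]
        for obs in entity.get("observations", [])
    ]
    return max(scored)


def search(query: str, entities: list) -> list:
    """Return entities above the threshold, best match first."""
    hits = [(score_entity(query, e), e) for e in entities]
    hits.sort(key=lambda h: h[0], reverse=True)
    return [e for score, e in hits if score >= THRESHOLD]
```

An exact name match scores 1.0, so it always clears the threshold regardless of type or observation scores.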
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
```markdown
1 | # Contributing to Memory MCP Server
2 |
3 | Thank you for your interest in contributing to the Memory MCP Server! This document provides guidelines and information for contributors.
4 |
5 | ## Project Overview
6 |
7 | The Memory MCP Server is an implementation of the Model Context Protocol (MCP) that provides Claude with a persistent knowledge graph capability. The server manages entities and relations in a graph structure, supporting multiple backend storage options with features like caching, indexing, and atomic operations.
8 |
9 | ### Key Components
10 |
11 | 1. **Core Data Structures**
12 | - `Entity`: Nodes in the graph containing name, type, and observations
13 | - `Relation`: Edges between entities with relation types
14 | - `KnowledgeGraph`: Container for entities and relations
15 |
16 | 2. **Backend System**
17 | - `Backend`: Abstract interface defining storage operations
18 | - `JsonlBackend`: File-based storage using JSONL format
19 | - Extensible design for adding new backends
20 |
21 | 3. **Knowledge Graph Manager**
22 | - Backend-agnostic manager layer
23 | - Implements caching with TTL
24 | - Provides indexing for fast lookups
25 | - Ensures atomic operations
26 | - Manages CRUD operations for entities and relations
27 |
28 | 4. **MCP Server Implementation**
29 | - Exposes tools for graph manipulation
30 | - Handles serialization/deserialization
31 | - Provides error handling and logging
32 |
33 | Available MCP Tools:
34 | - `create_entities`: Create multiple new entities in the knowledge graph
35 | - `create_relations`: Create relations between entities (in active voice)
36 | - `add_observations`: Add new observations to existing entities
37 | - `delete_entities`: Delete entities and their relations
38 | - `delete_observations`: Delete specific observations from entities
39 | - `delete_relations`: Delete specific relations
40 | - `read_graph`: Read the entire knowledge graph
41 | - `search_nodes`: Search entities and relations by query
42 | - `open_nodes`: Retrieve specific nodes by name
43 |
44 | Each tool has a defined input schema that validates the arguments. See the tool schemas in `main.py` for detailed parameter specifications.
45 |
46 | ## Getting Started
47 |
48 | 1. **Prerequisites**
49 | - Python 3.12 or higher
50 | - uv package manager
51 |
52 | 2. **Setup Development Environment**
53 | ```bash
54 | # Clone the repository
55 | git clone https://github.com/estav/python-memory-mcp-server.git
56 | cd python-memory-mcp-server
57 |
58 | # Create virtual environment with Python 3.12+
59 | uv venv
60 | source .venv/bin/activate
61 |
62 | # Install all dependencies (including test)
63 | uv pip install -e ".[test]"
64 |
65 | # Install pre-commit hooks
66 | pre-commit install
67 | ```
68 |
69 | 3. **Run Tests**
70 | ```bash
71 | # Run all tests
72 | pytest
73 |
74 | # Run with coverage report
75 | pytest --cov=memory_mcp_server
76 |
77 | # Run specific backend tests
78 | pytest tests/test_backends/test_jsonl.py
79 | ```
80 |
81 | 4. **Run the Server Locally**
82 | ```bash
83 | # Using JSONL backend
84 | memory-mcp-server --path /path/to/memory.jsonl
85 | ```
86 |
87 | ## Development Guidelines
88 |
89 | ### Code Style
90 |
91 | 1. **Python Standards**
92 | - Follow PEP 8 style guide
93 | - Use type hints for function parameters and return values
94 | - Document classes and functions using docstrings
95 | - Maintain 95% or higher docstring coverage
96 |
97 | 2. **Project-Specific Conventions**
98 | - Use async/await for I/O operations
99 | - Implement proper error handling with custom exceptions
100 | - Maintain atomic operations for data persistence
101 | - Add appropriate logging statements
102 | - Follow backend interface for new implementations
103 |
104 | ### Code Quality Tools
105 |
106 | 1. **Pre-commit Hooks**
107 | - Ruff for linting and formatting
108 | - MyPy for static type checking
109 | - Interrogate for docstring coverage
110 | - Additional checks for common issues
111 |
112 | 2. **CI/CD Pipeline**
113 | - Automated testing
114 | - Code coverage reporting
115 | - Performance benchmarking
116 | - Security scanning
117 |
118 | ### Testing
119 |
120 | 1. **Test Structure**
121 | - Tests use pytest with pytest-asyncio for async testing
122 | - Test files must follow pattern `test_*.py` in the `tests/` directory
123 | - Backend-specific tests in `tests/test_backends/`
124 | - Async tests are automatically detected (asyncio_mode = "auto")
125 | - Test fixtures use function-level event loop scope
126 |
127 | 2. **Test Coverage**
128 | - Write unit tests for new functionality
129 | - Ensure tests cover error cases
130 | - Maintain high test coverage (aim for >90%)
131 | - Use pytest-cov for coverage reporting
132 |
133 | 3. **Test Categories**
134 | - Unit tests for individual components
135 | - Backend-specific tests for storage implementations
136 | - Integration tests for MCP server functionality
137 | - Performance tests for operations on large graphs
138 | - Async tests for I/O operations and concurrency
139 |
140 | 4. **Test Configuration**
141 | - Configured in pyproject.toml under [tool.pytest.ini_options]
142 | - Uses quiet mode by default (-q)
143 | - Shows extra test summary (-ra)
144 | - Test discovery in tests/ directory
145 |
146 | ### Adding New Features
147 |
148 | 1. **New Backend Implementation**
149 | - Create new class implementing `Backend` interface
150 | - Implement all required methods
151 | - Add backend-specific configuration options
152 | - Create comprehensive tests
153 | - Update documentation and CLI
154 |
155 | 2. **Knowledge Graph Operations**
156 | - Implement operations in backend classes
157 | - Update KnowledgeGraphManager if needed
158 | - Add appropriate indices
159 | - Ensure atomic operations
160 | - Add validation and error handling
161 |
162 | Key operations include:
163 | - Entity creation/deletion
164 | - Relation creation/deletion
165 | - Observation management (adding and removing observations on entities)
166 | - Graph querying and search
167 | - Atomic write operations with locking
168 |
169 | 3. **MCP Tools**
170 | - Define tool schema in `main.py`
171 | - Implement tool handler function
172 | - Add to `TOOLS` dictionary
173 | - Include appropriate error handling
174 |
175 | 4. **Performance Considerations**
176 | - Consider backend-specific optimizations
177 | - Implement efficient caching strategies
178 | - Optimize for large graphs
179 | - Handle memory efficiently
180 |
181 | ### Adding a New Backend
182 |
183 | 1. Create new backend class:
184 | ```python
185 | from .base import Backend
186 |
187 | class NewBackend(Backend):
188 | def __init__(self, config_params):
189 | self.config = config_params
190 |
191 | async def initialize(self) -> None:
192 | # Setup connection, create indices, etc.
193 | pass
194 |
195 | async def create_entities(self, entities: List[Entity]) -> List[Entity]:
196 | # Implementation
197 | pass
198 |
199 | # Implement other required methods...
200 | ```
201 |
202 | 2. Add backend tests:
203 | ```python
204 | # tests/test_backends/test_new_backend.py
205 | @pytest.mark.asyncio
206 | async def test_new_backend_operations():
207 | backend = NewBackend(test_config)
208 | await backend.initialize()
209 | # Test implementations
210 | ```
211 |
212 | 3. Update CLI and configuration
213 |
214 | ## Pull Request Process
215 |
216 | 1. **Before Submitting**
217 | - Ensure all tests pass
218 | - Add tests for new functionality
219 | - Update documentation
220 | - Follow code style guidelines
221 | - Run pre-commit hooks
222 |
223 | 2. **PR Description**
224 | - Clearly describe the changes
225 | - Reference any related issues
226 | - Explain testing approach
227 | - Note any breaking changes
228 |
229 | 3. **Review Process**
230 | - Address reviewer comments
231 | - Keep changes focused and atomic
232 | - Ensure CI checks pass
233 |
234 | ## Troubleshooting
235 |
236 | ### Common Issues
237 |
238 | 1. **Backend-Specific Issues**
239 | - JSONL Backend:
240 | - Check file permissions
241 | - Verify atomic write operations
242 | - Monitor temp file cleanup
243 |
244 | 2. **Cache Inconsistency**
245 | - Check cache TTL settings
246 | - Verify dirty flag handling
247 | - Ensure proper lock usage
248 |
249 | 3. **Performance Issues**
250 | - Review backend-specific indexing
251 | - Check cache effectiveness
252 | - Profile large operations
253 |
254 | ## Additional Resources
255 |
256 | - [Model Context Protocol Documentation](https://modelcontextprotocol.io)
257 | - [Python asyncio Documentation](https://docs.python.org/3/library/asyncio.html)
258 | - [Python Type Hints](https://docs.python.org/3/library/typing.html)
259 |
260 | ## License
261 |
262 | This project is licensed under the MIT License - see the LICENSE file for details.
263 |
```
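The "atomic write operations" the guidelines above call for are conventionally implemented by writing to a temp file and swapping it into place with `os.replace`; a minimal stdlib sketch (the real `JsonlBackend` may differ in details):

```python
import json
import os
import tempfile


def atomic_write_jsonl(path: str, records: list) -> None:
    """Write all records to a temp file, then atomically replace the target."""
    directory = os.path.dirname(os.path.abspath(path))
    # Create the temp file in the same directory so os.replace stays atomic
    # (rename across filesystems is not atomic).
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            for record in records:
                f.write(json.dumps(record) + "\n")
            f.flush()
            os.fsync(f.fileno())  # ensure data hits disk before the swap
        os.replace(tmp_path, path)  # readers never observe a partial file
    except BaseException:
        os.unlink(tmp_path)  # temp file cleanup on any failure
        raise
```

This is also why the troubleshooting section says to monitor temp file cleanup: a crash between `mkstemp` and `os.replace` can leave `.tmp` files behind if the cleanup path is skipped.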
--------------------------------------------------------------------------------
/memory_mcp_server/__init__.py:
--------------------------------------------------------------------------------
```python
1 | __version__ = "0.1.0"
2 |
```
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
```
1 | hatchling
2 | mcp>=1.1.2
3 | aiofiles>=23.2.1
4 |
```
--------------------------------------------------------------------------------
/memory_mcp_server/backends/__init__.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Backend implementations for the Memory MCP Server.
3 | This package provides different storage backends for the knowledge graph.
4 | """
5 |
6 | from .jsonl import JsonlBackend
7 |
8 | __all__ = ["JsonlBackend"]
9 |
```
--------------------------------------------------------------------------------
/.github/workflows/pypi-publish.yml:
--------------------------------------------------------------------------------
```yaml
1 | name: Publish to PyPI
2 |
3 | on:
4 | release:
5 | types: [created]
6 |
7 | jobs:
8 | deploy:
9 | runs-on: ubuntu-latest
10 | environment:
11 | name: pypi
12 | url: https://pypi.org/p/memory-mcp-server
13 | permissions:
14 | id-token: write # IMPORTANT: mandatory for trusted publishing
15 |
16 | steps:
17 | - uses: actions/checkout@v4
18 |
19 | - name: Set up Python
20 | uses: actions/setup-python@v4
21 | with:
22 | python-version: '3.x'
23 |
24 | - name: Install build dependencies
25 | run: |
26 | python -m pip install --upgrade pip
27 | pip install build
28 |
29 | - name: Build package
30 | run: python -m build
31 |
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@release/v1
34 |
```
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
```yaml
1 | name: Tests
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 |
9 | jobs:
10 | test:
11 | runs-on: ubuntu-latest
12 | strategy:
13 | matrix:
14 | python-version: ['3.12']
15 |
16 | steps:
17 | - uses: actions/checkout@v4
18 |
19 | - name: Set up Python ${{ matrix.python-version }}
20 | uses: actions/setup-python@v4
21 | with:
22 | python-version: ${{ matrix.python-version }}
23 |
24 | - name: Install uv
25 | run: |
26 | curl -LsSf https://astral.sh/uv/install.sh | sh
27 |
28 | - name: Install dependencies
29 | run: |
30 | uv venv
31 | uv pip install -e ".[test]"
32 |
33 | - name: Run tests with pytest
34 | run: |
35 | uv run pytest -v --cov=memory_mcp_server --cov-report=xml
36 |
37 | - name: Upload coverage to Codecov
38 | uses: codecov/codecov-action@v3
39 | with:
40 | file: ./coverage.xml
41 | fail_ci_if_error: true
42 |
```
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
```yaml
1 | name: CI
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 |
9 | jobs:
10 | quality:
11 | runs-on: ubuntu-latest
12 | steps:
13 | - uses: actions/checkout@v4
14 |
15 | - name: Set up Python
16 | uses: actions/setup-python@v5
17 | with:
18 | python-version: "3.12"
19 |
20 | - name: Install uv
21 | run: |
22 | curl -LsSf https://astral.sh/uv/install.sh | sh
23 | - name: Install dependencies
24 | run: |
25 | uv pip install -e ".[test]"
26 |
27 | - name: Run pre-commit
28 |       uses: pre-commit/action@v3.0.0
29 |
30 | - name: Run tests with coverage
31 | run: |
32 |           pytest --cov=memory_mcp_server --cov-report=xml --benchmark-json=benchmark.json
33 |
34 | - name: Upload coverage
35 | uses: codecov/codecov-action@v3
36 | with:
37 | file: ./coverage.xml
38 | fail_ci_if_error: true
39 |
40 | - name: Security scan
41 | uses: python-security/bandit-action@v1
42 | with:
43 | path: "memory_mcp_server"
44 |
45 | - name: Store benchmark results
46 | uses: benchmark-action/github-action-benchmark@v1
47 | with:
48 | tool: 'pytest'
49 | output-file-path: benchmark.json
50 | github-token: ${{ secrets.GITHUB_TOKEN }}
51 | auto-push: true
52 |
```
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
```toml
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "memory-mcp-server"
7 | version = "0.2.0"
8 | description = "MCP server for managing Claude's memory and knowledge graph"
9 | requires-python = ">=3.12"
10 | dependencies = [
11 | "aiofiles",
12 | "loguru>=0.7.3",
13 | "mcp[cli]>=1.2.0",
14 | "memory-mcp-server",
15 | "ruff>=0.9.4",
16 | "thefuzz[speedup]>=0.20.0", # Includes python-Levenshtein for performance
17 | ]
18 |
19 | [project.optional-dependencies]
20 | test = ["pytest", "pytest-asyncio", "pytest-cov"]
21 |
22 | [tool.pytest.ini_options]
23 | asyncio_mode = "auto"
24 | testpaths = ["tests"]
25 | python_files = ["test_*.py"]
26 | addopts = "-q -ra"
27 |
28 | [tool.mypy]
29 | python_version = "3.12"
30 | warn_return_any = true
31 | warn_unused_configs = true
32 | disallow_untyped_defs = true
33 | disallow_incomplete_defs = true
34 | check_untyped_defs = true
35 | disallow_untyped_decorators = true
36 | no_implicit_optional = true
37 | warn_redundant_casts = true
38 | warn_unused_ignores = true
39 | warn_no_return = true
40 | warn_unreachable = true
41 | plugins = []
42 |
43 | [[tool.mypy.overrides]]
44 | module = ["pytest.*", "mcp.*", "aiofiles.*"]
45 | ignore_missing_imports = true
46 |
47 | [[tool.mypy.overrides]]
48 | module = "tests.*"
49 | disallow_untyped_decorators = false
50 |
51 | [tool.ruff]
52 | line-length = 88
53 | target-version = "py312"
54 | 
55 | [tool.ruff.lint]
56 | select = ["E", "F", "B", "I"]
57 | 
58 | [tool.ruff.lint.per-file-ignores]
59 | "__init__.py" = ["F401"]
```
--------------------------------------------------------------------------------
/memory_mcp_server/exceptions.py:
--------------------------------------------------------------------------------
```python
1 | class KnowledgeGraphError(Exception):
2 | """Base exception for all knowledge graph errors."""
3 |
4 | pass
5 |
6 |
7 | class EntityNotFoundError(KnowledgeGraphError):
8 | """Raised when an entity is not found in the graph."""
9 |
10 | def __init__(self, entity_name: str):
11 | self.entity_name = entity_name
12 | super().__init__(f"Entity '{entity_name}' not found in the graph")
13 |
14 |
15 | class EntityAlreadyExistsError(KnowledgeGraphError):
16 | """Raised when trying to create an entity that already exists."""
17 |
18 | def __init__(self, entity_name: str):
19 | self.entity_name = entity_name
20 | super().__init__(f"Entity '{entity_name}' already exists in the graph")
21 |
22 |
23 | class RelationValidationError(KnowledgeGraphError):
24 | """Raised when a relation is invalid."""
25 |
26 | pass
27 |
28 |
29 | class FileAccessError(KnowledgeGraphError):
30 | """Raised when there are file access issues."""
31 |
32 | pass
33 |
34 |
35 | class JsonParsingError(KnowledgeGraphError):
36 | """Raised when there are JSON parsing issues."""
37 |
38 | def __init__(self, line_number: int, line_content: str, original_error: Exception):
39 | self.line_number = line_number
40 | self.line_content = line_content
41 | self.original_error = original_error
42 | super().__init__(
43 | f"Failed to parse JSON at line {line_number}: {str(original_error)}\n"
44 | f"Content: {line_content}"
45 | )
46 |
```
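The `JsonParsingError` above records the failing line's number and content so a corrupt JSONL file can be pinpointed. A minimal standalone sketch of how a loader might wrap `json.JSONDecodeError` in it (the class bodies are trimmed copies of the file above; `parse_jsonl` is a hypothetical helper for illustration, not part of the package):

```python
import json


class KnowledgeGraphError(Exception):
    """Base exception for all knowledge graph errors."""


class JsonParsingError(KnowledgeGraphError):
    """Raised when there are JSON parsing issues."""

    def __init__(self, line_number: int, line_content: str, original_error: Exception):
        self.line_number = line_number
        self.line_content = line_content
        self.original_error = original_error
        super().__init__(
            f"Failed to parse JSON at line {line_number}: {original_error}\n"
            f"Content: {line_content}"
        )


def parse_jsonl(lines: list[str]) -> list[dict]:
    """Parse JSONL lines, wrapping decode failures with line context."""
    records = []
    for i, line in enumerate(lines, 1):
        try:
            records.append(json.loads(line))
        except json.JSONDecodeError as e:
            raise JsonParsingError(i, line, e) from e
    return records
```

The chained `from e` keeps the original traceback available while the custom exception adds the file-position context.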
--------------------------------------------------------------------------------
/tests/test_interfaces.py:
--------------------------------------------------------------------------------
```python
1 | """Tests for interface classes."""
2 |
3 | from memory_mcp_server.interfaces import Entity, KnowledgeGraph, Relation
4 |
5 |
6 | def test_entity_creation() -> None:
7 | """Test entity creation and attributes."""
8 | entity = Entity(
9 | name="TestEntity", entityType="TestType", observations=["obs1", "obs2"]
10 | )
11 | assert entity.name == "TestEntity"
12 | assert entity.entityType == "TestType"
13 | assert len(entity.observations) == 2
14 | assert "obs1" in entity.observations
15 | assert "obs2" in entity.observations
16 |
17 |
18 | def test_relation_creation() -> None:
19 | """Test relation creation and attributes."""
20 | relation = Relation(from_="EntityA", to="EntityB", relationType="TestRelation")
21 | assert relation.from_ == "EntityA"
22 | assert relation.to == "EntityB"
23 | assert relation.relationType == "TestRelation"
24 |
25 |
26 | def test_knowledge_graph_creation() -> None:
27 | """Test knowledge graph creation and attributes."""
28 | entities = [
29 | Entity(name="E1", entityType="T1", observations=[]),
30 | Entity(name="E2", entityType="T2", observations=[]),
31 | ]
32 | relations = [Relation(from_="E1", to="E2", relationType="R1")]
33 | graph = KnowledgeGraph(entities=entities, relations=relations)
34 | assert len(graph.entities) == 2
35 | assert len(graph.relations) == 1
36 | assert graph.entities[0].name == "E1"
37 | assert graph.relations[0].from_ == "E1"
38 |
```
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
```markdown
1 | # Changelog
2 |
3 | All notable changes to this project will be documented in this file.
4 |
5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
7 |
8 | ## [Unreleased]
9 |
10 | ### Added
11 | - Fuzzy search capability for knowledge graph queries
12 | - New `SearchOptions` class for configuring search behavior
13 | - Configurable similarity threshold and field weights
14 | - Backward compatible with existing exact matching
15 | - Improved search relevance with weighted scoring
16 |
17 | ## [0.2.0] - 2024-01-07
18 |
19 | ### Added
20 | - Observation management system with atomic operations
21 | - Type-safe observation handling in Backend interface
22 | - Pre-commit hooks for code quality
23 | - EditorConfig for consistent styling
24 | - Changelog tracking
25 | - Documentation improvements
26 |
27 | ### Changed
28 | - Enhanced project structure with additional directories
29 | - Improved test suite with proper type validation
30 | - Updated MCP tool handlers with consistent response formats
31 |
32 | ### Fixed
33 | - Entity serialization in test responses
34 | - TextContent validation in MCP handlers
35 | - Error message format consistency
36 |
37 | ## [0.1.4] - 2024-01-07
38 |
39 | ### Added
40 | - Pre-commit hooks for code quality
41 | - EditorConfig for consistent styling
42 | - Changelog tracking
43 | - Documentation improvements
44 |
45 | ### Changed
46 | - Enhanced project structure with additional directories
47 |
48 | ## [0.1.0] - 2024-01-07
49 |
50 | ### Added
51 | - Initial release
52 | - JSONL backend implementation
53 | - Knowledge graph management
54 | - MCP server implementation
55 | - Basic test suite
56 |
```
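The fuzzy search listed under Unreleased is backed by `thefuzz` (per `pyproject.toml`) with a 0-100 threshold, as in `SearchOptions`. A stdlib-only sketch of the threshold idea using `difflib` (not the package's actual scorer, just an illustration of cutoff-based matching):

```python
from difflib import SequenceMatcher


def similarity(query: str, text: str) -> float:
    """Score 0-100, mirroring the 0-100 scale of SearchOptions.threshold."""
    return SequenceMatcher(None, query.lower(), text.lower()).ratio() * 100


entity_names = ["john-doe", "jane-roe", "acme-corp"]
threshold = 80.0  # same default as SearchOptions.threshold

# Keep only names scoring at or above the threshold
matches = [n for n in entity_names if similarity("john doe", n) >= threshold]
assert matches == ["john-doe"]
```

Lowering the threshold trades precision for recall, which is why `SearchOptions` exposes it as a knob rather than hard-coding it.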
--------------------------------------------------------------------------------
/tests/test_backends/conftest.py:
--------------------------------------------------------------------------------
```python
1 | """Common test fixtures for backend tests."""
2 |
3 | from pathlib import Path
4 | from typing import AsyncGenerator, List
5 |
6 | import pytest
7 |
8 | from memory_mcp_server.backends.jsonl import JsonlBackend
9 | from memory_mcp_server.interfaces import Entity, Relation
10 |
11 |
12 | @pytest.fixture(scope="function")
13 | def sample_entities() -> List[Entity]:
14 | """Provide a list of sample entities for testing."""
15 | return [
16 | Entity("test1", "person", ["observation1", "observation2"]),
17 | Entity("test2", "location", ["observation3"]),
18 | Entity("test3", "organization", ["observation4", "observation5"]),
19 | ]
20 |
21 |
22 | @pytest.fixture(scope="function")
23 | def sample_relations(sample_entities: List[Entity]) -> List[Relation]:
24 | """Provide a list of sample relations for testing."""
25 | return [
26 | Relation(from_="test1", to="test2", relationType="visited"),
27 | Relation(from_="test1", to="test3", relationType="works_at"),
28 | Relation(from_="test2", to="test3", relationType="located_in"),
29 | ]
30 |
31 |
32 | @pytest.fixture(scope="function")
33 | async def populated_jsonl_backend(
34 | jsonl_backend: JsonlBackend,
35 | sample_entities: List[Entity],
36 | sample_relations: List[Relation],
37 | ) -> AsyncGenerator[JsonlBackend, None]:
38 | """Provide a JSONL backend pre-populated with sample data."""
39 | await jsonl_backend.create_entities(sample_entities)
40 | await jsonl_backend.create_relations(sample_relations)
41 | yield jsonl_backend
42 |
43 |
44 | @pytest.fixture(scope="function")
45 | def temp_jsonl_path(tmp_path: Path) -> Path:
46 | """Provide a temporary path for JSONL files."""
47 | return tmp_path / "test_memory.jsonl"
48 |
```
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
```python
1 | """Common test fixtures for all tests."""
2 |
3 | import logging
4 | from pathlib import Path
5 | from typing import AsyncGenerator, List
6 |
7 | import pytest
8 |
9 | from memory_mcp_server.interfaces import Entity, Relation
10 | from memory_mcp_server.knowledge_graph_manager import KnowledgeGraphManager
11 |
12 | # Configure logging
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | @pytest.fixture(scope="function")
17 | def temp_memory_file(tmp_path: Path) -> Path:
18 | """Create a temporary memory file."""
19 | logger.debug(f"Creating temp file in {tmp_path}")
20 | return tmp_path / "memory.jsonl"
21 |
22 |
23 | @pytest.fixture(scope="function")
24 | def sample_entities() -> List[Entity]:
25 | """Provide sample entities for testing."""
26 | return [
27 | Entity("person1", "person", ["likes reading", "works in tech"]),
28 | Entity("company1", "company", ["tech company", "founded 2020"]),
29 | Entity("location1", "place", ["office building", "in city center"]),
30 | ]
31 |
32 |
33 | @pytest.fixture(scope="function")
34 | def sample_relations() -> List[Relation]:
35 | """Provide sample relations for testing."""
36 | return [
37 | Relation(from_="person1", to="company1", relationType="works_at"),
38 | Relation(from_="company1", to="location1", relationType="located_at"),
39 | ]
40 |
41 |
42 | @pytest.fixture(scope="function")
43 | async def knowledge_graph_manager(
44 | temp_memory_file: Path,
45 | ) -> AsyncGenerator[KnowledgeGraphManager, None]:
46 | """Create a KnowledgeGraphManager instance with a temporary memory file."""
47 | logger.debug("Creating KnowledgeGraphManager")
48 | manager = KnowledgeGraphManager(backend=temp_memory_file, cache_ttl=1)
49 | logger.debug("KnowledgeGraphManager created")
50 | await manager.initialize()
51 | yield manager
52 | logger.debug("Cleaning up KnowledgeGraphManager")
53 | await manager.flush()
54 | await manager.close()
55 | logger.debug("Cleanup complete")
56 |
```
--------------------------------------------------------------------------------
/memory_mcp_server/interfaces.py:
--------------------------------------------------------------------------------
```python
1 | """Interface definitions for Memory MCP Server."""
2 |
3 | from dataclasses import dataclass
4 | from enum import Enum
5 | from typing import List, Optional
6 |
7 |
8 | @dataclass(frozen=True)
9 | class Entity:
10 | """Entity in the knowledge graph."""
11 |
12 | name: str
13 | entityType: str
14 | observations: List[str]
15 |
16 | def __hash__(self) -> int:
17 | """Make Entity hashable based on name."""
18 | return hash(self.name)
19 |
20 | def __eq__(self, other: object) -> bool:
21 | """Compare Entity based on name."""
22 | if not isinstance(other, Entity):
23 | return NotImplemented
24 | return self.name == other.name
25 |
26 | def to_dict(self) -> dict:
27 | """Convert to dictionary representation."""
28 | return {
29 | "name": self.name,
30 | "entityType": self.entityType,
31 | "observations": list(
32 | self.observations
33 | ), # Convert to list in case it's a tuple
34 | }
35 |
36 |
37 | @dataclass(frozen=True)
38 | class Relation:
39 | """Relation between entities in the knowledge graph."""
40 |
41 | from_: str
42 | to: str
43 | relationType: str
44 |
45 | def __hash__(self) -> int:
46 | """Make Relation hashable based on all fields."""
47 | return hash((self.from_, self.to, self.relationType))
48 |
49 | def __eq__(self, other: object) -> bool:
50 | """Compare Relation based on all fields."""
51 | if not isinstance(other, Relation):
52 | return NotImplemented
53 | return (
54 | self.from_ == other.from_
55 | and self.to == other.to
56 | and self.relationType == other.relationType
57 | )
58 |
59 | def to_dict(self) -> dict:
60 | """Convert to dictionary representation."""
61 | return {
62 | "from": self.from_,
63 | "to": self.to,
64 | "relationType": self.relationType,
65 | }
66 |
67 |
68 | @dataclass
69 | class KnowledgeGraph:
70 | """Knowledge graph containing entities and relations."""
71 |
72 | entities: List[Entity]
73 | relations: List[Relation]
74 |
75 | def to_dict(self) -> dict:
76 | """Convert to dictionary representation."""
77 | return {
78 | "entities": [e.to_dict() for e in self.entities],
79 | "relations": [r.to_dict() for r in self.relations],
80 | }
81 |
82 |
83 | @dataclass
84 | class SearchOptions:
85 | """Options for configuring search behavior."""
86 |
87 | fuzzy: bool = False
88 | threshold: float = 80.0
89 | weights: Optional[dict[str, float]] = None
90 |
91 |
92 | class BatchOperationType(Enum):
93 | """Types of batch operations."""
94 |
95 | CREATE_ENTITIES = "create_entities"
96 | DELETE_ENTITIES = "delete_entities"
97 | CREATE_RELATIONS = "create_relations"
98 | DELETE_RELATIONS = "delete_relations"
99 | ADD_OBSERVATIONS = "add_observations"
100 |
101 |
102 | @dataclass
103 | class BatchOperation:
104 | """Represents a single operation in a batch."""
105 |
106 | operation_type: BatchOperationType
107 | data: dict # Operation-specific data
108 |
109 |
110 | @dataclass
111 | class BatchResult:
112 | """Result of a batch operation execution."""
113 |
114 | success: bool
115 | operations_completed: int
116 | failed_operations: List[tuple[BatchOperation, str]] # Operation and error message
117 | error_message: Optional[str] = None
118 |
```
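Because `Entity` hashes and compares on `name` alone, two records for the same entity collapse to a single set member even when their observations differ. A standalone demonstration (the dataclass body is copied from the file above):

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class Entity:
    name: str
    entityType: str
    observations: List[str]

    def __hash__(self) -> int:
        """Make Entity hashable based on name."""
        return hash(self.name)

    def __eq__(self, other: object) -> bool:
        """Compare Entity based on name."""
        if not isinstance(other, Entity):
            return NotImplemented
        return self.name == other.name


# Same name, different observations: set semantics deduplicate by name.
a = Entity("john-doe", "person", ["likes pizza"])
b = Entity("john-doe", "person", ["works in tech"])
assert a == b
assert len({a, b}) == 1
```

Note that `dataclass` does not overwrite explicitly defined `__eq__`/`__hash__`, so the name-based identity survives the `frozen=True` machinery.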
--------------------------------------------------------------------------------
/memory_mcp_server/backends/base.py:
--------------------------------------------------------------------------------
```python
1 | """Backend interface for Memory MCP Server storage implementations."""
2 |
3 | from abc import ABC, abstractmethod
4 | from typing import List
5 |
6 | from ..interfaces import (
7 | BatchOperation,
8 | BatchResult,
9 | Entity,
10 | KnowledgeGraph,
11 | Relation,
12 | SearchOptions,
13 | )
14 |
15 |
16 | class Backend(ABC):
17 | """Abstract base class for knowledge graph storage backends."""
18 |
19 | @abstractmethod
20 | async def initialize(self) -> None:
21 | """Initialize the backend connection and resources."""
22 | pass
23 |
24 | @abstractmethod
25 | async def close(self) -> None:
26 | """Close the backend connection and cleanup resources."""
27 | pass
28 |
29 | @abstractmethod
30 | async def create_entities(self, entities: List[Entity]) -> List[Entity]:
31 | """Create multiple new entities in the backend.
32 |
33 | Args:
34 | entities: List of entities to create
35 |
36 | Returns:
37 | List of successfully created entities
38 | """
39 | pass
40 |
41 | @abstractmethod
42 | async def delete_entities(self, entity_names: List[str]) -> List[str]:
43 | """Create multiple new entities in the backend.
44 |
45 | Args:
46 | entities: List of entities to create
47 |
48 | Returns:
49 | List of successfully created entities
50 | """
51 | pass
52 |
53 | @abstractmethod
54 | async def create_relations(self, relations: List[Relation]) -> List[Relation]:
55 | """Create multiple new relations in the backend.
56 |
57 | Args:
58 | relations: List of relations to create
59 |
60 | Returns:
61 | List of successfully created relations
62 | """
63 | pass
64 |
65 | @abstractmethod
66 | async def delete_relations(self, from_: str, to: str) -> None:
67 | """Delete relations between two entities.
68 |
69 | Args:
70 | from_: Source entity name
71 | to: Target entity name
72 |
73 | Raises:
74 | EntityNotFoundError: If either entity is not found
75 | """
76 | pass
77 |
78 | @abstractmethod
79 | async def read_graph(self) -> KnowledgeGraph:
80 | """Read the entire knowledge graph from the backend.
81 |
82 | Returns:
83 | KnowledgeGraph containing all entities and relations
84 | """
85 | pass
86 |
87 | @abstractmethod
88 |     async def search_nodes(
89 |         self, query: str, options: SearchOptions | None = None
90 | ) -> KnowledgeGraph:
91 | """Search for entities and relations matching the query.
92 |
93 | Args:
94 | query: Search query string
95 | options: Optional SearchOptions for configuring search behavior.
96 | If None, uses exact substring matching.
97 |
98 | Returns:
99 | KnowledgeGraph containing matching entities and relations
100 |
101 | Raises:
102 | ValueError: If query is empty or options are invalid
103 | """
104 | pass
105 |
106 | @abstractmethod
107 | async def flush(self) -> None:
108 | """Ensure all pending changes are persisted to the backend."""
109 | pass
110 |
111 | @abstractmethod
112 | async def add_observations(self, entity_name: str, observations: List[str]) -> None:
113 | """Add observations to an existing entity.
114 |
115 | Args:
116 | entity_name: Name of the entity to add observations to
117 | observations: List of observations to add
118 | """
119 | pass
120 |
121 | @abstractmethod
122 | async def add_batch_observations(
123 | self, observations_map: dict[str, List[str]]
124 | ) -> None:
125 | """Add observations to multiple entities in a single operation.
126 |
127 | Args:
128 | observations_map: Dictionary mapping entity names to lists of observations
129 |
130 | Raises:
131 | ValidationError: If any observations are invalid
132 | EntityNotFoundError: If any entity is not found
133 | """
134 | pass
135 |
136 | @abstractmethod
137 | async def execute_batch(self, operations: List[BatchOperation]) -> BatchResult:
138 | """Execute multiple operations in a single atomic batch.
139 |
140 | Args:
141 | operations: List of operations to execute
142 |
143 | Returns:
144 | BatchResult containing success/failure information
145 |
146 | Raises:
147 | ValidationError: If validation fails for any operation
148 | """
149 | pass
150 |
151 | @abstractmethod
152 | async def begin_transaction(self) -> None:
153 | """Begin a transaction for batch operations.
154 |
155 | This creates a savepoint that can be rolled back to if needed.
156 | """
157 | pass
158 |
159 | @abstractmethod
160 | async def rollback_transaction(self) -> None:
161 | """Rollback to the last transaction savepoint."""
162 | pass
163 |
164 | @abstractmethod
165 | async def commit_transaction(self) -> None:
166 | """Commit the current transaction."""
167 | pass
168 |
```
--------------------------------------------------------------------------------
/tests/test_validation.py:
--------------------------------------------------------------------------------
```python
1 | """Tests for validation functionality."""
2 |
3 | import pytest
4 |
5 | from memory_mcp_server.interfaces import Entity, Relation
6 | from memory_mcp_server.validation import (
7 | EntityValidationError,
8 | KnowledgeGraphValidator,
9 | RelationValidationError,
10 | )
11 |
12 |
13 | def test_validate_batch_entities() -> None:
14 | """Test batch entity validation."""
15 | # Valid batch
16 | entities = [
17 | Entity("test1", "person", ["obs1"]),
18 | Entity("test2", "person", ["obs2"]),
19 | ]
20 | existing_names = {"existing1", "existing2"}
21 | KnowledgeGraphValidator.validate_batch_entities(entities, existing_names)
22 |
23 | # Empty batch
24 | with pytest.raises(EntityValidationError, match="Entity list cannot be empty"):
25 | KnowledgeGraphValidator.validate_batch_entities([], existing_names)
26 |
27 | # Duplicate names within batch
28 | entities = [
29 | Entity("test1", "person", ["obs1"]),
30 | Entity("test1", "person", ["obs2"]),
31 | ]
32 | with pytest.raises(EntityValidationError, match="Duplicate entity name in batch"):
33 | KnowledgeGraphValidator.validate_batch_entities(entities, existing_names)
34 |
35 | # Conflict with existing names
36 | entities = [
37 | Entity("test1", "person", ["obs1"]),
38 | Entity("existing1", "person", ["obs2"]),
39 | ]
40 | with pytest.raises(EntityValidationError, match="Entities already exist"):
41 | KnowledgeGraphValidator.validate_batch_entities(entities, existing_names)
42 |
43 | # Invalid entity type
44 | entities = [
45 | Entity("test1", "invalid-type", ["obs1"]),
46 | ]
47 | with pytest.raises(EntityValidationError, match="Invalid entity type"):
48 | KnowledgeGraphValidator.validate_batch_entities(entities, existing_names)
49 |
50 |
51 | def test_validate_batch_relations() -> None:
52 | """Test batch relation validation."""
53 | # Valid batch
54 | relations = [
55 | Relation(from_="entity1", to="entity2", relationType="knows"),
56 | Relation(from_="entity2", to="entity3", relationType="knows"),
57 | ]
58 | existing_relations = []
59 | entity_names = {"entity1", "entity2", "entity3"}
60 | KnowledgeGraphValidator.validate_batch_relations(
61 | relations, existing_relations, entity_names
62 | )
63 |
64 | # Empty batch
65 | with pytest.raises(RelationValidationError, match="Relations list cannot be empty"):
66 | KnowledgeGraphValidator.validate_batch_relations(
67 | [], existing_relations, entity_names
68 | )
69 |
70 | # Duplicate relations
71 | relations = [
72 | Relation(from_="entity1", to="entity2", relationType="knows"),
73 | Relation(from_="entity1", to="entity2", relationType="knows"), # Same relation
74 | ]
75 | with pytest.raises(RelationValidationError, match="Duplicate relation"):
76 | KnowledgeGraphValidator.validate_batch_relations(
77 | relations, existing_relations, entity_names
78 | )
79 |
80 | # Missing entities
81 | relations = [
82 | Relation("entity1", "nonexistent", "knows"),
83 | ]
84 | with pytest.raises(RelationValidationError, match="Entities not found"):
85 | KnowledgeGraphValidator.validate_batch_relations(
86 | relations, existing_relations, entity_names
87 | )
88 |
89 | # Invalid relation type
90 | relations = [
91 | Relation("entity1", "entity2", "invalid-type"),
92 | ]
93 | with pytest.raises(RelationValidationError, match="Invalid relation type"):
94 | KnowledgeGraphValidator.validate_batch_relations(
95 | relations, existing_relations, entity_names
96 | )
97 |
98 | # Self-referential relation
99 | relations = [
100 | Relation("entity1", "entity1", "knows"),
101 | ]
102 | with pytest.raises(
103 | RelationValidationError, match="Self-referential relations not allowed"
104 | ):
105 | KnowledgeGraphValidator.validate_batch_relations(
106 | relations, existing_relations, entity_names
107 | )
108 |
109 | # Cycle detection
110 | relations = [
111 | Relation("entity1", "entity2", "knows"),
112 | Relation("entity2", "entity3", "knows"),
113 | Relation("entity3", "entity1", "knows"),
114 | ]
115 | with pytest.raises(RelationValidationError, match="Circular dependency detected"):
116 | KnowledgeGraphValidator.validate_batch_relations(
117 | relations, existing_relations, entity_names
118 | )
119 |
120 |
121 | def test_validate_batch_observations() -> None:
122 | """Test batch observation validation."""
123 | # Valid batch
124 | existing_entities = {
125 | "entity1": Entity("entity1", "person", ["existing1"]),
126 | "entity2": Entity("entity2", "person", ["existing2"]),
127 | }
128 | observations_map = {
129 | "entity1": ["new1", "new2"],
130 | "entity2": ["new3"],
131 | }
132 | KnowledgeGraphValidator.validate_batch_observations(
133 | observations_map, existing_entities
134 | )
135 |
136 | # Empty batch
137 | with pytest.raises(EntityValidationError, match="Observations map cannot be empty"):
138 | KnowledgeGraphValidator.validate_batch_observations({}, existing_entities)
139 |
140 | # Missing entities
141 | observations_map = {
142 | "entity1": ["new1"],
143 | "nonexistent": ["new2"],
144 | }
145 | with pytest.raises(EntityValidationError, match="Entities not found"):
146 | KnowledgeGraphValidator.validate_batch_observations(
147 | observations_map, existing_entities
148 | )
149 |
150 | # Empty observations list is allowed (skipped)
151 | observations_map = {
152 | "entity1": [],
153 | }
154 | KnowledgeGraphValidator.validate_batch_observations(
155 | observations_map, existing_entities
156 | )
157 |
158 | # Invalid observation format
159 | observations_map = {
160 | "entity1": ["", "new2"], # Empty observation
161 | }
162 | with pytest.raises(EntityValidationError, match="Empty observation"):
163 | KnowledgeGraphValidator.validate_batch_observations(
164 | observations_map, existing_entities
165 | )
166 |
167 | # Duplicate observations
168 | observations_map = {
169 | "entity1": ["existing1", "new2"], # Duplicate with existing observation
170 | }
171 | with pytest.raises(EntityValidationError, match="Duplicate observations"):
172 | KnowledgeGraphValidator.validate_batch_observations(
173 | observations_map, existing_entities
174 | )
175 |
```
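The last test expects `validate_batch_relations` to reject cycles. The validator's actual algorithm isn't shown in this chunk; a plain DFS over `(from_, to)` pairs is one way the check could work, sketched here for illustration:

```python
def has_cycle(edges: list[tuple[str, str]]) -> bool:
    """Return True if the directed edges contain a cycle (DFS with a visiting set)."""
    graph: dict[str, list[str]] = {}
    for src, dst in edges:
        graph.setdefault(src, []).append(dst)

    visiting: set[str] = set()  # nodes on the current DFS path
    done: set[str] = set()      # nodes fully explored

    def dfs(node: str) -> bool:
        if node in visiting:
            return True  # back edge found -> cycle
        if node in done:
            return False
        visiting.add(node)
        for nxt in graph.get(node, []):
            if dfs(nxt):
                return True
        visiting.discard(node)
        done.add(node)
        return False

    return any(dfs(n) for n in list(graph))


# Mirrors the cycle case from the test above: e1 -> e2 -> e3 -> e1
assert has_cycle([("e1", "e2"), ("e2", "e3"), ("e3", "e1")])
assert not has_cycle([("e1", "e2"), ("e2", "e3")])
```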
--------------------------------------------------------------------------------
/memory_mcp_server/knowledge_graph_manager.py:
--------------------------------------------------------------------------------
```python
1 | """Knowledge graph manager that delegates to a configured backend."""
2 |
3 | import asyncio
4 | from pathlib import Path
5 | from typing import Dict, List, Optional, Union
6 |
7 | from .backends.base import Backend
8 | from .backends.jsonl import JsonlBackend
9 | from .interfaces import Entity, KnowledgeGraph, Relation, SearchOptions
10 | from .validation import KnowledgeGraphValidator, ValidationError
11 |
12 |
13 | class KnowledgeGraphManager:
14 | """Manages knowledge graph operations through a configured backend."""
15 |
16 | backend: Backend
17 | _write_lock: asyncio.Lock
18 |
19 | def __init__(
20 | self,
21 | backend: Union[Backend, Path],
22 | cache_ttl: int = 60,
23 | ):
24 | """Initialize the KnowledgeGraphManager.
25 |
26 | Args:
27 | backend: Either a Backend instance or Path to use default JSONL backend
28 | cache_ttl: Cache TTL in seconds (only used for JSONL backend)
29 | """
30 | if isinstance(backend, Path):
31 | self.backend = JsonlBackend(backend, cache_ttl)
32 | else:
33 | self.backend = backend
34 | self._write_lock = asyncio.Lock()
35 |
36 | async def initialize(self) -> None:
37 | """Initialize the backend connection."""
38 | await self.backend.initialize()
39 |
40 | async def close(self) -> None:
41 | """Close the backend connection."""
42 | await self.backend.close()
43 |
44 | async def create_entities(self, entities: List[Entity]) -> List[Entity]:
45 | """Create multiple new entities.
46 |
47 | Args:
48 | entities: List of entities to create
49 |
50 | Returns:
51 | List of successfully created entities
52 |
53 | Raises:
54 | ValidationError: If any entity fails validation
55 | """
56 | # Get existing entities for validation
57 | graph = await self.read_graph()
58 | existing_names = {entity.name for entity in graph.entities}
59 |
60 | # Validate all entities in one pass
61 | KnowledgeGraphValidator.validate_batch_entities(entities, existing_names)
62 |
63 | async with self._write_lock:
64 | return await self.backend.create_entities(entities)
65 |
66 | async def delete_entities(self, entity_names: List[str]) -> List[str]:
67 | """Delete multiple existing entities by name.
68 |
69 | Args:
70 | entity_names: List of entity names to delete
71 |
72 | Returns:
73 | List of successfully deleted entity names
74 |
75 | Raises:
76 | ValueError: If entity_names list is empty
77 | EntityNotFoundError: If any entity is not found in the graph
78 | FileAccessError: If there are file system issues (backend specific)
79 | """
80 | if not entity_names:
81 | raise ValueError("Entity names list cannot be empty")
82 |
83 | async with self._write_lock:
84 | return await self.backend.delete_entities(entity_names)
85 |
86 | async def delete_relations(self, from_: str, to: str) -> None:
87 | """Delete relations between two entities.
88 |
89 | Args:
90 | from_: Source entity name
91 | to: Target entity name
92 |
93 | Raises:
94 | EntityNotFoundError: If either entity is not found
95 | """
96 | async with self._write_lock:
97 | return await self.backend.delete_relations(from_, to)
98 |
99 | async def create_relations(self, relations: List[Relation]) -> List[Relation]:
100 | """Create multiple new relations.
101 |
102 | Args:
103 | relations: List of relations to create
104 |
105 | Returns:
106 | List of successfully created relations
107 |
108 | Raises:
109 | ValidationError: If any relation fails validation
110 | EntityNotFoundError: If referenced entities don't exist
111 | """
112 | # Get existing graph for validation
113 | graph = await self.read_graph()
114 | existing_names = {entity.name for entity in graph.entities}
115 |
116 | # Validate all relations in one pass
117 | KnowledgeGraphValidator.validate_batch_relations(
118 | relations, graph.relations, existing_names
119 | )
120 |
121 | async with self._write_lock:
122 | return await self.backend.create_relations(relations)
123 |
124 | async def read_graph(self) -> KnowledgeGraph:
125 | """Read the entire knowledge graph.
126 |
127 | Returns:
128 | Current state of the knowledge graph
129 | """
130 | return await self.backend.read_graph()
131 |
132 | async def search_nodes(
133 | self, query: str, options: Optional[SearchOptions] = None
134 | ) -> KnowledgeGraph:
135 | """Search for entities and relations matching query.
136 |
137 | Args:
138 | query: Search query string
139 | options: Optional SearchOptions for configuring search behavior.
140 | If None, uses exact substring matching.
141 |
142 | Returns:
143 | KnowledgeGraph containing matches
144 |
145 | Raises:
146 | ValueError: If query is empty or options are invalid
147 | """
148 | return await self.backend.search_nodes(query, options)
149 |
150 | async def flush(self) -> None:
151 | """Ensure any pending changes are persisted."""
152 | await self.backend.flush()
153 |
154 | async def add_observations(self, entity_name: str, observations: List[str]) -> None:
155 | """Add observations to an existing entity.
156 |
157 | Args:
158 | entity_name: Name of the entity to add observations to
159 | observations: List of observations to add
160 |
161 | Raises:
162 |             ValidationError: If the entity is not found, observations are
163 |                 invalid, or duplicate observations already exist
164 | ValueError: If observations list is empty
165 | """
166 | if not observations:
167 | raise ValueError("Observations list cannot be empty")
168 |
169 | # Validate new observations
170 | KnowledgeGraphValidator.validate_observations(observations)
171 |
172 | # Get existing entity to check for duplicate observations
173 | graph = await self.read_graph()
174 | entity = next((e for e in graph.entities if e.name == entity_name), None)
175 | if not entity:
176 | raise ValidationError(f"Entity not found: {entity_name}")
177 |
178 | # Check for duplicates against existing observations
179 | existing_observations = set(entity.observations)
180 | duplicates = [obs for obs in observations if obs in existing_observations]
181 | if duplicates:
182 | raise ValidationError(f"Duplicate observations: {', '.join(duplicates)}")
183 |
184 | async with self._write_lock:
185 | await self.backend.add_observations(entity_name, observations)
186 |
187 | async def add_batch_observations(
188 | self, observations_map: Dict[str, List[str]]
189 | ) -> None:
190 | """Add observations to multiple entities in a single operation.
191 |
192 | Args:
193 | observations_map: Dictionary mapping entity names to lists of observations
194 |
195 |         Raises:
196 |             ValidationError: If the observations map is empty, if any
197 |                 entity is not found, or if any observation is invalid
198 |                 or a duplicate
199 | """
200 | # Get existing graph for validation
201 | graph = await self.read_graph()
202 | entities_map = {entity.name: entity for entity in graph.entities}
203 |
204 | # Validate all observations in one pass
205 | KnowledgeGraphValidator.validate_batch_observations(
206 | observations_map, entities_map
207 | )
208 |
209 | # All validation passed, perform the batch update
210 | async with self._write_lock:
211 | await self.backend.add_batch_observations(observations_map)
212 |
```
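The duplicate-observation guard in `add_observations` above reduces to a set-membership check against the entity's existing observations. A minimal standalone sketch of that pattern (the function name is illustrative, not part of the package):

```python
def find_duplicate_observations(existing, new):
    """Return items of `new` already present in `existing`, preserving order."""
    existing_set = set(existing)  # O(1) membership checks
    return [obs for obs in new if obs in existing_set]


# Any non-empty result would trigger the ValidationError raised above.
dupes = find_duplicate_observations(["loves pizza"], ["loves pizza", "hikes"])
```

Building the set once keeps the check linear in the total number of observations, which matters for the batch variant below.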
--------------------------------------------------------------------------------
/tests/test_knowledge_graph_manager.py:
--------------------------------------------------------------------------------
```python
1 | """Tests for KnowledgeGraphManager."""
2 |
3 | import asyncio
4 | from typing import List
5 |
6 | import pytest
7 |
8 | from memory_mcp_server.interfaces import Entity, Relation
9 | from memory_mcp_server.knowledge_graph_manager import KnowledgeGraphManager
10 | from memory_mcp_server.validation import EntityValidationError, ValidationError
11 |
12 |
13 | @pytest.mark.asyncio(scope="function")
14 | async def test_create_entities(
15 | knowledge_graph_manager: KnowledgeGraphManager,
16 | ) -> None:
17 | """Test the creation of new entities in the knowledge graph.
18 |
19 | This test verifies that:
20 | 1. Entities can be created successfully
21 | 2. The created entities are stored in the graph
22 | 3. Entity attributes are preserved correctly
23 | """
24 | print("\nStarting test_create_entities")
25 | entities = [
26 | Entity(
27 | name="john-doe",
28 | entityType="person",
29 | observations=["loves pizza"],
30 | )
31 | ]
32 |
33 | created_entities = await knowledge_graph_manager.create_entities(entities)
34 | print("Created entities")
35 | assert len(created_entities) == 1
36 |
37 | graph = await knowledge_graph_manager.read_graph()
38 | print("Read graph")
39 | assert len(graph.entities) == 1
40 | assert graph.entities[0].name == "john-doe"
41 |
42 | print("test_create_entities: Complete")
43 |
44 |
45 | @pytest.mark.asyncio(scope="function")
46 | async def test_create_relations(
47 | knowledge_graph_manager: KnowledgeGraphManager,
48 | ) -> None:
49 | """Test the creation of relations between entities.
50 |
51 | This test verifies that:
52 | 1. Relations can be created between existing entities
53 | 2. Relations are stored properly in the graph
54 | 3. Relation properties (from, to, type) are preserved
55 | """
56 | print("\nStarting test_create_relations")
57 |
58 | entities = [
59 | Entity(name="alice-smith", entityType="person", observations=["test"]),
60 | Entity(name="bob-jones", entityType="person", observations=["test"]),
61 | ]
62 | await knowledge_graph_manager.create_entities(entities)
63 | print("Created entities")
64 |
65 | relations = [Relation(from_="alice-smith", to="bob-jones", relationType="knows")]
66 | created_relations = await knowledge_graph_manager.create_relations(relations)
67 | print("Created relations")
68 |
69 | assert len(created_relations) == 1
70 | assert created_relations[0].from_ == "alice-smith"
71 | assert created_relations[0].to == "bob-jones"
72 |
73 | print("test_create_relations: Complete")
74 |
75 |
76 | @pytest.mark.asyncio(scope="function")
77 | async def test_search_functionality(
78 | knowledge_graph_manager: KnowledgeGraphManager,
79 | ) -> None:
80 | """Test the search functionality across different criteria.
81 |
82 | This test verifies searching by:
83 | 1. Entity name
84 | 2. Entity type
85 | 3. Observation content
86 | 4. Case insensitivity
87 | """
88 | # Create test entities with varied data
89 | entities = [
90 | Entity(
91 | name="search-test-1",
92 | entityType="project",
93 | observations=["keyword1", "unique1"],
94 | ),
95 | Entity(name="search-test-2", entityType="project", observations=["keyword2"]),
96 | Entity(name="different-type", entityType="document", observations=["keyword1"]),
97 | ]
98 | await knowledge_graph_manager.create_entities(entities)
99 |
100 | # Test search by name
101 | name_result = await knowledge_graph_manager.search_nodes("search-test")
102 | assert len(name_result.entities) == 2
103 | assert all("search-test" in e.name for e in name_result.entities)
104 |
105 | # Test search by type
106 | type_result = await knowledge_graph_manager.search_nodes("document")
107 | assert len(type_result.entities) == 1
108 | assert type_result.entities[0].name == "different-type"
109 |
110 | # Test search by observation
111 | obs_result = await knowledge_graph_manager.search_nodes("keyword1")
112 | assert len(obs_result.entities) == 2
113 | assert any(e.name == "search-test-1" for e in obs_result.entities)
114 | assert any(e.name == "different-type" for e in obs_result.entities)
115 |
116 |
117 | @pytest.mark.asyncio(scope="function")
118 | async def test_error_handling(
119 | knowledge_graph_manager: KnowledgeGraphManager,
120 | ) -> None:
121 | """Test error handling in various scenarios.
122 |
123 | This test verifies proper error handling for:
124 | 1. Invalid entity names
125 | 2. Non-existent entities in relations
126 | 3. Empty delete requests
127 | 4. Deleting non-existent entities
128 | """
129 | # Test invalid entity name
130 | with pytest.raises(EntityValidationError, match="Invalid entity name"):
131 | await knowledge_graph_manager.create_entities(
132 | [Entity(name="Invalid Name", entityType="person", observations=[])]
133 | )
134 |
135 | # Test relation with non-existent entities
136 | with pytest.raises(ValidationError, match="Entities not found"):
137 | await knowledge_graph_manager.create_relations(
138 | [
139 | Relation(
140 | from_="non-existent", to="also-non-existent", relationType="knows"
141 | )
142 | ]
143 | )
144 |
145 | # Test deleting empty list
146 | with pytest.raises(ValueError, match="cannot be empty"):
147 | await knowledge_graph_manager.delete_entities([])
148 |
149 | # Test deleting non-existent entities
150 | result = await knowledge_graph_manager.delete_entities(["non-existent"])
151 | assert result == []
152 |
153 |
154 | @pytest.mark.asyncio(scope="function")
155 | async def test_graph_persistence(
156 | knowledge_graph_manager: KnowledgeGraphManager,
157 | ) -> None:
158 | """Test that graph changes persist after reloading.
159 |
160 | This test verifies that:
161 | 1. Created entities persist after a graph reload
162 | 2. Added relations persist after a graph reload
163 | 3. New observations persist after a graph reload
164 | """
165 | # Create initial data
166 | entity = Entity(
167 | name="persistence-test", entityType="project", observations=["initial"]
168 | )
169 | await knowledge_graph_manager.create_entities([entity])
170 |
171 | # Force a reload of the graph by clearing the cache
172 | knowledge_graph_manager._cache = None # type: ignore
173 |
174 | # Verify data persists
175 | graph = await knowledge_graph_manager.read_graph()
176 | assert len(graph.entities) == 1
177 | assert graph.entities[0].name == "persistence-test"
178 | assert "initial" in graph.entities[0].observations
179 |
180 |
181 | @pytest.mark.asyncio(scope="function")
182 | async def test_concurrent_operations(
183 | knowledge_graph_manager: KnowledgeGraphManager,
184 | ) -> None:
185 | """Test handling of concurrent operations.
186 |
187 | This test verifies that:
188 | 1. Multiple concurrent entity creations/deletions are handled properly
189 | 2. Cache remains consistent under concurrent operations
190 | 3. No data is lost during concurrent writes
191 | """
192 |
193 | # Create multiple entities concurrently
194 | async def create_entity(index: int) -> List[Entity]:
195 | entity = Entity(
196 | name=f"concurrent-{index}",
197 | entityType="project",
198 | observations=[f"obs{index}"],
199 | )
200 | return await knowledge_graph_manager.create_entities([entity])
201 |
202 | # Delete entities concurrently
203 | async def delete_entity(index: int) -> List[str]:
204 | return await knowledge_graph_manager.delete_entities([f"concurrent-{index}"])
205 |
206 | # First create 5 entities
207 | create_tasks = [create_entity(i) for i in range(5)]
208 | create_results = await asyncio.gather(*create_tasks)
209 | assert all(len(r) == 1 for r in create_results)
210 |
211 | # Then concurrently delete 3 of them while creating 2 more
212 | delete_tasks = [delete_entity(i) for i in range(3)]
213 | create_tasks = [create_entity(i) for i in range(5, 7)]
214 | delete_results, create_results = await asyncio.gather(
215 | asyncio.gather(*delete_tasks), asyncio.gather(*create_tasks)
216 | )
217 |
218 | # Verify deletions
219 | assert all(len(r) == 1 for r in delete_results)
220 |
221 | # Verify creations
222 | assert all(len(r) == 1 for r in create_results)
223 |
224 | # Verify final state
225 | graph = await knowledge_graph_manager.read_graph()
226 | expected_names = {"concurrent-5", "concurrent-6", "concurrent-3", "concurrent-4"}
227 | assert len(graph.entities) == 4
228 | assert all(e.name in expected_names for e in graph.entities)
229 |
```
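`test_concurrent_operations` above leans on `asyncio.gather` to interleave writes while the manager's write lock keeps state consistent. A minimal self-contained sketch of the gather pattern (names are illustrative):

```python
import asyncio


async def work(i):
    # Yield control so concurrently scheduled tasks can interleave
    await asyncio.sleep(0)
    return i * 2


async def main():
    # Run several coroutines concurrently; results come back in call order
    return await asyncio.gather(*(work(i) for i in range(3)))


results = asyncio.run(main())
```

`gather` preserves argument order regardless of completion order, which is why the test can assert on `create_results` and `delete_results` positionally.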
--------------------------------------------------------------------------------
/tests/test_server.py:
--------------------------------------------------------------------------------
```python
1 | """Tests for the MCP server implementation."""
2 |
3 | import json
4 | from typing import Any, Dict, List, Protocol, cast
5 |
6 | import pytest
7 | from mcp.types import TextContent
8 |
9 | from memory_mcp_server.exceptions import EntityNotFoundError
10 | from memory_mcp_server.interfaces import Entity, KnowledgeGraph, Relation
11 |
12 |
13 | # Mock tools and handlers
14 | def handle_error(error: Exception) -> str:
15 | """Mock error handler."""
16 | if isinstance(error, EntityNotFoundError):
17 | return str(error)
18 | return f"Error: {str(error)}"
19 |
20 |
21 | async def create_entities_handler(
22 | manager: Any, arguments: Dict[str, Any]
23 | ) -> List[TextContent]:
24 | """Mock create entities handler."""
25 | entities = [
26 | Entity(
27 | name=e["name"],
28 | entityType=e["entityType"],
29 | observations=e.get("observations", []),
30 | )
31 | for e in arguments["entities"]
32 | ]
33 | result = await manager.create_entities(entities)
34 | return [TextContent(type="text", text=json.dumps([e.to_dict() for e in result]))]
35 |
36 |
37 | async def create_relations_handler(
38 | manager: Any, arguments: Dict[str, Any]
39 | ) -> List[TextContent]:
40 | """Mock create relations handler."""
41 | relations = [
42 | Relation(from_=r["from"], to=r["to"], relationType=r["relationType"])
43 | for r in arguments["relations"]
44 | ]
45 | result = await manager.create_relations(relations)
46 | return [TextContent(type="text", text=json.dumps([r.to_dict() for r in result]))]
47 |
48 |
49 | async def add_observations_handler(
50 | manager: Any, arguments: Dict[str, Any]
51 | ) -> List[TextContent]:
52 | """Mock add observations handler."""
53 | await manager.add_observations(arguments["entity"], arguments["observations"])
54 | return [TextContent(type="text", text=json.dumps({"success": True}))]
55 |
56 |
57 | async def delete_entities_handler(
58 | manager: Any, arguments: Dict[str, Any]
59 | ) -> List[TextContent]:
60 | """Mock delete entities handler."""
61 | await manager.delete_entities(arguments["names"])
62 | return [TextContent(type="text", text=json.dumps({"success": True}))]
63 |
64 |
65 | async def delete_observations_handler(
66 | manager: Any, arguments: Dict[str, Any]
67 | ) -> List[TextContent]:
68 | """Mock delete observations handler."""
69 | await manager.delete_observations(arguments["entity"], arguments["observations"])
70 | return [TextContent(type="text", text=json.dumps({"success": True}))]
71 |
72 |
73 | async def delete_relations_handler(
74 | manager: Any, arguments: Dict[str, Any]
75 | ) -> List[TextContent]:
76 | """Mock delete relations handler."""
77 | await manager.delete_relations(arguments["from"], arguments["to"])
78 | return [TextContent(type="text", text=json.dumps({"success": True}))]
79 |
80 |
81 | async def read_graph_handler(
82 | manager: Any, arguments: Dict[str, Any]
83 | ) -> List[TextContent]:
84 | """Mock read graph handler."""
85 | graph = await manager.read_graph()
86 | return [TextContent(type="text", text=json.dumps(graph.to_dict()))]
87 |
88 |
89 | async def search_nodes_handler(
90 | manager: Any, arguments: Dict[str, Any]
91 | ) -> List[TextContent]:
92 | """Mock search nodes handler."""
93 | result = await manager.search_nodes(arguments["query"])
94 | return [TextContent(type="text", text=json.dumps(result.to_dict()))]
95 |
96 |
97 | TOOLS: Dict[str, Any] = {
98 | "create_entities": create_entities_handler,
99 | "create_relations": create_relations_handler,
100 | "add_observations": add_observations_handler,
101 | "delete_entities": delete_entities_handler,
102 | "delete_relations": delete_relations_handler,
103 | "read_graph": read_graph_handler,
104 | "search_nodes": search_nodes_handler,
105 | }
106 |
107 |
108 | class MockManagerProtocol(Protocol):
109 | """Protocol defining the interface for MockManager."""
110 |
111 | async def create_entities(self, entities: List[Entity]) -> List[Entity]:
112 | ...
113 |
114 | async def create_relations(self, relations: List[Relation]) -> List[Relation]:
115 | ...
116 |
117 | async def add_observations(self, entity: str, observations: List[str]) -> None:
118 | ...
119 |
120 | async def delete_entities(self, names: List[str]) -> None:
121 | ...
122 |
123 | async def delete_relations(self, from_: str, to: str) -> None:
124 | ...
125 |
126 | async def read_graph(self) -> KnowledgeGraph:
127 | ...
128 |
129 | async def search_nodes(self, query: str) -> KnowledgeGraph:
130 | ...
131 |
132 | async def flush(self) -> None:
133 | ...
134 |
135 |
136 | @pytest.fixture(scope="function")
137 | def mock_manager() -> MockManagerProtocol:
138 | """Create a mock manager for testing."""
139 |
140 | class MockManager:
141 | def __init__(self) -> None:
142 | self.entities: List[Entity] = []
143 | self.relations: List[Relation] = []
144 |
145 | async def create_entities(self, entities: List[Entity]) -> List[Entity]:
146 | self.entities.extend(entities)
147 | return entities
148 |
149 | async def create_relations(self, relations: List[Relation]) -> List[Relation]:
150 | return relations
151 |
152 | async def add_observations(self, entity: str, observations: List[str]) -> None:
153 | if entity == "MissingEntity":
154 | raise EntityNotFoundError(entity)
155 |
156 | async def delete_entities(self, names: List[str]) -> None:
157 | for name in names:
158 | if name == "MissingEntity":
159 | raise EntityNotFoundError(name)
160 |
161 | async def delete_relations(self, from_: str, to: str) -> None:
162 | if from_ == "MissingEntity" or to == "MissingEntity":
163 | raise EntityNotFoundError("MissingEntity")
164 |
165 | async def read_graph(self) -> KnowledgeGraph:
166 | # Return current state including any created entities
167 | return KnowledgeGraph(
168 | entities=self.entities,
169 | relations=self.relations,
170 | )
171 |
172 | async def search_nodes(self, query: str) -> KnowledgeGraph:
173 | # If query matches "TestEntity", return graph; otherwise empty
174 | if "TestEntity".lower() in query.lower():
175 | return await self.read_graph()
176 | return KnowledgeGraph(entities=[], relations=[])
177 |
178 | async def open_nodes(self, names: List[str]) -> KnowledgeGraph:
179 | # If "TestEntity" is requested, return it
180 | if "TestEntity" in names:
181 | return await self.read_graph()
182 | return KnowledgeGraph(entities=[], relations=[])
183 |
184 | async def flush(self) -> None:
185 | """Mock flush method to comply with interface."""
186 | pass
187 |
188 | return MockManager()
189 |
190 |
191 | @pytest.mark.asyncio
192 | async def test_create_entities(mock_manager: MockManagerProtocol) -> None:
193 | """Test creating entities through the MCP server."""
194 | handler = cast(Any, TOOLS["create_entities"])
195 | arguments = {
196 | "entities": [
197 | {"name": "E1", "entityType": "TypeX", "observations": ["obsA"]},
198 | {"name": "E2", "entityType": "TypeY", "observations": ["obsB"]},
199 | ]
200 | }
201 | result = await handler(mock_manager, arguments)
202 | data = json.loads(result[0].text)
203 | assert len(data) == 2
204 | assert data[0]["name"] == "E1"
205 | assert data[1]["observations"] == ["obsB"]
206 |
207 |
208 | @pytest.mark.asyncio
209 | async def test_create_relations(mock_manager: MockManagerProtocol) -> None:
210 | """Test creating relations through the MCP server."""
211 | handler = cast(Any, TOOLS["create_relations"])
212 | arguments = {"relations": [{"from": "E1", "to": "E2", "relationType": "likes"}]}
213 | result = await handler(mock_manager, arguments)
214 | data = json.loads(result[0].text)
215 | assert len(data) == 1
216 | assert data[0]["from"] == "E1"
217 | assert data[0]["to"] == "E2"
218 |
219 |
220 | @pytest.mark.asyncio
221 | async def test_add_observations(mock_manager: MockManagerProtocol) -> None:
222 | """Test adding observations through the MCP server."""
223 | handler = cast(Any, TOOLS["add_observations"])
224 | arguments = {"entity": "E1", "observations": ["newObs"]}
225 | result = await handler(mock_manager, arguments)
226 | data = json.loads(result[0].text)
227 | assert data["success"] is True
228 |
229 |
230 | @pytest.mark.asyncio
231 | async def test_delete_entities(mock_manager: MockManagerProtocol) -> None:
232 | """Test deleting entities through the MCP server."""
233 | handler = cast(Any, TOOLS["delete_entities"])
234 | arguments = {"names": ["E1"]}
235 | result = await handler(mock_manager, arguments)
236 | data = json.loads(result[0].text)
237 | assert data["success"] is True
238 |
239 |
240 | @pytest.mark.asyncio
241 | async def test_delete_relations(mock_manager: MockManagerProtocol) -> None:
242 | """Test deleting relations through the MCP server."""
243 | handler = cast(Any, TOOLS["delete_relations"])
244 | arguments = {"from": "E1", "to": "E2"}
245 | result = await handler(mock_manager, arguments)
246 | data = json.loads(result[0].text)
247 | assert data["success"] is True
248 |
249 |
250 | @pytest.mark.asyncio
251 | async def test_read_graph(mock_manager: MockManagerProtocol) -> None:
252 | """Test reading the graph through the MCP server."""
253 | # Create test entity first
254 | await mock_manager.create_entities(
255 | [
256 | Entity(
257 | name="TestEntity",
258 | entityType="TestType",
259 | observations=["test observation"],
260 | )
261 | ]
262 | )
263 |
264 | handler = cast(Any, TOOLS["read_graph"])
265 | arguments: Dict[str, Any] = {}
266 | result = await handler(mock_manager, arguments)
267 | data = json.loads(result[0].text)
268 |
269 | assert len(data["entities"]) == 1
270 | assert data["entities"][0]["name"] == "TestEntity"
271 | assert isinstance(
272 | data["entities"][0]["observations"], (list, tuple)
273 | ) # Allow both list and tuple
274 |
275 |
276 | @pytest.mark.asyncio
277 | async def test_save_graph(mock_manager: MockManagerProtocol) -> None:
278 | """Test saving the graph through the MCP server."""
279 | # First create some test data
280 | await mock_manager.create_entities(
281 | [Entity(name="TestSave", entityType="TestType", observations=[])]
282 | )
283 |
284 | # Explicitly save the graph
285 | await mock_manager.flush()
286 |
287 | # Read back the graph
288 | graph = await mock_manager.read_graph()
289 |
290 | # Verify our test entity exists
291 | assert any(e.name == "TestSave" for e in graph.entities)
292 | # Verify the save preserved the structure
293 | assert isinstance(
294 | graph.entities[0].observations, (list, tuple)
295 | ) # Allow both list and tuple for immutability
296 |
297 |
298 | @pytest.mark.asyncio
299 | async def test_search_nodes(mock_manager: MockManagerProtocol) -> None:
300 | """Test searching nodes through the MCP server."""
301 | # Create test entity first
302 | await mock_manager.create_entities(
303 | [
304 | Entity(
305 | name="TestEntity",
306 | entityType="TestType",
307 | observations=["test observation"],
308 | )
309 | ]
310 | )
311 |
312 | handler = cast(Any, TOOLS["search_nodes"])
313 | arguments = {"query": "TestEntity"}
314 | result = await handler(mock_manager, arguments)
315 | data = json.loads(result[0].text)
316 |
317 | assert len(data["entities"]) == 1
318 | assert data["entities"][0]["name"] == "TestEntity"
319 | assert isinstance(
320 | data["entities"][0]["observations"], (list, tuple)
321 | ) # Allow both list and tuple
322 |
323 |
324 | def test_error_handling() -> None:
325 | """Test error handling functionality."""
326 | msg = handle_error(EntityNotFoundError("MissingEntity"))
327 | assert "Entity 'MissingEntity' not found in the graph" in msg
328 |
```
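The `TOOLS` table in the tests above is a plain name-to-coroutine dispatch. A minimal analogue of that lookup-and-await pattern (handler and names are hypothetical, not the server's real API):

```python
import asyncio
import json


async def echo_handler(manager, arguments):
    # Trivial handler: serialize the arguments back to the caller
    return [json.dumps(arguments)]


TOOLS = {"echo": echo_handler}


async def dispatch(tool, manager, arguments):
    handler = TOOLS.get(tool)
    if handler is None:
        raise KeyError(f"Unknown tool: {tool}")
    return await handler(manager, arguments)


result = asyncio.run(dispatch("echo", None, {"x": 1}))
```

Keeping handlers behind a dict makes each one independently testable with a mock manager, which is exactly what these tests exploit.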
--------------------------------------------------------------------------------
/memory_mcp_server/validation.py:
--------------------------------------------------------------------------------
```python
1 | """Validation module for knowledge graph consistency."""
2 |
3 | import re
4 | from typing import List, Optional, Set
5 |
6 | from .interfaces import Entity, KnowledgeGraph, Relation
7 |
8 |
9 | class ValidationError(Exception):
10 | """Base class for validation errors."""
11 |
12 | pass
13 |
14 |
15 | class EntityValidationError(ValidationError):
16 | """Raised when entity validation fails."""
17 |
18 | pass
19 |
20 |
21 | class RelationValidationError(ValidationError):
22 | """Raised when relation validation fails."""
23 |
24 | pass
25 |
26 |
27 | class KnowledgeGraphValidator:
28 | """Validator for ensuring knowledge graph consistency."""
29 |
30 | # Constants for validation rules
31 | ENTITY_NAME_PATTERN = r"^[a-z][a-z0-9-]{0,99}$"
32 | MAX_OBSERVATION_LENGTH = 500
33 | VALID_ENTITY_TYPES = {
34 | "person",
35 | "concept",
36 | "project",
37 | "document",
38 | "tool",
39 | "organization",
40 | "location",
41 | "event",
42 | }
43 | VALID_RELATION_TYPES = {
44 | "knows",
45 | "contains",
46 | "uses",
47 | "created",
48 | "belongs-to",
49 | "depends-on",
50 | "related-to",
51 | }
52 |
53 | @classmethod
54 | def validate_entity_name(cls, name: str) -> None:
55 | """Validate entity name follows naming convention.
56 |
57 | Args:
58 | name: Entity name to validate
59 |
60 | Raises:
61 | EntityValidationError: If name is invalid
62 | """
63 | if not re.match(cls.ENTITY_NAME_PATTERN, name):
64 | raise EntityValidationError(
65 | f"Invalid entity name '{name}'. Must start with lowercase letter, "
66 | "contain only lowercase letters, numbers and hyphens, "
67 | "and be 1-100 characters long."
68 | )
69 |
70 | @classmethod
71 | def validate_entity_type(cls, entity_type: str) -> None:
72 | """Validate entity type is from allowed set.
73 |
74 | Args:
75 | entity_type: Entity type to validate
76 |
77 | Raises:
78 | EntityValidationError: If type is invalid
79 | """
80 | if entity_type not in cls.VALID_ENTITY_TYPES:
81 | raise EntityValidationError(
82 | f"Invalid entity type '{entity_type}'. Must be one of: "
83 | f"{', '.join(sorted(cls.VALID_ENTITY_TYPES))}"
84 | )
85 |
86 | @classmethod
87 | def validate_observations(cls, observations: List[str]) -> None:
88 | """Validate entity observations.
89 |
90 | Args:
91 | observations: List of observations to validate
92 |
93 | Raises:
94 | EntityValidationError: If any observation is invalid
95 | """
96 | seen = set()
97 | for obs in observations:
98 | if not obs:
99 | raise EntityValidationError("Empty observation")
100 | if len(obs) > cls.MAX_OBSERVATION_LENGTH:
101 | raise EntityValidationError(
102 | f"Observation exceeds length of {cls.MAX_OBSERVATION_LENGTH} chars"
103 | )
104 | if obs in seen:
105 | raise EntityValidationError(f"Duplicate observation: {obs}")
106 | seen.add(obs)
107 |
108 | @classmethod
109 | def validate_entity(cls, entity: Entity) -> None:
110 | """Validate an entity.
111 |
112 | Args:
113 | entity: Entity to validate
114 |
115 | Raises:
116 | EntityValidationError: If entity is invalid
117 | """
118 | cls.validate_entity_name(entity.name)
119 | cls.validate_entity_type(entity.entityType)
120 | cls.validate_observations(list(entity.observations))
121 |
122 | @classmethod
123 | def validate_relation_type(cls, relation_type: str) -> None:
124 | """Validate relation type is from allowed set.
125 |
126 | Args:
127 | relation_type: Relation type to validate
128 |
129 | Raises:
130 | RelationValidationError: If type is invalid
131 | """
132 | if relation_type not in cls.VALID_RELATION_TYPES:
133 | valid_types = ", ".join(sorted(cls.VALID_RELATION_TYPES))
134 | raise RelationValidationError(
135 | f"Invalid relation type '{relation_type}'. Valid types: {valid_types}"
136 | )
137 |
138 | @classmethod
139 | def validate_relation(cls, relation: Relation) -> None:
140 | """Validate a relation.
141 |
142 | Args:
143 | relation: Relation to validate
144 |
145 | Raises:
146 | RelationValidationError: If relation is invalid
147 | """
148 | if relation.from_ == relation.to:
149 | raise RelationValidationError("Self-referential relations not allowed")
150 | cls.validate_relation_type(relation.relationType)
151 |
152 | @classmethod
153 | def validate_no_cycles(
154 | cls,
155 | relations: List[Relation],
156 | existing_relations: Optional[List[Relation]] = None,
157 | ) -> None:
158 | """Validate that relations don't create cycles.
159 |
160 | Args:
161 | relations: New relations to validate
162 | existing_relations: Optional list of existing relations to check against
163 |
164 | Raises:
165 | RelationValidationError: If cycles are detected
166 | """
167 | # Build adjacency list
168 | graph: dict[str, Set[str]] = {}
169 | all_relations = list(relations)
170 | if existing_relations:
171 | all_relations.extend(existing_relations)
172 |
173 | for rel in all_relations:
174 | if rel.from_ not in graph:
175 | graph[rel.from_] = set()
176 | graph[rel.from_].add(rel.to)
177 |
178 | # Check for cycles using DFS
179 | def has_cycle(node: str, visited: Set[str], path: Set[str]) -> bool:
180 | visited.add(node)
181 | path.add(node)
182 |
183 | for neighbor in graph.get(node, set()):
184 | if neighbor not in visited:
185 | if has_cycle(neighbor, visited, path):
186 | return True
187 | elif neighbor in path:
188 | return True
189 |
190 | path.remove(node)
191 | return False
192 |
193 | visited: Set[str] = set()
194 | path: Set[str] = set()
195 |
196 | for node in graph:
197 | if node not in visited:
198 | if has_cycle(node, visited, path):
199 | raise RelationValidationError(
200 | "Circular dependency detected in relations"
201 | )
202 |
203 | @classmethod
204 | def validate_graph(cls, graph: KnowledgeGraph) -> None:
205 | """Validate entire knowledge graph.
206 |
207 | Args:
208 | graph: Knowledge graph to validate
209 |
210 | Raises:
211 | ValidationError: If any validation fails
212 | """
213 | # Validate all entities
214 | entity_names = set()
215 | for entity in graph.entities:
216 | cls.validate_entity(entity)
217 | if entity.name in entity_names:
218 | raise EntityValidationError(f"Duplicate entity name: {entity.name}")
219 | entity_names.add(entity.name)
220 |
221 | # Validate all relations
222 | for relation in graph.relations:
223 | cls.validate_relation(relation)
224 | if relation.from_ not in entity_names:
225 | raise RelationValidationError(
226 | f"Source entity '{relation.from_}' not found in graph"
227 | )
228 | if relation.to not in entity_names:
229 | raise RelationValidationError(
230 | f"Target entity '{relation.to}' not found in graph"
231 | )
232 |
233 | # Check for cycles
234 | cls.validate_no_cycles(graph.relations)
235 |
236 | @classmethod
237 | def validate_batch_entities(
238 | cls, entities: List[Entity], existing_names: Set[str]
239 | ) -> None:
240 | """Validate a batch of entities efficiently.
241 |
242 | Args:
243 | entities: List of entities to validate
244 | existing_names: Set of existing entity names
245 |
246 | Raises:
247 | EntityValidationError: If validation fails
248 | """
249 | if not entities:
250 | raise EntityValidationError("Entity list cannot be empty")
251 |
252 | # Check for duplicates within the batch
253 | new_names = set()
254 | for entity in entities:
255 | if entity.name in new_names:
256 | raise EntityValidationError(
257 | f"Duplicate entity name in batch: {entity.name}"
258 | )
259 | new_names.add(entity.name)
260 |
261 | # Check for conflicts with existing entities
262 | conflicts = new_names.intersection(existing_names)
263 | if conflicts:
264 | raise EntityValidationError(
265 | f"Entities already exist: {', '.join(conflicts)}"
266 | )
267 |
268 | # Validate all entities in one pass
269 | for entity in entities:
270 | cls.validate_entity(entity)
271 |
272 | @classmethod
273 | def validate_batch_relations(
274 | cls,
275 | relations: List[Relation],
276 | existing_relations: List[Relation],
277 | entity_names: Set[str],
278 | ) -> None:
279 | """Validate a batch of relations efficiently.
280 |
281 | Args:
282 | relations: List of relations to validate
283 | existing_relations: List of existing relations
284 | entity_names: Set of valid entity names
285 |
286 | Raises:
287 | RelationValidationError: If validation fails
288 | """
289 | if not relations:
290 | raise RelationValidationError("Relations list cannot be empty")
291 |
292 | # Track relation keys to prevent duplicates
293 | seen_relations: Set[tuple[str, str, str]] = set()
294 |
295 | # Validate all relations in one pass
296 | missing_entities = set()
297 | for relation in relations:
298 | # Basic validation
299 | cls.validate_relation(relation)
300 |
301 | # Check for duplicate relations
302 | key = (relation.from_, relation.to, relation.relationType)
303 | if key in seen_relations:
304 | raise RelationValidationError(
305 | f"Duplicate relation: {relation.from_} -> {relation.to}"
306 | )
307 | seen_relations.add(key)
308 |
309 | # Collect missing entities
310 | if relation.from_ not in entity_names:
311 | missing_entities.add(relation.from_)
312 | if relation.to not in entity_names:
313 | missing_entities.add(relation.to)
314 |
315 | # Report all missing entities at once
316 | if missing_entities:
317 | raise RelationValidationError(
318 | f"Entities not found: {', '.join(missing_entities)}"
319 | )
320 |
321 | # Check for cycles including existing relations
322 | cls.validate_no_cycles(relations, existing_relations)
323 |
324 | @classmethod
325 | def validate_batch_observations(
326 | cls,
327 | observations_map: dict[str, List[str]],
328 | existing_entities: dict[str, Entity],
329 | ) -> None:
330 | """Validate a batch of observations efficiently.
331 |
332 | Args:
333 | observations_map: Dictionary mapping entity names to lists of observations
334 | existing_entities: Dictionary of existing entities
335 |
336 | Raises:
337 | EntityValidationError: If validation fails
338 | """
339 | if not observations_map:
340 | raise EntityValidationError("Observations map cannot be empty")
341 |
342 | # Check for missing entities first
343 | missing_entities = [
344 | name for name in observations_map if name not in existing_entities
345 | ]
346 | if missing_entities:
347 | raise EntityValidationError(
348 | f"Entities not found: {', '.join(missing_entities)}"
349 | )
350 |
351 | # Validate all observations in one pass
352 | for entity_name, observations in observations_map.items():
353 | if not observations:
354 | continue
355 |
356 | # Validate observation format
357 | cls.validate_observations(observations)
358 |
359 | # Check for duplicates against existing observations
360 | entity = existing_entities[entity_name]
361 | existing_observations = set(entity.observations)
362 | duplicates = [obs for obs in observations if obs in existing_observations]
363 | if duplicates:
364 | raise EntityValidationError(
365 | f"Duplicate observations for {entity_name}: {', '.join(duplicates)}"
366 | )
367 |
```
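`validate_no_cycles` above uses a depth-first search with a "current path" set to detect back edges. A standalone sketch of the same algorithm over bare `(src, dst)` pairs (function name is illustrative):

```python
def has_cycle(edges):
    """Detect a cycle in a directed graph given as (src, dst) pairs."""
    graph = {}
    for src, dst in edges:
        graph.setdefault(src, set()).add(dst)

    visited, path = set(), set()

    def dfs(node):
        visited.add(node)
        path.add(node)  # nodes on the current DFS path
        for neighbor in graph.get(node, ()):
            if neighbor not in visited:
                if dfs(neighbor):
                    return True
            elif neighbor in path:  # back edge -> cycle
                return True
        path.remove(node)
        return False

    # `visited` mutates as we go, so each component is explored once
    return any(dfs(n) for n in list(graph) if n not in visited)
```

The `path` set distinguishes a genuine back edge from a cross edge to an already-finished node, so diamond-shaped DAGs are not flagged as cycles.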
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """Memory MCP server using FastMCP."""
3 |
4 | import os
5 | from pathlib import Path
6 | from typing import Any, Dict, List, Optional
7 |
8 | from loguru import logger as logging
9 | from mcp.server.fastmcp import Context, FastMCP
10 | from mcp.server.fastmcp.prompts.base import Message, UserMessage
11 | from pydantic import BaseModel
12 |
13 | from memory_mcp_server.interfaces import Entity, Relation
14 | from memory_mcp_server.knowledge_graph_manager import KnowledgeGraphManager
15 |
16 | # Error type constants
17 | ERROR_TYPES = {
18 | "NOT_FOUND": "NOT_FOUND",
19 | "VALIDATION_ERROR": "VALIDATION_ERROR",
20 | "INTERNAL_ERROR": "INTERNAL_ERROR",
21 | "ALREADY_EXISTS": "ALREADY_EXISTS",
22 | "INVALID_RELATION": "INVALID_RELATION",
23 | "NO_RESULTS": "NO_RESULTS", # Used when search returns no matches
24 | }
25 |
26 |
27 | # Response models
28 | class EntityResponse(BaseModel):
29 | success: bool
30 | data: Optional[Dict[str, Any]] = None
31 | error: Optional[str] = None
32 | error_type: Optional[str] = None
33 |
34 |
35 | class GraphResponse(BaseModel):
36 | success: bool
37 | data: Optional[Dict[str, Any]] = None
38 | error: Optional[str] = None
39 | error_type: Optional[str] = None
40 |
41 |
42 | class OperationResponse(BaseModel):
43 | success: bool
44 | error: Optional[str] = None
45 | error_type: Optional[str] = None
46 |
47 |
48 | # Create FastMCP server with dependencies and instructions
49 | mcp = FastMCP(
50 | "Memory",
51 |     dependencies=["pydantic", "aiofiles", "thefuzz"],
52 | version="0.1.0",
53 | instructions="""
54 | Memory MCP server providing knowledge graph functionality.
55 | Available tools:
56 | - get_entity: Retrieve entity by name
57 | - get_graph: Get entire knowledge graph
58 | - create_entities: Create multiple entities
59 | - add_observation: Add observation to entity
60 | - create_relation: Create relation between entities
61 | - search_memory: Search entities by query
62 | - delete_entities: Delete multiple entities
63 | - delete_relation: Delete relation between entities
64 | - flush_memory: Persist changes to storage
65 | """,
66 | )
67 |
68 | # Initialize knowledge graph manager using environment variable
69 | # Default to ~/.claude/memory.jsonl if MEMORY_FILE_PATH not set
70 | default_memory_path = Path.home() / ".claude" / "memory.jsonl"
71 | memory_file = Path(os.getenv("MEMORY_FILE_PATH", str(default_memory_path)))
72 |
73 | logging.info(f"Memory server using file: {memory_file}")
74 |
75 | # Create KnowledgeGraphManager instance
76 | kg = KnowledgeGraphManager(memory_file, 60)  # 60-second cache TTL
77 |
78 |
79 | def serialize_to_dict(obj: Any) -> Any:
80 |     """Serialize objects to dictionaries, falling back to str for other types."""
81 | if hasattr(obj, "to_dict"):
82 | return obj.to_dict()
83 | elif hasattr(obj, "__dict__"):
84 | return obj.__dict__
85 | else:
86 | return str(obj)
87 |
88 |
89 | @mcp.tool()
90 | async def get_entity(entity_name: str) -> EntityResponse:
91 | """Get entity by name from memory."""
92 | try:
93 | result = await kg.search_nodes(entity_name)
94 | if result:
95 | return EntityResponse(success=True, data=serialize_to_dict(result))
96 | return EntityResponse(
97 | success=False,
98 | error=f"Entity '{entity_name}' not found",
99 | error_type=ERROR_TYPES["NOT_FOUND"],
100 | )
101 | except ValueError as e:
102 | return EntityResponse(
103 | success=False, error=str(e), error_type=ERROR_TYPES["VALIDATION_ERROR"]
104 | )
105 | except Exception as e:
106 | return EntityResponse(
107 | success=False, error=str(e), error_type=ERROR_TYPES["INTERNAL_ERROR"]
108 | )
109 |
110 |
111 | @mcp.tool()
112 | async def get_graph() -> GraphResponse:
113 | """Get the entire knowledge graph."""
114 | try:
115 | graph = await kg.read_graph()
116 | return GraphResponse(success=True, data=serialize_to_dict(graph))
117 | except Exception as e:
118 | return GraphResponse(
119 | success=False, error=str(e), error_type=ERROR_TYPES["INTERNAL_ERROR"]
120 | )
121 |
122 |
123 | @mcp.tool()
124 | async def create_entities(entities: List[Entity]) -> OperationResponse:
125 | """Create multiple new entities."""
126 | try:
127 | await kg.create_entities(entities)
128 | return OperationResponse(success=True)
129 | except ValueError as e:
130 | return OperationResponse(
131 | success=False, error=str(e), error_type=ERROR_TYPES["VALIDATION_ERROR"]
132 | )
133 | except Exception as e:
134 | return OperationResponse(
135 | success=False, error=str(e), error_type=ERROR_TYPES["INTERNAL_ERROR"]
136 | )
137 |
138 |
139 | @mcp.tool()
140 | async def add_observation(
141 |     entity: str, observation: str, ctx: Optional[Context] = None
142 | ) -> OperationResponse:
143 | """Add an observation to an existing entity."""
144 | try:
145 | if ctx:
146 |             await ctx.info(f"Adding observation to {entity}")
147 |
148 | # Check if entity exists
149 | exists = await kg.search_nodes(entity)
150 | if not exists:
151 | return OperationResponse(
152 | success=False,
153 | error=f"Entity '{entity}' not found",
154 | error_type=ERROR_TYPES["NOT_FOUND"],
155 | )
156 |
157 | await kg.add_observations(entity, [observation])
158 | return OperationResponse(success=True)
159 | except ValueError as e:
160 | return OperationResponse(
161 | success=False, error=str(e), error_type=ERROR_TYPES["VALIDATION_ERROR"]
162 | )
163 | except Exception as e:
164 | return OperationResponse(
165 | success=False, error=str(e), error_type=ERROR_TYPES["INTERNAL_ERROR"]
166 | )
167 |
168 |
169 | @mcp.tool()
170 | async def create_relation(
171 |     from_entity: str, to_entity: str, relation_type: str, ctx: Optional[Context] = None
172 | ) -> OperationResponse:
173 | """Create a relation between entities."""
174 | try:
175 | if ctx:
176 |             await ctx.info(f"Creating relation: {from_entity} -{relation_type}-> {to_entity}")
177 |
178 | # Check if entities exist
179 | from_exists = await kg.search_nodes(from_entity)
180 | to_exists = await kg.search_nodes(to_entity)
181 |
182 | if not from_exists:
183 | return OperationResponse(
184 | success=False,
185 | error=f"Source entity '{from_entity}' not found",
186 | error_type=ERROR_TYPES["NOT_FOUND"],
187 | )
188 |
189 | if not to_exists:
190 | return OperationResponse(
191 | success=False,
192 | error=f"Target entity '{to_entity}' not found",
193 | error_type=ERROR_TYPES["NOT_FOUND"],
194 | )
195 |
196 | await kg.create_relations(
197 | [Relation(from_=from_entity, to=to_entity, relationType=relation_type)]
198 | )
199 | return OperationResponse(success=True)
200 | except ValueError as e:
201 | return OperationResponse(
202 | success=False, error=str(e), error_type=ERROR_TYPES["VALIDATION_ERROR"]
203 | )
204 | except Exception as e:
205 | return OperationResponse(
206 | success=False, error=str(e), error_type=ERROR_TYPES["INTERNAL_ERROR"]
207 | )
208 |
209 |
210 | @mcp.tool()
211 | async def search_memory(query: str, ctx: Optional[Context] = None) -> EntityResponse:
212 | """Search memory using natural language queries.
213 |
214 | Handles:
215 | - Temporal queries (e.g., "most recent", "last", "latest")
216 | - Activity queries (e.g., "workout", "exercise")
217 | - General entity searches
218 | """
219 | try:
220 | if ctx:
221 |             await ctx.info(f"Searching for: {query}")
222 |
223 | # Handle temporal queries
224 | temporal_keywords = ["recent", "last", "latest"]
225 | is_temporal = any(keyword in query.lower() for keyword in temporal_keywords)
226 |
227 | # Extract activity type from query
228 | activity_type = None
229 | if "workout" in query.lower():
230 | activity_type = "workout"
231 | elif "exercise" in query.lower():
232 | activity_type = "exercise"
233 | elif "physical activity" in query.lower():
234 | activity_type = "physical_activity"
235 |
236 | # Search for entities
237 | results = await kg.search_nodes(activity_type if activity_type else query)
238 |
239 | if not results:
240 | return EntityResponse(
241 | success=True,
242 | data={"entities": [], "relations": []},
243 | error="No matching activities found in memory",
244 |                 error_type=ERROR_TYPES["NO_RESULTS"],
245 | )
246 |
247 |         # For temporal queries, sort by timestamp when the backend returns a list of dicts
248 |         if is_temporal and isinstance(results, list):
249 | results.sort(key=lambda x: x.get("timestamp", ""), reverse=True)
250 | if results:
251 | results = results[0] # Get most recent
252 |
253 | return EntityResponse(success=True, data=serialize_to_dict(results))
254 | except ValueError as e:
255 | return EntityResponse(
256 | success=False, error=str(e), error_type=ERROR_TYPES["VALIDATION_ERROR"]
257 | )
258 | except Exception as e:
259 | return EntityResponse(
260 | success=False, error=str(e), error_type=ERROR_TYPES["INTERNAL_ERROR"]
261 | )
262 |
263 |
264 | @mcp.tool()
265 | async def delete_entities(names: List[str], ctx: Optional[Context] = None) -> OperationResponse:
266 | """Delete multiple entities and their relations."""
267 | try:
268 | if ctx:
269 |             await ctx.info(f"Deleting entities: {', '.join(names)}")
270 |
271 | await kg.delete_entities(names)
272 | return OperationResponse(success=True)
273 | except ValueError as e:
274 | return OperationResponse(
275 | success=False, error=str(e), error_type=ERROR_TYPES["VALIDATION_ERROR"]
276 | )
277 | except Exception as e:
278 | return OperationResponse(
279 | success=False, error=str(e), error_type=ERROR_TYPES["INTERNAL_ERROR"]
280 | )
281 |
282 |
283 | @mcp.tool()
284 | async def delete_relation(
285 |     from_entity: str, to_entity: str, ctx: Optional[Context] = None
286 | ) -> OperationResponse:
287 | """Delete relations between two entities."""
288 | try:
289 | if ctx:
290 |             await ctx.info(f"Deleting relations between {from_entity} and {to_entity}")
291 |
292 | # Check if entities exist
293 | from_exists = await kg.search_nodes(from_entity)
294 | to_exists = await kg.search_nodes(to_entity)
295 |
296 | if not from_exists:
297 | return OperationResponse(
298 | success=False,
299 | error=f"Source entity '{from_entity}' not found",
300 | error_type=ERROR_TYPES["NOT_FOUND"],
301 | )
302 |
303 | if not to_exists:
304 | return OperationResponse(
305 | success=False,
306 | error=f"Target entity '{to_entity}' not found",
307 | error_type=ERROR_TYPES["NOT_FOUND"],
308 | )
309 |
310 | await kg.delete_relations(from_entity, to_entity)
311 | return OperationResponse(success=True)
312 | except ValueError as e:
313 | return OperationResponse(
314 | success=False, error=str(e), error_type=ERROR_TYPES["VALIDATION_ERROR"]
315 | )
316 | except Exception as e:
317 | return OperationResponse(
318 | success=False, error=str(e), error_type=ERROR_TYPES["INTERNAL_ERROR"]
319 | )
320 |
321 |
322 | @mcp.tool()
323 | async def flush_memory(ctx: Optional[Context] = None) -> OperationResponse:
324 | """Ensure all changes are persisted to storage."""
325 | try:
326 | if ctx:
327 |             await ctx.info("Flushing memory to storage")
328 |
329 | await kg.flush()
330 | return OperationResponse(success=True)
331 | except Exception as e:
332 | return OperationResponse(
333 | success=False, error=str(e), error_type=ERROR_TYPES["INTERNAL_ERROR"]
334 | )
335 |
336 |
337 | @mcp.prompt()
338 | def create_entity_prompt(name: str, entity_type: str) -> list[Message]:
339 | """Generate prompt for entity creation."""
340 | return [
341 | UserMessage(
342 | f"I want to create a new entity in memory:\n"
343 | f"Name: {name}\n"
344 | f"Type: {entity_type}\n\n"
345 | f"What observations should I record about this entity?"
346 | )
347 | ]
348 |
349 |
350 | @mcp.prompt()
351 | def search_prompt(query: str) -> list[Message]:
352 | """Generate prompt for memory search."""
353 | return [
354 | UserMessage(
355 | f"I want to search my memory for information about: {query}\n\n"
356 | f"What specific aspects of these results would you like me to explain?"
357 | )
358 | ]
359 |
360 |
361 | @mcp.prompt()
362 | def relation_prompt(from_entity: str, to_entity: str) -> list[Message]:
363 | """Generate prompt for creating a relation."""
364 | return [
365 | UserMessage(
366 | f"I want to establish a relationship between:\n"
367 | f"Source: {from_entity}\n"
368 | f"Target: {to_entity}\n\n"
369 | f"What type of relationship exists between these entities?"
370 | )
371 | ]
372 |
```
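Every tool above returns a typed response envelope instead of raising across the MCP boundary: exceptions are caught and mapped to `error_type` codes. A minimal synchronous stand-in for that pattern (a dataclass in place of the pydantic models, a plain dict in place of the real `KnowledgeGraphManager`):

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional

ERROR_TYPES = {
    "NOT_FOUND": "NOT_FOUND",
    "VALIDATION_ERROR": "VALIDATION_ERROR",
    "INTERNAL_ERROR": "INTERNAL_ERROR",
}


@dataclass
class EntityResponse:
    """Plain-dataclass stand-in for the pydantic response model."""

    success: bool
    data: Optional[Dict[str, Any]] = None
    error: Optional[str] = None
    error_type: Optional[str] = None


def get_entity(store: Dict[str, Dict[str, Any]], name: str) -> EntityResponse:
    """Mirror the tool's control flow: catch, classify, and wrap errors."""
    try:
        if name not in store:
            return EntityResponse(
                success=False,
                error=f"Entity '{name}' not found",
                error_type=ERROR_TYPES["NOT_FOUND"],
            )
        return EntityResponse(success=True, data=store[name])
    except ValueError as e:
        return EntityResponse(
            success=False, error=str(e), error_type=ERROR_TYPES["VALIDATION_ERROR"]
        )
    except Exception as e:  # last-resort catch, as in the server
        return EntityResponse(
            success=False, error=str(e), error_type=ERROR_TYPES["INTERNAL_ERROR"]
        )


store = {"Alice": {"entityType": "person", "observations": ["likes apples"]}}
ok = get_entity(store, "Alice")
missing = get_entity(store, "Bob")
```

Because failures come back as data rather than exceptions, the MCP client can branch on `error_type` without parsing error strings.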
--------------------------------------------------------------------------------
/tests/test_backends/test_jsonl.py:
--------------------------------------------------------------------------------
```python
1 | import json
2 | from pathlib import Path
3 |
4 | import pytest
5 |
6 | from memory_mcp_server.backends.jsonl import JsonlBackend
7 | from memory_mcp_server.exceptions import EntityNotFoundError, FileAccessError
8 | from memory_mcp_server.interfaces import (
9 | BatchOperation,
10 | BatchOperationType,
11 | BatchResult,
12 | Entity,
13 | Relation,
14 | SearchOptions,
15 | )
16 |
17 | # --- Fixtures ---
18 |
19 |
20 | @pytest.fixture
21 | async def backend(tmp_path: Path) -> JsonlBackend:
22 | b = JsonlBackend(tmp_path / "test.jsonl")
23 | await b.initialize()
24 | yield b
25 | await b.close()
26 |
27 |
28 | # --- Entity Creation / Duplication ---
29 |
30 |
31 | @pytest.mark.asyncio
32 | async def test_create_entities(backend: JsonlBackend):
33 | entities = [
34 | Entity(name="Alice", entityType="person", observations=["likes apples"]),
35 | Entity(name="Bob", entityType="person", observations=["enjoys biking"]),
36 | ]
37 | created = await backend.create_entities(entities)
38 | assert len(created) == 2, "Should create two new entities"
39 |
40 | graph = await backend.read_graph()
41 | assert len(graph.entities) == 2, "Graph should contain two entities"
42 |
43 |
44 | @pytest.mark.asyncio
45 | async def test_duplicate_entities(backend: JsonlBackend):
46 | entity = Entity(name="Alice", entityType="person", observations=["likes apples"])
47 | created1 = await backend.create_entities([entity])
48 | created2 = await backend.create_entities([entity])
49 | assert len(created1) == 1
50 | assert len(created2) == 0, "Duplicate entity creation should return empty list"
51 |
52 |
53 | # --- Relation Creation / Deletion ---
54 |
55 |
56 | @pytest.mark.asyncio
57 | async def test_create_relations(backend: JsonlBackend):
58 | entities = [
59 | Entity(name="Alice", entityType="person", observations=[""]),
60 | Entity(name="Wonderland", entityType="place", observations=["fantasy land"]),
61 | ]
62 | await backend.create_entities(entities)
63 | relation = Relation(from_="Alice", to="Wonderland", relationType="visits")
64 | created_relations = await backend.create_relations([relation])
65 | assert len(created_relations) == 1
66 |
67 | graph = await backend.read_graph()
68 | assert len(graph.relations) == 1
69 |
70 |
71 | @pytest.mark.asyncio
72 | async def test_create_relation_missing_entity(backend: JsonlBackend):
73 | # No entities have been created.
74 | relation = Relation(from_="Alice", to="Nowhere", relationType="visits")
75 | with pytest.raises(EntityNotFoundError):
76 | await backend.create_relations([relation])
77 |
78 |
79 | @pytest.mark.asyncio
80 | async def test_delete_relations(backend: JsonlBackend):
81 | entities = [
82 | Entity(name="Alice", entityType="person", observations=[]),
83 | Entity(name="Bob", entityType="person", observations=[]),
84 | ]
85 | await backend.create_entities(entities)
86 | # Create two distinct relations.
87 | relation1 = Relation(from_="Alice", to="Bob", relationType="likes")
88 | relation2 = Relation(from_="Alice", to="Bob", relationType="follows")
89 | await backend.create_relations([relation1, relation2])
90 | await backend.delete_relations("Alice", "Bob")
91 | graph = await backend.read_graph()
92 | assert (
93 | len(graph.relations) == 0
94 | ), "All relations between Alice and Bob should be removed"
95 |
96 |
97 | @pytest.mark.asyncio
98 | async def test_delete_entities(backend: JsonlBackend):
99 | entities = [
100 | Entity(name="Alice", entityType="person", observations=["obs1"]),
101 | Entity(name="Bob", entityType="person", observations=["obs2"]),
102 | ]
103 | await backend.create_entities(entities)
104 | # Create a relation so that deletion cascades.
105 | relation = Relation(from_="Alice", to="Bob", relationType="knows")
106 | await backend.create_relations([relation])
107 | deleted = await backend.delete_entities(["Alice"])
108 | assert "Alice" in deleted
109 |
110 | graph = await backend.read_graph()
111 | # Only Bob should remain and the relation should have been removed.
112 | assert len(graph.entities) == 1
113 | assert graph.entities[0].name == "Bob"
114 | assert len(graph.relations) == 0
115 |
116 |
117 | # --- Searching ---
118 |
119 |
120 | @pytest.mark.asyncio
121 | async def test_search_nodes_exact(backend: JsonlBackend):
122 | entities = [
123 | Entity(
124 | name="Alice Wonderland", entityType="person", observations=["loves tea"]
125 | ),
126 | Entity(name="Wonderland", entityType="place", observations=["magical"]),
127 | ]
128 | await backend.create_entities(entities)
129 | result = await backend.search_nodes("Wonderland")
130 | # Both entities should match the substring.
131 | assert len(result.entities) == 2
132 | # No relations were created.
133 | assert len(result.relations) == 0
134 |
135 |
136 | @pytest.mark.asyncio
137 | async def test_search_nodes_fuzzy(backend: JsonlBackend):
138 | entities = [
139 | Entity(
140 | name="John Smith", entityType="person", observations=["software engineer"]
141 | ),
142 | Entity(
143 | name="Jane Smith", entityType="person", observations=["product manager"]
144 | ),
145 | ]
146 | await backend.create_entities(entities)
147 | options = SearchOptions(
148 | fuzzy=True,
149 | threshold=90,
150 | weights={"name": 0.7, "type": 0.5, "observations": 0.3},
151 | )
152 | result = await backend.search_nodes("Jon Smith", options)
153 | assert len(result.entities) == 1, "Fuzzy search should match John Smith"
154 | assert result.entities[0].name == "John Smith"
155 |
156 |
157 | @pytest.mark.asyncio
158 | async def test_search_nodes_fuzzy_weights(backend: JsonlBackend):
159 | # Clear any existing entities.
160 | current = await backend.read_graph()
161 | if current.entities:
162 | await backend.delete_entities([e.name for e in current.entities])
163 | entities = [
164 | Entity(
165 | name="Programming Guide",
166 | entityType="document",
167 | observations=["A guide about programming development"],
168 | ),
169 | Entity(
170 | name="Software Manual",
171 | entityType="document",
172 | observations=["Programming tutorial and guide"],
173 | ),
174 | ]
175 | await backend.create_entities(entities)
176 | # With name-weight high, only one should match.
177 | options_name = SearchOptions(
178 | fuzzy=True,
179 | threshold=60,
180 | weights={"name": 1.0, "type": 0.1, "observations": 0.1},
181 | )
182 | result = await backend.search_nodes("programming", options_name)
183 | assert len(result.entities) == 1
184 | assert result.entities[0].name == "Programming Guide"
185 |
186 | # With observation weight high, both should match.
187 | options_obs = SearchOptions(
188 | fuzzy=True,
189 | threshold=60,
190 | weights={"name": 0.1, "type": 0.1, "observations": 1.0},
191 | )
192 | result = await backend.search_nodes("programming", options_obs)
193 | assert len(result.entities) == 2
194 |
195 |
196 | # --- Observations ---
197 |
198 |
199 | @pytest.mark.asyncio
200 | async def test_add_observations(backend: JsonlBackend):
201 | entity = Entity(name="Alice", entityType="person", observations=["initial"])
202 | await backend.create_entities([entity])
203 | await backend.add_observations("Alice", ["update"])
204 | graph = await backend.read_graph()
205 | alice = next(e for e in graph.entities if e.name == "Alice")
206 | assert "update" in alice.observations
207 |
208 |
209 | @pytest.mark.asyncio
210 | async def test_add_batch_observations(backend: JsonlBackend):
211 | entities = [
212 | Entity(name="Alice", entityType="person", observations=["obs1"]),
213 | Entity(name="Bob", entityType="person", observations=["obs2"]),
214 | ]
215 | await backend.create_entities(entities)
216 | observations_map = {"Alice": ["new1", "new2"], "Bob": ["new3"]}
217 | await backend.add_batch_observations(observations_map)
218 | graph = await backend.read_graph()
219 | alice = next(e for e in graph.entities if e.name == "Alice")
220 | bob = next(e for e in graph.entities if e.name == "Bob")
221 | assert set(alice.observations) == {"obs1", "new1", "new2"}
222 | assert set(bob.observations) == {"obs2", "new3"}
223 |
224 |
225 | @pytest.mark.asyncio
226 | async def test_add_batch_observations_empty_map(backend: JsonlBackend):
227 | with pytest.raises(ValueError, match="Observations map cannot be empty"):
228 | await backend.add_batch_observations({})
229 |
230 |
231 | @pytest.mark.asyncio
232 | async def test_add_batch_observations_missing_entity(backend: JsonlBackend):
233 | entity = Entity(name="Alice", entityType="person", observations=["obs1"])
234 | await backend.create_entities([entity])
235 | observations_map = {"Alice": ["new"], "Bob": ["obs"]}
236 | with pytest.raises(EntityNotFoundError):
237 | await backend.add_batch_observations(observations_map)
238 |
239 |
240 | # --- Transaction Management ---
241 |
242 |
243 | @pytest.mark.asyncio
244 | async def test_transaction_management(backend: JsonlBackend):
245 | entities = [
246 | Entity(name="Alice", entityType="person", observations=["obs1"]),
247 | Entity(name="Bob", entityType="person", observations=["obs2"]),
248 | ]
249 | await backend.create_entities(entities)
250 | # Begin a transaction.
251 | await backend.begin_transaction()
252 | await backend.create_entities(
253 | [Entity(name="Charlie", entityType="person", observations=["obs3"])]
254 | )
255 | await backend.delete_entities(["Alice"])
256 | # Within transaction, changes are visible.
257 | graph = await backend.read_graph()
258 | names = {e.name for e in graph.entities}
259 | assert "Charlie" in names
260 | assert "Alice" not in names
261 | # Roll back.
262 | await backend.rollback_transaction()
263 | graph = await backend.read_graph()
264 | names = {e.name for e in graph.entities}
265 | assert "Alice" in names
266 | assert "Charlie" not in names
267 |
268 | # Test commit.
269 | await backend.begin_transaction()
270 | await backend.create_entities(
271 | [Entity(name="Dave", entityType="person", observations=["obs4"])]
272 | )
273 | await backend.commit_transaction()
274 | graph = await backend.read_graph()
275 | names = {e.name for e in graph.entities}
276 | assert "Dave" in names
277 |
278 |
279 | # --- Persistence and File Format ---
280 |
281 |
282 | @pytest.mark.asyncio
283 | async def test_persistence(tmp_path: Path):
284 | file_path = tmp_path / "persist.jsonl"
285 | backend1 = JsonlBackend(file_path)
286 | await backend1.initialize()
287 | entity = Entity(name="Alice", entityType="person", observations=["obs"])
288 | await backend1.create_entities([entity])
289 | await backend1.close()
290 |
291 | backend2 = JsonlBackend(file_path)
292 | await backend2.initialize()
293 | graph = await backend2.read_graph()
294 | assert any(e.name == "Alice" for e in graph.entities)
295 | await backend2.close()
296 |
297 |
298 | @pytest.mark.asyncio
299 | async def test_atomic_writes(tmp_path: Path):
300 | file_path = tmp_path / "atomic.jsonl"
301 | backend = JsonlBackend(file_path)
302 | await backend.initialize()
303 | entity = Entity(name="Alice", entityType="person", observations=["obs"])
304 | await backend.create_entities([entity])
305 | await backend.close()
306 | temp_file = file_path.with_suffix(".tmp")
307 | assert not temp_file.exists(), "Temporary file should be removed after writing"
308 | assert file_path.exists()
309 |
310 |
311 | @pytest.mark.asyncio
312 | async def test_file_format(tmp_path: Path):
313 | file_path = tmp_path / "format.jsonl"
314 | backend = JsonlBackend(file_path)
315 | await backend.initialize()
316 | entity = Entity(name="Alice", entityType="person", observations=["obs"])
317 | relation = Relation(from_="Alice", to="Alice", relationType="self")
318 | await backend.create_entities([entity])
319 | await backend.create_relations([relation])
320 | await backend.close()
321 | with open(file_path, "r", encoding="utf-8") as f:
322 | lines = f.read().splitlines()
323 | assert len(lines) == 2, "File should contain exactly two JSON lines"
324 | data1 = json.loads(lines[0])
325 | data2 = json.loads(lines[1])
326 | types = {data1.get("type"), data2.get("type")}
327 | assert "entity" in types and "relation" in types
328 |
329 |
330 | # --- Error / Corruption Handling ---
331 |
332 |
333 | @pytest.mark.asyncio
334 | async def test_corrupted_file_handling(tmp_path: Path):
335 | file_path = tmp_path / "corrupted.jsonl"
336 | # Write one valid and one corrupted JSON line.
337 | with open(file_path, "w", encoding="utf-8") as f:
338 | f.write(
339 | '{"type": "entity", "name": "Alice", "entityType": "person", "observations": []}\n'
340 | )
341 | f.write(
342 | '{"type": "relation", "from": "Alice", "to": "Bob"'
343 | ) # missing closing brace
344 | backend = JsonlBackend(file_path)
345 | await backend.initialize()
346 | with pytest.raises(FileAccessError, match="Error loading graph"):
347 | await backend.read_graph()
348 | await backend.close()
349 |
350 |
351 | @pytest.mark.asyncio
352 | async def test_file_access_error_propagation(tmp_path: Path):
353 | file_path = tmp_path / "error.jsonl"
354 | # Create a directory with the same name as the file.
355 | file_path.mkdir()
356 | backend = JsonlBackend(file_path)
357 | with pytest.raises(FileAccessError, match="is a directory"):
358 | await backend.initialize()
359 | await backend.close()
360 |
361 |
362 | # --- Caching ---
363 |
364 |
365 | @pytest.mark.asyncio
366 | async def test_caching(backend: JsonlBackend):
367 | entity = Entity(name="Alice", entityType="person", observations=["obs"])
368 | await backend.create_entities([entity])
369 | graph1 = await backend.read_graph()
370 | graph2 = await backend.read_graph()
371 | assert graph1 is graph2, "Repeated reads should return the cached graph"
372 |
373 |
374 | # --- Batch Operations ---
375 |
376 |
377 | @pytest.mark.asyncio
378 | async def test_execute_batch(backend: JsonlBackend):
379 | # Create an initial entity.
380 | await backend.create_entities(
381 | [Entity(name="Alice", entityType="person", observations=["obs"])]
382 | )
383 | operations = [
384 | BatchOperation(
385 | operation_type=BatchOperationType.CREATE_ENTITIES,
386 | data={
387 | "entities": [
388 | Entity(name="Bob", entityType="person", observations=["obs2"])
389 | ]
390 | },
391 | ),
392 | BatchOperation(
393 | operation_type=BatchOperationType.CREATE_RELATIONS,
394 | data={
395 | "relations": [Relation(from_="Alice", to="Bob", relationType="knows")]
396 | },
397 | ),
398 | BatchOperation(
399 | operation_type=BatchOperationType.ADD_OBSERVATIONS,
400 | data={"observations_map": {"Alice": ["new_obs"]}},
401 | ),
402 | ]
403 | result: BatchResult = await backend.execute_batch(operations)
404 | print(result)
405 | assert result.success, "Batch operations should succeed"
406 | graph = await backend.read_graph()
407 | assert any(e.name == "Bob" for e in graph.entities)
408 | assert len(graph.relations) == 1
409 | alice = next(e for e in graph.entities if e.name == "Alice")
410 | assert "new_obs" in alice.observations
411 |
412 |
413 | @pytest.mark.asyncio
414 | async def test_execute_batch_failure(backend: JsonlBackend):
415 | # Create an initial entity.
416 | await backend.create_entities(
417 | [Entity(name="Alice", entityType="person", observations=["obs"])]
418 | )
419 | operations = [
420 | BatchOperation(
421 | operation_type=BatchOperationType.CREATE_RELATIONS,
422 | data={
423 | "relations": [
424 | Relation(from_="Alice", to="NonExistent", relationType="knows")
425 | ]
426 | },
427 | ),
428 | ]
429 | result: BatchResult = await backend.execute_batch(operations)
430 | assert (
431 | not result.success
432 | ), "Batch operation should fail if a relation refers to a non-existent entity"
433 | # Verify that rollback occurred (no partial changes).
434 | graph = await backend.read_graph()
435 | assert len(graph.entities) == 1
436 | assert len(graph.relations) == 0
437 |
```
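`test_atomic_writes` above checks the write-to-temp-then-rename discipline the backend is expected to follow. A minimal synchronous sketch of that pattern (the real backend writes asynchronously via aiofiles; `atomic_write_jsonl` here is a hypothetical helper, not the project's API):

```python
import json
import os
from pathlib import Path
from tempfile import TemporaryDirectory


def atomic_write_jsonl(path: Path, records: list) -> None:
    """Write all records to a sibling .tmp file, then rename into place.

    os.replace is atomic on POSIX for same-filesystem moves, so readers
    never observe a half-written file.
    """
    tmp = path.with_suffix(".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")
    os.replace(tmp, path)  # atomically swap the new file in


with TemporaryDirectory() as d:
    target = Path(d) / "memory.jsonl"
    atomic_write_jsonl(target, [{"type": "entity", "name": "Alice"}])
    lines = target.read_text(encoding="utf-8").splitlines()
    leftover = target.with_suffix(".tmp").exists()
```

The rename is the step the test indirectly verifies: after `close()`, only the final file exists and the `.tmp` sibling is gone.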
--------------------------------------------------------------------------------
/memory_mcp_server/backends/jsonl.py:
--------------------------------------------------------------------------------
```python
1 | import asyncio
2 | import json
3 | import time
4 | from collections import defaultdict
5 | from dataclasses import dataclass
6 | from pathlib import Path
7 | from typing import Any, Dict, List, Optional, Set, Tuple, cast
8 |
9 | import aiofiles
10 | from thefuzz import fuzz
11 |
12 | from ..exceptions import EntityNotFoundError, FileAccessError
13 | from ..interfaces import (
14 | BatchOperation,
15 | BatchOperationType,
16 | BatchResult,
17 | Entity,
18 | KnowledgeGraph,
19 | Relation,
20 | SearchOptions,
21 | )
22 | from .base import Backend
23 |
24 |
25 | @dataclass
26 | class SearchResult:
27 | entity: Entity
28 | score: float
29 |
30 |
31 | class ReentrantLock:
32 | def __init__(self):
33 | self._lock = asyncio.Lock()
34 | self._owner = None
35 | self._count = 0
36 |
37 | async def acquire(self):
38 | current = asyncio.current_task()
39 | if self._owner == current:
40 | self._count += 1
41 | return
42 | await self._lock.acquire()
43 | self._owner = current
44 | self._count = 1
45 |
46 | def release(self):
47 | current = asyncio.current_task()
48 | if self._owner != current:
49 | raise RuntimeError("Lock not owned by current task")
50 | self._count -= 1
51 | if self._count == 0:
52 | self._owner = None
53 | self._lock.release()
54 |
55 | async def __aenter__(self):
56 | await self.acquire()
57 | return self
58 |
59 | async def __aexit__(self, exc_type, exc_val, tb):
60 | self.release()
61 |
62 |
63 | class JsonlBackend(Backend):
64 | def __init__(self, memory_path: Path, cache_ttl: int = 60):
65 | self.memory_path = memory_path
66 | self.cache_ttl = cache_ttl
67 | self._cache: Optional[KnowledgeGraph] = None
68 | self._cache_timestamp: float = 0.0
69 | self._cache_file_mtime: float = 0.0
70 | self._dirty = False
71 | self._write_lock = ReentrantLock()
72 | self._lock = asyncio.Lock()
73 |
74 | # Transaction support: when a transaction is active, we work on separate copies.
75 | self._transaction_cache: Optional[KnowledgeGraph] = None
76 | self._transaction_indices: Optional[Dict[str, Any]] = None
77 | self._in_transaction = False
78 |
79 | self._indices: Dict[str, Any] = {
80 | "entity_names": {},
81 | "entity_types": defaultdict(list),
82 | "relations_from": defaultdict(list),
83 | "relations_to": defaultdict(list),
84 | "relation_keys": set(),
85 | "observation_index": defaultdict(set),
86 | }
87 |
88 | async def initialize(self) -> None:
89 | self.memory_path.parent.mkdir(parents=True, exist_ok=True)
90 | if self.memory_path.exists() and self.memory_path.is_dir():
91 | raise FileAccessError(f"Path {self.memory_path} is a directory")
92 |
93 | async def close(self) -> None:
94 | await self.flush()
95 |
96 | def _build_indices(self, graph: KnowledgeGraph) -> None:
97 | # Build indices for faster lookups.
98 | entity_names: Dict[str, Entity] = {}
99 | entity_types: Dict[str, List[Entity]] = defaultdict(list)
100 | relations_from: Dict[str, List[Relation]] = defaultdict(list)
101 | relations_to: Dict[str, List[Relation]] = defaultdict(list)
102 | relation_keys: Set[Tuple[str, str, str]] = set()
103 |
104 | for entity in graph.entities:
105 | entity_names[entity.name] = entity
106 | entity_types[entity.entityType].append(entity)
107 |
108 | for relation in graph.relations:
109 | relations_from[relation.from_].append(relation)
110 | relations_to[relation.to].append(relation)
111 | relation_keys.add((relation.from_, relation.to, relation.relationType))
112 |
113 | self._indices["entity_names"] = entity_names
114 | self._indices["entity_types"] = entity_types
115 | self._indices["relations_from"] = relations_from
116 | self._indices["relations_to"] = relations_to
117 | self._indices["relation_keys"] = relation_keys
118 |
119 | # Build the observation index.
120 | observation_index = cast(
121 | Dict[str, Set[str]], self._indices["observation_index"]
122 | )
123 | observation_index.clear()
124 | for entity in graph.entities:
125 | for obs in entity.observations:
126 | for word in obs.lower().split():
127 | observation_index[word].add(entity.name)
128 |
129 | async def _check_cache(self) -> KnowledgeGraph:
130 | # During a transaction, always use the transaction snapshot.
131 | if self._in_transaction:
132 | return self._transaction_cache # type: ignore
133 |
134 | current_time = time.monotonic()
135 | file_mtime = (
136 | self.memory_path.stat().st_mtime if self.memory_path.exists() else 0
137 | )
138 | needs_refresh = (
139 | self._cache is None
140 | or (current_time - self._cache_timestamp > self.cache_ttl)
141 | or self._dirty
142 | or (file_mtime > self._cache_file_mtime)
143 | )
144 |
145 | if needs_refresh:
146 | async with self._lock:
147 | current_time = time.monotonic()
148 | file_mtime = (
149 | self.memory_path.stat().st_mtime if self.memory_path.exists() else 0
150 | )
151 | needs_refresh = (
152 | self._cache is None
153 | or (current_time - self._cache_timestamp > self.cache_ttl)
154 | or self._dirty
155 | or (file_mtime > self._cache_file_mtime)
156 | )
157 | if needs_refresh:
158 | try:
159 | graph = await self._load_graph_from_file()
160 | self._cache = graph
161 | self._cache_timestamp = current_time
162 | self._cache_file_mtime = file_mtime
163 | self._build_indices(graph)
164 | self._dirty = False
165 | except FileAccessError:
166 | raise
167 | except Exception as e:
168 | raise FileAccessError(f"Error loading graph: {str(e)}") from e
169 |
170 | return cast(KnowledgeGraph, self._cache)
171 |
172 | async def _load_graph_from_file(self) -> KnowledgeGraph:
173 | if not self.memory_path.exists():
174 | return KnowledgeGraph(entities=[], relations=[])
175 |
176 | graph = KnowledgeGraph(entities=[], relations=[])
177 | try:
178 | async with aiofiles.open(self.memory_path, mode="r", encoding="utf-8") as f:
179 | async for line in f:
180 | line = line.strip()
181 | if not line:
182 | continue
183 | try:
184 | item = json.loads(line)
185 | if item["type"] == "entity":
186 | graph.entities.append(
187 | Entity(
188 | name=item["name"],
189 | entityType=item["entityType"],
190 | observations=item["observations"],
191 | )
192 | )
193 | elif item["type"] == "relation":
194 | graph.relations.append(
195 | Relation(
196 | from_=item["from"],
197 | to=item["to"],
198 | relationType=item["relationType"],
199 | )
200 | )
201 | except json.JSONDecodeError as e:
202 | raise FileAccessError(f"Error loading graph: {str(e)}") from e
203 | except KeyError as e:
204 | raise FileAccessError(
205 | f"Error loading graph: Missing required key {str(e)}"
206 | ) from e
207 | return graph
208 | except Exception as err:
209 | raise FileAccessError(f"Error reading file: {str(err)}") from err
210 |
211 | async def _save_graph(self, graph: KnowledgeGraph) -> None:
212 |         # Atomically write the graph to disk via a temporary file; during a transaction this is only called on commit.
213 | temp_path = self.memory_path.with_suffix(".tmp")
214 | buffer_size = 1000 # Buffer size (number of lines)
215 | try:
216 | async with aiofiles.open(temp_path, mode="w", encoding="utf-8") as f:
217 | buffer = []
218 | # Write entities.
219 | for entity in graph.entities:
220 | line = json.dumps(
221 | {
222 | "type": "entity",
223 | "name": entity.name,
224 | "entityType": entity.entityType,
225 | "observations": entity.observations,
226 | }
227 | )
228 | buffer.append(line)
229 | if len(buffer) >= buffer_size:
230 | await f.write("\n".join(buffer) + "\n")
231 | buffer = []
232 | if buffer:
233 | await f.write("\n".join(buffer) + "\n")
234 | buffer = []
235 |
236 | # Write relations.
237 | for relation in graph.relations:
238 | line = json.dumps(
239 | {
240 | "type": "relation",
241 | "from": relation.from_,
242 | "to": relation.to,
243 | "relationType": relation.relationType,
244 | }
245 | )
246 | buffer.append(line)
247 | if len(buffer) >= buffer_size:
248 | await f.write("\n".join(buffer) + "\n")
249 | buffer = []
250 | if buffer:
251 | await f.write("\n".join(buffer) + "\n")
252 | temp_path.replace(self.memory_path)
253 | except Exception as err:
254 | raise FileAccessError(f"Error saving file: {str(err)}") from err
255 | finally:
256 | if temp_path.exists():
257 | try:
258 | temp_path.unlink()
259 | except Exception:
260 | pass
261 |
262 | async def _get_current_state(self) -> Tuple[KnowledgeGraph, Dict[str, Any]]:
263 | # Returns the active graph and indices. If a transaction is in progress,
264 | # return the transaction copies; otherwise, return the persistent ones.
265 | if self._in_transaction:
266 | return self._transaction_cache, self._transaction_indices # type: ignore
267 | else:
268 | graph = await self._check_cache()
269 | return graph, self._indices
270 |
271 | async def create_entities(self, entities: List[Entity]) -> List[Entity]:
272 | async with self._write_lock:
273 | graph, indices = await self._get_current_state()
274 | existing_entities = cast(Dict[str, Entity], indices["entity_names"])
275 | new_entities = []
276 |
277 | for entity in entities:
278 | if not entity.name or not entity.entityType:
279 | raise ValueError(f"Invalid entity: {entity}")
280 | if entity.name not in existing_entities:
281 | new_entities.append(entity)
282 | existing_entities[entity.name] = entity
283 | cast(Dict[str, List[Entity]], indices["entity_types"]).setdefault(
284 | entity.entityType, []
285 | ).append(entity)
286 |
287 | if new_entities:
288 | graph.entities.extend(new_entities)
289 | # If not in a transaction, immediately persist the change.
290 | if not self._in_transaction:
291 | self._dirty = True
292 | await self._save_graph(graph)
293 | self._dirty = False
294 | self._cache_timestamp = time.monotonic()
295 |
296 | return new_entities
297 |
298 | async def delete_entities(self, entity_names: List[str]) -> List[str]:
299 | if not entity_names:
300 | return []
301 |
302 | async with self._write_lock:
303 | graph, indices = await self._get_current_state()
304 | existing_entities = cast(Dict[str, Entity], indices["entity_names"])
305 | deleted_names = []
306 | relation_keys = cast(Set[Tuple[str, str, str]], indices["relation_keys"])
307 |
308 | for name in entity_names:
309 | if name in existing_entities:
310 | entity = existing_entities.pop(name)
311 | entity_type_list = cast(
312 | Dict[str, List[Entity]], indices["entity_types"]
313 | ).get(entity.entityType, [])
314 | if entity in entity_type_list:
315 | entity_type_list.remove(entity)
316 |
317 | # Remove associated relations.
318 | relations_from = cast(
319 | Dict[str, List[Relation]], indices["relations_from"]
320 | ).get(name, [])
321 | relations_to = cast(
322 | Dict[str, List[Relation]], indices["relations_to"]
323 | ).get(name, [])
324 | relations_to_remove = relations_from + relations_to
325 |
326 | for relation in relations_to_remove:
327 | if relation in graph.relations:
328 | graph.relations.remove(relation)
329 | relation_keys.discard(
330 | (relation.from_, relation.to, relation.relationType)
331 | )
332 | if relation in cast(
333 | Dict[str, List[Relation]], indices["relations_from"]
334 | ).get(relation.from_, []):
335 | cast(Dict[str, List[Relation]], indices["relations_from"])[
336 | relation.from_
337 | ].remove(relation)
338 | if relation in cast(
339 | Dict[str, List[Relation]], indices["relations_to"]
340 | ).get(relation.to, []):
341 | cast(Dict[str, List[Relation]], indices["relations_to"])[
342 | relation.to
343 | ].remove(relation)
344 |
345 | deleted_names.append(name)
346 |
347 | if deleted_names:
348 | graph.entities = [
349 | e for e in graph.entities if e.name not in deleted_names
350 | ]
351 | if not self._in_transaction:
352 | self._dirty = True
353 | await self._save_graph(graph)
354 | self._dirty = False
355 | self._cache_timestamp = time.monotonic()
356 |
357 | return deleted_names
358 |
359 | async def create_relations(self, relations: List[Relation]) -> List[Relation]:
360 | async with self._write_lock:
361 | graph, indices = await self._get_current_state()
362 | existing_entities = cast(Dict[str, Entity], indices["entity_names"])
363 | relation_keys = cast(Set[Tuple[str, str, str]], indices["relation_keys"])
364 | new_relations = []
365 |
366 | for relation in relations:
367 | if not relation.from_ or not relation.to or not relation.relationType:
368 | raise ValueError(f"Invalid relation: {relation}")
369 |
370 | if relation.from_ not in existing_entities:
371 | raise EntityNotFoundError(f"Entity not found: {relation.from_}")
372 | if relation.to not in existing_entities:
373 | raise EntityNotFoundError(f"Entity not found: {relation.to}")
374 |
375 | key = (relation.from_, relation.to, relation.relationType)
376 | if key not in relation_keys:
377 | new_relations.append(relation)
378 | relation_keys.add(key)
379 | cast(
380 | Dict[str, List[Relation]], indices["relations_from"]
381 | ).setdefault(relation.from_, []).append(relation)
382 | cast(Dict[str, List[Relation]], indices["relations_to"]).setdefault(
383 | relation.to, []
384 | ).append(relation)
385 |
386 | if new_relations:
387 | graph.relations.extend(new_relations)
388 | if not self._in_transaction:
389 | self._dirty = True
390 | await self._save_graph(graph)
391 | self._dirty = False
392 | self._cache_timestamp = time.monotonic()
393 |
394 | return new_relations
395 |
396 | async def delete_relations(self, from_: str, to: str) -> None:
397 | async with self._write_lock:
398 | graph, indices = await self._get_current_state()
399 | existing_entities = cast(Dict[str, Entity], indices["entity_names"])
400 |
401 | if from_ not in existing_entities:
402 | raise EntityNotFoundError(f"Entity not found: {from_}")
403 | if to not in existing_entities:
404 | raise EntityNotFoundError(f"Entity not found: {to}")
405 |
406 | relations_from = cast(
407 | Dict[str, List[Relation]], indices["relations_from"]
408 | ).get(from_, [])
409 | relations_to_remove = [rel for rel in relations_from if rel.to == to]
410 |
411 | if relations_to_remove:
412 | graph.relations = [
413 | rel for rel in graph.relations if rel not in relations_to_remove
414 | ]
415 | relation_keys = cast(
416 | Set[Tuple[str, str, str]], indices["relation_keys"]
417 | )
418 | for rel in relations_to_remove:
419 | relation_keys.discard((rel.from_, rel.to, rel.relationType))
420 | if rel in cast(
421 | Dict[str, List[Relation]], indices["relations_from"]
422 | ).get(from_, []):
423 | cast(Dict[str, List[Relation]], indices["relations_from"])[
424 | from_
425 | ].remove(rel)
426 | if rel in cast(
427 | Dict[str, List[Relation]], indices["relations_to"]
428 | ).get(to, []):
429 | cast(Dict[str, List[Relation]], indices["relations_to"])[
430 | to
431 | ].remove(rel)
432 | if not self._in_transaction:
433 | self._dirty = True
434 | await self._save_graph(graph)
435 | self._dirty = False
436 | self._cache_timestamp = time.monotonic()
437 |
438 | async def read_graph(self) -> KnowledgeGraph:
439 | return await self._check_cache()
440 |
441 | async def flush(self) -> None:
442 | async with self._write_lock:
443 | # During a transaction, disk is not touched until commit.
444 | if self._dirty and not self._in_transaction:
445 | graph = await self._check_cache()
446 | await self._save_graph(graph)
447 | self._dirty = False
448 | self._cache_timestamp = time.monotonic()
449 |
450 | async def search_nodes(
451 | self, query: str, options: Optional[SearchOptions] = None
452 | ) -> KnowledgeGraph:
453 | """
454 | Search for entities and relations matching the query.
455 | If options is provided and options.fuzzy is True, fuzzy matching is used with weights and threshold.
456 |         Otherwise, a simple case-insensitive substring search is performed.
457 | Relations are returned only if both endpoints are in the set of matched entities.
458 | """
459 | graph = await self._check_cache()
460 | matched_entities = []
461 | if options is not None and options.fuzzy:
462 | # Use provided weights or default to 1.0 if not provided.
463 | weights = (
464 | options.weights
465 | if options.weights is not None
466 | else {"name": 1.0, "type": 1.0, "observations": 1.0}
467 | )
468 | q = query.strip()
469 | for entity in graph.entities:
470 | # Compute robust scores for each field.
471 | name_score = fuzz.WRatio(q, entity.name)
472 | type_score = fuzz.WRatio(q, entity.entityType)
473 | obs_score = 0
474 | if entity.observations:
475 | # For each observation, take the best between WRatio and partial_ratio.
476 | scores = [
477 | max(fuzz.WRatio(q, obs), fuzz.partial_ratio(q, obs))
478 | for obs in entity.observations
479 | ]
480 | obs_score = max(scores) if scores else 0
481 |
482 | total_score = (
483 | name_score * weights.get("name", 1.0)
484 | + type_score * weights.get("type", 1.0)
485 | + obs_score * weights.get("observations", 1.0)
486 | )
487 | if total_score >= options.threshold:
488 | matched_entities.append(entity)
489 | else:
490 | q = query.lower()
491 | for entity in graph.entities:
492 | if (
493 | q in entity.name.lower()
494 | or q in entity.entityType.lower()
495 | or any(q in obs.lower() for obs in entity.observations)
496 | ):
497 | matched_entities.append(entity)
498 |
499 | matched_names = {entity.name for entity in matched_entities}
500 | matched_relations = [
501 | rel
502 | for rel in graph.relations
503 | if rel.from_ in matched_names and rel.to in matched_names
504 | ]
505 | return KnowledgeGraph(entities=matched_entities, relations=matched_relations)
506 |
507 | async def add_observations(self, entity_name: str, observations: List[str]) -> None:
508 | if not observations:
509 | raise ValueError("Observations list cannot be empty")
510 |
511 | async with self._write_lock:
512 | graph, indices = await self._get_current_state()
513 | existing_entities = cast(Dict[str, Entity], indices["entity_names"])
514 |
515 | if entity_name not in existing_entities:
516 | raise EntityNotFoundError(f"Entity not found: {entity_name}")
517 |
518 | entity = existing_entities[entity_name]
519 | updated_entity = Entity(
520 | name=entity.name,
521 | entityType=entity.entityType,
522 | observations=list(entity.observations) + observations,
523 | )
524 |
525 | graph.entities = [
526 | updated_entity if e.name == entity_name else e for e in graph.entities
527 | ]
528 | existing_entities[entity_name] = updated_entity
529 |
530 | entity_types = cast(Dict[str, List[Entity]], indices["entity_types"])
531 | if entity_name in [
532 | e.name for e in entity_types.get(updated_entity.entityType, [])
533 | ]:
534 | entity_types[updated_entity.entityType] = [
535 | updated_entity if e.name == entity_name else e
536 | for e in entity_types[updated_entity.entityType]
537 | ]
538 |
539 | if not self._in_transaction:
540 | self._dirty = True
541 | await self._save_graph(graph)
542 | self._dirty = False
543 | self._cache_timestamp = time.monotonic()
544 |
545 | async def add_batch_observations(
546 | self, observations_map: Dict[str, List[str]]
547 | ) -> None:
548 | if not observations_map:
549 | raise ValueError("Observations map cannot be empty")
550 |
551 | async with self._write_lock:
552 | graph, indices = await self._get_current_state()
553 | existing_entities = cast(Dict[str, Entity], indices["entity_names"])
554 | entity_types = cast(Dict[str, List[Entity]], indices["entity_types"])
555 |
556 | missing_entities = [
557 | name for name in observations_map if name not in existing_entities
558 | ]
559 | if missing_entities:
560 | raise EntityNotFoundError(
561 | f"Entities not found: {', '.join(missing_entities)}"
562 | )
563 |
564 | updated_entities = {}
565 | for entity_name, observations in observations_map.items():
566 | if not observations:
567 | continue
568 | entity = existing_entities[entity_name]
569 | updated_entity = Entity(
570 | name=entity.name,
571 | entityType=entity.entityType,
572 | observations=list(entity.observations) + observations,
573 | )
574 | updated_entities[entity_name] = updated_entity
575 |
576 | if updated_entities:
577 | graph.entities = [
578 | updated_entities.get(e.name, e) for e in graph.entities
579 | ]
580 | for updated_entity in updated_entities.values():
581 | existing_entities[updated_entity.name] = updated_entity
582 | et_list = entity_types.get(updated_entity.entityType, [])
583 | for i, e in enumerate(et_list):
584 | if e.name == updated_entity.name:
585 | et_list[i] = updated_entity
586 | break
587 | if not self._in_transaction:
588 | self._dirty = True
589 | await self._save_graph(graph)
590 | self._dirty = False
591 | self._cache_timestamp = time.monotonic()
592 |
593 | #
594 | # Transaction Methods
595 | #
596 | async def begin_transaction(self) -> None:
597 | async with self._write_lock:
598 | if self._in_transaction:
599 | raise ValueError("Transaction already in progress")
600 | graph = await self._check_cache()
601 |             # Snapshot the graph and indices: copy the containers, while Entity/Relation objects are treated as immutable and shared.
602 | self._transaction_cache = KnowledgeGraph(
603 | entities=list(graph.entities), relations=list(graph.relations)
604 | )
605 | self._transaction_indices = {
606 | "entity_names": dict(self._indices["entity_names"]),
607 | "entity_types": defaultdict(
608 | list, {k: list(v) for k, v in self._indices["entity_types"].items()}
609 | ),
610 | "relations_from": defaultdict(
611 | list,
612 | {k: list(v) for k, v in self._indices["relations_from"].items()},
613 | ),
614 | "relations_to": defaultdict(
615 | list, {k: list(v) for k, v in self._indices["relations_to"].items()}
616 | ),
617 | "relation_keys": set(self._indices["relation_keys"]),
618 | "observation_index": defaultdict(
619 | set,
620 | {k: set(v) for k, v in self._indices["observation_index"].items()},
621 | ),
622 | }
623 | self._in_transaction = True
624 |
625 | async def rollback_transaction(self) -> None:
626 | async with self._write_lock:
627 | if not self._in_transaction:
628 | raise ValueError("No transaction in progress")
629 | # Discard the transaction state; since disk writes were deferred, the file remains unchanged.
630 | self._transaction_cache = None
631 | self._transaction_indices = None
632 | self._in_transaction = False
633 |
634 | async def commit_transaction(self) -> None:
635 | async with self._write_lock:
636 | if not self._in_transaction:
637 | raise ValueError("No transaction in progress")
638 | # Persist the transaction state to disk.
639 | await self._save_graph(cast(KnowledgeGraph, self._transaction_cache))
640 | # Update the persistent state with the transaction snapshot.
641 | self._cache = self._transaction_cache
642 | self._indices = self._transaction_indices # type: ignore
643 | self._transaction_cache = None
644 | self._transaction_indices = None
645 | self._in_transaction = False
646 | self._dirty = False
647 | self._cache_timestamp = time.monotonic()
648 |
649 | async def execute_batch(self, operations: List[BatchOperation]) -> BatchResult:
650 | if not operations:
651 | return BatchResult(
652 | success=True,
653 | operations_completed=0,
654 | failed_operations=[],
655 | )
656 |
657 |         # No outer lock here: begin_transaction, commit_transaction,
658 |         # rollback_transaction, and the per-operation methods each acquire
659 |         # self._write_lock themselves; holding it across those calls would
660 |         # deadlock, since asyncio.Lock is not reentrant.
661 |         try:
662 |             # Start a transaction so that no disk writes occur until commit.
663 |             await self.begin_transaction()
664 | 
665 |             completed = 0
666 |             failed_ops: List[Tuple[BatchOperation, str]] = []
667 | 
668 |             # Execute each operation.
669 |             for operation in operations:
670 |                 try:
671 |                     if (
672 |                         operation.operation_type
673 |                         == BatchOperationType.CREATE_ENTITIES
674 |                     ):
675 |                         await self.create_entities(operation.data["entities"])
676 |                     elif (
677 |                         operation.operation_type
678 |                         == BatchOperationType.DELETE_ENTITIES
679 |                     ):
680 |                         await self.delete_entities(operation.data["entity_names"])
681 |                     elif (
682 |                         operation.operation_type
683 |                         == BatchOperationType.CREATE_RELATIONS
684 |                     ):
685 |                         await self.create_relations(operation.data["relations"])
686 |                     elif (
687 |                         operation.operation_type
688 |                         == BatchOperationType.DELETE_RELATIONS
689 |                     ):
690 |                         await self.delete_relations(
691 |                             operation.data["from_"], operation.data["to"]
692 |                         )
693 |                     elif (
694 |                         operation.operation_type
695 |                         == BatchOperationType.ADD_OBSERVATIONS
696 |                     ):
697 |                         await self.add_batch_observations(
698 |                             operation.data["observations_map"]
699 |                         )
700 |                     else:
701 |                         raise ValueError(
702 |                             f"Unknown operation type: {operation.operation_type}"
703 |                         )
704 |                     completed += 1
705 |                 except Exception as e:
706 |                     failed_ops.append((operation, str(e)))
707 |                     if not operation.data.get("allow_partial", False):
708 |                         # On failure, roll back and return.
709 |                         await self.rollback_transaction()
710 |                         return BatchResult(
711 |                             success=False,
712 |                             operations_completed=completed,
713 |                             failed_operations=failed_ops,
714 |                             error_message=f"Operation failed: {str(e)}",
715 |                         )
716 | 
717 |             # Commit the transaction (persisting all changes) or report partial success.
718 |             await self.commit_transaction()
719 |             if failed_ops:
720 |                 return BatchResult(
721 |                     success=True,
722 |                     operations_completed=completed,
723 |                     failed_operations=failed_ops,
724 |                     error_message="Some operations failed",
725 |                 )
726 |             else:
727 |                 return BatchResult(
728 |                     success=True,
729 |                     operations_completed=completed,
730 |                     failed_operations=[],
731 |                 )
732 | 
733 |         except Exception as e:
734 |             if self._in_transaction:
735 |                 await self.rollback_transaction()
736 |             return BatchResult(
737 |                 success=False,
738 |                 operations_completed=0,
739 |                 failed_operations=[],
740 |                 error_message=f"Batch execution failed: {str(e)}",
741 |             )
739 |
```
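For reference, the on-disk format that `_save_graph` and `_load_graph_from_file` agree on is one JSON object per line with a `"type"` discriminator; note the relation source key is `"from"` on disk but `from_` in the in-memory model (since `from` is a Python keyword). A minimal, synchronous sketch of that round-trip, using plain dicts in place of the `Entity`/`Relation` models:

```python
import json
import os
import tempfile

# One JSON object per line; "type" discriminates entities from relations.
# On disk the relation source key is "from", while the in-memory model
# uses "from_". The sample records here are illustrative only.
records = [
    {"type": "entity", "name": "alice", "entityType": "person",
     "observations": ["likes graphs"]},
    {"type": "relation", "from": "alice", "to": "bob",
     "relationType": "knows"},
]

def dump_jsonl(records, path):
    # Mirrors _save_graph: newline-joined JSON lines with a trailing newline.
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(json.dumps(r) for r in records) + "\n")

def load_jsonl(path):
    # Mirrors _load_graph_from_file: skip blank lines, dispatch on "type".
    entities, relations = [], []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            item = json.loads(line)
            (entities if item["type"] == "entity" else relations).append(item)
    return entities, relations

path = os.path.join(tempfile.mkdtemp(), "memory.jsonl")
dump_jsonl(records, path)
entities, relations = load_jsonl(path)
```

The real backend additionally writes to a `.tmp` sibling and uses `Path.replace` so readers never see a half-written file.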
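The transaction methods above implement a deferred-write pattern: mutations go to a snapshot while a transaction is open, and the persistent state is only swapped in on commit. A stripped-down, synchronous sketch of the same pattern over a plain dict (the `Store` class here is hypothetical, not the backend's API):

```python
class Store:
    """Deferred-write transactions over a dict, echoing begin/commit/rollback."""

    def __init__(self):
        self._state = {}   # "persistent" state (stands in for disk + cache)
        self._txn = None   # snapshot while a transaction is open

    def _current(self):
        # Like _get_current_state: route writes to the snapshot if one exists.
        return self._txn if self._txn is not None else self._state

    def begin(self):
        if self._txn is not None:
            raise ValueError("Transaction already in progress")
        self._txn = dict(self._state)  # shallow snapshot of the container

    def put(self, key, value):
        self._current()[key] = value

    def rollback(self):
        if self._txn is None:
            raise ValueError("No transaction in progress")
        self._txn = None  # discard the snapshot; _state was never touched

    def commit(self):
        if self._txn is None:
            raise ValueError("No transaction in progress")
        self._state = self._txn  # the snapshot becomes the persistent state
        self._txn = None

store = Store()
store.put("a", 1)
store.begin()
store.put("b", 2)
store.rollback()   # "b" is discarded
store.begin()
store.put("c", 3)
store.commit()     # "c" survives
```

Because values written mid-transaction never reach `_state` until `commit`, a failed batch can simply drop the snapshot, which is exactly why `execute_batch` can report a clean failure without repairing anything on disk.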