This is page 6 of 9. Use http://codebase.md/getzep/graphiti?lines=false&page={x} to view the full context.
# Directory Structure
```
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── ISSUE_TEMPLATE
│ │ └── bug_report.md
│ ├── pull_request_template.md
│ ├── secret_scanning.yml
│ └── workflows
│ ├── ai-moderator.yml
│ ├── cla.yml
│ ├── claude-code-review-manual.yml
│ ├── claude-code-review.yml
│ ├── claude.yml
│ ├── codeql.yml
│ ├── daily_issue_maintenance.yml
│ ├── issue-triage.yml
│ ├── lint.yml
│ ├── release-graphiti-core.yml
│ ├── release-mcp-server.yml
│ ├── release-server-container.yml
│ ├── typecheck.yml
│ └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│ ├── azure-openai
│ │ ├── .env.example
│ │ ├── azure_openai_neo4j.py
│ │ └── README.md
│ ├── data
│ │ └── manybirds_products.json
│ ├── ecommerce
│ │ ├── runner.ipynb
│ │ └── runner.py
│ ├── langgraph-agent
│ │ ├── agent.ipynb
│ │ └── tinybirds-jess.png
│ ├── opentelemetry
│ │ ├── .env.example
│ │ ├── otel_stdout_example.py
│ │ ├── pyproject.toml
│ │ ├── README.md
│ │ └── uv.lock
│ ├── podcast
│ │ ├── podcast_runner.py
│ │ ├── podcast_transcript.txt
│ │ └── transcript_parser.py
│ ├── quickstart
│ │ ├── quickstart_falkordb.py
│ │ ├── quickstart_neo4j.py
│ │ ├── quickstart_neptune.py
│ │ ├── README.md
│ │ └── requirements.txt
│ └── wizard_of_oz
│ ├── parser.py
│ ├── runner.py
│ └── woo.txt
├── graphiti_core
│ ├── __init__.py
│ ├── cross_encoder
│ │ ├── __init__.py
│ │ ├── bge_reranker_client.py
│ │ ├── client.py
│ │ ├── gemini_reranker_client.py
│ │ └── openai_reranker_client.py
│ ├── decorators.py
│ ├── driver
│ │ ├── __init__.py
│ │ ├── driver.py
│ │ ├── falkordb_driver.py
│ │ ├── graph_operations
│ │ │ └── graph_operations.py
│ │ ├── kuzu_driver.py
│ │ ├── neo4j_driver.py
│ │ ├── neptune_driver.py
│ │ └── search_interface
│ │ └── search_interface.py
│ ├── edges.py
│ ├── embedder
│ │ ├── __init__.py
│ │ ├── azure_openai.py
│ │ ├── client.py
│ │ ├── gemini.py
│ │ ├── openai.py
│ │ └── voyage.py
│ ├── errors.py
│ ├── graph_queries.py
│ ├── graphiti_types.py
│ ├── graphiti.py
│ ├── helpers.py
│ ├── llm_client
│ │ ├── __init__.py
│ │ ├── anthropic_client.py
│ │ ├── azure_openai_client.py
│ │ ├── client.py
│ │ ├── config.py
│ │ ├── errors.py
│ │ ├── gemini_client.py
│ │ ├── groq_client.py
│ │ ├── openai_base_client.py
│ │ ├── openai_client.py
│ │ ├── openai_generic_client.py
│ │ └── utils.py
│ ├── migrations
│ │ └── __init__.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── edges
│ │ │ ├── __init__.py
│ │ │ └── edge_db_queries.py
│ │ └── nodes
│ │ ├── __init__.py
│ │ └── node_db_queries.py
│ ├── nodes.py
│ ├── prompts
│ │ ├── __init__.py
│ │ ├── dedupe_edges.py
│ │ ├── dedupe_nodes.py
│ │ ├── eval.py
│ │ ├── extract_edge_dates.py
│ │ ├── extract_edges.py
│ │ ├── extract_nodes.py
│ │ ├── invalidate_edges.py
│ │ ├── lib.py
│ │ ├── models.py
│ │ ├── prompt_helpers.py
│ │ ├── snippets.py
│ │ └── summarize_nodes.py
│ ├── py.typed
│ ├── search
│ │ ├── __init__.py
│ │ ├── search_config_recipes.py
│ │ ├── search_config.py
│ │ ├── search_filters.py
│ │ ├── search_helpers.py
│ │ ├── search_utils.py
│ │ └── search.py
│ ├── telemetry
│ │ ├── __init__.py
│ │ └── telemetry.py
│ ├── tracer.py
│ └── utils
│ ├── __init__.py
│ ├── bulk_utils.py
│ ├── datetime_utils.py
│ ├── maintenance
│ │ ├── __init__.py
│ │ ├── community_operations.py
│ │ ├── dedup_helpers.py
│ │ ├── edge_operations.py
│ │ ├── graph_data_operations.py
│ │ ├── node_operations.py
│ │ └── temporal_operations.py
│ ├── ontology_utils
│ │ └── entity_types_utils.py
│ └── text_utils.py
├── images
│ ├── arxiv-screenshot.png
│ ├── graphiti-graph-intro.gif
│ ├── graphiti-intro-slides-stock-2.gif
│ └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│ ├── .env.example
│ ├── .python-version
│ ├── config
│ │ ├── config-docker-falkordb-combined.yaml
│ │ ├── config-docker-falkordb.yaml
│ │ ├── config-docker-neo4j.yaml
│ │ ├── config.yaml
│ │ └── mcp_config_stdio_example.json
│ ├── docker
│ │ ├── build-standalone.sh
│ │ ├── build-with-version.sh
│ │ ├── docker-compose-falkordb.yml
│ │ ├── docker-compose-neo4j.yml
│ │ ├── docker-compose.yml
│ │ ├── Dockerfile
│ │ ├── Dockerfile.standalone
│ │ ├── github-actions-example.yml
│ │ ├── README-falkordb-combined.md
│ │ └── README.md
│ ├── docs
│ │ └── cursor_rules.md
│ ├── main.py
│ ├── pyproject.toml
│ ├── pytest.ini
│ ├── README.md
│ ├── src
│ │ ├── __init__.py
│ │ ├── config
│ │ │ ├── __init__.py
│ │ │ └── schema.py
│ │ ├── graphiti_mcp_server.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ ├── entity_types.py
│ │ │ └── response_types.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── factories.py
│ │ │ └── queue_service.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── formatting.py
│ │ └── utils.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── pytest.ini
│ │ ├── README.md
│ │ ├── run_tests.py
│ │ ├── test_async_operations.py
│ │ ├── test_comprehensive_integration.py
│ │ ├── test_configuration.py
│ │ ├── test_falkordb_integration.py
│ │ ├── test_fixtures.py
│ │ ├── test_http_integration.py
│ │ ├── test_integration.py
│ │ ├── test_mcp_integration.py
│ │ ├── test_mcp_transports.py
│ │ ├── test_stdio_simple.py
│ │ └── test_stress_load.py
│ └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│ ├── .env.example
│ ├── graph_service
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ ├── main.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ └── zep_graphiti.py
│ ├── Makefile
│ ├── pyproject.toml
│ ├── README.md
│ └── uv.lock
├── signatures
│ └── version1
│ └── cla.json
├── tests
│ ├── cross_encoder
│ │ ├── test_bge_reranker_client_int.py
│ │ └── test_gemini_reranker_client.py
│ ├── driver
│ │ ├── __init__.py
│ │ └── test_falkordb_driver.py
│ ├── embedder
│ │ ├── embedder_fixtures.py
│ │ ├── test_gemini.py
│ │ ├── test_openai.py
│ │ └── test_voyage.py
│ ├── evals
│ │ ├── data
│ │ │ └── longmemeval_data
│ │ │ ├── longmemeval_oracle.json
│ │ │ └── README.md
│ │ ├── eval_cli.py
│ │ ├── eval_e2e_graph_building.py
│ │ ├── pytest.ini
│ │ └── utils.py
│ ├── helpers_test.py
│ ├── llm_client
│ │ ├── test_anthropic_client_int.py
│ │ ├── test_anthropic_client.py
│ │ ├── test_azure_openai_client.py
│ │ ├── test_client.py
│ │ ├── test_errors.py
│ │ └── test_gemini_client.py
│ ├── test_edge_int.py
│ ├── test_entity_exclusion_int.py
│ ├── test_graphiti_int.py
│ ├── test_graphiti_mock.py
│ ├── test_node_int.py
│ ├── test_text_utils.py
│ └── utils
│ ├── maintenance
│ │ ├── test_bulk_utils.py
│ │ ├── test_edge_operations.py
│ │ ├── test_node_operations.py
│ │ └── test_temporal_operations_int.py
│ └── search
│ └── search_utils_test.py
├── uv.lock
└── Zep-CLA.md
```
# Files
--------------------------------------------------------------------------------
/mcp_server/tests/test_stress_load.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Stress and load testing for Graphiti MCP Server.
Tests system behavior under high load, resource constraints, and edge conditions.
"""
import asyncio
import gc
import random
import time
from dataclasses import dataclass
import psutil
import pytest
from test_fixtures import TestDataGenerator, graphiti_test_client
@dataclass
class LoadTestConfig:
"""Configuration for load testing scenarios."""
num_clients: int = 10
operations_per_client: int = 100
ramp_up_time: float = 5.0 # seconds
test_duration: float = 60.0 # seconds
target_throughput: float | None = None # ops/sec
think_time: float = 0.1 # seconds between ops
@dataclass
class LoadTestResult:
    """Aggregated outcome of a single load test run.

    Times are in seconds; throughput is in operations per second.
    """

    # Operation counts.
    total_operations: int
    successful_operations: int
    failed_operations: int
    # Span covered by the recorded operations, and ops/sec over that span.
    duration: float
    throughput: float
    # Latency statistics over every recorded operation.
    average_latency: float
    p50_latency: float
    p95_latency: float
    p99_latency: float
    max_latency: float
    # Failure tallies keyed by error kind.
    errors: dict[str, int]
    # Process resource snapshot (e.g. CPU %, memory in MB, thread count).
    resource_usage: dict[str, float]
class LoadTester:
    """Drive simulated clients against a tool session and aggregate their metrics.

    Each operation issued by :meth:`run_client_workload` is recorded in
    ``metrics`` as a ``(start, duration, success)`` sample. Failures are also
    tallied by kind in ``errors``. :meth:`calculate_results` folds the samples
    into a :class:`LoadTestResult`.
    """

    def __init__(self, config: LoadTestConfig):
        self.config = config
        # One sample per operation: (wall-clock start, duration in seconds, success).
        self.metrics: list[tuple[float, float, bool]] = []
        # Failure counts keyed by 'timeout' or by the raised exception's class name.
        self.errors: dict[str, int] = {}
        # Wall-clock (time.time()) start of the whole run, set by the caller. When
        # set, clients stop issuing operations once config.test_duration elapses.
        self.start_time: float | None = None
        # psutil's cpu_percent() reports usage since the previous call on the *same*
        # Process object, and its first call always returns 0.0. Keep one handle and
        # prime it now so calculate_results() reports a meaningful figure.
        self._process = psutil.Process()
        self._process.cpu_percent(interval=None)

    def _record(self, wall_start: float, timer_start: float, success: bool) -> None:
        """Append one metrics sample; the duration is taken from the monotonic clock."""
        self.metrics.append((wall_start, time.perf_counter() - timer_start, success))

    def _count_error(self, error_type: str) -> None:
        """Increment the failure tally for ``error_type``."""
        self.errors[error_type] = self.errors.get(error_type, 0) + 1

    async def run_client_workload(self, client_id: int, session, group_id: str) -> dict[str, int]:
        """Run the workload for one simulated client.

        Args:
            client_id: Index of the simulated client. It staggers this client's
                start across ``config.ramp_up_time`` and is embedded in episode names.
            session: Object with an awaitable ``call_tool(name, args)``
                (presumably an MCP client session; confirm against callers).
            group_id: Group every operation targets.

        Returns:
            ``{'success': n, 'failure': m}`` counts for this client.
        """
        stats = {'success': 0, 'failure': 0}
        data_gen = TestDataGenerator()

        # Ramp-up: spread client start times linearly over the ramp-up window.
        ramp_delay = (client_id / self.config.num_clients) * self.config.ramp_up_time
        await asyncio.sleep(ramp_delay)

        for op_num in range(self.config.operations_per_client):
            wall_start = time.time()
            # Time the operation on a monotonic clock; wall-clock time can jump.
            timer_start = time.perf_counter()
            try:
                operation = random.choice(
                    [
                        'add_memory',
                        'search_memory_nodes',
                        'get_episodes',
                    ]
                )

                if operation == 'add_memory':
                    args = {
                        'name': f'Load Test {client_id}-{op_num}',
                        'episode_body': data_gen.generate_technical_document(),
                        'source': 'text',
                        'source_description': 'load test',
                        'group_id': group_id,
                    }
                elif operation == 'search_memory_nodes':
                    args = {
                        'query': random.choice(['performance', 'architecture', 'test', 'data']),
                        'group_id': group_id,
                        'limit': 10,
                    }
                else:  # get_episodes
                    args = {
                        'group_id': group_id,
                        'last_n': 10,
                    }

                await asyncio.wait_for(session.call_tool(operation, args), timeout=30.0)
            except asyncio.TimeoutError:
                self._record(wall_start, timer_start, False)
                self._count_error('timeout')
                stats['failure'] += 1
            except Exception as e:
                self._record(wall_start, timer_start, False)
                self._count_error(type(e).__name__)
                stats['failure'] += 1
            else:
                self._record(wall_start, timer_start, True)
                stats['success'] += 1

            # Think time between operations.
            await asyncio.sleep(self.config.think_time)

            # Stop once the configured test duration has been exceeded.
            if self.start_time and (time.time() - self.start_time) > self.config.test_duration:
                break

        return stats

    @staticmethod
    def _percentile(data: list[float], p: float) -> float:
        """Return the nearest-rank ``p``-th percentile of already-sorted ``data``.

        The nearest-rank value is the smallest sample such that at least ``p``
        percent of the samples are <= it. Returns 0.0 for empty input.
        """
        import math

        if not data:
            return 0.0
        rank = math.ceil(len(data) * p / 100)
        return data[min(max(rank - 1, 0), len(data) - 1)]

    def calculate_results(self) -> LoadTestResult:
        """Summarize the recorded samples into a :class:`LoadTestResult`.

        ``duration`` spans from the earliest operation start to the latest
        operation end. An all-zero result is returned when nothing was recorded.
        """
        if not self.metrics:
            return LoadTestResult(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, {}, {})

        total = len(self.metrics)
        successful = sum(1 for _, _, ok in self.metrics if ok)
        latencies = sorted(latency for _, latency, _ in self.metrics)
        first_start = min(start for start, _, _ in self.metrics)
        last_end = max(start + latency for start, latency, _ in self.metrics)
        duration = last_end - first_start

        # Snapshot of this (test-runner) process, using the handle primed in __init__.
        resource_usage = {
            'cpu_percent': self._process.cpu_percent(interval=None),
            'memory_mb': self._process.memory_info().rss / 1024 / 1024,
            'num_threads': self._process.num_threads(),
        }

        return LoadTestResult(
            total_operations=total,
            successful_operations=successful,
            failed_operations=total - successful,
            duration=duration,
            throughput=total / duration if duration > 0 else 0,
            average_latency=sum(latencies) / total,
            p50_latency=self._percentile(latencies, 50),
            p95_latency=self._percentile(latencies, 95),
            p99_latency=self._percentile(latencies, 99),
            max_latency=latencies[-1],
            errors=self.errors,
            resource_usage=resource_usage,
        )
class TestLoadScenarios:
    """Various load testing scenarios."""

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_sustained_load(self):
        """Test system under sustained moderate load.

        Five staggered clients each issue up to 20 randomly chosen tool calls.
        Successes must outnumber failures, and average/P95 latency must stay
        under the budgets asserted below.
        """
        config = LoadTestConfig(
            num_clients=5,
            operations_per_client=20,
            ramp_up_time=2.0,
            test_duration=30.0,
            think_time=0.5,
        )

        async with graphiti_test_client() as (session, group_id):
            tester = LoadTester(config)
            tester.start_time = time.time()

            # Run all simulated clients concurrently.
            client_tasks = [
                tester.run_client_workload(client_id, session, group_id)
                for client_id in range(config.num_clients)
            ]
            await asyncio.gather(*client_tasks)

            results = tester.calculate_results()

            assert results.successful_operations > results.failed_operations
            assert results.average_latency < 5.0, (
                f'Average latency too high: {results.average_latency:.2f}s'
            )
            assert results.p95_latency < 10.0, f'P95 latency too high: {results.p95_latency:.2f}s'

            print('\nSustained Load Test Results:')
            print(f' Total operations: {results.total_operations}')
            print(
                f' Success rate: {results.successful_operations / results.total_operations * 100:.1f}%'
            )
            print(f' Throughput: {results.throughput:.2f} ops/s')
            print(f' Avg latency: {results.average_latency:.2f}s')
            print(f' P95 latency: {results.p95_latency:.2f}s')

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_spike_load(self):
        """Test system response to sudden load spikes.

        A short, paced "normal" phase is followed by 50 concurrent add_memory
        calls. At least 80% of the spike must succeed.
        """
        async with graphiti_test_client() as (session, group_id):
            # Normal load phase: requests issued one at a time with a pause between
            # them. Each call is awaited before sleeping. call_tool() returns an
            # awaitable that (as a coroutine) does not run until awaited, so
            # collecting the calls for a later gather() would fire them all at once
            # and the pacing sleeps would not space them out.
            for i in range(3):
                await session.call_tool(
                    'add_memory',
                    {
                        'name': f'Normal Load {i}',
                        'episode_body': 'Normal operation',
                        'source': 'text',
                        'source_description': 'normal',
                        'group_id': group_id,
                    },
                )
                await asyncio.sleep(0.5)

            # Spike phase: a sudden burst of concurrent requests.
            spike_start = time.time()
            spike_tasks = [
                session.call_tool(
                    'add_memory',
                    {
                        'name': f'Spike Load {i}',
                        'episode_body': TestDataGenerator.generate_technical_document(),
                        'source': 'text',
                        'source_description': 'spike',
                        'group_id': group_id,
                    },
                )
                for i in range(50)
            ]

            spike_results = await asyncio.gather(*spike_tasks, return_exceptions=True)
            spike_duration = time.time() - spike_start

            spike_failures = sum(1 for r in spike_results if isinstance(r, Exception))
            spike_success_rate = (len(spike_results) - spike_failures) / len(spike_results)

            print('\nSpike Load Test Results:')
            print(f' Spike size: {len(spike_tasks)} operations')
            print(f' Duration: {spike_duration:.2f}s')
            print(f' Success rate: {spike_success_rate * 100:.1f}%')
            print(f' Throughput: {len(spike_tasks) / spike_duration:.2f} ops/s')

            # System should handle at least 80% of the spike.
            assert spike_success_rate > 0.8, f'Too many failures during spike: {spike_failures}'

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_memory_leak_detection(self):
        """Test for memory leaks during extended operation.

        Runs 10 batches of 10 concurrent add_memory calls and reports RSS growth
        of the current process. This is a soft check: it prints a warning above
        100 MB of growth but never fails.
        """
        async with graphiti_test_client() as (session, group_id):
            # NOTE(review): psutil.Process() with no pid is the *test* process. If
            # the server runs out-of-process (depends on graphiti_test_client;
            # confirm), its memory is not measured here.
            process = psutil.Process()
            gc.collect()  # Force garbage collection
            initial_memory = process.memory_info().rss / 1024 / 1024  # MB

            for batch in range(10):
                batch_tasks = [
                    session.call_tool(
                        'add_memory',
                        {
                            'name': f'Memory Test {batch}-{i}',
                            'episode_body': TestDataGenerator.generate_technical_document(),
                            'source': 'text',
                            'source_description': 'memory test',
                            'group_id': group_id,
                        },
                    )
                    for i in range(10)
                ]
                await asyncio.gather(*batch_tasks)

                # Force garbage collection between batches.
                gc.collect()
                await asyncio.sleep(1)

            gc.collect()
            final_memory = process.memory_info().rss / 1024 / 1024  # MB
            memory_growth = final_memory - initial_memory

            print('\nMemory Leak Test:')
            print(f' Initial memory: {initial_memory:.1f} MB')
            print(f' Final memory: {final_memory:.1f} MB')
            print(f' Growth: {memory_growth:.1f} MB')

            # Soft threshold: the right limit depends on the system.
            if memory_growth > 100:  # More than 100MB growth
                print(f' ⚠️ Potential memory leak detected: {memory_growth:.1f} MB growth')

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_connection_pool_exhaustion(self):
        """Test behavior when connection pools are exhausted.

        Fires 100 concurrent searches and reports how many failures mention a
        connection. This is report-only: neither errors nor a timeout fail the test.
        """
        async with graphiti_test_client() as (session, group_id):
            # Many more concurrent requests than a typical pool size.
            long_tasks = [
                session.call_tool(
                    'search_memory_nodes',
                    {
                        'query': f'complex query {i} '
                        + ' '.join([TestDataGenerator.fake.word() for _ in range(10)]),
                        'group_id': group_id,
                        'limit': 100,
                    },
                )
                for i in range(100)
            ]

            try:
                results = await asyncio.wait_for(
                    asyncio.gather(*long_tasks, return_exceptions=True), timeout=60.0
                )

                connection_errors = sum(
                    1
                    for r in results
                    if isinstance(r, Exception) and 'connection' in str(r).lower()
                )

                print('\nConnection Pool Test:')
                print(f' Total requests: {len(long_tasks)}')
                print(f' Connection errors: {connection_errors}')

            except asyncio.TimeoutError:
                # Deliberately reported rather than failed (best-effort diagnostic).
                print(' Test timed out - possible deadlock or exhaustion')

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_gradual_degradation(self):
        """Test system degradation under increasing load.

        Each level issues that many concurrent add_memory calls. Every level
        must keep its success rate above 50%.
        """
        async with graphiti_test_client() as (session, group_id):
            load_levels = [5, 10, 20, 40, 80]  # Increasing concurrent operations
            results_by_level = {}

            for level in load_levels:
                level_start = time.time()
                tasks = [
                    session.call_tool(
                        'add_memory',
                        {
                            'name': f'Load Level {level} Op {i}',
                            'episode_body': f'Testing at load level {level}',
                            'source': 'text',
                            'source_description': 'degradation test',
                            'group_id': group_id,
                        },
                    )
                    for i in range(level)
                ]

                level_results = await asyncio.gather(*tasks, return_exceptions=True)
                level_duration = time.time() - level_start

                failures = sum(1 for r in level_results if isinstance(r, Exception))
                success_rate = (level - failures) / level * 100
                throughput = level / level_duration

                results_by_level[level] = {
                    'success_rate': success_rate,
                    'throughput': throughput,
                    'duration': level_duration,
                }

                print(f'\nLoad Level {level}:')
                print(f' Success rate: {success_rate:.1f}%')
                print(f' Throughput: {throughput:.2f} ops/s')
                print(f' Duration: {level_duration:.2f}s')

                # Brief pause between levels.
                await asyncio.sleep(2)

            # Graceful degradation: success rate must not drop below 50% at any level.
            for level, metrics in results_by_level.items():
                assert metrics['success_rate'] > 50, f'Poor performance at load level {level}'
class TestResourceLimits:
    """Test behavior at resource limits."""

    @staticmethod
    def _is_rate_limit_error(result: object) -> bool:
        """Return True if ``result`` is an exception that looks like a rate-limit rejection.

        Matches explicit rate-limit wording, "too many requests", or HTTP 429 in
        the exception's class name or message. A bare 'rate' substring is not
        enough, because it also matches words like 'generate' and 'separate'.
        """
        if not isinstance(result, Exception):
            return False
        text = f'{type(result).__name__} {result}'.lower()
        markers = ('rate limit', 'ratelimit', 'rate_limit', 'too many requests', '429')
        return any(marker in text for marker in markers)

    @pytest.mark.asyncio
    async def test_large_payload_handling(self):
        """Test handling of very large payloads.

        Sends add_memory episodes of 1KB to 1MB and reports the outcome and
        elapsed time for each. This is report-only: no assertion is made.
        """
        async with graphiti_test_client() as (session, group_id):
            payload_sizes = [
                (1_000, '1KB'),
                (10_000, '10KB'),
                (100_000, '100KB'),
                (1_000_000, '1MB'),
            ]

            for size, label in payload_sizes:
                content = 'x' * size

                start_time = time.time()
                try:
                    await asyncio.wait_for(
                        session.call_tool(
                            'add_memory',
                            {
                                'name': f'Large Payload {label}',
                                'episode_body': content,
                                'source': 'text',
                                'source_description': 'payload test',
                                'group_id': group_id,
                            },
                        ),
                        timeout=30.0,
                    )
                    duration = time.time() - start_time
                    status = '✅ Success'

                except asyncio.TimeoutError:
                    duration = 30.0
                    status = '⏱️ Timeout'

                except Exception as e:
                    duration = time.time() - start_time
                    status = f'❌ Error: {type(e).__name__}'

                print(f'Payload {label}: {status} ({duration:.2f}s)')

    @pytest.mark.asyncio
    async def test_rate_limit_handling(self):
        """Test handling of rate limits.

        Fires 100 add_memory calls concurrently and reports how many failed, and
        how many of those failures look like rate limiting. This is report-only.
        """
        async with graphiti_test_client() as (session, group_id):
            # Rapid-fire requests to try to trigger rate limits.
            rapid_tasks = [
                session.call_tool(
                    'add_memory',
                    {
                        'name': f'Rate Limit Test {i}',
                        'episode_body': f'Testing rate limit {i}',
                        'source': 'text',
                        'source_description': 'rate test',
                        'group_id': group_id,
                    },
                )
                for i in range(100)
            ]

            results = await asyncio.gather(*rapid_tasks, return_exceptions=True)

            # Any exception is a failure; rate-limit errors are a subset of those.
            failures = sum(1 for r in results if isinstance(r, Exception))
            rate_limit_errors = sum(1 for r in results if self._is_rate_limit_error(r))

            print('\nRate Limit Test:')
            print(f' Total requests: {len(rapid_tasks)}')
            print(f' Rate limit errors: {rate_limit_errors}')
            print(f' Other errors: {failures - rate_limit_errors}')
            print(
                f' Success rate: {(len(rapid_tasks) - failures) / len(rapid_tasks) * 100:.1f}%'
            )
def generate_load_test_report(results: list[LoadTestResult]) -> str:
    """Render a plain-text summary of one or more load test runs.

    Args:
        results: Runs to report, numbered from 1 in the given order.

    Returns:
        The report as one newline-joined string. A run with zero operations
        (e.g. the all-zero result produced when no metrics were recorded)
        reports a 0.0% success rate instead of raising ZeroDivisionError.
    """
    separator = '=' * 60
    lines = ['\n' + separator, 'LOAD TEST REPORT', separator]

    for i, result in enumerate(results):
        # Guard the division: an empty run has total_operations == 0.
        success_rate = (
            result.successful_operations / result.total_operations * 100
            if result.total_operations
            else 0.0
        )
        lines.append(f'\nTest Run {i + 1}:')
        lines.append(f' Total Operations: {result.total_operations}')
        lines.append(f' Success Rate: {success_rate:.1f}%')
        lines.append(f' Throughput: {result.throughput:.2f} ops/s')
        lines.append(
            f' Latency (avg/p50/p95/p99/max): {result.average_latency:.2f}/{result.p50_latency:.2f}/{result.p95_latency:.2f}/{result.p99_latency:.2f}/{result.max_latency:.2f}s'
        )

        if result.errors:
            lines.append(' Errors:')
            lines.extend(f' {error_type}: {count}' for error_type, count in result.errors.items())

        lines.append(' Resource Usage:')
        lines.extend(f' {metric}: {value:.2f}' for metric, value in result.resource_usage.items())

    lines.append(separator)
    return '\n'.join(lines)
if __name__ == '__main__':
    # Propagate pytest's exit status so shells/CI see failures. Discarding the
    # return value of pytest.main() made the script always exit with status 0.
    raise SystemExit(pytest.main([__file__, '-v', '--asyncio-mode=auto', '-m', 'slow']))
```
--------------------------------------------------------------------------------
/graphiti_core/utils/maintenance/node_operations.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
from collections.abc import Awaitable, Callable
from time import time
from typing import Any
from pydantic import BaseModel
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.llm_client.config import ModelSize
from graphiti_core.nodes import (
EntityNode,
EpisodeType,
EpisodicNode,
create_entity_node_embeddings,
)
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
from graphiti_core.prompts.extract_nodes import (
EntitySummary,
ExtractedEntities,
ExtractedEntity,
MissedEntities,
)
from graphiti_core.search.search import search
from graphiti_core.search.search_config import SearchResults
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.dedup_helpers import (
DedupCandidateIndexes,
DedupResolutionState,
_build_candidate_indexes,
_resolve_with_similarity,
)
from graphiti_core.utils.maintenance.edge_operations import (
filter_existing_duplicate_of_edges,
)
from graphiti_core.utils.text_utils import MAX_SUMMARY_CHARS, truncate_at_sentence
# Module-level logger for the node maintenance operations below.
logger = logging.getLogger(__name__)
# Async predicate over an EntityNode. It is the type of the optional
# ``should_summarize_node`` argument of extract_attributes_from_nodes.
NodeSummaryFilter = Callable[[EntityNode], Awaitable[bool]]
async def extract_nodes_reflexion(
    llm_client: LLMClient,
    episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
    node_names: list[str],
    group_id: str | None = None,
) -> list[str]:
    """Ask the LLM which entities an extraction pass overlooked.

    Args:
        llm_client: Client used to run the ``extract_nodes.reflexion`` prompt.
        episode: Episode the entities were extracted from.
        previous_episodes: Earlier episodes, passed along as prompt context.
        node_names: Names already extracted from ``episode``.
        group_id: Forwarded unchanged to ``generate_response``.

    Returns:
        The ``missed_entities`` entry of the LLM response, or ``[]`` if absent.
    """
    reflexion_prompt = prompt_library.extract_nodes.reflexion(
        {
            'episode_content': episode.content,
            'previous_episodes': [prior.content for prior in previous_episodes],
            'extracted_entities': node_names,
        }
    )
    response = await llm_client.generate_response(
        reflexion_prompt,
        MissedEntities,
        group_id=group_id,
        prompt_name='extract_nodes.reflexion',
    )
    return response.get('missed_entities', [])
async def extract_nodes(
    clients: GraphitiClients,
    episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
    entity_types: dict[str, type[BaseModel]] | None = None,
    excluded_entity_types: list[str] | None = None,
) -> list[EntityNode]:
    """Extract entity nodes from an episode with the LLM, refined by reflexion.

    The extraction prompt is picked by ``episode.source`` (message, text or json).
    After a pass, while the iteration budget allows, ``extract_nodes_reflexion``
    asks which entities were missed. Their names are fed into the next pass via
    the ``custom_prompt`` context field. The loop ends when reflexion reports
    nothing missing or the ``MAX_REFLEXION_ITERATIONS`` budget is exhausted.

    Args:
        clients: Supplies the ``llm_client`` used for every prompt.
        episode: Episode to extract from; its ``group_id`` is stamped on new nodes.
        previous_episodes: Earlier episodes supplied as prompt context.
        entity_types: Optional custom entity types (name -> model). Each model's
            docstring is sent as the type description. Ids 1..N follow the dict's
            iteration order; id 0 is the generic ``Entity`` fallback.
        excluded_entity_types: Type names whose extracted entities are dropped.

    Returns:
        New ``EntityNode`` objects (empty summaries, not persisted here), one per
        extracted entity that has a non-blank name and a non-excluded type.
    """
    start = time()
    llm_client = clients.llm_client
    llm_response = {}
    custom_prompt = ''
    entities_missed = True
    reflexion_iterations = 0

    entity_types_context = [
        {
            'entity_type_id': 0,
            'entity_type_name': 'Entity',
            'entity_type_description': 'Default entity classification. Use this entity type if the entity is not one of the other listed types.',
        }
    ]
    if entity_types is not None:
        entity_types_context += [
            {
                'entity_type_id': i + 1,
                'entity_type_name': type_name,
                'entity_type_description': type_model.__doc__,
            }
            for i, (type_name, type_model) in enumerate(entity_types.items())
        ]

    context = {
        'episode_content': episode.content,
        'episode_timestamp': episode.valid_at.isoformat(),
        'previous_episodes': [ep.content for ep in previous_episodes],
        'custom_prompt': custom_prompt,
        'entity_types': entity_types_context,
        'source_description': episode.source_description,
    }

    while entities_missed and reflexion_iterations <= MAX_REFLEXION_ITERATIONS:
        # Refresh the hint on every pass. Strings are immutable, so reassigning
        # ``custom_prompt`` below never changed the value captured in ``context``;
        # without this line the missed entities would never reach the LLM.
        context['custom_prompt'] = custom_prompt

        if episode.source == EpisodeType.message:
            llm_response = await llm_client.generate_response(
                prompt_library.extract_nodes.extract_message(context),
                response_model=ExtractedEntities,
                group_id=episode.group_id,
                prompt_name='extract_nodes.extract_message',
            )
        elif episode.source == EpisodeType.text:
            llm_response = await llm_client.generate_response(
                prompt_library.extract_nodes.extract_text(context),
                response_model=ExtractedEntities,
                group_id=episode.group_id,
                prompt_name='extract_nodes.extract_text',
            )
        elif episode.source == EpisodeType.json:
            llm_response = await llm_client.generate_response(
                prompt_library.extract_nodes.extract_json(context),
                response_model=ExtractedEntities,
                group_id=episode.group_id,
                prompt_name='extract_nodes.extract_json',
            )
        # NOTE(review): any other ``episode.source`` leaves ``llm_response`` at its
        # previous value ({} on the first pass); confirm that is intended.

        response_object = ExtractedEntities(**llm_response)
        extracted_entities: list[ExtractedEntity] = response_object.extracted_entities

        reflexion_iterations += 1
        if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
            missing_entities = await extract_nodes_reflexion(
                llm_client,
                episode,
                previous_episodes,
                [entity.name for entity in extracted_entities],
                episode.group_id,
            )

            entities_missed = len(missing_entities) != 0

            custom_prompt = 'Make sure that the following entities are extracted: ' + ''.join(
                f'\n{entity},' for entity in missing_entities
            )

    # Drop entities whose names are empty or whitespace-only.
    filtered_extracted_entities = [entity for entity in extracted_entities if entity.name.strip()]
    end = time()
    logger.debug(f'Extracted new nodes: {filtered_extracted_entities} in {(end - start) * 1000} ms')

    extracted_nodes = []
    for extracted_entity in filtered_extracted_entities:
        type_id = extracted_entity.entity_type_id
        # Ids outside the advertised range fall back to the generic type.
        if 0 <= type_id < len(entity_types_context):
            entity_type_name = entity_types_context[type_id].get('entity_type_name')
        else:
            entity_type_name = 'Entity'

        if excluded_entity_types and entity_type_name in excluded_entity_types:
            logger.debug(f'Excluding entity "{extracted_entity.name}" of type "{entity_type_name}"')
            continue

        # Deterministic label order ('Entity' first); a set literal would give a
        # hash-seed-dependent order across runs.
        type_label = str(entity_type_name)
        labels: list[str] = ['Entity'] if type_label == 'Entity' else ['Entity', type_label]

        new_node = EntityNode(
            name=extracted_entity.name,
            group_id=episode.group_id,
            labels=labels,
            summary='',
            created_at=utc_now(),
        )
        extracted_nodes.append(new_node)
        logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')

    logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
    return extracted_nodes
async def _collect_candidate_nodes(
    clients: GraphitiClients,
    extracted_nodes: list[EntityNode],
    existing_nodes_override: list[EntityNode] | None,
) -> list[EntityNode]:
    """Gather uuid-unique existing-node candidates for ``extracted_nodes``.

    One hybrid (RRF) node search is run per extracted node, using its name as
    the query and scoped to its ``group_id``. The searches run concurrently via
    ``semaphore_gather``. Any ``existing_nodes_override`` entries are appended
    after the search hits. The combined sequence is de-duplicated by ``uuid``,
    keeping the first occurrence and preserving order.
    """
    pending_searches = [
        search(
            clients=clients,
            query=node.name,
            group_ids=[node.group_id],
            search_filter=SearchFilters(),
            config=NODE_HYBRID_SEARCH_RRF,
        )
        for node in extracted_nodes
    ]
    search_results: list[SearchResults] = await semaphore_gather(*pending_searches)

    overrides = existing_nodes_override if existing_nodes_override is not None else []
    unique_by_uuid: dict[str, EntityNode] = {}
    for result in search_results:
        for candidate in result.nodes:
            unique_by_uuid.setdefault(candidate.uuid, candidate)
    for candidate in overrides:
        unique_by_uuid.setdefault(candidate.uuid, candidate)

    # dicts keep insertion order, so this is first-seen order.
    return list(unique_by_uuid.values())
async def _resolve_with_llm(
llm_client: LLMClient,
extracted_nodes: list[EntityNode],
indexes: DedupCandidateIndexes,
state: DedupResolutionState,
episode: EpisodicNode | None,
previous_episodes: list[EpisodicNode] | None,
entity_types: dict[str, type[BaseModel]] | None,
) -> None:
"""Escalate unresolved nodes to the dedupe prompt so the LLM can select or reject duplicates.
The guardrails below defensively ignore malformed or duplicate LLM responses so the
ingestion workflow remains deterministic even when the model misbehaves.
"""
if not state.unresolved_indices:
return
entity_types_dict: dict[str, type[BaseModel]] = entity_types if entity_types is not None else {}
llm_extracted_nodes = [extracted_nodes[i] for i in state.unresolved_indices]
extracted_nodes_context = [
{
'id': i,
'name': node.name,
'entity_type': node.labels,
'entity_type_description': entity_types_dict.get(
next((item for item in node.labels if item != 'Entity'), '')
).__doc__
or 'Default Entity Type',
}
for i, node in enumerate(llm_extracted_nodes)
]
sent_ids = [ctx['id'] for ctx in extracted_nodes_context]
logger.debug(
'Sending %d entities to LLM for deduplication with IDs 0-%d (actual IDs sent: %s)',
len(llm_extracted_nodes),
len(llm_extracted_nodes) - 1,
sent_ids if len(sent_ids) < 20 else f'{sent_ids[:10]}...{sent_ids[-10:]}',
)
if llm_extracted_nodes:
sample_size = min(3, len(extracted_nodes_context))
logger.debug(
'First %d entities: %s',
sample_size,
[(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[:sample_size]],
)
if len(extracted_nodes_context) > 3:
logger.debug(
'Last %d entities: %s',
sample_size,
[(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[-sample_size:]],
)
existing_nodes_context = [
{
**{
'idx': i,
'name': candidate.name,
'entity_types': candidate.labels,
},
**candidate.attributes,
}
for i, candidate in enumerate(indexes.existing_nodes)
]
context = {
'extracted_nodes': extracted_nodes_context,
'existing_nodes': existing_nodes_context,
'episode_content': episode.content if episode is not None else '',
'previous_episodes': (
[ep.content for ep in previous_episodes] if previous_episodes is not None else []
),
}
llm_response = await llm_client.generate_response(
prompt_library.dedupe_nodes.nodes(context),
response_model=NodeResolutions,
prompt_name='dedupe_nodes.nodes',
)
node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions
valid_relative_range = range(len(state.unresolved_indices))
processed_relative_ids: set[int] = set()
received_ids = {r.id for r in node_resolutions}
expected_ids = set(valid_relative_range)
missing_ids = expected_ids - received_ids
extra_ids = received_ids - expected_ids
logger.debug(
'Received %d resolutions for %d entities',
len(node_resolutions),
len(state.unresolved_indices),
)
if missing_ids:
logger.warning('LLM did not return resolutions for IDs: %s', sorted(missing_ids))
if extra_ids:
logger.warning(
'LLM returned invalid IDs outside valid range 0-%d: %s (all returned IDs: %s)',
len(state.unresolved_indices) - 1,
sorted(extra_ids),
sorted(received_ids),
)
for resolution in node_resolutions:
relative_id: int = resolution.id
duplicate_idx: int = resolution.duplicate_idx
if relative_id not in valid_relative_range:
logger.warning(
'Skipping invalid LLM dedupe id %d (valid range: 0-%d, received %d resolutions)',
relative_id,
len(state.unresolved_indices) - 1,
len(node_resolutions),
)
continue
if relative_id in processed_relative_ids:
logger.warning('Duplicate LLM dedupe id %s received; ignoring.', relative_id)
continue
processed_relative_ids.add(relative_id)
original_index = state.unresolved_indices[relative_id]
extracted_node = extracted_nodes[original_index]
resolved_node: EntityNode
if duplicate_idx == -1:
resolved_node = extracted_node
elif 0 <= duplicate_idx < len(indexes.existing_nodes):
resolved_node = indexes.existing_nodes[duplicate_idx]
else:
logger.warning(
'Invalid duplicate_idx %s for extracted node %s; treating as no duplicate.',
duplicate_idx,
extracted_node.uuid,
)
resolved_node = extracted_node
state.resolved_nodes[original_index] = resolved_node
state.uuid_map[extracted_node.uuid] = resolved_node.uuid
if resolved_node.uuid != extracted_node.uuid:
state.duplicate_pairs.append((extracted_node, resolved_node))
async def resolve_extracted_nodes(
    clients: GraphitiClients,
    extracted_nodes: list[EntityNode],
    episode: EpisodicNode | None = None,
    previous_episodes: list[EpisodicNode] | None = None,
    entity_types: dict[str, type[BaseModel]] | None = None,
    existing_nodes_override: list[EntityNode] | None = None,
) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
    """Deduplicate freshly extracted entity nodes against candidate nodes from the graph.

    Resolution runs in passes:
      1. collect candidate nodes (search results merged with ``existing_nodes_override``);
      2. resolve deterministic matches with the similarity indexes;
      3. escalate the remaining holdouts to the LLM dedupe prompt.
    Any node still unresolved afterwards maps to itself.

    Returns:
        ``(resolved_nodes, uuid_map, new_duplicates)``:
        - ``resolved_nodes``: resolved nodes in extraction order;
        - ``uuid_map``: extracted uuid -> resolved uuid;
        - ``new_duplicates``: the duplicate pairs kept by
          ``filter_existing_duplicate_of_edges``. Presumably these are the pairs
          without an existing duplicate edge; confirm against that helper.
    """
    candidates = await _collect_candidate_nodes(
        clients,
        extracted_nodes,
        existing_nodes_override,
    )
    candidate_indexes: DedupCandidateIndexes = _build_candidate_indexes(candidates)
    resolution = DedupResolutionState(
        resolved_nodes=[None] * len(extracted_nodes),
        uuid_map={},
        unresolved_indices=[],
    )

    # Cheap deterministic pass first; only the holdouts reach the LLM.
    _resolve_with_similarity(extracted_nodes, candidate_indexes, resolution)
    await _resolve_with_llm(
        clients.llm_client,
        extracted_nodes,
        candidate_indexes,
        resolution,
        episode,
        previous_episodes,
        entity_types,
    )

    # Anything neither pass resolved is kept as a brand-new node.
    for position, extracted in enumerate(extracted_nodes):
        if resolution.resolved_nodes[position] is not None:
            continue
        resolution.resolved_nodes[position] = extracted
        resolution.uuid_map[extracted.uuid] = extracted.uuid

    logger.debug(
        'Resolved nodes: %s',
        [
            (resolved.name, resolved.uuid)
            for resolved in resolution.resolved_nodes
            if resolved is not None
        ],
    )

    unlinked_duplicates: list[tuple[EntityNode, EntityNode]] = (
        await filter_existing_duplicate_of_edges(clients.driver, resolution.duplicate_pairs)
    )
    final_nodes = [resolved for resolved in resolution.resolved_nodes if resolved is not None]
    return final_nodes, resolution.uuid_map, unlinked_duplicates
async def extract_attributes_from_nodes(
    clients: GraphitiClients,
    nodes: list[EntityNode],
    episode: EpisodicNode | None = None,
    previous_episodes: list[EpisodicNode] | None = None,
    entity_types: dict[str, type[BaseModel]] | None = None,
    should_summarize_node: NodeSummaryFilter | None = None,
) -> list[EntityNode]:
    """Extract attributes and summaries for all ``nodes`` via ``semaphore_gather``, then embed them.

    A node's entity type is looked up in ``entity_types`` under its first label other than
    ``'Entity'``, or under ``''`` when it has no such label.
    """
    llm_client = clients.llm_client

    def _entity_type_for(node: EntityNode) -> type[BaseModel] | None:
        # Map the node's most specific label to its custom entity type, if any.
        if entity_types is None:
            return None
        specific_label = next((label for label in node.labels if label != 'Entity'), '')
        return entity_types.get(specific_label)

    extraction_tasks = [
        extract_attributes_from_node(
            llm_client,
            node,
            episode,
            previous_episodes,
            _entity_type_for(node),
            should_summarize_node,
        )
        for node in nodes
    ]
    updated_nodes: list[EntityNode] = await semaphore_gather(*extraction_tasks)

    await create_entity_node_embeddings(clients.embedder, updated_nodes)
    return updated_nodes
async def extract_attributes_from_node(
    llm_client: LLMClient,
    node: EntityNode,
    episode: EpisodicNode | None = None,
    previous_episodes: list[EpisodicNode] | None = None,
    entity_type: type[BaseModel] | None = None,
    should_summarize_node: NodeSummaryFilter | None = None,
) -> EntityNode:
    """Fill ``node`` with LLM-extracted attributes and, if allowed, a refreshed summary.

    The node is mutated in place and returned. Attribute extraction runs first, then the
    summary step. The extracted attributes are merged into ``node.attributes`` only after
    the summary step, so the summary prompt sees the node's pre-existing attributes.
    """
    extracted_attributes = await _extract_entity_attributes(
        llm_client, node, episode, previous_episodes, entity_type
    )
    await _extract_entity_summary(
        llm_client, node, episode, previous_episodes, should_summarize_node
    )
    node.attributes.update(extracted_attributes)
    return node
async def _extract_entity_attributes(
    llm_client: LLMClient,
    node: EntityNode,
    episode: EpisodicNode | None,
    previous_episodes: list[EpisodicNode] | None,
    entity_type: type[BaseModel] | None,
) -> dict[str, Any]:
    """Ask the LLM for the custom attributes declared by ``entity_type``.

    Returns ``{}`` without calling the LLM when there is no entity type or it declares
    no fields. The raw response is validated by constructing ``entity_type`` from it,
    which raises on mismatch. It is then returned unchanged.
    """
    if entity_type is None or not entity_type.model_fields:
        return {}

    # The node's current summary is deliberately left out of this prompt.
    prompt_context = _build_episode_context(
        node_data={
            'name': node.name,
            'entity_types': node.labels,
            'attributes': node.attributes,
        },
        episode=episode,
        previous_episodes=previous_episodes,
    )
    response = await llm_client.generate_response(
        prompt_library.extract_nodes.extract_attributes(prompt_context),
        response_model=entity_type,
        model_size=ModelSize.small,
        group_id=node.group_id,
        prompt_name='extract_nodes.extract_attributes',
    )
    entity_type(**response)  # validation only; the result object is discarded
    return response
async def _extract_entity_summary(
    llm_client: LLMClient,
    node: EntityNode,
    episode: EpisodicNode | None,
    previous_episodes: list[EpisodicNode] | None,
    should_summarize_node: NodeSummaryFilter | None,
) -> None:
    """Regenerate ``node.summary`` with the LLM unless ``should_summarize_node`` vetoes it.

    Both the existing summary sent in the prompt and the new summary are truncated at a
    sentence boundary to ``MAX_SUMMARY_CHARS``. A response without a ``summary`` key
    yields an empty summary.
    """
    if should_summarize_node is not None:
        wants_summary = await should_summarize_node(node)
        if not wants_summary:
            return

    node_payload = {
        'name': node.name,
        'summary': truncate_at_sentence(node.summary, MAX_SUMMARY_CHARS),
        'entity_types': node.labels,
        'attributes': node.attributes,
    }
    response = await llm_client.generate_response(
        prompt_library.extract_nodes.extract_summary(
            _build_episode_context(
                node_data=node_payload,
                episode=episode,
                previous_episodes=previous_episodes,
            )
        ),
        response_model=EntitySummary,
        model_size=ModelSize.small,
        group_id=node.group_id,
        prompt_name='extract_nodes.extract_summary',
    )
    new_summary = response.get('summary', '')
    node.summary = truncate_at_sentence(new_summary, MAX_SUMMARY_CHARS)
def _build_episode_context(
node_data: dict[str, Any],
episode: EpisodicNode | None,
previous_episodes: list[EpisodicNode] | None,
) -> dict[str, Any]:
return {
'node': node_data,
'episode_content': episode.content if episode is not None else '',
'previous_episodes': (
[ep.content for ep in previous_episodes] if previous_episodes is not None else []
),
}
```
--------------------------------------------------------------------------------
/graphiti_core/edges.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from time import time
from typing import Any
from uuid import uuid4
from pydantic import BaseModel, Field
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import parse_db_date
from graphiti_core.models.edges.edge_db_queries import (
COMMUNITY_EDGE_RETURN,
EPISODIC_EDGE_RETURN,
EPISODIC_EDGE_SAVE,
get_community_edge_save_query,
get_entity_edge_return_query,
get_entity_edge_save_query,
)
from graphiti_core.nodes import Node
# Module-level logger; this module uses it for debug messages about saving, deleting and
# embedding edges.
logger = logging.getLogger(__name__)
class Edge(BaseModel, ABC):
    """Abstract base for graph edges; identity (hash and equality) is the edge ``uuid``."""

    uuid: str = Field(default_factory=lambda: str(uuid4()))
    group_id: str = Field(description='partition of the graph')
    source_node_uuid: str
    target_node_uuid: str
    created_at: datetime

    @abstractmethod
    async def save(self, driver: GraphDriver): ...

    async def delete(self, driver: GraphDriver):
        """Remove this edge (matched by ``uuid``) from the graph.

        Delegates to ``driver.graph_operations_interface`` when one is configured. On Kuzu,
        RELATES_TO edges are stored as ``RelatesToNode_`` nodes, so a second query
        detach-deletes those.
        """
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.edge_delete(self, driver)

        if driver.provider == GraphProvider.KUZU:
            await driver.execute_query(
                """
                MATCH (n)-[e:MENTIONS|HAS_MEMBER {uuid: $uuid}]->(m)
                DELETE e
                """,
                uuid=self.uuid,
            )
            await driver.execute_query(
                """
                MATCH (e:RelatesToNode_ {uuid: $uuid})
                DETACH DELETE e
                """,
                uuid=self.uuid,
            )
        else:
            await driver.execute_query(
                """
                MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
                DELETE e
                """,
                uuid=self.uuid,
            )

        logger.debug(f'Deleted Edge: {self.uuid}')

    @classmethod
    async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        """Bulk variant of :meth:`delete` for every edge whose uuid is in ``uuids``."""
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.edge_delete_by_uuids(cls, driver, uuids)

        if driver.provider == GraphProvider.KUZU:
            await driver.execute_query(
                """
                MATCH (n)-[e:MENTIONS|HAS_MEMBER]->(m)
                WHERE e.uuid IN $uuids
                DELETE e
                """,
                uuids=uuids,
            )
            await driver.execute_query(
                """
                MATCH (e:RelatesToNode_)
                WHERE e.uuid IN $uuids
                DETACH DELETE e
                """,
                uuids=uuids,
            )
        else:
            await driver.execute_query(
                """
                MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
                WHERE e.uuid IN $uuids
                DELETE e
                """,
                uuids=uuids,
            )

        logger.debug(f'Deleted Edges: {uuids}')

    def __hash__(self):
        return hash(self.uuid)

    def __eq__(self, other):
        # Edges are equal when their uuids match, consistent with __hash__. The check must
        # be against Edge: testing isinstance(other, Node) meant no two edges -- not even
        # an edge and itself -- ever compared equal.
        if isinstance(other, Edge):
            return self.uuid == other.uuid
        return False

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
class EpisodicEdge(Edge):
    """MENTIONS edge from an Episodic node to an Entity node it references."""

    async def save(self, driver: GraphDriver):
        """Persist this edge with ``EPISODIC_EDGE_SAVE`` and return the driver result."""
        result = await driver.execute_query(
            EPISODIC_EDGE_SAVE,
            episode_uuid=self.source_node_uuid,
            entity_uuid=self.target_node_uuid,
            uuid=self.uuid,
            group_id=self.group_id,
            created_at=self.created_at,
        )

        logger.debug(f'Saved edge to Graph: {self.uuid}')

        return result

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        """Fetch a single MENTIONS edge by uuid.

        Raises:
            EdgeNotFoundError: if no edge has this uuid.
        """
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
            RETURN
            """
            + EPISODIC_EDGE_RETURN,
            uuid=uuid,
            routing_='r',
        )

        edges = [get_episodic_edge_from_record(record) for record in records]

        if len(edges) == 0:
            raise EdgeNotFoundError(uuid)
        return edges[0]

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        """Fetch the MENTIONS edges whose uuid is in ``uuids``.

        Returns ``[]`` for an empty ``uuids`` list without querying the database, matching
        ``EntityEdge.get_by_uuids``.

        Raises:
            EdgeNotFoundError: if ``uuids`` is non-empty but none of them match.
        """
        if not uuids:
            # Guard: an empty list would otherwise reach ``uuids[0]`` below and raise
            # IndexError instead of a meaningful result.
            return []

        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
            WHERE e.uuid IN $uuids
            RETURN
            """
            + EPISODIC_EDGE_RETURN,
            uuids=uuids,
            routing_='r',
        )

        edges = [get_episodic_edge_from_record(record) for record in records]

        if len(edges) == 0:
            raise EdgeNotFoundError(uuids[0])
        return edges

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
    ):
        """Page through MENTIONS edges in ``group_ids``, ordered by uuid descending.

        ``uuid_cursor`` keeps only uuids strictly less than the cursor, and ``limit`` caps
        the page size.

        Raises:
            GroupsEdgesNotFoundError: if the page is empty.
        """
        cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''

        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
            WHERE e.group_id IN $group_ids
            """
            + cursor_query
            + """
            RETURN
            """
            + EPISODIC_EDGE_RETURN
            + """
            ORDER BY e.uuid DESC
            """
            + limit_query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )

        edges = [get_episodic_edge_from_record(record) for record in records]

        if len(edges) == 0:
            raise GroupsEdgesNotFoundError(group_ids)
        return edges
class EntityEdge(Edge):
    """RELATES_TO edge between two Entity nodes, carrying a natural-language ``fact``.

    On Kuzu the edge is materialized as an intermediate ``RelatesToNode_`` node linked to
    both endpoints by RELATES_TO relationships, so each query has a Kuzu-specific variant.
    """

    name: str = Field(description='name of the edge, relation name')
    fact: str = Field(description='fact representing the edge and nodes that it connects')
    fact_embedding: list[float] | None = Field(default=None, description='embedding of the fact')
    episodes: list[str] = Field(
        default=[],
        description='list of episode ids that reference these entity edges',
    )
    expired_at: datetime | None = Field(
        default=None, description='datetime of when the node was invalidated'
    )
    valid_at: datetime | None = Field(
        default=None, description='datetime of when the fact became true'
    )
    invalid_at: datetime | None = Field(
        default=None, description='datetime of when the fact stopped being true'
    )
    attributes: dict[str, Any] = Field(
        default={}, description='Additional attributes of the edge. Dependent on edge name'
    )

    async def generate_embedding(self, embedder: EmbedderClient):
        """Embed ``fact`` (newlines replaced by spaces), store it on the edge and return it."""
        start = time()

        text = self.fact.replace('\n', ' ')
        self.fact_embedding = await embedder.create(input_data=[text])

        # time() returns seconds; convert so the logged value really is in milliseconds.
        elapsed_ms = (time() - start) * 1000
        logger.debug(f'embedded {text} in {elapsed_ms} ms')

        return self.fact_embedding

    async def load_fact_embedding(self, driver: GraphDriver):
        """Load this edge's stored ``fact_embedding`` from the graph into the model.

        Neptune stores the embedding as a comma-separated string, so it is split and
        converted back to floats in the query.

        Raises:
            EdgeNotFoundError: if no edge with this uuid exists.
        """
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.edge_load_embeddings(self, driver)

        query = """
            MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
            RETURN e.fact_embedding AS fact_embedding
        """

        if driver.provider == GraphProvider.NEPTUNE:
            query = """
                MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
                RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
            """

        if driver.provider == GraphProvider.KUZU:
            query = """
                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
                RETURN e.fact_embedding AS fact_embedding
            """

        records, _, _ = await driver.execute_query(
            query,
            uuid=self.uuid,
            routing_='r',
        )

        if len(records) == 0:
            raise EdgeNotFoundError(self.uuid)

        self.fact_embedding = records[0]['fact_embedding']

    async def save(self, driver: GraphDriver):
        """Persist this edge via ``get_entity_edge_save_query`` and return the driver result.

        Kuzu receives the custom attributes as a JSON string. Other providers receive them
        as flat properties merged into ``edge_data``.
        """
        edge_data: dict[str, Any] = {
            'source_uuid': self.source_node_uuid,
            'target_uuid': self.target_node_uuid,
            'uuid': self.uuid,
            'name': self.name,
            'group_id': self.group_id,
            'fact': self.fact,
            'fact_embedding': self.fact_embedding,
            'episodes': self.episodes,
            'created_at': self.created_at,
            'expired_at': self.expired_at,
            'valid_at': self.valid_at,
            'invalid_at': self.invalid_at,
        }

        if driver.provider == GraphProvider.KUZU:
            edge_data['attributes'] = json.dumps(self.attributes)
            result = await driver.execute_query(
                get_entity_edge_save_query(driver.provider),
                **edge_data,
            )
        else:
            # Custom attributes become flat properties, but they must never overwrite the
            # core fields set above (e.g. an attribute that happens to be named 'uuid' or
            # 'fact' would otherwise corrupt the stored edge).
            for key, value in (self.attributes or {}).items():
                edge_data.setdefault(key, value)
            result = await driver.execute_query(
                get_entity_edge_save_query(driver.provider),
                edge_data=edge_data,
            )

        logger.debug(f'Saved edge to Graph: {self.uuid}')

        return result

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        """Fetch a single RELATES_TO edge by uuid.

        Raises:
            EdgeNotFoundError: if no edge has this uuid.
        """
        match_query = """
            MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
        """
        if driver.provider == GraphProvider.KUZU:
            match_query = """
                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
            """

        records, _, _ = await driver.execute_query(
            match_query
            + """
            RETURN
            """
            + get_entity_edge_return_query(driver.provider),
            uuid=uuid,
            routing_='r',
        )

        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

        if len(edges) == 0:
            raise EdgeNotFoundError(uuid)
        return edges[0]

    @classmethod
    async def get_between_nodes(
        cls, driver: GraphDriver, source_node_uuid: str, target_node_uuid: str
    ):
        """Return every RELATES_TO edge from ``source_node_uuid`` to ``target_node_uuid``.

        Only edges in that direction are matched; an empty list is returned when none exist.
        """
        match_query = """
            MATCH (n:Entity {uuid: $source_node_uuid})-[e:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
        """
        if driver.provider == GraphProvider.KUZU:
            match_query = """
                MATCH (n:Entity {uuid: $source_node_uuid})
                      -[:RELATES_TO]->(e:RelatesToNode_)
                      -[:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
            """

        records, _, _ = await driver.execute_query(
            match_query
            + """
            RETURN
            """
            + get_entity_edge_return_query(driver.provider),
            source_node_uuid=source_node_uuid,
            target_node_uuid=target_node_uuid,
            routing_='r',
        )

        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

        return edges

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        """Return the RELATES_TO edges whose uuid is in ``uuids``.

        Returns ``[]`` without querying when ``uuids`` is empty. Unlike ``get_by_uuid``,
        it does not raise when nothing matches.
        """
        if len(uuids) == 0:
            return []

        match_query = """
            MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
        """
        if driver.provider == GraphProvider.KUZU:
            match_query = """
                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
            """

        records, _, _ = await driver.execute_query(
            match_query
            + """
            WHERE e.uuid IN $uuids
            RETURN
            """
            + get_entity_edge_return_query(driver.provider),
            uuids=uuids,
            routing_='r',
        )

        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

        return edges

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
        with_embeddings: bool = False,
    ):
        """Page through RELATES_TO edges in ``group_ids``, ordered by uuid descending.

        ``uuid_cursor`` keeps only uuids strictly less than the cursor, ``limit`` caps the
        page size, and ``with_embeddings`` also returns ``fact_embedding``.

        Raises:
            GroupsEdgesNotFoundError: if the page is empty.
        """
        cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
        with_embeddings_query: LiteralString = (
            """,
                e.fact_embedding AS fact_embedding
                """
            if with_embeddings
            else ''
        )

        match_query = """
            MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
        """
        if driver.provider == GraphProvider.KUZU:
            match_query = """
                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
            """

        records, _, _ = await driver.execute_query(
            match_query
            + """
            WHERE e.group_id IN $group_ids
            """
            + cursor_query
            + """
            RETURN
            """
            + get_entity_edge_return_query(driver.provider)
            + with_embeddings_query
            + """
            ORDER BY e.uuid DESC
            """
            + limit_query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )

        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

        if len(edges) == 0:
            raise GroupsEdgesNotFoundError(group_ids)
        return edges

    @classmethod
    async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
        """Return the RELATES_TO edges attached to the entity ``node_uuid``."""
        # NOTE(review): the default pattern is undirected, but the Kuzu pattern only
        # follows edges where this node is the source. Confirm whether incoming edges
        # should also be returned on Kuzu.
        match_query = """
            MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
        """
        if driver.provider == GraphProvider.KUZU:
            match_query = """
                MATCH (n:Entity {uuid: $node_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
            """

        records, _, _ = await driver.execute_query(
            match_query
            + """
            RETURN
            """
            + get_entity_edge_return_query(driver.provider),
            node_uuid=node_uuid,
            routing_='r',
        )

        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

        return edges
class CommunityEdge(Edge):
    """HAS_MEMBER edge from a Community node to one of its members."""

    async def save(self, driver: GraphDriver):
        """Persist this edge via ``get_community_edge_save_query`` and return the driver result."""
        result = await driver.execute_query(
            get_community_edge_save_query(driver.provider),
            community_uuid=self.source_node_uuid,
            entity_uuid=self.target_node_uuid,
            uuid=self.uuid,
            group_id=self.group_id,
            created_at=self.created_at,
        )

        logger.debug(f'Saved edge to Graph: {self.uuid}')

        return result

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        """Fetch a single HAS_MEMBER edge by uuid.

        Raises:
            EdgeNotFoundError: if no edge has this uuid (consistent with the other edge
                types; previously an IndexError escaped instead).
        """
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m)
            RETURN
            """
            + COMMUNITY_EDGE_RETURN,
            uuid=uuid,
            routing_='r',
        )

        edges = [get_community_edge_from_record(record) for record in records]

        if len(edges) == 0:
            raise EdgeNotFoundError(uuid)
        return edges[0]

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        """Return the HAS_MEMBER edges whose uuid is in ``uuids`` (possibly empty)."""
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Community)-[e:HAS_MEMBER]->(m)
            WHERE e.uuid IN $uuids
            RETURN
            """
            + COMMUNITY_EDGE_RETURN,
            uuids=uuids,
            routing_='r',
        )

        edges = [get_community_edge_from_record(record) for record in records]

        return edges

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
    ):
        """Page through HAS_MEMBER edges in ``group_ids``, ordered by uuid descending.

        ``uuid_cursor`` keeps only uuids strictly less than the cursor and ``limit`` caps
        the page size. An empty page yields ``[]``; this method does not raise.
        """
        cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''

        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Community)-[e:HAS_MEMBER]->(m)
            WHERE e.group_id IN $group_ids
            """
            + cursor_query
            + """
            RETURN
            """
            + COMMUNITY_EDGE_RETURN
            + """
            ORDER BY e.uuid DESC
            """
            + limit_query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )

        edges = [get_community_edge_from_record(record) for record in records]

        return edges
# Edge helpers
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
    """Convert a record with ``EPISODIC_EDGE_RETURN`` columns into an :class:`EpisodicEdge`."""
    fields = {
        'uuid': record['uuid'],
        'group_id': record['group_id'],
        'source_node_uuid': record['source_node_uuid'],
        'target_node_uuid': record['target_node_uuid'],
        'created_at': parse_db_date(record['created_at']),
    }
    return EpisodicEdge(**fields)
def get_entity_edge_from_record(record: Any, provider: GraphProvider) -> EntityEdge:
    """Convert a query record into an :class:`EntityEdge`.

    Custom attributes come from ``record['attributes']``. On Kuzu this is a JSON string.
    Otherwise it is a mapping of the edge's stored properties, from which the core edge
    fields are removed. The record's own mapping is copied, never mutated.
    """
    if provider == GraphProvider.KUZU:
        raw_attributes = record['attributes']
        attributes: dict[str, Any] = json.loads(raw_attributes) if raw_attributes else {}
    else:
        # Copy first: popping directly from record['attributes'] would mutate the
        # caller's record. ``or {}`` tolerates a None property map.
        attributes = dict(record['attributes'] or {})
        # Core fields are saved as flat properties next to the custom attributes; drop
        # them so ``attributes`` holds only the custom ones.
        for reserved_key in (
            'uuid',
            'source_node_uuid',
            'target_node_uuid',
            'fact',
            'fact_embedding',
            'name',
            'group_id',
            'episodes',
            'created_at',
            'expired_at',
            'valid_at',
            'invalid_at',
        ):
            attributes.pop(reserved_key, None)

    return EntityEdge(
        uuid=record['uuid'],
        source_node_uuid=record['source_node_uuid'],
        target_node_uuid=record['target_node_uuid'],
        fact=record['fact'],
        fact_embedding=record.get('fact_embedding'),
        name=record['name'],
        group_id=record['group_id'],
        episodes=record['episodes'],
        created_at=parse_db_date(record['created_at']),  # type: ignore
        expired_at=parse_db_date(record['expired_at']),
        valid_at=parse_db_date(record['valid_at']),
        invalid_at=parse_db_date(record['invalid_at']),
        attributes=attributes,
    )
def get_community_edge_from_record(record: Any):
    """Convert a record with ``COMMUNITY_EDGE_RETURN`` columns into a :class:`CommunityEdge`."""
    fields = {
        'uuid': record['uuid'],
        'group_id': record['group_id'],
        'source_node_uuid': record['source_node_uuid'],
        'target_node_uuid': record['target_node_uuid'],
        'created_at': parse_db_date(record['created_at']),
    }
    return CommunityEdge(**fields)
async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
# filter out falsey values from edges
filtered_edges = [edge for edge in edges if edge.fact]
if len(filtered_edges) == 0:
return
fact_embeddings = await embedder.create_batch([edge.fact for edge in filtered_edges])
for edge, fact_embedding in zip(filtered_edges, fact_embeddings, strict=True):
edge.fact_embedding = fact_embedding
```
--------------------------------------------------------------------------------
/tests/utils/maintenance/test_node_operations.py:
--------------------------------------------------------------------------------
```python
import logging
from collections import defaultdict
from unittest.mock import AsyncMock, MagicMock
import pytest
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search_config import SearchResults
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.dedup_helpers import (
DedupCandidateIndexes,
DedupResolutionState,
_build_candidate_indexes,
_cached_shingles,
_has_high_entropy,
_hash_shingle,
_jaccard_similarity,
_lsh_bands,
_minhash_signature,
_name_entropy,
_normalize_name_for_fuzzy,
_normalize_string_exact,
_resolve_with_similarity,
_shingles,
)
from graphiti_core.utils.maintenance.node_operations import (
_collect_candidate_nodes,
_resolve_with_llm,
extract_attributes_from_node,
extract_attributes_from_nodes,
resolve_extracted_nodes,
)
def _make_clients():
    """Build mock-backed GraphitiClients plus the LLM's ``generate_response`` AsyncMock."""
    generate_response = AsyncMock()
    llm_client = MagicMock()
    llm_client.generate_response = generate_response

    # model_construct skips validation so the MagicMock doubles are accepted.
    clients = GraphitiClients.model_construct(
        driver=MagicMock(),
        embedder=MagicMock(),
        cross_encoder=MagicMock(),
        llm_client=llm_client,
    )
    return clients, generate_response
def _make_episode(group_id: str = 'group'):
    """Return a minimal message episode in ``group_id``, valid as of now."""
    return EpisodicNode(
        content='content',
        source=EpisodeType.message,
        source_description='test',
        name='episode',
        group_id=group_id,
        valid_at=utc_now(),
    )
@pytest.mark.asyncio
async def test_resolve_nodes_exact_match_skips_llm(monkeypatch):
    """An exact name match against a search candidate resolves without the LLM."""
    clients, llm_generate = _make_clients()
    existing = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])
    extracted = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])

    async def search_stub(*_, **__):
        return SearchResults(nodes=[existing])

    monkeypatch.setattr('graphiti_core.utils.maintenance.node_operations.search', search_stub)
    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges',
        AsyncMock(return_value=[]),
    )

    nodes, uuid_map, _ = await resolve_extracted_nodes(
        clients, [extracted], episode=_make_episode(), previous_episodes=[]
    )

    assert nodes[0].uuid == existing.uuid
    assert uuid_map[extracted.uuid] == existing.uuid
    llm_generate.assert_not_awaited()
@pytest.mark.asyncio
async def test_resolve_nodes_low_entropy_uses_llm(monkeypatch):
    """A short, low-entropy name with no candidates is escalated to the LLM."""
    clients, llm_generate = _make_clients()
    no_duplicate = {'id': 0, 'duplicate_idx': -1, 'name': 'Joe', 'duplicates': []}
    llm_generate.return_value = {'entity_resolutions': [no_duplicate]}
    extracted = EntityNode(name='Joe', group_id='group', labels=['Entity'])

    async def empty_search(*_, **__):
        return SearchResults(nodes=[])

    monkeypatch.setattr('graphiti_core.utils.maintenance.node_operations.search', empty_search)
    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges',
        AsyncMock(return_value=[]),
    )

    nodes, uuid_map, _ = await resolve_extracted_nodes(
        clients, [extracted], episode=_make_episode(), previous_episodes=[]
    )

    assert nodes[0].uuid == extracted.uuid
    assert uuid_map[extracted.uuid] == extracted.uuid
    llm_generate.assert_awaited()
@pytest.mark.asyncio
async def test_resolve_nodes_fuzzy_match(monkeypatch):
    """Names differing only by punctuation match fuzzily; the LLM is never consulted."""
    clients, llm_generate = _make_clients()
    hyphenated = EntityNode(name='Joe-Michaels', group_id='group', labels=['Entity'])
    extracted = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])

    async def search_stub(*_, **__):
        return SearchResults(nodes=[hyphenated])

    monkeypatch.setattr('graphiti_core.utils.maintenance.node_operations.search', search_stub)
    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges',
        AsyncMock(return_value=[]),
    )

    nodes, uuid_map, _ = await resolve_extracted_nodes(
        clients, [extracted], episode=_make_episode(), previous_episodes=[]
    )

    assert nodes[0].uuid == hyphenated.uuid
    assert uuid_map[extracted.uuid] == hyphenated.uuid
    llm_generate.assert_not_awaited()
@pytest.mark.asyncio
async def test_collect_candidate_nodes_dedupes_and_merges_override(monkeypatch):
    """An override node sharing a uuid with a search hit collapses into one candidate."""
    clients, _ = _make_clients()
    searched = EntityNode(name='Alice', group_id='group', labels=['Entity'])
    same_uuid_override = EntityNode(
        uuid=searched.uuid, name='Alice Alt', group_id='group', labels=['Entity']
    )
    extracted = EntityNode(name='Alice', group_id='group', labels=['Entity'])
    search_mock = AsyncMock(return_value=SearchResults(nodes=[searched]))
    monkeypatch.setattr('graphiti_core.utils.maintenance.node_operations.search', search_mock)

    candidates = await _collect_candidate_nodes(
        clients, [extracted], existing_nodes_override=[same_uuid_override]
    )

    assert len(candidates) == 1
    assert candidates[0].uuid == searched.uuid
    search_mock.assert_awaited()
def test_build_candidate_indexes_populates_structures():
    """Every candidate index structure gets an entry for a single candidate."""
    node = EntityNode(name='Bob Dylan', group_id='group', labels=['Entity'])
    idx = _build_candidate_indexes([node])

    assert idx.normalized_existing[node.name.lower()][0].uuid == node.uuid
    assert idx.nodes_by_uuid[node.uuid] is node
    assert node.uuid in idx.shingles_by_candidate
    assert any(node.uuid in bucket for bucket in idx.lsh_buckets.values())
def test_normalize_helpers():
    """Exact normalization trims and lowercases; fuzzy normalization also drops punctuation."""
    assert _normalize_string_exact(' Alice Smith ') == 'alice smith'
    fuzzy = _normalize_name_for_fuzzy('Alice-Smith!')
    assert fuzzy == 'alice smith'
def test_name_entropy_variants():
    """Varied characters outscore repeats, and the empty string has zero entropy."""
    varied, repeated = _name_entropy('alice'), _name_entropy('aaaaa')
    assert varied > repeated
    assert _name_entropy('') == 0.0
def test_has_high_entropy_rules():
    """A multi-word name passes the entropy gate; a two-character repeat does not."""
    high = _has_high_entropy('meaningful name')
    assert high is True
    assert _has_high_entropy('aa') is False
def test_shingles_and_cache():
    """Shingles are character trigrams; the cached variant returns one shared object."""
    text = 'alice'
    trigrams = _shingles(text)

    assert trigrams == {'ali', 'lic', 'ice'}
    assert _cached_shingles(text) == trigrams
    assert _cached_shingles(text) is _cached_shingles(text)
def test_hash_minhash_and_lsh():
    """MinHash yields 32 values split into 4-wide LSH bands; these shingle hashes are distinct."""
    sample = {'abc', 'bcd', 'cde'}
    sig = _minhash_signature(sample)
    assert len(sig) == 32

    assert all(len(band) == 4 for band in _lsh_bands(sig))

    distinct_hashes = {_hash_shingle(shingle, 0) for shingle in sample}
    assert len(distinct_hashes) == len(sample)
def test_jaccard_similarity_edges():
    """Partial overlap gives |A & B| / |A | B|; two empty sets count as identical."""
    left, right = {'a', 'b'}, {'a', 'c'}
    assert _jaccard_similarity(left, right) == pytest.approx(1 / 3)
    assert _jaccard_similarity(set(), set()) == 1.0
    assert _jaccard_similarity(left, set()) == 0.0
def test_resolve_with_similarity_exact_match_updates_state():
    """A unique exact match resolves immediately and records the duplicate pair."""
    existing = EntityNode(name='Charlie Parker', group_id='group', labels=['Entity'])
    extracted = EntityNode(name='Charlie Parker', group_id='group', labels=['Entity'])
    idx = _build_candidate_indexes([existing])
    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])

    _resolve_with_similarity([extracted], idx, state)

    assert state.resolved_nodes[0].uuid == existing.uuid
    assert state.uuid_map[extracted.uuid] == existing.uuid
    assert state.unresolved_indices == []
    assert state.duplicate_pairs == [(extracted, existing)]
def test_resolve_with_similarity_low_entropy_defers_resolution():
    """A low-entropy name with no candidates stays unresolved for the LLM pass."""
    extracted = EntityNode(name='Bob', group_id='group', labels=['Entity'])
    empty_indexes = DedupCandidateIndexes(
        existing_nodes=[],
        nodes_by_uuid={},
        normalized_existing=defaultdict(list),
        shingles_by_candidate={},
        lsh_buckets=defaultdict(list),
    )
    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])

    _resolve_with_similarity([extracted], empty_indexes, state)

    assert state.resolved_nodes[0] is None
    assert state.unresolved_indices == [0]
    assert state.duplicate_pairs == []
def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():
    """Ambiguous exact matches (two same-named candidates) are deferred to the LLM."""
    first, second = [
        EntityNode(name='Johnny Appleseed', group_id='group', labels=['Entity'])
        for _ in range(2)
    ]
    extracted = EntityNode(name='Johnny Appleseed', group_id='group', labels=['Entity'])
    idx = _build_candidate_indexes([first, second])
    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])

    _resolve_with_similarity([extracted], idx, state)

    assert state.resolved_nodes[0] is None
    assert state.unresolved_indices == [0]
    assert state.duplicate_pairs == []
@pytest.mark.asyncio
async def test_resolve_with_llm_updates_unresolved(monkeypatch):
    """An LLM duplicate_idx pointing at a candidate resolves the pending node to it."""
    extracted = EntityNode(name='Dizzy', group_id='group', labels=['Entity'])
    existing = EntityNode(name='Dizzy Gillespie', group_id='group', labels=['Entity'])
    indexes = _build_candidate_indexes([existing])
    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
    prompt_context = {}

    def recording_prompt(context):
        # Capture what the dedupe prompt was built from.
        prompt_context.update(context)
        return ['prompt']

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
        recording_prompt,
    )

    async def fake_generate_response(*_, **__):
        resolution = {
            'id': 0,
            'duplicate_idx': 0,
            'name': 'Dizzy Gillespie',
            'duplicates': [0],
        }
        return {'entity_resolutions': [resolution]}

    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(side_effect=fake_generate_response)

    await _resolve_with_llm(
        llm_client,
        [extracted],
        indexes,
        state,
        episode=_make_episode(),
        previous_episodes=[],
        entity_types=None,
    )

    assert state.resolved_nodes[0].uuid == existing.uuid
    assert state.uuid_map[extracted.uuid] == existing.uuid
    assert prompt_context['existing_nodes'][0]['idx'] == 0
    assert isinstance(prompt_context['existing_nodes'], list)
    assert state.duplicate_pairs == [(extracted, existing)]
@pytest.mark.asyncio
async def test_resolve_with_llm_ignores_out_of_range_relative_ids(monkeypatch, caplog):
    """A resolution whose id is outside the unresolved range is logged and skipped."""
    extracted = EntityNode(name='Dexter', group_id='group', labels=['Entity'])
    indexes = _build_candidate_indexes([])
    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
        lambda context: ['prompt'],
    )
    bogus = {'id': 5, 'duplicate_idx': -1, 'name': 'Dexter', 'duplicates': []}
    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(return_value={'entity_resolutions': [bogus]})

    with caplog.at_level(logging.WARNING):
        await _resolve_with_llm(
            llm_client,
            [extracted],
            indexes,
            state,
            episode=_make_episode(),
            previous_episodes=[],
            entity_types=None,
        )

    assert state.resolved_nodes[0] is None
    assert 'Skipping invalid LLM dedupe id 5' in caplog.text
@pytest.mark.asyncio
async def test_resolve_with_llm_ignores_duplicate_relative_ids(monkeypatch):
    """When the LLM repeats an id, only the first resolution for that id is applied."""
    extracted = EntityNode(name='Dizzy', group_id='group', labels=['Entity'])
    existing = EntityNode(name='Dizzy Gillespie', group_id='group', labels=['Entity'])
    indexes = _build_candidate_indexes([existing])
    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
        lambda context: ['prompt'],
    )
    first_resolution = {
        'id': 0,
        'duplicate_idx': 0,
        'name': 'Dizzy Gillespie',
        'duplicates': [0],
    }
    repeated_resolution = {'id': 0, 'duplicate_idx': -1, 'name': 'Dizzy', 'duplicates': []}
    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(
        return_value={'entity_resolutions': [first_resolution, repeated_resolution]}
    )

    await _resolve_with_llm(
        llm_client,
        [extracted],
        indexes,
        state,
        episode=_make_episode(),
        previous_episodes=[],
        entity_types=None,
    )

    assert state.resolved_nodes[0].uuid == existing.uuid
    assert state.uuid_map[extracted.uuid] == existing.uuid
    assert state.duplicate_pairs == [(extracted, existing)]
@pytest.mark.asyncio
async def test_resolve_with_llm_invalid_duplicate_idx_defaults_to_extracted(monkeypatch):
    """An out-of-range ``duplicate_idx`` resolves the node to itself, with no duplicate pair."""
    dexter = EntityNode(name='Dexter', group_id='group', labels=['Entity'])
    candidate_indexes = _build_candidate_indexes([])
    dedupe_state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
        lambda context: ['prompt'],
    )

    # There are no candidates at all, so duplicate_idx=10 cannot point anywhere valid.
    bad_duplicate = {
        'id': 0,
        'duplicate_idx': 10,
        'name': 'Dexter',
        'duplicates': [],
    }
    llm_client = MagicMock(
        generate_response=AsyncMock(return_value={'entity_resolutions': [bad_duplicate]})
    )

    await _resolve_with_llm(
        llm_client,
        [dexter],
        candidate_indexes,
        dedupe_state,
        episode=_make_episode(),
        previous_episodes=[],
        entity_types=None,
    )

    assert dedupe_state.resolved_nodes[0] == dexter
    assert dedupe_state.uuid_map[dexter.uuid] == dexter.uuid
    assert dedupe_state.duplicate_pairs == []
@pytest.mark.asyncio
async def test_extract_attributes_without_callback_generates_summary():
    """With ``should_summarize_node=None`` the summary is regenerated by default."""
    llm_client = MagicMock(
        generate_response=AsyncMock(
            return_value={'summary': 'Generated summary', 'attributes': {}}
        )
    )
    target = EntityNode(
        name='Test Node', group_id='group', labels=['Entity'], summary='Old summary'
    )

    updated = await extract_attributes_from_node(
        llm_client,
        target,
        episode=_make_episode(),
        previous_episodes=[],
        entity_type=None,
        should_summarize_node=None,  # no callback: default behaviour
    )

    # A single LLM call produced the new summary.
    assert updated.summary == 'Generated summary'
    assert llm_client.generate_response.call_count == 1
@pytest.mark.asyncio
async def test_extract_attributes_with_callback_skip_summary():
    """A callback returning False keeps the old summary and makes no LLM call."""
    llm_client = MagicMock(
        generate_response=AsyncMock(
            return_value={'summary': 'This should not be used', 'attributes': {}}
        )
    )
    target = EntityNode(
        name='Test Node', group_id='group', labels=['Entity'], summary='Old summary'
    )

    async def never_summarize(node: EntityNode) -> bool:
        return False

    updated = await extract_attributes_from_node(
        llm_client,
        target,
        episode=_make_episode(),
        previous_episodes=[],
        entity_type=None,
        should_summarize_node=never_summarize,
    )

    assert updated.summary == 'Old summary'
    assert llm_client.generate_response.call_count == 0
@pytest.mark.asyncio
async def test_extract_attributes_with_callback_generate_summary():
    """A callback returning True triggers summary regeneration via one LLM call."""
    llm_client = MagicMock(
        generate_response=AsyncMock(
            return_value={'summary': 'New generated summary', 'attributes': {}}
        )
    )
    target = EntityNode(
        name='Test Node', group_id='group', labels=['Entity'], summary='Old summary'
    )

    async def always_summarize(node: EntityNode) -> bool:
        return True

    updated = await extract_attributes_from_node(
        llm_client,
        target,
        episode=_make_episode(),
        previous_episodes=[],
        entity_type=None,
        should_summarize_node=always_summarize,
    )

    assert updated.summary == 'New generated summary'
    assert llm_client.generate_response.call_count == 1
@pytest.mark.asyncio
async def test_extract_attributes_with_selective_callback():
    """A callback can skip summaries for some nodes (here: User) and not others."""
    llm_client = MagicMock(
        generate_response=AsyncMock(
            return_value={'summary': 'Generated summary', 'attributes': {}}
        )
    )
    user_node = EntityNode(name='User', group_id='group', labels=['Entity', 'User'], summary='Old')
    topic_node = EntityNode(
        name='Topic', group_id='group', labels=['Entity', 'Topic'], summary='Old'
    )
    episode = _make_episode()

    async def skip_users(node: EntityNode) -> bool:
        return 'User' not in node.labels

    # Process the nodes one after another: user first, then topic.
    outcomes = []
    for target in (user_node, topic_node):
        outcomes.append(
            await extract_attributes_from_node(
                llm_client,
                target,
                episode=episode,
                previous_episodes=[],
                entity_type=None,
                should_summarize_node=skip_users,
            )
        )
    result_user, result_topic = outcomes

    assert result_user.summary == 'Old'
    assert result_topic.summary == 'Generated summary'
    # Only the Topic node reached the LLM.
    assert llm_client.generate_response.call_count == 1
@pytest.mark.asyncio
async def test_extract_attributes_from_nodes_with_callback():
    """The bulk ``extract_attributes_from_nodes`` forwards the callback to every node."""
    clients, _ = _make_clients()
    clients.llm_client.generate_response = AsyncMock(
        return_value={'summary': 'New summary', 'attributes': {}}
    )
    clients.embedder.create = AsyncMock(return_value=[0.1, 0.2, 0.3])
    clients.embedder.create_batch = AsyncMock(return_value=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])

    user_node = EntityNode(name='Node1', group_id='group', labels=['Entity', 'User'], summary='Old1')
    topic_node = EntityNode(name='Node2', group_id='group', labels=['Entity', 'Topic'], summary='Old2')

    seen_names = []

    async def recording_filter(node: EntityNode) -> bool:
        # Remember every node the callback is consulted for; skip User nodes.
        seen_names.append(node.name)
        return 'User' not in node.labels

    results = await extract_attributes_from_nodes(
        clients,
        [user_node, topic_node],
        episode=_make_episode(),
        previous_episodes=[],
        entity_types=None,
        should_summarize_node=recording_filter,
    )

    # Consulted exactly once per node.
    assert sorted(seen_names) == ['Node1', 'Node2']

    user_result = next(n for n in results if n.name == 'Node1')
    topic_result = next(n for n in results if n.name == 'Node2')
    assert user_result.summary == 'Old1'
    assert topic_result.summary == 'New summary'
```
--------------------------------------------------------------------------------
/tests/llm_client/test_gemini_client.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Running tests: pytest -xvs tests/llm_client/test_gemini_client.py
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import BaseModel
from graphiti_core.llm_client.config import LLMConfig, ModelSize
from graphiti_core.llm_client.errors import RateLimitError
from graphiti_core.llm_client.gemini_client import DEFAULT_MODEL, DEFAULT_SMALL_MODEL, GeminiClient
from graphiti_core.prompts.models import Message
# Pydantic schema used as ``response_model`` in the structured-output tests below.
class ResponseModel(BaseModel):
    """Minimal structured-output schema: one required string and one defaulted int."""

    test_field: str
    optional_field: int = 0
@pytest.fixture
def mock_gemini_client():
    """Patch ``google.genai.Client`` and yield the mocked client instance.

    The instance exposes ``aio.models.generate_content`` as an ``AsyncMock``, so tests
    can set ``return_value``/``side_effect`` on the async generation call. The patch
    stays active until the dependent test finishes.
    """
    with patch('google.genai.Client') as client_cls:
        instance = client_cls.return_value
        instance.aio = MagicMock(models=MagicMock(generate_content=AsyncMock()))
        yield instance
@pytest.fixture
def gemini_client(mock_gemini_client):
    """Return a cache-less ``GeminiClient`` wired to the mocked SDK client."""
    llm_config = LLMConfig(
        api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
    )
    wrapper = GeminiClient(config=llm_config, cache=False)
    # Point the wrapper at the mock explicitly so every request is captured by it.
    wrapper.client = mock_gemini_client
    return wrapper
class TestGeminiClientInitialization:
    """Construction-time behaviour of GeminiClient (the genai SDK client is patched out)."""

    @patch('google.genai.Client')
    def test_init_with_config(self, mock_client):
        """Values from the config object are exposed on the constructed client."""
        llm_config = LLMConfig(
            api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
        )
        client = GeminiClient(config=llm_config, cache=False, max_tokens=1000)

        assert client.config == llm_config
        assert (client.model, client.temperature, client.max_tokens) == ('test-model', 0.5, 1000)

    @patch('google.genai.Client')
    def test_init_with_default_model(self, mock_client):
        """A config naming DEFAULT_MODEL yields a client using DEFAULT_MODEL."""
        client = GeminiClient(
            config=LLMConfig(api_key='test_api_key', model=DEFAULT_MODEL), cache=False
        )
        assert client.model == DEFAULT_MODEL

    @patch('google.genai.Client')
    def test_init_without_config(self, mock_client):
        """Omitting the config still produces one, but leaves ``model`` unset."""
        client = GeminiClient(cache=False)
        assert client.config is not None
        # No config.model was given, so the client reports None rather than DEFAULT_MODEL.
        assert client.model is None

    @patch('google.genai.Client')
    def test_init_with_thinking_config(self, mock_client):
        """A thinking config passed to the constructor is stored as given."""
        with patch('google.genai.types.ThinkingConfig') as thinking_config_cls:
            expected = thinking_config_cls.return_value
            client = GeminiClient(thinking_config=expected)
            assert client.thinking_config == expected
def _mock_response(text='', *, candidates=None, prompt_feedback=None):
    """Build a MagicMock shaped like a Gemini ``generate_content`` response.

    Only the attributes these tests rely on are populated: ``text``, ``candidates``
    (an empty list unless given) and ``prompt_feedback`` (``None`` unless given).
    """
    response = MagicMock()
    response.text = text
    response.candidates = [] if candidates is None else candidates
    response.prompt_feedback = prompt_feedback
    return response


def _safety_blocked_candidate():
    """Build a mock candidate whose finish reason is ``'SAFETY'`` with a blocking rating."""
    candidate = MagicMock()
    candidate.finish_reason = 'SAFETY'
    candidate.safety_ratings = [
        MagicMock(blocked=True, category='HARM_CATEGORY_HARASSMENT', probability='HIGH')
    ]
    return candidate


class TestGeminiClientGenerateResponse:
    """Tests for GeminiClient generate_response method."""

    @pytest.mark.asyncio
    async def test_generate_response_simple_text(self, gemini_client, mock_gemini_client):
        """Test successful response generation with simple text."""
        generate_content = mock_gemini_client.aio.models.generate_content
        generate_content.return_value = _mock_response('Test response text')

        messages = [Message(role='user', content='Test message')]
        result = await gemini_client.generate_response(messages)

        assert isinstance(result, dict)
        assert result['content'] == 'Test response text'
        generate_content.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_response_with_structured_output(
        self, gemini_client, mock_gemini_client
    ):
        """Test response generation with structured output."""
        generate_content = mock_gemini_client.aio.models.generate_content
        generate_content.return_value = _mock_response(
            '{"test_field": "test_value", "optional_field": 42}'
        )

        messages = [
            Message(role='system', content='System message'),
            Message(role='user', content='User message'),
        ]
        result = await gemini_client.generate_response(
            messages=messages, response_model=ResponseModel
        )

        assert isinstance(result, dict)
        assert result['test_field'] == 'test_value'
        assert result['optional_field'] == 42
        generate_content.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_response_with_system_message(self, gemini_client, mock_gemini_client):
        """Test that the system message ends up in the request's system instruction."""
        generate_content = mock_gemini_client.aio.models.generate_content
        generate_content.return_value = _mock_response('Response with system context')

        messages = [
            Message(role='system', content='System message'),
            Message(role='user', content='User message'),
        ]
        await gemini_client.generate_response(messages)

        request_config = generate_content.call_args[1]['config']
        assert 'System message' in request_config.system_instruction

    def test_get_model_for_size(self, gemini_client):
        """Test model selection based on size (a synchronous helper, so no event loop)."""
        assert gemini_client._get_model_for_size(ModelSize.small) == DEFAULT_SMALL_MODEL
        assert gemini_client._get_model_for_size(ModelSize.medium) == gemini_client.model

    @pytest.mark.asyncio
    async def test_rate_limit_error_handling(self, gemini_client, mock_gemini_client):
        """Test that a 'rate limit' failure surfaces as RateLimitError."""
        mock_gemini_client.aio.models.generate_content.side_effect = Exception(
            'Rate limit exceeded'
        )

        messages = [Message(role='user', content='Test message')]
        with pytest.raises(RateLimitError):
            await gemini_client.generate_response(messages)

    @pytest.mark.asyncio
    async def test_quota_error_handling(self, gemini_client, mock_gemini_client):
        """Test that a quota failure surfaces as RateLimitError."""
        mock_gemini_client.aio.models.generate_content.side_effect = Exception(
            'Quota exceeded for requests'
        )

        messages = [Message(role='user', content='Test message')]
        with pytest.raises(RateLimitError):
            await gemini_client.generate_response(messages)

    @pytest.mark.asyncio
    async def test_resource_exhausted_error_handling(self, gemini_client, mock_gemini_client):
        """Test that a resource-exhausted failure surfaces as RateLimitError."""
        mock_gemini_client.aio.models.generate_content.side_effect = Exception(
            'resource_exhausted: Request limit exceeded'
        )

        messages = [Message(role='user', content='Test message')]
        with pytest.raises(RateLimitError):
            await gemini_client.generate_response(messages)

    @pytest.mark.asyncio
    async def test_safety_block_handling(self, gemini_client, mock_gemini_client):
        """Test that a candidate stopped for SAFETY raises a safety-filter error."""
        mock_gemini_client.aio.models.generate_content.return_value = _mock_response(
            '', candidates=[_safety_blocked_candidate()]
        )

        messages = [Message(role='user', content='Test message')]
        with pytest.raises(Exception, match='Content blocked by safety filters'):
            await gemini_client.generate_response(messages)

    @pytest.mark.asyncio
    async def test_prompt_block_handling(self, gemini_client, mock_gemini_client):
        """Test that a blocked prompt raises a safety-filter error."""
        prompt_feedback = MagicMock()
        prompt_feedback.block_reason = 'BLOCKED_REASON_OTHER'
        mock_gemini_client.aio.models.generate_content.return_value = _mock_response(
            '', prompt_feedback=prompt_feedback
        )

        messages = [Message(role='user', content='Test message')]
        with pytest.raises(Exception, match='Content blocked by safety filters'):
            await gemini_client.generate_response(messages)

    @pytest.mark.asyncio
    async def test_structured_output_parsing_error(self, gemini_client, mock_gemini_client):
        """Test that unparseable structured output fails after exhausting retries."""
        generate_content = mock_gemini_client.aio.models.generate_content
        generate_content.return_value = _mock_response('Invalid JSON that cannot be parsed')

        messages = [Message(role='user', content='Test message')]
        with pytest.raises(Exception):  # noqa: B017
            await gemini_client.generate_response(messages, response_model=ResponseModel)

        # One request per allowed attempt.
        assert generate_content.call_count == GeminiClient.MAX_RETRIES

    @pytest.mark.asyncio
    async def test_retry_logic_with_safety_block(self, gemini_client, mock_gemini_client):
        """Test that safety blocks are not retried."""
        generate_content = mock_gemini_client.aio.models.generate_content
        generate_content.return_value = _mock_response(
            '', candidates=[_safety_blocked_candidate()]
        )

        messages = [Message(role='user', content='Test message')]
        with pytest.raises(Exception, match='Content blocked by safety filters'):
            await gemini_client.generate_response(messages)

        # A safety block is final: exactly one request.
        assert generate_content.call_count == 1

    @pytest.mark.asyncio
    async def test_retry_logic_with_validation_error(self, gemini_client, mock_gemini_client):
        """Test that an invalid first response is retried and the valid second one is used."""
        generate_content = mock_gemini_client.aio.models.generate_content
        generate_content.side_effect = [
            _mock_response('Invalid JSON that cannot be parsed'),
            _mock_response('{"test_field": "correct_value"}'),
        ]

        messages = [Message(role='user', content='Test message')]
        result = await gemini_client.generate_response(messages, response_model=ResponseModel)

        assert generate_content.call_count == 2
        assert result['test_field'] == 'correct_value'

    @pytest.mark.asyncio
    async def test_max_retries_exceeded(self, gemini_client, mock_gemini_client):
        """Test behavior when every attempt returns invalid output."""
        generate_content = mock_gemini_client.aio.models.generate_content
        generate_content.return_value = _mock_response('Invalid JSON that cannot be parsed')

        messages = [Message(role='user', content='Test message')]
        with pytest.raises(Exception):  # noqa: B017
            await gemini_client.generate_response(messages, response_model=ResponseModel)

        # One request per allowed attempt.
        assert generate_content.call_count == GeminiClient.MAX_RETRIES

    @pytest.mark.asyncio
    async def test_empty_response_handling(self, gemini_client, mock_gemini_client):
        """Test that empty structured responses exhaust retries and raise."""
        generate_content = mock_gemini_client.aio.models.generate_content
        generate_content.return_value = _mock_response('')

        messages = [Message(role='user', content='Test message')]
        with pytest.raises(Exception):  # noqa: B017
            await gemini_client.generate_response(messages, response_model=ResponseModel)

        # One request per allowed attempt.
        assert generate_content.call_count == GeminiClient.MAX_RETRIES

    @pytest.mark.asyncio
    async def test_custom_max_tokens(self, gemini_client, mock_gemini_client):
        """Test that explicit max_tokens parameter takes precedence over all other values."""
        generate_content = mock_gemini_client.aio.models.generate_content
        generate_content.return_value = _mock_response('Test response')

        messages = [Message(role='user', content='Test message')]
        await gemini_client.generate_response(messages, max_tokens=500)

        request_config = generate_content.call_args[1]['config']
        assert request_config.max_output_tokens == 500

    @pytest.mark.asyncio
    async def test_max_tokens_precedence_fallback(self, mock_gemini_client):
        """Test max_tokens precedence when no explicit parameter is provided."""
        generate_content = mock_gemini_client.aio.models.generate_content
        generate_content.return_value = _mock_response('Test response')
        messages = [Message(role='user', content='Test message')]

        # Case 1: no explicit max_tokens; the instance-level max_tokens (2000) wins
        # over the config's 1000.
        llm_config = LLMConfig(
            api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
        )
        client = GeminiClient(
            config=llm_config, cache=False, max_tokens=2000, client=mock_gemini_client
        )
        await client.generate_response(messages)
        request_config = generate_content.call_args[1]['config']
        assert request_config.max_output_tokens == 2000

        # Case 2: no explicit or instance max_tokens; the per-model mapping is used.
        llm_config = LLMConfig(api_key='test_api_key', model='gemini-2.5-flash', temperature=0.5)
        client = GeminiClient(config=llm_config, cache=False, client=mock_gemini_client)
        await client.generate_response(messages)
        request_config = generate_content.call_args[1]['config']
        assert request_config.max_output_tokens == 65536

    @pytest.mark.asyncio
    async def test_model_size_selection(self, gemini_client, mock_gemini_client):
        """Test that the correct model is selected based on model size."""
        generate_content = mock_gemini_client.aio.models.generate_content
        generate_content.return_value = _mock_response('Test response')

        messages = [Message(role='user', content='Test message')]
        await gemini_client.generate_response(messages, model_size=ModelSize.small)

        assert generate_content.call_args[1]['model'] == DEFAULT_SMALL_MODEL

    @pytest.mark.asyncio
    async def test_gemini_model_max_tokens_mapping(self, mock_gemini_client):
        """Test that different Gemini models use their correct max tokens."""
        generate_content = mock_gemini_client.aio.models.generate_content
        generate_content.return_value = _mock_response('Test response')

        # (model_name, expected_max_tokens)
        test_cases = [
            ('gemini-2.5-flash', 65536),
            ('gemini-2.5-pro', 65536),
            ('gemini-2.5-flash-lite', 64000),
            ('gemini-2.0-flash', 8192),
            ('gemini-1.5-pro', 8192),
            ('gemini-1.5-flash', 8192),
            ('unknown-model', 8192),  # Fallback case
        ]
        messages = [Message(role='user', content='Test message')]

        for model_name, expected_max_tokens in test_cases:
            # No explicit or instance max_tokens, so only the model mapping applies.
            llm_config = LLMConfig(api_key='test_api_key', model=model_name, temperature=0.5)
            client = GeminiClient(config=llm_config, cache=False, client=mock_gemini_client)
            await client.generate_response(messages)

            request_config = generate_content.call_args[1]['config']
            assert request_config.max_output_tokens == expected_max_tokens, (
                f'Model {model_name} should use {expected_max_tokens} tokens'
            )
if __name__ == '__main__':
    # Use this file's own path so the script works from any working directory, and
    # propagate pytest's exit status to the shell.
    raise SystemExit(pytest.main(['-v', __file__]))
```
--------------------------------------------------------------------------------
/mcp_server/tests/test_comprehensive_integration.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Comprehensive integration test suite for Graphiti MCP Server.
Covers all MCP tools with consideration for LLM inference latency.
"""
import asyncio
import json
import os
import time
from contextlib import AsyncExitStack
from dataclasses import dataclass
from typing import Any

import pytest
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
@dataclass
class TestMetrics:
    """Track test performance metrics for a single operation.

    Attributes:
        operation: Label for the measured operation (e.g. ``'call_add_memory'``).
        start_time: Timestamp taken before the operation started, in seconds.
        end_time: Timestamp taken after the operation finished, in seconds.
        success: Whether the operation completed successfully.
        details: Free-form result/error information for diagnostics.
    """

    # The name starts with "Test" but this is a data record, not a test class.
    # This marker stops pytest from trying (and warning about failing) to collect it.
    __test__ = False

    operation: str
    start_time: float
    end_time: float
    success: bool
    details: dict[str, Any]

    @property
    def duration(self) -> float:
        """Calculate operation duration in seconds."""
        return self.end_time - self.start_time
class GraphitiTestClient:
    """Enhanced test client for comprehensive Graphiti MCP testing.

    Use as an async context manager. On entry it:
    - launches the server with ``uv run ../main.py --transport stdio``;
    - opens an MCP ``ClientSession`` over the process's stdio and initializes it.

    On exit it closes the session and then the transport. Every tool call made
    through :meth:`call_tool_with_metrics` is recorded in ``self.metrics``.
    """

    def __init__(self, test_group_id: str | None = None):
        # Timestamp-based default (1 s resolution) keeps data from separate runs apart.
        self.test_group_id = test_group_id or f'test_{int(time.time())}'
        self.session = None
        self.metrics: list[TestMetrics] = []
        self.default_timeout = 30  # seconds
        # Owns the stdio transport and the session so they are closed together, LIFO.
        self._exit_stack: AsyncExitStack | None = None

    async def __aenter__(self):
        """Start the MCP server over stdio and return this client with a ready session."""
        server_params = StdioServerParameters(
            command='uv',
            args=['run', '../main.py', '--transport', 'stdio'],
            env={
                'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
                'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
                'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
                'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY', 'test_key_for_mock'),
                'FALKORDB_URI': os.environ.get('FALKORDB_URI', 'redis://localhost:6379'),
            },
        )
        stack = AsyncExitStack()
        try:
            read, write = await stack.enter_async_context(stdio_client(server_params))
            # Per the MCP SDK, ClientSession must be entered as an async context
            # manager so its message receive loop runs. Otherwise initialize()
            # never receives a reply.
            self.session = await stack.enter_async_context(ClientSession(read, write))
            await self.session.initialize()
            # Wait for server to be fully ready
            await asyncio.sleep(2)
        except BaseException:
            # Don't leave the server subprocess running if startup fails or is cancelled.
            await stack.aclose()
            raise
        self._exit_stack = stack
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the session, then the stdio transport (reverse order of opening)."""
        stack, self._exit_stack = self._exit_stack, None
        if stack is not None:
            await stack.__aexit__(exc_type, exc_val, exc_tb)

    async def call_tool_with_metrics(
        self, tool_name: str, arguments: dict[str, Any], timeout: float | None = None
    ) -> tuple[Any, TestMetrics]:
        """Call a tool and capture performance metrics.

        Args:
            tool_name: Name of the MCP tool to invoke.
            arguments: Tool arguments.
            timeout: Seconds to wait for the call; ``None`` uses ``self.default_timeout``.

        Returns:
            ``(content, metric)``, where ``content`` is the text of the first content
            item, or ``None`` if there is none or the call failed. ``metric`` is the
            :class:`TestMetrics` entry that was also appended to ``self.metrics``.
        """
        start_time = time.time()
        if timeout is None:
            timeout = self.default_timeout

        try:
            result = await asyncio.wait_for(
                self.session.call_tool(tool_name, arguments), timeout=timeout
            )
            content = result.content[0].text if result.content else None
            details = {'result': content, 'tool': tool_name}
            # Tool-level failures come back as a result flagged isError rather than
            # as an exception; they must not be counted as successes.
            success = not getattr(result, 'isError', False)
            if not success:
                details['error'] = content
        except asyncio.TimeoutError:
            content = None
            success = False
            details = {'error': f'Timeout after {timeout}s', 'tool': tool_name}
        except Exception as e:
            content = None
            success = False
            details = {'error': str(e), 'tool': tool_name}

        metric = TestMetrics(
            operation=f'call_{tool_name}',
            start_time=start_time,
            end_time=time.time(),
            success=success,
            details=details,
        )
        self.metrics.append(metric)
        return content, metric

    async def wait_for_episode_processing(
        self, expected_count: int = 1, max_wait: int = 60, poll_interval: int = 2
    ) -> bool:
        """
        Wait for episodes to be processed with intelligent polling.

        Args:
            expected_count: Number of episodes expected to be processed
            max_wait: Maximum seconds to wait
            poll_interval: Seconds between status checks

        Returns:
            True if at least ``expected_count`` episodes were reported for the test
            group before the deadline, False otherwise.
        """
        # Monotonic clock: the deadline is unaffected by wall-clock adjustments.
        deadline = time.monotonic() + max_wait
        while time.monotonic() < deadline:
            result, _ = await self.call_tool_with_metrics(
                'get_episodes', {'group_id': self.test_group_id, 'last_n': 100}
            )
            if result:
                try:
                    payload = json.loads(result) if isinstance(result, str) else result
                    if len(payload.get('episodes') or []) >= expected_count:
                        return True
                except (json.JSONDecodeError, AttributeError, TypeError):
                    # Response not in the expected shape yet; keep polling.
                    pass
            await asyncio.sleep(poll_interval)
        return False
class TestCoreOperations:
    """Test core Graphiti operations."""

    @pytest.mark.asyncio
    async def test_server_initialization(self):
        """Verify server initializes with all required tools."""
        async with GraphitiTestClient() as client:
            tools_result = await client.session.list_tools()
            tools = {tool.name for tool in tools_result.tools}
            required_tools = {
                'add_memory',
                'search_memory_nodes',
                'search_memory_facts',
                'get_episodes',
                'delete_episode',
                'delete_entity_edge',
                'get_entity_edge',
                'clear_graph',
                'get_status',
            }
            missing_tools = required_tools - tools
            assert not missing_tools, f'Missing required tools: {missing_tools}'

    @pytest.mark.asyncio
    async def test_add_text_memory(self):
        """Test adding text-based memories."""
        async with GraphitiTestClient() as client:
            result, metric = await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Tech Conference Notes',
                    'episode_body': 'The AI conference featured talks on LLMs, RAG systems, and knowledge graphs. Notable speakers included researchers from OpenAI and Anthropic.',
                    'source': 'text',
                    'source_description': 'conference notes',
                    'group_id': client.test_group_id,
                },
            )
            assert metric.success, f'Failed to add memory: {metric.details}'
            assert 'queued' in str(result).lower(), f'Unexpected add_memory response: {result}'

            # Wait for processing
            processed = await client.wait_for_episode_processing(expected_count=1)
            assert processed, 'Episode was not processed within timeout'

    @pytest.mark.asyncio
    async def test_add_json_memory(self):
        """Test adding structured JSON memories."""
        async with GraphitiTestClient() as client:
            json_data = {
                'project': {
                    'name': 'GraphitiDB',
                    'version': '2.0.0',
                    'features': ['temporal-awareness', 'hybrid-search', 'custom-entities'],
                },
                'team': {'size': 5, 'roles': ['engineering', 'product', 'research']},
            }
            result, metric = await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Project Data',
                    'episode_body': json.dumps(json_data),
                    'source': 'json',
                    'source_description': 'project database',
                    'group_id': client.test_group_id,
                },
            )
            assert metric.success, f'Failed to add memory: {metric.details}'
            assert 'queued' in str(result).lower(), f'Unexpected add_memory response: {result}'

    @pytest.mark.asyncio
    async def test_add_message_memory(self):
        """Test adding conversation/message memories."""
        async with GraphitiTestClient() as client:
            # One "role: text" turn per line, with no inherited source indentation.
            conversation = (
                'user: What are the key features of Graphiti?\n'
                'assistant: Graphiti offers temporal-aware knowledge graphs, hybrid retrieval, and real-time updates.\n'
                'user: How does it handle entity resolution?\n'
                'assistant: It uses LLM-based entity extraction and deduplication with semantic similarity matching.\n'
            )
            result, metric = await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Feature Discussion',
                    'episode_body': conversation,
                    'source': 'message',
                    'source_description': 'support chat',
                    'group_id': client.test_group_id,
                },
            )
            assert metric.success, f'Failed to add memory: {metric.details}'
            assert 'queued' in str(result).lower(), f'Unexpected add_memory response: {result}'
            # add_memory only queues the episode, so the call itself should return quickly.
            assert metric.duration < 5, f'Add memory took too long: {metric.duration}s'
class TestSearchOperations:
    """Test search and retrieval operations.

    Each test seeds its own group and checks that seeding succeeded and was
    processed before querying. Otherwise a search over an empty group could pass
    vacuously.
    """

    @pytest.mark.asyncio
    async def test_search_nodes_semantic(self):
        """Test semantic search for nodes."""
        async with GraphitiTestClient() as client:
            _, add_metric = await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Product Launch',
                    'episode_body': 'Our new AI assistant product launches in Q2 2024 with advanced NLP capabilities.',
                    'source': 'text',
                    'source_description': 'product roadmap',
                    'group_id': client.test_group_id,
                },
            )
            assert add_metric.success, f'Failed to seed memory: {add_metric.details}'
            processed = await client.wait_for_episode_processing()
            assert processed, 'Seed episode was not processed within timeout'

            result, metric = await client.call_tool_with_metrics(
                'search_memory_nodes',
                {'query': 'AI product features', 'group_id': client.test_group_id, 'limit': 10},
            )
            assert metric.success, f'search_memory_nodes failed: {metric.details}'
            assert result is not None

    @pytest.mark.asyncio
    async def test_search_facts_with_filters(self):
        """Test fact search with various filters."""
        async with GraphitiTestClient() as client:
            _, add_metric = await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Company Facts',
                    'episode_body': 'Acme Corp was founded in 2020. They have 50 employees and $10M in revenue.',
                    'source': 'text',
                    'source_description': 'company profile',
                    'group_id': client.test_group_id,
                },
            )
            assert add_metric.success, f'Failed to seed memory: {add_metric.details}'
            processed = await client.wait_for_episode_processing()
            assert processed, 'Seed episode was not processed within timeout'

            # Search with date filter
            result, metric = await client.call_tool_with_metrics(
                'search_memory_facts',
                {
                    'query': 'company information',
                    'group_id': client.test_group_id,
                    'created_after': '2020-01-01T00:00:00Z',
                    'limit': 20,
                },
            )
            assert metric.success, f'search_memory_facts failed: {metric.details}'

    @pytest.mark.asyncio
    async def test_hybrid_search(self):
        """Test hybrid search combining semantic and keyword search."""
        async with GraphitiTestClient() as client:
            shared = {
                'source': 'text',
                'source_description': 'documentation',
                'group_id': client.test_group_id,
            }
            test_memories = [
                {
                    **shared,
                    'name': 'Technical Doc',
                    'episode_body': 'GraphQL API endpoints support pagination, filtering, and real-time subscriptions.',
                },
                {
                    **shared,
                    'name': 'Architecture',
                    'episode_body': 'The system uses Neo4j for graph storage and OpenAI embeddings for semantic search.',
                },
            ]
            for memory in test_memories:
                _, add_metric = await client.call_tool_with_metrics('add_memory', memory)
                assert add_metric.success, (
                    f'Failed to seed memory {memory["name"]!r}: {add_metric.details}'
                )

            processed = await client.wait_for_episode_processing(expected_count=len(test_memories))
            assert processed, 'Seed episodes were not processed within timeout'

            # Test semantic + keyword search
            result, metric = await client.call_tool_with_metrics(
                'search_memory_nodes',
                {'query': 'Neo4j graph database', 'group_id': client.test_group_id, 'limit': 10},
            )
            assert metric.success, f'search_memory_nodes failed: {metric.details}'
class TestEpisodeManagement:
    """Listing and deleting episodes through the MCP tools."""

    @pytest.mark.asyncio
    async def test_get_episodes_pagination(self):
        """Add five episodes and check that ``last_n=3`` returns at most three."""
        async with GraphitiTestClient() as client:
            group = client.test_group_id
            for idx in range(5):
                await client.call_tool_with_metrics(
                    'add_memory',
                    {
                        'name': f'Episode {idx}',
                        'episode_body': f'This is test episode number {idx}',
                        'source': 'text',
                        'source_description': 'test',
                        'group_id': group,
                    },
                )
            await client.wait_for_episode_processing(expected_count=5)

            page, metric = await client.call_tool_with_metrics(
                'get_episodes', {'group_id': group, 'last_n': 3}
            )
            assert metric.success
            # The tool may hand back either raw JSON text or an already-decoded object.
            if isinstance(page, str):
                page = json.loads(page)
            assert len(page.get('episodes', [])) <= 3

    @pytest.mark.asyncio
    async def test_delete_episode(self):
        """Add one episode, look up its UUID via ``get_episodes``, then delete it."""
        async with GraphitiTestClient() as client:
            group = client.test_group_id
            await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'To Delete',
                    'episode_body': 'This episode will be deleted',
                    'source': 'text',
                    'source_description': 'test',
                    'group_id': group,
                },
            )
            await client.wait_for_episode_processing()

            latest, _ = await client.call_tool_with_metrics(
                'get_episodes', {'group_id': group, 'last_n': 1}
            )
            if isinstance(latest, str):
                latest = json.loads(latest)
            target_uuid = latest['episodes'][0]['uuid']

            outcome, metric = await client.call_tool_with_metrics(
                'delete_episode', {'episode_uuid': target_uuid}
            )
            assert metric.success
            assert 'deleted' in str(outcome).lower()
class TestEntityAndEdgeOperations:
    """Test entity and edge management."""

    @pytest.mark.asyncio
    async def test_get_entity_edge(self):
        """Ingest a relationship-rich episode and check its entities are searchable.

        Fetching a specific edge requires a valid edge UUID, which this test
        does not obtain yet. For now it verifies that the node search over the
        ingested data succeeds. Previously the search outcome was discarded, so
        the test could never fail.
        """
        async with GraphitiTestClient() as client:
            # Add data to create entities and edges
            await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Relationship Data',
                    'episode_body': 'Alice works at TechCorp. Bob is the CEO of TechCorp.',
                    'source': 'text',
                    'source_description': 'org chart',
                    'group_id': client.test_group_id,
                },
            )
            await client.wait_for_episode_processing()

            # Search for nodes to get UUIDs
            result, metric = await client.call_tool_with_metrics(
                'search_memory_nodes',
                {'query': 'TechCorp', 'group_id': client.test_group_id, 'limit': 5},
            )
            assert metric.success
            assert result is not None
            # TODO: resolve an edge UUID from the created entities and call
            # get_entity_edge on it once edge creation patterns are settled.

    @pytest.mark.asyncio
    async def test_delete_entity_edge(self):
        """Test deleting entity edges (not implemented yet)."""
        # Skip explicitly so the placeholder shows up as skipped instead of as a
        # passing test that verifies nothing.
        pytest.skip('delete_entity_edge test not implemented yet')
class TestErrorHandling:
    """Test error conditions and edge cases."""

    @pytest.mark.asyncio
    async def test_invalid_tool_arguments(self):
        """Calling ``add_memory`` without its required fields must report an error."""
        async with GraphitiTestClient() as client:
            # Missing required arguments
            result, metric = await client.call_tool_with_metrics(
                'add_memory',
                {'name': 'Incomplete'},  # Missing required fields
            )
            assert not metric.success
            assert 'error' in str(metric.details).lower()

    @pytest.mark.asyncio
    async def test_timeout_handling(self):
        """A very large episode with a 5s timeout either succeeds or fails with a timeout."""
        async with GraphitiTestClient() as client:
            # Simulate a very large episode that might time out
            large_text = 'Large document content. ' * 10000
            result, metric = await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Large Document',
                    'episode_body': large_text,
                    'source': 'text',
                    'source_description': 'large file',
                    'group_id': client.test_group_id,
                },
                timeout=5,  # Short timeout
            )
            # Check if timeout was handled gracefully
            if not metric.success:
                assert 'timeout' in str(metric.details).lower()

    @pytest.mark.asyncio
    async def test_concurrent_operations(self):
        """Launch five ``add_memory`` calls concurrently; at least three must succeed."""
        async with GraphitiTestClient() as client:
            tasks = [
                client.call_tool_with_metrics(
                    'add_memory',
                    {
                        'name': f'Concurrent {i}',
                        'episode_body': f'Concurrent operation {i}',
                        'source': 'text',
                        'source_description': 'concurrent test',
                        'group_id': client.test_group_id,
                    },
                )
                for i in range(5)
            ]
            results = await asyncio.gather(*tasks, return_exceptions=True)
            # With return_exceptions=True a failed call comes back as an exception
            # object rather than a (result, metric) tuple. Count it as a failure
            # instead of unpacking it; unpacking raised TypeError and hid the
            # real error.
            successful = sum(
                1
                for outcome in results
                if not isinstance(outcome, BaseException) and outcome[1].success
            )
            # At least 60% should succeed
            assert successful >= 3, (
                f'Only {successful}/{len(tasks)} concurrent operations succeeded'
            )
class TestPerformance:
    """Test performance characteristics and optimization."""

    @pytest.mark.asyncio
    async def test_latency_metrics(self):
        """Time one call each of add_memory, search_memory_nodes and get_episodes.

        Each duration is printed. get_episodes must finish within 2s and
        search_memory_nodes within 10s; add_memory is timed but not bounded.
        """
        async with GraphitiTestClient() as client:
            operations = [
                (
                    'add_memory',
                    {
                        'name': 'Perf Test',
                        'episode_body': 'Simple text',
                        'source': 'text',
                        'source_description': 'test',
                        'group_id': client.test_group_id,
                    },
                ),
                (
                    'search_memory_nodes',
                    {'query': 'test', 'group_id': client.test_group_id, 'limit': 10},
                ),
                ('get_episodes', {'group_id': client.test_group_id, 'last_n': 10}),
            ]
            for tool_name, args in operations:
                _, metric = await client.call_tool_with_metrics(tool_name, args)
                # Log performance metrics
                print(f'{tool_name}: {metric.duration:.2f}s')
                # Basic latency assertions
                if tool_name == 'get_episodes':
                    assert metric.duration < 2, f'{tool_name} too slow'
                elif tool_name == 'search_memory_nodes':
                    assert metric.duration < 10, f'{tool_name} too slow'

    @pytest.mark.asyncio
    async def test_batch_processing_efficiency(self):
        """Queue a batch of episodes and require under 15s of wall time per item."""
        async with GraphitiTestClient() as client:
            batch_size = 10
            # perf_counter is monotonic and high-resolution. time.time() follows
            # the system clock and can jump mid-measurement.
            start_time = time.perf_counter()
            # Batch add memories
            for i in range(batch_size):
                await client.call_tool_with_metrics(
                    'add_memory',
                    {
                        'name': f'Batch {i}',
                        'episode_body': f'Batch content {i}',
                        'source': 'text',
                        'source_description': 'batch test',
                        'group_id': client.test_group_id,
                    },
                )
            # Wait for all to process
            processed = await client.wait_for_episode_processing(
                expected_count=batch_size,
                max_wait=120,  # Allow more time for batch
            )
            total_time = time.perf_counter() - start_time
            avg_time_per_item = total_time / batch_size
            assert processed, f'Failed to process {batch_size} items'
            assert avg_time_per_item < 15, (
                f'Batch processing too slow: {avg_time_per_item:.2f}s per item'
            )
            # Generate performance report
            print('\nBatch Performance Report:')
            print(f' Total items: {batch_size}')
            print(f' Total time: {total_time:.2f}s')
            print(f' Avg per item: {avg_time_per_item:.2f}s')
class TestDatabaseBackends:
    """Test different database backend configurations."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize('database', ['neo4j', 'falkordb'])
    async def test_database_operations(self, database):
        """Placeholder for running the server against a specific database backend.

        Builds the environment a backend-specific server would need, then skips.
        Launching such a server is not implemented yet.
        """
        env_vars = {
            'DATABASE_PROVIDER': database,
            'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY'),
        }
        if database == 'neo4j':
            env_vars.update(
                {
                    'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
                    'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
                    'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
                }
            )
        elif database == 'falkordb':
            env_vars['FALKORDB_URI'] = os.environ.get('FALKORDB_URI', 'redis://localhost:6379')
        # Skip explicitly: a bare `pass` made both parametrizations report as
        # passing even though nothing was exercised.
        pytest.skip(f'Backend-specific server tests not implemented for {database}')
def generate_test_report(client: GraphitiTestClient) -> str:
    """Summarise the metrics recorded on ``client`` as a printable report.

    The report lists overall totals (count, success rate, mean duration), a
    per-operation breakdown ordered by operation name, and the five slowest
    recorded calls. Returns ``'No metrics collected'`` when the client has no
    metrics.
    """
    recorded = client.metrics
    if not recorded:
        return 'No metrics collected'

    rule = '=' * 60
    op_count = len(recorded)
    ok_count = len([m for m in recorded if m.success])
    mean_duration = sum(m.duration for m in recorded) / op_count

    lines = [
        '\n' + rule,
        'GRAPHITI MCP TEST REPORT',
        rule,
        f'\nTotal Operations: {op_count}',
        f'Successful: {ok_count} ({ok_count / op_count * 100:.1f}%)',
        f'Average Duration: {mean_duration:.2f}s',
        '\nOperation Breakdown:',
    ]

    # Aggregate call count, successes and cumulative duration per operation name.
    per_op: dict[str, dict[str, float]] = {}
    for m in recorded:
        bucket = per_op.setdefault(m.operation, {'count': 0, 'success': 0, 'total_duration': 0})
        bucket['count'] += 1
        if m.success:
            bucket['success'] += 1
        bucket['total_duration'] += m.duration

    for op_name in sorted(per_op):
        bucket = per_op[op_name]
        avg_dur = bucket['total_duration'] / bucket['count']
        success_rate = bucket['success'] / bucket['count'] * 100
        lines.append(
            f' {op_name}: {bucket["count"]} calls, {success_rate:.0f}% success, {avg_dur:.2f}s avg'
        )

    lines.append('\nSlowest Operations:')
    slowest = sorted(recorded, key=lambda m: m.duration, reverse=True)[:5]
    lines.extend(f' {m.operation}: {m.duration:.2f}s' for m in slowest)
    lines.append(rule)
    return '\n'.join(lines)
if __name__ == '__main__':
    # Allow running this module directly. '--asyncio-mode=auto' lets the
    # pytest-asyncio plugin drive the async test methods.
    pytest_args = [__file__, '-v', '--asyncio-mode=auto']
    pytest.main(pytest_args)
```
--------------------------------------------------------------------------------
/graphiti_core/utils/maintenance/edge_operations.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
from datetime import datetime
from time import time
from pydantic import BaseModel
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.edges import (
CommunityEdge,
EntityEdge,
EpisodicEdge,
create_entity_edge_embeddings,
)
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.llm_client.config import ModelSize
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
from graphiti_core.search.search import search
from graphiti_core.search.search_config import SearchResults
from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact
# Fallback relation label. Entity edges revert to it when their custom type
# name is not permitted for the edge's source/target node signature.
DEFAULT_EDGE_NAME = 'RELATES_TO'
# Module-level logger for edge extraction/resolution diagnostics.
logger = logging.getLogger(__name__)
def build_episodic_edges(
    entity_nodes: list[EntityNode],
    episode_uuid: str,
    created_at: datetime,
) -> list[EpisodicEdge]:
    """Link an episode to each entity node it mentions.

    Produces one ``EpisodicEdge`` per node. Each edge points from the episode
    (``episode_uuid``) to the node and inherits the node's ``group_id``.
    """
    episodic_edges: list[EpisodicEdge] = []
    for entity in entity_nodes:
        episodic_edges.append(
            EpisodicEdge(
                source_node_uuid=episode_uuid,
                target_node_uuid=entity.uuid,
                created_at=created_at,
                group_id=entity.group_id,
            )
        )
    logger.debug(f'Built episodic edges: {episodic_edges}')
    return episodic_edges
def build_community_edges(
    entity_nodes: list[EntityNode],
    community_node: CommunityNode,
    created_at: datetime,
) -> list[CommunityEdge]:
    """Connect a community node to each of its member entity nodes.

    Every edge points from ``community_node`` to one member and is stamped
    with the community's ``group_id`` and the given ``created_at``.
    """
    community_uuid = community_node.uuid
    community_group = community_node.group_id
    return [
        CommunityEdge(
            source_node_uuid=community_uuid,
            target_node_uuid=member.uuid,
            created_at=created_at,
            group_id=community_group,
        )
        for member in entity_nodes
    ]
def _parse_llm_datetime(value: str | None, field_name: str) -> datetime | None:
    """Parse an ISO-8601 timestamp produced by the LLM into a UTC datetime.

    A trailing ``Z`` is treated as UTC. Returns ``None`` when ``value`` is
    empty or cannot be parsed; parse failures are logged as warnings.
    """
    if not value:
        return None
    try:
        return ensure_utc(datetime.fromisoformat(value.replace('Z', '+00:00')))
    except ValueError as e:
        logger.warning(f'WARNING: Error parsing {field_name} date: {e}. Input: {value}')
        return None


async def extract_edges(
    clients: GraphitiClients,
    episode: EpisodicNode,
    nodes: list[EntityNode],
    previous_episodes: list[EpisodicNode],
    edge_type_map: dict[tuple[str, str], list[str]],
    group_id: str = '',
    edge_types: dict[str, type[BaseModel]] | None = None,
) -> list[EntityEdge]:
    """Extract relationship facts between ``nodes`` from an episode using the LLM.

    Parameters
    ----------
    clients : GraphitiClients
        Provides the LLM client used for extraction (and optional reflexion).
    episode : EpisodicNode
        Episode whose content is mined for facts. Its ``valid_at`` is passed as
        the reference time, and its uuid is recorded on every new edge.
    nodes : list[EntityNode]
        Candidate endpoints; the LLM refers to them by list index.
    previous_episodes : list[EpisodicNode]
        Earlier episodes whose content is passed to the LLM as context.
    edge_type_map : dict[tuple[str, str], list[str]]
        Maps (source label, target label) signatures to custom edge type names.
    group_id : str
        Group id stamped on the new edges and forwarded to the LLM client.
    edge_types : dict[str, type[BaseModel]] | None
        Custom edge type models; each model's docstring is sent as its description.

    Returns
    -------
    list[EntityEdge]
        New, unresolved edges. Facts that are blank or reference out-of-range
        entity ids are dropped. Returns ``[]`` without calling the LLM when
        ``nodes`` is empty.
    """
    if not nodes:
        # With no candidate endpoints no extracted fact could be anchored to an
        # entity, so skip the LLM calls altogether.
        logger.warning('No entities provided for edge extraction')
        return []

    start = time()
    extract_edges_max_tokens = 16384
    llm_client = clients.llm_client

    # Invert edge_type_map so each custom type name maps to its (source, target)
    # label signature. The loop variable is `type_names` so it does not read as
    # the `edge_types` parameter.
    edge_type_signature_map: dict[str, tuple[str, str]] = {
        type_name: signature
        for signature, type_names in edge_type_map.items()
        for type_name in type_names
    }

    edge_types_context = (
        [
            {
                'fact_type_name': type_name,
                'fact_type_signature': edge_type_signature_map.get(type_name, ('Entity', 'Entity')),
                'fact_type_description': type_model.__doc__,
            }
            for type_name, type_model in edge_types.items()
        ]
        if edge_types is not None
        else []
    )

    # Prepare context for LLM; nodes are referenced by their list index.
    context = {
        'episode_content': episode.content,
        'nodes': [
            {'id': idx, 'name': node.name, 'entity_types': node.labels}
            for idx, node in enumerate(nodes)
        ],
        'previous_episodes': [ep.content for ep in previous_episodes],
        'reference_time': episode.valid_at,
        'edge_types': edge_types_context,
        'custom_prompt': '',
    }

    reflexion_iterations = 0
    while True:
        llm_response = await llm_client.generate_response(
            prompt_library.extract_edges.edge(context),
            response_model=ExtractedEdges,
            max_tokens=extract_edges_max_tokens,
            group_id=group_id,
            prompt_name='extract_edges.edge',
        )
        edges_data = ExtractedEdges(**llm_response).edges
        context['extracted_facts'] = [edge_data.fact for edge_data in edges_data]

        reflexion_iterations += 1
        # Re-extracting only helps after a reflexion pass supplies new hints.
        # Once the reflexion budget is spent, keep the current extraction
        # instead of issuing another call with no new feedback.
        if reflexion_iterations >= MAX_REFLEXION_ITERATIONS:
            break

        reflexion_response = await llm_client.generate_response(
            prompt_library.extract_edges.reflexion(context),
            response_model=MissingFacts,
            max_tokens=extract_edges_max_tokens,
            group_id=group_id,
            prompt_name='extract_edges.reflexion',
        )
        missing_facts = reflexion_response.get('missing_facts', [])
        if not missing_facts:
            break

        # Feed the missed facts back into the next extraction attempt.
        context['custom_prompt'] = (
            'The following facts were missed in a previous extraction: '
            + ''.join(f'\n{fact},' for fact in missing_facts)
        )

    end = time()
    logger.debug(f'Extracted new edges: {edges_data} in {(end - start) * 1000} ms')

    if not edges_data:
        return []

    # Convert the extracted data into EntityEdge objects
    edges: list[EntityEdge] = []
    for edge_data in edges_data:
        # Filter out empty edges
        if not edge_data.fact.strip():
            continue

        source_node_idx = edge_data.source_entity_id
        target_node_idx = edge_data.target_entity_id
        if not (0 <= source_node_idx < len(nodes) and 0 <= target_node_idx < len(nodes)):
            logger.warning(
                f'Invalid entity IDs in edge extraction for {edge_data.relation_type}. '
                f'source_entity_id: {source_node_idx}, target_entity_id: {target_node_idx}, '
                f'but only {len(nodes)} entities available (valid range: 0-{len(nodes) - 1})'
            )
            continue

        # Validate Edge Date information
        valid_at_datetime = _parse_llm_datetime(edge_data.valid_at, 'valid_at')
        invalid_at_datetime = _parse_llm_datetime(edge_data.invalid_at, 'invalid_at')

        edge = EntityEdge(
            source_node_uuid=nodes[source_node_idx].uuid,
            target_node_uuid=nodes[target_node_idx].uuid,
            name=edge_data.relation_type,
            group_id=group_id,
            fact=edge_data.fact,
            episodes=[episode.uuid],
            created_at=utc_now(),
            valid_at=valid_at_datetime,
            invalid_at=invalid_at_datetime,
        )
        edges.append(edge)
        logger.debug(
            f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
        )

    logger.debug(f'Extracted edges: {[(e.name, e.uuid) for e in edges]}')
    return edges
def _dedupe_exact_edges(edges: list[EntityEdge]) -> list[EntityEdge]:
    """Keep the first of any edges sharing endpoints and exact-normalized fact text."""
    seen: set[tuple[str, str, str]] = set()
    unique_edges: list[EntityEdge] = []
    for edge in edges:
        key = (
            edge.source_node_uuid,
            edge.target_node_uuid,
            _normalize_string_exact(edge.fact),
        )
        if key in seen:
            continue
        seen.add(key)
        unique_edges.append(edge)
    return unique_edges


def _allowed_edge_types_for(
    edge: EntityEdge,
    uuid_entity_map: dict[str, EntityNode],
    edge_types: dict[str, type[BaseModel]],
    edge_type_map: dict[tuple[str, str], list[str]],
) -> dict[str, type[BaseModel]]:
    """Return the custom edge types whose (source, target) label signature fits *edge*.

    Both endpoints always carry the generic ``'Entity'`` label. An endpoint
    missing from ``uuid_entity_map`` is treated as a plain ``'Entity'``.
    """
    source_node = uuid_entity_map.get(edge.source_node_uuid)
    target_node = uuid_entity_map.get(edge.target_node_uuid)
    source_labels = source_node.labels + ['Entity'] if source_node is not None else ['Entity']
    target_labels = target_node.labels + ['Entity'] if target_node is not None else ['Entity']

    allowed: dict[str, type[BaseModel]] = {}
    for source_label in source_labels:
        for target_label in target_labels:
            for type_name in edge_type_map.get((source_label, target_label), []):
                type_model = edge_types.get(type_name)
                if type_model is not None:
                    allowed[type_name] = type_model
    return allowed


async def resolve_extracted_edges(
    clients: GraphitiClients,
    extracted_edges: list[EntityEdge],
    episode: EpisodicNode,
    entities: list[EntityNode],
    edge_types: dict[str, type[BaseModel]],
    edge_type_map: dict[tuple[str, str], list[str]],
) -> tuple[list[EntityEdge], list[EntityEdge]]:
    """Deduplicate, type-check and resolve freshly extracted edges against the graph.

    Steps:
    1. Collapse exact duplicates within ``extracted_edges`` (same endpoints and
       normalized fact).
    2. Embed the edges. Then, per edge, gather duplicate candidates (a hybrid
       search restricted to existing edges between the same endpoints) and
       invalidation candidates (an unrestricted hybrid search in the edge's
       group).
    3. Reset custom edge names that are not permitted for the edge's node
       signature to ``DEFAULT_EDGE_NAME``. Non-custom (LLM generated) labels
       are kept.
    4. Resolve each edge with ``resolve_extracted_edge``, then re-embed the
       resolved and invalidated edges.

    ``edge_types`` may be empty or ``None``; both mean no custom types.

    Returns
    -------
    tuple[list[EntityEdge], list[EntityEdge]]
        ``(resolved_edges, invalidated_edges)``.
    """
    # A missing mapping is treated as "no custom edge types". Previously
    # `edge_types.get` crashed on None even though the type-name set tolerated it.
    edge_types = edge_types or {}

    # Fast path: deduplicate exact matches within the extracted edges before parallel processing
    extracted_edges = _dedupe_exact_edges(extracted_edges)

    driver = clients.driver
    llm_client = clients.llm_client
    embedder = clients.embedder
    await create_entity_edge_embeddings(embedder, extracted_edges)

    valid_edges_list: list[list[EntityEdge]] = await semaphore_gather(
        *[
            EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
            for edge in extracted_edges
        ]
    )

    related_edges_results: list[SearchResults] = await semaphore_gather(
        *[
            search(
                clients,
                extracted_edge.fact,
                group_ids=[extracted_edge.group_id],
                config=EDGE_HYBRID_SEARCH_RRF,
                search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
            )
            for extracted_edge, valid_edges in zip(extracted_edges, valid_edges_list, strict=True)
        ]
    )
    related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results]

    edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather(
        *[
            search(
                clients,
                extracted_edge.fact,
                group_ids=[extracted_edge.group_id],
                config=EDGE_HYBRID_SEARCH_RRF,
                search_filter=SearchFilters(),
            )
            for extracted_edge in extracted_edges
        ]
    )
    edge_invalidation_candidates: list[list[EntityEdge]] = [
        result.edges for result in edge_invalidation_candidate_results
    ]

    logger.debug(
        f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
    )

    # Build entity hash table
    uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}

    # `edge_types_lst` holds, per extracted edge, the custom edge definitions
    # whose node signature matches that edge. A registered custom name outside
    # that subset falls back to the default label; non-custom (LLM generated)
    # labels are left untouched.
    custom_type_names = set(edge_types)
    edge_types_lst: list[dict[str, type[BaseModel]]] = []
    for extracted_edge in extracted_edges:
        allowed_types = _allowed_edge_types_for(
            extracted_edge, uuid_entity_map, edge_types, edge_type_map
        )
        edge_types_lst.append(allowed_types)
        if extracted_edge.name in custom_type_names and extracted_edge.name not in allowed_types:
            extracted_edge.name = DEFAULT_EDGE_NAME

    # resolve edges with related edges in the graph and find invalidation candidates
    results: list[tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]] = list(
        await semaphore_gather(
            *[
                resolve_extracted_edge(
                    llm_client,
                    extracted_edge,
                    related_edges,
                    existing_edges,
                    episode,
                    extracted_edge_types,
                    custom_type_names,
                )
                for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
                    extracted_edges,
                    related_edges_lists,
                    edge_invalidation_candidates,
                    edge_types_lst,
                    strict=True,
                )
            ]
        )
    )

    resolved_edges: list[EntityEdge] = []
    invalidated_edges: list[EntityEdge] = []
    for resolved_edge, invalidated_edge_chunk, _duplicates in results:
        resolved_edges.append(resolved_edge)
        invalidated_edges.extend(invalidated_edge_chunk)

    logger.debug(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')

    await semaphore_gather(
        create_entity_edge_embeddings(embedder, resolved_edges),
        create_entity_edge_embeddings(embedder, invalidated_edges),
    )

    return resolved_edges, invalidated_edges
def resolve_edge_contradictions(
    resolved_edge: EntityEdge, invalidation_candidates: list[EntityEdge]
) -> list[EntityEdge]:
    """Expire the candidate edges that ``resolved_edge`` supersedes.

    A candidate is skipped when its validity window does not overlap the
    resolved edge's window. That happens when the candidate ended on or before
    the resolved edge began, or the resolved edge ended on or before the
    candidate began.

    Otherwise, if the candidate became valid strictly before the resolved
    edge, it is mutated in place. Its ``invalid_at`` is set to the resolved
    edge's ``valid_at``, and ``expired_at`` is stamped with the current time
    unless already set. Candidates missing the timestamps needed for these
    comparisons are left untouched.

    Returns the candidates that were expired, in input order.
    """
    expired: list[EntityEdge] = []
    if not invalidation_candidates:
        return expired

    for candidate in invalidation_candidates:
        cand_invalid = ensure_utc(candidate.invalid_at)
        new_valid = ensure_utc(resolved_edge.valid_at)
        cand_valid = ensure_utc(candidate.valid_at)
        new_invalid = ensure_utc(resolved_edge.invalid_at)

        # Candidate stopped being valid no later than the new edge started.
        if cand_invalid is not None and new_valid is not None and cand_invalid <= new_valid:
            continue
        # New edge stops being valid no later than the candidate started.
        if cand_valid is not None and new_invalid is not None and new_invalid <= cand_valid:
            continue
        # Only a candidate that strictly predates the new edge is superseded by it.
        if cand_valid is None or new_valid is None or cand_valid >= new_valid:
            continue

        candidate.invalid_at = resolved_edge.valid_at
        if candidate.expired_at is None:
            candidate.expired_at = utc_now()
        expired.append(candidate)

    return expired
async def resolve_extracted_edge(
    llm_client: LLMClient,
    extracted_edge: EntityEdge,
    related_edges: list[EntityEdge],
    existing_edges: list[EntityEdge],
    episode: EpisodicNode,
    edge_type_candidates: dict[str, type[BaseModel]] | None = None,
    custom_edge_type_names: set[str] | None = None,
) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
    """Resolve an extracted edge against existing graph context.

    Parameters
    ----------
    llm_client : LLMClient
        Client used to invoke the LLM for deduplication and attribute extraction.
    extracted_edge : EntityEdge
        Newly extracted edge whose canonical representation is being resolved.
    related_edges : list[EntityEdge]
        Candidate edges with identical endpoints used for duplicate detection.
    existing_edges : list[EntityEdge]
        Broader set of edges evaluated for contradiction / invalidation.
    episode : EpisodicNode
        Episode providing content context when extracting edge attributes.
    edge_type_candidates : dict[str, type[BaseModel]] | None
        Custom edge types permitted for the current source/target signature.
    custom_edge_type_names : set[str] | None
        Full catalog of registered custom edge names. Used to distinguish
        between disallowed custom types (which fall back to the default label)
        and ad-hoc labels emitted by the LLM.

    Returns
    -------
    tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]
        ``(resolved_edge, invalidated_edges, duplicate_edges)``.
        - ``resolved_edge`` is either ``extracted_edge`` or the first duplicate
          reused from ``related_edges``.
        - ``invalidated_edges`` are the existing edges it expires.
        - ``duplicate_edges`` are the related edges the LLM matched as duplicates.
    """
    if not related_edges and not existing_edges:
        return extracted_edge, [], []

    # Fast path: if the fact text and endpoints already exist verbatim, reuse the matching edge.
    normalized_fact = _normalize_string_exact(extracted_edge.fact)
    for edge in related_edges:
        if (
            edge.source_node_uuid == extracted_edge.source_node_uuid
            and edge.target_node_uuid == extracted_edge.target_node_uuid
            and _normalize_string_exact(edge.fact) == normalized_fact
        ):
            resolved = edge
            if episode is not None and episode.uuid not in resolved.episodes:
                resolved.episodes.append(episode.uuid)
            return resolved, [], []

    start = time()

    # Prepare context for LLM; candidates are referenced by their list index.
    related_edges_context = [{'idx': i, 'fact': edge.fact} for i, edge in enumerate(related_edges)]

    invalidation_edge_candidates_context = [
        {'idx': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
    ]

    edge_types_context = (
        [
            {
                'fact_type_name': type_name,
                'fact_type_description': type_model.__doc__,
            }
            for type_name, type_model in edge_type_candidates.items()
        ]
        if edge_type_candidates is not None
        else []
    )

    context = {
        'existing_edges': related_edges_context,
        'new_edge': extracted_edge.fact,
        'edge_invalidation_candidates': invalidation_edge_candidates_context,
        'edge_types': edge_types_context,
    }

    # At least one of the lists is non-empty here (guarded by the early return above).
    logger.debug(
        'Resolving edge: sent %d EXISTING FACTS%s and %d INVALIDATION CANDIDATES%s',
        len(related_edges),
        f' (idx 0-{len(related_edges) - 1})' if related_edges else '',
        len(existing_edges),
        f' (idx 0-{len(existing_edges) - 1})' if existing_edges else '',
    )

    llm_response = await llm_client.generate_response(
        prompt_library.dedupe_edges.resolve_edge(context),
        response_model=EdgeDuplicate,
        model_size=ModelSize.small,
        prompt_name='dedupe_edges.resolve_edge',
    )
    response_object = EdgeDuplicate(**llm_response)
    duplicate_facts = response_object.duplicate_facts

    # Validate duplicate_facts are in valid range for EXISTING FACTS
    invalid_duplicates = [i for i in duplicate_facts if i < 0 or i >= len(related_edges)]
    if invalid_duplicates:
        logger.warning(
            'LLM returned invalid duplicate_facts idx values %s (valid range: 0-%d for EXISTING FACTS)',
            invalid_duplicates,
            len(related_edges) - 1,
        )

    duplicate_fact_ids: list[int] = [i for i in duplicate_facts if 0 <= i < len(related_edges)]

    resolved_edge = extracted_edge
    if duplicate_fact_ids:
        # Reuse the first matching existing edge as the canonical edge.
        resolved_edge = related_edges[duplicate_fact_ids[0]]
        # Record this episode on it only once (mirrors the exact-match fast
        # path). Previously the uuid was appended unconditionally.
        if episode is not None and episode.uuid not in resolved_edge.episodes:
            resolved_edge.episodes.append(episode.uuid)

    contradicted_facts: list[int] = response_object.contradicted_facts

    # Validate contradicted_facts are in valid range for INVALIDATION CANDIDATES
    invalid_contradictions = [i for i in contradicted_facts if i < 0 or i >= len(existing_edges)]
    if invalid_contradictions:
        logger.warning(
            'LLM returned invalid contradicted_facts idx values %s (valid range: 0-%d for INVALIDATION CANDIDATES)',
            invalid_contradictions,
            len(existing_edges) - 1,
        )

    invalidation_candidates: list[EntityEdge] = [
        existing_edges[i] for i in contradicted_facts if 0 <= i < len(existing_edges)
    ]

    fact_type: str = response_object.fact_type
    candidate_type_names = set(edge_type_candidates or {})
    custom_type_names = custom_edge_type_names or set()

    is_default_type = fact_type.upper() == 'DEFAULT'
    is_custom_type = fact_type in custom_type_names
    is_allowed_custom_type = fact_type in candidate_type_names

    if is_allowed_custom_type:
        # The LLM selected a custom type that is allowed for the node pair.
        # Adopt the custom type and, if needed, extract its structured attributes.
        resolved_edge.name = fact_type

        edge_attributes_context = {
            'episode_content': episode.content,
            'reference_time': episode.valid_at,
            'fact': resolved_edge.fact,
        }

        edge_model = edge_type_candidates.get(fact_type) if edge_type_candidates else None
        if edge_model is not None and len(edge_model.model_fields) != 0:
            edge_attributes_response = await llm_client.generate_response(
                prompt_library.extract_edges.extract_attributes(edge_attributes_context),
                response_model=edge_model,  # type: ignore
                model_size=ModelSize.small,
                prompt_name='extract_edges.extract_attributes',
            )

            resolved_edge.attributes = edge_attributes_response
    elif not is_default_type and is_custom_type:
        # The LLM picked a custom type that is not allowed for this signature.
        # Reset to the default label and drop any structured attributes.
        resolved_edge.name = DEFAULT_EDGE_NAME
        resolved_edge.attributes = {}
    elif not is_default_type:
        # Non-custom labels are allowed to pass through so long as the LLM does
        # not return the sentinel DEFAULT value.
        resolved_edge.name = fact_type
        resolved_edge.attributes = {}

    end = time()
    logger.debug(
        f'Resolved Edge: {extracted_edge.name} is {resolved_edge.name}, in {(end - start) * 1000} ms'
    )

    now = utc_now()

    if resolved_edge.invalid_at and not resolved_edge.expired_at:
        resolved_edge.expired_at = now

    # Determine if the new_edge needs to be expired: a contradicting candidate
    # that became valid later than the new edge ends the new edge's validity.
    if resolved_edge.expired_at is None:
        # Earliest valid_at first; candidates without valid_at sort last.
        invalidation_candidates.sort(key=lambda c: (c.valid_at is None, ensure_utc(c.valid_at)))
        for candidate in invalidation_candidates:
            candidate_valid_at_utc = ensure_utc(candidate.valid_at)
            resolved_edge_valid_at_utc = ensure_utc(resolved_edge.valid_at)
            if (
                candidate_valid_at_utc is not None
                and resolved_edge_valid_at_utc is not None
                and candidate_valid_at_utc > resolved_edge_valid_at_utc
            ):
                # Expire new edge since we have information about more recent events
                resolved_edge.invalid_at = candidate.valid_at
                resolved_edge.expired_at = now
                break

    # Determine which contradictory edges need to be expired
    invalidated_edges: list[EntityEdge] = resolve_edge_contradictions(
        resolved_edge, invalidation_candidates
    )
    duplicate_edges: list[EntityEdge] = [related_edges[idx] for idx in duplicate_fact_ids]

    return resolved_edge, invalidated_edges, duplicate_edges
async def filter_existing_duplicate_of_edges(
    driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
) -> list[tuple[EntityNode, EntityNode]]:
    """Return the node pairs that are not yet linked by an ``IS_DUPLICATE_OF`` edge.

    Pairs are keyed by ``(source.uuid, target.uuid)``, so repeated pairs collapse to a
    single entry (the last occurrence's node objects are kept). One read query checks
    which pairs already have the edge. Its shape depends on the provider:

    * Neptune receives ``{'source', 'target'}`` maps;
    * Kuzu stores the edge as an intermediate ``RelatesToNode_`` node and receives
      ``{'src', 'dst'}`` maps;
    * every other provider receives ``(source_uuid, target_uuid)`` tuples.

    Args:
        driver: graph driver used to run the lookup query.
        duplicates_node_tuples: candidate ``(source, target)`` duplicate pairs.

    Returns:
        The candidate pairs for which no ``IS_DUPLICATE_OF`` edge exists yet.
    """
    if not duplicates_node_tuples:
        return []

    duplicate_nodes_map = {
        (source.uuid, target.uuid): (source, target) for source, target in duplicates_node_tuples
    }

    # Declared once; each branch assigns the provider-specific query and parameters.
    query: LiteralString
    if driver.provider == GraphProvider.NEPTUNE:
        query = """
            UNWIND $duplicate_node_uuids AS duplicate_tuple
            MATCH (n:Entity {uuid: duplicate_tuple.source})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple.target})
            RETURN DISTINCT
                n.uuid AS source_uuid,
                m.uuid AS target_uuid
        """
        duplicate_node_uuids = [
            {'source': source_uuid, 'target': target_uuid}
            for source_uuid, target_uuid in duplicate_nodes_map
        ]
    elif driver.provider == GraphProvider.KUZU:
        query = """
            UNWIND $duplicate_node_uuids AS duplicate
            MATCH (n:Entity {uuid: duplicate.src})-[:RELATES_TO]->(e:RelatesToNode_ {name: 'IS_DUPLICATE_OF'})-[:RELATES_TO]->(m:Entity {uuid: duplicate.dst})
            RETURN DISTINCT
                n.uuid AS source_uuid,
                m.uuid AS target_uuid
        """
        duplicate_node_uuids = [{'src': src, 'dst': dst} for src, dst in duplicate_nodes_map]
    else:
        query = """
            UNWIND $duplicate_node_uuids AS duplicate_tuple
            MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
            RETURN DISTINCT
                n.uuid AS source_uuid,
                m.uuid AS target_uuid
        """
        duplicate_node_uuids = list(duplicate_nodes_map)

    records, _, _ = await driver.execute_query(
        query,
        duplicate_node_uuids=duplicate_node_uuids,
        routing_='r',
    )

    # Remove duplicates that already have the IS_DUPLICATE_OF edge.
    for record in records:
        duplicate_nodes_map.pop((record.get('source_uuid'), record.get('target_uuid')), None)

    return list(duplicate_nodes_map.values())
```
--------------------------------------------------------------------------------
/graphiti_core/nodes.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from time import time
from typing import Any
from uuid import uuid4
from pydantic import BaseModel, Field
from typing_extensions import LiteralString
from graphiti_core.driver.driver import (
GraphDriver,
GraphProvider,
)
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import NodeNotFoundError
from graphiti_core.helpers import parse_db_date
from graphiti_core.models.nodes.node_db_queries import (
COMMUNITY_NODE_RETURN,
COMMUNITY_NODE_RETURN_NEPTUNE,
EPISODIC_NODE_RETURN,
EPISODIC_NODE_RETURN_NEPTUNE,
get_community_node_save_query,
get_entity_node_return_query,
get_entity_node_save_query,
get_episode_node_save_query,
)
from graphiti_core.utils.datetime_utils import utc_now
logger = logging.getLogger(__name__)


class EpisodeType(Enum):
    """
    Enumeration of different types of episodes that can be processed.

    This enum defines the various sources or formats of episodes that the system
    can handle. It's used to categorize and potentially handle different types
    of input data differently.

    Attributes:
    -----------
    message : str
        Represents a standard message-type episode. The content for this type
        should be formatted as "actor: content". For example, "user: Hello, how are you?"
        or "assistant: I'm doing well, thank you for asking."
    json : str
        Represents an episode containing a JSON string object with structured data.
    text : str
        Represents a plain text episode.
    """

    message = 'message'
    json = 'json'
    text = 'text'

    @staticmethod
    def from_str(episode_type: str) -> 'EpisodeType':
        """Return the member whose value equals ``episode_type``.

        Args:
            episode_type: one of ``'message'``, ``'json'`` or ``'text'``.

        Returns:
            The matching ``EpisodeType`` member.

        Raises:
            NotImplementedError: if ``episode_type`` is not a known episode type. The
                message names the offending value.
        """
        try:
            # Enum lookup by value replaces the hand-written if-chain and stays in sync
            # automatically when members are added.
            return EpisodeType(episode_type)
        except ValueError:
            message = f'Episode type: {episode_type} not implemented'
            logger.error(message)
            raise NotImplementedError(message) from None
class Node(BaseModel, ABC):
uuid: str = Field(default_factory=lambda: str(uuid4()))
name: str = Field(description='name of the node')
group_id: str = Field(description='partition of the graph')
labels: list[str] = Field(default_factory=list)
created_at: datetime = Field(default_factory=lambda: utc_now())
@abstractmethod
async def save(self, driver: GraphDriver): ...
async def delete(self, driver: GraphDriver):
if driver.graph_operations_interface:
return await driver.graph_operations_interface.node_delete(self, driver)
match driver.provider:
case GraphProvider.NEO4J:
records, _, _ = await driver.execute_query(
"""
MATCH (n {uuid: $uuid})
WHERE n:Entity OR n:Episodic OR n:Community
OPTIONAL MATCH (n)-[r]-()
WITH collect(r.uuid) AS edge_uuids, n
DETACH DELETE n
RETURN edge_uuids
""",
uuid=self.uuid,
)
case GraphProvider.KUZU:
for label in ['Episodic', 'Community']:
await driver.execute_query(
f"""
MATCH (n:{label} {{uuid: $uuid}})
DETACH DELETE n
""",
uuid=self.uuid,
)
# Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
# Explicitly delete the "edge" nodes first, then the entity node.
await driver.execute_query(
"""
MATCH (n:Entity {uuid: $uuid})-[:RELATES_TO]->(e:RelatesToNode_)
DETACH DELETE e
""",
uuid=self.uuid,
)
await driver.execute_query(
"""
MATCH (n:Entity {uuid: $uuid})
DETACH DELETE n
""",
uuid=self.uuid,
)
case _: # FalkorDB, Neptune
for label in ['Entity', 'Episodic', 'Community']:
await driver.execute_query(
f"""
MATCH (n:{label} {{uuid: $uuid}})
DETACH DELETE n
""",
uuid=self.uuid,
)
logger.debug(f'Deleted Node: {self.uuid}')
def __hash__(self):
return hash(self.uuid)
def __eq__(self, other):
if isinstance(other, Node):
return self.uuid == other.uuid
return False
@classmethod
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100):
if driver.graph_operations_interface:
return await driver.graph_operations_interface.node_delete_by_group_id(
cls, driver, group_id, batch_size
)
match driver.provider:
case GraphProvider.NEO4J:
async with driver.session() as session:
await session.run(
"""
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
CALL (n) {
DETACH DELETE n
} IN TRANSACTIONS OF $batch_size ROWS
""",
group_id=group_id,
batch_size=batch_size,
)
case GraphProvider.KUZU:
for label in ['Episodic', 'Community']:
await driver.execute_query(
f"""
MATCH (n:{label} {{group_id: $group_id}})
DETACH DELETE n
""",
group_id=group_id,
)
# Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
# Explicitly delete the "edge" nodes first, then the entity node.
await driver.execute_query(
"""
MATCH (n:Entity {group_id: $group_id})-[:RELATES_TO]->(e:RelatesToNode_)
DETACH DELETE e
""",
group_id=group_id,
)
await driver.execute_query(
"""
MATCH (n:Entity {group_id: $group_id})
DETACH DELETE n
""",
group_id=group_id,
)
case _: # FalkorDB, Neptune
for label in ['Entity', 'Episodic', 'Community']:
await driver.execute_query(
f"""
MATCH (n:{label} {{group_id: $group_id}})
DETACH DELETE n
""",
group_id=group_id,
)
@classmethod
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
if driver.graph_operations_interface:
return await driver.graph_operations_interface.node_delete_by_uuids(
cls, driver, uuids, group_id=None, batch_size=batch_size
)
match driver.provider:
case GraphProvider.FALKORDB:
for label in ['Entity', 'Episodic', 'Community']:
await driver.execute_query(
f"""
MATCH (n:{label})
WHERE n.uuid IN $uuids
DETACH DELETE n
""",
uuids=uuids,
)
case GraphProvider.KUZU:
for label in ['Episodic', 'Community']:
await driver.execute_query(
f"""
MATCH (n:{label})
WHERE n.uuid IN $uuids
DETACH DELETE n
""",
uuids=uuids,
)
# Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
# Explicitly delete the "edge" nodes first, then the entity node.
await driver.execute_query(
"""
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)
WHERE n.uuid IN $uuids
DETACH DELETE e
""",
uuids=uuids,
)
await driver.execute_query(
"""
MATCH (n:Entity)
WHERE n.uuid IN $uuids
DETACH DELETE n
""",
uuids=uuids,
)
case _: # Neo4J, Neptune
async with driver.session() as session:
# Collect all edge UUIDs before deleting nodes
await session.run(
"""
MATCH (n:Entity|Episodic|Community)
WHERE n.uuid IN $uuids
MATCH (n)-[r]-()
RETURN collect(r.uuid) AS edge_uuids
""",
uuids=uuids,
)
# Now delete the nodes in batches
await session.run(
"""
MATCH (n:Entity|Episodic|Community)
WHERE n.uuid IN $uuids
CALL (n) {
DETACH DELETE n
} IN TRANSACTIONS OF $batch_size ROWS
""",
uuids=uuids,
batch_size=batch_size,
)
@classmethod
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
@classmethod
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): ...
class EpisodicNode(Node):
    """One ingested episode: its raw content, its source, and when it was created.

    ``entity_edges`` lists the uuids of entity edges referenced by this episode.
    """

    source: EpisodeType = Field(description='source type')
    source_description: str = Field(description='description of the data source')
    content: str = Field(description='raw episode data')
    valid_at: datetime = Field(
        description='datetime of when the original document was created',
    )
    entity_edges: list[str] = Field(
        description='list of entity edges referenced in this episode',
        default_factory=list,
    )

    async def save(self, driver: GraphDriver):
        """Persist this episode and return the driver's result.

        Delegates to the driver's graph-operations interface when one is configured.
        """
        operations = driver.graph_operations_interface
        if operations:
            return await operations.episodic_node_save(self, driver)

        save_query = get_episode_node_save_query(driver.provider)
        result = await driver.execute_query(
            save_query,
            uuid=self.uuid,
            name=self.name,
            group_id=self.group_id,
            source_description=self.source_description,
            content=self.content,
            entity_edges=self.entity_edges,
            created_at=self.created_at,
            valid_at=self.valid_at,
            source=self.source.value,
        )
        logger.debug(f'Saved Node to Graph: {self.uuid}')
        return result

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        """Fetch the episode with ``uuid``; raise ``NodeNotFoundError`` if absent."""
        if driver.provider == GraphProvider.NEPTUNE:
            return_clause = EPISODIC_NODE_RETURN_NEPTUNE
        else:
            return_clause = EPISODIC_NODE_RETURN
        match_clause: LiteralString = """
            MATCH (e:Episodic {uuid: $uuid})
            RETURN
        """
        records, _, _ = await driver.execute_query(
            match_clause + return_clause,
            uuid=uuid,
            routing_='r',
        )

        found = [get_episodic_node_from_record(row) for row in records]
        if not found:
            raise NodeNotFoundError(uuid)
        return found[0]

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        """Fetch every episode whose uuid is in ``uuids`` (missing uuids are skipped)."""
        if driver.provider == GraphProvider.NEPTUNE:
            return_clause = EPISODIC_NODE_RETURN_NEPTUNE
        else:
            return_clause = EPISODIC_NODE_RETURN
        match_clause: LiteralString = """
            MATCH (e:Episodic)
            WHERE e.uuid IN $uuids
            RETURN DISTINCT
        """
        records, _, _ = await driver.execute_query(
            match_clause + return_clause,
            uuids=uuids,
            routing_='r',
        )
        return [get_episodic_node_from_record(row) for row in records]

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
    ):
        """Page through episodes in ``group_ids``, newest uuid first.

        ``uuid_cursor`` (when truthy) restricts results to uuids strictly below it.
        ``limit`` caps the page size when given.
        """
        cursor_filter: LiteralString = ''
        if uuid_cursor:
            cursor_filter = 'AND e.uuid < $uuid'
        limit_clause: LiteralString = ''
        if limit is not None:
            limit_clause = 'LIMIT $limit'
        if driver.provider == GraphProvider.NEPTUNE:
            return_clause = EPISODIC_NODE_RETURN_NEPTUNE
        else:
            return_clause = EPISODIC_NODE_RETURN

        # ORDER BY uses the projected `uuid` name: after RETURN DISTINCT, Cypher only
        # allows projected names there (assumes the return clause projects `uuid`).
        query = (
            """
            MATCH (e:Episodic)
            WHERE e.group_id IN $group_ids
            """
            + cursor_filter
            + """
            RETURN DISTINCT
            """
            + return_clause
            + """
            ORDER BY uuid DESC
            """
            + limit_clause
        )
        records, _, _ = await driver.execute_query(
            query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )
        return [get_episodic_node_from_record(row) for row in records]

    @classmethod
    async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str):
        """Fetch the episodes that have a MENTIONS edge to the given entity node."""
        if driver.provider == GraphProvider.NEPTUNE:
            return_clause = EPISODIC_NODE_RETURN_NEPTUNE
        else:
            return_clause = EPISODIC_NODE_RETURN
        match_clause: LiteralString = """
            MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
            RETURN DISTINCT
        """
        records, _, _ = await driver.execute_query(
            match_clause + return_clause,
            entity_node_uuid=entity_node_uuid,
            routing_='r',
        )
        return [get_episodic_node_from_record(row) for row in records]
class EntityNode(Node):
    """An extracted entity with an optional name embedding, a summary and attributes.

    ``attributes`` holds additional properties whose keys depend on the node's labels.
    """

    name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
    summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
    attributes: dict[str, Any] = Field(
        default_factory=dict,
        description='Additional attributes of the node. Dependent on node labels',
    )

    async def generate_name_embedding(self, embedder: EmbedderClient):
        """Embed ``name`` (newlines replaced by spaces), store it on the node and return it."""
        start = time()
        text = self.name.replace('\n', ' ')
        self.name_embedding = await embedder.create(input_data=[text])
        end = time()
        # time() returns seconds; convert so the logged value matches its "ms" unit.
        logger.debug(f'embedded {text} in {(end - start) * 1000} ms')

        return self.name_embedding

    async def load_name_embedding(self, driver: GraphDriver):
        """Load this node's stored name embedding into ``name_embedding``.

        Neptune stores the embedding as a comma-separated string, which the query
        converts back into a list of floats.

        Raises:
            NodeNotFoundError: if no Entity node with this uuid exists.
        """
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.node_load_embeddings(self, driver)

        if driver.provider == GraphProvider.NEPTUNE:
            query: LiteralString = """
                MATCH (n:Entity {uuid: $uuid})
                RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
            """
        else:
            query: LiteralString = """
                MATCH (n:Entity {uuid: $uuid})
                RETURN n.name_embedding AS name_embedding
            """
        records, _, _ = await driver.execute_query(
            query,
            uuid=self.uuid,
            routing_='r',
        )

        if len(records) == 0:
            raise NodeNotFoundError(self.uuid)

        self.name_embedding = records[0]['name_embedding']

    async def save(self, driver: GraphDriver):
        """Persist this entity node and return the driver's result.

        Kuzu receives ``attributes`` as a JSON string and ``labels`` (always including
        ``Entity``) as a parameter. Other providers receive the attributes merged into
        the node properties and the labels embedded in the save query.
        """
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.node_save(self, driver)

        entity_data: dict[str, Any] = {
            'uuid': self.uuid,
            'name': self.name,
            'name_embedding': self.name_embedding,
            'group_id': self.group_id,
            'summary': self.summary,
            'created_at': self.created_at,
        }

        if driver.provider == GraphProvider.KUZU:
            entity_data['attributes'] = json.dumps(self.attributes)
            entity_data['labels'] = list(set(self.labels + ['Entity']))
            result = await driver.execute_query(
                get_entity_node_save_query(driver.provider, labels=''),
                **entity_data,
            )
        else:
            # NOTE(review): attributes are merged last, so an attribute key that matches
            # a core field (e.g. 'name') overrides it.
            entity_data.update(self.attributes or {})
            labels = ':'.join(self.labels + ['Entity'])

            result = await driver.execute_query(
                get_entity_node_save_query(driver.provider, labels),
                entity_data=entity_data,
            )

        logger.debug(f'Saved Node to Graph: {self.uuid}')

        return result

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        """Fetch the entity with ``uuid``; raise ``NodeNotFoundError`` if absent."""
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Entity {uuid: $uuid})
            RETURN
            """
            + get_entity_node_return_query(driver.provider),
            uuid=uuid,
            routing_='r',
        )

        nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

        if len(nodes) == 0:
            raise NodeNotFoundError(uuid)

        return nodes[0]

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        """Fetch every entity whose uuid is in ``uuids`` (missing uuids are skipped)."""
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Entity)
            WHERE n.uuid IN $uuids
            RETURN
            """
            + get_entity_node_return_query(driver.provider),
            uuids=uuids,
            routing_='r',
        )

        nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

        return nodes

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
        with_embeddings: bool = False,
    ):
        """Page through entities in ``group_ids``, ordered by descending uuid.

        ``uuid_cursor`` (when truthy) restricts results to uuids strictly below it.
        ``limit`` caps the page size. ``with_embeddings`` also returns
        ``name_embedding``.
        """
        cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else ''
        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
        with_embeddings_query: LiteralString = (
            """,
            n.name_embedding AS name_embedding
            """
            if with_embeddings
            else ''
        )

        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Entity)
            WHERE n.group_id IN $group_ids
            """
            + cursor_query
            + """
            RETURN
            """
            + get_entity_node_return_query(driver.provider)
            + with_embeddings_query
            + """
            ORDER BY n.uuid DESC
            """
            + limit_query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )

        nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

        return nodes
class CommunityNode(Node):
    """A community of related nodes, with a summary of its members and a name embedding."""

    name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
    summary: str = Field(description='region summary of member nodes', default_factory=str)

    async def save(self, driver: GraphDriver):
        """Persist this community node and return the driver's result.

        On Neptune, the name/uuid/group_id are also written via ``save_to_aoss``
        under the ``'communities'`` key.
        """
        if driver.provider == GraphProvider.NEPTUNE:
            await driver.save_to_aoss(  # pyright: ignore reportAttributeAccessIssue
                'communities',
                [{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
            )
        result = await driver.execute_query(
            get_community_node_save_query(driver.provider),  # type: ignore
            uuid=self.uuid,
            name=self.name,
            group_id=self.group_id,
            summary=self.summary,
            name_embedding=self.name_embedding,
            created_at=self.created_at,
        )

        logger.debug(f'Saved Node to Graph: {self.uuid}')

        return result

    async def generate_name_embedding(self, embedder: EmbedderClient):
        """Embed ``name`` (newlines replaced by spaces), store it on the node and return it."""
        start = time()
        text = self.name.replace('\n', ' ')
        self.name_embedding = await embedder.create(input_data=[text])
        end = time()
        # time() returns seconds; convert so the logged value matches its "ms" unit.
        logger.debug(f'embedded {text} in {(end - start) * 1000} ms')

        return self.name_embedding

    async def load_name_embedding(self, driver: GraphDriver):
        """Load this community's stored name embedding into ``name_embedding``.

        Neptune stores the embedding as a comma-separated string, which the query
        converts back into a list of floats.

        Raises:
            NodeNotFoundError: if no Community node with this uuid exists.
        """
        if driver.provider == GraphProvider.NEPTUNE:
            query: LiteralString = """
                MATCH (c:Community {uuid: $uuid})
                RETURN [x IN split(c.name_embedding, ",") | toFloat(x)] as name_embedding
            """
        else:
            query: LiteralString = """
                MATCH (c:Community {uuid: $uuid})
                RETURN c.name_embedding AS name_embedding
            """

        records, _, _ = await driver.execute_query(
            query,
            uuid=self.uuid,
            routing_='r',
        )

        if len(records) == 0:
            raise NodeNotFoundError(self.uuid)

        self.name_embedding = records[0]['name_embedding']

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        """Fetch the community with ``uuid``; raise ``NodeNotFoundError`` if absent."""
        records, _, _ = await driver.execute_query(
            """
            MATCH (c:Community {uuid: $uuid})
            RETURN
            """
            + (
                COMMUNITY_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else COMMUNITY_NODE_RETURN
            ),
            uuid=uuid,
            routing_='r',
        )

        nodes = [get_community_node_from_record(record) for record in records]

        if len(nodes) == 0:
            raise NodeNotFoundError(uuid)

        return nodes[0]

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        """Fetch every community whose uuid is in ``uuids`` (missing uuids are skipped)."""
        records, _, _ = await driver.execute_query(
            """
            MATCH (c:Community)
            WHERE c.uuid IN $uuids
            RETURN
            """
            + (
                COMMUNITY_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else COMMUNITY_NODE_RETURN
            ),
            uuids=uuids,
            routing_='r',
        )

        communities = [get_community_node_from_record(record) for record in records]

        return communities

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
    ):
        """Page through communities in ``group_ids``, ordered by descending uuid.

        ``uuid_cursor`` (when truthy) restricts results to uuids strictly below it.
        ``limit`` caps the page size when given.
        """
        cursor_query: LiteralString = 'AND c.uuid < $uuid' if uuid_cursor else ''
        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''

        records, _, _ = await driver.execute_query(
            """
            MATCH (c:Community)
            WHERE c.group_id IN $group_ids
            """
            + cursor_query
            + """
            RETURN
            """
            + (
                COMMUNITY_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else COMMUNITY_NODE_RETURN
            )
            + """
            ORDER BY c.uuid DESC
            """
            + limit_query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )

        communities = [get_community_node_from_record(record) for record in records]

        return communities
# Node helpers
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
    """Build an ``EpisodicNode`` from a database record.

    Both timestamps are parsed first; then each is required to be present.

    Raises:
        ValueError: if ``created_at`` or ``valid_at`` parses to ``None``. The message
            includes the record's uuid, or ``"unknown"`` if it has none.
    """
    timestamps = {
        'created_at': parse_db_date(record['created_at']),
        'valid_at': parse_db_date(record['valid_at']),
    }
    for field_name, parsed in timestamps.items():
        if parsed is None:
            raise ValueError(
                f'{field_name} cannot be None for episode {record.get("uuid", "unknown")}'
            )

    return EpisodicNode(
        content=record['content'],
        created_at=timestamps['created_at'],  # type: ignore
        valid_at=timestamps['valid_at'],  # type: ignore
        uuid=record['uuid'],
        group_id=record['group_id'],
        source=EpisodeType.from_str(record['source']),
        name=record['name'],
        source_description=record['source_description'],
        entity_edges=record['entity_edges'],
    )
def get_entity_node_from_record(record: Any, provider: GraphProvider) -> EntityNode:
    """Build an ``EntityNode`` from a database record.

    For Kuzu, ``attributes`` is stored as a JSON string and is decoded. Other
    providers return a mapping. Keys that duplicate core node fields are dropped
    from the attributes. The per-group label ``Entity_<group_id without dashes>``
    is removed from ``labels``.

    The record itself is not modified: its attributes mapping and labels list are
    copied before any key or label is removed.
    """
    if provider == GraphProvider.KUZU:
        raw_attributes = record['attributes']
        attributes = json.loads(raw_attributes) if raw_attributes else {}
    else:
        attributes = dict(record['attributes'] or {})

    for reserved_key in (
        'uuid',
        'name',
        'group_id',
        'name_embedding',
        'summary',
        'created_at',
        'labels',
    ):
        attributes.pop(reserved_key, None)

    labels = list(record.get('labels') or [])
    group_id = record.get('group_id')
    if group_id:
        group_label = 'Entity_' + group_id.replace('-', '')
        if group_label in labels:
            labels.remove(group_label)

    return EntityNode(
        uuid=record['uuid'],
        name=record['name'],
        name_embedding=record.get('name_embedding'),
        group_id=group_id,
        labels=labels,
        created_at=parse_db_date(record['created_at']),  # type: ignore
        summary=record['summary'],
        attributes=attributes,
    )
def get_community_node_from_record(record: Any) -> CommunityNode:
    """Build a ``CommunityNode`` from a database record (``created_at`` is parsed)."""
    node_uuid = record['uuid']
    node_name = record['name']
    node_group = record['group_id']
    embedding = record['name_embedding']
    created = parse_db_date(record['created_at'])
    node_summary = record['summary']

    return CommunityNode(
        uuid=node_uuid,
        name=node_name,
        group_id=node_group,
        name_embedding=embedding,
        created_at=created,  # type: ignore
        summary=node_summary,
    )
async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
    """Set ``name_embedding`` on every node with a non-empty name using one batch call.

    Nodes with a falsy name are skipped and left unchanged. The embedder must return
    exactly one vector per embedded name; otherwise ``zip(strict=True)`` raises
    ``ValueError``.
    """
    named = [candidate for candidate in nodes if candidate.name]
    if not named:
        return

    names = [candidate.name for candidate in named]
    vectors = await embedder.create_batch(names)
    for candidate, vector in zip(named, vectors, strict=True):
        candidate.name_embedding = vector
```
--------------------------------------------------------------------------------
/mcp_server/src/graphiti_mcp_server.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Graphiti MCP Server - Exposes Graphiti functionality through the Model Context Protocol (MCP)
"""
import argparse
import asyncio
import logging
import os
import sys
from pathlib import Path
from typing import Any, Optional
from dotenv import load_dotenv
from graphiti_core import Graphiti
from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EpisodeType, EpisodicNode
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
from mcp.server.fastmcp import FastMCP
from pydantic import BaseModel
from starlette.responses import JSONResponse
from config.schema import GraphitiConfig, ServerConfig
from models.response_types import (
EpisodeSearchResponse,
ErrorResponse,
FactSearchResponse,
NodeResult,
NodeSearchResponse,
StatusResponse,
SuccessResponse,
)
from services.factories import DatabaseDriverFactory, EmbedderFactory, LLMClientFactory
from services.queue_service import QueueService
from utils.formatting import format_fact_result
# Load .env file from mcp_server directory
# (this module lives in mcp_server/src/, so two `.parent` hops reach mcp_server/).
mcp_server_dir = Path(__file__).parent.parent
env_file = mcp_server_dir / '.env'
if env_file.exists():
    load_dotenv(env_file)
else:
    # Try current working directory as fallback
    load_dotenv()
# The os.getenv read below runs after load_dotenv, so values from .env are visible to it.


# Semaphore limit for concurrent Graphiti operations.
#
# This controls how many episodes can be processed simultaneously. Each episode
# processing involves multiple LLM calls (entity extraction, deduplication, etc.),
# so the actual number of concurrent LLM requests will be higher.
#
# TUNING GUIDELINES:
#
# LLM Provider Rate Limits (requests per minute):
# - OpenAI Tier 1 (free): 3 RPM -> SEMAPHORE_LIMIT=1-2
# - OpenAI Tier 2: 60 RPM -> SEMAPHORE_LIMIT=5-8
# - OpenAI Tier 3: 500 RPM -> SEMAPHORE_LIMIT=10-15
# - OpenAI Tier 4: 5,000 RPM -> SEMAPHORE_LIMIT=20-50
# - Anthropic (default): 50 RPM -> SEMAPHORE_LIMIT=5-8
# - Anthropic (high tier): 1,000 RPM -> SEMAPHORE_LIMIT=15-30
# - Azure OpenAI (varies): Consult your quota -> adjust accordingly
#
# NOTE(review): the RPM figures above are not verified here and provider limits
# change over time; confirm against current provider documentation.
#
# SYMPTOMS:
# - Too high: 429 rate limit errors, increased costs from parallel processing
# - Too low: Slow throughput, underutilized API quota
#
# MONITORING:
# - Watch logs for rate limit errors (429)
# - Monitor episode processing times
# - Check LLM provider dashboard for actual request rates
#
# DEFAULT: 10 (suitable for OpenAI Tier 3, mid-tier Anthropic)
# int() raises ValueError at import time if SEMAPHORE_LIMIT is set to a non-integer.
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 10))


# Configure structured logging with timestamps.
# The root handler writes to stderr, presumably to keep stdout free for the MCP
# stdio transport; confirm before changing the stream.
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
DATE_FORMAT = '%Y-%m-%d %H:%M:%S'

logging.basicConfig(
    level=logging.INFO,
    format=LOG_FORMAT,
    datefmt=DATE_FORMAT,
    stream=sys.stderr,
)

# Configure specific loggers
logging.getLogger('uvicorn').setLevel(logging.INFO)
logging.getLogger('uvicorn.access').setLevel(logging.WARNING)  # Reduce access log noise
logging.getLogger('mcp.server.streamable_http_manager').setLevel(
    logging.WARNING
)  # Reduce MCP noise
# Patch uvicorn's logging config to use our format
def configure_uvicorn_logging():
    """Give each uvicorn logger a single stderr handler formatted like the root logger.

    For 'uvicorn', 'uvicorn.error' and 'uvicorn.access': existing handlers are
    removed, a fresh ``StreamHandler(sys.stderr)`` using ``LOG_FORMAT``/``DATE_FORMAT``
    is attached, and propagation to parent loggers is turned off.
    """
    uvicorn_logger_names = ('uvicorn', 'uvicorn.error', 'uvicorn.access')
    for logger_name in uvicorn_logger_names:
        target = logging.getLogger(logger_name)
        target.handlers.clear()

        stderr_handler = logging.StreamHandler(sys.stderr)
        stderr_handler.setFormatter(logging.Formatter(LOG_FORMAT, datefmt=DATE_FORMAT))
        target.addHandler(stderr_handler)

        target.propagate = False
logger = logging.getLogger(__name__)

# Global config: only annotated here; no value is bound at import time, so reading it
# before assignment raises NameError. The original note says it is "properly
# initialized later".
config: GraphitiConfig


# MCP server instructions (passed to FastMCP below as the server's `instructions`).
GRAPHITI_MCP_INSTRUCTIONS = """
Graphiti is a memory service for AI agents built on a knowledge graph. Graphiti performs well
with dynamic data such as user interactions, changing enterprise data, and external information.
Graphiti transforms information into a richly connected knowledge network, allowing you to
capture relationships between concepts, entities, and information. The system organizes data as episodes
(content snippets), nodes (entities), and facts (relationships between entities), creating a dynamic,
queryable memory store that evolves with new information. Graphiti supports multiple data formats, including
structured JSON data, enabling seamless integration with existing data pipelines and systems.
Facts contain temporal metadata, allowing you to track the time of creation and whether a fact is invalid
(superseded by new information).
Key capabilities:
1. Add episodes (text, messages, or JSON) to the knowledge graph with the add_memory tool
2. Search for nodes (entities) in the graph using natural language queries with search_nodes
3. Find relevant facts (relationships between entities) with search_facts
4. Retrieve specific entity edges or episodes by UUID
5. Manage the knowledge graph with tools like delete_episode, delete_entity_edge, and clear_graph
The server connects to a database for persistent storage and uses language models for certain operations.
Each piece of information is organized by group_id, allowing you to maintain separate knowledge domains.
When adding information, provide descriptive names and detailed content to improve search quality.
When searching, use specific queries and consider filtering by group_id for more relevant results.
For optimal performance, ensure the database is properly configured and accessible, and valid
API keys are provided for any language model operations.
"""

# MCP server instance; tools are registered on it with the `@mcp.tool()` decorator.
mcp = FastMCP(
    'Graphiti Agent Memory',
    instructions=GRAPHITI_MCP_INSTRUCTIONS,
)

# Global services
# 'GraphitiService' is a string forward reference because the class is defined
# further down in this module.
graphiti_service: Optional['GraphitiService'] = None
queue_service: QueueService | None = None

# Global client for backward compatibility
graphiti_client: Graphiti | None = None
# Only annotated; no Semaphore is created at import time.
semaphore: asyncio.Semaphore
class GraphitiService:
"""Graphiti service using the unified configuration system."""
def __init__(self, config: GraphitiConfig, semaphore_limit: int = 10):
self.config = config
self.semaphore_limit = semaphore_limit
self.semaphore = asyncio.Semaphore(semaphore_limit)
self.client: Graphiti | None = None
self.entity_types = None
async def initialize(self) -> None:
"""Initialize the Graphiti client with factory-created components."""
try:
# Create clients using factories
llm_client = None
embedder_client = None
# Create LLM client based on configured provider
try:
llm_client = LLMClientFactory.create(self.config.llm)
except Exception as e:
logger.warning(f'Failed to create LLM client: {e}')
# Create embedder client based on configured provider
try:
embedder_client = EmbedderFactory.create(self.config.embedder)
except Exception as e:
logger.warning(f'Failed to create embedder client: {e}')
# Get database configuration
db_config = DatabaseDriverFactory.create_config(self.config.database)
# Build entity types from configuration
custom_types = None
if self.config.graphiti.entity_types:
custom_types = {}
for entity_type in self.config.graphiti.entity_types:
# Create a dynamic Pydantic model for each entity type
# Note: Don't use 'name' as it's a protected Pydantic attribute
entity_model = type(
entity_type.name,
(BaseModel,),
{
'__doc__': entity_type.description,
},
)
custom_types[entity_type.name] = entity_model
# Store entity types for later use
self.entity_types = custom_types
# Initialize Graphiti client with appropriate driver
try:
if self.config.database.provider.lower() == 'falkordb':
# For FalkorDB, create a FalkorDriver instance directly
from graphiti_core.driver.falkordb_driver import FalkorDriver
falkor_driver = FalkorDriver(
host=db_config['host'],
port=db_config['port'],
password=db_config['password'],
database=db_config['database'],
)
self.client = Graphiti(
graph_driver=falkor_driver,
llm_client=llm_client,
embedder=embedder_client,
max_coroutines=self.semaphore_limit,
)
else:
# For Neo4j (default), use the original approach
self.client = Graphiti(
uri=db_config['uri'],
user=db_config['user'],
password=db_config['password'],
llm_client=llm_client,
embedder=embedder_client,
max_coroutines=self.semaphore_limit,
)
except Exception as db_error:
# Check for connection errors
error_msg = str(db_error).lower()
if 'connection refused' in error_msg or 'could not connect' in error_msg:
db_provider = self.config.database.provider
if db_provider.lower() == 'falkordb':
raise RuntimeError(
f'\n{"=" * 70}\n'
f'Database Connection Error: FalkorDB is not running\n'
f'{"=" * 70}\n\n'
f'FalkorDB at {db_config["host"]}:{db_config["port"]} is not accessible.\n\n'
f'To start FalkorDB:\n'
f' - Using Docker Compose: cd mcp_server && docker compose up\n'
f' - Or run FalkorDB manually: docker run -p 6379:6379 falkordb/falkordb\n\n'
f'{"=" * 70}\n'
) from db_error
elif db_provider.lower() == 'neo4j':
raise RuntimeError(
f'\n{"=" * 70}\n'
f'Database Connection Error: Neo4j is not running\n'
f'{"=" * 70}\n\n'
f'Neo4j at {db_config.get("uri", "unknown")} is not accessible.\n\n'
f'To start Neo4j:\n'
f' - Using Docker Compose: cd mcp_server && docker compose -f docker/docker-compose-neo4j.yml up\n'
f' - Or install Neo4j Desktop from: https://neo4j.com/download/\n'
f' - Or run Neo4j manually: docker run -p 7474:7474 -p 7687:7687 neo4j:latest\n\n'
f'{"=" * 70}\n'
) from db_error
else:
raise RuntimeError(
f'\n{"=" * 70}\n'
f'Database Connection Error: {db_provider} is not running\n'
f'{"=" * 70}\n\n'
f'{db_provider} at {db_config.get("uri", "unknown")} is not accessible.\n\n'
f'Please ensure {db_provider} is running and accessible.\n\n'
f'{"=" * 70}\n'
) from db_error
# Re-raise other errors
raise
# Build indices
await self.client.build_indices_and_constraints()
logger.info('Successfully initialized Graphiti client')
# Log configuration details
if llm_client:
logger.info(
f'Using LLM provider: {self.config.llm.provider} / {self.config.llm.model}'
)
else:
logger.info('No LLM client configured - entity extraction will be limited')
if embedder_client:
logger.info(f'Using Embedder provider: {self.config.embedder.provider}')
else:
logger.info('No Embedder client configured - search will be limited')
if self.entity_types:
entity_type_names = list(self.entity_types.keys())
logger.info(f'Using custom entity types: {", ".join(entity_type_names)}')
else:
logger.info('Using default entity types')
logger.info(f'Using database: {self.config.database.provider}')
logger.info(f'Using group_id: {self.config.graphiti.group_id}')
except Exception as e:
logger.error(f'Failed to initialize Graphiti client: {e}')
raise
async def get_client(self) -> Graphiti:
    """Return the Graphiti client, running ``initialize()`` first if no client exists yet.

    Raises:
        RuntimeError: If ``initialize()`` returns without having set ``self.client``.
    """
    if self.client is not None:
        return self.client

    await self.initialize()
    # initialize() is responsible for assigning self.client; refuse to hand back None.
    if self.client is None:
        raise RuntimeError('Failed to initialize Graphiti client')
    return self.client
@mcp.tool()
async def add_memory(
    name: str,
    episode_body: str,
    group_id: str | None = None,
    source: str = 'text',
    source_description: str = '',
    uuid: str | None = None,
) -> SuccessResponse | ErrorResponse:
    """Add an episode to memory. This is the primary way to add information to the graph.

    This function returns immediately and processes the episode addition in the background.
    Episodes for the same group_id are processed sequentially to avoid race conditions.

    Args:
        name (str): Name of the episode
        episode_body (str): The content of the episode to persist to memory. When source='json', this must be a
                           properly escaped JSON string, not a raw Python dictionary. The JSON data will be
                           automatically processed to extract entities and relationships.
        group_id (str, optional): A unique ID for this graph. If not provided, uses the default group_id from CLI
                                 or a generated one.
        source (str, optional): Source type, must be one of:
                               - 'text': For plain text content (default)
                               - 'json': For structured data
                               - 'message': For conversation-style content
        source_description (str, optional): Description of the source
        uuid (str, optional): Optional UUID for the episode

    Examples:
        # Adding plain text content
        add_memory(
            name="Company News",
            episode_body="Acme Corp announced a new product line today.",
            source="text",
            source_description="news article",
            group_id="some_arbitrary_string"
        )

        # Adding structured JSON data
        # NOTE: episode_body should be a JSON string (standard JSON escaping)
        add_memory(
            name="Customer Profile",
            episode_body='{"company": {"name": "Acme Technologies"}, "products": [{"id": "P001", "name": "CloudSync"}, {"id": "P002", "name": "DataMiner"}]}',
            source="json",
            source_description="CRM data"
        )
    """
    # NOTE: the docstring above is published to MCP clients as the tool description,
    # so its wording is part of the tool's interface.
    global graphiti_service, queue_service

    services_ready = graphiti_service is not None and queue_service is not None
    if not services_ready:
        return ErrorResponse(error='Services not initialized')

    try:
        # An explicit group_id wins; otherwise use the configured default.
        target_group = group_id or config.graphiti.group_id

        # Map the free-form source string onto EpisodeType; unknown values degrade to text.
        episode_type = EpisodeType.text
        if source:
            try:
                episode_type = EpisodeType[source.lower()]
            except (KeyError, AttributeError):
                logger.warning(f"Unknown source type '{source}', using 'text' as default")
                episode_type = EpisodeType.text

        # Hand the episode to the queue service; processing happens asynchronously.
        await queue_service.add_episode(
            group_id=target_group,
            name=name,
            content=episode_body,
            source_description=source_description,
            episode_type=episode_type,
            entity_types=graphiti_service.entity_types,
            uuid=uuid if uuid else None,  # normalise empty strings to None
        )
        return SuccessResponse(
            message=f"Episode '{name}' queued for processing in group '{target_group}'"
        )
    except Exception as e:
        reason = str(e)
        logger.error(f'Error queuing episode: {reason}')
        return ErrorResponse(error=f'Error queuing episode: {reason}')
@mcp.tool()
async def search_nodes(
    query: str,
    group_ids: list[str] | None = None,
    max_nodes: int = 10,
    entity_types: list[str] | None = None,
) -> NodeSearchResponse | ErrorResponse:
    """Search for nodes in the graph memory.

    Args:
        query: The search query
        group_ids: Optional list of group IDs to filter results
        max_nodes: Maximum number of nodes to return (default: 10)
        entity_types: Optional list of entity type names to filter by
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        # Validate max_nodes parameter (same contract as search_memory_facts.max_facts).
        # A non-positive value would otherwise be pushed into the search limit below.
        if max_nodes <= 0:
            return ErrorResponse(error='max_nodes must be a positive integer')

        client = await graphiti_service.get_client()

        # Use the provided group_ids or fall back to the default from config if none provided
        effective_group_ids = (
            group_ids
            if group_ids is not None
            else [config.graphiti.group_id]
            if config.graphiti.group_id
            else []
        )

        # Create search filters
        search_filters = SearchFilters(
            node_labels=entity_types,
        )

        from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF

        # The recipe carries its own result `limit`. Passing it unchanged silently caps the
        # result count, so requests for more than that limit were truncated. Work on a deep
        # copy so the shared module-level recipe is never mutated.
        search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
        search_config.limit = max_nodes

        results = await client.search_(
            query=query,
            config=search_config,
            group_ids=effective_group_ids,
            search_filter=search_filters,
        )

        # Extract nodes from results (slice kept as a safety net on the requested size)
        nodes = results.nodes[:max_nodes] if results.nodes else []

        if not nodes:
            return NodeSearchResponse(message='No relevant nodes found', nodes=[])

        # Format the results
        node_results = []
        for node in nodes:
            # Get attributes and ensure no embeddings are included
            attrs = node.attributes if hasattr(node, 'attributes') else {}
            # Remove any embedding keys that might be in attributes
            attrs = {k: v for k, v in attrs.items() if 'embedding' not in k.lower()}

            node_results.append(
                NodeResult(
                    uuid=node.uuid,
                    name=node.name,
                    labels=node.labels if node.labels else [],
                    created_at=node.created_at.isoformat() if node.created_at else None,
                    summary=node.summary,
                    group_id=node.group_id,
                    attributes=attrs,
                )
            )

        return NodeSearchResponse(message='Nodes retrieved successfully', nodes=node_results)
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error searching nodes: {error_msg}')
        return ErrorResponse(error=f'Error searching nodes: {error_msg}')
@mcp.tool()
async def search_memory_facts(
    query: str,
    group_ids: list[str] | None = None,
    max_facts: int = 10,
    center_node_uuid: str | None = None,
) -> FactSearchResponse | ErrorResponse:
    """Search the graph memory for relevant facts.

    Args:
        query: The search query
        group_ids: Optional list of group IDs to filter results
        max_facts: Maximum number of facts to return (default: 10)
        center_node_uuid: Optional UUID of a node to center the search around
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        # Reject non-positive result counts before touching the database.
        if max_facts <= 0:
            return ErrorResponse(error='max_facts must be a positive integer')

        client = await graphiti_service.get_client()

        # Explicit group_ids (even an empty list) win; otherwise the configured default, if any.
        if group_ids is not None:
            effective_group_ids = group_ids
        elif config.graphiti.group_id:
            effective_group_ids = [config.graphiti.group_id]
        else:
            effective_group_ids = []

        edges = await client.search(
            group_ids=effective_group_ids,
            query=query,
            num_results=max_facts,
            center_node_uuid=center_node_uuid,
        )

        if not edges:
            return FactSearchResponse(message='No relevant facts found', facts=[])

        return FactSearchResponse(
            message='Facts retrieved successfully',
            facts=[format_fact_result(edge) for edge in edges],
        )
    except Exception as e:
        reason = str(e)
        logger.error(f'Error searching facts: {reason}')
        return ErrorResponse(error=f'Error searching facts: {reason}')
@mcp.tool()
async def delete_entity_edge(uuid: str) -> SuccessResponse | ErrorResponse:
    """Delete an entity edge from the graph memory.

    Args:
        uuid: UUID of the entity edge to delete
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        driver = (await graphiti_service.get_client()).driver
        # Load the edge object by UUID, then let it remove itself via the same driver.
        edge = await EntityEdge.get_by_uuid(driver, uuid)
        await edge.delete(driver)
        return SuccessResponse(message=f'Entity edge with UUID {uuid} deleted successfully')
    except Exception as e:
        reason = str(e)
        logger.error(f'Error deleting entity edge: {reason}')
        return ErrorResponse(error=f'Error deleting entity edge: {reason}')
@mcp.tool()
async def delete_episode(uuid: str) -> SuccessResponse | ErrorResponse:
    """Delete an episode from the graph memory.

    Args:
        uuid: UUID of the episode to delete
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        driver = (await graphiti_service.get_client()).driver
        # Load the episodic node by UUID, then let it remove itself via the same driver.
        episode = await EpisodicNode.get_by_uuid(driver, uuid)
        await episode.delete(driver)
        return SuccessResponse(message=f'Episode with UUID {uuid} deleted successfully')
    except Exception as e:
        reason = str(e)
        logger.error(f'Error deleting episode: {reason}')
        return ErrorResponse(error=f'Error deleting episode: {reason}')
@mcp.tool()
async def get_entity_edge(uuid: str) -> dict[str, Any] | ErrorResponse:
    """Get an entity edge from the graph memory by its UUID.

    Args:
        uuid: UUID of the entity edge to retrieve
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        driver = (await graphiti_service.get_client()).driver
        edge = await EntityEdge.get_by_uuid(driver, uuid)
        # format_fact_result produces a plain dict; it is returned as-is and the MCP
        # layer takes care of serializing it.
        return format_fact_result(edge)
    except Exception as e:
        reason = str(e)
        logger.error(f'Error getting entity edge: {reason}')
        return ErrorResponse(error=f'Error getting entity edge: {reason}')
@mcp.tool()
async def get_episodes(
    group_ids: list[str] | None = None,
    max_episodes: int = 10,
) -> EpisodeSearchResponse | ErrorResponse:
    """Get episodes from the graph memory.

    Args:
        group_ids: Optional list of group IDs to filter results
        max_episodes: Maximum number of episodes to return (default: 10)
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        # Validate max_episodes parameter (same contract as search_memory_facts.max_facts).
        # Without this, zero/negative values reach the database limit and fail with an
        # opaque driver error instead of a clear message.
        if max_episodes <= 0:
            return ErrorResponse(error='max_episodes must be a positive integer')

        client = await graphiti_service.get_client()

        # Use the provided group_ids or fall back to the default from config if none provided
        effective_group_ids = (
            group_ids
            if group_ids is not None
            else [config.graphiti.group_id]
            if config.graphiti.group_id
            else []
        )

        # Get episodes from the driver directly
        from graphiti_core.nodes import EpisodicNode

        if effective_group_ids:
            episodes = await EpisodicNode.get_by_group_ids(
                client.driver, effective_group_ids, limit=max_episodes
            )
        else:
            # No group to scope by (none passed and no configured default):
            # return an empty result rather than querying every group.
            episodes = []

        if not episodes:
            return EpisodeSearchResponse(message='No episodes found', episodes=[])

        # Serialize each episode to a plain dict; enum-valued sources are reduced to their value.
        episode_results = [
            {
                'uuid': episode.uuid,
                'name': episode.name,
                'content': episode.content,
                'created_at': episode.created_at.isoformat() if episode.created_at else None,
                'source': episode.source.value
                if hasattr(episode.source, 'value')
                else str(episode.source),
                'source_description': episode.source_description,
                'group_id': episode.group_id,
            }
            for episode in episodes
        ]

        return EpisodeSearchResponse(
            message='Episodes retrieved successfully', episodes=episode_results
        )
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error getting episodes: {error_msg}')
        return ErrorResponse(error=f'Error getting episodes: {error_msg}')
@mcp.tool()
async def clear_graph(group_ids: list[str] | None = None) -> SuccessResponse | ErrorResponse:
    """Clear all data from the graph for specified group IDs.

    Args:
        group_ids: Optional list of group IDs to clear. If not provided, clears the default group.
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        client = await graphiti_service.get_client()

        # Fall back to the configured default group only when the caller gave no group IDs.
        # The previous one-liner `group_ids or [default] if default else []` parsed as
        # `(group_ids or [default]) if default else []` (a conditional expression binds
        # looser than `or`). That discarded caller-supplied group_ids whenever no default
        # group was configured.
        default_group_ids = [config.graphiti.group_id] if config.graphiti.group_id else []
        effective_group_ids = group_ids or default_group_ids

        if not effective_group_ids:
            return ErrorResponse(error='No group IDs specified for clearing')

        # Clear data for the specified group IDs
        await clear_data(client.driver, group_ids=effective_group_ids)

        return SuccessResponse(
            message=f'Graph data cleared successfully for group IDs: {", ".join(effective_group_ids)}'
        )
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error clearing graph: {error_msg}')
        return ErrorResponse(error=f'Error clearing graph: {error_msg}')
@mcp.tool()
async def get_status() -> StatusResponse:
    """Get the status of the Graphiti MCP server and database connection."""
    global graphiti_service

    if graphiti_service is None:
        return StatusResponse(status='error', message='Graphiti service not initialized')

    try:
        client = await graphiti_service.get_client()

        # Round-trip a trivial query so a broken connection surfaces as an exception.
        async with client.driver.session() as session:
            result = await session.run('MATCH (n) RETURN count(n) as count')
            # Drain the result so the query actually executes; the truthiness guard
            # presumably tolerates drivers whose run() returns nothing — confirm per driver.
            if result:
                async for _record in result:
                    pass

        # Report the provider from the service's own config rather than the module global.
        provider_name = graphiti_service.config.database.provider
        return StatusResponse(
            status='ok',
            message=f'Graphiti MCP server is running and connected to {provider_name} database',
        )
    except Exception as e:
        reason = str(e)
        logger.error(f'Error checking database connection: {reason}')
        return StatusResponse(
            status='error',
            message=f'Graphiti MCP server is running but database connection failed: {reason}',
        )
@mcp.custom_route('/health', methods=['GET'])
async def health_check(request) -> JSONResponse:
    """Liveness endpoint for Docker and load balancers.

    Returns a constant payload; it does not probe the database (see the get_status tool for that).
    """
    payload = {'status': 'healthy', 'service': 'graphiti-mcp'}
    return JSONResponse(payload)
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser for the Graphiti MCP server."""
    parser = argparse.ArgumentParser(
        description='Run the Graphiti MCP server with YAML configuration support'
    )

    # Configuration file argument
    # Default to config/config.yaml relative to the mcp_server directory
    default_config = Path(__file__).parent.parent / 'config' / 'config.yaml'
    parser.add_argument(
        '--config',
        type=Path,
        default=default_config,
        help='Path to YAML configuration file (default: config/config.yaml)',
    )

    # Transport arguments
    parser.add_argument(
        '--transport',
        choices=['sse', 'stdio', 'http'],
        help='Transport to use: http (recommended, default), stdio (standard I/O), or sse (deprecated)',
    )
    parser.add_argument(
        '--host',
        help='Host to bind the MCP server to',
    )
    parser.add_argument(
        '--port',
        type=int,
        help='Port to bind the MCP server to',
    )

    # Provider selection arguments
    parser.add_argument(
        '--llm-provider',
        choices=['openai', 'azure_openai', 'anthropic', 'gemini', 'groq'],
        help='LLM provider to use',
    )
    parser.add_argument(
        '--embedder-provider',
        choices=['openai', 'azure_openai', 'gemini', 'voyage'],
        help='Embedder provider to use',
    )
    parser.add_argument(
        '--database-provider',
        choices=['neo4j', 'falkordb'],
        help='Database provider to use',
    )

    # LLM configuration arguments
    parser.add_argument('--model', help='Model name to use with the LLM client')
    parser.add_argument('--small-model', help='Small model name to use with the LLM client')
    parser.add_argument(
        '--temperature', type=float, help='Temperature setting for the LLM (0.0-2.0)'
    )

    # Embedder configuration arguments
    parser.add_argument('--embedder-model', help='Model name to use with the embedder')

    # Graphiti-specific arguments
    parser.add_argument(
        '--group-id',
        help='Namespace for the graph. If not provided, uses config file or generates random UUID.',
    )
    parser.add_argument(
        '--user-id',
        help='User ID for tracking operations',
    )
    parser.add_argument(
        '--destroy-graph',
        action='store_true',
        help='Destroy all Graphiti graphs on startup',
    )
    return parser


async def initialize_server() -> ServerConfig:
    """Parse CLI arguments and initialize the Graphiti server configuration.

    Side effects: sets the module-level ``config``, ``graphiti_service``, ``queue_service``,
    ``graphiti_client`` and ``semaphore`` globals, exports ``CONFIG_PATH`` to the
    environment, optionally wipes all graphs (``--destroy-graph``), and copies the
    configured host/port onto ``mcp.settings``.

    Returns:
        The ``server`` section of the loaded configuration (transport, host, port).
    """
    global config, graphiti_service, queue_service, graphiti_client, semaphore

    args = _build_arg_parser().parse_args()

    # Set config path in environment for the settings to pick up
    if args.config:
        os.environ['CONFIG_PATH'] = str(args.config)

    # Load configuration with environment variables and YAML
    config = GraphitiConfig()

    # Apply CLI overrides
    config.apply_cli_overrides(args)

    # Also apply legacy CLI args for backward compatibility
    if hasattr(args, 'destroy_graph'):
        config.destroy_graph = args.destroy_graph

    # Log configuration details
    logger.info('Using configuration:')
    logger.info(f' - LLM: {config.llm.provider} / {config.llm.model}')
    logger.info(f' - Embedder: {config.embedder.provider} / {config.embedder.model}')
    logger.info(f' - Database: {config.database.provider}')
    logger.info(f' - Group ID: {config.graphiti.group_id}')
    logger.info(f' - Transport: {config.server.transport}')

    # Log graphiti-core version
    try:
        import graphiti_core

        graphiti_version = getattr(graphiti_core, '__version__', 'unknown')
        logger.info(f' - Graphiti Core: {graphiti_version}')
    except Exception:
        # Check for Docker-stored version file
        version_file = Path('/app/.graphiti-core-version')
        if version_file.exists():
            graphiti_version = version_file.read_text().strip()
            logger.info(f' - Graphiti Core: {graphiti_version}')
        else:
            logger.info(' - Graphiti Core: version unavailable')

    # Handle graph destruction if requested
    if hasattr(config, 'destroy_graph') and config.destroy_graph:
        logger.warning('Destroying all Graphiti graphs as requested...')
        temp_service = GraphitiService(config, SEMAPHORE_LIMIT)
        await temp_service.initialize()
        temp_client = await temp_service.get_client()
        try:
            await clear_data(temp_client.driver)
        finally:
            # This throwaway client is never used again (a fresh service is created
            # below), so close its driver instead of leaking its connections.
            await temp_client.close()
        logger.info('All graphs destroyed')

    # Initialize services
    graphiti_service = GraphitiService(config, SEMAPHORE_LIMIT)
    queue_service = QueueService()
    await graphiti_service.initialize()

    # Set global client for backward compatibility
    graphiti_client = await graphiti_service.get_client()
    semaphore = graphiti_service.semaphore

    # Initialize queue service with the client
    await queue_service.initialize(graphiti_client)

    # Set MCP server settings
    if config.server.host:
        mcp.settings.host = config.server.host
    if config.server.port:
        mcp.settings.port = config.server.port

    # Return MCP configuration for transport
    return config.server
async def run_mcp_server():
"""Run the MCP server in the current event loop."""
# Initialize the server
mcp_config = await initialize_server()
# Run the server with configured transport
logger.info(f'Starting MCP server with transport: {mcp_config.transport}')
if mcp_config.transport == 'stdio':
await mcp.run_stdio_async()
elif mcp_config.transport == 'sse':
logger.info(
f'Running MCP server with SSE transport on {mcp.settings.host}:{mcp.settings.port}'
)
logger.info(f'Access the server at: http://{mcp.settings.host}:{mcp.settings.port}/sse')
await mcp.run_sse_async()
elif mcp_config.transport == 'http':
# Use localhost for display if binding to 0.0.0.0
display_host = 'localhost' if mcp.settings.host == '0.0.0.0' else mcp.settings.host
logger.info(
f'Running MCP server with streamable HTTP transport on {mcp.settings.host}:{mcp.settings.port}'
)
logger.info('=' * 60)
logger.info('MCP Server Access Information:')
logger.info(f' Base URL: http://{display_host}:{mcp.settings.port}/')
logger.info(f' MCP Endpoint: http://{display_host}:{mcp.settings.port}/mcp/')
logger.info(' Transport: HTTP (streamable)')
# Show FalkorDB Browser UI access if enabled
if os.environ.get('BROWSER', '1') == '1':
logger.info(f' FalkorDB Browser UI: http://{display_host}:3000/')
logger.info('=' * 60)
logger.info('For MCP clients, connect to the /mcp/ endpoint above')
# Configure uvicorn logging to match our format
configure_uvicorn_logging()
await mcp.run_streamable_http_async()
else:
raise ValueError(
f'Unsupported transport: {mcp_config.transport}. Use "sse", "stdio", or "http"'
)
def main():
    """Console entry point: run the Graphiti MCP server until it stops.

    Ctrl+C is logged as a normal shutdown; any other exception is logged and re-raised.
    """
    try:
        # Initialization and serving share one event loop.
        asyncio.run(run_mcp_server())
    except KeyboardInterrupt:
        logger.info('Server shutting down...')
    except Exception as exc:
        logger.error(f'Error initializing Graphiti MCP server: {exc!s}')
        raise


if __name__ == '__main__':
    main()
```