This is page 6 of 9. Use http://codebase.md/getzep/graphiti?lines=false&page={x} to view the full context.

# Directory Structure

```
├── .env.example
├── .github
│   ├── dependabot.yml
│   ├── ISSUE_TEMPLATE
│   │   └── bug_report.md
│   ├── pull_request_template.md
│   ├── secret_scanning.yml
│   └── workflows
│       ├── ai-moderator.yml
│       ├── cla.yml
│       ├── claude-code-review-manual.yml
│       ├── claude-code-review.yml
│       ├── claude.yml
│       ├── codeql.yml
│       ├── daily_issue_maintenance.yml
│       ├── issue-triage.yml
│       ├── lint.yml
│       ├── release-graphiti-core.yml
│       ├── release-mcp-server.yml
│       ├── release-server-container.yml
│       ├── typecheck.yml
│       └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│   ├── azure-openai
│   │   ├── .env.example
│   │   ├── azure_openai_neo4j.py
│   │   └── README.md
│   ├── data
│   │   └── manybirds_products.json
│   ├── ecommerce
│   │   ├── runner.ipynb
│   │   └── runner.py
│   ├── langgraph-agent
│   │   ├── agent.ipynb
│   │   └── tinybirds-jess.png
│   ├── opentelemetry
│   │   ├── .env.example
│   │   ├── otel_stdout_example.py
│   │   ├── pyproject.toml
│   │   ├── README.md
│   │   └── uv.lock
│   ├── podcast
│   │   ├── podcast_runner.py
│   │   ├── podcast_transcript.txt
│   │   └── transcript_parser.py
│   ├── quickstart
│   │   ├── quickstart_falkordb.py
│   │   ├── quickstart_neo4j.py
│   │   ├── quickstart_neptune.py
│   │   ├── README.md
│   │   └── requirements.txt
│   └── wizard_of_oz
│       ├── parser.py
│       ├── runner.py
│       └── woo.txt
├── graphiti_core
│   ├── __init__.py
│   ├── cross_encoder
│   │   ├── __init__.py
│   │   ├── bge_reranker_client.py
│   │   ├── client.py
│   │   ├── gemini_reranker_client.py
│   │   └── openai_reranker_client.py
│   ├── decorators.py
│   ├── driver
│   │   ├── __init__.py
│   │   ├── driver.py
│   │   ├── falkordb_driver.py
│   │   ├── graph_operations
│   │   │   └── graph_operations.py
│   │   ├── kuzu_driver.py
│   │   ├── neo4j_driver.py
│   │   ├── neptune_driver.py
│   │   └── search_interface
│   │       └── search_interface.py
│   ├── edges.py
│   ├── embedder
│   │   ├── __init__.py
│   │   ├── azure_openai.py
│   │   ├── client.py
│   │   ├── gemini.py
│   │   ├── openai.py
│   │   └── voyage.py
│   ├── errors.py
│   ├── graph_queries.py
│   ├── graphiti_types.py
│   ├── graphiti.py
│   ├── helpers.py
│   ├── llm_client
│   │   ├── __init__.py
│   │   ├── anthropic_client.py
│   │   ├── azure_openai_client.py
│   │   ├── client.py
│   │   ├── config.py
│   │   ├── errors.py
│   │   ├── gemini_client.py
│   │   ├── groq_client.py
│   │   ├── openai_base_client.py
│   │   ├── openai_client.py
│   │   ├── openai_generic_client.py
│   │   └── utils.py
│   ├── migrations
│   │   └── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── edges
│   │   │   ├── __init__.py
│   │   │   └── edge_db_queries.py
│   │   └── nodes
│   │       ├── __init__.py
│   │       └── node_db_queries.py
│   ├── nodes.py
│   ├── prompts
│   │   ├── __init__.py
│   │   ├── dedupe_edges.py
│   │   ├── dedupe_nodes.py
│   │   ├── eval.py
│   │   ├── extract_edge_dates.py
│   │   ├── extract_edges.py
│   │   ├── extract_nodes.py
│   │   ├── invalidate_edges.py
│   │   ├── lib.py
│   │   ├── models.py
│   │   ├── prompt_helpers.py
│   │   ├── snippets.py
│   │   └── summarize_nodes.py
│   ├── py.typed
│   ├── search
│   │   ├── __init__.py
│   │   ├── search_config_recipes.py
│   │   ├── search_config.py
│   │   ├── search_filters.py
│   │   ├── search_helpers.py
│   │   ├── search_utils.py
│   │   └── search.py
│   ├── telemetry
│   │   ├── __init__.py
│   │   └── telemetry.py
│   ├── tracer.py
│   └── utils
│       ├── __init__.py
│       ├── bulk_utils.py
│       ├── datetime_utils.py
│       ├── maintenance
│       │   ├── __init__.py
│       │   ├── community_operations.py
│       │   ├── dedup_helpers.py
│       │   ├── edge_operations.py
│       │   ├── graph_data_operations.py
│       │   ├── node_operations.py
│       │   └── temporal_operations.py
│       ├── ontology_utils
│       │   └── entity_types_utils.py
│       └── text_utils.py
├── images
│   ├── arxiv-screenshot.png
│   ├── graphiti-graph-intro.gif
│   ├── graphiti-intro-slides-stock-2.gif
│   └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│   ├── .env.example
│   ├── .python-version
│   ├── config
│   │   ├── config-docker-falkordb-combined.yaml
│   │   ├── config-docker-falkordb.yaml
│   │   ├── config-docker-neo4j.yaml
│   │   ├── config.yaml
│   │   └── mcp_config_stdio_example.json
│   ├── docker
│   │   ├── build-standalone.sh
│   │   ├── build-with-version.sh
│   │   ├── docker-compose-falkordb.yml
│   │   ├── docker-compose-neo4j.yml
│   │   ├── docker-compose.yml
│   │   ├── Dockerfile
│   │   ├── Dockerfile.standalone
│   │   ├── github-actions-example.yml
│   │   ├── README-falkordb-combined.md
│   │   └── README.md
│   ├── docs
│   │   └── cursor_rules.md
│   ├── main.py
│   ├── pyproject.toml
│   ├── pytest.ini
│   ├── README.md
│   ├── src
│   │   ├── __init__.py
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   └── schema.py
│   │   ├── graphiti_mcp_server.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── entity_types.py
│   │   │   └── response_types.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   ├── factories.py
│   │   │   └── queue_service.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── formatting.py
│   │       └── utils.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── conftest.py
│   │   ├── pytest.ini
│   │   ├── README.md
│   │   ├── run_tests.py
│   │   ├── test_async_operations.py
│   │   ├── test_comprehensive_integration.py
│   │   ├── test_configuration.py
│   │   ├── test_falkordb_integration.py
│   │   ├── test_fixtures.py
│   │   ├── test_http_integration.py
│   │   ├── test_integration.py
│   │   ├── test_mcp_integration.py
│   │   ├── test_mcp_transports.py
│   │   ├── test_stdio_simple.py
│   │   └── test_stress_load.py
│   └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│   ├── .env.example
│   ├── graph_service
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── dto
│   │   │   ├── __init__.py
│   │   │   ├── common.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   ├── main.py
│   │   ├── routers
│   │   │   ├── __init__.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   └── zep_graphiti.py
│   ├── Makefile
│   ├── pyproject.toml
│   ├── README.md
│   └── uv.lock
├── signatures
│   └── version1
│       └── cla.json
├── tests
│   ├── cross_encoder
│   │   ├── test_bge_reranker_client_int.py
│   │   └── test_gemini_reranker_client.py
│   ├── driver
│   │   ├── __init__.py
│   │   └── test_falkordb_driver.py
│   ├── embedder
│   │   ├── embedder_fixtures.py
│   │   ├── test_gemini.py
│   │   ├── test_openai.py
│   │   └── test_voyage.py
│   ├── evals
│   │   ├── data
│   │   │   └── longmemeval_data
│   │   │       ├── longmemeval_oracle.json
│   │   │       └── README.md
│   │   ├── eval_cli.py
│   │   ├── eval_e2e_graph_building.py
│   │   ├── pytest.ini
│   │   └── utils.py
│   ├── helpers_test.py
│   ├── llm_client
│   │   ├── test_anthropic_client_int.py
│   │   ├── test_anthropic_client.py
│   │   ├── test_azure_openai_client.py
│   │   ├── test_client.py
│   │   ├── test_errors.py
│   │   └── test_gemini_client.py
│   ├── test_edge_int.py
│   ├── test_entity_exclusion_int.py
│   ├── test_graphiti_int.py
│   ├── test_graphiti_mock.py
│   ├── test_node_int.py
│   ├── test_text_utils.py
│   └── utils
│       ├── maintenance
│       │   ├── test_bulk_utils.py
│       │   ├── test_edge_operations.py
│       │   ├── test_node_operations.py
│       │   └── test_temporal_operations_int.py
│       └── search
│           └── search_utils_test.py
├── uv.lock
└── Zep-CLA.md
```

# Files

--------------------------------------------------------------------------------
/mcp_server/tests/test_stress_load.py:
--------------------------------------------------------------------------------

```python
#!/usr/bin/env python3
"""
Stress and load testing for Graphiti MCP Server.
Tests system behavior under high load, resource constraints, and edge conditions.
"""

import asyncio
import gc
import random
import time
from dataclasses import dataclass

import psutil
import pytest
from test_fixtures import TestDataGenerator, graphiti_test_client


@dataclass
class LoadTestConfig:
    """Configuration for load testing scenarios."""

    num_clients: int = 10
    operations_per_client: int = 100
    ramp_up_time: float = 5.0  # seconds
    test_duration: float = 60.0  # seconds
    target_throughput: float | None = None  # ops/sec
    think_time: float = 0.1  # seconds between ops


@dataclass
class LoadTestResult:
    """Results from a load test run."""

    total_operations: int
    successful_operations: int
    failed_operations: int
    duration: float
    throughput: float
    average_latency: float
    p50_latency: float
    p95_latency: float
    p99_latency: float
    max_latency: float
    errors: dict[str, int]
    resource_usage: dict[str, float]


class LoadTester:
    """Orchestrate load testing scenarios."""

    def __init__(self, config: LoadTestConfig):
        self.config = config
        self.metrics: list[tuple[float, float, bool]] = []  # (start, duration, success)
        self.errors: dict[str, int] = {}
        self.start_time: float | None = None

    async def run_client_workload(self, client_id: int, session, group_id: str) -> dict[str, int]:
        """Run workload for a single simulated client."""
        stats = {'success': 0, 'failure': 0}
        data_gen = TestDataGenerator()

        # Ramp-up delay
        ramp_delay = (client_id / self.config.num_clients) * self.config.ramp_up_time
        await asyncio.sleep(ramp_delay)

        for op_num in range(self.config.operations_per_client):
            operation_start = time.time()

            try:
                # Randomly select operation type
                operation = random.choice(
                    [
                        'add_memory',
                        'search_memory_nodes',
                        'get_episodes',
                    ]
                )

                if operation == 'add_memory':
                    args = {
                        'name': f'Load Test {client_id}-{op_num}',
                        'episode_body': data_gen.generate_technical_document(),
                        'source': 'text',
                        'source_description': 'load test',
                        'group_id': group_id,
                    }
                elif operation == 'search_memory_nodes':
                    args = {
                        'query': random.choice(['performance', 'architecture', 'test', 'data']),
                        'group_id': group_id,
                        'limit': 10,
                    }
                else:  # get_episodes
                    args = {
                        'group_id': group_id,
                        'last_n': 10,
                    }

                # Execute operation with timeout
                await asyncio.wait_for(session.call_tool(operation, args), timeout=30.0)

                duration = time.time() - operation_start
                self.metrics.append((operation_start, duration, True))
                stats['success'] += 1

            except asyncio.TimeoutError:
                duration = time.time() - operation_start
                self.metrics.append((operation_start, duration, False))
                self.errors['timeout'] = self.errors.get('timeout', 0) + 1
                stats['failure'] += 1

            except Exception as e:
                duration = time.time() - operation_start
                self.metrics.append((operation_start, duration, False))
                error_type = type(e).__name__
                self.errors[error_type] = self.errors.get(error_type, 0) + 1
                stats['failure'] += 1

            # Think time between operations
            await asyncio.sleep(self.config.think_time)

            # Stop if we've exceeded test duration
            if self.start_time and (time.time() - self.start_time) > self.config.test_duration:
                break

        return stats

    def calculate_results(self) -> LoadTestResult:
        """Calculate load test results from metrics."""
        if not self.metrics:
            return LoadTestResult(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, {}, {})

        successful = [m for m in self.metrics if m[2]]
        failed = [m for m in self.metrics if not m[2]]

        latencies = sorted([m[1] for m in self.metrics])
        duration = max([m[0] + m[1] for m in self.metrics]) - min([m[0] for m in self.metrics])

        # Calculate percentiles
        def percentile(data: list[float], p: float) -> float:
            if not data:
                return 0.0
            idx = int(len(data) * p / 100)
            return data[min(idx, len(data) - 1)]

        # Get resource usage
        process = psutil.Process()
        resource_usage = {
            'cpu_percent': process.cpu_percent(),
            'memory_mb': process.memory_info().rss / 1024 / 1024,
            'num_threads': process.num_threads(),
        }

        return LoadTestResult(
            total_operations=len(self.metrics),
            successful_operations=len(successful),
            failed_operations=len(failed),
            duration=duration,
            throughput=len(self.metrics) / duration if duration > 0 else 0,
            average_latency=sum(latencies) / len(latencies) if latencies else 0,
            p50_latency=percentile(latencies, 50),
            p95_latency=percentile(latencies, 95),
            p99_latency=percentile(latencies, 99),
            max_latency=max(latencies) if latencies else 0,
            errors=self.errors,
            resource_usage=resource_usage,
        )


class TestLoadScenarios:
    """Various load testing scenarios."""

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_sustained_load(self):
        """Test system under sustained moderate load."""
        config = LoadTestConfig(
            num_clients=5,
            operations_per_client=20,
            ramp_up_time=2.0,
            test_duration=30.0,
            think_time=0.5,
        )

        async with graphiti_test_client() as (session, group_id):
            tester = LoadTester(config)
            tester.start_time = time.time()

            # Run client workloads
            client_tasks = []
            for client_id in range(config.num_clients):
                task = tester.run_client_workload(client_id, session, group_id)
                client_tasks.append(task)

            # Execute all clients
            await asyncio.gather(*client_tasks)

            # Calculate results
            results = tester.calculate_results()

            # Assertions
            assert results.successful_operations > results.failed_operations
            assert results.average_latency < 5.0, (
                f'Average latency too high: {results.average_latency:.2f}s'
            )
            assert results.p95_latency < 10.0, f'P95 latency too high: {results.p95_latency:.2f}s'

            # Report results
            print('\nSustained Load Test Results:')
            print(f'  Total operations: {results.total_operations}')
            print(
                f'  Success rate: {results.successful_operations / results.total_operations * 100:.1f}%'
            )
            print(f'  Throughput: {results.throughput:.2f} ops/s')
            print(f'  Avg latency: {results.average_latency:.2f}s')
            print(f'  P95 latency: {results.p95_latency:.2f}s')

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_spike_load(self):
        """Test system response to sudden load spikes."""
        async with graphiti_test_client() as (session, group_id):
            # Normal load phase
            normal_tasks = []
            for i in range(3):
                task = session.call_tool(
                    'add_memory',
                    {
                        'name': f'Normal Load {i}',
                        'episode_body': 'Normal operation',
                        'source': 'text',
                        'source_description': 'normal',
                        'group_id': group_id,
                    },
                )
                normal_tasks.append(task)
                await asyncio.sleep(0.5)

            await asyncio.gather(*normal_tasks)

            # Spike phase - sudden burst of requests
            spike_start = time.time()
            spike_tasks = []
            for i in range(50):
                task = session.call_tool(
                    'add_memory',
                    {
                        'name': f'Spike Load {i}',
                        'episode_body': TestDataGenerator.generate_technical_document(),
                        'source': 'text',
                        'source_description': 'spike',
                        'group_id': group_id,
                    },
                )
                spike_tasks.append(task)

            # Execute spike
            spike_results = await asyncio.gather(*spike_tasks, return_exceptions=True)
            spike_duration = time.time() - spike_start

            # Analyze spike handling
            spike_failures = sum(1 for r in spike_results if isinstance(r, Exception))
            spike_success_rate = (len(spike_results) - spike_failures) / len(spike_results)

            print('\nSpike Load Test Results:')
            print(f'  Spike size: {len(spike_tasks)} operations')
            print(f'  Duration: {spike_duration:.2f}s')
            print(f'  Success rate: {spike_success_rate * 100:.1f}%')
            print(f'  Throughput: {len(spike_tasks) / spike_duration:.2f} ops/s')

            # System should handle at least 80% of spike
            assert spike_success_rate > 0.8, f'Too many failures during spike: {spike_failures}'

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_memory_leak_detection(self):
        """Test for memory leaks during extended operation."""
        async with graphiti_test_client() as (session, group_id):
            process = psutil.Process()
            gc.collect()  # Force garbage collection
            initial_memory = process.memory_info().rss / 1024 / 1024  # MB

            # Perform many operations
            for batch in range(10):
                batch_tasks = []
                for i in range(10):
                    task = session.call_tool(
                        'add_memory',
                        {
                            'name': f'Memory Test {batch}-{i}',
                            'episode_body': TestDataGenerator.generate_technical_document(),
                            'source': 'text',
                            'source_description': 'memory test',
                            'group_id': group_id,
                        },
                    )
                    batch_tasks.append(task)

                await asyncio.gather(*batch_tasks)

                # Force garbage collection between batches
                gc.collect()
                await asyncio.sleep(1)

            # Check memory after operations
            gc.collect()
            final_memory = process.memory_info().rss / 1024 / 1024  # MB
            memory_growth = final_memory - initial_memory

            print('\nMemory Leak Test:')
            print(f'  Initial memory: {initial_memory:.1f} MB')
            print(f'  Final memory: {final_memory:.1f} MB')
            print(f'  Growth: {memory_growth:.1f} MB')

            # Allow for some memory growth but flag potential leaks
            # This is a soft check - actual threshold depends on system
            if memory_growth > 100:  # More than 100MB growth
                print(f'  ⚠️  Potential memory leak detected: {memory_growth:.1f} MB growth')

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_connection_pool_exhaustion(self):
        """Test behavior when connection pools are exhausted."""
        async with graphiti_test_client() as (session, group_id):
            # Create many concurrent long-running operations
            long_tasks = []
            for i in range(100):  # Many more than typical pool size
                task = session.call_tool(
                    'search_memory_nodes',
                    {
                        'query': f'complex query {i} '
                        + ' '.join([TestDataGenerator.fake.word() for _ in range(10)]),
                        'group_id': group_id,
                        'limit': 100,
                    },
                )
                long_tasks.append(task)

            # Execute with timeout
            try:
                results = await asyncio.wait_for(
                    asyncio.gather(*long_tasks, return_exceptions=True), timeout=60.0
                )

                # Count connection-related errors
                connection_errors = sum(
                    1
                    for r in results
                    if isinstance(r, Exception) and 'connection' in str(r).lower()
                )

                print('\nConnection Pool Test:')
                print(f'  Total requests: {len(long_tasks)}')
                print(f'  Connection errors: {connection_errors}')

            except asyncio.TimeoutError:
                print('  Test timed out - possible deadlock or exhaustion')

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_gradual_degradation(self):
        """Test system degradation under increasing load."""
        async with graphiti_test_client() as (session, group_id):
            load_levels = [5, 10, 20, 40, 80]  # Increasing concurrent operations
            results_by_level = {}

            for level in load_levels:
                level_start = time.time()
                tasks = []

                for i in range(level):
                    task = session.call_tool(
                        'add_memory',
                        {
                            'name': f'Load Level {level} Op {i}',
                            'episode_body': f'Testing at load level {level}',
                            'source': 'text',
                            'source_description': 'degradation test',
                            'group_id': group_id,
                        },
                    )
                    tasks.append(task)

                # Execute level
                level_results = await asyncio.gather(*tasks, return_exceptions=True)
                level_duration = time.time() - level_start

                # Calculate metrics
                failures = sum(1 for r in level_results if isinstance(r, Exception))
                success_rate = (level - failures) / level * 100
                throughput = level / level_duration

                results_by_level[level] = {
                    'success_rate': success_rate,
                    'throughput': throughput,
                    'duration': level_duration,
                }

                print(f'\nLoad Level {level}:')
                print(f'  Success rate: {success_rate:.1f}%')
                print(f'  Throughput: {throughput:.2f} ops/s')
                print(f'  Duration: {level_duration:.2f}s')

                # Brief pause between levels
                await asyncio.sleep(2)

            # Verify graceful degradation
            # Success rate should not drop below 50% even at high load
            for level, metrics in results_by_level.items():
                assert metrics['success_rate'] > 50, f'Poor performance at load level {level}'


class TestResourceLimits:
    """Test behavior at resource limits."""

    @pytest.mark.asyncio
    async def test_large_payload_handling(self):
        """Test handling of very large payloads."""
        async with graphiti_test_client() as (session, group_id):
            payload_sizes = [
                (1_000, '1KB'),
                (10_000, '10KB'),
                (100_000, '100KB'),
                (1_000_000, '1MB'),
            ]

            for size, label in payload_sizes:
                content = 'x' * size

                start_time = time.time()
                try:
                    await asyncio.wait_for(
                        session.call_tool(
                            'add_memory',
                            {
                                'name': f'Large Payload {label}',
                                'episode_body': content,
                                'source': 'text',
                                'source_description': 'payload test',
                                'group_id': group_id,
                            },
                        ),
                        timeout=30.0,
                    )
                    duration = time.time() - start_time
                    status = '✅ Success'

                except asyncio.TimeoutError:
                    duration = 30.0
                    status = '⏱️  Timeout'

                except Exception as e:
                    duration = time.time() - start_time
                    status = f'❌ Error: {type(e).__name__}'

                print(f'Payload {label}: {status} ({duration:.2f}s)')

    @pytest.mark.asyncio
    async def test_rate_limit_handling(self):
        """Test handling of rate limits."""
        async with graphiti_test_client() as (session, group_id):
            # Rapid fire requests to trigger rate limits
            rapid_tasks = []
            for i in range(100):
                task = session.call_tool(
                    'add_memory',
                    {
                        'name': f'Rate Limit Test {i}',
                        'episode_body': f'Testing rate limit {i}',
                        'source': 'text',
                        'source_description': 'rate test',
                        'group_id': group_id,
                    },
                )
                rapid_tasks.append(task)

            # Execute without delays
            results = await asyncio.gather(*rapid_tasks, return_exceptions=True)

            # Count rate limit errors
            rate_limit_errors = sum(
                1
                for r in results
                if isinstance(r, Exception) and ('rate' in str(r).lower() or '429' in str(r))
            )

            print('\nRate Limit Test:')
            print(f'  Total requests: {len(rapid_tasks)}')
            print(f'  Rate limit errors: {rate_limit_errors}')
            print(
                f'  Success rate: {(len(rapid_tasks) - rate_limit_errors) / len(rapid_tasks) * 100:.1f}%'
            )


def generate_load_test_report(results: list[LoadTestResult]) -> str:
    """Generate comprehensive load test report."""
    report = []
    report.append('\n' + '=' * 60)
    report.append('LOAD TEST REPORT')
    report.append('=' * 60)

    for i, result in enumerate(results):
        report.append(f'\nTest Run {i + 1}:')
        report.append(f'  Total Operations: {result.total_operations}')
        report.append(
            f'  Success Rate: {result.successful_operations / result.total_operations * 100:.1f}%'
        )
        report.append(f'  Throughput: {result.throughput:.2f} ops/s')
        report.append(
            f'  Latency (avg/p50/p95/p99/max): {result.average_latency:.2f}/{result.p50_latency:.2f}/{result.p95_latency:.2f}/{result.p99_latency:.2f}/{result.max_latency:.2f}s'
        )

        if result.errors:
            report.append('  Errors:')
            for error_type, count in result.errors.items():
                report.append(f'    {error_type}: {count}')

        report.append('  Resource Usage:')
        for metric, value in result.resource_usage.items():
            report.append(f'    {metric}: {value:.2f}')

    report.append('=' * 60)
    return '\n'.join(report)


if __name__ == '__main__':
    pytest.main([__file__, '-v', '--asyncio-mode=auto', '-m', 'slow'])

```
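
The load-testing classes above are exercised through pytest, but they also compose into a standalone driver. Below is a minimal sketch, assuming `mcp_server/tests` is on `sys.path` (so `test_fixtures` and `test_stress_load` import directly) and that a Graphiti MCP server is reachable for `graphiti_test_client`; the `main` coroutine is hypothetical and only illustrates the intended call order.

```python
import asyncio
import time

from test_fixtures import graphiti_test_client
from test_stress_load import LoadTestConfig, LoadTester, generate_load_test_report


async def main() -> None:
    # Small configuration so the sketch finishes quickly.
    config = LoadTestConfig(num_clients=3, operations_per_client=5, think_time=0.2)

    async with graphiti_test_client() as (session, group_id):
        tester = LoadTester(config)
        tester.start_time = time.time()

        # Run all simulated clients concurrently, then summarize the collected metrics.
        await asyncio.gather(
            *[
                tester.run_client_workload(client_id, session, group_id)
                for client_id in range(config.num_clients)
            ]
        )
        print(generate_load_test_report([tester.calculate_results()]))


if __name__ == '__main__':
    asyncio.run(main())
```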

--------------------------------------------------------------------------------
/graphiti_core/utils/maintenance/node_operations.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
from collections.abc import Awaitable, Callable
from time import time
from typing import Any

from pydantic import BaseModel

from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.llm_client.config import ModelSize
from graphiti_core.nodes import (
    EntityNode,
    EpisodeType,
    EpisodicNode,
    create_entity_node_embeddings,
)
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
from graphiti_core.prompts.extract_nodes import (
    EntitySummary,
    ExtractedEntities,
    ExtractedEntity,
    MissedEntities,
)
from graphiti_core.search.search import search
from graphiti_core.search.search_config import SearchResults
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.dedup_helpers import (
    DedupCandidateIndexes,
    DedupResolutionState,
    _build_candidate_indexes,
    _resolve_with_similarity,
)
from graphiti_core.utils.maintenance.edge_operations import (
    filter_existing_duplicate_of_edges,
)
from graphiti_core.utils.text_utils import MAX_SUMMARY_CHARS, truncate_at_sentence

logger = logging.getLogger(__name__)

NodeSummaryFilter = Callable[[EntityNode], Awaitable[bool]]


async def extract_nodes_reflexion(
    llm_client: LLMClient,
    episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
    node_names: list[str],
    group_id: str | None = None,
) -> list[str]:
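    """Ask the LLM, via the reflexion prompt, which entities were missed in the initial extraction."""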
    # Prepare context for LLM
    context = {
        'episode_content': episode.content,
        'previous_episodes': [ep.content for ep in previous_episodes],
        'extracted_entities': node_names,
    }

    llm_response = await llm_client.generate_response(
        prompt_library.extract_nodes.reflexion(context),
        MissedEntities,
        group_id=group_id,
        prompt_name='extract_nodes.reflexion',
    )
    missed_entities = llm_response.get('missed_entities', [])

    return missed_entities


async def extract_nodes(
    clients: GraphitiClients,
    episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
    entity_types: dict[str, type[BaseModel]] | None = None,
    excluded_entity_types: list[str] | None = None,
) -> list[EntityNode]:
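    """Extract entity nodes from the episode, re-prompting via reflexion until no
    entities appear to be missed (bounded by MAX_REFLEXION_ITERATIONS)."""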
    start = time()
    llm_client = clients.llm_client
    llm_response = {}
    custom_prompt = ''
    entities_missed = True
    reflexion_iterations = 0

    entity_types_context = [
        {
            'entity_type_id': 0,
            'entity_type_name': 'Entity',
            'entity_type_description': 'Default entity classification. Use this entity type if the entity is not one of the other listed types.',
        }
    ]

    entity_types_context += (
        [
            {
                'entity_type_id': i + 1,
                'entity_type_name': type_name,
                'entity_type_description': type_model.__doc__,
            }
            for i, (type_name, type_model) in enumerate(entity_types.items())
        ]
        if entity_types is not None
        else []
    )

    context = {
        'episode_content': episode.content,
        'episode_timestamp': episode.valid_at.isoformat(),
        'previous_episodes': [ep.content for ep in previous_episodes],
        'custom_prompt': custom_prompt,
        'entity_types': entity_types_context,
        'source_description': episode.source_description,
    }

    while entities_missed and reflexion_iterations <= MAX_REFLEXION_ITERATIONS:
        if episode.source == EpisodeType.message:
            llm_response = await llm_client.generate_response(
                prompt_library.extract_nodes.extract_message(context),
                response_model=ExtractedEntities,
                group_id=episode.group_id,
                prompt_name='extract_nodes.extract_message',
            )
        elif episode.source == EpisodeType.text:
            llm_response = await llm_client.generate_response(
                prompt_library.extract_nodes.extract_text(context),
                response_model=ExtractedEntities,
                group_id=episode.group_id,
                prompt_name='extract_nodes.extract_text',
            )
        elif episode.source == EpisodeType.json:
            llm_response = await llm_client.generate_response(
                prompt_library.extract_nodes.extract_json(context),
                response_model=ExtractedEntities,
                group_id=episode.group_id,
                prompt_name='extract_nodes.extract_json',
            )

        response_object = ExtractedEntities(**llm_response)

        extracted_entities: list[ExtractedEntity] = response_object.extracted_entities

        reflexion_iterations += 1
        if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
            missing_entities = await extract_nodes_reflexion(
                llm_client,
                episode,
                previous_episodes,
                [entity.name for entity in extracted_entities],
                episode.group_id,
            )

            entities_missed = len(missing_entities) != 0

            custom_prompt = 'Make sure that the following entities are extracted: '
            for entity in missing_entities:
                custom_prompt += f'\n{entity},'

    filtered_extracted_entities = [entity for entity in extracted_entities if entity.name.strip()]
    end = time()
    logger.debug(f'Extracted new nodes: {filtered_extracted_entities} in {(end - start) * 1000} ms')
    # Convert the extracted data into EntityNode objects
    extracted_nodes = []
    for extracted_entity in filtered_extracted_entities:
        type_id = extracted_entity.entity_type_id
        if 0 <= type_id < len(entity_types_context):
            entity_type_name = entity_types_context[extracted_entity.entity_type_id].get(
                'entity_type_name'
            )
        else:
            entity_type_name = 'Entity'

        # Check if this entity type should be excluded
        if excluded_entity_types and entity_type_name in excluded_entity_types:
            logger.debug(f'Excluding entity "{extracted_entity.name}" of type "{entity_type_name}"')
            continue

        labels: list[str] = list({'Entity', str(entity_type_name)})

        new_node = EntityNode(
            name=extracted_entity.name,
            group_id=episode.group_id,
            labels=labels,
            summary='',
            created_at=utc_now(),
        )
        extracted_nodes.append(new_node)
        logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')

    logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')

    return extracted_nodes


async def _collect_candidate_nodes(
    clients: GraphitiClients,
    extracted_nodes: list[EntityNode],
    existing_nodes_override: list[EntityNode] | None,
) -> list[EntityNode]:
    """Search per extracted name and return unique candidates with overrides honored in order."""
    search_results: list[SearchResults] = await semaphore_gather(
        *[
            search(
                clients=clients,
                query=node.name,
                group_ids=[node.group_id],
                search_filter=SearchFilters(),
                config=NODE_HYBRID_SEARCH_RRF,
            )
            for node in extracted_nodes
        ]
    )

    candidate_nodes: list[EntityNode] = [node for result in search_results for node in result.nodes]

    if existing_nodes_override is not None:
        candidate_nodes.extend(existing_nodes_override)

    seen_candidate_uuids: set[str] = set()
    ordered_candidates: list[EntityNode] = []
    for candidate in candidate_nodes:
        if candidate.uuid in seen_candidate_uuids:
            continue
        seen_candidate_uuids.add(candidate.uuid)
        ordered_candidates.append(candidate)

    return ordered_candidates


async def _resolve_with_llm(
    llm_client: LLMClient,
    extracted_nodes: list[EntityNode],
    indexes: DedupCandidateIndexes,
    state: DedupResolutionState,
    episode: EpisodicNode | None,
    previous_episodes: list[EpisodicNode] | None,
    entity_types: dict[str, type[BaseModel]] | None,
) -> None:
    """Escalate unresolved nodes to the dedupe prompt so the LLM can select or reject duplicates.

    The guardrails below defensively ignore malformed or duplicate LLM responses so the
    ingestion workflow remains deterministic even when the model misbehaves.
    """
    if not state.unresolved_indices:
        return

    entity_types_dict: dict[str, type[BaseModel]] = entity_types if entity_types is not None else {}

    llm_extracted_nodes = [extracted_nodes[i] for i in state.unresolved_indices]

    extracted_nodes_context = [
        {
            'id': i,
            'name': node.name,
            'entity_type': node.labels,
            'entity_type_description': entity_types_dict.get(
                next((item for item in node.labels if item != 'Entity'), '')
            ).__doc__
            or 'Default Entity Type',
        }
        for i, node in enumerate(llm_extracted_nodes)
    ]

    sent_ids = [ctx['id'] for ctx in extracted_nodes_context]
    logger.debug(
        'Sending %d entities to LLM for deduplication with IDs 0-%d (actual IDs sent: %s)',
        len(llm_extracted_nodes),
        len(llm_extracted_nodes) - 1,
        sent_ids if len(sent_ids) < 20 else f'{sent_ids[:10]}...{sent_ids[-10:]}',
    )
    if llm_extracted_nodes:
        sample_size = min(3, len(extracted_nodes_context))
        logger.debug(
            'First %d entities: %s',
            sample_size,
            [(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[:sample_size]],
        )
        if len(extracted_nodes_context) > 3:
            logger.debug(
                'Last %d entities: %s',
                sample_size,
                [(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[-sample_size:]],
            )

    existing_nodes_context = [
        {
            **{
                'idx': i,
                'name': candidate.name,
                'entity_types': candidate.labels,
            },
            **candidate.attributes,
        }
        for i, candidate in enumerate(indexes.existing_nodes)
    ]

    context = {
        'extracted_nodes': extracted_nodes_context,
        'existing_nodes': existing_nodes_context,
        'episode_content': episode.content if episode is not None else '',
        'previous_episodes': (
            [ep.content for ep in previous_episodes] if previous_episodes is not None else []
        ),
    }

    llm_response = await llm_client.generate_response(
        prompt_library.dedupe_nodes.nodes(context),
        response_model=NodeResolutions,
        prompt_name='dedupe_nodes.nodes',
    )

    node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions

    valid_relative_range = range(len(state.unresolved_indices))
    processed_relative_ids: set[int] = set()

    received_ids = {r.id for r in node_resolutions}
    expected_ids = set(valid_relative_range)
    missing_ids = expected_ids - received_ids
    extra_ids = received_ids - expected_ids

    logger.debug(
        'Received %d resolutions for %d entities',
        len(node_resolutions),
        len(state.unresolved_indices),
    )

    if missing_ids:
        logger.warning('LLM did not return resolutions for IDs: %s', sorted(missing_ids))

    if extra_ids:
        logger.warning(
            'LLM returned invalid IDs outside valid range 0-%d: %s (all returned IDs: %s)',
            len(state.unresolved_indices) - 1,
            sorted(extra_ids),
            sorted(received_ids),
        )

    for resolution in node_resolutions:
        relative_id: int = resolution.id
        duplicate_idx: int = resolution.duplicate_idx

        if relative_id not in valid_relative_range:
            logger.warning(
                'Skipping invalid LLM dedupe id %d (valid range: 0-%d, received %d resolutions)',
                relative_id,
                len(state.unresolved_indices) - 1,
                len(node_resolutions),
            )
            continue

        if relative_id in processed_relative_ids:
            logger.warning('Duplicate LLM dedupe id %s received; ignoring.', relative_id)
            continue
        processed_relative_ids.add(relative_id)

        original_index = state.unresolved_indices[relative_id]
        extracted_node = extracted_nodes[original_index]

        resolved_node: EntityNode
        if duplicate_idx == -1:
            resolved_node = extracted_node
        elif 0 <= duplicate_idx < len(indexes.existing_nodes):
            resolved_node = indexes.existing_nodes[duplicate_idx]
        else:
            logger.warning(
                'Invalid duplicate_idx %s for extracted node %s; treating as no duplicate.',
                duplicate_idx,
                extracted_node.uuid,
            )
            resolved_node = extracted_node

        state.resolved_nodes[original_index] = resolved_node
        state.uuid_map[extracted_node.uuid] = resolved_node.uuid
        if resolved_node.uuid != extracted_node.uuid:
            state.duplicate_pairs.append((extracted_node, resolved_node))


async def resolve_extracted_nodes(
    clients: GraphitiClients,
    extracted_nodes: list[EntityNode],
    episode: EpisodicNode | None = None,
    previous_episodes: list[EpisodicNode] | None = None,
    entity_types: dict[str, type[BaseModel]] | None = None,
    existing_nodes_override: list[EntityNode] | None = None,
) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
    """Search for existing nodes, resolve deterministic matches, then escalate holdouts to the LLM dedupe prompt."""
    llm_client = clients.llm_client
    driver = clients.driver
    existing_nodes = await _collect_candidate_nodes(
        clients,
        extracted_nodes,
        existing_nodes_override,
    )

    indexes: DedupCandidateIndexes = _build_candidate_indexes(existing_nodes)

    state = DedupResolutionState(
        resolved_nodes=[None] * len(extracted_nodes),
        uuid_map={},
        unresolved_indices=[],
    )

    _resolve_with_similarity(extracted_nodes, indexes, state)

    await _resolve_with_llm(
        llm_client,
        extracted_nodes,
        indexes,
        state,
        episode,
        previous_episodes,
        entity_types,
    )

    for idx, node in enumerate(extracted_nodes):
        if state.resolved_nodes[idx] is None:
            state.resolved_nodes[idx] = node
            state.uuid_map[node.uuid] = node.uuid

    logger.debug(
        'Resolved nodes: %s',
        [(node.name, node.uuid) for node in state.resolved_nodes if node is not None],
    )

    new_node_duplicates: list[
        tuple[EntityNode, EntityNode]
    ] = await filter_existing_duplicate_of_edges(driver, state.duplicate_pairs)

    return (
        [node for node in state.resolved_nodes if node is not None],
        state.uuid_map,
        new_node_duplicates,
    )


async def extract_attributes_from_nodes(
    clients: GraphitiClients,
    nodes: list[EntityNode],
    episode: EpisodicNode | None = None,
    previous_episodes: list[EpisodicNode] | None = None,
    entity_types: dict[str, type[BaseModel]] | None = None,
    should_summarize_node: NodeSummaryFilter | None = None,
) -> list[EntityNode]:
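    """Concurrently extract attributes and summaries for each node, then create embeddings for the updated nodes."""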
    llm_client = clients.llm_client
    embedder = clients.embedder
    updated_nodes: list[EntityNode] = await semaphore_gather(
        *[
            extract_attributes_from_node(
                llm_client,
                node,
                episode,
                previous_episodes,
                (
                    entity_types.get(next((item for item in node.labels if item != 'Entity'), ''))
                    if entity_types is not None
                    else None
                ),
                should_summarize_node,
            )
            for node in nodes
        ]
    )

    await create_entity_node_embeddings(embedder, updated_nodes)

    return updated_nodes


async def extract_attributes_from_node(
    llm_client: LLMClient,
    node: EntityNode,
    episode: EpisodicNode | None = None,
    previous_episodes: list[EpisodicNode] | None = None,
    entity_type: type[BaseModel] | None = None,
    should_summarize_node: NodeSummaryFilter | None = None,
) -> EntityNode:
    # Extract attributes if entity type is defined and has attributes
    llm_response = await _extract_entity_attributes(
        llm_client, node, episode, previous_episodes, entity_type
    )

    # Extract summary if needed
    await _extract_entity_summary(
        llm_client, node, episode, previous_episodes, should_summarize_node
    )

    node.attributes.update(llm_response)

    return node


async def _extract_entity_attributes(
    llm_client: LLMClient,
    node: EntityNode,
    episode: EpisodicNode | None,
    previous_episodes: list[EpisodicNode] | None,
    entity_type: type[BaseModel] | None,
) -> dict[str, Any]:
    if entity_type is None or len(entity_type.model_fields) == 0:
        return {}

    attributes_context = _build_episode_context(
        # should not include summary
        node_data={
            'name': node.name,
            'entity_types': node.labels,
            'attributes': node.attributes,
        },
        episode=episode,
        previous_episodes=previous_episodes,
    )

    llm_response = await llm_client.generate_response(
        prompt_library.extract_nodes.extract_attributes(attributes_context),
        response_model=entity_type,
        model_size=ModelSize.small,
        group_id=node.group_id,
        prompt_name='extract_nodes.extract_attributes',
    )

    # validate response
    entity_type(**llm_response)

    return llm_response


async def _extract_entity_summary(
    llm_client: LLMClient,
    node: EntityNode,
    episode: EpisodicNode | None,
    previous_episodes: list[EpisodicNode] | None,
    should_summarize_node: NodeSummaryFilter | None,
) -> None:
    if should_summarize_node is not None and not await should_summarize_node(node):
        return

    summary_context = _build_episode_context(
        node_data={
            'name': node.name,
            'summary': truncate_at_sentence(node.summary, MAX_SUMMARY_CHARS),
            'entity_types': node.labels,
            'attributes': node.attributes,
        },
        episode=episode,
        previous_episodes=previous_episodes,
    )

    summary_response = await llm_client.generate_response(
        prompt_library.extract_nodes.extract_summary(summary_context),
        response_model=EntitySummary,
        model_size=ModelSize.small,
        group_id=node.group_id,
        prompt_name='extract_nodes.extract_summary',
    )

    node.summary = truncate_at_sentence(summary_response.get('summary', ''), MAX_SUMMARY_CHARS)


def _build_episode_context(
    node_data: dict[str, Any],
    episode: EpisodicNode | None,
    previous_episodes: list[EpisodicNode] | None,
) -> dict[str, Any]:
    return {
        'node': node_data,
        'episode_content': episode.content if episode is not None else '',
        'previous_episodes': (
            [ep.content for ep in previous_episodes] if previous_episodes is not None else []
        ),
    }

```
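
Taken together, these helpers form the node-maintenance pipeline used during ingestion. The sketch below shows the rough call order; `process_episode_nodes` is hypothetical, and `clients` (a `GraphitiClients` instance), `episode`, and `previous_episodes` are assumed to be constructed elsewhere by the ingestion flow.

```python
from graphiti_core.utils.maintenance.node_operations import (
    extract_attributes_from_nodes,
    extract_nodes,
    resolve_extracted_nodes,
)


async def process_episode_nodes(clients, episode, previous_episodes):
    # 1. Extract candidate entities from the episode content (with reflexion retries).
    extracted = await extract_nodes(clients, episode, previous_episodes)

    # 2. Deduplicate against existing graph nodes: deterministic similarity matching
    #    first, then the LLM dedupe prompt for unresolved holdouts.
    resolved, uuid_map, duplicate_pairs = await resolve_extracted_nodes(
        clients, extracted, episode, previous_episodes
    )

    # 3. Populate attributes and summaries, then embed the updated nodes.
    nodes = await extract_attributes_from_nodes(
        clients, resolved, episode, previous_episodes
    )
    return nodes, uuid_map, duplicate_pairs
```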

--------------------------------------------------------------------------------
/graphiti_core/edges.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from time import time
from typing import Any
from uuid import uuid4

from pydantic import BaseModel, Field
from typing_extensions import LiteralString

from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import parse_db_date
from graphiti_core.models.edges.edge_db_queries import (
    COMMUNITY_EDGE_RETURN,
    EPISODIC_EDGE_RETURN,
    EPISODIC_EDGE_SAVE,
    get_community_edge_save_query,
    get_entity_edge_return_query,
    get_entity_edge_save_query,
)
from graphiti_core.nodes import Node

logger = logging.getLogger(__name__)


class Edge(BaseModel, ABC):
    uuid: str = Field(default_factory=lambda: str(uuid4()))
    group_id: str = Field(description='partition of the graph')
    source_node_uuid: str
    target_node_uuid: str
    created_at: datetime

    @abstractmethod
    async def save(self, driver: GraphDriver): ...

    async def delete(self, driver: GraphDriver):
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.edge_delete(self, driver)

        if driver.provider == GraphProvider.KUZU:
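            # Kuzu represents entity edges as intermediate RelatesToNode_ nodes, so both
            # the relationship (MENTIONS/HAS_MEMBER) and the edge node itself are removed.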
            await driver.execute_query(
                """
                MATCH (n)-[e:MENTIONS|HAS_MEMBER {uuid: $uuid}]->(m)
                DELETE e
                """,
                uuid=self.uuid,
            )
            await driver.execute_query(
                """
                MATCH (e:RelatesToNode_ {uuid: $uuid})
                DETACH DELETE e
                """,
                uuid=self.uuid,
            )
        else:
            await driver.execute_query(
                """
                MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
                DELETE e
                """,
                uuid=self.uuid,
            )

        logger.debug(f'Deleted Edge: {self.uuid}')

    @classmethod
    async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.edge_delete_by_uuids(cls, driver, uuids)

        if driver.provider == GraphProvider.KUZU:
            await driver.execute_query(
                """
                MATCH (n)-[e:MENTIONS|HAS_MEMBER]->(m)
                WHERE e.uuid IN $uuids
                DELETE e
                """,
                uuids=uuids,
            )
            await driver.execute_query(
                """
                MATCH (e:RelatesToNode_)
                WHERE e.uuid IN $uuids
                DETACH DELETE e
                """,
                uuids=uuids,
            )
        else:
            await driver.execute_query(
                """
                MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
                WHERE e.uuid IN $uuids
                DELETE e
                """,
                uuids=uuids,
            )

        logger.debug(f'Deleted Edges: {uuids}')

    def __hash__(self):
        return hash(self.uuid)

    def __eq__(self, other):
        if isinstance(other, Node):
            return self.uuid == other.uuid
        return False

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...


class EpisodicEdge(Edge):
    async def save(self, driver: GraphDriver):
        result = await driver.execute_query(
            EPISODIC_EDGE_SAVE,
            episode_uuid=self.source_node_uuid,
            entity_uuid=self.target_node_uuid,
            uuid=self.uuid,
            group_id=self.group_id,
            created_at=self.created_at,
        )

        logger.debug(f'Saved edge to Graph: {self.uuid}')

        return result

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
            RETURN
            """
            + EPISODIC_EDGE_RETURN,
            uuid=uuid,
            routing_='r',
        )

        edges = [get_episodic_edge_from_record(record) for record in records]

        if len(edges) == 0:
            raise EdgeNotFoundError(uuid)
        return edges[0]

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
            WHERE e.uuid IN $uuids
            RETURN
            """
            + EPISODIC_EDGE_RETURN,
            uuids=uuids,
            routing_='r',
        )

        edges = [get_episodic_edge_from_record(record) for record in records]

        if len(edges) == 0:
            raise EdgeNotFoundError(uuids[0])
        return edges

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
    ):
        cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''

        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
            WHERE e.group_id IN $group_ids
            """
            + cursor_query
            + """
            RETURN
            """
            + EPISODIC_EDGE_RETURN
            + """
            ORDER BY e.uuid DESC
            """
            + limit_query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )

        edges = [get_episodic_edge_from_record(record) for record in records]

        if len(edges) == 0:
            raise GroupsEdgesNotFoundError(group_ids)
        return edges


class EntityEdge(Edge):
    name: str = Field(description='name of the edge, relation name')
    fact: str = Field(description='fact representing the edge and nodes that it connects')
    fact_embedding: list[float] | None = Field(default=None, description='embedding of the fact')
    episodes: list[str] = Field(
        default=[],
        description='list of episode ids that reference these entity edges',
    )
    expired_at: datetime | None = Field(
        default=None, description='datetime of when the node was invalidated'
    )
    valid_at: datetime | None = Field(
        default=None, description='datetime of when the fact became true'
    )
    invalid_at: datetime | None = Field(
        default=None, description='datetime of when the fact stopped being true'
    )
    attributes: dict[str, Any] = Field(
        default={}, description='Additional attributes of the edge. Dependent on edge name'
    )

    async def generate_embedding(self, embedder: EmbedderClient):
        start = time()

        text = self.fact.replace('\n', ' ')
        self.fact_embedding = await embedder.create(input_data=[text])

        end = time()
        logger.debug(f'embedded {text} in {(end - start) * 1000} ms')

        return self.fact_embedding

    async def load_fact_embedding(self, driver: GraphDriver):
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.edge_load_embeddings(self, driver)

        query = """
            MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
            RETURN e.fact_embedding AS fact_embedding
        """

        if driver.provider == GraphProvider.NEPTUNE:
            query = """
                MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
                RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
            """

        if driver.provider == GraphProvider.KUZU:
            query = """
                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
                RETURN e.fact_embedding AS fact_embedding
            """

        records, _, _ = await driver.execute_query(
            query,
            uuid=self.uuid,
            routing_='r',
        )

        if len(records) == 0:
            raise EdgeNotFoundError(self.uuid)

        self.fact_embedding = records[0]['fact_embedding']

    async def save(self, driver: GraphDriver):
        edge_data: dict[str, Any] = {
            'source_uuid': self.source_node_uuid,
            'target_uuid': self.target_node_uuid,
            'uuid': self.uuid,
            'name': self.name,
            'group_id': self.group_id,
            'fact': self.fact,
            'fact_embedding': self.fact_embedding,
            'episodes': self.episodes,
            'created_at': self.created_at,
            'expired_at': self.expired_at,
            'valid_at': self.valid_at,
            'invalid_at': self.invalid_at,
        }

        if driver.provider == GraphProvider.KUZU:
            edge_data['attributes'] = json.dumps(self.attributes)
            result = await driver.execute_query(
                get_entity_edge_save_query(driver.provider),
                **edge_data,
            )
        else:
            edge_data.update(self.attributes or {})
            result = await driver.execute_query(
                get_entity_edge_save_query(driver.provider),
                edge_data=edge_data,
            )

        logger.debug(f'Saved edge to Graph: {self.uuid}')

        return result

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        match_query = """
            MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
        """
        if driver.provider == GraphProvider.KUZU:
            match_query = """
                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
            """

        records, _, _ = await driver.execute_query(
            match_query
            + """
            RETURN
            """
            + get_entity_edge_return_query(driver.provider),
            uuid=uuid,
            routing_='r',
        )

        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

        if len(edges) == 0:
            raise EdgeNotFoundError(uuid)
        return edges[0]

    @classmethod
    async def get_between_nodes(
        cls, driver: GraphDriver, source_node_uuid: str, target_node_uuid: str
    ):
        match_query = """
            MATCH (n:Entity {uuid: $source_node_uuid})-[e:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
        """
        if driver.provider == GraphProvider.KUZU:
            match_query = """
                MATCH (n:Entity {uuid: $source_node_uuid})
                      -[:RELATES_TO]->(e:RelatesToNode_)
                      -[:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
            """

        records, _, _ = await driver.execute_query(
            match_query
            + """
            RETURN
            """
            + get_entity_edge_return_query(driver.provider),
            source_node_uuid=source_node_uuid,
            target_node_uuid=target_node_uuid,
            routing_='r',
        )

        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

        return edges

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        if len(uuids) == 0:
            return []

        match_query = """
            MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
        """
        if driver.provider == GraphProvider.KUZU:
            match_query = """
                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
            """

        records, _, _ = await driver.execute_query(
            match_query
            + """
            WHERE e.uuid IN $uuids
            RETURN
            """
            + get_entity_edge_return_query(driver.provider),
            uuids=uuids,
            routing_='r',
        )

        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

        return edges

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
        with_embeddings: bool = False,
    ):
        cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
        with_embeddings_query: LiteralString = (
            """,
                e.fact_embedding AS fact_embedding
                """
            if with_embeddings
            else ''
        )

        match_query = """
            MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
        """
        if driver.provider == GraphProvider.KUZU:
            match_query = """
                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
            """

        records, _, _ = await driver.execute_query(
            match_query
            + """
            WHERE e.group_id IN $group_ids
            """
            + cursor_query
            + """
            RETURN
            """
            + get_entity_edge_return_query(driver.provider)
            + with_embeddings_query
            + """
            ORDER BY e.uuid DESC
            """
            + limit_query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )

        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

        if len(edges) == 0:
            raise GroupsEdgesNotFoundError(group_ids)
        return edges

    @classmethod
    async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
        match_query = """
            MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
        """
        if driver.provider == GraphProvider.KUZU:
            match_query = """
                MATCH (n:Entity {uuid: $node_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
            """

        records, _, _ = await driver.execute_query(
            match_query
            + """
            RETURN
            """
            + get_entity_edge_return_query(driver.provider),
            node_uuid=node_uuid,
            routing_='r',
        )

        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

        return edges


class CommunityEdge(Edge):
    async def save(self, driver: GraphDriver):
        result = await driver.execute_query(
            get_community_edge_save_query(driver.provider),
            community_uuid=self.source_node_uuid,
            entity_uuid=self.target_node_uuid,
            uuid=self.uuid,
            group_id=self.group_id,
            created_at=self.created_at,
        )

        logger.debug(f'Saved edge to Graph: {self.uuid}')

        return result

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m)
            RETURN
            """
            + COMMUNITY_EDGE_RETURN,
            uuid=uuid,
            routing_='r',
        )

        edges = [get_community_edge_from_record(record) for record in records]

        if len(edges) == 0:
            raise EdgeNotFoundError(uuid)
        return edges[0]

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Community)-[e:HAS_MEMBER]->(m)
            WHERE e.uuid IN $uuids
            RETURN
            """
            + COMMUNITY_EDGE_RETURN,
            uuids=uuids,
            routing_='r',
        )

        edges = [get_community_edge_from_record(record) for record in records]

        return edges

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
    ):
        cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''

        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Community)-[e:HAS_MEMBER]->(m)
            WHERE e.group_id IN $group_ids
            """
            + cursor_query
            + """
            RETURN
            """
            + COMMUNITY_EDGE_RETURN
            + """
            ORDER BY e.uuid DESC
            """
            + limit_query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )

        edges = [get_community_edge_from_record(record) for record in records]

        return edges


# Edge helpers
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
    return EpisodicEdge(
        uuid=record['uuid'],
        group_id=record['group_id'],
        source_node_uuid=record['source_node_uuid'],
        target_node_uuid=record['target_node_uuid'],
        created_at=parse_db_date(record['created_at']),  # type: ignore
    )


def get_entity_edge_from_record(record: Any, provider: GraphProvider) -> EntityEdge:
    episodes = record['episodes']
    if provider == GraphProvider.KUZU:
        attributes = json.loads(record['attributes']) if record['attributes'] else {}
    else:
        attributes = record['attributes']
        attributes.pop('uuid', None)
        attributes.pop('source_node_uuid', None)
        attributes.pop('target_node_uuid', None)
        attributes.pop('fact', None)
        attributes.pop('fact_embedding', None)
        attributes.pop('name', None)
        attributes.pop('group_id', None)
        attributes.pop('episodes', None)
        attributes.pop('created_at', None)
        attributes.pop('expired_at', None)
        attributes.pop('valid_at', None)
        attributes.pop('invalid_at', None)

    edge = EntityEdge(
        uuid=record['uuid'],
        source_node_uuid=record['source_node_uuid'],
        target_node_uuid=record['target_node_uuid'],
        fact=record['fact'],
        fact_embedding=record.get('fact_embedding'),
        name=record['name'],
        group_id=record['group_id'],
        episodes=episodes,
        created_at=parse_db_date(record['created_at']),  # type: ignore
        expired_at=parse_db_date(record['expired_at']),
        valid_at=parse_db_date(record['valid_at']),
        invalid_at=parse_db_date(record['invalid_at']),
        attributes=attributes,
    )

    return edge


def get_community_edge_from_record(record: Any):
    return CommunityEdge(
        uuid=record['uuid'],
        group_id=record['group_id'],
        source_node_uuid=record['source_node_uuid'],
        target_node_uuid=record['target_node_uuid'],
        created_at=parse_db_date(record['created_at']),  # type: ignore
    )


async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
    # filter out edges with empty (falsy) facts
    filtered_edges = [edge for edge in edges if edge.fact]

    if len(filtered_edges) == 0:
        return
    fact_embeddings = await embedder.create_batch([edge.fact for edge in filtered_edges])
    for edge, fact_embedding in zip(filtered_edges, fact_embeddings, strict=True):
        edge.fact_embedding = fact_embedding

```
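
The `EntityEdge` helpers above keep embedding generation (`create_entity_edge_embeddings`, which batches facts through the embedder and skips edges with empty facts) separate from persistence (`EntityEdge.save`). The following is a minimal sketch of how the two might be combined; the import paths and the literal field values (`WORKS_AT`, the example fact, explicit `uuid`/`created_at`) are illustrative assumptions, not a canonical usage pattern from this repository.

```python
from datetime import datetime, timezone
from uuid import uuid4

from graphiti_core.driver.driver import GraphDriver
from graphiti_core.edges import EntityEdge, create_entity_edge_embeddings
from graphiti_core.embedder import EmbedderClient


async def save_fact_edge(
    driver: GraphDriver,
    embedder: EmbedderClient,
    source_uuid: str,
    target_uuid: str,
) -> EntityEdge:
    # Build the edge; uuid and created_at are passed explicitly for clarity.
    edge = EntityEdge(
        uuid=str(uuid4()),
        source_node_uuid=source_uuid,
        target_node_uuid=target_uuid,
        name='WORKS_AT',
        fact='Alice works at Acme Corp',
        group_id='group',
        created_at=datetime.now(timezone.utc),
    )

    # Batch-embed the fact (no-op for edges without facts), then persist.
    await create_entity_edge_embeddings(embedder, [edge])
    await edge.save(driver)
    return edge
```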

--------------------------------------------------------------------------------
/tests/utils/maintenance/test_node_operations.py:
--------------------------------------------------------------------------------

```python
import logging
from collections import defaultdict
from unittest.mock import AsyncMock, MagicMock

import pytest

from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search_config import SearchResults
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.dedup_helpers import (
    DedupCandidateIndexes,
    DedupResolutionState,
    _build_candidate_indexes,
    _cached_shingles,
    _has_high_entropy,
    _hash_shingle,
    _jaccard_similarity,
    _lsh_bands,
    _minhash_signature,
    _name_entropy,
    _normalize_name_for_fuzzy,
    _normalize_string_exact,
    _resolve_with_similarity,
    _shingles,
)
from graphiti_core.utils.maintenance.node_operations import (
    _collect_candidate_nodes,
    _resolve_with_llm,
    extract_attributes_from_node,
    extract_attributes_from_nodes,
    resolve_extracted_nodes,
)


def _make_clients():
    driver = MagicMock()
    embedder = MagicMock()
    cross_encoder = MagicMock()
    llm_client = MagicMock()
    llm_generate = AsyncMock()
    llm_client.generate_response = llm_generate

    clients = GraphitiClients.model_construct(  # bypass validation to allow test doubles
        driver=driver,
        embedder=embedder,
        cross_encoder=cross_encoder,
        llm_client=llm_client,
    )

    return clients, llm_generate


def _make_episode(group_id: str = 'group'):
    return EpisodicNode(
        name='episode',
        group_id=group_id,
        source=EpisodeType.message,
        source_description='test',
        content='content',
        valid_at=utc_now(),
    )


@pytest.mark.asyncio
async def test_resolve_nodes_exact_match_skips_llm(monkeypatch):
    clients, llm_generate = _make_clients()

    candidate = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])
    extracted = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])

    async def fake_search(*_, **__):
        return SearchResults(nodes=[candidate])

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.search',
        fake_search,
    )
    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges',
        AsyncMock(return_value=[]),
    )

    resolved, uuid_map, _ = await resolve_extracted_nodes(
        clients,
        [extracted],
        episode=_make_episode(),
        previous_episodes=[],
    )

    assert resolved[0].uuid == candidate.uuid
    assert uuid_map[extracted.uuid] == candidate.uuid
    llm_generate.assert_not_awaited()


@pytest.mark.asyncio
async def test_resolve_nodes_low_entropy_uses_llm(monkeypatch):
    clients, llm_generate = _make_clients()
    llm_generate.return_value = {
        'entity_resolutions': [
            {
                'id': 0,
                'duplicate_idx': -1,
                'name': 'Joe',
                'duplicates': [],
            }
        ]
    }

    extracted = EntityNode(name='Joe', group_id='group', labels=['Entity'])

    async def fake_search(*_, **__):
        return SearchResults(nodes=[])

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.search',
        fake_search,
    )
    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges',
        AsyncMock(return_value=[]),
    )

    resolved, uuid_map, _ = await resolve_extracted_nodes(
        clients,
        [extracted],
        episode=_make_episode(),
        previous_episodes=[],
    )

    assert resolved[0].uuid == extracted.uuid
    assert uuid_map[extracted.uuid] == extracted.uuid
    llm_generate.assert_awaited()


@pytest.mark.asyncio
async def test_resolve_nodes_fuzzy_match(monkeypatch):
    clients, llm_generate = _make_clients()

    candidate = EntityNode(name='Joe-Michaels', group_id='group', labels=['Entity'])
    extracted = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])

    async def fake_search(*_, **__):
        return SearchResults(nodes=[candidate])

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.search',
        fake_search,
    )
    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges',
        AsyncMock(return_value=[]),
    )

    resolved, uuid_map, _ = await resolve_extracted_nodes(
        clients,
        [extracted],
        episode=_make_episode(),
        previous_episodes=[],
    )

    assert resolved[0].uuid == candidate.uuid
    assert uuid_map[extracted.uuid] == candidate.uuid
    llm_generate.assert_not_awaited()


@pytest.mark.asyncio
async def test_collect_candidate_nodes_dedupes_and_merges_override(monkeypatch):
    clients, _ = _make_clients()

    candidate = EntityNode(name='Alice', group_id='group', labels=['Entity'])
    override_duplicate = EntityNode(
        uuid=candidate.uuid,
        name='Alice Alt',
        group_id='group',
        labels=['Entity'],
    )
    extracted = EntityNode(name='Alice', group_id='group', labels=['Entity'])

    search_mock = AsyncMock(return_value=SearchResults(nodes=[candidate]))
    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.search',
        search_mock,
    )

    result = await _collect_candidate_nodes(
        clients,
        [extracted],
        existing_nodes_override=[override_duplicate],
    )

    assert len(result) == 1
    assert result[0].uuid == candidate.uuid
    search_mock.assert_awaited()


def test_build_candidate_indexes_populates_structures():
    candidate = EntityNode(name='Bob Dylan', group_id='group', labels=['Entity'])

    indexes = _build_candidate_indexes([candidate])

    normalized_key = candidate.name.lower()
    assert indexes.normalized_existing[normalized_key][0].uuid == candidate.uuid
    assert indexes.nodes_by_uuid[candidate.uuid] is candidate
    assert candidate.uuid in indexes.shingles_by_candidate
    assert any(candidate.uuid in bucket for bucket in indexes.lsh_buckets.values())


def test_normalize_helpers():
    assert _normalize_string_exact('  Alice   Smith ') == 'alice smith'
    assert _normalize_name_for_fuzzy('Alice-Smith!') == 'alice smith'


def test_name_entropy_variants():
    assert _name_entropy('alice') > _name_entropy('aaaaa')
    assert _name_entropy('') == 0.0


def test_has_high_entropy_rules():
    assert _has_high_entropy('meaningful name') is True
    assert _has_high_entropy('aa') is False


def test_shingles_and_cache():
    raw = 'alice'
    shingle_set = _shingles(raw)
    assert shingle_set == {'ali', 'lic', 'ice'}
    assert _cached_shingles(raw) == shingle_set
    assert _cached_shingles(raw) is _cached_shingles(raw)


def test_hash_minhash_and_lsh():
    shingles = {'abc', 'bcd', 'cde'}
    signature = _minhash_signature(shingles)
    assert len(signature) == 32
    bands = _lsh_bands(signature)
    assert all(len(band) == 4 for band in bands)
    hashed = {_hash_shingle(s, 0) for s in shingles}
    assert len(hashed) == len(shingles)


def test_jaccard_similarity_edges():
    a = {'a', 'b'}
    b = {'a', 'c'}
    assert _jaccard_similarity(a, b) == pytest.approx(1 / 3)
    assert _jaccard_similarity(set(), set()) == 1.0
    assert _jaccard_similarity(a, set()) == 0.0


def test_resolve_with_similarity_exact_match_updates_state():
    candidate = EntityNode(name='Charlie Parker', group_id='group', labels=['Entity'])
    extracted = EntityNode(name='Charlie Parker', group_id='group', labels=['Entity'])

    indexes = _build_candidate_indexes([candidate])
    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])

    _resolve_with_similarity([extracted], indexes, state)

    assert state.resolved_nodes[0].uuid == candidate.uuid
    assert state.uuid_map[extracted.uuid] == candidate.uuid
    assert state.unresolved_indices == []
    assert state.duplicate_pairs == [(extracted, candidate)]


def test_resolve_with_similarity_low_entropy_defers_resolution():
    extracted = EntityNode(name='Bob', group_id='group', labels=['Entity'])
    indexes = DedupCandidateIndexes(
        existing_nodes=[],
        nodes_by_uuid={},
        normalized_existing=defaultdict(list),
        shingles_by_candidate={},
        lsh_buckets=defaultdict(list),
    )
    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])

    _resolve_with_similarity([extracted], indexes, state)

    assert state.resolved_nodes[0] is None
    assert state.unresolved_indices == [0]
    assert state.duplicate_pairs == []


def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():
    candidate1 = EntityNode(name='Johnny Appleseed', group_id='group', labels=['Entity'])
    candidate2 = EntityNode(name='Johnny Appleseed', group_id='group', labels=['Entity'])
    extracted = EntityNode(name='Johnny Appleseed', group_id='group', labels=['Entity'])

    indexes = _build_candidate_indexes([candidate1, candidate2])
    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])

    _resolve_with_similarity([extracted], indexes, state)

    assert state.resolved_nodes[0] is None
    assert state.unresolved_indices == [0]
    assert state.duplicate_pairs == []


@pytest.mark.asyncio
async def test_resolve_with_llm_updates_unresolved(monkeypatch):
    extracted = EntityNode(name='Dizzy', group_id='group', labels=['Entity'])
    candidate = EntityNode(name='Dizzy Gillespie', group_id='group', labels=['Entity'])

    indexes = _build_candidate_indexes([candidate])
    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])

    captured_context = {}

    def fake_prompt_nodes(context):
        captured_context.update(context)
        return ['prompt']

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
        fake_prompt_nodes,
    )

    async def fake_generate_response(*_, **__):
        return {
            'entity_resolutions': [
                {
                    'id': 0,
                    'duplicate_idx': 0,
                    'name': 'Dizzy Gillespie',
                    'duplicates': [0],
                }
            ]
        }

    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(side_effect=fake_generate_response)

    await _resolve_with_llm(
        llm_client,
        [extracted],
        indexes,
        state,
        episode=_make_episode(),
        previous_episodes=[],
        entity_types=None,
    )

    assert state.resolved_nodes[0].uuid == candidate.uuid
    assert state.uuid_map[extracted.uuid] == candidate.uuid
    assert captured_context['existing_nodes'][0]['idx'] == 0
    assert isinstance(captured_context['existing_nodes'], list)
    assert state.duplicate_pairs == [(extracted, candidate)]


@pytest.mark.asyncio
async def test_resolve_with_llm_ignores_out_of_range_relative_ids(monkeypatch, caplog):
    extracted = EntityNode(name='Dexter', group_id='group', labels=['Entity'])

    indexes = _build_candidate_indexes([])
    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
        lambda context: ['prompt'],
    )

    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(
        return_value={
            'entity_resolutions': [
                {
                    'id': 5,
                    'duplicate_idx': -1,
                    'name': 'Dexter',
                    'duplicates': [],
                }
            ]
        }
    )

    with caplog.at_level(logging.WARNING):
        await _resolve_with_llm(
            llm_client,
            [extracted],
            indexes,
            state,
            episode=_make_episode(),
            previous_episodes=[],
            entity_types=None,
        )

    assert state.resolved_nodes[0] is None
    assert 'Skipping invalid LLM dedupe id 5' in caplog.text


@pytest.mark.asyncio
async def test_resolve_with_llm_ignores_duplicate_relative_ids(monkeypatch):
    extracted = EntityNode(name='Dizzy', group_id='group', labels=['Entity'])
    candidate = EntityNode(name='Dizzy Gillespie', group_id='group', labels=['Entity'])

    indexes = _build_candidate_indexes([candidate])
    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
        lambda context: ['prompt'],
    )

    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(
        return_value={
            'entity_resolutions': [
                {
                    'id': 0,
                    'duplicate_idx': 0,
                    'name': 'Dizzy Gillespie',
                    'duplicates': [0],
                },
                {
                    'id': 0,
                    'duplicate_idx': -1,
                    'name': 'Dizzy',
                    'duplicates': [],
                },
            ]
        }
    )

    await _resolve_with_llm(
        llm_client,
        [extracted],
        indexes,
        state,
        episode=_make_episode(),
        previous_episodes=[],
        entity_types=None,
    )

    assert state.resolved_nodes[0].uuid == candidate.uuid
    assert state.uuid_map[extracted.uuid] == candidate.uuid
    assert state.duplicate_pairs == [(extracted, candidate)]


@pytest.mark.asyncio
async def test_resolve_with_llm_invalid_duplicate_idx_defaults_to_extracted(monkeypatch):
    extracted = EntityNode(name='Dexter', group_id='group', labels=['Entity'])

    indexes = _build_candidate_indexes([])
    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
        lambda context: ['prompt'],
    )

    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(
        return_value={
            'entity_resolutions': [
                {
                    'id': 0,
                    'duplicate_idx': 10,
                    'name': 'Dexter',
                    'duplicates': [],
                }
            ]
        }
    )

    await _resolve_with_llm(
        llm_client,
        [extracted],
        indexes,
        state,
        episode=_make_episode(),
        previous_episodes=[],
        entity_types=None,
    )

    assert state.resolved_nodes[0] == extracted
    assert state.uuid_map[extracted.uuid] == extracted.uuid
    assert state.duplicate_pairs == []


@pytest.mark.asyncio
async def test_extract_attributes_without_callback_generates_summary():
    """Test that summary is generated when no callback is provided (default behavior)."""
    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(
        return_value={'summary': 'Generated summary', 'attributes': {}}
    )

    node = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')
    episode = _make_episode()

    result = await extract_attributes_from_node(
        llm_client,
        node,
        episode=episode,
        previous_episodes=[],
        entity_type=None,
        should_summarize_node=None,  # No callback provided
    )

    # Summary should be generated
    assert result.summary == 'Generated summary'
    # LLM should have been called for summary
    assert llm_client.generate_response.call_count == 1


@pytest.mark.asyncio
async def test_extract_attributes_with_callback_skip_summary():
    """Test that summary is NOT regenerated when callback returns False."""
    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(
        return_value={'summary': 'This should not be used', 'attributes': {}}
    )

    node = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')
    episode = _make_episode()

    # Callback that always returns False (skip summary generation)
    async def skip_summary_filter(node: EntityNode) -> bool:
        return False

    result = await extract_attributes_from_node(
        llm_client,
        node,
        episode=episode,
        previous_episodes=[],
        entity_type=None,
        should_summarize_node=skip_summary_filter,
    )

    # Summary should remain unchanged
    assert result.summary == 'Old summary'
    # LLM should NOT have been called for summary
    assert llm_client.generate_response.call_count == 0


@pytest.mark.asyncio
async def test_extract_attributes_with_callback_generate_summary():
    """Test that summary is regenerated when callback returns True."""
    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(
        return_value={'summary': 'New generated summary', 'attributes': {}}
    )

    node = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')
    episode = _make_episode()

    # Callback that always returns True (generate summary)
    async def generate_summary_filter(node: EntityNode) -> bool:
        return True

    result = await extract_attributes_from_node(
        llm_client,
        node,
        episode=episode,
        previous_episodes=[],
        entity_type=None,
        should_summarize_node=generate_summary_filter,
    )

    # Summary should be updated
    assert result.summary == 'New generated summary'
    # LLM should have been called for summary
    assert llm_client.generate_response.call_count == 1


@pytest.mark.asyncio
async def test_extract_attributes_with_selective_callback():
    """Test callback that selectively skips summaries based on node properties."""
    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(
        return_value={'summary': 'Generated summary', 'attributes': {}}
    )

    user_node = EntityNode(name='User', group_id='group', labels=['Entity', 'User'], summary='Old')
    topic_node = EntityNode(
        name='Topic', group_id='group', labels=['Entity', 'Topic'], summary='Old'
    )

    episode = _make_episode()

    # Callback that skips User nodes but generates for others
    async def selective_filter(node: EntityNode) -> bool:
        return 'User' not in node.labels

    result_user = await extract_attributes_from_node(
        llm_client,
        user_node,
        episode=episode,
        previous_episodes=[],
        entity_type=None,
        should_summarize_node=selective_filter,
    )

    result_topic = await extract_attributes_from_node(
        llm_client,
        topic_node,
        episode=episode,
        previous_episodes=[],
        entity_type=None,
        should_summarize_node=selective_filter,
    )

    # User summary should remain unchanged
    assert result_user.summary == 'Old'
    # Topic summary should be generated
    assert result_topic.summary == 'Generated summary'
    # LLM should have been called only once (for topic)
    assert llm_client.generate_response.call_count == 1


@pytest.mark.asyncio
async def test_extract_attributes_from_nodes_with_callback():
    """Test that callback is properly passed through extract_attributes_from_nodes."""
    clients, _ = _make_clients()
    clients.llm_client.generate_response = AsyncMock(
        return_value={'summary': 'New summary', 'attributes': {}}
    )
    clients.embedder.create = AsyncMock(return_value=[0.1, 0.2, 0.3])
    clients.embedder.create_batch = AsyncMock(return_value=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])

    node1 = EntityNode(name='Node1', group_id='group', labels=['Entity', 'User'], summary='Old1')
    node2 = EntityNode(name='Node2', group_id='group', labels=['Entity', 'Topic'], summary='Old2')

    episode = _make_episode()

    call_tracker = []

    # Callback that tracks which nodes it's called with
    async def tracking_filter(node: EntityNode) -> bool:
        call_tracker.append(node.name)
        return 'User' not in node.labels

    results = await extract_attributes_from_nodes(
        clients,
        [node1, node2],
        episode=episode,
        previous_episodes=[],
        entity_types=None,
        should_summarize_node=tracking_filter,
    )

    # Callback should have been called for both nodes
    assert len(call_tracker) == 2
    assert 'Node1' in call_tracker
    assert 'Node2' in call_tracker

    # Node1 (User) should keep old summary, Node2 (Topic) should get new summary
    node1_result = next(n for n in results if n.name == 'Node1')
    node2_result = next(n for n in results if n.name == 'Node2')

    assert node1_result.summary == 'Old1'
    assert node2_result.summary == 'New summary'

```
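
The tests above exercise `_build_candidate_indexes` and `_resolve_with_similarity` one node at a time. A compact sketch of the full similarity pass (index existing candidates once, resolve extracted nodes against them, leave ambiguous ones in `unresolved_indices` for the LLM pass) is shown below; it uses only the signatures visible in the tests, and the one-slot-per-extracted-node layout of `resolved_nodes` is inferred from them.

```python
from graphiti_core.nodes import EntityNode
from graphiti_core.utils.maintenance.dedup_helpers import (
    DedupResolutionState,
    _build_candidate_indexes,
    _resolve_with_similarity,
)


def similarity_pass(
    extracted: list[EntityNode], existing: list[EntityNode]
) -> DedupResolutionState:
    # Index existing candidates once: exact-name map, shingles, LSH buckets.
    indexes = _build_candidate_indexes(existing)

    # One slot per extracted node; anything still unresolved afterwards would
    # be handed to _resolve_with_llm in the real pipeline.
    state = DedupResolutionState(
        resolved_nodes=[None] * len(extracted),
        uuid_map={},
        unresolved_indices=[],
    )
    _resolve_with_similarity(extracted, indexes, state)
    return state
```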

--------------------------------------------------------------------------------
/tests/llm_client/test_gemini_client.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Running tests: pytest -xvs tests/llm_client/test_gemini_client.py

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from pydantic import BaseModel

from graphiti_core.llm_client.config import LLMConfig, ModelSize
from graphiti_core.llm_client.errors import RateLimitError
from graphiti_core.llm_client.gemini_client import DEFAULT_MODEL, DEFAULT_SMALL_MODEL, GeminiClient
from graphiti_core.prompts.models import Message


# Test model for response testing
class ResponseModel(BaseModel):
    """Test model for response testing."""

    test_field: str
    optional_field: int = 0


@pytest.fixture
def mock_gemini_client():
    """Fixture to mock the Google Gemini client."""
    with patch('google.genai.Client') as mock_client:
        # Setup mock instance and its methods
        mock_instance = mock_client.return_value
        mock_instance.aio = MagicMock()
        mock_instance.aio.models = MagicMock()
        mock_instance.aio.models.generate_content = AsyncMock()
        yield mock_instance


@pytest.fixture
def gemini_client(mock_gemini_client):
    """Fixture to create a GeminiClient with a mocked client."""
    config = LLMConfig(api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000)
    client = GeminiClient(config=config, cache=False)
    # Replace the client's client with our mock to ensure we're using the mock
    client.client = mock_gemini_client
    return client


class TestGeminiClientInitialization:
    """Tests for GeminiClient initialization."""

    @patch('google.genai.Client')
    def test_init_with_config(self, mock_client):
        """Test initialization with a config object."""
        config = LLMConfig(
            api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
        )
        client = GeminiClient(config=config, cache=False, max_tokens=1000)

        assert client.config == config
        assert client.model == 'test-model'
        assert client.temperature == 0.5
        assert client.max_tokens == 1000

    @patch('google.genai.Client')
    def test_init_with_default_model(self, mock_client):
        """Test initialization with default model when none is provided."""
        config = LLMConfig(api_key='test_api_key', model=DEFAULT_MODEL)
        client = GeminiClient(config=config, cache=False)

        assert client.model == DEFAULT_MODEL

    @patch('google.genai.Client')
    def test_init_without_config(self, mock_client):
        """Test initialization without a config uses defaults."""
        client = GeminiClient(cache=False)

        assert client.config is not None
        # When no config.model is set, it will be None, not DEFAULT_MODEL
        assert client.model is None

    @patch('google.genai.Client')
    def test_init_with_thinking_config(self, mock_client):
        """Test initialization with thinking config."""
        with patch('google.genai.types.ThinkingConfig') as mock_thinking_config:
            thinking_config = mock_thinking_config.return_value
            client = GeminiClient(thinking_config=thinking_config)
            assert client.thinking_config == thinking_config


class TestGeminiClientGenerateResponse:
    """Tests for GeminiClient generate_response method."""

    @pytest.mark.asyncio
    async def test_generate_response_simple_text(self, gemini_client, mock_gemini_client):
        """Test successful response generation with simple text."""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.text = 'Test response text'
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method
        messages = [Message(role='user', content='Test message')]
        result = await gemini_client.generate_response(messages)

        # Assertions
        assert isinstance(result, dict)
        assert result['content'] == 'Test response text'
        mock_gemini_client.aio.models.generate_content.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_response_with_structured_output(
        self, gemini_client, mock_gemini_client
    ):
        """Test response generation with structured output."""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.text = '{"test_field": "test_value", "optional_field": 42}'
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method
        messages = [
            Message(role='system', content='System message'),
            Message(role='user', content='User message'),
        ]
        result = await gemini_client.generate_response(
            messages=messages, response_model=ResponseModel
        )

        # Assertions
        assert isinstance(result, dict)
        assert result['test_field'] == 'test_value'
        assert result['optional_field'] == 42
        mock_gemini_client.aio.models.generate_content.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_response_with_system_message(self, gemini_client, mock_gemini_client):
        """Test response generation with system message handling."""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.text = 'Response with system context'
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method
        messages = [
            Message(role='system', content='System message'),
            Message(role='user', content='User message'),
        ]
        await gemini_client.generate_response(messages)

        # Verify system message is processed correctly
        call_args = mock_gemini_client.aio.models.generate_content.call_args
        config = call_args[1]['config']
        assert 'System message' in config.system_instruction

    @pytest.mark.asyncio
    async def test_get_model_for_size(self, gemini_client):
        """Test model selection based on size."""
        # Test small model
        small_model = gemini_client._get_model_for_size(ModelSize.small)
        assert small_model == DEFAULT_SMALL_MODEL

        # Test medium/large model
        medium_model = gemini_client._get_model_for_size(ModelSize.medium)
        assert medium_model == gemini_client.model

    @pytest.mark.asyncio
    async def test_rate_limit_error_handling(self, gemini_client, mock_gemini_client):
        """Test handling of rate limit errors."""
        # Setup mock to raise rate limit error
        mock_gemini_client.aio.models.generate_content.side_effect = Exception(
            'Rate limit exceeded'
        )

        # Call method and check exception
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(RateLimitError):
            await gemini_client.generate_response(messages)

    @pytest.mark.asyncio
    async def test_quota_error_handling(self, gemini_client, mock_gemini_client):
        """Test handling of quota errors."""
        # Setup mock to raise quota error
        mock_gemini_client.aio.models.generate_content.side_effect = Exception(
            'Quota exceeded for requests'
        )

        # Call method and check exception
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(RateLimitError):
            await gemini_client.generate_response(messages)

    @pytest.mark.asyncio
    async def test_resource_exhausted_error_handling(self, gemini_client, mock_gemini_client):
        """Test handling of resource exhausted errors."""
        # Setup mock to raise resource exhausted error
        mock_gemini_client.aio.models.generate_content.side_effect = Exception(
            'resource_exhausted: Request limit exceeded'
        )

        # Call method and check exception
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(RateLimitError):
            await gemini_client.generate_response(messages)

    @pytest.mark.asyncio
    async def test_safety_block_handling(self, gemini_client, mock_gemini_client):
        """Test handling of safety blocks."""
        # Setup mock response with safety block
        mock_candidate = MagicMock()
        mock_candidate.finish_reason = 'SAFETY'
        mock_candidate.safety_ratings = [
            MagicMock(blocked=True, category='HARM_CATEGORY_HARASSMENT', probability='HIGH')
        ]

        mock_response = MagicMock()
        mock_response.candidates = [mock_candidate]
        mock_response.prompt_feedback = None
        mock_response.text = ''
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method and check exception
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(Exception, match='Content blocked by safety filters'):
            await gemini_client.generate_response(messages)

    @pytest.mark.asyncio
    async def test_prompt_block_handling(self, gemini_client, mock_gemini_client):
        """Test handling of prompt blocks."""
        # Setup mock response with prompt block
        mock_prompt_feedback = MagicMock()
        mock_prompt_feedback.block_reason = 'BLOCKED_REASON_OTHER'

        mock_response = MagicMock()
        mock_response.candidates = []
        mock_response.prompt_feedback = mock_prompt_feedback
        mock_response.text = ''
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method and check exception
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(Exception, match='Content blocked by safety filters'):
            await gemini_client.generate_response(messages)

    @pytest.mark.asyncio
    async def test_structured_output_parsing_error(self, gemini_client, mock_gemini_client):
        """Test handling of structured output parsing errors."""
        # Setup mock response with invalid JSON that will exhaust retries
        mock_response = MagicMock()
        mock_response.text = 'Invalid JSON that cannot be parsed'
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method and check exception - should exhaust retries
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(Exception):  # noqa: B017
            await gemini_client.generate_response(messages, response_model=ResponseModel)

        # Should have called generate_content MAX_RETRIES times (2 attempts total)
        assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES

    @pytest.mark.asyncio
    async def test_retry_logic_with_safety_block(self, gemini_client, mock_gemini_client):
        """Test that safety blocks are not retried."""
        # Setup mock response with safety block
        mock_candidate = MagicMock()
        mock_candidate.finish_reason = 'SAFETY'
        mock_candidate.safety_ratings = [
            MagicMock(blocked=True, category='HARM_CATEGORY_HARASSMENT', probability='HIGH')
        ]

        mock_response = MagicMock()
        mock_response.candidates = [mock_candidate]
        mock_response.prompt_feedback = None
        mock_response.text = ''
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method and check that it doesn't retry
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(Exception, match='Content blocked by safety filters'):
            await gemini_client.generate_response(messages)

        # Should only be called once (no retries for safety blocks)
        assert mock_gemini_client.aio.models.generate_content.call_count == 1

    @pytest.mark.asyncio
    async def test_retry_logic_with_validation_error(self, gemini_client, mock_gemini_client):
        """Test retry behavior on validation error."""
        # First call returns invalid JSON, second call returns valid data
        mock_response1 = MagicMock()
        mock_response1.text = 'Invalid JSON that cannot be parsed'
        mock_response1.candidates = []
        mock_response1.prompt_feedback = None

        mock_response2 = MagicMock()
        mock_response2.text = '{"test_field": "correct_value"}'
        mock_response2.candidates = []
        mock_response2.prompt_feedback = None

        mock_gemini_client.aio.models.generate_content.side_effect = [
            mock_response1,
            mock_response2,
        ]

        # Call method
        messages = [Message(role='user', content='Test message')]
        result = await gemini_client.generate_response(messages, response_model=ResponseModel)

        # Should have called generate_content twice due to retry
        assert mock_gemini_client.aio.models.generate_content.call_count == 2
        assert result['test_field'] == 'correct_value'

    @pytest.mark.asyncio
    async def test_max_retries_exceeded(self, gemini_client, mock_gemini_client):
        """Test behavior when max retries are exceeded."""
        # Setup mock to always return invalid JSON
        mock_response = MagicMock()
        mock_response.text = 'Invalid JSON that cannot be parsed'
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method and check exception
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(Exception):  # noqa: B017
            await gemini_client.generate_response(messages, response_model=ResponseModel)

        # Should have called generate_content MAX_RETRIES times (2 attempts total)
        assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES

    @pytest.mark.asyncio
    async def test_empty_response_handling(self, gemini_client, mock_gemini_client):
        """Test handling of empty responses."""
        # Setup mock response with no text
        mock_response = MagicMock()
        mock_response.text = ''
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method with structured output and check exception
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(Exception):  # noqa: B017
            await gemini_client.generate_response(messages, response_model=ResponseModel)

        # Should have exhausted retries due to empty response (2 attempts total)
        assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES

    @pytest.mark.asyncio
    async def test_custom_max_tokens(self, gemini_client, mock_gemini_client):
        """Test that explicit max_tokens parameter takes precedence over all other values."""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.text = 'Test response'
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method with custom max tokens (should take precedence)
        messages = [Message(role='user', content='Test message')]
        await gemini_client.generate_response(messages, max_tokens=500)

        # Verify explicit max_tokens parameter takes precedence
        call_args = mock_gemini_client.aio.models.generate_content.call_args
        config = call_args[1]['config']
        # Explicit parameter should override everything else
        assert config.max_output_tokens == 500

    @pytest.mark.asyncio
    async def test_max_tokens_precedence_fallback(self, mock_gemini_client):
        """Test max_tokens precedence when no explicit parameter is provided."""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.text = 'Test response'
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Test case 1: No explicit max_tokens, has instance max_tokens
        config = LLMConfig(
            api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
        )
        client = GeminiClient(
            config=config, cache=False, max_tokens=2000, client=mock_gemini_client
        )

        messages = [Message(role='user', content='Test message')]
        await client.generate_response(messages)

        call_args = mock_gemini_client.aio.models.generate_content.call_args
        config = call_args[1]['config']
        # Instance max_tokens should be used
        assert config.max_output_tokens == 2000

        # Test case 2: No explicit max_tokens, no instance max_tokens, uses model mapping
        config = LLMConfig(api_key='test_api_key', model='gemini-2.5-flash', temperature=0.5)
        client = GeminiClient(config=config, cache=False, client=mock_gemini_client)

        messages = [Message(role='user', content='Test message')]
        await client.generate_response(messages)

        call_args = mock_gemini_client.aio.models.generate_content.call_args
        config = call_args[1]['config']
        # Model mapping should be used
        assert config.max_output_tokens == 65536

    @pytest.mark.asyncio
    async def test_model_size_selection(self, gemini_client, mock_gemini_client):
        """Test that the correct model is selected based on model size."""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.text = 'Test response'
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method with small model size
        messages = [Message(role='user', content='Test message')]
        await gemini_client.generate_response(messages, model_size=ModelSize.small)

        # Verify correct model is used
        call_args = mock_gemini_client.aio.models.generate_content.call_args
        assert call_args[1]['model'] == DEFAULT_SMALL_MODEL

    @pytest.mark.asyncio
    async def test_gemini_model_max_tokens_mapping(self, mock_gemini_client):
        """Test that different Gemini models use their correct max tokens."""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.text = 'Test response'
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Test data: (model_name, expected_max_tokens)
        test_cases = [
            ('gemini-2.5-flash', 65536),
            ('gemini-2.5-pro', 65536),
            ('gemini-2.5-flash-lite', 64000),
            ('gemini-2.0-flash', 8192),
            ('gemini-1.5-pro', 8192),
            ('gemini-1.5-flash', 8192),
            ('unknown-model', 8192),  # Fallback case
        ]

        for model_name, expected_max_tokens in test_cases:
            # Create client with specific model, no explicit max_tokens to test mapping
            config = LLMConfig(api_key='test_api_key', model=model_name, temperature=0.5)
            client = GeminiClient(config=config, cache=False, client=mock_gemini_client)

            # Call method without explicit max_tokens to test model mapping fallback
            messages = [Message(role='user', content='Test message')]
            await client.generate_response(messages)

            # Verify correct max tokens is used from model mapping
            call_args = mock_gemini_client.aio.models.generate_content.call_args
            config = call_args[1]['config']
            assert config.max_output_tokens == expected_max_tokens, (
                f'Model {model_name} should use {expected_max_tokens} tokens'
            )


if __name__ == '__main__':
    pytest.main(['-v', 'test_gemini_client.py'])

```
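
The fixtures above construct `GeminiClient` from an `LLMConfig` and drive it through `generate_response`, with structured output validated against a Pydantic model. For reference, a minimal non-test sketch of the same flow is shown here; it assumes a valid API key in a `GEMINI_API_KEY` environment variable (the variable name is an assumption) and only the public surface the tests exercise.

```python
import asyncio
import os

from pydantic import BaseModel

from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.gemini_client import GeminiClient
from graphiti_core.prompts.models import Message


class Extraction(BaseModel):
    test_field: str


async def main() -> None:
    # API key variable name is assumed for this sketch.
    config = LLMConfig(
        api_key=os.environ['GEMINI_API_KEY'],
        model='gemini-2.5-flash',
        temperature=0.5,
    )
    client = GeminiClient(config=config, cache=False)

    messages = [
        Message(role='system', content='Return JSON matching the schema.'),
        Message(role='user', content='Extract the test field from: value is "hello".'),
    ]
    # Structured output: parsing failures are retried, as the tests above verify.
    result = await client.generate_response(messages, response_model=Extraction)
    print(result['test_field'])


if __name__ == '__main__':
    asyncio.run(main())
```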

--------------------------------------------------------------------------------
/mcp_server/tests/test_comprehensive_integration.py:
--------------------------------------------------------------------------------

```python
#!/usr/bin/env python3
"""
Comprehensive integration test suite for Graphiti MCP Server.
Covers all MCP tools with consideration for LLM inference latency.
"""

import asyncio
import json
import os
import time
from dataclasses import dataclass
from typing import Any

import pytest
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client


@dataclass
class TestMetrics:
    """Track test performance metrics."""

    operation: str
    start_time: float
    end_time: float
    success: bool
    details: dict[str, Any]

    @property
    def duration(self) -> float:
        """Calculate operation duration in seconds."""
        return self.end_time - self.start_time


class GraphitiTestClient:
    """Enhanced test client for comprehensive Graphiti MCP testing."""

    def __init__(self, test_group_id: str | None = None):
        self.test_group_id = test_group_id or f'test_{int(time.time())}'
        self.session = None
        self.metrics: list[TestMetrics] = []
        self.default_timeout = 30  # seconds

    async def __aenter__(self):
        """Initialize MCP client session."""
        server_params = StdioServerParameters(
            command='uv',
            args=['run', '../main.py', '--transport', 'stdio'],
            env={
                'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
                'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
                'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
                'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY', 'test_key_for_mock'),
                'FALKORDB_URI': os.environ.get('FALKORDB_URI', 'redis://localhost:6379'),
            },
        )

        self.client_context = stdio_client(server_params)
        read, write = await self.client_context.__aenter__()
        self.session = ClientSession(read, write)
        await self.session.initialize()

        # Wait for server to be fully ready
        await asyncio.sleep(2)

        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Clean up client session."""
        if self.session:
            await self.session.__aexit__(exc_type, exc_val, exc_tb)
        if hasattr(self, 'client_context'):
            await self.client_context.__aexit__(exc_type, exc_val, exc_tb)

    async def call_tool_with_metrics(
        self, tool_name: str, arguments: dict[str, Any], timeout: float | None = None
    ) -> tuple[Any, TestMetrics]:
        """Call a tool and capture performance metrics."""
        start_time = time.time()
        timeout = timeout or self.default_timeout

        try:
            result = await asyncio.wait_for(
                self.session.call_tool(tool_name, arguments), timeout=timeout
            )

            content = result.content[0].text if result.content else None
            success = True
            details = {'result': content, 'tool': tool_name}

        except asyncio.TimeoutError:
            content = None
            success = False
            details = {'error': f'Timeout after {timeout}s', 'tool': tool_name}

        except Exception as e:
            content = None
            success = False
            details = {'error': str(e), 'tool': tool_name}

        end_time = time.time()
        metric = TestMetrics(
            operation=f'call_{tool_name}',
            start_time=start_time,
            end_time=end_time,
            success=success,
            details=details,
        )
        self.metrics.append(metric)

        return content, metric

    async def wait_for_episode_processing(
        self, expected_count: int = 1, max_wait: int = 60, poll_interval: int = 2
    ) -> bool:
        """
        Wait for episodes to be processed with intelligent polling.

        Args:
            expected_count: Number of episodes expected to be processed
            max_wait: Maximum seconds to wait
            poll_interval: Seconds between status checks

        Returns:
            True if episodes were processed successfully
        """
        start_time = time.time()

        while (time.time() - start_time) < max_wait:
            result, _ = await self.call_tool_with_metrics(
                'get_episodes', {'group_id': self.test_group_id, 'last_n': 100}
            )

            if result:
                try:
                    episodes = json.loads(result) if isinstance(result, str) else result
                    if len(episodes.get('episodes', [])) >= expected_count:
                        return True
                except (json.JSONDecodeError, AttributeError):
                    pass

            await asyncio.sleep(poll_interval)

        return False


class TestCoreOperations:
    """Test core Graphiti operations."""

    @pytest.mark.asyncio
    async def test_server_initialization(self):
        """Verify server initializes with all required tools."""
        async with GraphitiTestClient() as client:
            tools_result = await client.session.list_tools()
            tools = {tool.name for tool in tools_result.tools}

            required_tools = {
                'add_memory',
                'search_memory_nodes',
                'search_memory_facts',
                'get_episodes',
                'delete_episode',
                'delete_entity_edge',
                'get_entity_edge',
                'clear_graph',
                'get_status',
            }

            missing_tools = required_tools - tools
            assert not missing_tools, f'Missing required tools: {missing_tools}'

    @pytest.mark.asyncio
    async def test_add_text_memory(self):
        """Test adding text-based memories."""
        async with GraphitiTestClient() as client:
            # Add memory
            result, metric = await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Tech Conference Notes',
                    'episode_body': 'The AI conference featured talks on LLMs, RAG systems, and knowledge graphs. Notable speakers included researchers from OpenAI and Anthropic.',
                    'source': 'text',
                    'source_description': 'conference notes',
                    'group_id': client.test_group_id,
                },
            )

            assert metric.success, f'Failed to add memory: {metric.details}'
            assert 'queued' in str(result).lower()

            # Wait for processing
            processed = await client.wait_for_episode_processing(expected_count=1)
            assert processed, 'Episode was not processed within timeout'

    @pytest.mark.asyncio
    async def test_add_json_memory(self):
        """Test adding structured JSON memories."""
        async with GraphitiTestClient() as client:
            json_data = {
                'project': {
                    'name': 'GraphitiDB',
                    'version': '2.0.0',
                    'features': ['temporal-awareness', 'hybrid-search', 'custom-entities'],
                },
                'team': {'size': 5, 'roles': ['engineering', 'product', 'research']},
            }

            result, metric = await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Project Data',
                    'episode_body': json.dumps(json_data),
                    'source': 'json',
                    'source_description': 'project database',
                    'group_id': client.test_group_id,
                },
            )

            assert metric.success
            assert 'queued' in str(result).lower()

    @pytest.mark.asyncio
    async def test_add_message_memory(self):
        """Test adding conversation/message memories."""
        async with GraphitiTestClient() as client:
            conversation = """
            user: What are the key features of Graphiti?
            assistant: Graphiti offers temporal-aware knowledge graphs, hybrid retrieval, and real-time updates.
            user: How does it handle entity resolution?
            assistant: It uses LLM-based entity extraction and deduplication with semantic similarity matching.
            """

            result, metric = await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Feature Discussion',
                    'episode_body': conversation,
                    'source': 'message',
                    'source_description': 'support chat',
                    'group_id': client.test_group_id,
                },
            )

            assert metric.success
            assert metric.duration < 5, f'Add memory took too long: {metric.duration}s'


class TestSearchOperations:
    """Test search and retrieval operations."""

    @pytest.mark.asyncio
    async def test_search_nodes_semantic(self):
        """Test semantic search for nodes."""
        async with GraphitiTestClient() as client:
            # First add some test data
            await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Product Launch',
                    'episode_body': 'Our new AI assistant product launches in Q2 2024 with advanced NLP capabilities.',
                    'source': 'text',
                    'source_description': 'product roadmap',
                    'group_id': client.test_group_id,
                },
            )

            # Wait for processing
            await client.wait_for_episode_processing()

            # Search for nodes
            result, metric = await client.call_tool_with_metrics(
                'search_memory_nodes',
                {'query': 'AI product features', 'group_id': client.test_group_id, 'limit': 10},
            )

            assert metric.success
            assert result is not None

    @pytest.mark.asyncio
    async def test_search_facts_with_filters(self):
        """Test fact search with various filters."""
        async with GraphitiTestClient() as client:
            # Add test data
            await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Company Facts',
                    'episode_body': 'Acme Corp was founded in 2020. They have 50 employees and $10M in revenue.',
                    'source': 'text',
                    'source_description': 'company profile',
                    'group_id': client.test_group_id,
                },
            )

            await client.wait_for_episode_processing()

            # Search with date filter
            result, metric = await client.call_tool_with_metrics(
                'search_memory_facts',
                {
                    'query': 'company information',
                    'group_id': client.test_group_id,
                    'created_after': '2020-01-01T00:00:00Z',
                    'limit': 20,
                },
            )

            assert metric.success

    @pytest.mark.asyncio
    async def test_hybrid_search(self):
        """Test hybrid search combining semantic and keyword search."""
        async with GraphitiTestClient() as client:
            # Add diverse test data
            test_memories = [
                {
                    'name': 'Technical Doc',
                    'episode_body': 'GraphQL API endpoints support pagination, filtering, and real-time subscriptions.',
                    'source': 'text',
                },
                {
                    'name': 'Architecture',
                    'episode_body': 'The system uses Neo4j for graph storage and OpenAI embeddings for semantic search.',
                    'source': 'text',
                },
            ]

            for memory in test_memories:
                memory['group_id'] = client.test_group_id
                memory['source_description'] = 'documentation'
                await client.call_tool_with_metrics('add_memory', memory)

            await client.wait_for_episode_processing(expected_count=2)

            # Test semantic + keyword search
            result, metric = await client.call_tool_with_metrics(
                'search_memory_nodes',
                {'query': 'Neo4j graph database', 'group_id': client.test_group_id, 'limit': 10},
            )

            assert metric.success


class TestEpisodeManagement:
    """Test episode lifecycle operations."""

    @pytest.mark.asyncio
    async def test_get_episodes_pagination(self):
        """Test retrieving episodes with pagination."""
        async with GraphitiTestClient() as client:
            # Add multiple episodes
            for i in range(5):
                await client.call_tool_with_metrics(
                    'add_memory',
                    {
                        'name': f'Episode {i}',
                        'episode_body': f'This is test episode number {i}',
                        'source': 'text',
                        'source_description': 'test',
                        'group_id': client.test_group_id,
                    },
                )

            await client.wait_for_episode_processing(expected_count=5)

            # Test pagination
            result, metric = await client.call_tool_with_metrics(
                'get_episodes', {'group_id': client.test_group_id, 'last_n': 3}
            )

            assert metric.success
            episodes = json.loads(result) if isinstance(result, str) else result
            assert len(episodes.get('episodes', [])) <= 3

    @pytest.mark.asyncio
    async def test_delete_episode(self):
        """Test deleting specific episodes."""
        async with GraphitiTestClient() as client:
            # Add an episode
            await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'To Delete',
                    'episode_body': 'This episode will be deleted',
                    'source': 'text',
                    'source_description': 'test',
                    'group_id': client.test_group_id,
                },
            )

            await client.wait_for_episode_processing()

            # Get episode UUID
            result, _ = await client.call_tool_with_metrics(
                'get_episodes', {'group_id': client.test_group_id, 'last_n': 1}
            )

            episodes = json.loads(result) if isinstance(result, str) else result
            episode_uuid = episodes['episodes'][0]['uuid']

            # Delete the episode
            result, metric = await client.call_tool_with_metrics(
                'delete_episode', {'episode_uuid': episode_uuid}
            )

            assert metric.success
            assert 'deleted' in str(result).lower()


class TestEntityAndEdgeOperations:
    """Test entity and edge management."""

    @pytest.mark.asyncio
    async def test_get_entity_edge(self):
        """Test retrieving entity edges."""
        async with GraphitiTestClient() as client:
            # Add data to create entities and edges
            await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Relationship Data',
                    'episode_body': 'Alice works at TechCorp. Bob is the CEO of TechCorp.',
                    'source': 'text',
                    'source_description': 'org chart',
                    'group_id': client.test_group_id,
                },
            )

            await client.wait_for_episode_processing()

            # Search for nodes to get UUIDs
            result, _ = await client.call_tool_with_metrics(
                'search_memory_nodes',
                {'query': 'TechCorp', 'group_id': client.test_group_id, 'limit': 5},
            )

            # Note: This test assumes edges are created between entities
            # Actual edge retrieval would require valid edge UUIDs

    @pytest.mark.asyncio
    async def test_delete_entity_edge(self):
        """Test deleting entity edges."""
        # Similar structure to get_entity_edge but with deletion
        pass  # Implement based on actual edge creation patterns


class TestErrorHandling:
    """Test error conditions and edge cases."""

    @pytest.mark.asyncio
    async def test_invalid_tool_arguments(self):
        """Test handling of invalid tool arguments."""
        async with GraphitiTestClient() as client:
            # Missing required arguments
            result, metric = await client.call_tool_with_metrics(
                'add_memory',
                {'name': 'Incomplete'},  # Missing required fields
            )

            assert not metric.success
            assert 'error' in str(metric.details).lower()

    @pytest.mark.asyncio
    async def test_timeout_handling(self):
        """Test timeout handling for long operations."""
        async with GraphitiTestClient() as client:
            # Simulate a very large episode that might time out
            large_text = 'Large document content. ' * 10000

            result, metric = await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Large Document',
                    'episode_body': large_text,
                    'source': 'text',
                    'source_description': 'large file',
                    'group_id': client.test_group_id,
                },
                timeout=5,  # Short timeout
            )

            # Check if timeout was handled gracefully
            if not metric.success:
                assert 'timeout' in str(metric.details).lower()

    @pytest.mark.asyncio
    async def test_concurrent_operations(self):
        """Test handling of concurrent operations."""
        async with GraphitiTestClient() as client:
            # Launch multiple operations concurrently
            tasks = []
            for i in range(5):
                task = client.call_tool_with_metrics(
                    'add_memory',
                    {
                        'name': f'Concurrent {i}',
                        'episode_body': f'Concurrent operation {i}',
                        'source': 'text',
                        'source_description': 'concurrent test',
                        'group_id': client.test_group_id,
                    },
                )
                tasks.append(task)

            results = await asyncio.gather(*tasks, return_exceptions=True)

            # Check that operations were queued successfully. With
            # return_exceptions=True, gather may return exceptions instead of
            # (result, metric) tuples, so skip those entries.
            successful = sum(
                1
                for item in results
                if not isinstance(item, BaseException) and item[1].success
            )
            assert successful >= 3  # At least 60% should succeed


class TestPerformance:
    """Test performance characteristics and optimization."""

    @pytest.mark.asyncio
    async def test_latency_metrics(self):
        """Measure and validate operation latencies."""
        async with GraphitiTestClient() as client:
            operations = [
                (
                    'add_memory',
                    {
                        'name': 'Perf Test',
                        'episode_body': 'Simple text',
                        'source': 'text',
                        'source_description': 'test',
                        'group_id': client.test_group_id,
                    },
                ),
                (
                    'search_memory_nodes',
                    {'query': 'test', 'group_id': client.test_group_id, 'limit': 10},
                ),
                ('get_episodes', {'group_id': client.test_group_id, 'last_n': 10}),
            ]

            for tool_name, args in operations:
                _, metric = await client.call_tool_with_metrics(tool_name, args)

                # Log performance metrics
                print(f'{tool_name}: {metric.duration:.2f}s')

                # Basic latency assertions
                if tool_name == 'get_episodes':
                    assert metric.duration < 2, f'{tool_name} too slow'
                elif tool_name == 'search_memory_nodes':
                    assert metric.duration < 10, f'{tool_name} too slow'

    @pytest.mark.asyncio
    async def test_batch_processing_efficiency(self):
        """Test efficiency of batch operations."""
        async with GraphitiTestClient() as client:
            batch_size = 10
            start_time = time.time()

            # Batch add memories
            for i in range(batch_size):
                await client.call_tool_with_metrics(
                    'add_memory',
                    {
                        'name': f'Batch {i}',
                        'episode_body': f'Batch content {i}',
                        'source': 'text',
                        'source_description': 'batch test',
                        'group_id': client.test_group_id,
                    },
                )

            # Wait for all to process
            processed = await client.wait_for_episode_processing(
                expected_count=batch_size,
                max_wait=120,  # Allow more time for batch
            )

            total_time = time.time() - start_time
            avg_time_per_item = total_time / batch_size

            assert processed, f'Failed to process {batch_size} items'
            assert avg_time_per_item < 15, (
                f'Batch processing too slow: {avg_time_per_item:.2f}s per item'
            )

            # Generate performance report
            print('\nBatch Performance Report:')
            print(f'  Total items: {batch_size}')
            print(f'  Total time: {total_time:.2f}s')
            print(f'  Avg per item: {avg_time_per_item:.2f}s')


class TestDatabaseBackends:
    """Test different database backend configurations."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize('database', ['neo4j', 'falkordb'])
    async def test_database_operations(self, database):
        """Test operations with different database backends."""
        env_vars = {
            'DATABASE_PROVIDER': database,
            'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY'),
        }

        if database == 'neo4j':
            env_vars.update(
                {
                    'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
                    'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
                    'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
                }
            )
        elif database == 'falkordb':
            env_vars['FALKORDB_URI'] = os.environ.get('FALKORDB_URI', 'redis://localhost:6379')

        # This test would require setting up server with specific database
        # Implementation depends on database availability
        pass  # Placeholder for database-specific tests


def generate_test_report(client: GraphitiTestClient) -> str:
    """Generate a comprehensive test report from metrics."""
    if not client.metrics:
        return 'No metrics collected'

    report = []
    report.append('\n' + '=' * 60)
    report.append('GRAPHITI MCP TEST REPORT')
    report.append('=' * 60)

    # Summary statistics
    total_ops = len(client.metrics)
    successful_ops = sum(1 for m in client.metrics if m.success)
    avg_duration = sum(m.duration for m in client.metrics) / total_ops

    report.append(f'\nTotal Operations: {total_ops}')
    report.append(f'Successful: {successful_ops} ({successful_ops / total_ops * 100:.1f}%)')
    report.append(f'Average Duration: {avg_duration:.2f}s')

    # Operation breakdown
    report.append('\nOperation Breakdown:')
    operation_stats = {}
    for metric in client.metrics:
        if metric.operation not in operation_stats:
            operation_stats[metric.operation] = {'count': 0, 'success': 0, 'total_duration': 0}
        stats = operation_stats[metric.operation]
        stats['count'] += 1
        stats['success'] += 1 if metric.success else 0
        stats['total_duration'] += metric.duration

    for op, stats in sorted(operation_stats.items()):
        avg_dur = stats['total_duration'] / stats['count']
        success_rate = stats['success'] / stats['count'] * 100
        report.append(
            f'  {op}: {stats["count"]} calls, {success_rate:.0f}% success, {avg_dur:.2f}s avg'
        )

    # Slowest operations
    slowest = sorted(client.metrics, key=lambda m: m.duration, reverse=True)[:5]
    report.append('\nSlowest Operations:')
    for metric in slowest:
        report.append(f'  {metric.operation}: {metric.duration:.2f}s')

    report.append('=' * 60)
    return '\n'.join(report)
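
# Example (hypothetical) usage of generate_test_report after exercising the client:
#
#     async with GraphitiTestClient() as client:
#         await client.call_tool_with_metrics('get_status', {})
#         print(generate_test_report(client))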


if __name__ == '__main__':
    # Run tests with pytest
    pytest.main([__file__, '-v', '--asyncio-mode=auto'])

```

--------------------------------------------------------------------------------
/graphiti_core/utils/maintenance/edge_operations.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
from datetime import datetime
from time import time

from pydantic import BaseModel
from typing_extensions import LiteralString

from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.edges import (
    CommunityEdge,
    EntityEdge,
    EpisodicEdge,
    create_entity_edge_embeddings,
)
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.llm_client.config import ModelSize
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
from graphiti_core.search.search import search
from graphiti_core.search.search_config import SearchResults
from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact

DEFAULT_EDGE_NAME = 'RELATES_TO'

logger = logging.getLogger(__name__)


def build_episodic_edges(
    entity_nodes: list[EntityNode],
    episode_uuid: str,
    created_at: datetime,
) -> list[EpisodicEdge]:
    episodic_edges: list[EpisodicEdge] = [
        EpisodicEdge(
            source_node_uuid=episode_uuid,
            target_node_uuid=node.uuid,
            created_at=created_at,
            group_id=node.group_id,
        )
        for node in entity_nodes
    ]

    logger.debug(f'Built episodic edges: {episodic_edges}')

    return episodic_edges


def build_community_edges(
    entity_nodes: list[EntityNode],
    community_node: CommunityNode,
    created_at: datetime,
) -> list[CommunityEdge]:
    edges: list[CommunityEdge] = [
        CommunityEdge(
            source_node_uuid=community_node.uuid,
            target_node_uuid=node.uuid,
            created_at=created_at,
            group_id=community_node.group_id,
        )
        for node in entity_nodes
    ]

    return edges


async def extract_edges(
    clients: GraphitiClients,
    episode: EpisodicNode,
    nodes: list[EntityNode],
    previous_episodes: list[EpisodicNode],
    edge_type_map: dict[tuple[str, str], list[str]],
    group_id: str = '',
    edge_types: dict[str, type[BaseModel]] | None = None,
) -> list[EntityEdge]:
    start = time()

    extract_edges_max_tokens = 16384
    llm_client = clients.llm_client

    edge_type_signature_map: dict[str, tuple[str, str]] = {
        edge_type: signature
        for signature, edge_types in edge_type_map.items()
        for edge_type in edge_types
    }
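    # e.g. (hypothetical): {('Person', 'Company'): ['WORKS_AT']} in edge_type_map
    # becomes {'WORKS_AT': ('Person', 'Company')} here, mapping each custom edge
    # type name back to the node-label signature it is registered for.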

    edge_types_context = (
        [
            {
                'fact_type_name': type_name,
                'fact_type_signature': edge_type_signature_map.get(type_name, ('Entity', 'Entity')),
                'fact_type_description': type_model.__doc__,
            }
            for type_name, type_model in edge_types.items()
        ]
        if edge_types is not None
        else []
    )

    # Prepare context for LLM
    context = {
        'episode_content': episode.content,
        'nodes': [
            {'id': idx, 'name': node.name, 'entity_types': node.labels}
            for idx, node in enumerate(nodes)
        ],
        'previous_episodes': [ep.content for ep in previous_episodes],
        'reference_time': episode.valid_at,
        'edge_types': edge_types_context,
        'custom_prompt': '',
    }
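
    # Reflexion loop: extract candidate edges, then ask the LLM which facts were
    # missed and re-prompt with those omissions appended, repeating until no facts
    # are reported missing or MAX_REFLEXION_ITERATIONS is reached.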

    facts_missed = True
    reflexion_iterations = 0
    while facts_missed and reflexion_iterations <= MAX_REFLEXION_ITERATIONS:
        llm_response = await llm_client.generate_response(
            prompt_library.extract_edges.edge(context),
            response_model=ExtractedEdges,
            max_tokens=extract_edges_max_tokens,
            group_id=group_id,
            prompt_name='extract_edges.edge',
        )
        edges_data = ExtractedEdges(**llm_response).edges

        context['extracted_facts'] = [edge_data.fact for edge_data in edges_data]

        reflexion_iterations += 1
        if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
            reflexion_response = await llm_client.generate_response(
                prompt_library.extract_edges.reflexion(context),
                response_model=MissingFacts,
                max_tokens=extract_edges_max_tokens,
                group_id=group_id,
                prompt_name='extract_edges.reflexion',
            )

            missing_facts = reflexion_response.get('missing_facts', [])

            custom_prompt = 'The following facts were missed in a previous extraction: '
            for fact in missing_facts:
                custom_prompt += f'\n{fact},'

            context['custom_prompt'] = custom_prompt

            facts_missed = len(missing_facts) != 0

    end = time()
    logger.debug(f'Extracted new edges: {edges_data} in {(end - start) * 1000} ms')

    if len(edges_data) == 0:
        return []

    # Convert the extracted data into EntityEdge objects
    edges = []
    for edge_data in edges_data:
        # Validate Edge Date information
        valid_at = edge_data.valid_at
        invalid_at = edge_data.invalid_at
        valid_at_datetime = None
        invalid_at_datetime = None

        # Filter out empty edges
        if not edge_data.fact.strip():
            continue

        source_node_idx = edge_data.source_entity_id
        target_node_idx = edge_data.target_entity_id

        if len(nodes) == 0:
            logger.warning('No entities provided for edge extraction')
            continue

        if not (0 <= source_node_idx < len(nodes) and 0 <= target_node_idx < len(nodes)):
            logger.warning(
                f'Invalid entity IDs in edge extraction for {edge_data.relation_type}. '
                f'source_entity_id: {source_node_idx}, target_entity_id: {target_node_idx}, '
                f'but only {len(nodes)} entities available (valid range: 0-{len(nodes) - 1})'
            )
            continue
        source_node_uuid = nodes[source_node_idx].uuid
        target_node_uuid = nodes[target_node_idx].uuid

        if valid_at:
            try:
                valid_at_datetime = ensure_utc(
                    datetime.fromisoformat(valid_at.replace('Z', '+00:00'))
                )
            except ValueError as e:
                logger.warning(f'WARNING: Error parsing valid_at date: {e}. Input: {valid_at}')

        if invalid_at:
            try:
                invalid_at_datetime = ensure_utc(
                    datetime.fromisoformat(invalid_at.replace('Z', '+00:00'))
                )
            except ValueError as e:
                logger.warning(f'WARNING: Error parsing invalid_at date: {e}. Input: {invalid_at}')
        edge = EntityEdge(
            source_node_uuid=source_node_uuid,
            target_node_uuid=target_node_uuid,
            name=edge_data.relation_type,
            group_id=group_id,
            fact=edge_data.fact,
            episodes=[episode.uuid],
            created_at=utc_now(),
            valid_at=valid_at_datetime,
            invalid_at=invalid_at_datetime,
        )
        edges.append(edge)
        logger.debug(
            f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
        )

    logger.debug(f'Extracted edges: {[(e.name, e.uuid) for e in edges]}')

    return edges


async def resolve_extracted_edges(
    clients: GraphitiClients,
    extracted_edges: list[EntityEdge],
    episode: EpisodicNode,
    entities: list[EntityNode],
    edge_types: dict[str, type[BaseModel]],
    edge_type_map: dict[tuple[str, str], list[str]],
) -> tuple[list[EntityEdge], list[EntityEdge]]:
    # Fast path: deduplicate exact matches within the extracted edges before parallel processing
    seen: dict[tuple[str, str, str], EntityEdge] = {}
    deduplicated_edges: list[EntityEdge] = []

    for edge in extracted_edges:
        key = (
            edge.source_node_uuid,
            edge.target_node_uuid,
            _normalize_string_exact(edge.fact),
        )
        if key not in seen:
            seen[key] = edge
            deduplicated_edges.append(edge)

    extracted_edges = deduplicated_edges

    driver = clients.driver
    llm_client = clients.llm_client
    embedder = clients.embedder
    await create_entity_edge_embeddings(embedder, extracted_edges)

    valid_edges_list: list[list[EntityEdge]] = await semaphore_gather(
        *[
            EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
            for edge in extracted_edges
        ]
    )

    related_edges_results: list[SearchResults] = await semaphore_gather(
        *[
            search(
                clients,
                extracted_edge.fact,
                group_ids=[extracted_edge.group_id],
                config=EDGE_HYBRID_SEARCH_RRF,
                search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
            )
            for extracted_edge, valid_edges in zip(extracted_edges, valid_edges_list, strict=True)
        ]
    )

    related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results]

    edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather(
        *[
            search(
                clients,
                extracted_edge.fact,
                group_ids=[extracted_edge.group_id],
                config=EDGE_HYBRID_SEARCH_RRF,
                search_filter=SearchFilters(),
            )
            for extracted_edge in extracted_edges
        ]
    )

    edge_invalidation_candidates: list[list[EntityEdge]] = [
        result.edges for result in edge_invalidation_candidate_results
    ]

    logger.debug(
        f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
    )

    # Build entity hash table
    uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}

    # Determine which edge types are relevant for each edge.
    # `edge_types_lst` stores the subset of custom edge definitions whose
    # node signature matches each extracted edge. Anything outside this subset
    # should only stay on the edge if it is a non-custom (LLM generated) label.
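    # e.g. (hypothetical): if edge_type_map registers 'WORKS_AT' for the
    # ('Person', 'Company') signature, an edge between a Person and a Company keeps
    # 'WORKS_AT' as a candidate, while an edge between two plain Entity nodes does not.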
    edge_types_lst: list[dict[str, type[BaseModel]]] = []
    custom_type_names = set(edge_types or {})
    for extracted_edge in extracted_edges:
        source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
        target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
        source_node_labels = (
            source_node.labels + ['Entity'] if source_node is not None else ['Entity']
        )
        target_node_labels = (
            target_node.labels + ['Entity'] if target_node is not None else ['Entity']
        )
        label_tuples = [
            (source_label, target_label)
            for source_label in source_node_labels
            for target_label in target_node_labels
        ]

        extracted_edge_types = {}
        for label_tuple in label_tuples:
            type_names = edge_type_map.get(label_tuple, [])
            for type_name in type_names:
                type_model = edge_types.get(type_name)
                if type_model is None:
                    continue

                extracted_edge_types[type_name] = type_model

        edge_types_lst.append(extracted_edge_types)

    for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True):
        allowed_type_names = set(extracted_edge_types)
        is_custom_name = extracted_edge.name in custom_type_names
        if not allowed_type_names:
            # No custom types are valid for this node pairing. Keep LLM generated
            # labels, but flip disallowed custom names back to the default.
            if is_custom_name and extracted_edge.name != DEFAULT_EDGE_NAME:
                extracted_edge.name = DEFAULT_EDGE_NAME
            continue
        if is_custom_name and extracted_edge.name not in allowed_type_names:
            # Custom name exists but it is not permitted for this source/target
            # signature, so fall back to the default edge label.
            extracted_edge.name = DEFAULT_EDGE_NAME

    # resolve edges with related edges in the graph and find invalidation candidates
    results: list[tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]] = list(
        await semaphore_gather(
            *[
                resolve_extracted_edge(
                    llm_client,
                    extracted_edge,
                    related_edges,
                    existing_edges,
                    episode,
                    extracted_edge_types,
                    custom_type_names,
                )
                for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
                    extracted_edges,
                    related_edges_lists,
                    edge_invalidation_candidates,
                    edge_types_lst,
                    strict=True,
                )
            ]
        )
    )

    resolved_edges: list[EntityEdge] = []
    invalidated_edges: list[EntityEdge] = []
    for result in results:
        resolved_edge = result[0]
        invalidated_edge_chunk = result[1]

        resolved_edges.append(resolved_edge)
        invalidated_edges.extend(invalidated_edge_chunk)

    logger.debug(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')

    await semaphore_gather(
        create_entity_edge_embeddings(embedder, resolved_edges),
        create_entity_edge_embeddings(embedder, invalidated_edges),
    )

    return resolved_edges, invalidated_edges


def resolve_edge_contradictions(
    resolved_edge: EntityEdge, invalidation_candidates: list[EntityEdge]
) -> list[EntityEdge]:
    if len(invalidation_candidates) == 0:
        return []

    # Determine which contradictory edges need to be expired
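    # Illustration (hypothetical dates): a new edge valid from 2024-06-01 invalidates a
    # candidate that became valid on 2024-01-01 and has no invalid_at; candidates whose
    # validity window ended before the new edge started (or started after it ended) do
    # not overlap and are left untouched.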
    invalidated_edges: list[EntityEdge] = []
    for edge in invalidation_candidates:
        # (Edge invalid before new edge becomes valid) or (new edge invalid before edge becomes valid)
        edge_invalid_at_utc = ensure_utc(edge.invalid_at)
        resolved_edge_valid_at_utc = ensure_utc(resolved_edge.valid_at)
        edge_valid_at_utc = ensure_utc(edge.valid_at)
        resolved_edge_invalid_at_utc = ensure_utc(resolved_edge.invalid_at)

        if (
            edge_invalid_at_utc is not None
            and resolved_edge_valid_at_utc is not None
            and edge_invalid_at_utc <= resolved_edge_valid_at_utc
        ) or (
            edge_valid_at_utc is not None
            and resolved_edge_invalid_at_utc is not None
            and resolved_edge_invalid_at_utc <= edge_valid_at_utc
        ):
            continue
        # New edge invalidates edge
        elif (
            edge_valid_at_utc is not None
            and resolved_edge_valid_at_utc is not None
            and edge_valid_at_utc < resolved_edge_valid_at_utc
        ):
            edge.invalid_at = resolved_edge.valid_at
            edge.expired_at = edge.expired_at if edge.expired_at is not None else utc_now()
            invalidated_edges.append(edge)

    return invalidated_edges


async def resolve_extracted_edge(
    llm_client: LLMClient,
    extracted_edge: EntityEdge,
    related_edges: list[EntityEdge],
    existing_edges: list[EntityEdge],
    episode: EpisodicNode,
    edge_type_candidates: dict[str, type[BaseModel]] | None = None,
    custom_edge_type_names: set[str] | None = None,
) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
    """Resolve an extracted edge against existing graph context.

    Parameters
    ----------
    llm_client : LLMClient
        Client used to invoke the LLM for deduplication and attribute extraction.
    extracted_edge : EntityEdge
        Newly extracted edge whose canonical representation is being resolved.
    related_edges : list[EntityEdge]
        Candidate edges with identical endpoints used for duplicate detection.
    existing_edges : list[EntityEdge]
        Broader set of edges evaluated for contradiction / invalidation.
    episode : EpisodicNode
        Episode providing content context when extracting edge attributes.
    edge_type_candidates : dict[str, type[BaseModel]] | None
        Custom edge types permitted for the current source/target signature.
    custom_edge_type_names : set[str] | None
        Full catalog of registered custom edge names. Used to distinguish
        between disallowed custom types (which fall back to the default label)
        and ad-hoc labels emitted by the LLM.

    Returns
    -------
    tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]
        The resolved edge, any duplicates, and edges to invalidate.
    """
    if len(related_edges) == 0 and len(existing_edges) == 0:
        return extracted_edge, [], []

    # Fast path: if the fact text and endpoints already exist verbatim, reuse the matching edge.
    normalized_fact = _normalize_string_exact(extracted_edge.fact)
    for edge in related_edges:
        if (
            edge.source_node_uuid == extracted_edge.source_node_uuid
            and edge.target_node_uuid == extracted_edge.target_node_uuid
            and _normalize_string_exact(edge.fact) == normalized_fact
        ):
            resolved = edge
            if episode is not None and episode.uuid not in resolved.episodes:
                resolved.episodes.append(episode.uuid)
            return resolved, [], []

    start = time()

    # Prepare context for LLM
    related_edges_context = [{'idx': i, 'fact': edge.fact} for i, edge in enumerate(related_edges)]

    invalidation_edge_candidates_context = [
        {'idx': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
    ]

    edge_types_context = (
        [
            {
                'fact_type_name': type_name,
                'fact_type_description': type_model.__doc__,
            }
            for type_name, type_model in edge_type_candidates.items()
        ]
        if edge_type_candidates is not None
        else []
    )

    context = {
        'existing_edges': related_edges_context,
        'new_edge': extracted_edge.fact,
        'edge_invalidation_candidates': invalidation_edge_candidates_context,
        'edge_types': edge_types_context,
    }

    if related_edges or existing_edges:
        logger.debug(
            'Resolving edge: sent %d EXISTING FACTS%s and %d INVALIDATION CANDIDATES%s',
            len(related_edges),
            f' (idx 0-{len(related_edges) - 1})' if related_edges else '',
            len(existing_edges),
            f' (idx 0-{len(existing_edges) - 1})' if existing_edges else '',
        )

    llm_response = await llm_client.generate_response(
        prompt_library.dedupe_edges.resolve_edge(context),
        response_model=EdgeDuplicate,
        model_size=ModelSize.small,
        prompt_name='dedupe_edges.resolve_edge',
    )
    response_object = EdgeDuplicate(**llm_response)
    duplicate_facts = response_object.duplicate_facts

    # Validate duplicate_facts are in valid range for EXISTING FACTS
    invalid_duplicates = [i for i in duplicate_facts if i < 0 or i >= len(related_edges)]
    if invalid_duplicates:
        logger.warning(
            'LLM returned invalid duplicate_facts idx values %s (valid range: 0-%d for EXISTING FACTS)',
            invalid_duplicates,
            len(related_edges) - 1,
        )

    duplicate_fact_ids: list[int] = [i for i in duplicate_facts if 0 <= i < len(related_edges)]

    resolved_edge = extracted_edge
    for duplicate_fact_id in duplicate_fact_ids:
        resolved_edge = related_edges[duplicate_fact_id]
        break

    if duplicate_fact_ids and episode is not None:
        resolved_edge.episodes.append(episode.uuid)

    contradicted_facts: list[int] = response_object.contradicted_facts

    # Validate contradicted_facts are in valid range for INVALIDATION CANDIDATES
    invalid_contradictions = [i for i in contradicted_facts if i < 0 or i >= len(existing_edges)]
    if invalid_contradictions:
        logger.warning(
            'LLM returned invalid contradicted_facts idx values %s (valid range: 0-%d for INVALIDATION CANDIDATES)',
            invalid_contradictions,
            len(existing_edges) - 1,
        )

    invalidation_candidates: list[EntityEdge] = [
        existing_edges[i] for i in contradicted_facts if 0 <= i < len(existing_edges)
    ]

    fact_type: str = response_object.fact_type
    candidate_type_names = set(edge_type_candidates or {})
    custom_type_names = custom_edge_type_names or set()

    is_default_type = fact_type.upper() == 'DEFAULT'
    is_custom_type = fact_type in custom_type_names
    is_allowed_custom_type = fact_type in candidate_type_names

    if is_allowed_custom_type:
        # The LLM selected a custom type that is allowed for the node pair.
        # Adopt the custom type and, if needed, extract its structured attributes.
        resolved_edge.name = fact_type

        edge_attributes_context = {
            'episode_content': episode.content,
            'reference_time': episode.valid_at,
            'fact': resolved_edge.fact,
        }

        edge_model = edge_type_candidates.get(fact_type) if edge_type_candidates else None
        if edge_model is not None and len(edge_model.model_fields) != 0:
            edge_attributes_response = await llm_client.generate_response(
                prompt_library.extract_edges.extract_attributes(edge_attributes_context),
                response_model=edge_model,  # type: ignore
                model_size=ModelSize.small,
                prompt_name='extract_edges.extract_attributes',
            )

            resolved_edge.attributes = edge_attributes_response
    elif not is_default_type and is_custom_type:
        # The LLM picked a custom type that is not allowed for this signature.
        # Reset to the default label and drop any structured attributes.
        resolved_edge.name = DEFAULT_EDGE_NAME
        resolved_edge.attributes = {}
    elif not is_default_type:
        # Non-custom labels are allowed to pass through so long as the LLM does
        # not return the sentinel DEFAULT value.
        resolved_edge.name = fact_type
        resolved_edge.attributes = {}

    end = time()
    logger.debug(
        f'Resolved Edge: {extracted_edge.name} is {resolved_edge.name}, in {(end - start) * 1000} ms'
    )

    now = utc_now()

    if resolved_edge.invalid_at and not resolved_edge.expired_at:
        resolved_edge.expired_at = now

    # Determine if the new_edge needs to be expired
    if resolved_edge.expired_at is None:
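        # Candidates with a known valid_at sort first, oldest to newest; candidates
        # without a valid_at sort last.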
        invalidation_candidates.sort(key=lambda c: (c.valid_at is None, ensure_utc(c.valid_at)))
        for candidate in invalidation_candidates:
            candidate_valid_at_utc = ensure_utc(candidate.valid_at)
            resolved_edge_valid_at_utc = ensure_utc(resolved_edge.valid_at)
            if (
                candidate_valid_at_utc is not None
                and resolved_edge_valid_at_utc is not None
                and candidate_valid_at_utc > resolved_edge_valid_at_utc
            ):
                # Expire new edge since we have information about more recent events
                resolved_edge.invalid_at = candidate.valid_at
                resolved_edge.expired_at = now
                break

    # Determine which contradictory edges need to be expired
    invalidated_edges: list[EntityEdge] = resolve_edge_contradictions(
        resolved_edge, invalidation_candidates
    )
    duplicate_edges: list[EntityEdge] = [related_edges[idx] for idx in duplicate_fact_ids]

    return resolved_edge, invalidated_edges, duplicate_edges


async def filter_existing_duplicate_of_edges(
    driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
) -> list[tuple[EntityNode, EntityNode]]:
    if not duplicates_node_tuples:
        return []

    duplicate_nodes_map = {
        (source.uuid, target.uuid): (source, target) for source, target in duplicates_node_tuples
    }

    if driver.provider == GraphProvider.NEPTUNE:
        query: LiteralString = """
            UNWIND $duplicate_node_uuids AS duplicate_tuple
            MATCH (n:Entity {uuid: duplicate_tuple.source})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple.target})
            RETURN DISTINCT
                n.uuid AS source_uuid,
                m.uuid AS target_uuid
        """

        duplicate_nodes = [
            {'source': source.uuid, 'target': target.uuid}
            for source, target in duplicates_node_tuples
        ]

        records, _, _ = await driver.execute_query(
            query,
            duplicate_node_uuids=duplicate_nodes,
            routing_='r',
        )
    else:
        if driver.provider == GraphProvider.KUZU:
            query = """
                UNWIND $duplicate_node_uuids AS duplicate
                MATCH (n:Entity {uuid: duplicate.src})-[:RELATES_TO]->(e:RelatesToNode_ {name: 'IS_DUPLICATE_OF'})-[:RELATES_TO]->(m:Entity {uuid: duplicate.dst})
                RETURN DISTINCT
                    n.uuid AS source_uuid,
                    m.uuid AS target_uuid
            """
            duplicate_node_uuids = [{'src': src, 'dst': dst} for src, dst in duplicate_nodes_map]
        else:
            query: LiteralString = """
                UNWIND $duplicate_node_uuids AS duplicate_tuple
                MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
                RETURN DISTINCT
                    n.uuid AS source_uuid,
                    m.uuid AS target_uuid
            """
            duplicate_node_uuids = list(duplicate_nodes_map.keys())

        records, _, _ = await driver.execute_query(
            query,
            duplicate_node_uuids=duplicate_node_uuids,
            routing_='r',
        )

    # Remove duplicates that already have the IS_DUPLICATE_OF edge
    for record in records:
        duplicate_tuple = (record.get('source_uuid'), record.get('target_uuid'))
        if duplicate_nodes_map.get(duplicate_tuple):
            duplicate_nodes_map.pop(duplicate_tuple)

    return list(duplicate_nodes_map.values())

```

--------------------------------------------------------------------------------
/graphiti_core/nodes.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from time import time
from typing import Any
from uuid import uuid4

from pydantic import BaseModel, Field
from typing_extensions import LiteralString

from graphiti_core.driver.driver import (
    GraphDriver,
    GraphProvider,
)
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import NodeNotFoundError
from graphiti_core.helpers import parse_db_date
from graphiti_core.models.nodes.node_db_queries import (
    COMMUNITY_NODE_RETURN,
    COMMUNITY_NODE_RETURN_NEPTUNE,
    EPISODIC_NODE_RETURN,
    EPISODIC_NODE_RETURN_NEPTUNE,
    get_community_node_save_query,
    get_entity_node_return_query,
    get_entity_node_save_query,
    get_episode_node_save_query,
)
from graphiti_core.utils.datetime_utils import utc_now

logger = logging.getLogger(__name__)


class EpisodeType(Enum):
    """
    Enumeration of different types of episodes that can be processed.

    This enum defines the various sources or formats of episodes that the system
    can handle. It's used to categorize and potentially handle different types
    of input data differently.

    Attributes:
    -----------
    message : str
        Represents a standard message-type episode. The content for this type
        should be formatted as "actor: content". For example, "user: Hello, how are you?"
        or "assistant: I'm doing well, thank you for asking."
    json : str
        Represents an episode containing a JSON string object with structured data.
    text : str
        Represents a plain text episode.
    """

    message = 'message'
    json = 'json'
    text = 'text'

    @staticmethod
    def from_str(episode_type: str):
        if episode_type == 'message':
            return EpisodeType.message
        if episode_type == 'json':
            return EpisodeType.json
        if episode_type == 'text':
            return EpisodeType.text
        logger.error(f'Episode type: {episode_type} not implemented')
        raise NotImplementedError


class Node(BaseModel, ABC):
    uuid: str = Field(default_factory=lambda: str(uuid4()))
    name: str = Field(description='name of the node')
    group_id: str = Field(description='partition of the graph')
    labels: list[str] = Field(default_factory=list)
    created_at: datetime = Field(default_factory=lambda: utc_now())

    @abstractmethod
    async def save(self, driver: GraphDriver): ...

    async def delete(self, driver: GraphDriver):
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.node_delete(self, driver)

        match driver.provider:
            case GraphProvider.NEO4J:
                records, _, _ = await driver.execute_query(
                    """
                    MATCH (n {uuid: $uuid})
                    WHERE n:Entity OR n:Episodic OR n:Community
                    OPTIONAL MATCH (n)-[r]-()
                    WITH collect(r.uuid) AS edge_uuids, n
                    DETACH DELETE n
                    RETURN edge_uuids
                    """,
                    uuid=self.uuid,
                )

            case GraphProvider.KUZU:
                for label in ['Episodic', 'Community']:
                    await driver.execute_query(
                        f"""
                        MATCH (n:{label} {{uuid: $uuid}})
                        DETACH DELETE n
                        """,
                        uuid=self.uuid,
                    )
                # Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
                # Explicitly delete the "edge" nodes first, then the entity node.
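                # (In Kuzu the logical edge (a)-[:RELATES_TO]->(b) is stored as
                # (a)-[:RELATES_TO]->(:RelatesToNode_)-[:RELATES_TO]->(b), so the
                # intermediate RelatesToNode_ must be removed explicitly.)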
                await driver.execute_query(
                    """
                    MATCH (n:Entity {uuid: $uuid})-[:RELATES_TO]->(e:RelatesToNode_)
                    DETACH DELETE e
                    """,
                    uuid=self.uuid,
                )
                await driver.execute_query(
                    """
                    MATCH (n:Entity {uuid: $uuid})
                    DETACH DELETE n
                    """,
                    uuid=self.uuid,
                )
            case _:  # FalkorDB, Neptune
                for label in ['Entity', 'Episodic', 'Community']:
                    await driver.execute_query(
                        f"""
                        MATCH (n:{label} {{uuid: $uuid}})
                        DETACH DELETE n
                        """,
                        uuid=self.uuid,
                    )

        logger.debug(f'Deleted Node: {self.uuid}')

    def __hash__(self):
        return hash(self.uuid)

    def __eq__(self, other):
        if isinstance(other, Node):
            return self.uuid == other.uuid
        return False

    @classmethod
    async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100):
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.node_delete_by_group_id(
                cls, driver, group_id, batch_size
            )

        match driver.provider:
            case GraphProvider.NEO4J:
                async with driver.session() as session:
                    await session.run(
                        """
                        MATCH (n:Entity|Episodic|Community {group_id: $group_id})
                        CALL (n) {
                            DETACH DELETE n
                        } IN TRANSACTIONS OF $batch_size ROWS
                        """,
                        group_id=group_id,
                        batch_size=batch_size,
                    )

            case GraphProvider.KUZU:
                for label in ['Episodic', 'Community']:
                    await driver.execute_query(
                        f"""
                        MATCH (n:{label} {{group_id: $group_id}})
                        DETACH DELETE n
                        """,
                        group_id=group_id,
                    )
                # Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
                # Explicitly delete the "edge" nodes first, then the entity node.
                await driver.execute_query(
                    """
                    MATCH (n:Entity {group_id: $group_id})-[:RELATES_TO]->(e:RelatesToNode_)
                    DETACH DELETE e
                    """,
                    group_id=group_id,
                )
                await driver.execute_query(
                    """
                    MATCH (n:Entity {group_id: $group_id})
                    DETACH DELETE n
                    """,
                    group_id=group_id,
                )
            case _:  # FalkorDB, Neptune
                for label in ['Entity', 'Episodic', 'Community']:
                    await driver.execute_query(
                        f"""
                        MATCH (n:{label} {{group_id: $group_id}})
                        DETACH DELETE n
                        """,
                        group_id=group_id,
                    )

    @classmethod
    async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.node_delete_by_uuids(
                cls, driver, uuids, group_id=None, batch_size=batch_size
            )

        match driver.provider:
            case GraphProvider.FALKORDB:
                for label in ['Entity', 'Episodic', 'Community']:
                    await driver.execute_query(
                        f"""
                        MATCH (n:{label})
                        WHERE n.uuid IN $uuids
                        DETACH DELETE n
                        """,
                        uuids=uuids,
                    )
            case GraphProvider.KUZU:
                for label in ['Episodic', 'Community']:
                    await driver.execute_query(
                        f"""
                        MATCH (n:{label})
                        WHERE n.uuid IN $uuids
                        DETACH DELETE n
                        """,
                        uuids=uuids,
                    )
                # Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
                # Explicitly delete the "edge" nodes first, then the entity node.
                await driver.execute_query(
                    """
                    MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)
                    WHERE n.uuid IN $uuids
                    DETACH DELETE e
                    """,
                    uuids=uuids,
                )
                await driver.execute_query(
                    """
                    MATCH (n:Entity)
                    WHERE n.uuid IN $uuids
                    DETACH DELETE n
                    """,
                    uuids=uuids,
                )
            case _:  # Neo4J, Neptune
                async with driver.session() as session:
                    # Collect all edge UUIDs before deleting nodes
                    await session.run(
                        """
                        MATCH (n:Entity|Episodic|Community)
                        WHERE n.uuid IN $uuids
                        MATCH (n)-[r]-()
                        RETURN collect(r.uuid) AS edge_uuids
                        """,
                        uuids=uuids,
                    )

                    # Now delete the nodes in batches
                    await session.run(
                        """
                        MATCH (n:Entity|Episodic|Community)
                        WHERE n.uuid IN $uuids
                        CALL (n) {
                            DETACH DELETE n
                        } IN TRANSACTIONS OF $batch_size ROWS
                        """,
                        uuids=uuids,
                        batch_size=batch_size,
                    )

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): ...


class EpisodicNode(Node):
    source: EpisodeType = Field(description='source type')
    source_description: str = Field(description='description of the data source')
    content: str = Field(description='raw episode data')
    valid_at: datetime = Field(
        description='datetime of when the original document was created',
    )
    entity_edges: list[str] = Field(
        description='list of entity edges referenced in this episode',
        default_factory=list,
    )

    async def save(self, driver: GraphDriver):
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.episodic_node_save(self, driver)

        episode_args = {
            'uuid': self.uuid,
            'name': self.name,
            'group_id': self.group_id,
            'source_description': self.source_description,
            'content': self.content,
            'entity_edges': self.entity_edges,
            'created_at': self.created_at,
            'valid_at': self.valid_at,
            'source': self.source.value,
        }

        result = await driver.execute_query(
            get_episode_node_save_query(driver.provider), **episode_args
        )

        logger.debug(f'Saved Node to Graph: {self.uuid}')

        return result

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        records, _, _ = await driver.execute_query(
            """
            MATCH (e:Episodic {uuid: $uuid})
            RETURN
            """
            + (
                EPISODIC_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else EPISODIC_NODE_RETURN
            ),
            uuid=uuid,
            routing_='r',
        )

        episodes = [get_episodic_node_from_record(record) for record in records]

        if len(episodes) == 0:
            raise NodeNotFoundError(uuid)

        return episodes[0]

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        records, _, _ = await driver.execute_query(
            """
            MATCH (e:Episodic)
            WHERE e.uuid IN $uuids
            RETURN DISTINCT
            """
            + (
                EPISODIC_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else EPISODIC_NODE_RETURN
            ),
            uuids=uuids,
            routing_='r',
        )

        episodes = [get_episodic_node_from_record(record) for record in records]

        return episodes

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
    ):
        cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
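        # Keyset pagination: when uuid_cursor is provided, only episodes with a smaller
        # uuid are returned; together with the ORDER BY uuid DESC below, callers can page
        # through results by passing the last uuid of the previous page as the cursor.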

        records, _, _ = await driver.execute_query(
            """
            MATCH (e:Episodic)
            WHERE e.group_id IN $group_ids
            """
            + cursor_query
            + """
            RETURN DISTINCT
            """
            + (
                EPISODIC_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else EPISODIC_NODE_RETURN
            )
            + """
            ORDER BY uuid DESC
            """
            + limit_query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )

        episodes = [get_episodic_node_from_record(record) for record in records]

        return episodes

    @classmethod
    async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str):
        records, _, _ = await driver.execute_query(
            """
            MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
            RETURN DISTINCT
            """
            + (
                EPISODIC_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else EPISODIC_NODE_RETURN
            ),
            entity_node_uuid=entity_node_uuid,
            routing_='r',
        )

        episodes = [get_episodic_node_from_record(record) for record in records]

        return episodes


class EntityNode(Node):
    name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
    summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
    attributes: dict[str, Any] = Field(
        default={}, description='Additional attributes of the node. Dependent on node labels'
    )

    async def generate_name_embedding(self, embedder: EmbedderClient):
        start = time()
        text = self.name.replace('\n', ' ')
        self.name_embedding = await embedder.create(input_data=[text])
        end = time()
        logger.debug(f'embedded {text} in {(end - start) * 1000} ms')

        return self.name_embedding

    async def load_name_embedding(self, driver: GraphDriver):
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.node_load_embeddings(self, driver)

        if driver.provider == GraphProvider.NEPTUNE:
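            # Neptune stores embeddings as a comma-separated string property, so the
            # query splits the value and casts each element back to a float.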
            query: LiteralString = """
                MATCH (n:Entity {uuid: $uuid})
                RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
            """

        else:
            query: LiteralString = """
                MATCH (n:Entity {uuid: $uuid})
                RETURN n.name_embedding AS name_embedding
            """
        records, _, _ = await driver.execute_query(
            query,
            uuid=self.uuid,
            routing_='r',
        )

        if len(records) == 0:
            raise NodeNotFoundError(self.uuid)

        self.name_embedding = records[0]['name_embedding']

    async def save(self, driver: GraphDriver):
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.node_save(self, driver)

        entity_data: dict[str, Any] = {
            'uuid': self.uuid,
            'name': self.name,
            'name_embedding': self.name_embedding,
            'group_id': self.group_id,
            'summary': self.summary,
            'created_at': self.created_at,
        }

        if driver.provider == GraphProvider.KUZU:
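            # Kuzu keeps dynamic attributes in a single JSON string column and labels
            # as a list property; other providers (the else branch) flatten attributes
            # onto the node and bake the labels into the Cypher label string.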
            entity_data['attributes'] = json.dumps(self.attributes)
            entity_data['labels'] = list(set(self.labels + ['Entity']))
            result = await driver.execute_query(
                get_entity_node_save_query(driver.provider, labels=''),
                **entity_data,
            )
        else:
            entity_data.update(self.attributes or {})
            labels = ':'.join(self.labels + ['Entity'])

            result = await driver.execute_query(
                get_entity_node_save_query(driver.provider, labels),
                entity_data=entity_data,
            )

        logger.debug(f'Saved Node to Graph: {self.uuid}')

        return result

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Entity {uuid: $uuid})
            RETURN
            """
            + get_entity_node_return_query(driver.provider),
            uuid=uuid,
            routing_='r',
        )

        nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

        if len(nodes) == 0:
            raise NodeNotFoundError(uuid)

        return nodes[0]

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Entity)
            WHERE n.uuid IN $uuids
            RETURN
            """
            + get_entity_node_return_query(driver.provider),
            uuids=uuids,
            routing_='r',
        )

        nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

        return nodes

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
        with_embeddings: bool = False,
    ):
        cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else ''
        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
        with_embeddings_query: LiteralString = (
            """,
            n.name_embedding AS name_embedding
            """
            if with_embeddings
            else ''
        )

        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Entity)
            WHERE n.group_id IN $group_ids
            """
            + cursor_query
            + """
            RETURN
            """
            + get_entity_node_return_query(driver.provider)
            + with_embeddings_query
            + """
            ORDER BY n.uuid DESC
            """
            + limit_query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )

        nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

        return nodes


class CommunityNode(Node):
    name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
    summary: str = Field(description='region summary of member nodes', default_factory=str)

    async def save(self, driver: GraphDriver):
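        # Neptune additionally mirrors the community's name/uuid/group_id into the
        # driver's AOSS ('communities') index before the graph write.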
        if driver.provider == GraphProvider.NEPTUNE:
            await driver.save_to_aoss(  # pyright: ignore reportAttributeAccessIssue
                'communities',
                [{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
            )
        result = await driver.execute_query(
            get_community_node_save_query(driver.provider),  # type: ignore
            uuid=self.uuid,
            name=self.name,
            group_id=self.group_id,
            summary=self.summary,
            name_embedding=self.name_embedding,
            created_at=self.created_at,
        )

        logger.debug(f'Saved Node to Graph: {self.uuid}')

        return result

    async def generate_name_embedding(self, embedder: EmbedderClient):
        start = time()
        text = self.name.replace('\n', ' ')
        self.name_embedding = await embedder.create(input_data=[text])
        end = time()
        logger.debug(f'embedded {text} in {(end - start) * 1000} ms')

        return self.name_embedding

    async def load_name_embedding(self, driver: GraphDriver):
        if driver.provider == GraphProvider.NEPTUNE:
            query: LiteralString = """
                MATCH (c:Community {uuid: $uuid})
                RETURN [x IN split(c.name_embedding, ",") | toFloat(x)] as name_embedding
            """
        else:
            query: LiteralString = """
            MATCH (c:Community {uuid: $uuid})
            RETURN c.name_embedding AS name_embedding
            """

        records, _, _ = await driver.execute_query(
            query,
            uuid=self.uuid,
            routing_='r',
        )

        if len(records) == 0:
            raise NodeNotFoundError(self.uuid)

        self.name_embedding = records[0]['name_embedding']

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        records, _, _ = await driver.execute_query(
            """
            MATCH (c:Community {uuid: $uuid})
            RETURN
            """
            + (
                COMMUNITY_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else COMMUNITY_NODE_RETURN
            ),
            uuid=uuid,
            routing_='r',
        )

        nodes = [get_community_node_from_record(record) for record in records]

        if len(nodes) == 0:
            raise NodeNotFoundError(uuid)

        return nodes[0]

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        records, _, _ = await driver.execute_query(
            """
            MATCH (c:Community)
            WHERE c.uuid IN $uuids
            RETURN
            """
            + (
                COMMUNITY_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else COMMUNITY_NODE_RETURN
            ),
            uuids=uuids,
            routing_='r',
        )

        communities = [get_community_node_from_record(record) for record in records]

        return communities

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
    ):
        cursor_query: LiteralString = 'AND c.uuid < $uuid' if uuid_cursor else ''
        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''

        records, _, _ = await driver.execute_query(
            """
            MATCH (c:Community)
            WHERE c.group_id IN $group_ids
            """
            + cursor_query
            + """
            RETURN
            """
            + (
                COMMUNITY_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else COMMUNITY_NODE_RETURN
            )
            + """
            ORDER BY c.uuid DESC
            """
            + limit_query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )

        communities = [get_community_node_from_record(record) for record in records]

        return communities


# Node helpers
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
    created_at = parse_db_date(record['created_at'])
    valid_at = parse_db_date(record['valid_at'])

    if created_at is None:
        raise ValueError(f'created_at cannot be None for episode {record.get("uuid", "unknown")}')
    if valid_at is None:
        raise ValueError(f'valid_at cannot be None for episode {record.get("uuid", "unknown")}')

    return EpisodicNode(
        content=record['content'],
        created_at=created_at,
        valid_at=valid_at,
        uuid=record['uuid'],
        group_id=record['group_id'],
        source=EpisodeType.from_str(record['source']),
        name=record['name'],
        source_description=record['source_description'],
        entity_edges=record['entity_edges'],
    )


def get_entity_node_from_record(record: Any, provider: GraphProvider) -> EntityNode:
    if provider == GraphProvider.KUZU:
        attributes = json.loads(record['attributes']) if record['attributes'] else {}
    else:
        attributes = record['attributes']
        attributes.pop('uuid', None)
        attributes.pop('name', None)
        attributes.pop('group_id', None)
        attributes.pop('name_embedding', None)
        attributes.pop('summary', None)
        attributes.pop('created_at', None)
        attributes.pop('labels', None)

    labels = record.get('labels', [])
    group_id = record.get('group_id')
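    # Drop the synthetic per-group label ('Entity_' + group_id with dashes removed)
    # so it does not appear in the node's public label list.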
    if 'Entity_' + group_id.replace('-', '') in labels:
        labels.remove('Entity_' + group_id.replace('-', ''))

    entity_node = EntityNode(
        uuid=record['uuid'],
        name=record['name'],
        name_embedding=record.get('name_embedding'),
        group_id=group_id,
        labels=labels,
        created_at=parse_db_date(record['created_at']),  # type: ignore
        summary=record['summary'],
        attributes=attributes,
    )

    return entity_node


def get_community_node_from_record(record: Any) -> CommunityNode:
    return CommunityNode(
        uuid=record['uuid'],
        name=record['name'],
        group_id=record['group_id'],
        name_embedding=record['name_embedding'],
        created_at=parse_db_date(record['created_at']),  # type: ignore
        summary=record['summary'],
    )


async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
    # filter out nodes with falsy (empty) names
    filtered_nodes = [node for node in nodes if node.name]

    if not filtered_nodes:
        return

    name_embeddings = await embedder.create_batch([node.name for node in filtered_nodes])
    for node, name_embedding in zip(filtered_nodes, name_embeddings, strict=True):
        node.name_embedding = name_embedding

```

--------------------------------------------------------------------------------
/mcp_server/src/graphiti_mcp_server.py:
--------------------------------------------------------------------------------

```python
#!/usr/bin/env python3
"""
Graphiti MCP Server - Exposes Graphiti functionality through the Model Context Protocol (MCP)
"""

import argparse
import asyncio
import logging
import os
import sys
from pathlib import Path
from typing import Any, Optional

from dotenv import load_dotenv
from graphiti_core import Graphiti
from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EpisodeType, EpisodicNode
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.utils.maintenance.graph_data_operations import clear_data
from mcp.server.fastmcp import FastMCP
from pydantic import BaseModel
from starlette.responses import JSONResponse

from config.schema import GraphitiConfig, ServerConfig
from models.response_types import (
    EpisodeSearchResponse,
    ErrorResponse,
    FactSearchResponse,
    NodeResult,
    NodeSearchResponse,
    StatusResponse,
    SuccessResponse,
)
from services.factories import DatabaseDriverFactory, EmbedderFactory, LLMClientFactory
from services.queue_service import QueueService
from utils.formatting import format_fact_result

# Load .env file from mcp_server directory
mcp_server_dir = Path(__file__).parent.parent
env_file = mcp_server_dir / '.env'
if env_file.exists():
    load_dotenv(env_file)
else:
    # Try current working directory as fallback
    load_dotenv()


# Semaphore limit for concurrent Graphiti operations.
#
# This controls how many episodes can be processed simultaneously. Each episode
# processing involves multiple LLM calls (entity extraction, deduplication, etc.),
# so the actual number of concurrent LLM requests will be higher.
#
# TUNING GUIDELINES:
#
# LLM Provider Rate Limits (requests per minute):
# - OpenAI Tier 1 (free):     3 RPM   -> SEMAPHORE_LIMIT=1-2
# - OpenAI Tier 2:            60 RPM   -> SEMAPHORE_LIMIT=5-8
# - OpenAI Tier 3:           500 RPM   -> SEMAPHORE_LIMIT=10-15
# - OpenAI Tier 4:         5,000 RPM   -> SEMAPHORE_LIMIT=20-50
# - Anthropic (default):     50 RPM   -> SEMAPHORE_LIMIT=5-8
# - Anthropic (high tier): 1,000 RPM   -> SEMAPHORE_LIMIT=15-30
# - Azure OpenAI (varies):  Consult your quota -> adjust accordingly
#
# SYMPTOMS:
# - Too high: 429 rate limit errors, increased costs from parallel processing
# - Too low: Slow throughput, underutilized API quota
#
# MONITORING:
# - Watch logs for rate limit errors (429)
# - Monitor episode processing times
# - Check LLM provider dashboard for actual request rates
#
# DEFAULT: 10 (suitable for OpenAI Tier 3, mid-tier Anthropic)
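#
# Back-of-the-envelope sizing (hypothetical numbers): if each episode triggers
# roughly 20 LLM calls and your provider allows 500 RPM, the ceiling is about
# 500 / 20 = 25 episodes per minute, so a SEMAPHORE_LIMIT in the 10-15 range
# keeps concurrent requests comfortably below quota while still using most of it.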
SEMAPHORE_LIMIT = int(os.getenv('SEMAPHORE_LIMIT', 10))


# Configure structured logging with timestamps
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
DATE_FORMAT = '%Y-%m-%d %H:%M:%S'

logging.basicConfig(
    level=logging.INFO,
    format=LOG_FORMAT,
    datefmt=DATE_FORMAT,
    stream=sys.stderr,
)

# Configure specific loggers
logging.getLogger('uvicorn').setLevel(logging.INFO)
logging.getLogger('uvicorn.access').setLevel(logging.WARNING)  # Reduce access log noise
logging.getLogger('mcp.server.streamable_http_manager').setLevel(
    logging.WARNING
)  # Reduce MCP noise


# Patch uvicorn's logging config to use our format
def configure_uvicorn_logging():
    """Configure uvicorn loggers to match our format after they're created."""
    for logger_name in ['uvicorn', 'uvicorn.error', 'uvicorn.access']:
        uvicorn_logger = logging.getLogger(logger_name)
        # Remove existing handlers and add our own with proper formatting
        uvicorn_logger.handlers.clear()
        handler = logging.StreamHandler(sys.stderr)
        handler.setFormatter(logging.Formatter(LOG_FORMAT, datefmt=DATE_FORMAT))
        uvicorn_logger.addHandler(handler)
        uvicorn_logger.propagate = False


logger = logging.getLogger(__name__)

# Create global config instance - will be properly initialized later
config: GraphitiConfig

# MCP server instructions
GRAPHITI_MCP_INSTRUCTIONS = """
Graphiti is a memory service for AI agents built on a knowledge graph. Graphiti performs well
with dynamic data such as user interactions, changing enterprise data, and external information.

Graphiti transforms information into a richly connected knowledge network, allowing you to 
capture relationships between concepts, entities, and information. The system organizes data as episodes 
(content snippets), nodes (entities), and facts (relationships between entities), creating a dynamic, 
queryable memory store that evolves with new information. Graphiti supports multiple data formats, including 
structured JSON data, enabling seamless integration with existing data pipelines and systems.

Facts contain temporal metadata, allowing you to track the time of creation and whether a fact is invalid 
(superseded by new information).

Key capabilities:
1. Add episodes (text, messages, or JSON) to the knowledge graph with the add_memory tool
2. Search for nodes (entities) in the graph using natural language queries with search_nodes
3. Find relevant facts (relationships between entities) with search_memory_facts
4. Retrieve specific entity edges or episodes by UUID
5. Manage the knowledge graph with tools like delete_episode, delete_entity_edge, and clear_graph

The server connects to a database for persistent storage and uses language models for certain operations. 
Each piece of information is organized by group_id, allowing you to maintain separate knowledge domains.

When adding information, provide descriptive names and detailed content to improve search quality. 
When searching, use specific queries and consider filtering by group_id for more relevant results.

For optimal performance, ensure the database is properly configured and accessible, and valid 
API keys are provided for any language model operations.
"""

# MCP server instance
mcp = FastMCP(
    'Graphiti Agent Memory',
    instructions=GRAPHITI_MCP_INSTRUCTIONS,
)

# Global services
graphiti_service: Optional['GraphitiService'] = None
queue_service: QueueService | None = None

# Global client for backward compatibility
graphiti_client: Graphiti | None = None
semaphore: asyncio.Semaphore


class GraphitiService:
    """Graphiti service using the unified configuration system."""

    def __init__(self, config: GraphitiConfig, semaphore_limit: int = 10):
        self.config = config
        self.semaphore_limit = semaphore_limit
        self.semaphore = asyncio.Semaphore(semaphore_limit)
        self.client: Graphiti | None = None
        self.entity_types = None

    async def initialize(self) -> None:
        """Initialize the Graphiti client with factory-created components."""
        try:
            # Create clients using factories
            llm_client = None
            embedder_client = None

            # Create LLM client based on configured provider
            try:
                llm_client = LLMClientFactory.create(self.config.llm)
            except Exception as e:
                logger.warning(f'Failed to create LLM client: {e}')

            # Create embedder client based on configured provider
            try:
                embedder_client = EmbedderFactory.create(self.config.embedder)
            except Exception as e:
                logger.warning(f'Failed to create embedder client: {e}')

            # Get database configuration
            db_config = DatabaseDriverFactory.create_config(self.config.database)

            # Build entity types from configuration
            custom_types = None
            if self.config.graphiti.entity_types:
                custom_types = {}
                for entity_type in self.config.graphiti.entity_types:
                    # Create a dynamic Pydantic model for each entity type
                    # Note: Don't use 'name' as it's a protected Pydantic attribute
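                    # For illustration (hypothetical entity type): a config entry named
                    # "Preference" with a description is equivalent to declaring
                    #
                    #     class Preference(BaseModel):
                    #         """A user's expressed preference."""
                    #
                    # i.e. a label-only model with no extra fields.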
                    entity_model = type(
                        entity_type.name,
                        (BaseModel,),
                        {
                            '__doc__': entity_type.description,
                        },
                    )
                    custom_types[entity_type.name] = entity_model

            # Store entity types for later use
            self.entity_types = custom_types

            # Initialize Graphiti client with appropriate driver
            try:
                if self.config.database.provider.lower() == 'falkordb':
                    # For FalkorDB, create a FalkorDriver instance directly
                    from graphiti_core.driver.falkordb_driver import FalkorDriver

                    falkor_driver = FalkorDriver(
                        host=db_config['host'],
                        port=db_config['port'],
                        password=db_config['password'],
                        database=db_config['database'],
                    )

                    self.client = Graphiti(
                        graph_driver=falkor_driver,
                        llm_client=llm_client,
                        embedder=embedder_client,
                        max_coroutines=self.semaphore_limit,
                    )
                else:
                    # For Neo4j (default), use the original approach
                    self.client = Graphiti(
                        uri=db_config['uri'],
                        user=db_config['user'],
                        password=db_config['password'],
                        llm_client=llm_client,
                        embedder=embedder_client,
                        max_coroutines=self.semaphore_limit,
                    )
            except Exception as db_error:
                # Check for connection errors
                error_msg = str(db_error).lower()
                if 'connection refused' in error_msg or 'could not connect' in error_msg:
                    db_provider = self.config.database.provider
                    if db_provider.lower() == 'falkordb':
                        raise RuntimeError(
                            f'\n{"=" * 70}\n'
                            f'Database Connection Error: FalkorDB is not running\n'
                            f'{"=" * 70}\n\n'
                            f'FalkorDB at {db_config["host"]}:{db_config["port"]} is not accessible.\n\n'
                            f'To start FalkorDB:\n'
                            f'  - Using Docker Compose: cd mcp_server && docker compose up\n'
                            f'  - Or run FalkorDB manually: docker run -p 6379:6379 falkordb/falkordb\n\n'
                            f'{"=" * 70}\n'
                        ) from db_error
                    elif db_provider.lower() == 'neo4j':
                        raise RuntimeError(
                            f'\n{"=" * 70}\n'
                            f'Database Connection Error: Neo4j is not running\n'
                            f'{"=" * 70}\n\n'
                            f'Neo4j at {db_config.get("uri", "unknown")} is not accessible.\n\n'
                            f'To start Neo4j:\n'
                            f'  - Using Docker Compose: cd mcp_server && docker compose -f docker/docker-compose-neo4j.yml up\n'
                            f'  - Or install Neo4j Desktop from: https://neo4j.com/download/\n'
                            f'  - Or run Neo4j manually: docker run -p 7474:7474 -p 7687:7687 neo4j:latest\n\n'
                            f'{"=" * 70}\n'
                        ) from db_error
                    else:
                        raise RuntimeError(
                            f'\n{"=" * 70}\n'
                            f'Database Connection Error: {db_provider} is not running\n'
                            f'{"=" * 70}\n\n'
                            f'{db_provider} at {db_config.get("uri", "unknown")} is not accessible.\n\n'
                            f'Please ensure {db_provider} is running and accessible.\n\n'
                            f'{"=" * 70}\n'
                        ) from db_error
                # Re-raise other errors
                raise

            # Build indices
            await self.client.build_indices_and_constraints()

            logger.info('Successfully initialized Graphiti client')

            # Log configuration details
            if llm_client:
                logger.info(
                    f'Using LLM provider: {self.config.llm.provider} / {self.config.llm.model}'
                )
            else:
                logger.info('No LLM client configured - entity extraction will be limited')

            if embedder_client:
                logger.info(f'Using Embedder provider: {self.config.embedder.provider}')
            else:
                logger.info('No Embedder client configured - search will be limited')

            if self.entity_types:
                entity_type_names = list(self.entity_types.keys())
                logger.info(f'Using custom entity types: {", ".join(entity_type_names)}')
            else:
                logger.info('Using default entity types')

            logger.info(f'Using database: {self.config.database.provider}')
            logger.info(f'Using group_id: {self.config.graphiti.group_id}')

        except Exception as e:
            logger.error(f'Failed to initialize Graphiti client: {e}')
            raise

    async def get_client(self) -> Graphiti:
        """Get the Graphiti client, initializing if necessary."""
        if self.client is None:
            await self.initialize()
        if self.client is None:
            raise RuntimeError('Failed to initialize Graphiti client')
        return self.client


@mcp.tool()
async def add_memory(
    name: str,
    episode_body: str,
    group_id: str | None = None,
    source: str = 'text',
    source_description: str = '',
    uuid: str | None = None,
) -> SuccessResponse | ErrorResponse:
    """Add an episode to memory. This is the primary way to add information to the graph.

    This function returns immediately and processes the episode addition in the background.
    Episodes for the same group_id are processed sequentially to avoid race conditions.

    Args:
        name (str): Name of the episode
        episode_body (str): The content of the episode to persist to memory. When source='json', this must be a
                           properly escaped JSON string, not a raw Python dictionary. The JSON data will be
                           automatically processed to extract entities and relationships.
        group_id (str, optional): A unique ID for this graph. If not provided, uses the default group_id from CLI
                                 or a generated one.
        source (str, optional): Source type, must be one of:
                               - 'text': For plain text content (default)
                               - 'json': For structured data
                               - 'message': For conversation-style content
        source_description (str, optional): Description of the source
        uuid (str, optional): Optional UUID for the episode

    Examples:
        # Adding plain text content
        add_memory(
            name="Company News",
            episode_body="Acme Corp announced a new product line today.",
            source="text",
            source_description="news article",
            group_id="some_arbitrary_string"
        )

        # Adding structured JSON data
        # NOTE: episode_body should be a JSON string (standard JSON escaping)
        add_memory(
            name="Customer Profile",
            episode_body='{"company": {"name": "Acme Technologies"}, "products": [{"id": "P001", "name": "CloudSync"}, {"id": "P002", "name": "DataMiner"}]}',
            source="json",
            source_description="CRM data"
        )
    """
    global graphiti_service, queue_service

    if graphiti_service is None or queue_service is None:
        return ErrorResponse(error='Services not initialized')

    try:
        # Use the provided group_id or fall back to the default from config
        effective_group_id = group_id or config.graphiti.group_id

        # Try to parse the source as an EpisodeType enum, with fallback to text
        episode_type = EpisodeType.text  # Default
        if source:
            try:
                episode_type = EpisodeType[source.lower()]
            except (KeyError, AttributeError):
                # If the source doesn't match any enum value, use text as default
                logger.warning(f"Unknown source type '{source}', using 'text' as default")
                episode_type = EpisodeType.text

        # Submit to queue service for async processing
        await queue_service.add_episode(
            group_id=effective_group_id,
            name=name,
            content=episode_body,
            source_description=source_description,
            episode_type=episode_type,
            entity_types=graphiti_service.entity_types,
            uuid=uuid or None,  # Normalize falsy values (e.g. empty string) to None
        )

        return SuccessResponse(
            message=f"Episode '{name}' queued for processing in group '{effective_group_id}'"
        )
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error queuing episode: {error_msg}')
        return ErrorResponse(error=f'Error queuing episode: {error_msg}')


@mcp.tool()
async def search_nodes(
    query: str,
    group_ids: list[str] | None = None,
    max_nodes: int = 10,
    entity_types: list[str] | None = None,
) -> NodeSearchResponse | ErrorResponse:
    """Search for nodes in the graph memory.

    Args:
        query: The search query
        group_ids: Optional list of group IDs to filter results
        max_nodes: Maximum number of nodes to return (default: 10)
        entity_types: Optional list of entity type names to filter by
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        client = await graphiti_service.get_client()

        # Use the provided group_ids or fall back to the default from config if none provided
        effective_group_ids = (
            group_ids
            if group_ids is not None
            else [config.graphiti.group_id]
            if config.graphiti.group_id
            else []
        )

        # Create search filters
        search_filters = SearchFilters(
            node_labels=entity_types,
        )

        # Use the search_ method with node search config
        from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF

        results = await client.search_(
            query=query,
            config=NODE_HYBRID_SEARCH_RRF,
            group_ids=effective_group_ids,
            search_filter=search_filters,
        )

        # Extract nodes from results
        nodes = results.nodes[:max_nodes] if results.nodes else []

        if not nodes:
            return NodeSearchResponse(message='No relevant nodes found', nodes=[])

        # Format the results
        node_results = []
        for node in nodes:
            # Get attributes and ensure no embeddings are included
            attrs = node.attributes if hasattr(node, 'attributes') else {}
            # Remove any embedding keys that might be in attributes
            attrs = {k: v for k, v in attrs.items() if 'embedding' not in k.lower()}

            node_results.append(
                NodeResult(
                    uuid=node.uuid,
                    name=node.name,
                    labels=node.labels if node.labels else [],
                    created_at=node.created_at.isoformat() if node.created_at else None,
                    summary=node.summary,
                    group_id=node.group_id,
                    attributes=attrs,
                )
            )

        return NodeSearchResponse(message='Nodes retrieved successfully', nodes=node_results)
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error searching nodes: {error_msg}')
        return ErrorResponse(error=f'Error searching nodes: {error_msg}')


@mcp.tool()
async def search_memory_facts(
    query: str,
    group_ids: list[str] | None = None,
    max_facts: int = 10,
    center_node_uuid: str | None = None,
) -> FactSearchResponse | ErrorResponse:
    """Search the graph memory for relevant facts.

    Args:
        query: The search query
        group_ids: Optional list of group IDs to filter results
        max_facts: Maximum number of facts to return (default: 10)
        center_node_uuid: Optional UUID of a node to center the search around
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        # Validate max_facts parameter
        if max_facts <= 0:
            return ErrorResponse(error='max_facts must be a positive integer')

        client = await graphiti_service.get_client()

        # Use the provided group_ids or fall back to the default from config if none provided
        effective_group_ids = (
            group_ids
            if group_ids is not None
            else [config.graphiti.group_id]
            if config.graphiti.group_id
            else []
        )

        relevant_edges = await client.search(
            group_ids=effective_group_ids,
            query=query,
            num_results=max_facts,
            center_node_uuid=center_node_uuid,
        )

        if not relevant_edges:
            return FactSearchResponse(message='No relevant facts found', facts=[])

        facts = [format_fact_result(edge) for edge in relevant_edges]
        return FactSearchResponse(message='Facts retrieved successfully', facts=facts)
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error searching facts: {error_msg}')
        return ErrorResponse(error=f'Error searching facts: {error_msg}')


@mcp.tool()
async def delete_entity_edge(uuid: str) -> SuccessResponse | ErrorResponse:
    """Delete an entity edge from the graph memory.

    Args:
        uuid: UUID of the entity edge to delete
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        client = await graphiti_service.get_client()

        # Get the entity edge by UUID
        entity_edge = await EntityEdge.get_by_uuid(client.driver, uuid)
        # Delete the edge using its delete method
        await entity_edge.delete(client.driver)
        return SuccessResponse(message=f'Entity edge with UUID {uuid} deleted successfully')
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error deleting entity edge: {error_msg}')
        return ErrorResponse(error=f'Error deleting entity edge: {error_msg}')


@mcp.tool()
async def delete_episode(uuid: str) -> SuccessResponse | ErrorResponse:
    """Delete an episode from the graph memory.

    Args:
        uuid: UUID of the episode to delete
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        client = await graphiti_service.get_client()

        # Get the episodic node by UUID
        episodic_node = await EpisodicNode.get_by_uuid(client.driver, uuid)
        # Delete the node using its delete method
        await episodic_node.delete(client.driver)
        return SuccessResponse(message=f'Episode with UUID {uuid} deleted successfully')
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error deleting episode: {error_msg}')
        return ErrorResponse(error=f'Error deleting episode: {error_msg}')


@mcp.tool()
async def get_entity_edge(uuid: str) -> dict[str, Any] | ErrorResponse:
    """Get an entity edge from the graph memory by its UUID.

    Args:
        uuid: UUID of the entity edge to retrieve
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        client = await graphiti_service.get_client()

        # Get the entity edge directly using the EntityEdge class method
        entity_edge = await EntityEdge.get_by_uuid(client.driver, uuid)

        # Use the format_fact_result function to serialize the edge
        # Return the Python dict directly - MCP will handle serialization
        return format_fact_result(entity_edge)
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error getting entity edge: {error_msg}')
        return ErrorResponse(error=f'Error getting entity edge: {error_msg}')


@mcp.tool()
async def get_episodes(
    group_ids: list[str] | None = None,
    max_episodes: int = 10,
) -> EpisodeSearchResponse | ErrorResponse:
    """Get episodes from the graph memory.

    Args:
        group_ids: Optional list of group IDs to filter results
        max_episodes: Maximum number of episodes to return (default: 10)
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        client = await graphiti_service.get_client()

        # Use the provided group_ids or fall back to the default from config if none provided
        effective_group_ids = (
            group_ids
            if group_ids is not None
            else [config.graphiti.group_id]
            if config.graphiti.group_id
            else []
        )

        # Get episodes from the driver directly
        from graphiti_core.nodes import EpisodicNode

        if effective_group_ids:
            episodes = await EpisodicNode.get_by_group_ids(
                client.driver, effective_group_ids, limit=max_episodes
            )
        else:
            # If no group IDs, we need to use a different approach
            # For now, return empty list when no group IDs specified
            episodes = []

        if not episodes:
            return EpisodeSearchResponse(message='No episodes found', episodes=[])

        # Format the results
        episode_results = []
        for episode in episodes:
            episode_dict = {
                'uuid': episode.uuid,
                'name': episode.name,
                'content': episode.content,
                'created_at': episode.created_at.isoformat() if episode.created_at else None,
                'source': episode.source.value
                if hasattr(episode.source, 'value')
                else str(episode.source),
                'source_description': episode.source_description,
                'group_id': episode.group_id,
            }
            episode_results.append(episode_dict)

        return EpisodeSearchResponse(
            message='Episodes retrieved successfully', episodes=episode_results
        )
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error getting episodes: {error_msg}')
        return ErrorResponse(error=f'Error getting episodes: {error_msg}')


@mcp.tool()
async def clear_graph(group_ids: list[str] | None = None) -> SuccessResponse | ErrorResponse:
    """Clear all data from the graph for specified group IDs.

    Args:
        group_ids: Optional list of group IDs to clear. If not provided, clears the default group.
    """
    global graphiti_service

    if graphiti_service is None:
        return ErrorResponse(error='Graphiti service not initialized')

    try:
        client = await graphiti_service.get_client()

        # Use the provided group_ids or fall back to the default from config if none provided
        effective_group_ids = (
            group_ids or ([config.graphiti.group_id] if config.graphiti.group_id else [])
        )

        if not effective_group_ids:
            return ErrorResponse(error='No group IDs specified for clearing')

        # Clear data for the specified group IDs
        await clear_data(client.driver, group_ids=effective_group_ids)

        return SuccessResponse(
            message=f'Graph data cleared successfully for group IDs: {", ".join(effective_group_ids)}'
        )
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error clearing graph: {error_msg}')
        return ErrorResponse(error=f'Error clearing graph: {error_msg}')


@mcp.tool()
async def get_status() -> StatusResponse:
    """Get the status of the Graphiti MCP server and database connection."""
    global graphiti_service

    if graphiti_service is None:
        return StatusResponse(status='error', message='Graphiti service not initialized')

    try:
        client = await graphiti_service.get_client()

        # Test database connection with a simple query
        async with client.driver.session() as session:
            result = await session.run('MATCH (n) RETURN count(n) as count')
            # Consume the result to verify query execution
            if result:
                _ = [record async for record in result]

        # Use the provider from the service's config, not the global
        provider_name = graphiti_service.config.database.provider
        return StatusResponse(
            status='ok',
            message=f'Graphiti MCP server is running and connected to {provider_name} database',
        )
    except Exception as e:
        error_msg = str(e)
        logger.error(f'Error checking database connection: {error_msg}')
        return StatusResponse(
            status='error',
            message=f'Graphiti MCP server is running but database connection failed: {error_msg}',
        )


@mcp.custom_route('/health', methods=['GET'])
async def health_check(request) -> JSONResponse:
    """Health check endpoint for Docker and load balancers."""
    return JSONResponse({'status': 'healthy', 'service': 'graphiti-mcp'})


async def initialize_server() -> ServerConfig:
    """Parse CLI arguments and initialize the Graphiti server configuration."""
    global config, graphiti_service, queue_service, graphiti_client, semaphore

    parser = argparse.ArgumentParser(
        description='Run the Graphiti MCP server with YAML configuration support'
    )

    # Configuration file argument
    # Default to config/config.yaml relative to the mcp_server directory
    default_config = Path(__file__).parent.parent / 'config' / 'config.yaml'
    parser.add_argument(
        '--config',
        type=Path,
        default=default_config,
        help='Path to YAML configuration file (default: config/config.yaml)',
    )

    # Transport arguments
    parser.add_argument(
        '--transport',
        choices=['sse', 'stdio', 'http'],
        help='Transport to use: http (recommended, default), stdio (standard I/O), or sse (deprecated)',
    )
    parser.add_argument(
        '--host',
        help='Host to bind the MCP server to',
    )
    parser.add_argument(
        '--port',
        type=int,
        help='Port to bind the MCP server to',
    )

    # Provider selection arguments
    parser.add_argument(
        '--llm-provider',
        choices=['openai', 'azure_openai', 'anthropic', 'gemini', 'groq'],
        help='LLM provider to use',
    )
    parser.add_argument(
        '--embedder-provider',
        choices=['openai', 'azure_openai', 'gemini', 'voyage'],
        help='Embedder provider to use',
    )
    parser.add_argument(
        '--database-provider',
        choices=['neo4j', 'falkordb'],
        help='Database provider to use',
    )

    # LLM configuration arguments
    parser.add_argument('--model', help='Model name to use with the LLM client')
    parser.add_argument('--small-model', help='Small model name to use with the LLM client')
    parser.add_argument(
        '--temperature', type=float, help='Temperature setting for the LLM (0.0-2.0)'
    )

    # Embedder configuration arguments
    parser.add_argument('--embedder-model', help='Model name to use with the embedder')

    # Graphiti-specific arguments
    parser.add_argument(
        '--group-id',
        help='Namespace for the graph. If not provided, uses config file or generates random UUID.',
    )
    parser.add_argument(
        '--user-id',
        help='User ID for tracking operations',
    )
    parser.add_argument(
        '--destroy-graph',
        action='store_true',
        help='Destroy all Graphiti graphs on startup',
    )

    args = parser.parse_args()

    # Set config path in environment for the settings to pick up
    if args.config:
        os.environ['CONFIG_PATH'] = str(args.config)

    # Load configuration with environment variables and YAML
    config = GraphitiConfig()

    # Apply CLI overrides
    config.apply_cli_overrides(args)

    # Also apply legacy CLI args for backward compatibility
    if hasattr(args, 'destroy_graph'):
        config.destroy_graph = args.destroy_graph

    # Log configuration details
    logger.info('Using configuration:')
    logger.info(f'  - LLM: {config.llm.provider} / {config.llm.model}')
    logger.info(f'  - Embedder: {config.embedder.provider} / {config.embedder.model}')
    logger.info(f'  - Database: {config.database.provider}')
    logger.info(f'  - Group ID: {config.graphiti.group_id}')
    logger.info(f'  - Transport: {config.server.transport}')

    # Log graphiti-core version
    try:
        import graphiti_core

        graphiti_version = getattr(graphiti_core, '__version__', 'unknown')
        logger.info(f'  - Graphiti Core: {graphiti_version}')
    except Exception:
        # Check for Docker-stored version file
        version_file = Path('/app/.graphiti-core-version')
        if version_file.exists():
            graphiti_version = version_file.read_text().strip()
            logger.info(f'  - Graphiti Core: {graphiti_version}')
        else:
            logger.info('  - Graphiti Core: version unavailable')

    # Handle graph destruction if requested
    if hasattr(config, 'destroy_graph') and config.destroy_graph:
        logger.warning('Destroying all Graphiti graphs as requested...')
        temp_service = GraphitiService(config, SEMAPHORE_LIMIT)
        await temp_service.initialize()
        client = await temp_service.get_client()
        await clear_data(client.driver)
        logger.info('All graphs destroyed')

    # Initialize services
    graphiti_service = GraphitiService(config, SEMAPHORE_LIMIT)
    queue_service = QueueService()
    await graphiti_service.initialize()

    # Set global client for backward compatibility
    graphiti_client = await graphiti_service.get_client()
    semaphore = graphiti_service.semaphore

    # Initialize queue service with the client
    await queue_service.initialize(graphiti_client)

    # Set MCP server settings
    if config.server.host:
        mcp.settings.host = config.server.host
    if config.server.port:
        mcp.settings.port = config.server.port

    # Return MCP configuration for transport
    return config.server


async def run_mcp_server():
    """Run the MCP server in the current event loop."""
    # Initialize the server
    mcp_config = await initialize_server()

    # Run the server with configured transport
    logger.info(f'Starting MCP server with transport: {mcp_config.transport}')
    if mcp_config.transport == 'stdio':
        await mcp.run_stdio_async()
    elif mcp_config.transport == 'sse':
        logger.info(
            f'Running MCP server with SSE transport on {mcp.settings.host}:{mcp.settings.port}'
        )
        logger.info(f'Access the server at: http://{mcp.settings.host}:{mcp.settings.port}/sse')
        await mcp.run_sse_async()
    elif mcp_config.transport == 'http':
        # Use localhost for display if binding to 0.0.0.0
        display_host = 'localhost' if mcp.settings.host == '0.0.0.0' else mcp.settings.host
        logger.info(
            f'Running MCP server with streamable HTTP transport on {mcp.settings.host}:{mcp.settings.port}'
        )
        logger.info('=' * 60)
        logger.info('MCP Server Access Information:')
        logger.info(f'  Base URL: http://{display_host}:{mcp.settings.port}/')
        logger.info(f'  MCP Endpoint: http://{display_host}:{mcp.settings.port}/mcp/')
        logger.info('  Transport: HTTP (streamable)')

        # Show FalkorDB Browser UI access if enabled
        if os.environ.get('BROWSER', '1') == '1':
            logger.info(f'  FalkorDB Browser UI: http://{display_host}:3000/')

        logger.info('=' * 60)
        logger.info('For MCP clients, connect to the /mcp/ endpoint above')

        # Configure uvicorn logging to match our format
        configure_uvicorn_logging()

        await mcp.run_streamable_http_async()
    else:
        raise ValueError(
            f'Unsupported transport: {mcp_config.transport}. Use "sse", "stdio", or "http"'
        )


def main():
    """Main function to run the Graphiti MCP server."""
    try:
        # Run everything in a single event loop
        asyncio.run(run_mcp_server())
    except KeyboardInterrupt:
        logger.info('Server shutting down...')
    except Exception as e:
        logger.error(f'Error initializing Graphiti MCP server: {str(e)}')
        raise


if __name__ == '__main__':
    main()

```
Page 6/9