#
tokens: 48895/50000 13/234 files (page 5/9)
lines: off (toggle) GitHub
raw markdown copy
This is page 5 of 9. Use http://codebase.md/getzep/graphiti?lines=false&page={x} to view the full context.

# Directory Structure

```
├── .env.example
├── .github
│   ├── dependabot.yml
│   ├── ISSUE_TEMPLATE
│   │   └── bug_report.md
│   ├── pull_request_template.md
│   ├── secret_scanning.yml
│   └── workflows
│       ├── ai-moderator.yml
│       ├── cla.yml
│       ├── claude-code-review-manual.yml
│       ├── claude-code-review.yml
│       ├── claude.yml
│       ├── codeql.yml
│       ├── daily_issue_maintenance.yml
│       ├── issue-triage.yml
│       ├── lint.yml
│       ├── release-graphiti-core.yml
│       ├── release-mcp-server.yml
│       ├── release-server-container.yml
│       ├── typecheck.yml
│       └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│   ├── azure-openai
│   │   ├── .env.example
│   │   ├── azure_openai_neo4j.py
│   │   └── README.md
│   ├── data
│   │   └── manybirds_products.json
│   ├── ecommerce
│   │   ├── runner.ipynb
│   │   └── runner.py
│   ├── langgraph-agent
│   │   ├── agent.ipynb
│   │   └── tinybirds-jess.png
│   ├── opentelemetry
│   │   ├── .env.example
│   │   ├── otel_stdout_example.py
│   │   ├── pyproject.toml
│   │   ├── README.md
│   │   └── uv.lock
│   ├── podcast
│   │   ├── podcast_runner.py
│   │   ├── podcast_transcript.txt
│   │   └── transcript_parser.py
│   ├── quickstart
│   │   ├── quickstart_falkordb.py
│   │   ├── quickstart_neo4j.py
│   │   ├── quickstart_neptune.py
│   │   ├── README.md
│   │   └── requirements.txt
│   └── wizard_of_oz
│       ├── parser.py
│       ├── runner.py
│       └── woo.txt
├── graphiti_core
│   ├── __init__.py
│   ├── cross_encoder
│   │   ├── __init__.py
│   │   ├── bge_reranker_client.py
│   │   ├── client.py
│   │   ├── gemini_reranker_client.py
│   │   └── openai_reranker_client.py
│   ├── decorators.py
│   ├── driver
│   │   ├── __init__.py
│   │   ├── driver.py
│   │   ├── falkordb_driver.py
│   │   ├── graph_operations
│   │   │   └── graph_operations.py
│   │   ├── kuzu_driver.py
│   │   ├── neo4j_driver.py
│   │   ├── neptune_driver.py
│   │   └── search_interface
│   │       └── search_interface.py
│   ├── edges.py
│   ├── embedder
│   │   ├── __init__.py
│   │   ├── azure_openai.py
│   │   ├── client.py
│   │   ├── gemini.py
│   │   ├── openai.py
│   │   └── voyage.py
│   ├── errors.py
│   ├── graph_queries.py
│   ├── graphiti_types.py
│   ├── graphiti.py
│   ├── helpers.py
│   ├── llm_client
│   │   ├── __init__.py
│   │   ├── anthropic_client.py
│   │   ├── azure_openai_client.py
│   │   ├── client.py
│   │   ├── config.py
│   │   ├── errors.py
│   │   ├── gemini_client.py
│   │   ├── groq_client.py
│   │   ├── openai_base_client.py
│   │   ├── openai_client.py
│   │   ├── openai_generic_client.py
│   │   └── utils.py
│   ├── migrations
│   │   └── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── edges
│   │   │   ├── __init__.py
│   │   │   └── edge_db_queries.py
│   │   └── nodes
│   │       ├── __init__.py
│   │       └── node_db_queries.py
│   ├── nodes.py
│   ├── prompts
│   │   ├── __init__.py
│   │   ├── dedupe_edges.py
│   │   ├── dedupe_nodes.py
│   │   ├── eval.py
│   │   ├── extract_edge_dates.py
│   │   ├── extract_edges.py
│   │   ├── extract_nodes.py
│   │   ├── invalidate_edges.py
│   │   ├── lib.py
│   │   ├── models.py
│   │   ├── prompt_helpers.py
│   │   ├── snippets.py
│   │   └── summarize_nodes.py
│   ├── py.typed
│   ├── search
│   │   ├── __init__.py
│   │   ├── search_config_recipes.py
│   │   ├── search_config.py
│   │   ├── search_filters.py
│   │   ├── search_helpers.py
│   │   ├── search_utils.py
│   │   └── search.py
│   ├── telemetry
│   │   ├── __init__.py
│   │   └── telemetry.py
│   ├── tracer.py
│   └── utils
│       ├── __init__.py
│       ├── bulk_utils.py
│       ├── datetime_utils.py
│       ├── maintenance
│       │   ├── __init__.py
│       │   ├── community_operations.py
│       │   ├── dedup_helpers.py
│       │   ├── edge_operations.py
│       │   ├── graph_data_operations.py
│       │   ├── node_operations.py
│       │   └── temporal_operations.py
│       ├── ontology_utils
│       │   └── entity_types_utils.py
│       └── text_utils.py
├── images
│   ├── arxiv-screenshot.png
│   ├── graphiti-graph-intro.gif
│   ├── graphiti-intro-slides-stock-2.gif
│   └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│   ├── .env.example
│   ├── .python-version
│   ├── config
│   │   ├── config-docker-falkordb-combined.yaml
│   │   ├── config-docker-falkordb.yaml
│   │   ├── config-docker-neo4j.yaml
│   │   ├── config.yaml
│   │   └── mcp_config_stdio_example.json
│   ├── docker
│   │   ├── build-standalone.sh
│   │   ├── build-with-version.sh
│   │   ├── docker-compose-falkordb.yml
│   │   ├── docker-compose-neo4j.yml
│   │   ├── docker-compose.yml
│   │   ├── Dockerfile
│   │   ├── Dockerfile.standalone
│   │   ├── github-actions-example.yml
│   │   ├── README-falkordb-combined.md
│   │   └── README.md
│   ├── docs
│   │   └── cursor_rules.md
│   ├── main.py
│   ├── pyproject.toml
│   ├── pytest.ini
│   ├── README.md
│   ├── src
│   │   ├── __init__.py
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   └── schema.py
│   │   ├── graphiti_mcp_server.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── entity_types.py
│   │   │   └── response_types.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   ├── factories.py
│   │   │   └── queue_service.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── formatting.py
│   │       └── utils.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── conftest.py
│   │   ├── pytest.ini
│   │   ├── README.md
│   │   ├── run_tests.py
│   │   ├── test_async_operations.py
│   │   ├── test_comprehensive_integration.py
│   │   ├── test_configuration.py
│   │   ├── test_falkordb_integration.py
│   │   ├── test_fixtures.py
│   │   ├── test_http_integration.py
│   │   ├── test_integration.py
│   │   ├── test_mcp_integration.py
│   │   ├── test_mcp_transports.py
│   │   ├── test_stdio_simple.py
│   │   └── test_stress_load.py
│   └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│   ├── .env.example
│   ├── graph_service
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── dto
│   │   │   ├── __init__.py
│   │   │   ├── common.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   ├── main.py
│   │   ├── routers
│   │   │   ├── __init__.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   └── zep_graphiti.py
│   ├── Makefile
│   ├── pyproject.toml
│   ├── README.md
│   └── uv.lock
├── signatures
│   └── version1
│       └── cla.json
├── tests
│   ├── cross_encoder
│   │   ├── test_bge_reranker_client_int.py
│   │   └── test_gemini_reranker_client.py
│   ├── driver
│   │   ├── __init__.py
│   │   └── test_falkordb_driver.py
│   ├── embedder
│   │   ├── embedder_fixtures.py
│   │   ├── test_gemini.py
│   │   ├── test_openai.py
│   │   └── test_voyage.py
│   ├── evals
│   │   ├── data
│   │   │   └── longmemeval_data
│   │   │       ├── longmemeval_oracle.json
│   │   │       └── README.md
│   │   ├── eval_cli.py
│   │   ├── eval_e2e_graph_building.py
│   │   ├── pytest.ini
│   │   └── utils.py
│   ├── helpers_test.py
│   ├── llm_client
│   │   ├── test_anthropic_client_int.py
│   │   ├── test_anthropic_client.py
│   │   ├── test_azure_openai_client.py
│   │   ├── test_client.py
│   │   ├── test_errors.py
│   │   └── test_gemini_client.py
│   ├── test_edge_int.py
│   ├── test_entity_exclusion_int.py
│   ├── test_graphiti_int.py
│   ├── test_graphiti_mock.py
│   ├── test_node_int.py
│   ├── test_text_utils.py
│   └── utils
│       ├── maintenance
│       │   ├── test_bulk_utils.py
│       │   ├── test_edge_operations.py
│       │   ├── test_node_operations.py
│       │   └── test_temporal_operations_int.py
│       └── search
│           └── search_utils_test.py
├── uv.lock
└── Zep-CLA.md
```

# Files

--------------------------------------------------------------------------------
/graphiti_core/models/nodes/node_db_queries.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Any

from graphiti_core.driver.driver import GraphProvider


def get_episode_node_save_query(provider: GraphProvider) -> str:
    """Build the Cypher statement that upserts a single Episodic node.

    Every variant MERGEs on ``$uuid``, writes the episode fields from query
    parameters and returns the node's uuid.

    Args:
        provider: Graph backend the statement will be executed against.

    Returns:
        The provider-specific Cypher text.
    """
    if provider == GraphProvider.NEPTUNE:
        # The Neptune statement flattens entity_edges into one '|'-joined string.
        return """
                MERGE (n:Episodic {uuid: $uuid})
                SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
                entity_edges: join([x IN coalesce($entity_edges, []) | toString(x) ], '|'), created_at: $created_at, valid_at: $valid_at}
                RETURN n.uuid AS uuid
            """

    if provider == GraphProvider.KUZU:
        # The Kuzu statement assigns each property individually instead of
        # replacing the whole property map.
        return """
                MERGE (n:Episodic {uuid: $uuid})
                SET
                    n.name = $name,
                    n.group_id = $group_id,
                    n.created_at = $created_at,
                    n.source = $source,
                    n.source_description = $source_description,
                    n.content = $content,
                    n.valid_at = $valid_at,
                    n.entity_edges = $entity_edges
                RETURN n.uuid AS uuid
            """

    # FalkorDB and every remaining provider (Neo4j) use the same statement.
    return """
                MERGE (n:Episodic {uuid: $uuid})
                SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
                entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
                RETURN n.uuid AS uuid
            """


def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
    """Build the Cypher statement used when saving Episodic nodes in bulk.

    The Neptune, FalkorDB and default (Neo4j) statements UNWIND an
    ``$episodes`` list and upsert one node per element. The Kuzu statement
    contains no UNWIND and binds a single episode's fields directly
    (presumably the caller runs it once per episode -- confirm against caller).

    Args:
        provider: Graph backend the statement will be executed against.

    Returns:
        The provider-specific Cypher text.
    """
    if provider == GraphProvider.NEPTUNE:
        # The Neptune statement flattens entity_edges into one '|'-joined string.
        return """
                UNWIND $episodes AS episode
                MERGE (n:Episodic {uuid: episode.uuid})
                SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
                    source: episode.source, content: episode.content,
                entity_edges: join([x IN coalesce(episode.entity_edges, []) | toString(x) ], '|'), created_at: episode.created_at, valid_at: episode.valid_at}
                RETURN n.uuid AS uuid
            """

    if provider == GraphProvider.KUZU:
        return """
                MERGE (n:Episodic {uuid: $uuid})
                SET
                    n.name = $name,
                    n.group_id = $group_id,
                    n.created_at = $created_at,
                    n.source = $source,
                    n.source_description = $source_description,
                    n.content = $content,
                    n.valid_at = $valid_at,
                    n.entity_edges = $entity_edges
                RETURN n.uuid AS uuid
            """

    # FalkorDB and every remaining provider (Neo4j) use the same statement.
    return """
                UNWIND $episodes AS episode
                MERGE (n:Episodic {uuid: episode.uuid})
                SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content, 
                entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
                RETURN n.uuid AS uuid
            """


# Column projection for Episodic nodes (presumably appended after a RETURN
# keyword by callers -- confirm). The query using it must bind the node to the
# alias `e`; each property is returned under its own name.
EPISODIC_NODE_RETURN = """
    e.uuid AS uuid,
    e.name AS name,
    e.group_id AS group_id,
    e.created_at AS created_at,
    e.source AS source,
    e.source_description AS source_description,
    e.content AS content,
    e.valid_at AS valid_at,
    e.entity_edges AS entity_edges
"""

# Column projection for Episodic nodes on Neptune; the query using it must bind
# the node to the alias `e`. The Neptune save queries in this module persist
# entity_edges as a single '|'-joined string, so the value is split on that
# same '|' delimiter here. Splitting on ',' would return one element holding
# the whole joined string instead of the individual entries.
EPISODIC_NODE_RETURN_NEPTUNE = """
    e.content AS content,
    e.created_at AS created_at,
    e.valid_at AS valid_at,
    e.uuid AS uuid,
    e.name AS name,
    e.group_id AS group_id,
    e.source_description AS source_description,
    e.source AS source,
    split(e.entity_edges, "|") AS entity_edges
"""


def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: bool = False) -> str:
    """Build the Cypher statement that upserts a single Entity node.

    Args:
        provider: Graph backend the statement will be executed against.
        labels: Label text spliced directly into the statement (not sent as a
            query parameter). The Neptune branch splits it on ':' and emits
            one ``SET n:<label>`` clause per piece; the FalkorDB and default
            branches insert it verbatim after ``SET n:``. The Kuzu branch does
            not use it and takes labels from the ``$labels`` parameter.
        has_aoss: Only read by the default (Neo4j) branch. When true, the
            ``db.create.setNodeVectorProperty`` call is left out.

    Returns:
        The provider-specific Cypher text.
    """
    if provider == GraphProvider.FALKORDB:
        return f"""
                MERGE (n:Entity {{uuid: $entity_data.uuid}})
                SET n:{labels}
                SET n = $entity_data
                SET n.name_embedding = vecf32($entity_data.name_embedding)
                RETURN n.uuid AS uuid
            """

    if provider == GraphProvider.KUZU:
        return """
                MERGE (n:Entity {uuid: $uuid})
                SET
                    n.name = $name,
                    n.group_id = $group_id,
                    n.labels = $labels,
                    n.created_at = $created_at,
                    n.name_embedding = $name_embedding,
                    n.summary = $summary,
                    n.attributes = $attributes
                WITH n
                RETURN n.uuid AS uuid
            """

    if provider == GraphProvider.NEPTUNE:
        # One SET clause per ':'-separated label.
        set_label_clauses = ''.join(f' SET n:{label}\n' for label in labels.split(':'))
        return f"""
                MERGE (n:Entity {{uuid: $entity_data.uuid}})
                {set_label_clauses}
                SET n = removeKeyFromMap(removeKeyFromMap($entity_data, "labels"), "name_embedding")
                SET n.name_embedding = join([x IN coalesce($entity_data.name_embedding, []) | toString(x) ], ",")
                RETURN n.uuid AS uuid
            """

    # Default branch (Neo4j).
    if has_aoss:
        embedding_clause = ''
    else:
        embedding_clause = 'WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)'
    return f"""
                MERGE (n:Entity {{uuid: $entity_data.uuid}})
                SET n:{labels}
                SET n = $entity_data
                {embedding_clause}
                RETURN n.uuid AS uuid
            """


def get_entity_node_save_bulk_query(
    provider: GraphProvider, nodes: list[dict], has_aoss: bool = False
) -> str | Any:
    """Build the statement(s) used when saving Entity nodes in bulk.

    The shape of the return value depends on the provider:

    * FalkorDB: a list of ``(statement, params)`` tuples, one per
      node/label pair, where ``params`` is ``{'nodes': [node]}``.
    * Neptune: a list of statement strings, one per node.
    * Kuzu: a single statement string that binds one node's fields.
    * Default (Neo4j): a single statement string that UNWINDs ``$nodes``.

    Args:
        provider: Graph backend the statement(s) will be executed against.
        nodes: Node dicts; the FalkorDB and Neptune branches read
            ``node['labels']`` and splice each label into the statement text.
        has_aoss: Only read by the default (Neo4j) branch. When true, the
            ``db.create.setNodeVectorProperty`` call is left out.

    Returns:
        A statement string or a list as described above.
    """
    if provider == GraphProvider.FALKORDB:
        return [
            (
                f"""
                            UNWIND $nodes AS node
                            MERGE (n:Entity {{uuid: node.uuid}})
                            SET n:{label}
                            SET n = node
                            WITH n, node
                            SET n.name_embedding = vecf32(node.name_embedding)
                            RETURN n.uuid AS uuid
                            """,
                {'nodes': [node]},
            )
            for node in nodes
            for label in node['labels']
        ]

    if provider == GraphProvider.NEPTUNE:
        # NOTE(review): each statement hard-codes one node's labels yet still
        # UNWINDs $nodes; which nodes it touches depends on the params the
        # caller supplies -- confirm against caller.
        statements = []
        for node in nodes:
            set_label_clauses = ''.join(f' SET n:{label}\n' for label in node['labels'])
            statements.append(
                f"""
                        UNWIND $nodes AS node
                        MERGE (n:Entity {{uuid: node.uuid}})
                        {set_label_clauses}
                        SET n = removeKeyFromMap(removeKeyFromMap(node, "labels"), "name_embedding")
                        SET n.name_embedding = join([x IN coalesce(node.name_embedding, []) | toString(x) ], ",")
                        RETURN n.uuid AS uuid
                    """
            )
        return statements

    if provider == GraphProvider.KUZU:
        return """
                MERGE (n:Entity {uuid: $uuid})
                SET
                    n.name = $name,
                    n.group_id = $group_id,
                    n.labels = $labels,
                    n.created_at = $created_at,
                    n.name_embedding = $name_embedding,
                    n.summary = $summary,
                    n.attributes = $attributes
                RETURN n.uuid AS uuid
            """

    # Default branch (Neo4j).
    if has_aoss:
        embedding_clause = ''
    else:
        embedding_clause = 'WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)'
    return f"""
                    UNWIND $nodes AS node
                    MERGE (n:Entity {{uuid: node.uuid}})
                    SET n:$(node.labels)
                    SET n = node
                    {embedding_clause}
                RETURN n.uuid AS uuid
            """


def get_entity_node_return_query(provider: GraphProvider) -> str:
    """Return the column projection used when reading Entity nodes.

    The node must be bound to the alias ``n``. ``name_embedding`` is not part
    of the projection and must be loaded manually using
    ``load_name_embedding()``.

    Args:
        provider: Graph backend the projection will be used with.

    Returns:
        For Kuzu, a projection reading ``labels`` and ``attributes`` from node
        properties; otherwise one deriving them via ``labels(n)`` and
        ``properties(n)``.
    """
    kuzu_projection = """
            n.uuid AS uuid,
            n.name AS name,
            n.group_id AS group_id,
            n.labels AS labels,
            n.created_at AS created_at,
            n.summary AS summary,
            n.attributes AS attributes
        """
    default_projection = """
        n.uuid AS uuid,
        n.name AS name,
        n.group_id AS group_id,
        n.created_at AS created_at,
        n.summary AS summary,
        labels(n) AS labels,
        properties(n) AS attributes
    """
    return kuzu_projection if provider == GraphProvider.KUZU else default_projection


def get_community_node_save_query(provider: GraphProvider) -> str:
    """Build the Cypher statement that upserts a single Community node.

    Every variant MERGEs on ``$uuid`` and returns the node's uuid; they differ
    in how ``name_embedding`` is written.

    Args:
        provider: Graph backend the statement will be executed against.

    Returns:
        The provider-specific Cypher text.
    """
    if provider == GraphProvider.FALKORDB:
        # Embedding is wrapped in vecf32() inside the property map.
        return """
                MERGE (n:Community {uuid: $uuid})
                SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at, name_embedding: vecf32($name_embedding)}
                RETURN n.uuid AS uuid
            """

    if provider == GraphProvider.NEPTUNE:
        # Embedding is stored as one ','-joined string.
        return """
                MERGE (n:Community {uuid: $uuid})
                SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
                SET n.name_embedding = join([x IN coalesce($name_embedding, []) | toString(x) ], ",")
                RETURN n.uuid AS uuid
            """

    if provider == GraphProvider.KUZU:
        # Properties are assigned one by one, embedding included.
        return """
                MERGE (n:Community {uuid: $uuid})
                SET
                    n.name = $name,
                    n.group_id = $group_id,
                    n.created_at = $created_at,
                    n.name_embedding = $name_embedding,
                    n.summary = $summary
                RETURN n.uuid AS uuid
            """

    # Default branch (Neo4j): embedding is written via setNodeVectorProperty.
    return """
                MERGE (n:Community {uuid: $uuid})
                SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
                WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
                RETURN n.uuid AS uuid
            """


# Column projection for Community nodes (presumably appended after a RETURN
# keyword by callers -- confirm). The query using it must bind the node to the
# alias `c`; name_embedding is returned as stored.
COMMUNITY_NODE_RETURN = """
    c.uuid AS uuid,
    c.name AS name,
    c.group_id AS group_id,
    c.created_at AS created_at,
    c.name_embedding AS name_embedding,
    c.summary AS summary
"""

# Column projection for Community nodes on Neptune. Unlike
# COMMUNITY_NODE_RETURN, the node must be bound to the alias `n`.
# name_embedding is split on "," and each piece converted with toFloat(),
# mirroring the ","-joined string written by the Neptune branch of
# get_community_node_save_query.
COMMUNITY_NODE_RETURN_NEPTUNE = """
    n.uuid AS uuid,
    n.name AS name,
    [x IN split(n.name_embedding, ",") | toFloat(x)] AS name_embedding,
    n.group_id AS group_id,
    n.summary AS summary,
    n.created_at AS created_at
"""

```

--------------------------------------------------------------------------------
/tests/cross_encoder/test_gemini_reranker_client.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Running tests: pytest -xvs tests/cross_encoder/test_gemini_reranker_client.py

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient
from graphiti_core.llm_client import LLMConfig, RateLimitError


@pytest.fixture
def mock_gemini_client():
    """Yield a mocked Gemini client instance while ``google.genai.Client`` is patched.

    The yielded object is the patched class's ``return_value``; its
    ``aio.models.generate_content`` attribute is an ``AsyncMock`` so tests can
    await it and script its results.
    """
    with patch('google.genai.Client') as patched_client_cls:
        fake_client = patched_client_cls.return_value
        # Assemble the aio.models.generate_content chain explicitly, innermost first.
        fake_models = MagicMock()
        fake_models.generate_content = AsyncMock()
        fake_aio = MagicMock()
        fake_aio.models = fake_models
        fake_client.aio = fake_aio
        yield fake_client


@pytest.fixture
def gemini_reranker_client(mock_gemini_client):
    """Provide a GeminiRerankerClient whose ``client`` attribute is the mocked client."""
    reranker = GeminiRerankerClient(
        config=LLMConfig(api_key='test_api_key', model='test-model')
    )
    # Overwrite the client attribute so every call in the test goes to the mock
    reranker.client = mock_gemini_client
    return reranker


def create_mock_response(score_text: str) -> MagicMock:
    """Build a stand-in Gemini response whose ``text`` attribute is ``score_text``."""
    return MagicMock(text=score_text)


class TestGeminiRerankerClientInitialization:
    """Construction scenarios for GeminiRerankerClient."""

    def test_init_with_config(self):
        """An explicitly supplied LLMConfig is kept on the client."""
        llm_config = LLMConfig(api_key='test_api_key', model='test-model')

        reranker = GeminiRerankerClient(config=llm_config)

        assert reranker.config == llm_config

    @patch('google.genai.Client')
    def test_init_without_config(self, mock_client):
        """Constructing without a config still leaves a config on the client."""
        reranker = GeminiRerankerClient()

        assert reranker.config is not None

    def test_init_with_custom_client(self):
        """A caller-supplied client object is stored as the reranker's client."""
        injected_client = MagicMock()

        reranker = GeminiRerankerClient(client=injected_client)

        assert reranker.client == injected_client


class TestGeminiRerankerClientRanking:
    """Tests for GeminiRerankerClient rank method."""

    @pytest.mark.asyncio
    async def test_rank_basic_functionality(self, gemini_reranker_client, mock_gemini_client):
        """Rank three passages and check result shape, ordering and normalised scores."""
        # One scripted reply per passage: high, medium and low relevance
        mock_gemini_client.aio.models.generate_content.side_effect = [
            create_mock_response(raw_score) for raw_score in ('85', '45', '20')
        ]

        question = 'What is the capital of France?'
        documents = [
            'Paris is the capital and most populous city of France.',
            'London is the capital city of England and the United Kingdom.',
            'Berlin is the capital and largest city of Germany.',
        ]

        ranked = await gemini_reranker_client.rank(question, documents)

        # Shape: three (passage, score) tuples of the right types
        assert len(ranked) == 3
        assert all(isinstance(entry, tuple) for entry in ranked)
        for text, relevance in ranked:
            assert isinstance(text, str) and isinstance(relevance, float)

        # Scores lie in [0, 1] and are sorted from best to worst
        relevances = [entry[1] for entry in ranked]
        assert all(0.0 <= value <= 1.0 for value in relevances)
        assert relevances == sorted(relevances, reverse=True)

        # Raw scores are divided by 100: 85 -> 0.85, 45 -> 0.45, 20 -> 0.20
        assert relevances == [0.85, 0.45, 0.20]

    @pytest.mark.asyncio
    async def test_rank_empty_passages(self, gemini_reranker_client):
        """Ranking an empty passage list yields an empty result."""
        ranked = await gemini_reranker_client.rank('Test query', [])

        assert ranked == []

    @pytest.mark.asyncio
    async def test_rank_single_passage(self, gemini_reranker_client, mock_gemini_client):
        """A lone passage comes back on its own with a score of 1.0."""
        # Scripted reply in case the client asks the model for a score
        mock_gemini_client.aio.models.generate_content.return_value = create_mock_response('75')

        only_passage = 'Single test passage'

        ranked = await gemini_reranker_client.rank('Test query', [only_passage])

        assert len(ranked) == 1
        text, relevance = ranked[0]
        assert text == 'Single test passage'
        assert relevance == 1.0  # Single passage gets full score

    @pytest.mark.asyncio
    async def test_rank_score_extraction_with_regex(
        self, gemini_reranker_client, mock_gemini_client
    ):
        """Scores are pulled out of replies that wrap the number in other text."""
        reply_texts = (
            'Score: 90',  # Text before the number
            'The relevance is 65 out of 100',  # Text around the number
            '8',  # Bare number
        )
        mock_gemini_client.aio.models.generate_content.side_effect = [
            create_mock_response(reply) for reply in reply_texts
        ]

        ranked = await gemini_reranker_client.rank(
            'Test query', ['Passage 1', 'Passage 2', 'Passage 3']
        )

        # Each extracted number shows up divided by 100
        relevances = [entry[1] for entry in ranked]
        for expected in (0.90, 0.65, 0.08):
            assert expected in relevances

    @pytest.mark.asyncio
    async def test_rank_invalid_score_handling(self, gemini_reranker_client, mock_gemini_client):
        """Replies without a usable number are scored 0.0 instead of failing."""
        reply_texts = (
            'Not a number',  # No digits at all
            '',  # Empty reply
            '95',  # Valid reply
        )
        mock_gemini_client.aio.models.generate_content.side_effect = [
            create_mock_response(reply) for reply in reply_texts
        ]

        ranked = await gemini_reranker_client.rank(
            'Test query', ['Passage 1', 'Passage 2', 'Passage 3']
        )

        relevances = [entry[1] for entry in ranked]
        assert 0.95 in relevances  # The one valid score
        assert relevances.count(0.0) == 2  # Both invalid replies fall back to 0.0

    @pytest.mark.asyncio
    async def test_rank_score_clamping(self, gemini_reranker_client, mock_gemini_client):
        """Out-of-range raw scores are clamped so every result lies in [0, 1]."""
        # Note: regex only matches 1-3 digits, so negative numbers won't match
        reply_texts = (
            '999',  # Above 100 but within regex range
            'invalid',  # Invalid response becomes 0.0
            '50',  # Normal score
        )
        mock_gemini_client.aio.models.generate_content.side_effect = [
            create_mock_response(reply) for reply in reply_texts
        ]

        ranked = await gemini_reranker_client.rank(
            'Test query', ['Passage 1', 'Passage 2', 'Passage 3']
        )

        relevances = [entry[1] for entry in ranked]
        assert all(0.0 <= value <= 1.0 for value in relevances)
        # 999 / 100 = 9.99, clamped down to 1.0
        assert 1.0 in relevances
        # The invalid reply falls back to 0.0
        assert 0.0 in relevances
        # 50 / 100 = 0.5
        assert 0.5 in relevances

    @pytest.mark.asyncio
    async def test_rank_rate_limit_error(self, gemini_reranker_client, mock_gemini_client):
        """A 'Rate limit exceeded' failure from the model surfaces as RateLimitError."""
        generate_content = mock_gemini_client.aio.models.generate_content
        generate_content.side_effect = Exception('Rate limit exceeded')

        with pytest.raises(RateLimitError):
            await gemini_reranker_client.rank('Test query', ['Passage 1', 'Passage 2'])

    @pytest.mark.asyncio
    async def test_rank_quota_error(self, gemini_reranker_client, mock_gemini_client):
        """A 'Quota exceeded' failure from the model surfaces as RateLimitError."""
        generate_content = mock_gemini_client.aio.models.generate_content
        generate_content.side_effect = Exception('Quota exceeded')

        with pytest.raises(RateLimitError):
            await gemini_reranker_client.rank('Test query', ['Passage 1', 'Passage 2'])

    @pytest.mark.asyncio
    async def test_rank_resource_exhausted_error(self, gemini_reranker_client, mock_gemini_client):
        """A 'resource_exhausted' failure from the model surfaces as RateLimitError."""
        generate_content = mock_gemini_client.aio.models.generate_content
        generate_content.side_effect = Exception('resource_exhausted')

        with pytest.raises(RateLimitError):
            await gemini_reranker_client.rank('Test query', ['Passage 1', 'Passage 2'])

    @pytest.mark.asyncio
    async def test_rank_429_error(self, gemini_reranker_client, mock_gemini_client):
        """An 'HTTP 429 Too Many Requests' failure surfaces as RateLimitError."""
        generate_content = mock_gemini_client.aio.models.generate_content
        generate_content.side_effect = Exception('HTTP 429 Too Many Requests')

        with pytest.raises(RateLimitError):
            await gemini_reranker_client.rank('Test query', ['Passage 1', 'Passage 2'])

    @pytest.mark.asyncio
    async def test_rank_generic_error(self, gemini_reranker_client, mock_gemini_client):
        """Test that a non-rate-limit API failure propagates with its message intact.

        The raised error must not be a ``RateLimitError``: ``pytest.raises(Exception)``
        alone would also accept that subclass, so without the explicit type check
        this test could not tell the generic path apart from the rate-limit path.
        """
        # Setup mock to raise an error whose message matches no rate-limit pattern
        mock_gemini_client.aio.models.generate_content.side_effect = Exception('Generic error')

        query = 'Test query'
        passages = ['Passage 1', 'Passage 2']

        with pytest.raises(Exception) as exc_info:
            await gemini_reranker_client.rank(query, passages)

        assert not isinstance(exc_info.value, RateLimitError)
        assert 'Generic error' in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_rank_concurrent_requests(self, gemini_reranker_client, mock_gemini_client):
        """Ranking three passages issues one generate_content call per passage."""
        generate_content = mock_gemini_client.aio.models.generate_content
        # One canned score per passage, handed out in call order.
        generate_content.side_effect = [
            create_mock_response(score) for score in ('80', '60', '40')
        ]

        await gemini_reranker_client.rank('Test query', ['Passage 1', 'Passage 2', 'Passage 3'])

        # One API call per passage.
        assert generate_content.call_count == 3

        # Each call must use the configured model and the expected generation settings.
        for recorded_call in generate_content.call_args_list:
            _, call_kwargs = recorded_call
            assert call_kwargs['model'] == gemini_reranker_client.config.model
            assert call_kwargs['config'].temperature == 0.0
            assert call_kwargs['config'].max_output_tokens == 3

    @pytest.mark.asyncio
    async def test_rank_response_parsing_error(self, gemini_reranker_client, mock_gemini_client):
        """Responses that contain no number are scored 0.0 instead of raising."""
        # Neither reply contains a numeric score.
        unparseable_replies = [
            create_mock_response(text) for text in ('not a number at all', 'also invalid text')
        ]
        mock_gemini_client.aio.models.generate_content.side_effect = unparseable_replies

        # Two passages are used to avoid the single passage special case.
        ranked = await gemini_reranker_client.rank('Test query', ['Passage 1', 'Passage 2'])

        # Both passages are returned, each with a 0.0 score.
        assert len(ranked) == 2
        assert all(score == 0.0 for _, score in ranked)

    @pytest.mark.asyncio
    async def test_rank_empty_response_text(self, gemini_reranker_client, mock_gemini_client):
        """A response whose text is the empty string is scored 0.0 instead of raising."""
        # The response text is '' (an empty string, not None).
        blank_response = MagicMock(text='')
        mock_gemini_client.aio.models.generate_content.return_value = blank_response

        # Two passages are used to avoid the single passage special case.
        ranked = await gemini_reranker_client.rank('Test query', ['Passage 1', 'Passage 2'])

        # Both passages are returned, each with a 0.0 score.
        assert len(ranked) == 2
        assert all(score == 0.0 for _, score in ranked)


if __name__ == '__main__':
    # Target this file by its own path so a direct run does not depend on the
    # current working directory.
    pytest.main(['-v', __file__])

```

--------------------------------------------------------------------------------
/tests/test_edge_int.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
import sys
from datetime import datetime

import numpy as np
import pytest

from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
from tests.helpers_test import get_edge_count, get_node_count, group_id

# Ask pytest to load the pytest-asyncio plugin for this module; the tests below
# are coroutine functions marked with `@pytest.mark.asyncio`.
pytest_plugins = ('pytest_asyncio',)


def setup_logging():
    """Configure the root logger to emit INFO-level records to stdout.

    The root logger level is set to INFO and a stream handler writing to
    ``sys.stdout`` is attached, formatted as
    ``timestamp - logger name - level - message``. Calling this function again
    does not attach a second stdout handler, so records are not printed twice.

    Returns:
        The root logger.
    """
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Reuse a stdout handler that is already attached: adding a new one on
    # every call would duplicate each log line.
    already_attached = any(
        isinstance(handler, logging.StreamHandler)
        and getattr(handler, 'stream', None) is sys.stdout
        for handler in logger.handlers
    )
    if already_attached:
        return logger

    # Console handler at INFO level with the shared format.
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    logger.addHandler(console_handler)

    return logger


@pytest.mark.asyncio
async def test_episodic_edge(graph_driver, mock_embedder):
    """Exercise the EpisodicEdge lifecycle against the graph driver.

    Saves an episode node and an entity node, links them with an EpisodicEdge,
    reads the edge back by uuid, by uuids and by group ids, looks the episode up
    through the entity node, and finally deletes the edge and both nodes.
    """
    timestamp = datetime.now()

    # Episode node: absent before save(), present afterwards.
    episode_node = EpisodicNode(
        name='test_episode',
        labels=[],
        created_at=timestamp,
        valid_at=timestamp,
        source=EpisodeType.message,
        source_description='conversation message',
        content='Alice likes Bob',
        entity_edges=[],
        group_id=group_id,
    )
    assert await get_node_count(graph_driver, [episode_node.uuid]) == 0
    await episode_node.save(graph_driver)
    assert await get_node_count(graph_driver, [episode_node.uuid]) == 1

    # Entity node: absent before save(), present afterwards.
    alice_node = EntityNode(
        name='Alice', labels=[], created_at=timestamp, summary='Alice summary', group_id=group_id
    )
    await alice_node.generate_name_embedding(mock_embedder)
    assert await get_node_count(graph_driver, [alice_node.uuid]) == 0
    await alice_node.save(graph_driver)
    assert await get_node_count(graph_driver, [alice_node.uuid]) == 1

    # Episode -> entity edge: absent before save(), present afterwards.
    episodic_edge = EpisodicEdge(
        source_node_uuid=episode_node.uuid,
        target_node_uuid=alice_node.uuid,
        created_at=timestamp,
        group_id=group_id,
    )
    assert await get_edge_count(graph_driver, [episodic_edge.uuid]) == 0
    await episodic_edge.save(graph_driver)
    assert await get_edge_count(graph_driver, [episodic_edge.uuid]) == 1

    def assert_matches_saved_edge(edge):
        # A fetched edge must mirror the saved edge field by field.
        assert edge.uuid == episodic_edge.uuid
        assert edge.source_node_uuid == episode_node.uuid
        assert edge.target_node_uuid == alice_node.uuid
        assert edge.created_at == timestamp
        assert edge.group_id == group_id

    # Fetch by a single uuid.
    assert_matches_saved_edge(await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge.uuid))

    # Fetch by a list of uuids.
    by_uuids = await EpisodicEdge.get_by_uuids(graph_driver, [episodic_edge.uuid])
    assert len(by_uuids) == 1
    assert_matches_saved_edge(by_uuids[0])

    # Fetch by group ids.
    by_group = await EpisodicEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
    assert len(by_group) == 1
    assert_matches_saved_edge(by_group[0])

    # The episode must be reachable through the entity node it mentions.
    episodes = await EpisodicNode.get_by_entity_node_uuid(graph_driver, alice_node.uuid)
    assert len(episodes) == 1
    linked_episode = episodes[0]
    assert linked_episode.uuid == episode_node.uuid
    assert linked_episode.name == 'test_episode'
    assert linked_episode.created_at == timestamp
    assert linked_episode.group_id == group_id

    # Delete through the instance method.
    await episodic_edge.delete(graph_driver)
    assert await get_edge_count(graph_driver, [episodic_edge.uuid]) == 0

    # Re-create the edge, then delete it through the uuid-list API.
    await episodic_edge.save(graph_driver)
    await episodic_edge.delete_by_uuids(graph_driver, [episodic_edge.uuid])
    assert await get_edge_count(graph_driver, [episodic_edge.uuid]) == 0

    # Remove both nodes again.
    await episode_node.delete(graph_driver)
    assert await get_node_count(graph_driver, [episode_node.uuid]) == 0
    await alice_node.delete(graph_driver)
    assert await get_node_count(graph_driver, [alice_node.uuid]) == 0

    await graph_driver.close()


@pytest.mark.asyncio
async def test_entity_edge(graph_driver, mock_embedder):
    """Exercise the EntityEdge lifecycle against the graph driver.

    Saves two entity nodes and an EntityEdge between them, reads the edge back
    by uuid, by uuids, by group ids and by node uuid, reloads its fact
    embedding, and then checks that the edge disappears when it is deleted
    directly or when its source node is deleted (by instance, by uuids, or by
    group id).
    """
    timestamp = datetime.now()

    # Source entity: absent before save(), present afterwards.
    alice_node = EntityNode(
        name='Alice', labels=[], created_at=timestamp, summary='Alice summary', group_id=group_id
    )
    await alice_node.generate_name_embedding(mock_embedder)
    assert await get_node_count(graph_driver, [alice_node.uuid]) == 0
    await alice_node.save(graph_driver)
    assert await get_node_count(graph_driver, [alice_node.uuid]) == 1

    # Target entity: absent before save(), present afterwards.
    bob_node = EntityNode(
        name='Bob', labels=[], created_at=timestamp, summary='Bob summary', group_id=group_id
    )
    await bob_node.generate_name_embedding(mock_embedder)
    assert await get_node_count(graph_driver, [bob_node.uuid]) == 0
    await bob_node.save(graph_driver)
    assert await get_node_count(graph_driver, [bob_node.uuid]) == 1

    # Entity -> entity edge: absent before save(), present afterwards.
    entity_edge = EntityEdge(
        source_node_uuid=alice_node.uuid,
        target_node_uuid=bob_node.uuid,
        created_at=timestamp,
        name='likes',
        fact='Alice likes Bob',
        episodes=[],
        expired_at=timestamp,
        valid_at=timestamp,
        invalid_at=timestamp,
        group_id=group_id,
    )
    expected_embedding = await entity_edge.generate_embedding(mock_embedder)
    assert await get_edge_count(graph_driver, [entity_edge.uuid]) == 0
    await entity_edge.save(graph_driver)
    assert await get_edge_count(graph_driver, [entity_edge.uuid]) == 1

    def assert_matches_saved_edge(edge):
        # A fetched edge must mirror the saved edge field by field.
        assert edge.uuid == entity_edge.uuid
        assert edge.source_node_uuid == alice_node.uuid
        assert edge.target_node_uuid == bob_node.uuid
        assert edge.created_at == timestamp
        assert edge.group_id == group_id

    # Fetch by a single uuid.
    assert_matches_saved_edge(await EntityEdge.get_by_uuid(graph_driver, entity_edge.uuid))

    # Fetch by a list of uuids.
    by_uuids = await EntityEdge.get_by_uuids(graph_driver, [entity_edge.uuid])
    assert len(by_uuids) == 1
    assert_matches_saved_edge(by_uuids[0])

    # Fetch by group ids.
    by_group = await EntityEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
    assert len(by_group) == 1
    assert_matches_saved_edge(by_group[0])

    # Fetch by the uuid of a node attached to the edge.
    by_node = await EntityEdge.get_by_node_uuid(graph_driver, alice_node.uuid)
    assert len(by_node) == 1
    assert_matches_saved_edge(by_node[0])

    # The reloaded fact embedding must match the one generated before saving.
    await entity_edge.load_fact_embedding(graph_driver)
    assert np.allclose(entity_edge.fact_embedding, expected_embedding)

    # Delete through the instance method.
    await entity_edge.delete(graph_driver)
    assert await get_edge_count(graph_driver, [entity_edge.uuid]) == 0

    # Re-create the edge, then delete it through the uuid-list API.
    await entity_edge.save(graph_driver)
    await entity_edge.delete_by_uuids(graph_driver, [entity_edge.uuid])
    assert await get_edge_count(graph_driver, [entity_edge.uuid]) == 0

    # Deleting the source node must take the edge with it.
    await entity_edge.save(graph_driver)
    await alice_node.delete(graph_driver)
    assert await get_node_count(graph_driver, [alice_node.uuid]) == 0
    assert await get_edge_count(graph_driver, [entity_edge.uuid]) == 0

    # Same expectation when the node is deleted through the uuid-list API.
    await alice_node.save(graph_driver)
    await entity_edge.save(graph_driver)
    await alice_node.delete_by_uuids(graph_driver, [alice_node.uuid])
    assert await get_node_count(graph_driver, [alice_node.uuid]) == 0
    assert await get_edge_count(graph_driver, [entity_edge.uuid]) == 0

    # Same expectation when the node is deleted through its group id.
    await alice_node.save(graph_driver)
    await entity_edge.save(graph_driver)
    await alice_node.delete_by_group_id(graph_driver, alice_node.group_id)
    assert await get_node_count(graph_driver, [alice_node.uuid]) == 0
    assert await get_edge_count(graph_driver, [entity_edge.uuid]) == 0

    # Remove both nodes again.
    await alice_node.delete(graph_driver)
    assert await get_node_count(graph_driver, [alice_node.uuid]) == 0
    await bob_node.delete(graph_driver)
    assert await get_node_count(graph_driver, [bob_node.uuid]) == 0

    await graph_driver.close()


@pytest.mark.asyncio
async def test_community_edge(graph_driver, mock_embedder):
    """Exercise the CommunityEdge lifecycle against the graph driver.

    Saves two community nodes and one entity node, links the two communities
    with a CommunityEdge, reads the edge back by uuid, by uuids and by group
    ids, then deletes the edge and all three nodes.
    """
    timestamp = datetime.now()

    # First community: absent before save(), present afterwards.
    community_node_1 = CommunityNode(
        name='test_community_1',
        group_id=group_id,
        summary='Community A summary',
    )
    await community_node_1.generate_name_embedding(mock_embedder)
    assert await get_node_count(graph_driver, [community_node_1.uuid]) == 0
    await community_node_1.save(graph_driver)
    assert await get_node_count(graph_driver, [community_node_1.uuid]) == 1

    # Second community: absent before save(), present afterwards.
    community_node_2 = CommunityNode(
        name='test_community_2',
        group_id=group_id,
        summary='Community B summary',
    )
    await community_node_2.generate_name_embedding(mock_embedder)
    assert await get_node_count(graph_driver, [community_node_2.uuid]) == 0
    await community_node_2.save(graph_driver)
    assert await get_node_count(graph_driver, [community_node_2.uuid]) == 1

    # Entity node: absent before save(), present afterwards.
    alice_node = EntityNode(
        name='Alice', labels=[], created_at=timestamp, summary='Alice summary', group_id=group_id
    )
    await alice_node.generate_name_embedding(mock_embedder)
    assert await get_node_count(graph_driver, [alice_node.uuid]) == 0
    await alice_node.save(graph_driver)
    assert await get_node_count(graph_driver, [alice_node.uuid]) == 1

    # Community -> community edge: absent before save(), present afterwards.
    community_edge = CommunityEdge(
        source_node_uuid=community_node_1.uuid,
        target_node_uuid=community_node_2.uuid,
        created_at=timestamp,
        group_id=group_id,
    )
    assert await get_edge_count(graph_driver, [community_edge.uuid]) == 0
    await community_edge.save(graph_driver)
    assert await get_edge_count(graph_driver, [community_edge.uuid]) == 1

    def assert_matches_saved_edge(edge):
        # A fetched edge must mirror the saved edge field by field.
        assert edge.uuid == community_edge.uuid
        assert edge.source_node_uuid == community_node_1.uuid
        assert edge.target_node_uuid == community_node_2.uuid
        assert edge.created_at == timestamp
        assert edge.group_id == group_id

    # Fetch by a single uuid.
    assert_matches_saved_edge(await CommunityEdge.get_by_uuid(graph_driver, community_edge.uuid))

    # Fetch by a list of uuids.
    by_uuids = await CommunityEdge.get_by_uuids(graph_driver, [community_edge.uuid])
    assert len(by_uuids) == 1
    assert_matches_saved_edge(by_uuids[0])

    # Fetch by group ids.
    # NOTE(review): limit=1 caps the result at one row, so the length check below
    # cannot detect extra edges; the episodic/entity edge tests use limit=2 - confirm intent.
    by_group = await CommunityEdge.get_by_group_ids(graph_driver, [group_id], limit=1)
    assert len(by_group) == 1
    assert_matches_saved_edge(by_group[0])

    # Delete through the instance method.
    await community_edge.delete(graph_driver)
    assert await get_edge_count(graph_driver, [community_edge.uuid]) == 0

    # Re-create the edge, then delete it through the uuid-list API.
    await community_edge.save(graph_driver)
    await community_edge.delete_by_uuids(graph_driver, [community_edge.uuid])
    assert await get_edge_count(graph_driver, [community_edge.uuid]) == 0

    # Remove all three nodes again.
    await alice_node.delete(graph_driver)
    assert await get_node_count(graph_driver, [alice_node.uuid]) == 0
    await community_node_1.delete(graph_driver)
    assert await get_node_count(graph_driver, [community_node_1.uuid]) == 0
    await community_node_2.delete(graph_driver)
    assert await get_node_count(graph_driver, [community_node_2.uuid]) == 0

    await graph_driver.close()

```

--------------------------------------------------------------------------------
/tests/embedder/test_gemini.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Running tests: pytest -xvs tests/embedder/test_gemini.py

from collections.abc import Generator
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from embedder_fixtures import create_embedding_values

from graphiti_core.embedder.gemini import (
    DEFAULT_EMBEDDING_MODEL,
    GeminiEmbedder,
    GeminiEmbedderConfig,
)


def create_gemini_embedding(multiplier: float = 0.1, dimension: int = 1536) -> MagicMock:
    """Build a MagicMock standing in for a single Gemini embedding.

    The mock's ``values`` attribute is whatever
    ``create_embedding_values(multiplier, dimension)`` returns.
    """
    return MagicMock(values=create_embedding_values(multiplier, dimension))


@pytest.fixture
def mock_gemini_response() -> MagicMock:
    """Fixture: a mocked embeddings response holding one default mock embedding."""
    return MagicMock(embeddings=[create_gemini_embedding()])


@pytest.fixture
def mock_gemini_batch_response() -> MagicMock:
    """Fixture: a mocked embeddings response holding three mock embeddings.

    The embeddings are built with multipliers 0.1, 0.2 and 0.3, in that order.
    """
    batch_embeddings = [create_gemini_embedding(scale) for scale in (0.1, 0.2, 0.3)]
    return MagicMock(embeddings=batch_embeddings)


@pytest.fixture
def mock_gemini_client() -> Generator[Any, Any, None]:
    """Fixture: patch ``google.genai.Client`` and yield the mocked client instance.

    The yielded instance exposes ``aio.models.embed_content`` as an AsyncMock so
    tests can await it and set its return value or side effect. The patch stays
    active until the fixture is torn down.
    """
    with patch('google.genai.Client') as patched_client_cls:
        # Assemble the async namespace first, then hang it on the client instance.
        async_namespace = MagicMock()
        async_namespace.models = MagicMock()
        async_namespace.models.embed_content = AsyncMock()

        fake_client = patched_client_cls.return_value
        fake_client.aio = async_namespace
        yield fake_client


@pytest.fixture
def gemini_embedder(mock_gemini_client: Any) -> GeminiEmbedder:
    """Fixture: a GeminiEmbedder built with a test API key and the mocked client."""
    embedder = GeminiEmbedder(config=GeminiEmbedderConfig(api_key='test_api_key'))
    # Replace whatever client the constructor created with the mock.
    embedder.client = mock_gemini_client
    return embedder


class TestGeminiEmbedderInitialization:
    """Construction of GeminiEmbedder with a full, a missing and a partial config."""

    @patch('google.genai.Client')
    def test_init_with_config(self, mock_client):
        """An explicit config is stored on the embedder with every field intact."""
        explicit_config = GeminiEmbedderConfig(
            api_key='test_api_key', embedding_model='custom-model', embedding_dim=768
        )
        stored_config = GeminiEmbedder(config=explicit_config).config

        assert stored_config == explicit_config
        assert stored_config.embedding_model == 'custom-model'
        assert stored_config.api_key == 'test_api_key'
        assert stored_config.embedding_dim == 768

    @patch('google.genai.Client')
    def test_init_without_config(self, mock_client):
        """Omitting the config yields a config that uses the default embedding model."""
        stored_config = GeminiEmbedder().config

        assert stored_config is not None
        assert stored_config.embedding_model == DEFAULT_EMBEDDING_MODEL

    @patch('google.genai.Client')
    def test_init_with_partial_config(self, mock_client):
        """A config with only an API key keeps the key and the default embedding model."""
        key_only_config = GeminiEmbedderConfig(api_key='test_api_key')
        stored_config = GeminiEmbedder(config=key_only_config).config

        assert stored_config.api_key == 'test_api_key'
        assert stored_config.embedding_model == DEFAULT_EMBEDDING_MODEL


class TestGeminiEmbedderCreate:
    """Behaviour of GeminiEmbedder.create against a mocked Gemini client."""

    @pytest.mark.asyncio
    async def test_create_calls_api_correctly(
        self,
        gemini_embedder: GeminiEmbedder,
        mock_gemini_client: Any,
        mock_gemini_response: MagicMock,
    ) -> None:
        """create() calls embed_content once and returns the first embedding's values."""
        embed_content = mock_gemini_client.aio.models.embed_content
        embed_content.return_value = mock_gemini_response

        embedding = await gemini_embedder.create('Test input')

        # Exactly one API call, using the default model and the input wrapped in a list.
        embed_content.assert_called_once()
        call_kwargs = embed_content.call_args.kwargs
        assert call_kwargs['model'] == DEFAULT_EMBEDDING_MODEL
        assert call_kwargs['contents'] == ['Test input']

        # The returned vector is the first embedding's values.
        assert embedding == mock_gemini_response.embeddings[0].values

    @pytest.mark.asyncio
    @patch('google.genai.Client')
    async def test_create_with_custom_model(
        self, mock_client_class, mock_gemini_client: Any, mock_gemini_response: MagicMock
    ) -> None:
        """A custom embedding model from the config is passed through to the API call."""
        embed_content = mock_gemini_client.aio.models.embed_content
        custom_config = GeminiEmbedderConfig(api_key='test_api_key', embedding_model='custom-model')
        embedder = GeminiEmbedder(config=custom_config)
        embedder.client = mock_gemini_client
        embed_content.return_value = mock_gemini_response

        await embedder.create('Test input')

        assert embed_content.call_args.kwargs['model'] == 'custom-model'

    @pytest.mark.asyncio
    @patch('google.genai.Client')
    async def test_create_with_custom_dimension(
        self, mock_client_class, mock_gemini_client: Any
    ) -> None:
        """A custom embedding dimension is requested from the API and reflected in the result."""
        embed_content = mock_gemini_client.aio.models.embed_content
        embedder = GeminiEmbedder(
            config=GeminiEmbedderConfig(api_key='test_api_key', embedding_dim=768)
        )
        embedder.client = mock_gemini_client

        # The mocked API answers with a single 768-dimensional embedding.
        embed_content.return_value = MagicMock(embeddings=[create_gemini_embedding(0.1, 768)])

        embedding = await embedder.create('Test input')

        # The request config carries the custom dimension ...
        assert embed_content.call_args.kwargs['config'].output_dimensionality == 768

        # ... and the returned vector has that length.
        assert len(embedding) == 768

    @pytest.mark.asyncio
    async def test_create_with_different_input_types(
        self,
        gemini_embedder: GeminiEmbedder,
        mock_gemini_client: Any,
        mock_gemini_response: MagicMock,
    ) -> None:
        """create() accepts a string, a list of strings and a list of integers."""
        embed_content = mock_gemini_client.aio.models.embed_content
        embed_content.return_value = mock_gemini_response

        # A plain string, a list of strings, then a list of integers.
        for payload in ('Test string', ['Test', 'List'], [1, 2, 3]):
            await gemini_embedder.create(payload)

        # Each input triggered its own API call.
        assert embed_content.call_count == 3

    @pytest.mark.asyncio
    async def test_create_no_embeddings_error(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """create() raises ValueError when the response contains no embeddings."""
        mock_gemini_client.aio.models.embed_content.return_value = MagicMock(embeddings=[])

        with pytest.raises(ValueError) as exc_info:
            await gemini_embedder.create('Test input')

        assert 'No embeddings returned from Gemini API in create()' in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_create_no_values_error(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """create() raises ValueError when the only embedding has ``values`` set to None."""
        valueless_embedding = MagicMock(values=None)
        mock_gemini_client.aio.models.embed_content.return_value = MagicMock(
            embeddings=[valueless_embedding]
        )

        with pytest.raises(ValueError) as exc_info:
            await gemini_embedder.create('Test input')

        assert 'No embeddings returned from Gemini API in create()' in str(exc_info.value)


class TestGeminiEmbedderCreateBatch:
    """Behaviour of GeminiEmbedder.create_batch against a mocked Gemini client."""

    @pytest.mark.asyncio
    async def test_create_batch_processes_multiple_inputs(
        self,
        gemini_embedder: GeminiEmbedder,
        mock_gemini_client: Any,
        mock_gemini_batch_response: MagicMock,
    ) -> None:
        """Three inputs are sent in one API call and yield three vectors in order."""
        embed_content = mock_gemini_client.aio.models.embed_content
        embed_content.return_value = mock_gemini_batch_response
        texts = ['Input 1', 'Input 2', 'Input 3']

        vectors = await gemini_embedder.create_batch(texts)

        # A single API call carrying the default model and the whole batch.
        embed_content.assert_called_once()
        call_kwargs = embed_content.call_args.kwargs
        assert call_kwargs['model'] == DEFAULT_EMBEDDING_MODEL
        assert call_kwargs['contents'] == texts

        # One vector per input, in response order.
        assert len(vectors) == 3
        assert vectors == [item.values for item in mock_gemini_batch_response.embeddings]

    @pytest.mark.asyncio
    async def test_create_batch_single_input(
        self,
        gemini_embedder: GeminiEmbedder,
        mock_gemini_client: Any,
        mock_gemini_response: MagicMock,
    ) -> None:
        """A one-element batch yields exactly one vector."""
        mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response

        vectors = await gemini_embedder.create_batch(['Single input'])

        assert len(vectors) == 1
        assert vectors[0] == mock_gemini_response.embeddings[0].values

    @pytest.mark.asyncio
    async def test_create_batch_empty_input(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """An empty batch returns an empty list without calling the API."""
        embed_content = mock_gemini_client.aio.models.embed_content
        # A response with no embeddings is configured, but it must never be requested.
        embed_content.return_value = MagicMock(embeddings=[])

        vectors = await gemini_embedder.create_batch([])

        assert vectors == []
        embed_content.assert_not_called()

    @pytest.mark.asyncio
    async def test_create_batch_no_embeddings_error(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """create_batch() raises ValueError when the response contains no embeddings."""
        mock_gemini_client.aio.models.embed_content.return_value = MagicMock(embeddings=[])

        with pytest.raises(ValueError) as exc_info:
            await gemini_embedder.create_batch(['Input 1', 'Input 2'])

        assert 'No embeddings returned from Gemini API' in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_create_batch_empty_values_error(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """create_batch() raises ValueError when an embedding has ``values`` set to None."""
        populated_embedding = MagicMock(values=[0.1, 0.2, 0.3])
        valueless_embedding = MagicMock(values=None)

        # Responses are handed out in call order: the first is intended for the batch
        # request, the next two for per-item requests for 'Input 1' and 'Input 2'.
        mock_gemini_client.aio.models.embed_content.side_effect = [
            MagicMock(embeddings=[populated_embedding, valueless_embedding]),
            MagicMock(embeddings=[populated_embedding]),
            MagicMock(embeddings=[valueless_embedding]),
        ]

        with pytest.raises(ValueError) as exc_info:
            await gemini_embedder.create_batch(['Input 1', 'Input 2'])

        assert 'Empty embedding values returned' in str(exc_info.value)

    @pytest.mark.asyncio
    @patch('google.genai.Client')
    async def test_create_batch_with_custom_model_and_dimension(
        self, mock_client_class, mock_gemini_client: Any
    ) -> None:
        """Custom model and dimension from the config are used for the batch request."""
        embed_content = mock_gemini_client.aio.models.embed_content
        custom_config = GeminiEmbedderConfig(
            api_key='test_api_key', embedding_model='custom-batch-model', embedding_dim=512
        )
        embedder = GeminiEmbedder(config=custom_config)
        embedder.client = mock_gemini_client

        # The mocked API answers with two 512-dimensional embeddings.
        embed_content.return_value = MagicMock(
            embeddings=[create_gemini_embedding(0.1, 512), create_gemini_embedding(0.2, 512)]
        )

        vectors = await embedder.create_batch(['Input 1', 'Input 2'])

        # The request carries the custom model and dimension.
        call_kwargs = embed_content.call_args.kwargs
        assert call_kwargs['model'] == 'custom-batch-model'
        assert call_kwargs['config'].output_dimensionality == 512

        # Two vectors come back, each of the custom length.
        assert len(vectors) == 2
        assert all(len(vector) == 512 for vector in vectors)


if __name__ == '__main__':
    # Direct execution runs pytest on this file only, with the -xvs flags.
    pytest_args = ['-xvs', __file__]
    pytest.main(pytest_args)

```

--------------------------------------------------------------------------------
/tests/driver/test_falkordb_driver.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import unittest
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from graphiti_core.driver.driver import GraphProvider

try:
    from graphiti_core.driver.falkordb_driver import FalkorDriver, FalkorDriverSession

    HAS_FALKORDB = True
except ImportError:
    FalkorDriver = None
    HAS_FALKORDB = False


class TestFalkorDriver:
    """Unit tests for FalkorDriver, exercised against a mocked FalkorDB client."""

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def setup_method(self):
        """Create a driver and swap its client for a MagicMock."""
        # Patch the FalkorDB class while the driver is constructed so that the
        # constructor works against a mock instead of the real client class.
        with patch('graphiti_core.driver.falkordb_driver.FalkorDB'):
            self.driver = FalkorDriver()
        self.mock_client = MagicMock()
        self.driver.client = self.mock_client

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_init_with_connection_params(self):
        """Connection parameters are forwarded unchanged to the FalkorDB constructor."""
        connection_kwargs = {
            'host': 'test-host',
            'port': '1234',
            'username': 'test-user',
            'password': 'test-pass',
        }
        with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as falkor_cls:
            created = FalkorDriver(**connection_kwargs)
            assert created.provider == GraphProvider.FALKORDB
            falkor_cls.assert_called_once_with(**connection_kwargs)

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_init_with_falkor_db_instance(self):
        """A pre-built FalkorDB instance is used as-is and no new one is constructed."""
        with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as falkor_cls:
            injected_client = MagicMock()
            created = FalkorDriver(falkor_db=injected_client)
            assert created.provider == GraphProvider.FALKORDB
            assert created.client is injected_client
            falkor_cls.assert_not_called()

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_provider(self):
        """The driver reports FALKORDB as its provider."""
        assert self.driver.provider == GraphProvider.FALKORDB

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_get_graph_with_name(self):
        """_get_graph selects the graph with the requested name."""
        expected_graph = MagicMock()
        self.mock_client.select_graph.return_value = expected_graph

        selected = self.driver._get_graph('test_graph')

        self.mock_client.select_graph.assert_called_once_with('test_graph')
        assert selected is expected_graph

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_get_graph_with_none_defaults_to_default_database(self):
        """_get_graph(None) falls back to the 'default_db' graph."""
        expected_graph = MagicMock()
        self.mock_client.select_graph.return_value = expected_graph

        selected = self.driver._get_graph(None)

        self.mock_client.select_graph.assert_called_once_with('default_db')
        assert selected is expected_graph

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_execute_query_success(self):
        """A successful query yields (records as dicts, column names, None)."""
        query_result = MagicMock()
        query_result.header = [('col1', 'column1'), ('col2', 'column2')]
        query_result.result_set = [['row1col1', 'row1col2']]
        graph_stub = MagicMock()
        graph_stub.query = AsyncMock(return_value=query_result)
        self.mock_client.select_graph.return_value = graph_stub

        outcome = await self.driver.execute_query('MATCH (n) RETURN n', param1='value1')

        # Keyword arguments are expected to arrive at the graph as a params dict.
        graph_stub.query.assert_called_once_with('MATCH (n) RETURN n', {'param1': 'value1'})

        records, columns, summary = outcome
        assert records == [{'column1': 'row1col1', 'column2': 'row1col2'}]
        assert columns == ['column1', 'column2']
        assert summary is None

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_execute_query_handles_index_already_exists_error(self):
        """An 'already indexed' error is logged at info level and yields None."""
        graph_stub = MagicMock()
        graph_stub.query = AsyncMock(side_effect=Exception('Index already indexed'))
        self.mock_client.select_graph.return_value = graph_stub

        with patch('graphiti_core.driver.falkordb_driver.logger') as logger_stub:
            outcome = await self.driver.execute_query('CREATE INDEX ...')

            logger_stub.info.assert_called_once()
            assert outcome is None

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_execute_query_propagates_other_exceptions(self):
        """Any other query error is logged at error level and re-raised."""
        graph_stub = MagicMock()
        graph_stub.query = AsyncMock(side_effect=Exception('Other error'))
        self.mock_client.select_graph.return_value = graph_stub

        with patch('graphiti_core.driver.falkordb_driver.logger') as logger_stub:
            with pytest.raises(Exception, match='Other error'):
                await self.driver.execute_query('INVALID QUERY')

            logger_stub.error.assert_called_once()

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_execute_query_converts_datetime_parameters(self):
        """datetime keyword arguments reach the graph as ISO-8601 strings."""
        empty_result = MagicMock()
        empty_result.header = []
        empty_result.result_set = []
        graph_stub = MagicMock()
        graph_stub.query = AsyncMock(return_value=empty_result)
        self.mock_client.select_graph.return_value = graph_stub

        created_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)

        await self.driver.execute_query(
            'CREATE (n:Node) SET n.created_at = $created_at', created_at=created_at
        )

        # The second positional argument of the query call is the params dict.
        positional_args = graph_stub.query.call_args[0]
        assert positional_args[1]['created_at'] == created_at.isoformat()

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_session_creation(self):
        """session() wraps the selected graph in a FalkorDriverSession."""
        expected_graph = MagicMock()
        self.mock_client.select_graph.return_value = expected_graph

        created_session = self.driver.session()

        assert isinstance(created_session, FalkorDriverSession)
        assert created_session.graph is expected_graph

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_session_creation_with_none_uses_default_database(self):
        """session() without a database argument still returns a FalkorDriverSession."""
        self.mock_client.select_graph.return_value = MagicMock()

        created_session = self.driver.session()

        assert isinstance(created_session, FalkorDriverSession)

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_close_calls_connection_close(self):
        """close() falls back to connection.close() when no aclose() is available."""
        connection_stub = MagicMock()
        connection_stub.close = AsyncMock()
        self.mock_client.connection = connection_stub

        # Deleting the attribute makes the MagicMock raise AttributeError for it
        del self.mock_client.aclose

        def only_connection_has_close(obj, attr):
            # Every hasattr() probe answers False except (connection, 'close')
            return attr == 'close' and obj is connection_stub

        with patch('builtins.hasattr') as hasattr_stub:
            hasattr_stub.side_effect = only_connection_has_close

            await self.driver.close()

        connection_stub.close.assert_called_once()

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_delete_all_indexes(self):
        """delete_all_indexes issues a single index-listing query when none are found."""
        with patch.object(self.driver, 'execute_query', new_callable=AsyncMock) as execute_stub:
            # None simulates "no indexes found"
            execute_stub.return_value = None

            await self.driver.delete_all_indexes()

            execute_stub.assert_called_once_with('CALL db.indexes()')


class TestFalkorDriverSession:
    """Unit tests for FalkorDriverSession built around a mocked graph."""

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def setup_method(self):
        """Create a session backed by a MagicMock graph."""
        self.mock_graph = MagicMock()
        self.session = FalkorDriverSession(self.mock_graph)

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_session_async_context_manager(self):
        """Entering the session with ``async with`` yields the session itself."""
        async with self.session as entered:
            assert entered is self.session

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_close_method(self):
        """close() completes without raising."""
        await self.session.close()

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_execute_write_passes_session_and_args(self):
        """execute_write hands the session plus all extra arguments to the callback."""

        async def write_callback(session, *args, **kwargs):
            assert session is self.session
            assert args == ('arg1', 'arg2')
            assert kwargs == {'key': 'value'}
            return 'result'

        outcome = await self.session.execute_write(write_callback, 'arg1', 'arg2', key='value')
        assert outcome == 'result'

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_run_single_query_with_parameters(self):
        """run() with one query string forwards its kwargs as the params dict."""
        self.mock_graph.query = AsyncMock()

        await self.session.run('MATCH (n) RETURN n', param1='value1', param2='value2')

        expected_params = {'param1': 'value1', 'param2': 'value2'}
        self.mock_graph.query.assert_called_once_with('MATCH (n) RETURN n', expected_params)

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_run_multiple_queries_as_list(self):
        """run() with a list of (query, params) pairs executes each pair in order."""
        self.mock_graph.query = AsyncMock()

        first_query = ('MATCH (n) RETURN n', {'param1': 'value1'})
        second_query = ('CREATE (n:Node)', {'param2': 'value2'})

        await self.session.run([first_query, second_query])

        assert self.mock_graph.query.call_count == 2
        recorded_calls = self.mock_graph.query.call_args_list
        # Index 0 of each recorded call holds its positional arguments
        assert recorded_calls[0][0] == ('MATCH (n) RETURN n', {'param1': 'value1'})
        assert recorded_calls[1][0] == ('CREATE (n:Node)', {'param2': 'value2'})

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_run_converts_datetime_objects_to_iso_strings(self):
        """datetime kwargs passed to run() reach the graph as ISO-8601 strings."""
        self.mock_graph.query = AsyncMock()
        created_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)

        await self.session.run(
            'CREATE (n:Node) SET n.created_at = $created_at', created_at=created_at
        )

        self.mock_graph.query.assert_called_once()
        positional_args = self.mock_graph.query.call_args[0]
        assert positional_args[1]['created_at'] == created_at.isoformat()


class TestDatetimeConversion:
    """Tests for the convert_datetimes_to_strings helper of the FalkorDB driver."""

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_convert_datetime_dict(self):
        """Datetimes are converted at every nesting level of a dict."""
        from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings

        moment = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        inner = {'nested_datetime': moment, 'nested_string': 'nested_test'}
        payload = {'string_val': 'test', 'datetime_val': moment, 'nested_dict': inner}

        converted = convert_datetimes_to_strings(payload)

        assert converted['string_val'] == 'test'
        assert converted['datetime_val'] == moment.isoformat()
        assert converted['nested_dict']['nested_datetime'] == moment.isoformat()
        assert converted['nested_dict']['nested_string'] == 'nested_test'

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_convert_datetime_list_and_tuple(self):
        """Datetimes inside lists (including nested ones) and tuples are converted."""
        from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings

        moment = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)

        # List input, with one nested list
        converted_list = convert_datetimes_to_strings(['test', moment, ['nested', moment]])
        assert converted_list[0] == 'test'
        assert converted_list[1] == moment.isoformat()
        assert converted_list[2][1] == moment.isoformat()

        # Tuple input must come back as a tuple
        converted_tuple = convert_datetimes_to_strings(('test', moment))
        assert isinstance(converted_tuple, tuple)
        assert converted_tuple[0] == 'test'
        assert converted_tuple[1] == moment.isoformat()

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_convert_single_datetime(self):
        """A bare datetime is converted to its ISO-8601 string."""
        from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings

        moment = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        converted = convert_datetimes_to_strings(moment)
        assert converted == moment.isoformat()

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_convert_other_types_unchanged(self):
        """Values that are not datetimes pass through untouched."""
        from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings

        # Plain values are compared by equality, singletons by identity
        assert convert_datetimes_to_strings('string') == 'string'
        assert convert_datetimes_to_strings(123) == 123
        assert convert_datetimes_to_strings(None) is None
        assert convert_datetimes_to_strings(True) is True


# Simple integration test
class TestFalkorDriverIntegration:
    """Simple integration test for FalkorDB driver."""

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_basic_integration_with_real_falkordb(self):
        """Run a trivial query against a real FalkorDB instance.

        The instance is located through the FALKORDB_HOST / FALKORDB_PORT environment
        variables (defaulting to localhost:6379). The test is skipped when connecting
        or querying fails, but a wrong query result is reported as a failure.
        """
        pytest.importorskip('falkordb')

        falkor_host = os.getenv('FALKORDB_HOST', 'localhost')
        falkor_port = os.getenv('FALKORDB_PORT', '6379')

        # Only connecting and issuing the query may turn into a skip. The assertions are
        # kept outside this try block: AssertionError is an Exception subclass, so a
        # failing assertion inside it would be swallowed and reported as "skipped".
        try:
            driver = FalkorDriver(host=falkor_host, port=falkor_port)
            result = await driver.execute_query('RETURN 1 as test')
        except Exception as e:
            pytest.skip(f'FalkorDB not available for integration test: {e}')

        try:
            assert result is not None

            result_set, header, _summary = result
            assert header == ['test']
            assert result_set == [{'test': 1}]
        finally:
            # Release the connection even when an assertion above fails
            await driver.close()

```

--------------------------------------------------------------------------------
/mcp_server/src/services/factories.py:
--------------------------------------------------------------------------------

```python
"""Factory classes for creating LLM, Embedder, and Database clients."""

from openai import AsyncAzureOpenAI

from config.schema import (
    DatabaseConfig,
    EmbedderConfig,
    LLMConfig,
)

# Try to import FalkorDriver if available
try:
    from graphiti_core.driver.falkordb_driver import FalkorDriver  # noqa: F401

    HAS_FALKOR = True
except ImportError:
    HAS_FALKOR = False

# Kuzu support removed - FalkorDB is now the default
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
from graphiti_core.llm_client import LLMClient, OpenAIClient
from graphiti_core.llm_client.config import LLMConfig as GraphitiLLMConfig

# Try to import additional providers if available
try:
    from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient

    HAS_AZURE_EMBEDDER = True
except ImportError:
    HAS_AZURE_EMBEDDER = False

try:
    from graphiti_core.embedder.gemini import GeminiEmbedder

    HAS_GEMINI_EMBEDDER = True
except ImportError:
    HAS_GEMINI_EMBEDDER = False

try:
    from graphiti_core.embedder.voyage import VoyageAIEmbedder

    HAS_VOYAGE_EMBEDDER = True
except ImportError:
    HAS_VOYAGE_EMBEDDER = False

try:
    from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient

    HAS_AZURE_LLM = True
except ImportError:
    HAS_AZURE_LLM = False

try:
    from graphiti_core.llm_client.anthropic_client import AnthropicClient

    HAS_ANTHROPIC = True
except ImportError:
    HAS_ANTHROPIC = False

try:
    from graphiti_core.llm_client.gemini_client import GeminiClient

    HAS_GEMINI = True
except ImportError:
    HAS_GEMINI = False

try:
    from graphiti_core.llm_client.groq_client import GroqClient

    HAS_GROQ = True
except ImportError:
    HAS_GROQ = False
from utils.utils import create_azure_credential_token_provider


def _validate_api_key(provider_name: str, api_key: str | None, logger) -> str:
    """Validate API key is present.

    Args:
        provider_name: Name of the provider (e.g., 'OpenAI', 'Anthropic')
        api_key: The API key to validate
        logger: Logger instance for output

    Returns:
        The validated API key

    Raises:
        ValueError: If API key is None or empty
    """
    if not api_key:
        raise ValueError(
            f'{provider_name} API key is not configured. Please set the appropriate environment variable.'
        )

    logger.info(f'Creating {provider_name} client')

    return api_key


class LLMClientFactory:
    """Factory for creating LLM clients based on configuration."""

    @staticmethod
    def create(config: LLMConfig) -> LLMClient:
        """Create an LLM client based on the configured provider."""
        import logging

        logger = logging.getLogger(__name__)

        provider = config.provider.lower()

        match provider:
            case 'openai':
                if not config.providers.openai:
                    raise ValueError('OpenAI provider configuration not found')

                api_key = config.providers.openai.api_key
                _validate_api_key('OpenAI', api_key, logger)

                from graphiti_core.llm_client.config import LLMConfig as CoreLLMConfig

                # Determine appropriate small model based on main model type
                is_reasoning_model = (
                    config.model.startswith('gpt-5')
                    or config.model.startswith('o1')
                    or config.model.startswith('o3')
                )
                small_model = (
                    'gpt-5-nano' if is_reasoning_model else 'gpt-4.1-mini'
                )  # Use reasoning model for small tasks if main model is reasoning

                llm_config = CoreLLMConfig(
                    api_key=api_key,
                    model=config.model,
                    small_model=small_model,
                    temperature=config.temperature,
                    max_tokens=config.max_tokens,
                )

                # Only pass reasoning/verbosity parameters for reasoning models (gpt-5 family)
                if is_reasoning_model:
                    return OpenAIClient(config=llm_config, reasoning='minimal', verbosity='low')
                else:
                    # For non-reasoning models, explicitly pass None to disable these parameters
                    return OpenAIClient(config=llm_config, reasoning=None, verbosity=None)

            case 'azure_openai':
                if not HAS_AZURE_LLM:
                    raise ValueError(
                        'Azure OpenAI LLM client not available in current graphiti-core version'
                    )
                if not config.providers.azure_openai:
                    raise ValueError('Azure OpenAI provider configuration not found')
                azure_config = config.providers.azure_openai

                if not azure_config.api_url:
                    raise ValueError('Azure OpenAI API URL is required')

                # Handle Azure AD authentication if enabled
                api_key: str | None = None
                azure_ad_token_provider = None
                if azure_config.use_azure_ad:
                    logger.info('Creating Azure OpenAI LLM client with Azure AD authentication')
                    azure_ad_token_provider = create_azure_credential_token_provider()
                else:
                    api_key = azure_config.api_key
                    _validate_api_key('Azure OpenAI', api_key, logger)

                # Create the Azure OpenAI client first
                azure_client = AsyncAzureOpenAI(
                    api_key=api_key,
                    azure_endpoint=azure_config.api_url,
                    api_version=azure_config.api_version,
                    azure_deployment=azure_config.deployment_name,
                    azure_ad_token_provider=azure_ad_token_provider,
                )

                # Then create the LLMConfig
                from graphiti_core.llm_client.config import LLMConfig as CoreLLMConfig

                llm_config = CoreLLMConfig(
                    api_key=api_key,
                    base_url=azure_config.api_url,
                    model=config.model,
                    temperature=config.temperature,
                    max_tokens=config.max_tokens,
                )

                return AzureOpenAILLMClient(
                    azure_client=azure_client,
                    config=llm_config,
                    max_tokens=config.max_tokens,
                )

            case 'anthropic':
                if not HAS_ANTHROPIC:
                    raise ValueError(
                        'Anthropic client not available in current graphiti-core version'
                    )
                if not config.providers.anthropic:
                    raise ValueError('Anthropic provider configuration not found')

                api_key = config.providers.anthropic.api_key
                _validate_api_key('Anthropic', api_key, logger)

                llm_config = GraphitiLLMConfig(
                    api_key=api_key,
                    model=config.model,
                    temperature=config.temperature,
                    max_tokens=config.max_tokens,
                )
                return AnthropicClient(config=llm_config)

            case 'gemini':
                if not HAS_GEMINI:
                    raise ValueError('Gemini client not available in current graphiti-core version')
                if not config.providers.gemini:
                    raise ValueError('Gemini provider configuration not found')

                api_key = config.providers.gemini.api_key
                _validate_api_key('Gemini', api_key, logger)

                llm_config = GraphitiLLMConfig(
                    api_key=api_key,
                    model=config.model,
                    temperature=config.temperature,
                    max_tokens=config.max_tokens,
                )
                return GeminiClient(config=llm_config)

            case 'groq':
                if not HAS_GROQ:
                    raise ValueError('Groq client not available in current graphiti-core version')
                if not config.providers.groq:
                    raise ValueError('Groq provider configuration not found')

                api_key = config.providers.groq.api_key
                _validate_api_key('Groq', api_key, logger)

                llm_config = GraphitiLLMConfig(
                    api_key=api_key,
                    base_url=config.providers.groq.api_url,
                    model=config.model,
                    temperature=config.temperature,
                    max_tokens=config.max_tokens,
                )
                return GroqClient(config=llm_config)

            case _:
                raise ValueError(f'Unsupported LLM provider: {provider}')


class EmbedderFactory:
    """Factory for creating Embedder clients based on configuration."""

    @staticmethod
    def create(config: EmbedderConfig) -> EmbedderClient:
        """Create an Embedder client based on the configured provider."""
        import logging

        logger = logging.getLogger(__name__)

        provider = config.provider.lower()

        match provider:
            case 'openai':
                if not config.providers.openai:
                    raise ValueError('OpenAI provider configuration not found')

                api_key = config.providers.openai.api_key
                _validate_api_key('OpenAI Embedder', api_key, logger)

                from graphiti_core.embedder.openai import OpenAIEmbedderConfig

                embedder_config = OpenAIEmbedderConfig(
                    api_key=api_key,
                    embedding_model=config.model,
                )
                return OpenAIEmbedder(config=embedder_config)

            case 'azure_openai':
                if not HAS_AZURE_EMBEDDER:
                    raise ValueError(
                        'Azure OpenAI embedder not available in current graphiti-core version'
                    )
                if not config.providers.azure_openai:
                    raise ValueError('Azure OpenAI provider configuration not found')
                azure_config = config.providers.azure_openai

                if not azure_config.api_url:
                    raise ValueError('Azure OpenAI API URL is required')

                # Handle Azure AD authentication if enabled
                api_key: str | None = None
                azure_ad_token_provider = None
                if azure_config.use_azure_ad:
                    logger.info(
                        'Creating Azure OpenAI Embedder client with Azure AD authentication'
                    )
                    azure_ad_token_provider = create_azure_credential_token_provider()
                else:
                    api_key = azure_config.api_key
                    _validate_api_key('Azure OpenAI Embedder', api_key, logger)

                # Create the Azure OpenAI client first
                azure_client = AsyncAzureOpenAI(
                    api_key=api_key,
                    azure_endpoint=azure_config.api_url,
                    api_version=azure_config.api_version,
                    azure_deployment=azure_config.deployment_name,
                    azure_ad_token_provider=azure_ad_token_provider,
                )

                return AzureOpenAIEmbedderClient(
                    azure_client=azure_client,
                    model=config.model or 'text-embedding-3-small',
                )

            case 'gemini':
                if not HAS_GEMINI_EMBEDDER:
                    raise ValueError(
                        'Gemini embedder not available in current graphiti-core version'
                    )
                if not config.providers.gemini:
                    raise ValueError('Gemini provider configuration not found')

                api_key = config.providers.gemini.api_key
                _validate_api_key('Gemini Embedder', api_key, logger)

                from graphiti_core.embedder.gemini import GeminiEmbedderConfig

                gemini_config = GeminiEmbedderConfig(
                    api_key=api_key,
                    embedding_model=config.model or 'models/text-embedding-004',
                    embedding_dim=config.dimensions or 768,
                )
                return GeminiEmbedder(config=gemini_config)

            case 'voyage':
                if not HAS_VOYAGE_EMBEDDER:
                    raise ValueError(
                        'Voyage embedder not available in current graphiti-core version'
                    )
                if not config.providers.voyage:
                    raise ValueError('Voyage provider configuration not found')

                api_key = config.providers.voyage.api_key
                _validate_api_key('Voyage Embedder', api_key, logger)

                from graphiti_core.embedder.voyage import VoyageAIEmbedderConfig

                voyage_config = VoyageAIEmbedderConfig(
                    api_key=api_key,
                    embedding_model=config.model or 'voyage-3',
                    embedding_dim=config.dimensions or 1024,
                )
                return VoyageAIEmbedder(config=voyage_config)

            case _:
                raise ValueError(f'Unsupported Embedder provider: {provider}')


class DatabaseDriverFactory:
    """Factory for creating Database drivers based on configuration.

    Note: This returns configuration dictionaries that can be passed to Graphiti(),
    not driver instances directly, as the drivers require complex initialization.
    """

    @staticmethod
    def create_config(config: DatabaseConfig) -> dict:
        """Create database configuration dictionary based on the configured provider."""
        provider = config.provider.lower()

        match provider:
            case 'neo4j':
                # Use Neo4j config if provided, otherwise use defaults
                if config.providers.neo4j:
                    neo4j_config = config.providers.neo4j
                else:
                    # Create default Neo4j configuration
                    from config.schema import Neo4jProviderConfig

                    neo4j_config = Neo4jProviderConfig()

                # Check for environment variable overrides (for CI/CD compatibility)
                import os

                uri = os.environ.get('NEO4J_URI', neo4j_config.uri)
                username = os.environ.get('NEO4J_USER', neo4j_config.username)
                password = os.environ.get('NEO4J_PASSWORD', neo4j_config.password)

                return {
                    'uri': uri,
                    'user': username,
                    'password': password,
                    # Note: database and use_parallel_runtime would need to be passed
                    # to the driver after initialization if supported
                }

            case 'falkordb':
                if not HAS_FALKOR:
                    raise ValueError(
                        'FalkorDB driver not available in current graphiti-core version'
                    )

                # Use FalkorDB config if provided, otherwise use defaults
                if config.providers.falkordb:
                    falkor_config = config.providers.falkordb
                else:
                    # Create default FalkorDB configuration
                    from config.schema import FalkorDBProviderConfig

                    falkor_config = FalkorDBProviderConfig()

                # Check for environment variable overrides (for CI/CD compatibility)
                import os
                from urllib.parse import urlparse

                uri = os.environ.get('FALKORDB_URI', falkor_config.uri)
                password = os.environ.get('FALKORDB_PASSWORD', falkor_config.password)

                # Parse the URI to extract host and port
                parsed = urlparse(uri)
                host = parsed.hostname or 'localhost'
                port = parsed.port or 6379

                return {
                    'driver': 'falkordb',
                    'host': host,
                    'port': port,
                    'password': password,
                    'database': falkor_config.database,
                }

            case _:
                raise ValueError(f'Unsupported Database provider: {provider}')

```

--------------------------------------------------------------------------------
/graphiti_core/llm_client/anthropic_client.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import logging
import os
import typing
from json import JSONDecodeError
from typing import TYPE_CHECKING, Literal

from pydantic import BaseModel, ValidationError

from ..prompts.models import Message
from .client import LLMClient
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
from .errors import RateLimitError, RefusalError

if TYPE_CHECKING:
    import anthropic
    from anthropic import AsyncAnthropic
    from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
else:
    try:
        import anthropic
        from anthropic import AsyncAnthropic
        from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
    except ImportError:
        raise ImportError(
            'anthropic is required for AnthropicClient. '
            'Install it with: pip install graphiti-core[anthropic]'
        ) from None


# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Model identifiers this client is typed against. This is a static-typing aid only:
# AnthropicClient.__init__ casts config.model to this type without validating the value.
AnthropicModel = Literal[
    'claude-sonnet-4-5-latest',
    'claude-sonnet-4-5-20250929',
    'claude-haiku-4-5-latest',
    'claude-3-7-sonnet-latest',
    'claude-3-7-sonnet-20250219',
    'claude-3-5-haiku-latest',
    'claude-3-5-haiku-20241022',
    'claude-3-5-sonnet-latest',
    'claude-3-5-sonnet-20241022',
    'claude-3-5-sonnet-20240620',
    'claude-3-opus-latest',
    'claude-3-opus-20240229',
    'claude-3-sonnet-20240229',
    'claude-3-haiku-20240307',
    'claude-2.1',
    'claude-2.0',
]

# Model used by AnthropicClient when the supplied LLMConfig does not name one.
# NOTE(review): confirm this alias is accepted by the Anthropic API - TODO verify.
DEFAULT_MODEL: AnthropicModel = 'claude-haiku-4-5-latest'

# Maximum output tokens for different Anthropic models
# Based on official Anthropic documentation (as of 2025)
# Note: These represent standard limits without beta headers.
# Some models support higher limits with additional configuration (e.g., Claude 3.7 supports
# 128K with 'anthropic-beta: output-128k-2025-02-19' header, but this is not currently implemented).
# NOTE(review): the "64K" entries are encoded as 64 * 1024 = 65536; confirm this matches the
# exact cap enforced by the API for those models.
ANTHROPIC_MODEL_MAX_TOKENS: dict[str, int] = {
    # Claude 4.5 models - 64K tokens
    'claude-sonnet-4-5-latest': 65536,
    'claude-sonnet-4-5-20250929': 65536,
    'claude-haiku-4-5-latest': 65536,
    # Claude 3.7 models - standard 64K tokens
    'claude-3-7-sonnet-latest': 65536,
    'claude-3-7-sonnet-20250219': 65536,
    # Claude 3.5 models - 8K tokens
    'claude-3-5-haiku-latest': 8192,
    'claude-3-5-haiku-20241022': 8192,
    'claude-3-5-sonnet-latest': 8192,
    'claude-3-5-sonnet-20241022': 8192,
    'claude-3-5-sonnet-20240620': 8192,
    # Claude 3 models - 4K tokens
    'claude-3-opus-latest': 4096,
    'claude-3-opus-20240229': 4096,
    'claude-3-sonnet-20240229': 4096,
    'claude-3-haiku-20240307': 4096,
    # Claude 2 models - 4K tokens
    'claude-2.1': 4096,
    'claude-2.0': 4096,
}

# Default max tokens for models not in the mapping
DEFAULT_ANTHROPIC_MAX_TOKENS: int = 8192


class AnthropicClient(LLMClient):
    """
    A client for the Anthropic LLM.

    Args:
        config: A configuration object for the LLM.
        cache: Whether to cache the LLM responses.
        client: An optional client instance to use.
        max_tokens: The maximum number of tokens to generate.

    Methods:
        generate_response: Generate a response from the LLM.

    Notes:
        - If a LLMConfig is not provided, api_key will be pulled from the ANTHROPIC_API_KEY environment
            variable, and all default values will be used for the LLMConfig.

    """

    model: AnthropicModel

    def __init__(
        self,
        config: LLMConfig | None = None,
        cache: bool = False,
        client: AsyncAnthropic | None = None,
        max_tokens: int = DEFAULT_MAX_TOKENS,
    ) -> None:
        """
        Set up the client, filling in defaults for anything the caller omitted.

        Args:
            config: LLM configuration. When omitted, a fresh LLMConfig is created whose
                api_key comes from the ANTHROPIC_API_KEY environment variable.
            cache: Forwarded unchanged to the LLMClient base initializer.
            client: Pre-built AsyncAnthropic instance to reuse instead of creating one.
            max_tokens: Only applied when ``config`` is omitted; it is written onto the
                freshly created LLMConfig.
        """
        if config is None:
            # No configuration supplied: build one from the environment and arguments.
            config = LLMConfig()
            config.api_key = os.getenv('ANTHROPIC_API_KEY')
            config.max_tokens = max_tokens

        if config.model is None:
            config.model = DEFAULT_MODEL

        super().__init__(config, cache)

        # The cast only satisfies the type checker; the model string is not validated.
        self.model = typing.cast(AnthropicModel, config.model)

        # Reuse the caller's client when one was given, otherwise create our own.
        self.client = client or AsyncAnthropic(api_key=config.api_key, max_retries=1)

    def _extract_json_from_text(self, text: str) -> dict[str, typing.Any]:
        """Extract JSON from text content.

        A helper method to extract JSON from text content, used when tool use fails or
        no response_model is provided.

        Args:
            text: The text to extract JSON from

        Returns:
            Extracted JSON as a dictionary

        Raises:
            ValueError: If JSON cannot be extracted or parsed
        """
        try:
            json_start = text.find('{')
            json_end = text.rfind('}') + 1
            if json_start >= 0 and json_end > json_start:
                json_str = text[json_start:json_end]
                return json.loads(json_str)
            else:
                raise ValueError(f'Could not extract JSON from model response: {text}')
        except (JSONDecodeError, ValueError) as e:
            raise ValueError(f'Could not extract JSON from model response: {text}') from e

    def _create_tool(
        self, response_model: type[BaseModel] | None = None
    ) -> tuple[list[ToolUnionParam], ToolChoiceParam]:
        """
        Build the single tool definition, and the tool choice selecting it, for a request.

        Args:
            response_model: Optional Pydantic model. When given, its JSON schema becomes
                the tool's input schema and its class name becomes the tool name.
                When omitted, a permissive generic JSON-object tool is produced.

        Returns:
            A ``(tools, tool_choice)`` pair: a one-element tool list and a tool choice
            that names that tool.
        """
        if response_model is None:
            # No schema to enforce: accept any JSON object.
            name = 'generic_json_output'
            summary = 'Output data in JSON format'
            schema: dict[str, typing.Any] = {
                'type': 'object',
                'additionalProperties': True,
                'description': 'Any JSON object containing the requested information',
            }
        else:
            # Derive the tool from the Pydantic model's own schema.
            schema = response_model.model_json_schema()
            name = response_model.__name__
            summary = schema.get('description', f'Extract {name} information')

        tools = typing.cast(
            list[ToolUnionParam],
            [{'name': name, 'description': summary, 'input_schema': schema}],
        )
        choice = typing.cast(ToolChoiceParam, {'type': 'tool', 'name': name})
        return tools, choice

    def _get_max_tokens_for_model(self, model: str) -> int:
        """Look up the output-token ceiling recorded for an Anthropic model.

        Args:
            model: The model name to look up

        Returns:
            int: The entry from ANTHROPIC_MODEL_MAX_TOKENS, or
                DEFAULT_ANTHROPIC_MAX_TOKENS when the model is not listed
        """
        try:
            return ANTHROPIC_MODEL_MAX_TOKENS[model]
        except KeyError:
            # Unknown model name: fall back to the module-wide default.
            return DEFAULT_ANTHROPIC_MAX_TOKENS

    def _resolve_max_tokens(self, requested_max_tokens: int | None, model: str) -> int:
        """
        Resolve the maximum output tokens to use based on precedence rules.

        Precedence order (highest to lowest):
        1. Explicit max_tokens parameter passed to generate_response()
        2. Instance max_tokens set during client initialization
        3. Model-specific maximum tokens from ANTHROPIC_MODEL_MAX_TOKENS mapping
        4. DEFAULT_ANTHROPIC_MAX_TOKENS as final fallback

        Args:
            requested_max_tokens: The max_tokens parameter passed to generate_response()
            model: The model name to look up model-specific limits

        Returns:
            int: The resolved maximum tokens to use
        """
        # 1. Use explicit parameter if provided
        if requested_max_tokens is not None:
            return requested_max_tokens

        # 2. Use instance max_tokens if set during initialization
        if self.max_tokens is not None:
            return self.max_tokens

        # 3. Use model-specific maximum or return DEFAULT_ANTHROPIC_MAX_TOKENS
        return self._get_max_tokens_for_model(model)

    async def _generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int | None = None,
        model_size: ModelSize = ModelSize.medium,
    ) -> dict[str, typing.Any]:
        """
        Generate a response from the Anthropic LLM using tool-based approach for all requests.

        The first message is sent as the system prompt and the remaining messages as
        the conversation. The result is taken from the first ``tool_use`` content
        block; if there is none, from the first ``text`` block that contains JSON.

        Args:
            messages: List of message objects to send to the LLM. Must be non-empty.
            response_model: Optional Pydantic model to use for structured output.
            max_tokens: Maximum number of tokens to generate. When None, the limit is
                resolved via ``_resolve_max_tokens``.
            model_size: Accepted for interface compatibility; not used by this method.

        Returns:
            Dictionary containing the structured response from the LLM.

        Raises:
            RateLimitError: If the rate limit is exceeded.
            RefusalError: If the LLM refuses to respond.
            ValueError: If no structured data can be extracted from the response.
            Exception: If an error occurs during the generation process.
        """
        system_message = messages[0]
        user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]]
        user_messages_cast = typing.cast(list[MessageParam], user_messages)

        # Resolve the output-token limit (explicit value, instance value, then model table)
        max_creation_tokens: int = self._resolve_max_tokens(max_tokens, self.model)

        try:
            # Create the appropriate tool based on whether response_model is provided
            tools, tool_choice = self._create_tool(response_model)
            result = await self.client.messages.create(
                system=system_message.content,
                max_tokens=max_creation_tokens,
                temperature=self.temperature,
                messages=user_messages_cast,
                model=self.model,
                tools=tools,
                tool_choice=tool_choice,
            )

            # Preferred path: the forced tool call carries the structured output.
            for content_item in result.content:
                if content_item.type == 'tool_use':
                    if isinstance(content_item.input, dict):
                        tool_args: dict[str, typing.Any] = content_item.input
                    else:
                        tool_args = json.loads(str(content_item.input))
                    return tool_args

            # Fallback: scan every block for text. A leading non-text block must not
            # abort the search, so the failure is only raised after the loop.
            for content_item in result.content:
                if content_item.type == 'text':
                    return self._extract_json_from_text(content_item.text)

            # Neither a tool call nor a text block was present.
            raise ValueError(
                f'Could not extract structured data from model response: {result.content}'
            )

        except anthropic.RateLimitError as e:
            raise RateLimitError(f'Rate limit exceeded. Please try again later. Error: {e}') from e
        except anthropic.APIError as e:
            # Special case for content policy violations. We convert these to RefusalError
            # to bypass the retry mechanism, as retrying policy-violating content will always fail.
            # This avoids wasting API calls and provides more specific error messaging to the user.
            if 'refused to respond' in str(e).lower():
                raise RefusalError(str(e)) from e
            raise

    async def generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int | None = None,
        model_size: ModelSize = ModelSize.medium,
        group_id: str | None = None,
        prompt_name: str | None = None,
    ) -> dict[str, typing.Any]:
        """
        Generate a response from the LLM, retrying on invalid output.

        Up to two retries are made after the initial attempt. Before each retry a
        user message describing the previous failure is appended to a private copy
        of the conversation; the caller's ``messages`` list is left untouched.

        Args:
            messages: List of message objects to send to the LLM.
            response_model: Optional Pydantic model to use for structured output. When
                given, the response is validated against it and returned via
                ``model_dump()``.
            max_tokens: Maximum number of tokens to generate. Defaults to the
                instance's ``max_tokens`` when None.
            model_size: Recorded on the tracing span and forwarded to
                ``_generate_response``.
            group_id: Accepted for interface compatibility; not used by this method.
            prompt_name: Optional prompt name recorded on the tracing span.

        Returns:
            Dictionary containing the structured response from the LLM.

        Raises:
            RateLimitError: If the rate limit is exceeded (never retried).
            RefusalError: If the LLM refuses to respond (never retried).
            Exception: The last error once the retries are exhausted.
        """
        if max_tokens is None:
            max_tokens = self.max_tokens

        # Retry hints are appended below; use a copy so the caller's list is not mutated.
        conversation = list(messages)

        # Wrap entire operation in tracing span
        with self.tracer.start_span('llm.generate') as span:
            attributes = {
                'llm.provider': 'anthropic',
                'model.size': model_size.value,
                'max_tokens': max_tokens,
            }
            if prompt_name:
                attributes['prompt.name'] = prompt_name
            span.add_attributes(attributes)

            retry_count = 0
            max_retries = 2
            last_error: Exception | None = None

            while retry_count <= max_retries:
                try:
                    response = await self._generate_response(
                        conversation, response_model, max_tokens, model_size
                    )

                    # If we have a response_model, attempt to validate the response
                    if response_model is not None:
                        model_instance = response_model(**response)
                        return model_instance.model_dump()

                    # If no validation needed, return the response
                    return response

                except (RateLimitError, RefusalError) as e:
                    # These errors should not trigger retries. Record the error actually
                    # being raised, not a stale (or missing) one from an earlier attempt.
                    span.set_status('error', str(e))
                    raise
                except Exception as e:
                    last_error = e

                    if retry_count >= max_retries:
                        if isinstance(e, ValidationError):
                            logger.error(
                                f'Validation error after {retry_count}/{max_retries} attempts: {e}'
                            )
                        else:
                            logger.error(f'Max retries ({max_retries}) exceeded. Last error: {e}')
                        span.set_status('error', str(e))
                        span.record_exception(e)
                        raise

                    if isinstance(e, ValidationError):
                        response_model_cast = typing.cast(type[BaseModel], response_model)
                        error_context = f'The previous response was invalid. Please provide a valid {response_model_cast.__name__} object. Error: {e}'
                    else:
                        error_context = (
                            f'The previous response attempt was invalid. '
                            f'Error type: {e.__class__.__name__}. '
                            f'Error details: {str(e)}. '
                            f'Please try again with a valid response.'
                        )

                    # Common retry logic: tell the model what went wrong and try again
                    retry_count += 1
                    conversation.append(Message(role='user', content=error_context))
                    logger.warning(
                        f'Retrying after error (attempt {retry_count}/{max_retries}): {e}'
                    )

            # Defensive: the loop always returns or raises, but keep a final failure path
            span.set_status('error', str(last_error))
            raise last_error or Exception('Max retries exceeded with no specific error')

```

--------------------------------------------------------------------------------
/mcp_server/tests/test_async_operations.py:
--------------------------------------------------------------------------------

```python
#!/usr/bin/env python3
"""
Asynchronous operation tests for Graphiti MCP Server.
Tests concurrent operations, queue management, and async patterns.
"""

import asyncio
import contextlib
import json
import time

import pytest
from test_fixtures import (
    TestDataGenerator,
    graphiti_test_client,
)


class TestAsyncQueueManagement:
    """Test asynchronous queue operations and episode processing."""

    @pytest.mark.asyncio
    async def test_sequential_queue_processing(self):
        """Submit five episodes to one group and check they come back in timestamp order."""
        async with graphiti_test_client() as (session, group_id):
            # Submit the episodes one after another, awaiting each call
            for index in range(5):
                payload = {
                    'name': f'Sequential Test {index}',
                    'episode_body': f'Episode {index} with timestamp {time.time()}',
                    'source': 'text',
                    'source_description': 'sequential test',
                    'group_id': group_id,
                    'reference_id': f'seq_{index}',  # Add reference for tracking
                }
                await session.call_tool('add_memory', payload)

            # Fixed pause to give the server time to work through the queue
            await asyncio.sleep(10)

            # Read the episodes back for this group
            listing = await session.call_tool('get_episodes', {'group_id': group_id, 'last_n': 10})
            processed = json.loads(listing.content[0].text)['episodes']

            # Every submitted episode must be present
            episode_count = len(processed)
            assert episode_count >= 5, f'Expected at least 5 episodes, got {episode_count}'

            # The created_at values must already be in ascending order
            created = [episode.get('created_at') for episode in processed]
            assert created == sorted(created), 'Episodes not processed in order'

    @pytest.mark.asyncio
    async def test_concurrent_group_processing(self):
        """Submit two episodes to each of three groups at once and bound the total time."""
        async with graphiti_test_client() as (session, _):
            groups = [f'group_{i}_{time.time()}' for i in range(3)]

            # One pending add_memory call per (group, episode) pair: six in total
            pending = [
                session.call_tool(
                    'add_memory',
                    {
                        'name': f'Group {group_id} Episode {j}',
                        'episode_body': f'Content for {group_id}',
                        'source': 'text',
                        'source_description': 'concurrent test',
                        'group_id': group_id,
                    },
                )
                for group_id in groups
                for j in range(2)
            ]

            # Run everything together and time the batch
            began = time.time()
            outcomes = await asyncio.gather(*pending, return_exceptions=True)
            execution_time = time.time() - began

            # No call may have raised
            failures = [outcome for outcome in outcomes if isinstance(outcome, Exception)]
            assert not failures, f'Concurrent operations failed: {failures}'

            # The whole batch must finish within the 30 second budget
            assert execution_time < 30, f'Concurrent execution too slow: {execution_time}s'

    @pytest.mark.asyncio
    async def test_queue_overflow_handling(self):
        """Fire 100 add_memory calls at one group and require that at least one succeeds."""
        async with graphiti_test_client() as (session, group_id):
            # Large burst intended to stress the queue
            burst = [
                session.call_tool(
                    'add_memory',
                    {
                        'name': f'Overflow Test {i}',
                        'episode_body': f'Episode {i}',
                        'source': 'text',
                        'source_description': 'overflow test',
                        'group_id': group_id,
                    },
                )
                for i in range(100)
            ]

            # Collect exceptions as results rather than letting the first one abort
            outcomes = await asyncio.gather(*burst, return_exceptions=True)

            # Anything that is not an exception counts as successfully queued
            successful = len([o for o in outcomes if not isinstance(o, Exception)])

            assert successful > 0, 'No episodes were queued successfully'

            # Report partial acceptance for diagnostics; it is not a failure
            if successful < 100:
                print(f'Queue overflow: {successful}/100 episodes queued')


class TestConcurrentOperations:
    """Test concurrent tool calls and operations."""

    @pytest.mark.asyncio
    async def test_concurrent_search_operations(self):
        """Seed a group with documents, then run five node searches at the same time."""
        async with graphiti_test_client() as (session, group_id):
            generator = TestDataGenerator()

            # Seed data: five generated technical documents
            seed_calls = [
                session.call_tool(
                    'add_memory',
                    {
                        'name': 'Search Test Data',
                        'episode_body': generator.generate_technical_document(),
                        'source': 'text',
                        'source_description': 'search test',
                        'group_id': group_id,
                    },
                )
                for _ in range(5)
            ]
            await asyncio.gather(*seed_calls)
            await asyncio.sleep(15)  # Fixed pause for processing

            search_queries = [
                'architecture',
                'performance',
                'implementation',
                'dependencies',
                'latency',
            ]

            # One pending search per query term
            search_calls = [
                session.call_tool(
                    'search_memory_nodes',
                    {
                        'query': term,
                        'group_id': group_id,
                        'limit': 10,
                    },
                )
                for term in search_queries
            ]

            began = time.time()
            outcomes = await asyncio.gather(*search_calls, return_exceptions=True)
            search_time = time.time() - began

            # No search may have raised
            failures = [outcome for outcome in outcomes if isinstance(outcome, Exception)]
            assert not failures, f'Search operations failed: {failures}'

            # Budget: two seconds per query for the whole batch
            assert search_time < len(search_queries) * 2, 'Searches not executing concurrently'

    @pytest.mark.asyncio
    async def test_mixed_operation_concurrency(self):
        """Run an add, a search, an episode listing and a status call together."""
        async with graphiti_test_client() as (session, group_id):
            operations = [
                # Add memory operation
                session.call_tool(
                    'add_memory',
                    {
                        'name': 'Mixed Op Test',
                        'episode_body': 'Testing mixed operations',
                        'source': 'text',
                        'source_description': 'test',
                        'group_id': group_id,
                    },
                ),
                # Search operation
                session.call_tool(
                    'search_memory_nodes',
                    {
                        'query': 'test',
                        'group_id': group_id,
                        'limit': 5,
                    },
                ),
                # Get episodes operation
                session.call_tool(
                    'get_episodes',
                    {
                        'group_id': group_id,
                        'last_n': 10,
                    },
                ),
                # Get status operation
                session.call_tool('get_status', {}),
            ]

            # Execute all four together, capturing exceptions as results
            outcomes = await asyncio.gather(*operations, return_exceptions=True)

            # None of the four may have raised
            for position, outcome in enumerate(outcomes):
                assert not isinstance(outcome, Exception), f'Operation {position} failed: {outcome}'


class TestAsyncErrorHandling:
    """Test async error handling and recovery."""

    @pytest.mark.asyncio
    async def test_timeout_recovery(self):
        """Check the session still answers get_status after a call hits a client-side timeout."""
        async with graphiti_test_client() as (session, group_id):
            # A one-million-character body, submitted under a two second deadline
            large_content = 'x' * 1000000
            oversized_call = session.call_tool(
                'add_memory',
                {
                    'name': 'Timeout Test',
                    'episode_body': large_content,
                    'source': 'text',
                    'source_description': 'timeout test',
                    'group_id': group_id,
                },
            )

            # A timeout here is acceptable and deliberately ignored
            with contextlib.suppress(asyncio.TimeoutError):
                await asyncio.wait_for(oversized_call, timeout=2.0)

            # The session must still be usable afterwards
            status_result = await session.call_tool('get_status', {})
            assert status_result is not None, 'Server unresponsive after timeout'

    @pytest.mark.asyncio
    async def test_cancellation_handling(self):
        """Cancel an in-flight add_memory task and confirm the session keeps working."""
        async with graphiti_test_client() as (session, group_id):
            payload = {
                'name': 'Cancellation Test',
                'episode_body': TestDataGenerator.generate_technical_document(),
                'source': 'text',
                'source_description': 'cancel test',
                'group_id': group_id,
            }

            # Schedule the call as a task so it can be cancelled from here
            in_flight = asyncio.create_task(session.call_tool('add_memory', payload))

            # Let it start, then cancel it
            await asyncio.sleep(0.1)
            in_flight.cancel()

            # Awaiting the cancelled task must surface CancelledError
            with pytest.raises(asyncio.CancelledError):
                await in_flight

            # The session must still answer afterwards
            status = await session.call_tool('get_status', {})
            assert status is not None

    @pytest.mark.asyncio
    async def test_exception_propagation(self):
        """Expect ValueError from an incomplete add_memory call, then a working status call."""
        async with graphiti_test_client() as (session, group_id):
            # Only group_id is supplied; every other add_memory field is left out
            incomplete_payload = {
                # Missing required fields
                'group_id': group_id,
            }
            with pytest.raises(ValueError):
                await session.call_tool('add_memory', incomplete_payload)

            # The session must still answer afterwards
            status = await session.call_tool('get_status', {})
            assert status is not None


class TestAsyncPerformance:
    """Performance tests for async operations."""

    @pytest.mark.asyncio
    async def test_async_throughput(self, performance_benchmark):
        """Time 50 concurrent add_memory calls and require more than one success per second."""
        async with graphiti_test_client() as (session, group_id):
            num_operations = 50
            # The clock starts before the calls are even constructed
            began = time.time()

            pending = [
                session.call_tool(
                    'add_memory',
                    {
                        'name': f'Throughput Test {i}',
                        'episode_body': f'Content {i}',
                        'source': 'text',
                        'source_description': 'throughput test',
                        'group_id': group_id,
                    },
                )
                for i in range(num_operations)
            ]

            outcomes = await asyncio.gather(*pending, return_exceptions=True)
            total_time = time.time() - began

            # Successes are results that are not exceptions
            successful = len([o for o in outcomes if not isinstance(o, Exception)])
            throughput = successful / total_time

            performance_benchmark.record('async_throughput', throughput)

            # Human-readable summary
            print('\nAsync Throughput Test:')
            print(f'  Operations: {num_operations}')
            print(f'  Successful: {successful}')
            print(f'  Total time: {total_time:.2f}s')
            print(f'  Throughput: {throughput:.2f} ops/s')

            # Minimum acceptable rate
            assert throughput > 1.0, f'Throughput too low: {throughput:.2f} ops/s'

    @pytest.mark.asyncio
    async def test_latency_under_load(self, performance_benchmark):
        """Measure get_status latency while ten add_memory calls run in the background.

        Five sequential get_status calls are timed. The average must stay below
        2 seconds and the slowest below 5 seconds. Background tasks are always
        cancelled and awaited, even if a probe fails, so no pending task is left
        behind when the client context closes.
        """
        async with graphiti_test_client() as (session, group_id):
            # Create background load in a separate group
            background_tasks = [
                asyncio.create_task(
                    session.call_tool(
                        'add_memory',
                        {
                            'name': f'Background {i}',
                            'episode_body': TestDataGenerator.generate_technical_document(),
                            'source': 'text',
                            'source_description': 'background',
                            'group_id': f'background_{group_id}',
                        },
                    )
                )
                for i in range(10)
            ]

            latencies = []
            try:
                # Measure latency of operations under load
                for _ in range(5):
                    start = time.time()
                    await session.call_tool('get_status', {})
                    latency = time.time() - start
                    latencies.append(latency)
                    performance_benchmark.record('latency_under_load', latency)
            finally:
                # Clean up background tasks: cancel() only requests cancellation, so
                # await the tasks as well. return_exceptions=True absorbs the resulting
                # CancelledError (or any error a task had already finished with).
                for task in background_tasks:
                    task.cancel()
                await asyncio.gather(*background_tasks, return_exceptions=True)

            # Analyze latencies
            avg_latency = sum(latencies) / len(latencies)
            max_latency = max(latencies)

            print('\nLatency Under Load:')
            print(f'  Average: {avg_latency:.3f}s')
            print(f'  Max: {max_latency:.3f}s')

            # Assert acceptable latency
            assert avg_latency < 2.0, f'Average latency too high: {avg_latency:.3f}s'
            assert max_latency < 5.0, f'Max latency too high: {max_latency:.3f}s'


class TestAsyncStreamHandling:
    """Exercise retrieval of large and incrementally growing episode sets."""

    @pytest.mark.asyncio
    async def test_large_response_streaming(self):
        """Queue 20 episodes, pause, then fetch them all in a single request."""
        async with graphiti_test_client() as (session, group_id):
            # Queue the episodes one call at a time
            for index in range(20):
                episode = {
                    'name': f'Stream Test {index}',
                    'episode_body': f'Episode content {index}',
                    'source': 'text',
                    'source_description': 'stream test',
                    'group_id': group_id,
                }
                await session.call_tool('add_memory', episode)

            # Fixed pause before reading the episodes back
            await asyncio.sleep(30)

            # Ask for more episodes than were added so nothing is cut off
            request = {'group_id': group_id, 'last_n': 100}
            result = await session.call_tool('get_episodes', request)

            # The first content item carries a JSON document with an 'episodes' key
            response_text = result.content[0].text
            episodes = json.loads(response_text)['episodes']
            episode_count = len(episodes)
            assert episode_count >= 20, f'Expected at least 20 episodes, got {episode_count}'

    @pytest.mark.asyncio
    async def test_incremental_processing(self):
        """Add three batches of five episodes and check the running total after each."""
        async with graphiti_test_client() as (session, group_id):
            for batch in range(3):
                # Build the five add_memory calls of this batch and run them together
                pending_calls = [
                    session.call_tool(
                        'add_memory',
                        {
                            'name': f'Batch {batch} Item {i}',
                            'episode_body': f'Content for batch {batch}',
                            'source': 'text',
                            'source_description': 'incremental test',
                            'group_id': group_id,
                        },
                    )
                    for i in range(5)
                ]
                await asyncio.gather(*pending_calls)

                # Fixed pause before checking this batch
                await asyncio.sleep(10)

                request = {'group_id': group_id, 'last_n': 100}
                result = await session.call_tool('get_episodes', request)
                episodes = json.loads(result.content[0].text)['episodes']

                # Every completed batch contributes five episodes to the minimum
                expected_min = 5 * (batch + 1)
                assert len(episodes) >= expected_min, (
                    f'Batch {batch}: Expected at least {expected_min} episodes'
                )


if __name__ == '__main__':
    # Allow running this test module directly: verbose output, and the
    # asyncio mode option set to "auto".
    pytest_args = [__file__, '-v', '--asyncio-mode=auto']
    pytest.main(pytest_args)

```

--------------------------------------------------------------------------------
/graphiti_core/search/search.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
from collections import defaultdict
from time import time

from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.edges import EntityEdge
from graphiti_core.embedder.client import EMBEDDING_DIM
from graphiti_core.errors import SearchRerankerError
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import semaphore_gather
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.search.search_config import (
    DEFAULT_SEARCH_LIMIT,
    CommunityReranker,
    CommunitySearchConfig,
    CommunitySearchMethod,
    EdgeReranker,
    EdgeSearchConfig,
    EdgeSearchMethod,
    EpisodeReranker,
    EpisodeSearchConfig,
    NodeReranker,
    NodeSearchConfig,
    NodeSearchMethod,
    SearchConfig,
    SearchResults,
)
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import (
    community_fulltext_search,
    community_similarity_search,
    edge_bfs_search,
    edge_fulltext_search,
    edge_similarity_search,
    episode_fulltext_search,
    episode_mentions_reranker,
    get_embeddings_for_communities,
    get_embeddings_for_edges,
    get_embeddings_for_nodes,
    maximal_marginal_relevance,
    node_bfs_search,
    node_distance_reranker,
    node_fulltext_search,
    node_similarity_search,
    rrf,
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


async def search(
    clients: GraphitiClients,
    query: str,
    group_ids: list[str] | None,
    config: SearchConfig,
    search_filter: SearchFilters,
    center_node_uuid: str | None = None,
    bfs_origin_node_uuids: list[str] | None = None,
    query_vector: list[float] | None = None,
    driver: GraphDriver | None = None,
) -> SearchResults:
    """Run the edge, node, episode and community searches for ``query``.

    The four searches are awaited together through ``semaphore_gather`` and
    their results and reranker scores are bundled into one ``SearchResults``.

    Args:
        clients: Supplies the embedder and cross encoder, plus the default driver.
        query: Search text. A blank query returns an empty ``SearchResults``.
        group_ids: Groups to search; an empty list or ``['']`` is treated as ``None``.
        config: Per-type search configs, the result limit and the reranker minimum score.
        search_filter: Filters forwarded to the edge, node and episode searches.
        center_node_uuid: Forwarded to the edge and node searches.
        bfs_origin_node_uuids: Forwarded to the edge and node searches.
        query_vector: Pre-computed query embedding, used instead of calling the embedder.
        driver: Overrides ``clients.driver`` when given.

    Returns:
        The combined search results with their reranker scores.
    """
    start = time()

    driver = driver or clients.driver
    embedder = clients.embedder
    cross_encoder = clients.cross_encoder

    if not query.strip():
        return SearchResults()

    edge_config = config.edge_config
    node_config = config.node_config
    community_config = config.community_config

    # A real embedding is only required when some configured search uses
    # cosine similarity or reranks with MMR.
    needs_embedding = (
        (
            edge_config
            and (
                EdgeSearchMethod.cosine_similarity in edge_config.search_methods
                or EdgeReranker.mmr == edge_config.reranker
            )
        )
        or (
            node_config
            and (
                NodeSearchMethod.cosine_similarity in node_config.search_methods
                or NodeReranker.mmr == node_config.reranker
            )
        )
        or (
            community_config
            and (
                CommunitySearchMethod.cosine_similarity in community_config.search_methods
                or CommunityReranker.mmr == community_config.reranker
            )
        )
    )

    if needs_embedding:
        if query_vector is not None:
            search_vector = query_vector
        else:
            search_vector = await embedder.create(input_data=[query.replace('\n', ' ')])
    else:
        # Placeholder vector for searches that never look at the embedding
        search_vector = [0.0] * EMBEDDING_DIM

    # if group_ids is empty, set it to None
    if not group_ids or group_ids == ['']:
        group_ids = None

    max_results = config.limit
    min_score = config.reranker_min_score

    edge_outcome, node_outcome, episode_outcome, community_outcome = await semaphore_gather(
        edge_search(
            driver,
            cross_encoder,
            query,
            search_vector,
            group_ids,
            edge_config,
            search_filter,
            center_node_uuid,
            bfs_origin_node_uuids,
            max_results,
            min_score,
        ),
        node_search(
            driver,
            cross_encoder,
            query,
            search_vector,
            group_ids,
            node_config,
            search_filter,
            center_node_uuid,
            bfs_origin_node_uuids,
            max_results,
            min_score,
        ),
        episode_search(
            driver,
            cross_encoder,
            query,
            search_vector,
            group_ids,
            config.episode_config,
            search_filter,
            max_results,
            min_score,
        ),
        community_search(
            driver,
            cross_encoder,
            query,
            search_vector,
            group_ids,
            community_config,
            max_results,
            min_score,
        ),
    )

    # Each search yields an (items, scores) pair
    edges, edge_reranker_scores = edge_outcome
    nodes, node_reranker_scores = node_outcome
    episodes, episode_reranker_scores = episode_outcome
    communities, community_reranker_scores = community_outcome

    results = SearchResults(
        edges=edges,
        edge_reranker_scores=edge_reranker_scores,
        nodes=nodes,
        node_reranker_scores=node_reranker_scores,
        episodes=episodes,
        episode_reranker_scores=episode_reranker_scores,
        communities=communities,
        community_reranker_scores=community_reranker_scores,
    )

    latency = (time() - start) * 1000

    logger.debug(f'search returned context for query {query} in {latency} ms')

    return results


async def edge_search(
    driver: GraphDriver,
    cross_encoder: CrossEncoderClient,
    query: str,
    query_vector: list[float],
    group_ids: list[str] | None,
    config: EdgeSearchConfig | None,
    search_filter: SearchFilters,
    center_node_uuid: str | None = None,
    bfs_origin_node_uuids: list[str] | None = None,
    limit: int = DEFAULT_SEARCH_LIMIT,
    reranker_min_score: float = 0,
) -> tuple[list[EntityEdge], list[float]]:
    """Search for entity edges and rerank them according to ``config``.

    Runs the configured search methods (full-text, cosine similarity, BFS),
    each asked for ``2 * limit`` candidates, merges the candidates by uuid and
    applies the configured reranker.

    Args:
        driver: Graph driver passed to every search helper.
        cross_encoder: Used only by the cross-encoder reranker.
        query: Query text for full-text search and cross-encoder ranking.
        query_vector: Query embedding for similarity search and MMR.
        group_ids: Groups to restrict the search to, or ``None``.
        config: Edge search configuration; ``None`` disables edge search.
        search_filter: Filters forwarded to the search helpers.
        center_node_uuid: Required by the node-distance reranker.
        bfs_origin_node_uuids: BFS origins. When ``None`` and BFS is enabled,
            a BFS is run from the source nodes of the edges found so far.
        limit: Maximum number of edges (and scores) to return.
        reranker_min_score: Minimum score passed to the rerankers.

    Returns:
        ``(edges, scores)`` truncated to ``limit``, where ``scores[i]`` belongs
        to ``edges[i]``.

    Raises:
        SearchRerankerError: The node-distance reranker is configured but
            ``center_node_uuid`` is ``None``.
    """
    if config is None:
        return [], []

    # Build search tasks based on configured search methods
    search_tasks = []
    if EdgeSearchMethod.bm25 in config.search_methods:
        search_tasks.append(
            edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
        )
    if EdgeSearchMethod.cosine_similarity in config.search_methods:
        search_tasks.append(
            edge_similarity_search(
                driver,
                query_vector,
                None,
                None,
                search_filter,
                group_ids,
                2 * limit,
                config.sim_min_score,
            )
        )
    if EdgeSearchMethod.bfs in config.search_methods:
        search_tasks.append(
            edge_bfs_search(
                driver,
                bfs_origin_node_uuids,
                config.bfs_max_depth,
                search_filter,
                group_ids,
                2 * limit,
            )
        )

    # Execute only the configured search methods
    search_results: list[list[EntityEdge]] = []
    if search_tasks:
        search_results = list(await semaphore_gather(*search_tasks))

    # Without explicit origins, expand from the source nodes of what was found
    if EdgeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
        source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
        search_results.append(
            await edge_bfs_search(
                driver,
                source_node_uuids,
                config.bfs_max_depth,
                search_filter,
                group_ids,
                2 * limit,
            )
        )

    edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}

    reranked_uuids: list[str] = []
    edge_scores: list[float] = []
    if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
        search_result_uuids = [[edge.uuid for edge in result] for result in search_results]

        reranked_uuids, edge_scores = rrf(search_result_uuids, min_score=reranker_min_score)
    elif config.reranker == EdgeReranker.mmr:
        search_result_uuids_and_vectors = await get_embeddings_for_edges(
            driver, list(edge_uuid_map.values())
        )
        reranked_uuids, edge_scores = maximal_marginal_relevance(
            query_vector,
            search_result_uuids_and_vectors,
            config.mmr_lambda,
            reranker_min_score,
        )
    elif config.reranker == EdgeReranker.cross_encoder:
        # Keyed by fact text, so edges sharing an identical fact collapse to one entry
        fact_to_uuid_map = {edge.fact: edge.uuid for edge in list(edge_uuid_map.values())[:limit]}
        reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
        reranked_uuids = [
            fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
        ]
        edge_scores = [score for _, score in reranked_facts if score >= reranker_min_score]
    elif config.reranker == EdgeReranker.node_distance:
        if center_node_uuid is None:
            raise SearchRerankerError('No center node provided for Node Distance reranker')

        # use rrf as a preliminary sort
        sorted_result_uuids, _ = rrf(
            [[edge.uuid for edge in result] for result in search_results],
            min_score=reranker_min_score,
        )
        sorted_results = [edge_uuid_map[uuid] for uuid in sorted_result_uuids]

        # node distance reranking: group edges by source node, rank the nodes
        source_to_edge_uuid_map = defaultdict(list)
        for edge in sorted_results:
            source_to_edge_uuid_map[edge.source_node_uuid].append(edge.uuid)

        source_uuids = list(source_to_edge_uuid_map)

        # NOTE(review): assumes the reranker returns parallel uuid/score lists — confirm
        reranked_node_uuids, source_node_scores = await node_distance_reranker(
            driver, source_uuids, center_node_uuid, min_score=reranker_min_score
        )

        # A source node may own several edges; repeat its score for each of
        # them so that the score list stays aligned with the edge list.
        for node_uuid, node_score in zip(reranked_node_uuids, source_node_scores):
            node_edge_uuids = source_to_edge_uuid_map[node_uuid]
            reranked_uuids.extend(node_edge_uuids)
            edge_scores.extend([node_score] * len(node_edge_uuids))

    reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]

    if config.reranker == EdgeReranker.episode_mentions:
        # Order by number of mentioning episodes, moving each score together
        # with its edge. sorted() is stable, so ties keep their RRF order.
        ranked_pairs = sorted(
            zip(reranked_edges, edge_scores),
            key=lambda pair: len(pair[0].episodes),
            reverse=True,
        )
        reranked_edges = [edge for edge, _ in ranked_pairs]
        edge_scores = [score for _, score in ranked_pairs]

    return reranked_edges[:limit], edge_scores[:limit]


async def node_search(
    driver: GraphDriver,
    cross_encoder: CrossEncoderClient,
    query: str,
    query_vector: list[float],
    group_ids: list[str] | None,
    config: NodeSearchConfig | None,
    search_filter: SearchFilters,
    center_node_uuid: str | None = None,
    bfs_origin_node_uuids: list[str] | None = None,
    limit=DEFAULT_SEARCH_LIMIT,
    reranker_min_score: float = 0,
) -> tuple[list[EntityNode], list[float]]:
    """Search for entity nodes and rerank them according to ``config``.

    Each configured search method is asked for ``2 * limit`` candidates. The
    candidates are merged by uuid and ordered by the configured reranker.

    Args:
        driver: Graph driver passed to every search helper.
        cross_encoder: Used only by the cross-encoder reranker.
        query: Query text for full-text search and cross-encoder ranking.
        query_vector: Query embedding for similarity search and MMR.
        group_ids: Groups to restrict the search to, or ``None``.
        config: Node search configuration; ``None`` disables node search.
        search_filter: Filters forwarded to the search helpers.
        center_node_uuid: Required by the node-distance reranker.
        bfs_origin_node_uuids: BFS origins. When ``None`` and BFS is enabled,
            a BFS is run from the nodes found by the other searches.
        limit: Maximum number of nodes (and scores) to return.
        reranker_min_score: Minimum score passed to the rerankers.

    Returns:
        ``(nodes, scores)``, each truncated to ``limit``.

    Raises:
        SearchRerankerError: The node-distance reranker is configured but
            ``center_node_uuid`` is ``None``.
    """
    if config is None:
        return [], []

    methods = config.search_methods
    candidate_limit = 2 * limit
    bfs_enabled = NodeSearchMethod.bfs in methods

    # Queue one coroutine per configured search method
    pending = []
    if NodeSearchMethod.bm25 in methods:
        pending.append(
            node_fulltext_search(driver, query, search_filter, group_ids, candidate_limit)
        )
    if NodeSearchMethod.cosine_similarity in methods:
        pending.append(
            node_similarity_search(
                driver,
                query_vector,
                search_filter,
                group_ids,
                candidate_limit,
                config.sim_min_score,
            )
        )
    if bfs_enabled:
        pending.append(
            node_bfs_search(
                driver,
                bfs_origin_node_uuids,
                search_filter,
                config.bfs_max_depth,
                group_ids,
                candidate_limit,
            )
        )

    search_results: list[list[EntityNode]] = (
        list(await semaphore_gather(*pending)) if pending else []
    )

    # Without explicit origins, expand from the nodes found so far
    if bfs_enabled and bfs_origin_node_uuids is None:
        discovered_uuids = [node.uuid for hits in search_results for node in hits]
        expansion = await node_bfs_search(
            driver,
            discovered_uuids,
            search_filter,
            config.bfs_max_depth,
            group_ids,
            candidate_limit,
        )
        search_results.append(expansion)

    node_uuid_map = {node.uuid: node for hits in search_results for node in hits}
    search_result_uuids = [[node.uuid for node in hits] for hits in search_results]

    reranker = config.reranker
    reranked_uuids: list[str] = []
    node_scores: list[float] = []
    if reranker == NodeReranker.rrf:
        reranked_uuids, node_scores = rrf(search_result_uuids, min_score=reranker_min_score)
    elif reranker == NodeReranker.mmr:
        uuids_and_vectors = await get_embeddings_for_nodes(driver, list(node_uuid_map.values()))

        reranked_uuids, node_scores = maximal_marginal_relevance(
            query_vector,
            uuids_and_vectors,
            config.mmr_lambda,
            reranker_min_score,
        )
    elif reranker == NodeReranker.cross_encoder:
        # Rank by node name; nodes sharing a name map to a single uuid
        name_to_uuid_map = {node.name: node.uuid for node in node_uuid_map.values()}

        ranked_names = await cross_encoder.rank(query, list(name_to_uuid_map))
        for name, score in ranked_names:
            if score >= reranker_min_score:
                reranked_uuids.append(name_to_uuid_map[name])
                node_scores.append(score)
    elif reranker == NodeReranker.episode_mentions:
        reranked_uuids, node_scores = await episode_mentions_reranker(
            driver, search_result_uuids, min_score=reranker_min_score
        )
    elif reranker == NodeReranker.node_distance:
        if center_node_uuid is None:
            raise SearchRerankerError('No center node provided for Node Distance reranker')
        # RRF provides the preliminary ordering handed to the distance reranker
        preliminary_uuids, _ = rrf(search_result_uuids, min_score=reranker_min_score)
        reranked_uuids, node_scores = await node_distance_reranker(
            driver,
            preliminary_uuids,
            center_node_uuid,
            min_score=reranker_min_score,
        )

    reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]

    return reranked_nodes[:limit], node_scores[:limit]


async def episode_search(
    driver: GraphDriver,
    cross_encoder: CrossEncoderClient,
    query: str,
    _query_vector: list[float],
    group_ids: list[str] | None,
    config: EpisodeSearchConfig | None,
    search_filter: SearchFilters,
    limit=DEFAULT_SEARCH_LIMIT,
    reranker_min_score: float = 0,
) -> tuple[list[EpisodicNode], list[float]]:
    """Full-text search for episodes, reranked according to ``config``.

    Args:
        driver: Graph driver passed to the full-text search.
        cross_encoder: Used only by the cross-encoder reranker.
        query: Query text.
        _query_vector: Unused; accepted so the signature parallels the other searches.
        group_ids: Groups to restrict the search to, or ``None``.
        config: Episode search configuration; ``None`` disables episode search.
        search_filter: Filters forwarded to the full-text search.
        limit: Maximum number of episodes (and scores) to return.
        reranker_min_score: Minimum score passed to the rerankers.

    Returns:
        ``(episodes, scores)``, each truncated to ``limit``.
    """
    if config is None:
        return [], []

    # Full-text search is the only episode search method used here
    fulltext_hits = await semaphore_gather(
        episode_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
    )
    search_results: list[list[EpisodicNode]] = list(fulltext_hits)

    episode_uuid_map = {episode.uuid: episode for hits in search_results for episode in hits}
    search_result_uuids = [[episode.uuid for episode in hits] for hits in search_results]

    reranker = config.reranker
    reranked_uuids: list[str] = []
    episode_scores: list[float] = []
    if reranker == EpisodeReranker.rrf:
        reranked_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)

    elif reranker == EpisodeReranker.cross_encoder:
        # use rrf as a preliminary reranker; its scores are replaced below
        rrf_result_uuids, _ = rrf(search_result_uuids, min_score=reranker_min_score)
        rrf_results = [episode_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]

        # Rank by content; episodes with identical content map to a single uuid
        content_to_uuid_map = {episode.content: episode.uuid for episode in rrf_results}

        ranked_contents = await cross_encoder.rank(query, list(content_to_uuid_map))
        for content, score in ranked_contents:
            if score >= reranker_min_score:
                reranked_uuids.append(content_to_uuid_map[content])
                episode_scores.append(score)

    reranked_episodes = [episode_uuid_map[uuid] for uuid in reranked_uuids]

    return reranked_episodes[:limit], episode_scores[:limit]


async def community_search(
    driver: GraphDriver,
    cross_encoder: CrossEncoderClient,
    query: str,
    query_vector: list[float],
    group_ids: list[str] | None,
    config: CommunitySearchConfig | None,
    limit=DEFAULT_SEARCH_LIMIT,
    reranker_min_score: float = 0,
) -> tuple[list[CommunityNode], list[float]]:
    """Search for community nodes and rerank them according to ``config``.

    Only the search methods listed in ``config.search_methods`` are executed,
    matching ``edge_search`` and ``node_search``. In particular the similarity
    search is skipped unless cosine similarity is configured, because
    ``search`` only computes a real query embedding in that case (or for MMR)
    and otherwise passes a zero placeholder vector.

    Args:
        driver: Graph driver passed to every search helper.
        cross_encoder: Used only by the cross-encoder reranker.
        query: Query text for full-text search and cross-encoder ranking.
        query_vector: Query embedding for similarity search and MMR.
        group_ids: Groups to restrict the search to, or ``None``.
        config: Community search configuration; ``None`` disables the search.
        limit: Maximum number of communities (and scores) to return.
        reranker_min_score: Minimum score passed to the rerankers.

    Returns:
        ``(communities, scores)``, each truncated to ``limit``.
    """
    if config is None:
        return [], []

    # Build search tasks based on configured search methods
    search_tasks = []
    if CommunitySearchMethod.bm25 in config.search_methods:
        search_tasks.append(community_fulltext_search(driver, query, group_ids, 2 * limit))
    if CommunitySearchMethod.cosine_similarity in config.search_methods:
        search_tasks.append(
            community_similarity_search(
                driver, query_vector, group_ids, 2 * limit, config.sim_min_score
            )
        )

    # Execute only the configured search methods
    search_results: list[list[CommunityNode]] = []
    if search_tasks:
        search_results = list(await semaphore_gather(*search_tasks))

    search_result_uuids = [[community.uuid for community in result] for result in search_results]
    community_uuid_map = {
        community.uuid: community for result in search_results for community in result
    }

    reranked_uuids: list[str] = []
    community_scores: list[float] = []
    if config.reranker == CommunityReranker.rrf:
        reranked_uuids, community_scores = rrf(search_result_uuids, min_score=reranker_min_score)
    elif config.reranker == CommunityReranker.mmr:
        search_result_uuids_and_vectors = await get_embeddings_for_communities(
            driver, list(community_uuid_map.values())
        )

        reranked_uuids, community_scores = maximal_marginal_relevance(
            query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
        )
    elif config.reranker == CommunityReranker.cross_encoder:
        # Rank by community name; communities sharing a name map to a single uuid
        name_to_uuid_map = {node.name: node.uuid for result in search_results for node in result}
        reranked_nodes = await cross_encoder.rank(query, list(name_to_uuid_map.keys()))
        reranked_uuids = [
            name_to_uuid_map[name] for name, score in reranked_nodes if score >= reranker_min_score
        ]
        community_scores = [score for _, score in reranked_nodes if score >= reranker_min_score]

    reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]

    return reranked_communities[:limit], community_scores[:limit]

```

--------------------------------------------------------------------------------
/mcp_server/tests/test_mcp_integration.py:
--------------------------------------------------------------------------------

```python
#!/usr/bin/env python3
"""
Integration test for the refactored Graphiti MCP Server using the official MCP Python SDK.
Tests all major MCP tools and handles episode processing latency.
"""

import asyncio
import json
import os
import time
from typing import Any

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client


class GraphitiMCPIntegrationTest:
    """Integration test client for Graphiti MCP Server using official MCP SDK."""

    def __init__(self):
        self.test_group_id = f'test_group_{int(time.time())}'
        self.session = None

    async def __aenter__(self):
        """Start the MCP client session."""
        # Configure server parameters to run our refactored server
        server_params = StdioServerParameters(
            command='uv',
            args=['run', 'main.py', '--transport', 'stdio'],
            env={
                'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
                'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
                'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
                'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY', 'dummy_key_for_testing'),
            },
        )

        print(f'🚀 Starting MCP client session with test group: {self.test_group_id}')

        # Use the async context manager properly
        self.client_context = stdio_client(server_params)
        read, write = await self.client_context.__aenter__()
        self.session = ClientSession(read, write)
        await self.session.initialize()

        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the MCP client session."""
        if self.session:
            await self.session.close()
        if hasattr(self, 'client_context'):
            await self.client_context.__aexit__(exc_type, exc_val, exc_tb)

    async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Call an MCP tool and return the result."""
        try:
            result = await self.session.call_tool(tool_name, arguments)
            return result.content[0].text if result.content else {'error': 'No content returned'}
        except Exception as e:
            return {'error': str(e)}

    async def test_server_initialization(self) -> bool:
        """Test that the server initializes properly."""
        print('🔍 Testing server initialization...')

        try:
            # List available tools to verify server is responding
            tools_result = await self.session.list_tools()
            tools = [tool.name for tool in tools_result.tools]

            expected_tools = [
                'add_memory',
                'search_memory_nodes',
                'search_memory_facts',
                'get_episodes',
                'delete_episode',
                'delete_entity_edge',
                'get_entity_edge',
                'clear_graph',
            ]

            available_tools = len([tool for tool in expected_tools if tool in tools])
            print(
                f'   ✅ Server responding with {len(tools)} tools ({available_tools}/{len(expected_tools)} expected)'
            )
            print(f'   Available tools: {", ".join(sorted(tools))}')

            return available_tools >= len(expected_tools) * 0.8  # 80% of expected tools

        except Exception as e:
            print(f'   ❌ Server initialization failed: {e}')
            return False

    async def test_add_memory_operations(self) -> dict[str, bool]:
        """Test adding various types of memory episodes."""
        print('📝 Testing add_memory operations...')

        results = {}

        # Test 1: Add text episode
        print('   Testing text episode...')
        try:
            result = await self.call_tool(
                'add_memory',
                {
                    'name': 'Test Company News',
                    'episode_body': 'Acme Corp announced a revolutionary new AI product that will transform the industry. The CEO mentioned this is their biggest launch since 2020.',
                    'source': 'text',
                    'source_description': 'news article',
                    'group_id': self.test_group_id,
                },
            )

            if isinstance(result, str) and 'queued' in result.lower():
                print(f'   ✅ Text episode: {result}')
                results['text'] = True
            else:
                print(f'   ❌ Text episode failed: {result}')
                results['text'] = False
        except Exception as e:
            print(f'   ❌ Text episode error: {e}')
            results['text'] = False

        # Test 2: Add JSON episode
        print('   Testing JSON episode...')
        try:
            json_data = {
                'company': {'name': 'TechCorp', 'founded': 2010},
                'products': [
                    {'id': 'P001', 'name': 'CloudSync', 'category': 'software'},
                    {'id': 'P002', 'name': 'DataMiner', 'category': 'analytics'},
                ],
                'employees': 150,
            }

            result = await self.call_tool(
                'add_memory',
                {
                    'name': 'Company Profile',
                    'episode_body': json.dumps(json_data),
                    'source': 'json',
                    'source_description': 'CRM data',
                    'group_id': self.test_group_id,
                },
            )

            if isinstance(result, str) and 'queued' in result.lower():
                print(f'   ✅ JSON episode: {result}')
                results['json'] = True
            else:
                print(f'   ❌ JSON episode failed: {result}')
                results['json'] = False
        except Exception as e:
            print(f'   ❌ JSON episode error: {e}')
            results['json'] = False

        # Test 3: Add message episode
        print('   Testing message episode...')
        try:
            result = await self.call_tool(
                'add_memory',
                {
                    'name': 'Customer Support Chat',
                    'episode_body': "user: What's your return policy?\nassistant: You can return items within 30 days of purchase with receipt.\nuser: Thanks!",
                    'source': 'message',
                    'source_description': 'support chat log',
                    'group_id': self.test_group_id,
                },
            )

            if isinstance(result, str) and 'queued' in result.lower():
                print(f'   ✅ Message episode: {result}')
                results['message'] = True
            else:
                print(f'   ❌ Message episode failed: {result}')
                results['message'] = False
        except Exception as e:
            print(f'   ❌ Message episode error: {e}')
            results['message'] = False

        return results

    async def wait_for_processing(self, max_wait: int = 45) -> bool:
        """Wait for episode processing to complete."""
        print(f'⏳ Waiting up to {max_wait} seconds for episode processing...')

        for i in range(max_wait):
            await asyncio.sleep(1)

            try:
                # Check if we have any episodes
                result = await self.call_tool(
                    'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
                )

                # Parse the JSON result if it's a string
                if isinstance(result, str):
                    try:
                        parsed_result = json.loads(result)
                        if isinstance(parsed_result, list) and len(parsed_result) > 0:
                            print(
                                f'   ✅ Found {len(parsed_result)} processed episodes after {i + 1} seconds'
                            )
                            return True
                    except json.JSONDecodeError:
                        if 'episodes' in result.lower():
                            print(f'   ✅ Episodes detected after {i + 1} seconds')
                            return True

            except Exception as e:
                if i == 0:  # Only log first error to avoid spam
                    print(f'   ⚠️  Waiting for processing... ({e})')
                continue

        print(f'   ⚠️  Still waiting after {max_wait} seconds...')
        return False

    async def test_search_operations(self) -> dict[str, bool]:
        """Exercise the node-search and fact-search tools against the test group.

        Returns:
            dict[str, bool]: ``'nodes'`` and ``'facts'`` entries, each True when the
            corresponding tool call produced an acceptable response.
        """
        print('🔍 Testing search operations...')

        async def probe(tool: str, query: str, limit_arg: str, kind: str, label: str) -> bool:
            # Runs one search tool and reports whether its response looks usable.
            # ``kind`` is both the key expected in a JSON response and the word
            # looked for in a plain-text response.
            print(f'   Testing {tool}...')
            try:
                response = await self.call_tool(
                    tool,
                    {'query': query, 'group_ids': [self.test_group_id], limit_arg: 5},
                )

                ok = False
                if isinstance(response, str):
                    try:
                        payload = json.loads(response)
                    except json.JSONDecodeError:
                        # Not JSON: accept a plain-text message that mentions both
                        # the item kind and "successfully".
                        lowered = response.lower()
                        ok = kind in lowered and 'successfully' in lowered
                        if ok:
                            print(f'   ✅ {label} search completed successfully')
                    else:
                        items = payload.get(kind, [])
                        ok = isinstance(items, list)
                        print(f'   ✅ {label} search returned {len(items)} {kind}')

                if not ok:
                    print(f'   ❌ {label} search failed: {response}')
                return ok

            except Exception as exc:
                print(f'   ❌ {label} search error: {exc}')
                return False

        outcome: dict[str, bool] = {}

        # Node search first, then fact search (the result keys keep this order).
        outcome['nodes'] = await probe(
            'search_memory_nodes', 'Acme Corp product launch AI', 'max_nodes', 'nodes', 'Node'
        )
        outcome['facts'] = await probe(
            'search_memory_facts', 'company products software TechCorp', 'max_facts', 'facts', 'Fact'
        )

        return outcome

    async def test_episode_retrieval(self) -> bool:
        """Fetch recent episodes for the test group via the ``get_episodes`` tool.

        Returns:
            bool: True when a non-empty JSON list of episodes comes back, or when a
            non-JSON text response mentions "episode"; False otherwise (including
            when the call raises).
        """
        print('📚 Testing episode retrieval...')

        try:
            response = await self.call_tool(
                'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
            )

            if isinstance(response, str):
                try:
                    episodes = json.loads(response)
                except json.JSONDecodeError:
                    # Plain-text reply: treat any mention of "episode" as success.
                    if 'episode' in response.lower():
                        print('   ✅ Episode retrieval completed')
                        return True
                else:
                    if isinstance(episodes, list):
                        print(f'   ✅ Retrieved {len(episodes)} episodes')

                        # Preview at most the first three episodes.
                        for position, item in enumerate(episodes[:3], start=1):
                            title = item.get('name', 'Unknown')
                            origin = item.get('source', 'unknown')
                            print(f'     Episode {position}: {title} (source: {origin})')

                        return bool(episodes)

            # Anything else (non-string, JSON that is not a list, unrelated text).
            print(f'   ❌ Unexpected result format: {response}')
            return False

        except Exception as exc:
            print(f'   ❌ Episode retrieval failed: {exc}')
            return False

    async def test_error_handling(self) -> dict[str, bool]:
        """Probe the node-search tool with edge-case inputs.

        Two cases are run: a search in a group id that is not the test group, and a
        search with an empty query. A case passes unless the stringified response
        contains both "error" and "not initialized" (case-insensitive), or the call
        raises.

        Returns:
            dict[str, bool]: ``'nonexistent_group'`` and ``'empty_query'`` outcomes.
        """
        print('🧪 Testing error handling...')

        # (result key, label used in messages, query text, group id or None for the test group)
        scenarios = (
            ('nonexistent_group', 'Nonexistent group', 'nonexistent data', 'nonexistent_group_12345'),
            ('empty_query', 'Empty query', '', None),
        )

        outcome: dict[str, bool] = {}
        for key, label, query, group in scenarios:
            print(f'   Testing {label.lower()} handling...')
            try:
                group_id = self.test_group_id if group is None else group
                response = await self.call_tool(
                    'search_memory_nodes',
                    {'query': query, 'group_ids': [group_id], 'max_nodes': 5},
                )

                # Should handle gracefully, not crash
                text = str(response).lower()
                handled = not ('error' in text and 'not initialized' in text)
                if handled:
                    print(f'   ✅ {label} handled gracefully')
                else:
                    print(f'   ❌ {label} caused issues: {response}')

                outcome[key] = handled

            except Exception as exc:
                print(f'   ❌ {label} test failed: {exc}')
                outcome[key] = False

        return outcome

    async def run_comprehensive_test(self) -> dict[str, Any]:
        """Run every stage of the integration suite and print a summary table.

        Stages run in a fixed order: server initialization, add-memory operations,
        processing wait, search, episode retrieval and error handling. When server
        initialization fails, the remaining stages are skipped and the partially
        filled result is returned.

        Returns:
            dict[str, Any]: Per-stage outcomes plus an ``'overall_success'`` entry.
        """
        print('🚀 Starting Comprehensive Graphiti MCP Server Integration Test')
        print(f'   Test group ID: {self.test_group_id}')
        print('=' * 70)

        report: dict[str, Any] = {
            'server_init': False,
            'add_memory': {},
            'processing_wait': False,
            'search': {},
            'episodes': False,
            'error_handling': {},
            'overall_success': False,
        }

        def verdict(flag: Any) -> str:
            # Pass/fail badge used on every summary row.
            return '✅ PASS' if flag else '❌ FAIL'

        def tally(outcomes: Any, unit: str) -> str:
            # "(passed/total unit)" counter for a dict of boolean outcomes.
            if not outcomes:
                return f'(0/0 {unit})'
            return f'({sum(outcomes.values())}/{len(outcomes)} {unit})'

        # Stage 1: nothing else can run without a working server.
        report['server_init'] = await self.test_server_initialization()
        if not report['server_init']:
            print('❌ Server initialization failed, aborting remaining tests')
            return report

        print()

        # Stages 2-6, each followed by a blank separator line.
        stages = (
            ('add_memory', self.test_add_memory_operations),
            ('processing_wait', self.wait_for_processing),
            ('search', self.test_search_operations),
            ('episodes', self.test_episode_retrieval),
            ('error_handling', self.test_error_handling),
        )
        for key, stage in stages:
            report[key] = await stage()
            print()

        added = report['add_memory']
        searched = report['search']
        handled = report['error_handling']

        # A category passes when at least one of its checks passed; an empty
        # error-handling result counts as a pass, an empty search result as a fail.
        memory_ok = any(added.values())
        search_ok = any(searched.values()) if searched else False
        errors_ok = any(handled.values()) if handled else True

        # Search results are reported below but do not feed the overall verdict.
        report['overall_success'] = (
            report['server_init']
            and memory_ok
            and (report['episodes'] or report['processing_wait'])
            and errors_ok
        )

        # Summary table
        print('=' * 70)
        print('📊 COMPREHENSIVE TEST SUMMARY')
        print('-' * 35)
        print(f'Server Initialization:    {verdict(report["server_init"])}')
        print(f'Memory Operations:        {verdict(memory_ok)} {tally(added, "types")}')
        print(f'Processing Pipeline:      {verdict(report["processing_wait"])}')
        print(f'Search Operations:        {verdict(search_ok)} {tally(searched, "types")}')
        print(f'Episode Retrieval:        {verdict(report["episodes"])}')
        print(f'Error Handling:           {verdict(errors_ok)} {tally(handled, "cases")}')
        print('-' * 35)
        print(f'🎯 OVERALL RESULT: {"✅ SUCCESS" if report["overall_success"] else "❌ FAILED"}')

        if report['overall_success']:
            closing = (
                '\n🎉 The refactored Graphiti MCP server is working correctly!',
                '   All core functionality has been successfully tested.',
            )
        else:
            closing = (
                '\n⚠️  Some issues were detected. Review the test results above.',
                '   The refactoring may need additional attention.',
            )
        for line in closing:
            print(line)

        return report


async def main() -> None:
    """Run the integration test suite and terminate with a matching exit status.

    Exits with status 0 when the suite reports overall success and with status 1
    when it reports failure or when setting up / running the suite raises.

    Raises:
        SystemExit: Always; the exit code carries the result.
    """
    # Imported locally so this function does not rely on the module's import block.
    # ``sys.exit`` replaces the ``exit()`` helper, which is injected by the ``site``
    # module for interactive use (it also closes stdin and is missing under -S).
    import sys

    try:
        async with GraphitiMCPIntegrationTest() as test:
            results = await test.run_comprehensive_test()

            # Exit inside the context manager; SystemExit is not an ``Exception``
            # subclass, so it passes through the handler below untouched.
            sys.exit(0 if results['overall_success'] else 1)
    except Exception as e:
        print(f'❌ Test setup failed: {e}')
        sys.exit(1)


# Script entry point: drive the async main() coroutine to completion.
if __name__ == '__main__':
    asyncio.run(main())

```

--------------------------------------------------------------------------------
/graphiti_core/llm_client/gemini_client.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import logging
import re
import typing
from typing import TYPE_CHECKING, ClassVar

from pydantic import BaseModel

from ..prompts.models import Message
from .client import LLMClient, get_extraction_language_instruction
from .config import LLMConfig, ModelSize
from .errors import RateLimitError

if TYPE_CHECKING:
    from google import genai
    from google.genai import types
else:
    try:
        from google import genai
        from google.genai import types
    except ImportError:
        # If gemini client is not installed, raise an ImportError
        raise ImportError(
            'google-genai is required for GeminiClient. '
            'Install it with: pip install graphiti-core[google-genai]'
        ) from None


logger = logging.getLogger(__name__)

# Fallback model names used by GeminiClient._get_model_for_size when the client
# has no model / small_model set.
DEFAULT_MODEL = 'gemini-2.5-flash'
DEFAULT_SMALL_MODEL = 'gemini-2.5-flash-lite'

# Maximum output tokens for different Gemini models.
# Looked up by exact model name in GeminiClient._get_max_tokens_for_model.
GEMINI_MODEL_MAX_TOKENS = {
    # Gemini 2.5 models
    'gemini-2.5-pro': 65536,
    'gemini-2.5-flash': 65536,
    'gemini-2.5-flash-lite': 64000,
    # Gemini 2.0 models
    'gemini-2.0-flash': 8192,
    'gemini-2.0-flash-lite': 8192,
    # Gemini 1.5 models
    'gemini-1.5-pro': 8192,
    'gemini-1.5-flash': 8192,
    'gemini-1.5-flash-8b': 8192,
}

# Default max tokens for models not in the mapping
DEFAULT_GEMINI_MAX_TOKENS = 8192


class GeminiClient(LLMClient):
    """
    GeminiClient is a client class for interacting with Google's Gemini language models.

    This class extends the LLMClient and provides methods to initialize the client
    and generate responses from the Gemini language model.

    Attributes:
        model (str): The model name to use for generating responses.
        temperature (float): The temperature to use for generating responses.
        max_tokens (int | None): Instance-level cap on output tokens; None defers to the per-model limit.
        thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
        client (genai.Client): The google-genai client used to issue requests.
    Methods:
        __init__(config: LLMConfig | None = None, cache: bool = False, max_tokens: int | None = None,
                 thinking_config: types.ThinkingConfig | None = None, client: genai.Client | None = None):
            Initializes the GeminiClient with the provided configuration, cache setting, token cap,
            optional thinking config and optional pre-built client.

        _generate_response(messages: list[Message]) -> dict[str, typing.Any]:
            Generates a response from the language model based on the provided messages.

        generate_response(messages: list[Message], ...) -> dict[str, typing.Any]:
            Wraps _generate_response with tracing and retry handling.
    """

    # Class-level constants
    # Total number of attempts generate_response() makes before giving up.
    MAX_RETRIES: ClassVar[int] = 2

    def __init__(
        self,
        config: LLMConfig | None = None,
        cache: bool = False,
        max_tokens: int | None = None,
        thinking_config: types.ThinkingConfig | None = None,
        client: 'genai.Client | None' = None,
    ):
        """
        Initialize the GeminiClient with the provided configuration, cache setting, and optional thinking config.

        Args:
            config (LLMConfig | None): The configuration for the LLM client, including API key, model, temperature, and max tokens.
            cache (bool): Whether to use caching for responses. Defaults to False.
            max_tokens (int | None): Instance-level maximum output tokens. When None, the limit is resolved
                per request (see _resolve_max_tokens).
            thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
                Only use with models that support thinking (gemini-2.5+). Defaults to None.
            client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
        """
        if config is None:
            config = LLMConfig()

        super().__init__(config, cache)

        self.model = config.model

        if client is None:
            self.client = genai.Client(api_key=config.api_key)
        else:
            self.client = client

        # Deliberately overrides whatever the base class stored: None means
        # "use the model-specific limit" in _resolve_max_tokens.
        self.max_tokens = max_tokens
        self.thinking_config = thinking_config

    def _check_safety_blocks(self, response) -> None:
        """
        Raise if the first candidate of the response finished with reason 'SAFETY'.

        Args:
            response: The object returned by the Gemini generate_content call.

        Raises:
            Exception: When the first candidate was blocked; the message lists the
                category and probability of every rating flagged as blocked.
        """
        # Check if the response was blocked for safety reasons
        if not (hasattr(response, 'candidates') and response.candidates):
            return

        candidate = response.candidates[0]
        if not (hasattr(candidate, 'finish_reason') and candidate.finish_reason == 'SAFETY'):
            return

        # Content was blocked for safety reasons - collect safety details
        safety_info = []
        safety_ratings = getattr(candidate, 'safety_ratings', None)

        if safety_ratings:
            for rating in safety_ratings:
                if getattr(rating, 'blocked', False):
                    category = getattr(rating, 'category', 'Unknown')
                    probability = getattr(rating, 'probability', 'Unknown')
                    safety_info.append(f'{category}: {probability}')

        safety_details = (
            ', '.join(safety_info) if safety_info else 'Content blocked for safety reasons'
        )
        raise Exception(f'Response blocked by Gemini safety filters: {safety_details}')

    def _check_prompt_blocks(self, response) -> None:
        """
        Raise if the response's prompt feedback carries a block reason.

        Args:
            response: The object returned by the Gemini generate_content call.

        Raises:
            Exception: When ``response.prompt_feedback.block_reason`` is set.
        """
        prompt_feedback = getattr(response, 'prompt_feedback', None)
        if not prompt_feedback:
            return

        block_reason = getattr(prompt_feedback, 'block_reason', None)
        if block_reason:
            raise Exception(f'Prompt blocked by Gemini: {block_reason}')

    def _get_model_for_size(self, model_size: ModelSize) -> str:
        """Get the appropriate model name based on the requested size.

        ModelSize.small maps to the configured small model (or DEFAULT_SMALL_MODEL);
        every other size maps to the configured model (or DEFAULT_MODEL).
        """
        if model_size == ModelSize.small:
            return self.small_model or DEFAULT_SMALL_MODEL
        else:
            return self.model or DEFAULT_MODEL

    def _get_max_tokens_for_model(self, model: str) -> int:
        """Get the maximum output tokens for a specific Gemini model.

        Falls back to DEFAULT_GEMINI_MAX_TOKENS when the exact model name is not in
        GEMINI_MODEL_MAX_TOKENS.
        """
        return GEMINI_MODEL_MAX_TOKENS.get(model, DEFAULT_GEMINI_MAX_TOKENS)

    def _resolve_max_tokens(self, requested_max_tokens: int | None, model: str) -> int:
        """
        Resolve the maximum output tokens to use based on precedence rules.

        Precedence order (highest to lowest):
        1. Explicit max_tokens parameter passed to generate_response()
        2. Instance max_tokens set during client initialization
        3. Model-specific maximum tokens from GEMINI_MODEL_MAX_TOKENS mapping
        4. DEFAULT_GEMINI_MAX_TOKENS as final fallback

        Args:
            requested_max_tokens: The max_tokens parameter passed to generate_response()
            model: The model name to look up model-specific limits

        Returns:
            int: The resolved maximum tokens to use
        """
        # 1. Use explicit parameter if provided
        if requested_max_tokens is not None:
            return requested_max_tokens

        # 2. Use instance max_tokens if set during initialization
        if self.max_tokens is not None:
            return self.max_tokens

        # 3. Use model-specific maximum or return DEFAULT_GEMINI_MAX_TOKENS
        return self._get_max_tokens_for_model(model)

    def salvage_json(self, raw_output: str) -> dict[str, typing.Any] | None:
        """
        Attempt to recover a JSON value from raw model output.

        Parsing is attempted only when the output ends with a closing ``]`` or ``}``
        (optionally followed by whitespace); in that case the whole string is handed
        to ``json.loads``. Output that does not end with a closing bracket, or that
        fails to parse, yields None.

        Args:
            raw_output (str): The raw output from the LLM.

        Returns:
            The parsed JSON value (whatever ``json.loads`` produced; for output ending
            in ``]`` this is typically a list rather than a dict), or None if no
            salvage is possible.
        """
        if not raw_output:
            return None
        # Try to salvage a JSON array
        array_match = re.search(r'\]\s*$', raw_output)
        if array_match:
            try:
                return json.loads(raw_output[: array_match.end()])
            except Exception:
                # Best effort: fall through to the object attempt / None.
                pass
        # Try to salvage a JSON object
        obj_match = re.search(r'\}\s*$', raw_output)
        if obj_match:
            try:
                return json.loads(raw_output[: obj_match.end()])
            except Exception:
                # Best effort: nothing salvageable.
                pass
        return None

    async def _generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int | None = None,
        model_size: ModelSize = ModelSize.medium,
    ) -> dict[str, typing.Any]:
        """
        Generate a response from the Gemini language model.

        Note that the content of each non-system message is replaced in place with
        its cleaned form before being sent.

        Args:
            messages (list[Message]): A list of messages to send to the language model.
            response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
            max_tokens (int | None): The maximum number of tokens to generate in the response. If None, uses precedence rules.
            model_size (ModelSize): The size of the model to use (small or medium).

        Returns:
            dict[str, typing.Any]: The validated model dumped to a dict when ``response_model`` is given
            (or salvaged JSON if validation failed), otherwise ``{'content': <response text>}``.

        Raises:
            RateLimitError: If the API rate limit is exceeded.
            Exception: If there is an error generating the response or content is blocked.
                The original error is attached as ``__cause__`` and its text is kept as the message.
        """
        try:
            gemini_messages: typing.Any = []
            # If a response model is provided, add schema for structured output
            system_prompt = ''
            if response_model is not None:
                # Get the schema from the Pydantic model
                pydantic_schema = response_model.model_json_schema()

                # Create instruction to output in the desired JSON format
                system_prompt += (
                    f'Output ONLY valid JSON matching this schema: {json.dumps(pydantic_schema)}.\n'
                    'Do not include any explanatory text before or after the JSON.\n\n'
                )

            # Add messages content
            # First check for a system message; it is sent as the system instruction
            # rather than as a conversation turn.
            if messages and messages[0].role == 'system':
                system_prompt = f'{messages[0].content}\n\n {system_prompt}'
                messages = messages[1:]

            # Add the rest of the messages
            for m in messages:
                m.content = self._clean_input(m.content)
                gemini_messages.append(
                    types.Content(role=m.role, parts=[types.Part.from_text(text=m.content)])
                )

            # Get the appropriate model for the requested size
            model = self._get_model_for_size(model_size)

            # Resolve max_tokens using precedence rules (see _resolve_max_tokens for details)
            resolved_max_tokens = self._resolve_max_tokens(max_tokens, model)

            # Create generation config
            generation_config = types.GenerateContentConfig(
                temperature=self.temperature,
                max_output_tokens=resolved_max_tokens,
                response_mime_type='application/json' if response_model else None,
                response_schema=response_model if response_model else None,
                system_instruction=system_prompt,
                thinking_config=self.thinking_config,
            )

            # Generate content
            response = await self.client.aio.models.generate_content(
                model=model,
                contents=gemini_messages,
                config=generation_config,
            )

            # Always capture the raw output for debugging
            raw_output = getattr(response, 'text', None)

            # Check for safety and prompt blocks
            self._check_safety_blocks(response)
            self._check_prompt_blocks(response)

            # If this was a structured output request, parse the response into the Pydantic model
            if response_model is not None:
                try:
                    if not raw_output:
                        raise ValueError('No response text')

                    validated_model = response_model.model_validate(json.loads(raw_output))

                    # Return as a dictionary for API consistency
                    return validated_model.model_dump()
                except Exception as e:
                    if raw_output:
                        logger.error(
                            '🦀 LLM generation failed parsing as JSON, will try to salvage.'
                        )
                        logger.error(self._get_failed_generation_log(gemini_messages, raw_output))
                        # Try to salvage
                        salvaged = self.salvage_json(raw_output)
                        if salvaged is not None:
                            logger.warning('Salvaged partial JSON from truncated/malformed output.')
                            return salvaged
                    raise Exception(f'Failed to parse structured response: {e}') from e

            # Otherwise, return the response text as a dictionary
            return {'content': raw_output}

        except Exception as e:
            # Check if it's a rate limit error based on Gemini API error codes
            error_message = str(e).lower()
            if (
                'rate limit' in error_message
                or 'quota' in error_message
                or 'resource_exhausted' in error_message
                or '429' in str(e)
            ):
                raise RateLimitError from e

            logger.error(f'Error in generating LLM response: {e}')
            # Keep the original text on the wrapper: a bare ``raise Exception`` would
            # produce an empty message, leaving retry prompts and logs without details.
            raise Exception(str(e)) from e

    async def generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int | None = None,
        model_size: ModelSize = ModelSize.medium,
        group_id: str | None = None,
        prompt_name: str | None = None,
    ) -> dict[str, typing.Any]:
        """
        Generate a response from the Gemini language model with retry logic and error handling.
        This method overrides the parent class method to provide a direct implementation with advanced retry logic.

        The ``messages`` list is modified in place: the language instruction is appended to the
        first message, and each failed attempt appends a user message describing the error.

        Args:
            messages (list[Message]): A list of messages to send to the language model.
            response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
            max_tokens (int | None): The maximum number of tokens to generate in the response.
            model_size (ModelSize): The size of the model to use (small or medium).
            group_id (str | None): Optional partition identifier for the graph.
            prompt_name (str | None): Optional name of the prompt for tracing.

        Returns:
            dict[str, typing.Any]: The response from the language model.

        Raises:
            RateLimitError: Immediately, without retrying, when the API reports a rate limit.
            Exception: When content is blocked (not retried) or when MAX_RETRIES attempts have failed.
        """
        # Add multilingual extraction instructions
        messages[0].content += get_extraction_language_instruction(group_id)

        # Wrap entire operation in tracing span
        with self.tracer.start_span('llm.generate') as span:
            attributes = {
                'llm.provider': 'gemini',
                'model.size': model_size.value,
                'max_tokens': max_tokens or self.max_tokens,
            }
            if prompt_name:
                attributes['prompt.name'] = prompt_name
            span.add_attributes(attributes)

            retry_count = 0
            last_error = None

            while retry_count < self.MAX_RETRIES:
                try:
                    return await self._generate_response(
                        messages=messages,
                        response_model=response_model,
                        max_tokens=max_tokens,
                        model_size=model_size,
                    )
                except RateLimitError as e:
                    # Rate limit errors should not trigger retries (fail fast)
                    span.set_status('error', str(e))
                    raise
                except Exception as e:
                    last_error = e

                    # Check if this is a safety block - these typically shouldn't be retried.
                    # Fall back to the cause's text in case the wrapper has no message.
                    error_text = str(e) or (str(e.__cause__) if e.__cause__ else '')
                    if 'safety' in error_text.lower() or 'blocked' in error_text.lower():
                        logger.warning(f'Content blocked by safety filters: {e}')
                        span.set_status('error', str(e))
                        raise Exception(f'Content blocked by safety filters: {e}') from e

                    retry_count += 1

                    # Construct a detailed error message for the LLM
                    error_context = (
                        f'The previous response attempt was invalid. '
                        f'Error type: {e.__class__.__name__}. '
                        f'Error details: {str(e)}. '
                        f'Please try again with a valid response, ensuring the output matches '
                        f'the expected format and constraints.'
                    )

                    error_message = Message(role='user', content=error_context)
                    messages.append(error_message)
                    logger.warning(
                        f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
                    )

            # If we exit the loop without returning, all retries are exhausted
            logger.error('🦀 LLM generation failed and retries are exhausted.')
            # Failed attempts surface only as exceptions, so there is no raw output to log.
            logger.error(self._get_failed_generation_log(messages, None))
            logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {last_error}')
            span.set_status('error', str(last_error))
            if last_error:
                span.record_exception(last_error)
            raise last_error or Exception('Max retries exceeded')

```

--------------------------------------------------------------------------------
/tests/utils/maintenance/test_edge_operations.py:
--------------------------------------------------------------------------------

```python
from datetime import datetime, timedelta, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

import pytest
from pydantic import BaseModel

from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.search.search_config import SearchResults
from graphiti_core.utils.maintenance.edge_operations import (
    DEFAULT_EDGE_NAME,
    resolve_extracted_edge,
    resolve_extracted_edges,
)


@pytest.fixture
def mock_llm_client():
    """Provide a stand-in LLM client whose ``generate_response`` is an awaitable mock."""
    return MagicMock(generate_response=AsyncMock())


@pytest.fixture
def mock_extracted_edge():
    """Provide a freshly created edge with no validity window, tied to 'episode_1'."""
    fields = {
        'name': 'test_edge',
        'fact': 'Test fact',
        'group_id': 'group_1',
        'source_node_uuid': 'source_uuid',
        'target_node_uuid': 'target_uuid',
        'episodes': ['episode_1'],
        'created_at': datetime.now(timezone.utc),
        'valid_at': None,
        'invalid_at': None,
    }
    return EntityEdge(**fields)


@pytest.fixture
def mock_related_edges():
    """Provide a one-element list holding an edge created and valid as of about a day ago."""
    one_day = timedelta(days=1)
    related = EntityEdge(
        name='related_edge',
        fact='Related fact',
        group_id='group_1',
        source_node_uuid='source_uuid_2',
        target_node_uuid='target_uuid_2',
        episodes=['episode_2'],
        created_at=datetime.now(timezone.utc) - one_day,
        valid_at=datetime.now(timezone.utc) - one_day,
        invalid_at=None,
    )
    return [related]


@pytest.fixture
def mock_existing_edges():
    """Provide a one-element list holding an edge created and valid as of about two days ago."""
    two_days = timedelta(days=2)
    existing = EntityEdge(
        name='existing_edge',
        fact='Existing fact',
        group_id='group_1',
        source_node_uuid='source_uuid_3',
        target_node_uuid='target_uuid_3',
        episodes=['episode_3'],
        created_at=datetime.now(timezone.utc) - two_days,
        valid_at=datetime.now(timezone.utc) - two_days,
        invalid_at=None,
    )
    return [existing]


@pytest.fixture
def mock_current_episode():
    """Provide the episode under test ('episode_1'), valid as of now."""
    fields = {
        'uuid': 'episode_1',
        'name': 'Current Episode',
        'group_id': 'group_1',
        'content': 'Current episode content',
        'valid_at': datetime.now(timezone.utc),
        'source': 'message',
        'source_description': 'Test source description',
    }
    return EpisodicNode(**fields)


@pytest.fixture
def mock_previous_episodes():
    """Provide a one-element list holding an earlier episode ('episode_2') from a day ago."""
    earlier = EpisodicNode(
        uuid='episode_2',
        name='Previous Episode',
        group_id='group_1',
        content='Previous episode content',
        valid_at=datetime.now(timezone.utc) - timedelta(days=1),
        source='message',
        source_description='Test source description',
    )
    return [earlier]


# Run the tests in this file with pytest when it is executed directly as a script.
if __name__ == '__main__':
    pytest.main([__file__])


@pytest.mark.asyncio
async def test_resolve_extracted_edge_exact_fact_short_circuit(
    mock_llm_client,
    mock_existing_edges,
    mock_current_episode,
):
    """Reuse a related edge whose fact differs only in letter case and surrounding whitespace.

    Expects the related edge object itself to be returned, with the current episode
    recorded on it exactly once, no duplicates or invalidations reported, and no
    LLM call made.
    """
    # Both edges connect the same pair of nodes within the same group.
    shared = {
        'source_node_uuid': 'source_uuid',
        'target_node_uuid': 'target_uuid',
        'group_id': 'group_1',
        'valid_at': None,
        'invalid_at': None,
    }

    new_edge = EntityEdge(
        name='test_edge',
        fact='Related fact',
        episodes=['episode_1'],
        created_at=datetime.now(timezone.utc),
        **shared,
    )
    prior_edge = EntityEdge(
        name='related_edge',
        fact=' related FACT  ',
        episodes=['episode_2'],
        created_at=datetime.now(timezone.utc) - timedelta(days=1),
        **shared,
    )
    candidates = [prior_edge]

    outcome = await resolve_extracted_edge(
        mock_llm_client,
        new_edge,
        candidates,
        mock_existing_edges,
        mock_current_episode,
        edge_type_candidates=None,
    )
    resolved, duplicates, invalidated = outcome

    assert resolved is candidates[0]
    assert resolved.episodes.count(mock_current_episode.uuid) == 1
    assert duplicates == []
    assert invalidated == []
    mock_llm_client.generate_response.assert_not_called()


# Field-less pydantic model used as the 'OCCURRED_AT' entry of the edge_types
# mapping in the test below. The docstring is left as-is on purpose: pydantic may
# surface it in the model's schema.
class OccurredAtEdge(BaseModel):
    """Edge model stub for OCCURRED_AT."""


@pytest.mark.asyncio
async def test_resolve_extracted_edges_resets_unmapped_names(monkeypatch):
    """A custom edge name with no mapping for the endpoints' labels falls back to the default.

    'OCCURRED_AT' is mapped only for ('Event', 'Entity'), while the nodes used here are labelled
    'Document' and 'Topic', so the resolved edge is expected to carry DEFAULT_EDGE_NAME.
    """
    from graphiti_core.utils.maintenance import edge_operations as edge_ops

    async def run_in_order(*awaitables, max_coroutines=None):
        # Sequential stand-in for semaphore_gather
        finished = []
        for awaitable in awaitables:
            finished.append(await awaitable)
        return finished

    # Replace embedding creation, graph lookups, concurrency and search with inert doubles
    monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
    monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
    monkeypatch.setattr(edge_ops, 'semaphore_gather', run_in_order)
    monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))

    llm_reply = {
        'duplicate_facts': [],
        'contradicted_facts': [],
        'fact_type': 'DEFAULT',
    }
    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(return_value=llm_reply)

    clients = SimpleNamespace(
        driver=MagicMock(),
        llm_client=llm_client,
        embedder=MagicMock(),
        cross_encoder=MagicMock(),
    )

    document_node = EntityNode(
        uuid='source_uuid',
        name='Document Node',
        group_id='group_1',
        labels=['Document'],
    )
    topic_node = EntityNode(
        uuid='target_uuid',
        name='Topic Node',
        group_id='group_1',
        labels=['Topic'],
    )

    candidate_edge = EntityEdge(
        source_node_uuid=document_node.uuid,
        target_node_uuid=topic_node.uuid,
        name='OCCURRED_AT',
        group_id='group_1',
        fact='Document occurred at somewhere',
        episodes=[],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    current_episode = EpisodicNode(
        uuid='episode_uuid',
        name='Episode',
        group_id='group_1',
        source='message',
        source_description='desc',
        content='Episode content',
        valid_at=datetime.now(timezone.utc),
    )

    registered_types = {'OCCURRED_AT': OccurredAtEdge}
    label_pair_map = {('Event', 'Entity'): ['OCCURRED_AT']}

    resolved_edges, invalidated_edges = await resolve_extracted_edges(
        clients,
        [candidate_edge],
        current_episode,
        [document_node, topic_node],
        registered_types,
        label_pair_map,
    )

    first_resolved = resolved_edges[0]
    assert first_resolved.name == DEFAULT_EDGE_NAME
    assert invalidated_edges == []


@pytest.mark.asyncio
async def test_resolve_extracted_edges_keeps_unknown_names(monkeypatch):
    """An edge name that is not one of the registered custom edge types is left unchanged.

    Only 'OCCURRED_AT' is registered, so the extracted 'INTERACTED_WITH' edge is expected to
    come back with its original name.
    """
    from graphiti_core.utils.maintenance import edge_operations as edge_ops

    async def gather_one_by_one(*pending, max_coroutines=None):
        # Await each item in turn instead of scheduling them concurrently
        collected = []
        for item in pending:
            collected.append(await item)
        return collected

    embeddings_stub = AsyncMock(return_value=None)
    between_nodes_stub = AsyncMock(return_value=[])
    search_stub = AsyncMock(return_value=SearchResults())

    monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', embeddings_stub)
    monkeypatch.setattr(EntityEdge, 'get_between_nodes', between_nodes_stub)
    monkeypatch.setattr(edge_ops, 'semaphore_gather', gather_one_by_one)
    monkeypatch.setattr(edge_ops, 'search', search_stub)

    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(
        return_value={
            'duplicate_facts': [],
            'contradicted_facts': [],
            'fact_type': 'DEFAULT',
        }
    )

    clients = SimpleNamespace(
        driver=MagicMock(),
        llm_client=llm_client,
        embedder=MagicMock(),
        cross_encoder=MagicMock(),
    )

    user_node = EntityNode(
        uuid='source_uuid',
        name='User Node',
        group_id='group_1',
        labels=['User'],
    )
    topic_node = EntityNode(
        uuid='target_uuid',
        name='Topic Node',
        group_id='group_1',
        labels=['Topic'],
    )

    interaction_edge = EntityEdge(
        source_node_uuid=user_node.uuid,
        target_node_uuid=topic_node.uuid,
        name='INTERACTED_WITH',
        group_id='group_1',
        fact='User interacted with topic',
        episodes=[],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    current_episode = EpisodicNode(
        uuid='episode_uuid',
        name='Episode',
        group_id='group_1',
        source='message',
        source_description='desc',
        content='Episode content',
        valid_at=datetime.now(timezone.utc),
    )

    resolved_edges, invalidated_edges = await resolve_extracted_edges(
        clients,
        [interaction_edge],
        current_episode,
        [user_node, topic_node],
        {'OCCURRED_AT': OccurredAtEdge},
        {('Event', 'Entity'): ['OCCURRED_AT']},
    )

    assert resolved_edges[0].name == 'INTERACTED_WITH'
    assert invalidated_edges == []


@pytest.mark.asyncio
async def test_resolve_extracted_edge_rejects_unmapped_fact_type(mock_llm_client):
    """A custom fact type chosen by the LLM is rejected when it is not a candidate for the edge.

    The LLM answers 'OCCURRED_AT', which is a known custom type name but is absent from the
    (empty) candidate mapping, so the resolved edge must fall back to DEFAULT_EDGE_NAME.
    """
    mock_llm_client.generate_response.return_value = {
        'duplicate_facts': [],
        'contradicted_facts': [],
        'fact_type': 'OCCURRED_AT',
    }

    extracted_edge = EntityEdge(
        source_node_uuid='source_uuid',
        target_node_uuid='target_uuid',
        name='OCCURRED_AT',
        group_id='group_1',
        fact='Document occurred at somewhere',
        episodes=[],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    episode = EpisodicNode(
        uuid='episode_uuid',
        name='Episode',
        group_id='group_1',
        source='message',
        source_description='desc',
        content='Episode content',
        valid_at=datetime.now(timezone.utc),
    )

    related_edge = EntityEdge(
        source_node_uuid='alt_source',
        target_node_uuid='alt_target',
        name='OTHER',
        group_id='group_1',
        fact='Different fact',
        episodes=[],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    # The third element of the result is the duplicate list (it is what the integer-index test
    # and the bulk dedupe code consume as duplicates), so the second one is the invalidated list.
    resolved_edge, invalidated, duplicates = await resolve_extracted_edge(
        mock_llm_client,
        extracted_edge,
        [related_edge],
        [],
        episode,
        edge_type_candidates={},
        custom_edge_type_names={'OCCURRED_AT'},
    )

    assert resolved_edge.name == DEFAULT_EDGE_NAME
    assert duplicates == []
    assert invalidated == []


@pytest.mark.asyncio
async def test_resolve_extracted_edge_accepts_unknown_fact_type(mock_llm_client):
    """A fact type outside the custom type names is adopted as the edge name as-is.

    The LLM answers 'INTERACTED_WITH' while only 'OCCURRED_AT' is a custom type, so the resolved
    edge must take that name and end up with no attributes.
    """
    mock_llm_client.generate_response.return_value = {
        'duplicate_facts': [],
        'contradicted_facts': [],
        'fact_type': 'INTERACTED_WITH',
    }

    extracted_edge = EntityEdge(
        source_node_uuid='source_uuid',
        target_node_uuid='target_uuid',
        name='DEFAULT',
        group_id='group_1',
        fact='User interacted with topic',
        episodes=[],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    episode = EpisodicNode(
        uuid='episode_uuid',
        name='Episode',
        group_id='group_1',
        source='message',
        source_description='desc',
        content='Episode content',
        valid_at=datetime.now(timezone.utc),
    )

    related_edge = EntityEdge(
        source_node_uuid='source_uuid',
        target_node_uuid='target_uuid',
        name='DEFAULT',
        group_id='group_1',
        fact='User mentioned a topic',
        episodes=[],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    # The third element of the result is the duplicate list (it is what the integer-index test
    # and the bulk dedupe code consume as duplicates), so the second one is the invalidated list.
    resolved_edge, invalidated, duplicates = await resolve_extracted_edge(
        mock_llm_client,
        extracted_edge,
        [related_edge],
        [],
        episode,
        edge_type_candidates={'OCCURRED_AT': OccurredAtEdge},
        custom_edge_type_names={'OCCURRED_AT'},
    )

    assert resolved_edge.name == 'INTERACTED_WITH'
    assert resolved_edge.attributes == {}
    assert duplicates == []
    assert invalidated == []


@pytest.mark.asyncio
async def test_resolve_extracted_edge_uses_integer_indices_for_duplicates(mock_llm_client):
    """Integer entries in the LLM's 'duplicate_facts' select related edges by list position."""
    # The mocked LLM flags positions 0 and 1 of the related edges as duplicates
    mock_llm_client.generate_response.return_value = {
        'duplicate_facts': [0, 1],
        'contradicted_facts': [],
        'fact_type': 'DEFAULT',
    }

    def build_related(fact, episode_id, age_in_days):
        # Related edges differ only in fact text, originating episode and age
        return EntityEdge(
            source_node_uuid='source_uuid',
            target_node_uuid='target_uuid',
            name='test_edge',
            group_id='group_1',
            fact=fact,
            episodes=[episode_id],
            created_at=datetime.now(timezone.utc) - timedelta(days=age_in_days),
            valid_at=None,
            invalid_at=None,
        )

    extracted_edge = EntityEdge(
        source_node_uuid='source_uuid',
        target_node_uuid='target_uuid',
        name='test_edge',
        group_id='group_1',
        fact='User likes yoga',
        episodes=[],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    episode = EpisodicNode(
        uuid='episode_uuid',
        name='Episode',
        group_id='group_1',
        source='message',
        source_description='desc',
        content='Episode content',
        valid_at=datetime.now(timezone.utc),
    )

    yoga_enjoys = build_related('User enjoys yoga', 'episode_1', 1)
    yoga_practices = build_related('User practices yoga', 'episode_2', 2)
    swimming = build_related('User loves swimming', 'episode_3', 3)

    resolved_edge, invalidated, duplicates = await resolve_extracted_edge(
        mock_llm_client,
        extracted_edge,
        [yoga_enjoys, yoga_practices, swimming],
        [],
        episode,
        edge_type_candidates=None,
        custom_edge_type_names=set(),
    )

    # The LLM must have been consulted exactly once
    mock_llm_client.generate_response.assert_called_once()

    # Positions [0, 1] correspond to the two yoga edges; nothing is invalidated
    assert len(duplicates) == 2
    assert yoga_enjoys in duplicates
    assert yoga_practices in duplicates
    assert invalidated == []

    # The resolved edge is the first duplicate; compare by UUID because its episode list is
    # modified during resolution
    assert resolved_edge.uuid == yoga_enjoys.uuid
    assert episode.uuid in resolved_edge.episodes


@pytest.mark.asyncio
async def test_resolve_extracted_edges_fast_path_deduplication(monkeypatch):
    """Edges whose facts match up to case and outer whitespace reach the resolver only once."""
    from graphiti_core.utils.maintenance import edge_operations as edge_ops

    # Every edge handed to the patched per-edge resolver is recorded here
    resolver_inputs = []

    async def mock_resolve_extracted_edge(
        llm_client,
        extracted_edge,
        related_edges,
        existing_edges,
        episode,
        edge_type_candidates=None,
        custom_edge_type_names=None,
    ):
        resolver_inputs.append(extracted_edge)
        return extracted_edge, [], []

    async def immediate_gather(*aws, max_coroutines=None):
        # Await each item right away, in order
        return [await aw for aw in aws]

    monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
    monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
    monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather)
    monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))
    monkeypatch.setattr(edge_ops, 'resolve_extracted_edge', mock_resolve_extracted_edge)

    clients = SimpleNamespace(
        driver=MagicMock(),
        llm_client=MagicMock(),
        embedder=MagicMock(),
        cross_encoder=MagicMock(),
    )

    assistant_node = EntityNode(
        uuid='source_uuid',
        name='Assistant',
        group_id='group_1',
        labels=['Entity'],
    )
    user_node = EntityNode(
        uuid='target_uuid',
        name='User',
        group_id='group_1',
        labels=['Entity'],
    )

    def build_edge(fact):
        # All edges share endpoints, name and group; only the fact text varies
        return EntityEdge(
            source_node_uuid=assistant_node.uuid,
            target_node_uuid=user_node.uuid,
            name='recommends',
            group_id='group_1',
            fact=fact,
            episodes=[],
            created_at=datetime.now(timezone.utc),
            valid_at=None,
            invalid_at=None,
        )

    # Three edges carrying the same fact; the middle one varies case and outer whitespace
    identical_edges = [
        build_edge('assistant recommends yoga poses'),
        build_edge('  Assistant Recommends YOGA Poses  '),
        build_edge('assistant recommends yoga poses'),
    ]

    episode = EpisodicNode(
        uuid='episode_uuid',
        name='Episode',
        group_id='group_1',
        source='message',
        source_description='desc',
        content='Episode content',
        valid_at=datetime.now(timezone.utc),
    )

    resolved_edges, invalidated_edges = await resolve_extracted_edges(
        clients,
        identical_edges,
        episode,
        [assistant_node, user_node],
        {},
        {},
    )

    # The three edges collapse to one before per-edge resolution, so the resolver runs once
    assert len(resolver_inputs) == 1
    assert len(resolved_edges) == 1
    assert invalidated_edges == []

```

--------------------------------------------------------------------------------
/graphiti_core/utils/bulk_utils.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import logging
import typing
from datetime import datetime

import numpy as np
from pydantic import BaseModel, Field
from typing_extensions import Any

from graphiti_core.driver.driver import (
    GraphDriver,
    GraphDriverSession,
    GraphProvider,
)
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
from graphiti_core.embedder import EmbedderClient
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import normalize_l2, semaphore_gather
from graphiti_core.models.edges.edge_db_queries import (
    get_entity_edge_save_bulk_query,
    get_episodic_edge_save_bulk_query,
)
from graphiti_core.models.nodes.node_db_queries import (
    get_entity_node_save_bulk_query,
    get_episode_node_save_bulk_query,
)
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
from graphiti_core.utils.maintenance.dedup_helpers import (
    DedupResolutionState,
    _build_candidate_indexes,
    _normalize_string_exact,
    _resolve_with_similarity,
)
from graphiti_core.utils.maintenance.edge_operations import (
    extract_edges,
    resolve_extracted_edge,
)
from graphiti_core.utils.maintenance.graph_data_operations import (
    EPISODE_WINDOW_LEN,
    retrieve_episodes,
)
from graphiti_core.utils.maintenance.node_operations import (
    extract_nodes,
    resolve_extracted_nodes,
)

# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)

# No code in this module reads CHUNK_SIZE; it is only mentioned in a comment inside
# dedupe_nodes_bulk as the typical upper bound on batch size.
# NOTE(review): presumably imported by callers that split bulk input into chunks - confirm.
CHUNK_SIZE = 10


def _build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]:
    """Collapse alias -> canonical chains while preserving direction.

    The incoming pairs represent directed mappings discovered during node dedupe. We use a simple
    union-find with iterative path compression to ensure every source UUID resolves to its ultimate
    canonical target, even if aliases appear lexicographically smaller than the canonical UUID.
    """

    parent: dict[str, str] = {}

    def find(uuid: str) -> str:
        """Directed union-find lookup using iterative path compression."""
        parent.setdefault(uuid, uuid)
        root = uuid
        while parent[root] != root:
            root = parent[root]

        while parent[uuid] != root:
            next_uuid = parent[uuid]
            parent[uuid] = root
            uuid = next_uuid

        return root

    for source_uuid, target_uuid in pairs:
        parent.setdefault(source_uuid, source_uuid)
        parent.setdefault(target_uuid, target_uuid)
        parent[find(source_uuid)] = find(target_uuid)

    return {uuid: find(uuid) for uuid in parent}


# Pydantic model describing a single episode by name, text content and source metadata.
# No function in this module references it.
# NOTE(review): presumably the input record for bulk episode ingestion - confirm against callers.
# Documented with comments rather than a class docstring so the model's JSON schema is unchanged.
class RawEpisode(BaseModel):
    name: str
    # Optional identifier; defaults to None when the caller does not supply one
    uuid: str | None = Field(default=None)
    content: str
    source_description: str
    source: EpisodeType
    reference_time: datetime


async def retrieve_previous_episodes_bulk(
    driver: GraphDriver, episodes: list[EpisodicNode]
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
    """Pair every episode with the episodes retrieved for its timestamp and group.

    One ``retrieve_episodes`` call is issued per episode, using the episode's ``valid_at`` as
    the reference time, its ``group_id`` as the only group and ``EPISODE_WINDOW_LEN`` as
    ``last_n``; all calls are awaited together through ``semaphore_gather``.

    Returns a list of ``(episode, retrieved_episodes)`` tuples in the order of ``episodes``.
    """
    lookups = [
        retrieve_episodes(
            driver, current.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[current.group_id]
        )
        for current in episodes
    ]
    histories = await semaphore_gather(*lookups)

    paired: list[tuple[EpisodicNode, list[EpisodicNode]]] = []
    for position, current in enumerate(episodes):
        paired.append((current, histories[position]))

    return paired


async def add_nodes_and_edges_bulk(
    driver: GraphDriver,
    episodic_nodes: list[EpisodicNode],
    episodic_edges: list[EpisodicEdge],
    entity_nodes: list[EntityNode],
    entity_edges: list[EntityEdge],
    embedder: EmbedderClient,
):
    """Persist a batch of nodes and edges through a single ``execute_write`` call.

    Opens a session on ``driver``, runs ``add_nodes_and_edges_bulk_tx`` with the given
    collections and embedder via ``session.execute_write``, and always closes the session
    afterwards, including when the write raises.
    """
    # Positional payload forwarded to the transaction function, in its parameter order
    tx_payload = (episodic_nodes, episodic_edges, entity_nodes, entity_edges, embedder)

    session = driver.session()
    try:
        await session.execute_write(add_nodes_and_edges_bulk_tx, *tx_payload, driver=driver)
    finally:
        await session.close()


async def add_nodes_and_edges_bulk_tx(
    tx: GraphDriverSession,
    episodic_nodes: list[EpisodicNode],
    episodic_edges: list[EpisodicEdge],
    entity_nodes: list[EntityNode],
    entity_edges: list[EntityEdge],
    embedder: EmbedderClient,
    driver: GraphDriver,
):
    """Serialize the given nodes and edges and write them through ``tx``.

    Steps:
    1. Episodic nodes are turned into dicts, their ``source`` replaced by the string of the
       enum value and any ``labels`` key dropped.
    2. Entity nodes and entity edges are turned into dicts; a missing name/fact embedding is
       generated with ``embedder`` first. Custom attributes are stored as a JSON string under
       ``attributes`` for Kuzu and merged into the top level of the dict for other providers.
    3. The payloads are written either through ``driver.graph_operations_interface`` (when the
       driver provides one), row by row for Kuzu, or with one bulk query per collection.
    """
    # dict(model) yields the model's fields as key/value pairs
    episodes = [dict(episode) for episode in episodic_nodes]
    for episode in episodes:
        # Store the enum's underlying value as a plain string
        episode['source'] = str(episode['source'].value)
        episode.pop('labels', None)

    nodes = []

    for node in entity_nodes:
        # Embeddings are generated one node at a time, only when absent
        if node.name_embedding is None:
            await node.generate_name_embedding(embedder)

        entity_data: dict[str, Any] = {
            'uuid': node.uuid,
            'name': node.name,
            'group_id': node.group_id,
            'summary': node.summary,
            'created_at': node.created_at,
            'name_embedding': node.name_embedding,
            # Always include 'Entity'; going through a set removes repeated labels (order is
            # not preserved)
            'labels': list(set(node.labels + ['Entity'])),
        }

        if driver.provider == GraphProvider.KUZU:
            # Kuzu receives the attributes as one JSON string, with datetimes made serializable
            attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
            entity_data['attributes'] = json.dumps(attributes)
        else:
            # NOTE(review): dict.update lets an attribute named like a core key above (e.g.
            # 'name') overwrite that key - confirm attribute names can never collide.
            entity_data.update(node.attributes or {})

        nodes.append(entity_data)

    edges = []
    for edge in entity_edges:
        # Embeddings are generated one edge at a time, only when absent
        if edge.fact_embedding is None:
            await edge.generate_embedding(embedder)
        edge_data: dict[str, Any] = {
            'uuid': edge.uuid,
            'source_node_uuid': edge.source_node_uuid,
            'target_node_uuid': edge.target_node_uuid,
            'name': edge.name,
            'fact': edge.fact,
            'group_id': edge.group_id,
            'episodes': edge.episodes,
            'created_at': edge.created_at,
            'expired_at': edge.expired_at,
            'valid_at': edge.valid_at,
            'invalid_at': edge.invalid_at,
            'fact_embedding': edge.fact_embedding,
        }

        if driver.provider == GraphProvider.KUZU:
            # Same JSON-string encoding of attributes as for entity nodes
            attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
            edge_data['attributes'] = json.dumps(attributes)
        else:
            # NOTE(review): same key-collision caveat as for node attributes above
            edge_data.update(edge.attributes or {})

        edges.append(edge_data)

    if driver.graph_operations_interface:
        # The driver supplies its own bulk-save operations: delegate all four writes to it
        await driver.graph_operations_interface.episodic_node_save_bulk(None, driver, tx, episodes)
        await driver.graph_operations_interface.node_save_bulk(None, driver, tx, nodes)
        await driver.graph_operations_interface.episodic_edge_save_bulk(
            None, driver, tx, [edge.model_dump() for edge in episodic_edges]
        )
        await driver.graph_operations_interface.edge_save_bulk(None, driver, tx, edges)

    elif driver.provider == GraphProvider.KUZU:
        # FIXME: Kuzu's UNWIND does not currently support STRUCT[] type properly, so we insert the data one by one instead for now.
        # Each record is passed to the query as keyword parameters. Entity edges are written
        # before episodic edges here, unlike in the other two branches.
        episode_query = get_episode_node_save_bulk_query(driver.provider)
        for episode in episodes:
            await tx.run(episode_query, **episode)
        entity_node_query = get_entity_node_save_bulk_query(driver.provider, nodes)
        for node in nodes:
            await tx.run(entity_node_query, **node)
        entity_edge_query = get_entity_edge_save_bulk_query(driver.provider)
        for edge in edges:
            await tx.run(entity_edge_query, **edge)
        episodic_edge_query = get_episodic_edge_save_bulk_query(driver.provider)
        for edge in episodic_edges:
            await tx.run(episodic_edge_query, **edge.model_dump())
    else:
        # Default path: one query per collection, each receiving the whole list as a parameter
        await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
        await tx.run(
            get_entity_node_save_bulk_query(driver.provider, nodes),
            nodes=nodes,
        )
        await tx.run(
            get_episodic_edge_save_bulk_query(driver.provider),
            episodic_edges=[edge.model_dump() for edge in episodic_edges],
        )
        await tx.run(
            get_entity_edge_save_bulk_query(driver.provider),
            entity_edges=edges,
        )


async def extract_nodes_and_edges_bulk(
    clients: GraphitiClients,
    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
    edge_type_map: dict[tuple[str, str], list[str]],
    entity_types: dict[str, type[BaseModel]] | None = None,
    excluded_entity_types: list[str] | None = None,
    edge_types: dict[str, type[BaseModel]] | None = None,
) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]:
    """Extract entity nodes, then entity edges, for every episode of a batch.

    Node extraction runs for all episodes first; edge extraction then runs for all episodes,
    each call receiving the nodes extracted for that same episode. Both phases are awaited
    through ``semaphore_gather``.

    Returns ``(nodes_per_episode, edges_per_episode)``, both ordered like ``episode_tuples``.
    """
    node_jobs = [
        extract_nodes(clients, episode, history, entity_types, excluded_entity_types)
        for episode, history in episode_tuples
    ]
    extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather(*node_jobs)

    edge_jobs = []
    for position, (episode, history) in enumerate(episode_tuples):
        episode_nodes = extracted_nodes_bulk[position]
        edge_jobs.append(
            extract_edges(
                clients,
                episode,
                episode_nodes,
                history,
                edge_type_map=edge_type_map,
                group_id=episode.group_id,
                edge_types=edge_types,
            )
        )
    extracted_edges_bulk: list[list[EntityEdge]] = await semaphore_gather(*edge_jobs)

    return extracted_nodes_bulk, extracted_edges_bulk


async def dedupe_nodes_bulk(
    clients: GraphitiClients,
    extracted_nodes: list[list[EntityNode]],
    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
    entity_types: dict[str, type[BaseModel]] | None = None,
) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
    """Resolve entity duplicates across an in-memory batch using a two-pass strategy.

    1. Run :func:`resolve_extracted_nodes` for every episode in parallel so each batch item is
       reconciled against the live graph just like the non-batch flow.
    2. Re-run the deterministic similarity heuristics across the union of resolved nodes to catch
       duplicates that only co-occur inside this batch, emitting a canonical UUID map that callers
       can apply to edges and persistence.

    ``extracted_nodes`` and ``episode_tuples`` are matched by position and must have the same
    length (the results are zipped with ``strict=True``).

    Returns a tuple ``(nodes_by_episode, compressed_map)``: the deduplicated canonical nodes
    keyed by episode UUID, and a map from every UUID seen in a duplicate pair or per-episode
    UUID map to its canonical UUID.
    """

    first_pass_results = await semaphore_gather(
        *[
            resolve_extracted_nodes(
                clients,
                nodes,
                episode_tuples[i][0],
                episode_tuples[i][1],
                entity_types,
            )
            for i, nodes in enumerate(extracted_nodes)
        ]
    )

    episode_resolutions: list[tuple[str, list[EntityNode]]] = []
    per_episode_uuid_maps: list[dict[str, str]] = []
    duplicate_pairs: list[tuple[str, str]] = []

    # Each first-pass result is unpacked as (resolved nodes, uuid map, duplicate node pairs)
    for (resolved_nodes, uuid_map, duplicates), (episode, _) in zip(
        first_pass_results, episode_tuples, strict=True
    ):
        episode_resolutions.append((episode.uuid, resolved_nodes))
        per_episode_uuid_maps.append(uuid_map)
        # Keep only the UUIDs of each (source, target) node pair
        duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)

    # Second pass: grow a pool of canonical nodes, matching each resolved node against the pool
    canonical_nodes: dict[str, EntityNode] = {}
    for _, resolved_nodes in episode_resolutions:
        for node in resolved_nodes:
            # NOTE: this loop is O(n^2) in the number of nodes inside the batch because we rebuild
            # the MinHash index for the accumulated canonical pool each time. The LRU-backed
            # shingle cache keeps the constant factors low for typical batch sizes (≤ CHUNK_SIZE),
            # but if batches grow significantly we should switch to an incremental index or chunked
            # processing.
            if not canonical_nodes:
                # The very first node seeds the pool
                canonical_nodes[node.uuid] = node
                continue

            existing_candidates = list(canonical_nodes.values())
            normalized = _normalize_string_exact(node.name)
            # Cheap check first: a pool node with the same normalized name wins immediately
            exact_match = next(
                (
                    candidate
                    for candidate in existing_candidates
                    if _normalize_string_exact(candidate.name) == normalized
                ),
                None,
            )
            if exact_match is not None:
                # The same node object can recur across episodes; only record real aliases
                if exact_match.uuid != node.uuid:
                    duplicate_pairs.append((node.uuid, exact_match.uuid))
                continue

            indexes = _build_candidate_indexes(existing_candidates)
            # Single-slot state for the one node being resolved; slot 0 is read back below.
            # NOTE(review): presumably _resolve_with_similarity fills the slot when it finds a
            # match and leaves it as None otherwise - confirm in dedup_helpers.
            state = DedupResolutionState(
                resolved_nodes=[None],
                uuid_map={},
                unresolved_indices=[],
            )
            _resolve_with_similarity([node], indexes, state)

            resolved = state.resolved_nodes[0]
            if resolved is None:
                # No match in the pool: the node becomes canonical itself
                canonical_nodes[node.uuid] = node
                continue

            canonical_uuid = resolved.uuid
            canonical_nodes.setdefault(canonical_uuid, resolved)
            if canonical_uuid != node.uuid:
                duplicate_pairs.append((node.uuid, canonical_uuid))

    # Combine the first-pass per-episode maps with all duplicate pairs collected in both passes
    union_pairs: list[tuple[str, str]] = []
    for uuid_map in per_episode_uuid_maps:
        union_pairs.extend(uuid_map.items())
    union_pairs.extend(duplicate_pairs)

    # Collapse alias chains so every UUID points directly at its final canonical UUID
    compressed_map: dict[str, str] = _build_directed_uuid_map(union_pairs)

    nodes_by_episode: dict[str, list[EntityNode]] = {}
    for episode_uuid, resolved_nodes in episode_resolutions:
        deduped_nodes: list[EntityNode] = []
        # Canonical UUIDs already emitted for this episode
        seen: set[str] = set()
        for node in resolved_nodes:
            canonical_uuid = compressed_map.get(node.uuid, node.uuid)
            if canonical_uuid in seen:
                continue
            seen.add(canonical_uuid)
            canonical_node = canonical_nodes.get(canonical_uuid)
            if canonical_node is None:
                # The canonical UUID is not in the pool: log it and keep the node as it is
                logger.error(
                    'Canonical node %s missing during batch dedupe; falling back to %s',
                    canonical_uuid,
                    node.uuid,
                )
                canonical_node = node
            deduped_nodes.append(canonical_node)

        nodes_by_episode[episode_uuid] = deduped_nodes

    return nodes_by_episode, compressed_map


async def dedupe_edges_bulk(
    clients: GraphitiClients,
    extracted_edges: list[list[EntityEdge]],
    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
    _entities: list[EntityNode],
    edge_types: dict[str, type[BaseModel]],
    _edge_type_map: dict[tuple[str, str], list[str]],
) -> dict[str, list[EntityEdge]]:
    """Deduplicate the edges extracted for a batch of episodes against each other.

    For every edge, candidate duplicates are the other edges of the batch that connect the same
    source and target nodes and either share at least one lower-cased, whitespace-separated word
    with its fact or have a fact-embedding similarity of at least ``min_score``. Each edge is then
    passed to ``resolve_extracted_edge`` with its candidates, and the duplicates it reports are
    merged into groups whose representative is the lexicographically smallest UUID.

    ``extracted_edges`` and ``episode_tuples`` are matched by position. ``_entities`` and
    ``_edge_type_map`` are accepted but not used. Edge invalidation reported by the resolver is
    ignored.

    Returns a dict from episode UUID to that episode's edges with every edge replaced by the
    representative of its duplicate group. The lists keep one entry per input edge, so the same
    representative can appear more than once.
    """
    embedder = clients.embedder
    min_score = 0.6

    # generate embeddings
    await semaphore_gather(
        *[create_entity_edge_embeddings(embedder, edges) for edges in extracted_edges]
    )

    # Flatten the batch once and cache every fact's word set: both are invariant across the
    # comparison loops below, so there is no need to rebuild them per episode / per pair.
    all_edges: list[tuple[EntityEdge, set[str]]] = [
        (edge, set(edge.fact.lower().split())) for edges in extracted_edges for edge in edges
    ]

    # Find similar results
    dedupe_tuples: list[tuple[EpisodicNode, EntityEdge, list[EntityEdge]]] = []
    for i, edges_i in enumerate(extracted_edges):
        for edge in edges_i:
            edge_words = set(edge.fact.lower().split())
            candidates: list[EntityEdge] = []
            for existing_edge, existing_edge_words in all_edges:
                # Skip self-comparison
                if edge.uuid == existing_edge.uuid:
                    continue
                # Only edges between the same pair of nodes (same direction) can be duplicates
                if (
                    edge.source_node_uuid != existing_edge.source_node_uuid
                    or edge.target_node_uuid != existing_edge.target_node_uuid
                ):
                    continue

                # Approximate BM25 by checking for word overlaps (this is faster than creating
                # many in-memory indices). This approach will cast a wider net than BM25, which
                # is ideal for this use case
                if not edge_words.isdisjoint(existing_edge_words):
                    candidates.append(existing_edge)
                    continue

                # Check for semantic similarity even if there is no overlap
                similarity = np.dot(
                    normalize_l2(edge.fact_embedding or []),
                    normalize_l2(existing_edge.fact_embedding or []),
                )
                if similarity >= min_score:
                    candidates.append(existing_edge)

            dedupe_tuples.append((episode_tuples[i][0], edge, candidates))

    # The candidates double as both the related and the existing edges for the resolver
    bulk_edge_resolutions: list[
        tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]
    ] = await semaphore_gather(
        *[
            resolve_extracted_edge(
                clients.llm_client,
                edge,
                candidates,
                candidates,
                episode,
                edge_types,
                set(edge_types),
            )
            for episode, edge, candidates in dedupe_tuples
        ]
    )

    # For now we won't track edge invalidation
    # NOTE(review): only the reported duplicate list (third element) is used. When the resolver
    # returns an existing candidate as the resolved edge with an empty duplicate list (its
    # exact-fact short circuit), that match is not recorded here - confirm this is intended.
    duplicate_pairs: list[tuple[str, str]] = [
        (edge.uuid, duplicate.uuid)
        for (_, edge, _), (_, _, duplicates) in zip(
            dedupe_tuples, bulk_edge_resolutions, strict=True
        )
        for duplicate in duplicates
    ]

    # Now we compress the duplicate_map, so that 3 -> 2 and 2 -> 1 becomes 3 -> 1 (sorted by uuid)
    compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)

    edge_uuid_map: dict[str, EntityEdge] = {
        edge.uuid: edge for edges in extracted_edges for edge in edges
    }

    edges_by_episode: dict[str, list[EntityEdge]] = {}
    for i, edges in enumerate(extracted_edges):
        episode = episode_tuples[i][0]

        edges_by_episode[episode.uuid] = [
            edge_uuid_map[compressed_map.get(edge.uuid, edge.uuid)] for edge in edges
        ]

    return edges_by_episode


class UnionFind:
    """Disjoint-set structure whose representative is the smallest element of each set.

    ``union`` always attaches the larger of the two roots under the smaller one, so the root of
    a set is its minimum under ``<`` (lexicographic order for strings).
    """

    def __init__(self, elements):
        # start each element in its own set
        self.parent = {e: e for e in elements}

    def find(self, x):
        """Return the representative of ``x``'s set; raises KeyError for an unknown element.

        Implemented iteratively: ``union`` can build parent chains as long as the number of
        elements, which would overflow the interpreter's recursion limit in a recursive lookup.
        """
        # first walk: locate the root
        root = x
        while self.parent[root] != root:
            root = self.parent[root]

        # second walk: path-compression, point every visited element straight at the root
        while self.parent[x] != root:
            next_x = self.parent[x]
            self.parent[x] = root
            x = next_x

        return root

    def union(self, a, b):
        """Merge the sets containing ``a`` and ``b``; a no-op if they are already merged."""
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return
        # attach the lexicographically larger root under the smaller
        if ra < rb:
            self.parent[rb] = ra
        else:
            self.parent[ra] = rb


def compress_uuid_map(duplicate_pairs: list[tuple[str, str]]) -> dict[str, str]:
    """
    Group ids connected by duplicate pairs and map each id to its group's representative.

    duplicate_pairs: iterable of (id1, id2) pairs
    returns: dict mapping each id -> lexicographically smallest id in its duplicate set
             (only ids that occur in at least one pair are present)
    """
    members = {uuid for pair in duplicate_pairs for uuid in (pair[0], pair[1])}

    groups = UnionFind(members)
    for first, second in duplicate_pairs:
        groups.union(first, second)

    # every lookup goes through find so each id resolves to its final root
    resolved: dict[str, str] = {}
    for uuid in members:
        resolved[uuid] = groups.find(uuid)
    return resolved


# Type variable bounded by Edge, used by resolve_edge_pointers to annotate its edge list.
E = typing.TypeVar('E', bound=Edge)


def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
    """Rewrite the endpoint UUIDs of every edge through ``uuid_map``, in place.

    An endpoint that has no entry in ``uuid_map`` is left unchanged. The same list object that
    was passed in is returned.
    """
    for edge in edges:
        for endpoint in ('source_node_uuid', 'target_node_uuid'):
            current = getattr(edge, endpoint)
            setattr(edge, endpoint, uuid_map.get(current, current))

    return edges

```
Page 5/9 · First · Prev · Next · Last