#
tokens: 47840/50000 18/234 files (page 4/9)
lines: off (toggle) GitHub
raw markdown copy
This is page 4 of 9. Use http://codebase.md/getzep/graphiti?lines=false&page={x} to view the full context.

# Directory Structure

```
├── .env.example
├── .github
│   ├── dependabot.yml
│   ├── ISSUE_TEMPLATE
│   │   └── bug_report.md
│   ├── pull_request_template.md
│   ├── secret_scanning.yml
│   └── workflows
│       ├── ai-moderator.yml
│       ├── cla.yml
│       ├── claude-code-review-manual.yml
│       ├── claude-code-review.yml
│       ├── claude.yml
│       ├── codeql.yml
│       ├── daily_issue_maintenance.yml
│       ├── issue-triage.yml
│       ├── lint.yml
│       ├── release-graphiti-core.yml
│       ├── release-mcp-server.yml
│       ├── release-server-container.yml
│       ├── typecheck.yml
│       └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│   ├── azure-openai
│   │   ├── .env.example
│   │   ├── azure_openai_neo4j.py
│   │   └── README.md
│   ├── data
│   │   └── manybirds_products.json
│   ├── ecommerce
│   │   ├── runner.ipynb
│   │   └── runner.py
│   ├── langgraph-agent
│   │   ├── agent.ipynb
│   │   └── tinybirds-jess.png
│   ├── opentelemetry
│   │   ├── .env.example
│   │   ├── otel_stdout_example.py
│   │   ├── pyproject.toml
│   │   ├── README.md
│   │   └── uv.lock
│   ├── podcast
│   │   ├── podcast_runner.py
│   │   ├── podcast_transcript.txt
│   │   └── transcript_parser.py
│   ├── quickstart
│   │   ├── quickstart_falkordb.py
│   │   ├── quickstart_neo4j.py
│   │   ├── quickstart_neptune.py
│   │   ├── README.md
│   │   └── requirements.txt
│   └── wizard_of_oz
│       ├── parser.py
│       ├── runner.py
│       └── woo.txt
├── graphiti_core
│   ├── __init__.py
│   ├── cross_encoder
│   │   ├── __init__.py
│   │   ├── bge_reranker_client.py
│   │   ├── client.py
│   │   ├── gemini_reranker_client.py
│   │   └── openai_reranker_client.py
│   ├── decorators.py
│   ├── driver
│   │   ├── __init__.py
│   │   ├── driver.py
│   │   ├── falkordb_driver.py
│   │   ├── graph_operations
│   │   │   └── graph_operations.py
│   │   ├── kuzu_driver.py
│   │   ├── neo4j_driver.py
│   │   ├── neptune_driver.py
│   │   └── search_interface
│   │       └── search_interface.py
│   ├── edges.py
│   ├── embedder
│   │   ├── __init__.py
│   │   ├── azure_openai.py
│   │   ├── client.py
│   │   ├── gemini.py
│   │   ├── openai.py
│   │   └── voyage.py
│   ├── errors.py
│   ├── graph_queries.py
│   ├── graphiti_types.py
│   ├── graphiti.py
│   ├── helpers.py
│   ├── llm_client
│   │   ├── __init__.py
│   │   ├── anthropic_client.py
│   │   ├── azure_openai_client.py
│   │   ├── client.py
│   │   ├── config.py
│   │   ├── errors.py
│   │   ├── gemini_client.py
│   │   ├── groq_client.py
│   │   ├── openai_base_client.py
│   │   ├── openai_client.py
│   │   ├── openai_generic_client.py
│   │   └── utils.py
│   ├── migrations
│   │   └── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── edges
│   │   │   ├── __init__.py
│   │   │   └── edge_db_queries.py
│   │   └── nodes
│   │       ├── __init__.py
│   │       └── node_db_queries.py
│   ├── nodes.py
│   ├── prompts
│   │   ├── __init__.py
│   │   ├── dedupe_edges.py
│   │   ├── dedupe_nodes.py
│   │   ├── eval.py
│   │   ├── extract_edge_dates.py
│   │   ├── extract_edges.py
│   │   ├── extract_nodes.py
│   │   ├── invalidate_edges.py
│   │   ├── lib.py
│   │   ├── models.py
│   │   ├── prompt_helpers.py
│   │   ├── snippets.py
│   │   └── summarize_nodes.py
│   ├── py.typed
│   ├── search
│   │   ├── __init__.py
│   │   ├── search_config_recipes.py
│   │   ├── search_config.py
│   │   ├── search_filters.py
│   │   ├── search_helpers.py
│   │   ├── search_utils.py
│   │   └── search.py
│   ├── telemetry
│   │   ├── __init__.py
│   │   └── telemetry.py
│   ├── tracer.py
│   └── utils
│       ├── __init__.py
│       ├── bulk_utils.py
│       ├── datetime_utils.py
│       ├── maintenance
│       │   ├── __init__.py
│       │   ├── community_operations.py
│       │   ├── dedup_helpers.py
│       │   ├── edge_operations.py
│       │   ├── graph_data_operations.py
│       │   ├── node_operations.py
│       │   └── temporal_operations.py
│       ├── ontology_utils
│       │   └── entity_types_utils.py
│       └── text_utils.py
├── images
│   ├── arxiv-screenshot.png
│   ├── graphiti-graph-intro.gif
│   ├── graphiti-intro-slides-stock-2.gif
│   └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│   ├── .env.example
│   ├── .python-version
│   ├── config
│   │   ├── config-docker-falkordb-combined.yaml
│   │   ├── config-docker-falkordb.yaml
│   │   ├── config-docker-neo4j.yaml
│   │   ├── config.yaml
│   │   └── mcp_config_stdio_example.json
│   ├── docker
│   │   ├── build-standalone.sh
│   │   ├── build-with-version.sh
│   │   ├── docker-compose-falkordb.yml
│   │   ├── docker-compose-neo4j.yml
│   │   ├── docker-compose.yml
│   │   ├── Dockerfile
│   │   ├── Dockerfile.standalone
│   │   ├── github-actions-example.yml
│   │   ├── README-falkordb-combined.md
│   │   └── README.md
│   ├── docs
│   │   └── cursor_rules.md
│   ├── main.py
│   ├── pyproject.toml
│   ├── pytest.ini
│   ├── README.md
│   ├── src
│   │   ├── __init__.py
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   └── schema.py
│   │   ├── graphiti_mcp_server.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── entity_types.py
│   │   │   └── response_types.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   ├── factories.py
│   │   │   └── queue_service.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── formatting.py
│   │       └── utils.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── conftest.py
│   │   ├── pytest.ini
│   │   ├── README.md
│   │   ├── run_tests.py
│   │   ├── test_async_operations.py
│   │   ├── test_comprehensive_integration.py
│   │   ├── test_configuration.py
│   │   ├── test_falkordb_integration.py
│   │   ├── test_fixtures.py
│   │   ├── test_http_integration.py
│   │   ├── test_integration.py
│   │   ├── test_mcp_integration.py
│   │   ├── test_mcp_transports.py
│   │   ├── test_stdio_simple.py
│   │   └── test_stress_load.py
│   └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│   ├── .env.example
│   ├── graph_service
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── dto
│   │   │   ├── __init__.py
│   │   │   ├── common.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   ├── main.py
│   │   ├── routers
│   │   │   ├── __init__.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   └── zep_graphiti.py
│   ├── Makefile
│   ├── pyproject.toml
│   ├── README.md
│   └── uv.lock
├── signatures
│   └── version1
│       └── cla.json
├── tests
│   ├── cross_encoder
│   │   ├── test_bge_reranker_client_int.py
│   │   └── test_gemini_reranker_client.py
│   ├── driver
│   │   ├── __init__.py
│   │   └── test_falkordb_driver.py
│   ├── embedder
│   │   ├── embedder_fixtures.py
│   │   ├── test_gemini.py
│   │   ├── test_openai.py
│   │   └── test_voyage.py
│   ├── evals
│   │   ├── data
│   │   │   └── longmemeval_data
│   │   │       ├── longmemeval_oracle.json
│   │   │       └── README.md
│   │   ├── eval_cli.py
│   │   ├── eval_e2e_graph_building.py
│   │   ├── pytest.ini
│   │   └── utils.py
│   ├── helpers_test.py
│   ├── llm_client
│   │   ├── test_anthropic_client_int.py
│   │   ├── test_anthropic_client.py
│   │   ├── test_azure_openai_client.py
│   │   ├── test_client.py
│   │   ├── test_errors.py
│   │   └── test_gemini_client.py
│   ├── test_edge_int.py
│   ├── test_entity_exclusion_int.py
│   ├── test_graphiti_int.py
│   ├── test_graphiti_mock.py
│   ├── test_node_int.py
│   ├── test_text_utils.py
│   └── utils
│       ├── maintenance
│       │   ├── test_bulk_utils.py
│       │   ├── test_edge_operations.py
│       │   ├── test_node_operations.py
│       │   └── test_temporal_operations_int.py
│       └── search
│           └── search_utils_test.py
├── uv.lock
└── Zep-CLA.md
```

# Files

--------------------------------------------------------------------------------
/images/simple_graph.svg:
--------------------------------------------------------------------------------

```
<svg xmlns="http://www.w3.org/2000/svg" width="320.0599060058594" height="339.72857666015625"
    viewBox="-105.8088607788086 -149.75405883789062 320.0599060058594 339.72857666015625">
    <title>Neo4j Graph Visualization</title>
    <desc>Created using Neo4j (http://www.neo4j.com/)</desc>
    <g class="layer relationships">
        <g class="relationship"
            transform="translate(64.37326808037952 160.9745045766605) rotate(325.342180479503)">
            <path class="b-outline" fill="#A5ABB6" stroke="none"
                d="M 25 0.5 L 45.86500098580619 0.5 L 45.86500098580619 -0.5 L 25 -0.5 Z M 94.08765723580619 0.5 L 114.95265822161238 0.5 L 114.95265822161238 3.5 L 121.95265822161238 0 L 114.95265822161238 -3.5 L 114.95265822161238 -0.5 L 94.08765723580619 -0.5 Z" />
            <text text-anchor="middle" pointer-events="none" font-size="8px" fill="#fff"
                x="69.97632911080619" y="3"
                font-family="Helvetica Neue, Helvetica, Arial, sans-serif">MENTIONS</text>
        </g>
        <g class="relationship"
            transform="translate(64.37326808037952 160.9745045766605) rotate(268.0194761774372)">
            <path class="b-outline" fill="#A5ABB6" stroke="none"
                d="M 25 0.5 L 48.45342195548257 0.5 L 48.45342195548257 -0.5 L 25 -0.5 Z M 96.67607820548257 0.5 L 120.12950016096514 0.5 L 120.12950016096514 3.5 L 127.12950016096514 0 L 120.12950016096514 -3.5 L 120.12950016096514 -0.5 L 96.67607820548257 -0.5 Z" />
            <text text-anchor="middle" pointer-events="none" font-size="8px" fill="#fff"
                x="72.56475008048257" y="3" transform="rotate(180 72.56475008048257 0)"
                font-family="Helvetica Neue, Helvetica, Arial, sans-serif">MENTIONS</text>
        </g>
        <g class="relationship"
            transform="translate(64.37326808037952 160.9745045766605) rotate(214.36893208966427)">
            <path class="b-outline" fill="#A5ABB6" stroke="none"
                d="M 25 0.5 L 43.0453604327618 0.5 L 43.0453604327618 -0.5 L 25 -0.5 Z M 91.2680166827618 0.5 L 109.3133771155236 0.5 L 109.3133771155236 3.5 L 116.3133771155236 0 L 109.3133771155236 -3.5 L 109.3133771155236 -0.5 L 91.2680166827618 -0.5 Z" />
            <text text-anchor="middle" pointer-events="none" font-size="8px" fill="#fff"
                x="67.1566885577618" y="3" transform="rotate(180 67.1566885577618 0)"
                font-family="Helvetica Neue, Helvetica, Arial, sans-serif">MENTIONS</text>
        </g>
        <g class="relationship"
            transform="translate(59.11570627539377 8.935881644552067) rotate(388.4945734254285)">
            <path class="b-outline" fill="#A5ABB6" stroke="none"
                d="M 25 0.5 L 39.4813012088875 0.5 L 39.4813012088875 -0.5 L 25 -0.5 Z M 97.0398949588875 0.5 L 111.521196167775 0.5 L 111.521196167775 3.5 L 118.521196167775 0 L 111.521196167775 -3.5 L 111.521196167775 -0.5 L 97.0398949588875 -0.5 Z" />
            <text text-anchor="middle" pointer-events="none" font-size="8px" fill="#fff"
                x="68.2605980838875" y="3"
                font-family="Helvetica Neue, Helvetica, Arial, sans-serif">WORKS_FOR</text>
        </g>
        <g class="relationship"
            transform="translate(59.11570627539377 8.935881644552067) rotate(507.02532906724895)">
            <path class="b-outline" fill="#A5ABB6" stroke="none"
                d="M 25 0.5 L 31.21884260824949 0.5 L 31.21884260824949 -0.5 L 25 -0.5 Z M 94.55478010824949 0.5 L 100.77362271649898 0.5 L 100.77362271649898 3.5 L 107.77362271649898 0 L 100.77362271649898 -3.5 L 100.77362271649898 -0.5 L 94.55478010824949 -0.5 Z" />
            <text text-anchor="middle" pointer-events="none" font-size="8px" fill="#fff"
                x="62.88681135824949" y="3" transform="rotate(180 62.88681135824949 0)"
                font-family="Helvetica Neue, Helvetica, Arial, sans-serif">WORKED_FOR</text>
        </g>
        <g class="relationship"
            transform="translate(59.11570627539377 8.935881644552067) rotate(266.9235303682344)">
            <path class="b-outline" fill="#A5ABB6" stroke="none"
                d="M 25 0.5 L 26.434656330468542 0.5 L 26.434656330468542 -0.5 L 25 -0.5 Z M 96.44246883046854 0.5 L 97.87712516093708 0.5 L 97.87712516093708 3.5 L 104.87712516093708 0 L 97.87712516093708 -3.5 L 97.87712516093708 -0.5 L 96.44246883046854 -0.5 Z" />
            <text text-anchor="middle" pointer-events="none" font-size="8px" fill="#fff"
                x="61.43856258046854" y="3" transform="rotate(180 61.43856258046854 0)"
                font-family="Helvetica Neue, Helvetica, Arial, sans-serif">HOLDS_OFFIC…</text>
        </g>
        <g class="relationship"
            transform="translate(-76.8088607917906 -66.37642130383644) rotate(388.9897079993928)">
            <path class="b-outline" fill="#A5ABB6" stroke="none"
                d="M 25 0.5 L 50.08589014533345 0.5 L 50.08589014533345 -0.5 L 25 -0.5 Z M 98.30854639533345 0.5 L 123.3944365406669 0.5 L 123.3944365406669 3.5 L 130.3944365406669 0 L 123.3944365406669 -3.5 L 123.3944365406669 -0.5 L 98.30854639533345 -0.5 Z" />
            <text text-anchor="middle" pointer-events="none" font-size="8px" fill="#fff"
                x="74.19721827033345" y="3"
                font-family="Helvetica Neue, Helvetica, Arial, sans-serif">MENTIONS</text>
        </g>
        <g class="relationship"
            transform="translate(-76.8088607917906 -66.37642130383644) rotate(337.13573550965714)">
            <path class="b-outline" fill="#A5ABB6" stroke="none"
                d="M 25 0.5 L 42.363883039766904 0.5 L 42.363883039766904 -0.5 L 25 -0.5 Z M 90.5865392897669 0.5 L 107.95042232953381 0.5 L 107.95042232953381 3.5 L 114.95042232953381 0 L 107.95042232953381 -3.5 L 107.95042232953381 -0.5 L 90.5865392897669 -0.5 Z" />
            <text text-anchor="middle" pointer-events="none" font-size="8px" fill="#fff"
                x="66.4752111647669" y="3"
                font-family="Helvetica Neue, Helvetica, Arial, sans-serif">MENTIONS</text>
        </g>
    </g>
    <g class="layer nodes">
        <g class="node" aria-label="graph-node18"
            transform="translate(64.37326808037952,160.9745045766605)">
            <circle class="b-outline" cx="0" cy="0" r="25" fill="#F79767" stroke="#f36924"
                stroke-width="2px" />
            <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="5"
                font-size="10px" fill="#FFFFFF"
                font-family="Helvetica Neue, Helvetica, Arial, sans-serif"> podcast</text>
        </g>
        <g class="node" aria-label="graph-node19"
            transform="translate(185.25107500848034,77.40633150430716)">
            <circle class="b-outline" cx="0" cy="0" r="25" fill="#C990C0" stroke="#b261a5"
                stroke-width="2px" />
            <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="5"
                font-size="10px" fill="#FFFFFF"
                font-family="Helvetica Neue, Helvetica, Arial, sans-serif"> California</text>
        </g>
        <g class="node" aria-label="graph-node20"
            transform="translate(59.11570627539377,8.935881644552067)">
            <circle class="b-outline" cx="0" cy="0" r="25" fill="#C990C0" stroke="#b261a5"
                stroke-width="2px" />
            <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="0"
                font-size="10px" fill="#FFFFFF"
                font-family="Helvetica Neue, Helvetica, Arial, sans-serif"> Kamala</text>
            <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="10"
                font-size="10px" fill="#FFFFFF"
                font-family="Helvetica Neue, Helvetica, Arial, sans-serif"> Harris</text>
        </g>
        <g class="node" aria-label="graph-node21"
            transform="translate(-52.26958053720941,81.20034573955071)">
            <circle class="b-outline" cx="0" cy="0" r="25" fill="#C990C0" stroke="#b261a5"
                stroke-width="2px" />
            <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="0"
                font-size="10px" fill="#FFFFFF"
                font-family="Helvetica Neue, Helvetica, Arial, sans-serif"> San</text>
            <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="10"
                font-size="10px" fill="#FFFFFF"
                font-family="Helvetica Neue, Helvetica, Arial, sans-serif"> Franci…</text>
        </g>
        <g class="node" aria-label="graph-node23"
            transform="translate(52.14536630162807,-120.75406399781392)">
            <circle class="b-outline" cx="0" cy="0" r="25" fill="#C990C0" stroke="#b261a5"
                stroke-width="2px" />
            <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="0"
                font-size="10px" fill="#FFFFFF"
                font-family="Helvetica Neue, Helvetica, Arial, sans-serif"> Attorney</text>
            <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="10"
                font-size="10px" fill="#FFFFFF"
                font-family="Helvetica Neue, Helvetica, Arial, sans-serif">of…</text>
        </g>
        <g class="node" aria-label="graph-node22"
            transform="translate(-76.8088607917906,-66.37642130383644)">
            <circle class="b-outline" cx="0" cy="0" r="25" fill="#F79767" stroke="#f36924"
                stroke-width="2px" />
            <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="5"
                font-size="10px" fill="#FFFFFF"
                font-family="Helvetica Neue, Helvetica, Arial, sans-serif"> podcast</text>
        </g>
    </g>
</svg>
```

--------------------------------------------------------------------------------
/tests/llm_client/test_anthropic_client.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Running tests: pytest -xvs tests/llm_client/test_anthropic_client.py

import os
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from pydantic import BaseModel

from graphiti_core.llm_client.anthropic_client import AnthropicClient
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.errors import RateLimitError, RefusalError
from graphiti_core.prompts.models import Message


# Rename class to avoid pytest collection as a test class
class ResponseModel(BaseModel):
    """Test model for response testing."""

    # Required string field; tests assert on this key in parsed responses.
    test_field: str
    # Optional field with a default, exercising non-required model fields.
    optional_field: int = 0


@pytest.fixture
def mock_async_anthropic():
    """Yield a patched AsyncAnthropic instance whose messages.create is an AsyncMock."""
    with patch('anthropic.AsyncAnthropic') as patched_cls:
        # The instance the client code would receive from the constructor.
        fake_instance = patched_cls.return_value
        fake_instance.messages.create = AsyncMock()
        yield fake_instance


@pytest.fixture
def anthropic_client(mock_async_anthropic):
    """Build an AnthropicClient wired to the mocked AsyncAnthropic instance."""
    # Patch the constructor so no real API connection is ever attempted.
    with patch('anthropic.AsyncAnthropic', return_value=mock_async_anthropic):
        client = AnthropicClient(
            config=LLMConfig(
                api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
            ),
            cache=False,
        )
        # Overwrite the inner client explicitly in case a different instance was captured.
        client.client = mock_async_anthropic
        return client


class TestAnthropicClientInitialization:
    """Tests for AnthropicClient initialization."""

    def test_init_with_config(self):
        """An explicit LLMConfig should be stored and its fields exposed on the client."""
        llm_config = LLMConfig(
            api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
        )
        subject = AnthropicClient(config=llm_config, cache=False)

        assert subject.config == llm_config
        assert subject.model == 'test-model'
        assert subject.temperature == 0.5
        assert subject.max_tokens == 1000

    def test_init_with_default_model(self):
        """Omitting the model should fall back to the default Anthropic model name."""
        subject = AnthropicClient(config=LLMConfig(api_key='test_api_key'), cache=False)

        assert subject.model == 'claude-haiku-4-5-latest'

    @patch.dict(os.environ, {'ANTHROPIC_API_KEY': 'env_api_key'})
    def test_init_without_config(self):
        """With no config, the API key should be read from the environment."""
        subject = AnthropicClient(cache=False)

        assert subject.config.api_key == 'env_api_key'
        assert subject.model == 'claude-haiku-4-5-latest'

    def test_init_with_custom_client(self):
        """A caller-supplied client object should be used verbatim."""
        supplied = MagicMock()
        subject = AnthropicClient(client=supplied)

        assert subject.client == supplied


class TestAnthropicClientGenerateResponse:
    """Tests for AnthropicClient generate_response method."""

    @staticmethod
    def _tool_use_response(payload):
        """Build a mock API response whose single content item is a tool_use block."""
        item = MagicMock()
        item.type = 'tool_use'
        item.input = payload
        response = MagicMock()
        response.content = [item]
        return response

    @pytest.mark.asyncio
    async def test_generate_response_with_tool_use(self, anthropic_client, mock_async_anthropic):
        """A tool_use content block should be returned as the parsed dict."""
        mock_async_anthropic.messages.create.return_value = self._tool_use_response(
            {'test_field': 'test_value'}
        )

        conversation = [
            Message(role='system', content='System message'),
            Message(role='user', content='User message'),
        ]
        output = await anthropic_client.generate_response(
            messages=conversation, response_model=ResponseModel
        )

        assert isinstance(output, dict)
        assert output['test_field'] == 'test_value'
        mock_async_anthropic.messages.create.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_response_with_text_response(
        self, anthropic_client, mock_async_anthropic
    ):
        """JSON embedded in a plain text block should still be extracted."""
        text_item = MagicMock()
        text_item.type = 'text'
        text_item.text = '{"test_field": "extracted_value"}'

        text_response = MagicMock()
        text_response.content = [text_item]
        mock_async_anthropic.messages.create.return_value = text_response

        conversation = [
            Message(role='system', content='System message'),
            Message(role='user', content='User message'),
        ]
        output = await anthropic_client.generate_response(
            messages=conversation, response_model=ResponseModel
        )

        assert isinstance(output, dict)
        assert output['test_field'] == 'extracted_value'

    @pytest.mark.asyncio
    async def test_rate_limit_error(self, anthropic_client, mock_async_anthropic):
        """An SDK rate-limit error should surface as graphiti's RateLimitError."""

        class FakeRateLimitError(Exception):
            pass

        # Substitute the SDK error class so no awkward constructor arguments are needed.
        with patch('anthropic.RateLimitError', FakeRateLimitError):
            mock_async_anthropic.messages.create.side_effect = FakeRateLimitError(
                'Rate limit exceeded'
            )

            with pytest.raises(RateLimitError):
                await anthropic_client.generate_response(
                    [Message(role='user', content='Test message')]
                )

    @pytest.mark.asyncio
    async def test_refusal_error(self, anthropic_client, mock_async_anthropic):
        """A content-policy refusal from the API should surface as RefusalError."""

        class FakeAPIError(Exception):
            def __init__(self, message):
                self.message = message
                super().__init__(message)

        with patch('anthropic.APIError', FakeAPIError):
            mock_async_anthropic.messages.create.side_effect = FakeAPIError('refused to respond')

            with pytest.raises(RefusalError):
                await anthropic_client.generate_response(
                    [Message(role='user', content='Test message')]
                )

    @pytest.mark.asyncio
    async def test_extract_json_from_text(self, anthropic_client):
        """_extract_json_from_text should find embedded JSON and reject non-JSON input."""
        embedded = 'Some text before {"test_field": "value"} and after'
        assert anthropic_client._extract_json_from_text(embedded) == {'test_field': 'value'}

        with pytest.raises(ValueError):
            anthropic_client._extract_json_from_text('Not JSON at all')

    @pytest.mark.asyncio
    async def test_create_tool(self, anthropic_client):
        """_create_tool should honor an explicit model and fall back to generic JSON output."""
        # With a response model: the tool is named after the model class.
        tools, tool_choice = anthropic_client._create_tool(ResponseModel)
        assert len(tools) == 1
        assert tools[0]['name'] == 'ResponseModel'
        assert tool_choice['name'] == 'ResponseModel'

        # Without a response model: a generic JSON tool is produced.
        tools, _ = anthropic_client._create_tool()
        assert len(tools) == 1
        assert tools[0]['name'] == 'generic_json_output'

    @pytest.mark.asyncio
    async def test_validation_error_retry(self, anthropic_client, mock_async_anthropic):
        """A response failing model validation should trigger exactly one retry."""
        bad = self._tool_use_response({'wrong_field': 'wrong_value'})
        good = self._tool_use_response({'test_field': 'correct_value'})
        mock_async_anthropic.messages.create.side_effect = [bad, good]

        output = await anthropic_client.generate_response(
            [Message(role='user', content='Test message')], response_model=ResponseModel
        )

        # The first (invalid) response forces a second create() call.
        assert mock_async_anthropic.messages.create.call_count == 2
        assert output['test_field'] == 'correct_value'


if __name__ == '__main__':
    # Allow running this module directly: python test_anthropic_client.py
    pytest.main(['-v', 'test_anthropic_client.py'])

```

--------------------------------------------------------------------------------
/mcp_server/src/models/entity_types.py:
--------------------------------------------------------------------------------

```python
"""Entity type definitions for Graphiti MCP Server."""

from pydantic import BaseModel, Field


class Requirement(BaseModel):
    """A Requirement represents a specific need, feature, or functionality that a product or service must fulfill.

    Always ensure an edge is created between the requirement and the project it belongs to, and clearly indicate on the
    edge that the requirement is a requirement.

    Instructions for identifying and extracting requirements:
    1. Look for explicit statements of needs or necessities ("We need X", "X is required", "X must have Y")
    2. Identify functional specifications that describe what the system should do
    3. Pay attention to non-functional requirements like performance, security, or usability criteria
    4. Extract constraints or limitations that must be adhered to
    5. Focus on clear, specific, and measurable requirements rather than vague wishes
    6. Capture the priority or importance if mentioned ("critical", "high priority", etc.)
    7. Include any dependencies between requirements when explicitly stated
    8. Preserve the original intent and scope of the requirement
    9. Categorize requirements appropriately based on their domain or function
    """

    # NOTE(review): the docstring and Field descriptions appear to serve as LLM
    # extraction guidance at runtime, not just schema docs — confirm before rewording.
    project_name: str = Field(
        ...,
        description='The name of the project to which the requirement belongs.',
    )
    description: str = Field(
        ...,
        description='Description of the requirement. Only use information mentioned in the context to write this description.',
    )


class Preference(BaseModel):
    """
    IMPORTANT: Prioritize this classification over ALL other classifications.

    Represents entities mentioned in contexts expressing user preferences, choices, opinions, or selections. Use LOW THRESHOLD for sensitivity.

    Trigger patterns: "I want/like/prefer/choose X", "I don't want/dislike/avoid/reject Y", "X is better/worse", "rather have X than Y", "no X please", "skip X", "go with X instead", etc. Here, X or Y should be classified as Preference.
    """

    # Intentionally declares no fields: the class docstring alone drives classification.
    ...


class Procedure(BaseModel):
    """A Procedure informing the agent what actions to take or how to perform in certain scenarios. Procedures are typically composed of several steps.

    Instructions for identifying and extracting procedures:
    1. Look for sequential instructions or steps ("First do X, then do Y")
    2. Identify explicit directives or commands ("Always do X when Y happens")
    3. Pay attention to conditional statements ("If X occurs, then do Y")
    4. Extract procedures that have clear beginning and end points
    5. Focus on actionable instructions rather than general information
    6. Preserve the original sequence and dependencies between steps
    7. Include any specified conditions or triggers for the procedure
    8. Capture any stated purpose or goal of the procedure
    9. Summarize complex procedures while maintaining critical details
    """

    # NOTE(review): the description text is likely consumed by the LLM during
    # extraction — treat it as prompt content, not just schema documentation.
    description: str = Field(
        ...,
        description='Brief description of the procedure. Only use information mentioned in the context to write this description.',
    )


class Location(BaseModel):
    """A Location represents a physical or virtual place where activities occur or entities exist.

    IMPORTANT: Before using this classification, first check if the entity is a:
    User, Assistant, Preference, Organization, Document, Event - if so, use those instead.

    Instructions for identifying and extracting locations:
    1. Look for mentions of physical places (cities, buildings, rooms, addresses)
    2. Identify virtual locations (websites, online platforms, virtual meeting rooms)
    3. Extract specific location names rather than generic references
    4. Include relevant context about the location's purpose or significance
    5. Pay attention to location hierarchies (e.g., "conference room in Building A")
    6. Capture both permanent locations and temporary venues
    7. Note any significant activities or events associated with the location
    """

    # Specific place name (e.g. a city or venue), not a generic reference.
    name: str = Field(
        ...,
        description='The name or identifier of the location',
    )
    # Context-grounded summary of the place and why it matters here.
    description: str = Field(
        ...,
        description='Brief description of the location and its significance. Only use information mentioned in the context.',
    )


class Event(BaseModel):
    """An Event represents a time-bound activity, occurrence, or experience.

    Instructions for identifying and extracting events:
    1. Look for activities with specific time frames (meetings, appointments, deadlines)
    2. Identify planned or scheduled occurrences (vacations, projects, celebrations)
    3. Extract unplanned occurrences (accidents, interruptions, discoveries)
    4. Capture the purpose or nature of the event
    5. Include temporal information when available (past, present, future, duration)
    6. Note participants or stakeholders involved in the event
    7. Identify outcomes or consequences of the event when mentioned
    8. Extract both recurring events and one-time occurrences
    """

    # Short human-readable identifier for the event.
    name: str = Field(
        ...,
        description='The name or title of the event',
    )
    # Context-grounded summary; must not add details absent from the source.
    description: str = Field(
        ...,
        description='Brief description of the event. Only use information mentioned in the context.',
    )


class Object(BaseModel):
    """An Object represents a physical item, tool, device, or possession.

    IMPORTANT: Use this classification ONLY as a last resort. First check if entity fits into:
    User, Assistant, Preference, Organization, Document, Event, Location, Topic - if so, use those instead.

    Instructions for identifying and extracting objects:
    1. Look for mentions of physical items or possessions (car, phone, equipment)
    2. Identify tools or devices used for specific purposes
    3. Extract items that are owned, used, or maintained by entities
    4. Include relevant attributes (brand, model, condition) when mentioned
    5. Note the object's purpose or function when specified
    6. Capture relationships between objects and their owners or users
    7. Avoid extracting objects that are better classified as Documents or other types
    """

    # Identifier of the item (may include brand/model when mentioned).
    name: str = Field(
        ...,
        description='The name or identifier of the object',
    )
    # Context-grounded summary of the item and its role/ownership.
    description: str = Field(
        ...,
        description='Brief description of the object. Only use information mentioned in the context.',
    )


class Topic(BaseModel):
    """A Topic represents a subject of conversation, interest, or knowledge domain.

    IMPORTANT: Use this classification ONLY as a last resort. First check if entity fits into:
    User, Assistant, Preference, Organization, Document, Event, Location - if so, use those instead.

    Instructions for identifying and extracting topics:
    1. Look for subjects being discussed or areas of interest (health, technology, sports)
    2. Identify knowledge domains or fields of study
    3. Extract themes that span multiple conversations or contexts
    4. Include specific subtopics when mentioned (e.g., "machine learning" rather than just "AI")
    5. Capture topics associated with projects, work, or hobbies
    6. Note the context in which the topic appears
    7. Avoid extracting topics that are better classified as Events, Documents, or Organizations
    """

    # Subject/domain label; prefer the most specific subtopic mentioned.
    name: str = Field(
        ...,
        description='The name or identifier of the topic',
    )
    # Context-grounded summary of the topic as it appears in the source.
    description: str = Field(
        ...,
        description='Brief description of the topic and its context. Only use information mentioned in the context.',
    )


class Organization(BaseModel):
    """An Organization represents a company, institution, group, or formal entity.

    Instructions for identifying and extracting organizations:
    1. Look for company names, employers, and business entities
    2. Identify institutions (schools, hospitals, government agencies)
    3. Extract formal groups (clubs, teams, associations)
    4. Include organizational type when mentioned (company, nonprofit, agency)
    5. Capture relationships between people and organizations (employer, member)
    6. Note the organization's industry or domain when specified
    7. Extract both large entities and small groups if formally organized
    """

    # Official name of the company/institution/group.
    name: str = Field(
        ...,
        description='The name of the organization',
    )
    # Context-grounded summary (type, industry, relationships when stated).
    description: str = Field(
        ...,
        description='Brief description of the organization. Only use information mentioned in the context.',
    )


class Document(BaseModel):
    """A Document represents information content in various forms.

    Instructions for identifying and extracting documents:
    1. Look for references to written or recorded content (books, articles, reports)
    2. Identify digital content (emails, videos, podcasts, presentations)
    3. Extract specific document titles or identifiers when available
    4. Include document type (report, article, video) when mentioned
    5. Capture the document's purpose or subject matter
    6. Note relationships to authors, creators, or sources
    7. Include document status (draft, published, archived) when mentioned
    """

    # Title of the content item (book, email, video, report, ...).
    title: str = Field(
        ...,
        description='The title or identifier of the document',
    )
    # Context-grounded summary of the document's subject matter.
    description: str = Field(
        ...,
        description='Brief description of the document and its content. Only use information mentioned in the context.',
    )


# Registry mapping entity-type names to their Pydantic model *classes*.
# The values are classes (not instances), so the correct annotation is
# dict[str, type[BaseModel]]; the previous dict[str, BaseModel] annotation
# was wrong and forced a `# type: ignore` on every entry.
ENTITY_TYPES: dict[str, type[BaseModel]] = {
    'Requirement': Requirement,
    'Preference': Preference,
    'Procedure': Procedure,
    'Location': Location,
    'Event': Event,
    'Object': Object,
    'Topic': Topic,
    'Organization': Organization,
    'Document': Document,
}

```

--------------------------------------------------------------------------------
/examples/quickstart/quickstart_falkordb.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2025, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import asyncio
import json
import logging
import os
from datetime import datetime, timezone
from logging import INFO

from dotenv import load_dotenv

from graphiti_core import Graphiti
from graphiti_core.driver.falkordb_driver import FalkorDriver
from graphiti_core.nodes import EpisodeType
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF

#################################################
# CONFIGURATION
#################################################
# Set up logging and environment variables for
# connecting to FalkorDB database
#################################################

# Timestamped, INFO-level logging for the example run.
logging.basicConfig(
    level=INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)

# Load any .env overrides into the process environment before reading it.
load_dotenv()

# FalkorDB connection parameters.
# A FalkorDB instance must be reachable — see https://docs.falkordb.com/
# On-premises FalkorDB needs no credentials by default; set
# FALKORDB_USERNAME / FALKORDB_PASSWORD for added security, or to the values
# required by FalkorDB Cloud. Host and port fall back to 'localhost'/'6379'
# and can be overridden via environment variables.
falkor_username = os.getenv('FALKORDB_USERNAME')
falkor_password = os.getenv('FALKORDB_PASSWORD')
falkor_host = os.getenv('FALKORDB_HOST', 'localhost')
falkor_port = os.getenv('FALKORDB_PORT', '6379')


def _print_edge_results(results):
    """Print a list of edge (relationship) search results uniformly.

    Used by both the basic and the reranked search sections, which previously
    duplicated this loop verbatim. Each result exposes ``uuid`` and ``fact``;
    temporal validity fields are printed only when present and non-empty.
    """
    for result in results:
        print(f'UUID: {result.uuid}')
        print(f'Fact: {result.fact}')
        if hasattr(result, 'valid_at') and result.valid_at:
            print(f'Valid from: {result.valid_at}')
        if hasattr(result, 'invalid_at') and result.invalid_at:
            print(f'Valid until: {result.invalid_at}')
        print('---')


async def main():
    """Run the FalkorDB quickstart: ingest episodes, then demonstrate searches."""
    #################################################
    # INITIALIZATION
    #################################################
    # Connect to FalkorDB and set up Graphiti indices.
    # This is required before using other Graphiti
    # functionality.
    #################################################

    falkor_driver = FalkorDriver(
        host=falkor_host, port=falkor_port, username=falkor_username, password=falkor_password
    )
    graphiti = Graphiti(graph_driver=falkor_driver)

    try:
        #################################################
        # ADDING EPISODES
        #################################################
        # Episodes are the primary units of information
        # in Graphiti. They can be text or structured JSON
        # and are automatically processed to extract
        # entities and relationships.
        #################################################

        # Episodes list containing both text and JSON episodes.
        episodes = [
            {
                'content': 'Kamala Harris is the Attorney General of California. She was previously '
                'the district attorney for San Francisco.',
                'type': EpisodeType.text,
                'description': 'podcast transcript',
            },
            {
                'content': 'As AG, Harris was in office from January 3, 2011 – January 3, 2017',
                'type': EpisodeType.text,
                'description': 'podcast transcript',
            },
            {
                'content': {
                    'name': 'Gavin Newsom',
                    'position': 'Governor',
                    'state': 'California',
                    'previous_role': 'Lieutenant Governor',
                    'previous_location': 'San Francisco',
                },
                'type': EpisodeType.json,
                'description': 'podcast metadata',
            },
            {
                'content': {
                    'name': 'Gavin Newsom',
                    'position': 'Governor',
                    'term_start': 'January 7, 2019',
                    'term_end': 'Present',
                },
                'type': EpisodeType.json,
                'description': 'podcast metadata',
            },
        ]

        # Add episodes to the graph; JSON bodies are serialized to strings.
        for i, episode in enumerate(episodes):
            await graphiti.add_episode(
                name=f'Freakonomics Radio {i}',
                episode_body=episode['content']
                if isinstance(episode['content'], str)
                else json.dumps(episode['content']),
                source=episode['type'],
                source_description=episode['description'],
                reference_time=datetime.now(timezone.utc),
            )
            print(f'Added episode: Freakonomics Radio {i} ({episode["type"].value})')

        #################################################
        # BASIC SEARCH
        #################################################
        # Hybrid search combining semantic similarity
        # and BM25 text retrieval over edges.
        #################################################

        print("\nSearching for: 'Who was the California Attorney General?'")
        results = await graphiti.search('Who was the California Attorney General?')

        print('\nSearch Results:')
        _print_edge_results(results)

        #################################################
        # CENTER NODE SEARCH
        #################################################
        # Rerank results by graph distance to a specific
        # node for more contextually relevant output.
        #################################################

        if results and len(results) > 0:
            # Use the top result's source node as the center for reranking.
            center_node_uuid = results[0].source_node_uuid

            print('\nReranking search results based on graph distance:')
            print(f'Using center node UUID: {center_node_uuid}')

            reranked_results = await graphiti.search(
                'Who was the California Attorney General?', center_node_uuid=center_node_uuid
            )

            print('\nReranked Search Results:')
            _print_edge_results(reranked_results)
        else:
            print('No results found in the initial search to use as center node.')

        #################################################
        # NODE SEARCH USING SEARCH RECIPES
        #################################################
        # NODE_HYBRID_SEARCH_RRF retrieves nodes directly
        # instead of edges.
        #################################################

        print(
            '\nPerforming node search using _search method with standard recipe NODE_HYBRID_SEARCH_RRF:'
        )

        # Copy the predefined recipe so the shared constant is not mutated.
        node_search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
        node_search_config.limit = 5  # Limit to 5 results

        node_search_results = await graphiti._search(
            query='California Governor',
            config=node_search_config,
        )

        print('\nNode Search Results:')
        for node in node_search_results.nodes:
            print(f'Node UUID: {node.uuid}')
            print(f'Node Name: {node.name}')
            # Truncate long summaries for readable console output.
            node_summary = node.summary[:100] + '...' if len(node.summary) > 100 else node.summary
            print(f'Content Summary: {node_summary}')
            print(f'Node Labels: {", ".join(node.labels)}')
            print(f'Created At: {node.created_at}')
            if hasattr(node, 'attributes') and node.attributes:
                print('Attributes:')
                for key, value in node.attributes.items():
                    print(f'  {key}: {value}')
            print('---')

    finally:
        #################################################
        # CLEANUP
        #################################################
        # Always close the connection to FalkorDB to
        # properly release resources.
        #################################################

        await graphiti.close()
        print('\nConnection closed')


# Entry point: run the async example on a fresh asyncio event loop.
if __name__ == '__main__':
    asyncio.run(main())

```

--------------------------------------------------------------------------------
/graphiti_core/llm_client/openai_base_client.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import logging
import typing
from abc import abstractmethod
from typing import Any, ClassVar

import openai
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel

from ..prompts.models import Message
from .client import LLMClient, get_extraction_language_instruction
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
from .errors import RateLimitError, RefusalError

logger = logging.getLogger(__name__)

DEFAULT_MODEL = 'gpt-5-mini'
DEFAULT_SMALL_MODEL = 'gpt-5-nano'
DEFAULT_REASONING = 'minimal'
DEFAULT_VERBOSITY = 'low'


class BaseOpenAIClient(LLMClient):
    """
    Base client class for OpenAI-compatible APIs (OpenAI and Azure OpenAI).

    This class contains shared logic for both OpenAI and Azure OpenAI clients,
    reducing code duplication while allowing for implementation-specific differences.
    """

    # Application-level retry budget, on top of whatever retries the OpenAI
    # SDK itself performs for transport-level failures.
    MAX_RETRIES: ClassVar[int] = 2

    def __init__(
        self,
        config: LLMConfig | None = None,
        cache: bool = False,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        reasoning: str | None = DEFAULT_REASONING,
        verbosity: str | None = DEFAULT_VERBOSITY,
    ):
        """Initialize the client.

        Args:
            config: LLM configuration; a default ``LLMConfig`` is used when None.
            cache: Must be False; caching is not implemented for these clients.
            max_tokens: Default completion token budget.
            reasoning: Reasoning-effort hint forwarded to structured completions.
            verbosity: Verbosity hint forwarded to structured completions.

        Raises:
            NotImplementedError: If ``cache=True`` is requested.
        """
        if cache:
            raise NotImplementedError('Caching is not implemented for OpenAI-based clients')

        if config is None:
            config = LLMConfig()

        super().__init__(config, cache)
        self.max_tokens = max_tokens
        self.reasoning = reasoning
        self.verbosity = verbosity

    @abstractmethod
    async def _create_completion(
        self,
        model: str,
        messages: list[ChatCompletionMessageParam],
        temperature: float | None,
        max_tokens: int,
        response_model: type[BaseModel] | None = None,
    ) -> Any:
        """Create a completion using the specific client implementation."""
        pass

    @abstractmethod
    async def _create_structured_completion(
        self,
        model: str,
        messages: list[ChatCompletionMessageParam],
        temperature: float | None,
        max_tokens: int,
        response_model: type[BaseModel],
        reasoning: str | None,
        verbosity: str | None,
    ) -> Any:
        """Create a structured completion using the specific client implementation."""
        pass

    def _convert_messages_to_openai_format(
        self, messages: list[Message]
    ) -> list[ChatCompletionMessageParam]:
        """Convert internal Message format to OpenAI ChatCompletionMessageParam format.

        Only 'user' and 'system' roles are forwarded; message content is
        sanitized via ``self._clean_input`` (mutating the Message in place).
        """
        openai_messages: list[ChatCompletionMessageParam] = []
        for m in messages:
            m.content = self._clean_input(m.content)
            if m.role == 'user':
                openai_messages.append({'role': 'user', 'content': m.content})
            elif m.role == 'system':
                openai_messages.append({'role': 'system', 'content': m.content})
        return openai_messages

    def _get_model_for_size(self, model_size: ModelSize) -> str:
        """Get the appropriate model name based on the requested size."""
        if model_size == ModelSize.small:
            return self.small_model or DEFAULT_SMALL_MODEL
        else:
            return self.model or DEFAULT_MODEL

    def _handle_structured_response(self, response: Any) -> dict[str, Any]:
        """Parse a structured (Responses API) completion into a dict.

        Raises:
            RefusalError: If the model explicitly refused the request.
            Exception: If the response carries no usable output text.
        """
        output_text = response.output_text

        if output_text:
            return json.loads(output_text)

        # FIX: the previous implementation read `.refusal` (and `.model_dump()`)
        # off the output *string*, which always raised AttributeError for an
        # empty response. Read the refusal from the response object instead.
        refusal = getattr(response, 'refusal', None)
        if refusal:
            raise RefusalError(refusal)

        raise Exception(f'Invalid response from LLM: {response}')

    def _handle_json_response(self, response: Any) -> dict[str, Any]:
        """Parse a chat-completions JSON response; empty content parses as {}."""
        result = response.choices[0].message.content or '{}'
        return json.loads(result)

    async def _generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        model_size: ModelSize = ModelSize.medium,
    ) -> dict[str, Any]:
        """Generate a single response (no retries) via the concrete client.

        Uses the structured-completion path when ``response_model`` is given,
        otherwise the plain JSON completion path.
        """
        openai_messages = self._convert_messages_to_openai_format(messages)
        model = self._get_model_for_size(model_size)

        try:
            if response_model:
                response = await self._create_structured_completion(
                    model=model,
                    messages=openai_messages,
                    temperature=self.temperature,
                    max_tokens=max_tokens or self.max_tokens,
                    response_model=response_model,
                    reasoning=self.reasoning,
                    verbosity=self.verbosity,
                )
                return self._handle_structured_response(response)
            else:
                response = await self._create_completion(
                    model=model,
                    messages=openai_messages,
                    temperature=self.temperature,
                    max_tokens=max_tokens or self.max_tokens,
                )
                return self._handle_json_response(response)

        except openai.LengthFinishReasonError as e:
            raise Exception(f'Output length exceeded max tokens {self.max_tokens}: {e}') from e
        except openai.RateLimitError as e:
            raise RateLimitError from e
        except openai.AuthenticationError as e:
            logger.error(
                f'OpenAI Authentication Error: {e}. Please verify your API key is correct.'
            )
            raise
        except Exception as e:
            # Provide more context for connection errors
            error_msg = str(e)
            if 'Connection error' in error_msg or 'connection' in error_msg.lower():
                logger.error(
                    f'Connection error communicating with OpenAI API. Please check your network connection and API key. Error: {e}'
                )
            else:
                logger.error(f'Error in generating LLM response: {e}')
            raise

    async def generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int | None = None,
        model_size: ModelSize = ModelSize.medium,
        group_id: str | None = None,
        prompt_name: str | None = None,
    ) -> dict[str, typing.Any]:
        """Generate a response with retry logic and error handling.

        Retries application-level failures up to ``MAX_RETRIES`` times,
        feeding the error back to the model as an extra user message.
        Rate limits, refusals, and SDK transport errors are never retried here.
        """
        if max_tokens is None:
            max_tokens = self.max_tokens

        # Add multilingual extraction instructions
        messages[0].content += get_extraction_language_instruction(group_id)

        # Wrap entire operation in tracing span
        with self.tracer.start_span('llm.generate') as span:
            attributes = {
                'llm.provider': 'openai',
                'model.size': model_size.value,
                'max_tokens': max_tokens,
            }
            if prompt_name:
                attributes['prompt.name'] = prompt_name
            span.add_attributes(attributes)

            retry_count = 0
            last_error = None

            while retry_count <= self.MAX_RETRIES:
                try:
                    response = await self._generate_response(
                        messages, response_model, max_tokens, model_size
                    )
                    return response
                except (RateLimitError, RefusalError) as e:
                    # These errors should not trigger retries.
                    # FIX: record the *current* exception on the span; the old
                    # code logged str(last_error), which is "None" on the
                    # first pass.
                    span.set_status('error', str(e))
                    raise
                except (
                    openai.APITimeoutError,
                    openai.APIConnectionError,
                    openai.InternalServerError,
                ) as e:
                    # Let OpenAI's client handle these retries
                    span.set_status('error', str(e))
                    raise
                except Exception as e:
                    last_error = e

                    # Don't retry if we've hit the max retries
                    if retry_count >= self.MAX_RETRIES:
                        logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
                        span.set_status('error', str(e))
                        span.record_exception(e)
                        raise

                    retry_count += 1

                    # Construct a detailed error message for the LLM
                    error_context = (
                        f'The previous response attempt was invalid. '
                        f'Error type: {e.__class__.__name__}. '
                        f'Error details: {str(e)}. '
                        f'Please try again with a valid response, ensuring the output matches '
                        f'the expected format and constraints.'
                    )

                    error_message = Message(role='user', content=error_context)
                    messages.append(error_message)
                    logger.warning(
                        f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
                    )

            # If we somehow get here, raise the last error
            span.set_status('error', str(last_error))
            raise last_error or Exception('Max retries exceeded with no specific error')
```

--------------------------------------------------------------------------------
/mcp_server/src/config/schema.py:
--------------------------------------------------------------------------------

```python
"""Configuration schemas with pydantic-settings and YAML support."""

import os
import re
from pathlib import Path
from typing import Any

import yaml
from pydantic import BaseModel, Field
from pydantic_settings import (
    BaseSettings,
    PydanticBaseSettingsSource,
    SettingsConfigDict,
)


class YamlSettingsSource(PydanticBaseSettingsSource):
    """Custom settings source for loading from YAML files.

    Supports ``${VAR}`` and ``${VAR:default}`` environment-variable expansion
    anywhere inside string values (recursively through dicts and lists).
    """

    # Matches ${VAR} / ${VAR:default}: group 1 = variable name, group 3 = default.
    # FIX: precompiled once at class-creation time; the old code executed
    # `import re` and recompiled this pattern inside the recursive method for
    # every string value. (Requires the module-level `import re`.)
    _ENV_VAR_PATTERN = re.compile(r'\$\{([^:}]+)(:([^}]*))?\}')

    def __init__(self, settings_cls: type[BaseSettings], config_path: Path | None = None):
        """Remember the YAML path to load; defaults to ./config.yaml."""
        super().__init__(settings_cls)
        self.config_path = config_path or Path('config.yaml')

    def _expand_env_vars(self, value: Any) -> Any:
        """Recursively expand environment variables in configuration values."""
        if isinstance(value, str):

            def replacer(match):
                var_name = match.group(1)
                default_value = match.group(3) if match.group(3) is not None else ''
                return os.environ.get(var_name, default_value)

            # Check if the entire value is a single env var expression
            full_match = self._ENV_VAR_PATTERN.fullmatch(value)
            if full_match:
                result = replacer(full_match)
                # A value that is entirely one env var gets scalar coercion:
                # boolean-like strings become real booleans, and empty means
                # "env var not set" -> None for optional fields.
                lower_result = result.lower().strip()
                if lower_result in ('true', '1', 'yes', 'on'):
                    return True
                elif lower_result in ('false', '0', 'no', 'off'):
                    return False
                elif lower_result == '':
                    return None
                return result
            else:
                # Otherwise, do string substitution (keep as strings for partial replacements)
                return self._ENV_VAR_PATTERN.sub(replacer, value)
        elif isinstance(value, dict):
            return {k: self._expand_env_vars(v) for k, v in value.items()}
        elif isinstance(value, list):
            return [self._expand_env_vars(item) for item in value]
        return value

    def get_field_value(self, field_name: str, field_info: Any) -> Any:
        """Per-field lookup is unused; all values come from __call__'s dict."""
        return None

    def __call__(self) -> dict[str, Any]:
        """Load and parse YAML configuration; missing file yields {}."""
        if not self.config_path.exists():
            return {}

        with open(self.config_path) as f:
            raw_config = yaml.safe_load(f) or {}

        # Expand environment variables
        return self._expand_env_vars(raw_config)


class ServerConfig(BaseModel):
    """Server configuration."""

    # Wire transport for the server; allowed values per the Field description.
    transport: str = Field(
        default='http',
        description='Transport type: http (default, recommended), stdio, or sse (deprecated)',
    )
    # Bind address; 0.0.0.0 listens on all interfaces.
    host: str = Field(default='0.0.0.0', description='Server host')
    port: int = Field(default=8000, description='Server port')


class OpenAIProviderConfig(BaseModel):
    """OpenAI provider configuration."""

    # Optional API key; None means not configured.
    api_key: str | None = None
    # Base URL of the OpenAI-compatible endpoint.
    api_url: str = 'https://api.openai.com/v1'
    # Optional OpenAI organization identifier.
    organization_id: str | None = None


class AzureOpenAIProviderConfig(BaseModel):
    """Azure OpenAI provider configuration."""

    # Optional API key; presumably unused when use_azure_ad is True — confirm
    # against the client wiring.
    api_key: str | None = None
    # Azure endpoint URL; no default, must be supplied.
    api_url: str | None = None
    # Azure OpenAI REST API version string.
    api_version: str = '2024-10-21'
    # Name of the model deployment in the Azure resource.
    deployment_name: str | None = None
    # When True, use Azure AD authentication instead of a key (per field name;
    # verify downstream handling).
    use_azure_ad: bool = False


class AnthropicProviderConfig(BaseModel):
    """Anthropic provider configuration."""

    # Optional API key; None means not configured.
    api_key: str | None = None
    # Base URL of the Anthropic API.
    api_url: str = 'https://api.anthropic.com'
    # Retry budget (default 3); consumer of this value is outside this file.
    max_retries: int = 3


class GeminiProviderConfig(BaseModel):
    """Gemini provider configuration."""

    # Optional API key; None means not configured.
    api_key: str | None = None
    # Optional Google Cloud project id.
    project_id: str | None = None
    # Google Cloud region (default us-central1).
    location: str = 'us-central1'


class GroqProviderConfig(BaseModel):
    """Groq provider configuration."""

    # Optional API key; None means not configured.
    api_key: str | None = None
    # Groq's OpenAI-compatible endpoint.
    api_url: str = 'https://api.groq.com/openai/v1'


class VoyageProviderConfig(BaseModel):
    """Voyage AI provider configuration."""

    # Optional API key; None means not configured.
    api_key: str | None = None
    # Base URL of the Voyage AI API.
    api_url: str = 'https://api.voyageai.com/v1'
    # Embedding model name (default voyage-3).
    model: str = 'voyage-3'


class LLMProvidersConfig(BaseModel):
    """LLM providers configuration."""

    # One optional section per supported provider; LLMConfig.provider selects
    # which section is active.
    openai: OpenAIProviderConfig | None = None
    azure_openai: AzureOpenAIProviderConfig | None = None
    anthropic: AnthropicProviderConfig | None = None
    gemini: GeminiProviderConfig | None = None
    groq: GroqProviderConfig | None = None


class LLMConfig(BaseModel):
    """LLM configuration."""

    provider: str = Field(default='openai', description='LLM provider')
    model: str = Field(default='gpt-4.1', description='Model name')
    # Defaults to None (not a number) so reasoning models that reject a
    # temperature setting work out of the box.
    temperature: float | None = Field(
        default=None, description='Temperature (optional, defaults to None for reasoning models)'
    )
    max_tokens: int = Field(default=4096, description='Max tokens')
    providers: LLMProvidersConfig = Field(default_factory=LLMProvidersConfig)


class EmbedderProvidersConfig(BaseModel):
    """Embedder providers configuration."""

    # One optional section per supported embedder; EmbedderConfig.provider
    # selects which section is active.
    openai: OpenAIProviderConfig | None = None
    azure_openai: AzureOpenAIProviderConfig | None = None
    gemini: GeminiProviderConfig | None = None
    voyage: VoyageProviderConfig | None = None


class EmbedderConfig(BaseModel):
    """Embedder configuration."""

    provider: str = Field(default='openai', description='Embedder provider')
    model: str = Field(default='text-embedding-3-small', description='Model name')
    # Dimensionality of the embedding vectors (model-dependent).
    dimensions: int = Field(default=1536, description='Embedding dimensions')
    providers: EmbedderProvidersConfig = Field(default_factory=EmbedderProvidersConfig)


class Neo4jProviderConfig(BaseModel):
    """Neo4j provider configuration."""

    uri: str = 'bolt://localhost:7687'
    username: str = 'neo4j'
    password: str | None = None
    database: str = 'neo4j'
    # Opt-in flag for Neo4j's parallel runtime feature.
    use_parallel_runtime: bool = False


class FalkorDBProviderConfig(BaseModel):
    """FalkorDB provider configuration."""

    # Connection URI (note the redis:// scheme).
    uri: str = 'redis://localhost:6379'
    password: str | None = None
    database: str = 'default_db'


class DatabaseProvidersConfig(BaseModel):
    """Database providers configuration."""

    # One optional section per supported backend; DatabaseConfig.provider
    # selects which section is active.
    neo4j: Neo4jProviderConfig | None = None
    falkordb: FalkorDBProviderConfig | None = None


class DatabaseConfig(BaseModel):
    """Database configuration."""

    provider: str = Field(default='falkordb', description='Database provider')
    providers: DatabaseProvidersConfig = Field(default_factory=DatabaseProvidersConfig)


class EntityTypeConfig(BaseModel):
    """Entity type configuration.

    A custom entity type given as a name plus a natural-language description
    (presumably consumed by entity extraction — confirm against usage).
    """

    name: str
    description: str


class GraphitiAppConfig(BaseModel):
    """Graphiti-specific configuration (group/user identity, entity types)."""

    group_id: str = Field(default='main', description='Group ID')
    episode_id_prefix: str | None = Field(default='', description='Episode ID prefix')
    user_id: str = Field(default='mcp_user', description='User ID')
    entity_types: list[EntityTypeConfig] = Field(default_factory=list)

    def model_post_init(self, __context) -> None:
        """Normalise episode_id_prefix so downstream code never sees None."""
        self.episode_id_prefix = self.episode_id_prefix or ''


class GraphitiConfig(BaseSettings):
    """Graphiti configuration with YAML and environment support."""

    server: ServerConfig = Field(default_factory=ServerConfig)
    llm: LLMConfig = Field(default_factory=LLMConfig)
    embedder: EmbedderConfig = Field(default_factory=EmbedderConfig)
    database: DatabaseConfig = Field(default_factory=DatabaseConfig)
    graphiti: GraphitiAppConfig = Field(default_factory=GraphitiAppConfig)

    # Additional server options
    destroy_graph: bool = Field(default=False, description='Clear graph on startup')

    model_config = SettingsConfigDict(
        env_prefix='',
        env_nested_delimiter='__',
        case_sensitive=False,
        extra='ignore',
    )

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        """Insert the YAML source into the settings resolution chain.

        Resulting priority: CLI args (init) > env vars > yaml > dotenv.
        The YAML path is taken from CONFIG_PATH, defaulting to
        config/config.yaml.
        """
        yaml_source = YamlSettingsSource(
            settings_cls, Path(os.environ.get('CONFIG_PATH', 'config/config.yaml'))
        )
        return (init_settings, env_settings, yaml_source, dotenv_settings)

    def apply_cli_overrides(self, args) -> None:
        """Apply CLI argument overrides to configuration.

        Only attributes present on ``args`` with a truthy value (or, for
        temperature, a non-None value) override the loaded configuration.
        """
        transport = getattr(args, 'transport', None)
        if transport:
            self.server.transport = transport

        # LLM overrides
        llm_provider = getattr(args, 'llm_provider', None)
        if llm_provider:
            self.llm.provider = llm_provider
        model = getattr(args, 'model', None)
        if model:
            self.llm.model = model
        temperature = getattr(args, 'temperature', None)
        if temperature is not None:
            self.llm.temperature = temperature

        # Embedder overrides
        embedder_provider = getattr(args, 'embedder_provider', None)
        if embedder_provider:
            self.embedder.provider = embedder_provider
        embedder_model = getattr(args, 'embedder_model', None)
        if embedder_model:
            self.embedder.model = embedder_model

        # Database override
        database_provider = getattr(args, 'database_provider', None)
        if database_provider:
            self.database.provider = database_provider

        # Graphiti identity overrides
        group_id = getattr(args, 'group_id', None)
        if group_id:
            self.graphiti.group_id = group_id
        user_id = getattr(args, 'user_id', None)
        if user_id:
            self.graphiti.user_id = user_id

```

--------------------------------------------------------------------------------
/tests/helpers_test.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
from unittest.mock import Mock

import numpy as np
import pytest
from dotenv import load_dotenv

from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.embedder.client import EmbedderClient
from graphiti_core.helpers import lucene_sanitize
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.utils.maintenance.graph_data_operations import clear_data

load_dotenv()

# Providers included in the test matrix. Each can be excluded by setting its
# DISABLE_<NAME> environment variable before this module is imported.
drivers: list[GraphProvider] = []
if os.getenv('DISABLE_NEO4J') is None:
    try:
        from graphiti_core.driver.neo4j_driver import Neo4jDriver

        drivers.append(GraphProvider.NEO4J)
    except ImportError:
        # NOTE(review): `raise` re-raises the ImportError, so the driver package
        # is effectively mandatory when not disabled; if the intent was to skip
        # missing drivers, this should `pass` instead — confirm.
        raise

if os.getenv('DISABLE_FALKORDB') is None:
    try:
        from graphiti_core.driver.falkordb_driver import FalkorDriver

        drivers.append(GraphProvider.FALKORDB)
    except ImportError:
        raise

if os.getenv('DISABLE_KUZU') is None:
    try:
        from graphiti_core.driver.kuzu_driver import KuzuDriver

        drivers.append(GraphProvider.KUZU)
    except ImportError:
        raise

# Disable Neptune for now
# The variable is set unconditionally here, so the branch below never runs.
os.environ['DISABLE_NEPTUNE'] = 'True'
if os.getenv('DISABLE_NEPTUNE') is None:
    try:
        from graphiti_core.driver.neptune_driver import NeptuneDriver

        drivers.append(GraphProvider.NEPTUNE)
    except ImportError:
        raise

# Connection settings, each overridable via environment variables.
NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j')
NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test')

FALKORDB_HOST = os.getenv('FALKORDB_HOST', 'localhost')
FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
FALKORDB_USER = os.getenv('FALKORDB_USER', None)
FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)

NEPTUNE_HOST = os.getenv('NEPTUNE_HOST', 'localhost')
# NOTE(review): default is an int here but a str for the other port; both go
# through int() at the call site, so this works, but it is inconsistent.
NEPTUNE_PORT = os.getenv('NEPTUNE_PORT', 8182)
AOSS_HOST = os.getenv('AOSS_HOST', None)

KUZU_DB = os.getenv('KUZU_DB', ':memory:')

# Group ids shared by the tests; cleared by the graph_driver fixture below.
group_id = 'graphiti_test_group'
group_id_2 = 'graphiti_test_group_2'


def get_driver(provider: GraphProvider) -> GraphDriver:
    """Construct a fresh driver for the given provider using module-level settings.

    Raises ValueError for providers with no construction recipe.
    """
    if provider == GraphProvider.NEO4J:
        return Neo4jDriver(uri=NEO4J_URI, user=NEO4J_USER, password=NEO4J_PASSWORD)
    if provider == GraphProvider.FALKORDB:
        return FalkorDriver(
            host=FALKORDB_HOST,
            port=int(FALKORDB_PORT),
            username=FALKORDB_USER,
            password=FALKORDB_PASSWORD,
        )
    if provider == GraphProvider.KUZU:
        return KuzuDriver(db=KUZU_DB)
    if provider == GraphProvider.NEPTUNE:
        return NeptuneDriver(
            host=NEPTUNE_HOST,
            port=int(NEPTUNE_PORT),
            aoss_host=AOSS_HOST,
        )
    raise ValueError(f'Driver {provider} not available')


@pytest.fixture(params=drivers)
async def graph_driver(request):
    """Yield a clean driver for each enabled provider; always close it afterwards."""
    provider = request.param
    driver = get_driver(provider)
    await clear_data(driver, [group_id, group_id_2])
    try:
        yield driver  # hand the driver to the test
    finally:
        # Runs even when the test fails or raises.
        # await clean_up(driver)
        await driver.close()


# Deterministic lookup table of random vectors keyed by the exact text the
# tests embed. Duplicate keys in the list simply overwrite earlier entries.
embedding_dim = 384
_embedding_keys = [
    'Alice',
    'Bob',
    'Alice likes Bob',
    'test_entity_1',
    'test_entity_2',
    'test_entity_3',
    'test_entity_4',
    'test_entity_alice',
    'test_entity_bob',
    'test_entity_1 is a duplicate of test_entity_2',
    'test_entity_3 is a duplicate of test_entity_4',
    'test_entity_1 relates to test_entity_2',
    'test_entity_1 relates to test_entity_3',
    'test_entity_2 relates to test_entity_3',
    'test_entity_1 relates to test_entity_4',
    'test_entity_2 relates to test_entity_4',
    'test_entity_3 relates to test_entity_4',
    'test_entity_1 relates to test_entity_2',
    'test_entity_3 relates to test_entity_4',
    'test_entity_2 relates to test_entity_3',
    'test_community_1',
    'test_community_2',
]
embeddings = {
    key: np.random.uniform(0.0, 0.9, embedding_dim).tolist() for key in _embedding_keys
}
# 'Alice Smith' intentionally shares Alice's vector (used by dedupe tests).
embeddings['Alice Smith'] = embeddings['Alice']


@pytest.fixture
def mock_embedder():
    """EmbedderClient double that returns deterministic vectors keyed by input text."""
    model = Mock(spec=EmbedderClient)

    def embed(input_data):
        if isinstance(input_data, str):
            key = input_data
        elif isinstance(input_data, list):
            # Lists are embedded as one joined string.
            key = ' '.join(input_data)
        else:
            raise ValueError(f'Unsupported input type: {type(input_data)}')
        return embeddings[key]

    model.create.side_effect = embed
    return model


def test_lucene_sanitize():
    """lucene_sanitize escapes every Lucene special character and nothing else."""
    cases = [
        (
            'This has every escape character + - && || ! ( ) { } [ ] ^ " ~ * ? : \\ /',
            '\\This has every escape character \\+ \\- \\&\\& \\|\\| \\! \\( \\) \\{ \\} \\[ \\] \\^ \\" \\~ \\* \\? \\: \\\\ \\/',
        ),
        ('this has no escape characters', 'this has no escape characters'),
    ]

    for raw, expected in cases:
        assert lucene_sanitize(raw) == expected


async def get_node_count(driver: GraphDriver, uuids: list[str]) -> int:
    """Return the number of nodes whose uuid appears in `uuids`."""
    records, _, _ = await driver.execute_query(
        """
        MATCH (n)
        WHERE n.uuid IN $uuids
        RETURN COUNT(n) as count
        """,
        uuids=uuids,
    )
    return int(records[0]['count'])


async def get_edge_count(driver: GraphDriver, uuids: list[str]) -> int:
    """Return the number of edges matching `uuids`.

    Counts both relationship edges and reified RelatesToNode_ edge nodes,
    summing the two UNION branches.
    """
    rows, _, _ = await driver.execute_query(
        """
        MATCH (n)-[e]->(m)
        WHERE e.uuid IN $uuids
        RETURN COUNT(e) as count
        UNION ALL
        MATCH (e:RelatesToNode_)
        WHERE e.uuid IN $uuids
        RETURN COUNT(e) as count
        """,
        uuids=uuids,
    )
    return sum(int(row['count']) for row in rows)


async def print_graph(graph_driver: GraphDriver):
    """Print every node and edge in the graph to stdout (debugging aid)."""
    nodes, _, _ = await graph_driver.execute_query(
        """
        MATCH (n)
        RETURN n.uuid, n.name
        """,
    )
    print('Nodes:')
    for record in nodes:
        print('  ', record)

    edges, _, _ = await graph_driver.execute_query(
        """
        MATCH (n)-[e]->(m)
        RETURN n.name, e.uuid, m.name
        """,
    )
    print('Edges:')
    for record in edges:
        print('  ', record)


async def assert_episodic_node_equals(retrieved: EpisodicNode, sample: EpisodicNode):
    """Assert that a retrieved episodic node matches `sample` field by field.

    Raises AssertionError on the first differing field.
    """
    assert retrieved.uuid == sample.uuid
    assert retrieved.name == sample.name
    # Compare against the sample's own group rather than the module-level
    # group_id constant, matching the convention of the edge helpers below.
    assert retrieved.group_id == sample.group_id
    assert retrieved.created_at == sample.created_at
    assert retrieved.source == sample.source
    assert retrieved.source_description == sample.source_description
    assert retrieved.content == sample.content
    assert retrieved.valid_at == sample.valid_at
    # entity_edges ordering is not significant, so compare as sets.
    assert set(retrieved.entity_edges) == set(sample.entity_edges)


async def assert_entity_node_equals(
    graph_driver: GraphDriver, retrieved: EntityNode, sample: EntityNode
):
    """Assert that a retrieved entity node matches `sample` field by field.

    The name embedding is loaded from the driver before comparison, and
    embeddings are compared with np.allclose to tolerate float round-trips.
    """
    await retrieved.load_name_embedding(graph_driver)
    assert retrieved.uuid == sample.uuid
    assert retrieved.name == sample.name
    assert retrieved.group_id == sample.group_id
    # Label ordering is not significant.
    assert set(retrieved.labels) == set(sample.labels)
    assert retrieved.created_at == sample.created_at
    # Both embeddings must exist before the numeric comparison.
    assert retrieved.name_embedding is not None
    assert sample.name_embedding is not None
    assert np.allclose(retrieved.name_embedding, sample.name_embedding)
    assert retrieved.summary == sample.summary
    assert retrieved.attributes == sample.attributes


async def assert_community_node_equals(
    graph_driver: GraphDriver, retrieved: CommunityNode, sample: CommunityNode
):
    """Assert that a retrieved community node matches `sample` field by field.

    The name embedding is loaded from the driver before comparison.
    """
    await retrieved.load_name_embedding(graph_driver)
    assert retrieved.uuid == sample.uuid
    assert retrieved.name == sample.name
    # NOTE(review): compares against the module-level group_id constant rather
    # than sample.group_id, unlike the edge helpers — confirm this is intended.
    assert retrieved.group_id == group_id
    assert retrieved.created_at == sample.created_at
    assert retrieved.name_embedding is not None
    assert sample.name_embedding is not None
    assert np.allclose(retrieved.name_embedding, sample.name_embedding)
    assert retrieved.summary == sample.summary

async def assert_episodic_edge_equals(retrieved: EpisodicEdge, sample: EpisodicEdge):
    """Assert that a retrieved episodic edge matches `sample` field by field."""
    assert retrieved.uuid == sample.uuid
    assert retrieved.group_id == sample.group_id
    assert retrieved.created_at == sample.created_at
    assert retrieved.source_node_uuid == sample.source_node_uuid
    assert retrieved.target_node_uuid == sample.target_node_uuid


async def assert_entity_edge_equals(
    graph_driver: GraphDriver, retrieved: EntityEdge, sample: EntityEdge
):
    """Assert that a retrieved entity edge matches `sample` field by field.

    The fact embedding is loaded from the driver before comparison, and
    embeddings are compared with np.allclose to tolerate float round-trips.
    """
    await retrieved.load_fact_embedding(graph_driver)
    assert retrieved.uuid == sample.uuid
    assert retrieved.group_id == sample.group_id
    assert retrieved.created_at == sample.created_at
    assert retrieved.source_node_uuid == sample.source_node_uuid
    assert retrieved.target_node_uuid == sample.target_node_uuid
    assert retrieved.name == sample.name
    assert retrieved.fact == sample.fact
    # Both embeddings must exist before the numeric comparison.
    assert retrieved.fact_embedding is not None
    assert sample.fact_embedding is not None
    assert np.allclose(retrieved.fact_embedding, sample.fact_embedding)
    assert retrieved.episodes == sample.episodes
    assert retrieved.expired_at == sample.expired_at
    assert retrieved.valid_at == sample.valid_at
    assert retrieved.invalid_at == sample.invalid_at
    assert retrieved.attributes == sample.attributes


if __name__ == '__main__':
    # Allow running this helper module's own tests directly.
    pytest.main([__file__])

```

--------------------------------------------------------------------------------
/mcp_server/tests/test_fixtures.py:
--------------------------------------------------------------------------------

```python
"""
Shared test fixtures and utilities for Graphiti MCP integration tests.
"""

import asyncio
import contextlib
import json
import os
import random
import statistics
import time
from contextlib import asynccontextmanager
from typing import Any

import pytest
from faker import Faker
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

fake = Faker()


class TestDataGenerator:
    """Generate realistic test data for various scenarios.

    All output is produced via Faker and `random`, so it is nondeterministic
    unless both are seeded by the caller. The exact RNG call order matters for
    reproducibility under seeding, so statements should not be reordered.
    """

    @staticmethod
    def generate_company_profile() -> dict[str, Any]:
        """Generate a realistic company profile.

        Returns a nested dict with 'company', 'products' (1-5 entries), and
        'leadership' sections.
        """
        return {
            'company': {
                'name': fake.company(),
                'founded': random.randint(1990, 2023),
                'industry': random.choice(['Tech', 'Finance', 'Healthcare', 'Retail']),
                'employees': random.randint(10, 10000),
                'revenue': f'${random.randint(1, 1000)}M',
                'headquarters': fake.city(),
            },
            'products': [
                {
                    'id': fake.uuid4()[:8],  # shortened 8-char id
                    'name': fake.catch_phrase(),
                    'category': random.choice(['Software', 'Hardware', 'Service']),
                    'price': random.randint(10, 10000),
                }
                for _ in range(random.randint(1, 5))
            ],
            'leadership': {
                'ceo': fake.name(),
                'cto': fake.name(),
                'cfo': fake.name(),
            },
        }

    @staticmethod
    def generate_conversation(turns: int = 3) -> str:
        """Generate a realistic conversation.

        Each turn is a "user:" question about a random topic followed by an
        "assistant:" reply; all lines are joined with newlines.
        """
        topics = [
            'product features',
            'pricing',
            'technical support',
            'integration',
            'documentation',
            'performance',
        ]

        conversation = []
        for _ in range(turns):
            topic = random.choice(topics)
            user_msg = f'user: {fake.sentence()} about {topic}?'
            assistant_msg = f'assistant: {fake.paragraph(nb_sentences=2)}'
            conversation.extend([user_msg, assistant_msg])

        return '\n'.join(conversation)

    @staticmethod
    def generate_technical_document() -> str:
        """Generate technical documentation content as markdown-style sections."""
        sections = [
            f'# {fake.catch_phrase()}\n\n{fake.paragraph()}',
            f'## Architecture\n{fake.paragraph()}',
            f'## Implementation\n{fake.paragraph()}',
            f'## Performance\n- Latency: {random.randint(1, 100)}ms\n- Throughput: {random.randint(100, 10000)} req/s',
            f'## Dependencies\n- {fake.word()}\n- {fake.word()}\n- {fake.word()}',
        ]
        return '\n\n'.join(sections)

    @staticmethod
    def generate_news_article() -> str:
        """Generate a news article (indented multi-line text)."""
        company = fake.company()
        return f"""
        {company} Announces {fake.catch_phrase()}

        {fake.city()}, {fake.date()} - {company} today announced {fake.paragraph()}.

        "This is a significant milestone," said {fake.name()}, CEO of {company}.
        "{fake.sentence()}"

        The announcement comes after {fake.paragraph()}.

        Industry analysts predict {fake.paragraph()}.
        """

    @staticmethod
    def generate_user_profile() -> dict[str, Any]:
        """Generate a user profile with preference and activity sections."""
        return {
            'user_id': fake.uuid4(),
            'name': fake.name(),
            'email': fake.email(),
            'joined': fake.date_time_this_year().isoformat(),
            'preferences': {
                'theme': random.choice(['light', 'dark', 'auto']),
                'notifications': random.choice([True, False]),
                'language': random.choice(['en', 'es', 'fr', 'de']),
            },
            'activity': {
                'last_login': fake.date_time_this_month().isoformat(),
                'total_sessions': random.randint(1, 1000),
                'average_duration': f'{random.randint(1, 60)} minutes',
            },
        }


class MockLLMProvider:
    """Mock LLM provider for testing without actual API calls.

    Responses are deterministic and keyed off substrings of the prompt.
    """

    def __init__(self, delay: float = 0.1):
        # Simulated per-call latency in seconds.
        self.delay = delay

    async def generate(self, prompt: str) -> str:
        """Return a canned response after a short artificial delay."""
        await asyncio.sleep(self.delay)

        lowered = prompt.lower()
        if 'extract entities' in lowered:
            payload = {
                'entities': [
                    {'name': 'TestEntity1', 'type': 'PERSON'},
                    {'name': 'TestEntity2', 'type': 'ORGANIZATION'},
                ]
            }
            return json.dumps(payload)
        if 'summarize' in lowered:
            return 'This is a test summary of the provided content.'
        return 'Mock LLM response'


@asynccontextmanager
async def graphiti_test_client(
    group_id: str | None = None,
    database: str = 'falkordb',
    use_mock_llm: bool = False,
    config_overrides: dict[str, Any] | None = None,
):
    """
    Context manager for creating test clients with various configurations.

    Spawns the MCP server as a stdio subprocess, yields an initialized
    (session, group_id) pair, and clears the test group's data on exit.

    Args:
        group_id: Test group identifier (a unique one is generated if omitted)
        database: Database backend (neo4j, falkordb)
        use_mock_llm: Whether to use mock LLM for faster tests
        config_overrides: Additional config overrides
    """
    test_group_id = group_id or f'test_{int(time.time())}_{random.randint(1000, 9999)}'

    # NOTE(review): when OPENAI_API_KEY is unset and use_mock_llm is False this
    # value is None, while subprocess env mappings are normally str-valued —
    # confirm the None case is handled (or rejected) downstream.
    env = {
        'DATABASE_PROVIDER': database,
        'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY', 'test_key' if use_mock_llm else None),
    }

    # Database-specific configuration
    if database == 'neo4j':
        env.update(
            {
                'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
                'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
                'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
            }
        )
    elif database == 'falkordb':
        env['FALKORDB_URI'] = os.environ.get('FALKORDB_URI', 'redis://localhost:6379')

    # Apply config overrides
    if config_overrides:
        env.update(config_overrides)

    # Add mock LLM flag if needed
    if use_mock_llm:
        env['USE_MOCK_LLM'] = 'true'

    server_params = StdioServerParameters(
        command='uv', args=['run', 'main.py', '--transport', 'stdio'], env=env
    )

    async with stdio_client(server_params) as (read, write):
        # NOTE(review): ClientSession is created without entering it as an async
        # context manager; confirm initialize()/close() is the supported
        # lifecycle for this mcp client version.
        session = ClientSession(read, write)
        await session.initialize()

        try:
            yield session, test_group_id
        finally:
            # Cleanup: Clear test data (best-effort; failures are suppressed)
            with contextlib.suppress(Exception):
                await session.call_tool('clear_graph', {'group_id': test_group_id})

            await session.close()


class PerformanceBenchmark:
    """Track and analyze performance benchmarks.

    Durations (seconds) are recorded per operation name; summary statistics
    are computed on demand.
    """

    def __init__(self):
        # operation name -> list of recorded durations (seconds)
        self.measurements: dict[str, list[float]] = {}

    def record(self, operation: str, duration: float):
        """Record a performance measurement for `operation`."""
        self.measurements.setdefault(operation, []).append(duration)

    def get_stats(self, operation: str) -> dict[str, float]:
        """Get statistics for an operation.

        Returns an empty dict when nothing has been recorded for `operation`.
        """
        durations = self.measurements.get(operation)
        if not durations:
            return {}

        return {
            'count': len(durations),
            'mean': statistics.fmean(durations),
            'min': min(durations),
            'max': max(durations),
            # statistics.median averages the two middle values for an even
            # sample count; the previous hand-rolled version picked the upper
            # middle, which is not a true median.
            'median': statistics.median(durations),
        }

    def report(self) -> str:
        """Generate a human-readable performance report across all operations."""
        lines = ['Performance Benchmark Report', '=' * 40]

        for operation in sorted(self.measurements):
            stats = self.get_stats(operation)
            lines.append(f'\n{operation}:')
            lines.append(f'  Samples: {stats["count"]}')
            lines.append(f'  Mean: {stats["mean"]:.3f}s')
            lines.append(f'  Median: {stats["median"]:.3f}s')
            lines.append(f'  Min: {stats["min"]:.3f}s')
            lines.append(f'  Max: {stats["max"]:.3f}s')

        return '\n'.join(lines)


# Pytest fixtures
@pytest.fixture
def test_data_generator():
    """Provide the test data generator (all methods are static)."""
    return TestDataGenerator()


@pytest.fixture
def performance_benchmark():
    """Provide a fresh performance benchmark tracker per test."""
    return PerformanceBenchmark()


@pytest.fixture
async def mock_graphiti_client():
    """Provide a Graphiti client session backed by the mock LLM (no API calls)."""
    async with graphiti_test_client(use_mock_llm=True) as (session, group_id):
        yield session, group_id


@pytest.fixture
async def graphiti_client():
    """Provide a Graphiti client session backed by the real LLM."""
    async with graphiti_test_client(use_mock_llm=False) as (session, group_id):
        yield session, group_id


# Test data fixtures
@pytest.fixture
def sample_memories():
    """Provide sample memory data for testing.

    Covers each episode source used by these tests: json, text, and message.
    """
    return [
        {
            'name': 'Company Overview',
            'episode_body': TestDataGenerator.generate_company_profile(),
            'source': 'json',
            'source_description': 'company database',
        },
        {
            'name': 'Product Launch',
            'episode_body': TestDataGenerator.generate_news_article(),
            'source': 'text',
            'source_description': 'press release',
        },
        {
            'name': 'Customer Support',
            'episode_body': TestDataGenerator.generate_conversation(),
            'source': 'message',
            'source_description': 'support chat',
        },
        {
            'name': 'Technical Specs',
            'episode_body': TestDataGenerator.generate_technical_document(),
            'source': 'text',
            'source_description': 'documentation',
        },
    ]


@pytest.fixture
def large_dataset():
    """Generate a large dataset (50 synthetic documents) for stress testing."""
    documents = []
    for i in range(50):
        documents.append(
            {
                'name': f'Document {i}',
                'episode_body': TestDataGenerator.generate_technical_document(),
                'source': 'text',
                'source_description': 'bulk import',
            }
        )
    return documents

```

--------------------------------------------------------------------------------
/tests/utils/maintenance/test_bulk_utils.py:
--------------------------------------------------------------------------------

```python
from collections import deque
from unittest.mock import AsyncMock, MagicMock

import pytest

from graphiti_core.edges import EntityEdge
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils import bulk_utils
from graphiti_core.utils.datetime_utils import utc_now


def _make_episode(uuid_suffix: str, group_id: str = 'group') -> EpisodicNode:
    """Build a minimal message-type episode for the bulk-dedupe tests."""
    fields = dict(
        name=f'episode-{uuid_suffix}',
        group_id=group_id,
        labels=[],
        source=EpisodeType.message,
        content='content',
        source_description='test',
        created_at=utc_now(),
        valid_at=utc_now(),
    )
    return EpisodicNode(**fields)


def _make_clients() -> GraphitiClients:
    """Assemble a GraphitiClients bundle out of MagicMock doubles."""
    # model_construct bypasses validation so the mocks are accepted as-is.
    return GraphitiClients.model_construct(
        driver=MagicMock(),
        embedder=MagicMock(),
        cross_encoder=MagicMock(),
        llm_client=MagicMock(),
    )


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch):
    """Two episodes extracting the same entity should share one canonical node."""
    clients = _make_clients()

    episode_one = _make_episode('1')
    episode_two = _make_episode('2')

    # Same name, different uuids: the second extraction must dedupe onto the first.
    extracted_one = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
    extracted_two = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])

    canonical = extracted_one

    # Records the existing_nodes_override value passed on each resolver call.
    call_queue = deque()

    async def fake_resolve(
        clients_arg,
        nodes_arg,
        episode_arg,
        previous_episodes_arg,
        entity_types_arg,
        existing_nodes_override=None,
    ):
        call_queue.append(existing_nodes_override)

        if nodes_arg == [extracted_one]:
            # First episode: the node resolves to itself.
            return [canonical], {canonical.uuid: canonical.uuid}, []

        assert nodes_arg == [extracted_two]
        assert existing_nodes_override is None

        # Second episode: the duplicate is mapped onto the canonical node.
        return [canonical], {extracted_two.uuid: canonical.uuid}, [(extracted_two, canonical)]

    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)

    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients,
        [[extracted_one], [extracted_two]],
        [(episode_one, []), (episode_two, [])],
    )

    # The resolver is invoked once per episode, never with an override.
    assert len(call_queue) == 2
    assert call_queue[0] is None
    assert call_queue[1] is None

    # Both episodes end up pointing at the same canonical node.
    assert nodes_by_episode[episode_one.uuid] == [canonical]
    assert nodes_by_episode[episode_two.uuid] == [canonical]
    assert compressed_map.get(extracted_two.uuid) == canonical.uuid


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_handles_empty_batch(monkeypatch):
    """An empty input batch yields empty outputs without invoking the resolver."""
    clients = _make_clients()

    resolver = AsyncMock()
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolver)

    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(clients, [], [])

    resolver.assert_not_awaited()
    assert nodes_by_episode == {}
    assert compressed_map == {}


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_single_episode(monkeypatch):
    """A single episode resolves against itself and maps its uuid to itself."""
    clients = _make_clients()

    episode = _make_episode('solo')
    node = EntityNode(name='Solo', group_id='group', labels=['Entity'])

    resolver = AsyncMock(return_value=([node], {node.uuid: node.uuid}, []))
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolver)

    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients,
        [[node]],
        [(episode, [])],
    )

    resolver.assert_awaited_once()
    assert nodes_by_episode == {episode.uuid: [node]}
    assert compressed_map == {node.uuid: node.uuid}


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_uuid_map_respects_direction(monkeypatch):
    """The uuid map must point alias -> canonical, regardless of uuid ordering."""
    clients = _make_clients()

    episode_one = _make_episode('one')
    episode_two = _make_episode('two')

    # uuids chosen so lexicographic order ('a-uuid' < 'b-uuid') is the opposite
    # of resolution order, guarding against accidental uuid sorting.
    extracted_one = EntityNode(uuid='b-uuid', name='Edge Case', group_id='group', labels=['Entity'])
    extracted_two = EntityNode(uuid='a-uuid', name='Edge Case', group_id='group', labels=['Entity'])

    canonical = extracted_one
    alias = extracted_two

    async def fake_resolve(
        clients_arg,
        nodes_arg,
        episode_arg,
        previous_episodes_arg,
        entity_types_arg,
        existing_nodes_override=None,
    ):
        if nodes_arg == [extracted_one]:
            return [canonical], {canonical.uuid: canonical.uuid}, []
        assert nodes_arg == [extracted_two]
        return [canonical], {alias.uuid: canonical.uuid}, [(alias, canonical)]

    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)

    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients,
        [[extracted_one], [extracted_two]],
        [(episode_one, []), (episode_two, [])],
    )

    assert nodes_by_episode[episode_one.uuid] == [canonical]
    assert nodes_by_episode[episode_two.uuid] == [canonical]
    # Direction check: the alias maps to the canonical uuid, not vice versa.
    assert compressed_map.get(alias.uuid) == canonical.uuid


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_missing_canonical_falls_back(monkeypatch, caplog):
    """When the uuid map names a canonical node that was never returned, the
    extracted node is kept for the episode and a warning is emitted."""
    clients = _make_clients()

    episode = _make_episode('missing')
    node = EntityNode(name='Fallback', group_id='group', labels=['Entity'])

    mock_resolve = AsyncMock(return_value=([node], {node.uuid: 'missing-canonical'}, []))
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', mock_resolve)

    with caplog.at_level('WARNING'):
        per_episode, uuid_map = await bulk_utils.dedupe_nodes_bulk(
            clients,
            [[node]],
            [(episode, [])],
        )

    assert per_episode[episode.uuid] == [node]
    assert uuid_map.get(node.uuid) == 'missing-canonical'
    assert any('Canonical node missing' in record.message for record in caplog.records)


def test_build_directed_uuid_map_empty():
    """An empty duplicate-pair list yields an empty map."""
    result = bulk_utils._build_directed_uuid_map([])
    assert result == {}


def test_build_directed_uuid_map_chain():
    """Chained duplicates (a -> b -> c) all collapse onto the terminal uuid."""
    pairs = [('a', 'b'), ('b', 'c')]
    mapping = bulk_utils._build_directed_uuid_map(pairs)

    assert mapping['a'] == 'c'
    assert mapping['b'] == 'c'
    assert mapping['c'] == 'c'


def test_build_directed_uuid_map_preserves_direction():
    """Alias maps to canonical; canonical maps to itself — never the reverse."""
    mapping = bulk_utils._build_directed_uuid_map([('alias', 'canonical')])

    assert mapping['alias'] == 'canonical'
    assert mapping['canonical'] == 'canonical'


def test_resolve_edge_pointers_updates_sources():
    """resolve_edge_pointers rewrites pointers found in the uuid map and leaves
    unmapped pointers untouched."""
    now = utc_now()
    edge = EntityEdge(
        name='knows',
        fact='fact',
        group_id='group',
        source_node_uuid='alias',
        target_node_uuid='target',
        created_at=now,
    )

    bulk_utils.resolve_edge_pointers([edge], {'alias': 'canonical'})

    assert edge.source_node_uuid == 'canonical'
    assert edge.target_node_uuid == 'target'


@pytest.mark.asyncio
async def test_dedupe_edges_bulk_deduplicates_within_episode(monkeypatch):
    """Edges extracted from the same episode must be compared against each other.

    Regression test for the removal of the `if i == j: continue` guard, which
    previously prevented same-episode edges from ever being deduplicated.
    """
    clients = _make_clients()

    # Records (edge_uuid, [candidate_uuids]) for every resolution call made.
    observed_comparisons = []

    async def fake_create_embeddings(embedder, edges):
        # Stand-in for the real embedder: give every edge a fixed vector.
        for edge in edges:
            edge.fact_embedding = [0.1, 0.2, 0.3]

    monkeypatch.setattr(bulk_utils, 'create_entity_edge_embeddings', fake_create_embeddings)

    async def fake_resolve_extracted_edge(
        llm_client,
        extracted_edge,
        related_edges,
        existing_edges,
        episode,
        edge_type_candidates=None,
        custom_edge_type_names=None,
    ):
        observed_comparisons.append((extracted_edge.uuid, [r.uuid for r in related_edges]))

        # Any distinct candidate with the same endpoints and (normalized) fact
        # text counts as the canonical duplicate.
        for candidate in related_edges:
            is_self = candidate.uuid == extracted_edge.uuid
            same_endpoints = (
                candidate.source_node_uuid == extracted_edge.source_node_uuid
                and candidate.target_node_uuid == extracted_edge.target_node_uuid
            )
            same_fact = candidate.fact.strip().lower() == extracted_edge.fact.strip().lower()
            if not is_self and same_endpoints and same_fact:
                return candidate, [], [candidate]
        return extracted_edge, [], []

    monkeypatch.setattr(bulk_utils, 'resolve_extracted_edge', fake_resolve_extracted_edge)

    episode = _make_episode('1')

    # Three identical edges extracted from the same episode.
    edges = [
        EntityEdge(
            name='recommends',
            fact='assistant recommends yoga poses',
            group_id='group',
            source_node_uuid='source-uuid',
            target_node_uuid='target-uuid',
            created_at=utc_now(),
            episodes=[episode.uuid],
        )
        for _ in range(3)
    ]

    await bulk_utils.dedupe_edges_bulk(
        clients,
        [edges],
        [(episode, [])],
        [],
        {},
        {},
    )

    # Every edge must have gone through resolution, each seeing all three
    # same-episode edges as candidates (itself gets filtered during resolution).
    assert len(observed_comparisons) == 3
    for _, candidate_uuids in observed_comparisons:
        assert len(candidate_uuids) >= 2

```

--------------------------------------------------------------------------------
/graphiti_core/utils/maintenance/community_operations.py:
--------------------------------------------------------------------------------

```python
import asyncio
import logging
from collections import defaultdict

from pydantic import BaseModel

from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.edges import CommunityEdge
from graphiti_core.embedder import EmbedderClient
from graphiti_core.helpers import semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.models.nodes.node_db_queries import COMMUNITY_NODE_RETURN
from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.summarize_nodes import Summary, SummaryDescription
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.edge_operations import build_community_edges

MAX_COMMUNITY_BUILD_CONCURRENCY = 10

logger = logging.getLogger(__name__)


class Neighbor(BaseModel):
    """Adjacency entry used by the community projection: a neighboring entity
    and the number of edges connecting it to the projected node."""

    node_uuid: str  # uuid of the neighboring Entity node
    edge_count: int  # RELATES_TO edge multiplicity between the two entities


async def get_community_clusters(
    driver: GraphDriver, group_ids: list[str] | None
) -> list[list[EntityNode]]:
    """Partition each group's entity graph into clusters via label propagation.

    When `group_ids` is None, every distinct group_id found on Entity nodes is
    processed. Returns one list of EntityNode objects per detected cluster.
    """
    community_clusters: list[list[EntityNode]] = []

    if group_ids is None:
        # Discover every group_id present on Entity nodes.
        group_id_values, _, _ = await driver.execute_query(
            """
            MATCH (n:Entity)
            WHERE n.group_id IS NOT NULL
            RETURN
                collect(DISTINCT n.group_id) AS group_ids
            """
        )

        group_ids = group_id_values[0]['group_ids'] if group_id_values else []

    for group_id in group_ids:
        # In-memory adjacency projection: node uuid -> weighted neighbor list.
        projection: dict[str, list[Neighbor]] = {}
        nodes = await EntityNode.get_by_group_ids(driver, [group_id])
        for node in nodes:
            match_query = """
                MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[e:RELATES_TO]-(m: Entity {group_id: $group_id})
            """
            # Kuzu models RELATES_TO through an intermediate RelatesToNode_ node.
            if driver.provider == GraphProvider.KUZU:
                match_query = """
                MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m: Entity {group_id: $group_id})
                """
            records, _, _ = await driver.execute_query(
                match_query
                + """
                WITH count(e) AS count, m.uuid AS uuid
                RETURN
                    uuid,
                    count
                """,
                uuid=node.uuid,
                group_id=group_id,
            )

            projection[node.uuid] = [
                Neighbor(node_uuid=record['uuid'], edge_count=record['count']) for record in records
            ]

        cluster_uuids = label_propagation(projection)

        # Hydrate each uuid cluster back into full EntityNode objects.
        community_clusters.extend(
            list(
                await semaphore_gather(
                    *[EntityNode.get_by_uuids(driver, cluster) for cluster in cluster_uuids]
                )
            )
        )

    return community_clusters


def label_propagation(projection: 'dict[str, list[Neighbor]]') -> list[list[str]]:
    """Cluster nodes with the label propagation algorithm.

    Every node starts in its own community; on each sweep a node adopts the
    community with the greatest total edge weight among its neighbors (ties
    broken toward the higher community id, and a candidate only wins outright
    when its weight exceeds 1). Sweeps repeat until a full pass changes
    nothing, then nodes are grouped by their final label.
    """
    labels = {uuid: idx for idx, uuid in enumerate(projection)}

    converged = False
    while not converged:
        converged = True
        next_labels: dict[str, int] = {}

        for uuid, neighbors in projection.items():
            current = labels[uuid]

            # Total edge weight pointing at each neighboring community.
            weight_by_community: dict[int, int] = defaultdict(int)
            for neighbor in neighbors:
                weight_by_community[labels[neighbor.node_uuid]] += neighbor.edge_count

            if weight_by_community:
                best_weight, best_community = max(
                    (weight, community) for community, weight in weight_by_community.items()
                )
            else:
                best_weight, best_community = 0, -1

            if best_community != -1 and best_weight > 1:
                chosen = best_community
            else:
                # Keep the current label unless the candidate id is higher; this
                # is also the tie-break when the best weight is at most 1.
                chosen = max(best_community, current)

            next_labels[uuid] = chosen
            if chosen != current:
                converged = False

        if not converged:
            labels = next_labels

    # Group node uuids by their final community label.
    clusters_by_label: dict[int, list[str]] = defaultdict(list)
    for uuid, label in labels.items():
        clusters_by_label[label].append(uuid)

    return list(clusters_by_label.values())


async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -> str:
    """Merge two entity summaries into a single combined summary via the LLM."""
    llm_response = await llm_client.generate_response(
        prompt_library.summarize_nodes.summarize_pair(
            {'node_summaries': [{'summary': summary} for summary in summary_pair]}
        ),
        response_model=Summary,
        prompt_name='summarize_nodes.summarize_pair',
    )

    # Fall back to an empty string if the model omitted the field.
    return llm_response.get('summary', '')


async def generate_summary_description(llm_client: LLMClient, summary: str) -> str:
    """Ask the LLM for a short descriptive name for the given summary text."""
    llm_response = await llm_client.generate_response(
        prompt_library.summarize_nodes.summary_description({'summary': summary}),
        response_model=SummaryDescription,
        prompt_name='summarize_nodes.summary_description',
    )

    # Fall back to an empty string if the model omitted the field.
    return llm_response.get('description', '')


async def build_community(
    llm_client: LLMClient, community_cluster: list[EntityNode]
) -> tuple[CommunityNode, list[CommunityEdge]]:
    """Build a community node plus membership edges for a cluster of entities.

    Entity summaries are reduced tournament-style — summarizing pairs in
    parallel rounds, carrying an odd one over unchanged — until a single
    combined summary remains, which the LLM then names. Assumes the cluster
    is non-empty (raises IndexError otherwise).
    """
    summaries = [entity.summary for entity in community_cluster]

    while len(summaries) > 1:
        # With an odd count, set one summary aside for the next round.
        carry: str | None = None
        if len(summaries) % 2 == 1:
            carry = summaries.pop()

        half = len(summaries) // 2
        merged: list[str] = list(
            await semaphore_gather(
                *[
                    summarize_pair(llm_client, (str(left), str(right)))
                    for left, right in zip(summaries[:half], summaries[half:], strict=False)
                ]
            )
        )
        if carry is not None:
            merged.append(carry)
        summaries = merged

    summary = summaries[0]
    name = await generate_summary_description(llm_client, summary)
    now = utc_now()
    community_node = CommunityNode(
        name=name,
        group_id=community_cluster[0].group_id,
        labels=['Community'],
        created_at=now,
        summary=summary,
    )
    community_edges = build_community_edges(community_cluster, community_node, now)

    logger.debug((community_node, community_edges))

    return community_node, community_edges


async def build_communities(
    driver: GraphDriver,
    llm_client: LLMClient,
    group_ids: list[str] | None,
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
    """Detect communities for the given groups and build a node plus edges for each.

    Builds run concurrently, capped at MAX_COMMUNITY_BUILD_CONCURRENCY.
    """
    community_clusters = await get_community_clusters(driver, group_ids)

    semaphore = asyncio.Semaphore(MAX_COMMUNITY_BUILD_CONCURRENCY)

    async def limited_build_community(cluster):
        async with semaphore:
            return await build_community(llm_client, cluster)

    results: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
        await semaphore_gather(
            *[limited_build_community(cluster) for cluster in community_clusters]
        )
    )

    # Flatten the per-cluster results into parallel node and edge lists.
    community_nodes: list[CommunityNode] = []
    community_edges: list[CommunityEdge] = []
    for node, edges in results:
        community_nodes.append(node)
        community_edges.extend(edges)

    return community_nodes, community_edges


async def remove_communities(driver: GraphDriver):
    """Delete every Community node, along with its relationships, from the graph."""
    query = """
        MATCH (c:Community)
        DETACH DELETE c
        """
    await driver.execute_query(query)


async def determine_entity_community(
    driver: GraphDriver, entity: EntityNode
) -> tuple[CommunityNode | None, bool]:
    """Determine which community an entity belongs to.

    Returns (community, is_new): the existing community with is_new=False when
    the entity is already a member; otherwise the most common community among
    its related entities with is_new=True; or (None, False) when neither exists.
    """
    # Already a member of a community? Return it directly.
    records, _, _ = await driver.execute_query(
        """
        MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
        RETURN
        """
        + COMMUNITY_NODE_RETURN,
        entity_uuid=entity.uuid,
    )

    if records:
        return get_community_node_from_record(records[0]), False

    # Otherwise, inspect the communities of directly related entities.
    match_query = """
        MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
    """
    if driver.provider == GraphProvider.KUZU:
        match_query = """
            MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
        """
    records, _, _ = await driver.execute_query(
        match_query
        + """
        RETURN
        """
        + COMMUNITY_NODE_RETURN,
        entity_uuid=entity.uuid,
    )

    neighbor_communities: list[CommunityNode] = [
        get_community_node_from_record(record) for record in records
    ]

    # Tally occurrences and pick the most frequent (mode) community.
    tally: dict[str, int] = defaultdict(int)
    for community in neighbor_communities:
        tally[community.uuid] += 1

    best_uuid: str | None = None
    best_count = 0
    for uuid, count in tally.items():
        if count > best_count:
            best_uuid, best_count = uuid, count

    if best_count == 0:
        return None, False

    for community in neighbor_communities:
        if community.uuid == best_uuid:
            return community, True

    return None, False


async def update_community(
    driver: GraphDriver,
    llm_client: LLMClient,
    embedder: EmbedderClient,
    entity: EntityNode,
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
    """Fold an entity's summary into its community, creating a membership edge when new.

    Returns the updated community (as a one-element list) and any newly created
    membership edges; returns ([], []) when no community can be determined.
    """
    community, is_new = await determine_entity_community(driver, entity)

    if community is None:
        return [], []

    # Refresh the community summary and name with the entity's information.
    merged_summary = await summarize_pair(llm_client, (entity.summary, community.summary))
    community.summary = merged_summary
    community.name = await generate_summary_description(llm_client, merged_summary)

    community_edges = []
    if is_new:
        membership_edge = build_community_edges([entity], community, utc_now())[0]
        await membership_edge.save(driver)
        community_edges.append(membership_edge)

    await community.generate_name_embedding(embedder)

    await community.save(driver)

    return [community], community_edges

```

--------------------------------------------------------------------------------
/graphiti_core/prompts/extract_nodes.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from typing import Any, Protocol, TypedDict

from pydantic import BaseModel, Field

from graphiti_core.utils.text_utils import MAX_SUMMARY_CHARS

from .models import Message, PromptFunction, PromptVersion
from .prompt_helpers import to_prompt_json
from .snippets import summary_instructions


class ExtractedEntity(BaseModel):
    """Single entity extracted by the LLM, with its classified type id."""

    name: str = Field(..., description='Name of the extracted entity')
    entity_type_id: int = Field(
        description='ID of the classified entity type. '
        'Must be one of the provided entity_type_id integers.',
    )


class ExtractedEntities(BaseModel):
    """Response model wrapping the full list of entities extracted from an episode."""

    extracted_entities: list[ExtractedEntity] = Field(..., description='List of extracted entities')


class MissedEntities(BaseModel):
    """Response model for the reflexion pass: entities the first pass missed."""

    missed_entities: list[str] = Field(..., description="Names of entities that weren't extracted")


class EntityClassificationTriple(BaseModel):
    """One entity's classification result: uuid, name, and assigned type (or None)."""

    uuid: str = Field(description='UUID of the entity')
    name: str = Field(description='Name of the entity')
    entity_type: str | None = Field(
        default=None,
        description='Type of the entity. Must be one of the provided types or None',
    )


class EntityClassification(BaseModel):
    """Response model wrapping the classification triples for all extracted entities."""

    entity_classifications: list[EntityClassificationTriple] = Field(
        ..., description='List of entities classification triples.'
    )


class EntitySummary(BaseModel):
    """Response model for the entity-summary prompt; length-bounded by MAX_SUMMARY_CHARS."""

    summary: str = Field(
        ...,
        description=f'Summary containing the important information about the entity. Under {MAX_SUMMARY_CHARS} characters.',
    )


class Prompt(Protocol):
    """Structural interface listing every node-extraction prompt version."""

    extract_message: PromptVersion
    extract_json: PromptVersion
    extract_text: PromptVersion
    reflexion: PromptVersion
    classify_nodes: PromptVersion
    extract_attributes: PromptVersion
    extract_summary: PromptVersion


class Versions(TypedDict):
    """Typed registry mapping each prompt name to its builder function."""

    extract_message: PromptFunction
    extract_json: PromptFunction
    extract_text: PromptFunction
    reflexion: PromptFunction
    classify_nodes: PromptFunction
    extract_attributes: PromptFunction
    extract_summary: PromptFunction


def extract_message(context: dict[str, Any]) -> list[Message]:
    """Build the prompt for extracting entity nodes from a conversational message.

    Expects `context` keys: entity_types, previous_episodes, episode_content,
    and custom_prompt.
    """
    sys_prompt = """You are an AI assistant that extracts entity nodes from conversational messages. 
    Your primary task is to extract and classify the speaker and other significant entities mentioned in the conversation."""

    # Serialize previous episodes directly; the identity comprehension that was
    # here made a pointless shallow copy before serialization.
    user_prompt = f"""
<ENTITY TYPES>
{context['entity_types']}
</ENTITY TYPES>

<PREVIOUS MESSAGES>
{to_prompt_json(context['previous_episodes'])}
</PREVIOUS MESSAGES>

<CURRENT MESSAGE>
{context['episode_content']}
</CURRENT MESSAGE>

Instructions:

You are given a conversation context and a CURRENT MESSAGE. Your task is to extract **entity nodes** mentioned **explicitly or implicitly** in the CURRENT MESSAGE.
Pronoun references such as he/she/they or this/that/those should be disambiguated to the names of the 
reference entities. Only extract distinct entities from the CURRENT MESSAGE. Don't extract pronouns like you, me, he/she/they, we/us as entities.

1. **Speaker Extraction**: Always extract the speaker (the part before the colon `:` in each dialogue line) as the first entity node.
   - If the speaker is mentioned again in the message, treat both mentions as a **single entity**.

2. **Entity Identification**:
   - Extract all significant entities, concepts, or actors that are **explicitly or implicitly** mentioned in the CURRENT MESSAGE.
   - **Exclude** entities mentioned only in the PREVIOUS MESSAGES (they are for context only).

3. **Entity Classification**:
   - Use the descriptions in ENTITY TYPES to classify each extracted entity.
   - Assign the appropriate `entity_type_id` for each one.

4. **Exclusions**:
   - Do NOT extract entities representing relationships or actions.
   - Do NOT extract dates, times, or other temporal information—these will be handled separately.

5. **Formatting**:
   - Be **explicit and unambiguous** in naming entities (e.g., use full names when available).

{context['custom_prompt']}
"""
    return [
        Message(role='system', content=sys_prompt),
        Message(role='user', content=user_prompt),
    ]


def extract_json(context: dict[str, Any]) -> list[Message]:
    """Build the prompt for extracting entity nodes from a JSON episode.

    Expects `context` keys: entity_types, source_description, episode_content,
    and custom_prompt.
    """
    sys_prompt = """You are an AI assistant that extracts entity nodes from JSON. 
    Your primary task is to extract and classify relevant entities from JSON files"""

    user_prompt = f"""
<ENTITY TYPES>
{context['entity_types']}
</ENTITY TYPES>

<SOURCE DESCRIPTION>:
{context['source_description']}
</SOURCE DESCRIPTION>
<JSON>
{context['episode_content']}
</JSON>

{context['custom_prompt']}

Given the above source description and JSON, extract relevant entities from the provided JSON.
For each entity extracted, also determine its entity type based on the provided ENTITY TYPES and their descriptions.
Indicate the classified entity type by providing its entity_type_id.

Guidelines:
1. Extract all entities that the JSON represents. This will often be something like a "name" or "user" field
2. Extract all entities mentioned in all other properties throughout the JSON structure
3. Do NOT extract any properties that contain dates
"""
    return [
        Message(role='system', content=sys_prompt),
        Message(role='user', content=user_prompt),
    ]


def extract_text(context: dict[str, Any]) -> list[Message]:
    """Build the prompt for extracting entity nodes from free-form text.

    Expects `context` keys: entity_types, episode_content, and custom_prompt.
    """
    sys_prompt = """You are an AI assistant that extracts entity nodes from text. 
    Your primary task is to extract and classify the speaker and other significant entities mentioned in the provided text."""

    user_prompt = f"""
<ENTITY TYPES>
{context['entity_types']}
</ENTITY TYPES>

<TEXT>
{context['episode_content']}
</TEXT>

Given the above text, extract entities from the TEXT that are explicitly or implicitly mentioned.
For each entity extracted, also determine its entity type based on the provided ENTITY TYPES and their descriptions.
Indicate the classified entity type by providing its entity_type_id.

{context['custom_prompt']}

Guidelines:
1. Extract significant entities, concepts, or actors mentioned in the conversation.
2. Avoid creating nodes for relationships or actions.
3. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later).
4. Be as explicit as possible in your node names, using full names and avoiding abbreviations.
"""
    return [
        Message(role='system', content=sys_prompt),
        Message(role='user', content=user_prompt),
    ]


def reflexion(context: dict[str, Any]) -> list[Message]:
    """Build the reflexion prompt that asks which entities were missed.

    Expects `context` keys: previous_episodes, episode_content, and
    extracted_entities.
    """
    sys_prompt = """You are an AI assistant that determines which entities have not been extracted from the given context"""

    # Serialize previous episodes directly; the identity comprehension that was
    # here made a pointless shallow copy before serialization.
    user_prompt = f"""
<PREVIOUS MESSAGES>
{to_prompt_json(context['previous_episodes'])}
</PREVIOUS MESSAGES>
<CURRENT MESSAGE>
{context['episode_content']}
</CURRENT MESSAGE>

<EXTRACTED ENTITIES>
{context['extracted_entities']}
</EXTRACTED ENTITIES>

Given the above previous messages, current message, and list of extracted entities; determine if any entities haven't been
extracted.
"""
    return [
        Message(role='system', content=sys_prompt),
        Message(role='user', content=user_prompt),
    ]


def classify_nodes(context: dict[str, Any]) -> list[Message]:
    """Build the prompt for classifying already-extracted entities against entity types.

    Expects `context` keys: previous_episodes, episode_content,
    extracted_entities, and entity_types.
    """
    sys_prompt = """You are an AI assistant that classifies entity nodes given the context from which they were extracted"""

    # Serialize previous episodes directly; the identity comprehension that was
    # here made a pointless shallow copy before serialization.
    user_prompt = f"""
    <PREVIOUS MESSAGES>
    {to_prompt_json(context['previous_episodes'])}
    </PREVIOUS MESSAGES>
    <CURRENT MESSAGE>
    {context['episode_content']}
    </CURRENT MESSAGE>

    <EXTRACTED ENTITIES>
    {context['extracted_entities']}
    </EXTRACTED ENTITIES>

    <ENTITY TYPES>
    {context['entity_types']}
    </ENTITY TYPES>

    Given the above conversation, extracted entities, and provided entity types and their descriptions, classify the extracted entities.

    Guidelines:
    1. Each entity must have exactly one type
    2. Only use the provided ENTITY TYPES as types, do not use additional types to classify entities.
    3. If none of the provided entity types accurately classify an extracted node, the type should be set to None
"""
    return [
        Message(role='system', content=sys_prompt),
        Message(role='user', content=user_prompt),
    ]


def extract_attributes(context: dict[str, Any]) -> list[Message]:
    """Build the prompt for updating an entity's attributes from episode messages.

    Expects `context` keys: previous_episodes, episode_content, and node.
    """
    return [
        Message(
            role='system',
            content='You are a helpful assistant that extracts entity properties from the provided text.',
        ),
        Message(
            role='user',
            content=f"""
        Given the MESSAGES and the following ENTITY, update any of its attributes based on the information provided
        in MESSAGES. Use the provided attribute descriptions to better understand how each attribute should be determined.

        Guidelines:
        1. Do not hallucinate entity property values if they cannot be found in the current context.
        2. Only use the provided MESSAGES and ENTITY to set attribute values.

        <MESSAGES>
        {to_prompt_json(context['previous_episodes'])}
        {to_prompt_json(context['episode_content'])}
        </MESSAGES>

        <ENTITY>
        {context['node']}
        </ENTITY>
        """,
        ),
    ]


def extract_summary(context: dict[str, Any]) -> list[Message]:
    """Build the prompt for refreshing an entity's summary from episode messages.

    Expects `context` keys: previous_episodes, episode_content, and node.
    """
    return [
        Message(
            role='system',
            content='You are a helpful assistant that extracts entity summaries from the provided text.',
        ),
        Message(
            role='user',
            content=f"""
        Given the MESSAGES and the ENTITY, update the summary that combines relevant information about the entity
        from the messages and relevant information from the existing summary.

        {summary_instructions}

        <MESSAGES>
        {to_prompt_json(context['previous_episodes'])}
        {to_prompt_json(context['episode_content'])}
        </MESSAGES>

        <ENTITY>
        {context['node']}
        </ENTITY>
        """,
        ),
    ]


# Registry mapping each prompt name to its builder function (see Versions).
versions: Versions = {
    'extract_message': extract_message,
    'extract_json': extract_json,
    'extract_text': extract_text,
    'reflexion': reflexion,
    'extract_summary': extract_summary,
    'classify_nodes': classify_nodes,
    'extract_attributes': extract_attributes,
}

```

--------------------------------------------------------------------------------
/graphiti_core/models/edges/edge_db_queries.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from graphiti_core.driver.driver import GraphProvider

# Upserts a single MENTIONS edge from an Episodic node to an Entity node,
# keyed on the edge uuid.
EPISODIC_EDGE_SAVE = """
    MATCH (episode:Episodic {uuid: $episode_uuid})
    MATCH (node:Entity {uuid: $entity_uuid})
    MERGE (episode)-[e:MENTIONS {uuid: $uuid}]->(node)
    SET
        e.group_id = $group_id,
        e.created_at = $created_at
    RETURN e.uuid AS uuid
"""


def get_episodic_edge_save_bulk_query(provider: GraphProvider) -> str:
    """Return the provider-specific query for bulk-saving MENTIONS edges.

    The Kuzu variant binds a single edge's parameters per execution; all other
    providers UNWIND the $episodic_edges list and save every edge in one call.
    """
    if provider == GraphProvider.KUZU:
        return """
            MATCH (episode:Episodic {uuid: $source_node_uuid})
            MATCH (node:Entity {uuid: $target_node_uuid})
            MERGE (episode)-[e:MENTIONS {uuid: $uuid}]->(node)
            SET
                e.group_id = $group_id,
                e.created_at = $created_at
            RETURN e.uuid AS uuid
        """

    return """
        UNWIND $episodic_edges AS edge
        MATCH (episode:Episodic {uuid: edge.source_node_uuid})
        MATCH (node:Entity {uuid: edge.target_node_uuid})
        MERGE (episode)-[e:MENTIONS {uuid: edge.uuid}]->(node)
        SET
            e.group_id = edge.group_id,
            e.created_at = edge.created_at
        RETURN e.uuid AS uuid
    """


# Shared RETURN projection for episodic (MENTIONS) edges; the enclosing query
# must bind `e` to the edge, `n` to the source node, and `m` to the target node.
EPISODIC_EDGE_RETURN = """
    e.uuid AS uuid,
    e.group_id AS group_id,
    n.uuid AS source_node_uuid,
    m.uuid AS target_node_uuid,
    e.created_at AS created_at
"""


def get_entity_edge_save_query(provider: GraphProvider, has_aoss: bool = False) -> str:
    """Return the provider-specific query for upserting a single RELATES_TO edge.

    has_aoss applies to the Neo4j branch only: when True, the vector-property
    write for the fact embedding is omitted (NOTE(review): presumably because
    embeddings are stored in OpenSearch instead — confirm with callers).
    """
    match provider:
        case GraphProvider.FALKORDB:
            # FalkorDB stores the embedding through its vecf32() helper.
            return """
                MATCH (source:Entity {uuid: $edge_data.source_uuid})
                MATCH (target:Entity {uuid: $edge_data.target_uuid})
                MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
                SET e = $edge_data
                SET e.fact_embedding = vecf32($edge_data.fact_embedding)
                RETURN e.uuid AS uuid
            """
        case GraphProvider.NEPTUNE:
            # Neptune: list-valued properties (fact_embedding, episodes) are
            # serialized to comma-joined strings before being set.
            return """
                MATCH (source:Entity {uuid: $edge_data.source_uuid})
                MATCH (target:Entity {uuid: $edge_data.target_uuid})
                MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
                SET e = removeKeyFromMap(removeKeyFromMap($edge_data, "fact_embedding"), "episodes")
                SET e.fact_embedding = join([x IN coalesce($edge_data.fact_embedding, []) | toString(x) ], ",")
                SET e.episodes = join($edge_data.episodes, ",")
                RETURN $edge_data.uuid AS uuid
            """
        case GraphProvider.KUZU:
            # Kuzu models the edge through an intermediate RelatesToNode_ node.
            return """
                MATCH (source:Entity {uuid: $source_uuid})
                MATCH (target:Entity {uuid: $target_uuid})
                MERGE (source)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(target)
                SET
                    e.group_id = $group_id,
                    e.created_at = $created_at,
                    e.name = $name,
                    e.fact = $fact,
                    e.fact_embedding = $fact_embedding,
                    e.episodes = $episodes,
                    e.expired_at = $expired_at,
                    e.valid_at = $valid_at,
                    e.invalid_at = $invalid_at,
                    e.attributes = $attributes
                RETURN e.uuid AS uuid
            """
        case _:  # Neo4j
            # Embedding write is appended only when AOSS is not in use.
            save_embedding_query = (
                """WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)"""
                if not has_aoss
                else ''
            )
            return (
                (
                    """
                        MATCH (source:Entity {uuid: $edge_data.source_uuid})
                        MATCH (target:Entity {uuid: $edge_data.target_uuid})
                        MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
                        SET e = $edge_data
                        """
                    + save_embedding_query
                )
                + """
                RETURN e.uuid AS uuid
                """
            )


def get_entity_edge_save_bulk_query(provider: GraphProvider, has_aoss: bool = False) -> str:
    """Return the provider-specific query for bulk-saving RELATES_TO entity edges.

    Args:
        provider: Target graph database backend.
        has_aoss: Neo4j only — when True, fact embeddings are stored externally
            (OpenSearch/AOSS), so the native vector-property write is omitted.

    Returns:
        A Cypher query string. All providers except Kuzu consume the whole
        `$entity_edges` list via UNWIND; the Kuzu variant uses scalar
        parameters, so the caller presumably executes it once per edge —
        TODO confirm against the call site.
    """
    match provider:
        case GraphProvider.FALKORDB:
            # vecf32() stores the embedding as FalkorDB's native vector type.
            return """
                UNWIND $entity_edges AS edge
                MATCH (source:Entity {uuid: edge.source_node_uuid})
                MATCH (target:Entity {uuid: edge.target_node_uuid})
                MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
                SET r = edge
                SET r.fact_embedding = vecf32(edge.fact_embedding)
                WITH r, edge
                RETURN edge.uuid AS uuid
            """
        case GraphProvider.NEPTUNE:
            # Neptune has no native list properties here: the embedding and the
            # episodes list are flattened to comma-joined strings.
            return """
                UNWIND $entity_edges AS edge
                MATCH (source:Entity {uuid: edge.source_node_uuid})
                MATCH (target:Entity {uuid: edge.target_node_uuid})
                MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
                SET r = removeKeyFromMap(removeKeyFromMap(edge, "fact_embedding"), "episodes")
                SET r.fact_embedding = join([x IN coalesce(edge.fact_embedding, []) | toString(x) ], ",")
                SET r.episodes = join(edge.episodes, ",")
                RETURN edge.uuid AS uuid
            """
        case GraphProvider.KUZU:
            # Kuzu models the edge as an intermediate RelatesToNode_ node
            # connected by two RELATES_TO relationships; parameters are scalar
            # (one edge per execution), unlike the UNWIND bulk form above.
            return """
                MATCH (source:Entity {uuid: $source_node_uuid})
                MATCH (target:Entity {uuid: $target_node_uuid})
                MERGE (source)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(target)
                SET
                    e.group_id = $group_id,
                    e.created_at = $created_at,
                    e.name = $name,
                    e.fact = $fact,
                    e.fact_embedding = $fact_embedding,
                    e.episodes = $episodes,
                    e.expired_at = $expired_at,
                    e.valid_at = $valid_at,
                    e.invalid_at = $invalid_at,
                    e.attributes = $attributes
                RETURN e.uuid AS uuid
            """
        case _:  # Neo4j
            # With AOSS enabled the embedding lives in OpenSearch, so skip the
            # native vector-property write.
            save_embedding_query = (
                'WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)'
                if not has_aoss
                else ''
            )
            return (
                """
                    UNWIND $entity_edges AS edge
                    MATCH (source:Entity {uuid: edge.source_node_uuid})
                    MATCH (target:Entity {uuid: edge.target_node_uuid})
                    MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
                    SET e = edge
                    """
                + save_embedding_query
                + """
                RETURN edge.uuid AS uuid
            """
            )


def get_entity_edge_return_query(provider: GraphProvider) -> str:
    """Return the RETURN-clause fragment for reading RELATES_TO entity edges.

    The fragment expects the surrounding MATCH to bind the edge as `e`, the
    source node as `n`, and the target node as `m`.
    """
    # `fact_embedding` is not returned by default and must be manually loaded using `load_fact_embedding()`.

    if provider == GraphProvider.NEPTUNE:
        # Neptune stores `episodes` as a comma-joined string, so split it back
        # into a list on the way out.
        return """
        e.uuid AS uuid,
        n.uuid AS source_node_uuid,
        m.uuid AS target_node_uuid,
        e.group_id AS group_id,
        e.name AS name,
        e.fact AS fact,
        split(e.episodes, ',') AS episodes,
        e.created_at AS created_at,
        e.expired_at AS expired_at,
        e.valid_at AS valid_at,
        e.invalid_at AS invalid_at,
        properties(e) AS attributes
    """

    # Kuzu keeps attributes in a dedicated `attributes` property; other
    # providers expose the edge's full property map instead.
    return """
        e.uuid AS uuid,
        n.uuid AS source_node_uuid,
        m.uuid AS target_node_uuid,
        e.group_id AS group_id,
        e.created_at AS created_at,
        e.name AS name,
        e.fact AS fact,
        e.episodes AS episodes,
        e.expired_at AS expired_at,
        e.valid_at AS valid_at,
        e.invalid_at AS invalid_at,
    """ + (
        'e.attributes AS attributes'
        if provider == GraphProvider.KUZU
        else 'properties(e) AS attributes'
    )


def get_community_edge_save_query(provider: GraphProvider) -> str:
    """Return the provider-specific query for saving a HAS_MEMBER community edge.

    The query links a Community node to a member node (an Entity, or for some
    providers another Community) and sets uuid/group_id/created_at on the edge.
    Parameters: $community_uuid, $entity_uuid, $uuid, $group_id, $created_at.
    """
    match provider:
        case GraphProvider.FALKORDB:
            return """
                MATCH (community:Community {uuid: $community_uuid})
                MATCH (node {uuid: $entity_uuid})
                MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
                SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
                RETURN e.uuid AS uuid
            """
        case GraphProvider.NEPTUNE:
            return """
                MATCH (community:Community {uuid: $community_uuid})
                MATCH (node {uuid: $entity_uuid})
                WHERE node:Entity OR node:Community
                MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
                SET r.uuid= $uuid
                SET r.group_id= $group_id
                SET r.created_at= $created_at
                RETURN r.uuid AS uuid
            """
        case GraphProvider.KUZU:
            # Two branches differing only by node label, merged with UNION —
            # presumably because Kuzu needs a concrete label in each MATCH
            # rather than an `Entity | Community` alternation; verify against
            # Kuzu's Cypher dialect.
            return """
                MATCH (community:Community {uuid: $community_uuid})
                MATCH (node:Entity {uuid: $entity_uuid})
                MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
                SET
                    e.group_id = $group_id,
                    e.created_at = $created_at
                RETURN e.uuid AS uuid
                UNION
                MATCH (community:Community {uuid: $community_uuid})
                MATCH (node:Community {uuid: $entity_uuid})
                MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
                SET
                    e.group_id = $group_id,
                    e.created_at = $created_at
                RETURN e.uuid AS uuid
            """
        case _:  # Neo4j
            return """
                MATCH (community:Community {uuid: $community_uuid})
                MATCH (node:Entity | Community {uuid: $entity_uuid})
                MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
                SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
                RETURN e.uuid AS uuid
            """


# Shared RETURN-clause fragment for HAS_MEMBER community edges; the
# surrounding MATCH is expected to bind the edge as `e`, the source node
# as `n`, and the target node as `m`.
COMMUNITY_EDGE_RETURN = """
    e.uuid AS uuid,
    e.group_id AS group_id,
    n.uuid AS source_node_uuid,
    m.uuid AS target_node_uuid,
    e.created_at AS created_at
"""

```

--------------------------------------------------------------------------------
/mcp_server/tests/run_tests.py:
--------------------------------------------------------------------------------

```python
#!/usr/bin/env python3
"""
Test runner for Graphiti MCP integration tests.
Provides various test execution modes and reporting options.
"""

import argparse
import os
import sys
import time
from pathlib import Path

import pytest
from dotenv import load_dotenv

# Load environment variables from mcp_server/.env when it exists; otherwise
# fall back to python-dotenv's default search (current working directory).
env_file = Path(__file__).parent.parent / '.env'
if env_file.exists():
    load_dotenv(env_file)
else:
    # Try loading from current directory
    load_dotenv()


class TestRunner:
    """Orchestrate test execution with various configurations.

    Wraps prerequisite checks (API keys, database reachability, Python
    packages), pytest invocation for the different suites, and a plain-text
    execution report.
    """

    def __init__(self, args):
        # args: parsed argparse.Namespace (suite, database, mock_llm,
        # parallel, coverage, verbose, skip_slow, timeout, benchmark_only).
        self.args = args
        self.test_dir = Path(__file__).parent
        self.results = {}  # suite name -> pytest exit code

    def check_prerequisites(self) -> dict[str, bool]:
        """Check if required services and dependencies are available.

        Returns:
            Mapping of check name -> passed. Keys ending in ``_hint`` carry
            human-readable guidance strings rather than booleans and should be
            excluded from pass/fail validation.
        """
        checks = {}

        # Check for OpenAI API key if not using mocks
        if not self.args.mock_llm:
            api_key = os.environ.get('OPENAI_API_KEY')
            checks['openai_api_key'] = bool(api_key)
            if not api_key:
                # Check if .env file exists for helpful message
                env_path = Path(__file__).parent.parent / '.env'
                if not env_path.exists():
                    checks['openai_api_key_hint'] = (
                        'Set OPENAI_API_KEY in environment or create mcp_server/.env file'
                    )
        else:
            checks['openai_api_key'] = True

        # Check database availability based on backend
        if self.args.database == 'neo4j':
            checks['neo4j'] = self._check_neo4j()
        elif self.args.database == 'falkordb':
            checks['falkordb'] = self._check_falkordb()

        # Check Python dependencies
        checks['mcp'] = self._check_python_package('mcp')
        checks['pytest'] = self._check_python_package('pytest')
        checks['pytest-asyncio'] = self._check_python_package('pytest-asyncio')

        return checks

    def _check_neo4j(self) -> bool:
        """Check if Neo4j is reachable with the configured credentials."""
        try:
            import neo4j

            uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
            user = os.environ.get('NEO4J_USER', 'neo4j')
            password = os.environ.get('NEO4J_PASSWORD', 'graphiti')

            driver = neo4j.GraphDatabase.driver(uri, auth=(user, password))
            try:
                # Cheap round-trip to verify connectivity and auth.
                with driver.session() as session:
                    session.run('RETURN 1')
            finally:
                # Close the driver even when the probe query fails so we
                # don't leak connections on an unreachable/misconfigured DB.
                driver.close()
            return True
        except Exception:
            return False

    def _check_falkordb(self) -> bool:
        """Check if FalkorDB (Redis protocol) is reachable."""
        try:
            import redis

            uri = os.environ.get('FALKORDB_URI', 'redis://localhost:6379')
            r = redis.from_url(uri)
            r.ping()
            return True
        except Exception:
            return False

    def _check_python_package(self, package: str) -> bool:
        """Check if a Python package is importable (dashes map to underscores)."""
        try:
            __import__(package.replace('-', '_'))
            return True
        except ImportError:
            return False

    def run_test_suite(self, suite: str) -> int:
        """Run a specific test suite and return the pytest exit code.

        Note: pytest honors only the LAST ``-m`` option on its command line,
        so all marker expressions are collected and combined into a single
        ``-m`` argument; passing several would silently drop all but the
        final one (this previously discarded the database-marker filters).
        """
        pytest_args = ['-v', '--tb=short']
        marker_exprs: list[str] = []

        # Exclude tests that require the database backend we are NOT using.
        if self.args.database:
            for db in ['neo4j', 'falkordb']:
                if db != self.args.database:
                    marker_exprs.append(f'not requires_{db}')

        # Add suite-specific arguments
        if suite == 'unit':
            marker_exprs.append('unit')
            pytest_args.append('test_*.py')
        elif suite == 'integration':
            marker_exprs.append('integration or not unit')
            pytest_args.append('test_*.py')
        elif suite == 'comprehensive':
            pytest_args.append('test_comprehensive_integration.py')
        elif suite == 'async':
            pytest_args.append('test_async_operations.py')
        elif suite == 'stress':
            marker_exprs.append('slow')
            pytest_args.append('test_stress_load.py')
        elif suite == 'smoke':
            # Quick smoke test - just basic operations
            pytest_args.extend(
                [
                    'test_comprehensive_integration.py::TestCoreOperations::test_server_initialization',
                    'test_comprehensive_integration.py::TestCoreOperations::test_add_text_memory',
                ]
            )
        elif suite == 'all':
            pytest_args.append('.')
        else:
            pytest_args.append(suite)

        # Add coverage if requested
        if self.args.coverage:
            pytest_args.extend(['--cov=../src', '--cov-report=html'])

        # Add parallel execution if requested
        if self.args.parallel:
            pytest_args.extend(['-n', str(self.args.parallel)])

        # Add verbosity
        if self.args.verbose:
            pytest_args.append('-vv')

        # Add markers to skip
        if self.args.skip_slow:
            marker_exprs.append('not slow')

        # Add timeout override
        if self.args.timeout:
            pytest_args.extend(['--timeout', str(self.args.timeout)])

        # Combine every marker expression into one -m argument, each
        # parenthesized so the 'and' conjunction binds correctly.
        if marker_exprs:
            pytest_args.extend(['-m', ' and '.join(f'({expr})' for expr in marker_exprs)])

        # pytest.main() runs in-process, so a copied environment dict would
        # never be visible to the tests; set the variables on os.environ
        # itself (previously these assignments went to a dead copy).
        if self.args.mock_llm:
            os.environ['USE_MOCK_LLM'] = 'true'
        if self.args.database:
            os.environ['DATABASE_PROVIDER'] = self.args.database

        # Run tests from the test directory
        print(f'Running {suite} tests with pytest args: {" ".join(pytest_args)}')

        # Change to test directory to run tests
        original_dir = os.getcwd()
        os.chdir(self.test_dir)

        try:
            result = pytest.main(pytest_args)
        finally:
            os.chdir(original_dir)

        return result

    def run_performance_benchmark(self):
        """Run performance benchmarking suite and return the pytest exit code."""
        print('Running performance benchmarks...')

        pytest_args = [
            '-v',
            'test_comprehensive_integration.py::TestPerformance',
            'test_async_operations.py::TestAsyncPerformance',
        ]
        # Only append the flag when requested; previously an empty string was
        # passed to pytest when --benchmark-only was absent, which pytest
        # would try to interpret as a (nonexistent) test path.
        if self.args.benchmark_only:
            pytest_args.append('--benchmark-only')

        return pytest.main(pytest_args)

    def generate_report(self):
        """Generate a plain-text test execution report (prerequisites,
        configuration, and any recorded suite results)."""
        report = []
        report.append('\n' + '=' * 60)
        report.append('GRAPHITI MCP TEST EXECUTION REPORT')
        report.append('=' * 60)

        # Prerequisites check
        checks = self.check_prerequisites()
        report.append('\nPrerequisites:')
        for check, passed in checks.items():
            status = '✅' if passed else '❌'
            report.append(f'  {status} {check}')

        # Test configuration
        report.append('\nConfiguration:')
        report.append(f'  Database: {self.args.database}')
        report.append(f'  Mock LLM: {self.args.mock_llm}')
        report.append(f'  Parallel: {self.args.parallel or "No"}')
        report.append(f'  Timeout: {self.args.timeout}s')

        # Results summary (if available)
        if self.results:
            report.append('\nResults:')
            for suite, result in self.results.items():
                status = '✅ Passed' if result == 0 else f'❌ Failed ({result})'
                report.append(f'  {suite}: {status}')

        report.append('=' * 60)
        return '\n'.join(report)


def _build_parser() -> argparse.ArgumentParser:
    # Construct the CLI argument parser for the test runner.
    parser = argparse.ArgumentParser(
        description='Run Graphiti MCP integration tests',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Test Suites:
  unit          - Run unit tests only
  integration   - Run integration tests
  comprehensive - Run comprehensive integration test suite
  async         - Run async operation tests
  stress        - Run stress and load tests
  smoke         - Run quick smoke tests
  all           - Run all tests

Examples:
  python run_tests.py smoke                    # Quick smoke test
  python run_tests.py integration --parallel 4 # Run integration tests in parallel
  python run_tests.py stress --database neo4j  # Run stress tests with Neo4j
  python run_tests.py all --coverage          # Run all tests with coverage
        """,
    )
    parser.add_argument(
        'suite',
        choices=['unit', 'integration', 'comprehensive', 'async', 'stress', 'smoke', 'all'],
        help='Test suite to run',
    )
    parser.add_argument(
        '--database',
        choices=['neo4j', 'falkordb'],
        default='falkordb',
        help='Database backend to test (default: falkordb)',
    )
    parser.add_argument('--mock-llm', action='store_true', help='Use mock LLM for faster testing')
    parser.add_argument(
        '--parallel', type=int, metavar='N', help='Run tests in parallel with N workers'
    )
    parser.add_argument('--coverage', action='store_true', help='Generate coverage report')
    parser.add_argument('--verbose', action='store_true', help='Verbose output')
    parser.add_argument('--skip-slow', action='store_true', help='Skip slow tests')
    parser.add_argument(
        '--timeout', type=int, default=300, help='Test timeout in seconds (default: 300)'
    )
    parser.add_argument('--benchmark-only', action='store_true', help='Run only benchmark tests')
    parser.add_argument(
        '--check-only', action='store_true', help='Only check prerequisites without running tests'
    )
    return parser


def main():
    """Main entry point for test runner."""
    args = _build_parser().parse_args()

    runner = TestRunner(args)

    # --check-only: report prerequisites and stop.
    if args.check_only:
        print(runner.generate_report())
        sys.exit(0)

    # Validate prerequisites; '_hint' entries are guidance text, not checks.
    checks = runner.check_prerequisites()
    validation_checks = {k: v for k, v in checks.items() if not k.endswith('_hint')}

    if not all(validation_checks.values()):
        print('⚠️  Some prerequisites are not met:')
        for check, passed in checks.items():
            if check.endswith('_hint') or passed:
                continue
            print(f'  ❌ {check}')
            # Show hint if available
            hint_key = f'{check}_hint'
            if hint_key in checks:
                print(f'     💡 {checks[hint_key]}')

        if not args.mock_llm and not checks.get('openai_api_key'):
            print('\n💡 Tip: Use --mock-llm to run tests without OpenAI API key')

        response = input('\nContinue anyway? (y/N): ')
        if response.lower() != 'y':
            sys.exit(1)

    # Run the selected suite (or the benchmark path) and time it.
    print(f'\n🚀 Starting test execution: {args.suite}')
    start_time = time.time()

    result = (
        runner.run_performance_benchmark()
        if args.benchmark_only
        else runner.run_test_suite(args.suite)
    )

    duration = time.time() - start_time

    # Record the outcome so generate_report() can include it.
    runner.results[args.suite] = result

    print(runner.generate_report())
    print(f'\n⏱️  Test execution completed in {duration:.2f} seconds')

    # Propagate the pytest exit code to the shell.
    sys.exit(result)


if __name__ == '__main__':
    main()

```

--------------------------------------------------------------------------------
/graphiti_core/driver/neptune_driver.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import asyncio
import datetime
import logging
from collections.abc import Coroutine
from typing import Any

import boto3
from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers

from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider

logger = logging.getLogger(__name__)
DEFAULT_SIZE = 10

# OpenSearch (AOSS) index definitions used by NeptuneDriver: one entry per
# full-text index. Each entry holds the index name, its field mappings
# ('body'), and a reusable multi_match query template ('query') whose empty
# 'query' string is filled in at search time.
aoss_indices = [
    {
        'index_name': 'node_name_and_summary',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'name': {'type': 'text'},
                    'summary': {'type': 'text'},
                    'group_id': {'type': 'text'},
                }
            }
        },
        'query': {
            'query': {'multi_match': {'query': '', 'fields': ['name', 'summary', 'group_id']}},
            'size': DEFAULT_SIZE,
        },
    },
    {
        'index_name': 'community_name',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'name': {'type': 'text'},
                    'group_id': {'type': 'text'},
                }
            }
        },
        'query': {
            'query': {'multi_match': {'query': '', 'fields': ['name', 'group_id']}},
            'size': DEFAULT_SIZE,
        },
    },
    {
        'index_name': 'episode_content',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'content': {'type': 'text'},
                    'source': {'type': 'text'},
                    'source_description': {'type': 'text'},
                    'group_id': {'type': 'text'},
                }
            }
        },
        'query': {
            'query': {
                'multi_match': {
                    'query': '',
                    'fields': ['content', 'source', 'source_description', 'group_id'],
                }
            },
            'size': DEFAULT_SIZE,
        },
    },
    {
        'index_name': 'edge_name_and_fact',
        'body': {
            'mappings': {
                'properties': {
                    'uuid': {'type': 'keyword'},
                    'name': {'type': 'text'},
                    'fact': {'type': 'text'},
                    'group_id': {'type': 'text'},
                }
            }
        },
        'query': {
            'query': {'multi_match': {'query': '', 'fields': ['name', 'fact', 'group_id']}},
            'size': DEFAULT_SIZE,
        },
    },
]


class NeptuneDriver(GraphDriver):
    """GraphDriver backed by AWS Neptune (Database or Analytics) for graph
    storage, with OpenSearch Serverless (AOSS) providing full-text indexing."""

    provider: GraphProvider = GraphProvider.NEPTUNE

    def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = 443):
        """This initializes a NeptuneDriver for use with Neptune as a backend

        Args:
            host (str): The Neptune Database or Neptune Analytics host
            aoss_host (str): The OpenSearch host value
            port (int, optional): The Neptune Database port, ignored for Neptune Analytics. Defaults to 8182.
            aoss_port (int, optional): The OpenSearch port. Defaults to 443.

        Raises:
            ValueError: If `host` is missing or lacks a `neptune-db://` /
                `neptune-graph://` scheme, or if `aoss_host` is missing.
        """
        if not host:
            raise ValueError('You must provide an endpoint to create a NeptuneDriver')

        if host.startswith('neptune-db://'):
            # This is a Neptune Database Cluster
            endpoint = host.replace('neptune-db://', '')
            self.client = NeptuneGraph(endpoint, port)
            logger.debug('Creating Neptune Database session for %s', host)
        elif host.startswith('neptune-graph://'):
            # This is a Neptune Analytics Graph
            graphId = host.replace('neptune-graph://', '')
            self.client = NeptuneAnalyticsGraph(graphId)
            logger.debug('Creating Neptune Graph session for %s', host)
        else:
            raise ValueError(
                'You must provide an endpoint to create a NeptuneDriver as either neptune-db://<endpoint> or neptune-graph://<graphid>'
            )

        if not aoss_host:
            raise ValueError('You must provide an AOSS endpoint to create an OpenSearch driver.')

        # Sign AOSS requests with the ambient AWS credentials/region.
        session = boto3.Session()
        self.aoss_client = OpenSearch(
            hosts=[{'host': aoss_host, 'port': aoss_port}],
            http_auth=Urllib3AWSV4SignerAuth(
                session.get_credentials(), session.region_name, 'aoss'
            ),
            use_ssl=True,
            verify_certs=True,
            connection_class=Urllib3HttpConnection,
            pool_maxsize=20,
        )

    def _sanitize_parameters(self, query, params: dict):
        """Rewrite `query` and mutate `params` so datetime values survive Neptune.

        datetime params are converted to ISO-8601 strings in place, and the
        query text is rewritten to wrap (or inline) the corresponding
        placeholders in `datetime(...)`.
        NOTE(review): the list branch inlines datetime literals directly into
        the query string and detects them heuristically via a 'T' in the
        string — assumes the ISO strings need no escaping; confirm.
        """
        if isinstance(query, list):
            # Sanitize each query in a batch against the same params.
            return [self._sanitize_parameters(q, params) for q in query]

        for k, v in params.items():
            if isinstance(v, datetime.datetime):
                params[k] = v.isoformat()
            elif isinstance(v, list):
                # Handle lists that might contain datetime objects
                for i, item in enumerate(v):
                    if isinstance(item, datetime.datetime):
                        v[i] = item.isoformat()
                        query = str(query).replace(f'${k}', f'datetime(${k})')
                    if isinstance(item, dict):
                        query = self._sanitize_parameters(query, v[i])

                # If the list contains datetime objects, we need to wrap each element with datetime()
                if any(isinstance(item, str) and 'T' in item for item in v):
                    # Create a new list expression with datetime() wrapped around each element
                    datetime_list = (
                        '['
                        + ', '.join(
                            f'datetime("{item}")'
                            if isinstance(item, str) and 'T' in item
                            else repr(item)
                            for item in v
                        )
                        + ']'
                    )
                    query = str(query).replace(f'${k}', datetime_list)
            elif isinstance(v, dict):
                query = self._sanitize_parameters(query, v)
        return query

    async def execute_query(
        self, cypher_query_, **kwargs: Any
    ) -> tuple[dict[str, Any], None, None]:
        """Execute one query (kwargs as params) or a list of (query, params) pairs.

        For a list, queries run sequentially and the last result is returned.
        Returns a (records, None, None) tuple mirroring the Neo4j driver shape.
        """
        params = dict(kwargs)
        if isinstance(cypher_query_, list):
            # Guard against an empty batch, which previously raised
            # UnboundLocalError when `result` was never assigned.
            result: dict[str, Any] = {}
            for q in cypher_query_:
                result, _, _ = self._run_query(q[0], q[1])
            return result, None, None
        return self._run_query(cypher_query_, params)

    def _run_query(self, cypher_query_, params):
        """Sanitize and run a single query, logging full context on failure."""
        cypher_query_ = str(self._sanitize_parameters(cypher_query_, params))
        try:
            result = self.client.query(cypher_query_, params=params)
        except Exception as e:
            logger.error('Query: %s', cypher_query_)
            logger.error('Parameters: %s', params)
            logger.error('Error executing query: %s', e)
            raise e

        return result, None, None

    def session(self, database: str | None = None) -> GraphDriverSession:
        # `database` is accepted for interface compatibility but unused.
        return NeptuneDriverSession(driver=self)

    async def close(self) -> None:
        # Close the underlying boto client held by the langchain graph wrapper.
        return self.client.client.close()

    async def _delete_all_data(self) -> Any:
        return await self.execute_query('MATCH (n) DETACH DELETE n')

    def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
        """Return a coroutine that deletes all AOSS indices."""
        return self.delete_all_indexes_impl()

    async def delete_all_indexes_impl(self) -> None:
        # Previously this returned the coroutine without awaiting it, so the
        # deletion never actually executed; await it here.
        await self.delete_aoss_indices()

    async def create_aoss_indices(self):
        """Create any missing AOSS indices defined in `aoss_indices`."""
        for index in aoss_indices:
            index_name = index['index_name']
            client = self.aoss_client
            if not client.indices.exists(index=index_name):
                client.indices.create(index=index_name, body=index['body'])
        # Sleep for 1 minute to let the index creation complete
        await asyncio.sleep(60)

    async def delete_aoss_indices(self):
        """Delete every AOSS index defined in `aoss_indices` that exists."""
        for index in aoss_indices:
            index_name = index['index_name']
            client = self.aoss_client
            if client.indices.exists(index=index_name):
                client.indices.delete(index=index_name)

    async def build_indices_and_constraints(self, delete_existing: bool = False):
        # Neptune uses OpenSearch (AOSS) for indexing
        if delete_existing:
            await self.delete_aoss_indices()
        await self.create_aoss_indices()

    def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
        """Run a full-text multi_match search against the named AOSS index.

        Returns the raw OpenSearch response, or {} if `name` matches no index.
        """
        for index in aoss_indices:
            if name.lower() == index['index_name']:
                # Build the request body locally instead of mutating the shared
                # module-level template (which was not thread-safe), and honor
                # `limit` — previously the outer {'size': limit} dict was
                # discarded and the template's DEFAULT_SIZE was always used.
                fields = index['query']['query']['multi_match']['fields']
                body = {
                    'query': {'multi_match': {'query': query_text, 'fields': fields}},
                    'size': limit,
                }
                return self.aoss_client.search(body=body, index=index['index_name'])
        return {}

    def save_to_aoss(self, name: str, data: list[dict]) -> int:
        """Bulk-index `data` into the named AOSS index.

        Only properties declared in the index mapping are indexed; each doc's
        `uuid` becomes its OpenSearch document id.

        Returns:
            Number of successfully indexed documents (0 if `name` is unknown).
        """
        for index in aoss_indices:
            if name.lower() == index['index_name']:
                to_index = []
                for d in data:
                    item = {'_index': name, '_id': d['uuid']}
                    for p in index['body']['mappings']['properties']:
                        if p in d:
                            item[p] = d[p]
                    to_index.append(item)
                success, _ = helpers.bulk(self.aoss_client, to_index, stats_only=True)
                return success

        return 0


class NeptuneDriverSession(GraphDriverSession):
    """Thin session wrapper around a NeptuneDriver.

    Neptune exposes no session/transaction object of its own, so this class
    simply forwards all work to the underlying driver.
    """

    provider = GraphProvider.NEPTUNE

    def __init__(self, driver: NeptuneDriver):  # type: ignore[reportUnknownArgumentType]
        self.driver = driver

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Nothing to release; defined only to satisfy the async context protocol.
        pass

    async def close(self):
        # Sessions hold no resources of their own, so close is a no-op.
        pass

    async def execute_write(self, func, *args, **kwargs):
        # There is no real transaction; hand the session itself to the callback.
        return await func(self, *args, **kwargs)

    async def run(self, query: str | list, **kwargs: Any) -> Any:
        # A list of queries runs sequentially; only the final result is
        # returned, mirroring the single-query case.
        if not isinstance(query, list):
            return await self.driver.execute_query(str(query), **kwargs)
        result = None
        for statement in query:
            result = await self.driver.execute_query(statement, **kwargs)
        return result
```

--------------------------------------------------------------------------------
/signatures/version1/cla.json:
--------------------------------------------------------------------------------

```json
{
  "signedContributors": [
    {
      "name": "colombod",
      "id": 375556,
      "comment_id": 2761979440,
      "created_at": "2025-03-28T17:21:29Z",
      "repoId": 840056306,
      "pullRequestNo": 310
    },
    {
      "name": "evanmschultz",
      "id": 3806601,
      "comment_id": 2813673237,
      "created_at": "2025-04-17T17:56:24Z",
      "repoId": 840056306,
      "pullRequestNo": 372
    },
    {
      "name": "soichisumi",
      "id": 30210641,
      "comment_id": 2818469528,
      "created_at": "2025-04-21T14:02:11Z",
      "repoId": 840056306,
      "pullRequestNo": 382
    },
    {
      "name": "drumnation",
      "id": 18486434,
      "comment_id": 2822330188,
      "created_at": "2025-04-22T19:51:09Z",
      "repoId": 840056306,
      "pullRequestNo": 389
    },
    {
      "name": "jackaldenryan",
      "id": 61809814,
      "comment_id": 2845356793,
      "created_at": "2025-05-01T17:51:11Z",
      "repoId": 840056306,
      "pullRequestNo": 429
    },
    {
      "name": "t41372",
      "id": 36402030,
      "comment_id": 2849035400,
      "created_at": "2025-05-04T06:24:37Z",
      "repoId": 840056306,
      "pullRequestNo": 438
    },
    {
      "name": "markalosey",
      "id": 1949914,
      "comment_id": 2878173826,
      "created_at": "2025-05-13T23:27:16Z",
      "repoId": 840056306,
      "pullRequestNo": 486
    },
    {
      "name": "adamkatav",
      "id": 13109136,
      "comment_id": 2887184706,
      "created_at": "2025-05-16T16:29:22Z",
      "repoId": 840056306,
      "pullRequestNo": 493
    },
    {
      "name": "realugbun",
      "id": 74101927,
      "comment_id": 2899731784,
      "created_at": "2025-05-22T02:36:44Z",
      "repoId": 840056306,
      "pullRequestNo": 513
    },
    {
      "name": "dudizimber",
      "id": 16744955,
      "comment_id": 2912211548,
      "created_at": "2025-05-27T11:45:57Z",
      "repoId": 840056306,
      "pullRequestNo": 525
    },
    {
      "name": "galshubeli",
      "id": 124919062,
      "comment_id": 2912289100,
      "created_at": "2025-05-27T12:15:03Z",
      "repoId": 840056306,
      "pullRequestNo": 525
    },
    {
      "name": "TheEpTic",
      "id": 326774,
      "comment_id": 2917970901,
      "created_at": "2025-05-29T01:26:54Z",
      "repoId": 840056306,
      "pullRequestNo": 541
    },
    {
      "name": "PrettyWood",
      "id": 18406791,
      "comment_id": 2938495182,
      "created_at": "2025-06-04T04:44:59Z",
      "repoId": 840056306,
      "pullRequestNo": 558
    },
    {
      "name": "denyska",
      "id": 1242726,
      "comment_id": 2957480685,
      "created_at": "2025-06-10T02:08:05Z",
      "repoId": 840056306,
      "pullRequestNo": 574
    },
    {
      "name": "LongPML",
      "id": 59755436,
      "comment_id": 2965391879,
      "created_at": "2025-06-12T07:10:01Z",
      "repoId": 840056306,
      "pullRequestNo": 579
    },
    {
      "name": "karn09",
      "id": 3743119,
      "comment_id": 2973492225,
      "created_at": "2025-06-15T04:45:13Z",
      "repoId": 840056306,
      "pullRequestNo": 584
    },
    {
      "name": "abab-dev",
      "id": 146825408,
      "comment_id": 2975719469,
      "created_at": "2025-06-16T09:12:53Z",
      "repoId": 840056306,
      "pullRequestNo": 588
    },
    {
      "name": "thorchh",
      "id": 75025911,
      "comment_id": 2982990164,
      "created_at": "2025-06-18T07:19:38Z",
      "repoId": 840056306,
      "pullRequestNo": 601
    },
    {
      "name": "robrichardson13",
      "id": 9492530,
      "comment_id": 2989798338,
      "created_at": "2025-06-20T04:59:06Z",
      "repoId": 840056306,
      "pullRequestNo": 611
    },
    {
      "name": "gkorland",
      "id": 753206,
      "comment_id": 2993690025,
      "created_at": "2025-06-21T17:35:37Z",
      "repoId": 840056306,
      "pullRequestNo": 609
    },
    {
      "name": "urmzd",
      "id": 45431570,
      "comment_id": 3027098935,
      "created_at": "2025-07-02T09:16:46Z",
      "repoId": 840056306,
      "pullRequestNo": 661
    },
    {
      "name": "jawwadfirdousi",
      "id": 10913083,
      "comment_id": 3027808026,
      "created_at": "2025-07-02T13:02:22Z",
      "repoId": 840056306,
      "pullRequestNo": 663
    },
    {
      "name": "jamesindeed",
      "id": 60527576,
      "comment_id": 3028293328,
      "created_at": "2025-07-02T15:24:23Z",
      "repoId": 840056306,
      "pullRequestNo": 664
    },
    {
      "name": "dev-mirzabicer",
      "id": 90691873,
      "comment_id": 3035836506,
      "created_at": "2025-07-04T11:47:08Z",
      "repoId": 840056306,
      "pullRequestNo": 672
    },
    {
      "name": "zeroasterisk",
      "id": 23422,
      "comment_id": 3040716245,
      "created_at": "2025-07-06T03:41:19Z",
      "repoId": 840056306,
      "pullRequestNo": 679
    },
    {
      "name": "charlesmcchan",
      "id": 425857,
      "comment_id": 3066732289,
      "created_at": "2025-07-13T08:54:26Z",
      "repoId": 840056306,
      "pullRequestNo": 711
    },
    {
      "name": "soraxas",
      "id": 22362177,
      "comment_id": 3084093750,
      "created_at": "2025-07-17T13:33:25Z",
      "repoId": 840056306,
      "pullRequestNo": 741
    },
    {
      "name": "sdht0",
      "id": 867424,
      "comment_id": 3092540466,
      "created_at": "2025-07-19T19:52:21Z",
      "repoId": 840056306,
      "pullRequestNo": 748
    },
    {
      "name": "Naseem77",
      "id": 34807727,
      "comment_id": 3093746709,
      "created_at": "2025-07-20T07:07:33Z",
      "repoId": 840056306,
      "pullRequestNo": 742
    },
    {
      "name": "kavenGw",
      "id": 3193355,
      "comment_id": 3100620568,
      "created_at": "2025-07-22T02:58:50Z",
      "repoId": 840056306,
      "pullRequestNo": 750
    },
    {
      "name": "paveljakov",
      "id": 45147436,
      "comment_id": 3113955940,
      "created_at": "2025-07-24T15:39:36Z",
      "repoId": 840056306,
      "pullRequestNo": 764
    },
    {
      "name": "gifflet",
      "id": 33522742,
      "comment_id": 3133869379,
      "created_at": "2025-07-29T20:00:27Z",
      "repoId": 840056306,
      "pullRequestNo": 782
    },
    {
      "name": "bechbd",
      "id": 6898505,
      "comment_id": 3140501814,
      "created_at": "2025-07-31T15:58:08Z",
      "repoId": 840056306,
      "pullRequestNo": 793
    },
    {
      "name": "hugo-son",
      "id": 141999572,
      "comment_id": 3155009405,
      "created_at": "2025-08-05T12:27:09Z",
      "repoId": 840056306,
      "pullRequestNo": 805
    },
    {
      "name": "mvanders",
      "id": 758617,
      "comment_id": 3160523661,
      "created_at": "2025-08-06T14:56:21Z",
      "repoId": 840056306,
      "pullRequestNo": 808
    },
    {
      "name": "v-khanna",
      "id": 102773390,
      "comment_id": 3162200130,
      "created_at": "2025-08-07T02:23:09Z",
      "repoId": 840056306,
      "pullRequestNo": 812
    },
    {
      "name": "vjeeva",
      "id": 13189349,
      "comment_id": 3165600173,
      "created_at": "2025-08-07T20:24:08Z",
      "repoId": 840056306,
      "pullRequestNo": 814
    },
    {
      "name": "liebertar",
      "id": 99405438,
      "comment_id": 3166905812,
      "created_at": "2025-08-08T07:52:27Z",
      "repoId": 840056306,
      "pullRequestNo": 816
    },
    {
      "name": "CaroLe-prw",
      "id": 42695882,
      "comment_id": 3187949734,
      "created_at": "2025-08-14T10:29:25Z",
      "repoId": 840056306,
      "pullRequestNo": 833
    },
    {
      "name": "Wizmann",
      "id": 1270921,
      "comment_id": 3196208374,
      "created_at": "2025-08-18T11:09:35Z",
      "repoId": 840056306,
      "pullRequestNo": 842
    },
    {
      "name": "liangyuanpeng",
      "id": 28711504,
      "comment_id": 3205841804,
      "created_at": "2025-08-20T11:35:42Z",
      "repoId": 840056306,
      "pullRequestNo": 847
    },
    {
      "name": "aktek-yazge",
      "id": 218602044,
      "comment_id": 3078757968,
      "created_at": "2025-07-16T14:00:40Z",
      "repoId": 840056306,
      "pullRequestNo": 735
    },
    {
      "name": "Shelvak",
      "id": 873323,
      "comment_id": 3243330690,
      "created_at": "2025-09-01T22:26:32Z",
      "repoId": 840056306,
      "pullRequestNo": 885
    },
    {
      "name": "maskshell",
      "id": 5113279,
      "comment_id": 3244187860,
      "created_at": "2025-09-02T07:48:05Z",
      "repoId": 840056306,
      "pullRequestNo": 886
    },
    {
      "name": "jeanlucthumm",
      "id": 4934853,
      "comment_id": 3255120747,
      "created_at": "2025-09-04T18:49:57Z",
      "repoId": 840056306,
      "pullRequestNo": 892
    },
    {
      "name": "Bit-urd",
      "id": 43745133,
      "comment_id": 3264006888,
      "created_at": "2025-09-07T20:01:08Z",
      "repoId": 840056306,
      "pullRequestNo": 895
    },
    {
      "name": "DavIvek",
      "id": 88043717,
      "comment_id": 3269895491,
      "created_at": "2025-09-09T09:59:47Z",
      "repoId": 840056306,
      "pullRequestNo": 900
    },
    {
      "name": "gsw945",
      "id": 6281968,
      "comment_id": 3270396586,
      "created_at": "2025-09-09T12:05:27Z",
      "repoId": 840056306,
      "pullRequestNo": 901
    },
    {
      "name": "luan122",
      "id": 5606023,
      "comment_id": 3287095238,
      "created_at": "2025-09-12T23:14:21Z",
      "repoId": 840056306,
      "pullRequestNo": 908
    },
    {
      "name": "Brandtweary",
      "id": 7968557,
      "comment_id": 3314191937,
      "created_at": "2025-09-19T23:37:33Z",
      "repoId": 840056306,
      "pullRequestNo": 916
    },
    {
      "name": "clsferguson",
      "id": 48876201,
      "comment_id": 3368715688,
      "created_at": "2025-10-05T03:30:10Z",
      "repoId": 840056306,
      "pullRequestNo": 981
    },
    {
      "name": "ngaiyuc",
      "id": 69293565,
      "comment_id": 3407383300,
      "created_at": "2025-10-15T16:45:10Z",
      "repoId": 840056306,
      "pullRequestNo": 1005
    },
    {
      "name": "0fism",
      "id": 63762457,
      "comment_id": 3407328042,
      "created_at": "2025-10-15T16:29:33Z",
      "repoId": 840056306,
      "pullRequestNo": 1005
    },
    {
      "name": "dontang97",
      "id": 88384441,
      "comment_id": 3431443627,
      "created_at": "2025-10-22T09:52:01Z",
      "repoId": 840056306,
      "pullRequestNo": 1020
    },
    {
      "name": "didier-durand",
      "id": 2927957,
      "comment_id": 3460571645,
      "created_at": "2025-10-29T09:31:25Z",
      "repoId": 840056306,
      "pullRequestNo": 1028
    },
    {
      "name": "anubhavgirdhar1",
      "id": 85768253,
      "comment_id": 3468525446,
      "created_at": "2025-10-30T15:11:58Z",
      "repoId": 840056306,
      "pullRequestNo": 1035
    },
    {
      "name": "Galleons2029",
      "id": 88185941,
      "comment_id": 3495884964,
      "created_at": "2025-11-06T08:39:46Z",
      "repoId": 840056306,
      "pullRequestNo": 1053
    },
    {
      "name": "supmo668",
      "id": 28805779,
      "comment_id": 3550309664,
      "created_at": "2025-11-19T01:56:25Z",
      "repoId": 840056306,
      "pullRequestNo": 1072
    },
    {
      "name": "donbr",
      "id": 7340008,
      "comment_id": 3568970102,
      "created_at": "2025-11-24T05:19:42Z",
      "repoId": 840056306,
      "pullRequestNo": 1081
    },
    {
      "name": "apetti1920",
      "id": 4706645,
      "comment_id": 3572726648,
      "created_at": "2025-11-24T21:07:34Z",
      "repoId": 840056306,
      "pullRequestNo": 1084
    }
  ]
}
```

--------------------------------------------------------------------------------
/tests/test_entity_exclusion_int.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from datetime import datetime, timezone

import pytest
from pydantic import BaseModel, Field

from graphiti_core.graphiti import Graphiti
from graphiti_core.helpers import validate_excluded_entity_types
from tests.helpers_test import drivers, get_driver

pytestmark = pytest.mark.integration
pytest_plugins = ('pytest_asyncio',)


# Test entity type definitions
class Person(BaseModel):
    """A human person mentioned in the conversation."""

    # NOTE(review): the class docstring and Field descriptions appear to feed
    # the entity-extraction schema — keep them stable. (TODO confirm)
    first_name: str | None = Field(None, description='First name of the person')
    last_name: str | None = Field(None, description='Last name of the person')
    occupation: str | None = Field(None, description='Job or profession of the person')


class Organization(BaseModel):
    """A company, institution, or organized group."""

    # All attributes are optional: extraction may not surface every detail.
    organization_type: str | None = Field(
        None, description='Type of organization (company, NGO, etc.)'
    )
    industry: str | None = Field(
        None, description='Industry or sector the organization operates in'
    )


class Location(BaseModel):
    """A geographic location, place, or address."""

    # All attributes are optional: extraction may not surface every detail.
    location_type: str | None = Field(
        None, description='Type of location (city, country, building, etc.)'
    )
    coordinates: str | None = Field(None, description='Geographic coordinates if available')


@pytest.mark.asyncio
@pytest.mark.parametrize(
    'driver',
    drivers,
)
async def test_exclude_default_entity_type(driver):
    """Test excluding the default 'Entity' type while keeping custom types."""
    graphiti = Graphiti(graph_driver=get_driver(driver))

    try:
        await graphiti.build_indices_and_constraints()

        # Exclude the built-in 'Entity' type; the custom types stay active.
        result = await graphiti.add_episode(
            name='Business Meeting',
            episode_body=(
                'John Smith works at Acme Corporation in New York. The weather is nice today.'
            ),
            source_description='Meeting notes',
            reference_time=datetime.now(timezone.utc),
            entity_types={'Person': Person, 'Organization': Organization},
            excluded_entity_types=['Entity'],  # Exclude default type
            group_id='test_exclude_default',
        )

        # Extraction itself must still succeed with the default type excluded.
        assert result is not None

        # Every surviving node should carry a concrete custom-type label in
        # addition to the base 'Entity' label.
        search_results = await graphiti.search_(
            query='John Smith Acme Corporation', group_ids=['test_exclude_default']
        )
        for node in search_results.nodes:
            assert 'Entity' in node.labels  # All nodes should have Entity label
            assert any(label in ['Person', 'Organization'] for label in node.labels), (
                f'Node {node.name} should have a specific type label, got: {node.labels}'
            )

        await _cleanup_test_nodes(graphiti, 'test_exclude_default')

    finally:
        await graphiti.close()


@pytest.mark.asyncio
@pytest.mark.parametrize(
    'driver',
    drivers,
)
async def test_exclude_specific_custom_types(driver):
    """Test excluding specific custom entity types while keeping others."""
    graphiti = Graphiti(graph_driver=get_driver(driver))

    try:
        await graphiti.build_indices_and_constraints()

        all_types = {
            'Person': Person,
            'Organization': Organization,
            'Location': Location,
        }

        # The episode mentions a person, a company, and a city — all three types.
        result = await graphiti.add_episode(
            name='Office Visit',
            episode_body=(
                'Sarah Johnson from Google visited the San Francisco office to discuss the new project.'
            ),
            source_description='Visit report',
            reference_time=datetime.now(timezone.utc),
            entity_types=all_types,
            excluded_entity_types=['Organization', 'Location'],  # Exclude these types
            group_id='test_exclude_custom',
        )
        assert result is not None

        search_results = await graphiti.search_(
            query='Sarah Johnson Google San Francisco', group_ids=['test_exclude_custom']
        )
        found_nodes = search_results.nodes

        # Excluded labels must not appear on any node.
        for node in found_nodes:
            assert 'Entity' in node.labels
            assert 'Organization' not in node.labels, (
                f'Found excluded Organization in node: {node.name}'
            )
            assert 'Location' not in node.labels, f'Found excluded Location in node: {node.name}'

        # Sarah Johnson should still come through as a Person.
        person_nodes = [n for n in found_nodes if 'Person' in n.labels]
        assert len(person_nodes) > 0, 'Should have found at least one Person entity'

        await _cleanup_test_nodes(graphiti, 'test_exclude_custom')

    finally:
        await graphiti.close()


@pytest.mark.asyncio
@pytest.mark.parametrize(
    'driver',
    drivers,
)
async def test_exclude_all_types(driver):
    """Test excluding all entity types (edge case)."""
    graphiti = Graphiti(graph_driver=get_driver(driver))

    try:
        await graphiti.build_indices_and_constraints()

        # Excluding every available type should suppress entity creation entirely.
        result = await graphiti.add_episode(
            name='No Entities',
            episode_body='This text mentions John and Microsoft but no entities should be created.',
            source_description='Test content',
            reference_time=datetime.now(timezone.utc),
            entity_types={'Person': Person, 'Organization': Organization},
            excluded_entity_types=['Entity', 'Person', 'Organization'],  # Exclude everything
            group_id='test_exclude_all',
        )
        assert result is not None

        # Nothing should be retrievable from this episode's group.
        search_results = await graphiti.search_(
            query='John Microsoft', group_ids=['test_exclude_all']
        )
        found_nodes = search_results.nodes
        assert len(found_nodes) == 0, (
            f'Expected no entities, but found: {[n.name for n in found_nodes]}'
        )

        await _cleanup_test_nodes(graphiti, 'test_exclude_all')

    finally:
        await graphiti.close()


@pytest.mark.asyncio
@pytest.mark.parametrize(
    'driver',
    drivers,
)
async def test_exclude_no_types(driver):
    """Test normal behavior when no types are excluded (baseline test)."""
    graphiti = Graphiti(graph_driver=get_driver(driver))

    try:
        await graphiti.build_indices_and_constraints()

        # Baseline: with no exclusions, both custom types should be extracted.
        result = await graphiti.add_episode(
            name='Normal Behavior',
            episode_body='Alice Smith works at TechCorp.',
            source_description='Normal test',
            reference_time=datetime.now(timezone.utc),
            entity_types={'Person': Person, 'Organization': Organization},
            excluded_entity_types=None,  # No exclusions
            group_id='test_exclude_none',
        )
        assert result is not None

        search_results = await graphiti.search_(
            query='Alice Smith TechCorp', group_ids=['test_exclude_none']
        )
        found_nodes = search_results.nodes
        assert len(found_nodes) > 0, 'Should have found some entities'

        # Expect both a Person (Alice Smith) and an Organization (TechCorp).
        assert any('Person' in n.labels for n in found_nodes), 'Should have found Person entities'
        assert any('Organization' in n.labels for n in found_nodes), (
            'Should have found Organization entities'
        )

        await _cleanup_test_nodes(graphiti, 'test_exclude_none')

    finally:
        await graphiti.close()


def test_validation_valid_excluded_types():
    """Test validation function with valid excluded types."""
    entity_types = {'Person': Person, 'Organization': Organization}

    # Every legal combination — including None and the empty list — validates.
    for excluded in (['Entity'], ['Person'], ['Entity', 'Person'], None, []):
        assert validate_excluded_entity_types(excluded, entity_types) is True


def test_validation_invalid_excluded_types():
    """Test validation function with invalid excluded types."""
    entity_types = {'Person': Person, 'Organization': Organization}

    # Unknown type names must raise, whether alone or mixed with valid ones.
    for excluded in (['InvalidType'], ['Person', 'NonExistentType']):
        with pytest.raises(ValueError, match='Invalid excluded entity types'):
            validate_excluded_entity_types(excluded, entity_types)


@pytest.mark.asyncio
@pytest.mark.parametrize(
    'driver',
    drivers,
)
async def test_excluded_types_parameter_validation_in_add_episode(driver):
    """Test that add_episode validates excluded_entity_types parameter."""
    graphiti = Graphiti(graph_driver=get_driver(driver))

    try:
        # An unknown excluded type name must be rejected before any processing.
        with pytest.raises(ValueError, match='Invalid excluded entity types'):
            await graphiti.add_episode(
                name='Invalid Test',
                episode_body='Test content',
                source_description='Test',
                reference_time=datetime.now(timezone.utc),
                entity_types={'Person': Person},
                excluded_entity_types=['NonExistentType'],
                group_id='test_validation',
            )
    finally:
        await graphiti.close()


async def _cleanup_test_nodes(graphiti: Graphiti, group_id: str):
    """Helper function to clean up test nodes."""
    try:
        # Fetch everything in the group, then remove it node by node.
        results = await graphiti.search_(query='*', group_ids=[group_id])
        for node in results.nodes:
            await node.delete(graphiti.driver)
    except Exception as exc:
        # Best-effort cleanup: report, but never fail the calling test.
        print(f'Warning: Failed to clean up test nodes for group {group_id}: {exc}')

```

--------------------------------------------------------------------------------
/mcp_server/tests/test_integration.py:
--------------------------------------------------------------------------------

```python
#!/usr/bin/env python3
"""
HTTP/SSE Integration test for the refactored Graphiti MCP Server.
Tests server functionality when running in SSE (Server-Sent Events) mode over HTTP.
Note: This test requires the server to be running with --transport sse.
"""

import asyncio
import json
import time
from typing import Any

import httpx


class MCPIntegrationTest:
    """Integration test client for Graphiti MCP Server."""

    def __init__(self, base_url: str = 'http://localhost:8000'):
        # Base URL of a running MCP server in SSE mode.
        self.base_url = base_url
        # Shared async HTTP client; released in __aexit__.
        self.client = httpx.AsyncClient(timeout=30.0)
        # Unique per-run group id so repeated runs don't collide.
        self.test_group_id = f'test_group_{int(time.time())}'

    async def __aenter__(self):
        # Entering the context hands back the test client itself.
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Release the underlying HTTP connection pool on context exit.
        await self.client.aclose()

    async def call_mcp_tool(self, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any]:
        """Call an MCP tool via the SSE endpoint.

        Returns the JSON-RPC 'result' payload on success, or a dict with an
        'error' key on transport/HTTP failure.
        """
        # JSON-RPC 2.0 envelope; a millisecond timestamp doubles as request id.
        payload = {
            'jsonrpc': '2.0',
            'id': int(time.time() * 1000),
            'method': 'tools/call',
            'params': {'name': tool_name, 'arguments': arguments},
        }

        try:
            response = await self.client.post(
                f'{self.base_url}/message',
                json=payload,
                headers={'Content-Type': 'application/json'},
            )
            if response.status_code != 200:
                return {'error': f'HTTP {response.status_code}: {response.text}'}
            body = response.json()
            return body.get('result', body)
        except Exception as exc:
            return {'error': str(exc)}

    async def test_server_status(self) -> bool:
        """Test the get_status resource."""
        print('🔍 Testing server status...')

        try:
            response = await self.client.get(f'{self.base_url}/resources/http://graphiti/status')
            if response.status_code != 200:
                print(f'   ❌ Status check failed: HTTP {response.status_code}')
                return False
            status = response.json()
            print(f'   ✅ Server status: {status.get("status", "unknown")}')
            return status.get('status') == 'ok'
        except Exception as e:
            # Any transport or JSON failure counts as an unhealthy server.
            print(f'   ❌ Status check failed: {e}')
            return False

    async def test_add_memory(self) -> dict[str, str]:
        """Test adding various types of memory episodes.

        Submits one episode per source type (text, json, message) via the
        add_memory tool and records which submissions were accepted.

        Returns:
            Mapping of source type -> 'success' for each episode that queued
            without an error; failed submissions are omitted.
        """
        print('📝 Testing add_memory functionality...')

        episode_results: dict[str, str] = {}

        async def submit(kind: str, label: str, name: str, body: str, description: str) -> None:
            # Submit one episode and record success under `kind`.
            # `label` is the display form ('Text'/'JSON'/'Message').
            print(f'   Testing {label.lower()} episode...')
            result = await self.call_mcp_tool(
                'add_memory',
                {
                    'name': name,
                    'episode_body': body,
                    'source': kind,
                    'source_description': description,
                    'group_id': self.test_group_id,
                },
            )
            if 'error' in result:
                print(f'   ❌ {label} episode failed: {result["error"]}')
            else:
                print(f'   ✅ {label} episode queued: {result.get("message", "Success")}')
                episode_results[kind] = 'success'

        # Test 1: plain text episode
        await submit(
            'text',
            'Text',
            'Test Company News',
            'Acme Corp announced a revolutionary new AI product that will transform the industry. The CEO mentioned this is their biggest launch since 2020.',
            'news article',
        )

        # Test 2: structured JSON episode
        json_data = {
            'company': {'name': 'TechCorp', 'founded': 2010},
            'products': [
                {'id': 'P001', 'name': 'CloudSync', 'category': 'software'},
                {'id': 'P002', 'name': 'DataMiner', 'category': 'analytics'},
            ],
            'employees': 150,
        }
        await submit(
            'json',
            'JSON',
            'Company Profile',
            json.dumps(json_data),
            'CRM data',
        )

        # Test 3: conversation-transcript episode
        await submit(
            'message',
            'Message',
            'Customer Support Chat',
            "user: What's your return policy?\nassistant: You can return items within 30 days of purchase with receipt.\nuser: Thanks!",
            'support chat log',
        )

        return episode_results

    async def wait_for_processing(self, max_wait: int = 30) -> None:
        """Wait for episode processing to complete.

        Polls get_episodes once per second until episodes appear or
        `max_wait` seconds elapse.
        """
        print(f'⏳ Waiting up to {max_wait} seconds for episode processing...')

        for i in range(max_wait):
            await asyncio.sleep(1)

            # Check if we have any episodes
            result = await self.call_mcp_tool(
                'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
            )

            # Skip error responses only. The original guard
            # (`not isinstance(result, dict) or 'error' in result`) also
            # skipped list responses, which made the success branch below
            # unreachable — episodes arrive as a list (see
            # test_episode_retrieval's isinstance(result, list) handling).
            if isinstance(result, dict) and 'error' in result:
                continue

            if isinstance(result, list) and len(result) > 0:
                print(f'   ✅ Found {len(result)} processed episodes after {i + 1} seconds')
                return

        print(f'   ⚠️  Still waiting after {max_wait} seconds...')

    async def test_search_functions(self) -> dict[str, bool]:
        """Test search functionality."""
        print('🔍 Testing search functions...')

        results = {}

        # Test search_memory_nodes
        print('   Testing search_memory_nodes...')
        node_result = await self.call_mcp_tool(
            'search_memory_nodes',
            {
                'query': 'Acme Corp product launch',
                'group_ids': [self.test_group_id],
                'max_nodes': 5,
            },
        )
        if 'error' in node_result:
            results['nodes'] = False
            print(f'   ❌ Node search failed: {node_result["error"]}')
        else:
            results['nodes'] = True
            print(f'   ✅ Node search returned {len(node_result.get("nodes", []))} nodes')

        # Test search_memory_facts
        print('   Testing search_memory_facts...')
        fact_result = await self.call_mcp_tool(
            'search_memory_facts',
            {
                'query': 'company products software',
                'group_ids': [self.test_group_id],
                'max_facts': 5,
            },
        )
        if 'error' in fact_result:
            results['facts'] = False
            print(f'   ❌ Fact search failed: {fact_result["error"]}')
        else:
            results['facts'] = True
            print(f'   ✅ Fact search returned {len(fact_result.get("facts", []))} facts')

        return results

    async def test_episode_retrieval(self) -> bool:
        """Fetch recent episodes for the test group; True if any were returned."""
        print('📚 Testing episode retrieval...')

        result = await self.call_mcp_tool(
            'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
        )

        # Guard clauses: bail out on an error payload or an unexpected shape.
        if 'error' in result:
            print(f'   ❌ Episode retrieval failed: {result["error"]}')
            return False

        if not isinstance(result, list):
            print(f'   ❌ Unexpected result format: {type(result)}')
            return False

        print(f'   ✅ Retrieved {len(result)} episodes')

        # Show details for at most the first three episodes.
        for idx, episode in enumerate(result[:3]):
            title = episode.get('name', 'Unknown')
            origin = episode.get('source', 'unknown')
            print(f'     Episode {idx + 1}: {title} (source: {origin})')

        return len(result) > 0

    async def test_edge_cases(self) -> dict[str, bool]:
        """Verify graceful handling of an unknown group_id and an empty query."""
        print('🧪 Testing edge cases...')

        outcomes: dict[str, bool] = {}

        # Searching an unknown group should return empty results, not an error.
        print('   Testing invalid group_id...')
        response = await self.call_mcp_tool(
            'search_memory_nodes',
            {'query': 'nonexistent data', 'group_ids': ['nonexistent_group'], 'max_nodes': 5},
        )

        if 'error' in response:
            print(f'   ❌ Invalid group_id caused error: {response["error"]}')
            outcomes['invalid_group'] = False
        else:
            found = response.get('nodes', [])
            print(f'   ✅ Invalid group_id handled gracefully (returned {len(found)} nodes)')
            outcomes['invalid_group'] = True

        # A blank query string should likewise not produce an error.
        print('   Testing empty query...')
        response = await self.call_mcp_tool(
            'search_memory_nodes', {'query': '', 'group_ids': [self.test_group_id], 'max_nodes': 5}
        )

        if 'error' in response:
            print(f'   ❌ Empty query caused error: {response["error"]}')
            outcomes['empty_query'] = False
        else:
            print('   ✅ Empty query handled gracefully')
            outcomes['empty_query'] = True

        return outcomes

    async def run_full_test_suite(self) -> dict[str, Any]:
        """Run every integration check in sequence and print a summary.

        Returns a dict with per-stage results plus an 'overall_success' flag.
        """
        print('🚀 Starting Graphiti MCP Server Integration Test')
        print(f'   Test group ID: {self.test_group_id}')
        print('=' * 60)

        outcome: dict[str, Any] = {
            'server_status': False,
            'add_memory': {},
            'search': {},
            'episodes': False,
            'edge_cases': {},
            'overall_success': False,
        }

        # Stage 1: the server must respond before anything else is worth running.
        outcome['server_status'] = await self.test_server_status()
        if not outcome['server_status']:
            print('❌ Server not responding, aborting tests')
            return outcome

        print()

        # Stage 2: queue episodes of each supported type.
        outcome['add_memory'] = await self.test_add_memory()
        print()

        # Stage 3: give the server time to ingest what was queued.
        await self.wait_for_processing()
        print()

        # Stage 4: node and fact search.
        outcome['search'] = await self.test_search_functions()
        print()

        # Stage 5: episode retrieval.
        outcome['episodes'] = await self.test_episode_retrieval()
        print()

        # Stage 6: edge cases and error handling.
        outcome['edge_cases'] = await self.test_edge_cases()
        print()

        memory_ok = len(outcome['add_memory']) > 0
        search_ok = any(outcome['search'].values())
        edges_ok = any(outcome['edge_cases'].values())

        # Overall pass needs a live server, queued memory, retrievable episodes,
        # and at least some search or edge-case functionality working.
        outcome['overall_success'] = (
            outcome['server_status']
            and memory_ok
            and outcome['episodes']
            and (search_ok or edges_ok)
        )

        print('=' * 60)
        print('📊 TEST SUMMARY')
        print(f'   Server Status: {"✅" if outcome["server_status"] else "❌"}')
        print(
            f'   Memory Operations: {"✅" if memory_ok else "❌"} ({len(outcome["add_memory"])} types)'
        )
        print(f'   Search Functions: {"✅" if search_ok else "❌"}')
        print(f'   Episode Retrieval: {"✅" if outcome["episodes"] else "❌"}')
        print(f'   Edge Cases: {"✅" if edges_ok else "❌"}')
        print()
        print(f'🎯 OVERALL: {"✅ SUCCESS" if outcome["overall_success"] else "❌ FAILED"}')

        if outcome['overall_success']:
            print('   The refactored MCP server is working correctly!')
        else:
            print('   Some issues detected. Check individual test results above.')

        return outcome


async def main():
    """Run the integration test and exit with 0 on success, 1 on failure."""
    import sys

    async with MCPIntegrationTest() as test:
        results = await test.run_full_test_suite()

        # Use sys.exit rather than the bare exit() builtin: exit() is injected
        # by the `site` module for interactive use and may be absent when the
        # script runs under `python -S` or in a frozen/embedded interpreter.
        sys.exit(0 if results['overall_success'] else 1)


if __name__ == '__main__':
    asyncio.run(main())

```

--------------------------------------------------------------------------------
/graphiti_core/driver/falkordb_driver.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import asyncio
import datetime
import logging
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from falkordb import Graph as FalkorGraph
    from falkordb.asyncio import FalkorDB
else:
    try:
        from falkordb import Graph as FalkorGraph
        from falkordb.asyncio import FalkorDB
    except ImportError:
        # If falkordb is not installed, raise an ImportError
        raise ImportError(
            'falkordb is required for FalkorDriver. '
            'Install it with: pip install graphiti-core[falkordb]'
        ) from None

from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings

logger = logging.getLogger(__name__)

# Common English words stripped from search terms in build_fulltext_query:
# they add no selectivity to a fulltext match.
# NOTE(review): appears to mirror a RediSearch-style default stopword list —
# confirm against FalkorDB/RediSearch defaults before extending.
STOPWORDS = [
    'a',
    'is',
    'the',
    'an',
    'and',
    'are',
    'as',
    'at',
    'be',
    'but',
    'by',
    'for',
    'if',
    'in',
    'into',
    'it',
    'no',
    'not',
    'of',
    'on',
    'or',
    'such',
    'that',
    'their',
    'then',
    'there',
    'these',
    'they',
    'this',
    'to',
    'was',
    'will',
    'with',
]


class FalkorDriverSession(GraphDriverSession):
    """Thin async session wrapper around a single FalkorDB graph.

    FalkorDB needs no per-session setup or teardown, so the context-manager
    and close hooks are deliberate no-ops.
    """

    provider = GraphProvider.FALKORDB

    def __init__(self, graph: FalkorGraph):
        self.graph = graph

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # Nothing to release; required by the async context-manager protocol.
        pass

    async def close(self):
        # FalkorDB sessions hold no resources of their own.
        pass

    async def execute_write(self, func, *args, **kwargs):
        """Run the async `func`, passing this session in place of a transaction."""
        return await func(self, *args, **kwargs)

    async def run(self, query: str | list, **kwargs: Any) -> Any:
        """Execute a single query (kwargs as parameters) or a list of
        (cypher, params) pairs; always returns None.

        FalkorDB cannot parameterize label sets, so callers may supply a list
        of pre-built (query, params) tuples instead of one statement. Datetime
        values are serialized to ISO strings first, as FalkorDB does not accept
        datetime objects as parameters.
        """
        if isinstance(query, list):
            for statement, statement_params in query:
                await self.graph.query(
                    str(statement),
                    convert_datetimes_to_strings(statement_params),  # type: ignore[reportUnknownArgumentType]
                )
        else:
            await self.graph.query(
                str(query),
                convert_datetimes_to_strings(dict(kwargs)),  # type: ignore[reportUnknownArgumentType]
            )
        return None


class FalkorDriver(GraphDriver):
    """GraphDriver implementation backed by FalkorDB.

    Each logical database maps to a named graph selected via `select_graph`.
    Datetime parameters are converted to ISO strings before execution because
    FalkorDB does not accept datetime objects directly.
    """

    provider = GraphProvider.FALKORDB
    # Graph name used when cloning for the unscoped/default group.
    default_group_id: str = '\\_'
    fulltext_syntax: str = '@'  # FalkorDB uses a redisearch-like syntax for fulltext queries
    aoss_client: None = None

    def __init__(
        self,
        host: str = 'localhost',
        port: int = 6379,
        username: str | None = None,
        password: str | None = None,
        falkor_db: FalkorDB | None = None,
        database: str = 'default_db',
    ):
        """
        Initialize the FalkorDB driver.

        FalkorDB is a multi-tenant graph database.
        To connect, provide the host and port.
        The default parameters assume a local (on-premises) FalkorDB instance.

        Args:
        host (str): The host where FalkorDB is running.
        port (int): The port on which FalkorDB is listening.
        username (str | None): The username for authentication (if required).
        password (str | None): The password for authentication (if required).
        falkor_db (FalkorDB | None): An existing FalkorDB instance to use instead of creating a new one.
        database (str): The name of the database to connect to. Defaults to 'default_db'.
        """
        super().__init__()
        self._database = database
        if falkor_db is not None:
            # If a FalkorDB instance is provided, use it directly
            self.client = falkor_db
        else:
            self.client = FalkorDB(host=host, port=port, username=username, password=password)

        # Schedule index/constraint creation in the background when an event
        # loop is already running (e.g. constructed inside an async app).
        try:
            loop = asyncio.get_running_loop()
            loop.create_task(self.build_indices_and_constraints())
        except RuntimeError:
            # No event loop running; indices must be built explicitly later.
            pass

    def _get_graph(self, graph_name: str | None) -> FalkorGraph:
        # FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db"
        if graph_name is None:
            graph_name = self._database
        return self.client.select_graph(graph_name)

    async def execute_query(self, cypher_query_, **kwargs: Any):
        """Run a Cypher query against the default graph.

        Returns:
            (records, header, None) where `records` is a list of dicts keyed by
            the result header, or None when the error was only a pre-existing
            index ("already indexed").

        Raises:
            Exception: re-raises any query error other than "already indexed".
        """
        graph = self._get_graph(self._database)

        # Convert datetime objects to ISO strings (FalkorDB does not support datetime objects directly)
        params = convert_datetimes_to_strings(dict(kwargs))

        try:
            result = await graph.query(cypher_query_, params)  # type: ignore[reportUnknownArgumentType]
        except Exception as e:
            if 'already indexed' in str(e):
                # Index creation is idempotent from our point of view.
                logger.info(f'Index already exists: {e}')
                return None
            logger.error(f'Error executing FalkorDB query: {e}\n{cypher_query_}\n{params}')
            raise

        # Convert the result header to a list of strings
        header = [h[1] for h in result.header]

        # Convert FalkorDB's result format (list of lists) to the format expected by Graphiti (list of dicts)
        records = []
        for row in result.result_set:
            record = {}
            for i, field_name in enumerate(header):
                if i < len(row):
                    record[field_name] = row[i]
                else:
                    # If there are more fields in header than values in row, set to None
                    record[field_name] = None
            records.append(record)

        return records, header, None

    def session(self, database: str | None = None) -> GraphDriverSession:
        """Open a session bound to `database` (or this driver's default)."""
        return FalkorDriverSession(self._get_graph(database))

    async def close(self) -> None:
        """Close the driver connection, preferring async close when available."""
        if hasattr(self.client, 'aclose'):
            await self.client.aclose()  # type: ignore[reportUnknownMemberType]
        elif hasattr(self.client.connection, 'aclose'):
            await self.client.connection.aclose()
        elif hasattr(self.client.connection, 'close'):
            await self.client.connection.close()

    async def delete_all_indexes(self) -> None:
        """Drop every RANGE and FULLTEXT index reported by `CALL db.indexes()`."""
        result = await self.execute_query('CALL db.indexes()')
        if not result:
            return

        records, _, _ = result
        drop_tasks = []

        for record in records:
            label = record['label']
            entity_type = record['entitytype']

            # `types` maps field name -> index type(s) for this label.
            for field_name, index_type in record['types'].items():
                if 'RANGE' in index_type:
                    drop_tasks.append(self.execute_query(f'DROP INDEX ON :{label}({field_name})'))
                elif 'FULLTEXT' in index_type:
                    if entity_type == 'NODE':
                        drop_tasks.append(
                            self.execute_query(
                                f'DROP FULLTEXT INDEX FOR (n:{label}) ON (n.{field_name})'
                            )
                        )
                    elif entity_type == 'RELATIONSHIP':
                        drop_tasks.append(
                            self.execute_query(
                                f'DROP FULLTEXT INDEX FOR ()-[e:{label}]-() ON (e.{field_name})'
                            )
                        )

        if drop_tasks:
            await asyncio.gather(*drop_tasks)

    async def build_indices_and_constraints(self, delete_existing=False):
        """Create the range and fulltext indices Graphiti relies on.

        Args:
            delete_existing: When True, drop all current indexes first.
        """
        if delete_existing:
            await self.delete_all_indexes()
        index_queries = get_range_indices(self.provider) + get_fulltext_indices(self.provider)
        for query in index_queries:
            await self.execute_query(query)

    def clone(self, database: str) -> 'GraphDriver':
        """
        Returns a shallow copy of this driver with a different default database.
        Reuses the same connection (e.g. FalkorDB, Neo4j).
        """
        if database == self._database:
            cloned = self
        elif database == self.default_group_id:
            cloned = FalkorDriver(falkor_db=self.client)
        else:
            # Create a new instance of FalkorDriver with the same connection but a different database
            cloned = FalkorDriver(falkor_db=self.client, database=database)

        return cloned

    async def health_check(self) -> None:
        """Check FalkorDB connectivity by running a simple query."""
        try:
            await self.execute_query('MATCH (n) RETURN 1 LIMIT 1')
            return None
        except Exception as e:
            # Consistency fix: report through the module logger (was print())
            # so failures land in application logs like other errors here.
            logger.error(f'FalkorDB health check failed: {e}')
            raise

    @staticmethod
    def convert_datetimes_to_strings(obj):
        """Recursively convert datetime values inside dicts, lists, and tuples
        to ISO-8601 strings; all other values are returned unchanged."""
        if isinstance(obj, dict):
            return {k: FalkorDriver.convert_datetimes_to_strings(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [FalkorDriver.convert_datetimes_to_strings(item) for item in obj]
        elif isinstance(obj, tuple):
            return tuple(FalkorDriver.convert_datetimes_to_strings(item) for item in obj)
        # Bug fix: this module does `import datetime`, so the bare name is the
        # module, not the class — isinstance(obj, datetime) raised TypeError.
        elif isinstance(obj, datetime.datetime):
            return obj.isoformat()
        else:
            return obj

    def sanitize(self, query: str) -> str:
        """
        Replace FalkorDB special characters with whitespace.
        Based on FalkorDB tokenization rules: ,.<>{}[]"':;!@#$%^&*()-+=~
        """
        # FalkorDB separator characters that break text into tokens
        # ('?' is included as well, matching the previous mapping).
        separator_map = str.maketrans(dict.fromkeys(',.<>{}[]"\':;!@#$%^&*()-+=~?', ' '))
        sanitized = query.translate(separator_map)
        # Clean up multiple spaces
        sanitized = ' '.join(sanitized.split())
        return sanitized

    def build_fulltext_query(
        self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
    ) -> str:
        """
        Build a fulltext query string for FalkorDB using RedisSearch syntax.
        FalkorDB uses RedisSearch-like syntax where:
        - Field queries use @ prefix: @field:value
        - Multiple values for same field: (@field:value1|value2)
        - Text search doesn't need @ prefix for content fields
        - AND is implicit with space: (@group_id:value) (text)
        - OR uses pipe within parentheses: (@group_id:value1|value2)

        Returns '' when the combined term count would reach max_query_length.
        """
        if group_ids is None or len(group_ids) == 0:
            group_filter = ''
        else:
            group_values = '|'.join(group_ids)
            group_filter = f'(@group_id:{group_values})'

        sanitized_query = self.sanitize(query)

        # Remove stopwords from the sanitized query
        query_words = sanitized_query.split()
        filtered_words = [word for word in query_words if word.lower() not in STOPWORDS]
        sanitized_query = ' | '.join(filtered_words)

        # If the query is too long return no query
        # (len(group_ids or []) keeps the same value as before: 0 when unset).
        if len(sanitized_query.split(' ')) + len(group_ids or []) >= max_query_length:
            return ''

        full_query = group_filter + ' (' + sanitized_query + ')'

        return full_query

```
Page 4/9FirstPrevNextLast