This is page 5 of 9. Use http://codebase.md/getzep/graphiti?page={x} to view the full context.

# Directory Structure

```
├── .env.example
├── .github
│   ├── dependabot.yml
│   ├── ISSUE_TEMPLATE
│   │   └── bug_report.md
│   ├── pull_request_template.md
│   ├── secret_scanning.yml
│   └── workflows
│       ├── ai-moderator.yml
│       ├── cla.yml
│       ├── claude-code-review-manual.yml
│       ├── claude-code-review.yml
│       ├── claude.yml
│       ├── codeql.yml
│       ├── lint.yml
│       ├── release-graphiti-core.yml
│       ├── release-mcp-server.yml
│       ├── release-server-container.yml
│       ├── typecheck.yml
│       └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│   ├── azure-openai
│   │   ├── .env.example
│   │   ├── azure_openai_neo4j.py
│   │   └── README.md
│   ├── data
│   │   └── manybirds_products.json
│   ├── ecommerce
│   │   ├── runner.ipynb
│   │   └── runner.py
│   ├── langgraph-agent
│   │   ├── agent.ipynb
│   │   └── tinybirds-jess.png
│   ├── opentelemetry
│   │   ├── .env.example
│   │   ├── otel_stdout_example.py
│   │   ├── pyproject.toml
│   │   ├── README.md
│   │   └── uv.lock
│   ├── podcast
│   │   ├── podcast_runner.py
│   │   ├── podcast_transcript.txt
│   │   └── transcript_parser.py
│   ├── quickstart
│   │   ├── dense_vs_normal_ingestion.py
│   │   ├── quickstart_falkordb.py
│   │   ├── quickstart_neo4j.py
│   │   ├── quickstart_neptune.py
│   │   ├── README.md
│   │   └── requirements.txt
│   └── wizard_of_oz
│       ├── parser.py
│       ├── runner.py
│       └── woo.txt
├── graphiti_core
│   ├── __init__.py
│   ├── cross_encoder
│   │   ├── __init__.py
│   │   ├── bge_reranker_client.py
│   │   ├── client.py
│   │   ├── gemini_reranker_client.py
│   │   └── openai_reranker_client.py
│   ├── decorators.py
│   ├── driver
│   │   ├── __init__.py
│   │   ├── driver.py
│   │   ├── falkordb_driver.py
│   │   ├── graph_operations
│   │   │   └── graph_operations.py
│   │   ├── kuzu_driver.py
│   │   ├── neo4j_driver.py
│   │   ├── neptune_driver.py
│   │   └── search_interface
│   │       └── search_interface.py
│   ├── edges.py
│   ├── embedder
│   │   ├── __init__.py
│   │   ├── azure_openai.py
│   │   ├── client.py
│   │   ├── gemini.py
│   │   ├── openai.py
│   │   └── voyage.py
│   ├── errors.py
│   ├── graph_queries.py
│   ├── graphiti_types.py
│   ├── graphiti.py
│   ├── helpers.py
│   ├── llm_client
│   │   ├── __init__.py
│   │   ├── anthropic_client.py
│   │   ├── azure_openai_client.py
│   │   ├── client.py
│   │   ├── config.py
│   │   ├── errors.py
│   │   ├── gemini_client.py
│   │   ├── groq_client.py
│   │   ├── openai_base_client.py
│   │   ├── openai_client.py
│   │   ├── openai_generic_client.py
│   │   └── utils.py
│   ├── migrations
│   │   └── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── edges
│   │   │   ├── __init__.py
│   │   │   └── edge_db_queries.py
│   │   └── nodes
│   │       ├── __init__.py
│   │       └── node_db_queries.py
│   ├── nodes.py
│   ├── prompts
│   │   ├── __init__.py
│   │   ├── dedupe_edges.py
│   │   ├── dedupe_nodes.py
│   │   ├── eval.py
│   │   ├── extract_edge_dates.py
│   │   ├── extract_edges.py
│   │   ├── extract_nodes.py
│   │   ├── invalidate_edges.py
│   │   ├── lib.py
│   │   ├── models.py
│   │   ├── prompt_helpers.py
│   │   ├── snippets.py
│   │   └── summarize_nodes.py
│   ├── py.typed
│   ├── search
│   │   ├── __init__.py
│   │   ├── search_config_recipes.py
│   │   ├── search_config.py
│   │   ├── search_filters.py
│   │   ├── search_helpers.py
│   │   ├── search_utils.py
│   │   └── search.py
│   ├── telemetry
│   │   ├── __init__.py
│   │   └── telemetry.py
│   ├── tracer.py
│   └── utils
│       ├── __init__.py
│       ├── bulk_utils.py
│       ├── content_chunking.py
│       ├── datetime_utils.py
│       ├── maintenance
│       │   ├── __init__.py
│       │   ├── community_operations.py
│       │   ├── dedup_helpers.py
│       │   ├── edge_operations.py
│       │   ├── graph_data_operations.py
│       │   ├── node_operations.py
│       │   └── temporal_operations.py
│       ├── ontology_utils
│       │   └── entity_types_utils.py
│       └── text_utils.py
├── images
│   ├── arxiv-screenshot.png
│   ├── graphiti-graph-intro.gif
│   ├── graphiti-intro-slides-stock-2.gif
│   └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│   ├── .env.example
│   ├── .python-version
│   ├── config
│   │   ├── config-docker-falkordb-combined.yaml
│   │   ├── config-docker-falkordb.yaml
│   │   ├── config-docker-neo4j.yaml
│   │   ├── config.yaml
│   │   └── mcp_config_stdio_example.json
│   ├── docker
│   │   ├── build-standalone.sh
│   │   ├── build-with-version.sh
│   │   ├── docker-compose-falkordb.yml
│   │   ├── docker-compose-neo4j.yml
│   │   ├── docker-compose.yml
│   │   ├── Dockerfile
│   │   ├── Dockerfile.standalone
│   │   ├── github-actions-example.yml
│   │   ├── README-falkordb-combined.md
│   │   └── README.md
│   ├── docs
│   │   └── cursor_rules.md
│   ├── main.py
│   ├── pyproject.toml
│   ├── pytest.ini
│   ├── README.md
│   ├── src
│   │   ├── __init__.py
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   └── schema.py
│   │   ├── graphiti_mcp_server.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── entity_types.py
│   │   │   └── response_types.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   ├── factories.py
│   │   │   └── queue_service.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── formatting.py
│   │       └── utils.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── conftest.py
│   │   ├── pytest.ini
│   │   ├── README.md
│   │   ├── run_tests.py
│   │   ├── test_async_operations.py
│   │   ├── test_comprehensive_integration.py
│   │   ├── test_configuration.py
│   │   ├── test_falkordb_integration.py
│   │   ├── test_fixtures.py
│   │   ├── test_http_integration.py
│   │   ├── test_integration.py
│   │   ├── test_mcp_integration.py
│   │   ├── test_mcp_transports.py
│   │   ├── test_stdio_simple.py
│   │   └── test_stress_load.py
│   └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│   ├── .env.example
│   ├── graph_service
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── dto
│   │   │   ├── __init__.py
│   │   │   ├── common.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   ├── main.py
│   │   ├── routers
│   │   │   ├── __init__.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   └── zep_graphiti.py
│   ├── Makefile
│   ├── pyproject.toml
│   ├── README.md
│   └── uv.lock
├── signatures
│   └── version1
│       └── cla.json
├── tests
│   ├── cross_encoder
│   │   ├── test_bge_reranker_client_int.py
│   │   └── test_gemini_reranker_client.py
│   ├── driver
│   │   ├── __init__.py
│   │   └── test_falkordb_driver.py
│   ├── embedder
│   │   ├── embedder_fixtures.py
│   │   ├── test_gemini.py
│   │   ├── test_openai.py
│   │   └── test_voyage.py
│   ├── evals
│   │   ├── data
│   │   │   └── longmemeval_data
│   │   │       ├── longmemeval_oracle.json
│   │   │       └── README.md
│   │   ├── eval_cli.py
│   │   ├── eval_e2e_graph_building.py
│   │   ├── pytest.ini
│   │   └── utils.py
│   ├── helpers_test.py
│   ├── llm_client
│   │   ├── test_anthropic_client_int.py
│   │   ├── test_anthropic_client.py
│   │   ├── test_azure_openai_client.py
│   │   ├── test_client.py
│   │   ├── test_errors.py
│   │   └── test_gemini_client.py
│   ├── test_edge_int.py
│   ├── test_entity_exclusion_int.py
│   ├── test_graphiti_int.py
│   ├── test_graphiti_mock.py
│   ├── test_node_int.py
│   ├── test_text_utils.py
│   └── utils
│       ├── maintenance
│       │   ├── test_bulk_utils.py
│       │   ├── test_edge_operations.py
│       │   ├── test_entity_extraction.py
│       │   ├── test_node_operations.py
│       │   └── test_temporal_operations_int.py
│       ├── search
│       │   └── search_utils_test.py
│       └── test_content_chunking.py
├── uv.lock
└── Zep-CLA.md
```

# Files

--------------------------------------------------------------------------------
/tests/cross_encoder/test_gemini_reranker_client.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Running tests: pytest -xvs tests/cross_encoder/test_gemini_reranker_client.py

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient
from graphiti_core.llm_client import LLMConfig, RateLimitError


@pytest.fixture
def mock_gemini_client():
    """Fixture to mock the Google Gemini client."""
    with patch('google.genai.Client') as mock_client:
        # Setup mock instance and its methods
        mock_instance = mock_client.return_value
        mock_instance.aio = MagicMock()
        mock_instance.aio.models = MagicMock()
        mock_instance.aio.models.generate_content = AsyncMock()
        yield mock_instance


@pytest.fixture
def gemini_reranker_client(mock_gemini_client):
    """Fixture to create a GeminiRerankerClient with a mocked client."""
    config = LLMConfig(api_key='test_api_key', model='test-model')
    client = GeminiRerankerClient(config=config)
    # Replace the underlying Gemini client with our mock so the test never calls the real API
    client.client = mock_gemini_client
    return client


def create_mock_response(score_text: str) -> MagicMock:
    """Helper function to create a mock Gemini response."""
    mock_response = MagicMock()
    mock_response.text = score_text
    return mock_response


class TestGeminiRerankerClientInitialization:
    """Tests for GeminiRerankerClient initialization."""

    def test_init_with_config(self):
        """Test initialization with a config object."""
        config = LLMConfig(api_key='test_api_key', model='test-model')
        client = GeminiRerankerClient(config=config)

        assert client.config == config

    @patch('google.genai.Client')
    def test_init_without_config(self, mock_client):
        """Test initialization without a config uses defaults."""
        client = GeminiRerankerClient()

        assert client.config is not None

    def test_init_with_custom_client(self):
        """Test initialization with a custom client."""
        mock_client = MagicMock()
        client = GeminiRerankerClient(client=mock_client)

        assert client.client == mock_client


class TestGeminiRerankerClientRanking:
    """Tests for GeminiRerankerClient rank method."""

    @pytest.mark.asyncio
    async def test_rank_basic_functionality(self, gemini_reranker_client, mock_gemini_client):
        """Test basic ranking functionality."""
        # Setup mock responses with different scores
        mock_responses = [
            create_mock_response('85'),  # High relevance
            create_mock_response('45'),  # Medium relevance
            create_mock_response('20'),  # Low relevance
        ]
        mock_gemini_client.aio.models.generate_content.side_effect = mock_responses

        # Test data
        query = 'What is the capital of France?'
        passages = [
            'Paris is the capital and most populous city of France.',
            'London is the capital city of England and the United Kingdom.',
            'Berlin is the capital and largest city of Germany.',
        ]

        # Call method
        result = await gemini_reranker_client.rank(query, passages)

        # Assertions
        assert len(result) == 3
        assert all(isinstance(item, tuple) for item in result)
        assert all(
            isinstance(passage, str) and isinstance(score, float) for passage, score in result
        )

        # Check scores are normalized to [0, 1] and sorted in descending order
        scores = [score for _, score in result]
        assert all(0.0 <= score <= 1.0 for score in scores)
        assert scores == sorted(scores, reverse=True)

        # Check that the highest scoring passage is first
        assert result[0][1] == 0.85  # 85/100
        assert result[1][1] == 0.45  # 45/100
        assert result[2][1] == 0.20  # 20/100

    @pytest.mark.asyncio
    async def test_rank_empty_passages(self, gemini_reranker_client):
        """Test ranking with empty passages list."""
        query = 'Test query'
        passages = []

        result = await gemini_reranker_client.rank(query, passages)

        assert result == []

    @pytest.mark.asyncio
    async def test_rank_single_passage(self, gemini_reranker_client, mock_gemini_client):
        """Test ranking with a single passage."""
        # Setup mock response
        mock_gemini_client.aio.models.generate_content.return_value = create_mock_response('75')

        query = 'Test query'
        passages = ['Single test passage']

        result = await gemini_reranker_client.rank(query, passages)

        assert len(result) == 1
        assert result[0][0] == 'Single test passage'
        assert result[0][1] == 1.0  # Single passage gets full score

    @pytest.mark.asyncio
    async def test_rank_score_extraction_with_regex(
        self, gemini_reranker_client, mock_gemini_client
    ):
        """Test score extraction from various response formats."""
        # Setup mock responses with different formats
        mock_responses = [
            create_mock_response('Score: 90'),  # Contains text before number
            create_mock_response('The relevance is 65 out of 100'),  # Contains text around number
            create_mock_response('8'),  # Just the number
        ]
        mock_gemini_client.aio.models.generate_content.side_effect = mock_responses

        query = 'Test query'
        passages = ['Passage 1', 'Passage 2', 'Passage 3']

        result = await gemini_reranker_client.rank(query, passages)

        # Check that scores were extracted correctly and normalized
        scores = [score for _, score in result]
        assert 0.90 in scores  # 90/100
        assert 0.65 in scores  # 65/100
        assert 0.08 in scores  # 8/100

    @pytest.mark.asyncio
    async def test_rank_invalid_score_handling(self, gemini_reranker_client, mock_gemini_client):
        """Test handling of invalid or non-numeric scores."""
        # Setup mock responses with invalid scores
        mock_responses = [
            create_mock_response('Not a number'),  # Invalid response
            create_mock_response(''),  # Empty response
            create_mock_response('95'),  # Valid response
        ]
        mock_gemini_client.aio.models.generate_content.side_effect = mock_responses

        query = 'Test query'
        passages = ['Passage 1', 'Passage 2', 'Passage 3']

        result = await gemini_reranker_client.rank(query, passages)

        # Check that invalid scores are handled gracefully (assigned 0.0)
        scores = [score for _, score in result]
        assert 0.95 in scores  # Valid score
        assert scores.count(0.0) == 2  # Two invalid scores assigned 0.0

    @pytest.mark.asyncio
    async def test_rank_score_clamping(self, gemini_reranker_client, mock_gemini_client):
        """Test that scores are properly clamped to [0, 1] range."""
        # Setup mock responses with extreme scores
        # Note: regex only matches 1-3 digits, so negative numbers won't match
        mock_responses = [
            create_mock_response('999'),  # Above 100 but within regex range
            create_mock_response('invalid'),  # Invalid response becomes 0.0
            create_mock_response('50'),  # Normal score
        ]
        mock_gemini_client.aio.models.generate_content.side_effect = mock_responses

        query = 'Test query'
        passages = ['Passage 1', 'Passage 2', 'Passage 3']

        result = await gemini_reranker_client.rank(query, passages)

        # Check that scores are normalized and clamped
        scores = [score for _, score in result]
        assert all(0.0 <= score <= 1.0 for score in scores)
        # 999 should be clamped to 1.0 (999/100 = 9.99, clamped to 1.0)
        assert 1.0 in scores
        # Invalid response should be 0.0
        assert 0.0 in scores
        # Normal score should be normalized (50/100 = 0.5)
        assert 0.5 in scores

    @pytest.mark.asyncio
    async def test_rank_rate_limit_error(self, gemini_reranker_client, mock_gemini_client):
        """Test handling of rate limit errors."""
        # Setup mock to raise rate limit error
        mock_gemini_client.aio.models.generate_content.side_effect = Exception(
            'Rate limit exceeded'
        )

        query = 'Test query'
        passages = ['Passage 1', 'Passage 2']

        with pytest.raises(RateLimitError):
            await gemini_reranker_client.rank(query, passages)

    @pytest.mark.asyncio
    async def test_rank_quota_error(self, gemini_reranker_client, mock_gemini_client):
        """Test handling of quota errors."""
        # Setup mock to raise quota error
        mock_gemini_client.aio.models.generate_content.side_effect = Exception('Quota exceeded')

        query = 'Test query'
        passages = ['Passage 1', 'Passage 2']

        with pytest.raises(RateLimitError):
            await gemini_reranker_client.rank(query, passages)

    @pytest.mark.asyncio
    async def test_rank_resource_exhausted_error(self, gemini_reranker_client, mock_gemini_client):
        """Test handling of resource exhausted errors."""
        # Setup mock to raise resource exhausted error
        mock_gemini_client.aio.models.generate_content.side_effect = Exception('resource_exhausted')

        query = 'Test query'
        passages = ['Passage 1', 'Passage 2']

        with pytest.raises(RateLimitError):
            await gemini_reranker_client.rank(query, passages)

    @pytest.mark.asyncio
    async def test_rank_429_error(self, gemini_reranker_client, mock_gemini_client):
        """Test handling of HTTP 429 errors."""
        # Setup mock to raise 429 error
        mock_gemini_client.aio.models.generate_content.side_effect = Exception(
            'HTTP 429 Too Many Requests'
        )

        query = 'Test query'
        passages = ['Passage 1', 'Passage 2']

        with pytest.raises(RateLimitError):
            await gemini_reranker_client.rank(query, passages)

    @pytest.mark.asyncio
    async def test_rank_generic_error(self, gemini_reranker_client, mock_gemini_client):
        """Test handling of generic errors."""
        # Setup mock to raise generic error
        mock_gemini_client.aio.models.generate_content.side_effect = Exception('Generic error')

        query = 'Test query'
        passages = ['Passage 1', 'Passage 2']

        with pytest.raises(Exception) as exc_info:
            await gemini_reranker_client.rank(query, passages)

        assert 'Generic error' in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_rank_concurrent_requests(self, gemini_reranker_client, mock_gemini_client):
        """Test that multiple passages are scored concurrently."""
        # Setup mock responses
        mock_responses = [
            create_mock_response('80'),
            create_mock_response('60'),
            create_mock_response('40'),
        ]
        mock_gemini_client.aio.models.generate_content.side_effect = mock_responses

        query = 'Test query'
        passages = ['Passage 1', 'Passage 2', 'Passage 3']

        await gemini_reranker_client.rank(query, passages)

        # Verify that generate_content was called for each passage
        assert mock_gemini_client.aio.models.generate_content.call_count == 3

        # Verify that all calls were made with correct parameters
        calls = mock_gemini_client.aio.models.generate_content.call_args_list
        for call in calls:
            args, kwargs = call
            assert kwargs['model'] == gemini_reranker_client.config.model
            assert kwargs['config'].temperature == 0.0
            assert kwargs['config'].max_output_tokens == 3

    @pytest.mark.asyncio
    async def test_rank_response_parsing_error(self, gemini_reranker_client, mock_gemini_client):
        """Test handling of response parsing errors."""
        # Setup mock responses that will trigger ValueError during parsing
        mock_responses = [
            create_mock_response('not a number at all'),  # Will fail regex match
            create_mock_response('also invalid text'),  # Will fail regex match
        ]
        mock_gemini_client.aio.models.generate_content.side_effect = mock_responses

        query = 'Test query'
        # Use multiple passages to avoid the single passage special case
        passages = ['Passage 1', 'Passage 2']

        result = await gemini_reranker_client.rank(query, passages)

        # Should handle the error gracefully and assign 0.0 score to both
        assert len(result) == 2
        assert all(score == 0.0 for _, score in result)

    @pytest.mark.asyncio
    async def test_rank_empty_response_text(self, gemini_reranker_client, mock_gemini_client):
        """Test handling of empty response text."""
        # Setup mock response with empty text
        mock_response = MagicMock()
        mock_response.text = ''  # Empty string instead of None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        query = 'Test query'
        # Use multiple passages to avoid the single passage special case
        passages = ['Passage 1', 'Passage 2']

        result = await gemini_reranker_client.rank(query, passages)

        # Should handle empty text gracefully and assign 0.0 score to both
        assert len(result) == 2
        assert all(score == 0.0 for _, score in result)


if __name__ == '__main__':
    pytest.main(['-v', 'test_gemini_reranker_client.py'])

```
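
The suite above pins down the observable contract of `GeminiRerankerClient.rank`: given a query and a list of passages, it returns `(passage, score)` tuples with scores normalized to [0, 1] and sorted in descending order, raising `RateLimitError` on quota/429-style failures. A minimal usage sketch against a live key (not part of the test suite; the environment variable and model name are illustrative assumptions):

```python
import asyncio
import os

from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient
from graphiti_core.llm_client import LLMConfig


async def main() -> None:
    # Assumes GEMINI_API_KEY is set; the model name here is an example, not a requirement.
    client = GeminiRerankerClient(
        config=LLMConfig(api_key=os.environ['GEMINI_API_KEY'], model='gemini-2.0-flash')
    )
    ranked = await client.rank(
        'What is the capital of France?',
        [
            'Paris is the capital and most populous city of France.',
            'Berlin is the capital and largest city of Germany.',
        ],
    )
    # Scores are normalized to [0, 1] and returned highest first.
    for passage, score in ranked:
        print(f'{score:.2f}  {passage}')


if __name__ == '__main__':
    asyncio.run(main())
```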

--------------------------------------------------------------------------------
/examples/quickstart/dense_vs_normal_ingestion.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2025, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Dense vs Normal Episode Ingestion Example
-----------------------------------------
This example demonstrates how Graphiti handles different types of content:

1. Normal Content (prose, narrative, conversations):
   - Lower entity density (few entities per token)
   - Processed in a single LLM call
   - Examples: meeting transcripts, news articles, documentation

2. Dense Content (structured data with many entities):
   - High entity density (many entities per token)
   - Automatically chunked for reliable extraction
   - Examples: bulk data imports, cost reports, entity-dense JSON

The chunking behavior is controlled by environment variables:
- CHUNK_MIN_TOKENS: Minimum tokens before considering chunking (default: 1000)
- CHUNK_DENSITY_THRESHOLD: Entity density threshold (default: 0.15)
- CHUNK_TOKEN_SIZE: Target size per chunk (default: 3000)
- CHUNK_OVERLAP_TOKENS: Overlap between chunks (default: 200)
"""

import asyncio
import json
import logging
import os
from datetime import datetime, timezone
from logging import INFO

from dotenv import load_dotenv

from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType

#################################################
# CONFIGURATION
#################################################

logging.basicConfig(
    level=INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)

load_dotenv()

neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
neo4j_user = os.environ.get('NEO4J_USER', 'neo4j')
neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password')

if not neo4j_uri or not neo4j_user or not neo4j_password:
    raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set')


#################################################
# EXAMPLE DATA
#################################################

# Normal content: A meeting transcript (low entity density)
# This is prose/narrative content with few entities per token.
# It will NOT trigger chunking - processed in a single LLM call.
NORMAL_EPISODE_CONTENT = """
Meeting Notes - Q4 Planning Session

Alice opened the meeting by reviewing our progress on the mobile app redesign.
She mentioned that the user research phase went well and highlighted key findings
from the customer interviews conducted last month.

Bob then presented the engineering timeline. He explained that the backend API
refactoring is about 60% complete and should be finished by end of November.
The team has resolved most of the performance issues identified in the load tests.

Carol raised concerns about the holiday freeze period affecting our deployment
schedule. She suggested we move the beta launch to early December to give the
QA team enough time for regression testing before the code freeze.

David agreed with Carol's assessment and proposed allocating two additional
engineers from the platform team to help with the testing effort. He also
mentioned that the documentation needs to be updated before the release.

Action items:
- Alice will finalize the design specs by Friday
- Bob will coordinate with the platform team on resource allocation
- Carol will update the project timeline in Jira
- David will schedule a follow-up meeting for next Tuesday

The meeting concluded at 3:30 PM with agreement to reconvene next week.
"""

# Dense content: AWS cost data (high entity density)
# This is structured data with many entities per token.
# It WILL trigger chunking - processed in multiple LLM calls.
DENSE_EPISODE_CONTENT = {
    'report_type': 'AWS Cost Breakdown',
    'months': [
        {
            'period': '2025-01',
            'services': [
                {'name': 'Amazon S3', 'cost': 2487.97},
                {'name': 'Amazon RDS', 'cost': 1071.74},
                {'name': 'Amazon ECS', 'cost': 853.74},
                {'name': 'Amazon OpenSearch', 'cost': 389.74},
                {'name': 'AWS Secrets Manager', 'cost': 265.77},
                {'name': 'CloudWatch', 'cost': 232.34},
                {'name': 'Amazon VPC', 'cost': 238.39},
                {'name': 'EC2 Other', 'cost': 226.82},
                {'name': 'Amazon EC2 Compute', 'cost': 78.27},
                {'name': 'Amazon DocumentDB', 'cost': 65.40},
                {'name': 'Amazon ECR', 'cost': 29.00},
                {'name': 'Amazon ELB', 'cost': 37.53},
            ],
        },
        {
            'period': '2025-02',
            'services': [
                {'name': 'Amazon S3', 'cost': 2721.04},
                {'name': 'Amazon RDS', 'cost': 1035.77},
                {'name': 'Amazon ECS', 'cost': 779.49},
                {'name': 'Amazon OpenSearch', 'cost': 357.90},
                {'name': 'AWS Secrets Manager', 'cost': 268.57},
                {'name': 'CloudWatch', 'cost': 224.57},
                {'name': 'Amazon VPC', 'cost': 215.15},
                {'name': 'EC2 Other', 'cost': 213.86},
                {'name': 'Amazon EC2 Compute', 'cost': 70.70},
                {'name': 'Amazon DocumentDB', 'cost': 59.07},
                {'name': 'Amazon ECR', 'cost': 33.92},
                {'name': 'Amazon ELB', 'cost': 33.89},
            ],
        },
        {
            'period': '2025-03',
            'services': [
                {'name': 'Amazon S3', 'cost': 2952.31},
                {'name': 'Amazon RDS', 'cost': 1198.79},
                {'name': 'Amazon ECS', 'cost': 869.78},
                {'name': 'Amazon OpenSearch', 'cost': 389.75},
                {'name': 'AWS Secrets Manager', 'cost': 271.33},
                {'name': 'CloudWatch', 'cost': 233.00},
                {'name': 'Amazon VPC', 'cost': 238.31},
                {'name': 'EC2 Other', 'cost': 227.78},
                {'name': 'Amazon EC2 Compute', 'cost': 78.21},
                {'name': 'Amazon DocumentDB', 'cost': 65.40},
                {'name': 'Amazon ECR', 'cost': 33.75},
                {'name': 'Amazon ELB', 'cost': 37.54},
            ],
        },
        {
            'period': '2025-04',
            'services': [
                {'name': 'Amazon S3', 'cost': 3189.62},
                {'name': 'Amazon RDS', 'cost': 1102.30},
                {'name': 'Amazon ECS', 'cost': 848.19},
                {'name': 'Amazon OpenSearch', 'cost': 379.14},
                {'name': 'AWS Secrets Manager', 'cost': 270.89},
                {'name': 'CloudWatch', 'cost': 230.64},
                {'name': 'Amazon VPC', 'cost': 230.54},
                {'name': 'EC2 Other', 'cost': 220.18},
                {'name': 'Amazon EC2 Compute', 'cost': 75.70},
                {'name': 'Amazon DocumentDB', 'cost': 63.29},
                {'name': 'Amazon ECR', 'cost': 35.21},
                {'name': 'Amazon ELB', 'cost': 36.30},
            ],
        },
        {
            'period': '2025-05',
            'services': [
                {'name': 'Amazon S3', 'cost': 3423.07},
                {'name': 'Amazon RDS', 'cost': 1014.50},
                {'name': 'Amazon ECS', 'cost': 874.75},
                {'name': 'Amazon OpenSearch', 'cost': 389.71},
                {'name': 'AWS Secrets Manager', 'cost': 274.91},
                {'name': 'CloudWatch', 'cost': 233.28},
                {'name': 'Amazon VPC', 'cost': 238.53},
                {'name': 'EC2 Other', 'cost': 227.27},
                {'name': 'Amazon EC2 Compute', 'cost': 78.27},
                {'name': 'Amazon DocumentDB', 'cost': 65.40},
                {'name': 'Amazon ECR', 'cost': 37.42},
                {'name': 'Amazon ELB', 'cost': 37.52},
            ],
        },
        {
            'period': '2025-06',
            'services': [
                {'name': 'Amazon S3', 'cost': 3658.14},
                {'name': 'Amazon RDS', 'cost': 963.60},
                {'name': 'Amazon ECS', 'cost': 942.45},
                {'name': 'Amazon OpenSearch', 'cost': 379.06},
                {'name': 'AWS Secrets Manager', 'cost': 282.41},
                {'name': 'CloudWatch', 'cost': 230.69},
                {'name': 'Amazon VPC', 'cost': 230.64},
                {'name': 'EC2 Other', 'cost': 220.11},
                {'name': 'Amazon EC2 Compute', 'cost': 75.74},
                {'name': 'Amazon DocumentDB', 'cost': 63.30},
                {'name': 'Amazon ECR', 'cost': 39.49},
                {'name': 'Amazon ELB', 'cost': 36.30},
            ],
        },
    ],
}

# Message content: A conversation (moderate entity density)
# This demonstrates message-type episodes with speaker patterns.
MESSAGE_EPISODE_CONTENT = """User: I need to set up a new development environment for the analytics project.
Assistant: I can help with that. Which cloud provider are you using - AWS, GCP, or Azure?
User: We're using AWS. The project requires PostgreSQL, Redis, and Elasticsearch.
Assistant: Great choice. For PostgreSQL, I recommend Amazon RDS with the db.r5.large instance type. For Redis, you can use Amazon ElastiCache. For Elasticsearch, Amazon OpenSearch Service would be ideal.
User: What about the compute layer? We need to run Python data pipelines.
Assistant: For Python data pipelines, I suggest Amazon ECS with Fargate for serverless containers, or AWS Lambda for smaller, event-driven tasks. If you need more control, EC2 instances with Auto Scaling would work well.
User: Let's go with ECS Fargate. Can you also recommend a CI/CD setup?
Assistant: For CI/CD with ECS Fargate, I recommend AWS CodePipeline integrated with CodeBuild for building container images, and CodeDeploy for blue-green deployments. You can store your container images in Amazon ECR.
"""


async def main():
    graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)

    try:
        #################################################
        # EXAMPLE 1: Normal Content (No Chunking)
        #################################################
        # This prose content has low entity density.
        # Graphiti will process it in a single LLM call.
        #################################################

        print('=' * 60)
        print('EXAMPLE 1: Normal Content (Meeting Transcript)')
        print('=' * 60)
        print(f'Content length: {len(NORMAL_EPISODE_CONTENT)} characters')
        print(f'Estimated tokens: ~{len(NORMAL_EPISODE_CONTENT) // 4}')
        print('Expected behavior: Single LLM call (no chunking)')
        print()

        await graphiti.add_episode(
            name='Q4 Planning Meeting',
            episode_body=NORMAL_EPISODE_CONTENT,
            source=EpisodeType.text,
            source_description='Meeting transcript',
            reference_time=datetime.now(timezone.utc),
        )
        print('Successfully added normal episode\n')

        #################################################
        # EXAMPLE 2: Dense Content (Chunking Triggered)
        #################################################
        # This structured data has high entity density.
        # Graphiti will automatically chunk it for
        # reliable extraction across multiple LLM calls.
        #################################################

        print('=' * 60)
        print('EXAMPLE 2: Dense Content (AWS Cost Report)')
        print('=' * 60)
        dense_json = json.dumps(DENSE_EPISODE_CONTENT)
        print(f'Content length: {len(dense_json)} characters')
        print(f'Estimated tokens: ~{len(dense_json) // 4}')
        print('Expected behavior: Multiple LLM calls (chunking enabled)')
        print()

        await graphiti.add_episode(
            name='AWS Cost Report 2025 H1',
            episode_body=dense_json,
            source=EpisodeType.json,
            source_description='AWS cost breakdown by service',
            reference_time=datetime.now(timezone.utc),
        )
        print('Successfully added dense episode\n')

        #################################################
        # EXAMPLE 3: Message Content
        #################################################
        # Conversation content with speaker patterns.
        # Chunking preserves message boundaries.
        #################################################

        print('=' * 60)
        print('EXAMPLE 3: Message Content (Conversation)')
        print('=' * 60)
        print(f'Content length: {len(MESSAGE_EPISODE_CONTENT)} characters')
        print(f'Estimated tokens: ~{len(MESSAGE_EPISODE_CONTENT) // 4}')
        print('Expected behavior: Depends on density threshold')
        print()

        await graphiti.add_episode(
            name='Dev Environment Setup Chat',
            episode_body=MESSAGE_EPISODE_CONTENT,
            source=EpisodeType.message,
            source_description='Support conversation',
            reference_time=datetime.now(timezone.utc),
        )
        print('Successfully added message episode\n')

        #################################################
        # SEARCH RESULTS
        #################################################

        print('=' * 60)
        print('SEARCH: Verifying extracted entities')
        print('=' * 60)

        # Search for entities from normal content
        print("\nSearching for: 'Q4 planning meeting participants'")
        results = await graphiti.search('Q4 planning meeting participants')
        print(f'Found {len(results)} results')
        for r in results[:3]:
            print(f'  - {r.fact}')

        # Search for entities from dense content
        print("\nSearching for: 'AWS S3 costs'")
        results = await graphiti.search('AWS S3 costs')
        print(f'Found {len(results)} results')
        for r in results[:3]:
            print(f'  - {r.fact}')

        # Search for entities from message content
        print("\nSearching for: 'ECS Fargate recommendations'")
        results = await graphiti.search('ECS Fargate recommendations')
        print(f'Found {len(results)} results')
        for r in results[:3]:
            print(f'  - {r.fact}')

    finally:
        await graphiti.close()
        print('\nConnection closed')


if __name__ == '__main__':
    asyncio.run(main())

```
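
The module docstring above names the environment variables that control chunking and their defaults. A minimal sketch of overriding them before constructing `Graphiti`, assuming they are read from the process environment (they can equally be placed in the `.env` file loaded by `load_dotenv()`); the override values below are illustrative, not recommendations:

```python
import os

# Illustrative overrides for the chunking thresholds described in the docstring above.
# Set these before Graphiti is constructed / the episode is added.
os.environ.setdefault('CHUNK_MIN_TOKENS', '500')          # consider chunking for shorter episodes
os.environ.setdefault('CHUNK_DENSITY_THRESHOLD', '0.10')  # chunk at a lower entity density
os.environ.setdefault('CHUNK_TOKEN_SIZE', '2000')         # smaller target size per chunk
os.environ.setdefault('CHUNK_OVERLAP_TOKENS', '200')      # keep the default overlap
```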

--------------------------------------------------------------------------------
/tests/test_edge_int.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
import sys
from datetime import datetime

import numpy as np
import pytest

from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
from tests.helpers_test import get_edge_count, get_node_count, group_id

pytest_plugins = ('pytest_asyncio',)


def setup_logging():
    # Create a logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)  # Set the logging level to INFO

    # Create console handler and set level to INFO
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)

    # Create formatter
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    # Add formatter to console handler
    console_handler.setFormatter(formatter)

    # Add console handler to logger
    logger.addHandler(console_handler)

    return logger


@pytest.mark.asyncio
async def test_episodic_edge(graph_driver, mock_embedder):
    now = datetime.now()

    # Create episodic node
    episode_node = EpisodicNode(
        name='test_episode',
        labels=[],
        created_at=now,
        valid_at=now,
        source=EpisodeType.message,
        source_description='conversation message',
        content='Alice likes Bob',
        entity_edges=[],
        group_id=group_id,
    )
    node_count = await get_node_count(graph_driver, [episode_node.uuid])
    assert node_count == 0
    await episode_node.save(graph_driver)
    node_count = await get_node_count(graph_driver, [episode_node.uuid])
    assert node_count == 1

    # Create entity node
    alice_node = EntityNode(
        name='Alice',
        labels=[],
        created_at=now,
        summary='Alice summary',
        group_id=group_id,
    )
    await alice_node.generate_name_embedding(mock_embedder)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    await alice_node.save(graph_driver)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 1

    # Create episodic to entity edge
    episodic_edge = EpisodicEdge(
        source_node_uuid=episode_node.uuid,
        target_node_uuid=alice_node.uuid,
        created_at=now,
        group_id=group_id,
    )
    edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
    assert edge_count == 0
    await episodic_edge.save(graph_driver)
    edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
    assert edge_count == 1

    # Get edge by uuid
    retrieved = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge.uuid)
    assert retrieved.uuid == episodic_edge.uuid
    assert retrieved.source_node_uuid == episode_node.uuid
    assert retrieved.target_node_uuid == alice_node.uuid
    assert retrieved.created_at == now
    assert retrieved.group_id == group_id

    # Get edge by uuids
    retrieved = await EpisodicEdge.get_by_uuids(graph_driver, [episodic_edge.uuid])
    assert len(retrieved) == 1
    assert retrieved[0].uuid == episodic_edge.uuid
    assert retrieved[0].source_node_uuid == episode_node.uuid
    assert retrieved[0].target_node_uuid == alice_node.uuid
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Get edge by group ids
    retrieved = await EpisodicEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
    assert len(retrieved) == 1
    assert retrieved[0].uuid == episodic_edge.uuid
    assert retrieved[0].source_node_uuid == episode_node.uuid
    assert retrieved[0].target_node_uuid == alice_node.uuid
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Get episodic node by entity node uuid
    retrieved = await EpisodicNode.get_by_entity_node_uuid(graph_driver, alice_node.uuid)
    assert len(retrieved) == 1
    assert retrieved[0].uuid == episode_node.uuid
    assert retrieved[0].name == 'test_episode'
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Delete edge by uuid
    await episodic_edge.delete(graph_driver)
    edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
    assert edge_count == 0

    # Delete edge by uuids
    await episodic_edge.save(graph_driver)
    await episodic_edge.delete_by_uuids(graph_driver, [episodic_edge.uuid])
    edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
    assert edge_count == 0

    # Cleanup nodes
    await episode_node.delete(graph_driver)
    node_count = await get_node_count(graph_driver, [episode_node.uuid])
    assert node_count == 0
    await alice_node.delete(graph_driver)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0

    await graph_driver.close()


@pytest.mark.asyncio
async def test_entity_edge(graph_driver, mock_embedder):
    now = datetime.now()

    # Create entity node
    alice_node = EntityNode(
        name='Alice',
        labels=[],
        created_at=now,
        summary='Alice summary',
        group_id=group_id,
    )
    await alice_node.generate_name_embedding(mock_embedder)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    await alice_node.save(graph_driver)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 1

    # Create entity node
    bob_node = EntityNode(
        name='Bob', labels=[], created_at=now, summary='Bob summary', group_id=group_id
    )
    await bob_node.generate_name_embedding(mock_embedder)
    node_count = await get_node_count(graph_driver, [bob_node.uuid])
    assert node_count == 0
    await bob_node.save(graph_driver)
    node_count = await get_node_count(graph_driver, [bob_node.uuid])
    assert node_count == 1

    # Create entity to entity edge
    entity_edge = EntityEdge(
        source_node_uuid=alice_node.uuid,
        target_node_uuid=bob_node.uuid,
        created_at=now,
        name='likes',
        fact='Alice likes Bob',
        episodes=[],
        expired_at=now,
        valid_at=now,
        invalid_at=now,
        group_id=group_id,
    )
    edge_embedding = await entity_edge.generate_embedding(mock_embedder)
    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
    assert edge_count == 0
    await entity_edge.save(graph_driver)
    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
    assert edge_count == 1

    # Get edge by uuid
    retrieved = await EntityEdge.get_by_uuid(graph_driver, entity_edge.uuid)
    assert retrieved.uuid == entity_edge.uuid
    assert retrieved.source_node_uuid == alice_node.uuid
    assert retrieved.target_node_uuid == bob_node.uuid
    assert retrieved.created_at == now
    assert retrieved.group_id == group_id

    # Get edge by uuids
    retrieved = await EntityEdge.get_by_uuids(graph_driver, [entity_edge.uuid])
    assert len(retrieved) == 1
    assert retrieved[0].uuid == entity_edge.uuid
    assert retrieved[0].source_node_uuid == alice_node.uuid
    assert retrieved[0].target_node_uuid == bob_node.uuid
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Get edge by group ids
    retrieved = await EntityEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
    assert len(retrieved) == 1
    assert retrieved[0].uuid == entity_edge.uuid
    assert retrieved[0].source_node_uuid == alice_node.uuid
    assert retrieved[0].target_node_uuid == bob_node.uuid
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Get edge by node uuid
    retrieved = await EntityEdge.get_by_node_uuid(graph_driver, alice_node.uuid)
    assert len(retrieved) == 1
    assert retrieved[0].uuid == entity_edge.uuid
    assert retrieved[0].source_node_uuid == alice_node.uuid
    assert retrieved[0].target_node_uuid == bob_node.uuid
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Get fact embedding
    await entity_edge.load_fact_embedding(graph_driver)
    assert np.allclose(entity_edge.fact_embedding, edge_embedding)

    # Delete edge by uuid
    await entity_edge.delete(graph_driver)
    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
    assert edge_count == 0

    # Delete edge by uuids
    await entity_edge.save(graph_driver)
    await entity_edge.delete_by_uuids(graph_driver, [entity_edge.uuid])
    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
    assert edge_count == 0

    # Deleting node should delete the edge
    await entity_edge.save(graph_driver)
    await alice_node.delete(graph_driver)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
    assert edge_count == 0

    # Deleting node by uuids should delete the edge
    await alice_node.save(graph_driver)
    await entity_edge.save(graph_driver)
    await alice_node.delete_by_uuids(graph_driver, [alice_node.uuid])
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
    assert edge_count == 0

    # Deleting node by group id should delete the edge
    await alice_node.save(graph_driver)
    await entity_edge.save(graph_driver)
    await alice_node.delete_by_group_id(graph_driver, alice_node.group_id)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
    assert edge_count == 0

    # Cleanup nodes
    await alice_node.delete(graph_driver)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    await bob_node.delete(graph_driver)
    node_count = await get_node_count(graph_driver, [bob_node.uuid])
    assert node_count == 0

    await graph_driver.close()


@pytest.mark.asyncio
async def test_community_edge(graph_driver, mock_embedder):
    now = datetime.now()

    # Create community node
    community_node_1 = CommunityNode(
        name='test_community_1',
        group_id=group_id,
        summary='Community A summary',
    )
    await community_node_1.generate_name_embedding(mock_embedder)
    node_count = await get_node_count(graph_driver, [community_node_1.uuid])
    assert node_count == 0
    await community_node_1.save(graph_driver)
    node_count = await get_node_count(graph_driver, [community_node_1.uuid])
    assert node_count == 1

    # Create community node
    community_node_2 = CommunityNode(
        name='test_community_2',
        group_id=group_id,
        summary='Community B summary',
    )
    await community_node_2.generate_name_embedding(mock_embedder)
    node_count = await get_node_count(graph_driver, [community_node_2.uuid])
    assert node_count == 0
    await community_node_2.save(graph_driver)
    node_count = await get_node_count(graph_driver, [community_node_2.uuid])
    assert node_count == 1

    # Create entity node
    alice_node = EntityNode(
        name='Alice', labels=[], created_at=now, summary='Alice summary', group_id=group_id
    )
    await alice_node.generate_name_embedding(mock_embedder)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    await alice_node.save(graph_driver)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 1

    # Create community to community edge
    community_edge = CommunityEdge(
        source_node_uuid=community_node_1.uuid,
        target_node_uuid=community_node_2.uuid,
        created_at=now,
        group_id=group_id,
    )
    edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
    assert edge_count == 0
    await community_edge.save(graph_driver)
    edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
    assert edge_count == 1

    # Get edge by uuid
    retrieved = await CommunityEdge.get_by_uuid(graph_driver, community_edge.uuid)
    assert retrieved.uuid == community_edge.uuid
    assert retrieved.source_node_uuid == community_node_1.uuid
    assert retrieved.target_node_uuid == community_node_2.uuid
    assert retrieved.created_at == now
    assert retrieved.group_id == group_id

    # Get edge by uuids
    retrieved = await CommunityEdge.get_by_uuids(graph_driver, [community_edge.uuid])
    assert len(retrieved) == 1
    assert retrieved[0].uuid == community_edge.uuid
    assert retrieved[0].source_node_uuid == community_node_1.uuid
    assert retrieved[0].target_node_uuid == community_node_2.uuid
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Get edge by group ids
    retrieved = await CommunityEdge.get_by_group_ids(graph_driver, [group_id], limit=1)
    assert len(retrieved) == 1
    assert retrieved[0].uuid == community_edge.uuid
    assert retrieved[0].source_node_uuid == community_node_1.uuid
    assert retrieved[0].target_node_uuid == community_node_2.uuid
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Delete edge by uuid
    await community_edge.delete(graph_driver)
    edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
    assert edge_count == 0

    # Delete edge by uuids
    await community_edge.save(graph_driver)
    await community_edge.delete_by_uuids(graph_driver, [community_edge.uuid])
    edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
    assert edge_count == 0

    # Cleanup nodes
    await alice_node.delete(graph_driver)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    await community_node_1.delete(graph_driver)
    node_count = await get_node_count(graph_driver, [community_node_1.uuid])
    assert node_count == 0
    await community_node_2.delete(graph_driver)
    node_count = await get_node_count(graph_driver, [community_node_2.uuid])
    assert node_count == 0

    await graph_driver.close()

```
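
`test_entity_edge` above walks the full edge lifecycle: build two entity nodes and an `EntityEdge`, embed and save them, fetch by UUID, and observe that deleting an endpoint node also removes the edge. A condensed sketch of that flow, assuming a configured `driver` and `embedder` equivalent to the `graph_driver` and `mock_embedder` fixtures; the optional temporal fields used in the test (`valid_at`, `expired_at`, `invalid_at`) are omitted here for brevity:

```python
from datetime import datetime, timezone

from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EntityNode


async def edge_lifecycle(driver, embedder, group_id: str) -> None:
    """Condensed, illustrative version of the flow exercised in test_entity_edge."""
    now = datetime.now(timezone.utc)
    alice = EntityNode(name='Alice', labels=[], created_at=now, summary='Alice summary', group_id=group_id)
    bob = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary', group_id=group_id)
    for node in (alice, bob):
        await node.generate_name_embedding(embedder)
        await node.save(driver)

    edge = EntityEdge(
        source_node_uuid=alice.uuid,
        target_node_uuid=bob.uuid,
        created_at=now,
        name='likes',
        fact='Alice likes Bob',
        episodes=[],
        group_id=group_id,
    )
    await edge.generate_embedding(embedder)
    await edge.save(driver)

    fetched = await EntityEdge.get_by_uuid(driver, edge.uuid)
    assert fetched.fact == 'Alice likes Bob'

    # Deleting an endpoint node also removes the edge (as asserted in test_entity_edge).
    await alice.delete(driver)
    await bob.delete(driver)
```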

--------------------------------------------------------------------------------
/tests/embedder/test_gemini.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Running tests: pytest -xvs tests/embedder/test_gemini.py

from collections.abc import Generator
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from embedder_fixtures import create_embedding_values

from graphiti_core.embedder.gemini import (
    DEFAULT_EMBEDDING_MODEL,
    GeminiEmbedder,
    GeminiEmbedderConfig,
)


def create_gemini_embedding(multiplier: float = 0.1, dimension: int = 1536) -> MagicMock:
    """Create a mock Gemini embedding with specified value multiplier and dimension."""
    mock_embedding = MagicMock()
    mock_embedding.values = create_embedding_values(multiplier, dimension)
    return mock_embedding


@pytest.fixture
def mock_gemini_response() -> MagicMock:
    """Create a mock Gemini embeddings response."""
    mock_result = MagicMock()
    mock_result.embeddings = [create_gemini_embedding()]
    return mock_result


@pytest.fixture
def mock_gemini_batch_response() -> MagicMock:
    """Create a mock Gemini batch embeddings response."""
    mock_result = MagicMock()
    mock_result.embeddings = [
        create_gemini_embedding(0.1),
        create_gemini_embedding(0.2),
        create_gemini_embedding(0.3),
    ]
    return mock_result


@pytest.fixture
def mock_gemini_client() -> Generator[Any, Any, None]:
    """Create a mocked Gemini client."""
    with patch('google.genai.Client') as mock_client:
        mock_instance = mock_client.return_value
        mock_instance.aio = MagicMock()
        mock_instance.aio.models = MagicMock()
        mock_instance.aio.models.embed_content = AsyncMock()
        yield mock_instance


@pytest.fixture
def gemini_embedder(mock_gemini_client: Any) -> GeminiEmbedder:
    """Create a GeminiEmbedder with a mocked client."""
    config = GeminiEmbedderConfig(api_key='test_api_key')
    client = GeminiEmbedder(config=config)
    client.client = mock_gemini_client
    return client


class TestGeminiEmbedderInitialization:
    """Tests for GeminiEmbedder initialization."""

    @patch('google.genai.Client')
    def test_init_with_config(self, mock_client):
        """Test initialization with a config object."""
        config = GeminiEmbedderConfig(
            api_key='test_api_key', embedding_model='custom-model', embedding_dim=768
        )
        embedder = GeminiEmbedder(config=config)

        assert embedder.config == config
        assert embedder.config.embedding_model == 'custom-model'
        assert embedder.config.api_key == 'test_api_key'
        assert embedder.config.embedding_dim == 768

    @patch('google.genai.Client')
    def test_init_without_config(self, mock_client):
        """Test initialization without a config uses defaults."""
        embedder = GeminiEmbedder()

        assert embedder.config is not None
        assert embedder.config.embedding_model == DEFAULT_EMBEDDING_MODEL

    @patch('google.genai.Client')
    def test_init_with_partial_config(self, mock_client):
        """Test initialization with partial config."""
        config = GeminiEmbedderConfig(api_key='test_api_key')
        embedder = GeminiEmbedder(config=config)

        assert embedder.config.api_key == 'test_api_key'
        assert embedder.config.embedding_model == DEFAULT_EMBEDDING_MODEL


class TestGeminiEmbedderCreate:
    """Tests for GeminiEmbedder create method."""

    @pytest.mark.asyncio
    async def test_create_calls_api_correctly(
        self,
        gemini_embedder: GeminiEmbedder,
        mock_gemini_client: Any,
        mock_gemini_response: MagicMock,
    ) -> None:
        """Test that create method correctly calls the API and processes the response."""
        # Setup
        mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response

        # Call method
        result = await gemini_embedder.create('Test input')

        # Verify API is called with correct parameters
        mock_gemini_client.aio.models.embed_content.assert_called_once()
        _, kwargs = mock_gemini_client.aio.models.embed_content.call_args
        assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL
        assert kwargs['contents'] == ['Test input']

        # Verify result is processed correctly
        assert result == mock_gemini_response.embeddings[0].values

    @pytest.mark.asyncio
    @patch('google.genai.Client')
    async def test_create_with_custom_model(
        self, mock_client_class, mock_gemini_client: Any, mock_gemini_response: MagicMock
    ) -> None:
        """Test create method with custom embedding model."""
        # Setup embedder with custom model
        config = GeminiEmbedderConfig(api_key='test_api_key', embedding_model='custom-model')
        embedder = GeminiEmbedder(config=config)
        embedder.client = mock_gemini_client
        mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response

        # Call method
        await embedder.create('Test input')

        # Verify custom model is used
        _, kwargs = mock_gemini_client.aio.models.embed_content.call_args
        assert kwargs['model'] == 'custom-model'

    @pytest.mark.asyncio
    @patch('google.genai.Client')
    async def test_create_with_custom_dimension(
        self, mock_client_class, mock_gemini_client: Any
    ) -> None:
        """Test create method with custom embedding dimension."""
        # Setup embedder with custom dimension
        config = GeminiEmbedderConfig(api_key='test_api_key', embedding_dim=768)
        embedder = GeminiEmbedder(config=config)
        embedder.client = mock_gemini_client

        # Setup mock response with custom dimension
        mock_response = MagicMock()
        mock_response.embeddings = [create_gemini_embedding(0.1, 768)]
        mock_gemini_client.aio.models.embed_content.return_value = mock_response

        # Call method
        result = await embedder.create('Test input')

        # Verify custom dimension is used in config
        _, kwargs = mock_gemini_client.aio.models.embed_content.call_args
        assert kwargs['config'].output_dimensionality == 768

        # Verify result has correct dimension
        assert len(result) == 768

    @pytest.mark.asyncio
    async def test_create_with_different_input_types(
        self,
        gemini_embedder: GeminiEmbedder,
        mock_gemini_client: Any,
        mock_gemini_response: MagicMock,
    ) -> None:
        """Test create method with different input types."""
        mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response

        # Test with string
        await gemini_embedder.create('Test string')

        # Test with list of strings
        await gemini_embedder.create(['Test', 'List'])

        # Test with iterable of integers
        await gemini_embedder.create([1, 2, 3])

        # Verify all calls were made
        assert mock_gemini_client.aio.models.embed_content.call_count == 3

    @pytest.mark.asyncio
    async def test_create_no_embeddings_error(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """Test create method handling of no embeddings response."""
        # Setup mock response with no embeddings
        mock_response = MagicMock()
        mock_response.embeddings = []
        mock_gemini_client.aio.models.embed_content.return_value = mock_response

        # Call method and expect exception
        with pytest.raises(ValueError) as exc_info:
            await gemini_embedder.create('Test input')

        assert 'No embeddings returned from Gemini API in create()' in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_create_no_values_error(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """Test create method handling of embeddings with no values."""
        # Setup mock response with embedding but no values
        mock_embedding = MagicMock()
        mock_embedding.values = None
        mock_response = MagicMock()
        mock_response.embeddings = [mock_embedding]
        mock_gemini_client.aio.models.embed_content.return_value = mock_response

        # Call method and expect exception
        with pytest.raises(ValueError) as exc_info:
            await gemini_embedder.create('Test input')

        assert 'No embeddings returned from Gemini API in create()' in str(exc_info.value)


class TestGeminiEmbedderCreateBatch:
    """Tests for GeminiEmbedder create_batch method."""

    @pytest.mark.asyncio
    async def test_create_batch_processes_multiple_inputs(
        self,
        gemini_embedder: GeminiEmbedder,
        mock_gemini_client: Any,
        mock_gemini_batch_response: MagicMock,
    ) -> None:
        """Test that create_batch method correctly processes multiple inputs."""
        # Setup
        mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_batch_response
        input_batch = ['Input 1', 'Input 2', 'Input 3']

        # Call method
        result = await gemini_embedder.create_batch(input_batch)

        # Verify API is called with correct parameters
        mock_gemini_client.aio.models.embed_content.assert_called_once()
        _, kwargs = mock_gemini_client.aio.models.embed_content.call_args
        assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL
        assert kwargs['contents'] == input_batch

        # Verify all results are processed correctly
        assert len(result) == 3
        assert result == [
            mock_gemini_batch_response.embeddings[0].values,
            mock_gemini_batch_response.embeddings[1].values,
            mock_gemini_batch_response.embeddings[2].values,
        ]

    @pytest.mark.asyncio
    async def test_create_batch_single_input(
        self,
        gemini_embedder: GeminiEmbedder,
        mock_gemini_client: Any,
        mock_gemini_response: MagicMock,
    ) -> None:
        """Test create_batch method with single input."""
        mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
        input_batch = ['Single input']

        result = await gemini_embedder.create_batch(input_batch)

        assert len(result) == 1
        assert result[0] == mock_gemini_response.embeddings[0].values

    @pytest.mark.asyncio
    async def test_create_batch_empty_input(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """Test create_batch method with empty input."""
        # Setup mock response with no embeddings
        mock_response = MagicMock()
        mock_response.embeddings = []
        mock_gemini_client.aio.models.embed_content.return_value = mock_response

        input_batch = []

        result = await gemini_embedder.create_batch(input_batch)
        assert result == []
        mock_gemini_client.aio.models.embed_content.assert_not_called()

    @pytest.mark.asyncio
    async def test_create_batch_no_embeddings_error(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """Test create_batch method handling of no embeddings response."""
        # Setup mock response with no embeddings
        mock_response = MagicMock()
        mock_response.embeddings = []
        mock_gemini_client.aio.models.embed_content.return_value = mock_response

        input_batch = ['Input 1', 'Input 2']

        with pytest.raises(ValueError) as exc_info:
            await gemini_embedder.create_batch(input_batch)

        assert 'No embeddings returned from Gemini API' in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_create_batch_empty_values_error(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """Test create_batch method handling of embeddings with empty values."""
        # Setup mock response with embeddings but empty values
        mock_embedding1 = MagicMock()
        mock_embedding1.values = [0.1, 0.2, 0.3]  # Valid values
        mock_embedding2 = MagicMock()
        mock_embedding2.values = None  # Empty values

        # Mock response for the initial batch call
        mock_batch_response = MagicMock()
        mock_batch_response.embeddings = [mock_embedding1, mock_embedding2]

        # Mock response for individual processing of 'Input 1'
        mock_individual_response_1 = MagicMock()
        mock_individual_response_1.embeddings = [mock_embedding1]

        # Mock response for individual processing of 'Input 2' (which has empty values)
        mock_individual_response_2 = MagicMock()
        mock_individual_response_2.embeddings = [mock_embedding2]

        # Set side_effect for embed_content to control return values for each call
        mock_gemini_client.aio.models.embed_content.side_effect = [
            mock_batch_response,  # First call for the batch
            mock_individual_response_1,  # Second call for individual item 1
            mock_individual_response_2,  # Third call for individual item 2
        ]

        input_batch = ['Input 1', 'Input 2']

        with pytest.raises(ValueError) as exc_info:
            await gemini_embedder.create_batch(input_batch)

        assert 'Empty embedding values returned' in str(exc_info.value)

    @pytest.mark.asyncio
    @patch('google.genai.Client')
    async def test_create_batch_with_custom_model_and_dimension(
        self, mock_client_class, mock_gemini_client: Any
    ) -> None:
        """Test create_batch method with custom model and dimension."""
        # Setup embedder with custom settings
        config = GeminiEmbedderConfig(
            api_key='test_api_key', embedding_model='custom-batch-model', embedding_dim=512
        )
        embedder = GeminiEmbedder(config=config)
        embedder.client = mock_gemini_client

        # Setup mock response
        mock_response = MagicMock()
        mock_response.embeddings = [
            create_gemini_embedding(0.1, 512),
            create_gemini_embedding(0.2, 512),
        ]
        mock_gemini_client.aio.models.embed_content.return_value = mock_response

        input_batch = ['Input 1', 'Input 2']
        result = await embedder.create_batch(input_batch)

        # Verify custom settings are used
        _, kwargs = mock_gemini_client.aio.models.embed_content.call_args
        assert kwargs['model'] == 'custom-batch-model'
        assert kwargs['config'].output_dimensionality == 512

        # Verify results have correct dimension
        assert len(result) == 2
        assert all(len(embedding) == 512 for embedding in result)


if __name__ == '__main__':
    pytest.main(['-xvs', __file__])

```
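
The tests above drive `GeminiEmbedder` entirely through mocks on `google.genai.Client`. For orientation, a minimal sketch of the same public surface used directly is shown below; the API key value and the 768-dimension choice are illustrative assumptions, and a real key plus the `google-genai` package would be required.

```python
import asyncio

from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig


async def main() -> None:
    # Assumes a real Google API key; the tests above avoid network calls by
    # swapping the underlying google.genai client for mocks.
    embedder = GeminiEmbedder(
        config=GeminiEmbedderConfig(api_key='YOUR_API_KEY', embedding_dim=768)
    )

    single = await embedder.create('Hello, Graphiti')            # one embedding (list of floats)
    batch = await embedder.create_batch(['doc one', 'doc two'])  # one embedding per input

    print(len(single), len(batch))


if __name__ == '__main__':
    asyncio.run(main())
```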

--------------------------------------------------------------------------------
/tests/utils/maintenance/test_entity_extraction.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
from unittest.mock import AsyncMock, MagicMock

import pytest

from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.nodes import EpisodeType, EpisodicNode
from graphiti_core.prompts.extract_nodes import ExtractedEntity
from graphiti_core.utils import content_chunking
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance import node_operations
from graphiti_core.utils.maintenance.node_operations import (
    _build_entity_types_context,
    _merge_extracted_entities,
    extract_nodes,
)


def _make_clients():
    """Create mock GraphitiClients for testing."""
    driver = MagicMock()
    embedder = MagicMock()
    cross_encoder = MagicMock()
    llm_client = MagicMock()
    llm_generate = AsyncMock()
    llm_client.generate_response = llm_generate

    clients = GraphitiClients.model_construct(  # bypass validation to allow test doubles
        driver=driver,
        embedder=embedder,
        cross_encoder=cross_encoder,
        llm_client=llm_client,
    )

    return clients, llm_generate


def _make_episode(
    content: str = 'Test content',
    source: EpisodeType = EpisodeType.text,
    group_id: str = 'group',
) -> EpisodicNode:
    """Create a test episode node."""
    return EpisodicNode(
        name='test_episode',
        group_id=group_id,
        source=source,
        source_description='test',
        content=content,
        valid_at=utc_now(),
    )


class TestExtractNodesSmallInput:
    @pytest.mark.asyncio
    async def test_small_input_single_llm_call(self, monkeypatch):
        """Small inputs should use a single LLM call without chunking."""
        clients, llm_generate = _make_clients()

        # Mock LLM response
        llm_generate.return_value = {
            'extracted_entities': [
                {'name': 'Alice', 'entity_type_id': 0},
                {'name': 'Bob', 'entity_type_id': 0},
            ]
        }

        # Small content (below threshold)
        episode = _make_episode(content='Alice talked to Bob.')

        nodes = await extract_nodes(
            clients,
            episode,
            previous_episodes=[],
        )

        # Verify results
        assert len(nodes) == 2
        assert {n.name for n in nodes} == {'Alice', 'Bob'}

        # LLM should be called exactly once
        llm_generate.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_extracts_entity_types(self, monkeypatch):
        """Entity type classification should work correctly."""
        clients, llm_generate = _make_clients()

        from pydantic import BaseModel

        class Person(BaseModel):
            """A human person."""

            pass

        llm_generate.return_value = {
            'extracted_entities': [
                {'name': 'Alice', 'entity_type_id': 1},  # Person
                {'name': 'Acme Corp', 'entity_type_id': 0},  # Default Entity
            ]
        }

        episode = _make_episode(content='Alice works at Acme Corp.')

        nodes = await extract_nodes(
            clients,
            episode,
            previous_episodes=[],
            entity_types={'Person': Person},
        )

        # Alice should have Person label
        alice = next(n for n in nodes if n.name == 'Alice')
        assert 'Person' in alice.labels

        # Acme should have Entity label
        acme = next(n for n in nodes if n.name == 'Acme Corp')
        assert 'Entity' in acme.labels

    @pytest.mark.asyncio
    async def test_excludes_entity_types(self, monkeypatch):
        """Excluded entity types should not appear in results."""
        clients, llm_generate = _make_clients()

        from pydantic import BaseModel

        class User(BaseModel):
            """A user of the system."""

            pass

        llm_generate.return_value = {
            'extracted_entities': [
                {'name': 'Alice', 'entity_type_id': 1},  # User (excluded)
                {'name': 'Project X', 'entity_type_id': 0},  # Entity
            ]
        }

        episode = _make_episode(content='Alice created Project X.')

        nodes = await extract_nodes(
            clients,
            episode,
            previous_episodes=[],
            entity_types={'User': User},
            excluded_entity_types=['User'],
        )

        # Alice should be excluded
        assert len(nodes) == 1
        assert nodes[0].name == 'Project X'

    @pytest.mark.asyncio
    async def test_filters_empty_names(self, monkeypatch):
        """Entities with empty names should be filtered out."""
        clients, llm_generate = _make_clients()

        llm_generate.return_value = {
            'extracted_entities': [
                {'name': 'Alice', 'entity_type_id': 0},
                {'name': '', 'entity_type_id': 0},
                {'name': '   ', 'entity_type_id': 0},
            ]
        }

        episode = _make_episode(content='Alice is here.')

        nodes = await extract_nodes(
            clients,
            episode,
            previous_episodes=[],
        )

        assert len(nodes) == 1
        assert nodes[0].name == 'Alice'


class TestExtractNodesChunking:
    @pytest.mark.asyncio
    async def test_large_input_triggers_chunking(self, monkeypatch):
        """Large inputs should be chunked and processed in parallel."""
        clients, llm_generate = _make_clients()

        # Track number of LLM calls
        call_count = 0

        async def mock_generate(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            return {
                'extracted_entities': [
                    {'name': f'Entity{call_count}', 'entity_type_id': 0},
                ]
            }

        llm_generate.side_effect = mock_generate

        # Patch should_chunk where it's imported in node_operations
        monkeypatch.setattr(node_operations, 'should_chunk', lambda content, ep_type: True)
        monkeypatch.setattr(content_chunking, 'CHUNK_TOKEN_SIZE', 50)  # Small chunk size

        # Large content that exceeds threshold
        large_content = 'word ' * 1000
        episode = _make_episode(content=large_content)

        await extract_nodes(
            clients,
            episode,
            previous_episodes=[],
        )

        # Multiple LLM calls should have been made
        assert call_count > 1

    @pytest.mark.asyncio
    async def test_json_content_uses_json_chunking(self, monkeypatch):
        """JSON episodes should use JSON-aware chunking."""
        clients, llm_generate = _make_clients()

        llm_generate.return_value = {
            'extracted_entities': [
                {'name': 'Service1', 'entity_type_id': 0},
            ]
        }

        # Patch should_chunk where it's imported in node_operations
        monkeypatch.setattr(node_operations, 'should_chunk', lambda content, ep_type: True)
        monkeypatch.setattr(content_chunking, 'CHUNK_TOKEN_SIZE', 50)  # Small chunk size

        # JSON content
        json_data = [{'service': f'Service{i}'} for i in range(50)]
        episode = _make_episode(
            content=json.dumps(json_data),
            source=EpisodeType.json,
        )

        await extract_nodes(
            clients,
            episode,
            previous_episodes=[],
        )

        # Verify JSON chunking was used (LLM called multiple times)
        assert llm_generate.await_count > 1

    @pytest.mark.asyncio
    async def test_message_content_uses_message_chunking(self, monkeypatch):
        """Message episodes should use message-aware chunking."""
        clients, llm_generate = _make_clients()

        llm_generate.return_value = {
            'extracted_entities': [
                {'name': 'Speaker', 'entity_type_id': 0},
            ]
        }

        # Patch should_chunk where it's imported in node_operations
        monkeypatch.setattr(node_operations, 'should_chunk', lambda content, ep_type: True)
        monkeypatch.setattr(content_chunking, 'CHUNK_TOKEN_SIZE', 50)  # Small chunk size

        # Conversation content
        messages = [f'Speaker{i}: Hello from speaker {i}!' for i in range(50)]
        episode = _make_episode(
            content='\n'.join(messages),
            source=EpisodeType.message,
        )

        await extract_nodes(
            clients,
            episode,
            previous_episodes=[],
        )

        assert llm_generate.await_count > 1

    @pytest.mark.asyncio
    async def test_deduplicates_across_chunks(self, monkeypatch):
        """Entities appearing in multiple chunks should be deduplicated."""
        clients, llm_generate = _make_clients()

        # Simulate same entity appearing in multiple chunks
        call_count = 0

        async def mock_generate(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            # Return 'Alice' in every chunk
            return {
                'extracted_entities': [
                    {'name': 'Alice', 'entity_type_id': 0},
                    {'name': f'Entity{call_count}', 'entity_type_id': 0},
                ]
            }

        llm_generate.side_effect = mock_generate

        # Patch should_chunk where it's imported in node_operations
        monkeypatch.setattr(node_operations, 'should_chunk', lambda content, ep_type: True)
        monkeypatch.setattr(content_chunking, 'CHUNK_TOKEN_SIZE', 50)  # Small chunk size

        large_content = 'word ' * 1000
        episode = _make_episode(content=large_content)

        nodes = await extract_nodes(
            clients,
            episode,
            previous_episodes=[],
        )

        # Alice should appear only once despite being in every chunk
        alice_count = sum(1 for n in nodes if n.name == 'Alice')
        assert alice_count == 1

    @pytest.mark.asyncio
    async def test_deduplication_case_insensitive(self, monkeypatch):
        """Deduplication should be case-insensitive."""
        clients, llm_generate = _make_clients()

        call_count = 0

        async def mock_generate(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return {'extracted_entities': [{'name': 'alice', 'entity_type_id': 0}]}
            return {'extracted_entities': [{'name': 'Alice', 'entity_type_id': 0}]}

        llm_generate.side_effect = mock_generate

        # Patch should_chunk where it's imported in node_operations
        monkeypatch.setattr(node_operations, 'should_chunk', lambda content, ep_type: True)
        monkeypatch.setattr(content_chunking, 'CHUNK_TOKEN_SIZE', 50)  # Small chunk size

        large_content = 'word ' * 1000
        episode = _make_episode(content=large_content)

        nodes = await extract_nodes(
            clients,
            episode,
            previous_episodes=[],
        )

        # Should have only one Alice (case-insensitive dedup)
        alice_variants = [n for n in nodes if n.name.lower() == 'alice']
        assert len(alice_variants) == 1


class TestExtractNodesPromptSelection:
    @pytest.mark.asyncio
    async def test_uses_text_prompt_for_text_episodes(self, monkeypatch):
        """Text episodes should use extract_text prompt."""
        clients, llm_generate = _make_clients()
        llm_generate.return_value = {'extracted_entities': []}

        episode = _make_episode(source=EpisodeType.text)

        await extract_nodes(clients, episode, previous_episodes=[])

        # Check prompt_name parameter
        call_kwargs = llm_generate.call_args[1]
        assert call_kwargs.get('prompt_name') == 'extract_nodes.extract_text'

    @pytest.mark.asyncio
    async def test_uses_json_prompt_for_json_episodes(self, monkeypatch):
        """JSON episodes should use extract_json prompt."""
        clients, llm_generate = _make_clients()
        llm_generate.return_value = {'extracted_entities': []}

        episode = _make_episode(content='{}', source=EpisodeType.json)

        await extract_nodes(clients, episode, previous_episodes=[])

        call_kwargs = llm_generate.call_args[1]
        assert call_kwargs.get('prompt_name') == 'extract_nodes.extract_json'

    @pytest.mark.asyncio
    async def test_uses_message_prompt_for_message_episodes(self, monkeypatch):
        """Message episodes should use extract_message prompt."""
        clients, llm_generate = _make_clients()
        llm_generate.return_value = {'extracted_entities': []}

        episode = _make_episode(source=EpisodeType.message)

        await extract_nodes(clients, episode, previous_episodes=[])

        call_kwargs = llm_generate.call_args[1]
        assert call_kwargs.get('prompt_name') == 'extract_nodes.extract_message'


class TestBuildEntityTypesContext:
    def test_default_entity_type_always_included(self):
        """Default Entity type should always be at index 0."""
        context = _build_entity_types_context(None)

        assert len(context) == 1
        assert context[0]['entity_type_id'] == 0
        assert context[0]['entity_type_name'] == 'Entity'

    def test_custom_types_added_after_default(self):
        """Custom entity types should be added with sequential IDs."""
        from pydantic import BaseModel

        class Person(BaseModel):
            """A human person."""

            pass

        class Organization(BaseModel):
            """A business or organization."""

            pass

        context = _build_entity_types_context(
            {
                'Person': Person,
                'Organization': Organization,
            }
        )

        assert len(context) == 3
        assert context[0]['entity_type_name'] == 'Entity'
        assert context[1]['entity_type_name'] == 'Person'
        assert context[1]['entity_type_id'] == 1
        assert context[2]['entity_type_name'] == 'Organization'
        assert context[2]['entity_type_id'] == 2


class TestMergeExtractedEntities:
    def test_merge_deduplicates_by_name(self):
        """Entities with same name should be deduplicated."""
        chunk_results = [
            [
                ExtractedEntity(name='Alice', entity_type_id=0),
                ExtractedEntity(name='Bob', entity_type_id=0),
            ],
            [
                ExtractedEntity(name='Alice', entity_type_id=0),  # Duplicate
                ExtractedEntity(name='Charlie', entity_type_id=0),
            ],
        ]

        merged = _merge_extracted_entities(chunk_results)

        assert len(merged) == 3
        names = {e.name for e in merged}
        assert names == {'Alice', 'Bob', 'Charlie'}

    def test_merge_prefers_first_occurrence(self):
        """When duplicates exist, first occurrence should be preferred."""
        chunk_results = [
            [ExtractedEntity(name='Alice', entity_type_id=1)],  # First: type 1
            [ExtractedEntity(name='Alice', entity_type_id=2)],  # Later: type 2
        ]

        merged = _merge_extracted_entities(chunk_results)

        assert len(merged) == 1
        assert merged[0].entity_type_id == 1  # First occurrence wins

```
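
The extraction tests above wire up `GraphitiClients` by hand with mocks. As a hedged, non-test sketch of the same call pattern (the `clients` argument is assumed to come from an already-configured `Graphiti` instance), custom entity types are passed exactly as the tests do:

```python
from pydantic import BaseModel

from graphiti_core.nodes import EpisodeType, EpisodicNode
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.node_operations import extract_nodes


class Person(BaseModel):
    """A human person."""


async def extract_people(clients) -> list:
    # `clients` is assumed to be the GraphitiClients bundle from a configured Graphiti object.
    episode = EpisodicNode(
        name='meeting_notes',
        group_id='demo',
        source=EpisodeType.text,
        source_description='example',
        content='Alice met Bob at Acme Corp.',
        valid_at=utc_now(),
    )
    # Entity type id 0 is always the default 'Entity'; custom types are numbered after it.
    return await extract_nodes(
        clients,
        episode,
        previous_episodes=[],
        entity_types={'Person': Person},
    )
```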

--------------------------------------------------------------------------------
/tests/driver/test_falkordb_driver.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import unittest
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from graphiti_core.driver.driver import GraphProvider

try:
    from graphiti_core.driver.falkordb_driver import FalkorDriver, FalkorDriverSession

    HAS_FALKORDB = True
except ImportError:
    FalkorDriver = None
    HAS_FALKORDB = False


class TestFalkorDriver:
    """Comprehensive test suite for FalkorDB driver."""

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def setup_method(self):
        """Set up test fixtures."""
        self.mock_client = MagicMock()
        with patch('graphiti_core.driver.falkordb_driver.FalkorDB'):
            self.driver = FalkorDriver()
        self.driver.client = self.mock_client

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_init_with_connection_params(self):
        """Test initialization with connection parameters."""
        with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as mock_falkor_db:
            driver = FalkorDriver(
                host='test-host', port='1234', username='test-user', password='test-pass'
            )
            assert driver.provider == GraphProvider.FALKORDB
            mock_falkor_db.assert_called_once_with(
                host='test-host', port='1234', username='test-user', password='test-pass'
            )

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_init_with_falkor_db_instance(self):
        """Test initialization with a FalkorDB instance."""
        with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as mock_falkor_db_class:
            mock_falkor_db = MagicMock()
            driver = FalkorDriver(falkor_db=mock_falkor_db)
            assert driver.provider == GraphProvider.FALKORDB
            assert driver.client is mock_falkor_db
            mock_falkor_db_class.assert_not_called()

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_provider(self):
        """Test driver provider identification."""
        assert self.driver.provider == GraphProvider.FALKORDB

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_get_graph_with_name(self):
        """Test _get_graph with specific graph name."""
        mock_graph = MagicMock()
        self.mock_client.select_graph.return_value = mock_graph

        result = self.driver._get_graph('test_graph')

        self.mock_client.select_graph.assert_called_once_with('test_graph')
        assert result is mock_graph

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_get_graph_with_none_defaults_to_default_database(self):
        """Test _get_graph with None defaults to default_db."""
        mock_graph = MagicMock()
        self.mock_client.select_graph.return_value = mock_graph

        result = self.driver._get_graph(None)

        self.mock_client.select_graph.assert_called_once_with('default_db')
        assert result is mock_graph

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_execute_query_success(self):
        """Test successful query execution."""
        mock_graph = MagicMock()
        mock_result = MagicMock()
        mock_result.header = [('col1', 'column1'), ('col2', 'column2')]
        mock_result.result_set = [['row1col1', 'row1col2']]
        mock_graph.query = AsyncMock(return_value=mock_result)
        self.mock_client.select_graph.return_value = mock_graph

        result = await self.driver.execute_query('MATCH (n) RETURN n', param1='value1')

        mock_graph.query.assert_called_once_with('MATCH (n) RETURN n', {'param1': 'value1'})

        result_set, header, summary = result
        assert result_set == [{'column1': 'row1col1', 'column2': 'row1col2'}]
        assert header == ['column1', 'column2']
        assert summary is None

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_execute_query_handles_index_already_exists_error(self):
        """Test handling of 'already indexed' error."""
        mock_graph = MagicMock()
        mock_graph.query = AsyncMock(side_effect=Exception('Index already indexed'))
        self.mock_client.select_graph.return_value = mock_graph

        with patch('graphiti_core.driver.falkordb_driver.logger') as mock_logger:
            result = await self.driver.execute_query('CREATE INDEX ...')

            mock_logger.info.assert_called_once()
            assert result is None

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_execute_query_propagates_other_exceptions(self):
        """Test that other exceptions are properly propagated."""
        mock_graph = MagicMock()
        mock_graph.query = AsyncMock(side_effect=Exception('Other error'))
        self.mock_client.select_graph.return_value = mock_graph

        with patch('graphiti_core.driver.falkordb_driver.logger') as mock_logger:
            with pytest.raises(Exception, match='Other error'):
                await self.driver.execute_query('INVALID QUERY')

            mock_logger.error.assert_called_once()

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_execute_query_converts_datetime_parameters(self):
        """Test that datetime objects in kwargs are converted to ISO strings."""
        mock_graph = MagicMock()
        mock_result = MagicMock()
        mock_result.header = []
        mock_result.result_set = []
        mock_graph.query = AsyncMock(return_value=mock_result)
        self.mock_client.select_graph.return_value = mock_graph

        test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)

        await self.driver.execute_query(
            'CREATE (n:Node) SET n.created_at = $created_at', created_at=test_datetime
        )

        call_args = mock_graph.query.call_args[0]
        assert call_args[1]['created_at'] == test_datetime.isoformat()

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_session_creation(self):
        """Test session creation with specific database."""
        mock_graph = MagicMock()
        self.mock_client.select_graph.return_value = mock_graph

        session = self.driver.session()

        assert isinstance(session, FalkorDriverSession)
        assert session.graph is mock_graph

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_session_creation_with_none_uses_default_database(self):
        """Test session creation with None uses default database."""
        mock_graph = MagicMock()
        self.mock_client.select_graph.return_value = mock_graph

        session = self.driver.session()

        assert isinstance(session, FalkorDriverSession)

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_close_calls_connection_close(self):
        """Test driver close method calls connection close."""
        mock_connection = MagicMock()
        mock_connection.close = AsyncMock()
        self.mock_client.connection = mock_connection

        # MagicMock auto-creates attributes, so delete the auto-created `aclose`
        # to keep the patched hasattr checks below consistent with a client
        # that only exposes connection.close().
        del self.mock_client.aclose

        with patch('builtins.hasattr') as mock_hasattr:
            # hasattr(self.client, 'aclose') returns False
            # hasattr(self.client.connection, 'aclose') returns False
            # hasattr(self.client.connection, 'close') returns True
            mock_hasattr.side_effect = lambda obj, attr: (
                attr == 'close' and obj is mock_connection
            )

            await self.driver.close()

        mock_connection.close.assert_called_once()

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_delete_all_indexes(self):
        """Test delete_all_indexes method."""
        with patch.object(self.driver, 'execute_query', new_callable=AsyncMock) as mock_execute:
            # Return None to simulate no indexes found
            mock_execute.return_value = None

            await self.driver.delete_all_indexes()

            mock_execute.assert_called_once_with('CALL db.indexes()')


class TestFalkorDriverSession:
    """Test FalkorDB driver session functionality."""

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def setup_method(self):
        """Set up test fixtures."""
        self.mock_graph = MagicMock()
        self.session = FalkorDriverSession(self.mock_graph)

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_session_async_context_manager(self):
        """Test session can be used as async context manager."""
        async with self.session as s:
            assert s is self.session

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_close_method(self):
        """Test session close method doesn't raise exceptions."""
        await self.session.close()  # Should not raise

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_execute_write_passes_session_and_args(self):
        """Test execute_write method passes session and arguments correctly."""

        async def test_func(session, *args, **kwargs):
            assert session is self.session
            assert args == ('arg1', 'arg2')
            assert kwargs == {'key': 'value'}
            return 'result'

        result = await self.session.execute_write(test_func, 'arg1', 'arg2', key='value')
        assert result == 'result'

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_run_single_query_with_parameters(self):
        """Test running a single query with parameters."""
        self.mock_graph.query = AsyncMock()

        await self.session.run('MATCH (n) RETURN n', param1='value1', param2='value2')

        self.mock_graph.query.assert_called_once_with(
            'MATCH (n) RETURN n', {'param1': 'value1', 'param2': 'value2'}
        )

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_run_multiple_queries_as_list(self):
        """Test running multiple queries passed as list."""
        self.mock_graph.query = AsyncMock()

        queries = [
            ('MATCH (n) RETURN n', {'param1': 'value1'}),
            ('CREATE (n:Node)', {'param2': 'value2'}),
        ]

        await self.session.run(queries)

        assert self.mock_graph.query.call_count == 2
        calls = self.mock_graph.query.call_args_list
        assert calls[0][0] == ('MATCH (n) RETURN n', {'param1': 'value1'})
        assert calls[1][0] == ('CREATE (n:Node)', {'param2': 'value2'})

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_run_converts_datetime_objects_to_iso_strings(self):
        """Test that datetime objects are converted to ISO strings."""
        self.mock_graph.query = AsyncMock()
        test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)

        await self.session.run(
            'CREATE (n:Node) SET n.created_at = $created_at', created_at=test_datetime
        )

        self.mock_graph.query.assert_called_once()
        call_args = self.mock_graph.query.call_args[0]
        assert call_args[1]['created_at'] == test_datetime.isoformat()


class TestDatetimeConversion:
    """Test datetime conversion utility function."""

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_convert_datetime_dict(self):
        """Test datetime conversion in nested dictionary."""
        from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings

        test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        input_dict = {
            'string_val': 'test',
            'datetime_val': test_datetime,
            'nested_dict': {'nested_datetime': test_datetime, 'nested_string': 'nested_test'},
        }

        result = convert_datetimes_to_strings(input_dict)

        assert result['string_val'] == 'test'
        assert result['datetime_val'] == test_datetime.isoformat()
        assert result['nested_dict']['nested_datetime'] == test_datetime.isoformat()
        assert result['nested_dict']['nested_string'] == 'nested_test'

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_convert_datetime_list_and_tuple(self):
        """Test datetime conversion in lists and tuples."""
        from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings

        test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)

        # Test list
        input_list = ['test', test_datetime, ['nested', test_datetime]]
        result_list = convert_datetimes_to_strings(input_list)
        assert result_list[0] == 'test'
        assert result_list[1] == test_datetime.isoformat()
        assert result_list[2][1] == test_datetime.isoformat()

        # Test tuple
        input_tuple = ('test', test_datetime)
        result_tuple = convert_datetimes_to_strings(input_tuple)
        assert isinstance(result_tuple, tuple)
        assert result_tuple[0] == 'test'
        assert result_tuple[1] == test_datetime.isoformat()

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_convert_single_datetime(self):
        """Test datetime conversion for single datetime object."""
        from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings

        test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        result = convert_datetimes_to_strings(test_datetime)
        assert result == test_datetime.isoformat()

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_convert_other_types_unchanged(self):
        """Test that non-datetime types are returned unchanged."""
        from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings

        assert convert_datetimes_to_strings('string') == 'string'
        assert convert_datetimes_to_strings(123) == 123
        assert convert_datetimes_to_strings(None) is None
        assert convert_datetimes_to_strings(True) is True


# Simple integration test
class TestFalkorDriverIntegration:
    """Simple integration test for FalkorDB driver."""

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_basic_integration_with_real_falkordb(self):
        """Basic integration test with real FalkorDB instance."""
        pytest.importorskip('falkordb')

        falkor_host = os.getenv('FALKORDB_HOST', 'localhost')
        falkor_port = os.getenv('FALKORDB_PORT', '6379')

        try:
            driver = FalkorDriver(host=falkor_host, port=falkor_port)

            # Test basic query execution
            result = await driver.execute_query('RETURN 1 as test')
            assert result is not None

            result_set, header, summary = result
            assert header == ['test']
            assert result_set == [{'test': 1}]

            await driver.close()

        except Exception as e:
            pytest.skip(f'FalkorDB not available for integration test: {e}')

```
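
The unit tests above replace `select_graph` with mocks; the sketch below mirrors the integration test at the end of the file against a local FalkorDB instance (host, port, and the Cypher statement are illustrative):

```python
import asyncio
from datetime import datetime, timezone

from graphiti_core.driver.falkordb_driver import FalkorDriver


async def main() -> None:
    # Assumes FalkorDB is reachable on localhost:6379, as in the integration test.
    driver = FalkorDriver(host='localhost', port='6379')

    # execute_query returns (result_set, header, summary); datetime kwargs are
    # converted to ISO strings before the query is sent to the graph.
    result_set, header, _ = await driver.execute_query(
        'CREATE (n:Node {created_at: $created_at}) RETURN n.created_at AS created_at',
        created_at=datetime(2024, 1, 1, tzinfo=timezone.utc),
    )
    print(header, result_set)

    await driver.close()


if __name__ == '__main__':
    asyncio.run(main())
```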

--------------------------------------------------------------------------------
/mcp_server/src/services/factories.py:
--------------------------------------------------------------------------------

```python
"""Factory classes for creating LLM, Embedder, and Database clients."""

from openai import AsyncAzureOpenAI

from config.schema import (
    DatabaseConfig,
    EmbedderConfig,
    LLMConfig,
)

# Try to import FalkorDriver if available
try:
    from graphiti_core.driver.falkordb_driver import FalkorDriver  # noqa: F401

    HAS_FALKOR = True
except ImportError:
    HAS_FALKOR = False

# Kuzu support removed - FalkorDB is now the default
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
from graphiti_core.llm_client import LLMClient, OpenAIClient
from graphiti_core.llm_client.config import LLMConfig as GraphitiLLMConfig

# Try to import additional providers if available
try:
    from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient

    HAS_AZURE_EMBEDDER = True
except ImportError:
    HAS_AZURE_EMBEDDER = False

try:
    from graphiti_core.embedder.gemini import GeminiEmbedder

    HAS_GEMINI_EMBEDDER = True
except ImportError:
    HAS_GEMINI_EMBEDDER = False

try:
    from graphiti_core.embedder.voyage import VoyageAIEmbedder

    HAS_VOYAGE_EMBEDDER = True
except ImportError:
    HAS_VOYAGE_EMBEDDER = False

try:
    from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient

    HAS_AZURE_LLM = True
except ImportError:
    HAS_AZURE_LLM = False

try:
    from graphiti_core.llm_client.anthropic_client import AnthropicClient

    HAS_ANTHROPIC = True
except ImportError:
    HAS_ANTHROPIC = False

try:
    from graphiti_core.llm_client.gemini_client import GeminiClient

    HAS_GEMINI = True
except ImportError:
    HAS_GEMINI = False

try:
    from graphiti_core.llm_client.groq_client import GroqClient

    HAS_GROQ = True
except ImportError:
    HAS_GROQ = False
from utils.utils import create_azure_credential_token_provider


def _validate_api_key(provider_name: str, api_key: str | None, logger) -> str:
    """Validate API key is present.

    Args:
        provider_name: Name of the provider (e.g., 'OpenAI', 'Anthropic')
        api_key: The API key to validate
        logger: Logger instance for output

    Returns:
        The validated API key

    Raises:
        ValueError: If API key is None or empty
    """
    if not api_key:
        raise ValueError(
            f'{provider_name} API key is not configured. Please set the appropriate environment variable.'
        )

    logger.info(f'Creating {provider_name} client')

    return api_key


class LLMClientFactory:
    """Factory for creating LLM clients based on configuration."""

    @staticmethod
    def create(config: LLMConfig) -> LLMClient:
        """Create an LLM client based on the configured provider."""
        import logging

        logger = logging.getLogger(__name__)

        provider = config.provider.lower()

        match provider:
            case 'openai':
                if not config.providers.openai:
                    raise ValueError('OpenAI provider configuration not found')

                api_key = config.providers.openai.api_key
                _validate_api_key('OpenAI', api_key, logger)

                from graphiti_core.llm_client.config import LLMConfig as CoreLLMConfig

                # Determine appropriate small model based on main model type
                is_reasoning_model = (
                    config.model.startswith('gpt-5')
                    or config.model.startswith('o1')
                    or config.model.startswith('o3')
                )
                small_model = (
                    'gpt-5-nano' if is_reasoning_model else 'gpt-4.1-mini'
                )  # Use reasoning model for small tasks if main model is reasoning

                llm_config = CoreLLMConfig(
                    api_key=api_key,
                    model=config.model,
                    small_model=small_model,
                    temperature=config.temperature,
                    max_tokens=config.max_tokens,
                )

                # Only pass reasoning/verbosity parameters for reasoning models (gpt-5 family)
                if is_reasoning_model:
                    return OpenAIClient(config=llm_config, reasoning='minimal', verbosity='low')
                else:
                    # For non-reasoning models, explicitly pass None to disable these parameters
                    return OpenAIClient(config=llm_config, reasoning=None, verbosity=None)

            case 'azure_openai':
                if not HAS_AZURE_LLM:
                    raise ValueError(
                        'Azure OpenAI LLM client not available in current graphiti-core version'
                    )
                if not config.providers.azure_openai:
                    raise ValueError('Azure OpenAI provider configuration not found')
                azure_config = config.providers.azure_openai

                if not azure_config.api_url:
                    raise ValueError('Azure OpenAI API URL is required')

                # Handle Azure AD authentication if enabled
                api_key: str | None = None
                azure_ad_token_provider = None
                if azure_config.use_azure_ad:
                    logger.info('Creating Azure OpenAI LLM client with Azure AD authentication')
                    azure_ad_token_provider = create_azure_credential_token_provider()
                else:
                    api_key = azure_config.api_key
                    _validate_api_key('Azure OpenAI', api_key, logger)

                # Create the Azure OpenAI client first
                azure_client = AsyncAzureOpenAI(
                    api_key=api_key,
                    azure_endpoint=azure_config.api_url,
                    api_version=azure_config.api_version,
                    azure_deployment=azure_config.deployment_name,
                    azure_ad_token_provider=azure_ad_token_provider,
                )

                # Then create the LLMConfig
                from graphiti_core.llm_client.config import LLMConfig as CoreLLMConfig

                llm_config = CoreLLMConfig(
                    api_key=api_key,
                    base_url=azure_config.api_url,
                    model=config.model,
                    temperature=config.temperature,
                    max_tokens=config.max_tokens,
                )

                return AzureOpenAILLMClient(
                    azure_client=azure_client,
                    config=llm_config,
                    max_tokens=config.max_tokens,
                )

            case 'anthropic':
                if not HAS_ANTHROPIC:
                    raise ValueError(
                        'Anthropic client not available in current graphiti-core version'
                    )
                if not config.providers.anthropic:
                    raise ValueError('Anthropic provider configuration not found')

                api_key = config.providers.anthropic.api_key
                _validate_api_key('Anthropic', api_key, logger)

                llm_config = GraphitiLLMConfig(
                    api_key=api_key,
                    model=config.model,
                    temperature=config.temperature,
                    max_tokens=config.max_tokens,
                )
                return AnthropicClient(config=llm_config)

            case 'gemini':
                if not HAS_GEMINI:
                    raise ValueError('Gemini client not available in current graphiti-core version')
                if not config.providers.gemini:
                    raise ValueError('Gemini provider configuration not found')

                api_key = config.providers.gemini.api_key
                _validate_api_key('Gemini', api_key, logger)

                llm_config = GraphitiLLMConfig(
                    api_key=api_key,
                    model=config.model,
                    temperature=config.temperature,
                    max_tokens=config.max_tokens,
                )
                return GeminiClient(config=llm_config)

            case 'groq':
                if not HAS_GROQ:
                    raise ValueError('Groq client not available in current graphiti-core version')
                if not config.providers.groq:
                    raise ValueError('Groq provider configuration not found')

                api_key = config.providers.groq.api_key
                _validate_api_key('Groq', api_key, logger)

                llm_config = GraphitiLLMConfig(
                    api_key=api_key,
                    base_url=config.providers.groq.api_url,
                    model=config.model,
                    temperature=config.temperature,
                    max_tokens=config.max_tokens,
                )
                return GroqClient(config=llm_config)

            case _:
                raise ValueError(f'Unsupported LLM provider: {provider}')


class EmbedderFactory:
    """Factory for creating Embedder clients based on configuration."""

    @staticmethod
    def create(config: EmbedderConfig) -> EmbedderClient:
        """Create an Embedder client based on the configured provider."""
        import logging

        logger = logging.getLogger(__name__)

        provider = config.provider.lower()

        match provider:
            case 'openai':
                if not config.providers.openai:
                    raise ValueError('OpenAI provider configuration not found')

                api_key = config.providers.openai.api_key
                _validate_api_key('OpenAI Embedder', api_key, logger)

                from graphiti_core.embedder.openai import OpenAIEmbedderConfig

                embedder_config = OpenAIEmbedderConfig(
                    api_key=api_key,
                    embedding_model=config.model,
                )
                return OpenAIEmbedder(config=embedder_config)

            case 'azure_openai':
                if not HAS_AZURE_EMBEDDER:
                    raise ValueError(
                        'Azure OpenAI embedder not available in current graphiti-core version'
                    )
                if not config.providers.azure_openai:
                    raise ValueError('Azure OpenAI provider configuration not found')
                azure_config = config.providers.azure_openai

                if not azure_config.api_url:
                    raise ValueError('Azure OpenAI API URL is required')

                # Handle Azure AD authentication if enabled
                api_key: str | None = None
                azure_ad_token_provider = None
                if azure_config.use_azure_ad:
                    logger.info(
                        'Creating Azure OpenAI Embedder client with Azure AD authentication'
                    )
                    azure_ad_token_provider = create_azure_credential_token_provider()
                else:
                    api_key = azure_config.api_key
                    _validate_api_key('Azure OpenAI Embedder', api_key, logger)

                # Create the Azure OpenAI client first
                azure_client = AsyncAzureOpenAI(
                    api_key=api_key,
                    azure_endpoint=azure_config.api_url,
                    api_version=azure_config.api_version,
                    azure_deployment=azure_config.deployment_name,
                    azure_ad_token_provider=azure_ad_token_provider,
                )

                return AzureOpenAIEmbedderClient(
                    azure_client=azure_client,
                    model=config.model or 'text-embedding-3-small',
                )

            case 'gemini':
                if not HAS_GEMINI_EMBEDDER:
                    raise ValueError(
                        'Gemini embedder not available in current graphiti-core version'
                    )
                if not config.providers.gemini:
                    raise ValueError('Gemini provider configuration not found')

                api_key = config.providers.gemini.api_key
                _validate_api_key('Gemini Embedder', api_key, logger)

                from graphiti_core.embedder.gemini import GeminiEmbedderConfig

                gemini_config = GeminiEmbedderConfig(
                    api_key=api_key,
                    embedding_model=config.model or 'models/text-embedding-004',
                    embedding_dim=config.dimensions or 768,
                )
                return GeminiEmbedder(config=gemini_config)

            case 'voyage':
                if not HAS_VOYAGE_EMBEDDER:
                    raise ValueError(
                        'Voyage embedder not available in current graphiti-core version'
                    )
                if not config.providers.voyage:
                    raise ValueError('Voyage provider configuration not found')

                api_key = config.providers.voyage.api_key
                _validate_api_key('Voyage Embedder', api_key, logger)

                from graphiti_core.embedder.voyage import VoyageAIEmbedderConfig

                voyage_config = VoyageAIEmbedderConfig(
                    api_key=api_key,
                    embedding_model=config.model or 'voyage-3',
                    embedding_dim=config.dimensions or 1024,
                )
                return VoyageAIEmbedder(config=voyage_config)

            case _:
                raise ValueError(f'Unsupported Embedder provider: {provider}')


class DatabaseDriverFactory:
    """Factory for creating Database drivers based on configuration.

    Note: This returns configuration dictionaries that can be passed to Graphiti(),
    not driver instances directly, as the drivers require complex initialization.
    """

    @staticmethod
    def create_config(config: DatabaseConfig) -> dict:
        """Create database configuration dictionary based on the configured provider."""
        provider = config.provider.lower()

        match provider:
            case 'neo4j':
                # Use Neo4j config if provided, otherwise use defaults
                if config.providers.neo4j:
                    neo4j_config = config.providers.neo4j
                else:
                    # Create default Neo4j configuration
                    from config.schema import Neo4jProviderConfig

                    neo4j_config = Neo4jProviderConfig()

                # Check for environment variable overrides (for CI/CD compatibility)
                import os

                uri = os.environ.get('NEO4J_URI', neo4j_config.uri)
                username = os.environ.get('NEO4J_USER', neo4j_config.username)
                password = os.environ.get('NEO4J_PASSWORD', neo4j_config.password)

                return {
                    'uri': uri,
                    'user': username,
                    'password': password,
                    # Note: database and use_parallel_runtime would need to be passed
                    # to the driver after initialization if supported
                }

            case 'falkordb':
                if not HAS_FALKOR:
                    raise ValueError(
                        'FalkorDB driver not available in current graphiti-core version'
                    )

                # Use FalkorDB config if provided, otherwise use defaults
                if config.providers.falkordb:
                    falkor_config = config.providers.falkordb
                else:
                    # Create default FalkorDB configuration
                    from config.schema import FalkorDBProviderConfig

                    falkor_config = FalkorDBProviderConfig()

                # Check for environment variable overrides (for CI/CD compatibility)
                import os
                from urllib.parse import urlparse

                uri = os.environ.get('FALKORDB_URI', falkor_config.uri)
                password = os.environ.get('FALKORDB_PASSWORD', falkor_config.password)

                # Parse the URI to extract host and port
                parsed = urlparse(uri)
                host = parsed.hostname or 'localhost'
                port = parsed.port or 6379

                return {
                    'driver': 'falkordb',
                    'host': host,
                    'port': port,
                    'password': password,
                    'database': falkor_config.database,
                }

            case _:
                raise ValueError(f'Unsupported Database provider: {provider}')

```
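
The factory above hands back plain dictionaries rather than driver objects, so the caller still has to wire them into `Graphiti`. The sketch below shows one way that wiring might look; it assumes `Graphiti()` accepts `uri`/`user`/`password` keyword arguments for Neo4j, as the docstring implies, and only gestures at the FalkorDB path rather than constructing a driver, since the exact `FalkorDriver` constructor is not shown here.

```python
# Minimal wiring sketch (not taken from this repository): consume the dictionary
# returned by DatabaseDriverFactory.create_config(). Assumes Graphiti() accepts
# uri/user/password keyword arguments for Neo4j connections.
from graphiti_core import Graphiti


def build_graphiti(db_config) -> Graphiti:
    params = DatabaseDriverFactory.create_config(db_config)

    if params.get('driver') == 'falkordb':
        # For FalkorDB the dictionary carries host/port/password/database instead
        # of a URI; these would be used to construct a FalkorDriver and hand it to
        # Graphiti as its graph driver (constructor details omitted here).
        raise NotImplementedError('FalkorDB wiring depends on FalkorDriver construction')

    return Graphiti(uri=params['uri'], user=params['user'], password=params['password'])
```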

--------------------------------------------------------------------------------
/graphiti_core/llm_client/anthropic_client.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import logging
import os
import typing
from json import JSONDecodeError
from typing import TYPE_CHECKING, Literal

from pydantic import BaseModel, ValidationError

from ..prompts.models import Message
from .client import LLMClient
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
from .errors import RateLimitError, RefusalError

if TYPE_CHECKING:
    import anthropic
    from anthropic import AsyncAnthropic
    from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
else:
    try:
        import anthropic
        from anthropic import AsyncAnthropic
        from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
    except ImportError:
        raise ImportError(
            'anthropic is required for AnthropicClient. '
            'Install it with: pip install graphiti-core[anthropic]'
        ) from None


logger = logging.getLogger(__name__)

AnthropicModel = Literal[
    'claude-sonnet-4-5-latest',
    'claude-sonnet-4-5-20250929',
    'claude-haiku-4-5-latest',
    'claude-3-7-sonnet-latest',
    'claude-3-7-sonnet-20250219',
    'claude-3-5-haiku-latest',
    'claude-3-5-haiku-20241022',
    'claude-3-5-sonnet-latest',
    'claude-3-5-sonnet-20241022',
    'claude-3-5-sonnet-20240620',
    'claude-3-opus-latest',
    'claude-3-opus-20240229',
    'claude-3-sonnet-20240229',
    'claude-3-haiku-20240307',
    'claude-2.1',
    'claude-2.0',
]

DEFAULT_MODEL: AnthropicModel = 'claude-haiku-4-5-latest'

# Maximum output tokens for different Anthropic models
# Based on official Anthropic documentation (as of 2025)
# Note: These represent standard limits without beta headers.
# Some models support higher limits with additional configuration (e.g., Claude 3.7 supports
# 128K with 'anthropic-beta: output-128k-2025-02-19' header, but this is not currently implemented).
ANTHROPIC_MODEL_MAX_TOKENS = {
    # Claude 4.5 models - 64K tokens
    'claude-sonnet-4-5-latest': 65536,
    'claude-sonnet-4-5-20250929': 65536,
    'claude-haiku-4-5-latest': 65536,
    # Claude 3.7 models - standard 64K tokens
    'claude-3-7-sonnet-latest': 65536,
    'claude-3-7-sonnet-20250219': 65536,
    # Claude 3.5 models - 8K tokens
    'claude-3-5-haiku-latest': 8192,
    'claude-3-5-haiku-20241022': 8192,
    'claude-3-5-sonnet-latest': 8192,
    'claude-3-5-sonnet-20241022': 8192,
    'claude-3-5-sonnet-20240620': 8192,
    # Claude 3 models - 4K tokens
    'claude-3-opus-latest': 4096,
    'claude-3-opus-20240229': 4096,
    'claude-3-sonnet-20240229': 4096,
    'claude-3-haiku-20240307': 4096,
    # Claude 2 models - 4K tokens
    'claude-2.1': 4096,
    'claude-2.0': 4096,
}

# Default max tokens for models not in the mapping
DEFAULT_ANTHROPIC_MAX_TOKENS = 8192


class AnthropicClient(LLMClient):
    """
    A client for the Anthropic LLM.

    Args:
        config: A configuration object for the LLM.
        cache: Whether to cache the LLM responses.
        client: An optional client instance to use.
        max_tokens: The maximum number of tokens to generate.

    Methods:
        generate_response: Generate a response from the LLM.

    Notes:
        - If an LLMConfig is not provided, the api_key will be pulled from the ANTHROPIC_API_KEY environment
            variable, and all default values will be used for the LLMConfig.

    """

    model: AnthropicModel

    def __init__(
        self,
        config: LLMConfig | None = None,
        cache: bool = False,
        client: AsyncAnthropic | None = None,
        max_tokens: int = DEFAULT_MAX_TOKENS,
    ) -> None:
        if config is None:
            config = LLMConfig()
            config.api_key = os.getenv('ANTHROPIC_API_KEY')
            config.max_tokens = max_tokens

        if config.model is None:
            config.model = DEFAULT_MODEL

        super().__init__(config, cache)
        # Explicitly set the instance model to the config model to prevent type checking errors
        self.model = typing.cast(AnthropicModel, config.model)

        if not client:
            self.client = AsyncAnthropic(
                api_key=config.api_key,
                max_retries=1,
            )
        else:
            self.client = client

    def _extract_json_from_text(self, text: str) -> dict[str, typing.Any]:
        """Extract JSON from text content.

        A helper method to extract JSON from text content, used when tool use fails or
        no response_model is provided.

        Args:
            text: The text to extract JSON from

        Returns:
            Extracted JSON as a dictionary

        Raises:
            ValueError: If JSON cannot be extracted or parsed
        """
        try:
            json_start = text.find('{')
            json_end = text.rfind('}') + 1
            if json_start >= 0 and json_end > json_start:
                json_str = text[json_start:json_end]
                return json.loads(json_str)
            else:
                raise ValueError(f'Could not extract JSON from model response: {text}')
        except (JSONDecodeError, ValueError) as e:
            raise ValueError(f'Could not extract JSON from model response: {text}') from e

    def _create_tool(
        self, response_model: type[BaseModel] | None = None
    ) -> tuple[list[ToolUnionParam], ToolChoiceParam]:
        """
        Create a tool definition based on the response_model if provided, or a generic JSON tool if not.

        Args:
            response_model: Optional Pydantic model to use for structured output.

        Returns:
            A tuple of (tool list, tool choice) for use with the Anthropic API.
        """
        if response_model is not None:
            # Use the response_model to define the tool
            model_schema = response_model.model_json_schema()
            tool_name = response_model.__name__
            description = model_schema.get('description', f'Extract {tool_name} information')
        else:
            # Create a generic JSON output tool
            tool_name = 'generic_json_output'
            description = 'Output data in JSON format'
            model_schema = {
                'type': 'object',
                'additionalProperties': True,
                'description': 'Any JSON object containing the requested information',
            }

        tool = {
            'name': tool_name,
            'description': description,
            'input_schema': model_schema,
        }
        tool_list = [tool]
        tool_list_cast = typing.cast(list[ToolUnionParam], tool_list)
        tool_choice = {'type': 'tool', 'name': tool_name}
        tool_choice_cast = typing.cast(ToolChoiceParam, tool_choice)
        return tool_list_cast, tool_choice_cast

    def _get_max_tokens_for_model(self, model: str) -> int:
        """Get the maximum output tokens for a specific Anthropic model.

        Args:
            model: The model name to look up

        Returns:
            int: The maximum output tokens for the model
        """
        return ANTHROPIC_MODEL_MAX_TOKENS.get(model, DEFAULT_ANTHROPIC_MAX_TOKENS)

    def _resolve_max_tokens(self, requested_max_tokens: int | None, model: str) -> int:
        """
        Resolve the maximum output tokens to use based on precedence rules.

        Precedence order (highest to lowest):
        1. Explicit max_tokens parameter passed to generate_response()
        2. Instance max_tokens set during client initialization
        3. Model-specific maximum tokens from ANTHROPIC_MODEL_MAX_TOKENS mapping
        4. DEFAULT_ANTHROPIC_MAX_TOKENS as final fallback

        Args:
            requested_max_tokens: The max_tokens parameter passed to generate_response()
            model: The model name to look up model-specific limits

        Returns:
            int: The resolved maximum tokens to use
        """
        # 1. Use explicit parameter if provided
        if requested_max_tokens is not None:
            return requested_max_tokens

        # 2. Use instance max_tokens if set during initialization
        if self.max_tokens is not None:
            return self.max_tokens

        # 3. Use model-specific maximum or return DEFAULT_ANTHROPIC_MAX_TOKENS
        return self._get_max_tokens_for_model(model)

    async def _generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int | None = None,
        model_size: ModelSize = ModelSize.medium,
    ) -> dict[str, typing.Any]:
        """
        Generate a response from the Anthropic LLM using tool-based approach for all requests.

        Args:
            messages: List of message objects to send to the LLM.
            response_model: Optional Pydantic model to use for structured output.
            max_tokens: Maximum number of tokens to generate.

        Returns:
            Dictionary containing the structured response from the LLM.

        Raises:
            RateLimitError: If the rate limit is exceeded.
            RefusalError: If the LLM refuses to respond.
            Exception: If an error occurs during the generation process.
        """
        system_message = messages[0]
        user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]]
        user_messages_cast = typing.cast(list[MessageParam], user_messages)

        # Resolve max_tokens dynamically based on the model's capabilities
        # This allows different models to use their full output capacity
        max_creation_tokens: int = self._resolve_max_tokens(max_tokens, self.model)

        try:
            # Create the appropriate tool based on whether response_model is provided
            tools, tool_choice = self._create_tool(response_model)
            result = await self.client.messages.create(
                system=system_message.content,
                max_tokens=max_creation_tokens,
                temperature=self.temperature,
                messages=user_messages_cast,
                model=self.model,
                tools=tools,
                tool_choice=tool_choice,
            )

            # Extract the tool output from the response
            for content_item in result.content:
                if content_item.type == 'tool_use':
                    if isinstance(content_item.input, dict):
                        tool_args: dict[str, typing.Any] = content_item.input
                    else:
                        tool_args = json.loads(str(content_item.input))
                    return tool_args

            # If we didn't get a proper tool_use response, fall back to the first text
            # block; other content types are skipped so the error below is raised only
            # when no usable block exists.
            for content_item in result.content:
                if content_item.type == 'text':
                    return self._extract_json_from_text(content_item.text)

            # If we get here, we couldn't parse a structured response
            raise ValueError(
                f'Could not extract structured data from model response: {result.content}'
            )

        except anthropic.RateLimitError as e:
            raise RateLimitError(f'Rate limit exceeded. Please try again later. Error: {e}') from e
        except anthropic.APIError as e:
            # Special case for content policy violations. We convert these to RefusalError
            # to bypass the retry mechanism, as retrying policy-violating content will always fail.
            # This avoids wasting API calls and provides more specific error messaging to the user.
            if 'refused to respond' in str(e).lower():
                raise RefusalError(str(e)) from e
            raise e
        except Exception as e:
            raise e

    async def generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int | None = None,
        model_size: ModelSize = ModelSize.medium,
        group_id: str | None = None,
        prompt_name: str | None = None,
    ) -> dict[str, typing.Any]:
        """
        Generate a response from the LLM.

        Args:
            messages: List of message objects to send to the LLM.
            response_model: Optional Pydantic model to use for structured output.
            max_tokens: Maximum number of tokens to generate.

        Returns:
            Dictionary containing the structured response from the LLM.

        Raises:
            RateLimitError: If the rate limit is exceeded.
            RefusalError: If the LLM refuses to respond.
            Exception: If an error occurs during the generation process.
        """
        if max_tokens is None:
            max_tokens = self.max_tokens

        # Wrap entire operation in tracing span
        with self.tracer.start_span('llm.generate') as span:
            attributes = {
                'llm.provider': 'anthropic',
                'model.size': model_size.value,
                'max_tokens': max_tokens,
            }
            if prompt_name:
                attributes['prompt.name'] = prompt_name
            span.add_attributes(attributes)

            retry_count = 0
            max_retries = 2
            last_error: Exception | None = None

            while retry_count <= max_retries:
                try:
                    response = await self._generate_response(
                        messages, response_model, max_tokens, model_size
                    )

                    # If we have a response_model, attempt to validate the response
                    if response_model is not None:
                        # Validate the response against the response_model
                        model_instance = response_model(**response)
                        return model_instance.model_dump()

                    # If no validation needed, return the response
                    return response

                except (RateLimitError, RefusalError) as e:
                    # These errors should not trigger retries
                    span.set_status('error', str(e))
                    raise
                except Exception as e:
                    last_error = e

                    if retry_count >= max_retries:
                        if isinstance(e, ValidationError):
                            logger.error(
                                f'Validation error after {retry_count}/{max_retries} attempts: {e}'
                            )
                        else:
                            logger.error(f'Max retries ({max_retries}) exceeded. Last error: {e}')
                        span.set_status('error', str(e))
                        span.record_exception(e)
                        raise e

                    if isinstance(e, ValidationError):
                        response_model_cast = typing.cast(type[BaseModel], response_model)
                        error_context = f'The previous response was invalid. Please provide a valid {response_model_cast.__name__} object. Error: {e}'
                    else:
                        error_context = (
                            f'The previous response attempt was invalid. '
                            f'Error type: {e.__class__.__name__}. '
                            f'Error details: {str(e)}. '
                            f'Please try again with a valid response.'
                        )

                    # Common retry logic
                    retry_count += 1
                    messages.append(Message(role='user', content=error_context))
                    logger.warning(
                        f'Retrying after error (attempt {retry_count}/{max_retries}): {e}'
                    )

            # If we somehow get here, raise the last error
            span.set_status('error', str(last_error))
            raise last_error or Exception('Max retries exceeded with no specific error')

```
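
As a quick orientation to the client above, here is a hedged usage sketch: a Pydantic `response_model` is turned into a single forced tool by `_create_tool`, and the output budget is resolved by `_resolve_max_tokens` (explicit argument, then the client's configured `max_tokens`, then the per-model table). The `Person` model and the `asyncio` entry point are illustrative, not part of this repository; a valid `ANTHROPIC_API_KEY` is assumed.

```python
# Illustrative usage sketch for AnthropicClient (Person and main() are made up).
import asyncio

from pydantic import BaseModel

from graphiti_core.llm_client.anthropic_client import AnthropicClient
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.prompts.models import Message


class Person(BaseModel):
    """A person mentioned in the text."""

    name: str
    occupation: str


async def main() -> None:
    # No api_key is set here; the anthropic SDK falls back to ANTHROPIC_API_KEY.
    client = AnthropicClient(config=LLMConfig(model='claude-3-5-haiku-latest'))

    messages = [
        Message(role='system', content='Extract the requested fields as JSON.'),
        Message(role='user', content='Ada Lovelace was a mathematician.'),
    ]

    # No max_tokens argument, so the client's configured max_tokens is used,
    # per the precedence documented in _resolve_max_tokens.
    result = await client.generate_response(messages, response_model=Person)
    print(result)  # e.g. {'name': 'Ada Lovelace', 'occupation': 'mathematician'}


if __name__ == '__main__':
    asyncio.run(main())
```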

--------------------------------------------------------------------------------
/tests/utils/maintenance/test_bulk_utils.py:
--------------------------------------------------------------------------------

```python
from collections import deque
from unittest.mock import AsyncMock, MagicMock

import pytest

from graphiti_core.edges import EntityEdge
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils import bulk_utils
from graphiti_core.utils.bulk_utils import extract_nodes_and_edges_bulk
from graphiti_core.utils.datetime_utils import utc_now


def _make_episode(uuid_suffix: str, group_id: str = 'group') -> EpisodicNode:
    return EpisodicNode(
        name=f'episode-{uuid_suffix}',
        group_id=group_id,
        labels=[],
        source=EpisodeType.message,
        content='content',
        source_description='test',
        created_at=utc_now(),
        valid_at=utc_now(),
    )


def _make_clients() -> GraphitiClients:
    driver = MagicMock()
    embedder = MagicMock()
    cross_encoder = MagicMock()
    llm_client = MagicMock()

    return GraphitiClients.model_construct(  # bypass validation to allow test doubles
        driver=driver,
        embedder=embedder,
        cross_encoder=cross_encoder,
        llm_client=llm_client,
    )


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch):
    clients = _make_clients()

    episode_one = _make_episode('1')
    episode_two = _make_episode('2')

    extracted_one = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
    extracted_two = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])

    canonical = extracted_one

    call_queue = deque()

    async def fake_resolve(
        clients_arg,
        nodes_arg,
        episode_arg,
        previous_episodes_arg,
        entity_types_arg,
        existing_nodes_override=None,
    ):
        call_queue.append(existing_nodes_override)

        if nodes_arg == [extracted_one]:
            return [canonical], {canonical.uuid: canonical.uuid}, []

        assert nodes_arg == [extracted_two]
        assert existing_nodes_override is None

        return [canonical], {extracted_two.uuid: canonical.uuid}, [(extracted_two, canonical)]

    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)

    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients,
        [[extracted_one], [extracted_two]],
        [(episode_one, []), (episode_two, [])],
    )

    assert len(call_queue) == 2
    assert call_queue[0] is None
    assert call_queue[1] is None

    assert nodes_by_episode[episode_one.uuid] == [canonical]
    assert nodes_by_episode[episode_two.uuid] == [canonical]
    assert compressed_map.get(extracted_two.uuid) == canonical.uuid


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_handles_empty_batch(monkeypatch):
    clients = _make_clients()

    resolve_mock = AsyncMock()
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)

    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients,
        [],
        [],
    )

    assert nodes_by_episode == {}
    assert compressed_map == {}
    resolve_mock.assert_not_awaited()


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_single_episode(monkeypatch):
    clients = _make_clients()

    episode = _make_episode('solo')
    extracted = EntityNode(name='Solo', group_id='group', labels=['Entity'])

    resolve_mock = AsyncMock(return_value=([extracted], {extracted.uuid: extracted.uuid}, []))
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)

    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients,
        [[extracted]],
        [(episode, [])],
    )

    assert nodes_by_episode == {episode.uuid: [extracted]}
    assert compressed_map == {extracted.uuid: extracted.uuid}
    resolve_mock.assert_awaited_once()


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_uuid_map_respects_direction(monkeypatch):
    clients = _make_clients()

    episode_one = _make_episode('one')
    episode_two = _make_episode('two')

    extracted_one = EntityNode(uuid='b-uuid', name='Edge Case', group_id='group', labels=['Entity'])
    extracted_two = EntityNode(uuid='a-uuid', name='Edge Case', group_id='group', labels=['Entity'])

    canonical = extracted_one
    alias = extracted_two

    async def fake_resolve(
        clients_arg,
        nodes_arg,
        episode_arg,
        previous_episodes_arg,
        entity_types_arg,
        existing_nodes_override=None,
    ):
        if nodes_arg == [extracted_one]:
            return [canonical], {canonical.uuid: canonical.uuid}, []
        assert nodes_arg == [extracted_two]
        return [canonical], {alias.uuid: canonical.uuid}, [(alias, canonical)]

    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)

    nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
        clients,
        [[extracted_one], [extracted_two]],
        [(episode_one, []), (episode_two, [])],
    )

    assert nodes_by_episode[episode_one.uuid] == [canonical]
    assert nodes_by_episode[episode_two.uuid] == [canonical]
    assert compressed_map.get(alias.uuid) == canonical.uuid


@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_missing_canonical_falls_back(monkeypatch, caplog):
    clients = _make_clients()

    episode = _make_episode('missing')
    extracted = EntityNode(name='Fallback', group_id='group', labels=['Entity'])

    resolve_mock = AsyncMock(return_value=([extracted], {extracted.uuid: 'missing-canonical'}, []))
    monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)

    with caplog.at_level('WARNING'):
        nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
            clients,
            [[extracted]],
            [(episode, [])],
        )

    assert nodes_by_episode[episode.uuid] == [extracted]
    assert compressed_map.get(extracted.uuid) == 'missing-canonical'
    assert any('Canonical node missing' in rec.message for rec in caplog.records)


def test_build_directed_uuid_map_empty():
    assert bulk_utils._build_directed_uuid_map([]) == {}


def test_build_directed_uuid_map_chain():
    mapping = bulk_utils._build_directed_uuid_map(
        [
            ('a', 'b'),
            ('b', 'c'),
        ]
    )

    assert mapping['a'] == 'c'
    assert mapping['b'] == 'c'
    assert mapping['c'] == 'c'


def test_build_directed_uuid_map_preserves_direction():
    mapping = bulk_utils._build_directed_uuid_map(
        [
            ('alias', 'canonical'),
        ]
    )

    assert mapping['alias'] == 'canonical'
    assert mapping['canonical'] == 'canonical'


def test_resolve_edge_pointers_updates_sources():
    created_at = utc_now()
    edge = EntityEdge(
        name='knows',
        fact='fact',
        group_id='group',
        source_node_uuid='alias',
        target_node_uuid='target',
        created_at=created_at,
    )

    bulk_utils.resolve_edge_pointers([edge], {'alias': 'canonical'})

    assert edge.source_node_uuid == 'canonical'
    assert edge.target_node_uuid == 'target'


@pytest.mark.asyncio
async def test_dedupe_edges_bulk_deduplicates_within_episode(monkeypatch):
    """Test that dedupe_edges_bulk correctly compares edges within the same episode.

    This test verifies the fix that removed the `if i == j: continue` check,
    which was preventing edges from the same episode from being compared against each other.
    """
    clients = _make_clients()

    # Track which edges are compared
    comparisons_made = []

    # Create mock embedder that sets embedding values
    async def mock_create_embeddings(embedder, edges):
        for edge in edges:
            edge.fact_embedding = [0.1, 0.2, 0.3]

    monkeypatch.setattr(bulk_utils, 'create_entity_edge_embeddings', mock_create_embeddings)

    # Mock resolve_extracted_edge to track comparisons and mark duplicates
    async def mock_resolve_extracted_edge(
        llm_client,
        extracted_edge,
        related_edges,
        existing_edges,
        episode,
        edge_type_candidates=None,
        custom_edge_type_names=None,
    ):
        # Track that this edge was compared against the related_edges
        comparisons_made.append((extracted_edge.uuid, [r.uuid for r in related_edges]))

        # If there are related edges with same source/target/fact, mark as duplicate
        for related in related_edges:
            if (
                related.uuid != extracted_edge.uuid  # Can't be duplicate of self
                and related.source_node_uuid == extracted_edge.source_node_uuid
                and related.target_node_uuid == extracted_edge.target_node_uuid
                and related.fact.strip().lower() == extracted_edge.fact.strip().lower()
            ):
                # Return the related edge and mark extracted_edge as duplicate
                return related, [], [related]
        # Otherwise return the extracted edge as-is
        return extracted_edge, [], []

    monkeypatch.setattr(bulk_utils, 'resolve_extracted_edge', mock_resolve_extracted_edge)

    episode = _make_episode('1')
    source_uuid = 'source-uuid'
    target_uuid = 'target-uuid'

    # Create 3 identical edges within the same episode
    edge1 = EntityEdge(
        name='recommends',
        fact='assistant recommends yoga poses',
        group_id='group',
        source_node_uuid=source_uuid,
        target_node_uuid=target_uuid,
        created_at=utc_now(),
        episodes=[episode.uuid],
    )
    edge2 = EntityEdge(
        name='recommends',
        fact='assistant recommends yoga poses',
        group_id='group',
        source_node_uuid=source_uuid,
        target_node_uuid=target_uuid,
        created_at=utc_now(),
        episodes=[episode.uuid],
    )
    edge3 = EntityEdge(
        name='recommends',
        fact='assistant recommends yoga poses',
        group_id='group',
        source_node_uuid=source_uuid,
        target_node_uuid=target_uuid,
        created_at=utc_now(),
        episodes=[episode.uuid],
    )

    await bulk_utils.dedupe_edges_bulk(
        clients,
        [[edge1, edge2, edge3]],
        [(episode, [])],
        [],
        {},
        {},
    )

    # Verify that edges were compared against each other (within same episode)
    # Each edge should have been compared against all 3 edges (including itself, which gets filtered)
    assert len(comparisons_made) == 3
    for _, compared_against in comparisons_made:
        # Each edge should have access to all 3 edges as candidates
        assert len(compared_against) >= 2  # At least 2 others (self is filtered out)


@pytest.mark.asyncio
async def test_extract_nodes_and_edges_bulk_passes_custom_instructions_to_extract_nodes(
    monkeypatch,
):
    """Test that custom_extraction_instructions is passed to extract_nodes."""
    clients = _make_clients()
    episode = _make_episode('1')

    # Track calls to extract_nodes
    extract_nodes_calls = []

    async def mock_extract_nodes(
        clients,
        episode,
        previous_episodes,
        entity_types=None,
        excluded_entity_types=None,
        custom_extraction_instructions=None,
    ):
        extract_nodes_calls.append(
            {
                'entity_types': entity_types,
                'excluded_entity_types': excluded_entity_types,
                'custom_extraction_instructions': custom_extraction_instructions,
            }
        )
        return []

    async def mock_extract_edges(
        clients,
        episode,
        nodes,
        previous_episodes,
        edge_type_map,
        group_id='',
        edge_types=None,
        custom_extraction_instructions=None,
    ):
        return []

    monkeypatch.setattr(bulk_utils, 'extract_nodes', mock_extract_nodes)
    monkeypatch.setattr(bulk_utils, 'extract_edges', mock_extract_edges)

    custom_instructions = 'Focus on extracting person entities and their relationships.'

    await extract_nodes_and_edges_bulk(
        clients,
        [(episode, [])],
        edge_type_map={},
        custom_extraction_instructions=custom_instructions,
    )

    assert len(extract_nodes_calls) == 1
    assert extract_nodes_calls[0]['custom_extraction_instructions'] == custom_instructions


@pytest.mark.asyncio
async def test_extract_nodes_and_edges_bulk_passes_custom_instructions_to_extract_edges(
    monkeypatch,
):
    """Test that custom_extraction_instructions is passed to extract_edges."""
    clients = _make_clients()
    episode = _make_episode('1')

    # Track calls to extract_edges
    extract_edges_calls = []
    extracted_node = EntityNode(name='Test', group_id='group', labels=['Entity'])

    async def mock_extract_nodes(
        clients,
        episode,
        previous_episodes,
        entity_types=None,
        excluded_entity_types=None,
        custom_extraction_instructions=None,
    ):
        return [extracted_node]

    async def mock_extract_edges(
        clients,
        episode,
        nodes,
        previous_episodes,
        edge_type_map,
        group_id='',
        edge_types=None,
        custom_extraction_instructions=None,
    ):
        extract_edges_calls.append(
            {
                'nodes': nodes,
                'edge_type_map': edge_type_map,
                'edge_types': edge_types,
                'custom_extraction_instructions': custom_extraction_instructions,
            }
        )
        return []

    monkeypatch.setattr(bulk_utils, 'extract_nodes', mock_extract_nodes)
    monkeypatch.setattr(bulk_utils, 'extract_edges', mock_extract_edges)

    custom_instructions = 'Extract only professional relationships between people.'

    await extract_nodes_and_edges_bulk(
        clients,
        [(episode, [])],
        edge_type_map={('Entity', 'Entity'): ['knows']},
        custom_extraction_instructions=custom_instructions,
    )

    assert len(extract_edges_calls) == 1
    assert extract_edges_calls[0]['custom_extraction_instructions'] == custom_instructions
    assert extract_edges_calls[0]['nodes'] == [extracted_node]


@pytest.mark.asyncio
async def test_extract_nodes_and_edges_bulk_custom_instructions_none_by_default(monkeypatch):
    """Test that custom_extraction_instructions defaults to None when not provided."""
    clients = _make_clients()
    episode = _make_episode('1')

    extract_nodes_calls = []
    extract_edges_calls = []

    async def mock_extract_nodes(
        clients,
        episode,
        previous_episodes,
        entity_types=None,
        excluded_entity_types=None,
        custom_extraction_instructions=None,
    ):
        extract_nodes_calls.append(
            {'custom_extraction_instructions': custom_extraction_instructions}
        )
        return []

    async def mock_extract_edges(
        clients,
        episode,
        nodes,
        previous_episodes,
        edge_type_map,
        group_id='',
        edge_types=None,
        custom_extraction_instructions=None,
    ):
        extract_edges_calls.append(
            {'custom_extraction_instructions': custom_extraction_instructions}
        )
        return []

    monkeypatch.setattr(bulk_utils, 'extract_nodes', mock_extract_nodes)
    monkeypatch.setattr(bulk_utils, 'extract_edges', mock_extract_edges)

    # Call without custom_extraction_instructions
    await extract_nodes_and_edges_bulk(
        clients,
        [(episode, [])],
        edge_type_map={},
    )

    assert len(extract_nodes_calls) == 1
    assert extract_nodes_calls[0]['custom_extraction_instructions'] is None
    assert len(extract_edges_calls) == 1
    assert extract_edges_calls[0]['custom_extraction_instructions'] is None


@pytest.mark.asyncio
async def test_extract_nodes_and_edges_bulk_custom_instructions_multiple_episodes(monkeypatch):
    """Test that custom_extraction_instructions is passed for all episodes in bulk."""
    clients = _make_clients()
    episode1 = _make_episode('1')
    episode2 = _make_episode('2')
    episode3 = _make_episode('3')

    extract_nodes_calls = []
    extract_edges_calls = []

    async def mock_extract_nodes(
        clients,
        episode,
        previous_episodes,
        entity_types=None,
        excluded_entity_types=None,
        custom_extraction_instructions=None,
    ):
        extract_nodes_calls.append(
            {
                'episode_name': episode.name,
                'custom_extraction_instructions': custom_extraction_instructions,
            }
        )
        return []

    async def mock_extract_edges(
        clients,
        episode,
        nodes,
        previous_episodes,
        edge_type_map,
        group_id='',
        edge_types=None,
        custom_extraction_instructions=None,
    ):
        extract_edges_calls.append(
            {
                'episode_name': episode.name,
                'custom_extraction_instructions': custom_extraction_instructions,
            }
        )
        return []

    monkeypatch.setattr(bulk_utils, 'extract_nodes', mock_extract_nodes)
    monkeypatch.setattr(bulk_utils, 'extract_edges', mock_extract_edges)

    custom_instructions = 'Extract entities related to financial transactions.'

    await extract_nodes_and_edges_bulk(
        clients,
        [(episode1, []), (episode2, []), (episode3, [])],
        edge_type_map={},
        custom_extraction_instructions=custom_instructions,
    )

    # All 3 episodes should have received the custom instructions
    assert len(extract_nodes_calls) == 3
    assert len(extract_edges_calls) == 3

    for call in extract_nodes_calls:
        assert call['custom_extraction_instructions'] == custom_instructions

    for call in extract_edges_calls:
        assert call['custom_extraction_instructions'] == custom_instructions

```
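
To make the alias-chain behavior asserted by `test_build_directed_uuid_map_chain` and `test_build_directed_uuid_map_preserves_direction` concrete, here is a small reference sketch of that collapse. It is not the implementation in `bulk_utils`; it only reproduces the behavior these tests pin down (cycles are not handled).

```python
# Reference sketch only -- not bulk_utils._build_directed_uuid_map itself.
# Collapses (alias -> canonical) pairs so every uuid resolves to the end of
# its chain, and canonical uuids map to themselves.
def build_directed_uuid_map_sketch(pairs: list[tuple[str, str]]) -> dict[str, str]:
    parent: dict[str, str] = {}
    for alias, canonical in pairs:
        parent[alias] = canonical  # direction is preserved: alias -> canonical
        parent.setdefault(canonical, canonical)

    def resolve(uuid: str) -> str:
        # Walk the chain until a uuid maps to itself; assumes no cycles.
        while parent[uuid] != uuid:
            uuid = parent[uuid]
        return uuid

    return {uuid: resolve(uuid) for uuid in parent}


assert build_directed_uuid_map_sketch([('a', 'b'), ('b', 'c')]) == {'a': 'c', 'b': 'c', 'c': 'c'}
assert build_directed_uuid_map_sketch([('alias', 'canonical')]) == {
    'alias': 'canonical',
    'canonical': 'canonical',
}
assert build_directed_uuid_map_sketch([]) == {}
```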

--------------------------------------------------------------------------------
/mcp_server/tests/test_async_operations.py:
--------------------------------------------------------------------------------

```python
#!/usr/bin/env python3
"""
Asynchronous operation tests for Graphiti MCP Server.
Tests concurrent operations, queue management, and async patterns.
"""

import asyncio
import contextlib
import json
import time

import pytest
from test_fixtures import (
    TestDataGenerator,
    graphiti_test_client,
)


class TestAsyncQueueManagement:
    """Test asynchronous queue operations and episode processing."""

    @pytest.mark.asyncio
    async def test_sequential_queue_processing(self):
        """Verify episodes are processed sequentially within a group."""
        async with graphiti_test_client() as (session, group_id):
            # Add multiple episodes quickly
            episodes = []
            for i in range(5):
                result = await session.call_tool(
                    'add_memory',
                    {
                        'name': f'Sequential Test {i}',
                        'episode_body': f'Episode {i} with timestamp {time.time()}',
                        'source': 'text',
                        'source_description': 'sequential test',
                        'group_id': group_id,
                        'reference_id': f'seq_{i}',  # Add reference for tracking
                    },
                )
                episodes.append(result)

            # Wait for processing
            await asyncio.sleep(10)  # Allow time for sequential processing

            # Retrieve episodes and verify order
            result = await session.call_tool('get_episodes', {'group_id': group_id, 'last_n': 10})

            processed_episodes = json.loads(result.content[0].text)['episodes']

            # Verify all episodes were processed
            assert len(processed_episodes) >= 5, (
                f'Expected at least 5 episodes, got {len(processed_episodes)}'
            )

            # Verify sequential processing (timestamps should be ordered)
            timestamps = [ep.get('created_at') for ep in processed_episodes]
            assert timestamps == sorted(timestamps), 'Episodes not processed in order'

    @pytest.mark.asyncio
    async def test_concurrent_group_processing(self):
        """Test that different groups can process concurrently."""
        async with graphiti_test_client() as (session, _):
            groups = [f'group_{i}_{time.time()}' for i in range(3)]
            tasks = []

            # Create tasks for different groups
            for group_id in groups:
                for j in range(2):
                    task = session.call_tool(
                        'add_memory',
                        {
                            'name': f'Group {group_id} Episode {j}',
                            'episode_body': f'Content for {group_id}',
                            'source': 'text',
                            'source_description': 'concurrent test',
                            'group_id': group_id,
                        },
                    )
                    tasks.append(task)

            # Execute all tasks concurrently
            start_time = time.time()
            results = await asyncio.gather(*tasks, return_exceptions=True)
            execution_time = time.time() - start_time

            # Verify all succeeded
            failures = [r for r in results if isinstance(r, Exception)]
            assert not failures, f'Concurrent operations failed: {failures}'

            # Check that execution was actually concurrent (should be faster than sequential)
            # Sequential would take at least 6 * processing_time
            assert execution_time < 30, f'Concurrent execution too slow: {execution_time}s'

    @pytest.mark.asyncio
    async def test_queue_overflow_handling(self):
        """Test behavior when queue reaches capacity."""
        async with graphiti_test_client() as (session, group_id):
            # Attempt to add many episodes rapidly
            tasks = []
            for i in range(100):  # Large number to potentially overflow
                task = session.call_tool(
                    'add_memory',
                    {
                        'name': f'Overflow Test {i}',
                        'episode_body': f'Episode {i}',
                        'source': 'text',
                        'source_description': 'overflow test',
                        'group_id': group_id,
                    },
                )
                tasks.append(task)

            # Execute with gathering to catch any failures
            results = await asyncio.gather(*tasks, return_exceptions=True)

            # Count successful queuing
            successful = sum(1 for r in results if not isinstance(r, Exception))

            # Should handle overflow gracefully
            assert successful > 0, 'No episodes were queued successfully'

            # Log overflow behavior
            if successful < 100:
                print(f'Queue overflow: {successful}/100 episodes queued')


class TestConcurrentOperations:
    """Test concurrent tool calls and operations."""

    @pytest.mark.asyncio
    async def test_concurrent_search_operations(self):
        """Test multiple concurrent search operations."""
        async with graphiti_test_client() as (session, group_id):
            # First, add some test data
            data_gen = TestDataGenerator()

            add_tasks = []
            for _ in range(5):
                task = session.call_tool(
                    'add_memory',
                    {
                        'name': 'Search Test Data',
                        'episode_body': data_gen.generate_technical_document(),
                        'source': 'text',
                        'source_description': 'search test',
                        'group_id': group_id,
                    },
                )
                add_tasks.append(task)

            await asyncio.gather(*add_tasks)
            await asyncio.sleep(15)  # Wait for processing

            # Now perform concurrent searches
            search_queries = [
                'architecture',
                'performance',
                'implementation',
                'dependencies',
                'latency',
            ]

            search_tasks = []
            for query in search_queries:
                task = session.call_tool(
                    'search_memory_nodes',
                    {
                        'query': query,
                        'group_id': group_id,
                        'limit': 10,
                    },
                )
                search_tasks.append(task)

            start_time = time.time()
            results = await asyncio.gather(*search_tasks, return_exceptions=True)
            search_time = time.time() - start_time

            # Verify all searches completed
            failures = [r for r in results if isinstance(r, Exception)]
            assert not failures, f'Search operations failed: {failures}'

            # Verify concurrent execution efficiency
            assert search_time < len(search_queries) * 2, 'Searches not executing concurrently'

    @pytest.mark.asyncio
    async def test_mixed_operation_concurrency(self):
        """Test different types of operations running concurrently."""
        async with graphiti_test_client() as (session, group_id):
            operations = []

            # Add memory operation
            operations.append(
                session.call_tool(
                    'add_memory',
                    {
                        'name': 'Mixed Op Test',
                        'episode_body': 'Testing mixed operations',
                        'source': 'text',
                        'source_description': 'test',
                        'group_id': group_id,
                    },
                )
            )

            # Search operation
            operations.append(
                session.call_tool(
                    'search_memory_nodes',
                    {
                        'query': 'test',
                        'group_id': group_id,
                        'limit': 5,
                    },
                )
            )

            # Get episodes operation
            operations.append(
                session.call_tool(
                    'get_episodes',
                    {
                        'group_id': group_id,
                        'last_n': 10,
                    },
                )
            )

            # Get status operation
            operations.append(session.call_tool('get_status', {}))

            # Execute all concurrently
            results = await asyncio.gather(*operations, return_exceptions=True)

            # Check results
            for i, result in enumerate(results):
                assert not isinstance(result, Exception), f'Operation {i} failed: {result}'


class TestAsyncErrorHandling:
    """Test async error handling and recovery."""

    @pytest.mark.asyncio
    async def test_timeout_recovery(self):
        """Test recovery from operation timeouts."""
        async with graphiti_test_client() as (session, group_id):
            # Create a very large episode that might time out
            large_content = 'x' * 1000000  # 1MB of data

            with contextlib.suppress(asyncio.TimeoutError):
                await asyncio.wait_for(
                    session.call_tool(
                        'add_memory',
                        {
                            'name': 'Timeout Test',
                            'episode_body': large_content,
                            'source': 'text',
                            'source_description': 'timeout test',
                            'group_id': group_id,
                        },
                    ),
                    timeout=2.0,  # Short timeout - expected to timeout
                )

            # Verify server is still responsive after timeout
            status_result = await session.call_tool('get_status', {})
            assert status_result is not None, 'Server unresponsive after timeout'

    @pytest.mark.asyncio
    async def test_cancellation_handling(self):
        """Test proper handling of cancelled operations."""
        async with graphiti_test_client() as (session, group_id):
            # Start a long-running operation
            task = asyncio.create_task(
                session.call_tool(
                    'add_memory',
                    {
                        'name': 'Cancellation Test',
                        'episode_body': TestDataGenerator.generate_technical_document(),
                        'source': 'text',
                        'source_description': 'cancel test',
                        'group_id': group_id,
                    },
                )
            )

            # Cancel after a short delay
            await asyncio.sleep(0.1)
            task.cancel()

            # Verify cancellation was handled
            with pytest.raises(asyncio.CancelledError):
                await task

            # Server should still be operational
            result = await session.call_tool('get_status', {})
            assert result is not None

    @pytest.mark.asyncio
    async def test_exception_propagation(self):
        """Test that exceptions are properly propagated in async context."""
        async with graphiti_test_client() as (session, group_id):
            # Call with invalid arguments
            with pytest.raises(ValueError):
                await session.call_tool(
                    'add_memory',
                    {
                        # Missing required fields
                        'group_id': group_id,
                    },
                )

            # Server should remain operational
            status = await session.call_tool('get_status', {})
            assert status is not None


class TestAsyncPerformance:
    """Performance tests for async operations."""

    @pytest.mark.asyncio
    async def test_async_throughput(self, performance_benchmark):
        """Measure throughput of async operations."""
        async with graphiti_test_client() as (session, group_id):
            num_operations = 50
            start_time = time.time()

            # Create many concurrent operations
            tasks = []
            for i in range(num_operations):
                task = session.call_tool(
                    'add_memory',
                    {
                        'name': f'Throughput Test {i}',
                        'episode_body': f'Content {i}',
                        'source': 'text',
                        'source_description': 'throughput test',
                        'group_id': group_id,
                    },
                )
                tasks.append(task)

            # Execute all
            results = await asyncio.gather(*tasks, return_exceptions=True)
            total_time = time.time() - start_time

            # Calculate metrics
            successful = sum(1 for r in results if not isinstance(r, Exception))
            throughput = successful / total_time

            performance_benchmark.record('async_throughput', throughput)

            # Log results
            print('\nAsync Throughput Test:')
            print(f'  Operations: {num_operations}')
            print(f'  Successful: {successful}')
            print(f'  Total time: {total_time:.2f}s')
            print(f'  Throughput: {throughput:.2f} ops/s')

            # Assert minimum throughput
            assert throughput > 1.0, f'Throughput too low: {throughput:.2f} ops/s'

    @pytest.mark.asyncio
    async def test_latency_under_load(self, performance_benchmark):
        """Test operation latency under concurrent load."""
        async with graphiti_test_client() as (session, group_id):
            # Create background load
            background_tasks = []
            for i in range(10):
                task = asyncio.create_task(
                    session.call_tool(
                        'add_memory',
                        {
                            'name': f'Background {i}',
                            'episode_body': TestDataGenerator.generate_technical_document(),
                            'source': 'text',
                            'source_description': 'background',
                            'group_id': f'background_{group_id}',
                        },
                    )
                )
                background_tasks.append(task)

            # Measure latency of operations under load
            latencies = []
            for _ in range(5):
                start = time.time()
                await session.call_tool('get_status', {})
                latency = time.time() - start
                latencies.append(latency)
                performance_benchmark.record('latency_under_load', latency)

            # Clean up background tasks
            for task in background_tasks:
                task.cancel()

            # Analyze latencies
            avg_latency = sum(latencies) / len(latencies)
            max_latency = max(latencies)

            print('\nLatency Under Load:')
            print(f'  Average: {avg_latency:.3f}s')
            print(f'  Max: {max_latency:.3f}s')

            # Assert acceptable latency
            assert avg_latency < 2.0, f'Average latency too high: {avg_latency:.3f}s'
            assert max_latency < 5.0, f'Max latency too high: {max_latency:.3f}s'


class TestAsyncStreamHandling:
    """Test handling of streaming responses and data."""

    @pytest.mark.asyncio
    async def test_large_response_streaming(self):
        """Test handling of large streamed responses."""
        async with graphiti_test_client() as (session, group_id):
            # Add many episodes
            for i in range(20):
                await session.call_tool(
                    'add_memory',
                    {
                        'name': f'Stream Test {i}',
                        'episode_body': f'Episode content {i}',
                        'source': 'text',
                        'source_description': 'stream test',
                        'group_id': group_id,
                    },
                )

            # Wait for processing
            await asyncio.sleep(30)

            # Request large result set
            result = await session.call_tool(
                'get_episodes',
                {
                    'group_id': group_id,
                    'last_n': 100,  # Request all
                },
            )

            # Verify response handling
            episodes = json.loads(result.content[0].text)['episodes']
            assert len(episodes) >= 20, f'Expected at least 20 episodes, got {len(episodes)}'

    @pytest.mark.asyncio
    async def test_incremental_processing(self):
        """Test incremental processing of results."""
        async with graphiti_test_client() as (session, group_id):
            # Add episodes incrementally
            for batch in range(3):
                batch_tasks = []
                for i in range(5):
                    task = session.call_tool(
                        'add_memory',
                        {
                            'name': f'Batch {batch} Item {i}',
                            'episode_body': f'Content for batch {batch}',
                            'source': 'text',
                            'source_description': 'incremental test',
                            'group_id': group_id,
                        },
                    )
                    batch_tasks.append(task)

                # Process batch
                await asyncio.gather(*batch_tasks)

                # Wait for this batch to process
                await asyncio.sleep(10)

                # Verify incremental results
                result = await session.call_tool(
                    'get_episodes',
                    {
                        'group_id': group_id,
                        'last_n': 100,
                    },
                )

                episodes = json.loads(result.content[0].text)['episodes']
                expected_min = (batch + 1) * 5
                assert len(episodes) >= expected_min, (
                    f'Batch {batch}: Expected at least {expected_min} episodes'
                )


if __name__ == '__main__':
    pytest.main([__file__, '-v', '--asyncio-mode=auto'])

```
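
Several of the tests above rely on fixed `asyncio.sleep(...)` waits before reading back episodes. A hedged alternative, using only the `session.call_tool('get_episodes', ...)` interface and response shape these tests already use, is to poll until the expected count appears or a deadline passes; the helper name and polling interval below are illustrative, not part of the suite.

```python
# Illustrative polling helper (not part of this test suite): replaces fixed
# sleeps with a bounded poll of get_episodes. Uses the same call_tool interface
# and JSON response shape as the tests above.
import asyncio
import json
import time


async def wait_for_episode_count(session, group_id, expected, timeout=60.0):
    episodes = []
    deadline = time.time() + timeout
    while time.time() < deadline:
        result = await session.call_tool('get_episodes', {'group_id': group_id, 'last_n': 100})
        episodes = json.loads(result.content[0].text)['episodes']
        if len(episodes) >= expected:
            return episodes
        await asyncio.sleep(1)
    raise TimeoutError(f'Saw {len(episodes)} of {expected} episodes after {timeout}s')
```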

--------------------------------------------------------------------------------
/graphiti_core/search/search.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
from collections import defaultdict
from time import time

from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.edges import EntityEdge
from graphiti_core.embedder.client import EMBEDDING_DIM
from graphiti_core.errors import SearchRerankerError
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import semaphore_gather
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.search.search_config import (
    DEFAULT_SEARCH_LIMIT,
    CommunityReranker,
    CommunitySearchConfig,
    CommunitySearchMethod,
    EdgeReranker,
    EdgeSearchConfig,
    EdgeSearchMethod,
    EpisodeReranker,
    EpisodeSearchConfig,
    NodeReranker,
    NodeSearchConfig,
    NodeSearchMethod,
    SearchConfig,
    SearchResults,
)
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import (
    community_fulltext_search,
    community_similarity_search,
    edge_bfs_search,
    edge_fulltext_search,
    edge_similarity_search,
    episode_fulltext_search,
    episode_mentions_reranker,
    get_embeddings_for_communities,
    get_embeddings_for_edges,
    get_embeddings_for_nodes,
    maximal_marginal_relevance,
    node_bfs_search,
    node_distance_reranker,
    node_fulltext_search,
    node_similarity_search,
    rrf,
)

logger = logging.getLogger(__name__)


async def search(
    clients: GraphitiClients,
    query: str,
    group_ids: list[str] | None,
    config: SearchConfig,
    search_filter: SearchFilters,
    center_node_uuid: str | None = None,
    bfs_origin_node_uuids: list[str] | None = None,
    query_vector: list[float] | None = None,
    driver: GraphDriver | None = None,
) -> SearchResults:
    start = time()

    driver = driver or clients.driver
    embedder = clients.embedder
    cross_encoder = clients.cross_encoder

    if query.strip() == '':
        return SearchResults()

    if (
        (
            config.edge_config
            and EdgeSearchMethod.cosine_similarity in config.edge_config.search_methods
        )
        or (config.edge_config and EdgeReranker.mmr == config.edge_config.reranker)
        or (
            config.node_config
            and NodeSearchMethod.cosine_similarity in config.node_config.search_methods
        )
        or (config.node_config and NodeReranker.mmr == config.node_config.reranker)
        or (
            config.community_config
            and CommunitySearchMethod.cosine_similarity in config.community_config.search_methods
        )
        or (config.community_config and CommunityReranker.mmr == config.community_config.reranker)
    ):
        search_vector = (
            query_vector
            if query_vector is not None
            else await embedder.create(input_data=[query.replace('\n', ' ')])
        )
    else:
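        # None of the configured search methods or rerankers use the query embedding,
        # so skip the embedding call and pass a zero-vector placeholder instead.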
        search_vector = [0.0] * EMBEDDING_DIM

    # if group_ids is empty, set it to None
    group_ids = group_ids if group_ids and group_ids != [''] else None
    (
        (edges, edge_reranker_scores),
        (nodes, node_reranker_scores),
        (episodes, episode_reranker_scores),
        (communities, community_reranker_scores),
    ) = await semaphore_gather(
        edge_search(
            driver,
            cross_encoder,
            query,
            search_vector,
            group_ids,
            config.edge_config,
            search_filter,
            center_node_uuid,
            bfs_origin_node_uuids,
            config.limit,
            config.reranker_min_score,
        ),
        node_search(
            driver,
            cross_encoder,
            query,
            search_vector,
            group_ids,
            config.node_config,
            search_filter,
            center_node_uuid,
            bfs_origin_node_uuids,
            config.limit,
            config.reranker_min_score,
        ),
        episode_search(
            driver,
            cross_encoder,
            query,
            search_vector,
            group_ids,
            config.episode_config,
            search_filter,
            config.limit,
            config.reranker_min_score,
        ),
        community_search(
            driver,
            cross_encoder,
            query,
            search_vector,
            group_ids,
            config.community_config,
            config.limit,
            config.reranker_min_score,
        ),
    )

    results = SearchResults(
        edges=edges,
        edge_reranker_scores=edge_reranker_scores,
        nodes=nodes,
        node_reranker_scores=node_reranker_scores,
        episodes=episodes,
        episode_reranker_scores=episode_reranker_scores,
        communities=communities,
        community_reranker_scores=community_reranker_scores,
    )

    latency = (time() - start) * 1000

    logger.debug(f'search returned context for query {query} in {latency} ms')

    return results


async def edge_search(
    driver: GraphDriver,
    cross_encoder: CrossEncoderClient,
    query: str,
    query_vector: list[float],
    group_ids: list[str] | None,
    config: EdgeSearchConfig | None,
    search_filter: SearchFilters,
    center_node_uuid: str | None = None,
    bfs_origin_node_uuids: list[str] | None = None,
    limit=DEFAULT_SEARCH_LIMIT,
    reranker_min_score: float = 0,
) -> tuple[list[EntityEdge], list[float]]:
    if config is None:
        return [], []

    # Build search tasks based on configured search methods
    search_tasks = []
    if EdgeSearchMethod.bm25 in config.search_methods:
        search_tasks.append(
            edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
        )
    if EdgeSearchMethod.cosine_similarity in config.search_methods:
        search_tasks.append(
            edge_similarity_search(
                driver,
                query_vector,
                None,
                None,
                search_filter,
                group_ids,
                2 * limit,
                config.sim_min_score,
            )
        )
    if EdgeSearchMethod.bfs in config.search_methods:
        search_tasks.append(
            edge_bfs_search(
                driver,
                bfs_origin_node_uuids,
                config.bfs_max_depth,
                search_filter,
                group_ids,
                2 * limit,
            )
        )

    # Execute only the configured search methods
    search_results: list[list[EntityEdge]] = []
    if search_tasks:
        search_results = list(await semaphore_gather(*search_tasks))

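    # When BFS is enabled but no origin nodes were provided, run a second BFS pass
    # seeded with the source nodes of the edges found by the other search methods.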
    if EdgeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
        source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
        search_results.append(
            await edge_bfs_search(
                driver,
                source_node_uuids,
                config.bfs_max_depth,
                search_filter,
                group_ids,
                2 * limit,
            )
        )

    edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}

    reranked_uuids: list[str] = []
    edge_scores: list[float] = []
    if config.reranker in (EdgeReranker.rrf, EdgeReranker.episode_mentions):
        search_result_uuids = [[edge.uuid for edge in result] for result in search_results]

        reranked_uuids, edge_scores = rrf(search_result_uuids, min_score=reranker_min_score)
    elif config.reranker == EdgeReranker.mmr:
        search_result_uuids_and_vectors = await get_embeddings_for_edges(
            driver, list(edge_uuid_map.values())
        )
        reranked_uuids, edge_scores = maximal_marginal_relevance(
            query_vector,
            search_result_uuids_and_vectors,
            config.mmr_lambda,
            reranker_min_score,
        )
    elif config.reranker == EdgeReranker.cross_encoder:
        fact_to_uuid_map = {edge.fact: edge.uuid for edge in list(edge_uuid_map.values())[:limit]}
        reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
        reranked_uuids = [
            fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
        ]
        edge_scores = [score for _, score in reranked_facts if score >= reranker_min_score]
    elif config.reranker == EdgeReranker.node_distance:
        if center_node_uuid is None:
            raise SearchRerankerError('No center node provided for Node Distance reranker')

        # use rrf as a preliminary sort
        sorted_result_uuids, node_scores = rrf(
            [[edge.uuid for edge in result] for result in search_results],
            min_score=reranker_min_score,
        )
        sorted_results = [edge_uuid_map[uuid] for uuid in sorted_result_uuids]

        # node distance reranking
        source_to_edge_uuid_map = defaultdict(list)
        for edge in sorted_results:
            source_to_edge_uuid_map[edge.source_node_uuid].append(edge.uuid)

        source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]

        reranked_node_uuids, edge_scores = await node_distance_reranker(
            driver, source_uuids, center_node_uuid, min_score=reranker_min_score
        )

        for node_uuid in reranked_node_uuids:
            reranked_uuids.extend(source_to_edge_uuid_map[node_uuid])

    reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]

    if config.reranker == EdgeReranker.episode_mentions:
        reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))

    return reranked_edges[:limit], edge_scores[:limit]


async def node_search(
    driver: GraphDriver,
    cross_encoder: CrossEncoderClient,
    query: str,
    query_vector: list[float],
    group_ids: list[str] | None,
    config: NodeSearchConfig | None,
    search_filter: SearchFilters,
    center_node_uuid: str | None = None,
    bfs_origin_node_uuids: list[str] | None = None,
    limit=DEFAULT_SEARCH_LIMIT,
    reranker_min_score: float = 0,
) -> tuple[list[EntityNode], list[float]]:
    if config is None:
        return [], []

    # Build search tasks based on configured search methods
    search_tasks = []
    if NodeSearchMethod.bm25 in config.search_methods:
        search_tasks.append(
            node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
        )
    if NodeSearchMethod.cosine_similarity in config.search_methods:
        search_tasks.append(
            node_similarity_search(
                driver,
                query_vector,
                search_filter,
                group_ids,
                2 * limit,
                config.sim_min_score,
            )
        )
    if NodeSearchMethod.bfs in config.search_methods:
        search_tasks.append(
            node_bfs_search(
                driver,
                bfs_origin_node_uuids,
                search_filter,
                config.bfs_max_depth,
                group_ids,
                2 * limit,
            )
        )

    # Execute only the configured search methods
    search_results: list[list[EntityNode]] = []
    if search_tasks:
        search_results = list(await semaphore_gather(*search_tasks))

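    # When BFS is enabled but no origin nodes were provided, run a second BFS pass
    # seeded with the nodes returned by the other search methods.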
    if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
        origin_node_uuids = [node.uuid for result in search_results for node in result]
        search_results.append(
            await node_bfs_search(
                driver,
                origin_node_uuids,
                search_filter,
                config.bfs_max_depth,
                group_ids,
                2 * limit,
            )
        )

    search_result_uuids = [[node.uuid for node in result] for result in search_results]
    node_uuid_map = {node.uuid: node for result in search_results for node in result}

    reranked_uuids: list[str] = []
    node_scores: list[float] = []
    if config.reranker == NodeReranker.rrf:
        reranked_uuids, node_scores = rrf(search_result_uuids, min_score=reranker_min_score)
    elif config.reranker == NodeReranker.mmr:
        search_result_uuids_and_vectors = await get_embeddings_for_nodes(
            driver, list(node_uuid_map.values())
        )

        reranked_uuids, node_scores = maximal_marginal_relevance(
            query_vector,
            search_result_uuids_and_vectors,
            config.mmr_lambda,
            reranker_min_score,
        )
    elif config.reranker == NodeReranker.cross_encoder:
        name_to_uuid_map = {node.name: node.uuid for node in list(node_uuid_map.values())}

        reranked_node_names = await cross_encoder.rank(query, list(name_to_uuid_map.keys()))
        reranked_uuids = [
            name_to_uuid_map[name]
            for name, score in reranked_node_names
            if score >= reranker_min_score
        ]
        node_scores = [score for _, score in reranked_node_names if score >= reranker_min_score]
    elif config.reranker == NodeReranker.episode_mentions:
        reranked_uuids, node_scores = await episode_mentions_reranker(
            driver, search_result_uuids, min_score=reranker_min_score
        )
    elif config.reranker == NodeReranker.node_distance:
        if center_node_uuid is None:
            raise SearchRerankerError('No center node provided for Node Distance reranker')
        reranked_uuids, node_scores = await node_distance_reranker(
            driver,
            rrf(search_result_uuids, min_score=reranker_min_score)[0],
            center_node_uuid,
            min_score=reranker_min_score,
        )

    reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]

    return reranked_nodes[:limit], node_scores[:limit]


async def episode_search(
    driver: GraphDriver,
    cross_encoder: CrossEncoderClient,
    query: str,
    _query_vector: list[float],
    group_ids: list[str] | None,
    config: EpisodeSearchConfig | None,
    search_filter: SearchFilters,
    limit=DEFAULT_SEARCH_LIMIT,
    reranker_min_score: float = 0,
) -> tuple[list[EpisodicNode], list[float]]:
    if config is None:
        return [], []
    search_results: list[list[EpisodicNode]] = list(
        await semaphore_gather(
            *[
                episode_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
            ]
        )
    )

    search_result_uuids = [[episode.uuid for episode in result] for result in search_results]
    episode_uuid_map = {episode.uuid: episode for result in search_results for episode in result}

    reranked_uuids: list[str] = []
    episode_scores: list[float] = []
    if config.reranker == EpisodeReranker.rrf:
        reranked_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)

    elif config.reranker == EpisodeReranker.cross_encoder:
        # use rrf as a preliminary reranker
        rrf_result_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
        rrf_results = [episode_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]

        content_to_uuid_map = {episode.content: episode.uuid for episode in rrf_results}

        reranked_contents = await cross_encoder.rank(query, list(content_to_uuid_map.keys()))
        reranked_uuids = [
            content_to_uuid_map[content]
            for content, score in reranked_contents
            if score >= reranker_min_score
        ]
        episode_scores = [score for _, score in reranked_contents if score >= reranker_min_score]

    reranked_episodes = [episode_uuid_map[uuid] for uuid in reranked_uuids]

    return reranked_episodes[:limit], episode_scores[:limit]


async def community_search(
    driver: GraphDriver,
    cross_encoder: CrossEncoderClient,
    query: str,
    query_vector: list[float],
    group_ids: list[str] | None,
    config: CommunitySearchConfig | None,
    limit=DEFAULT_SEARCH_LIMIT,
    reranker_min_score: float = 0,
) -> tuple[list[CommunityNode], list[float]]:
    if config is None:
        return [], []

    search_results: list[list[CommunityNode]] = list(
        await semaphore_gather(
            *[
                community_fulltext_search(driver, query, group_ids, 2 * limit),
                community_similarity_search(
                    driver, query_vector, group_ids, 2 * limit, config.sim_min_score
                ),
            ]
        )
    )

    search_result_uuids = [[community.uuid for community in result] for result in search_results]
    community_uuid_map = {
        community.uuid: community for result in search_results for community in result
    }

    reranked_uuids: list[str] = []
    community_scores: list[float] = []
    if config.reranker == CommunityReranker.rrf:
        reranked_uuids, community_scores = rrf(search_result_uuids, min_score=reranker_min_score)
    elif config.reranker == CommunityReranker.mmr:
        search_result_uuids_and_vectors = await get_embeddings_for_communities(
            driver, list(community_uuid_map.values())
        )

        reranked_uuids, community_scores = maximal_marginal_relevance(
            query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
        )
    elif config.reranker == CommunityReranker.cross_encoder:
        name_to_uuid_map = {node.name: node.uuid for result in search_results for node in result}
        reranked_nodes = await cross_encoder.rank(query, list(name_to_uuid_map.keys()))
        reranked_uuids = [
            name_to_uuid_map[name] for name, score in reranked_nodes if score >= reranker_min_score
        ]
        community_scores = [score for _, score in reranked_nodes if score >= reranker_min_score]

    reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]

    return reranked_communities[:limit], community_scores[:limit]

```
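
A minimal usage sketch (illustrative, not part of the source file above) showing how `search` might be wired up, assuming a pre-configured `GraphitiClients` instance. The `SearchConfig`, `EdgeSearchConfig`, and `NodeSearchConfig` keyword arguments and the bare `SearchFilters()` construction are assumptions inferred from how those fields are read in `search.py`.

```python
# Illustrative sketch only; `clients` is assumed to be an already-configured
# GraphitiClients (driver, embedder, cross_encoder, llm_client).
import asyncio

from graphiti_core.search.search import search
from graphiti_core.search.search_config import (
    EdgeReranker,
    EdgeSearchConfig,
    EdgeSearchMethod,
    NodeReranker,
    NodeSearchConfig,
    NodeSearchMethod,
    SearchConfig,
)
from graphiti_core.search.search_filters import SearchFilters


async def run_search(clients) -> None:
    # Hybrid BM25 + cosine-similarity search over edges and nodes, fused with RRF.
    config = SearchConfig(
        edge_config=EdgeSearchConfig(
            search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
            reranker=EdgeReranker.rrf,
        ),
        node_config=NodeSearchConfig(
            search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
            reranker=NodeReranker.rrf,
        ),
        limit=10,
    )

    results = await search(
        clients=clients,
        query='Acme Corp product launch',
        group_ids=None,  # no group filter
        config=config,
        search_filter=SearchFilters(),
    )
    for edge, score in zip(results.edges, results.edge_reranker_scores):
        print(f'{score:.3f}  {edge.fact}')


# asyncio.run(run_search(clients))
```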

--------------------------------------------------------------------------------
/mcp_server/tests/test_mcp_integration.py:
--------------------------------------------------------------------------------

```python
#!/usr/bin/env python3
"""
Integration test for the refactored Graphiti MCP Server using the official MCP Python SDK.
Tests all major MCP tools and handles episode processing latency.
"""

import asyncio
import json
import os
import time
from typing import Any

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client


class GraphitiMCPIntegrationTest:
    """Integration test client for Graphiti MCP Server using official MCP SDK."""

    def __init__(self):
        self.test_group_id = f'test_group_{int(time.time())}'
        self.session = None

    async def __aenter__(self):
        """Start the MCP client session."""
        # Configure server parameters to run our refactored server
        server_params = StdioServerParameters(
            command='uv',
            args=['run', 'main.py', '--transport', 'stdio'],
            env={
                'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
                'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
                'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
                'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY', 'dummy_key_for_testing'),
            },
        )

        print(f'🚀 Starting MCP client session with test group: {self.test_group_id}')

        # Enter the stdio transport and the client session as async context managers
        self.client_context = stdio_client(server_params)
        read, write = await self.client_context.__aenter__()
        self.session_context = ClientSession(read, write)
        self.session = await self.session_context.__aenter__()
        await self.session.initialize()

        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the MCP client session."""
        if hasattr(self, 'session_context'):
            await self.session_context.__aexit__(exc_type, exc_val, exc_tb)
        if hasattr(self, 'client_context'):
            await self.client_context.__aexit__(exc_type, exc_val, exc_tb)

    async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Call an MCP tool and return the result."""
        try:
            result = await self.session.call_tool(tool_name, arguments)
            return result.content[0].text if result.content else {'error': 'No content returned'}
        except Exception as e:
            return {'error': str(e)}

    async def test_server_initialization(self) -> bool:
        """Test that the server initializes properly."""
        print('🔍 Testing server initialization...')

        try:
            # List available tools to verify server is responding
            tools_result = await self.session.list_tools()
            tools = [tool.name for tool in tools_result.tools]

            expected_tools = [
                'add_memory',
                'search_memory_nodes',
                'search_memory_facts',
                'get_episodes',
                'delete_episode',
                'delete_entity_edge',
                'get_entity_edge',
                'clear_graph',
            ]

            available_tools = len([tool for tool in expected_tools if tool in tools])
            print(
                f'   ✅ Server responding with {len(tools)} tools ({available_tools}/{len(expected_tools)} expected)'
            )
            print(f'   Available tools: {", ".join(sorted(tools))}')

            return available_tools >= len(expected_tools) * 0.8  # 80% of expected tools

        except Exception as e:
            print(f'   ❌ Server initialization failed: {e}')
            return False

    async def test_add_memory_operations(self) -> dict[str, bool]:
        """Test adding various types of memory episodes."""
        print('📝 Testing add_memory operations...')

        results = {}

        # Test 1: Add text episode
        print('   Testing text episode...')
        try:
            result = await self.call_tool(
                'add_memory',
                {
                    'name': 'Test Company News',
                    'episode_body': 'Acme Corp announced a revolutionary new AI product that will transform the industry. The CEO mentioned this is their biggest launch since 2020.',
                    'source': 'text',
                    'source_description': 'news article',
                    'group_id': self.test_group_id,
                },
            )

            if isinstance(result, str) and 'queued' in result.lower():
                print(f'   ✅ Text episode: {result}')
                results['text'] = True
            else:
                print(f'   ❌ Text episode failed: {result}')
                results['text'] = False
        except Exception as e:
            print(f'   ❌ Text episode error: {e}')
            results['text'] = False

        # Test 2: Add JSON episode
        print('   Testing JSON episode...')
        try:
            json_data = {
                'company': {'name': 'TechCorp', 'founded': 2010},
                'products': [
                    {'id': 'P001', 'name': 'CloudSync', 'category': 'software'},
                    {'id': 'P002', 'name': 'DataMiner', 'category': 'analytics'},
                ],
                'employees': 150,
            }

            result = await self.call_tool(
                'add_memory',
                {
                    'name': 'Company Profile',
                    'episode_body': json.dumps(json_data),
                    'source': 'json',
                    'source_description': 'CRM data',
                    'group_id': self.test_group_id,
                },
            )

            if isinstance(result, str) and 'queued' in result.lower():
                print(f'   ✅ JSON episode: {result}')
                results['json'] = True
            else:
                print(f'   ❌ JSON episode failed: {result}')
                results['json'] = False
        except Exception as e:
            print(f'   ❌ JSON episode error: {e}')
            results['json'] = False

        # Test 3: Add message episode
        print('   Testing message episode...')
        try:
            result = await self.call_tool(
                'add_memory',
                {
                    'name': 'Customer Support Chat',
                    'episode_body': "user: What's your return policy?\nassistant: You can return items within 30 days of purchase with receipt.\nuser: Thanks!",
                    'source': 'message',
                    'source_description': 'support chat log',
                    'group_id': self.test_group_id,
                },
            )

            if isinstance(result, str) and 'queued' in result.lower():
                print(f'   ✅ Message episode: {result}')
                results['message'] = True
            else:
                print(f'   ❌ Message episode failed: {result}')
                results['message'] = False
        except Exception as e:
            print(f'   ❌ Message episode error: {e}')
            results['message'] = False

        return results

    async def wait_for_processing(self, max_wait: int = 45) -> bool:
        """Wait for episode processing to complete."""
        print(f'⏳ Waiting up to {max_wait} seconds for episode processing...')

        for i in range(max_wait):
            await asyncio.sleep(1)

            try:
                # Check if we have any episodes
                result = await self.call_tool(
                    'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
                )

                # Parse the JSON result if it's a string
                if isinstance(result, str):
                    try:
                        parsed_result = json.loads(result)
                        if isinstance(parsed_result, list) and len(parsed_result) > 0:
                            print(
                                f'   ✅ Found {len(parsed_result)} processed episodes after {i + 1} seconds'
                            )
                            return True
                    except json.JSONDecodeError:
                        if 'episodes' in result.lower():
                            print(f'   ✅ Episodes detected after {i + 1} seconds')
                            return True

            except Exception as e:
                if i == 0:  # Only log first error to avoid spam
                    print(f'   ⚠️  Waiting for processing... ({e})')
                continue

        print(f'   ⚠️  Still waiting after {max_wait} seconds...')
        return False

    async def test_search_operations(self) -> dict[str, bool]:
        """Test search functionality."""
        print('🔍 Testing search operations...')

        results = {}

        # Test search_memory_nodes
        print('   Testing search_memory_nodes...')
        try:
            result = await self.call_tool(
                'search_memory_nodes',
                {
                    'query': 'Acme Corp product launch AI',
                    'group_ids': [self.test_group_id],
                    'max_nodes': 5,
                },
            )

            success = False
            if isinstance(result, str):
                try:
                    parsed = json.loads(result)
                    nodes = parsed.get('nodes', [])
                    success = isinstance(nodes, list)
                    print(f'   ✅ Node search returned {len(nodes)} nodes')
                except json.JSONDecodeError:
                    success = 'nodes' in result.lower() and 'successfully' in result.lower()
                    if success:
                        print('   ✅ Node search completed successfully')

            results['nodes'] = success
            if not success:
                print(f'   ❌ Node search failed: {result}')

        except Exception as e:
            print(f'   ❌ Node search error: {e}')
            results['nodes'] = False

        # Test search_memory_facts
        print('   Testing search_memory_facts...')
        try:
            result = await self.call_tool(
                'search_memory_facts',
                {
                    'query': 'company products software TechCorp',
                    'group_ids': [self.test_group_id],
                    'max_facts': 5,
                },
            )

            success = False
            if isinstance(result, str):
                try:
                    parsed = json.loads(result)
                    facts = parsed.get('facts', [])
                    success = isinstance(facts, list)
                    print(f'   ✅ Fact search returned {len(facts)} facts')
                except json.JSONDecodeError:
                    success = 'facts' in result.lower() and 'successfully' in result.lower()
                    if success:
                        print('   ✅ Fact search completed successfully')

            results['facts'] = success
            if not success:
                print(f'   ❌ Fact search failed: {result}')

        except Exception as e:
            print(f'   ❌ Fact search error: {e}')
            results['facts'] = False

        return results

    async def test_episode_retrieval(self) -> bool:
        """Test episode retrieval."""
        print('📚 Testing episode retrieval...')

        try:
            result = await self.call_tool(
                'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
            )

            if isinstance(result, str):
                try:
                    parsed = json.loads(result)
                    if isinstance(parsed, list):
                        print(f'   ✅ Retrieved {len(parsed)} episodes')

                        # Show episode details
                        for i, episode in enumerate(parsed[:3]):
                            name = episode.get('name', 'Unknown')
                            source = episode.get('source', 'unknown')
                            print(f'     Episode {i + 1}: {name} (source: {source})')

                        return len(parsed) > 0
                except json.JSONDecodeError:
                    # Check if response indicates success
                    if 'episode' in result.lower():
                        print('   ✅ Episode retrieval completed')
                        return True

            print(f'   ❌ Unexpected result format: {result}')
            return False

        except Exception as e:
            print(f'   ❌ Episode retrieval failed: {e}')
            return False

    async def test_error_handling(self) -> dict[str, bool]:
        """Test error handling and edge cases."""
        print('🧪 Testing error handling...')

        results = {}

        # Test with nonexistent group
        print('   Testing nonexistent group handling...')
        try:
            result = await self.call_tool(
                'search_memory_nodes',
                {
                    'query': 'nonexistent data',
                    'group_ids': ['nonexistent_group_12345'],
                    'max_nodes': 5,
                },
            )

            # Should handle gracefully, not crash
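            # Only treat the call as a failure if it reports both an error and an
            # uninitialized server.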
            success = (
                'error' not in str(result).lower() or 'not initialized' not in str(result).lower()
            )
            if success:
                print('   ✅ Nonexistent group handled gracefully')
            else:
                print(f'   ❌ Nonexistent group caused issues: {result}')

            results['nonexistent_group'] = success

        except Exception as e:
            print(f'   ❌ Nonexistent group test failed: {e}')
            results['nonexistent_group'] = False

        # Test empty query
        print('   Testing empty query handling...')
        try:
            result = await self.call_tool(
                'search_memory_nodes',
                {'query': '', 'group_ids': [self.test_group_id], 'max_nodes': 5},
            )

            # Should handle gracefully
            success = (
                'error' not in str(result).lower() or 'not initialized' not in str(result).lower()
            )
            if success:
                print('   ✅ Empty query handled gracefully')
            else:
                print(f'   ❌ Empty query caused issues: {result}')

            results['empty_query'] = success

        except Exception as e:
            print(f'   ❌ Empty query test failed: {e}')
            results['empty_query'] = False

        return results

    async def run_comprehensive_test(self) -> dict[str, Any]:
        """Run the complete integration test suite."""
        print('🚀 Starting Comprehensive Graphiti MCP Server Integration Test')
        print(f'   Test group ID: {self.test_group_id}')
        print('=' * 70)

        results = {
            'server_init': False,
            'add_memory': {},
            'processing_wait': False,
            'search': {},
            'episodes': False,
            'error_handling': {},
            'overall_success': False,
        }

        # Test 1: Server Initialization
        results['server_init'] = await self.test_server_initialization()
        if not results['server_init']:
            print('❌ Server initialization failed, aborting remaining tests')
            return results

        print()

        # Test 2: Add Memory Operations
        results['add_memory'] = await self.test_add_memory_operations()
        print()

        # Test 3: Wait for Processing
        results['processing_wait'] = await self.wait_for_processing()
        print()

        # Test 4: Search Operations
        results['search'] = await self.test_search_operations()
        print()

        # Test 5: Episode Retrieval
        results['episodes'] = await self.test_episode_retrieval()
        print()

        # Test 6: Error Handling
        results['error_handling'] = await self.test_error_handling()
        print()

        # Calculate overall success
        memory_success = any(results['add_memory'].values())
        search_success = any(results['search'].values()) if results['search'] else False
        error_success = (
            any(results['error_handling'].values()) if results['error_handling'] else True
        )

        results['overall_success'] = (
            results['server_init']
            and memory_success
            and (results['episodes'] or results['processing_wait'])
            and error_success
        )

        # Print comprehensive summary
        print('=' * 70)
        print('📊 COMPREHENSIVE TEST SUMMARY')
        print('-' * 35)
        print(f'Server Initialization:    {"✅ PASS" if results["server_init"] else "❌ FAIL"}')

        memory_stats = f'({sum(results["add_memory"].values())}/{len(results["add_memory"])} types)'
        print(
            f'Memory Operations:        {"✅ PASS" if memory_success else "❌ FAIL"} {memory_stats}'
        )

        print(f'Processing Pipeline:      {"✅ PASS" if results["processing_wait"] else "❌ FAIL"}')

        search_stats = (
            f'({sum(results["search"].values())}/{len(results["search"])} types)'
            if results['search']
            else '(0/0 types)'
        )
        print(
            f'Search Operations:        {"✅ PASS" if search_success else "❌ FAIL"} {search_stats}'
        )

        print(f'Episode Retrieval:        {"✅ PASS" if results["episodes"] else "❌ FAIL"}')

        error_stats = (
            f'({sum(results["error_handling"].values())}/{len(results["error_handling"])} cases)'
            if results['error_handling']
            else '(0/0 cases)'
        )
        print(
            f'Error Handling:           {"✅ PASS" if error_success else "❌ FAIL"} {error_stats}'
        )

        print('-' * 35)
        print(f'🎯 OVERALL RESULT: {"✅ SUCCESS" if results["overall_success"] else "❌ FAILED"}')

        if results['overall_success']:
            print('\n🎉 The refactored Graphiti MCP server is working correctly!')
            print('   All core functionality has been successfully tested.')
        else:
            print('\n⚠️  Some issues were detected. Review the test results above.')
            print('   The refactoring may need additional attention.')

        return results


async def main():
    """Run the integration test."""
    try:
        async with GraphitiMCPIntegrationTest() as test:
            results = await test.run_comprehensive_test()

            # Exit with appropriate code
            exit_code = 0 if results['overall_success'] else 1
            exit(exit_code)
    except Exception as e:
        print(f'❌ Test setup failed: {e}')
        exit(1)


if __name__ == '__main__':
    asyncio.run(main())

```

--------------------------------------------------------------------------------
/graphiti_core/llm_client/gemini_client.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import logging
import re
import typing
from typing import TYPE_CHECKING, ClassVar

from pydantic import BaseModel

from ..prompts.models import Message
from .client import LLMClient, get_extraction_language_instruction
from .config import LLMConfig, ModelSize
from .errors import RateLimitError

if TYPE_CHECKING:
    from google import genai
    from google.genai import types
else:
    try:
        from google import genai
        from google.genai import types
    except ImportError:
        # If gemini client is not installed, raise an ImportError
        raise ImportError(
            'google-genai is required for GeminiClient. '
            'Install it with: pip install graphiti-core[google-genai]'
        ) from None


logger = logging.getLogger(__name__)

DEFAULT_MODEL = 'gemini-2.5-flash'
DEFAULT_SMALL_MODEL = 'gemini-2.5-flash-lite'

# Maximum output tokens for different Gemini models
GEMINI_MODEL_MAX_TOKENS = {
    # Gemini 2.5 models
    'gemini-2.5-pro': 65536,
    'gemini-2.5-flash': 65536,
    'gemini-2.5-flash-lite': 64000,
    # Gemini 2.0 models
    'gemini-2.0-flash': 8192,
    'gemini-2.0-flash-lite': 8192,
    # Gemini 1.5 models
    'gemini-1.5-pro': 8192,
    'gemini-1.5-flash': 8192,
    'gemini-1.5-flash-8b': 8192,
}

# Default max tokens for models not in the mapping
DEFAULT_GEMINI_MAX_TOKENS = 8192


class GeminiClient(LLMClient):
    """
    GeminiClient is a client class for interacting with Google's Gemini language models.

    This class extends the LLMClient and provides methods to initialize the client
    and generate responses from the Gemini language model.

    Attributes:
        model (str): The model name to use for generating responses.
        temperature (float): The temperature to use for generating responses.
        max_tokens (int): The maximum number of tokens to generate in a response.
        thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
    Methods:
        __init__(config: LLMConfig | None = None, cache: bool = False, thinking_config: types.ThinkingConfig | None = None):
            Initializes the GeminiClient with the provided configuration, cache setting, and optional thinking config.

        _generate_response(messages: list[Message]) -> dict[str, typing.Any]:
            Generates a response from the language model based on the provided messages.
    """

    # Class-level constants
    MAX_RETRIES: ClassVar[int] = 2

    def __init__(
        self,
        config: LLMConfig | None = None,
        cache: bool = False,
        max_tokens: int | None = None,
        thinking_config: types.ThinkingConfig | None = None,
        client: 'genai.Client | None' = None,
    ):
        """
        Initialize the GeminiClient with the provided configuration, cache setting, and optional thinking config.

        Args:
            config (LLMConfig | None): The configuration for the LLM client, including API key, model, temperature, and max tokens.
            cache (bool): Whether to use caching for responses. Defaults to False.
            thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
                Only use with models that support thinking (gemini-2.5+). Defaults to None.
            client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
        """
        if config is None:
            config = LLMConfig()

        super().__init__(config, cache)

        self.model = config.model

        if client is None:
            self.client = genai.Client(api_key=config.api_key)
        else:
            self.client = client

        self.max_tokens = max_tokens
        self.thinking_config = thinking_config

    def _check_safety_blocks(self, response) -> None:
        """Check if response was blocked for safety reasons and raise appropriate exceptions."""
        # Check if the response was blocked for safety reasons
        if not (hasattr(response, 'candidates') and response.candidates):
            return

        candidate = response.candidates[0]
        if not (hasattr(candidate, 'finish_reason') and candidate.finish_reason == 'SAFETY'):
            return

        # Content was blocked for safety reasons - collect safety details
        safety_info = []
        safety_ratings = getattr(candidate, 'safety_ratings', None)

        if safety_ratings:
            for rating in safety_ratings:
                if getattr(rating, 'blocked', False):
                    category = getattr(rating, 'category', 'Unknown')
                    probability = getattr(rating, 'probability', 'Unknown')
                    safety_info.append(f'{category}: {probability}')

        safety_details = (
            ', '.join(safety_info) if safety_info else 'Content blocked for safety reasons'
        )
        raise Exception(f'Response blocked by Gemini safety filters: {safety_details}')

    def _check_prompt_blocks(self, response) -> None:
        """Check if prompt was blocked and raise appropriate exceptions."""
        prompt_feedback = getattr(response, 'prompt_feedback', None)
        if not prompt_feedback:
            return

        block_reason = getattr(prompt_feedback, 'block_reason', None)
        if block_reason:
            raise Exception(f'Prompt blocked by Gemini: {block_reason}')

    def _get_model_for_size(self, model_size: ModelSize) -> str:
        """Get the appropriate model name based on the requested size."""
        if model_size == ModelSize.small:
            return self.small_model or DEFAULT_SMALL_MODEL
        else:
            return self.model or DEFAULT_MODEL

    def _get_max_tokens_for_model(self, model: str) -> int:
        """Get the maximum output tokens for a specific Gemini model."""
        return GEMINI_MODEL_MAX_TOKENS.get(model, DEFAULT_GEMINI_MAX_TOKENS)

    def _resolve_max_tokens(self, requested_max_tokens: int | None, model: str) -> int:
        """
        Resolve the maximum output tokens to use based on precedence rules.

        Precedence order (highest to lowest):
        1. Explicit max_tokens parameter passed to generate_response()
        2. Instance max_tokens set during client initialization
        3. Model-specific maximum tokens from GEMINI_MODEL_MAX_TOKENS mapping
        4. DEFAULT_GEMINI_MAX_TOKENS as final fallback

        Args:
            requested_max_tokens: The max_tokens parameter passed to generate_response()
            model: The model name to look up model-specific limits

        Returns:
            int: The resolved maximum tokens to use
        """
        # 1. Use explicit parameter if provided
        if requested_max_tokens is not None:
            return requested_max_tokens

        # 2. Use instance max_tokens if set during initialization
        if self.max_tokens is not None:
            return self.max_tokens

        # 3. Use model-specific maximum or return DEFAULT_GEMINI_MAX_TOKENS
        return self._get_max_tokens_for_model(model)

    def salvage_json(self, raw_output: str) -> dict[str, typing.Any] | None:
        """
        Attempt to salvage a JSON object if the raw output is truncated.

        This is accomplished by looking for a closing array or object bracket at the
        end of the raw output (ignoring trailing whitespace). If one is found, the
        output up to that bracket is parsed as JSON; if parsing still fails, None is returned.

        Args:
            raw_output (str): The raw output from the LLM.

        Returns:
            dict[str, typing.Any]: The salvaged JSON object.
            None: If no salvage is possible.
        """
        if not raw_output:
            return None
        # Try to salvage a JSON array
        array_match = re.search(r'\]\s*$', raw_output)
        if array_match:
            try:
                return json.loads(raw_output[: array_match.end()])
            except Exception:
                pass
        # Try to salvage a JSON object
        obj_match = re.search(r'\}\s*$', raw_output)
        if obj_match:
            try:
                return json.loads(raw_output[: obj_match.end()])
            except Exception:
                pass
        return None

    async def _generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int | None = None,
        model_size: ModelSize = ModelSize.medium,
    ) -> dict[str, typing.Any]:
        """
        Generate a response from the Gemini language model.

        Args:
            messages (list[Message]): A list of messages to send to the language model.
            response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
            max_tokens (int | None): The maximum number of tokens to generate in the response. If None, uses precedence rules.
            model_size (ModelSize): The size of the model to use (small or medium).

        Returns:
            dict[str, typing.Any]: The response from the language model.

        Raises:
            RateLimitError: If the API rate limit is exceeded.
            Exception: If there is an error generating the response or content is blocked.
        """
        try:
            gemini_messages: typing.Any = []
            # If a response model is provided, add schema for structured output
            system_prompt = ''
            if response_model is not None:
                # Get the schema from the Pydantic model
                pydantic_schema = response_model.model_json_schema()

                # Create instruction to output in the desired JSON format
                system_prompt += (
                    f'Output ONLY valid JSON matching this schema: {json.dumps(pydantic_schema)}.\n'
                    'Do not include any explanatory text before or after the JSON.\n\n'
                )

            # Add messages content
            # First check for a system message
            if messages and messages[0].role == 'system':
                system_prompt = f'{messages[0].content}\n\n {system_prompt}'
                messages = messages[1:]

            # Add the rest of the messages
            for m in messages:
                m.content = self._clean_input(m.content)
                gemini_messages.append(
                    types.Content(role=m.role, parts=[types.Part.from_text(text=m.content)])
                )

            # Get the appropriate model for the requested size
            model = self._get_model_for_size(model_size)

            # Resolve max_tokens using precedence rules (see _resolve_max_tokens for details)
            resolved_max_tokens = self._resolve_max_tokens(max_tokens, model)

            # Create generation config
            generation_config = types.GenerateContentConfig(
                temperature=self.temperature,
                max_output_tokens=resolved_max_tokens,
                response_mime_type='application/json' if response_model else None,
                response_schema=response_model if response_model else None,
                system_instruction=system_prompt,
                thinking_config=self.thinking_config,
            )

            # Generate content using the simple string approach
            response = await self.client.aio.models.generate_content(
                model=model,
                contents=gemini_messages,
                config=generation_config,
            )

            # Always capture the raw output for debugging
            raw_output = getattr(response, 'text', None)

            # Check for safety and prompt blocks
            self._check_safety_blocks(response)
            self._check_prompt_blocks(response)

            # If this was a structured output request, parse the response into the Pydantic model
            if response_model is not None:
                try:
                    if not raw_output:
                        raise ValueError('No response text')

                    validated_model = response_model.model_validate(json.loads(raw_output))

                    # Return as a dictionary for API consistency
                    return validated_model.model_dump()
                except Exception as e:
                    if raw_output:
                        logger.error(
                            '🦀 LLM generation failed parsing as JSON, will try to salvage.'
                        )
                        logger.error(self._get_failed_generation_log(gemini_messages, raw_output))
                        # Try to salvage
                        salvaged = self.salvage_json(raw_output)
                        if salvaged is not None:
                            logger.warning('Salvaged partial JSON from truncated/malformed output.')
                            return salvaged
                    raise Exception(f'Failed to parse structured response: {e}') from e

            # Otherwise, return the response text as a dictionary
            return {'content': raw_output}

        except Exception as e:
            # Check if it's a rate limit error based on Gemini API error codes
            error_message = str(e).lower()
            if (
                'rate limit' in error_message
                or 'quota' in error_message
                or 'resource_exhausted' in error_message
                or '429' in str(e)
            ):
                raise RateLimitError from e

            logger.error(f'Error in generating LLM response: {e}')
            raise Exception from e

    async def generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int | None = None,
        model_size: ModelSize = ModelSize.medium,
        group_id: str | None = None,
        prompt_name: str | None = None,
    ) -> dict[str, typing.Any]:
        """
        Generate a response from the Gemini language model with retry logic and error handling.
        This method overrides the parent class method to provide a direct implementation with advanced retry logic.

        Args:
            messages (list[Message]): A list of messages to send to the language model.
            response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
            max_tokens (int | None): The maximum number of tokens to generate in the response.
            model_size (ModelSize): The size of the model to use (small or medium).
            group_id (str | None): Optional partition identifier for the graph.
            prompt_name (str | None): Optional name of the prompt for tracing.

        Returns:
            dict[str, typing.Any]: The response from the language model.
        """
        # Add multilingual extraction instructions
        messages[0].content += get_extraction_language_instruction(group_id)

        # Wrap entire operation in tracing span
        with self.tracer.start_span('llm.generate') as span:
            attributes = {
                'llm.provider': 'gemini',
                'model.size': model_size.value,
                'max_tokens': max_tokens or self.max_tokens,
            }
            if prompt_name:
                attributes['prompt.name'] = prompt_name
            span.add_attributes(attributes)

            retry_count = 0
            last_error = None
            last_output = None

            while retry_count < self.MAX_RETRIES:
                try:
                    response = await self._generate_response(
                        messages=messages,
                        response_model=response_model,
                        max_tokens=max_tokens,
                        model_size=model_size,
                    )
                    last_output = (
                        response.get('content')
                        if isinstance(response, dict) and 'content' in response
                        else None
                    )
                    return response
                except RateLimitError as e:
                    # Rate limit errors should not trigger retries (fail fast)
                    span.set_status('error', str(e))
                    raise e
                except Exception as e:
                    last_error = e

                    # Check if this is a safety block - these typically shouldn't be retried
                    error_text = str(e) or (str(e.__cause__) if e.__cause__ else '')
                    if 'safety' in error_text.lower() or 'blocked' in error_text.lower():
                        logger.warning(f'Content blocked by safety filters: {e}')
                        span.set_status('error', str(e))
                        raise Exception(f'Content blocked by safety filters: {e}') from e

                    retry_count += 1

                    # Construct a detailed error message for the LLM
                    error_context = (
                        f'The previous response attempt was invalid. '
                        f'Error type: {e.__class__.__name__}. '
                        f'Error details: {str(e)}. '
                        f'Please try again with a valid response, ensuring the output matches '
                        f'the expected format and constraints.'
                    )

                    error_message = Message(role='user', content=error_context)
                    messages.append(error_message)
                    logger.warning(
                        f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
                    )

            # If we exit the loop without returning, all retries are exhausted
            logger.error('🦀 LLM generation failed and retries are exhausted.')
            logger.error(self._get_failed_generation_log(messages, last_output))
            logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {last_error}')
            span.set_status('error', str(last_error))
            if last_error:
                span.record_exception(last_error)
            raise last_error or Exception('Max retries exceeded')

```
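
A rough usage sketch for the client above (illustrative, not from the repository): it constructs the client with an `LLMConfig` and requests structured output via a Pydantic model. The `LLMConfig` keyword names (`api_key`, `model`, `temperature`) are assumptions based on how `GeminiClient` reads its config.

```python
# Illustrative sketch; LLMConfig keyword names (api_key, model, temperature)
# are assumed from how GeminiClient reads its config.
import asyncio

from pydantic import BaseModel

from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.gemini_client import GeminiClient
from graphiti_core.prompts.models import Message


class ExtractedFacts(BaseModel):
    facts: list[str]


async def demo() -> None:
    client = GeminiClient(
        config=LLMConfig(api_key='GEMINI_API_KEY_HERE', model='gemini-2.5-flash', temperature=0.0)
    )
    # Structured output: the response is validated against ExtractedFacts and
    # returned as a plain dict via model_dump().
    response = await client.generate_response(
        messages=[
            Message(role='system', content='Extract short factual statements.'),
            Message(role='user', content='Acme Corp launched a new AI product in 2024.'),
        ],
        response_model=ExtractedFacts,
    )
    print(response['facts'])


if __name__ == '__main__':
    asyncio.run(demo())
```
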
Page 5/9