#
tokens: 45869/50000 9/236 files (page 6/10)
lines: off (toggle) GitHub
raw markdown copy
This is page 6 of 10. Use http://codebase.md/getzep/graphiti?page={x} to view the full context.

# Directory Structure

```
├── .env.example
├── .github
│   ├── dependabot.yml
│   ├── ISSUE_TEMPLATE
│   │   └── bug_report.md
│   ├── pull_request_template.md
│   ├── secret_scanning.yml
│   └── workflows
│       ├── ai-moderator.yml
│       ├── cla.yml
│       ├── claude-code-review-manual.yml
│       ├── claude-code-review.yml
│       ├── claude.yml
│       ├── codeql.yml
│       ├── lint.yml
│       ├── release-graphiti-core.yml
│       ├── release-mcp-server.yml
│       ├── release-server-container.yml
│       ├── typecheck.yml
│       └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│   ├── azure-openai
│   │   ├── .env.example
│   │   ├── azure_openai_neo4j.py
│   │   └── README.md
│   ├── data
│   │   └── manybirds_products.json
│   ├── ecommerce
│   │   ├── runner.ipynb
│   │   └── runner.py
│   ├── langgraph-agent
│   │   ├── agent.ipynb
│   │   └── tinybirds-jess.png
│   ├── opentelemetry
│   │   ├── .env.example
│   │   ├── otel_stdout_example.py
│   │   ├── pyproject.toml
│   │   ├── README.md
│   │   └── uv.lock
│   ├── podcast
│   │   ├── podcast_runner.py
│   │   ├── podcast_transcript.txt
│   │   └── transcript_parser.py
│   ├── quickstart
│   │   ├── dense_vs_normal_ingestion.py
│   │   ├── quickstart_falkordb.py
│   │   ├── quickstart_neo4j.py
│   │   ├── quickstart_neptune.py
│   │   ├── README.md
│   │   └── requirements.txt
│   └── wizard_of_oz
│       ├── parser.py
│       ├── runner.py
│       └── woo.txt
├── graphiti_core
│   ├── __init__.py
│   ├── cross_encoder
│   │   ├── __init__.py
│   │   ├── bge_reranker_client.py
│   │   ├── client.py
│   │   ├── gemini_reranker_client.py
│   │   └── openai_reranker_client.py
│   ├── decorators.py
│   ├── driver
│   │   ├── __init__.py
│   │   ├── driver.py
│   │   ├── falkordb_driver.py
│   │   ├── graph_operations
│   │   │   └── graph_operations.py
│   │   ├── kuzu_driver.py
│   │   ├── neo4j_driver.py
│   │   ├── neptune_driver.py
│   │   └── search_interface
│   │       └── search_interface.py
│   ├── edges.py
│   ├── embedder
│   │   ├── __init__.py
│   │   ├── azure_openai.py
│   │   ├── client.py
│   │   ├── gemini.py
│   │   ├── openai.py
│   │   └── voyage.py
│   ├── errors.py
│   ├── graph_queries.py
│   ├── graphiti_types.py
│   ├── graphiti.py
│   ├── helpers.py
│   ├── llm_client
│   │   ├── __init__.py
│   │   ├── anthropic_client.py
│   │   ├── azure_openai_client.py
│   │   ├── client.py
│   │   ├── config.py
│   │   ├── errors.py
│   │   ├── gemini_client.py
│   │   ├── groq_client.py
│   │   ├── openai_base_client.py
│   │   ├── openai_client.py
│   │   ├── openai_generic_client.py
│   │   └── utils.py
│   ├── migrations
│   │   └── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── edges
│   │   │   ├── __init__.py
│   │   │   └── edge_db_queries.py
│   │   └── nodes
│   │       ├── __init__.py
│   │       └── node_db_queries.py
│   ├── nodes.py
│   ├── prompts
│   │   ├── __init__.py
│   │   ├── dedupe_edges.py
│   │   ├── dedupe_nodes.py
│   │   ├── eval.py
│   │   ├── extract_edge_dates.py
│   │   ├── extract_edges.py
│   │   ├── extract_nodes.py
│   │   ├── invalidate_edges.py
│   │   ├── lib.py
│   │   ├── models.py
│   │   ├── prompt_helpers.py
│   │   ├── snippets.py
│   │   └── summarize_nodes.py
│   ├── py.typed
│   ├── search
│   │   ├── __init__.py
│   │   ├── search_config_recipes.py
│   │   ├── search_config.py
│   │   ├── search_filters.py
│   │   ├── search_helpers.py
│   │   ├── search_utils.py
│   │   └── search.py
│   ├── telemetry
│   │   ├── __init__.py
│   │   └── telemetry.py
│   ├── tracer.py
│   └── utils
│       ├── __init__.py
│       ├── bulk_utils.py
│       ├── content_chunking.py
│       ├── datetime_utils.py
│       ├── maintenance
│       │   ├── __init__.py
│       │   ├── community_operations.py
│       │   ├── dedup_helpers.py
│       │   ├── edge_operations.py
│       │   ├── graph_data_operations.py
│       │   ├── node_operations.py
│       │   └── temporal_operations.py
│       ├── ontology_utils
│       │   └── entity_types_utils.py
│       └── text_utils.py
├── images
│   ├── arxiv-screenshot.png
│   ├── graphiti-graph-intro.gif
│   ├── graphiti-intro-slides-stock-2.gif
│   └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│   ├── .env.example
│   ├── .python-version
│   ├── config
│   │   ├── config-docker-falkordb-combined.yaml
│   │   ├── config-docker-falkordb.yaml
│   │   ├── config-docker-neo4j.yaml
│   │   ├── config.yaml
│   │   └── mcp_config_stdio_example.json
│   ├── docker
│   │   ├── build-standalone.sh
│   │   ├── build-with-version.sh
│   │   ├── docker-compose-falkordb.yml
│   │   ├── docker-compose-neo4j.yml
│   │   ├── docker-compose.yml
│   │   ├── Dockerfile
│   │   ├── Dockerfile.standalone
│   │   ├── github-actions-example.yml
│   │   ├── README-falkordb-combined.md
│   │   └── README.md
│   ├── docs
│   │   └── cursor_rules.md
│   ├── main.py
│   ├── pyproject.toml
│   ├── pytest.ini
│   ├── README.md
│   ├── src
│   │   ├── __init__.py
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   └── schema.py
│   │   ├── graphiti_mcp_server.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── entity_types.py
│   │   │   └── response_types.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   ├── factories.py
│   │   │   └── queue_service.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── formatting.py
│   │       └── utils.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── conftest.py
│   │   ├── pytest.ini
│   │   ├── README.md
│   │   ├── run_tests.py
│   │   ├── test_async_operations.py
│   │   ├── test_comprehensive_integration.py
│   │   ├── test_configuration.py
│   │   ├── test_falkordb_integration.py
│   │   ├── test_fixtures.py
│   │   ├── test_http_integration.py
│   │   ├── test_integration.py
│   │   ├── test_mcp_integration.py
│   │   ├── test_mcp_transports.py
│   │   ├── test_stdio_simple.py
│   │   └── test_stress_load.py
│   └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│   ├── .env.example
│   ├── graph_service
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── dto
│   │   │   ├── __init__.py
│   │   │   ├── common.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   ├── main.py
│   │   ├── routers
│   │   │   ├── __init__.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   └── zep_graphiti.py
│   ├── Makefile
│   ├── pyproject.toml
│   ├── README.md
│   └── uv.lock
├── signatures
│   └── version1
│       └── cla.json
├── tests
│   ├── cross_encoder
│   │   ├── test_bge_reranker_client_int.py
│   │   └── test_gemini_reranker_client.py
│   ├── driver
│   │   ├── __init__.py
│   │   └── test_falkordb_driver.py
│   ├── embedder
│   │   ├── embedder_fixtures.py
│   │   ├── test_gemini.py
│   │   ├── test_openai.py
│   │   └── test_voyage.py
│   ├── evals
│   │   ├── data
│   │   │   └── longmemeval_data
│   │   │       ├── longmemeval_oracle.json
│   │   │       └── README.md
│   │   ├── eval_cli.py
│   │   ├── eval_e2e_graph_building.py
│   │   ├── pytest.ini
│   │   └── utils.py
│   ├── helpers_test.py
│   ├── llm_client
│   │   ├── test_anthropic_client_int.py
│   │   ├── test_anthropic_client.py
│   │   ├── test_azure_openai_client.py
│   │   ├── test_client.py
│   │   ├── test_errors.py
│   │   └── test_gemini_client.py
│   ├── test_edge_int.py
│   ├── test_entity_exclusion_int.py
│   ├── test_graphiti_int.py
│   ├── test_graphiti_mock.py
│   ├── test_node_int.py
│   ├── test_text_utils.py
│   └── utils
│       ├── maintenance
│       │   ├── test_bulk_utils.py
│       │   ├── test_edge_operations.py
│       │   ├── test_entity_extraction.py
│       │   ├── test_node_operations.py
│       │   └── test_temporal_operations_int.py
│       ├── search
│       │   └── search_utils_test.py
│       └── test_content_chunking.py
├── uv.lock
└── Zep-CLA.md
```

# Files

--------------------------------------------------------------------------------
/tests/utils/maintenance/test_edge_operations.py:
--------------------------------------------------------------------------------

```python
from datetime import datetime, timedelta, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

import pytest
from pydantic import BaseModel

from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.search.search_config import SearchResults
from graphiti_core.utils.maintenance.edge_operations import (
    DEFAULT_EDGE_NAME,
    resolve_extracted_edge,
    resolve_extracted_edges,
)


@pytest.fixture
def mock_llm_client():
    """Provide a stand-in LLM client whose ``generate_response`` can be awaited."""
    return MagicMock(generate_response=AsyncMock())


@pytest.fixture
def mock_extracted_edge():
    """Provide a freshly extracted edge with no validity window set."""
    edge_fields = {
        'source_node_uuid': 'source_uuid',
        'target_node_uuid': 'target_uuid',
        'name': 'test_edge',
        'group_id': 'group_1',
        'fact': 'Test fact',
        'episodes': ['episode_1'],
        'created_at': datetime.now(timezone.utc),
        'valid_at': None,
        'invalid_at': None,
    }
    return EntityEdge(**edge_fields)


@pytest.fixture
def mock_related_edges():
    """Provide a one-element list holding an edge created and valid since a day ago."""
    one_day = timedelta(days=1)
    related = EntityEdge(
        source_node_uuid='source_uuid_2',
        target_node_uuid='target_uuid_2',
        name='related_edge',
        group_id='group_1',
        fact='Related fact',
        episodes=['episode_2'],
        created_at=datetime.now(timezone.utc) - one_day,
        valid_at=datetime.now(timezone.utc) - one_day,
        invalid_at=None,
    )
    return [related]


@pytest.fixture
def mock_existing_edges():
    """Provide a one-element list holding an edge created and valid since two days ago."""
    two_days = timedelta(days=2)
    existing = EntityEdge(
        source_node_uuid='source_uuid_3',
        target_node_uuid='target_uuid_3',
        name='existing_edge',
        group_id='group_1',
        fact='Existing fact',
        episodes=['episode_3'],
        created_at=datetime.now(timezone.utc) - two_days,
        valid_at=datetime.now(timezone.utc) - two_days,
        invalid_at=None,
    )
    return [existing]


@pytest.fixture
def mock_current_episode():
    """Provide the episode (uuid ``episode_1``) that is treated as currently being ingested."""
    episode_fields = {
        'uuid': 'episode_1',
        'content': 'Current episode content',
        'valid_at': datetime.now(timezone.utc),
        'name': 'Current Episode',
        'group_id': 'group_1',
        'source': 'message',
        'source_description': 'Test source description',
    }
    return EpisodicNode(**episode_fields)


@pytest.fixture
def mock_previous_episodes():
    """Provide a one-element list holding an episode (uuid ``episode_2``) dated one day ago."""
    yesterday = datetime.now(timezone.utc) - timedelta(days=1)
    earlier_episode = EpisodicNode(
        uuid='episode_2',
        content='Previous episode content',
        valid_at=yesterday,
        name='Previous Episode',
        group_id='group_1',
        source='message',
        source_description='Test source description',
    )
    return [earlier_episode]


# Run the tests when this module is executed directly as a script.
# pytest.main() returns the session's exit code; pass it to the interpreter so
# a failing run does not terminate with status 0.
if __name__ == '__main__':
    raise SystemExit(pytest.main([__file__]))


@pytest.mark.asyncio
async def test_resolve_extracted_edge_exact_fact_short_circuit(
    mock_llm_client,
    mock_existing_edges,
    mock_current_episode,
):
    """A fact equal to a related edge's fact, ignoring case and outer whitespace, skips the LLM.

    The extracted edge and the single related edge share endpoints, and their
    facts differ only in letter case and surrounding whitespace. The test
    expects the related edge object itself to be returned, the current episode
    to appear in its episode list exactly once, both auxiliary lists to be
    empty, and ``generate_response`` never to be called.
    """
    extracted = EntityEdge(
        source_node_uuid='source_uuid',
        target_node_uuid='target_uuid',
        name='test_edge',
        group_id='group_1',
        fact='Related fact',
        episodes=['episode_1'],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    related_edges = [
        EntityEdge(
            source_node_uuid='source_uuid',
            target_node_uuid='target_uuid',
            name='related_edge',
            group_id='group_1',
            fact=' related FACT  ',
            episodes=['episode_2'],
            created_at=datetime.now(timezone.utc) - timedelta(days=1),
            valid_at=None,
            invalid_at=None,
        )
    ]

    # NOTE(review): names follow the unpacking order used by the
    # integer-indices duplicate test in this module (resolved, invalidated,
    # duplicates) - confirm against resolve_extracted_edge's return statement.
    resolved_edge, invalidated, duplicates = await resolve_extracted_edge(
        mock_llm_client,
        extracted,
        related_edges,
        mock_existing_edges,
        mock_current_episode,
        edge_type_candidates=None,
    )

    assert resolved_edge is related_edges[0]
    assert resolved_edge.episodes.count(mock_current_episode.uuid) == 1
    assert invalidated == []
    assert duplicates == []
    mock_llm_client.generate_response.assert_not_called()


class OccurredAtEdge(BaseModel):
    """Edge model stub for OCCURRED_AT.

    Declares no fields; the tests in this module only use it as the model
    registered under the 'OCCURRED_AT' edge type name.
    """


@pytest.mark.asyncio
async def test_resolve_extracted_edges_resets_unmapped_names(monkeypatch):
    """A custom edge name that is not mapped for the node labels is reset to the default.

    'OCCURRED_AT' is registered as a custom edge type, but the type map only
    allows it between ('Event', 'Entity') while the nodes here are labelled
    'Document' and 'Topic'. The resolved edge is expected to carry
    ``DEFAULT_EDGE_NAME`` and nothing should be invalidated.
    """
    from graphiti_core.utils.maintenance import edge_operations as edge_ops

    async def immediate_gather(*awaitables, max_coroutines=None):
        # Await each item in turn instead of scheduling them concurrently.
        outcomes = []
        for awaitable in awaitables:
            outcomes.append(await awaitable)
        return outcomes

    # Replace embedding creation, graph lookups, gathering and search with stubs.
    monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
    monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
    monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather)
    monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))

    llm_reply = {
        'duplicate_facts': [],
        'contradicted_facts': [],
        'fact_type': 'DEFAULT',
    }
    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(return_value=llm_reply)

    clients = SimpleNamespace(
        driver=MagicMock(),
        llm_client=llm_client,
        embedder=MagicMock(),
        cross_encoder=MagicMock(),
    )

    source_node = EntityNode(
        uuid='source_uuid', name='Document Node', group_id='group_1', labels=['Document']
    )
    target_node = EntityNode(
        uuid='target_uuid', name='Topic Node', group_id='group_1', labels=['Topic']
    )

    extracted_edge = EntityEdge(
        source_node_uuid=source_node.uuid,
        target_node_uuid=target_node.uuid,
        name='OCCURRED_AT',
        group_id='group_1',
        fact='Document occurred at somewhere',
        episodes=[],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    episode = EpisodicNode(
        uuid='episode_uuid',
        name='Episode',
        group_id='group_1',
        source='message',
        source_description='desc',
        content='Episode content',
        valid_at=datetime.now(timezone.utc),
    )

    custom_types = {'OCCURRED_AT': OccurredAtEdge}
    allowed_by_labels = {('Event', 'Entity'): ['OCCURRED_AT']}

    resolved_edges, invalidated_edges = await resolve_extracted_edges(
        clients,
        [extracted_edge],
        episode,
        [source_node, target_node],
        custom_types,
        allowed_by_labels,
    )

    first_resolved = resolved_edges[0]
    assert first_resolved.name == DEFAULT_EDGE_NAME
    assert invalidated_edges == []


@pytest.mark.asyncio
async def test_resolve_extracted_edges_keeps_unknown_names(monkeypatch):
    """An edge name that is not a registered custom type is left untouched.

    Only 'OCCURRED_AT' is registered as a custom edge type; the extracted edge
    is named 'INTERACTED_WITH'. The resolved edge is expected to keep that name
    and nothing should be invalidated.
    """
    from graphiti_core.utils.maintenance import edge_operations as edge_ops

    async def immediate_gather(*awaitables, max_coroutines=None):
        # Await each item in turn instead of scheduling them concurrently.
        outcomes = []
        for awaitable in awaitables:
            outcomes.append(await awaitable)
        return outcomes

    # Replace embedding creation, graph lookups, gathering and search with stubs.
    monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
    monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
    monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather)
    monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))

    llm_reply = {
        'duplicate_facts': [],
        'contradicted_facts': [],
        'fact_type': 'DEFAULT',
    }
    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(return_value=llm_reply)

    clients = SimpleNamespace(
        driver=MagicMock(),
        llm_client=llm_client,
        embedder=MagicMock(),
        cross_encoder=MagicMock(),
    )

    source_node = EntityNode(
        uuid='source_uuid', name='User Node', group_id='group_1', labels=['User']
    )
    target_node = EntityNode(
        uuid='target_uuid', name='Topic Node', group_id='group_1', labels=['Topic']
    )

    extracted_edge = EntityEdge(
        source_node_uuid=source_node.uuid,
        target_node_uuid=target_node.uuid,
        name='INTERACTED_WITH',
        group_id='group_1',
        fact='User interacted with topic',
        episodes=[],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    episode = EpisodicNode(
        uuid='episode_uuid',
        name='Episode',
        group_id='group_1',
        source='message',
        source_description='desc',
        content='Episode content',
        valid_at=datetime.now(timezone.utc),
    )

    custom_types = {'OCCURRED_AT': OccurredAtEdge}
    allowed_by_labels = {('Event', 'Entity'): ['OCCURRED_AT']}

    resolved_edges, invalidated_edges = await resolve_extracted_edges(
        clients,
        [extracted_edge],
        episode,
        [source_node, target_node],
        custom_types,
        allowed_by_labels,
    )

    first_resolved = resolved_edges[0]
    assert first_resolved.name == 'INTERACTED_WITH'
    assert invalidated_edges == []


@pytest.mark.asyncio
async def test_resolve_extracted_edge_rejects_unmapped_fact_type(mock_llm_client):
    """A custom fact type returned by the LLM is rejected when it is not a candidate.

    The mocked LLM answers with fact type 'OCCURRED_AT', which is listed in
    ``custom_edge_type_names`` but absent from the (empty)
    ``edge_type_candidates``. The resolved edge is expected to fall back to
    ``DEFAULT_EDGE_NAME`` with no duplicates and no invalidations.
    """
    mock_llm_client.generate_response.return_value = {
        'duplicate_facts': [],
        'contradicted_facts': [],
        'fact_type': 'OCCURRED_AT',
    }

    extracted_edge = EntityEdge(
        source_node_uuid='source_uuid',
        target_node_uuid='target_uuid',
        name='OCCURRED_AT',
        group_id='group_1',
        fact='Document occurred at somewhere',
        episodes=[],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    episode = EpisodicNode(
        uuid='episode_uuid',
        name='Episode',
        group_id='group_1',
        source='message',
        source_description='desc',
        content='Episode content',
        valid_at=datetime.now(timezone.utc),
    )

    # An unrelated edge (different endpoints and fact) passed as the only related edge.
    related_edge = EntityEdge(
        source_node_uuid='alt_source',
        target_node_uuid='alt_target',
        name='OTHER',
        group_id='group_1',
        fact='Different fact',
        episodes=[],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    # NOTE(review): names follow the unpacking order used by the
    # integer-indices duplicate test in this module (resolved, invalidated,
    # duplicates) - confirm against resolve_extracted_edge's return statement.
    resolved_edge, invalidated, duplicates = await resolve_extracted_edge(
        mock_llm_client,
        extracted_edge,
        [related_edge],
        [],
        episode,
        edge_type_candidates={},
        custom_edge_type_names={'OCCURRED_AT'},
    )

    assert resolved_edge.name == DEFAULT_EDGE_NAME
    assert invalidated == []
    assert duplicates == []


@pytest.mark.asyncio
async def test_resolve_extracted_edge_accepts_unknown_fact_type(mock_llm_client):
    """A fact type that is not a registered custom type is adopted as the edge name.

    The mocked LLM answers with 'INTERACTED_WITH' while only 'OCCURRED_AT' is a
    custom type. The resolved edge is expected to take the LLM's name, carry
    empty attributes, and produce no duplicates or invalidations.
    """
    mock_llm_client.generate_response.return_value = {
        'duplicate_facts': [],
        'contradicted_facts': [],
        'fact_type': 'INTERACTED_WITH',
    }

    extracted_edge = EntityEdge(
        source_node_uuid='source_uuid',
        target_node_uuid='target_uuid',
        name='DEFAULT',
        group_id='group_1',
        fact='User interacted with topic',
        episodes=[],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    episode = EpisodicNode(
        uuid='episode_uuid',
        name='Episode',
        group_id='group_1',
        source='message',
        source_description='desc',
        content='Episode content',
        valid_at=datetime.now(timezone.utc),
    )

    # Same endpoints as the extracted edge but a different fact.
    related_edge = EntityEdge(
        source_node_uuid='source_uuid',
        target_node_uuid='target_uuid',
        name='DEFAULT',
        group_id='group_1',
        fact='User mentioned a topic',
        episodes=[],
        created_at=datetime.now(timezone.utc),
        valid_at=None,
        invalid_at=None,
    )

    # NOTE(review): names follow the unpacking order used by the
    # integer-indices duplicate test in this module (resolved, invalidated,
    # duplicates) - confirm against resolve_extracted_edge's return statement.
    resolved_edge, invalidated, duplicates = await resolve_extracted_edge(
        mock_llm_client,
        extracted_edge,
        [related_edge],
        [],
        episode,
        edge_type_candidates={'OCCURRED_AT': OccurredAtEdge},
        custom_edge_type_names={'OCCURRED_AT'},
    )

    assert resolved_edge.name == 'INTERACTED_WITH'
    assert resolved_edge.attributes == {}
    assert invalidated == []
    assert duplicates == []


@pytest.mark.asyncio
async def test_resolve_extracted_edge_uses_integer_indices_for_duplicates(mock_llm_client):
    """Integer positions in the LLM's ``duplicate_facts`` select the matching related edges.

    The mocked LLM reports indices 0 and 1 as duplicates of the extracted fact.
    The test expects exactly the first two related edges to come back as
    duplicates, nothing to be invalidated, and the resolved edge to be the
    first duplicate with the current episode recorded on it.
    """

    def yoga_edge(fact, episodes, age_days=None):
        # All edges in this test share endpoints, name and group; only the fact,
        # the episode list and the creation time differ.
        created = datetime.now(timezone.utc)
        if age_days is not None:
            created = created - timedelta(days=age_days)
        return EntityEdge(
            source_node_uuid='source_uuid',
            target_node_uuid='target_uuid',
            name='test_edge',
            group_id='group_1',
            fact=fact,
            episodes=episodes,
            created_at=created,
            valid_at=None,
            invalid_at=None,
        )

    # The LLM flags the first two related edges (by position) as duplicates.
    mock_llm_client.generate_response.return_value = {
        'duplicate_facts': [0, 1],
        'contradicted_facts': [],
        'fact_type': 'DEFAULT',
    }

    extracted_edge = yoga_edge('User likes yoga', [])

    episode = EpisodicNode(
        uuid='episode_uuid',
        name='Episode',
        group_id='group_1',
        source='message',
        source_description='desc',
        content='Episode content',
        valid_at=datetime.now(timezone.utc),
    )

    # Three candidates; the list position is what the LLM indices refer to.
    related_edge_0 = yoga_edge('User enjoys yoga', ['episode_1'], age_days=1)
    related_edge_1 = yoga_edge('User practices yoga', ['episode_2'], age_days=2)
    related_edge_2 = yoga_edge('User loves swimming', ['episode_3'], age_days=3)
    related_edges = [related_edge_0, related_edge_1, related_edge_2]

    resolved_edge, invalidated, duplicates = await resolve_extracted_edge(
        mock_llm_client,
        extracted_edge,
        related_edges,
        [],
        episode,
        edge_type_candidates=None,
        custom_edge_type_names=set(),
    )

    # The LLM must have been consulted exactly once.
    mock_llm_client.generate_response.assert_called_once()

    # Indices [0, 1] map to the first two related edges and nothing else.
    assert len(duplicates) == 2
    assert related_edge_0 in duplicates
    assert related_edge_1 in duplicates
    assert invalidated == []

    # The first duplicate becomes the resolved edge; compare by UUID because
    # its episode list is modified during resolution.
    assert resolved_edge.uuid == related_edge_0.uuid
    assert episode.uuid in resolved_edge.episodes


@pytest.mark.asyncio
async def test_resolve_extracted_edges_fast_path_deduplication(monkeypatch):
    """Edges whose facts match up to case and outer whitespace are resolved only once.

    Three edges with the same endpoints and equivalent facts are submitted.
    ``resolve_extracted_edge`` is replaced with a recording stub, and the test
    expects it to be invoked a single time and a single edge to be returned.
    """
    from graphiti_core.utils.maintenance import edge_operations as edge_ops

    # One entry is appended for every call to the stubbed resolver.
    resolver_calls = []

    async def mock_resolve_extracted_edge(
        llm_client,
        extracted_edge,
        related_edges,
        existing_edges,
        episode,
        edge_type_candidates=None,
        custom_edge_type_names=None,
    ):
        resolver_calls.append(extracted_edge)
        return extracted_edge, [], []

    async def immediate_gather(*aws, max_coroutines=None):
        # Await each item in turn instead of scheduling them concurrently.
        return [await aw for aw in aws]

    monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
    monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
    monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather)
    monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))
    monkeypatch.setattr(edge_ops, 'resolve_extracted_edge', mock_resolve_extracted_edge)

    clients = SimpleNamespace(
        driver=MagicMock(),
        llm_client=MagicMock(),
        embedder=MagicMock(),
        cross_encoder=MagicMock(),
    )

    source_node = EntityNode(
        uuid='source_uuid', name='Assistant', group_id='group_1', labels=['Entity']
    )
    target_node = EntityNode(
        uuid='target_uuid', name='User', group_id='group_1', labels=['Entity']
    )

    def recommends_edge(fact):
        # Same endpoints, name and group for every edge; only the fact text varies.
        return EntityEdge(
            source_node_uuid=source_node.uuid,
            target_node_uuid=target_node.uuid,
            name='recommends',
            group_id='group_1',
            fact=fact,
            episodes=[],
            created_at=datetime.now(timezone.utc),
            valid_at=None,
            invalid_at=None,
        )

    # Three equivalent edges; the second differs only in case and outer whitespace.
    edge1 = recommends_edge('assistant recommends yoga poses')
    edge2 = recommends_edge('  Assistant Recommends YOGA Poses  ')
    edge3 = recommends_edge('assistant recommends yoga poses')

    episode = EpisodicNode(
        uuid='episode_uuid',
        name='Episode',
        group_id='group_1',
        source='message',
        source_description='desc',
        content='Episode content',
        valid_at=datetime.now(timezone.utc),
    )

    resolved_edges, invalidated_edges = await resolve_extracted_edges(
        clients,
        [edge1, edge2, edge3],
        episode,
        [source_node, target_node],
        {},
        {},
    )

    # The three inputs collapse to one before resolution, so the resolver
    # stub runs exactly once and a single edge is returned.
    assert len(resolver_calls) == 1
    assert len(resolved_edges) == 1
    assert invalidated_edges == []

```

--------------------------------------------------------------------------------
/mcp_server/tests/test_stress_load.py:
--------------------------------------------------------------------------------

```python
#!/usr/bin/env python3
"""
Stress and load testing for Graphiti MCP Server.
Tests system behavior under high load, resource constraints, and edge conditions.
"""

import asyncio
import gc
import random
import time
from dataclasses import dataclass

import psutil
import pytest
from test_fixtures import TestDataGenerator, graphiti_test_client


@dataclass
class LoadTestConfig:
    """Configuration for load testing scenarios."""

    # Number of simulated clients; also the divisor used to stagger client start times.
    num_clients: int = 10
    # Upper bound on the operations each client attempts.
    operations_per_client: int = 100
    # Window over which client starts are spread: client i waits
    # (i / num_clients) * ramp_up_time before its first operation.
    ramp_up_time: float = 5.0  # seconds
    # A client stops issuing operations once this much time has passed since
    # the tester's start_time (checked after each operation).
    test_duration: float = 60.0  # seconds
    # NOTE(review): not read by LoadTester in the visible code - confirm usage elsewhere.
    target_throughput: float | None = None  # ops/sec
    # Pause after every operation, successful or not.
    think_time: float = 0.1  # seconds between ops


@dataclass
class LoadTestResult:
    """Results from a load test run."""

    # Operation counts: total = successful + failed.
    total_operations: int
    successful_operations: int
    failed_operations: int
    # Span from the earliest operation start to the latest operation end, in seconds.
    duration: float
    # total_operations / duration (0 when duration is not positive).
    throughput: float
    # Latency statistics in seconds, computed over all operations (failures included).
    average_latency: float
    p50_latency: float
    p95_latency: float
    p99_latency: float
    max_latency: float
    # Failure counts keyed by 'timeout' or the exception class name.
    errors: dict[str, int]
    # Process snapshot with keys 'cpu_percent', 'memory_mb' and 'num_threads'
    # (empty when no operations were recorded).
    resource_usage: dict[str, float]


class LoadTester:
    """Orchestrate load testing scenarios.

    Each simulated client runs ``run_client_workload``, which appends one
    ``(start, duration, success)`` tuple to ``metrics`` per operation and counts
    failures in ``errors``. ``calculate_results`` then aggregates those records.
    Callers are expected to set ``start_time`` (a ``time.time()`` value) before
    starting the workloads if the ``test_duration`` cut-off should apply.
    """

    def __init__(self, config: LoadTestConfig):
        self.config = config
        self.metrics: list[tuple[float, float, bool]] = []  # (start, duration, success)
        self.errors: dict[str, int] = {}
        self.start_time: float | None = None
        # psutil measures cpu_percent() relative to the previous call on the same
        # Process object, and the first call always yields a meaningless 0.0.
        # Create the handle once and prime it here so calculate_results() reports
        # CPU usage accumulated since this tester was constructed.
        self._process = psutil.Process()
        self._process.cpu_percent()

    async def run_client_workload(self, client_id: int, session, group_id: str) -> dict[str, int]:
        """Run workload for a single simulated client.

        Args:
            client_id: Index of the client; determines its ramp-up delay and is
                embedded in generated episode names.
            session: Object exposing an awaitable ``call_tool(name, args)``.
            group_id: Group identifier passed with every tool call.

        Returns:
            A dict with ``'success'`` and ``'failure'`` counts for this client.
        """
        stats = {'success': 0, 'failure': 0}
        data_gen = TestDataGenerator()

        # Ramp-up delay: stagger client starts evenly across ramp_up_time.
        ramp_delay = (client_id / self.config.num_clients) * self.config.ramp_up_time
        await asyncio.sleep(ramp_delay)

        for op_num in range(self.config.operations_per_client):
            operation_start = time.time()

            try:
                # Randomly select operation type
                operation = random.choice(
                    [
                        'add_memory',
                        'search_memory_nodes',
                        'get_episodes',
                    ]
                )

                if operation == 'add_memory':
                    args = {
                        'name': f'Load Test {client_id}-{op_num}',
                        'episode_body': data_gen.generate_technical_document(),
                        'source': 'text',
                        'source_description': 'load test',
                        'group_id': group_id,
                    }
                elif operation == 'search_memory_nodes':
                    args = {
                        'query': random.choice(['performance', 'architecture', 'test', 'data']),
                        'group_id': group_id,
                        'limit': 10,
                    }
                else:  # get_episodes
                    args = {
                        'group_id': group_id,
                        'last_n': 10,
                    }

                # Execute operation with timeout
                await asyncio.wait_for(session.call_tool(operation, args), timeout=30.0)

                duration = time.time() - operation_start
                self.metrics.append((operation_start, duration, True))
                stats['success'] += 1

            except asyncio.TimeoutError:
                # Timeouts are tallied under a fixed 'timeout' key.
                duration = time.time() - operation_start
                self.metrics.append((operation_start, duration, False))
                self.errors['timeout'] = self.errors.get('timeout', 0) + 1
                stats['failure'] += 1

            except Exception as e:
                # Any other failure is tallied under its exception class name so
                # one failing operation does not abort the whole client.
                duration = time.time() - operation_start
                self.metrics.append((operation_start, duration, False))
                error_type = type(e).__name__
                self.errors[error_type] = self.errors.get(error_type, 0) + 1
                stats['failure'] += 1

            # Think time between operations
            await asyncio.sleep(self.config.think_time)

            # Stop if we've exceeded test duration
            if self.start_time and (time.time() - self.start_time) > self.config.test_duration:
                break

        return stats

    def calculate_results(self) -> LoadTestResult:
        """Calculate load test results from metrics.

        Returns:
            A ``LoadTestResult`` aggregating every recorded operation. When no
            operations were recorded, all numeric fields are 0 and both dicts
            are empty. Latency statistics include failed operations.
        """
        if not self.metrics:
            return LoadTestResult(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, {}, {})

        successful = [m for m in self.metrics if m[2]]
        failed = [m for m in self.metrics if not m[2]]

        latencies = sorted(m[1] for m in self.metrics)
        # Wall-clock span from the first operation start to the last operation end.
        duration = max(m[0] + m[1] for m in self.metrics) - min(m[0] for m in self.metrics)

        # Calculate percentiles
        def percentile(data: list[float], p: float) -> float:
            # Picks the element at index floor(len * p / 100) of the sorted
            # data, clamped to the last element.
            if not data:
                return 0.0
            idx = int(len(data) * p / 100)
            return data[min(idx, len(data) - 1)]

        # Get resource usage. Reuse the handle primed in __init__ so cpu_percent()
        # has a baseline; fall back to a fresh handle if __init__ was bypassed.
        process = getattr(self, '_process', None) or psutil.Process()
        resource_usage = {
            'cpu_percent': process.cpu_percent(),
            'memory_mb': process.memory_info().rss / 1024 / 1024,
            'num_threads': process.num_threads(),
        }

        return LoadTestResult(
            total_operations=len(self.metrics),
            successful_operations=len(successful),
            failed_operations=len(failed),
            duration=duration,
            throughput=len(self.metrics) / duration if duration > 0 else 0,
            average_latency=sum(latencies) / len(latencies) if latencies else 0,
            p50_latency=percentile(latencies, 50),
            p95_latency=percentile(latencies, 95),
            p99_latency=percentile(latencies, 99),
            max_latency=max(latencies) if latencies else 0,
            errors=self.errors,
            resource_usage=resource_usage,
        )


class TestLoadScenarios:
    """Various load testing scenarios.

    Every scenario opens its own ``graphiti_test_client()`` context, which yields a
    ``(session, group_id)`` pair, and drives the server through ``session.call_tool``.
    All scenarios are marked ``slow``.
    """

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_sustained_load(self):
        """Test system under sustained moderate load."""
        config = LoadTestConfig(
            num_clients=5,
            operations_per_client=20,
            ramp_up_time=2.0,
            test_duration=30.0,
            think_time=0.5,
        )

        async with graphiti_test_client() as (session, group_id):
            tester = LoadTester(config)
            tester.start_time = time.time()

            # Run client workloads
            client_tasks = []
            for client_id in range(config.num_clients):
                task = tester.run_client_workload(client_id, session, group_id)
                client_tasks.append(task)

            # Execute all clients concurrently
            await asyncio.gather(*client_tasks)

            # Calculate results
            results = tester.calculate_results()

            # Assertions
            assert results.successful_operations > results.failed_operations
            assert results.average_latency < 5.0, (
                f'Average latency too high: {results.average_latency:.2f}s'
            )
            assert results.p95_latency < 10.0, f'P95 latency too high: {results.p95_latency:.2f}s'

            # Report results (total_operations is > 0 here because the first assertion passed)
            print('\nSustained Load Test Results:')
            print(f'  Total operations: {results.total_operations}')
            print(
                f'  Success rate: {results.successful_operations / results.total_operations * 100:.1f}%'
            )
            print(f'  Throughput: {results.throughput:.2f} ops/s')
            print(f'  Avg latency: {results.average_latency:.2f}s')
            print(f'  P95 latency: {results.p95_latency:.2f}s')

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_spike_load(self):
        """Test system response to sudden load spikes."""
        async with graphiti_test_client() as (session, group_id):
            # Normal load phase: three requests issued 0.5s apart.
            # Each call is scheduled immediately with ensure_future; a bare coroutine would
            # not start until the gather below, which would turn this paced phase into a
            # second burst and make the sleeps pointless.
            normal_tasks = []
            for i in range(3):
                task = asyncio.ensure_future(
                    session.call_tool(
                        'add_memory',
                        {
                            'name': f'Normal Load {i}',
                            'episode_body': 'Normal operation',
                            'source': 'text',
                            'source_description': 'normal',
                            'group_id': group_id,
                        },
                    )
                )
                normal_tasks.append(task)
                await asyncio.sleep(0.5)

            await asyncio.gather(*normal_tasks)

            # Spike phase - sudden burst of requests.
            # These stay as bare coroutines on purpose: they all start together in gather.
            spike_start = time.time()
            spike_tasks = []
            for i in range(50):
                task = session.call_tool(
                    'add_memory',
                    {
                        'name': f'Spike Load {i}',
                        'episode_body': TestDataGenerator.generate_technical_document(),
                        'source': 'text',
                        'source_description': 'spike',
                        'group_id': group_id,
                    },
                )
                spike_tasks.append(task)

            # Execute spike; exceptions are returned as results rather than raised
            spike_results = await asyncio.gather(*spike_tasks, return_exceptions=True)
            spike_duration = time.time() - spike_start

            # Analyze spike handling
            spike_failures = sum(1 for r in spike_results if isinstance(r, Exception))
            spike_success_rate = (len(spike_results) - spike_failures) / len(spike_results)

            print('\nSpike Load Test Results:')
            print(f'  Spike size: {len(spike_tasks)} operations')
            print(f'  Duration: {spike_duration:.2f}s')
            print(f'  Success rate: {spike_success_rate * 100:.1f}%')
            print(f'  Throughput: {len(spike_tasks) / spike_duration:.2f} ops/s')

            # System should handle at least 80% of spike
            assert spike_success_rate > 0.8, f'Too many failures during spike: {spike_failures}'

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_memory_leak_detection(self):
        """Test for memory leaks during extended operation.

        This is a reporting-only check: growth above 100 MB prints a warning but does not
        fail the test.
        """
        async with graphiti_test_client() as (session, group_id):
            process = psutil.Process()
            gc.collect()  # Force garbage collection
            initial_memory = process.memory_info().rss / 1024 / 1024  # MB

            # Perform many operations: 10 batches of 10 concurrent add_memory calls
            for batch in range(10):
                batch_tasks = []
                for i in range(10):
                    task = session.call_tool(
                        'add_memory',
                        {
                            'name': f'Memory Test {batch}-{i}',
                            'episode_body': TestDataGenerator.generate_technical_document(),
                            'source': 'text',
                            'source_description': 'memory test',
                            'group_id': group_id,
                        },
                    )
                    batch_tasks.append(task)

                await asyncio.gather(*batch_tasks)

                # Force garbage collection between batches
                gc.collect()
                await asyncio.sleep(1)

            # Check memory after operations
            gc.collect()
            final_memory = process.memory_info().rss / 1024 / 1024  # MB
            memory_growth = final_memory - initial_memory

            print('\nMemory Leak Test:')
            print(f'  Initial memory: {initial_memory:.1f} MB')
            print(f'  Final memory: {final_memory:.1f} MB')
            print(f'  Growth: {memory_growth:.1f} MB')

            # Allow for some memory growth but flag potential leaks
            # This is a soft check - actual threshold depends on system
            if memory_growth > 100:  # More than 100MB growth
                print(f'  ⚠️  Potential memory leak detected: {memory_growth:.1f} MB growth')

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_connection_pool_exhaustion(self):
        """Test behavior when connection pools are exhausted.

        Reporting-only: connection errors are counted and printed, and a 60s timeout is
        reported rather than treated as a failure.
        """
        async with graphiti_test_client() as (session, group_id):
            # Create many concurrent long-running operations
            long_tasks = []
            for i in range(100):  # Many more than typical pool size
                task = session.call_tool(
                    'search_memory_nodes',
                    {
                        'query': f'complex query {i} '
                        + ' '.join([TestDataGenerator.fake.word() for _ in range(10)]),
                        'group_id': group_id,
                        'limit': 100,
                    },
                )
                long_tasks.append(task)

            # Execute with timeout
            try:
                results = await asyncio.wait_for(
                    asyncio.gather(*long_tasks, return_exceptions=True), timeout=60.0
                )

                # Count connection-related errors (matched on the exception message text)
                connection_errors = sum(
                    1
                    for r in results
                    if isinstance(r, Exception) and 'connection' in str(r).lower()
                )

                print('\nConnection Pool Test:')
                print(f'  Total requests: {len(long_tasks)}')
                print(f'  Connection errors: {connection_errors}')

            except asyncio.TimeoutError:
                print('  Test timed out - possible deadlock or exhaustion')

    @pytest.mark.asyncio
    @pytest.mark.slow
    async def test_gradual_degradation(self):
        """Test system degradation under increasing load."""
        async with graphiti_test_client() as (session, group_id):
            load_levels = [5, 10, 20, 40, 80]  # Increasing concurrent operations
            results_by_level = {}

            for level in load_levels:
                level_start = time.time()
                tasks = []

                for i in range(level):
                    task = session.call_tool(
                        'add_memory',
                        {
                            'name': f'Load Level {level} Op {i}',
                            'episode_body': f'Testing at load level {level}',
                            'source': 'text',
                            'source_description': 'degradation test',
                            'group_id': group_id,
                        },
                    )
                    tasks.append(task)

                # Execute level
                level_results = await asyncio.gather(*tasks, return_exceptions=True)
                level_duration = time.time() - level_start

                # Calculate metrics
                failures = sum(1 for r in level_results if isinstance(r, Exception))
                success_rate = (level - failures) / level * 100
                throughput = level / level_duration

                results_by_level[level] = {
                    'success_rate': success_rate,
                    'throughput': throughput,
                    'duration': level_duration,
                }

                print(f'\nLoad Level {level}:')
                print(f'  Success rate: {success_rate:.1f}%')
                print(f'  Throughput: {throughput:.2f} ops/s')
                print(f'  Duration: {level_duration:.2f}s')

                # Brief pause between levels
                await asyncio.sleep(2)

            # Verify graceful degradation
            # Success rate should not drop below 50% even at high load
            for level, metrics in results_by_level.items():
                assert metrics['success_rate'] > 50, f'Poor performance at load level {level}'


class TestResourceLimits:
    """Test behavior at resource limits.

    Both tests are reporting-only: they print what happened for each request class and do
    not assert on the outcome.
    """

    @pytest.mark.asyncio
    async def test_large_payload_handling(self):
        """Test handling of very large payloads."""
        async with graphiti_test_client() as (session, group_id):
            payload_sizes = [
                (1_000, '1KB'),
                (10_000, '10KB'),
                (100_000, '100KB'),
                (1_000_000, '1MB'),
            ]

            for size, label in payload_sizes:
                content = 'x' * size

                start_time = time.time()
                try:
                    # Each payload gets at most 30 seconds before it is reported as a timeout
                    await asyncio.wait_for(
                        session.call_tool(
                            'add_memory',
                            {
                                'name': f'Large Payload {label}',
                                'episode_body': content,
                                'source': 'text',
                                'source_description': 'payload test',
                                'group_id': group_id,
                            },
                        ),
                        timeout=30.0,
                    )
                    duration = time.time() - start_time
                    status = '✅ Success'

                except asyncio.TimeoutError:
                    duration = 30.0
                    status = '⏱️  Timeout'

                except Exception as e:
                    duration = time.time() - start_time
                    status = f'❌ Error: {type(e).__name__}'

                print(f'Payload {label}: {status} ({duration:.2f}s)')

    @pytest.mark.asyncio
    async def test_rate_limit_handling(self):
        """Test handling of rate limits."""
        async with graphiti_test_client() as (session, group_id):
            # Rapid fire requests to trigger rate limits
            rapid_tasks = []
            for i in range(100):
                task = session.call_tool(
                    'add_memory',
                    {
                        'name': f'Rate Limit Test {i}',
                        'episode_body': f'Testing rate limit {i}',
                        'source': 'text',
                        'source_description': 'rate test',
                        'group_id': group_id,
                    },
                )
                rapid_tasks.append(task)

            # Execute without delays
            results = await asyncio.gather(*rapid_tasks, return_exceptions=True)

            # Every request that came back as an exception, whatever the cause
            failed_requests = sum(1 for r in results if isinstance(r, Exception))

            # Count rate limit errors (matched on the exception message text)
            rate_limit_errors = sum(
                1
                for r in results
                if isinstance(r, Exception) and ('rate' in str(r).lower() or '429' in str(r))
            )

            # The success rate is based on all failures; counting only rate-limit errors
            # would report requests that failed for other reasons as successes.
            print('\nRate Limit Test:')
            print(f'  Total requests: {len(rapid_tasks)}')
            print(f'  Rate limit errors: {rate_limit_errors}')
            print(f'  Other errors: {failed_requests - rate_limit_errors}')
            print(
                f'  Success rate: {(len(rapid_tasks) - failed_requests) / len(rapid_tasks) * 100:.1f}%'
            )


def generate_load_test_report(results: list[LoadTestResult]) -> str:
    """Generate comprehensive load test report.

    Args:
        results: One entry per test run. Each entry must expose the ``LoadTestResult``
            fields read below (operation counts, throughput, latency statistics,
            ``errors`` and ``resource_usage``).

    Returns:
        A multi-line string framed by ``=`` rulers with one section per run. A run with
        zero total operations is reported with a 0.0% success rate instead of raising
        ``ZeroDivisionError``.
    """
    ruler = '=' * 60
    report = []
    report.append('\n' + ruler)
    report.append('LOAD TEST REPORT')
    report.append(ruler)

    for i, result in enumerate(results):
        # An empty run (no recorded operations) has no meaningful ratio; report 0%.
        if result.total_operations:
            success_pct = result.successful_operations / result.total_operations * 100
        else:
            success_pct = 0.0

        report.append(f'\nTest Run {i + 1}:')
        report.append(f'  Total Operations: {result.total_operations}')
        report.append(f'  Success Rate: {success_pct:.1f}%')
        report.append(f'  Throughput: {result.throughput:.2f} ops/s')
        report.append(
            f'  Latency (avg/p50/p95/p99/max): {result.average_latency:.2f}/{result.p50_latency:.2f}/{result.p95_latency:.2f}/{result.p99_latency:.2f}/{result.max_latency:.2f}s'
        )

        if result.errors:
            report.append('  Errors:')
            for error_type, count in result.errors.items():
                report.append(f'    {error_type}: {count}')

        report.append('  Resource Usage:')
        for metric, value in result.resource_usage.items():
            report.append(f'    {metric}: {value:.2f}')

    report.append(ruler)
    return '\n'.join(report)


if __name__ == '__main__':
    # Run only the tests marked `slow` from this file. pytest.main returns the run's exit
    # code; raising SystemExit with it makes a direct invocation exit non-zero on failure.
    raise SystemExit(pytest.main([__file__, '-v', '--asyncio-mode=auto', '-m', 'slow']))

```

--------------------------------------------------------------------------------
/graphiti_core/utils/bulk_utils.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import logging
import typing
from datetime import datetime

import numpy as np
from pydantic import BaseModel, Field
from typing_extensions import Any

from graphiti_core.driver.driver import (
    GraphDriver,
    GraphDriverSession,
    GraphProvider,
)
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
from graphiti_core.embedder import EmbedderClient
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import normalize_l2, semaphore_gather
from graphiti_core.models.edges.edge_db_queries import (
    get_entity_edge_save_bulk_query,
    get_episodic_edge_save_bulk_query,
)
from graphiti_core.models.nodes.node_db_queries import (
    get_entity_node_save_bulk_query,
    get_episode_node_save_bulk_query,
)
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
from graphiti_core.utils.maintenance.dedup_helpers import (
    DedupResolutionState,
    _build_candidate_indexes,
    _normalize_string_exact,
    _resolve_with_similarity,
)
from graphiti_core.utils.maintenance.edge_operations import (
    extract_edges,
    resolve_extracted_edge,
)
from graphiti_core.utils.maintenance.graph_data_operations import (
    EPISODE_WINDOW_LEN,
    retrieve_episodes,
)
from graphiti_core.utils.maintenance.node_operations import (
    extract_nodes,
    resolve_extracted_nodes,
)

logger = logging.getLogger(__name__)

CHUNK_SIZE = 10


def _build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]:
    """Collapse alias -> canonical chains while preserving direction.

    The incoming pairs represent directed mappings discovered during node dedupe. We use a simple
    union-find with iterative path compression to ensure every source UUID resolves to its ultimate
    canonical target, even if aliases appear lexicographically smaller than the canonical UUID.
    """

    parent: dict[str, str] = {}

    def find(uuid: str) -> str:
        """Directed union-find lookup using iterative path compression."""
        parent.setdefault(uuid, uuid)
        root = uuid
        while parent[root] != root:
            root = parent[root]

        while parent[uuid] != root:
            next_uuid = parent[uuid]
            parent[uuid] = root
            uuid = next_uuid

        return root

    for source_uuid, target_uuid in pairs:
        parent.setdefault(source_uuid, source_uuid)
        parent.setdefault(target_uuid, target_uuid)
        parent[find(source_uuid)] = find(target_uuid)

    return {uuid: find(uuid) for uuid in parent}


class RawEpisode(BaseModel):
    # Pydantic model describing a single episode by name, text content, source metadata
    # and reference time.
    # NOTE(review): presumably the caller-facing input for bulk ingestion before it is
    # turned into an EpisodicNode - confirm against the callers.
    name: str
    # Optional identifier; stays None unless the caller supplies one.
    uuid: str | None = Field(default=None)
    content: str
    source_description: str
    source: EpisodeType
    reference_time: datetime


async def retrieve_previous_episodes_bulk(
    driver: GraphDriver, episodes: list[EpisodicNode]
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
    """Pair every episode with the episodes retrieved for its ``valid_at`` time.

    One ``retrieve_episodes`` lookup is issued per input episode, limited to the last
    ``EPISODE_WINDOW_LEN`` episodes of that episode's own group. The lookups are awaited
    together through ``semaphore_gather``.

    Returns:
        ``(episode, retrieved_episodes)`` tuples in the same order as ``episodes``.
    """
    lookups = []
    for current in episodes:
        lookups.append(
            retrieve_episodes(
                driver, current.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[current.group_id]
            )
        )

    histories = await semaphore_gather(*lookups)

    # One gathered result per lookup, in submission order, so a positional pairing is exact.
    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] = list(
        zip(episodes, histories, strict=True)
    )
    return episode_tuples


async def add_nodes_and_edges_bulk(
    driver: GraphDriver,
    episodic_nodes: list[EpisodicNode],
    episodic_edges: list[EpisodicEdge],
    entity_nodes: list[EntityNode],
    entity_edges: list[EntityEdge],
    embedder: EmbedderClient,
):
    """Persist a batch of nodes and edges through a single ``execute_write`` call.

    Opens a session on ``driver``, hands ``add_nodes_and_edges_bulk_tx`` and the batch to
    ``session.execute_write``, and always closes the session afterwards, including when
    the write raises.
    """
    write_session = driver.session()
    # Positional arguments forwarded to the transaction function after `tx`.
    batch = (episodic_nodes, episodic_edges, entity_nodes, entity_edges, embedder)
    try:
        await write_session.execute_write(add_nodes_and_edges_bulk_tx, *batch, driver=driver)
    finally:
        await write_session.close()


async def add_nodes_and_edges_bulk_tx(
    tx: GraphDriverSession,
    episodic_nodes: list[EpisodicNode],
    episodic_edges: list[EpisodicEdge],
    entity_nodes: list[EntityNode],
    entity_edges: list[EntityEdge],
    embedder: EmbedderClient,
    driver: GraphDriver,
):
    """Serialise a batch of nodes and edges and write them through ``tx``.

    Steps:
      1. Episodic nodes become plain dicts; ``source`` is replaced by the string form of
         its enum value and any ``labels`` key is dropped.
      2. Entity nodes and entity edges get their missing embeddings generated with
         ``embedder`` and are flattened into dicts. For Kuzu the custom attributes are
         stored as one JSON string under ``attributes``; for every other provider they
         are merged into the record as top-level keys.
      3. The records are written in one of three ways: through
         ``driver.graph_operations_interface`` when it is set, one statement per record
         for Kuzu, or one bulk statement per record type otherwise.
    """

    def _attach_attributes(record: dict[str, Any], attributes: dict[str, Any] | None) -> None:
        """Store custom attributes on ``record`` in the provider-specific representation."""
        if driver.provider == GraphProvider.KUZU:
            serialisable = convert_datetimes_to_strings(attributes) if attributes else {}
            record['attributes'] = json.dumps(serialisable)
        else:
            record.update(attributes or {})

    episode_rows = []
    for episodic_node in episodic_nodes:
        row = dict(episodic_node)
        row['source'] = str(row['source'].value)
        row.pop('labels', None)
        episode_rows.append(row)

    node_rows = []
    for entity in entity_nodes:
        if entity.name_embedding is None:
            await entity.generate_name_embedding(embedder)

        node_row: dict[str, Any] = {
            'uuid': entity.uuid,
            'name': entity.name,
            'group_id': entity.group_id,
            'summary': entity.summary,
            'created_at': entity.created_at,
            'name_embedding': entity.name_embedding,
            # Every entity carries the 'Entity' label exactly once.
            'labels': list(set(entity.labels + ['Entity'])),
        }
        _attach_attributes(node_row, entity.attributes)
        node_rows.append(node_row)

    edge_rows = []
    for relation in entity_edges:
        if relation.fact_embedding is None:
            await relation.generate_embedding(embedder)

        edge_row: dict[str, Any] = {
            'uuid': relation.uuid,
            'source_node_uuid': relation.source_node_uuid,
            'target_node_uuid': relation.target_node_uuid,
            'name': relation.name,
            'fact': relation.fact,
            'group_id': relation.group_id,
            'episodes': relation.episodes,
            'created_at': relation.created_at,
            'expired_at': relation.expired_at,
            'valid_at': relation.valid_at,
            'invalid_at': relation.invalid_at,
            'fact_embedding': relation.fact_embedding,
        }
        _attach_attributes(edge_row, relation.attributes)
        edge_rows.append(edge_row)

    # Path 1: a custom graph-operations implementation takes over persistence entirely.
    operations = driver.graph_operations_interface
    if operations:
        await operations.episodic_node_save_bulk(None, driver, tx, episode_rows)
        await operations.node_save_bulk(None, driver, tx, node_rows)
        await operations.episodic_edge_save_bulk(
            None, driver, tx, [link.model_dump() for link in episodic_edges]
        )
        await operations.edge_save_bulk(None, driver, tx, edge_rows)
        return

    # Path 2: Kuzu, written record by record.
    if driver.provider == GraphProvider.KUZU:
        # FIXME: Kuzu's UNWIND does not currently support STRUCT[] type properly, so we insert the data one by one instead for now.
        episode_query = get_episode_node_save_bulk_query(driver.provider)
        for row in episode_rows:
            await tx.run(episode_query, **row)
        entity_node_query = get_entity_node_save_bulk_query(driver.provider, node_rows)
        for node_row in node_rows:
            await tx.run(entity_node_query, **node_row)
        entity_edge_query = get_entity_edge_save_bulk_query(driver.provider)
        for edge_row in edge_rows:
            await tx.run(entity_edge_query, **edge_row)
        episodic_edge_query = get_episodic_edge_save_bulk_query(driver.provider)
        for link in episodic_edges:
            await tx.run(episodic_edge_query, **link.model_dump())
        return

    # Path 3: every other provider, one bulk statement per record type.
    await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episode_rows)
    await tx.run(
        get_entity_node_save_bulk_query(driver.provider, node_rows),
        nodes=node_rows,
    )
    await tx.run(
        get_episodic_edge_save_bulk_query(driver.provider),
        episodic_edges=[link.model_dump() for link in episodic_edges],
    )
    await tx.run(
        get_entity_edge_save_bulk_query(driver.provider),
        entity_edges=edge_rows,
    )


async def extract_nodes_and_edges_bulk(
    clients: GraphitiClients,
    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
    edge_type_map: dict[tuple[str, str], list[str]],
    entity_types: dict[str, type[BaseModel]] | None = None,
    excluded_entity_types: list[str] | None = None,
    edge_types: dict[str, type[BaseModel]] | None = None,
    custom_extraction_instructions: str | None = None,
) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]:
    """Extract entity nodes, then entity edges, for every episode in the batch.

    Extraction runs in two stages, each awaited as a whole through ``semaphore_gather``:
    first ``extract_nodes`` for every ``(episode, previous_episodes)`` tuple, then
    ``extract_edges`` for every tuple using the nodes extracted for that same episode.

    Returns:
        ``(nodes_per_episode, edges_per_episode)``, both ordered like ``episode_tuples``.
    """
    node_jobs = []
    for episode, previous_episodes in episode_tuples:
        node_jobs.append(
            extract_nodes(
                clients,
                episode,
                previous_episodes,
                entity_types=entity_types,
                excluded_entity_types=excluded_entity_types,
                custom_extraction_instructions=custom_extraction_instructions,
            )
        )
    extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather(*node_jobs)

    # Edge extraction needs the nodes found for the same episode, so it can only start
    # once the node stage has completed.
    edge_jobs = []
    for (episode, previous_episodes), episode_nodes in zip(
        episode_tuples, extracted_nodes_bulk, strict=True
    ):
        edge_jobs.append(
            extract_edges(
                clients,
                episode,
                episode_nodes,
                previous_episodes,
                edge_type_map=edge_type_map,
                group_id=episode.group_id,
                edge_types=edge_types,
                custom_extraction_instructions=custom_extraction_instructions,
            )
        )
    extracted_edges_bulk: list[list[EntityEdge]] = await semaphore_gather(*edge_jobs)

    return extracted_nodes_bulk, extracted_edges_bulk


async def dedupe_nodes_bulk(
    clients: GraphitiClients,
    extracted_nodes: list[list[EntityNode]],
    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
    entity_types: dict[str, type[BaseModel]] | None = None,
) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
    """Resolve entity duplicates across an in-memory batch using a two-pass strategy.

    1. Run :func:`resolve_extracted_nodes` for every episode in parallel so each batch item is
       reconciled against the live graph just like the non-batch flow.
    2. Re-run the deterministic similarity heuristics across the union of resolved nodes to catch
       duplicates that only co-occur inside this batch, emitting a canonical UUID map that callers
       can apply to edges and persistence.

    Args:
        clients: Client bundle forwarded unchanged to :func:`resolve_extracted_nodes`.
        extracted_nodes: Extracted nodes per episode; index ``i`` belongs to
            ``episode_tuples[i]``.
        episode_tuples: ``(episode, previous_episodes)`` pairs, one per episode.
        entity_types: Optional entity type models forwarded to the first pass.

    Returns:
        A tuple ``(nodes_by_episode, compressed_map)``. ``nodes_by_episode`` maps each episode
        UUID to its canonical nodes with repeats removed; ``compressed_map`` maps every UUID
        seen in a first-pass UUID map or a duplicate pair to its final canonical UUID.
    """

    # Pass 1: resolve each episode's nodes independently, one call per episode.
    first_pass_results = await semaphore_gather(
        *[
            resolve_extracted_nodes(
                clients,
                nodes,
                episode_tuples[i][0],
                episode_tuples[i][1],
                entity_types,
            )
            for i, nodes in enumerate(extracted_nodes)
        ]
    )

    episode_resolutions: list[tuple[str, list[EntityNode]]] = []
    per_episode_uuid_maps: list[dict[str, str]] = []
    duplicate_pairs: list[tuple[str, str]] = []

    # Each first-pass result unpacks as (resolved nodes, uuid map, duplicate node pairs).
    # strict=True makes a length mismatch with episode_tuples raise instead of truncating.
    for (resolved_nodes, uuid_map, duplicates), (episode, _) in zip(
        first_pass_results, episode_tuples, strict=True
    ):
        episode_resolutions.append((episode.uuid, resolved_nodes))
        per_episode_uuid_maps.append(uuid_map)
        duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)

    # Pass 2: grow a pool of canonical nodes in encounter order and match every resolved node
    # against the pool built so far.
    canonical_nodes: dict[str, EntityNode] = {}
    for _, resolved_nodes in episode_resolutions:
        for node in resolved_nodes:
            # NOTE: this loop is O(n^2) in the number of nodes inside the batch because we rebuild
            # the MinHash index for the accumulated canonical pool each time. The LRU-backed
            # shingle cache keeps the constant factors low for typical batch sizes (≤ CHUNK_SIZE),
            # but if batches grow significantly we should switch to an incremental index or chunked
            # processing.
            if not canonical_nodes:
                # The very first node seeds the pool; there is nothing to compare it with.
                canonical_nodes[node.uuid] = node
                continue

            existing_candidates = list(canonical_nodes.values())
            normalized = _normalize_string_exact(node.name)
            # Cheap check first: a pool member whose normalized name is identical.
            exact_match = next(
                (
                    candidate
                    for candidate in existing_candidates
                    if _normalize_string_exact(candidate.name) == normalized
                ),
                None,
            )
            if exact_match is not None:
                # Record an alias only when the match is a different node; the same UUID
                # showing up again needs no mapping.
                if exact_match.uuid != node.uuid:
                    duplicate_pairs.append((node.uuid, exact_match.uuid))
                continue

            # No exact name match: run the similarity resolver for this single node against
            # indexes built from the current pool.
            indexes = _build_candidate_indexes(existing_candidates)
            state = DedupResolutionState(
                resolved_nodes=[None],
                uuid_map={},
                unresolved_indices=[],
            )
            _resolve_with_similarity([node], indexes, state)

            resolved = state.resolved_nodes[0]
            if resolved is None:
                # The resolver left the slot empty, so the node joins the pool as canonical.
                canonical_nodes[node.uuid] = node
                continue

            canonical_uuid = resolved.uuid
            # Keep the pool entry already registered under this UUID, if any.
            canonical_nodes.setdefault(canonical_uuid, resolved)
            if canonical_uuid != node.uuid:
                duplicate_pairs.append((node.uuid, canonical_uuid))

    # Combine the first-pass UUID maps with all duplicate pairs and collapse alias chains so
    # every UUID points at its final canonical UUID.
    union_pairs: list[tuple[str, str]] = []
    for uuid_map in per_episode_uuid_maps:
        union_pairs.extend(uuid_map.items())
    union_pairs.extend(duplicate_pairs)

    compressed_map: dict[str, str] = _build_directed_uuid_map(union_pairs)

    # Rebuild each episode's node list from canonical nodes, dropping repeats within an episode.
    nodes_by_episode: dict[str, list[EntityNode]] = {}
    for episode_uuid, resolved_nodes in episode_resolutions:
        deduped_nodes: list[EntityNode] = []
        seen: set[str] = set()
        for node in resolved_nodes:
            canonical_uuid = compressed_map.get(node.uuid, node.uuid)
            if canonical_uuid in seen:
                continue
            seen.add(canonical_uuid)
            canonical_node = canonical_nodes.get(canonical_uuid)
            if canonical_node is None:
                # The canonical UUID is not in the in-batch pool; log it and keep the
                # node we have rather than dropping it.
                logger.error(
                    'Canonical node %s missing during batch dedupe; falling back to %s',
                    canonical_uuid,
                    node.uuid,
                )
                canonical_node = node
            deduped_nodes.append(canonical_node)

        nodes_by_episode[episode_uuid] = deduped_nodes

    return nodes_by_episode, compressed_map


async def dedupe_edges_bulk(
    clients: GraphitiClients,
    extracted_edges: list[list[EntityEdge]],
    episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
    _entities: list[EntityNode],
    edge_types: dict[str, type[BaseModel]],
    _edge_type_map: dict[tuple[str, str], list[str]],
) -> dict[str, list[EntityEdge]]:
    """Deduplicate the entity edges extracted for a batch of episodes against each other.

    For every edge, the other edges in the batch that connect the same source and target
    node are screened as candidates: an edge qualifies when its fact shares at least one
    word with the edge's fact or, failing that, when the dot product of the L2-normalised
    fact embeddings reaches ``min_score``. Each edge is then passed with its candidates to
    :func:`resolve_extracted_edge`, and the duplicates it reports are merged into sets whose
    representative is chosen by :func:`compress_uuid_map`.

    Args:
        clients: Supplies the embedder and the LLM client.
        extracted_edges: Extracted edges per episode; index ``i`` belongs to
            ``episode_tuples[i]``.
        episode_tuples: ``(episode, previous_episodes)`` pairs, one per episode.
        _entities: Unused.
        edge_types: Edge type models forwarded to :func:`resolve_extracted_edge`.
        _edge_type_map: Unused.

    Returns:
        A mapping from episode UUID to that episode's edges, each replaced by the
        representative edge of its duplicate set.
    """
    embedder = clients.embedder
    min_score = 0.6

    # generate embeddings
    await semaphore_gather(
        *[create_entity_edge_embeddings(embedder, edges) for edges in extracted_edges]
    )

    # Tokenise every fact once and bucket the batch by (source, target) endpoints. Only
    # edges with identical endpoints can be duplicates, so each edge is compared with its
    # own bucket instead of the whole batch. Buckets keep batch order, which keeps the
    # candidate order stable.
    tokenized_edges: list[list[tuple[EntityEdge, set[str]]]] = [
        [(edge, set(edge.fact.lower().split())) for edge in edges] for edges in extracted_edges
    ]
    edges_by_endpoints: dict[tuple[str, str], list[tuple[EntityEdge, set[str]]]] = {}
    for episode_entries in tokenized_edges:
        for entry in episode_entries:
            endpoints = (entry[0].source_node_uuid, entry[0].target_node_uuid)
            edges_by_endpoints.setdefault(endpoints, []).append(entry)

    # Find similar results
    dedupe_tuples: list[tuple[EpisodicNode, EntityEdge, list[EntityEdge]]] = []
    for i, episode_entries in enumerate(tokenized_edges):
        for edge, edge_words in episode_entries:
            candidates: list[EntityEdge] = []
            bucket = edges_by_endpoints[(edge.source_node_uuid, edge.target_node_uuid)]
            for existing_edge, existing_edge_words in bucket:
                # Skip self-comparison
                if edge.uuid == existing_edge.uuid:
                    continue

                # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
                # This approach will cast a wider net than BM25, which is ideal for this use case
                if not edge_words.isdisjoint(existing_edge_words):
                    candidates.append(existing_edge)
                    continue

                # Check for semantic similarity even if there is no overlap. Without an
                # embedding on both sides there is nothing to compare, and a zero-length
                # stand-in vector cannot be dotted with a real embedding.
                if not edge.fact_embedding or not existing_edge.fact_embedding:
                    continue
                similarity = np.dot(
                    normalize_l2(edge.fact_embedding),
                    normalize_l2(existing_edge.fact_embedding),
                )
                if similarity >= min_score:
                    candidates.append(existing_edge)

            dedupe_tuples.append((episode_tuples[i][0], edge, candidates))

    bulk_edge_resolutions: list[
        tuple[EntityEdge, EntityEdge, list[EntityEdge]]
    ] = await semaphore_gather(
        *[
            resolve_extracted_edge(
                clients.llm_client,
                edge,
                candidates,
                candidates,
                episode,
                edge_types,
                set(edge_types),
            )
            for episode, edge, candidates in dedupe_tuples
        ]
    )

    # For now we won't track edge invalidation
    duplicate_pairs: list[tuple[str, str]] = []
    for (_, edge, _), (_, _, duplicates) in zip(
        dedupe_tuples, bulk_edge_resolutions, strict=True
    ):
        duplicate_pairs.extend((edge.uuid, duplicate.uuid) for duplicate in duplicates)

    # Now we compress the duplicate map, so that 3 -> 2 and 2 -> 1 becomes 3 -> 1 and 2 -> 1
    # (the smallest uuid of each duplicate set is its representative)
    compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)

    edge_uuid_map: dict[str, EntityEdge] = {
        edge.uuid: edge for edges in extracted_edges for edge in edges
    }

    edges_by_episode: dict[str, list[EntityEdge]] = {}
    for i, edges in enumerate(extracted_edges):
        episode = episode_tuples[i][0]

        edges_by_episode[episode.uuid] = [
            edge_uuid_map[compressed_map.get(edge.uuid, edge.uuid)] for edge in edges
        ]

    return edges_by_episode


class UnionFind:
    """Disjoint-set (union-find) structure over hashable, orderable elements.

    Each set is represented by its smallest member: ``union`` always keeps the
    smaller of the two roots as the root of the merged set.
    """

    def __init__(self, elements):
        # start each element in its own set
        self.parent = {e: e for e in elements}

    def find(self, x):
        """Return the root of the set containing ``x`` and compress the path to it.

        Raises:
            KeyError: if ``x`` was not among the elements given to the constructor.
        """
        # union() does not balance by rank, so parent chains can grow as long as the
        # number of elements. Walk iteratively to avoid hitting the recursion limit.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]

        # path-compression: point every node on the walked path directly at the root
        while x != root:
            next_node = self.parent[x]
            self.parent[x] = root
            x = next_node
        return root

    def union(self, a, b):
        """Merge the sets containing ``a`` and ``b``; no-op if already merged."""
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return
        # attach the lexicographically larger root under the smaller
        if ra < rb:
            self.parent[rb] = ra
        else:
            self.parent[ra] = rb


def compress_uuid_map(duplicate_pairs: list[tuple[str, str]]) -> dict[str, str]:
    """
    Collapse pairwise duplicate relations into a canonical-uuid lookup.

    duplicate_pairs: iterable of (uuid_a, uuid_b) pairs marking two uuids as duplicates
    returns: dict mapping every uuid that appears in a pair -> the lexicographically
        smallest uuid in its duplicate set
    """
    # Every uuid mentioned on either side of a pair takes part in the merge.
    seen_uuids = {member for pair in duplicate_pairs for member in (pair[0], pair[1])}

    disjoint_sets = UnionFind(seen_uuids)
    for left, right in duplicate_pairs:
        disjoint_sets.union(left, right)

    # find() resolves each uuid to the root of its set
    canonical: dict[str, str] = {}
    for member in seen_uuids:
        canonical[member] = disjoint_sets.find(member)
    return canonical


E = typing.TypeVar('E', bound=Edge)


def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
    """Rewrite each edge's source/target node uuid through ``uuid_map``, in place.

    Uuids missing from ``uuid_map`` are left as they are. The same ``edges`` list
    object is returned.
    """
    for current_edge in edges:
        old_source = current_edge.source_node_uuid
        old_target = current_edge.target_node_uuid
        new_source = uuid_map.get(old_source, old_source)
        new_target = uuid_map.get(old_target, old_target)
        current_edge.source_node_uuid = new_source
        current_edge.target_node_uuid = new_target

    return edges

```

--------------------------------------------------------------------------------
/tests/utils/maintenance/test_node_operations.py:
--------------------------------------------------------------------------------

```python
import logging
from collections import defaultdict
from unittest.mock import AsyncMock, MagicMock

import pytest

from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search_config import SearchResults
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.dedup_helpers import (
    DedupCandidateIndexes,
    DedupResolutionState,
    _build_candidate_indexes,
    _cached_shingles,
    _has_high_entropy,
    _hash_shingle,
    _jaccard_similarity,
    _lsh_bands,
    _minhash_signature,
    _name_entropy,
    _normalize_name_for_fuzzy,
    _normalize_string_exact,
    _resolve_with_similarity,
    _shingles,
)
from graphiti_core.utils.maintenance.node_operations import (
    _collect_candidate_nodes,
    _resolve_with_llm,
    extract_attributes_from_node,
    extract_attributes_from_nodes,
    resolve_extracted_nodes,
)


def _make_clients():
    """Build a GraphitiClients bundle made of mocks.

    Returns a ``(clients, llm_generate)`` tuple where ``llm_generate`` is the AsyncMock
    installed as ``clients.llm_client.generate_response``.
    """
    generate_response_mock = AsyncMock()
    fake_llm = MagicMock()
    fake_llm.generate_response = generate_response_mock

    # model_construct bypasses validation so the test doubles are accepted as fields
    bundle = GraphitiClients.model_construct(
        driver=MagicMock(),
        embedder=MagicMock(),
        cross_encoder=MagicMock(),
        llm_client=fake_llm,
    )

    return bundle, generate_response_mock


def _make_episode(group_id: str = 'group'):
    """Create a message-type EpisodicNode with fixed placeholder fields for ``group_id``."""
    episode_fields = {
        'name': 'episode',
        'group_id': group_id,
        'source': EpisodeType.message,
        'source_description': 'test',
        'content': 'content',
        'valid_at': utc_now(),
    }
    return EpisodicNode(**episode_fields)


@pytest.mark.asyncio
async def test_resolve_nodes_exact_match_skips_llm(monkeypatch):
    """A node named identically to the searched candidate resolves to it with no LLM await."""
    clients, llm_generate = _make_clients()

    existing_node = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])
    incoming_node = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])

    async def search_stub(*_, **__):
        # every search returns the single pre-existing node
        return SearchResults(nodes=[existing_node])

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges',
        AsyncMock(return_value=[]),
    )
    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.search',
        search_stub,
    )

    resolved_nodes, resolved_uuid_map, _ = await resolve_extracted_nodes(
        clients,
        [incoming_node],
        episode=_make_episode(),
        previous_episodes=[],
    )

    assert resolved_nodes[0].uuid == existing_node.uuid
    assert resolved_uuid_map[incoming_node.uuid] == existing_node.uuid
    llm_generate.assert_not_awaited()


@pytest.mark.asyncio
async def test_resolve_nodes_low_entropy_uses_llm(monkeypatch):
    """With name 'Joe' and no search candidates, the LLM is awaited and the node maps to itself."""
    clients, llm_generate = _make_clients()
    no_duplicate_resolution = {
        'id': 0,
        'duplicate_idx': -1,
        'name': 'Joe',
        'duplicates': [],
    }
    llm_generate.return_value = {'entity_resolutions': [no_duplicate_resolution]}

    incoming_node = EntityNode(name='Joe', group_id='group', labels=['Entity'])

    async def empty_search(*_, **__):
        return SearchResults(nodes=[])

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges',
        AsyncMock(return_value=[]),
    )
    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.search',
        empty_search,
    )

    resolved_nodes, resolved_uuid_map, _ = await resolve_extracted_nodes(
        clients,
        [incoming_node],
        episode=_make_episode(),
        previous_episodes=[],
    )

    assert resolved_nodes[0].uuid == incoming_node.uuid
    assert resolved_uuid_map[incoming_node.uuid] == incoming_node.uuid
    llm_generate.assert_awaited()


@pytest.mark.asyncio
async def test_resolve_nodes_fuzzy_match(monkeypatch):
    """'Joe Michaels' resolves to the candidate 'Joe-Michaels' without the LLM being awaited."""
    clients, llm_generate = _make_clients()

    hyphenated_node = EntityNode(name='Joe-Michaels', group_id='group', labels=['Entity'])
    incoming_node = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])

    async def search_stub(*_, **__):
        return SearchResults(nodes=[hyphenated_node])

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges',
        AsyncMock(return_value=[]),
    )
    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.search',
        search_stub,
    )

    resolved_nodes, resolved_uuid_map, _ = await resolve_extracted_nodes(
        clients,
        [incoming_node],
        episode=_make_episode(),
        previous_episodes=[],
    )

    assert resolved_nodes[0].uuid == hyphenated_node.uuid
    assert resolved_uuid_map[incoming_node.uuid] == hyphenated_node.uuid
    llm_generate.assert_not_awaited()


@pytest.mark.asyncio
async def test_collect_candidate_nodes_dedupes_and_merges_override(monkeypatch):
    """An override node sharing the searched candidate's uuid yields a single result entry."""
    clients, _ = _make_clients()

    searched_node = EntityNode(name='Alice', group_id='group', labels=['Entity'])
    # Same uuid as the searched node, different name.
    same_uuid_override = EntityNode(
        uuid=searched_node.uuid,
        name='Alice Alt',
        group_id='group',
        labels=['Entity'],
    )
    incoming_node = EntityNode(name='Alice', group_id='group', labels=['Entity'])

    search_double = AsyncMock(return_value=SearchResults(nodes=[searched_node]))
    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.search',
        search_double,
    )

    collected = await _collect_candidate_nodes(
        clients,
        [incoming_node],
        existing_nodes_override=[same_uuid_override],
    )

    assert len(collected) == 1
    assert collected[0].uuid == searched_node.uuid
    search_double.assert_awaited()


def test_build_candidate_indexes_populates_structures():
    """Each lookup structure produced by _build_candidate_indexes references the candidate."""
    node = EntityNode(name='Bob Dylan', group_id='group', labels=['Entity'])

    built = _build_candidate_indexes([node])

    lowered_name = node.name.lower()
    first_exact_match = built.normalized_existing[lowered_name][0]
    assert first_exact_match.uuid == node.uuid
    assert built.nodes_by_uuid[node.uuid] is node
    assert node.uuid in built.shingles_by_candidate
    # the candidate must appear in at least one LSH bucket
    assert any(node.uuid in members for members in built.lsh_buckets.values())


def test_normalize_helpers():
    """Pin the exact and fuzzy normalizer outputs for one sample input each."""
    exact_result = _normalize_string_exact('  Alice   Smith ')
    assert exact_result == 'alice smith'

    fuzzy_result = _normalize_name_for_fuzzy('Alice-Smith!')
    assert fuzzy_result == 'alice smith'


def test_name_entropy_variants():
    """A varied name scores above a single repeated character; the empty string scores 0.0."""
    varied_score = _name_entropy('alice')
    repeated_score = _name_entropy('aaaaa')
    assert varied_score > repeated_score

    empty_score = _name_entropy('')
    assert empty_score == 0.0


def test_has_high_entropy_rules():
    """_has_high_entropy returns True for 'meaningful name' and False for 'aa'."""
    expectations = (
        ('meaningful name', True),
        ('aa', False),
    )
    for sample_name, expected_flag in expectations:
        assert _has_high_entropy(sample_name) is expected_flag


def test_shingles_and_cache():
    """'alice' shingles to its three 3-grams; the cached variant returns one shared object."""
    text = 'alice'
    direct = _shingles(text)
    assert direct == {'ali', 'lic', 'ice'}
    assert _cached_shingles(text) == direct

    # two further lookups must return the very same object
    first_lookup = _cached_shingles(text)
    second_lookup = _cached_shingles(text)
    assert first_lookup is second_lookup


def test_hash_minhash_and_lsh():
    """Check signature length (32), band width (4), and distinct seed-0 hashes per shingle."""
    sample_shingles = {'abc', 'bcd', 'cde'}

    minhash = _minhash_signature(sample_shingles)
    assert len(minhash) == 32

    band_list = _lsh_bands(minhash)
    assert all(len(one_band) == 4 for one_band in band_list)

    distinct_hashes = set()
    for shingle in sample_shingles:
        distinct_hashes.add(_hash_shingle(shingle, 0))
    assert len(distinct_hashes) == len(sample_shingles)


def test_jaccard_similarity_edges():
    """Jaccard similarity for a partial overlap, two empty sets, and one empty set."""
    left = {'a', 'b'}
    right = {'a', 'c'}
    nothing = set()

    partial_overlap = _jaccard_similarity(left, right)
    assert partial_overlap == pytest.approx(1 / 3)

    both_empty = _jaccard_similarity(set(), nothing)
    assert both_empty == 1.0

    one_empty = _jaccard_similarity(left, set())
    assert one_empty == 0.0


def test_resolve_with_similarity_exact_match_updates_state():
    """A single exact-name candidate is written into every field of the resolution state."""
    existing_node = EntityNode(name='Charlie Parker', group_id='group', labels=['Entity'])
    incoming_node = EntityNode(name='Charlie Parker', group_id='group', labels=['Entity'])

    candidate_indexes = _build_candidate_indexes([existing_node])
    resolution = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])

    _resolve_with_similarity([incoming_node], candidate_indexes, resolution)

    resolved_first = resolution.resolved_nodes[0]
    assert resolved_first.uuid == existing_node.uuid
    assert resolution.uuid_map[incoming_node.uuid] == existing_node.uuid
    assert resolution.unresolved_indices == []
    assert resolution.duplicate_pairs == [(incoming_node, existing_node)]


def test_resolve_with_similarity_low_entropy_defers_resolution():
    """With empty indexes, the node 'Bob' stays unresolved and its index is queued."""
    incoming_node = EntityNode(name='Bob', group_id='group', labels=['Entity'])
    empty_indexes = DedupCandidateIndexes(
        existing_nodes=[],
        nodes_by_uuid={},
        normalized_existing=defaultdict(list),
        shingles_by_candidate={},
        lsh_buckets=defaultdict(list),
    )
    resolution = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])

    _resolve_with_similarity([incoming_node], empty_indexes, resolution)

    assert resolution.resolved_nodes[0] is None
    assert resolution.unresolved_indices == [0]
    assert resolution.duplicate_pairs == []


def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():
    """Two candidates with the same exact name leave the extracted node unresolved."""
    shared_name = 'Johnny Appleseed'
    first_existing = EntityNode(name=shared_name, group_id='group', labels=['Entity'])
    second_existing = EntityNode(name=shared_name, group_id='group', labels=['Entity'])
    incoming_node = EntityNode(name=shared_name, group_id='group', labels=['Entity'])

    candidate_indexes = _build_candidate_indexes([first_existing, second_existing])
    resolution = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])

    _resolve_with_similarity([incoming_node], candidate_indexes, resolution)

    assert resolution.resolved_nodes[0] is None
    assert resolution.unresolved_indices == [0]
    assert resolution.duplicate_pairs == []


@pytest.mark.asyncio
async def test_resolve_with_llm_updates_unresolved(monkeypatch):
    """An LLM reply naming candidate 0 as the duplicate resolves the pending node to it."""
    incoming_node = EntityNode(name='Dizzy', group_id='group', labels=['Entity'])
    existing_node = EntityNode(name='Dizzy Gillespie', group_id='group', labels=['Entity'])

    candidate_indexes = _build_candidate_indexes([existing_node])
    resolution = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])

    # Filled by the prompt stub so the context passed to the prompt can be inspected.
    prompt_context = {}

    def prompt_stub(context):
        prompt_context.update(context)
        return ['prompt']

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
        prompt_stub,
    )

    async def reply_with_match(*_, **__):
        matched_resolution = {
            'id': 0,
            'duplicate_idx': 0,
            'name': 'Dizzy Gillespie',
            'duplicates': [0],
        }
        return {'entity_resolutions': [matched_resolution]}

    fake_llm = MagicMock()
    fake_llm.generate_response = AsyncMock(side_effect=reply_with_match)

    await _resolve_with_llm(
        fake_llm,
        [incoming_node],
        candidate_indexes,
        resolution,
        episode=_make_episode(),
        previous_episodes=[],
        entity_types=None,
    )

    assert resolution.resolved_nodes[0].uuid == existing_node.uuid
    assert resolution.uuid_map[incoming_node.uuid] == existing_node.uuid
    assert prompt_context['existing_nodes'][0]['idx'] == 0
    assert isinstance(prompt_context['existing_nodes'], list)
    assert resolution.duplicate_pairs == [(incoming_node, existing_node)]


@pytest.mark.asyncio
async def test_resolve_with_llm_ignores_out_of_range_relative_ids(monkeypatch, caplog):
    """An LLM resolution whose 'id' (5) has no matching pending node is skipped with a warning.

    Only one node is pending (index 0), so the node must stay unresolved and the
    skip message must appear in the captured log.
    """
    extracted = EntityNode(name='Dexter', group_id='group', labels=['Entity'])

    # No existing candidates at all.
    indexes = _build_candidate_indexes([])
    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])

    # Replace the prompt builder with a stub so no real prompt is rendered.
    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
        lambda context: ['prompt'],
    )

    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(
        return_value={
            'entity_resolutions': [
                {
                    'id': 5,
                    'duplicate_idx': -1,
                    'name': 'Dexter',
                    'duplicates': [],
                }
            ]
        }
    )

    # Capture WARNING-level records emitted during the call.
    with caplog.at_level(logging.WARNING):
        await _resolve_with_llm(
            llm_client,
            [extracted],
            indexes,
            state,
            episode=_make_episode(),
            previous_episodes=[],
            entity_types=None,
        )

    assert state.resolved_nodes[0] is None
    assert 'Skipping invalid LLM dedupe id 5' in caplog.text


@pytest.mark.asyncio
async def test_resolve_with_llm_ignores_duplicate_relative_ids(monkeypatch):
    """When the LLM returns two resolutions for id 0, the first one determines the outcome."""
    incoming_node = EntityNode(name='Dizzy', group_id='group', labels=['Entity'])
    existing_node = EntityNode(name='Dizzy Gillespie', group_id='group', labels=['Entity'])

    candidate_indexes = _build_candidate_indexes([existing_node])
    resolution = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
        lambda context: ['prompt'],
    )

    first_resolution = {
        'id': 0,
        'duplicate_idx': 0,
        'name': 'Dizzy Gillespie',
        'duplicates': [0],
    }
    # Repeats id 0 with a contradicting "no duplicate" verdict.
    repeated_resolution = {
        'id': 0,
        'duplicate_idx': -1,
        'name': 'Dizzy',
        'duplicates': [],
    }
    fake_llm = MagicMock()
    fake_llm.generate_response = AsyncMock(
        return_value={'entity_resolutions': [first_resolution, repeated_resolution]}
    )

    await _resolve_with_llm(
        fake_llm,
        [incoming_node],
        candidate_indexes,
        resolution,
        episode=_make_episode(),
        previous_episodes=[],
        entity_types=None,
    )

    assert resolution.resolved_nodes[0].uuid == existing_node.uuid
    assert resolution.uuid_map[incoming_node.uuid] == existing_node.uuid
    assert resolution.duplicate_pairs == [(incoming_node, existing_node)]


@pytest.mark.asyncio
async def test_resolve_with_llm_invalid_duplicate_idx_defaults_to_extracted(monkeypatch):
    """A duplicate_idx of 10 with no candidates makes the node resolve to itself."""
    incoming_node = EntityNode(name='Dexter', group_id='group', labels=['Entity'])

    candidate_indexes = _build_candidate_indexes([])
    resolution = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
        lambda context: ['prompt'],
    )

    bad_index_resolution = {
        'id': 0,
        'duplicate_idx': 10,
        'name': 'Dexter',
        'duplicates': [],
    }
    fake_llm = MagicMock()
    fake_llm.generate_response = AsyncMock(
        return_value={'entity_resolutions': [bad_index_resolution]}
    )

    await _resolve_with_llm(
        fake_llm,
        [incoming_node],
        candidate_indexes,
        resolution,
        episode=_make_episode(),
        previous_episodes=[],
        entity_types=None,
    )

    assert resolution.resolved_nodes[0] == incoming_node
    assert resolution.uuid_map[incoming_node.uuid] == incoming_node.uuid
    assert resolution.duplicate_pairs == []


@pytest.mark.asyncio
async def test_extract_attributes_without_callback_generates_summary():
    """With should_summarize_node=None, the LLM-provided summary replaces the old one."""
    fake_llm = MagicMock()
    fake_llm.generate_response = AsyncMock(
        return_value={'summary': 'Generated summary', 'attributes': {}}
    )

    target_node = EntityNode(
        name='Test Node', group_id='group', labels=['Entity'], summary='Old summary'
    )
    current_episode = _make_episode()

    updated_node = await extract_attributes_from_node(
        fake_llm,
        target_node,
        episode=current_episode,
        previous_episodes=[],
        entity_type=None,
        should_summarize_node=None,  # no callback supplied
    )

    # the summary comes from the LLM reply ...
    assert updated_node.summary == 'Generated summary'
    # ... which required exactly one LLM call
    assert fake_llm.generate_response.call_count == 1


@pytest.mark.asyncio
async def test_extract_attributes_with_callback_skip_summary():
    """A callback returning False keeps the old summary and triggers no LLM call."""
    fake_llm = MagicMock()
    fake_llm.generate_response = AsyncMock(
        return_value={'summary': 'This should not be used', 'attributes': {}}
    )

    target_node = EntityNode(
        name='Test Node', group_id='group', labels=['Entity'], summary='Old summary'
    )
    current_episode = _make_episode()

    async def never_summarize(node: EntityNode) -> bool:
        # unconditionally veto summary generation
        return False

    updated_node = await extract_attributes_from_node(
        fake_llm,
        target_node,
        episode=current_episode,
        previous_episodes=[],
        entity_type=None,
        should_summarize_node=never_summarize,
    )

    # the old summary survives
    assert updated_node.summary == 'Old summary'
    # and the LLM was never invoked
    assert fake_llm.generate_response.call_count == 0


@pytest.mark.asyncio
async def test_extract_attributes_with_callback_generate_summary():
    """A callback returning True lets the LLM-provided summary replace the old one."""
    fake_llm = MagicMock()
    fake_llm.generate_response = AsyncMock(
        return_value={'summary': 'New generated summary', 'attributes': {}}
    )

    target_node = EntityNode(
        name='Test Node', group_id='group', labels=['Entity'], summary='Old summary'
    )
    current_episode = _make_episode()

    async def always_summarize(node: EntityNode) -> bool:
        # unconditionally allow summary generation
        return True

    updated_node = await extract_attributes_from_node(
        fake_llm,
        target_node,
        episode=current_episode,
        previous_episodes=[],
        entity_type=None,
        should_summarize_node=always_summarize,
    )

    # the summary now comes from the LLM reply
    assert updated_node.summary == 'New generated summary'
    # produced by exactly one LLM call
    assert fake_llm.generate_response.call_count == 1


@pytest.mark.asyncio
async def test_extract_attributes_with_selective_callback():
    """A label-based callback skips the 'User' node and summarizes the 'Topic' node."""
    fake_llm = MagicMock()
    fake_llm.generate_response = AsyncMock(
        return_value={'summary': 'Generated summary', 'attributes': {}}
    )

    labelled_user = EntityNode(
        name='User', group_id='group', labels=['Entity', 'User'], summary='Old'
    )
    labelled_topic = EntityNode(
        name='Topic', group_id='group', labels=['Entity', 'Topic'], summary='Old'
    )

    current_episode = _make_episode()

    async def skip_user_nodes(node: EntityNode) -> bool:
        # summarize everything except nodes carrying the 'User' label
        return 'User' not in node.labels

    # The shared arguments are identical for both calls; only the node differs.
    shared_kwargs = {
        'episode': current_episode,
        'previous_episodes': [],
        'entity_type': None,
        'should_summarize_node': skip_user_nodes,
    }
    user_outcome = await extract_attributes_from_node(fake_llm, labelled_user, **shared_kwargs)
    topic_outcome = await extract_attributes_from_node(fake_llm, labelled_topic, **shared_kwargs)

    # the User node keeps its summary
    assert user_outcome.summary == 'Old'
    # the Topic node receives the generated one
    assert topic_outcome.summary == 'Generated summary'
    # so only the Topic node cost an LLM call
    assert fake_llm.generate_response.call_count == 1


@pytest.mark.asyncio
async def test_extract_attributes_from_nodes_with_callback():
    """The batch entry point calls the callback once per node and honours its verdicts."""
    clients, _ = _make_clients()
    clients.llm_client.generate_response = AsyncMock(
        return_value={'summary': 'New summary', 'attributes': {}}
    )
    clients.embedder.create = AsyncMock(return_value=[0.1, 0.2, 0.3])
    clients.embedder.create_batch = AsyncMock(return_value=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])

    user_node = EntityNode(
        name='Node1', group_id='group', labels=['Entity', 'User'], summary='Old1'
    )
    topic_node = EntityNode(
        name='Node2', group_id='group', labels=['Entity', 'Topic'], summary='Old2'
    )

    current_episode = _make_episode()

    seen_names = []

    async def recording_filter(node: EntityNode) -> bool:
        # remember every node the callback is consulted about
        seen_names.append(node.name)
        return 'User' not in node.labels

    processed = await extract_attributes_from_nodes(
        clients,
        [user_node, topic_node],
        episode=current_episode,
        previous_episodes=[],
        entity_types=None,
        should_summarize_node=recording_filter,
    )

    # the callback saw both nodes
    assert len(seen_names) == 2
    assert 'Node1' in seen_names
    assert 'Node2' in seen_names

    # Node1 (User) keeps its old summary, Node2 (Topic) gets the new one
    processed_user = next(item for item in processed if item.name == 'Node1')
    processed_topic = next(item for item in processed if item.name == 'Node2')

    assert processed_user.summary == 'Old1'
    assert processed_topic.summary == 'New summary'

```

--------------------------------------------------------------------------------
/tests/llm_client/test_gemini_client.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# Running tests: pytest -xvs tests/llm_client/test_gemini_client.py

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from pydantic import BaseModel

from graphiti_core.llm_client.config import LLMConfig, ModelSize
from graphiti_core.llm_client.errors import RateLimitError
from graphiti_core.llm_client.gemini_client import DEFAULT_MODEL, DEFAULT_SMALL_MODEL, GeminiClient
from graphiti_core.prompts.models import Message


# Test model for response testing
class ResponseModel(BaseModel):
    """Pydantic model passed as ``response_model`` in the structured-output tests."""

    # required string field
    test_field: str
    # optional integer field, defaults to 0 when absent
    optional_field: int = 0


@pytest.fixture
def mock_gemini_client():
    """Patch ``google.genai.Client`` and yield its mocked instance.

    The instance exposes ``aio.models.generate_content`` as an AsyncMock.
    """
    with patch('google.genai.Client') as patched_client_cls:
        # the object that instantiating the patched class returns
        client_double = patched_client_cls.return_value
        client_double.aio = MagicMock()
        client_double.aio.models = MagicMock()
        client_double.aio.models.generate_content = AsyncMock()
        yield client_double


@pytest.fixture
def gemini_client(mock_gemini_client):
    """Return a GeminiClient whose ``client`` attribute is the mocked Gemini instance."""
    llm_config = LLMConfig(
        api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
    )
    client_under_test = GeminiClient(config=llm_config, cache=False)
    # Overwrite the underlying client so every call goes to the mock.
    client_under_test.client = mock_gemini_client
    return client_under_test


class TestGeminiClientInitialization:
    """Constructor tests for GeminiClient, run with ``google.genai.Client`` patched out."""

    @patch('google.genai.Client')
    def test_init_with_config(self, mock_client):
        """Values given in the config are readable back from the client."""
        llm_config = LLMConfig(
            api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
        )
        built_client = GeminiClient(config=llm_config, cache=False, max_tokens=1000)

        assert built_client.config == llm_config
        assert built_client.model == 'test-model'
        assert built_client.temperature == 0.5
        assert built_client.max_tokens == 1000

    @patch('google.genai.Client')
    def test_init_with_default_model(self, mock_client):
        """A config carrying DEFAULT_MODEL yields a client whose model is DEFAULT_MODEL."""
        llm_config = LLMConfig(api_key='test_api_key', model=DEFAULT_MODEL)
        built_client = GeminiClient(config=llm_config, cache=False)

        assert built_client.model == DEFAULT_MODEL

    @patch('google.genai.Client')
    def test_init_without_config(self, mock_client):
        """Omitting the config still produces a client with a config object."""
        built_client = GeminiClient(cache=False)

        assert built_client.config is not None
        # With no config.model set, the model is None rather than DEFAULT_MODEL
        assert built_client.model is None

    @patch('google.genai.Client')
    def test_init_with_thinking_config(self, mock_client):
        """A thinking config passed to the constructor is stored on the client."""
        with patch('google.genai.types.ThinkingConfig') as patched_thinking_cls:
            fake_thinking_config = patched_thinking_cls.return_value
            built_client = GeminiClient(thinking_config=fake_thinking_config)
            assert built_client.thinking_config == fake_thinking_config


class TestGeminiClientGenerateResponse:
    """Behavioural tests for GeminiClient.generate_response against a mocked SDK client."""

    @staticmethod
    def _stub_response(text, candidates=None, prompt_feedback=None):
        """Return a MagicMock carrying the response attributes these tests configure."""
        stub = MagicMock()
        stub.text = text
        stub.candidates = [] if candidates is None else candidates
        stub.prompt_feedback = prompt_feedback
        return stub

    @staticmethod
    def _safety_blocked_response():
        """Return a stub response whose single candidate finished with reason 'SAFETY'."""
        blocked_candidate = MagicMock()
        blocked_candidate.finish_reason = 'SAFETY'
        blocked_candidate.safety_ratings = [
            MagicMock(blocked=True, category='HARM_CATEGORY_HARASSMENT', probability='HIGH')
        ]
        return TestGeminiClientGenerateResponse._stub_response(
            '', candidates=[blocked_candidate]
        )

    @pytest.mark.asyncio
    async def test_generate_response_simple_text(self, gemini_client, mock_gemini_client):
        """Plain-text replies come back as a dict with the text under 'content'."""
        generate = mock_gemini_client.aio.models.generate_content
        generate.return_value = self._stub_response('Test response text')

        prompt = [Message(role='user', content='Test message')]
        result = await gemini_client.generate_response(prompt)

        assert isinstance(result, dict)
        assert result['content'] == 'Test response text'
        generate.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_response_with_structured_output(
        self, gemini_client, mock_gemini_client
    ):
        """A JSON reply is parsed into a dict when a response_model is supplied."""
        generate = mock_gemini_client.aio.models.generate_content
        generate.return_value = self._stub_response(
            '{"test_field": "test_value", "optional_field": 42}'
        )

        prompt = [
            Message(role='system', content='System message'),
            Message(role='user', content='User message'),
        ]
        result = await gemini_client.generate_response(
            messages=prompt, response_model=ResponseModel
        )

        assert isinstance(result, dict)
        assert result['test_field'] == 'test_value'
        assert result['optional_field'] == 42
        generate.assert_called_once()

    @pytest.mark.asyncio
    async def test_generate_response_with_system_message(self, gemini_client, mock_gemini_client):
        """System messages end up in the request config's system_instruction."""
        generate = mock_gemini_client.aio.models.generate_content
        generate.return_value = self._stub_response('Response with system context')

        prompt = [
            Message(role='system', content='System message'),
            Message(role='user', content='User message'),
        ]
        await gemini_client.generate_response(prompt)

        # Inspect the keyword arguments of the SDK call that was made.
        request_config = generate.call_args.kwargs['config']
        assert 'System message' in request_config.system_instruction

    @pytest.mark.asyncio
    async def test_get_model_for_size(self, gemini_client):
        """_get_model_for_size maps small to DEFAULT_SMALL_MODEL and medium to the client model."""
        assert gemini_client._get_model_for_size(ModelSize.small) == DEFAULT_SMALL_MODEL
        assert gemini_client._get_model_for_size(ModelSize.medium) == gemini_client.model

    @pytest.mark.asyncio
    async def test_rate_limit_error_handling(self, gemini_client, mock_gemini_client):
        """An SDK exception mentioning a rate limit surfaces as RateLimitError."""
        generate = mock_gemini_client.aio.models.generate_content
        generate.side_effect = Exception('Rate limit exceeded')

        prompt = [Message(role='user', content='Test message')]
        with pytest.raises(RateLimitError):
            await gemini_client.generate_response(prompt)

    @pytest.mark.asyncio
    async def test_quota_error_handling(self, gemini_client, mock_gemini_client):
        """An SDK exception mentioning quota surfaces as RateLimitError."""
        generate = mock_gemini_client.aio.models.generate_content
        generate.side_effect = Exception('Quota exceeded for requests')

        prompt = [Message(role='user', content='Test message')]
        with pytest.raises(RateLimitError):
            await gemini_client.generate_response(prompt)

    @pytest.mark.asyncio
    async def test_resource_exhausted_error_handling(self, gemini_client, mock_gemini_client):
        """An SDK exception mentioning resource_exhausted surfaces as RateLimitError."""
        generate = mock_gemini_client.aio.models.generate_content
        generate.side_effect = Exception('resource_exhausted: Request limit exceeded')

        prompt = [Message(role='user', content='Test message')]
        with pytest.raises(RateLimitError):
            await gemini_client.generate_response(prompt)

    @pytest.mark.asyncio
    async def test_safety_block_handling(self, gemini_client, mock_gemini_client):
        """A candidate finished with 'SAFETY' raises the safety-filter error."""
        generate = mock_gemini_client.aio.models.generate_content
        generate.return_value = self._safety_blocked_response()

        prompt = [Message(role='user', content='Test message')]
        with pytest.raises(Exception, match='Content blocked by safety filters'):
            await gemini_client.generate_response(prompt)

    @pytest.mark.asyncio
    async def test_prompt_block_handling(self, gemini_client, mock_gemini_client):
        """Prompt feedback carrying a block_reason raises the safety-filter error."""
        feedback = MagicMock()
        feedback.block_reason = 'BLOCKED_REASON_OTHER'
        generate = mock_gemini_client.aio.models.generate_content
        generate.return_value = self._stub_response('', prompt_feedback=feedback)

        prompt = [Message(role='user', content='Test message')]
        with pytest.raises(Exception, match='Content blocked by safety filters'):
            await gemini_client.generate_response(prompt)

    @pytest.mark.asyncio
    async def test_structured_output_parsing_error(self, gemini_client, mock_gemini_client):
        """Unparseable structured output raises once every retry has been used."""
        generate = mock_gemini_client.aio.models.generate_content
        generate.return_value = self._stub_response('Invalid JSON that cannot be parsed')

        prompt = [Message(role='user', content='Test message')]
        with pytest.raises(Exception):  # noqa: B017
            await gemini_client.generate_response(prompt, response_model=ResponseModel)

        # Every attempt hits the SDK, so the call count equals MAX_RETRIES.
        assert generate.call_count == GeminiClient.MAX_RETRIES

    @pytest.mark.asyncio
    async def test_retry_logic_with_safety_block(self, gemini_client, mock_gemini_client):
        """Safety blocks fail immediately instead of being retried."""
        generate = mock_gemini_client.aio.models.generate_content
        generate.return_value = self._safety_blocked_response()

        prompt = [Message(role='user', content='Test message')]
        with pytest.raises(Exception, match='Content blocked by safety filters'):
            await gemini_client.generate_response(prompt)

        # Exactly one SDK call: no retry is attempted for a safety block.
        assert generate.call_count == 1

    @pytest.mark.asyncio
    async def test_retry_logic_with_validation_error(self, gemini_client, mock_gemini_client):
        """An invalid first reply is retried and the valid second reply is returned."""
        generate = mock_gemini_client.aio.models.generate_content
        # Successive SDK calls return these stubs in order: invalid JSON, then valid JSON.
        generate.side_effect = [
            self._stub_response('Invalid JSON that cannot be parsed'),
            self._stub_response('{"test_field": "correct_value"}'),
        ]

        prompt = [Message(role='user', content='Test message')]
        result = await gemini_client.generate_response(prompt, response_model=ResponseModel)

        assert generate.call_count == 2
        assert result['test_field'] == 'correct_value'

    @pytest.mark.asyncio
    async def test_max_retries_exceeded(self, gemini_client, mock_gemini_client):
        """Persistently invalid replies raise after MAX_RETRIES SDK calls."""
        generate = mock_gemini_client.aio.models.generate_content
        generate.return_value = self._stub_response('Invalid JSON that cannot be parsed')

        prompt = [Message(role='user', content='Test message')]
        with pytest.raises(Exception):  # noqa: B017
            await gemini_client.generate_response(prompt, response_model=ResponseModel)

        assert generate.call_count == GeminiClient.MAX_RETRIES

    @pytest.mark.asyncio
    async def test_empty_response_handling(self, gemini_client, mock_gemini_client):
        """An empty reply to a structured request raises after MAX_RETRIES SDK calls."""
        generate = mock_gemini_client.aio.models.generate_content
        generate.return_value = self._stub_response('')

        prompt = [Message(role='user', content='Test message')]
        with pytest.raises(Exception):  # noqa: B017
            await gemini_client.generate_response(prompt, response_model=ResponseModel)

        assert generate.call_count == GeminiClient.MAX_RETRIES

    @pytest.mark.asyncio
    async def test_custom_max_tokens(self, gemini_client, mock_gemini_client):
        """A max_tokens argument passed to generate_response wins over other sources."""
        generate = mock_gemini_client.aio.models.generate_content
        generate.return_value = self._stub_response('Test response')

        prompt = [Message(role='user', content='Test message')]
        await gemini_client.generate_response(prompt, max_tokens=500)

        # The per-call value must be what reaches the SDK request config.
        request_config = generate.call_args.kwargs['config']
        assert request_config.max_output_tokens == 500

    @pytest.mark.asyncio
    async def test_max_tokens_precedence_fallback(self, mock_gemini_client):
        """Without a per-call max_tokens, the instance value and then the model mapping apply."""
        generate = mock_gemini_client.aio.models.generate_content
        generate.return_value = self._stub_response('Test response')

        # Case 1: the client was constructed with its own max_tokens.
        llm_config = LLMConfig(
            api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
        )
        gemini = GeminiClient(
            config=llm_config, cache=False, max_tokens=2000, client=mock_gemini_client
        )

        await gemini.generate_response([Message(role='user', content='Test message')])

        request_config = generate.call_args.kwargs['config']
        assert request_config.max_output_tokens == 2000

        # Case 2: no instance max_tokens either, so the per-model limit is expected.
        llm_config = LLMConfig(api_key='test_api_key', model='gemini-2.5-flash', temperature=0.5)
        gemini = GeminiClient(config=llm_config, cache=False, client=mock_gemini_client)

        await gemini.generate_response([Message(role='user', content='Test message')])

        request_config = generate.call_args.kwargs['config']
        assert request_config.max_output_tokens == 65536

    @pytest.mark.asyncio
    async def test_model_size_selection(self, gemini_client, mock_gemini_client):
        """Requesting ModelSize.small sends DEFAULT_SMALL_MODEL to the SDK."""
        generate = mock_gemini_client.aio.models.generate_content
        generate.return_value = self._stub_response('Test response')

        prompt = [Message(role='user', content='Test message')]
        await gemini_client.generate_response(prompt, model_size=ModelSize.small)

        assert generate.call_args.kwargs['model'] == DEFAULT_SMALL_MODEL

    @pytest.mark.asyncio
    async def test_gemini_model_max_tokens_mapping(self, mock_gemini_client):
        """Each known Gemini model, and an unknown one, gets its expected output-token limit."""
        generate = mock_gemini_client.aio.models.generate_content
        generate.return_value = self._stub_response('Test response')

        # Model name -> max_output_tokens expected in the request config.
        expected_limits = {
            'gemini-2.5-flash': 65536,
            'gemini-2.5-pro': 65536,
            'gemini-2.5-flash-lite': 64000,
            'gemini-2.0-flash': 8192,
            'gemini-1.5-pro': 8192,
            'gemini-1.5-flash': 8192,
            'unknown-model': 8192,  # Fallback case
        }

        for model_name, expected_max_tokens in expected_limits.items():
            # No max_tokens anywhere, so only the model mapping can supply the limit.
            llm_config = LLMConfig(api_key='test_api_key', model=model_name, temperature=0.5)
            gemini = GeminiClient(config=llm_config, cache=False, client=mock_gemini_client)

            await gemini.generate_response([Message(role='user', content='Test message')])

            # call_args reflects the most recent SDK call, i.e. the one for this model.
            request_config = generate.call_args.kwargs['config']
            assert request_config.max_output_tokens == expected_max_tokens, (
                f'Model {model_name} should use {expected_max_tokens} tokens'
            )


if __name__ == '__main__':
    # Running this file as a script hands the module to pytest in verbose mode.
    pytest_args = ['-v', 'test_gemini_client.py']
    pytest.main(pytest_args)

```

--------------------------------------------------------------------------------
/graphiti_core/utils/maintenance/node_operations.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
from collections.abc import Awaitable, Callable
from time import time
from typing import Any

from pydantic import BaseModel

from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.llm_client.config import ModelSize
from graphiti_core.nodes import (
    EntityNode,
    EpisodeType,
    EpisodicNode,
    create_entity_node_embeddings,
)
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
from graphiti_core.prompts.extract_nodes import (
    EntitySummary,
    ExtractedEntities,
    ExtractedEntity,
    MissedEntities,
)
from graphiti_core.search.search import search
from graphiti_core.search.search_config import SearchResults
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.utils.content_chunking import (
    chunk_json_content,
    chunk_message_content,
    chunk_text_content,
    should_chunk,
)
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.dedup_helpers import (
    DedupCandidateIndexes,
    DedupResolutionState,
    _build_candidate_indexes,
    _resolve_with_similarity,
)
from graphiti_core.utils.maintenance.edge_operations import (
    filter_existing_duplicate_of_edges,
)
from graphiti_core.utils.text_utils import MAX_SUMMARY_CHARS, truncate_at_sentence

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Type alias for an async predicate over an EntityNode: it is called with a node and the
# awaited result is a bool.
# NOTE(review): presumably used to decide whether a node's summary should be generated;
# its call sites are outside this excerpt, so confirm against the callers.
NodeSummaryFilter = Callable[[EntityNode], Awaitable[bool]]


async def extract_nodes_reflexion(
    llm_client: LLMClient,
    episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
    node_names: list[str],
    group_id: str | None = None,
) -> list[str]:
    """Run the reflexion prompt and return the 'missed_entities' it reports.

    The prompt receives the episode content, the content of the previous episodes and the
    names already extracted. If the response has no 'missed_entities' key, an empty list
    is returned.
    """
    prior_contents = [prior.content for prior in previous_episodes]
    reflexion_prompt = prompt_library.extract_nodes.reflexion(
        {
            'episode_content': episode.content,
            'previous_episodes': prior_contents,
            'extracted_entities': node_names,
        }
    )

    response = await llm_client.generate_response(
        reflexion_prompt,
        MissedEntities,
        group_id=group_id,
        prompt_name='extract_nodes.reflexion',
    )

    return response.get('missed_entities', [])


async def extract_nodes(
    clients: GraphitiClients,
    episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
    entity_types: dict[str, type[BaseModel]] | None = None,
    excluded_entity_types: list[str] | None = None,
    custom_extraction_instructions: str | None = None,
) -> list[EntityNode]:
    """Extract entity nodes from an episode, chunking the content when needed.

    When ``should_chunk`` reports that the episode content needs chunking, extraction is
    done per chunk and the results are merged; otherwise a single LLM call is made.
    Entities whose name is blank are dropped, and the remainder are converted to
    EntityNode objects (entities of an excluded type are skipped during conversion).
    """
    start = time()

    entity_types_context = _build_entity_types_context(entity_types)

    # Context shared by every extraction prompt for this episode.
    prompt_context = {
        'episode_content': episode.content,
        'episode_timestamp': episode.valid_at.isoformat(),
        'previous_episodes': [prior.content for prior in previous_episodes],
        'custom_extraction_instructions': custom_extraction_instructions or '',
        'entity_types': entity_types_context,
        'source_description': episode.source_description,
    }

    # Pick the extraction strategy, then run it once.
    if should_chunk(episode.content, episode.source):
        extract = _extract_nodes_chunked
    else:
        extract = _extract_nodes_single
    raw_entities = await extract(clients.llm_client, episode, prompt_context)

    # Names that are empty or whitespace-only are discarded.
    filtered_entities = [entity for entity in raw_entities if entity.name.strip()]

    end = time()
    logger.debug(f'Extracted {len(filtered_entities)} entities in {(end - start) * 1000:.0f} ms')

    extracted_nodes = _create_entity_nodes(
        filtered_entities, entity_types_context, excluded_entity_types, episode
    )

    logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
    return extracted_nodes


def _build_entity_types_context(
    entity_types: dict[str, type[BaseModel]] | None,
) -> list[dict]:
    """Build entity types context with ID mappings."""
    entity_types_context = [
        {
            'entity_type_id': 0,
            'entity_type_name': 'Entity',
            'entity_type_description': (
                'Default entity classification. Use this entity type '
                'if the entity is not one of the other listed types.'
            ),
        }
    ]

    if entity_types is not None:
        entity_types_context += [
            {
                'entity_type_id': i + 1,
                'entity_type_name': type_name,
                'entity_type_description': type_model.__doc__,
            }
            for i, (type_name, type_model) in enumerate(entity_types.items())
        ]

    return entity_types_context


async def _extract_nodes_single(
    llm_client: LLMClient,
    episode: EpisodicNode,
    context: dict,
) -> list[ExtractedEntity]:
    """Run one extraction call for the whole episode and return the parsed entities."""
    raw_response = await _call_extraction_llm(llm_client, episode, context)
    # Validate the raw dict through the response model before reading the entity list.
    return ExtractedEntities(**raw_response).extracted_entities


async def _extract_nodes_chunked(
    llm_client: LLMClient,
    episode: EpisodicNode,
    context: dict,
) -> list[ExtractedEntity]:
    """Split the episode content into chunks, extract from each, and merge the results."""
    # The chunker depends on the episode source; anything that is neither JSON nor a
    # message is chunked as text.
    if episode.source == EpisodeType.json:
        chunker = chunk_json_content
    elif episode.source == EpisodeType.message:
        chunker = chunk_message_content
    else:
        chunker = chunk_text_content
    chunks = chunker(episode.content)

    logger.debug(f'Chunked content into {len(chunks)} chunks for entity extraction')

    # One extraction coroutine per chunk, all handed to semaphore_gather together.
    extraction_tasks = []
    for chunk in chunks:
        extraction_tasks.append(_extract_from_chunk(llm_client, chunk, context, episode))
    chunk_results = await semaphore_gather(*extraction_tasks)

    # The same entity may be reported by several chunks; keep one copy per name.
    merged_entities = _merge_extracted_entities(chunk_results)
    logger.debug(
        f'Merged {sum(len(r) for r in chunk_results)} entities into {len(merged_entities)} unique'
    )

    return merged_entities


async def _extract_from_chunk(
    llm_client: LLMClient,
    chunk: str,
    base_context: dict,
    episode: EpisodicNode,
) -> list[ExtractedEntity]:
    """Extract entities from one chunk, using the base context with the chunk as content."""
    # Copy so the shared base context is left untouched for the other chunks.
    chunk_context = dict(base_context)
    chunk_context['episode_content'] = chunk

    raw_response = await _call_extraction_llm(llm_client, episode, chunk_context)
    parsed = ExtractedEntities(**raw_response)
    return parsed.extracted_entities


async def _call_extraction_llm(
    llm_client: LLMClient,
    episode: EpisodicNode,
    context: dict,
) -> dict:
    """Build the extraction prompt matching the episode source and send it to the LLM.

    Message and JSON episodes have dedicated prompts; text episodes and any other
    source use the text prompt.
    """
    extraction_prompts = prompt_library.extract_nodes
    source = episode.source

    if source == EpisodeType.message:
        prompt_name = 'extract_nodes.extract_message'
        prompt = extraction_prompts.extract_message(context)
    elif source == EpisodeType.json:
        prompt_name = 'extract_nodes.extract_json'
        prompt = extraction_prompts.extract_json(context)
    else:
        # Covers EpisodeType.text as well as the fallback for any other source.
        prompt_name = 'extract_nodes.extract_text'
        prompt = extraction_prompts.extract_text(context)

    llm_response = await llm_client.generate_response(
        prompt,
        response_model=ExtractedEntities,
        group_id=episode.group_id,
        prompt_name=prompt_name,
    )
    return llm_response


def _merge_extracted_entities(
    chunk_results: list[list[ExtractedEntity]],
) -> list[ExtractedEntity]:
    """Merge entities from multiple chunks, deduplicating by normalized name.

    When duplicates occur, prefer the first occurrence (maintains ordering).
    """
    seen_names: set[str] = set()
    merged: list[ExtractedEntity] = []

    for entities in chunk_results:
        for entity in entities:
            normalized = entity.name.strip().lower()
            if normalized and normalized not in seen_names:
                seen_names.add(normalized)
                merged.append(entity)

    return merged


def _create_entity_nodes(
    extracted_entities: list[ExtractedEntity],
    entity_types_context: list[dict],
    excluded_entity_types: list[str] | None,
    episode: EpisodicNode,
) -> list[EntityNode]:
    """Convert ExtractedEntity objects to EntityNode objects.

    Args:
        extracted_entities: Entities returned by the extraction step.
        entity_types_context: Entity-type entries; an entity's ``entity_type_id`` is used
            as an index into this list to find its type name.
        excluded_entity_types: Type names whose entities are skipped, or None to keep all.
        episode: Episode the entities came from; its ``group_id`` is copied to each node.

    Returns:
        One EntityNode per non-excluded entity, in input order, with an empty summary.
    """
    extracted_nodes = []

    for extracted_entity in extracted_entities:
        type_id = extracted_entity.entity_type_id
        # An out-of-range type id falls back to the default 'Entity' type.
        if 0 <= type_id < len(entity_types_context):
            entity_type_name = entity_types_context[type_id].get('entity_type_name')
        else:
            entity_type_name = 'Entity'

        # Check if this entity type should be excluded
        if excluded_entity_types and entity_type_name in excluded_entity_types:
            logger.debug(f'Excluding entity "{extracted_entity.name}" of type "{entity_type_name}"')
            continue

        # Build the labels in a fixed order ('Entity' first, then the specific type if it
        # differs) rather than via a set, whose iteration order for strings can change
        # from one interpreter run to the next.
        labels: list[str] = ['Entity']
        type_label = str(entity_type_name)
        if type_label != 'Entity':
            labels.append(type_label)

        new_node = EntityNode(
            name=extracted_entity.name,
            group_id=episode.group_id,
            labels=labels,
            summary='',
            created_at=utc_now(),
        )
        extracted_nodes.append(new_node)
        logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')

    return extracted_nodes


async def _collect_candidate_nodes(
    clients: GraphitiClients,
    extracted_nodes: list[EntityNode],
    existing_nodes_override: list[EntityNode] | None,
) -> list[EntityNode]:
    """Gather dedupe candidates: one search per extracted node name, plus any overrides.

    Search hits come first, in search order, followed by the override nodes. A node
    whose uuid has already been seen is dropped, so each uuid appears once.
    """
    searches = [
        search(
            clients=clients,
            query=node.name,
            group_ids=[node.group_id],
            search_filter=SearchFilters(),
            config=NODE_HYBRID_SEARCH_RRF,
        )
        for node in extracted_nodes
    ]
    search_results: list[SearchResults] = await semaphore_gather(*searches)

    # Keyed by uuid; insertion order gives first-seen order and setdefault keeps the
    # first node recorded for each uuid.
    first_by_uuid: dict[str, EntityNode] = {}
    for result in search_results:
        for found in result.nodes:
            first_by_uuid.setdefault(found.uuid, found)

    for override in existing_nodes_override or []:
        first_by_uuid.setdefault(override.uuid, override)

    return list(first_by_uuid.values())


async def _resolve_with_llm(
    llm_client: LLMClient,
    extracted_nodes: list[EntityNode],
    indexes: DedupCandidateIndexes,
    state: DedupResolutionState,
    episode: EpisodicNode | None,
    previous_episodes: list[EpisodicNode] | None,
    entity_types: dict[str, type[BaseModel]] | None,
) -> None:
    """Escalate unresolved nodes to the dedupe prompt so the LLM can select or reject duplicates.

    The guardrails below defensively ignore malformed or duplicate LLM responses so the
    ingestion workflow remains deterministic even when the model misbehaves.

    Args:
        llm_client: Client used to run the dedupe prompt.
        extracted_nodes: All extracted nodes; ``state.unresolved_indices`` index into it.
        indexes: Candidate indexes; ``indexes.existing_nodes`` is the candidate list that
            the LLM's ``duplicate_idx`` values refer to.
        state: Resolution state, updated in place (``resolved_nodes``, ``uuid_map`` and
            ``duplicate_pairs``).
        episode: Current episode, or None; its content is passed to the prompt.
        previous_episodes: Earlier episodes, or None; their content is passed to the prompt.
        entity_types: Custom entity type models keyed by name, or None.

    Returns:
        None. Nodes the LLM does not resolve validly are left untouched in ``state``.
    """
    if not state.unresolved_indices:
        return

    entity_types_dict: dict[str, type[BaseModel]] = entity_types if entity_types is not None else {}

    llm_extracted_nodes = [extracted_nodes[i] for i in state.unresolved_indices]

    # IDs sent to the LLM are positions within the unresolved subset, not within
    # extracted_nodes; they are mapped back through state.unresolved_indices below.
    extracted_nodes_context = []
    for i, node in enumerate(llm_extracted_nodes):
        specific_label = next((item for item in node.labels if item != 'Entity'), '')
        type_model = entity_types_dict.get(specific_label)
        # Only read __doc__ from a registered model. Reading it off the None returned by
        # a failed lookup makes the fallback depend on whether the interpreter gives
        # None a docstring.
        type_description = type_model.__doc__ if type_model is not None else None
        extracted_nodes_context.append(
            {
                'id': i,
                'name': node.name,
                'entity_type': node.labels,
                'entity_type_description': type_description or 'Default Entity Type',
            }
        )

    sent_ids = [ctx['id'] for ctx in extracted_nodes_context]
    logger.debug(
        'Sending %d entities to LLM for deduplication with IDs 0-%d (actual IDs sent: %s)',
        len(llm_extracted_nodes),
        len(llm_extracted_nodes) - 1,
        sent_ids if len(sent_ids) < 20 else f'{sent_ids[:10]}...{sent_ids[-10:]}',
    )
    if llm_extracted_nodes:
        sample_size = min(3, len(extracted_nodes_context))
        logger.debug(
            'First %d entities: %s',
            sample_size,
            [(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[:sample_size]],
        )
        if len(extracted_nodes_context) > 3:
            logger.debug(
                'Last %d entities: %s',
                sample_size,
                [(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[-sample_size:]],
            )

    # Candidate attributes are spread last, so an attribute named 'idx', 'name' or
    # 'entity_types' would override the corresponding base key.
    existing_nodes_context = [
        {
            **{
                'idx': i,
                'name': candidate.name,
                'entity_types': candidate.labels,
            },
            **candidate.attributes,
        }
        for i, candidate in enumerate(indexes.existing_nodes)
    ]

    context = {
        'extracted_nodes': extracted_nodes_context,
        'existing_nodes': existing_nodes_context,
        'episode_content': episode.content if episode is not None else '',
        'previous_episodes': (
            [ep.content for ep in previous_episodes] if previous_episodes is not None else []
        ),
    }

    llm_response = await llm_client.generate_response(
        prompt_library.dedupe_nodes.nodes(context),
        response_model=NodeResolutions,
        prompt_name='dedupe_nodes.nodes',
    )

    node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions

    valid_relative_range = range(len(state.unresolved_indices))
    processed_relative_ids: set[int] = set()

    # Compare the IDs that came back with the IDs that were sent, for diagnostics only.
    received_ids = {r.id for r in node_resolutions}
    expected_ids = set(valid_relative_range)
    missing_ids = expected_ids - received_ids
    extra_ids = received_ids - expected_ids

    logger.debug(
        'Received %d resolutions for %d entities',
        len(node_resolutions),
        len(state.unresolved_indices),
    )

    if missing_ids:
        logger.warning('LLM did not return resolutions for IDs: %s', sorted(missing_ids))

    if extra_ids:
        logger.warning(
            'LLM returned invalid IDs outside valid range 0-%d: %s (all returned IDs: %s)',
            len(state.unresolved_indices) - 1,
            sorted(extra_ids),
            sorted(received_ids),
        )

    for resolution in node_resolutions:
        relative_id: int = resolution.id
        duplicate_idx: int = resolution.duplicate_idx

        if relative_id not in valid_relative_range:
            logger.warning(
                'Skipping invalid LLM dedupe id %d (valid range: 0-%d, received %d resolutions)',
                relative_id,
                len(state.unresolved_indices) - 1,
                len(node_resolutions),
            )
            continue

        # Only the first resolution for a given ID is honoured.
        if relative_id in processed_relative_ids:
            logger.warning('Duplicate LLM dedupe id %s received; ignoring.', relative_id)
            continue
        processed_relative_ids.add(relative_id)

        original_index = state.unresolved_indices[relative_id]
        extracted_node = extracted_nodes[original_index]

        # -1 means "no duplicate"; an in-range index selects an existing candidate;
        # anything else is treated as "no duplicate" after a warning.
        resolved_node: EntityNode
        if duplicate_idx == -1:
            resolved_node = extracted_node
        elif 0 <= duplicate_idx < len(indexes.existing_nodes):
            resolved_node = indexes.existing_nodes[duplicate_idx]
        else:
            logger.warning(
                'Invalid duplicate_idx %s for extracted node %s; treating as no duplicate.',
                duplicate_idx,
                extracted_node.uuid,
            )
            resolved_node = extracted_node

        state.resolved_nodes[original_index] = resolved_node
        state.uuid_map[extracted_node.uuid] = resolved_node.uuid
        if resolved_node.uuid != extracted_node.uuid:
            state.duplicate_pairs.append((extracted_node, resolved_node))


async def resolve_extracted_nodes(
    clients: GraphitiClients,
    extracted_nodes: list[EntityNode],
    episode: EpisodicNode | None = None,
    previous_episodes: list[EpisodicNode] | None = None,
    entity_types: dict[str, type[BaseModel]] | None = None,
    existing_nodes_override: list[EntityNode] | None = None,
) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
    """Resolve extracted nodes against candidate nodes that already exist.

    Candidates are collected first, then a similarity pass settles the
    deterministic matches and whatever is left is escalated to the LLM dedupe
    prompt. Any extracted node that neither pass resolved is kept as-is.

    Args:
        clients: Bundle providing the LLM client and graph driver used here.
        extracted_nodes: Nodes to resolve; the result preserves their order.
        episode: Forwarded to the LLM pass.
        previous_episodes: Forwarded to the LLM pass.
        entity_types: Forwarded to the LLM pass.
        existing_nodes_override: Forwarded to candidate collection.

    Returns:
        A tuple of (resolved nodes, mapping of extracted-node uuid to
        resolved-node uuid, duplicate pairs as returned by
        ``filter_existing_duplicate_of_edges``).
    """
    llm = clients.llm_client
    graph_driver = clients.driver

    candidates = await _collect_candidate_nodes(
        clients, extracted_nodes, existing_nodes_override
    )
    candidate_indexes: DedupCandidateIndexes = _build_candidate_indexes(candidates)

    # One slot per extracted node; None marks "not resolved yet".
    resolution = DedupResolutionState(
        resolved_nodes=[None for _ in extracted_nodes],
        uuid_map={},
        unresolved_indices=[],
    )

    _resolve_with_similarity(extracted_nodes, candidate_indexes, resolution)
    await _resolve_with_llm(
        llm,
        extracted_nodes,
        candidate_indexes,
        resolution,
        episode,
        previous_episodes,
        entity_types,
    )

    # Anything still unresolved after both passes maps onto itself.
    for position, extracted in enumerate(extracted_nodes):
        if resolution.resolved_nodes[position] is not None:
            continue
        resolution.resolved_nodes[position] = extracted
        resolution.uuid_map[extracted.uuid] = extracted.uuid

    logger.debug(
        'Resolved nodes: %s',
        [
            (resolved.name, resolved.uuid)
            for resolved in resolution.resolved_nodes
            if resolved is not None
        ],
    )

    new_node_duplicates: list[
        tuple[EntityNode, EntityNode]
    ] = await filter_existing_duplicate_of_edges(graph_driver, resolution.duplicate_pairs)

    final_nodes = [resolved for resolved in resolution.resolved_nodes if resolved is not None]
    return final_nodes, resolution.uuid_map, new_node_duplicates


async def extract_attributes_from_nodes(
    clients: GraphitiClients,
    nodes: list[EntityNode],
    episode: EpisodicNode | None = None,
    previous_episodes: list[EpisodicNode] | None = None,
    entity_types: dict[str, type[BaseModel]] | None = None,
    should_summarize_node: NodeSummaryFilter | None = None,
) -> list[EntityNode]:
    """Run attribute/summary extraction for every node, then embed the results.

    Each node is handed to ``extract_attributes_from_node`` through
    ``semaphore_gather``; afterwards ``create_entity_node_embeddings`` is called
    on the returned nodes.

    Args:
        clients: Bundle providing the LLM client and the embedder.
        nodes: Nodes to process.
        episode: Forwarded to the per-node extraction.
        previous_episodes: Forwarded to the per-node extraction.
        entity_types: Optional mapping from label to attribute model.
        should_summarize_node: Forwarded to the per-node extraction.

    Returns:
        The nodes returned by the per-node extraction calls.
    """
    llm_client = clients.llm_client
    embedder = clients.embedder

    def _entity_type_for(candidate: EntityNode) -> type[BaseModel] | None:
        # Look up the first label other than 'Entity'; with no such label the
        # lookup key is the empty string.
        if entity_types is None:
            return None
        first_custom_label = next(
            (label for label in candidate.labels if label != 'Entity'), ''
        )
        return entity_types.get(first_custom_label)

    extraction_jobs = [
        extract_attributes_from_node(
            llm_client,
            current,
            episode,
            previous_episodes,
            _entity_type_for(current),
            should_summarize_node,
        )
        for current in nodes
    ]
    updated_nodes: list[EntityNode] = await semaphore_gather(*extraction_jobs)

    await create_entity_node_embeddings(embedder, updated_nodes)
    return updated_nodes


async def extract_attributes_from_node(
    llm_client: LLMClient,
    node: EntityNode,
    episode: EpisodicNode | None = None,
    previous_episodes: list[EpisodicNode] | None = None,
    entity_type: type[BaseModel] | None = None,
    should_summarize_node: NodeSummaryFilter | None = None,
) -> EntityNode:
    """Refresh one node's attributes and summary, mutating it in place.

    Args:
        llm_client: Client used for both extraction calls.
        node: Node to update; the same object is returned.
        episode: Episode forwarded to both extraction helpers.
        previous_episodes: Earlier episodes forwarded to both helpers.
        entity_type: Attribute model forwarded to the attribute helper.
        should_summarize_node: Filter forwarded to the summary helper.

    Returns:
        ``node``, after its attributes (and possibly summary) were updated.
    """
    # Attributes are extracted first but merged last, so the summary step
    # below still sees the node's attributes as they were on entry.
    attribute_updates = await _extract_entity_attributes(
        llm_client, node, episode, previous_episodes, entity_type
    )

    await _extract_entity_summary(
        llm_client, node, episode, previous_episodes, should_summarize_node
    )

    node.attributes.update(attribute_updates)
    return node


async def _extract_entity_attributes(
    llm_client: LLMClient,
    node: EntityNode,
    episode: EpisodicNode | None,
    previous_episodes: list[EpisodicNode] | None,
    entity_type: type[BaseModel] | None,
) -> dict[str, Any]:
    """Ask the LLM for the attributes declared by ``entity_type``.

    Returns an empty dict without calling the LLM when no entity type is given
    or the type declares no fields. Otherwise the raw LLM response is returned
    after it has been used to instantiate ``entity_type`` as a validation step.
    """
    if entity_type is None:
        return {}
    if not entity_type.model_fields:
        return {}

    # The node summary is deliberately left out of the attribute prompt.
    node_snapshot = {
        'name': node.name,
        'entity_types': node.labels,
        'attributes': node.attributes,
    }
    prompt_context = _build_episode_context(
        node_data=node_snapshot,
        episode=episode,
        previous_episodes=previous_episodes,
    )
    prompt = prompt_library.extract_nodes.extract_attributes(prompt_context)

    extracted = await llm_client.generate_response(
        prompt,
        response_model=entity_type,
        model_size=ModelSize.small,
        group_id=node.group_id,
        prompt_name='extract_nodes.extract_attributes',
    )

    # Instantiating the model validates the response; the instance is discarded.
    entity_type(**extracted)

    return extracted


async def _extract_entity_summary(
    llm_client: LLMClient,
    node: EntityNode,
    episode: EpisodicNode | None,
    previous_episodes: list[EpisodicNode] | None,
    should_summarize_node: NodeSummaryFilter | None,
) -> None:
    """Regenerate ``node.summary`` with the LLM, in place.

    When ``should_summarize_node`` is supplied and evaluates falsy for this
    node, nothing happens. Both the summary sent to the LLM and the summary
    stored afterwards are truncated to ``MAX_SUMMARY_CHARS`` via
    ``truncate_at_sentence``.
    """
    if should_summarize_node is not None:
        wants_summary = await should_summarize_node(node)
        if not wants_summary:
            return

    node_snapshot = {
        'name': node.name,
        'summary': truncate_at_sentence(node.summary, MAX_SUMMARY_CHARS),
        'entity_types': node.labels,
        'attributes': node.attributes,
    }
    prompt_context = _build_episode_context(
        node_data=node_snapshot,
        episode=episode,
        previous_episodes=previous_episodes,
    )
    prompt = prompt_library.extract_nodes.extract_summary(prompt_context)

    summary_response = await llm_client.generate_response(
        prompt,
        response_model=EntitySummary,
        model_size=ModelSize.small,
        group_id=node.group_id,
        prompt_name='extract_nodes.extract_summary',
    )

    # A response without a 'summary' key yields an empty summary.
    fresh_summary = summary_response.get('summary', '')
    node.summary = truncate_at_sentence(fresh_summary, MAX_SUMMARY_CHARS)


def _build_episode_context(
    node_data: dict[str, Any],
    episode: EpisodicNode | None,
    previous_episodes: list[EpisodicNode] | None,
) -> dict[str, Any]:
    """Assemble the prompt context shared by the node extraction prompts.

    Args:
        node_data: Node payload, stored unchanged under ``'node'``.
        episode: Current episode; its ``content`` is used, or ``''`` if None.
        previous_episodes: Earlier episodes; their ``content`` values are
            listed in order, or an empty list is used if None.

    Returns:
        A dict with the keys ``'node'``, ``'episode_content'`` and
        ``'previous_episodes'``.
    """
    episode_text = '' if episode is None else episode.content

    if previous_episodes is None:
        prior_contents = []
    else:
        prior_contents = [prior.content for prior in previous_episodes]

    return {
        'node': node_data,
        'episode_content': episode_text,
        'previous_episodes': prior_contents,
    }

```

--------------------------------------------------------------------------------
/mcp_server/tests/test_comprehensive_integration.py:
--------------------------------------------------------------------------------

```python
#!/usr/bin/env python3
"""
Comprehensive integration test suite for Graphiti MCP Server.
Covers all MCP tools with consideration for LLM inference latency.
"""

import asyncio
import json
import os
import time
from dataclasses import dataclass
from typing import Any

import pytest
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client


@dataclass
class TestMetrics:
    """Timing and outcome record for a single tool call made by the tests."""

    # The class name matches pytest's default ``Test*`` collection pattern, and
    # a dataclass always has an ``__init__``, so pytest would try to collect it
    # and emit a PytestCollectionWarning. Opt out explicitly. Being
    # unannotated, this is a plain class attribute, not a dataclass field.
    __test__ = False

    # Label for the measured operation.
    operation: str
    # Timestamp taken before the operation started, in seconds.
    start_time: float
    # Timestamp taken after the operation finished, in seconds.
    end_time: float
    # Whether the operation was considered successful.
    success: bool
    # Free-form extra data about the call (result or error information).
    details: dict[str, Any]

    @property
    def duration(self) -> float:
        """Return the elapsed time in seconds (``end_time - start_time``)."""
        return self.end_time - self.start_time


class GraphitiTestClient:
    """Async context manager that talks to the Graphiti MCP server over stdio.

    Entering the context starts the server through ``stdio_client`` and opens
    a ``ClientSession``; leaving it tears both down. Every call made through
    ``call_tool_with_metrics`` is recorded in ``self.metrics``.
    """

    def __init__(self, test_group_id: str | None = None):
        """Create a client; ``test_group_id`` defaults to ``test_<unix time>``."""
        self.test_group_id = test_group_id or f'test_{int(time.time())}'
        self.session = None
        self.metrics: list[TestMetrics] = []
        self.default_timeout = 30  # seconds

    async def __aenter__(self):
        """Start the server process and initialize the MCP client session."""
        server_params = StdioServerParameters(
            command='uv',
            args=['run', '../main.py', '--transport', 'stdio'],
            env={
                'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
                'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
                'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
                'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY', 'test_key_for_mock'),
                'FALKORDB_URI': os.environ.get('FALKORDB_URI', 'redis://localhost:6379'),
            },
        )

        self.client_context = stdio_client(server_params)
        read, write = await self.client_context.__aenter__()
        self.session = ClientSession(read, write)
        # ClientSession is an async context manager: entering it starts the
        # loop that reads server responses. Without it, initialize() would
        # wait for a reply that is never processed.
        await self.session.__aenter__()
        await self.session.initialize()

        # Wait for server to be fully ready
        await asyncio.sleep(2)

        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the session, then the stdio transport, even if the first step fails."""
        try:
            if self.session is not None:
                # ClientSession has no close(); it is shut down by leaving
                # its context.
                await self.session.__aexit__(exc_type, exc_val, exc_tb)
        finally:
            if hasattr(self, 'client_context'):
                await self.client_context.__aexit__(exc_type, exc_val, exc_tb)

    async def call_tool_with_metrics(
        self, tool_name: str, arguments: dict[str, Any], timeout: float | None = None
    ) -> tuple[Any, TestMetrics]:
        """Call a tool, record a ``TestMetrics`` entry, and return both.

        Args:
            tool_name: Name of the MCP tool to call.
            arguments: Arguments passed to the tool.
            timeout: Seconds to wait; ``None`` means ``self.default_timeout``.

        Returns:
            ``(content, metric)``. ``content`` is the text of the first content
            item of the result (``None`` when there is none, or when the call
            timed out or raised). ``metric.success`` is False on timeout, on
            any exception, and when the result is flagged ``isError``; in all
            three cases ``metric.details`` carries an ``'error'`` key.
        """
        start_time = time.time()
        if timeout is None:
            timeout = self.default_timeout

        try:
            result = await asyncio.wait_for(
                self.session.call_tool(tool_name, arguments), timeout=timeout
            )

            first_item = result.content[0] if result.content else None
            # Non-text content items have no ``text``; report None for those.
            content = getattr(first_item, 'text', None)

            # A tool failure is reported in-band through ``isError`` rather
            # than as an exception, so it must be checked explicitly.
            if getattr(result, 'isError', False):
                success = False
                details = {'error': content, 'tool': tool_name}
            else:
                success = True
                details = {'result': content, 'tool': tool_name}

        except asyncio.TimeoutError:
            content = None
            success = False
            details = {'error': f'Timeout after {timeout}s', 'tool': tool_name}

        except Exception as e:
            content = None
            success = False
            details = {'error': str(e), 'tool': tool_name}

        end_time = time.time()
        metric = TestMetrics(
            operation=f'call_{tool_name}',
            start_time=start_time,
            end_time=end_time,
            success=success,
            details=details,
        )
        self.metrics.append(metric)

        return content, metric

    async def wait_for_episode_processing(
        self, expected_count: int = 1, max_wait: int = 60, poll_interval: int = 2
    ) -> bool:
        """
        Poll ``get_episodes`` until enough episodes exist for the test group.

        Each poll goes through ``call_tool_with_metrics`` and is therefore also
        recorded in ``self.metrics``.

        Args:
            expected_count: Number of episodes expected to be processed
            max_wait: Maximum seconds to wait
            poll_interval: Seconds between status checks

        Returns:
            True if at least ``expected_count`` episodes were reported before
            ``max_wait`` elapsed, False otherwise
        """
        # A monotonic clock keeps the deadline immune to wall-clock changes.
        deadline = time.monotonic() + max_wait

        while time.monotonic() < deadline:
            result, _ = await self.call_tool_with_metrics(
                'get_episodes', {'group_id': self.test_group_id, 'last_n': 100}
            )

            if result:
                try:
                    payload = json.loads(result) if isinstance(result, str) else result
                except json.JSONDecodeError:
                    # Best effort: an unparsable reply just means "poll again".
                    payload = None

                if isinstance(payload, dict):
                    episodes = payload.get('episodes') or []
                    if len(episodes) >= expected_count:
                        return True

            await asyncio.sleep(poll_interval)

        return False


class TestCoreOperations:
    """Smoke tests for tool discovery and the ``add_memory`` tool."""

    @pytest.mark.asyncio
    async def test_server_initialization(self):
        """The server must advertise every tool the suite relies on."""
        required_tools = {
            'add_memory',
            'search_memory_nodes',
            'search_memory_facts',
            'get_episodes',
            'delete_episode',
            'delete_entity_edge',
            'get_entity_edge',
            'clear_graph',
            'get_status',
        }

        async with GraphitiTestClient() as client:
            listing = await client.session.list_tools()
            advertised = {entry.name for entry in listing.tools}

            missing_tools = required_tools.difference(advertised)
            assert not missing_tools, f'Missing required tools: {missing_tools}'

    @pytest.mark.asyncio
    async def test_add_text_memory(self):
        """A plain-text episode is queued and later shows up as processed."""
        async with GraphitiTestClient() as client:
            episode = {
                'name': 'Tech Conference Notes',
                'episode_body': 'The AI conference featured talks on LLMs, RAG systems, and knowledge graphs. Notable speakers included researchers from OpenAI and Anthropic.',
                'source': 'text',
                'source_description': 'conference notes',
                'group_id': client.test_group_id,
            }
            response_text, call_metric = await client.call_tool_with_metrics(
                'add_memory', episode
            )

            assert call_metric.success, f'Failed to add memory: {call_metric.details}'
            assert 'queued' in str(response_text).lower()

            # Processing happens after the call returns, so poll for it.
            was_processed = await client.wait_for_episode_processing(expected_count=1)
            assert was_processed, 'Episode was not processed within timeout'

    @pytest.mark.asyncio
    async def test_add_json_memory(self):
        """A JSON-encoded episode body is accepted and queued."""
        async with GraphitiTestClient() as client:
            project_info = {
                'name': 'GraphitiDB',
                'version': '2.0.0',
                'features': ['temporal-awareness', 'hybrid-search', 'custom-entities'],
            }
            team_info = {'size': 5, 'roles': ['engineering', 'product', 'research']}
            structured_body = json.dumps({'project': project_info, 'team': team_info})

            episode = {
                'name': 'Project Data',
                'episode_body': structured_body,
                'source': 'json',
                'source_description': 'project database',
                'group_id': client.test_group_id,
            }
            response_text, call_metric = await client.call_tool_with_metrics(
                'add_memory', episode
            )

            assert call_metric.success
            assert 'queued' in str(response_text).lower()

    @pytest.mark.asyncio
    async def test_add_message_memory(self):
        """A conversation transcript is accepted quickly as a message episode."""
        async with GraphitiTestClient() as client:
            conversation = """
            user: What are the key features of Graphiti?
            assistant: Graphiti offers temporal-aware knowledge graphs, hybrid retrieval, and real-time updates.
            user: How does it handle entity resolution?
            assistant: It uses LLM-based entity extraction and deduplication with semantic similarity matching.
            """

            episode = {
                'name': 'Feature Discussion',
                'episode_body': conversation,
                'source': 'message',
                'source_description': 'support chat',
                'group_id': client.test_group_id,
            }
            _, call_metric = await client.call_tool_with_metrics('add_memory', episode)

            assert call_metric.success
            assert call_metric.duration < 5, (
                f'Add memory took too long: {call_metric.duration}s'
            )


class TestSearchOperations:
    """Tests for the node and fact search tools."""

    @pytest.mark.asyncio
    async def test_search_nodes_semantic(self):
        """Node search succeeds and returns content after seeding one episode."""
        async with GraphitiTestClient() as client:
            seed_episode = {
                'name': 'Product Launch',
                'episode_body': 'Our new AI assistant product launches in Q2 2024 with advanced NLP capabilities.',
                'source': 'text',
                'source_description': 'product roadmap',
                'group_id': client.test_group_id,
            }
            await client.call_tool_with_metrics('add_memory', seed_episode)

            # Give the server a chance to process the seeded episode.
            await client.wait_for_episode_processing()

            node_query = {
                'query': 'AI product features',
                'group_id': client.test_group_id,
                'limit': 10,
            }
            found, search_metric = await client.call_tool_with_metrics(
                'search_memory_nodes', node_query
            )

            assert search_metric.success
            assert found is not None

    @pytest.mark.asyncio
    async def test_search_facts_with_filters(self):
        """Fact search accepts a ``created_after`` filter alongside the query."""
        async with GraphitiTestClient() as client:
            seed_episode = {
                'name': 'Company Facts',
                'episode_body': 'Acme Corp was founded in 2020. They have 50 employees and $10M in revenue.',
                'source': 'text',
                'source_description': 'company profile',
                'group_id': client.test_group_id,
            }
            await client.call_tool_with_metrics('add_memory', seed_episode)

            await client.wait_for_episode_processing()

            # Date-filtered fact query.
            fact_query = {
                'query': 'company information',
                'group_id': client.test_group_id,
                'created_after': '2020-01-01T00:00:00Z',
                'limit': 20,
            }
            _, search_metric = await client.call_tool_with_metrics(
                'search_memory_facts', fact_query
            )

            assert search_metric.success

    @pytest.mark.asyncio
    async def test_hybrid_search(self):
        """Node search succeeds over two seeded documentation episodes."""
        documents = (
            (
                'Technical Doc',
                'GraphQL API endpoints support pagination, filtering, and real-time subscriptions.',
            ),
            (
                'Architecture',
                'The system uses Neo4j for graph storage and OpenAI embeddings for semantic search.',
            ),
        )

        async with GraphitiTestClient() as client:
            for title, body in documents:
                episode = {
                    'name': title,
                    'episode_body': body,
                    'source': 'text',
                    'group_id': client.test_group_id,
                    'source_description': 'documentation',
                }
                await client.call_tool_with_metrics('add_memory', episode)

            await client.wait_for_episode_processing(expected_count=2)

            # Query that should exercise both semantic and keyword matching.
            node_query = {
                'query': 'Neo4j graph database',
                'group_id': client.test_group_id,
                'limit': 10,
            }
            _, search_metric = await client.call_tool_with_metrics(
                'search_memory_nodes', node_query
            )

            assert search_metric.success


class TestEpisodeManagement:
    """Tests for listing and deleting episodes."""

    @pytest.mark.asyncio
    async def test_get_episodes_pagination(self):
        """``get_episodes`` returns no more than ``last_n`` episodes."""
        async with GraphitiTestClient() as client:
            # Seed five episodes so that a limit of three is meaningful.
            for index in range(5):
                episode = {
                    'name': f'Episode {index}',
                    'episode_body': f'This is test episode number {index}',
                    'source': 'text',
                    'source_description': 'test',
                    'group_id': client.test_group_id,
                }
                await client.call_tool_with_metrics('add_memory', episode)

            await client.wait_for_episode_processing(expected_count=5)

            raw, list_metric = await client.call_tool_with_metrics(
                'get_episodes', {'group_id': client.test_group_id, 'last_n': 3}
            )

            assert list_metric.success
            if isinstance(raw, str):
                listing = json.loads(raw)
            else:
                listing = raw
            assert len(listing.get('episodes', [])) <= 3

    @pytest.mark.asyncio
    async def test_delete_episode(self):
        """An episode can be deleted by the uuid reported by ``get_episodes``."""
        async with GraphitiTestClient() as client:
            doomed_episode = {
                'name': 'To Delete',
                'episode_body': 'This episode will be deleted',
                'source': 'text',
                'source_description': 'test',
                'group_id': client.test_group_id,
            }
            await client.call_tool_with_metrics('add_memory', doomed_episode)

            await client.wait_for_episode_processing()

            # Look up the uuid of the most recent episode.
            raw, _ = await client.call_tool_with_metrics(
                'get_episodes', {'group_id': client.test_group_id, 'last_n': 1}
            )
            if isinstance(raw, str):
                listing = json.loads(raw)
            else:
                listing = raw
            episode_uuid = listing['episodes'][0]['uuid']

            reply, delete_metric = await client.call_tool_with_metrics(
                'delete_episode', {'episode_uuid': episode_uuid}
            )

            assert delete_metric.success
            assert 'deleted' in str(reply).lower()


class TestEntityAndEdgeOperations:
    """Tests around entity edges; currently only exercises the setup path."""

    @pytest.mark.asyncio
    async def test_get_entity_edge(self):
        """Seed relationship data and run a node search (no assertions yet)."""
        async with GraphitiTestClient() as client:
            # Text that should give rise to entities and edges between them.
            relationship_episode = {
                'name': 'Relationship Data',
                'episode_body': 'Alice works at TechCorp. Bob is the CEO of TechCorp.',
                'source': 'text',
                'source_description': 'org chart',
                'group_id': client.test_group_id,
            }
            await client.call_tool_with_metrics('add_memory', relationship_episode)

            await client.wait_for_episode_processing()

            # Node search that would supply UUIDs for an edge lookup.
            node_query = {'query': 'TechCorp', 'group_id': client.test_group_id, 'limit': 5}
            _search_result, _ = await client.call_tool_with_metrics(
                'search_memory_nodes', node_query
            )

            # Note: This test assumes edges are created between entities
            # Actual edge retrieval would require valid edge UUIDs

    @pytest.mark.asyncio
    async def test_delete_entity_edge(self):
        """Placeholder for edge deletion; intentionally does nothing yet."""
        # Would mirror test_get_entity_edge, followed by a deletion.
        # Implement based on actual edge creation patterns.
        return None


class TestErrorHandling:
    """Test error conditions and edge cases."""

    @pytest.mark.asyncio
    async def test_invalid_tool_arguments(self):
        """A call that omits required arguments must be reported as a failure."""
        async with GraphitiTestClient() as client:
            # Missing required arguments
            result, metric = await client.call_tool_with_metrics(
                'add_memory',
                {'name': 'Incomplete'},  # Missing required fields
            )

            assert not metric.success
            assert 'error' in str(metric.details).lower()

    @pytest.mark.asyncio
    async def test_timeout_handling(self):
        """If a large request fails under a short timeout, it fails as a timeout."""
        async with GraphitiTestClient() as client:
            # Simulate a very large episode that might time out
            large_text = 'Large document content. ' * 10000

            result, metric = await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Large Document',
                    'episode_body': large_text,
                    'source': 'text',
                    'source_description': 'large file',
                    'group_id': client.test_group_id,
                },
                timeout=5,  # Short timeout
            )

            # The call may legitimately succeed; only a failure is inspected.
            if not metric.success:
                assert 'timeout' in str(metric.details).lower()

    @pytest.mark.asyncio
    async def test_concurrent_operations(self):
        """At least three of five concurrent ``add_memory`` calls must succeed."""
        async with GraphitiTestClient() as client:
            # Launch multiple operations concurrently
            tasks = []
            for i in range(5):
                task = client.call_tool_with_metrics(
                    'add_memory',
                    {
                        'name': f'Concurrent {i}',
                        'episode_body': f'Concurrent operation {i}',
                        'source': 'text',
                        'source_description': 'concurrent test',
                        'group_id': client.test_group_id,
                    },
                )
                tasks.append(task)

            outcomes = await asyncio.gather(*tasks, return_exceptions=True)

            # With return_exceptions=True an entry may be an exception object
            # instead of a (content, metric) tuple. Count those as failures
            # rather than crashing on tuple unpacking.
            successful = sum(
                1
                for outcome in outcomes
                if not isinstance(outcome, BaseException) and outcome[1].success
            )
            assert successful >= 3  # At least 60% should succeed


class TestPerformance:
    """Latency and throughput checks against the running server."""

    @pytest.mark.asyncio
    async def test_latency_metrics(self):
        """Time one call per tool and enforce per-tool latency budgets."""
        # Upper bounds in seconds; tools not listed here are timed only.
        latency_budget = {'get_episodes': 2, 'search_memory_nodes': 10}

        async with GraphitiTestClient() as client:
            group = client.test_group_id
            add_args = {
                'name': 'Perf Test',
                'episode_body': 'Simple text',
                'source': 'text',
                'source_description': 'test',
                'group_id': group,
            }
            search_args = {'query': 'test', 'group_id': group, 'limit': 10}
            list_args = {'group_id': group, 'last_n': 10}
            operations = [
                ('add_memory', add_args),
                ('search_memory_nodes', search_args),
                ('get_episodes', list_args),
            ]

            for tool_name, args in operations:
                _, metric = await client.call_tool_with_metrics(tool_name, args)

                # Report the measured duration for this tool.
                print(f'{tool_name}: {metric.duration:.2f}s')

                limit = latency_budget.get(tool_name)
                if limit is not None:
                    assert metric.duration < limit, f'{tool_name} too slow'

    @pytest.mark.asyncio
    async def test_batch_processing_efficiency(self):
        """Queue ten episodes back to back and bound the average time per item."""
        async with GraphitiTestClient() as client:
            batch_size = 10
            started_at = time.time()

            for item in range(batch_size):
                episode = {
                    'name': f'Batch {item}',
                    'episode_body': f'Batch content {item}',
                    'source': 'text',
                    'source_description': 'batch test',
                    'group_id': client.test_group_id,
                }
                await client.call_tool_with_metrics('add_memory', episode)

            # A whole batch gets a longer processing window than the default.
            processed = await client.wait_for_episode_processing(
                expected_count=batch_size,
                max_wait=120,
            )

            total_time = time.time() - started_at
            avg_time_per_item = total_time / batch_size

            assert processed, f'Failed to process {batch_size} items'
            assert avg_time_per_item < 15, (
                f'Batch processing too slow: {avg_time_per_item:.2f}s per item'
            )

            # Summary of the batch run.
            print('\nBatch Performance Report:')
            print(f'  Total items: {batch_size}')
            print(f'  Total time: {total_time:.2f}s')
            print(f'  Avg per item: {avg_time_per_item:.2f}s')


class TestDatabaseBackends:
    """Placeholder tests for running against different database backends."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize('database', ['neo4j', 'falkordb'])
    async def test_database_operations(self, database):
        """Build the environment for ``database``; nothing is exercised yet."""
        env_vars = {
            'DATABASE_PROVIDER': database,
            'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY'),
        }

        if database == 'neo4j':
            env_vars['NEO4J_URI'] = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
            env_vars['NEO4J_USER'] = os.environ.get('NEO4J_USER', 'neo4j')
            env_vars['NEO4J_PASSWORD'] = os.environ.get('NEO4J_PASSWORD', 'graphiti')
        elif database == 'falkordb':
            env_vars.update(
                {'FALKORDB_URI': os.environ.get('FALKORDB_URI', 'redis://localhost:6379')}
            )

        # Placeholder for database-specific tests: they would require starting
        # the server with this environment, which depends on the database
        # being available.


def generate_test_report(client: GraphitiTestClient) -> str:
    """Render the metrics collected by ``client`` as a plain-text report.

    The report has a summary (total calls, success rate, average duration), a
    per-operation breakdown sorted by operation name, and the five slowest
    calls.

    Args:
        client: Object whose ``metrics`` list holds the recorded calls.

    Returns:
        The report text, or ``'No metrics collected'`` when there are none.
    """
    collected = client.metrics
    if not collected:
        return 'No metrics collected'

    rule = '=' * 60

    # Summary statistics.
    total_ops = len(collected)
    successful_ops = len([entry for entry in collected if entry.success])
    avg_duration = sum(entry.duration for entry in collected) / total_ops

    lines = [
        '\n' + rule,
        'GRAPHITI MCP TEST REPORT',
        rule,
        f'\nTotal Operations: {total_ops}',
        f'Successful: {successful_ops} ({successful_ops / total_ops * 100:.1f}%)',
        f'Average Duration: {avg_duration:.2f}s',
        '\nOperation Breakdown:',
    ]

    # Aggregate call count, success count and total time per operation.
    per_operation = {}
    for entry in collected:
        bucket = per_operation.setdefault(
            entry.operation, {'count': 0, 'success': 0, 'total_duration': 0}
        )
        bucket['count'] += 1
        if entry.success:
            bucket['success'] += 1
        bucket['total_duration'] += entry.duration

    for name in sorted(per_operation):
        bucket = per_operation[name]
        calls = bucket['count']
        avg_dur = bucket['total_duration'] / calls
        success_rate = bucket['success'] / calls * 100
        lines.append(f'  {name}: {calls} calls, {success_rate:.0f}% success, {avg_dur:.2f}s avg')

    # Five longest calls, slowest first.
    lines.append('\nSlowest Operations:')
    by_duration = sorted(collected, key=lambda entry: entry.duration, reverse=True)
    lines.extend(f'  {entry.operation}: {entry.duration:.2f}s' for entry in by_duration[:5])

    lines.append(rule)
    return '\n'.join(lines)


if __name__ == '__main__':
    # Run this file's tests with pytest and propagate pytest's exit code, so
    # that a failing run does not exit with status 0 when invoked as a script.
    raise SystemExit(pytest.main([__file__, '-v', '--asyncio-mode=auto']))

```

--------------------------------------------------------------------------------
/graphiti_core/edges.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from time import time
from typing import Any
from uuid import uuid4

from pydantic import BaseModel, Field
from typing_extensions import LiteralString

from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import parse_db_date
from graphiti_core.models.edges.edge_db_queries import (
    COMMUNITY_EDGE_RETURN,
    EPISODIC_EDGE_RETURN,
    EPISODIC_EDGE_SAVE,
    HAS_EPISODE_EDGE_RETURN,
    HAS_EPISODE_EDGE_SAVE,
    NEXT_EPISODE_EDGE_RETURN,
    NEXT_EPISODE_EDGE_SAVE,
    get_community_edge_save_query,
    get_entity_edge_return_query,
    get_entity_edge_save_query,
)
from graphiti_core.nodes import Node

logger = logging.getLogger(__name__)


class Edge(BaseModel, ABC):
    """Abstract base model shared by every edge type in this module.

    Holds the identifying fields common to all edges and provides the generic
    delete operations. Subclasses must implement `save`. Equality and hashing
    are based solely on `uuid`.
    """

    uuid: str = Field(default_factory=lambda: str(uuid4()))
    group_id: str = Field(description='partition of the graph')
    source_node_uuid: str
    target_node_uuid: str
    created_at: datetime

    @abstractmethod
    async def save(self, driver: GraphDriver): ...

    async def delete(self, driver: GraphDriver):
        """Delete this edge from the graph, matching it by `uuid`.

        If the driver exposes a `graph_operations_interface`, the deletion is
        delegated to it and its result is returned. Otherwise a provider-specific
        Cypher query is executed.

        Args:
            driver: Graph driver used to run the delete queries.
        """
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.edge_delete(self, driver)

        if driver.provider == GraphProvider.KUZU:
            # On Kuzu the entity relation is matched as a `RelatesToNode_` node
            # rather than a RELATES_TO relationship, so it needs its own query.
            await driver.execute_query(
                """
                MATCH (n)-[e:MENTIONS|HAS_MEMBER {uuid: $uuid}]->(m)
                DELETE e
                """,
                uuid=self.uuid,
            )
            await driver.execute_query(
                """
                MATCH (e:RelatesToNode_ {uuid: $uuid})
                DETACH DELETE e
                """,
                uuid=self.uuid,
            )
        else:
            await driver.execute_query(
                """
                MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
                DELETE e
                """,
                uuid=self.uuid,
            )

        logger.debug(f'Deleted Edge: {self.uuid}')

    @classmethod
    async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        """Delete every edge whose `uuid` is contained in `uuids`.

        Mirrors `delete`, but matches a list of uuids in one query per
        relationship shape.

        Args:
            driver: Graph driver used to run the delete queries.
            uuids: UUIDs of the edges to remove.
        """
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.edge_delete_by_uuids(cls, driver, uuids)

        if driver.provider == GraphProvider.KUZU:
            await driver.execute_query(
                """
                MATCH (n)-[e:MENTIONS|HAS_MEMBER]->(m)
                WHERE e.uuid IN $uuids
                DELETE e
                """,
                uuids=uuids,
            )
            await driver.execute_query(
                """
                MATCH (e:RelatesToNode_)
                WHERE e.uuid IN $uuids
                DETACH DELETE e
                """,
                uuids=uuids,
            )
        else:
            await driver.execute_query(
                """
                MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
                WHERE e.uuid IN $uuids
                DELETE e
                """,
                uuids=uuids,
            )

        logger.debug(f'Deleted Edges: {uuids}')

    def __hash__(self):
        # Hash on uuid only, consistent with __eq__ below.
        return hash(self.uuid)

    def __eq__(self, other):
        # Two edges are equal when they share a uuid. The comparison must be
        # against Edge (not Node); otherwise distinct Edge instances carrying
        # the same uuid would never compare equal.
        if isinstance(other, Edge):
            return self.uuid == other.uuid
        return False

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...


class EpisodicEdge(Edge):
    """MENTIONS edge from an Episodic node to an Entity node."""

    async def save(self, driver: GraphDriver):
        """Persist this edge with the EPISODIC_EDGE_SAVE query.

        Args:
            driver: Graph driver used to execute the query.

        Returns:
            Whatever `driver.execute_query` returns for the save query.
        """
        result = await driver.execute_query(
            EPISODIC_EDGE_SAVE,
            episode_uuid=self.source_node_uuid,
            entity_uuid=self.target_node_uuid,
            uuid=self.uuid,
            group_id=self.group_id,
            created_at=self.created_at,
        )

        logger.debug(f'Saved edge to Graph: {self.uuid}')

        return result

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        """Fetch a single MENTIONS edge by uuid.

        Raises:
            EdgeNotFoundError: If no edge with this uuid exists.
        """
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
            RETURN
            """
            + EPISODIC_EDGE_RETURN,
            uuid=uuid,
            routing_='r',
        )

        edges = [get_episodic_edge_from_record(record) for record in records]

        if len(edges) == 0:
            raise EdgeNotFoundError(uuid)
        return edges[0]

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        """Fetch all MENTIONS edges whose uuid is in `uuids`.

        Args:
            driver: Graph driver used to execute the query.
            uuids: UUIDs to look up. An empty list yields an empty result
                without querying the database.

        Raises:
            EdgeNotFoundError: If `uuids` is non-empty and no edge matched.
        """
        # Nothing requested means nothing can be missing; also avoids the
        # IndexError that `uuids[0]` below would raise on an empty list.
        if len(uuids) == 0:
            return []

        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
            WHERE e.uuid IN $uuids
            RETURN
            """
            + EPISODIC_EDGE_RETURN,
            uuids=uuids,
            routing_='r',
        )

        edges = [get_episodic_edge_from_record(record) for record in records]

        if len(edges) == 0:
            raise EdgeNotFoundError(uuids[0])
        return edges

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
    ):
        """Fetch MENTIONS edges belonging to any of `group_ids`.

        Results are ordered by uuid descending. When `uuid_cursor` is given,
        only edges with a uuid smaller than the cursor are returned, which
        allows paging through the result set together with `limit`.

        Raises:
            GroupsEdgesNotFoundError: If no edge matched.
        """
        cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''

        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
            WHERE e.group_id IN $group_ids
            """
            + cursor_query
            + """
            RETURN
            """
            + EPISODIC_EDGE_RETURN
            + """
            ORDER BY e.uuid DESC
            """
            + limit_query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )

        edges = [get_episodic_edge_from_record(record) for record in records]

        if len(edges) == 0:
            raise GroupsEdgesNotFoundError(group_ids)
        return edges


class EntityEdge(Edge):
    """RELATES_TO edge between two Entity nodes, carrying a fact.

    On the Kuzu provider the relation is matched through an intermediate
    `RelatesToNode_` node, so most queries below have a Kuzu-specific variant.
    """

    name: str = Field(description='name of the edge, relation name')
    fact: str = Field(description='fact representing the edge and nodes that it connects')
    fact_embedding: list[float] | None = Field(default=None, description='embedding of the fact')
    episodes: list[str] = Field(
        default=[],
        description='list of episode ids that reference these entity edges',
    )
    expired_at: datetime | None = Field(
        default=None, description='datetime of when the node was invalidated'
    )
    valid_at: datetime | None = Field(
        default=None, description='datetime of when the fact became true'
    )
    invalid_at: datetime | None = Field(
        default=None, description='datetime of when the fact stopped being true'
    )
    attributes: dict[str, Any] = Field(
        default={}, description='Additional attributes of the edge. Dependent on edge name'
    )

    async def generate_embedding(self, embedder: EmbedderClient):
        """Embed `fact` (newlines replaced by spaces) and store it on the edge.

        Args:
            embedder: Client whose `create` method produces the embedding.

        Returns:
            The value stored in `fact_embedding`.
        """
        start = time()

        text = self.fact.replace('\n', ' ')
        self.fact_embedding = await embedder.create(input_data=[text])

        end = time()
        # time() yields seconds; convert so the logged unit really is ms.
        logger.debug(f'embedded {text} in {(end - start) * 1000} ms')

        return self.fact_embedding

    async def load_fact_embedding(self, driver: GraphDriver):
        """Load `fact_embedding` for this edge from the graph.

        Delegates to `driver.graph_operations_interface` when present.

        Raises:
            EdgeNotFoundError: If no edge with this uuid exists.
        """
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.edge_load_embeddings(self, driver)

        query = """
            MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
            RETURN e.fact_embedding AS fact_embedding
        """

        if driver.provider == GraphProvider.NEPTUNE:
            # The Neptune query splits the stored value on commas and converts
            # each piece to a float.
            query = """
                MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
                RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
            """

        if driver.provider == GraphProvider.KUZU:
            query = """
                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
                RETURN e.fact_embedding AS fact_embedding
            """

        records, _, _ = await driver.execute_query(
            query,
            uuid=self.uuid,
            routing_='r',
        )

        if len(records) == 0:
            raise EdgeNotFoundError(self.uuid)

        self.fact_embedding = records[0]['fact_embedding']

    async def save(self, driver: GraphDriver):
        """Persist this edge using the provider-specific save query.

        For Kuzu, `attributes` is serialised to a JSON string and the fields
        are passed as individual query parameters. For other providers the
        attributes are merged into the same map as the core fields and passed
        as a single `edge_data` parameter.

        Returns:
            Whatever `driver.execute_query` returns for the save query.
        """
        edge_data: dict[str, Any] = {
            'source_uuid': self.source_node_uuid,
            'target_uuid': self.target_node_uuid,
            'uuid': self.uuid,
            'name': self.name,
            'group_id': self.group_id,
            'fact': self.fact,
            'fact_embedding': self.fact_embedding,
            'episodes': self.episodes,
            'created_at': self.created_at,
            'expired_at': self.expired_at,
            'valid_at': self.valid_at,
            'invalid_at': self.invalid_at,
        }

        if driver.provider == GraphProvider.KUZU:
            edge_data['attributes'] = json.dumps(self.attributes)
            result = await driver.execute_query(
                get_entity_edge_save_query(driver.provider),
                **edge_data,
            )
        else:
            # Custom attributes share one map with the core fields. Core fields
            # are applied last so an attribute that happens to use a reserved
            # key (e.g. `uuid`, `name`) cannot overwrite the real value.
            edge_data = {**(self.attributes or {}), **edge_data}
            result = await driver.execute_query(
                get_entity_edge_save_query(driver.provider),
                edge_data=edge_data,
            )

        logger.debug(f'Saved edge to Graph: {self.uuid}')

        return result

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        """Fetch a single entity edge by uuid.

        Raises:
            EdgeNotFoundError: If no edge with this uuid exists.
        """
        match_query = """
            MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
        """
        if driver.provider == GraphProvider.KUZU:
            match_query = """
                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
            """

        records, _, _ = await driver.execute_query(
            match_query
            + """
            RETURN
            """
            + get_entity_edge_return_query(driver.provider),
            uuid=uuid,
            routing_='r',
        )

        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

        if len(edges) == 0:
            raise EdgeNotFoundError(uuid)
        return edges[0]

    @classmethod
    async def get_between_nodes(
        cls, driver: GraphDriver, source_node_uuid: str, target_node_uuid: str
    ):
        """Fetch all entity edges directed from one entity node to another.

        Returns:
            A (possibly empty) list of matching edges.
        """
        match_query = """
            MATCH (n:Entity {uuid: $source_node_uuid})-[e:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
        """
        if driver.provider == GraphProvider.KUZU:
            match_query = """
                MATCH (n:Entity {uuid: $source_node_uuid})
                      -[:RELATES_TO]->(e:RelatesToNode_)
                      -[:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
            """

        records, _, _ = await driver.execute_query(
            match_query
            + """
            RETURN
            """
            + get_entity_edge_return_query(driver.provider),
            source_node_uuid=source_node_uuid,
            target_node_uuid=target_node_uuid,
            routing_='r',
        )

        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

        return edges

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        """Fetch all entity edges whose uuid is in `uuids`.

        Returns:
            A (possibly empty) list of matching edges; an empty `uuids` list
            short-circuits to `[]` without querying.
        """
        if len(uuids) == 0:
            return []

        match_query = """
            MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
        """
        if driver.provider == GraphProvider.KUZU:
            match_query = """
                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
            """

        records, _, _ = await driver.execute_query(
            match_query
            + """
            WHERE e.uuid IN $uuids
            RETURN
            """
            + get_entity_edge_return_query(driver.provider),
            uuids=uuids,
            routing_='r',
        )

        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

        return edges

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
        with_embeddings: bool = False,
    ):
        """Fetch entity edges belonging to any of `group_ids`.

        Results are ordered by uuid descending. When `uuid_cursor` is given,
        only edges with a uuid smaller than the cursor are returned. When
        `with_embeddings` is true, `e.fact_embedding` is added to the RETURN
        clause.

        Raises:
            GroupsEdgesNotFoundError: If no edge matched.
        """
        cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
        with_embeddings_query: LiteralString = (
            """,
                e.fact_embedding AS fact_embedding
                """
            if with_embeddings
            else ''
        )

        match_query = """
            MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
        """
        if driver.provider == GraphProvider.KUZU:
            match_query = """
                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
            """

        records, _, _ = await driver.execute_query(
            match_query
            + """
            WHERE e.group_id IN $group_ids
            """
            + cursor_query
            + """
            RETURN
            """
            + get_entity_edge_return_query(driver.provider)
            + with_embeddings_query
            + """
            ORDER BY e.uuid DESC
            """
            + limit_query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )

        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

        if len(edges) == 0:
            raise GroupsEdgesNotFoundError(group_ids)
        return edges

    @classmethod
    async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
        """Fetch entity edges attached to the entity node `node_uuid`.

        Returns:
            A (possibly empty) list of matching edges.
        """
        # NOTE(review): the default pattern is undirected (`-[e]-`) while the
        # Kuzu pattern only follows outgoing relations -- confirm this
        # difference is intentional.
        match_query = """
            MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
        """
        if driver.provider == GraphProvider.KUZU:
            match_query = """
                MATCH (n:Entity {uuid: $node_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
            """

        records, _, _ = await driver.execute_query(
            match_query
            + """
            RETURN
            """
            + get_entity_edge_return_query(driver.provider),
            node_uuid=node_uuid,
            routing_='r',
        )

        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

        return edges


class CommunityEdge(Edge):
    """HAS_MEMBER edge from a Community node to one of its members."""

    async def save(self, driver: GraphDriver):
        """Persist this edge using the provider-specific community save query.

        Returns:
            Whatever `driver.execute_query` returns for the save query.
        """
        result = await driver.execute_query(
            get_community_edge_save_query(driver.provider),
            community_uuid=self.source_node_uuid,
            entity_uuid=self.target_node_uuid,
            uuid=self.uuid,
            group_id=self.group_id,
            created_at=self.created_at,
        )

        logger.debug(f'Saved edge to Graph: {self.uuid}')

        return result

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        """Fetch a single HAS_MEMBER edge by uuid.

        Raises:
            EdgeNotFoundError: If no edge with this uuid exists.
        """
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m)
            RETURN
            """
            + COMMUNITY_EDGE_RETURN,
            uuid=uuid,
            routing_='r',
        )

        edges = [get_community_edge_from_record(record) for record in records]

        # Raise the same domain error as the other edge classes instead of
        # letting `edges[0]` fail with a bare IndexError.
        if len(edges) == 0:
            raise EdgeNotFoundError(uuid)
        return edges[0]

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        """Fetch all HAS_MEMBER edges whose uuid is in `uuids`.

        Returns:
            A (possibly empty) list of matching edges.
        """
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Community)-[e:HAS_MEMBER]->(m)
            WHERE e.uuid IN $uuids
            RETURN
            """
            + COMMUNITY_EDGE_RETURN,
            uuids=uuids,
            routing_='r',
        )

        edges = [get_community_edge_from_record(record) for record in records]

        return edges

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
    ):
        """Fetch HAS_MEMBER edges belonging to any of `group_ids`.

        Results are ordered by uuid descending. When `uuid_cursor` is given,
        only edges with a uuid smaller than the cursor are returned.

        Returns:
            A (possibly empty) list of matching edges.
        """
        cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''

        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Community)-[e:HAS_MEMBER]->(m)
            WHERE e.group_id IN $group_ids
            """
            + cursor_query
            + """
            RETURN
            """
            + COMMUNITY_EDGE_RETURN
            + """
            ORDER BY e.uuid DESC
            """
            + limit_query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )

        edges = [get_community_edge_from_record(record) for record in records]

        return edges


class HasEpisodeEdge(Edge):
    """HAS_EPISODE edge from a Saga node to an Episodic node."""

    async def save(self, driver: GraphDriver):
        """Write the edge with HAS_EPISODE_EDGE_SAVE and return the driver result."""
        params = {
            'saga_uuid': self.source_node_uuid,
            'episode_uuid': self.target_node_uuid,
            'uuid': self.uuid,
            'group_id': self.group_id,
            'created_at': self.created_at,
        }
        outcome = await driver.execute_query(HAS_EPISODE_EDGE_SAVE, **params)

        logger.debug(f'Saved edge to Graph: {self.uuid}')

        return outcome

    async def delete(self, driver: GraphDriver):
        """Remove the HAS_EPISODE relationship carrying this edge's uuid."""
        delete_query = """
            MATCH (n:Saga)-[e:HAS_EPISODE {uuid: $uuid}]->(m:Episodic)
            DELETE e
            """
        await driver.execute_query(delete_query, uuid=self.uuid)

        logger.debug(f'Deleted Edge: {self.uuid}')

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        """Look up one HAS_EPISODE edge by uuid.

        Raises:
            EdgeNotFoundError: When the query returns no rows.
        """
        query = (
            """
            MATCH (n:Saga)-[e:HAS_EPISODE {uuid: $uuid}]->(m:Episodic)
            RETURN
            """
            + HAS_EPISODE_EDGE_RETURN
        )
        rows, _, _ = await driver.execute_query(query, uuid=uuid, routing_='r')

        matches = [get_has_episode_edge_from_record(row) for row in rows]
        if not matches:
            raise EdgeNotFoundError(uuid)
        return matches[0]

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        """Return every HAS_EPISODE edge whose uuid appears in `uuids` (may be empty)."""
        query = (
            """
            MATCH (n:Saga)-[e:HAS_EPISODE]->(m:Episodic)
            WHERE e.uuid IN $uuids
            RETURN
            """
            + HAS_EPISODE_EDGE_RETURN
        )
        rows, _, _ = await driver.execute_query(query, uuids=uuids, routing_='r')

        return [get_has_episode_edge_from_record(row) for row in rows]

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
    ):
        """Return HAS_EPISODE edges for the given groups, newest uuid first.

        `uuid_cursor` restricts results to uuids below the cursor and `limit`
        caps the number of rows; both are optional.
        """
        # Optional query fragments: only present when the caller asked for them.
        cursor_query: LiteralString = ''
        if uuid_cursor:
            cursor_query = 'AND e.uuid < $uuid'

        limit_query: LiteralString = ''
        if limit is not None:
            limit_query = 'LIMIT $limit'

        query = (
            """
            MATCH (n:Saga)-[e:HAS_EPISODE]->(m:Episodic)
            WHERE e.group_id IN $group_ids
            """
            + cursor_query
            + """
            RETURN
            """
            + HAS_EPISODE_EDGE_RETURN
            + """
            ORDER BY e.uuid DESC
            """
            + limit_query
        )
        rows, _, _ = await driver.execute_query(
            query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )

        return [get_has_episode_edge_from_record(row) for row in rows]


class NextEpisodeEdge(Edge):
    """NEXT_EPISODE edge from one Episodic node to another."""

    async def save(self, driver: GraphDriver):
        """Write the edge with NEXT_EPISODE_EDGE_SAVE and return the driver result."""
        params = {
            'source_episode_uuid': self.source_node_uuid,
            'target_episode_uuid': self.target_node_uuid,
            'uuid': self.uuid,
            'group_id': self.group_id,
            'created_at': self.created_at,
        }
        outcome = await driver.execute_query(NEXT_EPISODE_EDGE_SAVE, **params)

        logger.debug(f'Saved edge to Graph: {self.uuid}')

        return outcome

    async def delete(self, driver: GraphDriver):
        """Remove the NEXT_EPISODE relationship carrying this edge's uuid."""
        delete_query = """
            MATCH (n:Episodic)-[e:NEXT_EPISODE {uuid: $uuid}]->(m:Episodic)
            DELETE e
            """
        await driver.execute_query(delete_query, uuid=self.uuid)

        logger.debug(f'Deleted Edge: {self.uuid}')

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        """Look up one NEXT_EPISODE edge by uuid.

        Raises:
            EdgeNotFoundError: When the query returns no rows.
        """
        query = (
            """
            MATCH (n:Episodic)-[e:NEXT_EPISODE {uuid: $uuid}]->(m:Episodic)
            RETURN
            """
            + NEXT_EPISODE_EDGE_RETURN
        )
        rows, _, _ = await driver.execute_query(query, uuid=uuid, routing_='r')

        matches = [get_next_episode_edge_from_record(row) for row in rows]
        if not matches:
            raise EdgeNotFoundError(uuid)
        return matches[0]

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        """Return every NEXT_EPISODE edge whose uuid appears in `uuids` (may be empty)."""
        query = (
            """
            MATCH (n:Episodic)-[e:NEXT_EPISODE]->(m:Episodic)
            WHERE e.uuid IN $uuids
            RETURN
            """
            + NEXT_EPISODE_EDGE_RETURN
        )
        rows, _, _ = await driver.execute_query(query, uuids=uuids, routing_='r')

        return [get_next_episode_edge_from_record(row) for row in rows]

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
    ):
        """Return NEXT_EPISODE edges for the given groups, newest uuid first.

        `uuid_cursor` restricts results to uuids below the cursor and `limit`
        caps the number of rows; both are optional.
        """
        # Optional query fragments: only present when the caller asked for them.
        cursor_query: LiteralString = ''
        if uuid_cursor:
            cursor_query = 'AND e.uuid < $uuid'

        limit_query: LiteralString = ''
        if limit is not None:
            limit_query = 'LIMIT $limit'

        query = (
            """
            MATCH (n:Episodic)-[e:NEXT_EPISODE]->(m:Episodic)
            WHERE e.group_id IN $group_ids
            """
            + cursor_query
            + """
            RETURN
            """
            + NEXT_EPISODE_EDGE_RETURN
            + """
            ORDER BY e.uuid DESC
            """
            + limit_query
        )
        rows, _, _ = await driver.execute_query(
            query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )

        return [get_next_episode_edge_from_record(row) for row in rows]


# Edge helpers
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
    """Build an EpisodicEdge from a single query result record."""
    # Fields are read in declaration order; created_at goes through parse_db_date.
    fields = {
        'uuid': record['uuid'],
        'group_id': record['group_id'],
        'source_node_uuid': record['source_node_uuid'],
        'target_node_uuid': record['target_node_uuid'],
        'created_at': parse_db_date(record['created_at']),
    }
    return EpisodicEdge(**fields)  # type: ignore


# Keys that belong to the EntityEdge model itself. They are filtered out of the
# raw attribute map so only custom attributes remain in `EntityEdge.attributes`.
_ENTITY_EDGE_CORE_FIELDS = frozenset(
    {
        'uuid',
        'source_node_uuid',
        'target_node_uuid',
        'fact',
        'fact_embedding',
        'name',
        'group_id',
        'episodes',
        'created_at',
        'expired_at',
        'valid_at',
        'invalid_at',
    }
)


def get_entity_edge_from_record(record: Any, provider: GraphProvider) -> EntityEdge:
    """Build an EntityEdge from a single query result record.

    Args:
        record: Mapping-like query result holding the edge fields.
        provider: Graph provider the record came from. For Kuzu the
            `attributes` value is a JSON string; otherwise it is a map that
            may also contain the core edge fields.

    Returns:
        The reconstructed EntityEdge. `fact_embedding` is taken from the
        record when present, else left as None.
    """
    episodes = record['episodes']
    if provider == GraphProvider.KUZU:
        attributes = json.loads(record['attributes']) if record['attributes'] else {}
    else:
        # Build a filtered copy instead of popping keys in place, so the
        # caller's record is not mutated; a missing/None map becomes {}.
        raw_attributes = record['attributes'] or {}
        attributes = {
            key: value
            for key, value in raw_attributes.items()
            if key not in _ENTITY_EDGE_CORE_FIELDS
        }

    edge = EntityEdge(
        uuid=record['uuid'],
        source_node_uuid=record['source_node_uuid'],
        target_node_uuid=record['target_node_uuid'],
        fact=record['fact'],
        fact_embedding=record.get('fact_embedding'),
        name=record['name'],
        group_id=record['group_id'],
        episodes=episodes,
        created_at=parse_db_date(record['created_at']),  # type: ignore
        expired_at=parse_db_date(record['expired_at']),
        valid_at=parse_db_date(record['valid_at']),
        invalid_at=parse_db_date(record['invalid_at']),
        attributes=attributes,
    )

    return edge


def get_community_edge_from_record(record: Any):
    """Build a CommunityEdge from a single query result record."""
    # Fields are read in declaration order; created_at goes through parse_db_date.
    fields = {
        'uuid': record['uuid'],
        'group_id': record['group_id'],
        'source_node_uuid': record['source_node_uuid'],
        'target_node_uuid': record['target_node_uuid'],
        'created_at': parse_db_date(record['created_at']),
    }
    return CommunityEdge(**fields)  # type: ignore


def get_has_episode_edge_from_record(record: Any) -> HasEpisodeEdge:
    """Build a HasEpisodeEdge from a single query result record."""
    # Fields are read in declaration order; created_at goes through parse_db_date.
    fields = {
        'uuid': record['uuid'],
        'group_id': record['group_id'],
        'source_node_uuid': record['source_node_uuid'],
        'target_node_uuid': record['target_node_uuid'],
        'created_at': parse_db_date(record['created_at']),
    }
    return HasEpisodeEdge(**fields)  # type: ignore


def get_next_episode_edge_from_record(record: Any) -> NextEpisodeEdge:
    """Build a NextEpisodeEdge from a single query result record."""
    # Fields are read in declaration order; created_at goes through parse_db_date.
    fields = {
        'uuid': record['uuid'],
        'group_id': record['group_id'],
        'source_node_uuid': record['source_node_uuid'],
        'target_node_uuid': record['target_node_uuid'],
        'created_at': parse_db_date(record['created_at']),
    }
    return NextEpisodeEdge(**fields)  # type: ignore


async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
    """Embed the facts of `edges` in one batch and store the results in place.

    Edges with a falsy `fact` are skipped. Nothing is sent to the embedder when
    no edge qualifies.
    """
    embeddable = [candidate for candidate in edges if candidate.fact]
    if not embeddable:
        return

    facts = [candidate.fact for candidate in embeddable]
    vectors = await embedder.create_batch(facts)

    # strict=True: a length mismatch between edges and embeddings is an error.
    for candidate, vector in zip(embeddable, vectors, strict=True):
        candidate.fact_embedding = vector

```

--------------------------------------------------------------------------------
/graphiti_core/utils/content_chunking.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import json
import logging
import random
import re
from itertools import combinations
from math import comb
from typing import TypeVar

from graphiti_core.helpers import (
    CHUNK_DENSITY_THRESHOLD,
    CHUNK_MIN_TOKENS,
    CHUNK_OVERLAP_TOKENS,
    CHUNK_TOKEN_SIZE,
)
from graphiti_core.nodes import EpisodeType

logger = logging.getLogger(__name__)

# Rough number of characters that make up one token (conservative estimate)
CHARS_PER_TOKEN = 4


def estimate_tokens(text: str) -> int:
    """Approximate the number of tokens in `text` from its length.

    The estimate is the character count integer-divided by CHARS_PER_TOKEN,
    so no tokenizer is involved.

    Args:
        text: The text to size up

    Returns:
        Estimated token count (rounded down)
    """
    char_count = len(text)
    return char_count // CHARS_PER_TOKEN


def _tokens_to_chars(tokens: int) -> int:
    """Return the approximate number of characters that `tokens` tokens span."""
    return CHARS_PER_TOKEN * tokens


def should_chunk(content: str, episode_type: EpisodeType) -> bool:
    """Decide whether `content` should be split into chunks.

    Chunking is only recommended when both conditions hold:
    1. The estimated token count is at least CHUNK_MIN_TOKENS
    2. The density heuristic for the episode type reports high entity density

    Content below the size threshold is never chunked, and the density
    heuristic is not evaluated for it.

    Args:
        content: The content to evaluate
        episode_type: Type of episode (json, message, text)

    Returns:
        True if content is large and has high entity density
    """
    estimated = estimate_tokens(content)
    is_large_enough = estimated >= CHUNK_MIN_TOKENS

    # Short-circuit keeps the density heuristic from running on small content.
    return is_large_enough and _estimate_high_density(content, episode_type, estimated)


def _estimate_high_density(content: str, episode_type: EpisodeType, tokens: int) -> bool:
    """Dispatch to the density heuristic that matches the episode type.

    JSON episodes use the JSON heuristic; every other episode type uses the
    text heuristic.

    Args:
        content: The content to analyze
        episode_type: Type of episode
        tokens: Pre-computed token count

    Returns:
        True if content appears to have high entity density
    """
    density_check = (
        _json_likely_dense if episode_type == EpisodeType.json else _text_likely_dense
    )
    return density_check(content, tokens)


def _json_likely_dense(content: str, tokens: int) -> bool:
    """Estimate entity density for JSON content.

    The number of "units" in the document is compared against the token count:
    - Array: the number of top-level elements
    - Object: the number of keys, counted down to a depth of 2
    - Anything else (scalar): never dense

    Content that fails to parse as JSON is judged by the text heuristic.

    Args:
        content: JSON string content
        tokens: Token count

    Returns:
        True if JSON appears to have high entity density
    """
    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        # Not valid JSON after all; treat it as plain text
        return _text_likely_dense(content, tokens)

    if isinstance(parsed, list):
        unit_count = len(parsed)
    elif isinstance(parsed, dict):
        unit_count = _count_json_keys(parsed, max_depth=2)
    else:
        # A bare scalar has nothing to split
        return False

    # Units per 1000 tokens, compared against the threshold scaled the same way
    per_thousand = (unit_count / tokens) * 1000 if tokens > 0 else 0
    return per_thousand > CHUNK_DENSITY_THRESHOLD * 1000


def _count_json_keys(data: dict, max_depth: int = 2, current_depth: int = 0) -> int:
    """Count keys in a JSON object up to a certain depth.

    Args:
        data: Dictionary to count keys in
        max_depth: Maximum depth to traverse
        current_depth: Current recursion depth

    Returns:
        Count of keys
    """
    if current_depth >= max_depth:
        return 0

    count = len(data)
    for value in data.values():
        if isinstance(value, dict):
            count += _count_json_keys(value, max_depth, current_depth + 1)
        elif isinstance(value, list):
            for item in value:
                if isinstance(item, dict):
                    count += _count_json_keys(item, max_depth, current_depth + 1)
    return count


def _text_likely_dense(content: str, tokens: int) -> bool:
    """Estimate entity density for text content.

    Counts words that start with an uppercase letter but are not entirely
    uppercase, skipping the first word and any word that follows a word
    ending in '.', '!' or '?'. The count per 1000 tokens is compared against
    half of the scaled JSON threshold.

    Args:
        content: Text content
        tokens: Token count

    Returns:
        True if text appears to have high entity density
    """
    if tokens == 0:
        return False

    words = content.split()
    if not words:
        return False

    sentence_enders = ('.', '!', '?')
    capitalized_count = 0

    # Pair each word with its predecessor; the first word is never counted.
    for previous, word in zip(words, words[1:]):
        # A word right after sentence-ending punctuation is a sentence starter
        if previous.endswith(sentence_enders):
            continue

        # Drop surrounding punctuation before inspecting the casing
        core = word.strip('.,!?;:\'"()[]{}')
        if core and core[0].isupper() and not core.isupper():
            capitalized_count += 1

    # Capitalized words per 1000 tokens
    density = (capitalized_count / tokens) * 1000 if tokens > 0 else 0

    # Text uses half the JSON threshold
    return density > CHUNK_DENSITY_THRESHOLD * 500


def chunk_json_content(
    content: str,
    chunk_size_tokens: int | None = None,
    overlap_tokens: int | None = None,
) -> list[str]:
    """Split JSON content into chunks while preserving structure.

    For arrays: splits at element boundaries, keeping complete objects.
    For objects: splits at top-level key boundaries.
    Content that is not valid JSON is handed to ``chunk_text_content``;
    a JSON scalar is returned unchanged as a single chunk.

    Args:
        content: JSON string to chunk
        chunk_size_tokens: Target size per chunk in tokens. ``None`` (or 0)
            falls back to ``CHUNK_TOKEN_SIZE``.
        overlap_tokens: Overlap between chunks in tokens. ``None`` falls back
            to ``CHUNK_OVERLAP_TOKENS``; an explicit ``0`` disables overlap.

    Returns:
        List of JSON string chunks
    """
    chunk_size_tokens = chunk_size_tokens or CHUNK_TOKEN_SIZE
    # Compare against None rather than using `or`: 0 is a legitimate request
    # for "no overlap" and must not be replaced by the default.
    if overlap_tokens is None:
        overlap_tokens = CHUNK_OVERLAP_TOKENS

    chunk_size_chars = _tokens_to_chars(chunk_size_tokens)
    overlap_chars = _tokens_to_chars(overlap_tokens)

    try:
        data = json.loads(content)
    except json.JSONDecodeError:
        logger.warning('Failed to parse JSON, falling back to text chunking')
        return chunk_text_content(content, chunk_size_tokens, overlap_tokens)

    if isinstance(data, list):
        return _chunk_json_array(data, chunk_size_chars, overlap_chars)
    if isinstance(data, dict):
        return _chunk_json_object(data, chunk_size_chars, overlap_chars)

    # Scalar value, return as-is
    return [content]


def _chunk_json_array(
    data: list,
    chunk_size_chars: int,
    overlap_chars: int,
) -> list[str]:
    """Chunk a JSON array by splitting at element boundaries."""
    if not data:
        return ['[]']

    chunks: list[str] = []
    current_elements: list = []
    current_size = 2  # Account for '[]'

    for element in data:
        element_json = json.dumps(element)
        element_size = len(element_json) + 2  # Account for comma and space

        # Check if adding this element would exceed chunk size
        if current_elements and current_size + element_size > chunk_size_chars:
            # Save current chunk
            chunks.append(json.dumps(current_elements))

            # Start new chunk with overlap (include last few elements)
            overlap_elements = _get_overlap_elements(current_elements, overlap_chars)
            current_elements = overlap_elements
            current_size = len(json.dumps(current_elements)) if current_elements else 2

        current_elements.append(element)
        current_size += element_size

    # Don't forget the last chunk
    if current_elements:
        chunks.append(json.dumps(current_elements))

    return chunks if chunks else ['[]']


def _get_overlap_elements(elements: list, overlap_chars: int) -> list:
    """Get elements from the end of a list that fit within overlap_chars."""
    if not elements:
        return []

    overlap_elements: list = []
    current_size = 2  # Account for '[]'

    for element in reversed(elements):
        element_json = json.dumps(element)
        element_size = len(element_json) + 2

        if current_size + element_size > overlap_chars:
            break

        overlap_elements.insert(0, element)
        current_size += element_size

    return overlap_elements


def _chunk_json_object(
    data: dict,
    chunk_size_chars: int,
    overlap_chars: int,
) -> list[str]:
    """Chunk a JSON object by splitting at top-level key boundaries."""
    if not data:
        return ['{}']

    chunks: list[str] = []
    current_keys: list[str] = []
    current_dict: dict = {}
    current_size = 2  # Account for '{}'

    for key, value in data.items():
        entry_json = json.dumps({key: value})
        entry_size = len(entry_json)

        # Check if adding this entry would exceed chunk size
        if current_dict and current_size + entry_size > chunk_size_chars:
            # Save current chunk
            chunks.append(json.dumps(current_dict))

            # Start new chunk with overlap (include last few keys)
            overlap_dict = _get_overlap_dict(current_dict, current_keys, overlap_chars)
            current_dict = overlap_dict
            current_keys = list(overlap_dict.keys())
            current_size = len(json.dumps(current_dict)) if current_dict else 2

        current_dict[key] = value
        current_keys.append(key)
        current_size += entry_size

    # Don't forget the last chunk
    if current_dict:
        chunks.append(json.dumps(current_dict))

    return chunks if chunks else ['{}']


def _get_overlap_dict(data: dict, keys: list[str], overlap_chars: int) -> dict:
    """Get key-value pairs from the end of a dict that fit within overlap_chars."""
    if not data or not keys:
        return {}

    overlap_dict: dict = {}
    current_size = 2  # Account for '{}'

    for key in reversed(keys):
        if key not in data:
            continue
        entry_json = json.dumps({key: data[key]})
        entry_size = len(entry_json)

        if current_size + entry_size > overlap_chars:
            break

        overlap_dict[key] = data[key]
        current_size += entry_size

    # Reverse to maintain original order
    return dict(reversed(list(overlap_dict.items())))


def chunk_text_content(
    content: str,
    chunk_size_tokens: int | None = None,
    overlap_tokens: int | None = None,
) -> list[str]:
    """Split text content at natural boundaries (paragraphs, sentences).

    Paragraphs (separated by blank lines) are packed into chunks. A paragraph
    that is larger than a whole chunk is delegated to ``_chunk_by_sentences``.
    When a chunk is closed, the next one is seeded with trailing text from it
    (via ``_get_overlap_text``) so entities at chunk boundaries appear in both.

    Args:
        content: Text to chunk
        chunk_size_tokens: Target size per chunk in tokens. ``None`` (or 0)
            falls back to ``CHUNK_TOKEN_SIZE``.
        overlap_tokens: Overlap between chunks in tokens. ``None`` falls back
            to ``CHUNK_OVERLAP_TOKENS``; an explicit ``0`` disables overlap.

    Returns:
        List of text chunks. Content that already fits in one chunk is
        returned unchanged as a single-element list.
    """
    chunk_size_tokens = chunk_size_tokens or CHUNK_TOKEN_SIZE
    # Compare against None rather than using `or`: 0 is a legitimate request
    # for "no overlap" and must not be replaced by the default.
    if overlap_tokens is None:
        overlap_tokens = CHUNK_OVERLAP_TOKENS

    chunk_size_chars = _tokens_to_chars(chunk_size_tokens)
    overlap_chars = _tokens_to_chars(overlap_tokens)

    if len(content) <= chunk_size_chars:
        return [content]

    # Split into paragraphs first
    paragraphs = re.split(r'\n\s*\n', content)

    chunks: list[str] = []
    current_chunk: list[str] = []
    current_size = 0

    for paragraph in paragraphs:
        paragraph = paragraph.strip()
        if not paragraph:
            continue

        para_size = len(paragraph)

        # If a single paragraph is too large, split it by sentences
        if para_size > chunk_size_chars:
            # First, save current chunk if any
            if current_chunk:
                chunks.append('\n\n'.join(current_chunk))
                current_chunk = []
                current_size = 0

            # Split large paragraph by sentences
            chunks.extend(_chunk_by_sentences(paragraph, chunk_size_chars, overlap_chars))
            continue

        # Check if adding this paragraph would exceed chunk size
        if current_chunk and current_size + para_size + 2 > chunk_size_chars:
            # Save current chunk
            finished = '\n\n'.join(current_chunk)
            chunks.append(finished)

            # Start new chunk with overlap (empty string means no overlap)
            overlap_text = _get_overlap_text(finished, overlap_chars)
            if overlap_text:
                current_chunk = [overlap_text]
                current_size = len(overlap_text)
            else:
                current_chunk = []
                current_size = 0

        current_chunk.append(paragraph)
        current_size += para_size + 2  # Account for '\n\n'

    # Don't forget the last chunk
    if current_chunk:
        chunks.append('\n\n'.join(current_chunk))

    return chunks if chunks else [content]


def _chunk_by_sentences(
    text: str,
    chunk_size_chars: int,
    overlap_chars: int,
) -> list[str]:
    """Split text by sentence boundaries."""
    # Split on sentence-ending punctuation followed by whitespace
    sentence_pattern = r'(?<=[.!?])\s+'
    sentences = re.split(sentence_pattern, text)

    chunks: list[str] = []
    current_chunk: list[str] = []
    current_size = 0

    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            continue

        sent_size = len(sentence)

        # If a single sentence is too large, split it by fixed size
        if sent_size > chunk_size_chars:
            if current_chunk:
                chunks.append(' '.join(current_chunk))
                current_chunk = []
                current_size = 0

            # Split by fixed size as last resort
            fixed_chunks = _chunk_by_size(sentence, chunk_size_chars, overlap_chars)
            chunks.extend(fixed_chunks)
            continue

        # Check if adding this sentence would exceed chunk size
        if current_chunk and current_size + sent_size + 1 > chunk_size_chars:
            chunks.append(' '.join(current_chunk))

            # Start new chunk with overlap
            overlap_text = _get_overlap_text(' '.join(current_chunk), overlap_chars)
            if overlap_text:
                current_chunk = [overlap_text]
                current_size = len(overlap_text)
            else:
                current_chunk = []
                current_size = 0

        current_chunk.append(sentence)
        current_size += sent_size + 1

    if current_chunk:
        chunks.append(' '.join(current_chunk))

    return chunks


def _chunk_by_size(
    text: str,
    chunk_size_chars: int,
    overlap_chars: int,
) -> list[str]:
    """Split text by fixed character size (last resort)."""
    chunks: list[str] = []
    start = 0

    while start < len(text):
        end = min(start + chunk_size_chars, len(text))

        # Try to break at word boundary
        if end < len(text):
            space_idx = text.rfind(' ', start, end)
            if space_idx > start:
                end = space_idx

        chunks.append(text[start:end].strip())

        # Move start forward, ensuring progress even if overlap >= chunk_size
        # Always advance by at least (chunk_size - overlap) or 1 char minimum
        min_progress = max(1, chunk_size_chars - overlap_chars)
        start = max(start + min_progress, end - overlap_chars)

    return chunks


def _get_overlap_text(text: str, overlap_chars: int) -> str:
    """Get the last overlap_chars characters of text, breaking at word boundary."""
    if len(text) <= overlap_chars:
        return text

    overlap_start = len(text) - overlap_chars
    # Find the next word boundary after overlap_start
    space_idx = text.find(' ', overlap_start)
    if space_idx != -1:
        return text[space_idx + 1 :]
    return text[overlap_start:]


def chunk_message_content(
    content: str,
    chunk_size_tokens: int | None = None,
    overlap_tokens: int | None = None,
) -> list[str]:
    """Split conversation content along message boundaries.

    The format is detected in this order:
    - a JSON array (chunked per element via ``_chunk_message_array``)
    - "Speaker: message" lines (chunked via ``_chunk_speaker_messages``)
    - anything else, including JSON that is not an array, is chunked by
      lines via ``_chunk_by_lines``

    Args:
        content: Conversation content to chunk
        chunk_size_tokens: Target size per chunk in tokens. ``None`` (or 0)
            falls back to ``CHUNK_TOKEN_SIZE``.
        overlap_tokens: Overlap between chunks in tokens. ``None`` falls back
            to ``CHUNK_OVERLAP_TOKENS``; an explicit ``0`` disables overlap.

    Returns:
        List of conversation chunks. Content that already fits in one chunk
        is returned unchanged as a single-element list.
    """
    chunk_size_tokens = chunk_size_tokens or CHUNK_TOKEN_SIZE
    # Compare against None rather than using `or`: 0 is a legitimate request
    # for "no overlap" and must not be replaced by the default.
    if overlap_tokens is None:
        overlap_tokens = CHUNK_OVERLAP_TOKENS

    chunk_size_chars = _tokens_to_chars(chunk_size_tokens)
    overlap_chars = _tokens_to_chars(overlap_tokens)

    if len(content) <= chunk_size_chars:
        return [content]

    # Try to detect message format
    # Check if it's JSON (array of message objects)
    try:
        data = json.loads(content)
    except json.JSONDecodeError:
        # Not JSON at all; continue with the text-based detectors below.
        data = None
    if isinstance(data, list):
        return _chunk_message_array(data, chunk_size_chars, overlap_chars)

    # Try speaker pattern (e.g., "Alice: Hello")
    speaker_pattern = r'^([A-Za-z_][A-Za-z0-9_\s]*):(.+?)(?=^[A-Za-z_][A-Za-z0-9_\s]*:|$)'
    if re.search(speaker_pattern, content, re.MULTILINE | re.DOTALL):
        return _chunk_speaker_messages(content, chunk_size_chars, overlap_chars)

    # Fallback to line-based chunking
    return _chunk_by_lines(content, chunk_size_chars, overlap_chars)


def _chunk_message_array(
    messages: list,
    chunk_size_chars: int,
    overlap_chars: int,
) -> list[str]:
    """Chunk a parsed JSON list of messages.

    Thin wrapper: the work is done by ``_chunk_json_array``, which cuts only
    between list elements.
    """
    return _chunk_json_array(messages, chunk_size_chars, overlap_chars)


def _chunk_speaker_messages(
    content: str,
    chunk_size_chars: int,
    overlap_chars: int,
) -> list[str]:
    """Chunk messages in 'Speaker: message' format."""
    # Split on speaker patterns
    pattern = r'(?=^[A-Za-z_][A-Za-z0-9_\s]*:)'
    messages = re.split(pattern, content, flags=re.MULTILINE)
    messages = [m.strip() for m in messages if m.strip()]

    if not messages:
        return [content]

    chunks: list[str] = []
    current_messages: list[str] = []
    current_size = 0

    for message in messages:
        msg_size = len(message)

        # If a single message is too large, include it as its own chunk
        if msg_size > chunk_size_chars:
            if current_messages:
                chunks.append('\n'.join(current_messages))
                current_messages = []
                current_size = 0
            chunks.append(message)
            continue

        if current_messages and current_size + msg_size + 1 > chunk_size_chars:
            chunks.append('\n'.join(current_messages))

            # Get overlap (last message(s) that fit)
            overlap_messages = _get_overlap_messages(current_messages, overlap_chars)
            current_messages = overlap_messages
            current_size = sum(len(m) for m in current_messages) + len(current_messages) - 1

        current_messages.append(message)
        current_size += msg_size + 1

    if current_messages:
        chunks.append('\n'.join(current_messages))

    return chunks if chunks else [content]


def _get_overlap_messages(messages: list[str], overlap_chars: int) -> list[str]:
    """Get messages from the end that fit within overlap_chars."""
    if not messages:
        return []

    overlap: list[str] = []
    current_size = 0

    for msg in reversed(messages):
        msg_size = len(msg) + 1
        if current_size + msg_size > overlap_chars:
            break
        overlap.insert(0, msg)
        current_size += msg_size

    return overlap


def _chunk_by_lines(
    content: str,
    chunk_size_chars: int,
    overlap_chars: int,
) -> list[str]:
    """Chunk content by line boundaries."""
    lines = content.split('\n')

    chunks: list[str] = []
    current_lines: list[str] = []
    current_size = 0

    for line in lines:
        line_size = len(line) + 1

        if current_lines and current_size + line_size > chunk_size_chars:
            chunks.append('\n'.join(current_lines))

            # Get overlap lines
            overlap_text = '\n'.join(current_lines)
            overlap = _get_overlap_text(overlap_text, overlap_chars)
            if overlap:
                current_lines = overlap.split('\n')
                current_size = len(overlap)
            else:
                current_lines = []
                current_size = 0

        current_lines.append(line)
        current_size += line_size

    if current_lines:
        chunks.append('\n'.join(current_lines))

    return chunks if chunks else [content]


# Generic element type for the items passed to generate_covering_chunks.
T = TypeVar('T')

# Threshold used by generate_covering_chunks: when C(n, k) exceeds this value,
# it evaluates random samples (at most this many distinct ones per round)
# instead of enumerating every combination.
MAX_COMBINATIONS_TO_EVALUATE = 1000


def _random_combination(n: int, k: int) -> tuple[int, ...]:
    """Generate a random combination of k items from range(n)."""
    return tuple(sorted(random.sample(range(n), k)))


def generate_covering_chunks(items: list[T], k: int) -> list[tuple[list[T], list[int]]]:
    """Greedily build chunks of ``k`` items so that every pair of items shares a chunk.

    This is an instance of the covering design problem: a chunk of K items
    covers K(K-1)/2 pairs. Each round selects the candidate chunk that covers
    the most still-uncovered pairs. Candidates are every k-combination when
    C(n, k) <= MAX_COMBINATIONS_TO_EVALUATE, otherwise a random sample of
    combinations. Pairs still uncovered when the greedy rounds stop are each
    emitted as a two-item chunk.

    Lower bound (Schönheim): F >= ceil(N/K * ceil((N-1)/(K-1)))

    Args:
        items: Items to distribute over covering chunks.
        k: Number of items per greedy chunk.

    Returns:
        List of ``(chunk_items, global_indices)`` tuples, where
        ``global_indices[i]`` is the position of ``chunk_items[i]`` in
        ``items``. If ``len(items) <= k`` a single chunk holding all items is
        returned.
    """
    total = len(items)
    if total <= k:
        return [(items, list(range(total)))]

    # Every unordered index pair starts out uncovered
    remaining: set[frozenset[int]] = {
        frozenset([i, j]) for i in range(total) for j in range(i + 1, total)
    }

    def newly_covered(indices: tuple[int, ...]) -> int:
        """Count the pairs inside ``indices`` that are still uncovered."""
        hits = 0
        for pos, first in enumerate(indices):
            for second in indices[pos + 1 :]:
                if frozenset([first, second]) in remaining:
                    hits += 1
        return hits

    result: list[tuple[list[T], list[int]]] = []

    # Exhaustive enumeration is only used while the candidate space is small
    sample_candidates = comb(total, k) > MAX_COMBINATIONS_TO_EVALUATE

    while remaining:
        winner: tuple[int, ...] | None = None
        winner_gain = 0

        if sample_candidates:
            already_tried: set[tuple[int, ...]] = set()
            # Cap draws (duplicates included) so the loop always terminates
            attempt_budget = MAX_COMBINATIONS_TO_EVALUATE * 3
            attempts = 0
            distinct = 0
            while distinct < MAX_COMBINATIONS_TO_EVALUATE:
                attempts += 1
                if attempts > attempt_budget:
                    break
                candidate = _random_combination(total, k)
                if candidate in already_tried:
                    continue
                already_tried.add(candidate)
                distinct += 1

                gain = newly_covered(candidate)
                # Strict '>' keeps the earliest candidate on ties
                if gain > winner_gain:
                    winner_gain, winner = gain, candidate
        else:
            for candidate in combinations(range(total), k):
                gain = newly_covered(candidate)
                # Strict '>' keeps the earliest candidate on ties
                if gain > winner_gain:
                    winner_gain, winner = gain, candidate

        if winner is None or winner_gain == 0:
            # No candidate made progress (possible with sampling); the
            # leftover pairs are handled individually after the loop.
            break

        # Retire every pair contained in the chosen chunk
        for pos, first in enumerate(winner):
            for second in winner[pos + 1 :]:
                remaining.discard(frozenset([first, second]))

        result.append(([items[idx] for idx in winner], list(winner)))

    # Guarantee coverage: one two-item chunk per pair the greedy phase missed
    for pair in remaining:
        pair_indices = sorted(pair)
        result.append(([items[idx] for idx in pair_indices], pair_indices))

    return result

```
Page 6/10FirstPrevNextLast