This is page 6 of 9. Use http://codebase.md/getzep/graphiti?page={x} to view the full context.
# Directory Structure
```
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── ISSUE_TEMPLATE
│ │ └── bug_report.md
│ ├── pull_request_template.md
│ ├── secret_scanning.yml
│ └── workflows
│ ├── ai-moderator.yml
│ ├── cla.yml
│ ├── claude-code-review-manual.yml
│ ├── claude-code-review.yml
│ ├── claude.yml
│ ├── codeql.yml
│ ├── lint.yml
│ ├── release-graphiti-core.yml
│ ├── release-mcp-server.yml
│ ├── release-server-container.yml
│ ├── typecheck.yml
│ └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│ ├── azure-openai
│ │ ├── .env.example
│ │ ├── azure_openai_neo4j.py
│ │ └── README.md
│ ├── data
│ │ └── manybirds_products.json
│ ├── ecommerce
│ │ ├── runner.ipynb
│ │ └── runner.py
│ ├── langgraph-agent
│ │ ├── agent.ipynb
│ │ └── tinybirds-jess.png
│ ├── opentelemetry
│ │ ├── .env.example
│ │ ├── otel_stdout_example.py
│ │ ├── pyproject.toml
│ │ ├── README.md
│ │ └── uv.lock
│ ├── podcast
│ │ ├── podcast_runner.py
│ │ ├── podcast_transcript.txt
│ │ └── transcript_parser.py
│ ├── quickstart
│ │ ├── dense_vs_normal_ingestion.py
│ │ ├── quickstart_falkordb.py
│ │ ├── quickstart_neo4j.py
│ │ ├── quickstart_neptune.py
│ │ ├── README.md
│ │ └── requirements.txt
│ └── wizard_of_oz
│ ├── parser.py
│ ├── runner.py
│ └── woo.txt
├── graphiti_core
│ ├── __init__.py
│ ├── cross_encoder
│ │ ├── __init__.py
│ │ ├── bge_reranker_client.py
│ │ ├── client.py
│ │ ├── gemini_reranker_client.py
│ │ └── openai_reranker_client.py
│ ├── decorators.py
│ ├── driver
│ │ ├── __init__.py
│ │ ├── driver.py
│ │ ├── falkordb_driver.py
│ │ ├── graph_operations
│ │ │ └── graph_operations.py
│ │ ├── kuzu_driver.py
│ │ ├── neo4j_driver.py
│ │ ├── neptune_driver.py
│ │ └── search_interface
│ │ └── search_interface.py
│ ├── edges.py
│ ├── embedder
│ │ ├── __init__.py
│ │ ├── azure_openai.py
│ │ ├── client.py
│ │ ├── gemini.py
│ │ ├── openai.py
│ │ └── voyage.py
│ ├── errors.py
│ ├── graph_queries.py
│ ├── graphiti_types.py
│ ├── graphiti.py
│ ├── helpers.py
│ ├── llm_client
│ │ ├── __init__.py
│ │ ├── anthropic_client.py
│ │ ├── azure_openai_client.py
│ │ ├── client.py
│ │ ├── config.py
│ │ ├── errors.py
│ │ ├── gemini_client.py
│ │ ├── groq_client.py
│ │ ├── openai_base_client.py
│ │ ├── openai_client.py
│ │ ├── openai_generic_client.py
│ │ └── utils.py
│ ├── migrations
│ │ └── __init__.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── edges
│ │ │ ├── __init__.py
│ │ │ └── edge_db_queries.py
│ │ └── nodes
│ │ ├── __init__.py
│ │ └── node_db_queries.py
│ ├── nodes.py
│ ├── prompts
│ │ ├── __init__.py
│ │ ├── dedupe_edges.py
│ │ ├── dedupe_nodes.py
│ │ ├── eval.py
│ │ ├── extract_edge_dates.py
│ │ ├── extract_edges.py
│ │ ├── extract_nodes.py
│ │ ├── invalidate_edges.py
│ │ ├── lib.py
│ │ ├── models.py
│ │ ├── prompt_helpers.py
│ │ ├── snippets.py
│ │ └── summarize_nodes.py
│ ├── py.typed
│ ├── search
│ │ ├── __init__.py
│ │ ├── search_config_recipes.py
│ │ ├── search_config.py
│ │ ├── search_filters.py
│ │ ├── search_helpers.py
│ │ ├── search_utils.py
│ │ └── search.py
│ ├── telemetry
│ │ ├── __init__.py
│ │ └── telemetry.py
│ ├── tracer.py
│ └── utils
│ ├── __init__.py
│ ├── bulk_utils.py
│ ├── content_chunking.py
│ ├── datetime_utils.py
│ ├── maintenance
│ │ ├── __init__.py
│ │ ├── community_operations.py
│ │ ├── dedup_helpers.py
│ │ ├── edge_operations.py
│ │ ├── graph_data_operations.py
│ │ ├── node_operations.py
│ │ └── temporal_operations.py
│ ├── ontology_utils
│ │ └── entity_types_utils.py
│ └── text_utils.py
├── images
│ ├── arxiv-screenshot.png
│ ├── graphiti-graph-intro.gif
│ ├── graphiti-intro-slides-stock-2.gif
│ └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│ ├── .env.example
│ ├── .python-version
│ ├── config
│ │ ├── config-docker-falkordb-combined.yaml
│ │ ├── config-docker-falkordb.yaml
│ │ ├── config-docker-neo4j.yaml
│ │ ├── config.yaml
│ │ └── mcp_config_stdio_example.json
│ ├── docker
│ │ ├── build-standalone.sh
│ │ ├── build-with-version.sh
│ │ ├── docker-compose-falkordb.yml
│ │ ├── docker-compose-neo4j.yml
│ │ ├── docker-compose.yml
│ │ ├── Dockerfile
│ │ ├── Dockerfile.standalone
│ │ ├── github-actions-example.yml
│ │ ├── README-falkordb-combined.md
│ │ └── README.md
│ ├── docs
│ │ └── cursor_rules.md
│ ├── main.py
│ ├── pyproject.toml
│ ├── pytest.ini
│ ├── README.md
│ ├── src
│ │ ├── __init__.py
│ │ ├── config
│ │ │ ├── __init__.py
│ │ │ └── schema.py
│ │ ├── graphiti_mcp_server.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ ├── entity_types.py
│ │ │ └── response_types.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── factories.py
│ │ │ └── queue_service.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── formatting.py
│ │ └── utils.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── pytest.ini
│ │ ├── README.md
│ │ ├── run_tests.py
│ │ ├── test_async_operations.py
│ │ ├── test_comprehensive_integration.py
│ │ ├── test_configuration.py
│ │ ├── test_falkordb_integration.py
│ │ ├── test_fixtures.py
│ │ ├── test_http_integration.py
│ │ ├── test_integration.py
│ │ ├── test_mcp_integration.py
│ │ ├── test_mcp_transports.py
│ │ ├── test_stdio_simple.py
│ │ └── test_stress_load.py
│ └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│ ├── .env.example
│ ├── graph_service
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ ├── main.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ └── zep_graphiti.py
│ ├── Makefile
│ ├── pyproject.toml
│ ├── README.md
│ └── uv.lock
├── signatures
│ └── version1
│ └── cla.json
├── tests
│ ├── cross_encoder
│ │ ├── test_bge_reranker_client_int.py
│ │ └── test_gemini_reranker_client.py
│ ├── driver
│ │ ├── __init__.py
│ │ └── test_falkordb_driver.py
│ ├── embedder
│ │ ├── embedder_fixtures.py
│ │ ├── test_gemini.py
│ │ ├── test_openai.py
│ │ └── test_voyage.py
│ ├── evals
│ │ ├── data
│ │ │ └── longmemeval_data
│ │ │ ├── longmemeval_oracle.json
│ │ │ └── README.md
│ │ ├── eval_cli.py
│ │ ├── eval_e2e_graph_building.py
│ │ ├── pytest.ini
│ │ └── utils.py
│ ├── helpers_test.py
│ ├── llm_client
│ │ ├── test_anthropic_client_int.py
│ │ ├── test_anthropic_client.py
│ │ ├── test_azure_openai_client.py
│ │ ├── test_client.py
│ │ ├── test_errors.py
│ │ └── test_gemini_client.py
│ ├── test_edge_int.py
│ ├── test_entity_exclusion_int.py
│ ├── test_graphiti_int.py
│ ├── test_graphiti_mock.py
│ ├── test_node_int.py
│ ├── test_text_utils.py
│ └── utils
│ ├── maintenance
│ │ ├── test_bulk_utils.py
│ │ ├── test_edge_operations.py
│ │ ├── test_entity_extraction.py
│ │ ├── test_node_operations.py
│ │ └── test_temporal_operations_int.py
│ ├── search
│ │ └── search_utils_test.py
│ └── test_content_chunking.py
├── uv.lock
└── Zep-CLA.md
```
# Files
--------------------------------------------------------------------------------
/tests/utils/maintenance/test_edge_operations.py:
--------------------------------------------------------------------------------
```python
from datetime import datetime, timedelta, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from pydantic import BaseModel
from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.search.search_config import SearchResults
from graphiti_core.utils.maintenance.edge_operations import (
DEFAULT_EDGE_NAME,
resolve_extracted_edge,
resolve_extracted_edges,
)
@pytest.fixture
def mock_llm_client():
client = MagicMock()
client.generate_response = AsyncMock()
return client
@pytest.fixture
def mock_extracted_edge():
return EntityEdge(
source_node_uuid='source_uuid',
target_node_uuid='target_uuid',
name='test_edge',
group_id='group_1',
fact='Test fact',
episodes=['episode_1'],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
@pytest.fixture
def mock_related_edges():
return [
EntityEdge(
source_node_uuid='source_uuid_2',
target_node_uuid='target_uuid_2',
name='related_edge',
group_id='group_1',
fact='Related fact',
episodes=['episode_2'],
created_at=datetime.now(timezone.utc) - timedelta(days=1),
valid_at=datetime.now(timezone.utc) - timedelta(days=1),
invalid_at=None,
)
]
@pytest.fixture
def mock_existing_edges():
return [
EntityEdge(
source_node_uuid='source_uuid_3',
target_node_uuid='target_uuid_3',
name='existing_edge',
group_id='group_1',
fact='Existing fact',
episodes=['episode_3'],
created_at=datetime.now(timezone.utc) - timedelta(days=2),
valid_at=datetime.now(timezone.utc) - timedelta(days=2),
invalid_at=None,
)
]
@pytest.fixture
def mock_current_episode():
return EpisodicNode(
uuid='episode_1',
content='Current episode content',
valid_at=datetime.now(timezone.utc),
name='Current Episode',
group_id='group_1',
source='message',
source_description='Test source description',
)
@pytest.fixture
def mock_previous_episodes():
return [
EpisodicNode(
uuid='episode_2',
content='Previous episode content',
valid_at=datetime.now(timezone.utc) - timedelta(days=1),
name='Previous Episode',
group_id='group_1',
source='message',
source_description='Test source description',
)
]
# Run the tests
if __name__ == '__main__':
pytest.main([__file__])
@pytest.mark.asyncio
async def test_resolve_extracted_edge_exact_fact_short_circuit(
mock_llm_client,
mock_existing_edges,
mock_current_episode,
):
extracted = EntityEdge(
source_node_uuid='source_uuid',
target_node_uuid='target_uuid',
name='test_edge',
group_id='group_1',
fact='Related fact',
episodes=['episode_1'],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
related_edges = [
EntityEdge(
source_node_uuid='source_uuid',
target_node_uuid='target_uuid',
name='related_edge',
group_id='group_1',
fact=' related FACT ',
episodes=['episode_2'],
created_at=datetime.now(timezone.utc) - timedelta(days=1),
valid_at=None,
invalid_at=None,
)
]
resolved_edge, duplicate_edges, invalidated = await resolve_extracted_edge(
mock_llm_client,
extracted,
related_edges,
mock_existing_edges,
mock_current_episode,
edge_type_candidates=None,
)
assert resolved_edge is related_edges[0]
assert resolved_edge.episodes.count(mock_current_episode.uuid) == 1
assert duplicate_edges == []
assert invalidated == []
mock_llm_client.generate_response.assert_not_called()
class OccurredAtEdge(BaseModel):
"""Edge model stub for OCCURRED_AT."""
@pytest.mark.asyncio
async def test_resolve_extracted_edges_resets_unmapped_names(monkeypatch):
from graphiti_core.utils.maintenance import edge_operations as edge_ops
monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
async def immediate_gather(*aws, max_coroutines=None):
return [await aw for aw in aws]
monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather)
monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))
llm_client = MagicMock()
llm_client.generate_response = AsyncMock(
return_value={
'duplicate_facts': [],
'contradicted_facts': [],
'fact_type': 'DEFAULT',
}
)
clients = SimpleNamespace(
driver=MagicMock(),
llm_client=llm_client,
embedder=MagicMock(),
cross_encoder=MagicMock(),
)
source_node = EntityNode(
uuid='source_uuid',
name='Document Node',
group_id='group_1',
labels=['Document'],
)
target_node = EntityNode(
uuid='target_uuid',
name='Topic Node',
group_id='group_1',
labels=['Topic'],
)
extracted_edge = EntityEdge(
source_node_uuid=source_node.uuid,
target_node_uuid=target_node.uuid,
name='OCCURRED_AT',
group_id='group_1',
fact='Document occurred at somewhere',
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
episode = EpisodicNode(
uuid='episode_uuid',
name='Episode',
group_id='group_1',
source='message',
source_description='desc',
content='Episode content',
valid_at=datetime.now(timezone.utc),
)
edge_types = {'OCCURRED_AT': OccurredAtEdge}
edge_type_map = {('Event', 'Entity'): ['OCCURRED_AT']}
resolved_edges, invalidated_edges = await resolve_extracted_edges(
clients,
[extracted_edge],
episode,
[source_node, target_node],
edge_types,
edge_type_map,
)
assert resolved_edges[0].name == DEFAULT_EDGE_NAME
assert invalidated_edges == []
@pytest.mark.asyncio
async def test_resolve_extracted_edges_keeps_unknown_names(monkeypatch):
from graphiti_core.utils.maintenance import edge_operations as edge_ops
monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
async def immediate_gather(*aws, max_coroutines=None):
return [await aw for aw in aws]
monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather)
monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))
llm_client = MagicMock()
llm_client.generate_response = AsyncMock(
return_value={
'duplicate_facts': [],
'contradicted_facts': [],
'fact_type': 'DEFAULT',
}
)
clients = SimpleNamespace(
driver=MagicMock(),
llm_client=llm_client,
embedder=MagicMock(),
cross_encoder=MagicMock(),
)
source_node = EntityNode(
uuid='source_uuid',
name='User Node',
group_id='group_1',
labels=['User'],
)
target_node = EntityNode(
uuid='target_uuid',
name='Topic Node',
group_id='group_1',
labels=['Topic'],
)
extracted_edge = EntityEdge(
source_node_uuid=source_node.uuid,
target_node_uuid=target_node.uuid,
name='INTERACTED_WITH',
group_id='group_1',
fact='User interacted with topic',
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
episode = EpisodicNode(
uuid='episode_uuid',
name='Episode',
group_id='group_1',
source='message',
source_description='desc',
content='Episode content',
valid_at=datetime.now(timezone.utc),
)
edge_types = {'OCCURRED_AT': OccurredAtEdge}
edge_type_map = {('Event', 'Entity'): ['OCCURRED_AT']}
resolved_edges, invalidated_edges = await resolve_extracted_edges(
clients,
[extracted_edge],
episode,
[source_node, target_node],
edge_types,
edge_type_map,
)
assert resolved_edges[0].name == 'INTERACTED_WITH'
assert invalidated_edges == []
@pytest.mark.asyncio
async def test_resolve_extracted_edge_rejects_unmapped_fact_type(mock_llm_client):
mock_llm_client.generate_response.return_value = {
'duplicate_facts': [],
'contradicted_facts': [],
'fact_type': 'OCCURRED_AT',
}
extracted_edge = EntityEdge(
source_node_uuid='source_uuid',
target_node_uuid='target_uuid',
name='OCCURRED_AT',
group_id='group_1',
fact='Document occurred at somewhere',
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
episode = EpisodicNode(
uuid='episode_uuid',
name='Episode',
group_id='group_1',
source='message',
source_description='desc',
content='Episode content',
valid_at=datetime.now(timezone.utc),
)
related_edge = EntityEdge(
source_node_uuid='alt_source',
target_node_uuid='alt_target',
name='OTHER',
group_id='group_1',
fact='Different fact',
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
resolved_edge, duplicates, invalidated = await resolve_extracted_edge(
mock_llm_client,
extracted_edge,
[related_edge],
[],
episode,
edge_type_candidates={},
custom_edge_type_names={'OCCURRED_AT'},
)
assert resolved_edge.name == DEFAULT_EDGE_NAME
assert duplicates == []
assert invalidated == []
@pytest.mark.asyncio
async def test_resolve_extracted_edge_accepts_unknown_fact_type(mock_llm_client):
mock_llm_client.generate_response.return_value = {
'duplicate_facts': [],
'contradicted_facts': [],
'fact_type': 'INTERACTED_WITH',
}
extracted_edge = EntityEdge(
source_node_uuid='source_uuid',
target_node_uuid='target_uuid',
name='DEFAULT',
group_id='group_1',
fact='User interacted with topic',
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
episode = EpisodicNode(
uuid='episode_uuid',
name='Episode',
group_id='group_1',
source='message',
source_description='desc',
content='Episode content',
valid_at=datetime.now(timezone.utc),
)
related_edge = EntityEdge(
source_node_uuid='source_uuid',
target_node_uuid='target_uuid',
name='DEFAULT',
group_id='group_1',
fact='User mentioned a topic',
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
resolved_edge, duplicates, invalidated = await resolve_extracted_edge(
mock_llm_client,
extracted_edge,
[related_edge],
[],
episode,
edge_type_candidates={'OCCURRED_AT': OccurredAtEdge},
custom_edge_type_names={'OCCURRED_AT'},
)
assert resolved_edge.name == 'INTERACTED_WITH'
assert resolved_edge.attributes == {}
assert duplicates == []
assert invalidated == []
@pytest.mark.asyncio
async def test_resolve_extracted_edge_uses_integer_indices_for_duplicates(mock_llm_client):
"""Test that resolve_extracted_edge correctly uses integer indices for LLM duplicate detection."""
# Mock LLM to return duplicate_facts with integer indices
mock_llm_client.generate_response.return_value = {
'duplicate_facts': [0, 1], # LLM identifies first two related edges as duplicates
'contradicted_facts': [],
'fact_type': 'DEFAULT',
}
extracted_edge = EntityEdge(
source_node_uuid='source_uuid',
target_node_uuid='target_uuid',
name='test_edge',
group_id='group_1',
fact='User likes yoga',
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
episode = EpisodicNode(
uuid='episode_uuid',
name='Episode',
group_id='group_1',
source='message',
source_description='desc',
content='Episode content',
valid_at=datetime.now(timezone.utc),
)
# Create multiple related edges - LLM should receive these with integer indices
related_edge_0 = EntityEdge(
source_node_uuid='source_uuid',
target_node_uuid='target_uuid',
name='test_edge',
group_id='group_1',
fact='User enjoys yoga',
episodes=['episode_1'],
created_at=datetime.now(timezone.utc) - timedelta(days=1),
valid_at=None,
invalid_at=None,
)
related_edge_1 = EntityEdge(
source_node_uuid='source_uuid',
target_node_uuid='target_uuid',
name='test_edge',
group_id='group_1',
fact='User practices yoga',
episodes=['episode_2'],
created_at=datetime.now(timezone.utc) - timedelta(days=2),
valid_at=None,
invalid_at=None,
)
related_edge_2 = EntityEdge(
source_node_uuid='source_uuid',
target_node_uuid='target_uuid',
name='test_edge',
group_id='group_1',
fact='User loves swimming',
episodes=['episode_3'],
created_at=datetime.now(timezone.utc) - timedelta(days=3),
valid_at=None,
invalid_at=None,
)
related_edges = [related_edge_0, related_edge_1, related_edge_2]
resolved_edge, invalidated, duplicates = await resolve_extracted_edge(
mock_llm_client,
extracted_edge,
related_edges,
[],
episode,
edge_type_candidates=None,
custom_edge_type_names=set(),
)
# Verify LLM was called
mock_llm_client.generate_response.assert_called_once()
# Verify the system correctly identified duplicates using integer indices
# The LLM returned [0, 1], so related_edge_0 and related_edge_1 should be marked as duplicates
assert len(duplicates) == 2
assert related_edge_0 in duplicates
assert related_edge_1 in duplicates
assert invalidated == []
# Verify that the resolved edge is one of the duplicates (the first one found)
# Check UUID since the episode list gets modified
assert resolved_edge.uuid == related_edge_0.uuid
assert episode.uuid in resolved_edge.episodes
@pytest.mark.asyncio
async def test_resolve_extracted_edges_fast_path_deduplication(monkeypatch):
"""Test that resolve_extracted_edges deduplicates exact matches before parallel processing."""
from graphiti_core.utils.maintenance import edge_operations as edge_ops
monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
# Track how many times resolve_extracted_edge is called
resolve_call_count = 0
async def mock_resolve_extracted_edge(
llm_client,
extracted_edge,
related_edges,
existing_edges,
episode,
edge_type_candidates=None,
custom_edge_type_names=None,
):
nonlocal resolve_call_count
resolve_call_count += 1
return extracted_edge, [], []
# Mock semaphore_gather to execute awaitable immediately
async def immediate_gather(*aws, max_coroutines=None):
results = []
for aw in aws:
results.append(await aw)
return results
monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather)
monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))
monkeypatch.setattr(edge_ops, 'resolve_extracted_edge', mock_resolve_extracted_edge)
llm_client = MagicMock()
clients = SimpleNamespace(
driver=MagicMock(),
llm_client=llm_client,
embedder=MagicMock(),
cross_encoder=MagicMock(),
)
source_node = EntityNode(
uuid='source_uuid',
name='Assistant',
group_id='group_1',
labels=['Entity'],
)
target_node = EntityNode(
uuid='target_uuid',
name='User',
group_id='group_1',
labels=['Entity'],
)
# Create 3 identical edges
edge1 = EntityEdge(
source_node_uuid=source_node.uuid,
target_node_uuid=target_node.uuid,
name='recommends',
group_id='group_1',
fact='assistant recommends yoga poses',
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
edge2 = EntityEdge(
source_node_uuid=source_node.uuid,
target_node_uuid=target_node.uuid,
name='recommends',
group_id='group_1',
fact=' Assistant Recommends YOGA Poses ', # Different whitespace/case
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
edge3 = EntityEdge(
source_node_uuid=source_node.uuid,
target_node_uuid=target_node.uuid,
name='recommends',
group_id='group_1',
fact='assistant recommends yoga poses',
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
episode = EpisodicNode(
uuid='episode_uuid',
name='Episode',
group_id='group_1',
source='message',
source_description='desc',
content='Episode content',
valid_at=datetime.now(timezone.utc),
)
resolved_edges, invalidated_edges = await resolve_extracted_edges(
clients,
[edge1, edge2, edge3],
episode,
[source_node, target_node],
{},
{},
)
# Fast path should have deduplicated the 3 identical edges to 1
# So resolve_extracted_edge should only be called once
assert resolve_call_count == 1
assert len(resolved_edges) == 1
assert invalidated_edges == []
```
--------------------------------------------------------------------------------
/mcp_server/tests/test_stress_load.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Stress and load testing for Graphiti MCP Server.
Tests system behavior under high load, resource constraints, and edge conditions.
"""
import asyncio
import gc
import random
import time
from dataclasses import dataclass
import psutil
import pytest
from test_fixtures import TestDataGenerator, graphiti_test_client
@dataclass
class LoadTestConfig:
"""Configuration for load testing scenarios."""
num_clients: int = 10
operations_per_client: int = 100
ramp_up_time: float = 5.0 # seconds
test_duration: float = 60.0 # seconds
target_throughput: float | None = None # ops/sec
think_time: float = 0.1 # seconds between ops
@dataclass
class LoadTestResult:
"""Results from a load test run."""
total_operations: int
successful_operations: int
failed_operations: int
duration: float
throughput: float
average_latency: float
p50_latency: float
p95_latency: float
p99_latency: float
max_latency: float
errors: dict[str, int]
resource_usage: dict[str, float]
class LoadTester:
"""Orchestrate load testing scenarios."""
def __init__(self, config: LoadTestConfig):
self.config = config
self.metrics: list[tuple[float, float, bool]] = [] # (start, duration, success)
self.errors: dict[str, int] = {}
self.start_time: float | None = None
async def run_client_workload(self, client_id: int, session, group_id: str) -> dict[str, int]:
"""Run workload for a single simulated client."""
stats = {'success': 0, 'failure': 0}
data_gen = TestDataGenerator()
# Ramp-up delay
ramp_delay = (client_id / self.config.num_clients) * self.config.ramp_up_time
await asyncio.sleep(ramp_delay)
for op_num in range(self.config.operations_per_client):
operation_start = time.time()
try:
# Randomly select operation type
operation = random.choice(
[
'add_memory',
'search_memory_nodes',
'get_episodes',
]
)
if operation == 'add_memory':
args = {
'name': f'Load Test {client_id}-{op_num}',
'episode_body': data_gen.generate_technical_document(),
'source': 'text',
'source_description': 'load test',
'group_id': group_id,
}
elif operation == 'search_memory_nodes':
args = {
'query': random.choice(['performance', 'architecture', 'test', 'data']),
'group_id': group_id,
'limit': 10,
}
else: # get_episodes
args = {
'group_id': group_id,
'last_n': 10,
}
# Execute operation with timeout
await asyncio.wait_for(session.call_tool(operation, args), timeout=30.0)
duration = time.time() - operation_start
self.metrics.append((operation_start, duration, True))
stats['success'] += 1
except asyncio.TimeoutError:
duration = time.time() - operation_start
self.metrics.append((operation_start, duration, False))
self.errors['timeout'] = self.errors.get('timeout', 0) + 1
stats['failure'] += 1
except Exception as e:
duration = time.time() - operation_start
self.metrics.append((operation_start, duration, False))
error_type = type(e).__name__
self.errors[error_type] = self.errors.get(error_type, 0) + 1
stats['failure'] += 1
# Think time between operations
await asyncio.sleep(self.config.think_time)
# Stop if we've exceeded test duration
if self.start_time and (time.time() - self.start_time) > self.config.test_duration:
break
return stats
def calculate_results(self) -> LoadTestResult:
"""Calculate load test results from metrics."""
if not self.metrics:
return LoadTestResult(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, {}, {})
successful = [m for m in self.metrics if m[2]]
failed = [m for m in self.metrics if not m[2]]
latencies = sorted([m[1] for m in self.metrics])
duration = max([m[0] + m[1] for m in self.metrics]) - min([m[0] for m in self.metrics])
# Calculate percentiles
def percentile(data: list[float], p: float) -> float:
if not data:
return 0.0
idx = int(len(data) * p / 100)
return data[min(idx, len(data) - 1)]
# Get resource usage
process = psutil.Process()
resource_usage = {
'cpu_percent': process.cpu_percent(),
'memory_mb': process.memory_info().rss / 1024 / 1024,
'num_threads': process.num_threads(),
}
return LoadTestResult(
total_operations=len(self.metrics),
successful_operations=len(successful),
failed_operations=len(failed),
duration=duration,
throughput=len(self.metrics) / duration if duration > 0 else 0,
average_latency=sum(latencies) / len(latencies) if latencies else 0,
p50_latency=percentile(latencies, 50),
p95_latency=percentile(latencies, 95),
p99_latency=percentile(latencies, 99),
max_latency=max(latencies) if latencies else 0,
errors=self.errors,
resource_usage=resource_usage,
)
class TestLoadScenarios:
"""Various load testing scenarios."""
@pytest.mark.asyncio
@pytest.mark.slow
async def test_sustained_load(self):
"""Test system under sustained moderate load."""
config = LoadTestConfig(
num_clients=5,
operations_per_client=20,
ramp_up_time=2.0,
test_duration=30.0,
think_time=0.5,
)
async with graphiti_test_client() as (session, group_id):
tester = LoadTester(config)
tester.start_time = time.time()
# Run client workloads
client_tasks = []
for client_id in range(config.num_clients):
task = tester.run_client_workload(client_id, session, group_id)
client_tasks.append(task)
# Execute all clients
await asyncio.gather(*client_tasks)
# Calculate results
results = tester.calculate_results()
# Assertions
assert results.successful_operations > results.failed_operations
assert results.average_latency < 5.0, (
f'Average latency too high: {results.average_latency:.2f}s'
)
assert results.p95_latency < 10.0, f'P95 latency too high: {results.p95_latency:.2f}s'
# Report results
print('\nSustained Load Test Results:')
print(f' Total operations: {results.total_operations}')
print(
f' Success rate: {results.successful_operations / results.total_operations * 100:.1f}%'
)
print(f' Throughput: {results.throughput:.2f} ops/s')
print(f' Avg latency: {results.average_latency:.2f}s')
print(f' P95 latency: {results.p95_latency:.2f}s')
@pytest.mark.asyncio
@pytest.mark.slow
async def test_spike_load(self):
"""Test system response to sudden load spikes."""
async with graphiti_test_client() as (session, group_id):
# Normal load phase
normal_tasks = []
for i in range(3):
task = session.call_tool(
'add_memory',
{
'name': f'Normal Load {i}',
'episode_body': 'Normal operation',
'source': 'text',
'source_description': 'normal',
'group_id': group_id,
},
)
normal_tasks.append(task)
await asyncio.sleep(0.5)
await asyncio.gather(*normal_tasks)
# Spike phase - sudden burst of requests
spike_start = time.time()
spike_tasks = []
for i in range(50):
task = session.call_tool(
'add_memory',
{
'name': f'Spike Load {i}',
'episode_body': TestDataGenerator.generate_technical_document(),
'source': 'text',
'source_description': 'spike',
'group_id': group_id,
},
)
spike_tasks.append(task)
# Execute spike
spike_results = await asyncio.gather(*spike_tasks, return_exceptions=True)
spike_duration = time.time() - spike_start
# Analyze spike handling
spike_failures = sum(1 for r in spike_results if isinstance(r, Exception))
spike_success_rate = (len(spike_results) - spike_failures) / len(spike_results)
print('\nSpike Load Test Results:')
print(f' Spike size: {len(spike_tasks)} operations')
print(f' Duration: {spike_duration:.2f}s')
print(f' Success rate: {spike_success_rate * 100:.1f}%')
print(f' Throughput: {len(spike_tasks) / spike_duration:.2f} ops/s')
# System should handle at least 80% of spike
assert spike_success_rate > 0.8, f'Too many failures during spike: {spike_failures}'
@pytest.mark.asyncio
@pytest.mark.slow
async def test_memory_leak_detection(self):
"""Test for memory leaks during extended operation."""
async with graphiti_test_client() as (session, group_id):
process = psutil.Process()
gc.collect() # Force garbage collection
initial_memory = process.memory_info().rss / 1024 / 1024 # MB
# Perform many operations
for batch in range(10):
batch_tasks = []
for i in range(10):
task = session.call_tool(
'add_memory',
{
'name': f'Memory Test {batch}-{i}',
'episode_body': TestDataGenerator.generate_technical_document(),
'source': 'text',
'source_description': 'memory test',
'group_id': group_id,
},
)
batch_tasks.append(task)
await asyncio.gather(*batch_tasks)
# Force garbage collection between batches
gc.collect()
await asyncio.sleep(1)
# Check memory after operations
gc.collect()
final_memory = process.memory_info().rss / 1024 / 1024 # MB
memory_growth = final_memory - initial_memory
print('\nMemory Leak Test:')
print(f' Initial memory: {initial_memory:.1f} MB')
print(f' Final memory: {final_memory:.1f} MB')
print(f' Growth: {memory_growth:.1f} MB')
# Allow for some memory growth but flag potential leaks
# This is a soft check - actual threshold depends on system
if memory_growth > 100: # More than 100MB growth
print(f' ⚠️ Potential memory leak detected: {memory_growth:.1f} MB growth')
@pytest.mark.asyncio
@pytest.mark.slow
async def test_connection_pool_exhaustion(self):
"""Test behavior when connection pools are exhausted."""
async with graphiti_test_client() as (session, group_id):
# Create many concurrent long-running operations
long_tasks = []
for i in range(100): # Many more than typical pool size
task = session.call_tool(
'search_memory_nodes',
{
'query': f'complex query {i} '
+ ' '.join([TestDataGenerator.fake.word() for _ in range(10)]),
'group_id': group_id,
'limit': 100,
},
)
long_tasks.append(task)
# Execute with timeout
try:
results = await asyncio.wait_for(
asyncio.gather(*long_tasks, return_exceptions=True), timeout=60.0
)
# Count connection-related errors
connection_errors = sum(
1
for r in results
if isinstance(r, Exception) and 'connection' in str(r).lower()
)
print('\nConnection Pool Test:')
print(f' Total requests: {len(long_tasks)}')
print(f' Connection errors: {connection_errors}')
except asyncio.TimeoutError:
print(' Test timed out - possible deadlock or exhaustion')
@pytest.mark.asyncio
@pytest.mark.slow
async def test_gradual_degradation(self):
"""Test system degradation under increasing load."""
async with graphiti_test_client() as (session, group_id):
load_levels = [5, 10, 20, 40, 80] # Increasing concurrent operations
results_by_level = {}
for level in load_levels:
level_start = time.time()
tasks = []
for i in range(level):
task = session.call_tool(
'add_memory',
{
'name': f'Load Level {level} Op {i}',
'episode_body': f'Testing at load level {level}',
'source': 'text',
'source_description': 'degradation test',
'group_id': group_id,
},
)
tasks.append(task)
# Execute level
level_results = await asyncio.gather(*tasks, return_exceptions=True)
level_duration = time.time() - level_start
# Calculate metrics
failures = sum(1 for r in level_results if isinstance(r, Exception))
success_rate = (level - failures) / level * 100
throughput = level / level_duration
results_by_level[level] = {
'success_rate': success_rate,
'throughput': throughput,
'duration': level_duration,
}
print(f'\nLoad Level {level}:')
print(f' Success rate: {success_rate:.1f}%')
print(f' Throughput: {throughput:.2f} ops/s')
print(f' Duration: {level_duration:.2f}s')
# Brief pause between levels
await asyncio.sleep(2)
# Verify graceful degradation
# Success rate should not drop below 50% even at high load
for level, metrics in results_by_level.items():
assert metrics['success_rate'] > 50, f'Poor performance at load level {level}'
class TestResourceLimits:
"""Test behavior at resource limits."""
@pytest.mark.asyncio
async def test_large_payload_handling(self):
"""Test handling of very large payloads."""
async with graphiti_test_client() as (session, group_id):
payload_sizes = [
(1_000, '1KB'),
(10_000, '10KB'),
(100_000, '100KB'),
(1_000_000, '1MB'),
]
for size, label in payload_sizes:
content = 'x' * size
start_time = time.time()
try:
await asyncio.wait_for(
session.call_tool(
'add_memory',
{
'name': f'Large Payload {label}',
'episode_body': content,
'source': 'text',
'source_description': 'payload test',
'group_id': group_id,
},
),
timeout=30.0,
)
duration = time.time() - start_time
status = '✅ Success'
except asyncio.TimeoutError:
duration = 30.0
status = '⏱️ Timeout'
except Exception as e:
duration = time.time() - start_time
status = f'❌ Error: {type(e).__name__}'
print(f'Payload {label}: {status} ({duration:.2f}s)')
@pytest.mark.asyncio
async def test_rate_limit_handling(self):
"""Test handling of rate limits."""
async with graphiti_test_client() as (session, group_id):
# Rapid fire requests to trigger rate limits
rapid_tasks = []
for i in range(100):
task = session.call_tool(
'add_memory',
{
'name': f'Rate Limit Test {i}',
'episode_body': f'Testing rate limit {i}',
'source': 'text',
'source_description': 'rate test',
'group_id': group_id,
},
)
rapid_tasks.append(task)
# Execute without delays
results = await asyncio.gather(*rapid_tasks, return_exceptions=True)
# Count rate limit errors
rate_limit_errors = sum(
1
for r in results
if isinstance(r, Exception) and ('rate' in str(r).lower() or '429' in str(r))
)
print('\nRate Limit Test:')
print(f' Total requests: {len(rapid_tasks)}')
print(f' Rate limit errors: {rate_limit_errors}')
print(
f' Success rate: {(len(rapid_tasks) - rate_limit_errors) / len(rapid_tasks) * 100:.1f}%'
)
def generate_load_test_report(results: list[LoadTestResult]) -> str:
"""Generate comprehensive load test report."""
report = []
report.append('\n' + '=' * 60)
report.append('LOAD TEST REPORT')
report.append('=' * 60)
for i, result in enumerate(results):
report.append(f'\nTest Run {i + 1}:')
report.append(f' Total Operations: {result.total_operations}')
report.append(
f' Success Rate: {result.successful_operations / result.total_operations * 100:.1f}%'
)
report.append(f' Throughput: {result.throughput:.2f} ops/s')
report.append(
f' Latency (avg/p50/p95/p99/max): {result.average_latency:.2f}/{result.p50_latency:.2f}/{result.p95_latency:.2f}/{result.p99_latency:.2f}/{result.max_latency:.2f}s'
)
if result.errors:
report.append(' Errors:')
for error_type, count in result.errors.items():
report.append(f' {error_type}: {count}')
report.append(' Resource Usage:')
for metric, value in result.resource_usage.items():
report.append(f' {metric}: {value:.2f}')
report.append('=' * 60)
return '\n'.join(report)
if __name__ == '__main__':
pytest.main([__file__, '-v', '--asyncio-mode=auto', '-m', 'slow'])
```
--------------------------------------------------------------------------------
/graphiti_core/utils/bulk_utils.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
import typing
from datetime import datetime
import numpy as np
from pydantic import BaseModel, Field
from typing_extensions import Any
from graphiti_core.driver.driver import (
GraphDriver,
GraphDriverSession,
GraphProvider,
)
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
from graphiti_core.embedder import EmbedderClient
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import normalize_l2, semaphore_gather
from graphiti_core.models.edges.edge_db_queries import (
get_entity_edge_save_bulk_query,
get_episodic_edge_save_bulk_query,
)
from graphiti_core.models.nodes.node_db_queries import (
get_entity_node_save_bulk_query,
get_episode_node_save_bulk_query,
)
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
from graphiti_core.utils.maintenance.dedup_helpers import (
DedupResolutionState,
_build_candidate_indexes,
_normalize_string_exact,
_resolve_with_similarity,
)
from graphiti_core.utils.maintenance.edge_operations import (
extract_edges,
resolve_extracted_edge,
)
from graphiti_core.utils.maintenance.graph_data_operations import (
EPISODE_WINDOW_LEN,
retrieve_episodes,
)
from graphiti_core.utils.maintenance.node_operations import (
extract_nodes,
resolve_extracted_nodes,
)
logger = logging.getLogger(__name__)
CHUNK_SIZE = 10
def _build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]:
"""Collapse alias -> canonical chains while preserving direction.
The incoming pairs represent directed mappings discovered during node dedupe. We use a simple
union-find with iterative path compression to ensure every source UUID resolves to its ultimate
canonical target, even if aliases appear lexicographically smaller than the canonical UUID.
"""
parent: dict[str, str] = {}
def find(uuid: str) -> str:
"""Directed union-find lookup using iterative path compression."""
parent.setdefault(uuid, uuid)
root = uuid
while parent[root] != root:
root = parent[root]
while parent[uuid] != root:
next_uuid = parent[uuid]
parent[uuid] = root
uuid = next_uuid
return root
for source_uuid, target_uuid in pairs:
parent.setdefault(source_uuid, source_uuid)
parent.setdefault(target_uuid, target_uuid)
parent[find(source_uuid)] = find(target_uuid)
return {uuid: find(uuid) for uuid in parent}
class RawEpisode(BaseModel):
name: str
uuid: str | None = Field(default=None)
content: str
source_description: str
source: EpisodeType
reference_time: datetime
async def retrieve_previous_episodes_bulk(
driver: GraphDriver, episodes: list[EpisodicNode]
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
previous_episodes_list = await semaphore_gather(
*[
retrieve_episodes(
driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id]
)
for episode in episodes
]
)
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] = [
(episode, previous_episodes_list[i]) for i, episode in enumerate(episodes)
]
return episode_tuples
async def add_nodes_and_edges_bulk(
driver: GraphDriver,
episodic_nodes: list[EpisodicNode],
episodic_edges: list[EpisodicEdge],
entity_nodes: list[EntityNode],
entity_edges: list[EntityEdge],
embedder: EmbedderClient,
):
session = driver.session()
try:
await session.execute_write(
add_nodes_and_edges_bulk_tx,
episodic_nodes,
episodic_edges,
entity_nodes,
entity_edges,
embedder,
driver=driver,
)
finally:
await session.close()
async def add_nodes_and_edges_bulk_tx(
tx: GraphDriverSession,
episodic_nodes: list[EpisodicNode],
episodic_edges: list[EpisodicEdge],
entity_nodes: list[EntityNode],
entity_edges: list[EntityEdge],
embedder: EmbedderClient,
driver: GraphDriver,
):
episodes = [dict(episode) for episode in episodic_nodes]
for episode in episodes:
episode['source'] = str(episode['source'].value)
episode.pop('labels', None)
nodes = []
for node in entity_nodes:
if node.name_embedding is None:
await node.generate_name_embedding(embedder)
entity_data: dict[str, Any] = {
'uuid': node.uuid,
'name': node.name,
'group_id': node.group_id,
'summary': node.summary,
'created_at': node.created_at,
'name_embedding': node.name_embedding,
'labels': list(set(node.labels + ['Entity'])),
}
if driver.provider == GraphProvider.KUZU:
attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
entity_data['attributes'] = json.dumps(attributes)
else:
entity_data.update(node.attributes or {})
nodes.append(entity_data)
edges = []
for edge in entity_edges:
if edge.fact_embedding is None:
await edge.generate_embedding(embedder)
edge_data: dict[str, Any] = {
'uuid': edge.uuid,
'source_node_uuid': edge.source_node_uuid,
'target_node_uuid': edge.target_node_uuid,
'name': edge.name,
'fact': edge.fact,
'group_id': edge.group_id,
'episodes': edge.episodes,
'created_at': edge.created_at,
'expired_at': edge.expired_at,
'valid_at': edge.valid_at,
'invalid_at': edge.invalid_at,
'fact_embedding': edge.fact_embedding,
}
if driver.provider == GraphProvider.KUZU:
attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
edge_data['attributes'] = json.dumps(attributes)
else:
edge_data.update(edge.attributes or {})
edges.append(edge_data)
if driver.graph_operations_interface:
await driver.graph_operations_interface.episodic_node_save_bulk(None, driver, tx, episodes)
await driver.graph_operations_interface.node_save_bulk(None, driver, tx, nodes)
await driver.graph_operations_interface.episodic_edge_save_bulk(
None, driver, tx, [edge.model_dump() for edge in episodic_edges]
)
await driver.graph_operations_interface.edge_save_bulk(None, driver, tx, edges)
elif driver.provider == GraphProvider.KUZU:
# FIXME: Kuzu's UNWIND does not currently support STRUCT[] type properly, so we insert the data one by one instead for now.
episode_query = get_episode_node_save_bulk_query(driver.provider)
for episode in episodes:
await tx.run(episode_query, **episode)
entity_node_query = get_entity_node_save_bulk_query(driver.provider, nodes)
for node in nodes:
await tx.run(entity_node_query, **node)
entity_edge_query = get_entity_edge_save_bulk_query(driver.provider)
for edge in edges:
await tx.run(entity_edge_query, **edge)
episodic_edge_query = get_episodic_edge_save_bulk_query(driver.provider)
for edge in episodic_edges:
await tx.run(episodic_edge_query, **edge.model_dump())
else:
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
await tx.run(
get_entity_node_save_bulk_query(driver.provider, nodes),
nodes=nodes,
)
await tx.run(
get_episodic_edge_save_bulk_query(driver.provider),
episodic_edges=[edge.model_dump() for edge in episodic_edges],
)
await tx.run(
get_entity_edge_save_bulk_query(driver.provider),
entity_edges=edges,
)
async def extract_nodes_and_edges_bulk(
clients: GraphitiClients,
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
edge_type_map: dict[tuple[str, str], list[str]],
entity_types: dict[str, type[BaseModel]] | None = None,
excluded_entity_types: list[str] | None = None,
edge_types: dict[str, type[BaseModel]] | None = None,
custom_extraction_instructions: str | None = None,
) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]:
extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather(
*[
extract_nodes(
clients,
episode,
previous_episodes,
entity_types=entity_types,
excluded_entity_types=excluded_entity_types,
custom_extraction_instructions=custom_extraction_instructions,
)
for episode, previous_episodes in episode_tuples
]
)
extracted_edges_bulk: list[list[EntityEdge]] = await semaphore_gather(
*[
extract_edges(
clients,
episode,
extracted_nodes_bulk[i],
previous_episodes,
edge_type_map=edge_type_map,
group_id=episode.group_id,
edge_types=edge_types,
custom_extraction_instructions=custom_extraction_instructions,
)
for i, (episode, previous_episodes) in enumerate(episode_tuples)
]
)
return extracted_nodes_bulk, extracted_edges_bulk
async def dedupe_nodes_bulk(
clients: GraphitiClients,
extracted_nodes: list[list[EntityNode]],
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
entity_types: dict[str, type[BaseModel]] | None = None,
) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
"""Resolve entity duplicates across an in-memory batch using a two-pass strategy.
1. Run :func:`resolve_extracted_nodes` for every episode in parallel so each batch item is
reconciled against the live graph just like the non-batch flow.
2. Re-run the deterministic similarity heuristics across the union of resolved nodes to catch
duplicates that only co-occur inside this batch, emitting a canonical UUID map that callers
can apply to edges and persistence.
"""
first_pass_results = await semaphore_gather(
*[
resolve_extracted_nodes(
clients,
nodes,
episode_tuples[i][0],
episode_tuples[i][1],
entity_types,
)
for i, nodes in enumerate(extracted_nodes)
]
)
episode_resolutions: list[tuple[str, list[EntityNode]]] = []
per_episode_uuid_maps: list[dict[str, str]] = []
duplicate_pairs: list[tuple[str, str]] = []
for (resolved_nodes, uuid_map, duplicates), (episode, _) in zip(
first_pass_results, episode_tuples, strict=True
):
episode_resolutions.append((episode.uuid, resolved_nodes))
per_episode_uuid_maps.append(uuid_map)
duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)
canonical_nodes: dict[str, EntityNode] = {}
for _, resolved_nodes in episode_resolutions:
for node in resolved_nodes:
# NOTE: this loop is O(n^2) in the number of nodes inside the batch because we rebuild
# the MinHash index for the accumulated canonical pool each time. The LRU-backed
# shingle cache keeps the constant factors low for typical batch sizes (≤ CHUNK_SIZE),
# but if batches grow significantly we should switch to an incremental index or chunked
# processing.
if not canonical_nodes:
canonical_nodes[node.uuid] = node
continue
existing_candidates = list(canonical_nodes.values())
normalized = _normalize_string_exact(node.name)
exact_match = next(
(
candidate
for candidate in existing_candidates
if _normalize_string_exact(candidate.name) == normalized
),
None,
)
if exact_match is not None:
if exact_match.uuid != node.uuid:
duplicate_pairs.append((node.uuid, exact_match.uuid))
continue
indexes = _build_candidate_indexes(existing_candidates)
state = DedupResolutionState(
resolved_nodes=[None],
uuid_map={},
unresolved_indices=[],
)
_resolve_with_similarity([node], indexes, state)
resolved = state.resolved_nodes[0]
if resolved is None:
canonical_nodes[node.uuid] = node
continue
canonical_uuid = resolved.uuid
canonical_nodes.setdefault(canonical_uuid, resolved)
if canonical_uuid != node.uuid:
duplicate_pairs.append((node.uuid, canonical_uuid))
union_pairs: list[tuple[str, str]] = []
for uuid_map in per_episode_uuid_maps:
union_pairs.extend(uuid_map.items())
union_pairs.extend(duplicate_pairs)
compressed_map: dict[str, str] = _build_directed_uuid_map(union_pairs)
nodes_by_episode: dict[str, list[EntityNode]] = {}
for episode_uuid, resolved_nodes in episode_resolutions:
deduped_nodes: list[EntityNode] = []
seen: set[str] = set()
for node in resolved_nodes:
canonical_uuid = compressed_map.get(node.uuid, node.uuid)
if canonical_uuid in seen:
continue
seen.add(canonical_uuid)
canonical_node = canonical_nodes.get(canonical_uuid)
if canonical_node is None:
logger.error(
'Canonical node %s missing during batch dedupe; falling back to %s',
canonical_uuid,
node.uuid,
)
canonical_node = node
deduped_nodes.append(canonical_node)
nodes_by_episode[episode_uuid] = deduped_nodes
return nodes_by_episode, compressed_map
async def dedupe_edges_bulk(
clients: GraphitiClients,
extracted_edges: list[list[EntityEdge]],
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
_entities: list[EntityNode],
edge_types: dict[str, type[BaseModel]],
_edge_type_map: dict[tuple[str, str], list[str]],
) -> dict[str, list[EntityEdge]]:
embedder = clients.embedder
min_score = 0.6
# generate embeddings
await semaphore_gather(
*[create_entity_edge_embeddings(embedder, edges) for edges in extracted_edges]
)
# Find similar results
dedupe_tuples: list[tuple[EpisodicNode, EntityEdge, list[EntityEdge]]] = []
for i, edges_i in enumerate(extracted_edges):
existing_edges: list[EntityEdge] = []
for edges_j in extracted_edges:
existing_edges += edges_j
for edge in edges_i:
candidates: list[EntityEdge] = []
for existing_edge in existing_edges:
# Skip self-comparison
if edge.uuid == existing_edge.uuid:
continue
# Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
# This approach will cast a wider net than BM25, which is ideal for this use case
if (
edge.source_node_uuid != existing_edge.source_node_uuid
or edge.target_node_uuid != existing_edge.target_node_uuid
):
continue
edge_words = set(edge.fact.lower().split())
existing_edge_words = set(existing_edge.fact.lower().split())
has_overlap = not edge_words.isdisjoint(existing_edge_words)
if has_overlap:
candidates.append(existing_edge)
continue
# Check for semantic similarity even if there is no overlap
similarity = np.dot(
normalize_l2(edge.fact_embedding or []),
normalize_l2(existing_edge.fact_embedding or []),
)
if similarity >= min_score:
candidates.append(existing_edge)
dedupe_tuples.append((episode_tuples[i][0], edge, candidates))
bulk_edge_resolutions: list[
tuple[EntityEdge, EntityEdge, list[EntityEdge]]
] = await semaphore_gather(
*[
resolve_extracted_edge(
clients.llm_client,
edge,
candidates,
candidates,
episode,
edge_types,
set(edge_types),
)
for episode, edge, candidates in dedupe_tuples
]
)
# For now we won't track edge invalidation
duplicate_pairs: list[tuple[str, str]] = []
for i, (_, _, duplicates) in enumerate(bulk_edge_resolutions):
episode, edge, candidates = dedupe_tuples[i]
for duplicate in duplicates:
duplicate_pairs.append((edge.uuid, duplicate.uuid))
# Now we compress the duplicate_map, so that 3 -> 2 and 2 -> becomes 3 -> 1 (sorted by uuid)
compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
edge_uuid_map: dict[str, EntityEdge] = {
edge.uuid: edge for edges in extracted_edges for edge in edges
}
edges_by_episode: dict[str, list[EntityEdge]] = {}
for i, edges in enumerate(extracted_edges):
episode = episode_tuples[i][0]
edges_by_episode[episode.uuid] = [
edge_uuid_map[compressed_map.get(edge.uuid, edge.uuid)] for edge in edges
]
return edges_by_episode
class UnionFind:
def __init__(self, elements):
# start each element in its own set
self.parent = {e: e for e in elements}
def find(self, x):
# path‐compression
if self.parent[x] != x:
self.parent[x] = self.find(self.parent[x])
return self.parent[x]
def union(self, a, b):
ra, rb = self.find(a), self.find(b)
if ra == rb:
return
# attach the lexicographically larger root under the smaller
if ra < rb:
self.parent[rb] = ra
else:
self.parent[ra] = rb
def compress_uuid_map(duplicate_pairs: list[tuple[str, str]]) -> dict[str, str]:
"""
all_ids: iterable of all entity IDs (strings)
duplicate_pairs: iterable of (id1, id2) pairs
returns: dict mapping each id -> lexicographically smallest id in its duplicate set
"""
all_uuids = set()
for pair in duplicate_pairs:
all_uuids.add(pair[0])
all_uuids.add(pair[1])
uf = UnionFind(all_uuids)
for a, b in duplicate_pairs:
uf.union(a, b)
# ensure full path‐compression before mapping
return {uuid: uf.find(uuid) for uuid in all_uuids}
E = typing.TypeVar('E', bound=Edge)
def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
for edge in edges:
source_uuid = edge.source_node_uuid
target_uuid = edge.target_node_uuid
edge.source_node_uuid = uuid_map.get(source_uuid, source_uuid)
edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)
return edges
```
--------------------------------------------------------------------------------
/graphiti_core/edges.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from time import time
from typing import Any
from uuid import uuid4
from pydantic import BaseModel, Field
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphDriver, GraphProvider
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
from graphiti_core.helpers import parse_db_date
from graphiti_core.models.edges.edge_db_queries import (
COMMUNITY_EDGE_RETURN,
EPISODIC_EDGE_RETURN,
EPISODIC_EDGE_SAVE,
get_community_edge_save_query,
get_entity_edge_return_query,
get_entity_edge_save_query,
)
from graphiti_core.nodes import Node
logger = logging.getLogger(__name__)
class Edge(BaseModel, ABC):
uuid: str = Field(default_factory=lambda: str(uuid4()))
group_id: str = Field(description='partition of the graph')
source_node_uuid: str
target_node_uuid: str
created_at: datetime
@abstractmethod
async def save(self, driver: GraphDriver): ...
async def delete(self, driver: GraphDriver):
if driver.graph_operations_interface:
return await driver.graph_operations_interface.edge_delete(self, driver)
if driver.provider == GraphProvider.KUZU:
await driver.execute_query(
"""
MATCH (n)-[e:MENTIONS|HAS_MEMBER {uuid: $uuid}]->(m)
DELETE e
""",
uuid=self.uuid,
)
await driver.execute_query(
"""
MATCH (e:RelatesToNode_ {uuid: $uuid})
DETACH DELETE e
""",
uuid=self.uuid,
)
else:
await driver.execute_query(
"""
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
DELETE e
""",
uuid=self.uuid,
)
logger.debug(f'Deleted Edge: {self.uuid}')
@classmethod
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
if driver.graph_operations_interface:
return await driver.graph_operations_interface.edge_delete_by_uuids(cls, driver, uuids)
if driver.provider == GraphProvider.KUZU:
await driver.execute_query(
"""
MATCH (n)-[e:MENTIONS|HAS_MEMBER]->(m)
WHERE e.uuid IN $uuids
DELETE e
""",
uuids=uuids,
)
await driver.execute_query(
"""
MATCH (e:RelatesToNode_)
WHERE e.uuid IN $uuids
DETACH DELETE e
""",
uuids=uuids,
)
else:
await driver.execute_query(
"""
MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
WHERE e.uuid IN $uuids
DELETE e
""",
uuids=uuids,
)
logger.debug(f'Deleted Edges: {uuids}')
def __hash__(self):
return hash(self.uuid)
def __eq__(self, other):
if isinstance(other, Node):
return self.uuid == other.uuid
return False
@classmethod
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
class EpisodicEdge(Edge):
async def save(self, driver: GraphDriver):
result = await driver.execute_query(
EPISODIC_EDGE_SAVE,
episode_uuid=self.source_node_uuid,
entity_uuid=self.target_node_uuid,
uuid=self.uuid,
group_id=self.group_id,
created_at=self.created_at,
)
logger.debug(f'Saved edge to Graph: {self.uuid}')
return result
@classmethod
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
RETURN
"""
+ EPISODIC_EDGE_RETURN,
uuid=uuid,
routing_='r',
)
edges = [get_episodic_edge_from_record(record) for record in records]
if len(edges) == 0:
raise EdgeNotFoundError(uuid)
return edges[0]
@classmethod
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
WHERE e.uuid IN $uuids
RETURN
"""
+ EPISODIC_EDGE_RETURN,
uuids=uuids,
routing_='r',
)
edges = [get_episodic_edge_from_record(record) for record in records]
if len(edges) == 0:
raise EdgeNotFoundError(uuids[0])
return edges
@classmethod
async def get_by_group_ids(
cls,
driver: GraphDriver,
group_ids: list[str],
limit: int | None = None,
uuid_cursor: str | None = None,
):
cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
records, _, _ = await driver.execute_query(
"""
MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN
"""
+ EPISODIC_EDGE_RETURN
+ """
ORDER BY e.uuid DESC
"""
+ limit_query,
group_ids=group_ids,
uuid=uuid_cursor,
limit=limit,
routing_='r',
)
edges = [get_episodic_edge_from_record(record) for record in records]
if len(edges) == 0:
raise GroupsEdgesNotFoundError(group_ids)
return edges
class EntityEdge(Edge):
name: str = Field(description='name of the edge, relation name')
fact: str = Field(description='fact representing the edge and nodes that it connects')
fact_embedding: list[float] | None = Field(default=None, description='embedding of the fact')
episodes: list[str] = Field(
default=[],
description='list of episode ids that reference these entity edges',
)
expired_at: datetime | None = Field(
default=None, description='datetime of when the node was invalidated'
)
valid_at: datetime | None = Field(
default=None, description='datetime of when the fact became true'
)
invalid_at: datetime | None = Field(
default=None, description='datetime of when the fact stopped being true'
)
attributes: dict[str, Any] = Field(
default={}, description='Additional attributes of the edge. Dependent on edge name'
)
async def generate_embedding(self, embedder: EmbedderClient):
start = time()
text = self.fact.replace('\n', ' ')
self.fact_embedding = await embedder.create(input_data=[text])
end = time()
logger.debug(f'embedded {text} in {end - start} ms')
return self.fact_embedding
async def load_fact_embedding(self, driver: GraphDriver):
if driver.graph_operations_interface:
return await driver.graph_operations_interface.edge_load_embeddings(self, driver)
query = """
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
RETURN e.fact_embedding AS fact_embedding
"""
if driver.provider == GraphProvider.NEPTUNE:
query = """
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
"""
if driver.provider == GraphProvider.KUZU:
query = """
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
RETURN e.fact_embedding AS fact_embedding
"""
records, _, _ = await driver.execute_query(
query,
uuid=self.uuid,
routing_='r',
)
if len(records) == 0:
raise EdgeNotFoundError(self.uuid)
self.fact_embedding = records[0]['fact_embedding']
async def save(self, driver: GraphDriver):
edge_data: dict[str, Any] = {
'source_uuid': self.source_node_uuid,
'target_uuid': self.target_node_uuid,
'uuid': self.uuid,
'name': self.name,
'group_id': self.group_id,
'fact': self.fact,
'fact_embedding': self.fact_embedding,
'episodes': self.episodes,
'created_at': self.created_at,
'expired_at': self.expired_at,
'valid_at': self.valid_at,
'invalid_at': self.invalid_at,
}
if driver.provider == GraphProvider.KUZU:
edge_data['attributes'] = json.dumps(self.attributes)
result = await driver.execute_query(
get_entity_edge_save_query(driver.provider),
**edge_data,
)
else:
edge_data.update(self.attributes or {})
result = await driver.execute_query(
get_entity_edge_save_query(driver.provider),
edge_data=edge_data,
)
logger.debug(f'Saved edge to Graph: {self.uuid}')
return result
@classmethod
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
match_query = """
MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
"""
if driver.provider == GraphProvider.KUZU:
match_query = """
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
"""
records, _, _ = await driver.execute_query(
match_query
+ """
RETURN
"""
+ get_entity_edge_return_query(driver.provider),
uuid=uuid,
routing_='r',
)
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
if len(edges) == 0:
raise EdgeNotFoundError(uuid)
return edges[0]
@classmethod
async def get_between_nodes(
cls, driver: GraphDriver, source_node_uuid: str, target_node_uuid: str
):
match_query = """
MATCH (n:Entity {uuid: $source_node_uuid})-[e:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
"""
if driver.provider == GraphProvider.KUZU:
match_query = """
MATCH (n:Entity {uuid: $source_node_uuid})
-[:RELATES_TO]->(e:RelatesToNode_)
-[:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
"""
records, _, _ = await driver.execute_query(
match_query
+ """
RETURN
"""
+ get_entity_edge_return_query(driver.provider),
source_node_uuid=source_node_uuid,
target_node_uuid=target_node_uuid,
routing_='r',
)
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
return edges
@classmethod
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
if len(uuids) == 0:
return []
match_query = """
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
"""
if driver.provider == GraphProvider.KUZU:
match_query = """
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
"""
records, _, _ = await driver.execute_query(
match_query
+ """
WHERE e.uuid IN $uuids
RETURN
"""
+ get_entity_edge_return_query(driver.provider),
uuids=uuids,
routing_='r',
)
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
return edges
@classmethod
async def get_by_group_ids(
cls,
driver: GraphDriver,
group_ids: list[str],
limit: int | None = None,
uuid_cursor: str | None = None,
with_embeddings: bool = False,
):
cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
with_embeddings_query: LiteralString = (
""",
e.fact_embedding AS fact_embedding
"""
if with_embeddings
else ''
)
match_query = """
MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
"""
if driver.provider == GraphProvider.KUZU:
match_query = """
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
"""
records, _, _ = await driver.execute_query(
match_query
+ """
WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN
"""
+ get_entity_edge_return_query(driver.provider)
+ with_embeddings_query
+ """
ORDER BY e.uuid DESC
"""
+ limit_query,
group_ids=group_ids,
uuid=uuid_cursor,
limit=limit,
routing_='r',
)
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
if len(edges) == 0:
raise GroupsEdgesNotFoundError(group_ids)
return edges
@classmethod
async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
match_query = """
MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
"""
if driver.provider == GraphProvider.KUZU:
match_query = """
MATCH (n:Entity {uuid: $node_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
"""
records, _, _ = await driver.execute_query(
match_query
+ """
RETURN
"""
+ get_entity_edge_return_query(driver.provider),
node_uuid=node_uuid,
routing_='r',
)
edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
return edges
class CommunityEdge(Edge):
async def save(self, driver: GraphDriver):
result = await driver.execute_query(
get_community_edge_save_query(driver.provider),
community_uuid=self.source_node_uuid,
entity_uuid=self.target_node_uuid,
uuid=self.uuid,
group_id=self.group_id,
created_at=self.created_at,
)
logger.debug(f'Saved edge to Graph: {self.uuid}')
return result
@classmethod
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m)
RETURN
"""
+ COMMUNITY_EDGE_RETURN,
uuid=uuid,
routing_='r',
)
edges = [get_community_edge_from_record(record) for record in records]
return edges[0]
@classmethod
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Community)-[e:HAS_MEMBER]->(m)
WHERE e.uuid IN $uuids
RETURN
"""
+ COMMUNITY_EDGE_RETURN,
uuids=uuids,
routing_='r',
)
edges = [get_community_edge_from_record(record) for record in records]
return edges
@classmethod
async def get_by_group_ids(
cls,
driver: GraphDriver,
group_ids: list[str],
limit: int | None = None,
uuid_cursor: str | None = None,
):
cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
records, _, _ = await driver.execute_query(
"""
MATCH (n:Community)-[e:HAS_MEMBER]->(m)
WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN
"""
+ COMMUNITY_EDGE_RETURN
+ """
ORDER BY e.uuid DESC
"""
+ limit_query,
group_ids=group_ids,
uuid=uuid_cursor,
limit=limit,
routing_='r',
)
edges = [get_community_edge_from_record(record) for record in records]
return edges
# Edge helpers
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
return EpisodicEdge(
uuid=record['uuid'],
group_id=record['group_id'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
created_at=parse_db_date(record['created_at']), # type: ignore
)
def get_entity_edge_from_record(record: Any, provider: GraphProvider) -> EntityEdge:
episodes = record['episodes']
if provider == GraphProvider.KUZU:
attributes = json.loads(record['attributes']) if record['attributes'] else {}
else:
attributes = record['attributes']
attributes.pop('uuid', None)
attributes.pop('source_node_uuid', None)
attributes.pop('target_node_uuid', None)
attributes.pop('fact', None)
attributes.pop('fact_embedding', None)
attributes.pop('name', None)
attributes.pop('group_id', None)
attributes.pop('episodes', None)
attributes.pop('created_at', None)
attributes.pop('expired_at', None)
attributes.pop('valid_at', None)
attributes.pop('invalid_at', None)
edge = EntityEdge(
uuid=record['uuid'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
fact=record['fact'],
fact_embedding=record.get('fact_embedding'),
name=record['name'],
group_id=record['group_id'],
episodes=episodes,
created_at=parse_db_date(record['created_at']), # type: ignore
expired_at=parse_db_date(record['expired_at']),
valid_at=parse_db_date(record['valid_at']),
invalid_at=parse_db_date(record['invalid_at']),
attributes=attributes,
)
return edge
def get_community_edge_from_record(record: Any):
return CommunityEdge(
uuid=record['uuid'],
group_id=record['group_id'],
source_node_uuid=record['source_node_uuid'],
target_node_uuid=record['target_node_uuid'],
created_at=parse_db_date(record['created_at']), # type: ignore
)
async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
# filter out falsey values from edges
filtered_edges = [edge for edge in edges if edge.fact]
if len(filtered_edges) == 0:
return
fact_embeddings = await embedder.create_batch([edge.fact for edge in filtered_edges])
for edge, fact_embedding in zip(filtered_edges, fact_embeddings, strict=True):
edge.fact_embedding = fact_embedding
```
--------------------------------------------------------------------------------
/tests/utils/maintenance/test_node_operations.py:
--------------------------------------------------------------------------------
```python
import logging
from collections import defaultdict
from unittest.mock import AsyncMock, MagicMock
import pytest
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search_config import SearchResults
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.dedup_helpers import (
DedupCandidateIndexes,
DedupResolutionState,
_build_candidate_indexes,
_cached_shingles,
_has_high_entropy,
_hash_shingle,
_jaccard_similarity,
_lsh_bands,
_minhash_signature,
_name_entropy,
_normalize_name_for_fuzzy,
_normalize_string_exact,
_resolve_with_similarity,
_shingles,
)
from graphiti_core.utils.maintenance.node_operations import (
_collect_candidate_nodes,
_resolve_with_llm,
extract_attributes_from_node,
extract_attributes_from_nodes,
resolve_extracted_nodes,
)
def _make_clients():
driver = MagicMock()
embedder = MagicMock()
cross_encoder = MagicMock()
llm_client = MagicMock()
llm_generate = AsyncMock()
llm_client.generate_response = llm_generate
clients = GraphitiClients.model_construct( # bypass validation to allow test doubles
driver=driver,
embedder=embedder,
cross_encoder=cross_encoder,
llm_client=llm_client,
)
return clients, llm_generate
def _make_episode(group_id: str = 'group'):
return EpisodicNode(
name='episode',
group_id=group_id,
source=EpisodeType.message,
source_description='test',
content='content',
valid_at=utc_now(),
)
@pytest.mark.asyncio
async def test_resolve_nodes_exact_match_skips_llm(monkeypatch):
clients, llm_generate = _make_clients()
candidate = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])
extracted = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])
async def fake_search(*_, **__):
return SearchResults(nodes=[candidate])
monkeypatch.setattr(
'graphiti_core.utils.maintenance.node_operations.search',
fake_search,
)
monkeypatch.setattr(
'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges',
AsyncMock(return_value=[]),
)
resolved, uuid_map, _ = await resolve_extracted_nodes(
clients,
[extracted],
episode=_make_episode(),
previous_episodes=[],
)
assert resolved[0].uuid == candidate.uuid
assert uuid_map[extracted.uuid] == candidate.uuid
llm_generate.assert_not_awaited()
@pytest.mark.asyncio
async def test_resolve_nodes_low_entropy_uses_llm(monkeypatch):
clients, llm_generate = _make_clients()
llm_generate.return_value = {
'entity_resolutions': [
{
'id': 0,
'duplicate_idx': -1,
'name': 'Joe',
'duplicates': [],
}
]
}
extracted = EntityNode(name='Joe', group_id='group', labels=['Entity'])
async def fake_search(*_, **__):
return SearchResults(nodes=[])
monkeypatch.setattr(
'graphiti_core.utils.maintenance.node_operations.search',
fake_search,
)
monkeypatch.setattr(
'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges',
AsyncMock(return_value=[]),
)
resolved, uuid_map, _ = await resolve_extracted_nodes(
clients,
[extracted],
episode=_make_episode(),
previous_episodes=[],
)
assert resolved[0].uuid == extracted.uuid
assert uuid_map[extracted.uuid] == extracted.uuid
llm_generate.assert_awaited()
@pytest.mark.asyncio
async def test_resolve_nodes_fuzzy_match(monkeypatch):
clients, llm_generate = _make_clients()
candidate = EntityNode(name='Joe-Michaels', group_id='group', labels=['Entity'])
extracted = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])
async def fake_search(*_, **__):
return SearchResults(nodes=[candidate])
monkeypatch.setattr(
'graphiti_core.utils.maintenance.node_operations.search',
fake_search,
)
monkeypatch.setattr(
'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges',
AsyncMock(return_value=[]),
)
resolved, uuid_map, _ = await resolve_extracted_nodes(
clients,
[extracted],
episode=_make_episode(),
previous_episodes=[],
)
assert resolved[0].uuid == candidate.uuid
assert uuid_map[extracted.uuid] == candidate.uuid
llm_generate.assert_not_awaited()
@pytest.mark.asyncio
async def test_collect_candidate_nodes_dedupes_and_merges_override(monkeypatch):
clients, _ = _make_clients()
candidate = EntityNode(name='Alice', group_id='group', labels=['Entity'])
override_duplicate = EntityNode(
uuid=candidate.uuid,
name='Alice Alt',
group_id='group',
labels=['Entity'],
)
extracted = EntityNode(name='Alice', group_id='group', labels=['Entity'])
search_mock = AsyncMock(return_value=SearchResults(nodes=[candidate]))
monkeypatch.setattr(
'graphiti_core.utils.maintenance.node_operations.search',
search_mock,
)
result = await _collect_candidate_nodes(
clients,
[extracted],
existing_nodes_override=[override_duplicate],
)
assert len(result) == 1
assert result[0].uuid == candidate.uuid
search_mock.assert_awaited()
def test_build_candidate_indexes_populates_structures():
candidate = EntityNode(name='Bob Dylan', group_id='group', labels=['Entity'])
indexes = _build_candidate_indexes([candidate])
normalized_key = candidate.name.lower()
assert indexes.normalized_existing[normalized_key][0].uuid == candidate.uuid
assert indexes.nodes_by_uuid[candidate.uuid] is candidate
assert candidate.uuid in indexes.shingles_by_candidate
assert any(candidate.uuid in bucket for bucket in indexes.lsh_buckets.values())
def test_normalize_helpers():
assert _normalize_string_exact(' Alice Smith ') == 'alice smith'
assert _normalize_name_for_fuzzy('Alice-Smith!') == 'alice smith'
def test_name_entropy_variants():
assert _name_entropy('alice') > _name_entropy('aaaaa')
assert _name_entropy('') == 0.0
def test_has_high_entropy_rules():
assert _has_high_entropy('meaningful name') is True
assert _has_high_entropy('aa') is False
def test_shingles_and_cache():
raw = 'alice'
shingle_set = _shingles(raw)
assert shingle_set == {'ali', 'lic', 'ice'}
assert _cached_shingles(raw) == shingle_set
assert _cached_shingles(raw) is _cached_shingles(raw)
def test_hash_minhash_and_lsh():
shingles = {'abc', 'bcd', 'cde'}
signature = _minhash_signature(shingles)
assert len(signature) == 32
bands = _lsh_bands(signature)
assert all(len(band) == 4 for band in bands)
hashed = {_hash_shingle(s, 0) for s in shingles}
assert len(hashed) == len(shingles)
def test_jaccard_similarity_edges():
a = {'a', 'b'}
b = {'a', 'c'}
assert _jaccard_similarity(a, b) == pytest.approx(1 / 3)
assert _jaccard_similarity(set(), set()) == 1.0
assert _jaccard_similarity(a, set()) == 0.0
def test_resolve_with_similarity_exact_match_updates_state():
candidate = EntityNode(name='Charlie Parker', group_id='group', labels=['Entity'])
extracted = EntityNode(name='Charlie Parker', group_id='group', labels=['Entity'])
indexes = _build_candidate_indexes([candidate])
state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])
_resolve_with_similarity([extracted], indexes, state)
assert state.resolved_nodes[0].uuid == candidate.uuid
assert state.uuid_map[extracted.uuid] == candidate.uuid
assert state.unresolved_indices == []
assert state.duplicate_pairs == [(extracted, candidate)]
def test_resolve_with_similarity_low_entropy_defers_resolution():
extracted = EntityNode(name='Bob', group_id='group', labels=['Entity'])
indexes = DedupCandidateIndexes(
existing_nodes=[],
nodes_by_uuid={},
normalized_existing=defaultdict(list),
shingles_by_candidate={},
lsh_buckets=defaultdict(list),
)
state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])
_resolve_with_similarity([extracted], indexes, state)
assert state.resolved_nodes[0] is None
assert state.unresolved_indices == [0]
assert state.duplicate_pairs == []
def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():
candidate1 = EntityNode(name='Johnny Appleseed', group_id='group', labels=['Entity'])
candidate2 = EntityNode(name='Johnny Appleseed', group_id='group', labels=['Entity'])
extracted = EntityNode(name='Johnny Appleseed', group_id='group', labels=['Entity'])
indexes = _build_candidate_indexes([candidate1, candidate2])
state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])
_resolve_with_similarity([extracted], indexes, state)
assert state.resolved_nodes[0] is None
assert state.unresolved_indices == [0]
assert state.duplicate_pairs == []
@pytest.mark.asyncio
async def test_resolve_with_llm_updates_unresolved(monkeypatch):
extracted = EntityNode(name='Dizzy', group_id='group', labels=['Entity'])
candidate = EntityNode(name='Dizzy Gillespie', group_id='group', labels=['Entity'])
indexes = _build_candidate_indexes([candidate])
state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
captured_context = {}
def fake_prompt_nodes(context):
captured_context.update(context)
return ['prompt']
monkeypatch.setattr(
'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
fake_prompt_nodes,
)
async def fake_generate_response(*_, **__):
return {
'entity_resolutions': [
{
'id': 0,
'duplicate_idx': 0,
'name': 'Dizzy Gillespie',
'duplicates': [0],
}
]
}
llm_client = MagicMock()
llm_client.generate_response = AsyncMock(side_effect=fake_generate_response)
await _resolve_with_llm(
llm_client,
[extracted],
indexes,
state,
episode=_make_episode(),
previous_episodes=[],
entity_types=None,
)
assert state.resolved_nodes[0].uuid == candidate.uuid
assert state.uuid_map[extracted.uuid] == candidate.uuid
assert captured_context['existing_nodes'][0]['idx'] == 0
assert isinstance(captured_context['existing_nodes'], list)
assert state.duplicate_pairs == [(extracted, candidate)]
@pytest.mark.asyncio
async def test_resolve_with_llm_ignores_out_of_range_relative_ids(monkeypatch, caplog):
extracted = EntityNode(name='Dexter', group_id='group', labels=['Entity'])
indexes = _build_candidate_indexes([])
state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
monkeypatch.setattr(
'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
lambda context: ['prompt'],
)
llm_client = MagicMock()
llm_client.generate_response = AsyncMock(
return_value={
'entity_resolutions': [
{
'id': 5,
'duplicate_idx': -1,
'name': 'Dexter',
'duplicates': [],
}
]
}
)
with caplog.at_level(logging.WARNING):
await _resolve_with_llm(
llm_client,
[extracted],
indexes,
state,
episode=_make_episode(),
previous_episodes=[],
entity_types=None,
)
assert state.resolved_nodes[0] is None
assert 'Skipping invalid LLM dedupe id 5' in caplog.text
@pytest.mark.asyncio
async def test_resolve_with_llm_ignores_duplicate_relative_ids(monkeypatch):
extracted = EntityNode(name='Dizzy', group_id='group', labels=['Entity'])
candidate = EntityNode(name='Dizzy Gillespie', group_id='group', labels=['Entity'])
indexes = _build_candidate_indexes([candidate])
state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
monkeypatch.setattr(
'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
lambda context: ['prompt'],
)
llm_client = MagicMock()
llm_client.generate_response = AsyncMock(
return_value={
'entity_resolutions': [
{
'id': 0,
'duplicate_idx': 0,
'name': 'Dizzy Gillespie',
'duplicates': [0],
},
{
'id': 0,
'duplicate_idx': -1,
'name': 'Dizzy',
'duplicates': [],
},
]
}
)
await _resolve_with_llm(
llm_client,
[extracted],
indexes,
state,
episode=_make_episode(),
previous_episodes=[],
entity_types=None,
)
assert state.resolved_nodes[0].uuid == candidate.uuid
assert state.uuid_map[extracted.uuid] == candidate.uuid
assert state.duplicate_pairs == [(extracted, candidate)]
@pytest.mark.asyncio
async def test_resolve_with_llm_invalid_duplicate_idx_defaults_to_extracted(monkeypatch):
extracted = EntityNode(name='Dexter', group_id='group', labels=['Entity'])
indexes = _build_candidate_indexes([])
state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
monkeypatch.setattr(
'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
lambda context: ['prompt'],
)
llm_client = MagicMock()
llm_client.generate_response = AsyncMock(
return_value={
'entity_resolutions': [
{
'id': 0,
'duplicate_idx': 10,
'name': 'Dexter',
'duplicates': [],
}
]
}
)
await _resolve_with_llm(
llm_client,
[extracted],
indexes,
state,
episode=_make_episode(),
previous_episodes=[],
entity_types=None,
)
assert state.resolved_nodes[0] == extracted
assert state.uuid_map[extracted.uuid] == extracted.uuid
assert state.duplicate_pairs == []
@pytest.mark.asyncio
async def test_extract_attributes_without_callback_generates_summary():
"""Test that summary is generated when no callback is provided (default behavior)."""
llm_client = MagicMock()
llm_client.generate_response = AsyncMock(
return_value={'summary': 'Generated summary', 'attributes': {}}
)
node = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')
episode = _make_episode()
result = await extract_attributes_from_node(
llm_client,
node,
episode=episode,
previous_episodes=[],
entity_type=None,
should_summarize_node=None, # No callback provided
)
# Summary should be generated
assert result.summary == 'Generated summary'
# LLM should have been called for summary
assert llm_client.generate_response.call_count == 1
@pytest.mark.asyncio
async def test_extract_attributes_with_callback_skip_summary():
"""Test that summary is NOT regenerated when callback returns False."""
llm_client = MagicMock()
llm_client.generate_response = AsyncMock(
return_value={'summary': 'This should not be used', 'attributes': {}}
)
node = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')
episode = _make_episode()
# Callback that always returns False (skip summary generation)
async def skip_summary_filter(node: EntityNode) -> bool:
return False
result = await extract_attributes_from_node(
llm_client,
node,
episode=episode,
previous_episodes=[],
entity_type=None,
should_summarize_node=skip_summary_filter,
)
# Summary should remain unchanged
assert result.summary == 'Old summary'
# LLM should NOT have been called for summary
assert llm_client.generate_response.call_count == 0
@pytest.mark.asyncio
async def test_extract_attributes_with_callback_generate_summary():
"""Test that summary is regenerated when callback returns True."""
llm_client = MagicMock()
llm_client.generate_response = AsyncMock(
return_value={'summary': 'New generated summary', 'attributes': {}}
)
node = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')
episode = _make_episode()
# Callback that always returns True (generate summary)
async def generate_summary_filter(node: EntityNode) -> bool:
return True
result = await extract_attributes_from_node(
llm_client,
node,
episode=episode,
previous_episodes=[],
entity_type=None,
should_summarize_node=generate_summary_filter,
)
# Summary should be updated
assert result.summary == 'New generated summary'
# LLM should have been called for summary
assert llm_client.generate_response.call_count == 1
@pytest.mark.asyncio
async def test_extract_attributes_with_selective_callback():
"""Test callback that selectively skips summaries based on node properties."""
llm_client = MagicMock()
llm_client.generate_response = AsyncMock(
return_value={'summary': 'Generated summary', 'attributes': {}}
)
user_node = EntityNode(name='User', group_id='group', labels=['Entity', 'User'], summary='Old')
topic_node = EntityNode(
name='Topic', group_id='group', labels=['Entity', 'Topic'], summary='Old'
)
episode = _make_episode()
# Callback that skips User nodes but generates for others
async def selective_filter(node: EntityNode) -> bool:
return 'User' not in node.labels
result_user = await extract_attributes_from_node(
llm_client,
user_node,
episode=episode,
previous_episodes=[],
entity_type=None,
should_summarize_node=selective_filter,
)
result_topic = await extract_attributes_from_node(
llm_client,
topic_node,
episode=episode,
previous_episodes=[],
entity_type=None,
should_summarize_node=selective_filter,
)
# User summary should remain unchanged
assert result_user.summary == 'Old'
# Topic summary should be generated
assert result_topic.summary == 'Generated summary'
# LLM should have been called only once (for topic)
assert llm_client.generate_response.call_count == 1
@pytest.mark.asyncio
async def test_extract_attributes_from_nodes_with_callback():
"""Test that callback is properly passed through extract_attributes_from_nodes."""
clients, _ = _make_clients()
clients.llm_client.generate_response = AsyncMock(
return_value={'summary': 'New summary', 'attributes': {}}
)
clients.embedder.create = AsyncMock(return_value=[0.1, 0.2, 0.3])
clients.embedder.create_batch = AsyncMock(return_value=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
node1 = EntityNode(name='Node1', group_id='group', labels=['Entity', 'User'], summary='Old1')
node2 = EntityNode(name='Node2', group_id='group', labels=['Entity', 'Topic'], summary='Old2')
episode = _make_episode()
call_tracker = []
# Callback that tracks which nodes it's called with
async def tracking_filter(node: EntityNode) -> bool:
call_tracker.append(node.name)
return 'User' not in node.labels
results = await extract_attributes_from_nodes(
clients,
[node1, node2],
episode=episode,
previous_episodes=[],
entity_types=None,
should_summarize_node=tracking_filter,
)
# Callback should have been called for both nodes
assert len(call_tracker) == 2
assert 'Node1' in call_tracker
assert 'Node2' in call_tracker
# Node1 (User) should keep old summary, Node2 (Topic) should get new summary
node1_result = next(n for n in results if n.name == 'Node1')
node2_result = next(n for n in results if n.name == 'Node2')
assert node1_result.summary == 'Old1'
assert node2_result.summary == 'New summary'
```
--------------------------------------------------------------------------------
/tests/llm_client/test_gemini_client.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Running tests: pytest -xvs tests/llm_client/test_gemini_client.py
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import BaseModel
from graphiti_core.llm_client.config import LLMConfig, ModelSize
from graphiti_core.llm_client.errors import RateLimitError
from graphiti_core.llm_client.gemini_client import DEFAULT_MODEL, DEFAULT_SMALL_MODEL, GeminiClient
from graphiti_core.prompts.models import Message
# Test model for response testing
class ResponseModel(BaseModel):
"""Test model for response testing."""
test_field: str
optional_field: int = 0
@pytest.fixture
def mock_gemini_client():
"""Fixture to mock the Google Gemini client."""
with patch('google.genai.Client') as mock_client:
# Setup mock instance and its methods
mock_instance = mock_client.return_value
mock_instance.aio = MagicMock()
mock_instance.aio.models = MagicMock()
mock_instance.aio.models.generate_content = AsyncMock()
yield mock_instance
@pytest.fixture
def gemini_client(mock_gemini_client):
"""Fixture to create a GeminiClient with a mocked client."""
config = LLMConfig(api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000)
client = GeminiClient(config=config, cache=False)
# Replace the client's client with our mock to ensure we're using the mock
client.client = mock_gemini_client
return client
class TestGeminiClientInitialization:
"""Tests for GeminiClient initialization."""
@patch('google.genai.Client')
def test_init_with_config(self, mock_client):
"""Test initialization with a config object."""
config = LLMConfig(
api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
)
client = GeminiClient(config=config, cache=False, max_tokens=1000)
assert client.config == config
assert client.model == 'test-model'
assert client.temperature == 0.5
assert client.max_tokens == 1000
@patch('google.genai.Client')
def test_init_with_default_model(self, mock_client):
"""Test initialization with default model when none is provided."""
config = LLMConfig(api_key='test_api_key', model=DEFAULT_MODEL)
client = GeminiClient(config=config, cache=False)
assert client.model == DEFAULT_MODEL
@patch('google.genai.Client')
def test_init_without_config(self, mock_client):
"""Test initialization without a config uses defaults."""
client = GeminiClient(cache=False)
assert client.config is not None
# When no config.model is set, it will be None, not DEFAULT_MODEL
assert client.model is None
@patch('google.genai.Client')
def test_init_with_thinking_config(self, mock_client):
"""Test initialization with thinking config."""
with patch('google.genai.types.ThinkingConfig') as mock_thinking_config:
thinking_config = mock_thinking_config.return_value
client = GeminiClient(thinking_config=thinking_config)
assert client.thinking_config == thinking_config
class TestGeminiClientGenerateResponse:
"""Tests for GeminiClient generate_response method."""
@pytest.mark.asyncio
async def test_generate_response_simple_text(self, gemini_client, mock_gemini_client):
"""Test successful response generation with simple text."""
# Setup mock response
mock_response = MagicMock()
mock_response.text = 'Test response text'
mock_response.candidates = []
mock_response.prompt_feedback = None
mock_gemini_client.aio.models.generate_content.return_value = mock_response
# Call method
messages = [Message(role='user', content='Test message')]
result = await gemini_client.generate_response(messages)
# Assertions
assert isinstance(result, dict)
assert result['content'] == 'Test response text'
mock_gemini_client.aio.models.generate_content.assert_called_once()
@pytest.mark.asyncio
async def test_generate_response_with_structured_output(
self, gemini_client, mock_gemini_client
):
"""Test response generation with structured output."""
# Setup mock response
mock_response = MagicMock()
mock_response.text = '{"test_field": "test_value", "optional_field": 42}'
mock_response.candidates = []
mock_response.prompt_feedback = None
mock_gemini_client.aio.models.generate_content.return_value = mock_response
# Call method
messages = [
Message(role='system', content='System message'),
Message(role='user', content='User message'),
]
result = await gemini_client.generate_response(
messages=messages, response_model=ResponseModel
)
# Assertions
assert isinstance(result, dict)
assert result['test_field'] == 'test_value'
assert result['optional_field'] == 42
mock_gemini_client.aio.models.generate_content.assert_called_once()
@pytest.mark.asyncio
async def test_generate_response_with_system_message(self, gemini_client, mock_gemini_client):
"""Test response generation with system message handling."""
# Setup mock response
mock_response = MagicMock()
mock_response.text = 'Response with system context'
mock_response.candidates = []
mock_response.prompt_feedback = None
mock_gemini_client.aio.models.generate_content.return_value = mock_response
# Call method
messages = [
Message(role='system', content='System message'),
Message(role='user', content='User message'),
]
await gemini_client.generate_response(messages)
# Verify system message is processed correctly
call_args = mock_gemini_client.aio.models.generate_content.call_args
config = call_args[1]['config']
assert 'System message' in config.system_instruction
@pytest.mark.asyncio
async def test_get_model_for_size(self, gemini_client):
"""Test model selection based on size."""
# Test small model
small_model = gemini_client._get_model_for_size(ModelSize.small)
assert small_model == DEFAULT_SMALL_MODEL
# Test medium/large model
medium_model = gemini_client._get_model_for_size(ModelSize.medium)
assert medium_model == gemini_client.model
@pytest.mark.asyncio
async def test_rate_limit_error_handling(self, gemini_client, mock_gemini_client):
"""Test handling of rate limit errors."""
# Setup mock to raise rate limit error
mock_gemini_client.aio.models.generate_content.side_effect = Exception(
'Rate limit exceeded'
)
# Call method and check exception
messages = [Message(role='user', content='Test message')]
with pytest.raises(RateLimitError):
await gemini_client.generate_response(messages)
@pytest.mark.asyncio
async def test_quota_error_handling(self, gemini_client, mock_gemini_client):
"""Test handling of quota errors."""
# Setup mock to raise quota error
mock_gemini_client.aio.models.generate_content.side_effect = Exception(
'Quota exceeded for requests'
)
# Call method and check exception
messages = [Message(role='user', content='Test message')]
with pytest.raises(RateLimitError):
await gemini_client.generate_response(messages)
@pytest.mark.asyncio
async def test_resource_exhausted_error_handling(self, gemini_client, mock_gemini_client):
"""Test handling of resource exhausted errors."""
# Setup mock to raise resource exhausted error
mock_gemini_client.aio.models.generate_content.side_effect = Exception(
'resource_exhausted: Request limit exceeded'
)
# Call method and check exception
messages = [Message(role='user', content='Test message')]
with pytest.raises(RateLimitError):
await gemini_client.generate_response(messages)
@pytest.mark.asyncio
async def test_safety_block_handling(self, gemini_client, mock_gemini_client):
"""Test handling of safety blocks."""
# Setup mock response with safety block
mock_candidate = MagicMock()
mock_candidate.finish_reason = 'SAFETY'
mock_candidate.safety_ratings = [
MagicMock(blocked=True, category='HARM_CATEGORY_HARASSMENT', probability='HIGH')
]
mock_response = MagicMock()
mock_response.candidates = [mock_candidate]
mock_response.prompt_feedback = None
mock_response.text = ''
mock_gemini_client.aio.models.generate_content.return_value = mock_response
# Call method and check exception
messages = [Message(role='user', content='Test message')]
with pytest.raises(Exception, match='Content blocked by safety filters'):
await gemini_client.generate_response(messages)
@pytest.mark.asyncio
async def test_prompt_block_handling(self, gemini_client, mock_gemini_client):
"""Test handling of prompt blocks."""
# Setup mock response with prompt block
mock_prompt_feedback = MagicMock()
mock_prompt_feedback.block_reason = 'BLOCKED_REASON_OTHER'
mock_response = MagicMock()
mock_response.candidates = []
mock_response.prompt_feedback = mock_prompt_feedback
mock_response.text = ''
mock_gemini_client.aio.models.generate_content.return_value = mock_response
# Call method and check exception
messages = [Message(role='user', content='Test message')]
with pytest.raises(Exception, match='Content blocked by safety filters'):
await gemini_client.generate_response(messages)
@pytest.mark.asyncio
async def test_structured_output_parsing_error(self, gemini_client, mock_gemini_client):
"""Test handling of structured output parsing errors."""
# Setup mock response with invalid JSON that will exhaust retries
mock_response = MagicMock()
mock_response.text = 'Invalid JSON that cannot be parsed'
mock_response.candidates = []
mock_response.prompt_feedback = None
mock_gemini_client.aio.models.generate_content.return_value = mock_response
# Call method and check exception - should exhaust retries
messages = [Message(role='user', content='Test message')]
with pytest.raises(Exception): # noqa: B017
await gemini_client.generate_response(messages, response_model=ResponseModel)
# Should have called generate_content MAX_RETRIES times (2 attempts total)
assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
@pytest.mark.asyncio
async def test_retry_logic_with_safety_block(self, gemini_client, mock_gemini_client):
"""Test that safety blocks are not retried."""
# Setup mock response with safety block
mock_candidate = MagicMock()
mock_candidate.finish_reason = 'SAFETY'
mock_candidate.safety_ratings = [
MagicMock(blocked=True, category='HARM_CATEGORY_HARASSMENT', probability='HIGH')
]
mock_response = MagicMock()
mock_response.candidates = [mock_candidate]
mock_response.prompt_feedback = None
mock_response.text = ''
mock_gemini_client.aio.models.generate_content.return_value = mock_response
# Call method and check that it doesn't retry
messages = [Message(role='user', content='Test message')]
with pytest.raises(Exception, match='Content blocked by safety filters'):
await gemini_client.generate_response(messages)
# Should only be called once (no retries for safety blocks)
assert mock_gemini_client.aio.models.generate_content.call_count == 1
@pytest.mark.asyncio
async def test_retry_logic_with_validation_error(self, gemini_client, mock_gemini_client):
"""Test retry behavior on validation error."""
# First call returns invalid JSON, second call returns valid data
mock_response1 = MagicMock()
mock_response1.text = 'Invalid JSON that cannot be parsed'
mock_response1.candidates = []
mock_response1.prompt_feedback = None
mock_response2 = MagicMock()
mock_response2.text = '{"test_field": "correct_value"}'
mock_response2.candidates = []
mock_response2.prompt_feedback = None
mock_gemini_client.aio.models.generate_content.side_effect = [
mock_response1,
mock_response2,
]
# Call method
messages = [Message(role='user', content='Test message')]
result = await gemini_client.generate_response(messages, response_model=ResponseModel)
# Should have called generate_content twice due to retry
assert mock_gemini_client.aio.models.generate_content.call_count == 2
assert result['test_field'] == 'correct_value'
@pytest.mark.asyncio
async def test_max_retries_exceeded(self, gemini_client, mock_gemini_client):
"""Test behavior when max retries are exceeded."""
# Setup mock to always return invalid JSON
mock_response = MagicMock()
mock_response.text = 'Invalid JSON that cannot be parsed'
mock_response.candidates = []
mock_response.prompt_feedback = None
mock_gemini_client.aio.models.generate_content.return_value = mock_response
# Call method and check exception
messages = [Message(role='user', content='Test message')]
with pytest.raises(Exception): # noqa: B017
await gemini_client.generate_response(messages, response_model=ResponseModel)
# Should have called generate_content MAX_RETRIES times (2 attempts total)
assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
@pytest.mark.asyncio
async def test_empty_response_handling(self, gemini_client, mock_gemini_client):
"""Test handling of empty responses."""
# Setup mock response with no text
mock_response = MagicMock()
mock_response.text = ''
mock_response.candidates = []
mock_response.prompt_feedback = None
mock_gemini_client.aio.models.generate_content.return_value = mock_response
# Call method with structured output and check exception
messages = [Message(role='user', content='Test message')]
with pytest.raises(Exception): # noqa: B017
await gemini_client.generate_response(messages, response_model=ResponseModel)
# Should have exhausted retries due to empty response (2 attempts total)
assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
@pytest.mark.asyncio
async def test_custom_max_tokens(self, gemini_client, mock_gemini_client):
"""Test that explicit max_tokens parameter takes precedence over all other values."""
# Setup mock response
mock_response = MagicMock()
mock_response.text = 'Test response'
mock_response.candidates = []
mock_response.prompt_feedback = None
mock_gemini_client.aio.models.generate_content.return_value = mock_response
# Call method with custom max tokens (should take precedence)
messages = [Message(role='user', content='Test message')]
await gemini_client.generate_response(messages, max_tokens=500)
# Verify explicit max_tokens parameter takes precedence
call_args = mock_gemini_client.aio.models.generate_content.call_args
config = call_args[1]['config']
# Explicit parameter should override everything else
assert config.max_output_tokens == 500
@pytest.mark.asyncio
async def test_max_tokens_precedence_fallback(self, mock_gemini_client):
"""Test max_tokens precedence when no explicit parameter is provided."""
# Setup mock response
mock_response = MagicMock()
mock_response.text = 'Test response'
mock_response.candidates = []
mock_response.prompt_feedback = None
mock_gemini_client.aio.models.generate_content.return_value = mock_response
# Test case 1: No explicit max_tokens, has instance max_tokens
config = LLMConfig(
api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
)
client = GeminiClient(
config=config, cache=False, max_tokens=2000, client=mock_gemini_client
)
messages = [Message(role='user', content='Test message')]
await client.generate_response(messages)
call_args = mock_gemini_client.aio.models.generate_content.call_args
config = call_args[1]['config']
# Instance max_tokens should be used
assert config.max_output_tokens == 2000
# Test case 2: No explicit max_tokens, no instance max_tokens, uses model mapping
config = LLMConfig(api_key='test_api_key', model='gemini-2.5-flash', temperature=0.5)
client = GeminiClient(config=config, cache=False, client=mock_gemini_client)
messages = [Message(role='user', content='Test message')]
await client.generate_response(messages)
call_args = mock_gemini_client.aio.models.generate_content.call_args
config = call_args[1]['config']
# Model mapping should be used
assert config.max_output_tokens == 65536
@pytest.mark.asyncio
async def test_model_size_selection(self, gemini_client, mock_gemini_client):
"""Test that the correct model is selected based on model size."""
# Setup mock response
mock_response = MagicMock()
mock_response.text = 'Test response'
mock_response.candidates = []
mock_response.prompt_feedback = None
mock_gemini_client.aio.models.generate_content.return_value = mock_response
# Call method with small model size
messages = [Message(role='user', content='Test message')]
await gemini_client.generate_response(messages, model_size=ModelSize.small)
# Verify correct model is used
call_args = mock_gemini_client.aio.models.generate_content.call_args
assert call_args[1]['model'] == DEFAULT_SMALL_MODEL
@pytest.mark.asyncio
async def test_gemini_model_max_tokens_mapping(self, mock_gemini_client):
"""Test that different Gemini models use their correct max tokens."""
# Setup mock response
mock_response = MagicMock()
mock_response.text = 'Test response'
mock_response.candidates = []
mock_response.prompt_feedback = None
mock_gemini_client.aio.models.generate_content.return_value = mock_response
# Test data: (model_name, expected_max_tokens)
test_cases = [
('gemini-2.5-flash', 65536),
('gemini-2.5-pro', 65536),
('gemini-2.5-flash-lite', 64000),
('gemini-2.0-flash', 8192),
('gemini-1.5-pro', 8192),
('gemini-1.5-flash', 8192),
('unknown-model', 8192), # Fallback case
]
for model_name, expected_max_tokens in test_cases:
# Create client with specific model, no explicit max_tokens to test mapping
config = LLMConfig(api_key='test_api_key', model=model_name, temperature=0.5)
client = GeminiClient(config=config, cache=False, client=mock_gemini_client)
# Call method without explicit max_tokens to test model mapping fallback
messages = [Message(role='user', content='Test message')]
await client.generate_response(messages)
# Verify correct max tokens is used from model mapping
call_args = mock_gemini_client.aio.models.generate_content.call_args
config = call_args[1]['config']
assert config.max_output_tokens == expected_max_tokens, (
f'Model {model_name} should use {expected_max_tokens} tokens'
)
if __name__ == '__main__':
pytest.main(['-v', 'test_gemini_client.py'])
```
--------------------------------------------------------------------------------
/graphiti_core/utils/maintenance/node_operations.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
from collections.abc import Awaitable, Callable
from time import time
from typing import Any
from pydantic import BaseModel
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import semaphore_gather
from graphiti_core.llm_client import LLMClient
from graphiti_core.llm_client.config import ModelSize
from graphiti_core.nodes import (
EntityNode,
EpisodeType,
EpisodicNode,
create_entity_node_embeddings,
)
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
from graphiti_core.prompts.extract_nodes import (
EntitySummary,
ExtractedEntities,
ExtractedEntity,
MissedEntities,
)
from graphiti_core.search.search import search
from graphiti_core.search.search_config import SearchResults
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.utils.content_chunking import (
chunk_json_content,
chunk_message_content,
chunk_text_content,
should_chunk,
)
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.dedup_helpers import (
DedupCandidateIndexes,
DedupResolutionState,
_build_candidate_indexes,
_resolve_with_similarity,
)
from graphiti_core.utils.maintenance.edge_operations import (
filter_existing_duplicate_of_edges,
)
from graphiti_core.utils.text_utils import MAX_SUMMARY_CHARS, truncate_at_sentence
logger = logging.getLogger(__name__)
NodeSummaryFilter = Callable[[EntityNode], Awaitable[bool]]
async def extract_nodes_reflexion(
llm_client: LLMClient,
episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
node_names: list[str],
group_id: str | None = None,
) -> list[str]:
# Prepare context for LLM
context = {
'episode_content': episode.content,
'previous_episodes': [ep.content for ep in previous_episodes],
'extracted_entities': node_names,
}
llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.reflexion(context),
MissedEntities,
group_id=group_id,
prompt_name='extract_nodes.reflexion',
)
missed_entities = llm_response.get('missed_entities', [])
return missed_entities
async def extract_nodes(
clients: GraphitiClients,
episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
entity_types: dict[str, type[BaseModel]] | None = None,
excluded_entity_types: list[str] | None = None,
custom_extraction_instructions: str | None = None,
) -> list[EntityNode]:
"""Extract entity nodes from an episode with adaptive chunking.
For high-density content (many entities per token), the content is chunked
and processed in parallel to avoid LLM timeouts and truncation issues.
"""
start = time()
llm_client = clients.llm_client
# Build entity types context
entity_types_context = _build_entity_types_context(entity_types)
# Build base context
context = {
'episode_content': episode.content,
'episode_timestamp': episode.valid_at.isoformat(),
'previous_episodes': [ep.content for ep in previous_episodes],
'custom_extraction_instructions': custom_extraction_instructions or '',
'entity_types': entity_types_context,
'source_description': episode.source_description,
}
# Check if chunking is needed (based on entity density)
if should_chunk(episode.content, episode.source):
extracted_entities = await _extract_nodes_chunked(llm_client, episode, context)
else:
extracted_entities = await _extract_nodes_single(llm_client, episode, context)
# Filter empty names
filtered_entities = [e for e in extracted_entities if e.name.strip()]
end = time()
logger.debug(f'Extracted {len(filtered_entities)} entities in {(end - start) * 1000:.0f} ms')
# Convert to EntityNode objects
extracted_nodes = _create_entity_nodes(
filtered_entities, entity_types_context, excluded_entity_types, episode
)
logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
return extracted_nodes
def _build_entity_types_context(
entity_types: dict[str, type[BaseModel]] | None,
) -> list[dict]:
"""Build entity types context with ID mappings."""
entity_types_context = [
{
'entity_type_id': 0,
'entity_type_name': 'Entity',
'entity_type_description': (
'Default entity classification. Use this entity type '
'if the entity is not one of the other listed types.'
),
}
]
if entity_types is not None:
entity_types_context += [
{
'entity_type_id': i + 1,
'entity_type_name': type_name,
'entity_type_description': type_model.__doc__,
}
for i, (type_name, type_model) in enumerate(entity_types.items())
]
return entity_types_context
async def _extract_nodes_single(
llm_client: LLMClient,
episode: EpisodicNode,
context: dict,
) -> list[ExtractedEntity]:
"""Extract entities using a single LLM call."""
llm_response = await _call_extraction_llm(llm_client, episode, context)
response_object = ExtractedEntities(**llm_response)
return response_object.extracted_entities
async def _extract_nodes_chunked(
llm_client: LLMClient,
episode: EpisodicNode,
context: dict,
) -> list[ExtractedEntity]:
"""Extract entities from large content using chunking."""
# Chunk the content based on episode type
if episode.source == EpisodeType.json:
chunks = chunk_json_content(episode.content)
elif episode.source == EpisodeType.message:
chunks = chunk_message_content(episode.content)
else:
chunks = chunk_text_content(episode.content)
logger.debug(f'Chunked content into {len(chunks)} chunks for entity extraction')
# Extract entities from each chunk in parallel
chunk_results = await semaphore_gather(
*[_extract_from_chunk(llm_client, chunk, context, episode) for chunk in chunks]
)
# Merge and deduplicate entities across chunks
merged_entities = _merge_extracted_entities(chunk_results)
logger.debug(
f'Merged {sum(len(r) for r in chunk_results)} entities into {len(merged_entities)} unique'
)
return merged_entities
async def _extract_from_chunk(
llm_client: LLMClient,
chunk: str,
base_context: dict,
episode: EpisodicNode,
) -> list[ExtractedEntity]:
"""Extract entities from a single chunk."""
chunk_context = {**base_context, 'episode_content': chunk}
llm_response = await _call_extraction_llm(llm_client, episode, chunk_context)
return ExtractedEntities(**llm_response).extracted_entities
async def _call_extraction_llm(
llm_client: LLMClient,
episode: EpisodicNode,
context: dict,
) -> dict:
"""Call the appropriate extraction prompt based on episode type."""
if episode.source == EpisodeType.message:
prompt = prompt_library.extract_nodes.extract_message(context)
prompt_name = 'extract_nodes.extract_message'
elif episode.source == EpisodeType.text:
prompt = prompt_library.extract_nodes.extract_text(context)
prompt_name = 'extract_nodes.extract_text'
elif episode.source == EpisodeType.json:
prompt = prompt_library.extract_nodes.extract_json(context)
prompt_name = 'extract_nodes.extract_json'
else:
# Fallback to text extraction
prompt = prompt_library.extract_nodes.extract_text(context)
prompt_name = 'extract_nodes.extract_text'
return await llm_client.generate_response(
prompt,
response_model=ExtractedEntities,
group_id=episode.group_id,
prompt_name=prompt_name,
)
def _merge_extracted_entities(
chunk_results: list[list[ExtractedEntity]],
) -> list[ExtractedEntity]:
"""Merge entities from multiple chunks, deduplicating by normalized name.
When duplicates occur, prefer the first occurrence (maintains ordering).
"""
seen_names: set[str] = set()
merged: list[ExtractedEntity] = []
for entities in chunk_results:
for entity in entities:
normalized = entity.name.strip().lower()
if normalized and normalized not in seen_names:
seen_names.add(normalized)
merged.append(entity)
return merged
def _create_entity_nodes(
extracted_entities: list[ExtractedEntity],
entity_types_context: list[dict],
excluded_entity_types: list[str] | None,
episode: EpisodicNode,
) -> list[EntityNode]:
"""Convert ExtractedEntity objects to EntityNode objects."""
extracted_nodes = []
for extracted_entity in extracted_entities:
type_id = extracted_entity.entity_type_id
if 0 <= type_id < len(entity_types_context):
entity_type_name = entity_types_context[type_id].get('entity_type_name')
else:
entity_type_name = 'Entity'
# Check if this entity type should be excluded
if excluded_entity_types and entity_type_name in excluded_entity_types:
logger.debug(f'Excluding entity "{extracted_entity.name}" of type "{entity_type_name}"')
continue
labels: list[str] = list({'Entity', str(entity_type_name)})
new_node = EntityNode(
name=extracted_entity.name,
group_id=episode.group_id,
labels=labels,
summary='',
created_at=utc_now(),
)
extracted_nodes.append(new_node)
logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
return extracted_nodes
async def _collect_candidate_nodes(
clients: GraphitiClients,
extracted_nodes: list[EntityNode],
existing_nodes_override: list[EntityNode] | None,
) -> list[EntityNode]:
"""Search per extracted name and return unique candidates with overrides honored in order."""
search_results: list[SearchResults] = await semaphore_gather(
*[
search(
clients=clients,
query=node.name,
group_ids=[node.group_id],
search_filter=SearchFilters(),
config=NODE_HYBRID_SEARCH_RRF,
)
for node in extracted_nodes
]
)
candidate_nodes: list[EntityNode] = [node for result in search_results for node in result.nodes]
if existing_nodes_override is not None:
candidate_nodes.extend(existing_nodes_override)
seen_candidate_uuids: set[str] = set()
ordered_candidates: list[EntityNode] = []
for candidate in candidate_nodes:
if candidate.uuid in seen_candidate_uuids:
continue
seen_candidate_uuids.add(candidate.uuid)
ordered_candidates.append(candidate)
return ordered_candidates
async def _resolve_with_llm(
llm_client: LLMClient,
extracted_nodes: list[EntityNode],
indexes: DedupCandidateIndexes,
state: DedupResolutionState,
episode: EpisodicNode | None,
previous_episodes: list[EpisodicNode] | None,
entity_types: dict[str, type[BaseModel]] | None,
) -> None:
"""Escalate unresolved nodes to the dedupe prompt so the LLM can select or reject duplicates.
The guardrails below defensively ignore malformed or duplicate LLM responses so the
ingestion workflow remains deterministic even when the model misbehaves.
"""
if not state.unresolved_indices:
return
entity_types_dict: dict[str, type[BaseModel]] = entity_types if entity_types is not None else {}
llm_extracted_nodes = [extracted_nodes[i] for i in state.unresolved_indices]
extracted_nodes_context = [
{
'id': i,
'name': node.name,
'entity_type': node.labels,
'entity_type_description': entity_types_dict.get(
next((item for item in node.labels if item != 'Entity'), '')
).__doc__
or 'Default Entity Type',
}
for i, node in enumerate(llm_extracted_nodes)
]
sent_ids = [ctx['id'] for ctx in extracted_nodes_context]
logger.debug(
'Sending %d entities to LLM for deduplication with IDs 0-%d (actual IDs sent: %s)',
len(llm_extracted_nodes),
len(llm_extracted_nodes) - 1,
sent_ids if len(sent_ids) < 20 else f'{sent_ids[:10]}...{sent_ids[-10:]}',
)
if llm_extracted_nodes:
sample_size = min(3, len(extracted_nodes_context))
logger.debug(
'First %d entities: %s',
sample_size,
[(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[:sample_size]],
)
if len(extracted_nodes_context) > 3:
logger.debug(
'Last %d entities: %s',
sample_size,
[(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[-sample_size:]],
)
existing_nodes_context = [
{
**{
'idx': i,
'name': candidate.name,
'entity_types': candidate.labels,
},
**candidate.attributes,
}
for i, candidate in enumerate(indexes.existing_nodes)
]
context = {
'extracted_nodes': extracted_nodes_context,
'existing_nodes': existing_nodes_context,
'episode_content': episode.content if episode is not None else '',
'previous_episodes': (
[ep.content for ep in previous_episodes] if previous_episodes is not None else []
),
}
llm_response = await llm_client.generate_response(
prompt_library.dedupe_nodes.nodes(context),
response_model=NodeResolutions,
prompt_name='dedupe_nodes.nodes',
)
node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions
valid_relative_range = range(len(state.unresolved_indices))
processed_relative_ids: set[int] = set()
received_ids = {r.id for r in node_resolutions}
expected_ids = set(valid_relative_range)
missing_ids = expected_ids - received_ids
extra_ids = received_ids - expected_ids
logger.debug(
'Received %d resolutions for %d entities',
len(node_resolutions),
len(state.unresolved_indices),
)
if missing_ids:
logger.warning('LLM did not return resolutions for IDs: %s', sorted(missing_ids))
if extra_ids:
logger.warning(
'LLM returned invalid IDs outside valid range 0-%d: %s (all returned IDs: %s)',
len(state.unresolved_indices) - 1,
sorted(extra_ids),
sorted(received_ids),
)
for resolution in node_resolutions:
relative_id: int = resolution.id
duplicate_idx: int = resolution.duplicate_idx
if relative_id not in valid_relative_range:
logger.warning(
'Skipping invalid LLM dedupe id %d (valid range: 0-%d, received %d resolutions)',
relative_id,
len(state.unresolved_indices) - 1,
len(node_resolutions),
)
continue
if relative_id in processed_relative_ids:
logger.warning('Duplicate LLM dedupe id %s received; ignoring.', relative_id)
continue
processed_relative_ids.add(relative_id)
original_index = state.unresolved_indices[relative_id]
extracted_node = extracted_nodes[original_index]
resolved_node: EntityNode
if duplicate_idx == -1:
resolved_node = extracted_node
elif 0 <= duplicate_idx < len(indexes.existing_nodes):
resolved_node = indexes.existing_nodes[duplicate_idx]
else:
logger.warning(
'Invalid duplicate_idx %s for extracted node %s; treating as no duplicate.',
duplicate_idx,
extracted_node.uuid,
)
resolved_node = extracted_node
state.resolved_nodes[original_index] = resolved_node
state.uuid_map[extracted_node.uuid] = resolved_node.uuid
if resolved_node.uuid != extracted_node.uuid:
state.duplicate_pairs.append((extracted_node, resolved_node))
async def resolve_extracted_nodes(
clients: GraphitiClients,
extracted_nodes: list[EntityNode],
episode: EpisodicNode | None = None,
previous_episodes: list[EpisodicNode] | None = None,
entity_types: dict[str, type[BaseModel]] | None = None,
existing_nodes_override: list[EntityNode] | None = None,
) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
"""Search for existing nodes, resolve deterministic matches, then escalate holdouts to the LLM dedupe prompt."""
llm_client = clients.llm_client
driver = clients.driver
existing_nodes = await _collect_candidate_nodes(
clients,
extracted_nodes,
existing_nodes_override,
)
indexes: DedupCandidateIndexes = _build_candidate_indexes(existing_nodes)
state = DedupResolutionState(
resolved_nodes=[None] * len(extracted_nodes),
uuid_map={},
unresolved_indices=[],
)
_resolve_with_similarity(extracted_nodes, indexes, state)
await _resolve_with_llm(
llm_client,
extracted_nodes,
indexes,
state,
episode,
previous_episodes,
entity_types,
)
for idx, node in enumerate(extracted_nodes):
if state.resolved_nodes[idx] is None:
state.resolved_nodes[idx] = node
state.uuid_map[node.uuid] = node.uuid
logger.debug(
'Resolved nodes: %s',
[(node.name, node.uuid) for node in state.resolved_nodes if node is not None],
)
new_node_duplicates: list[
tuple[EntityNode, EntityNode]
] = await filter_existing_duplicate_of_edges(driver, state.duplicate_pairs)
return (
[node for node in state.resolved_nodes if node is not None],
state.uuid_map,
new_node_duplicates,
)
async def extract_attributes_from_nodes(
clients: GraphitiClients,
nodes: list[EntityNode],
episode: EpisodicNode | None = None,
previous_episodes: list[EpisodicNode] | None = None,
entity_types: dict[str, type[BaseModel]] | None = None,
should_summarize_node: NodeSummaryFilter | None = None,
) -> list[EntityNode]:
llm_client = clients.llm_client
embedder = clients.embedder
updated_nodes: list[EntityNode] = await semaphore_gather(
*[
extract_attributes_from_node(
llm_client,
node,
episode,
previous_episodes,
(
entity_types.get(next((item for item in node.labels if item != 'Entity'), ''))
if entity_types is not None
else None
),
should_summarize_node,
)
for node in nodes
]
)
await create_entity_node_embeddings(embedder, updated_nodes)
return updated_nodes
async def extract_attributes_from_node(
llm_client: LLMClient,
node: EntityNode,
episode: EpisodicNode | None = None,
previous_episodes: list[EpisodicNode] | None = None,
entity_type: type[BaseModel] | None = None,
should_summarize_node: NodeSummaryFilter | None = None,
) -> EntityNode:
# Extract attributes if entity type is defined and has attributes
llm_response = await _extract_entity_attributes(
llm_client, node, episode, previous_episodes, entity_type
)
# Extract summary if needed
await _extract_entity_summary(
llm_client, node, episode, previous_episodes, should_summarize_node
)
node.attributes.update(llm_response)
return node
async def _extract_entity_attributes(
llm_client: LLMClient,
node: EntityNode,
episode: EpisodicNode | None,
previous_episodes: list[EpisodicNode] | None,
entity_type: type[BaseModel] | None,
) -> dict[str, Any]:
if entity_type is None or len(entity_type.model_fields) == 0:
return {}
attributes_context = _build_episode_context(
# should not include summary
node_data={
'name': node.name,
'entity_types': node.labels,
'attributes': node.attributes,
},
episode=episode,
previous_episodes=previous_episodes,
)
llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.extract_attributes(attributes_context),
response_model=entity_type,
model_size=ModelSize.small,
group_id=node.group_id,
prompt_name='extract_nodes.extract_attributes',
)
# validate response
entity_type(**llm_response)
return llm_response
async def _extract_entity_summary(
llm_client: LLMClient,
node: EntityNode,
episode: EpisodicNode | None,
previous_episodes: list[EpisodicNode] | None,
should_summarize_node: NodeSummaryFilter | None,
) -> None:
if should_summarize_node is not None and not await should_summarize_node(node):
return
summary_context = _build_episode_context(
node_data={
'name': node.name,
'summary': truncate_at_sentence(node.summary, MAX_SUMMARY_CHARS),
'entity_types': node.labels,
'attributes': node.attributes,
},
episode=episode,
previous_episodes=previous_episodes,
)
summary_response = await llm_client.generate_response(
prompt_library.extract_nodes.extract_summary(summary_context),
response_model=EntitySummary,
model_size=ModelSize.small,
group_id=node.group_id,
prompt_name='extract_nodes.extract_summary',
)
node.summary = truncate_at_sentence(summary_response.get('summary', ''), MAX_SUMMARY_CHARS)
def _build_episode_context(
node_data: dict[str, Any],
episode: EpisodicNode | None,
previous_episodes: list[EpisodicNode] | None,
) -> dict[str, Any]:
return {
'node': node_data,
'episode_content': episode.content if episode is not None else '',
'previous_episodes': (
[ep.content for ep in previous_episodes] if previous_episodes is not None else []
),
}
```
--------------------------------------------------------------------------------
/mcp_server/tests/test_comprehensive_integration.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Comprehensive integration test suite for Graphiti MCP Server.
Covers all MCP tools with consideration for LLM inference latency.
"""
import asyncio
import json
import os
import time
from dataclasses import dataclass
from typing import Any
import pytest
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
@dataclass
class TestMetrics:
"""Track test performance metrics."""
operation: str
start_time: float
end_time: float
success: bool
details: dict[str, Any]
@property
def duration(self) -> float:
"""Calculate operation duration in seconds."""
return self.end_time - self.start_time
class GraphitiTestClient:
"""Enhanced test client for comprehensive Graphiti MCP testing."""
def __init__(self, test_group_id: str | None = None):
self.test_group_id = test_group_id or f'test_{int(time.time())}'
self.session = None
self.metrics: list[TestMetrics] = []
self.default_timeout = 30 # seconds
async def __aenter__(self):
"""Initialize MCP client session."""
server_params = StdioServerParameters(
command='uv',
args=['run', '../main.py', '--transport', 'stdio'],
env={
'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY', 'test_key_for_mock'),
'FALKORDB_URI': os.environ.get('FALKORDB_URI', 'redis://localhost:6379'),
},
)
self.client_context = stdio_client(server_params)
read, write = await self.client_context.__aenter__()
self.session = ClientSession(read, write)
await self.session.initialize()
# Wait for server to be fully ready
await asyncio.sleep(2)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Clean up client session."""
if self.session:
await self.session.close()
if hasattr(self, 'client_context'):
await self.client_context.__aexit__(exc_type, exc_val, exc_tb)
async def call_tool_with_metrics(
self, tool_name: str, arguments: dict[str, Any], timeout: float | None = None
) -> tuple[Any, TestMetrics]:
"""Call a tool and capture performance metrics."""
start_time = time.time()
timeout = timeout or self.default_timeout
try:
result = await asyncio.wait_for(
self.session.call_tool(tool_name, arguments), timeout=timeout
)
content = result.content[0].text if result.content else None
success = True
details = {'result': content, 'tool': tool_name}
except asyncio.TimeoutError:
content = None
success = False
details = {'error': f'Timeout after {timeout}s', 'tool': tool_name}
except Exception as e:
content = None
success = False
details = {'error': str(e), 'tool': tool_name}
end_time = time.time()
metric = TestMetrics(
operation=f'call_{tool_name}',
start_time=start_time,
end_time=end_time,
success=success,
details=details,
)
self.metrics.append(metric)
return content, metric
async def wait_for_episode_processing(
self, expected_count: int = 1, max_wait: int = 60, poll_interval: int = 2
) -> bool:
"""
Wait for episodes to be processed with intelligent polling.
Args:
expected_count: Number of episodes expected to be processed
max_wait: Maximum seconds to wait
poll_interval: Seconds between status checks
Returns:
True if episodes were processed successfully
"""
start_time = time.time()
while (time.time() - start_time) < max_wait:
result, _ = await self.call_tool_with_metrics(
'get_episodes', {'group_id': self.test_group_id, 'last_n': 100}
)
if result:
try:
episodes = json.loads(result) if isinstance(result, str) else result
if len(episodes.get('episodes', [])) >= expected_count:
return True
except (json.JSONDecodeError, AttributeError):
pass
await asyncio.sleep(poll_interval)
return False
class TestCoreOperations:
"""Test core Graphiti operations."""
@pytest.mark.asyncio
async def test_server_initialization(self):
"""Verify server initializes with all required tools."""
async with GraphitiTestClient() as client:
tools_result = await client.session.list_tools()
tools = {tool.name for tool in tools_result.tools}
required_tools = {
'add_memory',
'search_memory_nodes',
'search_memory_facts',
'get_episodes',
'delete_episode',
'delete_entity_edge',
'get_entity_edge',
'clear_graph',
'get_status',
}
missing_tools = required_tools - tools
assert not missing_tools, f'Missing required tools: {missing_tools}'
@pytest.mark.asyncio
async def test_add_text_memory(self):
"""Test adding text-based memories."""
async with GraphitiTestClient() as client:
# Add memory
result, metric = await client.call_tool_with_metrics(
'add_memory',
{
'name': 'Tech Conference Notes',
'episode_body': 'The AI conference featured talks on LLMs, RAG systems, and knowledge graphs. Notable speakers included researchers from OpenAI and Anthropic.',
'source': 'text',
'source_description': 'conference notes',
'group_id': client.test_group_id,
},
)
assert metric.success, f'Failed to add memory: {metric.details}'
assert 'queued' in str(result).lower()
# Wait for processing
processed = await client.wait_for_episode_processing(expected_count=1)
assert processed, 'Episode was not processed within timeout'
@pytest.mark.asyncio
async def test_add_json_memory(self):
"""Test adding structured JSON memories."""
async with GraphitiTestClient() as client:
json_data = {
'project': {
'name': 'GraphitiDB',
'version': '2.0.0',
'features': ['temporal-awareness', 'hybrid-search', 'custom-entities'],
},
'team': {'size': 5, 'roles': ['engineering', 'product', 'research']},
}
result, metric = await client.call_tool_with_metrics(
'add_memory',
{
'name': 'Project Data',
'episode_body': json.dumps(json_data),
'source': 'json',
'source_description': 'project database',
'group_id': client.test_group_id,
},
)
assert metric.success
assert 'queued' in str(result).lower()
@pytest.mark.asyncio
async def test_add_message_memory(self):
"""Test adding conversation/message memories."""
async with GraphitiTestClient() as client:
conversation = """
user: What are the key features of Graphiti?
assistant: Graphiti offers temporal-aware knowledge graphs, hybrid retrieval, and real-time updates.
user: How does it handle entity resolution?
assistant: It uses LLM-based entity extraction and deduplication with semantic similarity matching.
"""
result, metric = await client.call_tool_with_metrics(
'add_memory',
{
'name': 'Feature Discussion',
'episode_body': conversation,
'source': 'message',
'source_description': 'support chat',
'group_id': client.test_group_id,
},
)
assert metric.success
assert metric.duration < 5, f'Add memory took too long: {metric.duration}s'
class TestSearchOperations:
"""Test search and retrieval operations."""
@pytest.mark.asyncio
async def test_search_nodes_semantic(self):
"""Test semantic search for nodes."""
async with GraphitiTestClient() as client:
# First add some test data
await client.call_tool_with_metrics(
'add_memory',
{
'name': 'Product Launch',
'episode_body': 'Our new AI assistant product launches in Q2 2024 with advanced NLP capabilities.',
'source': 'text',
'source_description': 'product roadmap',
'group_id': client.test_group_id,
},
)
# Wait for processing
await client.wait_for_episode_processing()
# Search for nodes
result, metric = await client.call_tool_with_metrics(
'search_memory_nodes',
{'query': 'AI product features', 'group_id': client.test_group_id, 'limit': 10},
)
assert metric.success
assert result is not None
@pytest.mark.asyncio
async def test_search_facts_with_filters(self):
"""Test fact search with various filters."""
async with GraphitiTestClient() as client:
# Add test data
await client.call_tool_with_metrics(
'add_memory',
{
'name': 'Company Facts',
'episode_body': 'Acme Corp was founded in 2020. They have 50 employees and $10M in revenue.',
'source': 'text',
'source_description': 'company profile',
'group_id': client.test_group_id,
},
)
await client.wait_for_episode_processing()
# Search with date filter
result, metric = await client.call_tool_with_metrics(
'search_memory_facts',
{
'query': 'company information',
'group_id': client.test_group_id,
'created_after': '2020-01-01T00:00:00Z',
'limit': 20,
},
)
assert metric.success
@pytest.mark.asyncio
async def test_hybrid_search(self):
"""Test hybrid search combining semantic and keyword search."""
async with GraphitiTestClient() as client:
# Add diverse test data
test_memories = [
{
'name': 'Technical Doc',
'episode_body': 'GraphQL API endpoints support pagination, filtering, and real-time subscriptions.',
'source': 'text',
},
{
'name': 'Architecture',
'episode_body': 'The system uses Neo4j for graph storage and OpenAI embeddings for semantic search.',
'source': 'text',
},
]
for memory in test_memories:
memory['group_id'] = client.test_group_id
memory['source_description'] = 'documentation'
await client.call_tool_with_metrics('add_memory', memory)
await client.wait_for_episode_processing(expected_count=2)
# Test semantic + keyword search
result, metric = await client.call_tool_with_metrics(
'search_memory_nodes',
{'query': 'Neo4j graph database', 'group_id': client.test_group_id, 'limit': 10},
)
assert metric.success
class TestEpisodeManagement:
"""Test episode lifecycle operations."""
@pytest.mark.asyncio
async def test_get_episodes_pagination(self):
"""Test retrieving episodes with pagination."""
async with GraphitiTestClient() as client:
# Add multiple episodes
for i in range(5):
await client.call_tool_with_metrics(
'add_memory',
{
'name': f'Episode {i}',
'episode_body': f'This is test episode number {i}',
'source': 'text',
'source_description': 'test',
'group_id': client.test_group_id,
},
)
await client.wait_for_episode_processing(expected_count=5)
# Test pagination
result, metric = await client.call_tool_with_metrics(
'get_episodes', {'group_id': client.test_group_id, 'last_n': 3}
)
assert metric.success
episodes = json.loads(result) if isinstance(result, str) else result
assert len(episodes.get('episodes', [])) <= 3
@pytest.mark.asyncio
async def test_delete_episode(self):
"""Test deleting specific episodes."""
async with GraphitiTestClient() as client:
# Add an episode
await client.call_tool_with_metrics(
'add_memory',
{
'name': 'To Delete',
'episode_body': 'This episode will be deleted',
'source': 'text',
'source_description': 'test',
'group_id': client.test_group_id,
},
)
await client.wait_for_episode_processing()
# Get episode UUID
result, _ = await client.call_tool_with_metrics(
'get_episodes', {'group_id': client.test_group_id, 'last_n': 1}
)
episodes = json.loads(result) if isinstance(result, str) else result
episode_uuid = episodes['episodes'][0]['uuid']
# Delete the episode
result, metric = await client.call_tool_with_metrics(
'delete_episode', {'episode_uuid': episode_uuid}
)
assert metric.success
assert 'deleted' in str(result).lower()
class TestEntityAndEdgeOperations:
"""Test entity and edge management."""
@pytest.mark.asyncio
async def test_get_entity_edge(self):
"""Test retrieving entity edges."""
async with GraphitiTestClient() as client:
# Add data to create entities and edges
await client.call_tool_with_metrics(
'add_memory',
{
'name': 'Relationship Data',
'episode_body': 'Alice works at TechCorp. Bob is the CEO of TechCorp.',
'source': 'text',
'source_description': 'org chart',
'group_id': client.test_group_id,
},
)
await client.wait_for_episode_processing()
# Search for nodes to get UUIDs
result, _ = await client.call_tool_with_metrics(
'search_memory_nodes',
{'query': 'TechCorp', 'group_id': client.test_group_id, 'limit': 5},
)
# Note: This test assumes edges are created between entities
# Actual edge retrieval would require valid edge UUIDs
@pytest.mark.asyncio
async def test_delete_entity_edge(self):
"""Test deleting entity edges."""
# Similar structure to get_entity_edge but with deletion
pass # Implement based on actual edge creation patterns
class TestErrorHandling:
"""Test error conditions and edge cases."""
@pytest.mark.asyncio
async def test_invalid_tool_arguments(self):
"""Test handling of invalid tool arguments."""
async with GraphitiTestClient() as client:
# Missing required arguments
result, metric = await client.call_tool_with_metrics(
'add_memory',
{'name': 'Incomplete'}, # Missing required fields
)
assert not metric.success
assert 'error' in str(metric.details).lower()
@pytest.mark.asyncio
async def test_timeout_handling(self):
"""Test timeout handling for long operations."""
async with GraphitiTestClient() as client:
# Simulate a very large episode that might time out
large_text = 'Large document content. ' * 10000
result, metric = await client.call_tool_with_metrics(
'add_memory',
{
'name': 'Large Document',
'episode_body': large_text,
'source': 'text',
'source_description': 'large file',
'group_id': client.test_group_id,
},
timeout=5, # Short timeout
)
# Check if timeout was handled gracefully
if not metric.success:
assert 'timeout' in str(metric.details).lower()
@pytest.mark.asyncio
async def test_concurrent_operations(self):
"""Test handling of concurrent operations."""
async with GraphitiTestClient() as client:
# Launch multiple operations concurrently
tasks = []
for i in range(5):
task = client.call_tool_with_metrics(
'add_memory',
{
'name': f'Concurrent {i}',
'episode_body': f'Concurrent operation {i}',
'source': 'text',
'source_description': 'concurrent test',
'group_id': client.test_group_id,
},
)
tasks.append(task)
results = await asyncio.gather(*tasks, return_exceptions=True)
# Check that operations were queued successfully
successful = sum(1 for r, m in results if m.success)
assert successful >= 3 # At least 60% should succeed
class TestPerformance:
"""Test performance characteristics and optimization."""
@pytest.mark.asyncio
async def test_latency_metrics(self):
"""Measure and validate operation latencies."""
async with GraphitiTestClient() as client:
operations = [
(
'add_memory',
{
'name': 'Perf Test',
'episode_body': 'Simple text',
'source': 'text',
'source_description': 'test',
'group_id': client.test_group_id,
},
),
(
'search_memory_nodes',
{'query': 'test', 'group_id': client.test_group_id, 'limit': 10},
),
('get_episodes', {'group_id': client.test_group_id, 'last_n': 10}),
]
for tool_name, args in operations:
_, metric = await client.call_tool_with_metrics(tool_name, args)
# Log performance metrics
print(f'{tool_name}: {metric.duration:.2f}s')
# Basic latency assertions
if tool_name == 'get_episodes':
assert metric.duration < 2, f'{tool_name} too slow'
elif tool_name == 'search_memory_nodes':
assert metric.duration < 10, f'{tool_name} too slow'
@pytest.mark.asyncio
async def test_batch_processing_efficiency(self):
"""Test efficiency of batch operations."""
async with GraphitiTestClient() as client:
batch_size = 10
start_time = time.time()
# Batch add memories
for i in range(batch_size):
await client.call_tool_with_metrics(
'add_memory',
{
'name': f'Batch {i}',
'episode_body': f'Batch content {i}',
'source': 'text',
'source_description': 'batch test',
'group_id': client.test_group_id,
},
)
# Wait for all to process
processed = await client.wait_for_episode_processing(
expected_count=batch_size,
max_wait=120, # Allow more time for batch
)
total_time = time.time() - start_time
avg_time_per_item = total_time / batch_size
assert processed, f'Failed to process {batch_size} items'
assert avg_time_per_item < 15, (
f'Batch processing too slow: {avg_time_per_item:.2f}s per item'
)
# Generate performance report
print('\nBatch Performance Report:')
print(f' Total items: {batch_size}')
print(f' Total time: {total_time:.2f}s')
print(f' Avg per item: {avg_time_per_item:.2f}s')
class TestDatabaseBackends:
"""Test different database backend configurations."""
@pytest.mark.asyncio
@pytest.mark.parametrize('database', ['neo4j', 'falkordb'])
async def test_database_operations(self, database):
"""Test operations with different database backends."""
env_vars = {
'DATABASE_PROVIDER': database,
'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY'),
}
if database == 'neo4j':
env_vars.update(
{
'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
}
)
elif database == 'falkordb':
env_vars['FALKORDB_URI'] = os.environ.get('FALKORDB_URI', 'redis://localhost:6379')
# This test would require setting up server with specific database
# Implementation depends on database availability
pass # Placeholder for database-specific tests
def generate_test_report(client: GraphitiTestClient) -> str:
"""Generate a comprehensive test report from metrics."""
if not client.metrics:
return 'No metrics collected'
report = []
report.append('\n' + '=' * 60)
report.append('GRAPHITI MCP TEST REPORT')
report.append('=' * 60)
# Summary statistics
total_ops = len(client.metrics)
successful_ops = sum(1 for m in client.metrics if m.success)
avg_duration = sum(m.duration for m in client.metrics) / total_ops
report.append(f'\nTotal Operations: {total_ops}')
report.append(f'Successful: {successful_ops} ({successful_ops / total_ops * 100:.1f}%)')
report.append(f'Average Duration: {avg_duration:.2f}s')
# Operation breakdown
report.append('\nOperation Breakdown:')
operation_stats = {}
for metric in client.metrics:
if metric.operation not in operation_stats:
operation_stats[metric.operation] = {'count': 0, 'success': 0, 'total_duration': 0}
stats = operation_stats[metric.operation]
stats['count'] += 1
stats['success'] += 1 if metric.success else 0
stats['total_duration'] += metric.duration
for op, stats in sorted(operation_stats.items()):
avg_dur = stats['total_duration'] / stats['count']
success_rate = stats['success'] / stats['count'] * 100
report.append(
f' {op}: {stats["count"]} calls, {success_rate:.0f}% success, {avg_dur:.2f}s avg'
)
# Slowest operations
slowest = sorted(client.metrics, key=lambda m: m.duration, reverse=True)[:5]
report.append('\nSlowest Operations:')
for metric in slowest:
report.append(f' {metric.operation}: {metric.duration:.2f}s')
report.append('=' * 60)
return '\n'.join(report)
if __name__ == '__main__':
# Run tests with pytest
pytest.main([__file__, '-v', '--asyncio-mode=auto'])
```
--------------------------------------------------------------------------------
/graphiti_core/utils/content_chunking.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
import random
import re
from itertools import combinations
from math import comb
from typing import TypeVar
from graphiti_core.helpers import (
CHUNK_DENSITY_THRESHOLD,
CHUNK_MIN_TOKENS,
CHUNK_OVERLAP_TOKENS,
CHUNK_TOKEN_SIZE,
)
from graphiti_core.nodes import EpisodeType
logger = logging.getLogger(__name__)
# Approximate characters per token (conservative estimate)
CHARS_PER_TOKEN = 4
def estimate_tokens(text: str) -> int:
"""Estimate token count using character-based heuristic.
Uses ~4 characters per token as a conservative estimate.
This is faster than actual tokenization and works across all LLM providers.
Args:
text: The text to estimate tokens for
Returns:
Estimated token count
"""
return len(text) // CHARS_PER_TOKEN
def _tokens_to_chars(tokens: int) -> int:
"""Convert token count to approximate character count."""
return tokens * CHARS_PER_TOKEN
def should_chunk(content: str, episode_type: EpisodeType) -> bool:
"""Determine whether content should be chunked based on size and entity density.
Only chunks content that is both:
1. Large enough to potentially cause LLM issues (>= CHUNK_MIN_TOKENS)
2. High entity density (many entities per token)
Short content processes fine regardless of density. This targets the specific
failure case of large entity-dense inputs while preserving context for
prose/narrative content and avoiding unnecessary chunking of small inputs.
Args:
content: The content to evaluate
episode_type: Type of episode (json, message, text)
Returns:
True if content is large and has high entity density
"""
tokens = estimate_tokens(content)
# Short content always processes fine - no need to chunk
if tokens < CHUNK_MIN_TOKENS:
return False
return _estimate_high_density(content, episode_type, tokens)
def _estimate_high_density(content: str, episode_type: EpisodeType, tokens: int) -> bool:
"""Estimate whether content has high entity density.
High-density content (many entities per token) benefits from chunking.
Low-density content (prose, narratives) loses context when chunked.
Args:
content: The content to analyze
episode_type: Type of episode
tokens: Pre-computed token count
Returns:
True if content appears to have high entity density
"""
if episode_type == EpisodeType.json:
return _json_likely_dense(content, tokens)
else:
return _text_likely_dense(content, tokens)
def _json_likely_dense(content: str, tokens: int) -> bool:
"""Estimate entity density for JSON content.
JSON is considered dense if it has many array elements or object keys,
as each typically represents a distinct entity or data point.
Heuristics:
- Array: Count elements, estimate entities per 1000 tokens
- Object: Count top-level keys
Args:
content: JSON string content
tokens: Token count
Returns:
True if JSON appears to have high entity density
"""
try:
data = json.loads(content)
except json.JSONDecodeError:
# Invalid JSON, fall back to text heuristics
return _text_likely_dense(content, tokens)
if isinstance(data, list):
# For arrays, each element likely contains entities
element_count = len(data)
# Estimate density: elements per 1000 tokens
density = (element_count / tokens) * 1000 if tokens > 0 else 0
return density > CHUNK_DENSITY_THRESHOLD * 1000 # Scale threshold
elif isinstance(data, dict):
# For objects, count keys recursively (shallow)
key_count = _count_json_keys(data, max_depth=2)
density = (key_count / tokens) * 1000 if tokens > 0 else 0
return density > CHUNK_DENSITY_THRESHOLD * 1000
else:
# Scalar value, no need to chunk
return False
def _count_json_keys(data: dict, max_depth: int = 2, current_depth: int = 0) -> int:
"""Count keys in a JSON object up to a certain depth.
Args:
data: Dictionary to count keys in
max_depth: Maximum depth to traverse
current_depth: Current recursion depth
Returns:
Count of keys
"""
if current_depth >= max_depth:
return 0
count = len(data)
for value in data.values():
if isinstance(value, dict):
count += _count_json_keys(value, max_depth, current_depth + 1)
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
count += _count_json_keys(item, max_depth, current_depth + 1)
return count
def _text_likely_dense(content: str, tokens: int) -> bool:
"""Estimate entity density for text content.
Uses capitalized words as a proxy for named entities (people, places,
organizations, products). High ratio of capitalized words suggests
high entity density.
Args:
content: Text content
tokens: Token count
Returns:
True if text appears to have high entity density
"""
if tokens == 0:
return False
# Split into words
words = content.split()
if not words:
return False
# Count capitalized words (excluding sentence starters)
# A word is "capitalized" if it starts with uppercase and isn't all caps
capitalized_count = 0
for i, word in enumerate(words):
# Skip if it's likely a sentence starter (after . ! ? or first word)
if i == 0:
continue
if i > 0 and words[i - 1].rstrip()[-1:] in '.!?':
continue
# Check if capitalized (first char upper, not all caps)
cleaned = word.strip('.,!?;:\'"()[]{}')
if cleaned and cleaned[0].isupper() and not cleaned.isupper():
capitalized_count += 1
# Calculate density: capitalized words per 1000 tokens
density = (capitalized_count / tokens) * 1000 if tokens > 0 else 0
# Text density threshold is typically lower than JSON
# A well-written article might have 5-10% named entities
return density > CHUNK_DENSITY_THRESHOLD * 500 # Half the JSON threshold
def chunk_json_content(
content: str,
chunk_size_tokens: int | None = None,
overlap_tokens: int | None = None,
) -> list[str]:
"""Split JSON content into chunks while preserving structure.
For arrays: splits at element boundaries, keeping complete objects.
For objects: splits at top-level key boundaries.
Args:
content: JSON string to chunk
chunk_size_tokens: Target size per chunk in tokens (default from env)
overlap_tokens: Overlap between chunks in tokens (default from env)
Returns:
List of JSON string chunks
"""
chunk_size_tokens = chunk_size_tokens or CHUNK_TOKEN_SIZE
overlap_tokens = overlap_tokens or CHUNK_OVERLAP_TOKENS
chunk_size_chars = _tokens_to_chars(chunk_size_tokens)
overlap_chars = _tokens_to_chars(overlap_tokens)
try:
data = json.loads(content)
except json.JSONDecodeError:
logger.warning('Failed to parse JSON, falling back to text chunking')
return chunk_text_content(content, chunk_size_tokens, overlap_tokens)
if isinstance(data, list):
return _chunk_json_array(data, chunk_size_chars, overlap_chars)
elif isinstance(data, dict):
return _chunk_json_object(data, chunk_size_chars, overlap_chars)
else:
# Scalar value, return as-is
return [content]
def _chunk_json_array(
data: list,
chunk_size_chars: int,
overlap_chars: int,
) -> list[str]:
"""Chunk a JSON array by splitting at element boundaries."""
if not data:
return ['[]']
chunks: list[str] = []
current_elements: list = []
current_size = 2 # Account for '[]'
for element in data:
element_json = json.dumps(element)
element_size = len(element_json) + 2 # Account for comma and space
# Check if adding this element would exceed chunk size
if current_elements and current_size + element_size > chunk_size_chars:
# Save current chunk
chunks.append(json.dumps(current_elements))
# Start new chunk with overlap (include last few elements)
overlap_elements = _get_overlap_elements(current_elements, overlap_chars)
current_elements = overlap_elements
current_size = len(json.dumps(current_elements)) if current_elements else 2
current_elements.append(element)
current_size += element_size
# Don't forget the last chunk
if current_elements:
chunks.append(json.dumps(current_elements))
return chunks if chunks else ['[]']
def _get_overlap_elements(elements: list, overlap_chars: int) -> list:
"""Get elements from the end of a list that fit within overlap_chars."""
if not elements:
return []
overlap_elements: list = []
current_size = 2 # Account for '[]'
for element in reversed(elements):
element_json = json.dumps(element)
element_size = len(element_json) + 2
if current_size + element_size > overlap_chars:
break
overlap_elements.insert(0, element)
current_size += element_size
return overlap_elements
def _chunk_json_object(
data: dict,
chunk_size_chars: int,
overlap_chars: int,
) -> list[str]:
"""Chunk a JSON object by splitting at top-level key boundaries."""
if not data:
return ['{}']
chunks: list[str] = []
current_keys: list[str] = []
current_dict: dict = {}
current_size = 2 # Account for '{}'
for key, value in data.items():
entry_json = json.dumps({key: value})
entry_size = len(entry_json)
# Check if adding this entry would exceed chunk size
if current_dict and current_size + entry_size > chunk_size_chars:
# Save current chunk
chunks.append(json.dumps(current_dict))
# Start new chunk with overlap (include last few keys)
overlap_dict = _get_overlap_dict(current_dict, current_keys, overlap_chars)
current_dict = overlap_dict
current_keys = list(overlap_dict.keys())
current_size = len(json.dumps(current_dict)) if current_dict else 2
current_dict[key] = value
current_keys.append(key)
current_size += entry_size
# Don't forget the last chunk
if current_dict:
chunks.append(json.dumps(current_dict))
return chunks if chunks else ['{}']
def _get_overlap_dict(data: dict, keys: list[str], overlap_chars: int) -> dict:
"""Get key-value pairs from the end of a dict that fit within overlap_chars."""
if not data or not keys:
return {}
overlap_dict: dict = {}
current_size = 2 # Account for '{}'
for key in reversed(keys):
if key not in data:
continue
entry_json = json.dumps({key: data[key]})
entry_size = len(entry_json)
if current_size + entry_size > overlap_chars:
break
overlap_dict[key] = data[key]
current_size += entry_size
# Reverse to maintain original order
return dict(reversed(list(overlap_dict.items())))
def chunk_text_content(
content: str,
chunk_size_tokens: int | None = None,
overlap_tokens: int | None = None,
) -> list[str]:
"""Split text content at natural boundaries (paragraphs, sentences).
Includes overlap to capture entities at chunk boundaries.
Args:
content: Text to chunk
chunk_size_tokens: Target size per chunk in tokens (default from env)
overlap_tokens: Overlap between chunks in tokens (default from env)
Returns:
List of text chunks
"""
chunk_size_tokens = chunk_size_tokens or CHUNK_TOKEN_SIZE
overlap_tokens = overlap_tokens or CHUNK_OVERLAP_TOKENS
chunk_size_chars = _tokens_to_chars(chunk_size_tokens)
overlap_chars = _tokens_to_chars(overlap_tokens)
if len(content) <= chunk_size_chars:
return [content]
# Split into paragraphs first
paragraphs = re.split(r'\n\s*\n', content)
chunks: list[str] = []
current_chunk: list[str] = []
current_size = 0
for paragraph in paragraphs:
paragraph = paragraph.strip()
if not paragraph:
continue
para_size = len(paragraph)
# If a single paragraph is too large, split it by sentences
if para_size > chunk_size_chars:
# First, save current chunk if any
if current_chunk:
chunks.append('\n\n'.join(current_chunk))
current_chunk = []
current_size = 0
# Split large paragraph by sentences
sentence_chunks = _chunk_by_sentences(paragraph, chunk_size_chars, overlap_chars)
chunks.extend(sentence_chunks)
continue
# Check if adding this paragraph would exceed chunk size
if current_chunk and current_size + para_size + 2 > chunk_size_chars:
# Save current chunk
chunks.append('\n\n'.join(current_chunk))
# Start new chunk with overlap
overlap_text = _get_overlap_text('\n\n'.join(current_chunk), overlap_chars)
if overlap_text:
current_chunk = [overlap_text]
current_size = len(overlap_text)
else:
current_chunk = []
current_size = 0
current_chunk.append(paragraph)
current_size += para_size + 2 # Account for '\n\n'
# Don't forget the last chunk
if current_chunk:
chunks.append('\n\n'.join(current_chunk))
return chunks if chunks else [content]
def _chunk_by_sentences(
text: str,
chunk_size_chars: int,
overlap_chars: int,
) -> list[str]:
"""Split text by sentence boundaries."""
# Split on sentence-ending punctuation followed by whitespace
sentence_pattern = r'(?<=[.!?])\s+'
sentences = re.split(sentence_pattern, text)
chunks: list[str] = []
current_chunk: list[str] = []
current_size = 0
for sentence in sentences:
sentence = sentence.strip()
if not sentence:
continue
sent_size = len(sentence)
# If a single sentence is too large, split it by fixed size
if sent_size > chunk_size_chars:
if current_chunk:
chunks.append(' '.join(current_chunk))
current_chunk = []
current_size = 0
# Split by fixed size as last resort
fixed_chunks = _chunk_by_size(sentence, chunk_size_chars, overlap_chars)
chunks.extend(fixed_chunks)
continue
# Check if adding this sentence would exceed chunk size
if current_chunk and current_size + sent_size + 1 > chunk_size_chars:
chunks.append(' '.join(current_chunk))
# Start new chunk with overlap
overlap_text = _get_overlap_text(' '.join(current_chunk), overlap_chars)
if overlap_text:
current_chunk = [overlap_text]
current_size = len(overlap_text)
else:
current_chunk = []
current_size = 0
current_chunk.append(sentence)
current_size += sent_size + 1
if current_chunk:
chunks.append(' '.join(current_chunk))
return chunks
def _chunk_by_size(
text: str,
chunk_size_chars: int,
overlap_chars: int,
) -> list[str]:
"""Split text by fixed character size (last resort)."""
chunks: list[str] = []
start = 0
while start < len(text):
end = min(start + chunk_size_chars, len(text))
# Try to break at word boundary
if end < len(text):
space_idx = text.rfind(' ', start, end)
if space_idx > start:
end = space_idx
chunks.append(text[start:end].strip())
# Move start forward, ensuring progress even if overlap >= chunk_size
# Always advance by at least (chunk_size - overlap) or 1 char minimum
min_progress = max(1, chunk_size_chars - overlap_chars)
start = max(start + min_progress, end - overlap_chars)
return chunks
def _get_overlap_text(text: str, overlap_chars: int) -> str:
"""Get the last overlap_chars characters of text, breaking at word boundary."""
if len(text) <= overlap_chars:
return text
overlap_start = len(text) - overlap_chars
# Find the next word boundary after overlap_start
space_idx = text.find(' ', overlap_start)
if space_idx != -1:
return text[space_idx + 1 :]
return text[overlap_start:]
def chunk_message_content(
content: str,
chunk_size_tokens: int | None = None,
overlap_tokens: int | None = None,
) -> list[str]:
"""Split conversation content preserving message boundaries.
Never splits mid-message. Messages are identified by patterns like:
- "Speaker: message"
- JSON message arrays
- Newline-separated messages
Args:
content: Conversation content to chunk
chunk_size_tokens: Target size per chunk in tokens (default from env)
overlap_tokens: Overlap between chunks in tokens (default from env)
Returns:
List of conversation chunks
"""
chunk_size_tokens = chunk_size_tokens or CHUNK_TOKEN_SIZE
overlap_tokens = overlap_tokens or CHUNK_OVERLAP_TOKENS
chunk_size_chars = _tokens_to_chars(chunk_size_tokens)
overlap_chars = _tokens_to_chars(overlap_tokens)
if len(content) <= chunk_size_chars:
return [content]
# Try to detect message format
# Check if it's JSON (array of message objects)
try:
data = json.loads(content)
if isinstance(data, list):
return _chunk_message_array(data, chunk_size_chars, overlap_chars)
except json.JSONDecodeError:
pass
# Try speaker pattern (e.g., "Alice: Hello")
speaker_pattern = r'^([A-Za-z_][A-Za-z0-9_\s]*):(.+?)(?=^[A-Za-z_][A-Za-z0-9_\s]*:|$)'
if re.search(speaker_pattern, content, re.MULTILINE | re.DOTALL):
return _chunk_speaker_messages(content, chunk_size_chars, overlap_chars)
# Fallback to line-based chunking
return _chunk_by_lines(content, chunk_size_chars, overlap_chars)
def _chunk_message_array(
messages: list,
chunk_size_chars: int,
overlap_chars: int,
) -> list[str]:
"""Chunk a JSON array of message objects."""
# Delegate to JSON array chunking
chunks = _chunk_json_array(messages, chunk_size_chars, overlap_chars)
return chunks
def _chunk_speaker_messages(
content: str,
chunk_size_chars: int,
overlap_chars: int,
) -> list[str]:
"""Chunk messages in 'Speaker: message' format."""
# Split on speaker patterns
pattern = r'(?=^[A-Za-z_][A-Za-z0-9_\s]*:)'
messages = re.split(pattern, content, flags=re.MULTILINE)
messages = [m.strip() for m in messages if m.strip()]
if not messages:
return [content]
chunks: list[str] = []
current_messages: list[str] = []
current_size = 0
for message in messages:
msg_size = len(message)
# If a single message is too large, include it as its own chunk
if msg_size > chunk_size_chars:
if current_messages:
chunks.append('\n'.join(current_messages))
current_messages = []
current_size = 0
chunks.append(message)
continue
if current_messages and current_size + msg_size + 1 > chunk_size_chars:
chunks.append('\n'.join(current_messages))
# Get overlap (last message(s) that fit)
overlap_messages = _get_overlap_messages(current_messages, overlap_chars)
current_messages = overlap_messages
current_size = sum(len(m) for m in current_messages) + len(current_messages) - 1
current_messages.append(message)
current_size += msg_size + 1
if current_messages:
chunks.append('\n'.join(current_messages))
return chunks if chunks else [content]
def _get_overlap_messages(messages: list[str], overlap_chars: int) -> list[str]:
"""Get messages from the end that fit within overlap_chars."""
if not messages:
return []
overlap: list[str] = []
current_size = 0
for msg in reversed(messages):
msg_size = len(msg) + 1
if current_size + msg_size > overlap_chars:
break
overlap.insert(0, msg)
current_size += msg_size
return overlap
def _chunk_by_lines(
content: str,
chunk_size_chars: int,
overlap_chars: int,
) -> list[str]:
"""Chunk content by line boundaries."""
lines = content.split('\n')
chunks: list[str] = []
current_lines: list[str] = []
current_size = 0
for line in lines:
line_size = len(line) + 1
if current_lines and current_size + line_size > chunk_size_chars:
chunks.append('\n'.join(current_lines))
# Get overlap lines
overlap_text = '\n'.join(current_lines)
overlap = _get_overlap_text(overlap_text, overlap_chars)
if overlap:
current_lines = overlap.split('\n')
current_size = len(overlap)
else:
current_lines = []
current_size = 0
current_lines.append(line)
current_size += line_size
if current_lines:
chunks.append('\n'.join(current_lines))
return chunks if chunks else [content]
T = TypeVar('T')
MAX_COMBINATIONS_TO_EVALUATE = 1000
def _random_combination(n: int, k: int) -> tuple[int, ...]:
"""Generate a random combination of k items from range(n)."""
return tuple(sorted(random.sample(range(n), k)))
def generate_covering_chunks(items: list[T], k: int) -> list[tuple[list[T], list[int]]]:
"""Generate chunks of items that cover all pairs using a greedy approach.
Based on the Handshake Flights Problem / Covering Design problem.
Each chunk of K items covers C(K,2) = K(K-1)/2 pairs. We greedily select
chunks to maximize coverage of uncovered pairs, minimizing the total number
of chunks needed to ensure every pair of items appears in at least one chunk.
For large inputs where C(n,k) > MAX_COMBINATIONS_TO_EVALUATE, random sampling
is used instead of exhaustive search to maintain performance.
Lower bound (Schönheim): F >= ceil(N/K * ceil((N-1)/(K-1)))
Args:
items: List of items to partition into covering chunks
k: Maximum number of items per chunk
Returns:
List of tuples (chunk_items, global_indices) where global_indices maps
each position in chunk_items to its index in the original items list.
"""
n = len(items)
if n <= k:
return [(items, list(range(n)))]
# Track uncovered pairs using frozensets of indices
uncovered_pairs: set[frozenset[int]] = {
frozenset([i, j]) for i in range(n) for j in range(i + 1, n)
}
chunks: list[tuple[list[T], list[int]]] = []
# Determine if we need to sample or can enumerate all combinations
total_combinations = comb(n, k)
use_sampling = total_combinations > MAX_COMBINATIONS_TO_EVALUATE
while uncovered_pairs:
# Greedy selection: find the chunk that covers the most uncovered pairs
best_chunk_indices: tuple[int, ...] | None = None
best_covered_count = 0
if use_sampling:
# Sample random combinations when there are too many to enumerate
seen_combinations: set[tuple[int, ...]] = set()
# Limit total attempts (including duplicates) to prevent infinite loops
max_total_attempts = MAX_COMBINATIONS_TO_EVALUATE * 3
total_attempts = 0
samples_evaluated = 0
while samples_evaluated < MAX_COMBINATIONS_TO_EVALUATE:
total_attempts += 1
if total_attempts > max_total_attempts:
# Too many total attempts, break to avoid infinite loop
break
chunk_indices = _random_combination(n, k)
if chunk_indices in seen_combinations:
continue
seen_combinations.add(chunk_indices)
samples_evaluated += 1
# Count how many uncovered pairs this chunk covers
covered_count = sum(
1
for i, idx_i in enumerate(chunk_indices)
for idx_j in chunk_indices[i + 1 :]
if frozenset([idx_i, idx_j]) in uncovered_pairs
)
if covered_count > best_covered_count:
best_covered_count = covered_count
best_chunk_indices = chunk_indices
else:
# Enumerate all combinations when feasible
for chunk_indices in combinations(range(n), k):
# Count how many uncovered pairs this chunk covers
covered_count = sum(
1
for i, idx_i in enumerate(chunk_indices)
for idx_j in chunk_indices[i + 1 :]
if frozenset([idx_i, idx_j]) in uncovered_pairs
)
if covered_count > best_covered_count:
best_covered_count = covered_count
best_chunk_indices = chunk_indices
if best_chunk_indices is None or best_covered_count == 0:
# Greedy search couldn't find a chunk covering uncovered pairs.
# This can happen with random sampling. Fall back to creating
# small chunks that directly cover remaining pairs.
break
# Mark pairs in this chunk as covered
for i, idx_i in enumerate(best_chunk_indices):
for idx_j in best_chunk_indices[i + 1 :]:
uncovered_pairs.discard(frozenset([idx_i, idx_j]))
chunk_items = [items[idx] for idx in best_chunk_indices]
chunks.append((chunk_items, list(best_chunk_indices)))
# Handle any remaining uncovered pairs that the greedy algorithm missed.
# This can happen when random sampling fails to find covering chunks.
# Create minimal chunks (size 2) to guarantee all pairs are covered.
for pair in uncovered_pairs:
pair_indices = sorted(pair)
chunk_items = [items[idx] for idx in pair_indices]
chunks.append((chunk_items, pair_indices))
return chunks
```
--------------------------------------------------------------------------------
/graphiti_core/nodes.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from time import time
from typing import Any
from uuid import uuid4
from pydantic import BaseModel, Field
from typing_extensions import LiteralString
from graphiti_core.driver.driver import (
GraphDriver,
GraphProvider,
)
from graphiti_core.embedder import EmbedderClient
from graphiti_core.errors import NodeNotFoundError
from graphiti_core.helpers import parse_db_date
from graphiti_core.models.nodes.node_db_queries import (
COMMUNITY_NODE_RETURN,
COMMUNITY_NODE_RETURN_NEPTUNE,
EPISODIC_NODE_RETURN,
EPISODIC_NODE_RETURN_NEPTUNE,
get_community_node_save_query,
get_entity_node_return_query,
get_entity_node_save_query,
get_episode_node_save_query,
)
from graphiti_core.utils.datetime_utils import utc_now
logger = logging.getLogger(__name__)
class EpisodeType(Enum):
"""
Enumeration of different types of episodes that can be processed.
This enum defines the various sources or formats of episodes that the system
can handle. It's used to categorize and potentially handle different types
of input data differently.
Attributes:
-----------
message : str
Represents a standard message-type episode. The content for this type
should be formatted as "actor: content". For example, "user: Hello, how are you?"
or "assistant: I'm doing well, thank you for asking."
json : str
Represents an episode containing a JSON string object with structured data.
text : str
Represents a plain text episode.
"""
message = 'message'
json = 'json'
text = 'text'
@staticmethod
def from_str(episode_type: str):
if episode_type == 'message':
return EpisodeType.message
if episode_type == 'json':
return EpisodeType.json
if episode_type == 'text':
return EpisodeType.text
logger.error(f'Episode type: {episode_type} not implemented')
raise NotImplementedError
class Node(BaseModel, ABC):
uuid: str = Field(default_factory=lambda: str(uuid4()))
name: str = Field(description='name of the node')
group_id: str = Field(description='partition of the graph')
labels: list[str] = Field(default_factory=list)
created_at: datetime = Field(default_factory=lambda: utc_now())
@abstractmethod
async def save(self, driver: GraphDriver): ...
async def delete(self, driver: GraphDriver):
if driver.graph_operations_interface:
return await driver.graph_operations_interface.node_delete(self, driver)
match driver.provider:
case GraphProvider.NEO4J:
records, _, _ = await driver.execute_query(
"""
MATCH (n {uuid: $uuid})
WHERE n:Entity OR n:Episodic OR n:Community
OPTIONAL MATCH (n)-[r]-()
WITH collect(r.uuid) AS edge_uuids, n
DETACH DELETE n
RETURN edge_uuids
""",
uuid=self.uuid,
)
case GraphProvider.KUZU:
for label in ['Episodic', 'Community']:
await driver.execute_query(
f"""
MATCH (n:{label} {{uuid: $uuid}})
DETACH DELETE n
""",
uuid=self.uuid,
)
# Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
# Explicitly delete the "edge" nodes first, then the entity node.
await driver.execute_query(
"""
MATCH (n:Entity {uuid: $uuid})-[:RELATES_TO]->(e:RelatesToNode_)
DETACH DELETE e
""",
uuid=self.uuid,
)
await driver.execute_query(
"""
MATCH (n:Entity {uuid: $uuid})
DETACH DELETE n
""",
uuid=self.uuid,
)
case _: # FalkorDB, Neptune
for label in ['Entity', 'Episodic', 'Community']:
await driver.execute_query(
f"""
MATCH (n:{label} {{uuid: $uuid}})
DETACH DELETE n
""",
uuid=self.uuid,
)
logger.debug(f'Deleted Node: {self.uuid}')
def __hash__(self):
return hash(self.uuid)
def __eq__(self, other):
if isinstance(other, Node):
return self.uuid == other.uuid
return False
@classmethod
async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100):
if driver.graph_operations_interface:
return await driver.graph_operations_interface.node_delete_by_group_id(
cls, driver, group_id, batch_size
)
match driver.provider:
case GraphProvider.NEO4J:
async with driver.session() as session:
await session.run(
"""
MATCH (n:Entity|Episodic|Community {group_id: $group_id})
CALL (n) {
DETACH DELETE n
} IN TRANSACTIONS OF $batch_size ROWS
""",
group_id=group_id,
batch_size=batch_size,
)
case GraphProvider.KUZU:
for label in ['Episodic', 'Community']:
await driver.execute_query(
f"""
MATCH (n:{label} {{group_id: $group_id}})
DETACH DELETE n
""",
group_id=group_id,
)
# Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
# Explicitly delete the "edge" nodes first, then the entity node.
await driver.execute_query(
"""
MATCH (n:Entity {group_id: $group_id})-[:RELATES_TO]->(e:RelatesToNode_)
DETACH DELETE e
""",
group_id=group_id,
)
await driver.execute_query(
"""
MATCH (n:Entity {group_id: $group_id})
DETACH DELETE n
""",
group_id=group_id,
)
case _: # FalkorDB, Neptune
for label in ['Entity', 'Episodic', 'Community']:
await driver.execute_query(
f"""
MATCH (n:{label} {{group_id: $group_id}})
DETACH DELETE n
""",
group_id=group_id,
)
@classmethod
async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
if driver.graph_operations_interface:
return await driver.graph_operations_interface.node_delete_by_uuids(
cls, driver, uuids, group_id=None, batch_size=batch_size
)
match driver.provider:
case GraphProvider.FALKORDB:
for label in ['Entity', 'Episodic', 'Community']:
await driver.execute_query(
f"""
MATCH (n:{label})
WHERE n.uuid IN $uuids
DETACH DELETE n
""",
uuids=uuids,
)
case GraphProvider.KUZU:
for label in ['Episodic', 'Community']:
await driver.execute_query(
f"""
MATCH (n:{label})
WHERE n.uuid IN $uuids
DETACH DELETE n
""",
uuids=uuids,
)
# Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
# Explicitly delete the "edge" nodes first, then the entity node.
await driver.execute_query(
"""
MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)
WHERE n.uuid IN $uuids
DETACH DELETE e
""",
uuids=uuids,
)
await driver.execute_query(
"""
MATCH (n:Entity)
WHERE n.uuid IN $uuids
DETACH DELETE n
""",
uuids=uuids,
)
case _: # Neo4J, Neptune
async with driver.session() as session:
# Collect all edge UUIDs before deleting nodes
await session.run(
"""
MATCH (n:Entity|Episodic|Community)
WHERE n.uuid IN $uuids
MATCH (n)-[r]-()
RETURN collect(r.uuid) AS edge_uuids
""",
uuids=uuids,
)
# Now delete the nodes in batches
await session.run(
"""
MATCH (n:Entity|Episodic|Community)
WHERE n.uuid IN $uuids
CALL (n) {
DETACH DELETE n
} IN TRANSACTIONS OF $batch_size ROWS
""",
uuids=uuids,
batch_size=batch_size,
)
@classmethod
async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
@classmethod
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): ...
class EpisodicNode(Node):
source: EpisodeType = Field(description='source type')
source_description: str = Field(description='description of the data source')
content: str = Field(description='raw episode data')
valid_at: datetime = Field(
description='datetime of when the original document was created',
)
entity_edges: list[str] = Field(
description='list of entity edges referenced in this episode',
default_factory=list,
)
async def save(self, driver: GraphDriver):
if driver.graph_operations_interface:
return await driver.graph_operations_interface.episodic_node_save(self, driver)
episode_args = {
'uuid': self.uuid,
'name': self.name,
'group_id': self.group_id,
'source_description': self.source_description,
'content': self.content,
'entity_edges': self.entity_edges,
'created_at': self.created_at,
'valid_at': self.valid_at,
'source': self.source.value,
}
result = await driver.execute_query(
get_episode_node_save_query(driver.provider), **episode_args
)
logger.debug(f'Saved Node to Graph: {self.uuid}')
return result
@classmethod
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
records, _, _ = await driver.execute_query(
"""
MATCH (e:Episodic {uuid: $uuid})
RETURN
"""
+ (
EPISODIC_NODE_RETURN_NEPTUNE
if driver.provider == GraphProvider.NEPTUNE
else EPISODIC_NODE_RETURN
),
uuid=uuid,
routing_='r',
)
episodes = [get_episodic_node_from_record(record) for record in records]
if len(episodes) == 0:
raise NodeNotFoundError(uuid)
return episodes[0]
@classmethod
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (e:Episodic)
WHERE e.uuid IN $uuids
RETURN DISTINCT
"""
+ (
EPISODIC_NODE_RETURN_NEPTUNE
if driver.provider == GraphProvider.NEPTUNE
else EPISODIC_NODE_RETURN
),
uuids=uuids,
routing_='r',
)
episodes = [get_episodic_node_from_record(record) for record in records]
return episodes
@classmethod
async def get_by_group_ids(
cls,
driver: GraphDriver,
group_ids: list[str],
limit: int | None = None,
uuid_cursor: str | None = None,
):
cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
records, _, _ = await driver.execute_query(
"""
MATCH (e:Episodic)
WHERE e.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN DISTINCT
"""
+ (
EPISODIC_NODE_RETURN_NEPTUNE
if driver.provider == GraphProvider.NEPTUNE
else EPISODIC_NODE_RETURN
)
+ """
ORDER BY uuid DESC
"""
+ limit_query,
group_ids=group_ids,
uuid=uuid_cursor,
limit=limit,
routing_='r',
)
episodes = [get_episodic_node_from_record(record) for record in records]
return episodes
@classmethod
async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str):
records, _, _ = await driver.execute_query(
"""
MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
RETURN DISTINCT
"""
+ (
EPISODIC_NODE_RETURN_NEPTUNE
if driver.provider == GraphProvider.NEPTUNE
else EPISODIC_NODE_RETURN
),
entity_node_uuid=entity_node_uuid,
routing_='r',
)
episodes = [get_episodic_node_from_record(record) for record in records]
return episodes
class EntityNode(Node):
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
attributes: dict[str, Any] = Field(
default={}, description='Additional attributes of the node. Dependent on node labels'
)
async def generate_name_embedding(self, embedder: EmbedderClient):
start = time()
text = self.name.replace('\n', ' ')
self.name_embedding = await embedder.create(input_data=[text])
end = time()
logger.debug(f'embedded {text} in {end - start} ms')
return self.name_embedding
async def load_name_embedding(self, driver: GraphDriver):
if driver.graph_operations_interface:
return await driver.graph_operations_interface.node_load_embeddings(self, driver)
if driver.provider == GraphProvider.NEPTUNE:
query: LiteralString = """
MATCH (n:Entity {uuid: $uuid})
RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
"""
else:
query: LiteralString = """
MATCH (n:Entity {uuid: $uuid})
RETURN n.name_embedding AS name_embedding
"""
records, _, _ = await driver.execute_query(
query,
uuid=self.uuid,
routing_='r',
)
if len(records) == 0:
raise NodeNotFoundError(self.uuid)
self.name_embedding = records[0]['name_embedding']
async def save(self, driver: GraphDriver):
if driver.graph_operations_interface:
return await driver.graph_operations_interface.node_save(self, driver)
entity_data: dict[str, Any] = {
'uuid': self.uuid,
'name': self.name,
'name_embedding': self.name_embedding,
'group_id': self.group_id,
'summary': self.summary,
'created_at': self.created_at,
}
if driver.provider == GraphProvider.KUZU:
entity_data['attributes'] = json.dumps(self.attributes)
entity_data['labels'] = list(set(self.labels + ['Entity']))
result = await driver.execute_query(
get_entity_node_save_query(driver.provider, labels=''),
**entity_data,
)
else:
entity_data.update(self.attributes or {})
labels = ':'.join(self.labels + ['Entity'])
result = await driver.execute_query(
get_entity_node_save_query(driver.provider, labels),
entity_data=entity_data,
)
logger.debug(f'Saved Node to Graph: {self.uuid}')
return result
@classmethod
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity {uuid: $uuid})
RETURN
"""
+ get_entity_node_return_query(driver.provider),
uuid=uuid,
routing_='r',
)
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
if len(nodes) == 0:
raise NodeNotFoundError(uuid)
return nodes[0]
@classmethod
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity)
WHERE n.uuid IN $uuids
RETURN
"""
+ get_entity_node_return_query(driver.provider),
uuids=uuids,
routing_='r',
)
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
return nodes
@classmethod
async def get_by_group_ids(
cls,
driver: GraphDriver,
group_ids: list[str],
limit: int | None = None,
uuid_cursor: str | None = None,
with_embeddings: bool = False,
):
cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else ''
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
with_embeddings_query: LiteralString = (
""",
n.name_embedding AS name_embedding
"""
if with_embeddings
else ''
)
records, _, _ = await driver.execute_query(
"""
MATCH (n:Entity)
WHERE n.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN
"""
+ get_entity_node_return_query(driver.provider)
+ with_embeddings_query
+ """
ORDER BY n.uuid DESC
"""
+ limit_query,
group_ids=group_ids,
uuid=uuid_cursor,
limit=limit,
routing_='r',
)
nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
return nodes
class CommunityNode(Node):
name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
summary: str = Field(description='region summary of member nodes', default_factory=str)
async def save(self, driver: GraphDriver):
if driver.provider == GraphProvider.NEPTUNE:
await driver.save_to_aoss( # pyright: ignore reportAttributeAccessIssue
'communities',
[{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
)
result = await driver.execute_query(
get_community_node_save_query(driver.provider), # type: ignore
uuid=self.uuid,
name=self.name,
group_id=self.group_id,
summary=self.summary,
name_embedding=self.name_embedding,
created_at=self.created_at,
)
logger.debug(f'Saved Node to Graph: {self.uuid}')
return result
async def generate_name_embedding(self, embedder: EmbedderClient):
start = time()
text = self.name.replace('\n', ' ')
self.name_embedding = await embedder.create(input_data=[text])
end = time()
logger.debug(f'embedded {text} in {end - start} ms')
return self.name_embedding
async def load_name_embedding(self, driver: GraphDriver):
if driver.provider == GraphProvider.NEPTUNE:
query: LiteralString = """
MATCH (c:Community {uuid: $uuid})
RETURN [x IN split(c.name_embedding, ",") | toFloat(x)] as name_embedding
"""
else:
query: LiteralString = """
MATCH (c:Community {uuid: $uuid})
RETURN c.name_embedding AS name_embedding
"""
records, _, _ = await driver.execute_query(
query,
uuid=self.uuid,
routing_='r',
)
if len(records) == 0:
raise NodeNotFoundError(self.uuid)
self.name_embedding = records[0]['name_embedding']
@classmethod
async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
records, _, _ = await driver.execute_query(
"""
MATCH (c:Community {uuid: $uuid})
RETURN
"""
+ (
COMMUNITY_NODE_RETURN_NEPTUNE
if driver.provider == GraphProvider.NEPTUNE
else COMMUNITY_NODE_RETURN
),
uuid=uuid,
routing_='r',
)
nodes = [get_community_node_from_record(record) for record in records]
if len(nodes) == 0:
raise NodeNotFoundError(uuid)
return nodes[0]
@classmethod
async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
records, _, _ = await driver.execute_query(
"""
MATCH (c:Community)
WHERE c.uuid IN $uuids
RETURN
"""
+ (
COMMUNITY_NODE_RETURN_NEPTUNE
if driver.provider == GraphProvider.NEPTUNE
else COMMUNITY_NODE_RETURN
),
uuids=uuids,
routing_='r',
)
communities = [get_community_node_from_record(record) for record in records]
return communities
@classmethod
async def get_by_group_ids(
cls,
driver: GraphDriver,
group_ids: list[str],
limit: int | None = None,
uuid_cursor: str | None = None,
):
cursor_query: LiteralString = 'AND c.uuid < $uuid' if uuid_cursor else ''
limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
records, _, _ = await driver.execute_query(
"""
MATCH (c:Community)
WHERE c.group_id IN $group_ids
"""
+ cursor_query
+ """
RETURN
"""
+ (
COMMUNITY_NODE_RETURN_NEPTUNE
if driver.provider == GraphProvider.NEPTUNE
else COMMUNITY_NODE_RETURN
)
+ """
ORDER BY c.uuid DESC
"""
+ limit_query,
group_ids=group_ids,
uuid=uuid_cursor,
limit=limit,
routing_='r',
)
communities = [get_community_node_from_record(record) for record in records]
return communities
# Node helpers
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
created_at = parse_db_date(record['created_at'])
valid_at = parse_db_date(record['valid_at'])
if created_at is None:
raise ValueError(f'created_at cannot be None for episode {record.get("uuid", "unknown")}')
if valid_at is None:
raise ValueError(f'valid_at cannot be None for episode {record.get("uuid", "unknown")}')
return EpisodicNode(
content=record['content'],
created_at=created_at,
valid_at=valid_at,
uuid=record['uuid'],
group_id=record['group_id'],
source=EpisodeType.from_str(record['source']),
name=record['name'],
source_description=record['source_description'],
entity_edges=record['entity_edges'],
)
def get_entity_node_from_record(record: Any, provider: GraphProvider) -> EntityNode:
if provider == GraphProvider.KUZU:
attributes = json.loads(record['attributes']) if record['attributes'] else {}
else:
attributes = record['attributes']
attributes.pop('uuid', None)
attributes.pop('name', None)
attributes.pop('group_id', None)
attributes.pop('name_embedding', None)
attributes.pop('summary', None)
attributes.pop('created_at', None)
attributes.pop('labels', None)
labels = record.get('labels', [])
group_id = record.get('group_id')
if 'Entity_' + group_id.replace('-', '') in labels:
labels.remove('Entity_' + group_id.replace('-', ''))
entity_node = EntityNode(
uuid=record['uuid'],
name=record['name'],
name_embedding=record.get('name_embedding'),
group_id=group_id,
labels=labels,
created_at=parse_db_date(record['created_at']), # type: ignore
summary=record['summary'],
attributes=attributes,
)
return entity_node
def get_community_node_from_record(record: Any) -> CommunityNode:
return CommunityNode(
uuid=record['uuid'],
name=record['name'],
group_id=record['group_id'],
name_embedding=record['name_embedding'],
created_at=parse_db_date(record['created_at']), # type: ignore
summary=record['summary'],
)
async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
# filter out falsey values from nodes
filtered_nodes = [node for node in nodes if node.name]
if not filtered_nodes:
return
name_embeddings = await embedder.create_batch([node.name for node in filtered_nodes])
for node, name_embedding in zip(filtered_nodes, name_embeddings, strict=True):
node.name_embedding = name_embedding
```