This is page 5 of 9. Use http://codebase.md/getzep/graphiti?lines=false&page={x} to view the full context.
# Directory Structure
```
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── ISSUE_TEMPLATE
│ │ └── bug_report.md
│ ├── pull_request_template.md
│ ├── secret_scanning.yml
│ └── workflows
│ ├── ai-moderator.yml
│ ├── cla.yml
│ ├── claude-code-review-manual.yml
│ ├── claude-code-review.yml
│ ├── claude.yml
│ ├── codeql.yml
│ ├── daily_issue_maintenance.yml
│ ├── issue-triage.yml
│ ├── lint.yml
│ ├── release-graphiti-core.yml
│ ├── release-mcp-server.yml
│ ├── release-server-container.yml
│ ├── typecheck.yml
│ └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│ ├── azure-openai
│ │ ├── .env.example
│ │ ├── azure_openai_neo4j.py
│ │ └── README.md
│ ├── data
│ │ └── manybirds_products.json
│ ├── ecommerce
│ │ ├── runner.ipynb
│ │ └── runner.py
│ ├── langgraph-agent
│ │ ├── agent.ipynb
│ │ └── tinybirds-jess.png
│ ├── opentelemetry
│ │ ├── .env.example
│ │ ├── otel_stdout_example.py
│ │ ├── pyproject.toml
│ │ ├── README.md
│ │ └── uv.lock
│ ├── podcast
│ │ ├── podcast_runner.py
│ │ ├── podcast_transcript.txt
│ │ └── transcript_parser.py
│ ├── quickstart
│ │ ├── quickstart_falkordb.py
│ │ ├── quickstart_neo4j.py
│ │ ├── quickstart_neptune.py
│ │ ├── README.md
│ │ └── requirements.txt
│ └── wizard_of_oz
│ ├── parser.py
│ ├── runner.py
│ └── woo.txt
├── graphiti_core
│ ├── __init__.py
│ ├── cross_encoder
│ │ ├── __init__.py
│ │ ├── bge_reranker_client.py
│ │ ├── client.py
│ │ ├── gemini_reranker_client.py
│ │ └── openai_reranker_client.py
│ ├── decorators.py
│ ├── driver
│ │ ├── __init__.py
│ │ ├── driver.py
│ │ ├── falkordb_driver.py
│ │ ├── graph_operations
│ │ │ └── graph_operations.py
│ │ ├── kuzu_driver.py
│ │ ├── neo4j_driver.py
│ │ ├── neptune_driver.py
│ │ └── search_interface
│ │ └── search_interface.py
│ ├── edges.py
│ ├── embedder
│ │ ├── __init__.py
│ │ ├── azure_openai.py
│ │ ├── client.py
│ │ ├── gemini.py
│ │ ├── openai.py
│ │ └── voyage.py
│ ├── errors.py
│ ├── graph_queries.py
│ ├── graphiti_types.py
│ ├── graphiti.py
│ ├── helpers.py
│ ├── llm_client
│ │ ├── __init__.py
│ │ ├── anthropic_client.py
│ │ ├── azure_openai_client.py
│ │ ├── client.py
│ │ ├── config.py
│ │ ├── errors.py
│ │ ├── gemini_client.py
│ │ ├── groq_client.py
│ │ ├── openai_base_client.py
│ │ ├── openai_client.py
│ │ ├── openai_generic_client.py
│ │ └── utils.py
│ ├── migrations
│ │ └── __init__.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── edges
│ │ │ ├── __init__.py
│ │ │ └── edge_db_queries.py
│ │ └── nodes
│ │ ├── __init__.py
│ │ └── node_db_queries.py
│ ├── nodes.py
│ ├── prompts
│ │ ├── __init__.py
│ │ ├── dedupe_edges.py
│ │ ├── dedupe_nodes.py
│ │ ├── eval.py
│ │ ├── extract_edge_dates.py
│ │ ├── extract_edges.py
│ │ ├── extract_nodes.py
│ │ ├── invalidate_edges.py
│ │ ├── lib.py
│ │ ├── models.py
│ │ ├── prompt_helpers.py
│ │ ├── snippets.py
│ │ └── summarize_nodes.py
│ ├── py.typed
│ ├── search
│ │ ├── __init__.py
│ │ ├── search_config_recipes.py
│ │ ├── search_config.py
│ │ ├── search_filters.py
│ │ ├── search_helpers.py
│ │ ├── search_utils.py
│ │ └── search.py
│ ├── telemetry
│ │ ├── __init__.py
│ │ └── telemetry.py
│ ├── tracer.py
│ └── utils
│ ├── __init__.py
│ ├── bulk_utils.py
│ ├── datetime_utils.py
│ ├── maintenance
│ │ ├── __init__.py
│ │ ├── community_operations.py
│ │ ├── dedup_helpers.py
│ │ ├── edge_operations.py
│ │ ├── graph_data_operations.py
│ │ ├── node_operations.py
│ │ └── temporal_operations.py
│ ├── ontology_utils
│ │ └── entity_types_utils.py
│ └── text_utils.py
├── images
│ ├── arxiv-screenshot.png
│ ├── graphiti-graph-intro.gif
│ ├── graphiti-intro-slides-stock-2.gif
│ └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│ ├── .env.example
│ ├── .python-version
│ ├── config
│ │ ├── config-docker-falkordb-combined.yaml
│ │ ├── config-docker-falkordb.yaml
│ │ ├── config-docker-neo4j.yaml
│ │ ├── config.yaml
│ │ └── mcp_config_stdio_example.json
│ ├── docker
│ │ ├── build-standalone.sh
│ │ ├── build-with-version.sh
│ │ ├── docker-compose-falkordb.yml
│ │ ├── docker-compose-neo4j.yml
│ │ ├── docker-compose.yml
│ │ ├── Dockerfile
│ │ ├── Dockerfile.standalone
│ │ ├── github-actions-example.yml
│ │ ├── README-falkordb-combined.md
│ │ └── README.md
│ ├── docs
│ │ └── cursor_rules.md
│ ├── main.py
│ ├── pyproject.toml
│ ├── pytest.ini
│ ├── README.md
│ ├── src
│ │ ├── __init__.py
│ │ ├── config
│ │ │ ├── __init__.py
│ │ │ └── schema.py
│ │ ├── graphiti_mcp_server.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ ├── entity_types.py
│ │ │ └── response_types.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── factories.py
│ │ │ └── queue_service.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── formatting.py
│ │ └── utils.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── pytest.ini
│ │ ├── README.md
│ │ ├── run_tests.py
│ │ ├── test_async_operations.py
│ │ ├── test_comprehensive_integration.py
│ │ ├── test_configuration.py
│ │ ├── test_falkordb_integration.py
│ │ ├── test_fixtures.py
│ │ ├── test_http_integration.py
│ │ ├── test_integration.py
│ │ ├── test_mcp_integration.py
│ │ ├── test_mcp_transports.py
│ │ ├── test_stdio_simple.py
│ │ └── test_stress_load.py
│ └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│ ├── .env.example
│ ├── graph_service
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ ├── main.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ └── zep_graphiti.py
│ ├── Makefile
│ ├── pyproject.toml
│ ├── README.md
│ └── uv.lock
├── signatures
│ └── version1
│ └── cla.json
├── tests
│ ├── cross_encoder
│ │ ├── test_bge_reranker_client_int.py
│ │ └── test_gemini_reranker_client.py
│ ├── driver
│ │ ├── __init__.py
│ │ └── test_falkordb_driver.py
│ ├── embedder
│ │ ├── embedder_fixtures.py
│ │ ├── test_gemini.py
│ │ ├── test_openai.py
│ │ └── test_voyage.py
│ ├── evals
│ │ ├── data
│ │ │ └── longmemeval_data
│ │ │ ├── longmemeval_oracle.json
│ │ │ └── README.md
│ │ ├── eval_cli.py
│ │ ├── eval_e2e_graph_building.py
│ │ ├── pytest.ini
│ │ └── utils.py
│ ├── helpers_test.py
│ ├── llm_client
│ │ ├── test_anthropic_client_int.py
│ │ ├── test_anthropic_client.py
│ │ ├── test_azure_openai_client.py
│ │ ├── test_client.py
│ │ ├── test_errors.py
│ │ └── test_gemini_client.py
│ ├── test_edge_int.py
│ ├── test_entity_exclusion_int.py
│ ├── test_graphiti_int.py
│ ├── test_graphiti_mock.py
│ ├── test_node_int.py
│ ├── test_text_utils.py
│ └── utils
│ ├── maintenance
│ │ ├── test_bulk_utils.py
│ │ ├── test_edge_operations.py
│ │ ├── test_node_operations.py
│ │ └── test_temporal_operations_int.py
│ └── search
│ └── search_utils_test.py
├── uv.lock
└── Zep-CLA.md
```
# Files
--------------------------------------------------------------------------------
/graphiti_core/models/nodes/node_db_queries.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Any
from graphiti_core.driver.driver import GraphProvider
def get_episode_node_save_query(provider: GraphProvider) -> str:
match provider:
case GraphProvider.NEPTUNE:
return """
MERGE (n:Episodic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
entity_edges: join([x IN coalesce($entity_edges, []) | toString(x) ], '|'), created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid
"""
case GraphProvider.KUZU:
return """
MERGE (n:Episodic {uuid: $uuid})
SET
n.name = $name,
n.group_id = $group_id,
n.created_at = $created_at,
n.source = $source,
n.source_description = $source_description,
n.content = $content,
n.valid_at = $valid_at,
n.entity_edges = $entity_edges
RETURN n.uuid AS uuid
"""
case GraphProvider.FALKORDB:
return """
MERGE (n:Episodic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid
"""
case _: # Neo4j
return """
MERGE (n:Episodic {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
RETURN n.uuid AS uuid
"""
def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
match provider:
case GraphProvider.NEPTUNE:
return """
UNWIND $episodes AS episode
MERGE (n:Episodic {uuid: episode.uuid})
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
source: episode.source, content: episode.content,
entity_edges: join([x IN coalesce(episode.entity_edges, []) | toString(x) ], '|'), created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid
"""
case GraphProvider.KUZU:
return """
MERGE (n:Episodic {uuid: $uuid})
SET
n.name = $name,
n.group_id = $group_id,
n.created_at = $created_at,
n.source = $source,
n.source_description = $source_description,
n.content = $content,
n.valid_at = $valid_at,
n.entity_edges = $entity_edges
RETURN n.uuid AS uuid
"""
case GraphProvider.FALKORDB:
return """
UNWIND $episodes AS episode
MERGE (n:Episodic {uuid: episode.uuid})
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content,
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid
"""
case _: # Neo4j
return """
UNWIND $episodes AS episode
MERGE (n:Episodic {uuid: episode.uuid})
SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content,
entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
RETURN n.uuid AS uuid
"""
EPISODIC_NODE_RETURN = """
e.uuid AS uuid,
e.name AS name,
e.group_id AS group_id,
e.created_at AS created_at,
e.source AS source,
e.source_description AS source_description,
e.content AS content,
e.valid_at AS valid_at,
e.entity_edges AS entity_edges
"""
EPISODIC_NODE_RETURN_NEPTUNE = """
e.content AS content,
e.created_at AS created_at,
e.valid_at AS valid_at,
e.uuid AS uuid,
e.name AS name,
e.group_id AS group_id,
e.source_description AS source_description,
e.source AS source,
split(e.entity_edges, ",") AS entity_edges
"""
def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: bool = False) -> str:
match provider:
case GraphProvider.FALKORDB:
return f"""
MERGE (n:Entity {{uuid: $entity_data.uuid}})
SET n:{labels}
SET n = $entity_data
SET n.name_embedding = vecf32($entity_data.name_embedding)
RETURN n.uuid AS uuid
"""
case GraphProvider.KUZU:
return """
MERGE (n:Entity {uuid: $uuid})
SET
n.name = $name,
n.group_id = $group_id,
n.labels = $labels,
n.created_at = $created_at,
n.name_embedding = $name_embedding,
n.summary = $summary,
n.attributes = $attributes
WITH n
RETURN n.uuid AS uuid
"""
case GraphProvider.NEPTUNE:
label_subquery = ''
for label in labels.split(':'):
label_subquery += f' SET n:{label}\n'
return f"""
MERGE (n:Entity {{uuid: $entity_data.uuid}})
{label_subquery}
SET n = removeKeyFromMap(removeKeyFromMap($entity_data, "labels"), "name_embedding")
SET n.name_embedding = join([x IN coalesce($entity_data.name_embedding, []) | toString(x) ], ",")
RETURN n.uuid AS uuid
"""
case _:
save_embedding_query = (
'WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)'
if not has_aoss
else ''
)
return (
f"""
MERGE (n:Entity {{uuid: $entity_data.uuid}})
SET n:{labels}
SET n = $entity_data
"""
+ save_embedding_query
+ """
RETURN n.uuid AS uuid
"""
)
def get_entity_node_save_bulk_query(
provider: GraphProvider, nodes: list[dict], has_aoss: bool = False
) -> str | Any:
match provider:
case GraphProvider.FALKORDB:
queries = []
for node in nodes:
for label in node['labels']:
queries.append(
(
f"""
UNWIND $nodes AS node
MERGE (n:Entity {{uuid: node.uuid}})
SET n:{label}
SET n = node
WITH n, node
SET n.name_embedding = vecf32(node.name_embedding)
RETURN n.uuid AS uuid
""",
{'nodes': [node]},
)
)
return queries
case GraphProvider.NEPTUNE:
queries = []
for node in nodes:
labels = ''
for label in node['labels']:
labels += f' SET n:{label}\n'
queries.append(
f"""
UNWIND $nodes AS node
MERGE (n:Entity {{uuid: node.uuid}})
{labels}
SET n = removeKeyFromMap(removeKeyFromMap(node, "labels"), "name_embedding")
SET n.name_embedding = join([x IN coalesce(node.name_embedding, []) | toString(x) ], ",")
RETURN n.uuid AS uuid
"""
)
return queries
case GraphProvider.KUZU:
return """
MERGE (n:Entity {uuid: $uuid})
SET
n.name = $name,
n.group_id = $group_id,
n.labels = $labels,
n.created_at = $created_at,
n.name_embedding = $name_embedding,
n.summary = $summary,
n.attributes = $attributes
RETURN n.uuid AS uuid
"""
case _: # Neo4j
save_embedding_query = (
'WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)'
if not has_aoss
else ''
)
return (
"""
UNWIND $nodes AS node
MERGE (n:Entity {uuid: node.uuid})
SET n:$(node.labels)
SET n = node
"""
+ save_embedding_query
+ """
RETURN n.uuid AS uuid
"""
)
def get_entity_node_return_query(provider: GraphProvider) -> str:
# `name_embedding` is not returned by default and must be loaded manually using `load_name_embedding()`.
if provider == GraphProvider.KUZU:
return """
n.uuid AS uuid,
n.name AS name,
n.group_id AS group_id,
n.labels AS labels,
n.created_at AS created_at,
n.summary AS summary,
n.attributes AS attributes
"""
return """
n.uuid AS uuid,
n.name AS name,
n.group_id AS group_id,
n.created_at AS created_at,
n.summary AS summary,
labels(n) AS labels,
properties(n) AS attributes
"""
def get_community_node_save_query(provider: GraphProvider) -> str:
match provider:
case GraphProvider.FALKORDB:
return """
MERGE (n:Community {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at, name_embedding: vecf32($name_embedding)}
RETURN n.uuid AS uuid
"""
case GraphProvider.NEPTUNE:
return """
MERGE (n:Community {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
SET n.name_embedding = join([x IN coalesce($name_embedding, []) | toString(x) ], ",")
RETURN n.uuid AS uuid
"""
case GraphProvider.KUZU:
return """
MERGE (n:Community {uuid: $uuid})
SET
n.name = $name,
n.group_id = $group_id,
n.created_at = $created_at,
n.name_embedding = $name_embedding,
n.summary = $summary
RETURN n.uuid AS uuid
"""
case _: # Neo4j
return """
MERGE (n:Community {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
RETURN n.uuid AS uuid
"""
COMMUNITY_NODE_RETURN = """
c.uuid AS uuid,
c.name AS name,
c.group_id AS group_id,
c.created_at AS created_at,
c.name_embedding AS name_embedding,
c.summary AS summary
"""
COMMUNITY_NODE_RETURN_NEPTUNE = """
n.uuid AS uuid,
n.name AS name,
[x IN split(n.name_embedding, ",") | toFloat(x)] AS name_embedding,
n.group_id AS group_id,
n.summary AS summary,
n.created_at AS created_at
"""
```
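
The query builders above are plain string factories keyed on `GraphProvider`; the calling code supplies the parameters named in each Cypher template. Below is a minimal, hedged sketch of that wiring, assuming only a driver object exposing `.provider` and an async `execute_query(query, **params)` (the same surface the FalkorDB driver tests later on this page exercise). The `episode` dict and the `save_episode_sketch` helper are illustrative, not part of graphiti_core.

```python
# Hedged sketch: dispatch a provider-specific save query through a driver.
# Assumes `driver` exposes `.provider` and an async `execute_query(query, **params)`;
# field values are illustrative.
from datetime import datetime, timezone

from graphiti_core.models.nodes.node_db_queries import get_episode_node_save_query


async def save_episode_sketch(driver, episode: dict) -> None:
    # Pick the Cypher template that matches the backing graph database.
    query = get_episode_node_save_query(driver.provider)
    # Bind the parameters named in the template ($uuid, $name, ...).
    await driver.execute_query(
        query,
        uuid=episode['uuid'],
        name=episode['name'],
        group_id=episode['group_id'],
        source=episode['source'],
        source_description=episode['source_description'],
        content=episode['content'],
        entity_edges=episode.get('entity_edges', []),
        created_at=datetime.now(timezone.utc),
        valid_at=episode['valid_at'],
    )
```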
--------------------------------------------------------------------------------
/tests/cross_encoder/test_gemini_reranker_client.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Running tests: pytest -xvs tests/cross_encoder/test_gemini_reranker_client.py
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient
from graphiti_core.llm_client import LLMConfig, RateLimitError
@pytest.fixture
def mock_gemini_client():
"""Fixture to mock the Google Gemini client."""
with patch('google.genai.Client') as mock_client:
# Setup mock instance and its methods
mock_instance = mock_client.return_value
mock_instance.aio = MagicMock()
mock_instance.aio.models = MagicMock()
mock_instance.aio.models.generate_content = AsyncMock()
yield mock_instance
@pytest.fixture
def gemini_reranker_client(mock_gemini_client):
"""Fixture to create a GeminiRerankerClient with a mocked client."""
config = LLMConfig(api_key='test_api_key', model='test-model')
client = GeminiRerankerClient(config=config)
# Replace the client's client with our mock to ensure we're using the mock
client.client = mock_gemini_client
return client
def create_mock_response(score_text: str) -> MagicMock:
"""Helper function to create a mock Gemini response."""
mock_response = MagicMock()
mock_response.text = score_text
return mock_response
class TestGeminiRerankerClientInitialization:
"""Tests for GeminiRerankerClient initialization."""
def test_init_with_config(self):
"""Test initialization with a config object."""
config = LLMConfig(api_key='test_api_key', model='test-model')
client = GeminiRerankerClient(config=config)
assert client.config == config
@patch('google.genai.Client')
def test_init_without_config(self, mock_client):
"""Test initialization without a config uses defaults."""
client = GeminiRerankerClient()
assert client.config is not None
def test_init_with_custom_client(self):
"""Test initialization with a custom client."""
mock_client = MagicMock()
client = GeminiRerankerClient(client=mock_client)
assert client.client == mock_client
class TestGeminiRerankerClientRanking:
"""Tests for GeminiRerankerClient rank method."""
@pytest.mark.asyncio
async def test_rank_basic_functionality(self, gemini_reranker_client, mock_gemini_client):
"""Test basic ranking functionality."""
# Setup mock responses with different scores
mock_responses = [
create_mock_response('85'), # High relevance
create_mock_response('45'), # Medium relevance
create_mock_response('20'), # Low relevance
]
mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
# Test data
query = 'What is the capital of France?'
passages = [
'Paris is the capital and most populous city of France.',
'London is the capital city of England and the United Kingdom.',
'Berlin is the capital and largest city of Germany.',
]
# Call method
result = await gemini_reranker_client.rank(query, passages)
# Assertions
assert len(result) == 3
assert all(isinstance(item, tuple) for item in result)
assert all(
isinstance(passage, str) and isinstance(score, float) for passage, score in result
)
# Check scores are normalized to [0, 1] and sorted in descending order
scores = [score for _, score in result]
assert all(0.0 <= score <= 1.0 for score in scores)
assert scores == sorted(scores, reverse=True)
# Check that the highest scoring passage is first
assert result[0][1] == 0.85 # 85/100
assert result[1][1] == 0.45 # 45/100
assert result[2][1] == 0.20 # 20/100
@pytest.mark.asyncio
async def test_rank_empty_passages(self, gemini_reranker_client):
"""Test ranking with empty passages list."""
query = 'Test query'
passages = []
result = await gemini_reranker_client.rank(query, passages)
assert result == []
@pytest.mark.asyncio
async def test_rank_single_passage(self, gemini_reranker_client, mock_gemini_client):
"""Test ranking with a single passage."""
# Setup mock response
mock_gemini_client.aio.models.generate_content.return_value = create_mock_response('75')
query = 'Test query'
passages = ['Single test passage']
result = await gemini_reranker_client.rank(query, passages)
assert len(result) == 1
assert result[0][0] == 'Single test passage'
assert result[0][1] == 1.0 # Single passage gets full score
@pytest.mark.asyncio
async def test_rank_score_extraction_with_regex(
self, gemini_reranker_client, mock_gemini_client
):
"""Test score extraction from various response formats."""
# Setup mock responses with different formats
mock_responses = [
create_mock_response('Score: 90'), # Contains text before number
create_mock_response('The relevance is 65 out of 100'), # Contains text around number
create_mock_response('8'), # Just the number
]
mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
query = 'Test query'
passages = ['Passage 1', 'Passage 2', 'Passage 3']
result = await gemini_reranker_client.rank(query, passages)
# Check that scores were extracted correctly and normalized
scores = [score for _, score in result]
assert 0.90 in scores # 90/100
assert 0.65 in scores # 65/100
assert 0.08 in scores # 8/100
@pytest.mark.asyncio
async def test_rank_invalid_score_handling(self, gemini_reranker_client, mock_gemini_client):
"""Test handling of invalid or non-numeric scores."""
# Setup mock responses with invalid scores
mock_responses = [
create_mock_response('Not a number'), # Invalid response
create_mock_response(''), # Empty response
create_mock_response('95'), # Valid response
]
mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
query = 'Test query'
passages = ['Passage 1', 'Passage 2', 'Passage 3']
result = await gemini_reranker_client.rank(query, passages)
# Check that invalid scores are handled gracefully (assigned 0.0)
scores = [score for _, score in result]
assert 0.95 in scores # Valid score
assert scores.count(0.0) == 2 # Two invalid scores assigned 0.0
@pytest.mark.asyncio
async def test_rank_score_clamping(self, gemini_reranker_client, mock_gemini_client):
"""Test that scores are properly clamped to [0, 1] range."""
# Setup mock responses with extreme scores
# Note: regex only matches 1-3 digits, so negative numbers won't match
mock_responses = [
create_mock_response('999'), # Above 100 but within regex range
create_mock_response('invalid'), # Invalid response becomes 0.0
create_mock_response('50'), # Normal score
]
mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
query = 'Test query'
passages = ['Passage 1', 'Passage 2', 'Passage 3']
result = await gemini_reranker_client.rank(query, passages)
# Check that scores are normalized and clamped
scores = [score for _, score in result]
assert all(0.0 <= score <= 1.0 for score in scores)
# 999 should be clamped to 1.0 (999/100 = 9.99, clamped to 1.0)
assert 1.0 in scores
# Invalid response should be 0.0
assert 0.0 in scores
# Normal score should be normalized (50/100 = 0.5)
assert 0.5 in scores
@pytest.mark.asyncio
async def test_rank_rate_limit_error(self, gemini_reranker_client, mock_gemini_client):
"""Test handling of rate limit errors."""
# Setup mock to raise rate limit error
mock_gemini_client.aio.models.generate_content.side_effect = Exception(
'Rate limit exceeded'
)
query = 'Test query'
passages = ['Passage 1', 'Passage 2']
with pytest.raises(RateLimitError):
await gemini_reranker_client.rank(query, passages)
@pytest.mark.asyncio
async def test_rank_quota_error(self, gemini_reranker_client, mock_gemini_client):
"""Test handling of quota errors."""
# Setup mock to raise quota error
mock_gemini_client.aio.models.generate_content.side_effect = Exception('Quota exceeded')
query = 'Test query'
passages = ['Passage 1', 'Passage 2']
with pytest.raises(RateLimitError):
await gemini_reranker_client.rank(query, passages)
@pytest.mark.asyncio
async def test_rank_resource_exhausted_error(self, gemini_reranker_client, mock_gemini_client):
"""Test handling of resource exhausted errors."""
# Setup mock to raise resource exhausted error
mock_gemini_client.aio.models.generate_content.side_effect = Exception('resource_exhausted')
query = 'Test query'
passages = ['Passage 1', 'Passage 2']
with pytest.raises(RateLimitError):
await gemini_reranker_client.rank(query, passages)
@pytest.mark.asyncio
async def test_rank_429_error(self, gemini_reranker_client, mock_gemini_client):
"""Test handling of HTTP 429 errors."""
# Setup mock to raise 429 error
mock_gemini_client.aio.models.generate_content.side_effect = Exception(
'HTTP 429 Too Many Requests'
)
query = 'Test query'
passages = ['Passage 1', 'Passage 2']
with pytest.raises(RateLimitError):
await gemini_reranker_client.rank(query, passages)
@pytest.mark.asyncio
async def test_rank_generic_error(self, gemini_reranker_client, mock_gemini_client):
"""Test handling of generic errors."""
# Setup mock to raise generic error
mock_gemini_client.aio.models.generate_content.side_effect = Exception('Generic error')
query = 'Test query'
passages = ['Passage 1', 'Passage 2']
with pytest.raises(Exception) as exc_info:
await gemini_reranker_client.rank(query, passages)
assert 'Generic error' in str(exc_info.value)
@pytest.mark.asyncio
async def test_rank_concurrent_requests(self, gemini_reranker_client, mock_gemini_client):
"""Test that multiple passages are scored concurrently."""
# Setup mock responses
mock_responses = [
create_mock_response('80'),
create_mock_response('60'),
create_mock_response('40'),
]
mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
query = 'Test query'
passages = ['Passage 1', 'Passage 2', 'Passage 3']
await gemini_reranker_client.rank(query, passages)
# Verify that generate_content was called for each passage
assert mock_gemini_client.aio.models.generate_content.call_count == 3
# Verify that all calls were made with correct parameters
calls = mock_gemini_client.aio.models.generate_content.call_args_list
for call in calls:
args, kwargs = call
assert kwargs['model'] == gemini_reranker_client.config.model
assert kwargs['config'].temperature == 0.0
assert kwargs['config'].max_output_tokens == 3
@pytest.mark.asyncio
async def test_rank_response_parsing_error(self, gemini_reranker_client, mock_gemini_client):
"""Test handling of response parsing errors."""
# Setup mock responses that will trigger ValueError during parsing
mock_responses = [
create_mock_response('not a number at all'), # Will fail regex match
create_mock_response('also invalid text'), # Will fail regex match
]
mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
query = 'Test query'
# Use multiple passages to avoid the single passage special case
passages = ['Passage 1', 'Passage 2']
result = await gemini_reranker_client.rank(query, passages)
# Should handle the error gracefully and assign 0.0 score to both
assert len(result) == 2
assert all(score == 0.0 for _, score in result)
@pytest.mark.asyncio
async def test_rank_empty_response_text(self, gemini_reranker_client, mock_gemini_client):
"""Test handling of empty response text."""
# Setup mock response with empty text
mock_response = MagicMock()
mock_response.text = '' # Empty string instead of None
mock_gemini_client.aio.models.generate_content.return_value = mock_response
query = 'Test query'
# Use multiple passages to avoid the single passage special case
passages = ['Passage 1', 'Passage 2']
result = await gemini_reranker_client.rank(query, passages)
# Should handle empty text gracefully and assign 0.0 score to both
assert len(result) == 2
assert all(score == 0.0 for _, score in result)
if __name__ == '__main__':
pytest.main(['-v', 'test_gemini_reranker_client.py'])
```
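
Taken together, these tests pin down the scoring contract the reranker is expected to honor: extract the first one-to-three-digit integer from the model's reply, divide by 100, clamp to [0, 1], and fall back to 0.0 when nothing parses. The snippet below restates that contract as a standalone helper for clarity; it mirrors the asserted behavior and is not `GeminiRerankerClient`'s actual implementation.

```python
import re


def parse_relevance_score(response_text: str) -> float:
    """Illustrative re-statement of the behavior the tests assert, not the client's source."""
    match = re.search(r'\b(\d{1,3})\b', response_text or '')
    if match is None:
        return 0.0  # non-numeric or empty replies are scored 0.0
    return max(0.0, min(1.0, int(match.group(1)) / 100))  # normalize to [0, 1] and clamp


assert parse_relevance_score('Score: 90') == 0.90
assert parse_relevance_score('999') == 1.0  # clamped, as in test_rank_score_clamping
assert parse_relevance_score('not a number') == 0.0
```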
--------------------------------------------------------------------------------
/tests/test_edge_int.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
import sys
from datetime import datetime
import numpy as np
import pytest
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
from tests.helpers_test import get_edge_count, get_node_count, group_id
pytest_plugins = ('pytest_asyncio',)
def setup_logging():
# Create a logger
logger = logging.getLogger()
logger.setLevel(logging.INFO) # Set the logging level to INFO
# Create console handler and set level to INFO
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
# Create formatter
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Add formatter to console handler
console_handler.setFormatter(formatter)
# Add console handler to logger
logger.addHandler(console_handler)
return logger
@pytest.mark.asyncio
async def test_episodic_edge(graph_driver, mock_embedder):
now = datetime.now()
# Create episodic node
episode_node = EpisodicNode(
name='test_episode',
labels=[],
created_at=now,
valid_at=now,
source=EpisodeType.message,
source_description='conversation message',
content='Alice likes Bob',
entity_edges=[],
group_id=group_id,
)
node_count = await get_node_count(graph_driver, [episode_node.uuid])
assert node_count == 0
await episode_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [episode_node.uuid])
assert node_count == 1
# Create entity node
alice_node = EntityNode(
name='Alice',
labels=[],
created_at=now,
summary='Alice summary',
group_id=group_id,
)
await alice_node.generate_name_embedding(mock_embedder)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
await alice_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 1
# Create episodic to entity edge
episodic_edge = EpisodicEdge(
source_node_uuid=episode_node.uuid,
target_node_uuid=alice_node.uuid,
created_at=now,
group_id=group_id,
)
edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
assert edge_count == 0
await episodic_edge.save(graph_driver)
edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
assert edge_count == 1
# Get edge by uuid
retrieved = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge.uuid)
assert retrieved.uuid == episodic_edge.uuid
assert retrieved.source_node_uuid == episode_node.uuid
assert retrieved.target_node_uuid == alice_node.uuid
assert retrieved.created_at == now
assert retrieved.group_id == group_id
# Get edge by uuids
retrieved = await EpisodicEdge.get_by_uuids(graph_driver, [episodic_edge.uuid])
assert len(retrieved) == 1
assert retrieved[0].uuid == episodic_edge.uuid
assert retrieved[0].source_node_uuid == episode_node.uuid
assert retrieved[0].target_node_uuid == alice_node.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Get edge by group ids
retrieved = await EpisodicEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
assert len(retrieved) == 1
assert retrieved[0].uuid == episodic_edge.uuid
assert retrieved[0].source_node_uuid == episode_node.uuid
assert retrieved[0].target_node_uuid == alice_node.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Get episodic node by entity node uuid
retrieved = await EpisodicNode.get_by_entity_node_uuid(graph_driver, alice_node.uuid)
assert len(retrieved) == 1
assert retrieved[0].uuid == episode_node.uuid
assert retrieved[0].name == 'test_episode'
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Delete edge by uuid
await episodic_edge.delete(graph_driver)
edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
assert edge_count == 0
# Delete edge by uuids
await episodic_edge.save(graph_driver)
await episodic_edge.delete_by_uuids(graph_driver, [episodic_edge.uuid])
edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
assert edge_count == 0
# Cleanup nodes
await episode_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, [episode_node.uuid])
assert node_count == 0
await alice_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
await graph_driver.close()
@pytest.mark.asyncio
async def test_entity_edge(graph_driver, mock_embedder):
now = datetime.now()
# Create entity node
alice_node = EntityNode(
name='Alice',
labels=[],
created_at=now,
summary='Alice summary',
group_id=group_id,
)
await alice_node.generate_name_embedding(mock_embedder)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
await alice_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 1
# Create entity node
bob_node = EntityNode(
name='Bob', labels=[], created_at=now, summary='Bob summary', group_id=group_id
)
await bob_node.generate_name_embedding(mock_embedder)
node_count = await get_node_count(graph_driver, [bob_node.uuid])
assert node_count == 0
await bob_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [bob_node.uuid])
assert node_count == 1
# Create entity to entity edge
entity_edge = EntityEdge(
source_node_uuid=alice_node.uuid,
target_node_uuid=bob_node.uuid,
created_at=now,
name='likes',
fact='Alice likes Bob',
episodes=[],
expired_at=now,
valid_at=now,
invalid_at=now,
group_id=group_id,
)
edge_embedding = await entity_edge.generate_embedding(mock_embedder)
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0
await entity_edge.save(graph_driver)
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 1
# Get edge by uuid
retrieved = await EntityEdge.get_by_uuid(graph_driver, entity_edge.uuid)
assert retrieved.uuid == entity_edge.uuid
assert retrieved.source_node_uuid == alice_node.uuid
assert retrieved.target_node_uuid == bob_node.uuid
assert retrieved.created_at == now
assert retrieved.group_id == group_id
# Get edge by uuids
retrieved = await EntityEdge.get_by_uuids(graph_driver, [entity_edge.uuid])
assert len(retrieved) == 1
assert retrieved[0].uuid == entity_edge.uuid
assert retrieved[0].source_node_uuid == alice_node.uuid
assert retrieved[0].target_node_uuid == bob_node.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Get edge by group ids
retrieved = await EntityEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
assert len(retrieved) == 1
assert retrieved[0].uuid == entity_edge.uuid
assert retrieved[0].source_node_uuid == alice_node.uuid
assert retrieved[0].target_node_uuid == bob_node.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Get edge by node uuid
retrieved = await EntityEdge.get_by_node_uuid(graph_driver, alice_node.uuid)
assert len(retrieved) == 1
assert retrieved[0].uuid == entity_edge.uuid
assert retrieved[0].source_node_uuid == alice_node.uuid
assert retrieved[0].target_node_uuid == bob_node.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Get fact embedding
await entity_edge.load_fact_embedding(graph_driver)
assert np.allclose(entity_edge.fact_embedding, edge_embedding)
# Delete edge by uuid
await entity_edge.delete(graph_driver)
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0
# Delete edge by uuids
await entity_edge.save(graph_driver)
await entity_edge.delete_by_uuids(graph_driver, [entity_edge.uuid])
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0
# Deleting node should delete the edge
await entity_edge.save(graph_driver)
await alice_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0
# Deleting node by uuids should delete the edge
await alice_node.save(graph_driver)
await entity_edge.save(graph_driver)
await alice_node.delete_by_uuids(graph_driver, [alice_node.uuid])
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0
# Deleting node by group id should delete the edge
await alice_node.save(graph_driver)
await entity_edge.save(graph_driver)
await alice_node.delete_by_group_id(graph_driver, alice_node.group_id)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0
# Cleanup nodes
await alice_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
await bob_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, [bob_node.uuid])
assert node_count == 0
await graph_driver.close()
@pytest.mark.asyncio
async def test_community_edge(graph_driver, mock_embedder):
now = datetime.now()
# Create community node
community_node_1 = CommunityNode(
name='test_community_1',
group_id=group_id,
summary='Community A summary',
)
await community_node_1.generate_name_embedding(mock_embedder)
node_count = await get_node_count(graph_driver, [community_node_1.uuid])
assert node_count == 0
await community_node_1.save(graph_driver)
node_count = await get_node_count(graph_driver, [community_node_1.uuid])
assert node_count == 1
# Create community node
community_node_2 = CommunityNode(
name='test_community_2',
group_id=group_id,
summary='Community B summary',
)
await community_node_2.generate_name_embedding(mock_embedder)
node_count = await get_node_count(graph_driver, [community_node_2.uuid])
assert node_count == 0
await community_node_2.save(graph_driver)
node_count = await get_node_count(graph_driver, [community_node_2.uuid])
assert node_count == 1
# Create entity node
alice_node = EntityNode(
name='Alice', labels=[], created_at=now, summary='Alice summary', group_id=group_id
)
await alice_node.generate_name_embedding(mock_embedder)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
await alice_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 1
# Create community to community edge
community_edge = CommunityEdge(
source_node_uuid=community_node_1.uuid,
target_node_uuid=community_node_2.uuid,
created_at=now,
group_id=group_id,
)
edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
assert edge_count == 0
await community_edge.save(graph_driver)
edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
assert edge_count == 1
# Get edge by uuid
retrieved = await CommunityEdge.get_by_uuid(graph_driver, community_edge.uuid)
assert retrieved.uuid == community_edge.uuid
assert retrieved.source_node_uuid == community_node_1.uuid
assert retrieved.target_node_uuid == community_node_2.uuid
assert retrieved.created_at == now
assert retrieved.group_id == group_id
# Get edge by uuids
retrieved = await CommunityEdge.get_by_uuids(graph_driver, [community_edge.uuid])
assert len(retrieved) == 1
assert retrieved[0].uuid == community_edge.uuid
assert retrieved[0].source_node_uuid == community_node_1.uuid
assert retrieved[0].target_node_uuid == community_node_2.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Get edge by group ids
retrieved = await CommunityEdge.get_by_group_ids(graph_driver, [group_id], limit=1)
assert len(retrieved) == 1
assert retrieved[0].uuid == community_edge.uuid
assert retrieved[0].source_node_uuid == community_node_1.uuid
assert retrieved[0].target_node_uuid == community_node_2.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Delete edge by uuid
await community_edge.delete(graph_driver)
edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
assert edge_count == 0
# Delete edge by uuids
await community_edge.save(graph_driver)
await community_edge.delete_by_uuids(graph_driver, [community_edge.uuid])
edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
assert edge_count == 0
# Cleanup nodes
await alice_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
await community_node_1.delete(graph_driver)
node_count = await get_node_count(graph_driver, [community_node_1.uuid])
assert node_count == 0
await community_node_2.delete(graph_driver)
node_count = await get_node_count(graph_driver, [community_node_2.uuid])
assert node_count == 0
await graph_driver.close()
```
--------------------------------------------------------------------------------
/tests/embedder/test_gemini.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Running tests: pytest -xvs tests/embedder/test_gemini.py
from collections.abc import Generator
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from embedder_fixtures import create_embedding_values
from graphiti_core.embedder.gemini import (
DEFAULT_EMBEDDING_MODEL,
GeminiEmbedder,
GeminiEmbedderConfig,
)
def create_gemini_embedding(multiplier: float = 0.1, dimension: int = 1536) -> MagicMock:
"""Create a mock Gemini embedding with specified value multiplier and dimension."""
mock_embedding = MagicMock()
mock_embedding.values = create_embedding_values(multiplier, dimension)
return mock_embedding
@pytest.fixture
def mock_gemini_response() -> MagicMock:
"""Create a mock Gemini embeddings response."""
mock_result = MagicMock()
mock_result.embeddings = [create_gemini_embedding()]
return mock_result
@pytest.fixture
def mock_gemini_batch_response() -> MagicMock:
"""Create a mock Gemini batch embeddings response."""
mock_result = MagicMock()
mock_result.embeddings = [
create_gemini_embedding(0.1),
create_gemini_embedding(0.2),
create_gemini_embedding(0.3),
]
return mock_result
@pytest.fixture
def mock_gemini_client() -> Generator[Any, Any, None]:
"""Create a mocked Gemini client."""
with patch('google.genai.Client') as mock_client:
mock_instance = mock_client.return_value
mock_instance.aio = MagicMock()
mock_instance.aio.models = MagicMock()
mock_instance.aio.models.embed_content = AsyncMock()
yield mock_instance
@pytest.fixture
def gemini_embedder(mock_gemini_client: Any) -> GeminiEmbedder:
"""Create a GeminiEmbedder with a mocked client."""
config = GeminiEmbedderConfig(api_key='test_api_key')
client = GeminiEmbedder(config=config)
client.client = mock_gemini_client
return client
class TestGeminiEmbedderInitialization:
"""Tests for GeminiEmbedder initialization."""
@patch('google.genai.Client')
def test_init_with_config(self, mock_client):
"""Test initialization with a config object."""
config = GeminiEmbedderConfig(
api_key='test_api_key', embedding_model='custom-model', embedding_dim=768
)
embedder = GeminiEmbedder(config=config)
assert embedder.config == config
assert embedder.config.embedding_model == 'custom-model'
assert embedder.config.api_key == 'test_api_key'
assert embedder.config.embedding_dim == 768
@patch('google.genai.Client')
def test_init_without_config(self, mock_client):
"""Test initialization without a config uses defaults."""
embedder = GeminiEmbedder()
assert embedder.config is not None
assert embedder.config.embedding_model == DEFAULT_EMBEDDING_MODEL
@patch('google.genai.Client')
def test_init_with_partial_config(self, mock_client):
"""Test initialization with partial config."""
config = GeminiEmbedderConfig(api_key='test_api_key')
embedder = GeminiEmbedder(config=config)
assert embedder.config.api_key == 'test_api_key'
assert embedder.config.embedding_model == DEFAULT_EMBEDDING_MODEL
class TestGeminiEmbedderCreate:
"""Tests for GeminiEmbedder create method."""
@pytest.mark.asyncio
async def test_create_calls_api_correctly(
self,
gemini_embedder: GeminiEmbedder,
mock_gemini_client: Any,
mock_gemini_response: MagicMock,
) -> None:
"""Test that create method correctly calls the API and processes the response."""
# Setup
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
# Call method
result = await gemini_embedder.create('Test input')
# Verify API is called with correct parameters
mock_gemini_client.aio.models.embed_content.assert_called_once()
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL
assert kwargs['contents'] == ['Test input']
# Verify result is processed correctly
assert result == mock_gemini_response.embeddings[0].values
@pytest.mark.asyncio
@patch('google.genai.Client')
async def test_create_with_custom_model(
self, mock_client_class, mock_gemini_client: Any, mock_gemini_response: MagicMock
) -> None:
"""Test create method with custom embedding model."""
# Setup embedder with custom model
config = GeminiEmbedderConfig(api_key='test_api_key', embedding_model='custom-model')
embedder = GeminiEmbedder(config=config)
embedder.client = mock_gemini_client
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
# Call method
await embedder.create('Test input')
# Verify custom model is used
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
assert kwargs['model'] == 'custom-model'
@pytest.mark.asyncio
@patch('google.genai.Client')
async def test_create_with_custom_dimension(
self, mock_client_class, mock_gemini_client: Any
) -> None:
"""Test create method with custom embedding dimension."""
# Setup embedder with custom dimension
config = GeminiEmbedderConfig(api_key='test_api_key', embedding_dim=768)
embedder = GeminiEmbedder(config=config)
embedder.client = mock_gemini_client
# Setup mock response with custom dimension
mock_response = MagicMock()
mock_response.embeddings = [create_gemini_embedding(0.1, 768)]
mock_gemini_client.aio.models.embed_content.return_value = mock_response
# Call method
result = await embedder.create('Test input')
# Verify custom dimension is used in config
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
assert kwargs['config'].output_dimensionality == 768
# Verify result has correct dimension
assert len(result) == 768
@pytest.mark.asyncio
async def test_create_with_different_input_types(
self,
gemini_embedder: GeminiEmbedder,
mock_gemini_client: Any,
mock_gemini_response: MagicMock,
) -> None:
"""Test create method with different input types."""
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
# Test with string
await gemini_embedder.create('Test string')
# Test with list of strings
await gemini_embedder.create(['Test', 'List'])
# Test with iterable of integers
await gemini_embedder.create([1, 2, 3])
# Verify all calls were made
assert mock_gemini_client.aio.models.embed_content.call_count == 3
@pytest.mark.asyncio
async def test_create_no_embeddings_error(
self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
) -> None:
"""Test create method handling of no embeddings response."""
# Setup mock response with no embeddings
mock_response = MagicMock()
mock_response.embeddings = []
mock_gemini_client.aio.models.embed_content.return_value = mock_response
# Call method and expect exception
with pytest.raises(ValueError) as exc_info:
await gemini_embedder.create('Test input')
assert 'No embeddings returned from Gemini API in create()' in str(exc_info.value)
@pytest.mark.asyncio
async def test_create_no_values_error(
self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
) -> None:
"""Test create method handling of embeddings with no values."""
# Setup mock response with embedding but no values
mock_embedding = MagicMock()
mock_embedding.values = None
mock_response = MagicMock()
mock_response.embeddings = [mock_embedding]
mock_gemini_client.aio.models.embed_content.return_value = mock_response
# Call method and expect exception
with pytest.raises(ValueError) as exc_info:
await gemini_embedder.create('Test input')
assert 'No embeddings returned from Gemini API in create()' in str(exc_info.value)
class TestGeminiEmbedderCreateBatch:
"""Tests for GeminiEmbedder create_batch method."""
@pytest.mark.asyncio
async def test_create_batch_processes_multiple_inputs(
self,
gemini_embedder: GeminiEmbedder,
mock_gemini_client: Any,
mock_gemini_batch_response: MagicMock,
) -> None:
"""Test that create_batch method correctly processes multiple inputs."""
# Setup
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_batch_response
input_batch = ['Input 1', 'Input 2', 'Input 3']
# Call method
result = await gemini_embedder.create_batch(input_batch)
# Verify API is called with correct parameters
mock_gemini_client.aio.models.embed_content.assert_called_once()
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL
assert kwargs['contents'] == input_batch
# Verify all results are processed correctly
assert len(result) == 3
assert result == [
mock_gemini_batch_response.embeddings[0].values,
mock_gemini_batch_response.embeddings[1].values,
mock_gemini_batch_response.embeddings[2].values,
]
@pytest.mark.asyncio
async def test_create_batch_single_input(
self,
gemini_embedder: GeminiEmbedder,
mock_gemini_client: Any,
mock_gemini_response: MagicMock,
) -> None:
"""Test create_batch method with single input."""
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
input_batch = ['Single input']
result = await gemini_embedder.create_batch(input_batch)
assert len(result) == 1
assert result[0] == mock_gemini_response.embeddings[0].values
@pytest.mark.asyncio
async def test_create_batch_empty_input(
self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
) -> None:
"""Test create_batch method with empty input."""
# Setup mock response with no embeddings
mock_response = MagicMock()
mock_response.embeddings = []
mock_gemini_client.aio.models.embed_content.return_value = mock_response
input_batch = []
result = await gemini_embedder.create_batch(input_batch)
assert result == []
mock_gemini_client.aio.models.embed_content.assert_not_called()
@pytest.mark.asyncio
async def test_create_batch_no_embeddings_error(
self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
) -> None:
"""Test create_batch method handling of no embeddings response."""
# Setup mock response with no embeddings
mock_response = MagicMock()
mock_response.embeddings = []
mock_gemini_client.aio.models.embed_content.return_value = mock_response
input_batch = ['Input 1', 'Input 2']
with pytest.raises(ValueError) as exc_info:
await gemini_embedder.create_batch(input_batch)
assert 'No embeddings returned from Gemini API' in str(exc_info.value)
@pytest.mark.asyncio
async def test_create_batch_empty_values_error(
self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
) -> None:
"""Test create_batch method handling of embeddings with empty values."""
# Setup mock response with embeddings but empty values
mock_embedding1 = MagicMock()
mock_embedding1.values = [0.1, 0.2, 0.3] # Valid values
mock_embedding2 = MagicMock()
mock_embedding2.values = None # Empty values
# Mock response for the initial batch call
mock_batch_response = MagicMock()
mock_batch_response.embeddings = [mock_embedding1, mock_embedding2]
# Mock response for individual processing of 'Input 1'
mock_individual_response_1 = MagicMock()
mock_individual_response_1.embeddings = [mock_embedding1]
# Mock response for individual processing of 'Input 2' (which has empty values)
mock_individual_response_2 = MagicMock()
mock_individual_response_2.embeddings = [mock_embedding2]
# Set side_effect for embed_content to control return values for each call
mock_gemini_client.aio.models.embed_content.side_effect = [
mock_batch_response, # First call for the batch
mock_individual_response_1, # Second call for individual item 1
mock_individual_response_2, # Third call for individual item 2
]
input_batch = ['Input 1', 'Input 2']
with pytest.raises(ValueError) as exc_info:
await gemini_embedder.create_batch(input_batch)
assert 'Empty embedding values returned' in str(exc_info.value)
@pytest.mark.asyncio
@patch('google.genai.Client')
async def test_create_batch_with_custom_model_and_dimension(
self, mock_client_class, mock_gemini_client: Any
) -> None:
"""Test create_batch method with custom model and dimension."""
# Setup embedder with custom settings
config = GeminiEmbedderConfig(
api_key='test_api_key', embedding_model='custom-batch-model', embedding_dim=512
)
embedder = GeminiEmbedder(config=config)
embedder.client = mock_gemini_client
# Setup mock response
mock_response = MagicMock()
mock_response.embeddings = [
create_gemini_embedding(0.1, 512),
create_gemini_embedding(0.2, 512),
]
mock_gemini_client.aio.models.embed_content.return_value = mock_response
input_batch = ['Input 1', 'Input 2']
result = await embedder.create_batch(input_batch)
# Verify custom settings are used
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
assert kwargs['model'] == 'custom-batch-model'
assert kwargs['config'].output_dimensionality == 512
# Verify results have correct dimension
assert len(result) == 2
assert all(len(embedding) == 512 for embedding in result)
if __name__ == '__main__':
pytest.main(['-xvs', __file__])
```
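
For reference, the public surface these tests exercise boils down to constructing a `GeminiEmbedderConfig`, instantiating `GeminiEmbedder`, and awaiting `create` / `create_batch`. A minimal usage sketch follows, assuming a valid Google API key; the key and dimension values are placeholders, not project defaults.

```python
import asyncio

from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig


async def main() -> None:
    # Config fields mirror those used in the tests; the values here are placeholders.
    config = GeminiEmbedderConfig(api_key='YOUR_GOOGLE_API_KEY', embedding_dim=768)
    embedder = GeminiEmbedder(config=config)

    single = await embedder.create('Paris is the capital of France.')
    batch = await embedder.create_batch(['Input 1', 'Input 2'])
    print(len(single), len(batch))


if __name__ == '__main__':
    asyncio.run(main())
```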
--------------------------------------------------------------------------------
/tests/driver/test_falkordb_driver.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import unittest
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from graphiti_core.driver.driver import GraphProvider
try:
from graphiti_core.driver.falkordb_driver import FalkorDriver, FalkorDriverSession
HAS_FALKORDB = True
except ImportError:
FalkorDriver = None
HAS_FALKORDB = False
class TestFalkorDriver:
"""Comprehensive test suite for FalkorDB driver."""
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def setup_method(self):
"""Set up test fixtures."""
self.mock_client = MagicMock()
with patch('graphiti_core.driver.falkordb_driver.FalkorDB'):
self.driver = FalkorDriver()
self.driver.client = self.mock_client
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_init_with_connection_params(self):
"""Test initialization with connection parameters."""
with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as mock_falkor_db:
driver = FalkorDriver(
host='test-host', port='1234', username='test-user', password='test-pass'
)
assert driver.provider == GraphProvider.FALKORDB
mock_falkor_db.assert_called_once_with(
host='test-host', port='1234', username='test-user', password='test-pass'
)
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_init_with_falkor_db_instance(self):
"""Test initialization with a FalkorDB instance."""
with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as mock_falkor_db_class:
mock_falkor_db = MagicMock()
driver = FalkorDriver(falkor_db=mock_falkor_db)
assert driver.provider == GraphProvider.FALKORDB
assert driver.client is mock_falkor_db
mock_falkor_db_class.assert_not_called()
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_provider(self):
"""Test driver provider identification."""
assert self.driver.provider == GraphProvider.FALKORDB
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_get_graph_with_name(self):
"""Test _get_graph with specific graph name."""
mock_graph = MagicMock()
self.mock_client.select_graph.return_value = mock_graph
result = self.driver._get_graph('test_graph')
self.mock_client.select_graph.assert_called_once_with('test_graph')
assert result is mock_graph
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_get_graph_with_none_defaults_to_default_database(self):
"""Test _get_graph with None defaults to default_db."""
mock_graph = MagicMock()
self.mock_client.select_graph.return_value = mock_graph
result = self.driver._get_graph(None)
self.mock_client.select_graph.assert_called_once_with('default_db')
assert result is mock_graph
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_execute_query_success(self):
"""Test successful query execution."""
mock_graph = MagicMock()
mock_result = MagicMock()
mock_result.header = [('col1', 'column1'), ('col2', 'column2')]
mock_result.result_set = [['row1col1', 'row1col2']]
mock_graph.query = AsyncMock(return_value=mock_result)
self.mock_client.select_graph.return_value = mock_graph
result = await self.driver.execute_query('MATCH (n) RETURN n', param1='value1')
mock_graph.query.assert_called_once_with('MATCH (n) RETURN n', {'param1': 'value1'})
result_set, header, summary = result
assert result_set == [{'column1': 'row1col1', 'column2': 'row1col2'}]
assert header == ['column1', 'column2']
assert summary is None
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_execute_query_handles_index_already_exists_error(self):
"""Test handling of 'already indexed' error."""
mock_graph = MagicMock()
mock_graph.query = AsyncMock(side_effect=Exception('Index already indexed'))
self.mock_client.select_graph.return_value = mock_graph
with patch('graphiti_core.driver.falkordb_driver.logger') as mock_logger:
result = await self.driver.execute_query('CREATE INDEX ...')
mock_logger.info.assert_called_once()
assert result is None
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_execute_query_propagates_other_exceptions(self):
"""Test that other exceptions are properly propagated."""
mock_graph = MagicMock()
mock_graph.query = AsyncMock(side_effect=Exception('Other error'))
self.mock_client.select_graph.return_value = mock_graph
with patch('graphiti_core.driver.falkordb_driver.logger') as mock_logger:
with pytest.raises(Exception, match='Other error'):
await self.driver.execute_query('INVALID QUERY')
mock_logger.error.assert_called_once()
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_execute_query_converts_datetime_parameters(self):
"""Test that datetime objects in kwargs are converted to ISO strings."""
mock_graph = MagicMock()
mock_result = MagicMock()
mock_result.header = []
mock_result.result_set = []
mock_graph.query = AsyncMock(return_value=mock_result)
self.mock_client.select_graph.return_value = mock_graph
test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
await self.driver.execute_query(
'CREATE (n:Node) SET n.created_at = $created_at', created_at=test_datetime
)
call_args = mock_graph.query.call_args[0]
assert call_args[1]['created_at'] == test_datetime.isoformat()
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_session_creation(self):
"""Test session creation with specific database."""
mock_graph = MagicMock()
self.mock_client.select_graph.return_value = mock_graph
session = self.driver.session()
assert isinstance(session, FalkorDriverSession)
assert session.graph is mock_graph
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_session_creation_with_none_uses_default_database(self):
"""Test session creation with None uses default database."""
mock_graph = MagicMock()
self.mock_client.select_graph.return_value = mock_graph
session = self.driver.session()
assert isinstance(session, FalkorDriverSession)
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_close_calls_connection_close(self):
"""Test driver close method calls connection close."""
mock_connection = MagicMock()
mock_connection.close = AsyncMock()
self.mock_client.connection = mock_connection
# Ensure hasattr checks work correctly
del self.mock_client.aclose # Remove aclose if it exists
with patch('builtins.hasattr') as mock_hasattr:
# hasattr(self.client, 'aclose') returns False
# hasattr(self.client.connection, 'aclose') returns False
# hasattr(self.client.connection, 'close') returns True
mock_hasattr.side_effect = lambda obj, attr: (
attr == 'close' and obj is mock_connection
)
await self.driver.close()
mock_connection.close.assert_called_once()
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_delete_all_indexes(self):
"""Test delete_all_indexes method."""
with patch.object(self.driver, 'execute_query', new_callable=AsyncMock) as mock_execute:
# Return None to simulate no indexes found
mock_execute.return_value = None
await self.driver.delete_all_indexes()
mock_execute.assert_called_once_with('CALL db.indexes()')
class TestFalkorDriverSession:
"""Test FalkorDB driver session functionality."""
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def setup_method(self):
"""Set up test fixtures."""
self.mock_graph = MagicMock()
self.session = FalkorDriverSession(self.mock_graph)
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_session_async_context_manager(self):
"""Test session can be used as async context manager."""
async with self.session as s:
assert s is self.session
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_close_method(self):
"""Test session close method doesn't raise exceptions."""
await self.session.close() # Should not raise
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_execute_write_passes_session_and_args(self):
"""Test execute_write method passes session and arguments correctly."""
async def test_func(session, *args, **kwargs):
assert session is self.session
assert args == ('arg1', 'arg2')
assert kwargs == {'key': 'value'}
return 'result'
result = await self.session.execute_write(test_func, 'arg1', 'arg2', key='value')
assert result == 'result'
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_run_single_query_with_parameters(self):
"""Test running a single query with parameters."""
self.mock_graph.query = AsyncMock()
await self.session.run('MATCH (n) RETURN n', param1='value1', param2='value2')
self.mock_graph.query.assert_called_once_with(
'MATCH (n) RETURN n', {'param1': 'value1', 'param2': 'value2'}
)
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_run_multiple_queries_as_list(self):
"""Test running multiple queries passed as list."""
self.mock_graph.query = AsyncMock()
queries = [
('MATCH (n) RETURN n', {'param1': 'value1'}),
('CREATE (n:Node)', {'param2': 'value2'}),
]
await self.session.run(queries)
assert self.mock_graph.query.call_count == 2
calls = self.mock_graph.query.call_args_list
assert calls[0][0] == ('MATCH (n) RETURN n', {'param1': 'value1'})
assert calls[1][0] == ('CREATE (n:Node)', {'param2': 'value2'})
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_run_converts_datetime_objects_to_iso_strings(self):
"""Test that datetime objects are converted to ISO strings."""
self.mock_graph.query = AsyncMock()
test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
await self.session.run(
'CREATE (n:Node) SET n.created_at = $created_at', created_at=test_datetime
)
self.mock_graph.query.assert_called_once()
call_args = self.mock_graph.query.call_args[0]
assert call_args[1]['created_at'] == test_datetime.isoformat()
class TestDatetimeConversion:
"""Test datetime conversion utility function."""
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_convert_datetime_dict(self):
"""Test datetime conversion in nested dictionary."""
from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings
test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
input_dict = {
'string_val': 'test',
'datetime_val': test_datetime,
'nested_dict': {'nested_datetime': test_datetime, 'nested_string': 'nested_test'},
}
result = convert_datetimes_to_strings(input_dict)
assert result['string_val'] == 'test'
assert result['datetime_val'] == test_datetime.isoformat()
assert result['nested_dict']['nested_datetime'] == test_datetime.isoformat()
assert result['nested_dict']['nested_string'] == 'nested_test'
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_convert_datetime_list_and_tuple(self):
"""Test datetime conversion in lists and tuples."""
from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings
test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
# Test list
input_list = ['test', test_datetime, ['nested', test_datetime]]
result_list = convert_datetimes_to_strings(input_list)
assert result_list[0] == 'test'
assert result_list[1] == test_datetime.isoformat()
assert result_list[2][1] == test_datetime.isoformat()
# Test tuple
input_tuple = ('test', test_datetime)
result_tuple = convert_datetimes_to_strings(input_tuple)
assert isinstance(result_tuple, tuple)
assert result_tuple[0] == 'test'
assert result_tuple[1] == test_datetime.isoformat()
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_convert_single_datetime(self):
"""Test datetime conversion for single datetime object."""
from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings
test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
result = convert_datetimes_to_strings(test_datetime)
assert result == test_datetime.isoformat()
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_convert_other_types_unchanged(self):
"""Test that non-datetime types are returned unchanged."""
from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings
assert convert_datetimes_to_strings('string') == 'string'
assert convert_datetimes_to_strings(123) == 123
assert convert_datetimes_to_strings(None) is None
assert convert_datetimes_to_strings(True) is True
# Simple integration test
class TestFalkorDriverIntegration:
"""Simple integration test for FalkorDB driver."""
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_basic_integration_with_real_falkordb(self):
"""Basic integration test with real FalkorDB instance."""
pytest.importorskip('falkordb')
falkor_host = os.getenv('FALKORDB_HOST', 'localhost')
falkor_port = os.getenv('FALKORDB_PORT', '6379')
try:
driver = FalkorDriver(host=falkor_host, port=falkor_port)
# Test basic query execution
result = await driver.execute_query('RETURN 1 as test')
assert result is not None
result_set, header, summary = result
assert header == ['test']
assert result_set == [{'test': 1}]
await driver.close()
except Exception as e:
pytest.skip(f'FalkorDB not available for integration test: {e}')
```
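The `TestDatetimeConversion` cases above pin down the expected behavior of `convert_datetimes_to_strings`. For reference only, here is a minimal sketch that would satisfy those assertions; the actual implementation lives in `graphiti_core/driver/falkordb_driver.py` and may differ:
```python
# Minimal sketch (assumption, not the shipped implementation) of a recursive converter
# consistent with the assertions in TestDatetimeConversion above.
from datetime import datetime
from typing import Any

def convert_datetimes_to_strings(obj: Any) -> Any:
    # Dicts: convert values recursively, keep keys as-is.
    if isinstance(obj, dict):
        return {key: convert_datetimes_to_strings(value) for key, value in obj.items()}
    # Lists and tuples: convert each element, preserving the container type.
    if isinstance(obj, list):
        return [convert_datetimes_to_strings(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(convert_datetimes_to_strings(item) for item in obj)
    # Datetimes become ISO-8601 strings; everything else passes through unchanged.
    if isinstance(obj, datetime):
        return obj.isoformat()
    return obj
```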
--------------------------------------------------------------------------------
/mcp_server/src/services/factories.py:
--------------------------------------------------------------------------------
```python
"""Factory classes for creating LLM, Embedder, and Database clients."""
from openai import AsyncAzureOpenAI
from config.schema import (
DatabaseConfig,
EmbedderConfig,
LLMConfig,
)
# Try to import FalkorDriver if available
try:
from graphiti_core.driver.falkordb_driver import FalkorDriver # noqa: F401
HAS_FALKOR = True
except ImportError:
HAS_FALKOR = False
# Kuzu support removed - FalkorDB is now the default
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
from graphiti_core.llm_client import LLMClient, OpenAIClient
from graphiti_core.llm_client.config import LLMConfig as GraphitiLLMConfig
# Try to import additional providers if available
try:
from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient
HAS_AZURE_EMBEDDER = True
except ImportError:
HAS_AZURE_EMBEDDER = False
try:
from graphiti_core.embedder.gemini import GeminiEmbedder
HAS_GEMINI_EMBEDDER = True
except ImportError:
HAS_GEMINI_EMBEDDER = False
try:
from graphiti_core.embedder.voyage import VoyageAIEmbedder
HAS_VOYAGE_EMBEDDER = True
except ImportError:
HAS_VOYAGE_EMBEDDER = False
try:
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
HAS_AZURE_LLM = True
except ImportError:
HAS_AZURE_LLM = False
try:
from graphiti_core.llm_client.anthropic_client import AnthropicClient
HAS_ANTHROPIC = True
except ImportError:
HAS_ANTHROPIC = False
try:
from graphiti_core.llm_client.gemini_client import GeminiClient
HAS_GEMINI = True
except ImportError:
HAS_GEMINI = False
try:
from graphiti_core.llm_client.groq_client import GroqClient
HAS_GROQ = True
except ImportError:
HAS_GROQ = False
from utils.utils import create_azure_credential_token_provider
def _validate_api_key(provider_name: str, api_key: str | None, logger) -> str:
"""Validate API key is present.
Args:
provider_name: Name of the provider (e.g., 'OpenAI', 'Anthropic')
api_key: The API key to validate
logger: Logger instance for output
Returns:
The validated API key
Raises:
ValueError: If API key is None or empty
"""
if not api_key:
raise ValueError(
f'{provider_name} API key is not configured. Please set the appropriate environment variable.'
)
logger.info(f'Creating {provider_name} client')
return api_key
class LLMClientFactory:
"""Factory for creating LLM clients based on configuration."""
@staticmethod
def create(config: LLMConfig) -> LLMClient:
"""Create an LLM client based on the configured provider."""
import logging
logger = logging.getLogger(__name__)
provider = config.provider.lower()
match provider:
case 'openai':
if not config.providers.openai:
raise ValueError('OpenAI provider configuration not found')
api_key = config.providers.openai.api_key
_validate_api_key('OpenAI', api_key, logger)
from graphiti_core.llm_client.config import LLMConfig as CoreLLMConfig
# Determine appropriate small model based on main model type
is_reasoning_model = (
config.model.startswith('gpt-5')
or config.model.startswith('o1')
or config.model.startswith('o3')
)
small_model = (
'gpt-5-nano' if is_reasoning_model else 'gpt-4.1-mini'
) # Use reasoning model for small tasks if main model is reasoning
llm_config = CoreLLMConfig(
api_key=api_key,
model=config.model,
small_model=small_model,
temperature=config.temperature,
max_tokens=config.max_tokens,
)
# Only pass reasoning/verbosity parameters for reasoning models (gpt-5 family)
if is_reasoning_model:
return OpenAIClient(config=llm_config, reasoning='minimal', verbosity='low')
else:
# For non-reasoning models, explicitly pass None to disable these parameters
return OpenAIClient(config=llm_config, reasoning=None, verbosity=None)
case 'azure_openai':
if not HAS_AZURE_LLM:
raise ValueError(
'Azure OpenAI LLM client not available in current graphiti-core version'
)
if not config.providers.azure_openai:
raise ValueError('Azure OpenAI provider configuration not found')
azure_config = config.providers.azure_openai
if not azure_config.api_url:
raise ValueError('Azure OpenAI API URL is required')
# Handle Azure AD authentication if enabled
api_key: str | None = None
azure_ad_token_provider = None
if azure_config.use_azure_ad:
logger.info('Creating Azure OpenAI LLM client with Azure AD authentication')
azure_ad_token_provider = create_azure_credential_token_provider()
else:
api_key = azure_config.api_key
_validate_api_key('Azure OpenAI', api_key, logger)
# Create the Azure OpenAI client first
azure_client = AsyncAzureOpenAI(
api_key=api_key,
azure_endpoint=azure_config.api_url,
api_version=azure_config.api_version,
azure_deployment=azure_config.deployment_name,
azure_ad_token_provider=azure_ad_token_provider,
)
# Then create the LLMConfig
from graphiti_core.llm_client.config import LLMConfig as CoreLLMConfig
llm_config = CoreLLMConfig(
api_key=api_key,
base_url=azure_config.api_url,
model=config.model,
temperature=config.temperature,
max_tokens=config.max_tokens,
)
return AzureOpenAILLMClient(
azure_client=azure_client,
config=llm_config,
max_tokens=config.max_tokens,
)
case 'anthropic':
if not HAS_ANTHROPIC:
raise ValueError(
'Anthropic client not available in current graphiti-core version'
)
if not config.providers.anthropic:
raise ValueError('Anthropic provider configuration not found')
api_key = config.providers.anthropic.api_key
_validate_api_key('Anthropic', api_key, logger)
llm_config = GraphitiLLMConfig(
api_key=api_key,
model=config.model,
temperature=config.temperature,
max_tokens=config.max_tokens,
)
return AnthropicClient(config=llm_config)
case 'gemini':
if not HAS_GEMINI:
raise ValueError('Gemini client not available in current graphiti-core version')
if not config.providers.gemini:
raise ValueError('Gemini provider configuration not found')
api_key = config.providers.gemini.api_key
_validate_api_key('Gemini', api_key, logger)
llm_config = GraphitiLLMConfig(
api_key=api_key,
model=config.model,
temperature=config.temperature,
max_tokens=config.max_tokens,
)
return GeminiClient(config=llm_config)
case 'groq':
if not HAS_GROQ:
raise ValueError('Groq client not available in current graphiti-core version')
if not config.providers.groq:
raise ValueError('Groq provider configuration not found')
api_key = config.providers.groq.api_key
_validate_api_key('Groq', api_key, logger)
llm_config = GraphitiLLMConfig(
api_key=api_key,
base_url=config.providers.groq.api_url,
model=config.model,
temperature=config.temperature,
max_tokens=config.max_tokens,
)
return GroqClient(config=llm_config)
case _:
raise ValueError(f'Unsupported LLM provider: {provider}')
class EmbedderFactory:
"""Factory for creating Embedder clients based on configuration."""
@staticmethod
def create(config: EmbedderConfig) -> EmbedderClient:
"""Create an Embedder client based on the configured provider."""
import logging
logger = logging.getLogger(__name__)
provider = config.provider.lower()
match provider:
case 'openai':
if not config.providers.openai:
raise ValueError('OpenAI provider configuration not found')
api_key = config.providers.openai.api_key
_validate_api_key('OpenAI Embedder', api_key, logger)
from graphiti_core.embedder.openai import OpenAIEmbedderConfig
embedder_config = OpenAIEmbedderConfig(
api_key=api_key,
embedding_model=config.model,
)
return OpenAIEmbedder(config=embedder_config)
case 'azure_openai':
if not HAS_AZURE_EMBEDDER:
raise ValueError(
'Azure OpenAI embedder not available in current graphiti-core version'
)
if not config.providers.azure_openai:
raise ValueError('Azure OpenAI provider configuration not found')
azure_config = config.providers.azure_openai
if not azure_config.api_url:
raise ValueError('Azure OpenAI API URL is required')
# Handle Azure AD authentication if enabled
api_key: str | None = None
azure_ad_token_provider = None
if azure_config.use_azure_ad:
logger.info(
'Creating Azure OpenAI Embedder client with Azure AD authentication'
)
azure_ad_token_provider = create_azure_credential_token_provider()
else:
api_key = azure_config.api_key
_validate_api_key('Azure OpenAI Embedder', api_key, logger)
# Create the Azure OpenAI client first
azure_client = AsyncAzureOpenAI(
api_key=api_key,
azure_endpoint=azure_config.api_url,
api_version=azure_config.api_version,
azure_deployment=azure_config.deployment_name,
azure_ad_token_provider=azure_ad_token_provider,
)
return AzureOpenAIEmbedderClient(
azure_client=azure_client,
model=config.model or 'text-embedding-3-small',
)
case 'gemini':
if not HAS_GEMINI_EMBEDDER:
raise ValueError(
'Gemini embedder not available in current graphiti-core version'
)
if not config.providers.gemini:
raise ValueError('Gemini provider configuration not found')
api_key = config.providers.gemini.api_key
_validate_api_key('Gemini Embedder', api_key, logger)
from graphiti_core.embedder.gemini import GeminiEmbedderConfig
gemini_config = GeminiEmbedderConfig(
api_key=api_key,
embedding_model=config.model or 'models/text-embedding-004',
embedding_dim=config.dimensions or 768,
)
return GeminiEmbedder(config=gemini_config)
case 'voyage':
if not HAS_VOYAGE_EMBEDDER:
raise ValueError(
'Voyage embedder not available in current graphiti-core version'
)
if not config.providers.voyage:
raise ValueError('Voyage provider configuration not found')
api_key = config.providers.voyage.api_key
_validate_api_key('Voyage Embedder', api_key, logger)
from graphiti_core.embedder.voyage import VoyageAIEmbedderConfig
voyage_config = VoyageAIEmbedderConfig(
api_key=api_key,
embedding_model=config.model or 'voyage-3',
embedding_dim=config.dimensions or 1024,
)
return VoyageAIEmbedder(config=voyage_config)
case _:
raise ValueError(f'Unsupported Embedder provider: {provider}')
class DatabaseDriverFactory:
"""Factory for creating Database drivers based on configuration.
Note: This returns configuration dictionaries that can be passed to Graphiti(),
not driver instances directly, as the drivers require complex initialization.
"""
@staticmethod
def create_config(config: DatabaseConfig) -> dict:
"""Create database configuration dictionary based on the configured provider."""
provider = config.provider.lower()
match provider:
case 'neo4j':
# Use Neo4j config if provided, otherwise use defaults
if config.providers.neo4j:
neo4j_config = config.providers.neo4j
else:
# Create default Neo4j configuration
from config.schema import Neo4jProviderConfig
neo4j_config = Neo4jProviderConfig()
# Check for environment variable overrides (for CI/CD compatibility)
import os
uri = os.environ.get('NEO4J_URI', neo4j_config.uri)
username = os.environ.get('NEO4J_USER', neo4j_config.username)
password = os.environ.get('NEO4J_PASSWORD', neo4j_config.password)
return {
'uri': uri,
'user': username,
'password': password,
# Note: database and use_parallel_runtime would need to be passed
# to the driver after initialization if supported
}
case 'falkordb':
if not HAS_FALKOR:
raise ValueError(
'FalkorDB driver not available in current graphiti-core version'
)
# Use FalkorDB config if provided, otherwise use defaults
if config.providers.falkordb:
falkor_config = config.providers.falkordb
else:
# Create default FalkorDB configuration
from config.schema import FalkorDBProviderConfig
falkor_config = FalkorDBProviderConfig()
# Check for environment variable overrides (for CI/CD compatibility)
import os
from urllib.parse import urlparse
uri = os.environ.get('FALKORDB_URI', falkor_config.uri)
password = os.environ.get('FALKORDB_PASSWORD', falkor_config.password)
# Parse the URI to extract host and port
parsed = urlparse(uri)
host = parsed.hostname or 'localhost'
port = parsed.port or 6379
return {
'driver': 'falkordb',
'host': host,
'port': port,
'password': password,
'database': falkor_config.database,
}
case _:
raise ValueError(f'Unsupported Database provider: {provider}')
```
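The factories above are meant to be driven from already-loaded `config.schema` objects. A short wiring sketch under that assumption (the import path and the `build_clients` helper are illustrative, not repo code):
```python
# Illustrative wiring sketch: combining the factories above. The import path is assumed
# from the file location; build_clients is a hypothetical helper.
from services.factories import DatabaseDriverFactory, EmbedderFactory, LLMClientFactory

def build_clients(llm_cfg, embedder_cfg, db_cfg):
    """llm_cfg, embedder_cfg, db_cfg are the config.schema objects the factories expect."""
    llm_client = LLMClientFactory.create(llm_cfg)        # raises ValueError on unknown provider or missing key
    embedder_client = EmbedderFactory.create(embedder_cfg)
    # Returns a dict of connection settings rather than a driver instance
    # (see the DatabaseDriverFactory docstring above).
    db_kwargs = DatabaseDriverFactory.create_config(db_cfg)
    return llm_client, embedder_client, db_kwargs
```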
--------------------------------------------------------------------------------
/graphiti_core/llm_client/anthropic_client.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
import os
import typing
from json import JSONDecodeError
from typing import TYPE_CHECKING, Literal
from pydantic import BaseModel, ValidationError
from ..prompts.models import Message
from .client import LLMClient
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
from .errors import RateLimitError, RefusalError
if TYPE_CHECKING:
import anthropic
from anthropic import AsyncAnthropic
from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
else:
try:
import anthropic
from anthropic import AsyncAnthropic
from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
except ImportError:
raise ImportError(
'anthropic is required for AnthropicClient. '
'Install it with: pip install graphiti-core[anthropic]'
) from None
logger = logging.getLogger(__name__)
AnthropicModel = Literal[
'claude-sonnet-4-5-latest',
'claude-sonnet-4-5-20250929',
'claude-haiku-4-5-latest',
'claude-3-7-sonnet-latest',
'claude-3-7-sonnet-20250219',
'claude-3-5-haiku-latest',
'claude-3-5-haiku-20241022',
'claude-3-5-sonnet-latest',
'claude-3-5-sonnet-20241022',
'claude-3-5-sonnet-20240620',
'claude-3-opus-latest',
'claude-3-opus-20240229',
'claude-3-sonnet-20240229',
'claude-3-haiku-20240307',
'claude-2.1',
'claude-2.0',
]
DEFAULT_MODEL: AnthropicModel = 'claude-haiku-4-5-latest'
# Maximum output tokens for different Anthropic models
# Based on official Anthropic documentation (as of 2025)
# Note: These represent standard limits without beta headers.
# Some models support higher limits with additional configuration (e.g., Claude 3.7 supports
# 128K with 'anthropic-beta: output-128k-2025-02-19' header, but this is not currently implemented).
ANTHROPIC_MODEL_MAX_TOKENS = {
# Claude 4.5 models - 64K tokens
'claude-sonnet-4-5-latest': 65536,
'claude-sonnet-4-5-20250929': 65536,
'claude-haiku-4-5-latest': 65536,
# Claude 3.7 models - standard 64K tokens
'claude-3-7-sonnet-latest': 65536,
'claude-3-7-sonnet-20250219': 65536,
    # Claude 3.5 models - 8K tokens
'claude-3-5-haiku-latest': 8192,
'claude-3-5-haiku-20241022': 8192,
'claude-3-5-sonnet-latest': 8192,
'claude-3-5-sonnet-20241022': 8192,
'claude-3-5-sonnet-20240620': 8192,
# Claude 3 models - 4K tokens
'claude-3-opus-latest': 4096,
'claude-3-opus-20240229': 4096,
'claude-3-sonnet-20240229': 4096,
'claude-3-haiku-20240307': 4096,
# Claude 2 models - 4K tokens
'claude-2.1': 4096,
'claude-2.0': 4096,
}
# Default max tokens for models not in the mapping
DEFAULT_ANTHROPIC_MAX_TOKENS = 8192
class AnthropicClient(LLMClient):
"""
A client for the Anthropic LLM.
Args:
config: A configuration object for the LLM.
cache: Whether to cache the LLM responses.
client: An optional client instance to use.
max_tokens: The maximum number of tokens to generate.
Methods:
generate_response: Generate a response from the LLM.
Notes:
- If a LLMConfig is not provided, api_key will be pulled from the ANTHROPIC_API_KEY environment
variable, and all default values will be used for the LLMConfig.
"""
model: AnthropicModel
def __init__(
self,
config: LLMConfig | None = None,
cache: bool = False,
client: AsyncAnthropic | None = None,
max_tokens: int = DEFAULT_MAX_TOKENS,
) -> None:
if config is None:
config = LLMConfig()
config.api_key = os.getenv('ANTHROPIC_API_KEY')
config.max_tokens = max_tokens
if config.model is None:
config.model = DEFAULT_MODEL
super().__init__(config, cache)
# Explicitly set the instance model to the config model to prevent type checking errors
self.model = typing.cast(AnthropicModel, config.model)
if not client:
self.client = AsyncAnthropic(
api_key=config.api_key,
max_retries=1,
)
else:
self.client = client
def _extract_json_from_text(self, text: str) -> dict[str, typing.Any]:
"""Extract JSON from text content.
A helper method to extract JSON from text content, used when tool use fails or
no response_model is provided.
Args:
text: The text to extract JSON from
Returns:
Extracted JSON as a dictionary
Raises:
ValueError: If JSON cannot be extracted or parsed
"""
try:
json_start = text.find('{')
json_end = text.rfind('}') + 1
if json_start >= 0 and json_end > json_start:
json_str = text[json_start:json_end]
return json.loads(json_str)
else:
raise ValueError(f'Could not extract JSON from model response: {text}')
except (JSONDecodeError, ValueError) as e:
raise ValueError(f'Could not extract JSON from model response: {text}') from e
def _create_tool(
self, response_model: type[BaseModel] | None = None
) -> tuple[list[ToolUnionParam], ToolChoiceParam]:
"""
Create a tool definition based on the response_model if provided, or a generic JSON tool if not.
Args:
response_model: Optional Pydantic model to use for structured output.
Returns:
            A tuple of (tools, tool_choice): a single-element tool list and the matching tool choice for the Anthropic API.
"""
if response_model is not None:
# Use the response_model to define the tool
model_schema = response_model.model_json_schema()
tool_name = response_model.__name__
description = model_schema.get('description', f'Extract {tool_name} information')
else:
# Create a generic JSON output tool
tool_name = 'generic_json_output'
description = 'Output data in JSON format'
model_schema = {
'type': 'object',
'additionalProperties': True,
'description': 'Any JSON object containing the requested information',
}
tool = {
'name': tool_name,
'description': description,
'input_schema': model_schema,
}
tool_list = [tool]
tool_list_cast = typing.cast(list[ToolUnionParam], tool_list)
tool_choice = {'type': 'tool', 'name': tool_name}
tool_choice_cast = typing.cast(ToolChoiceParam, tool_choice)
return tool_list_cast, tool_choice_cast
def _get_max_tokens_for_model(self, model: str) -> int:
"""Get the maximum output tokens for a specific Anthropic model.
Args:
model: The model name to look up
Returns:
int: The maximum output tokens for the model
"""
return ANTHROPIC_MODEL_MAX_TOKENS.get(model, DEFAULT_ANTHROPIC_MAX_TOKENS)
def _resolve_max_tokens(self, requested_max_tokens: int | None, model: str) -> int:
"""
Resolve the maximum output tokens to use based on precedence rules.
Precedence order (highest to lowest):
1. Explicit max_tokens parameter passed to generate_response()
2. Instance max_tokens set during client initialization
3. Model-specific maximum tokens from ANTHROPIC_MODEL_MAX_TOKENS mapping
4. DEFAULT_ANTHROPIC_MAX_TOKENS as final fallback
Args:
requested_max_tokens: The max_tokens parameter passed to generate_response()
model: The model name to look up model-specific limits
Returns:
int: The resolved maximum tokens to use
"""
# 1. Use explicit parameter if provided
if requested_max_tokens is not None:
return requested_max_tokens
# 2. Use instance max_tokens if set during initialization
if self.max_tokens is not None:
return self.max_tokens
# 3. Use model-specific maximum or return DEFAULT_ANTHROPIC_MAX_TOKENS
return self._get_max_tokens_for_model(model)
async def _generate_response(
self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int | None = None,
model_size: ModelSize = ModelSize.medium,
) -> dict[str, typing.Any]:
"""
Generate a response from the Anthropic LLM using tool-based approach for all requests.
Args:
messages: List of message objects to send to the LLM.
response_model: Optional Pydantic model to use for structured output.
max_tokens: Maximum number of tokens to generate.
Returns:
Dictionary containing the structured response from the LLM.
Raises:
RateLimitError: If the rate limit is exceeded.
RefusalError: If the LLM refuses to respond.
Exception: If an error occurs during the generation process.
"""
system_message = messages[0]
user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]]
user_messages_cast = typing.cast(list[MessageParam], user_messages)
# Resolve max_tokens dynamically based on the model's capabilities
# This allows different models to use their full output capacity
max_creation_tokens: int = self._resolve_max_tokens(max_tokens, self.model)
try:
# Create the appropriate tool based on whether response_model is provided
tools, tool_choice = self._create_tool(response_model)
result = await self.client.messages.create(
system=system_message.content,
max_tokens=max_creation_tokens,
temperature=self.temperature,
messages=user_messages_cast,
model=self.model,
tools=tools,
tool_choice=tool_choice,
)
# Extract the tool output from the response
for content_item in result.content:
if content_item.type == 'tool_use':
if isinstance(content_item.input, dict):
tool_args: dict[str, typing.Any] = content_item.input
else:
tool_args = json.loads(str(content_item.input))
return tool_args
# If we didn't get a proper tool_use response, try to extract from text
for content_item in result.content:
if content_item.type == 'text':
return self._extract_json_from_text(content_item.text)
else:
raise ValueError(
f'Could not extract structured data from model response: {result.content}'
)
# If we get here, we couldn't parse a structured response
raise ValueError(
f'Could not extract structured data from model response: {result.content}'
)
except anthropic.RateLimitError as e:
raise RateLimitError(f'Rate limit exceeded. Please try again later. Error: {e}') from e
except anthropic.APIError as e:
# Special case for content policy violations. We convert these to RefusalError
# to bypass the retry mechanism, as retrying policy-violating content will always fail.
# This avoids wasting API calls and provides more specific error messaging to the user.
if 'refused to respond' in str(e).lower():
raise RefusalError(str(e)) from e
raise e
except Exception as e:
raise e
async def generate_response(
self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int | None = None,
model_size: ModelSize = ModelSize.medium,
group_id: str | None = None,
prompt_name: str | None = None,
) -> dict[str, typing.Any]:
"""
Generate a response from the LLM.
Args:
messages: List of message objects to send to the LLM.
response_model: Optional Pydantic model to use for structured output.
            max_tokens: Maximum number of tokens to generate.
            model_size: Requested model size for this call.
            group_id: Optional group identifier associated with the request.
            prompt_name: Optional prompt name recorded on the tracing span.
Returns:
Dictionary containing the structured response from the LLM.
Raises:
RateLimitError: If the rate limit is exceeded.
RefusalError: If the LLM refuses to respond.
Exception: If an error occurs during the generation process.
"""
if max_tokens is None:
max_tokens = self.max_tokens
# Wrap entire operation in tracing span
with self.tracer.start_span('llm.generate') as span:
attributes = {
'llm.provider': 'anthropic',
'model.size': model_size.value,
'max_tokens': max_tokens,
}
if prompt_name:
attributes['prompt.name'] = prompt_name
span.add_attributes(attributes)
retry_count = 0
max_retries = 2
last_error: Exception | None = None
while retry_count <= max_retries:
try:
response = await self._generate_response(
messages, response_model, max_tokens, model_size
)
# If we have a response_model, attempt to validate the response
if response_model is not None:
# Validate the response against the response_model
model_instance = response_model(**response)
return model_instance.model_dump()
# If no validation needed, return the response
return response
except (RateLimitError, RefusalError):
# These errors should not trigger retries
span.set_status('error', str(last_error))
raise
except Exception as e:
last_error = e
if retry_count >= max_retries:
if isinstance(e, ValidationError):
logger.error(
f'Validation error after {retry_count}/{max_retries} attempts: {e}'
)
else:
logger.error(f'Max retries ({max_retries}) exceeded. Last error: {e}')
span.set_status('error', str(e))
span.record_exception(e)
raise e
if isinstance(e, ValidationError):
response_model_cast = typing.cast(type[BaseModel], response_model)
error_context = f'The previous response was invalid. Please provide a valid {response_model_cast.__name__} object. Error: {e}'
else:
error_context = (
f'The previous response attempt was invalid. '
f'Error type: {e.__class__.__name__}. '
f'Error details: {str(e)}. '
f'Please try again with a valid response.'
)
# Common retry logic
retry_count += 1
messages.append(Message(role='user', content=error_context))
logger.warning(
f'Retrying after error (attempt {retry_count}/{max_retries}): {e}'
)
# If we somehow get here, raise the last error
span.set_status('error', str(last_error))
raise last_error or Exception('Max retries exceeded with no specific error')
```
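A usage sketch for the client above, showing structured output through a Pydantic `response_model`; the example model, prompt text, and API key are illustrative only:
```python
# Usage sketch (assumption): structured extraction with AnthropicClient.
# Only names defined in the module above are used; values are placeholders.
import asyncio

from pydantic import BaseModel

from graphiti_core.llm_client.anthropic_client import AnthropicClient
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.prompts.models import Message

class Person(BaseModel):
    """Example response model; the client turns this schema into an Anthropic tool."""
    name: str
    occupation: str

async def main() -> None:
    # With no config, the client falls back to the ANTHROPIC_API_KEY environment variable.
    client = AnthropicClient(config=LLMConfig(api_key='sk-ant-...'))
    result = await client.generate_response(
        messages=[
            Message(role='system', content='Extract the person described in the text.'),
            Message(role='user', content='Ada Lovelace was a mathematician.'),
        ],
        response_model=Person,
    )
    print(result)  # validated dict, e.g. {'name': 'Ada Lovelace', 'occupation': 'mathematician'}

if __name__ == '__main__':
    asyncio.run(main())
```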
--------------------------------------------------------------------------------
/mcp_server/tests/test_async_operations.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Asynchronous operation tests for Graphiti MCP Server.
Tests concurrent operations, queue management, and async patterns.
"""
import asyncio
import contextlib
import json
import time
import pytest
from test_fixtures import (
TestDataGenerator,
graphiti_test_client,
)
class TestAsyncQueueManagement:
"""Test asynchronous queue operations and episode processing."""
@pytest.mark.asyncio
async def test_sequential_queue_processing(self):
"""Verify episodes are processed sequentially within a group."""
async with graphiti_test_client() as (session, group_id):
# Add multiple episodes quickly
episodes = []
for i in range(5):
result = await session.call_tool(
'add_memory',
{
'name': f'Sequential Test {i}',
'episode_body': f'Episode {i} with timestamp {time.time()}',
'source': 'text',
'source_description': 'sequential test',
'group_id': group_id,
'reference_id': f'seq_{i}', # Add reference for tracking
},
)
episodes.append(result)
# Wait for processing
await asyncio.sleep(10) # Allow time for sequential processing
# Retrieve episodes and verify order
result = await session.call_tool('get_episodes', {'group_id': group_id, 'last_n': 10})
processed_episodes = json.loads(result.content[0].text)['episodes']
# Verify all episodes were processed
assert len(processed_episodes) >= 5, (
f'Expected at least 5 episodes, got {len(processed_episodes)}'
)
# Verify sequential processing (timestamps should be ordered)
timestamps = [ep.get('created_at') for ep in processed_episodes]
assert timestamps == sorted(timestamps), 'Episodes not processed in order'
@pytest.mark.asyncio
async def test_concurrent_group_processing(self):
"""Test that different groups can process concurrently."""
async with graphiti_test_client() as (session, _):
groups = [f'group_{i}_{time.time()}' for i in range(3)]
tasks = []
# Create tasks for different groups
for group_id in groups:
for j in range(2):
task = session.call_tool(
'add_memory',
{
'name': f'Group {group_id} Episode {j}',
'episode_body': f'Content for {group_id}',
'source': 'text',
'source_description': 'concurrent test',
'group_id': group_id,
},
)
tasks.append(task)
# Execute all tasks concurrently
start_time = time.time()
results = await asyncio.gather(*tasks, return_exceptions=True)
execution_time = time.time() - start_time
# Verify all succeeded
failures = [r for r in results if isinstance(r, Exception)]
assert not failures, f'Concurrent operations failed: {failures}'
# Check that execution was actually concurrent (should be faster than sequential)
# Sequential would take at least 6 * processing_time
assert execution_time < 30, f'Concurrent execution too slow: {execution_time}s'
@pytest.mark.asyncio
async def test_queue_overflow_handling(self):
"""Test behavior when queue reaches capacity."""
async with graphiti_test_client() as (session, group_id):
# Attempt to add many episodes rapidly
tasks = []
for i in range(100): # Large number to potentially overflow
task = session.call_tool(
'add_memory',
{
'name': f'Overflow Test {i}',
'episode_body': f'Episode {i}',
'source': 'text',
'source_description': 'overflow test',
'group_id': group_id,
},
)
tasks.append(task)
# Execute with gathering to catch any failures
results = await asyncio.gather(*tasks, return_exceptions=True)
# Count successful queuing
successful = sum(1 for r in results if not isinstance(r, Exception))
# Should handle overflow gracefully
assert successful > 0, 'No episodes were queued successfully'
# Log overflow behavior
if successful < 100:
print(f'Queue overflow: {successful}/100 episodes queued')
class TestConcurrentOperations:
"""Test concurrent tool calls and operations."""
@pytest.mark.asyncio
async def test_concurrent_search_operations(self):
"""Test multiple concurrent search operations."""
async with graphiti_test_client() as (session, group_id):
# First, add some test data
data_gen = TestDataGenerator()
add_tasks = []
for _ in range(5):
task = session.call_tool(
'add_memory',
{
'name': 'Search Test Data',
'episode_body': data_gen.generate_technical_document(),
'source': 'text',
'source_description': 'search test',
'group_id': group_id,
},
)
add_tasks.append(task)
await asyncio.gather(*add_tasks)
await asyncio.sleep(15) # Wait for processing
# Now perform concurrent searches
search_queries = [
'architecture',
'performance',
'implementation',
'dependencies',
'latency',
]
search_tasks = []
for query in search_queries:
task = session.call_tool(
'search_memory_nodes',
{
'query': query,
'group_id': group_id,
'limit': 10,
},
)
search_tasks.append(task)
start_time = time.time()
results = await asyncio.gather(*search_tasks, return_exceptions=True)
search_time = time.time() - start_time
# Verify all searches completed
failures = [r for r in results if isinstance(r, Exception)]
assert not failures, f'Search operations failed: {failures}'
# Verify concurrent execution efficiency
assert search_time < len(search_queries) * 2, 'Searches not executing concurrently'
@pytest.mark.asyncio
async def test_mixed_operation_concurrency(self):
"""Test different types of operations running concurrently."""
async with graphiti_test_client() as (session, group_id):
operations = []
# Add memory operation
operations.append(
session.call_tool(
'add_memory',
{
'name': 'Mixed Op Test',
'episode_body': 'Testing mixed operations',
'source': 'text',
'source_description': 'test',
'group_id': group_id,
},
)
)
# Search operation
operations.append(
session.call_tool(
'search_memory_nodes',
{
'query': 'test',
'group_id': group_id,
'limit': 5,
},
)
)
# Get episodes operation
operations.append(
session.call_tool(
'get_episodes',
{
'group_id': group_id,
'last_n': 10,
},
)
)
# Get status operation
operations.append(session.call_tool('get_status', {}))
# Execute all concurrently
results = await asyncio.gather(*operations, return_exceptions=True)
# Check results
for i, result in enumerate(results):
assert not isinstance(result, Exception), f'Operation {i} failed: {result}'
class TestAsyncErrorHandling:
"""Test async error handling and recovery."""
@pytest.mark.asyncio
async def test_timeout_recovery(self):
"""Test recovery from operation timeouts."""
async with graphiti_test_client() as (session, group_id):
# Create a very large episode that might time out
large_content = 'x' * 1000000 # 1MB of data
with contextlib.suppress(asyncio.TimeoutError):
await asyncio.wait_for(
session.call_tool(
'add_memory',
{
'name': 'Timeout Test',
'episode_body': large_content,
'source': 'text',
'source_description': 'timeout test',
'group_id': group_id,
},
),
timeout=2.0, # Short timeout - expected to timeout
)
# Verify server is still responsive after timeout
status_result = await session.call_tool('get_status', {})
assert status_result is not None, 'Server unresponsive after timeout'
@pytest.mark.asyncio
async def test_cancellation_handling(self):
"""Test proper handling of cancelled operations."""
async with graphiti_test_client() as (session, group_id):
# Start a long-running operation
task = asyncio.create_task(
session.call_tool(
'add_memory',
{
'name': 'Cancellation Test',
'episode_body': TestDataGenerator.generate_technical_document(),
'source': 'text',
'source_description': 'cancel test',
'group_id': group_id,
},
)
)
# Cancel after a short delay
await asyncio.sleep(0.1)
task.cancel()
# Verify cancellation was handled
with pytest.raises(asyncio.CancelledError):
await task
# Server should still be operational
result = await session.call_tool('get_status', {})
assert result is not None
@pytest.mark.asyncio
async def test_exception_propagation(self):
"""Test that exceptions are properly propagated in async context."""
async with graphiti_test_client() as (session, group_id):
# Call with invalid arguments
with pytest.raises(ValueError):
await session.call_tool(
'add_memory',
{
# Missing required fields
'group_id': group_id,
},
)
# Server should remain operational
status = await session.call_tool('get_status', {})
assert status is not None
class TestAsyncPerformance:
"""Performance tests for async operations."""
@pytest.mark.asyncio
async def test_async_throughput(self, performance_benchmark):
"""Measure throughput of async operations."""
async with graphiti_test_client() as (session, group_id):
num_operations = 50
start_time = time.time()
# Create many concurrent operations
tasks = []
for i in range(num_operations):
task = session.call_tool(
'add_memory',
{
'name': f'Throughput Test {i}',
'episode_body': f'Content {i}',
'source': 'text',
'source_description': 'throughput test',
'group_id': group_id,
},
)
tasks.append(task)
# Execute all
results = await asyncio.gather(*tasks, return_exceptions=True)
total_time = time.time() - start_time
# Calculate metrics
successful = sum(1 for r in results if not isinstance(r, Exception))
throughput = successful / total_time
performance_benchmark.record('async_throughput', throughput)
# Log results
print('\nAsync Throughput Test:')
print(f' Operations: {num_operations}')
print(f' Successful: {successful}')
print(f' Total time: {total_time:.2f}s')
print(f' Throughput: {throughput:.2f} ops/s')
# Assert minimum throughput
assert throughput > 1.0, f'Throughput too low: {throughput:.2f} ops/s'
@pytest.mark.asyncio
async def test_latency_under_load(self, performance_benchmark):
"""Test operation latency under concurrent load."""
async with graphiti_test_client() as (session, group_id):
# Create background load
background_tasks = []
for i in range(10):
task = asyncio.create_task(
session.call_tool(
'add_memory',
{
'name': f'Background {i}',
'episode_body': TestDataGenerator.generate_technical_document(),
'source': 'text',
'source_description': 'background',
'group_id': f'background_{group_id}',
},
)
)
background_tasks.append(task)
# Measure latency of operations under load
latencies = []
for _ in range(5):
start = time.time()
await session.call_tool('get_status', {})
latency = time.time() - start
latencies.append(latency)
performance_benchmark.record('latency_under_load', latency)
# Clean up background tasks
for task in background_tasks:
task.cancel()
# Analyze latencies
avg_latency = sum(latencies) / len(latencies)
max_latency = max(latencies)
print('\nLatency Under Load:')
print(f' Average: {avg_latency:.3f}s')
print(f' Max: {max_latency:.3f}s')
# Assert acceptable latency
assert avg_latency < 2.0, f'Average latency too high: {avg_latency:.3f}s'
assert max_latency < 5.0, f'Max latency too high: {max_latency:.3f}s'
class TestAsyncStreamHandling:
"""Test handling of streaming responses and data."""
@pytest.mark.asyncio
async def test_large_response_streaming(self):
"""Test handling of large streamed responses."""
async with graphiti_test_client() as (session, group_id):
# Add many episodes
for i in range(20):
await session.call_tool(
'add_memory',
{
'name': f'Stream Test {i}',
'episode_body': f'Episode content {i}',
'source': 'text',
'source_description': 'stream test',
'group_id': group_id,
},
)
# Wait for processing
await asyncio.sleep(30)
# Request large result set
result = await session.call_tool(
'get_episodes',
{
'group_id': group_id,
'last_n': 100, # Request all
},
)
# Verify response handling
episodes = json.loads(result.content[0].text)['episodes']
assert len(episodes) >= 20, f'Expected at least 20 episodes, got {len(episodes)}'
@pytest.mark.asyncio
async def test_incremental_processing(self):
"""Test incremental processing of results."""
async with graphiti_test_client() as (session, group_id):
# Add episodes incrementally
for batch in range(3):
batch_tasks = []
for i in range(5):
task = session.call_tool(
'add_memory',
{
'name': f'Batch {batch} Item {i}',
'episode_body': f'Content for batch {batch}',
'source': 'text',
'source_description': 'incremental test',
'group_id': group_id,
},
)
batch_tasks.append(task)
# Process batch
await asyncio.gather(*batch_tasks)
# Wait for this batch to process
await asyncio.sleep(10)
# Verify incremental results
result = await session.call_tool(
'get_episodes',
{
'group_id': group_id,
'last_n': 100,
},
)
episodes = json.loads(result.content[0].text)['episodes']
expected_min = (batch + 1) * 5
assert len(episodes) >= expected_min, (
f'Batch {batch}: Expected at least {expected_min} episodes'
)
if __name__ == '__main__':
pytest.main([__file__, '-v', '--asyncio-mode=auto'])
```
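Several of the tests above use `asyncio.gather(..., return_exceptions=True)` to count partial successes instead of aborting the whole batch on the first failure. A standalone sketch of that pattern (not repo code):
```python
# Illustrative sketch: the gather-and-partition pattern used by the throughput and
# overflow tests above to separate successful results from exceptions.
import asyncio

async def flaky(i: int) -> str:
    if i % 3 == 0:
        raise RuntimeError(f'task {i} failed')
    return f'task {i} ok'

async def main() -> None:
    results = await asyncio.gather(*(flaky(i) for i in range(6)), return_exceptions=True)
    failures = [r for r in results if isinstance(r, Exception)]
    successes = [r for r in results if not isinstance(r, Exception)]
    # With return_exceptions=True, failures come back as values instead of propagating,
    # so the caller can report partial success.
    print(f'{len(successes)} succeeded, {len(failures)} failed')

asyncio.run(main())
```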
--------------------------------------------------------------------------------
/graphiti_core/search/search.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
from collections import defaultdict
from time import time
from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.edges import EntityEdge
from graphiti_core.embedder.client import EMBEDDING_DIM
from graphiti_core.errors import SearchRerankerError
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import semaphore_gather
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.search.search_config import (
DEFAULT_SEARCH_LIMIT,
CommunityReranker,
CommunitySearchConfig,
CommunitySearchMethod,
EdgeReranker,
EdgeSearchConfig,
EdgeSearchMethod,
EpisodeReranker,
EpisodeSearchConfig,
NodeReranker,
NodeSearchConfig,
NodeSearchMethod,
SearchConfig,
SearchResults,
)
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import (
community_fulltext_search,
community_similarity_search,
edge_bfs_search,
edge_fulltext_search,
edge_similarity_search,
episode_fulltext_search,
episode_mentions_reranker,
get_embeddings_for_communities,
get_embeddings_for_edges,
get_embeddings_for_nodes,
maximal_marginal_relevance,
node_bfs_search,
node_distance_reranker,
node_fulltext_search,
node_similarity_search,
rrf,
)
logger = logging.getLogger(__name__)
async def search(
clients: GraphitiClients,
query: str,
group_ids: list[str] | None,
config: SearchConfig,
search_filter: SearchFilters,
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
query_vector: list[float] | None = None,
driver: GraphDriver | None = None,
) -> SearchResults:
start = time()
driver = driver or clients.driver
embedder = clients.embedder
cross_encoder = clients.cross_encoder
if query.strip() == '':
return SearchResults()
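    # Compute a query embedding only if a configured search method or reranker will
    # consume it (cosine-similarity search or MMR reranking); otherwise a zero vector
    # of length EMBEDDING_DIM is substituted below.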
if (
config.edge_config
and EdgeSearchMethod.cosine_similarity in config.edge_config.search_methods
or config.edge_config
and EdgeReranker.mmr == config.edge_config.reranker
or config.node_config
and NodeSearchMethod.cosine_similarity in config.node_config.search_methods
or config.node_config
and NodeReranker.mmr == config.node_config.reranker
or (
config.community_config
and CommunitySearchMethod.cosine_similarity in config.community_config.search_methods
)
or (config.community_config and CommunityReranker.mmr == config.community_config.reranker)
):
search_vector = (
query_vector
if query_vector is not None
else await embedder.create(input_data=[query.replace('\n', ' ')])
)
else:
search_vector = [0.0] * EMBEDDING_DIM
# if group_ids is empty, set it to None
group_ids = group_ids if group_ids and group_ids != [''] else None
(
(edges, edge_reranker_scores),
(nodes, node_reranker_scores),
(episodes, episode_reranker_scores),
(communities, community_reranker_scores),
) = await semaphore_gather(
edge_search(
driver,
cross_encoder,
query,
search_vector,
group_ids,
config.edge_config,
search_filter,
center_node_uuid,
bfs_origin_node_uuids,
config.limit,
config.reranker_min_score,
),
node_search(
driver,
cross_encoder,
query,
search_vector,
group_ids,
config.node_config,
search_filter,
center_node_uuid,
bfs_origin_node_uuids,
config.limit,
config.reranker_min_score,
),
episode_search(
driver,
cross_encoder,
query,
search_vector,
group_ids,
config.episode_config,
search_filter,
config.limit,
config.reranker_min_score,
),
community_search(
driver,
cross_encoder,
query,
search_vector,
group_ids,
config.community_config,
config.limit,
config.reranker_min_score,
),
)
results = SearchResults(
edges=edges,
edge_reranker_scores=edge_reranker_scores,
nodes=nodes,
node_reranker_scores=node_reranker_scores,
episodes=episodes,
episode_reranker_scores=episode_reranker_scores,
communities=communities,
community_reranker_scores=community_reranker_scores,
)
latency = (time() - start) * 1000
logger.debug(f'search returned context for query {query} in {latency} ms')
return results
async def edge_search(
driver: GraphDriver,
cross_encoder: CrossEncoderClient,
query: str,
query_vector: list[float],
group_ids: list[str] | None,
config: EdgeSearchConfig | None,
search_filter: SearchFilters,
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT,
reranker_min_score: float = 0,
) -> tuple[list[EntityEdge], list[float]]:
if config is None:
return [], []
    # Build search tasks based on the configured search methods; each method over-fetches
    # 2 * limit candidates so the reranker has a larger pool to choose the final limit from.
search_tasks = []
if EdgeSearchMethod.bm25 in config.search_methods:
search_tasks.append(
edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
)
if EdgeSearchMethod.cosine_similarity in config.search_methods:
search_tasks.append(
edge_similarity_search(
driver,
query_vector,
None,
None,
search_filter,
group_ids,
2 * limit,
config.sim_min_score,
)
)
if EdgeSearchMethod.bfs in config.search_methods:
search_tasks.append(
edge_bfs_search(
driver,
bfs_origin_node_uuids,
config.bfs_max_depth,
search_filter,
group_ids,
2 * limit,
)
)
# Execute only the configured search methods
search_results: list[list[EntityEdge]] = []
if search_tasks:
search_results = list(await semaphore_gather(*search_tasks))
if EdgeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
search_results.append(
await edge_bfs_search(
driver,
source_node_uuids,
config.bfs_max_depth,
search_filter,
group_ids,
2 * limit,
)
)
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
reranked_uuids: list[str] = []
edge_scores: list[float] = []
if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
reranked_uuids, edge_scores = rrf(search_result_uuids, min_score=reranker_min_score)
elif config.reranker == EdgeReranker.mmr:
search_result_uuids_and_vectors = await get_embeddings_for_edges(
driver, list(edge_uuid_map.values())
)
reranked_uuids, edge_scores = maximal_marginal_relevance(
query_vector,
search_result_uuids_and_vectors,
config.mmr_lambda,
reranker_min_score,
)
elif config.reranker == EdgeReranker.cross_encoder:
fact_to_uuid_map = {edge.fact: edge.uuid for edge in list(edge_uuid_map.values())[:limit]}
reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
reranked_uuids = [
fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
]
edge_scores = [score for _, score in reranked_facts if score >= reranker_min_score]
elif config.reranker == EdgeReranker.node_distance:
if center_node_uuid is None:
raise SearchRerankerError('No center node provided for Node Distance reranker')
# use rrf as a preliminary sort
sorted_result_uuids, node_scores = rrf(
[[edge.uuid for edge in result] for result in search_results],
min_score=reranker_min_score,
)
sorted_results = [edge_uuid_map[uuid] for uuid in sorted_result_uuids]
# node distance reranking
source_to_edge_uuid_map = defaultdict(list)
for edge in sorted_results:
source_to_edge_uuid_map[edge.source_node_uuid].append(edge.uuid)
source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]
reranked_node_uuids, edge_scores = await node_distance_reranker(
driver, source_uuids, center_node_uuid, min_score=reranker_min_score
)
for node_uuid in reranked_node_uuids:
reranked_uuids.extend(source_to_edge_uuid_map[node_uuid])
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
if config.reranker == EdgeReranker.episode_mentions:
reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))
return reranked_edges[:limit], edge_scores[:limit]
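# --- Illustrative sketch (assumed constants; not the production helper) -------
# The rrf() used above implements reciprocal rank fusion: each result list
# contributes 1 / (rank_const + rank) for every UUID it contains, and the fused
# scores are then filtered by `reranker_min_score`. The rank constant below is
# assumed; the real implementation lives in graphiti_core.search.search_utils.
def _rrf_sketch(
    result_lists: list[list[str]], rank_const: int = 60, min_score: float = 0
) -> tuple[list[str], list[float]]:
    scores: dict[str, float] = {}
    for results in result_lists:
        for rank, uuid in enumerate(results, start=1):
            scores[uuid] = scores.get(uuid, 0.0) + 1.0 / (rank_const + rank)
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    kept = [(uuid, score) for uuid, score in ranked if score >= min_score]
    return [uuid for uuid, _ in kept], [score for _, score in kept]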
async def node_search(
driver: GraphDriver,
cross_encoder: CrossEncoderClient,
query: str,
query_vector: list[float],
group_ids: list[str] | None,
config: NodeSearchConfig | None,
search_filter: SearchFilters,
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT,
reranker_min_score: float = 0,
) -> tuple[list[EntityNode], list[float]]:
if config is None:
return [], []
# Build search tasks based on configured search methods
search_tasks = []
if NodeSearchMethod.bm25 in config.search_methods:
search_tasks.append(
node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
)
if NodeSearchMethod.cosine_similarity in config.search_methods:
search_tasks.append(
node_similarity_search(
driver,
query_vector,
search_filter,
group_ids,
2 * limit,
config.sim_min_score,
)
)
if NodeSearchMethod.bfs in config.search_methods:
search_tasks.append(
node_bfs_search(
driver,
bfs_origin_node_uuids,
search_filter,
config.bfs_max_depth,
group_ids,
2 * limit,
)
)
# Execute only the configured search methods
search_results: list[list[EntityNode]] = []
if search_tasks:
search_results = list(await semaphore_gather(*search_tasks))
if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
origin_node_uuids = [node.uuid for result in search_results for node in result]
search_results.append(
await node_bfs_search(
driver,
origin_node_uuids,
search_filter,
config.bfs_max_depth,
group_ids,
2 * limit,
)
)
search_result_uuids = [[node.uuid for node in result] for result in search_results]
node_uuid_map = {node.uuid: node for result in search_results for node in result}
reranked_uuids: list[str] = []
node_scores: list[float] = []
if config.reranker == NodeReranker.rrf:
reranked_uuids, node_scores = rrf(search_result_uuids, min_score=reranker_min_score)
elif config.reranker == NodeReranker.mmr:
search_result_uuids_and_vectors = await get_embeddings_for_nodes(
driver, list(node_uuid_map.values())
)
reranked_uuids, node_scores = maximal_marginal_relevance(
query_vector,
search_result_uuids_and_vectors,
config.mmr_lambda,
reranker_min_score,
)
elif config.reranker == NodeReranker.cross_encoder:
name_to_uuid_map = {node.name: node.uuid for node in list(node_uuid_map.values())}
reranked_node_names = await cross_encoder.rank(query, list(name_to_uuid_map.keys()))
reranked_uuids = [
name_to_uuid_map[name]
for name, score in reranked_node_names
if score >= reranker_min_score
]
node_scores = [score for _, score in reranked_node_names if score >= reranker_min_score]
elif config.reranker == NodeReranker.episode_mentions:
reranked_uuids, node_scores = await episode_mentions_reranker(
driver, search_result_uuids, min_score=reranker_min_score
)
elif config.reranker == NodeReranker.node_distance:
if center_node_uuid is None:
raise SearchRerankerError('No center node provided for Node Distance reranker')
reranked_uuids, node_scores = await node_distance_reranker(
driver,
rrf(search_result_uuids, min_score=reranker_min_score)[0],
center_node_uuid,
min_score=reranker_min_score,
)
reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
return reranked_nodes[:limit], node_scores[:limit]
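# --- Illustrative sketch (assumed details; not the production helper) ---------
# The NodeReranker.mmr branch above relies on maximal_marginal_relevance from
# search_utils: greedily pick the candidate most similar to the query while
# penalizing similarity to what has already been selected. This simplified
# stand-in assumes unit-normalized embeddings so a dot product approximates
# cosine similarity.
def _mmr_sketch(
    query_vector: list[float],
    candidates: list[tuple[str, list[float]]],
    mmr_lambda: float = 0.5,
) -> list[tuple[str, float]]:
    def dot(a: list[float], b: list[float]) -> float:
        return sum(x * y for x, y in zip(a, b))

    selected: list[tuple[str, float]] = []
    chosen: list[list[float]] = []
    remaining = list(candidates)

    def mmr_score(candidate: tuple[str, list[float]]) -> float:
        redundancy = max((dot(candidate[1], vec) for vec in chosen), default=0.0)
        return mmr_lambda * dot(query_vector, candidate[1]) - (1 - mmr_lambda) * redundancy

    while remaining:
        best = max(remaining, key=mmr_score)
        selected.append((best[0], mmr_score(best)))
        chosen.append(best[1])
        remaining.remove(best)
    return selected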
async def episode_search(
driver: GraphDriver,
cross_encoder: CrossEncoderClient,
query: str,
_query_vector: list[float],
group_ids: list[str] | None,
config: EpisodeSearchConfig | None,
search_filter: SearchFilters,
limit=DEFAULT_SEARCH_LIMIT,
reranker_min_score: float = 0,
) -> tuple[list[EpisodicNode], list[float]]:
if config is None:
return [], []
search_results: list[list[EpisodicNode]] = list(
await semaphore_gather(
*[
episode_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
]
)
)
search_result_uuids = [[episode.uuid for episode in result] for result in search_results]
episode_uuid_map = {episode.uuid: episode for result in search_results for episode in result}
reranked_uuids: list[str] = []
episode_scores: list[float] = []
if config.reranker == EpisodeReranker.rrf:
reranked_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
elif config.reranker == EpisodeReranker.cross_encoder:
# use rrf as a preliminary reranker
rrf_result_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
rrf_results = [episode_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
content_to_uuid_map = {episode.content: episode.uuid for episode in rrf_results}
reranked_contents = await cross_encoder.rank(query, list(content_to_uuid_map.keys()))
reranked_uuids = [
content_to_uuid_map[content]
for content, score in reranked_contents
if score >= reranker_min_score
]
episode_scores = [score for _, score in reranked_contents if score >= reranker_min_score]
reranked_episodes = [episode_uuid_map[uuid] for uuid in reranked_uuids]
return reranked_episodes[:limit], episode_scores[:limit]
async def community_search(
driver: GraphDriver,
cross_encoder: CrossEncoderClient,
query: str,
query_vector: list[float],
group_ids: list[str] | None,
config: CommunitySearchConfig | None,
limit=DEFAULT_SEARCH_LIMIT,
reranker_min_score: float = 0,
) -> tuple[list[CommunityNode], list[float]]:
if config is None:
return [], []
search_results: list[list[CommunityNode]] = list(
await semaphore_gather(
*[
community_fulltext_search(driver, query, group_ids, 2 * limit),
community_similarity_search(
driver, query_vector, group_ids, 2 * limit, config.sim_min_score
),
]
)
)
search_result_uuids = [[community.uuid for community in result] for result in search_results]
community_uuid_map = {
community.uuid: community for result in search_results for community in result
}
reranked_uuids: list[str] = []
community_scores: list[float] = []
if config.reranker == CommunityReranker.rrf:
reranked_uuids, community_scores = rrf(search_result_uuids, min_score=reranker_min_score)
elif config.reranker == CommunityReranker.mmr:
search_result_uuids_and_vectors = await get_embeddings_for_communities(
driver, list(community_uuid_map.values())
)
reranked_uuids, community_scores = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
)
elif config.reranker == CommunityReranker.cross_encoder:
name_to_uuid_map = {node.name: node.uuid for result in search_results for node in result}
reranked_nodes = await cross_encoder.rank(query, list(name_to_uuid_map.keys()))
reranked_uuids = [
name_to_uuid_map[name] for name, score in reranked_nodes if score >= reranker_min_score
]
community_scores = [score for _, score in reranked_nodes if score >= reranker_min_score]
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
return reranked_communities[:limit], community_scores[:limit]
```
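Across `edge_search`, `node_search`, `episode_search`, and `community_search`, the `cross_encoder` reranker follows the same two-stage shape: a cheap preliminary ordering (RRF, or simple truncation to `limit`) picks the candidate passages, the cross-encoder's `rank(query, passages)` scores them, and anything below `reranker_min_score` is dropped. The sketch below illustrates that flow with a hypothetical keyword-overlap stub in place of a real `CrossEncoderClient`; it is not the library's code.

```python
import asyncio


async def fake_rank(query: str, passages: list[str]) -> list[tuple[str, float]]:
    # Hypothetical stand-in for CrossEncoderClient.rank(): score by token overlap.
    query_tokens = set(query.lower().split())
    scored = [
        (passage, len(query_tokens & set(passage.lower().split())) / (len(query_tokens) or 1))
        for passage in passages
    ]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


async def two_stage_rerank(
    query: str,
    preselected_facts: list[str],  # e.g. RRF-ordered facts already cut to `limit`
    reranker_min_score: float,
) -> tuple[list[str], list[float]]:
    ranked = await fake_rank(query, preselected_facts)
    kept = [(fact, score) for fact, score in ranked if score >= reranker_min_score]
    return [fact for fact, _ in kept], [score for _, score in kept]


if __name__ == '__main__':
    facts = ['Acme launched an AI product', 'The weather was sunny today']
    print(asyncio.run(two_stage_rerank('Acme AI product launch', facts, 0.25)))
```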
--------------------------------------------------------------------------------
/mcp_server/tests/test_mcp_integration.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Integration test for the refactored Graphiti MCP Server using the official MCP Python SDK.
Tests all major MCP tools and handles episode processing latency.
"""
import asyncio
import json
import os
import time
from typing import Any
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
class GraphitiMCPIntegrationTest:
"""Integration test client for Graphiti MCP Server using official MCP SDK."""
def __init__(self):
self.test_group_id = f'test_group_{int(time.time())}'
self.session = None
async def __aenter__(self):
"""Start the MCP client session."""
# Configure server parameters to run our refactored server
server_params = StdioServerParameters(
command='uv',
args=['run', 'main.py', '--transport', 'stdio'],
env={
'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY', 'dummy_key_for_testing'),
},
)
print(f'🚀 Starting MCP client session with test group: {self.test_group_id}')
# Use the async context manager properly
self.client_context = stdio_client(server_params)
read, write = await self.client_context.__aenter__()
self.session = ClientSession(read, write)
await self.session.initialize()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Close the MCP client session."""
if self.session:
await self.session.close()
if hasattr(self, 'client_context'):
await self.client_context.__aexit__(exc_type, exc_val, exc_tb)
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
"""Call an MCP tool and return the result."""
try:
result = await self.session.call_tool(tool_name, arguments)
return result.content[0].text if result.content else {'error': 'No content returned'}
except Exception as e:
return {'error': str(e)}
async def test_server_initialization(self) -> bool:
"""Test that the server initializes properly."""
print('🔍 Testing server initialization...')
try:
# List available tools to verify server is responding
tools_result = await self.session.list_tools()
tools = [tool.name for tool in tools_result.tools]
expected_tools = [
'add_memory',
'search_memory_nodes',
'search_memory_facts',
'get_episodes',
'delete_episode',
'delete_entity_edge',
'get_entity_edge',
'clear_graph',
]
available_tools = len([tool for tool in expected_tools if tool in tools])
print(
f' ✅ Server responding with {len(tools)} tools ({available_tools}/{len(expected_tools)} expected)'
)
print(f' Available tools: {", ".join(sorted(tools))}')
return available_tools >= len(expected_tools) * 0.8 # 80% of expected tools
except Exception as e:
print(f' ❌ Server initialization failed: {e}')
return False
async def test_add_memory_operations(self) -> dict[str, bool]:
"""Test adding various types of memory episodes."""
print('📝 Testing add_memory operations...')
results = {}
# Test 1: Add text episode
print(' Testing text episode...')
try:
result = await self.call_tool(
'add_memory',
{
'name': 'Test Company News',
'episode_body': 'Acme Corp announced a revolutionary new AI product that will transform the industry. The CEO mentioned this is their biggest launch since 2020.',
'source': 'text',
'source_description': 'news article',
'group_id': self.test_group_id,
},
)
if isinstance(result, str) and 'queued' in result.lower():
print(f' ✅ Text episode: {result}')
results['text'] = True
else:
print(f' ❌ Text episode failed: {result}')
results['text'] = False
except Exception as e:
print(f' ❌ Text episode error: {e}')
results['text'] = False
# Test 2: Add JSON episode
print(' Testing JSON episode...')
try:
json_data = {
'company': {'name': 'TechCorp', 'founded': 2010},
'products': [
{'id': 'P001', 'name': 'CloudSync', 'category': 'software'},
{'id': 'P002', 'name': 'DataMiner', 'category': 'analytics'},
],
'employees': 150,
}
result = await self.call_tool(
'add_memory',
{
'name': 'Company Profile',
'episode_body': json.dumps(json_data),
'source': 'json',
'source_description': 'CRM data',
'group_id': self.test_group_id,
},
)
if isinstance(result, str) and 'queued' in result.lower():
print(f' ✅ JSON episode: {result}')
results['json'] = True
else:
print(f' ❌ JSON episode failed: {result}')
results['json'] = False
except Exception as e:
print(f' ❌ JSON episode error: {e}')
results['json'] = False
# Test 3: Add message episode
print(' Testing message episode...')
try:
result = await self.call_tool(
'add_memory',
{
'name': 'Customer Support Chat',
'episode_body': "user: What's your return policy?\nassistant: You can return items within 30 days of purchase with receipt.\nuser: Thanks!",
'source': 'message',
'source_description': 'support chat log',
'group_id': self.test_group_id,
},
)
if isinstance(result, str) and 'queued' in result.lower():
print(f' ✅ Message episode: {result}')
results['message'] = True
else:
print(f' ❌ Message episode failed: {result}')
results['message'] = False
except Exception as e:
print(f' ❌ Message episode error: {e}')
results['message'] = False
return results
async def wait_for_processing(self, max_wait: int = 45) -> bool:
"""Wait for episode processing to complete."""
print(f'⏳ Waiting up to {max_wait} seconds for episode processing...')
for i in range(max_wait):
await asyncio.sleep(1)
try:
# Check if we have any episodes
result = await self.call_tool(
'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
)
# Parse the JSON result if it's a string
if isinstance(result, str):
try:
parsed_result = json.loads(result)
if isinstance(parsed_result, list) and len(parsed_result) > 0:
print(
f' ✅ Found {len(parsed_result)} processed episodes after {i + 1} seconds'
)
return True
except json.JSONDecodeError:
if 'episodes' in result.lower():
print(f' ✅ Episodes detected after {i + 1} seconds')
return True
except Exception as e:
if i == 0: # Only log first error to avoid spam
print(f' ⚠️ Waiting for processing... ({e})')
continue
print(f' ⚠️ Still waiting after {max_wait} seconds...')
return False
async def test_search_operations(self) -> dict[str, bool]:
"""Test search functionality."""
print('🔍 Testing search operations...')
results = {}
# Test search_memory_nodes
print(' Testing search_memory_nodes...')
try:
result = await self.call_tool(
'search_memory_nodes',
{
'query': 'Acme Corp product launch AI',
'group_ids': [self.test_group_id],
'max_nodes': 5,
},
)
success = False
if isinstance(result, str):
try:
parsed = json.loads(result)
nodes = parsed.get('nodes', [])
success = isinstance(nodes, list)
print(f' ✅ Node search returned {len(nodes)} nodes')
except json.JSONDecodeError:
success = 'nodes' in result.lower() and 'successfully' in result.lower()
if success:
print(' ✅ Node search completed successfully')
results['nodes'] = success
if not success:
print(f' ❌ Node search failed: {result}')
except Exception as e:
print(f' ❌ Node search error: {e}')
results['nodes'] = False
# Test search_memory_facts
print(' Testing search_memory_facts...')
try:
result = await self.call_tool(
'search_memory_facts',
{
'query': 'company products software TechCorp',
'group_ids': [self.test_group_id],
'max_facts': 5,
},
)
success = False
if isinstance(result, str):
try:
parsed = json.loads(result)
facts = parsed.get('facts', [])
success = isinstance(facts, list)
print(f' ✅ Fact search returned {len(facts)} facts')
except json.JSONDecodeError:
success = 'facts' in result.lower() and 'successfully' in result.lower()
if success:
print(' ✅ Fact search completed successfully')
results['facts'] = success
if not success:
print(f' ❌ Fact search failed: {result}')
except Exception as e:
print(f' ❌ Fact search error: {e}')
results['facts'] = False
return results
async def test_episode_retrieval(self) -> bool:
"""Test episode retrieval."""
print('📚 Testing episode retrieval...')
try:
result = await self.call_tool(
'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
)
if isinstance(result, str):
try:
parsed = json.loads(result)
if isinstance(parsed, list):
print(f' ✅ Retrieved {len(parsed)} episodes')
# Show episode details
for i, episode in enumerate(parsed[:3]):
name = episode.get('name', 'Unknown')
source = episode.get('source', 'unknown')
print(f' Episode {i + 1}: {name} (source: {source})')
return len(parsed) > 0
except json.JSONDecodeError:
# Check if response indicates success
if 'episode' in result.lower():
print(' ✅ Episode retrieval completed')
return True
print(f' ❌ Unexpected result format: {result}')
return False
except Exception as e:
print(f' ❌ Episode retrieval failed: {e}')
return False
async def test_error_handling(self) -> dict[str, bool]:
"""Test error handling and edge cases."""
print('🧪 Testing error handling...')
results = {}
# Test with nonexistent group
print(' Testing nonexistent group handling...')
try:
result = await self.call_tool(
'search_memory_nodes',
{
'query': 'nonexistent data',
'group_ids': ['nonexistent_group_12345'],
'max_nodes': 5,
},
)
# Should handle gracefully, not crash
success = (
'error' not in str(result).lower() or 'not initialized' not in str(result).lower()
)
if success:
print(' ✅ Nonexistent group handled gracefully')
else:
print(f' ❌ Nonexistent group caused issues: {result}')
results['nonexistent_group'] = success
except Exception as e:
print(f' ❌ Nonexistent group test failed: {e}')
results['nonexistent_group'] = False
# Test empty query
print(' Testing empty query handling...')
try:
result = await self.call_tool(
'search_memory_nodes',
{'query': '', 'group_ids': [self.test_group_id], 'max_nodes': 5},
)
# Should handle gracefully
success = (
'error' not in str(result).lower() or 'not initialized' not in str(result).lower()
)
if success:
print(' ✅ Empty query handled gracefully')
else:
print(f' ❌ Empty query caused issues: {result}')
results['empty_query'] = success
except Exception as e:
print(f' ❌ Empty query test failed: {e}')
results['empty_query'] = False
return results
async def run_comprehensive_test(self) -> dict[str, Any]:
"""Run the complete integration test suite."""
print('🚀 Starting Comprehensive Graphiti MCP Server Integration Test')
print(f' Test group ID: {self.test_group_id}')
print('=' * 70)
results = {
'server_init': False,
'add_memory': {},
'processing_wait': False,
'search': {},
'episodes': False,
'error_handling': {},
'overall_success': False,
}
# Test 1: Server Initialization
results['server_init'] = await self.test_server_initialization()
if not results['server_init']:
print('❌ Server initialization failed, aborting remaining tests')
return results
print()
# Test 2: Add Memory Operations
results['add_memory'] = await self.test_add_memory_operations()
print()
# Test 3: Wait for Processing
results['processing_wait'] = await self.wait_for_processing()
print()
# Test 4: Search Operations
results['search'] = await self.test_search_operations()
print()
# Test 5: Episode Retrieval
results['episodes'] = await self.test_episode_retrieval()
print()
# Test 6: Error Handling
results['error_handling'] = await self.test_error_handling()
print()
# Calculate overall success
memory_success = any(results['add_memory'].values())
search_success = any(results['search'].values()) if results['search'] else False
error_success = (
any(results['error_handling'].values()) if results['error_handling'] else True
)
results['overall_success'] = (
results['server_init']
and memory_success
and (results['episodes'] or results['processing_wait'])
and error_success
)
# Print comprehensive summary
print('=' * 70)
print('📊 COMPREHENSIVE TEST SUMMARY')
print('-' * 35)
print(f'Server Initialization: {"✅ PASS" if results["server_init"] else "❌ FAIL"}')
memory_stats = f'({sum(results["add_memory"].values())}/{len(results["add_memory"])} types)'
print(
f'Memory Operations: {"✅ PASS" if memory_success else "❌ FAIL"} {memory_stats}'
)
print(f'Processing Pipeline: {"✅ PASS" if results["processing_wait"] else "❌ FAIL"}')
search_stats = (
f'({sum(results["search"].values())}/{len(results["search"])} types)'
if results['search']
else '(0/0 types)'
)
print(
f'Search Operations: {"✅ PASS" if search_success else "❌ FAIL"} {search_stats}'
)
print(f'Episode Retrieval: {"✅ PASS" if results["episodes"] else "❌ FAIL"}')
error_stats = (
f'({sum(results["error_handling"].values())}/{len(results["error_handling"])} cases)'
if results['error_handling']
else '(0/0 cases)'
)
print(
f'Error Handling: {"✅ PASS" if error_success else "❌ FAIL"} {error_stats}'
)
print('-' * 35)
print(f'🎯 OVERALL RESULT: {"✅ SUCCESS" if results["overall_success"] else "❌ FAILED"}')
if results['overall_success']:
print('\n🎉 The refactored Graphiti MCP server is working correctly!')
print(' All core functionality has been successfully tested.')
else:
print('\n⚠️ Some issues were detected. Review the test results above.')
print(' The refactoring may need additional attention.')
return results
async def main():
"""Run the integration test."""
try:
async with GraphitiMCPIntegrationTest() as test:
results = await test.run_comprehensive_test()
# Exit with appropriate code
exit_code = 0 if results['overall_success'] else 1
exit(exit_code)
except Exception as e:
print(f'❌ Test setup failed: {e}')
exit(1)
if __name__ == '__main__':
asyncio.run(main())
```
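The test above wires up the stdio transport by calling `__aenter__`/`__aexit__` by hand so it can keep the session on `self`. For reference, the MCP Python SDK's usual pattern nests the async context managers instead; a minimal sketch, assuming the same `uv run main.py` entry point the test launches:

```python
import asyncio

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client


async def list_graphiti_tools() -> None:
    # Same server launch command the integration test uses.
    server_params = StdioServerParameters(
        command='uv',
        args=['run', 'main.py', '--transport', 'stdio'],
    )
    async with stdio_client(server_params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            tools = await session.list_tools()
            print(sorted(tool.name for tool in tools.tools))


if __name__ == '__main__':
    asyncio.run(list_graphiti_tools())
```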
--------------------------------------------------------------------------------
/graphiti_core/llm_client/gemini_client.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
import re
import typing
from typing import TYPE_CHECKING, ClassVar
from pydantic import BaseModel
from ..prompts.models import Message
from .client import LLMClient, get_extraction_language_instruction
from .config import LLMConfig, ModelSize
from .errors import RateLimitError
if TYPE_CHECKING:
from google import genai
from google.genai import types
else:
try:
from google import genai
from google.genai import types
except ImportError:
# If gemini client is not installed, raise an ImportError
raise ImportError(
'google-genai is required for GeminiClient. '
'Install it with: pip install graphiti-core[google-genai]'
) from None
logger = logging.getLogger(__name__)
DEFAULT_MODEL = 'gemini-2.5-flash'
DEFAULT_SMALL_MODEL = 'gemini-2.5-flash-lite'
# Maximum output tokens for different Gemini models
GEMINI_MODEL_MAX_TOKENS = {
# Gemini 2.5 models
'gemini-2.5-pro': 65536,
'gemini-2.5-flash': 65536,
'gemini-2.5-flash-lite': 64000,
# Gemini 2.0 models
'gemini-2.0-flash': 8192,
'gemini-2.0-flash-lite': 8192,
# Gemini 1.5 models
'gemini-1.5-pro': 8192,
'gemini-1.5-flash': 8192,
'gemini-1.5-flash-8b': 8192,
}
# Default max tokens for models not in the mapping
DEFAULT_GEMINI_MAX_TOKENS = 8192
class GeminiClient(LLMClient):
"""
GeminiClient is a client class for interacting with Google's Gemini language models.
This class extends the LLMClient and provides methods to initialize the client
and generate responses from the Gemini language model.
Attributes:
model (str): The model name to use for generating responses.
temperature (float): The temperature to use for generating responses.
max_tokens (int): The maximum number of tokens to generate in a response.
thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
Methods:
        __init__(config: LLMConfig | None = None, cache: bool = False, max_tokens: int | None = None, thinking_config: types.ThinkingConfig | None = None, client: genai.Client | None = None):
            Initializes the GeminiClient with the provided configuration, cache setting, optional max output tokens, optional thinking config, and optional pre-built client.
_generate_response(messages: list[Message]) -> dict[str, typing.Any]:
Generates a response from the language model based on the provided messages.
"""
# Class-level constants
MAX_RETRIES: ClassVar[int] = 2
def __init__(
self,
config: LLMConfig | None = None,
cache: bool = False,
max_tokens: int | None = None,
thinking_config: types.ThinkingConfig | None = None,
client: 'genai.Client | None' = None,
):
"""
Initialize the GeminiClient with the provided configuration, cache setting, and optional thinking config.
Args:
config (LLMConfig | None): The configuration for the LLM client, including API key, model, temperature, and max tokens.
            cache (bool): Whether to use caching for responses. Defaults to False.
            max_tokens (int | None): The maximum number of output tokens to generate. If None, model-specific limits apply.
            thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
Only use with models that support thinking (gemini-2.5+). Defaults to None.
client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
"""
if config is None:
config = LLMConfig()
super().__init__(config, cache)
self.model = config.model
if client is None:
self.client = genai.Client(api_key=config.api_key)
else:
self.client = client
self.max_tokens = max_tokens
self.thinking_config = thinking_config
def _check_safety_blocks(self, response) -> None:
"""Check if response was blocked for safety reasons and raise appropriate exceptions."""
# Check if the response was blocked for safety reasons
if not (hasattr(response, 'candidates') and response.candidates):
return
candidate = response.candidates[0]
if not (hasattr(candidate, 'finish_reason') and candidate.finish_reason == 'SAFETY'):
return
# Content was blocked for safety reasons - collect safety details
safety_info = []
safety_ratings = getattr(candidate, 'safety_ratings', None)
if safety_ratings:
for rating in safety_ratings:
if getattr(rating, 'blocked', False):
category = getattr(rating, 'category', 'Unknown')
probability = getattr(rating, 'probability', 'Unknown')
safety_info.append(f'{category}: {probability}')
safety_details = (
', '.join(safety_info) if safety_info else 'Content blocked for safety reasons'
)
raise Exception(f'Response blocked by Gemini safety filters: {safety_details}')
def _check_prompt_blocks(self, response) -> None:
"""Check if prompt was blocked and raise appropriate exceptions."""
prompt_feedback = getattr(response, 'prompt_feedback', None)
if not prompt_feedback:
return
block_reason = getattr(prompt_feedback, 'block_reason', None)
if block_reason:
raise Exception(f'Prompt blocked by Gemini: {block_reason}')
def _get_model_for_size(self, model_size: ModelSize) -> str:
"""Get the appropriate model name based on the requested size."""
if model_size == ModelSize.small:
return self.small_model or DEFAULT_SMALL_MODEL
else:
return self.model or DEFAULT_MODEL
def _get_max_tokens_for_model(self, model: str) -> int:
"""Get the maximum output tokens for a specific Gemini model."""
return GEMINI_MODEL_MAX_TOKENS.get(model, DEFAULT_GEMINI_MAX_TOKENS)
def _resolve_max_tokens(self, requested_max_tokens: int | None, model: str) -> int:
"""
Resolve the maximum output tokens to use based on precedence rules.
Precedence order (highest to lowest):
1. Explicit max_tokens parameter passed to generate_response()
2. Instance max_tokens set during client initialization
3. Model-specific maximum tokens from GEMINI_MODEL_MAX_TOKENS mapping
        4. DEFAULT_GEMINI_MAX_TOKENS as final fallback
Args:
requested_max_tokens: The max_tokens parameter passed to generate_response()
model: The model name to look up model-specific limits
Returns:
int: The resolved maximum tokens to use
"""
# 1. Use explicit parameter if provided
if requested_max_tokens is not None:
return requested_max_tokens
# 2. Use instance max_tokens if set during initialization
if self.max_tokens is not None:
return self.max_tokens
# 3. Use model-specific maximum or return DEFAULT_GEMINI_MAX_TOKENS
return self._get_max_tokens_for_model(model)
def salvage_json(self, raw_output: str) -> dict[str, typing.Any] | None:
"""
Attempt to salvage a JSON object if the raw output is truncated.
This is accomplished by looking for the last closing bracket for an array or object.
If found, it will try to load the JSON object from the raw output.
If the JSON object is not valid, it will return None.
Args:
raw_output (str): The raw output from the LLM.
Returns:
dict[str, typing.Any]: The salvaged JSON object.
None: If no salvage is possible.
"""
if not raw_output:
return None
# Try to salvage a JSON array
array_match = re.search(r'\]\s*$', raw_output)
if array_match:
try:
return json.loads(raw_output[: array_match.end()])
except Exception:
pass
# Try to salvage a JSON object
obj_match = re.search(r'\}\s*$', raw_output)
if obj_match:
try:
return json.loads(raw_output[: obj_match.end()])
except Exception:
pass
return None
async def _generate_response(
self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int | None = None,
model_size: ModelSize = ModelSize.medium,
) -> dict[str, typing.Any]:
"""
Generate a response from the Gemini language model.
Args:
messages (list[Message]): A list of messages to send to the language model.
response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
max_tokens (int | None): The maximum number of tokens to generate in the response. If None, uses precedence rules.
model_size (ModelSize): The size of the model to use (small or medium).
Returns:
dict[str, typing.Any]: The response from the language model.
Raises:
RateLimitError: If the API rate limit is exceeded.
Exception: If there is an error generating the response or content is blocked.
"""
try:
gemini_messages: typing.Any = []
# If a response model is provided, add schema for structured output
system_prompt = ''
if response_model is not None:
# Get the schema from the Pydantic model
pydantic_schema = response_model.model_json_schema()
# Create instruction to output in the desired JSON format
system_prompt += (
f'Output ONLY valid JSON matching this schema: {json.dumps(pydantic_schema)}.\n'
'Do not include any explanatory text before or after the JSON.\n\n'
)
# Add messages content
# First check for a system message
if messages and messages[0].role == 'system':
system_prompt = f'{messages[0].content}\n\n {system_prompt}'
messages = messages[1:]
# Add the rest of the messages
for m in messages:
m.content = self._clean_input(m.content)
gemini_messages.append(
types.Content(role=m.role, parts=[types.Part.from_text(text=m.content)])
)
# Get the appropriate model for the requested size
model = self._get_model_for_size(model_size)
# Resolve max_tokens using precedence rules (see _resolve_max_tokens for details)
resolved_max_tokens = self._resolve_max_tokens(max_tokens, model)
# Create generation config
generation_config = types.GenerateContentConfig(
temperature=self.temperature,
max_output_tokens=resolved_max_tokens,
response_mime_type='application/json' if response_model else None,
response_schema=response_model if response_model else None,
system_instruction=system_prompt,
thinking_config=self.thinking_config,
)
            # Generate content with the Gemini async client
response = await self.client.aio.models.generate_content(
model=model,
contents=gemini_messages,
config=generation_config,
)
# Always capture the raw output for debugging
raw_output = getattr(response, 'text', None)
# Check for safety and prompt blocks
self._check_safety_blocks(response)
self._check_prompt_blocks(response)
# If this was a structured output request, parse the response into the Pydantic model
if response_model is not None:
try:
if not raw_output:
raise ValueError('No response text')
validated_model = response_model.model_validate(json.loads(raw_output))
# Return as a dictionary for API consistency
return validated_model.model_dump()
except Exception as e:
if raw_output:
logger.error(
'🦀 LLM generation failed parsing as JSON, will try to salvage.'
)
logger.error(self._get_failed_generation_log(gemini_messages, raw_output))
# Try to salvage
salvaged = self.salvage_json(raw_output)
if salvaged is not None:
logger.warning('Salvaged partial JSON from truncated/malformed output.')
return salvaged
raise Exception(f'Failed to parse structured response: {e}') from e
# Otherwise, return the response text as a dictionary
return {'content': raw_output}
except Exception as e:
# Check if it's a rate limit error based on Gemini API error codes
error_message = str(e).lower()
if (
'rate limit' in error_message
or 'quota' in error_message
or 'resource_exhausted' in error_message
or '429' in str(e)
):
raise RateLimitError from e
logger.error(f'Error in generating LLM response: {e}')
raise Exception from e
async def generate_response(
self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int | None = None,
model_size: ModelSize = ModelSize.medium,
group_id: str | None = None,
prompt_name: str | None = None,
) -> dict[str, typing.Any]:
"""
Generate a response from the Gemini language model with retry logic and error handling.
This method overrides the parent class method to provide a direct implementation with advanced retry logic.
Args:
messages (list[Message]): A list of messages to send to the language model.
response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
max_tokens (int | None): The maximum number of tokens to generate in the response.
model_size (ModelSize): The size of the model to use (small or medium).
group_id (str | None): Optional partition identifier for the graph.
prompt_name (str | None): Optional name of the prompt for tracing.
Returns:
dict[str, typing.Any]: The response from the language model.
"""
# Add multilingual extraction instructions
messages[0].content += get_extraction_language_instruction(group_id)
# Wrap entire operation in tracing span
with self.tracer.start_span('llm.generate') as span:
attributes = {
'llm.provider': 'gemini',
'model.size': model_size.value,
'max_tokens': max_tokens or self.max_tokens,
}
if prompt_name:
attributes['prompt.name'] = prompt_name
span.add_attributes(attributes)
retry_count = 0
last_error = None
last_output = None
while retry_count < self.MAX_RETRIES:
try:
response = await self._generate_response(
messages=messages,
response_model=response_model,
max_tokens=max_tokens,
model_size=model_size,
)
last_output = (
response.get('content')
if isinstance(response, dict) and 'content' in response
else None
)
return response
except RateLimitError as e:
# Rate limit errors should not trigger retries (fail fast)
span.set_status('error', str(e))
raise e
except Exception as e:
last_error = e
# Check if this is a safety block - these typically shouldn't be retried
error_text = str(e) or (str(e.__cause__) if e.__cause__ else '')
if 'safety' in error_text.lower() or 'blocked' in error_text.lower():
logger.warning(f'Content blocked by safety filters: {e}')
span.set_status('error', str(e))
raise Exception(f'Content blocked by safety filters: {e}') from e
retry_count += 1
# Construct a detailed error message for the LLM
error_context = (
f'The previous response attempt was invalid. '
f'Error type: {e.__class__.__name__}. '
f'Error details: {str(e)}. '
f'Please try again with a valid response, ensuring the output matches '
f'the expected format and constraints.'
)
error_message = Message(role='user', content=error_context)
messages.append(error_message)
logger.warning(
f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
)
# If we exit the loop without returning, all retries are exhausted
logger.error('🦀 LLM generation failed and retries are exhausted.')
logger.error(self._get_failed_generation_log(messages, last_output))
logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {last_error}')
span.set_status('error', str(last_error))
span.record_exception(last_error) if last_error else None
raise last_error or Exception('Max retries exceeded')
```
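`_resolve_max_tokens` above applies a four-level precedence that is easy to misread at a glance. The standalone restatement below uses only names defined in the file (with a trimmed copy of the model table) and exercises each level of the precedence:

```python
# Standalone restatement of the max-token precedence used by GeminiClient:
# explicit call argument > client-level max_tokens > per-model limit > default.
GEMINI_MODEL_MAX_TOKENS = {'gemini-2.5-flash': 65536, 'gemini-2.0-flash': 8192}
DEFAULT_GEMINI_MAX_TOKENS = 8192


def resolve_max_tokens(requested: int | None, instance_max: int | None, model: str) -> int:
    if requested is not None:
        return requested
    if instance_max is not None:
        return instance_max
    return GEMINI_MODEL_MAX_TOKENS.get(model, DEFAULT_GEMINI_MAX_TOKENS)


assert resolve_max_tokens(1024, 4096, 'gemini-2.5-flash') == 1024   # explicit wins
assert resolve_max_tokens(None, 4096, 'gemini-2.5-flash') == 4096   # client default
assert resolve_max_tokens(None, None, 'gemini-2.5-flash') == 65536  # model limit
assert resolve_max_tokens(None, None, 'unknown-model') == 8192      # fallback
```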
--------------------------------------------------------------------------------
/tests/utils/maintenance/test_edge_operations.py:
--------------------------------------------------------------------------------
```python
from datetime import datetime, timedelta, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from pydantic import BaseModel
from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.search.search_config import SearchResults
from graphiti_core.utils.maintenance.edge_operations import (
DEFAULT_EDGE_NAME,
resolve_extracted_edge,
resolve_extracted_edges,
)
@pytest.fixture
def mock_llm_client():
client = MagicMock()
client.generate_response = AsyncMock()
return client
@pytest.fixture
def mock_extracted_edge():
return EntityEdge(
source_node_uuid='source_uuid',
target_node_uuid='target_uuid',
name='test_edge',
group_id='group_1',
fact='Test fact',
episodes=['episode_1'],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
@pytest.fixture
def mock_related_edges():
return [
EntityEdge(
source_node_uuid='source_uuid_2',
target_node_uuid='target_uuid_2',
name='related_edge',
group_id='group_1',
fact='Related fact',
episodes=['episode_2'],
created_at=datetime.now(timezone.utc) - timedelta(days=1),
valid_at=datetime.now(timezone.utc) - timedelta(days=1),
invalid_at=None,
)
]
@pytest.fixture
def mock_existing_edges():
return [
EntityEdge(
source_node_uuid='source_uuid_3',
target_node_uuid='target_uuid_3',
name='existing_edge',
group_id='group_1',
fact='Existing fact',
episodes=['episode_3'],
created_at=datetime.now(timezone.utc) - timedelta(days=2),
valid_at=datetime.now(timezone.utc) - timedelta(days=2),
invalid_at=None,
)
]
@pytest.fixture
def mock_current_episode():
return EpisodicNode(
uuid='episode_1',
content='Current episode content',
valid_at=datetime.now(timezone.utc),
name='Current Episode',
group_id='group_1',
source='message',
source_description='Test source description',
)
@pytest.fixture
def mock_previous_episodes():
return [
EpisodicNode(
uuid='episode_2',
content='Previous episode content',
valid_at=datetime.now(timezone.utc) - timedelta(days=1),
name='Previous Episode',
group_id='group_1',
source='message',
source_description='Test source description',
)
]
# Run the tests
if __name__ == '__main__':
pytest.main([__file__])
@pytest.mark.asyncio
async def test_resolve_extracted_edge_exact_fact_short_circuit(
mock_llm_client,
mock_existing_edges,
mock_current_episode,
):
extracted = EntityEdge(
source_node_uuid='source_uuid',
target_node_uuid='target_uuid',
name='test_edge',
group_id='group_1',
fact='Related fact',
episodes=['episode_1'],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
related_edges = [
EntityEdge(
source_node_uuid='source_uuid',
target_node_uuid='target_uuid',
name='related_edge',
group_id='group_1',
fact=' related FACT ',
episodes=['episode_2'],
created_at=datetime.now(timezone.utc) - timedelta(days=1),
valid_at=None,
invalid_at=None,
)
]
resolved_edge, duplicate_edges, invalidated = await resolve_extracted_edge(
mock_llm_client,
extracted,
related_edges,
mock_existing_edges,
mock_current_episode,
edge_type_candidates=None,
)
assert resolved_edge is related_edges[0]
assert resolved_edge.episodes.count(mock_current_episode.uuid) == 1
assert duplicate_edges == []
assert invalidated == []
mock_llm_client.generate_response.assert_not_called()
class OccurredAtEdge(BaseModel):
"""Edge model stub for OCCURRED_AT."""
@pytest.mark.asyncio
async def test_resolve_extracted_edges_resets_unmapped_names(monkeypatch):
from graphiti_core.utils.maintenance import edge_operations as edge_ops
monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
async def immediate_gather(*aws, max_coroutines=None):
return [await aw for aw in aws]
monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather)
monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))
llm_client = MagicMock()
llm_client.generate_response = AsyncMock(
return_value={
'duplicate_facts': [],
'contradicted_facts': [],
'fact_type': 'DEFAULT',
}
)
clients = SimpleNamespace(
driver=MagicMock(),
llm_client=llm_client,
embedder=MagicMock(),
cross_encoder=MagicMock(),
)
source_node = EntityNode(
uuid='source_uuid',
name='Document Node',
group_id='group_1',
labels=['Document'],
)
target_node = EntityNode(
uuid='target_uuid',
name='Topic Node',
group_id='group_1',
labels=['Topic'],
)
extracted_edge = EntityEdge(
source_node_uuid=source_node.uuid,
target_node_uuid=target_node.uuid,
name='OCCURRED_AT',
group_id='group_1',
fact='Document occurred at somewhere',
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
episode = EpisodicNode(
uuid='episode_uuid',
name='Episode',
group_id='group_1',
source='message',
source_description='desc',
content='Episode content',
valid_at=datetime.now(timezone.utc),
)
edge_types = {'OCCURRED_AT': OccurredAtEdge}
edge_type_map = {('Event', 'Entity'): ['OCCURRED_AT']}
resolved_edges, invalidated_edges = await resolve_extracted_edges(
clients,
[extracted_edge],
episode,
[source_node, target_node],
edge_types,
edge_type_map,
)
assert resolved_edges[0].name == DEFAULT_EDGE_NAME
assert invalidated_edges == []
@pytest.mark.asyncio
async def test_resolve_extracted_edges_keeps_unknown_names(monkeypatch):
from graphiti_core.utils.maintenance import edge_operations as edge_ops
monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
async def immediate_gather(*aws, max_coroutines=None):
return [await aw for aw in aws]
monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather)
monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))
llm_client = MagicMock()
llm_client.generate_response = AsyncMock(
return_value={
'duplicate_facts': [],
'contradicted_facts': [],
'fact_type': 'DEFAULT',
}
)
clients = SimpleNamespace(
driver=MagicMock(),
llm_client=llm_client,
embedder=MagicMock(),
cross_encoder=MagicMock(),
)
source_node = EntityNode(
uuid='source_uuid',
name='User Node',
group_id='group_1',
labels=['User'],
)
target_node = EntityNode(
uuid='target_uuid',
name='Topic Node',
group_id='group_1',
labels=['Topic'],
)
extracted_edge = EntityEdge(
source_node_uuid=source_node.uuid,
target_node_uuid=target_node.uuid,
name='INTERACTED_WITH',
group_id='group_1',
fact='User interacted with topic',
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
episode = EpisodicNode(
uuid='episode_uuid',
name='Episode',
group_id='group_1',
source='message',
source_description='desc',
content='Episode content',
valid_at=datetime.now(timezone.utc),
)
edge_types = {'OCCURRED_AT': OccurredAtEdge}
edge_type_map = {('Event', 'Entity'): ['OCCURRED_AT']}
resolved_edges, invalidated_edges = await resolve_extracted_edges(
clients,
[extracted_edge],
episode,
[source_node, target_node],
edge_types,
edge_type_map,
)
assert resolved_edges[0].name == 'INTERACTED_WITH'
assert invalidated_edges == []
@pytest.mark.asyncio
async def test_resolve_extracted_edge_rejects_unmapped_fact_type(mock_llm_client):
mock_llm_client.generate_response.return_value = {
'duplicate_facts': [],
'contradicted_facts': [],
'fact_type': 'OCCURRED_AT',
}
extracted_edge = EntityEdge(
source_node_uuid='source_uuid',
target_node_uuid='target_uuid',
name='OCCURRED_AT',
group_id='group_1',
fact='Document occurred at somewhere',
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
episode = EpisodicNode(
uuid='episode_uuid',
name='Episode',
group_id='group_1',
source='message',
source_description='desc',
content='Episode content',
valid_at=datetime.now(timezone.utc),
)
related_edge = EntityEdge(
source_node_uuid='alt_source',
target_node_uuid='alt_target',
name='OTHER',
group_id='group_1',
fact='Different fact',
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
resolved_edge, duplicates, invalidated = await resolve_extracted_edge(
mock_llm_client,
extracted_edge,
[related_edge],
[],
episode,
edge_type_candidates={},
custom_edge_type_names={'OCCURRED_AT'},
)
assert resolved_edge.name == DEFAULT_EDGE_NAME
assert duplicates == []
assert invalidated == []
@pytest.mark.asyncio
async def test_resolve_extracted_edge_accepts_unknown_fact_type(mock_llm_client):
mock_llm_client.generate_response.return_value = {
'duplicate_facts': [],
'contradicted_facts': [],
'fact_type': 'INTERACTED_WITH',
}
extracted_edge = EntityEdge(
source_node_uuid='source_uuid',
target_node_uuid='target_uuid',
name='DEFAULT',
group_id='group_1',
fact='User interacted with topic',
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
episode = EpisodicNode(
uuid='episode_uuid',
name='Episode',
group_id='group_1',
source='message',
source_description='desc',
content='Episode content',
valid_at=datetime.now(timezone.utc),
)
related_edge = EntityEdge(
source_node_uuid='source_uuid',
target_node_uuid='target_uuid',
name='DEFAULT',
group_id='group_1',
fact='User mentioned a topic',
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
resolved_edge, duplicates, invalidated = await resolve_extracted_edge(
mock_llm_client,
extracted_edge,
[related_edge],
[],
episode,
edge_type_candidates={'OCCURRED_AT': OccurredAtEdge},
custom_edge_type_names={'OCCURRED_AT'},
)
assert resolved_edge.name == 'INTERACTED_WITH'
assert resolved_edge.attributes == {}
assert duplicates == []
assert invalidated == []
@pytest.mark.asyncio
async def test_resolve_extracted_edge_uses_integer_indices_for_duplicates(mock_llm_client):
"""Test that resolve_extracted_edge correctly uses integer indices for LLM duplicate detection."""
# Mock LLM to return duplicate_facts with integer indices
mock_llm_client.generate_response.return_value = {
'duplicate_facts': [0, 1], # LLM identifies first two related edges as duplicates
'contradicted_facts': [],
'fact_type': 'DEFAULT',
}
extracted_edge = EntityEdge(
source_node_uuid='source_uuid',
target_node_uuid='target_uuid',
name='test_edge',
group_id='group_1',
fact='User likes yoga',
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
episode = EpisodicNode(
uuid='episode_uuid',
name='Episode',
group_id='group_1',
source='message',
source_description='desc',
content='Episode content',
valid_at=datetime.now(timezone.utc),
)
# Create multiple related edges - LLM should receive these with integer indices
related_edge_0 = EntityEdge(
source_node_uuid='source_uuid',
target_node_uuid='target_uuid',
name='test_edge',
group_id='group_1',
fact='User enjoys yoga',
episodes=['episode_1'],
created_at=datetime.now(timezone.utc) - timedelta(days=1),
valid_at=None,
invalid_at=None,
)
related_edge_1 = EntityEdge(
source_node_uuid='source_uuid',
target_node_uuid='target_uuid',
name='test_edge',
group_id='group_1',
fact='User practices yoga',
episodes=['episode_2'],
created_at=datetime.now(timezone.utc) - timedelta(days=2),
valid_at=None,
invalid_at=None,
)
related_edge_2 = EntityEdge(
source_node_uuid='source_uuid',
target_node_uuid='target_uuid',
name='test_edge',
group_id='group_1',
fact='User loves swimming',
episodes=['episode_3'],
created_at=datetime.now(timezone.utc) - timedelta(days=3),
valid_at=None,
invalid_at=None,
)
related_edges = [related_edge_0, related_edge_1, related_edge_2]
    resolved_edge, duplicates, invalidated = await resolve_extracted_edge(
mock_llm_client,
extracted_edge,
related_edges,
[],
episode,
edge_type_candidates=None,
custom_edge_type_names=set(),
)
# Verify LLM was called
mock_llm_client.generate_response.assert_called_once()
# Verify the system correctly identified duplicates using integer indices
# The LLM returned [0, 1], so related_edge_0 and related_edge_1 should be marked as duplicates
assert len(duplicates) == 2
assert related_edge_0 in duplicates
assert related_edge_1 in duplicates
assert invalidated == []
# Verify that the resolved edge is one of the duplicates (the first one found)
# Check UUID since the episode list gets modified
assert resolved_edge.uuid == related_edge_0.uuid
assert episode.uuid in resolved_edge.episodes
@pytest.mark.asyncio
async def test_resolve_extracted_edges_fast_path_deduplication(monkeypatch):
"""Test that resolve_extracted_edges deduplicates exact matches before parallel processing."""
from graphiti_core.utils.maintenance import edge_operations as edge_ops
monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
# Track how many times resolve_extracted_edge is called
resolve_call_count = 0
async def mock_resolve_extracted_edge(
llm_client,
extracted_edge,
related_edges,
existing_edges,
episode,
edge_type_candidates=None,
custom_edge_type_names=None,
):
nonlocal resolve_call_count
resolve_call_count += 1
return extracted_edge, [], []
# Mock semaphore_gather to execute awaitable immediately
async def immediate_gather(*aws, max_coroutines=None):
results = []
for aw in aws:
results.append(await aw)
return results
monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather)
monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))
monkeypatch.setattr(edge_ops, 'resolve_extracted_edge', mock_resolve_extracted_edge)
llm_client = MagicMock()
clients = SimpleNamespace(
driver=MagicMock(),
llm_client=llm_client,
embedder=MagicMock(),
cross_encoder=MagicMock(),
)
source_node = EntityNode(
uuid='source_uuid',
name='Assistant',
group_id='group_1',
labels=['Entity'],
)
target_node = EntityNode(
uuid='target_uuid',
name='User',
group_id='group_1',
labels=['Entity'],
)
# Create 3 identical edges
edge1 = EntityEdge(
source_node_uuid=source_node.uuid,
target_node_uuid=target_node.uuid,
name='recommends',
group_id='group_1',
fact='assistant recommends yoga poses',
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
edge2 = EntityEdge(
source_node_uuid=source_node.uuid,
target_node_uuid=target_node.uuid,
name='recommends',
group_id='group_1',
fact=' Assistant Recommends YOGA Poses ', # Different whitespace/case
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
edge3 = EntityEdge(
source_node_uuid=source_node.uuid,
target_node_uuid=target_node.uuid,
name='recommends',
group_id='group_1',
fact='assistant recommends yoga poses',
episodes=[],
created_at=datetime.now(timezone.utc),
valid_at=None,
invalid_at=None,
)
episode = EpisodicNode(
uuid='episode_uuid',
name='Episode',
group_id='group_1',
source='message',
source_description='desc',
content='Episode content',
valid_at=datetime.now(timezone.utc),
)
resolved_edges, invalidated_edges = await resolve_extracted_edges(
clients,
[edge1, edge2, edge3],
episode,
[source_node, target_node],
{},
{},
)
# Fast path should have deduplicated the 3 identical edges to 1
# So resolve_extracted_edge should only be called once
assert resolve_call_count == 1
assert len(resolved_edges) == 1
assert invalidated_edges == []
```
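The fast-path test above expects edges whose facts differ only in case and surrounding whitespace to collapse to a single edge before any LLM call. The exact normalization lives in `graphiti_core.utils.maintenance.dedup_helpers` (`_normalize_string_exact`); the sketch below assumes a simple trim/collapse/lowercase rule, which is enough to reproduce the behavior the test asserts:

```python
import re


def normalize_fact(fact: str) -> str:
    # Assumed normalization: trim, collapse internal whitespace, lowercase.
    return re.sub(r'\s+', ' ', fact.strip()).lower()


def dedupe_exact(facts: list[str]) -> list[str]:
    seen: dict[str, str] = {}
    for fact in facts:
        seen.setdefault(normalize_fact(fact), fact)  # keep the first original form
    return list(seen.values())


facts = [
    'assistant recommends yoga poses',
    '  Assistant Recommends YOGA Poses  ',
    'assistant recommends yoga poses',
]
assert dedupe_exact(facts) == ['assistant recommends yoga poses']
```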
--------------------------------------------------------------------------------
/graphiti_core/utils/bulk_utils.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
import typing
from datetime import datetime
import numpy as np
from pydantic import BaseModel, Field
from typing_extensions import Any
from graphiti_core.driver.driver import (
GraphDriver,
GraphDriverSession,
GraphProvider,
)
from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
from graphiti_core.embedder import EmbedderClient
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import normalize_l2, semaphore_gather
from graphiti_core.models.edges.edge_db_queries import (
get_entity_edge_save_bulk_query,
get_episodic_edge_save_bulk_query,
)
from graphiti_core.models.nodes.node_db_queries import (
get_entity_node_save_bulk_query,
get_episode_node_save_bulk_query,
)
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
from graphiti_core.utils.maintenance.dedup_helpers import (
DedupResolutionState,
_build_candidate_indexes,
_normalize_string_exact,
_resolve_with_similarity,
)
from graphiti_core.utils.maintenance.edge_operations import (
extract_edges,
resolve_extracted_edge,
)
from graphiti_core.utils.maintenance.graph_data_operations import (
EPISODE_WINDOW_LEN,
retrieve_episodes,
)
from graphiti_core.utils.maintenance.node_operations import (
extract_nodes,
resolve_extracted_nodes,
)
logger = logging.getLogger(__name__)
CHUNK_SIZE = 10
def _build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]:
"""Collapse alias -> canonical chains while preserving direction.
The incoming pairs represent directed mappings discovered during node dedupe. We use a simple
union-find with iterative path compression to ensure every source UUID resolves to its ultimate
canonical target, even if aliases appear lexicographically smaller than the canonical UUID.
"""
parent: dict[str, str] = {}
def find(uuid: str) -> str:
"""Directed union-find lookup using iterative path compression."""
parent.setdefault(uuid, uuid)
root = uuid
while parent[root] != root:
root = parent[root]
while parent[uuid] != root:
next_uuid = parent[uuid]
parent[uuid] = root
uuid = next_uuid
return root
for source_uuid, target_uuid in pairs:
parent.setdefault(source_uuid, source_uuid)
parent.setdefault(target_uuid, target_uuid)
parent[find(source_uuid)] = find(target_uuid)
return {uuid: find(uuid) for uuid in parent}
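# Example (illustrative): alias chains collapse onto their final canonical UUID
# while the direction of each mapping is preserved, e.g.
#   _build_directed_uuid_map([('alias-a', 'alias-b'), ('alias-b', 'canonical')])
#   == {'alias-a': 'canonical', 'alias-b': 'canonical', 'canonical': 'canonical'}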
class RawEpisode(BaseModel):
name: str
uuid: str | None = Field(default=None)
content: str
source_description: str
source: EpisodeType
reference_time: datetime
async def retrieve_previous_episodes_bulk(
driver: GraphDriver, episodes: list[EpisodicNode]
) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
previous_episodes_list = await semaphore_gather(
*[
retrieve_episodes(
driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id]
)
for episode in episodes
]
)
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] = [
(episode, previous_episodes_list[i]) for i, episode in enumerate(episodes)
]
return episode_tuples
async def add_nodes_and_edges_bulk(
driver: GraphDriver,
episodic_nodes: list[EpisodicNode],
episodic_edges: list[EpisodicEdge],
entity_nodes: list[EntityNode],
entity_edges: list[EntityEdge],
embedder: EmbedderClient,
):
session = driver.session()
try:
await session.execute_write(
add_nodes_and_edges_bulk_tx,
episodic_nodes,
episodic_edges,
entity_nodes,
entity_edges,
embedder,
driver=driver,
)
finally:
await session.close()
async def add_nodes_and_edges_bulk_tx(
tx: GraphDriverSession,
episodic_nodes: list[EpisodicNode],
episodic_edges: list[EpisodicEdge],
entity_nodes: list[EntityNode],
entity_edges: list[EntityEdge],
embedder: EmbedderClient,
driver: GraphDriver,
):
episodes = [dict(episode) for episode in episodic_nodes]
for episode in episodes:
episode['source'] = str(episode['source'].value)
episode.pop('labels', None)
nodes = []
for node in entity_nodes:
if node.name_embedding is None:
await node.generate_name_embedding(embedder)
entity_data: dict[str, Any] = {
'uuid': node.uuid,
'name': node.name,
'group_id': node.group_id,
'summary': node.summary,
'created_at': node.created_at,
'name_embedding': node.name_embedding,
'labels': list(set(node.labels + ['Entity'])),
}
if driver.provider == GraphProvider.KUZU:
attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
entity_data['attributes'] = json.dumps(attributes)
else:
entity_data.update(node.attributes or {})
nodes.append(entity_data)
edges = []
for edge in entity_edges:
if edge.fact_embedding is None:
await edge.generate_embedding(embedder)
edge_data: dict[str, Any] = {
'uuid': edge.uuid,
'source_node_uuid': edge.source_node_uuid,
'target_node_uuid': edge.target_node_uuid,
'name': edge.name,
'fact': edge.fact,
'group_id': edge.group_id,
'episodes': edge.episodes,
'created_at': edge.created_at,
'expired_at': edge.expired_at,
'valid_at': edge.valid_at,
'invalid_at': edge.invalid_at,
'fact_embedding': edge.fact_embedding,
}
if driver.provider == GraphProvider.KUZU:
attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
edge_data['attributes'] = json.dumps(attributes)
else:
edge_data.update(edge.attributes or {})
edges.append(edge_data)
if driver.graph_operations_interface:
await driver.graph_operations_interface.episodic_node_save_bulk(None, driver, tx, episodes)
await driver.graph_operations_interface.node_save_bulk(None, driver, tx, nodes)
await driver.graph_operations_interface.episodic_edge_save_bulk(
None, driver, tx, [edge.model_dump() for edge in episodic_edges]
)
await driver.graph_operations_interface.edge_save_bulk(None, driver, tx, edges)
elif driver.provider == GraphProvider.KUZU:
# FIXME: Kuzu's UNWIND does not currently support STRUCT[] type properly, so we insert the data one by one instead for now.
episode_query = get_episode_node_save_bulk_query(driver.provider)
for episode in episodes:
await tx.run(episode_query, **episode)
entity_node_query = get_entity_node_save_bulk_query(driver.provider, nodes)
for node in nodes:
await tx.run(entity_node_query, **node)
entity_edge_query = get_entity_edge_save_bulk_query(driver.provider)
for edge in edges:
await tx.run(entity_edge_query, **edge)
episodic_edge_query = get_episodic_edge_save_bulk_query(driver.provider)
for edge in episodic_edges:
await tx.run(episodic_edge_query, **edge.model_dump())
else:
await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
await tx.run(
get_entity_node_save_bulk_query(driver.provider, nodes),
nodes=nodes,
)
await tx.run(
get_episodic_edge_save_bulk_query(driver.provider),
episodic_edges=[edge.model_dump() for edge in episodic_edges],
)
await tx.run(
get_entity_edge_save_bulk_query(driver.provider),
entity_edges=edges,
)
async def extract_nodes_and_edges_bulk(
clients: GraphitiClients,
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
edge_type_map: dict[tuple[str, str], list[str]],
entity_types: dict[str, type[BaseModel]] | None = None,
excluded_entity_types: list[str] | None = None,
edge_types: dict[str, type[BaseModel]] | None = None,
) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]:
extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather(
*[
extract_nodes(clients, episode, previous_episodes, entity_types, excluded_entity_types)
for episode, previous_episodes in episode_tuples
]
)
extracted_edges_bulk: list[list[EntityEdge]] = await semaphore_gather(
*[
extract_edges(
clients,
episode,
extracted_nodes_bulk[i],
previous_episodes,
edge_type_map=edge_type_map,
group_id=episode.group_id,
edge_types=edge_types,
)
for i, (episode, previous_episodes) in enumerate(episode_tuples)
]
)
return extracted_nodes_bulk, extracted_edges_bulk
async def dedupe_nodes_bulk(
clients: GraphitiClients,
extracted_nodes: list[list[EntityNode]],
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
entity_types: dict[str, type[BaseModel]] | None = None,
) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
"""Resolve entity duplicates across an in-memory batch using a two-pass strategy.
1. Run :func:`resolve_extracted_nodes` for every episode in parallel so each batch item is
reconciled against the live graph just like the non-batch flow.
2. Re-run the deterministic similarity heuristics across the union of resolved nodes to catch
duplicates that only co-occur inside this batch, emitting a canonical UUID map that callers
can apply to edges and persistence.
"""
first_pass_results = await semaphore_gather(
*[
resolve_extracted_nodes(
clients,
nodes,
episode_tuples[i][0],
episode_tuples[i][1],
entity_types,
)
for i, nodes in enumerate(extracted_nodes)
]
)
episode_resolutions: list[tuple[str, list[EntityNode]]] = []
per_episode_uuid_maps: list[dict[str, str]] = []
duplicate_pairs: list[tuple[str, str]] = []
for (resolved_nodes, uuid_map, duplicates), (episode, _) in zip(
first_pass_results, episode_tuples, strict=True
):
episode_resolutions.append((episode.uuid, resolved_nodes))
per_episode_uuid_maps.append(uuid_map)
duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)
canonical_nodes: dict[str, EntityNode] = {}
for _, resolved_nodes in episode_resolutions:
for node in resolved_nodes:
# NOTE: this loop is O(n^2) in the number of nodes inside the batch because we rebuild
# the MinHash index for the accumulated canonical pool each time. The LRU-backed
# shingle cache keeps the constant factors low for typical batch sizes (≤ CHUNK_SIZE),
# but if batches grow significantly we should switch to an incremental index or chunked
# processing.
if not canonical_nodes:
canonical_nodes[node.uuid] = node
continue
existing_candidates = list(canonical_nodes.values())
normalized = _normalize_string_exact(node.name)
exact_match = next(
(
candidate
for candidate in existing_candidates
if _normalize_string_exact(candidate.name) == normalized
),
None,
)
if exact_match is not None:
if exact_match.uuid != node.uuid:
duplicate_pairs.append((node.uuid, exact_match.uuid))
continue
indexes = _build_candidate_indexes(existing_candidates)
state = DedupResolutionState(
resolved_nodes=[None],
uuid_map={},
unresolved_indices=[],
)
_resolve_with_similarity([node], indexes, state)
resolved = state.resolved_nodes[0]
if resolved is None:
canonical_nodes[node.uuid] = node
continue
canonical_uuid = resolved.uuid
canonical_nodes.setdefault(canonical_uuid, resolved)
if canonical_uuid != node.uuid:
duplicate_pairs.append((node.uuid, canonical_uuid))
union_pairs: list[tuple[str, str]] = []
for uuid_map in per_episode_uuid_maps:
union_pairs.extend(uuid_map.items())
union_pairs.extend(duplicate_pairs)
compressed_map: dict[str, str] = _build_directed_uuid_map(union_pairs)
nodes_by_episode: dict[str, list[EntityNode]] = {}
for episode_uuid, resolved_nodes in episode_resolutions:
deduped_nodes: list[EntityNode] = []
seen: set[str] = set()
for node in resolved_nodes:
canonical_uuid = compressed_map.get(node.uuid, node.uuid)
if canonical_uuid in seen:
continue
seen.add(canonical_uuid)
canonical_node = canonical_nodes.get(canonical_uuid)
if canonical_node is None:
logger.error(
'Canonical node %s missing during batch dedupe; falling back to %s',
canonical_uuid,
node.uuid,
)
canonical_node = node
deduped_nodes.append(canonical_node)
nodes_by_episode[episode_uuid] = deduped_nodes
return nodes_by_episode, compressed_map
async def dedupe_edges_bulk(
clients: GraphitiClients,
extracted_edges: list[list[EntityEdge]],
episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
_entities: list[EntityNode],
edge_types: dict[str, type[BaseModel]],
_edge_type_map: dict[tuple[str, str], list[str]],
) -> dict[str, list[EntityEdge]]:
embedder = clients.embedder
min_score = 0.6
# generate embeddings
await semaphore_gather(
*[create_entity_edge_embeddings(embedder, edges) for edges in extracted_edges]
)
# Find similar results
dedupe_tuples: list[tuple[EpisodicNode, EntityEdge, list[EntityEdge]]] = []
for i, edges_i in enumerate(extracted_edges):
existing_edges: list[EntityEdge] = []
for edges_j in extracted_edges:
existing_edges += edges_j
for edge in edges_i:
candidates: list[EntityEdge] = []
for existing_edge in existing_edges:
# Skip self-comparison
if edge.uuid == existing_edge.uuid:
continue
# Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
# This approach will cast a wider net than BM25, which is ideal for this use case
if (
edge.source_node_uuid != existing_edge.source_node_uuid
or edge.target_node_uuid != existing_edge.target_node_uuid
):
continue
edge_words = set(edge.fact.lower().split())
existing_edge_words = set(existing_edge.fact.lower().split())
has_overlap = not edge_words.isdisjoint(existing_edge_words)
if has_overlap:
candidates.append(existing_edge)
continue
# Check for semantic similarity even if there is no overlap
similarity = np.dot(
normalize_l2(edge.fact_embedding or []),
normalize_l2(existing_edge.fact_embedding or []),
)
if similarity >= min_score:
candidates.append(existing_edge)
dedupe_tuples.append((episode_tuples[i][0], edge, candidates))
bulk_edge_resolutions: list[
tuple[EntityEdge, EntityEdge, list[EntityEdge]]
] = await semaphore_gather(
*[
resolve_extracted_edge(
clients.llm_client,
edge,
candidates,
candidates,
episode,
edge_types,
set(edge_types),
)
for episode, edge, candidates in dedupe_tuples
]
)
# For now we won't track edge invalidation
duplicate_pairs: list[tuple[str, str]] = []
for i, (_, _, duplicates) in enumerate(bulk_edge_resolutions):
episode, edge, candidates = dedupe_tuples[i]
for duplicate in duplicates:
duplicate_pairs.append((edge.uuid, duplicate.uuid))
# Now we compress the duplicate pairs into a uuid map, so that chains like 3 -> 2 and 2 -> 1 collapse to 3 -> 1 and 2 -> 1 (each uuid maps to the lexicographically smallest uuid in its set)
compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
edge_uuid_map: dict[str, EntityEdge] = {
edge.uuid: edge for edges in extracted_edges for edge in edges
}
edges_by_episode: dict[str, list[EntityEdge]] = {}
for i, edges in enumerate(extracted_edges):
episode = episode_tuples[i][0]
edges_by_episode[episode.uuid] = [
edge_uuid_map[compressed_map.get(edge.uuid, edge.uuid)] for edge in edges
]
return edges_by_episode
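# Illustrative sketch of the candidate gate above (edge facts are hypothetical): for edges that
# share the same source and target node uuids, 'Alice works at Acme' and 'Alice is employed by Acme'
# become candidates on word overlap alone ('alice', 'acme'); facts with no shared tokens still
# qualify when the cosine similarity of their L2-normalized fact embeddings is at least min_score (0.6).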
class UnionFind:
def __init__(self, elements):
# start each element in its own set
self.parent = {e: e for e in elements}
def find(self, x):
# path‐compression
if self.parent[x] != x:
self.parent[x] = self.find(self.parent[x])
return self.parent[x]
def union(self, a, b):
ra, rb = self.find(a), self.find(b)
if ra == rb:
return
# attach the lexicographically larger root under the smaller
if ra < rb:
self.parent[rb] = ra
else:
self.parent[ra] = rb
def compress_uuid_map(duplicate_pairs: list[tuple[str, str]]) -> dict[str, str]:
"""
all_ids: iterable of all entity IDs (strings)
duplicate_pairs: iterable of (id1, id2) pairs
returns: dict mapping each id -> lexicographically smallest id in its duplicate set
"""
all_uuids = set()
for pair in duplicate_pairs:
all_uuids.add(pair[0])
all_uuids.add(pair[1])
uf = UnionFind(all_uuids)
for a, b in duplicate_pairs:
uf.union(a, b)
# ensure full path‐compression before mapping
return {uuid: uf.find(uuid) for uuid in all_uuids}
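# Illustrative example (hypothetical uuids): every id in a duplicate set maps to the
# lexicographically smallest member, so chains compress fully, e.g.
#   compress_uuid_map([('3', '2'), ('2', '1')]) == {'1': '1', '2': '1', '3': '1'}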
E = typing.TypeVar('E', bound=Edge)
def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
for edge in edges:
source_uuid = edge.source_node_uuid
target_uuid = edge.target_node_uuid
edge.source_node_uuid = uuid_map.get(source_uuid, source_uuid)
edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)
return edges
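# Minimal usage sketch (hypothetical pipeline step): remap edge endpoints onto canonical
# node uuids after dedupe; endpoints missing from the map are left untouched.
#   uuid_map = compress_uuid_map(duplicate_pairs)
#   edges = resolve_edge_pointers(edges, uuid_map)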
```