This is page 5 of 9. Use http://codebase.md/getzep/graphiti?page={x} to view the full context.
# Directory Structure
```
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── ISSUE_TEMPLATE
│ │ └── bug_report.md
│ ├── pull_request_template.md
│ ├── secret_scanning.yml
│ └── workflows
│ ├── ai-moderator.yml
│ ├── cla.yml
│ ├── claude-code-review-manual.yml
│ ├── claude-code-review.yml
│ ├── claude.yml
│ ├── codeql.yml
│ ├── lint.yml
│ ├── release-graphiti-core.yml
│ ├── release-mcp-server.yml
│ ├── release-server-container.yml
│ ├── typecheck.yml
│ └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│ ├── azure-openai
│ │ ├── .env.example
│ │ ├── azure_openai_neo4j.py
│ │ └── README.md
│ ├── data
│ │ └── manybirds_products.json
│ ├── ecommerce
│ │ ├── runner.ipynb
│ │ └── runner.py
│ ├── langgraph-agent
│ │ ├── agent.ipynb
│ │ └── tinybirds-jess.png
│ ├── opentelemetry
│ │ ├── .env.example
│ │ ├── otel_stdout_example.py
│ │ ├── pyproject.toml
│ │ ├── README.md
│ │ └── uv.lock
│ ├── podcast
│ │ ├── podcast_runner.py
│ │ ├── podcast_transcript.txt
│ │ └── transcript_parser.py
│ ├── quickstart
│ │ ├── dense_vs_normal_ingestion.py
│ │ ├── quickstart_falkordb.py
│ │ ├── quickstart_neo4j.py
│ │ ├── quickstart_neptune.py
│ │ ├── README.md
│ │ └── requirements.txt
│ └── wizard_of_oz
│ ├── parser.py
│ ├── runner.py
│ └── woo.txt
├── graphiti_core
│ ├── __init__.py
│ ├── cross_encoder
│ │ ├── __init__.py
│ │ ├── bge_reranker_client.py
│ │ ├── client.py
│ │ ├── gemini_reranker_client.py
│ │ └── openai_reranker_client.py
│ ├── decorators.py
│ ├── driver
│ │ ├── __init__.py
│ │ ├── driver.py
│ │ ├── falkordb_driver.py
│ │ ├── graph_operations
│ │ │ └── graph_operations.py
│ │ ├── kuzu_driver.py
│ │ ├── neo4j_driver.py
│ │ ├── neptune_driver.py
│ │ └── search_interface
│ │ └── search_interface.py
│ ├── edges.py
│ ├── embedder
│ │ ├── __init__.py
│ │ ├── azure_openai.py
│ │ ├── client.py
│ │ ├── gemini.py
│ │ ├── openai.py
│ │ └── voyage.py
│ ├── errors.py
│ ├── graph_queries.py
│ ├── graphiti_types.py
│ ├── graphiti.py
│ ├── helpers.py
│ ├── llm_client
│ │ ├── __init__.py
│ │ ├── anthropic_client.py
│ │ ├── azure_openai_client.py
│ │ ├── client.py
│ │ ├── config.py
│ │ ├── errors.py
│ │ ├── gemini_client.py
│ │ ├── groq_client.py
│ │ ├── openai_base_client.py
│ │ ├── openai_client.py
│ │ ├── openai_generic_client.py
│ │ └── utils.py
│ ├── migrations
│ │ └── __init__.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── edges
│ │ │ ├── __init__.py
│ │ │ └── edge_db_queries.py
│ │ └── nodes
│ │ ├── __init__.py
│ │ └── node_db_queries.py
│ ├── nodes.py
│ ├── prompts
│ │ ├── __init__.py
│ │ ├── dedupe_edges.py
│ │ ├── dedupe_nodes.py
│ │ ├── eval.py
│ │ ├── extract_edge_dates.py
│ │ ├── extract_edges.py
│ │ ├── extract_nodes.py
│ │ ├── invalidate_edges.py
│ │ ├── lib.py
│ │ ├── models.py
│ │ ├── prompt_helpers.py
│ │ ├── snippets.py
│ │ └── summarize_nodes.py
│ ├── py.typed
│ ├── search
│ │ ├── __init__.py
│ │ ├── search_config_recipes.py
│ │ ├── search_config.py
│ │ ├── search_filters.py
│ │ ├── search_helpers.py
│ │ ├── search_utils.py
│ │ └── search.py
│ ├── telemetry
│ │ ├── __init__.py
│ │ └── telemetry.py
│ ├── tracer.py
│ └── utils
│ ├── __init__.py
│ ├── bulk_utils.py
│ ├── content_chunking.py
│ ├── datetime_utils.py
│ ├── maintenance
│ │ ├── __init__.py
│ │ ├── community_operations.py
│ │ ├── dedup_helpers.py
│ │ ├── edge_operations.py
│ │ ├── graph_data_operations.py
│ │ ├── node_operations.py
│ │ └── temporal_operations.py
│ ├── ontology_utils
│ │ └── entity_types_utils.py
│ └── text_utils.py
├── images
│ ├── arxiv-screenshot.png
│ ├── graphiti-graph-intro.gif
│ ├── graphiti-intro-slides-stock-2.gif
│ └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│ ├── .env.example
│ ├── .python-version
│ ├── config
│ │ ├── config-docker-falkordb-combined.yaml
│ │ ├── config-docker-falkordb.yaml
│ │ ├── config-docker-neo4j.yaml
│ │ ├── config.yaml
│ │ └── mcp_config_stdio_example.json
│ ├── docker
│ │ ├── build-standalone.sh
│ │ ├── build-with-version.sh
│ │ ├── docker-compose-falkordb.yml
│ │ ├── docker-compose-neo4j.yml
│ │ ├── docker-compose.yml
│ │ ├── Dockerfile
│ │ ├── Dockerfile.standalone
│ │ ├── github-actions-example.yml
│ │ ├── README-falkordb-combined.md
│ │ └── README.md
│ ├── docs
│ │ └── cursor_rules.md
│ ├── main.py
│ ├── pyproject.toml
│ ├── pytest.ini
│ ├── README.md
│ ├── src
│ │ ├── __init__.py
│ │ ├── config
│ │ │ ├── __init__.py
│ │ │ └── schema.py
│ │ ├── graphiti_mcp_server.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ ├── entity_types.py
│ │ │ └── response_types.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── factories.py
│ │ │ └── queue_service.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── formatting.py
│ │ └── utils.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── pytest.ini
│ │ ├── README.md
│ │ ├── run_tests.py
│ │ ├── test_async_operations.py
│ │ ├── test_comprehensive_integration.py
│ │ ├── test_configuration.py
│ │ ├── test_falkordb_integration.py
│ │ ├── test_fixtures.py
│ │ ├── test_http_integration.py
│ │ ├── test_integration.py
│ │ ├── test_mcp_integration.py
│ │ ├── test_mcp_transports.py
│ │ ├── test_stdio_simple.py
│ │ └── test_stress_load.py
│ └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│ ├── .env.example
│ ├── graph_service
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ ├── main.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ └── zep_graphiti.py
│ ├── Makefile
│ ├── pyproject.toml
│ ├── README.md
│ └── uv.lock
├── signatures
│ └── version1
│ └── cla.json
├── tests
│ ├── cross_encoder
│ │ ├── test_bge_reranker_client_int.py
│ │ └── test_gemini_reranker_client.py
│ ├── driver
│ │ ├── __init__.py
│ │ └── test_falkordb_driver.py
│ ├── embedder
│ │ ├── embedder_fixtures.py
│ │ ├── test_gemini.py
│ │ ├── test_openai.py
│ │ └── test_voyage.py
│ ├── evals
│ │ ├── data
│ │ │ └── longmemeval_data
│ │ │ ├── longmemeval_oracle.json
│ │ │ └── README.md
│ │ ├── eval_cli.py
│ │ ├── eval_e2e_graph_building.py
│ │ ├── pytest.ini
│ │ └── utils.py
│ ├── helpers_test.py
│ ├── llm_client
│ │ ├── test_anthropic_client_int.py
│ │ ├── test_anthropic_client.py
│ │ ├── test_azure_openai_client.py
│ │ ├── test_client.py
│ │ ├── test_errors.py
│ │ └── test_gemini_client.py
│ ├── test_edge_int.py
│ ├── test_entity_exclusion_int.py
│ ├── test_graphiti_int.py
│ ├── test_graphiti_mock.py
│ ├── test_node_int.py
│ ├── test_text_utils.py
│ └── utils
│ ├── maintenance
│ │ ├── test_bulk_utils.py
│ │ ├── test_edge_operations.py
│ │ ├── test_entity_extraction.py
│ │ ├── test_node_operations.py
│ │ └── test_temporal_operations_int.py
│ ├── search
│ │ └── search_utils_test.py
│ └── test_content_chunking.py
├── uv.lock
└── Zep-CLA.md
```
# Files
--------------------------------------------------------------------------------
/tests/cross_encoder/test_gemini_reranker_client.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Running tests: pytest -xvs tests/cross_encoder/test_gemini_reranker_client.py
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient
from graphiti_core.llm_client import LLMConfig, RateLimitError
@pytest.fixture
def mock_gemini_client():
"""Fixture to mock the Google Gemini client."""
with patch('google.genai.Client') as mock_client:
# Setup mock instance and its methods
mock_instance = mock_client.return_value
mock_instance.aio = MagicMock()
mock_instance.aio.models = MagicMock()
mock_instance.aio.models.generate_content = AsyncMock()
yield mock_instance
@pytest.fixture
def gemini_reranker_client(mock_gemini_client):
"""Fixture to create a GeminiRerankerClient with a mocked client."""
config = LLMConfig(api_key='test_api_key', model='test-model')
client = GeminiRerankerClient(config=config)
# Swap in the mocked Gemini client so every call made by the tests goes through it
client.client = mock_gemini_client
return client
def create_mock_response(score_text: str) -> MagicMock:
"""Helper function to create a mock Gemini response."""
mock_response = MagicMock()
mock_response.text = score_text
return mock_response
class TestGeminiRerankerClientInitialization:
"""Tests for GeminiRerankerClient initialization."""
def test_init_with_config(self):
"""Test initialization with a config object."""
config = LLMConfig(api_key='test_api_key', model='test-model')
client = GeminiRerankerClient(config=config)
assert client.config == config
@patch('google.genai.Client')
def test_init_without_config(self, mock_client):
"""Test initialization without a config uses defaults."""
client = GeminiRerankerClient()
assert client.config is not None
def test_init_with_custom_client(self):
"""Test initialization with a custom client."""
mock_client = MagicMock()
client = GeminiRerankerClient(client=mock_client)
assert client.client == mock_client
class TestGeminiRerankerClientRanking:
"""Tests for GeminiRerankerClient rank method."""
@pytest.mark.asyncio
async def test_rank_basic_functionality(self, gemini_reranker_client, mock_gemini_client):
"""Test basic ranking functionality."""
# Setup mock responses with different scores
mock_responses = [
create_mock_response('85'), # High relevance
create_mock_response('45'), # Medium relevance
create_mock_response('20'), # Low relevance
]
mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
# Test data
query = 'What is the capital of France?'
passages = [
'Paris is the capital and most populous city of France.',
'London is the capital city of England and the United Kingdom.',
'Berlin is the capital and largest city of Germany.',
]
# Call method
result = await gemini_reranker_client.rank(query, passages)
# Assertions
assert len(result) == 3
assert all(isinstance(item, tuple) for item in result)
assert all(
isinstance(passage, str) and isinstance(score, float) for passage, score in result
)
# Check scores are normalized to [0, 1] and sorted in descending order
scores = [score for _, score in result]
assert all(0.0 <= score <= 1.0 for score in scores)
assert scores == sorted(scores, reverse=True)
# Check that the highest scoring passage is first
assert result[0][1] == 0.85 # 85/100
assert result[1][1] == 0.45 # 45/100
assert result[2][1] == 0.20 # 20/100
@pytest.mark.asyncio
async def test_rank_empty_passages(self, gemini_reranker_client):
"""Test ranking with empty passages list."""
query = 'Test query'
passages = []
result = await gemini_reranker_client.rank(query, passages)
assert result == []
@pytest.mark.asyncio
async def test_rank_single_passage(self, gemini_reranker_client, mock_gemini_client):
"""Test ranking with a single passage."""
# Setup mock response
mock_gemini_client.aio.models.generate_content.return_value = create_mock_response('75')
query = 'Test query'
passages = ['Single test passage']
result = await gemini_reranker_client.rank(query, passages)
assert len(result) == 1
assert result[0][0] == 'Single test passage'
assert result[0][1] == 1.0 # Single passage gets full score
@pytest.mark.asyncio
async def test_rank_score_extraction_with_regex(
self, gemini_reranker_client, mock_gemini_client
):
"""Test score extraction from various response formats."""
# Setup mock responses with different formats
mock_responses = [
create_mock_response('Score: 90'), # Contains text before number
create_mock_response('The relevance is 65 out of 100'), # Contains text around number
create_mock_response('8'), # Just the number
]
mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
query = 'Test query'
passages = ['Passage 1', 'Passage 2', 'Passage 3']
result = await gemini_reranker_client.rank(query, passages)
# Check that scores were extracted correctly and normalized
scores = [score for _, score in result]
assert 0.90 in scores # 90/100
assert 0.65 in scores # 65/100
assert 0.08 in scores # 8/100
@pytest.mark.asyncio
async def test_rank_invalid_score_handling(self, gemini_reranker_client, mock_gemini_client):
"""Test handling of invalid or non-numeric scores."""
# Setup mock responses with invalid scores
mock_responses = [
create_mock_response('Not a number'), # Invalid response
create_mock_response(''), # Empty response
create_mock_response('95'), # Valid response
]
mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
query = 'Test query'
passages = ['Passage 1', 'Passage 2', 'Passage 3']
result = await gemini_reranker_client.rank(query, passages)
# Check that invalid scores are handled gracefully (assigned 0.0)
scores = [score for _, score in result]
assert 0.95 in scores # Valid score
assert scores.count(0.0) == 2 # Two invalid scores assigned 0.0
@pytest.mark.asyncio
async def test_rank_score_clamping(self, gemini_reranker_client, mock_gemini_client):
"""Test that scores are properly clamped to [0, 1] range."""
# Setup mock responses with extreme scores
# Note: regex only matches 1-3 digits, so negative numbers won't match
mock_responses = [
create_mock_response('999'), # Above 100 but within regex range
create_mock_response('invalid'), # Invalid response becomes 0.0
create_mock_response('50'), # Normal score
]
mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
query = 'Test query'
passages = ['Passage 1', 'Passage 2', 'Passage 3']
result = await gemini_reranker_client.rank(query, passages)
# Check that scores are normalized and clamped
scores = [score for _, score in result]
assert all(0.0 <= score <= 1.0 for score in scores)
# 999 should be clamped to 1.0 (999/100 = 9.99, clamped to 1.0)
assert 1.0 in scores
# Invalid response should be 0.0
assert 0.0 in scores
# Normal score should be normalized (50/100 = 0.5)
assert 0.5 in scores
@pytest.mark.asyncio
async def test_rank_rate_limit_error(self, gemini_reranker_client, mock_gemini_client):
"""Test handling of rate limit errors."""
# Setup mock to raise rate limit error
mock_gemini_client.aio.models.generate_content.side_effect = Exception(
'Rate limit exceeded'
)
query = 'Test query'
passages = ['Passage 1', 'Passage 2']
with pytest.raises(RateLimitError):
await gemini_reranker_client.rank(query, passages)
@pytest.mark.asyncio
async def test_rank_quota_error(self, gemini_reranker_client, mock_gemini_client):
"""Test handling of quota errors."""
# Setup mock to raise quota error
mock_gemini_client.aio.models.generate_content.side_effect = Exception('Quota exceeded')
query = 'Test query'
passages = ['Passage 1', 'Passage 2']
with pytest.raises(RateLimitError):
await gemini_reranker_client.rank(query, passages)
@pytest.mark.asyncio
async def test_rank_resource_exhausted_error(self, gemini_reranker_client, mock_gemini_client):
"""Test handling of resource exhausted errors."""
# Setup mock to raise resource exhausted error
mock_gemini_client.aio.models.generate_content.side_effect = Exception('resource_exhausted')
query = 'Test query'
passages = ['Passage 1', 'Passage 2']
with pytest.raises(RateLimitError):
await gemini_reranker_client.rank(query, passages)
@pytest.mark.asyncio
async def test_rank_429_error(self, gemini_reranker_client, mock_gemini_client):
"""Test handling of HTTP 429 errors."""
# Setup mock to raise 429 error
mock_gemini_client.aio.models.generate_content.side_effect = Exception(
'HTTP 429 Too Many Requests'
)
query = 'Test query'
passages = ['Passage 1', 'Passage 2']
with pytest.raises(RateLimitError):
await gemini_reranker_client.rank(query, passages)
@pytest.mark.asyncio
async def test_rank_generic_error(self, gemini_reranker_client, mock_gemini_client):
"""Test handling of generic errors."""
# Setup mock to raise generic error
mock_gemini_client.aio.models.generate_content.side_effect = Exception('Generic error')
query = 'Test query'
passages = ['Passage 1', 'Passage 2']
with pytest.raises(Exception) as exc_info:
await gemini_reranker_client.rank(query, passages)
assert 'Generic error' in str(exc_info.value)
@pytest.mark.asyncio
async def test_rank_concurrent_requests(self, gemini_reranker_client, mock_gemini_client):
"""Test that multiple passages are scored concurrently."""
# Setup mock responses
mock_responses = [
create_mock_response('80'),
create_mock_response('60'),
create_mock_response('40'),
]
mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
query = 'Test query'
passages = ['Passage 1', 'Passage 2', 'Passage 3']
await gemini_reranker_client.rank(query, passages)
# Verify that generate_content was called for each passage
assert mock_gemini_client.aio.models.generate_content.call_count == 3
# Verify that all calls were made with correct parameters
calls = mock_gemini_client.aio.models.generate_content.call_args_list
for call in calls:
args, kwargs = call
assert kwargs['model'] == gemini_reranker_client.config.model
assert kwargs['config'].temperature == 0.0
assert kwargs['config'].max_output_tokens == 3
@pytest.mark.asyncio
async def test_rank_response_parsing_error(self, gemini_reranker_client, mock_gemini_client):
"""Test handling of response parsing errors."""
# Setup mock responses that will trigger ValueError during parsing
mock_responses = [
create_mock_response('not a number at all'), # Will fail regex match
create_mock_response('also invalid text'), # Will fail regex match
]
mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
query = 'Test query'
# Use multiple passages to avoid the single passage special case
passages = ['Passage 1', 'Passage 2']
result = await gemini_reranker_client.rank(query, passages)
# Should handle the error gracefully and assign 0.0 score to both
assert len(result) == 2
assert all(score == 0.0 for _, score in result)
@pytest.mark.asyncio
async def test_rank_empty_response_text(self, gemini_reranker_client, mock_gemini_client):
"""Test handling of empty response text."""
# Setup mock response with empty text
mock_response = MagicMock()
mock_response.text = '' # Empty string instead of None
mock_gemini_client.aio.models.generate_content.return_value = mock_response
query = 'Test query'
# Use multiple passages to avoid the single passage special case
passages = ['Passage 1', 'Passage 2']
result = await gemini_reranker_client.rank(query, passages)
# Should handle empty text gracefully and assign 0.0 score to both
assert len(result) == 2
assert all(score == 0.0 for _, score in result)
if __name__ == '__main__':
pytest.main(['-v', 'test_gemini_reranker_client.py'])
```
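Read together, these tests pin down the scoring convention the reranker is expected to follow: the model's reply is scanned for a 1-3 digit integer, divided by 100, clamped to [0, 1], and an unparseable reply falls back to 0.0 (with a lone passage short-circuiting to 1.0). Below is a minimal sketch of that parsing convention for orientation only; `parse_relevance_score` is a hypothetical helper, not the actual `GeminiRerankerClient` implementation in `graphiti_core/cross_encoder/gemini_reranker_client.py`.

```python
import re


def parse_relevance_score(response_text: str) -> float:
    """Sketch of the score-parsing behavior the tests above assume."""
    match = re.search(r'\d{1,3}', response_text or '')
    if match is None:
        return 0.0  # unparseable replies are treated as irrelevant
    # Normalize the 0-100 output to [0, 1] and clamp out-of-range values.
    return max(0.0, min(1.0, int(match.group()) / 100))


assert parse_relevance_score('Score: 90') == 0.9
assert parse_relevance_score('not a number') == 0.0
assert parse_relevance_score('999') == 1.0
```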
--------------------------------------------------------------------------------
/examples/quickstart/dense_vs_normal_ingestion.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2025, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Dense vs Normal Episode Ingestion Example
-----------------------------------------
This example demonstrates how Graphiti handles different types of content:
1. Normal Content (prose, narrative, conversations):
- Lower entity density (few entities per token)
- Processed in a single LLM call
- Examples: meeting transcripts, news articles, documentation
2. Dense Content (structured data with many entities):
- High entity density (many entities per token)
- Automatically chunked for reliable extraction
- Examples: bulk data imports, cost reports, entity-dense JSON
The chunking behavior is controlled by environment variables:
- CHUNK_MIN_TOKENS: Minimum tokens before considering chunking (default: 1000)
- CHUNK_DENSITY_THRESHOLD: Entity density threshold (default: 0.15)
- CHUNK_TOKEN_SIZE: Target size per chunk (default: 3000)
- CHUNK_OVERLAP_TOKENS: Overlap between chunks (default: 200)
"""
import asyncio
import json
import logging
import os
from datetime import datetime, timezone
from logging import INFO
from dotenv import load_dotenv
from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType
#################################################
# CONFIGURATION
#################################################
logging.basicConfig(
level=INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)
load_dotenv()
neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
neo4j_user = os.environ.get('NEO4J_USER', 'neo4j')
neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password')
if not neo4j_uri or not neo4j_user or not neo4j_password:
raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set')
#################################################
# EXAMPLE DATA
#################################################
# Normal content: A meeting transcript (low entity density)
# This is prose/narrative content with few entities per token.
# It will NOT trigger chunking - processed in a single LLM call.
NORMAL_EPISODE_CONTENT = """
Meeting Notes - Q4 Planning Session
Alice opened the meeting by reviewing our progress on the mobile app redesign.
She mentioned that the user research phase went well and highlighted key findings
from the customer interviews conducted last month.
Bob then presented the engineering timeline. He explained that the backend API
refactoring is about 60% complete and should be finished by end of November.
The team has resolved most of the performance issues identified in the load tests.
Carol raised concerns about the holiday freeze period affecting our deployment
schedule. She suggested we move the beta launch to early December to give the
QA team enough time for regression testing before the code freeze.
David agreed with Carol's assessment and proposed allocating two additional
engineers from the platform team to help with the testing effort. He also
mentioned that the documentation needs to be updated before the release.
Action items:
- Alice will finalize the design specs by Friday
- Bob will coordinate with the platform team on resource allocation
- Carol will update the project timeline in Jira
- David will schedule a follow-up meeting for next Tuesday
The meeting concluded at 3:30 PM with agreement to reconvene next week.
"""
# Dense content: AWS cost data (high entity density)
# This is structured data with many entities per token.
# It WILL trigger chunking - processed in multiple LLM calls.
DENSE_EPISODE_CONTENT = {
'report_type': 'AWS Cost Breakdown',
'months': [
{
'period': '2025-01',
'services': [
{'name': 'Amazon S3', 'cost': 2487.97},
{'name': 'Amazon RDS', 'cost': 1071.74},
{'name': 'Amazon ECS', 'cost': 853.74},
{'name': 'Amazon OpenSearch', 'cost': 389.74},
{'name': 'AWS Secrets Manager', 'cost': 265.77},
{'name': 'CloudWatch', 'cost': 232.34},
{'name': 'Amazon VPC', 'cost': 238.39},
{'name': 'EC2 Other', 'cost': 226.82},
{'name': 'Amazon EC2 Compute', 'cost': 78.27},
{'name': 'Amazon DocumentDB', 'cost': 65.40},
{'name': 'Amazon ECR', 'cost': 29.00},
{'name': 'Amazon ELB', 'cost': 37.53},
],
},
{
'period': '2025-02',
'services': [
{'name': 'Amazon S3', 'cost': 2721.04},
{'name': 'Amazon RDS', 'cost': 1035.77},
{'name': 'Amazon ECS', 'cost': 779.49},
{'name': 'Amazon OpenSearch', 'cost': 357.90},
{'name': 'AWS Secrets Manager', 'cost': 268.57},
{'name': 'CloudWatch', 'cost': 224.57},
{'name': 'Amazon VPC', 'cost': 215.15},
{'name': 'EC2 Other', 'cost': 213.86},
{'name': 'Amazon EC2 Compute', 'cost': 70.70},
{'name': 'Amazon DocumentDB', 'cost': 59.07},
{'name': 'Amazon ECR', 'cost': 33.92},
{'name': 'Amazon ELB', 'cost': 33.89},
],
},
{
'period': '2025-03',
'services': [
{'name': 'Amazon S3', 'cost': 2952.31},
{'name': 'Amazon RDS', 'cost': 1198.79},
{'name': 'Amazon ECS', 'cost': 869.78},
{'name': 'Amazon OpenSearch', 'cost': 389.75},
{'name': 'AWS Secrets Manager', 'cost': 271.33},
{'name': 'CloudWatch', 'cost': 233.00},
{'name': 'Amazon VPC', 'cost': 238.31},
{'name': 'EC2 Other', 'cost': 227.78},
{'name': 'Amazon EC2 Compute', 'cost': 78.21},
{'name': 'Amazon DocumentDB', 'cost': 65.40},
{'name': 'Amazon ECR', 'cost': 33.75},
{'name': 'Amazon ELB', 'cost': 37.54},
],
},
{
'period': '2025-04',
'services': [
{'name': 'Amazon S3', 'cost': 3189.62},
{'name': 'Amazon RDS', 'cost': 1102.30},
{'name': 'Amazon ECS', 'cost': 848.19},
{'name': 'Amazon OpenSearch', 'cost': 379.14},
{'name': 'AWS Secrets Manager', 'cost': 270.89},
{'name': 'CloudWatch', 'cost': 230.64},
{'name': 'Amazon VPC', 'cost': 230.54},
{'name': 'EC2 Other', 'cost': 220.18},
{'name': 'Amazon EC2 Compute', 'cost': 75.70},
{'name': 'Amazon DocumentDB', 'cost': 63.29},
{'name': 'Amazon ECR', 'cost': 35.21},
{'name': 'Amazon ELB', 'cost': 36.30},
],
},
{
'period': '2025-05',
'services': [
{'name': 'Amazon S3', 'cost': 3423.07},
{'name': 'Amazon RDS', 'cost': 1014.50},
{'name': 'Amazon ECS', 'cost': 874.75},
{'name': 'Amazon OpenSearch', 'cost': 389.71},
{'name': 'AWS Secrets Manager', 'cost': 274.91},
{'name': 'CloudWatch', 'cost': 233.28},
{'name': 'Amazon VPC', 'cost': 238.53},
{'name': 'EC2 Other', 'cost': 227.27},
{'name': 'Amazon EC2 Compute', 'cost': 78.27},
{'name': 'Amazon DocumentDB', 'cost': 65.40},
{'name': 'Amazon ECR', 'cost': 37.42},
{'name': 'Amazon ELB', 'cost': 37.52},
],
},
{
'period': '2025-06',
'services': [
{'name': 'Amazon S3', 'cost': 3658.14},
{'name': 'Amazon RDS', 'cost': 963.60},
{'name': 'Amazon ECS', 'cost': 942.45},
{'name': 'Amazon OpenSearch', 'cost': 379.06},
{'name': 'AWS Secrets Manager', 'cost': 282.41},
{'name': 'CloudWatch', 'cost': 230.69},
{'name': 'Amazon VPC', 'cost': 230.64},
{'name': 'EC2 Other', 'cost': 220.11},
{'name': 'Amazon EC2 Compute', 'cost': 75.74},
{'name': 'Amazon DocumentDB', 'cost': 63.30},
{'name': 'Amazon ECR', 'cost': 39.49},
{'name': 'Amazon ELB', 'cost': 36.30},
],
},
],
}
# Message content: A conversation (moderate entity density)
# This demonstrates message-type episodes with speaker patterns.
MESSAGE_EPISODE_CONTENT = """User: I need to set up a new development environment for the analytics project.
Assistant: I can help with that. Which cloud provider are you using - AWS, GCP, or Azure?
User: We're using AWS. The project requires PostgreSQL, Redis, and Elasticsearch.
Assistant: Great choice. For PostgreSQL, I recommend Amazon RDS with the db.r5.large instance type. For Redis, you can use Amazon ElastiCache. For Elasticsearch, Amazon OpenSearch Service would be ideal.
User: What about the compute layer? We need to run Python data pipelines.
Assistant: For Python data pipelines, I suggest Amazon ECS with Fargate for serverless containers, or AWS Lambda for smaller, event-driven tasks. If you need more control, EC2 instances with Auto Scaling would work well.
User: Let's go with ECS Fargate. Can you also recommend a CI/CD setup?
Assistant: For CI/CD with ECS Fargate, I recommend AWS CodePipeline integrated with CodeBuild for building container images, and CodeDeploy for blue-green deployments. You can store your container images in Amazon ECR.
"""
async def main():
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
try:
#################################################
# EXAMPLE 1: Normal Content (No Chunking)
#################################################
# This prose content has low entity density.
# Graphiti will process it in a single LLM call.
#################################################
print('=' * 60)
print('EXAMPLE 1: Normal Content (Meeting Transcript)')
print('=' * 60)
print(f'Content length: {len(NORMAL_EPISODE_CONTENT)} characters')
print(f'Estimated tokens: ~{len(NORMAL_EPISODE_CONTENT) // 4}')
print('Expected behavior: Single LLM call (no chunking)')
print()
await graphiti.add_episode(
name='Q4 Planning Meeting',
episode_body=NORMAL_EPISODE_CONTENT,
source=EpisodeType.text,
source_description='Meeting transcript',
reference_time=datetime.now(timezone.utc),
)
print('Successfully added normal episode\n')
#################################################
# EXAMPLE 2: Dense Content (Chunking Triggered)
#################################################
# This structured data has high entity density.
# Graphiti will automatically chunk it for
# reliable extraction across multiple LLM calls.
#################################################
print('=' * 60)
print('EXAMPLE 2: Dense Content (AWS Cost Report)')
print('=' * 60)
dense_json = json.dumps(DENSE_EPISODE_CONTENT)
print(f'Content length: {len(dense_json)} characters')
print(f'Estimated tokens: ~{len(dense_json) // 4}')
print('Expected behavior: Multiple LLM calls (chunking enabled)')
print()
await graphiti.add_episode(
name='AWS Cost Report 2025 H1',
episode_body=dense_json,
source=EpisodeType.json,
source_description='AWS cost breakdown by service',
reference_time=datetime.now(timezone.utc),
)
print('Successfully added dense episode\n')
#################################################
# EXAMPLE 3: Message Content
#################################################
# Conversation content with speaker patterns.
# Chunking preserves message boundaries.
#################################################
print('=' * 60)
print('EXAMPLE 3: Message Content (Conversation)')
print('=' * 60)
print(f'Content length: {len(MESSAGE_EPISODE_CONTENT)} characters')
print(f'Estimated tokens: ~{len(MESSAGE_EPISODE_CONTENT) // 4}')
print('Expected behavior: Depends on density threshold')
print()
await graphiti.add_episode(
name='Dev Environment Setup Chat',
episode_body=MESSAGE_EPISODE_CONTENT,
source=EpisodeType.message,
source_description='Support conversation',
reference_time=datetime.now(timezone.utc),
)
print('Successfully added message episode\n')
#################################################
# SEARCH RESULTS
#################################################
print('=' * 60)
print('SEARCH: Verifying extracted entities')
print('=' * 60)
# Search for entities from normal content
print("\nSearching for: 'Q4 planning meeting participants'")
results = await graphiti.search('Q4 planning meeting participants')
print(f'Found {len(results)} results')
for r in results[:3]:
print(f' - {r.fact}')
# Search for entities from dense content
print("\nSearching for: 'AWS S3 costs'")
results = await graphiti.search('AWS S3 costs')
print(f'Found {len(results)} results')
for r in results[:3]:
print(f' - {r.fact}')
# Search for entities from message content
print("\nSearching for: 'ECS Fargate recommendations'")
results = await graphiti.search('ECS Fargate recommendations')
print(f'Found {len(results)} results')
for r in results[:3]:
print(f' - {r.fact}')
finally:
await graphiti.close()
print('\nConnection closed')
if __name__ == '__main__':
asyncio.run(main())
```
--------------------------------------------------------------------------------
/tests/test_edge_int.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
import sys
from datetime import datetime
import numpy as np
import pytest
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
from tests.helpers_test import get_edge_count, get_node_count, group_id
pytest_plugins = ('pytest_asyncio',)
def setup_logging():
# Create a logger
logger = logging.getLogger()
logger.setLevel(logging.INFO) # Set the logging level to INFO
# Create console handler and set level to INFO
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO)
# Create formatter
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Add formatter to console handler
console_handler.setFormatter(formatter)
# Add console handler to logger
logger.addHandler(console_handler)
return logger
@pytest.mark.asyncio
async def test_episodic_edge(graph_driver, mock_embedder):
now = datetime.now()
# Create episodic node
episode_node = EpisodicNode(
name='test_episode',
labels=[],
created_at=now,
valid_at=now,
source=EpisodeType.message,
source_description='conversation message',
content='Alice likes Bob',
entity_edges=[],
group_id=group_id,
)
node_count = await get_node_count(graph_driver, [episode_node.uuid])
assert node_count == 0
await episode_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [episode_node.uuid])
assert node_count == 1
# Create entity node
alice_node = EntityNode(
name='Alice',
labels=[],
created_at=now,
summary='Alice summary',
group_id=group_id,
)
await alice_node.generate_name_embedding(mock_embedder)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
await alice_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 1
# Create episodic to entity edge
episodic_edge = EpisodicEdge(
source_node_uuid=episode_node.uuid,
target_node_uuid=alice_node.uuid,
created_at=now,
group_id=group_id,
)
edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
assert edge_count == 0
await episodic_edge.save(graph_driver)
edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
assert edge_count == 1
# Get edge by uuid
retrieved = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge.uuid)
assert retrieved.uuid == episodic_edge.uuid
assert retrieved.source_node_uuid == episode_node.uuid
assert retrieved.target_node_uuid == alice_node.uuid
assert retrieved.created_at == now
assert retrieved.group_id == group_id
# Get edge by uuids
retrieved = await EpisodicEdge.get_by_uuids(graph_driver, [episodic_edge.uuid])
assert len(retrieved) == 1
assert retrieved[0].uuid == episodic_edge.uuid
assert retrieved[0].source_node_uuid == episode_node.uuid
assert retrieved[0].target_node_uuid == alice_node.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Get edge by group ids
retrieved = await EpisodicEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
assert len(retrieved) == 1
assert retrieved[0].uuid == episodic_edge.uuid
assert retrieved[0].source_node_uuid == episode_node.uuid
assert retrieved[0].target_node_uuid == alice_node.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Get episodic node by entity node uuid
retrieved = await EpisodicNode.get_by_entity_node_uuid(graph_driver, alice_node.uuid)
assert len(retrieved) == 1
assert retrieved[0].uuid == episode_node.uuid
assert retrieved[0].name == 'test_episode'
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Delete edge by uuid
await episodic_edge.delete(graph_driver)
edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
assert edge_count == 0
# Delete edge by uuids
await episodic_edge.save(graph_driver)
await episodic_edge.delete_by_uuids(graph_driver, [episodic_edge.uuid])
edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
assert edge_count == 0
# Cleanup nodes
await episode_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, [episode_node.uuid])
assert node_count == 0
await alice_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
await graph_driver.close()
@pytest.mark.asyncio
async def test_entity_edge(graph_driver, mock_embedder):
now = datetime.now()
# Create entity node
alice_node = EntityNode(
name='Alice',
labels=[],
created_at=now,
summary='Alice summary',
group_id=group_id,
)
await alice_node.generate_name_embedding(mock_embedder)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
await alice_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 1
# Create entity node
bob_node = EntityNode(
name='Bob', labels=[], created_at=now, summary='Bob summary', group_id=group_id
)
await bob_node.generate_name_embedding(mock_embedder)
node_count = await get_node_count(graph_driver, [bob_node.uuid])
assert node_count == 0
await bob_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [bob_node.uuid])
assert node_count == 1
# Create entity to entity edge
entity_edge = EntityEdge(
source_node_uuid=alice_node.uuid,
target_node_uuid=bob_node.uuid,
created_at=now,
name='likes',
fact='Alice likes Bob',
episodes=[],
expired_at=now,
valid_at=now,
invalid_at=now,
group_id=group_id,
)
edge_embedding = await entity_edge.generate_embedding(mock_embedder)
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0
await entity_edge.save(graph_driver)
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 1
# Get edge by uuid
retrieved = await EntityEdge.get_by_uuid(graph_driver, entity_edge.uuid)
assert retrieved.uuid == entity_edge.uuid
assert retrieved.source_node_uuid == alice_node.uuid
assert retrieved.target_node_uuid == bob_node.uuid
assert retrieved.created_at == now
assert retrieved.group_id == group_id
# Get edge by uuids
retrieved = await EntityEdge.get_by_uuids(graph_driver, [entity_edge.uuid])
assert len(retrieved) == 1
assert retrieved[0].uuid == entity_edge.uuid
assert retrieved[0].source_node_uuid == alice_node.uuid
assert retrieved[0].target_node_uuid == bob_node.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Get edge by group ids
retrieved = await EntityEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
assert len(retrieved) == 1
assert retrieved[0].uuid == entity_edge.uuid
assert retrieved[0].source_node_uuid == alice_node.uuid
assert retrieved[0].target_node_uuid == bob_node.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Get edge by node uuid
retrieved = await EntityEdge.get_by_node_uuid(graph_driver, alice_node.uuid)
assert len(retrieved) == 1
assert retrieved[0].uuid == entity_edge.uuid
assert retrieved[0].source_node_uuid == alice_node.uuid
assert retrieved[0].target_node_uuid == bob_node.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Get fact embedding
await entity_edge.load_fact_embedding(graph_driver)
assert np.allclose(entity_edge.fact_embedding, edge_embedding)
# Delete edge by uuid
await entity_edge.delete(graph_driver)
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0
# Delete edge by uuids
await entity_edge.save(graph_driver)
await entity_edge.delete_by_uuids(graph_driver, [entity_edge.uuid])
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0
# Deleting node should delete the edge
await entity_edge.save(graph_driver)
await alice_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0
# Deleting node by uuids should delete the edge
await alice_node.save(graph_driver)
await entity_edge.save(graph_driver)
await alice_node.delete_by_uuids(graph_driver, [alice_node.uuid])
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0
# Deleting node by group id should delete the edge
await alice_node.save(graph_driver)
await entity_edge.save(graph_driver)
await alice_node.delete_by_group_id(graph_driver, alice_node.group_id)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
assert edge_count == 0
# Cleanup nodes
await alice_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
await bob_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, [bob_node.uuid])
assert node_count == 0
await graph_driver.close()
@pytest.mark.asyncio
async def test_community_edge(graph_driver, mock_embedder):
now = datetime.now()
# Create community node
community_node_1 = CommunityNode(
name='test_community_1',
group_id=group_id,
summary='Community A summary',
)
await community_node_1.generate_name_embedding(mock_embedder)
node_count = await get_node_count(graph_driver, [community_node_1.uuid])
assert node_count == 0
await community_node_1.save(graph_driver)
node_count = await get_node_count(graph_driver, [community_node_1.uuid])
assert node_count == 1
# Create community node
community_node_2 = CommunityNode(
name='test_community_2',
group_id=group_id,
summary='Community B summary',
)
await community_node_2.generate_name_embedding(mock_embedder)
node_count = await get_node_count(graph_driver, [community_node_2.uuid])
assert node_count == 0
await community_node_2.save(graph_driver)
node_count = await get_node_count(graph_driver, [community_node_2.uuid])
assert node_count == 1
# Create entity node
alice_node = EntityNode(
name='Alice', labels=[], created_at=now, summary='Alice summary', group_id=group_id
)
await alice_node.generate_name_embedding(mock_embedder)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
await alice_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 1
# Create community to community edge
community_edge = CommunityEdge(
source_node_uuid=community_node_1.uuid,
target_node_uuid=community_node_2.uuid,
created_at=now,
group_id=group_id,
)
edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
assert edge_count == 0
await community_edge.save(graph_driver)
edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
assert edge_count == 1
# Get edge by uuid
retrieved = await CommunityEdge.get_by_uuid(graph_driver, community_edge.uuid)
assert retrieved.uuid == community_edge.uuid
assert retrieved.source_node_uuid == community_node_1.uuid
assert retrieved.target_node_uuid == community_node_2.uuid
assert retrieved.created_at == now
assert retrieved.group_id == group_id
# Get edge by uuids
retrieved = await CommunityEdge.get_by_uuids(graph_driver, [community_edge.uuid])
assert len(retrieved) == 1
assert retrieved[0].uuid == community_edge.uuid
assert retrieved[0].source_node_uuid == community_node_1.uuid
assert retrieved[0].target_node_uuid == community_node_2.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Get edge by group ids
retrieved = await CommunityEdge.get_by_group_ids(graph_driver, [group_id], limit=1)
assert len(retrieved) == 1
assert retrieved[0].uuid == community_edge.uuid
assert retrieved[0].source_node_uuid == community_node_1.uuid
assert retrieved[0].target_node_uuid == community_node_2.uuid
assert retrieved[0].created_at == now
assert retrieved[0].group_id == group_id
# Delete edge by uuid
await community_edge.delete(graph_driver)
edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
assert edge_count == 0
# Delete edge by uuids
await community_edge.save(graph_driver)
await community_edge.delete_by_uuids(graph_driver, [community_edge.uuid])
edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
assert edge_count == 0
# Cleanup nodes
await alice_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, [alice_node.uuid])
assert node_count == 0
await community_node_1.delete(graph_driver)
node_count = await get_node_count(graph_driver, [community_node_1.uuid])
assert node_count == 0
await community_node_2.delete(graph_driver)
node_count = await get_node_count(graph_driver, [community_node_2.uuid])
assert node_count == 0
await graph_driver.close()
```
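For orientation, the lifecycle that `test_entity_edge` walks through can be condensed into a short sketch. It reuses the constructor arguments shown above and assumes a connected `graph_driver` and an embedder like the test fixtures provide; the temporal fields (`valid_at`, `invalid_at`, `expired_at`) are omitted on the assumption that they default to unset.

```python
from datetime import datetime, timezone

from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EntityNode


async def edge_round_trip(graph_driver, embedder, group_id: str) -> None:
    """Condensed save/retrieve/delete flow mirroring test_entity_edge."""
    now = datetime.now(timezone.utc)
    alice = EntityNode(name='Alice', labels=[], created_at=now, summary='', group_id=group_id)
    bob = EntityNode(name='Bob', labels=[], created_at=now, summary='', group_id=group_id)
    for node in (alice, bob):
        await node.generate_name_embedding(embedder)
        await node.save(graph_driver)

    edge = EntityEdge(
        source_node_uuid=alice.uuid,
        target_node_uuid=bob.uuid,
        created_at=now,
        name='likes',
        fact='Alice likes Bob',
        episodes=[],
        group_id=group_id,
    )
    await edge.generate_embedding(embedder)
    await edge.save(graph_driver)

    same_edge = await EntityEdge.get_by_uuid(graph_driver, edge.uuid)
    assert same_edge.fact == 'Alice likes Bob'

    # Deleting an endpoint node also removes the edge, as the tests assert.
    await alice.delete(graph_driver)
    await bob.delete(graph_driver)
```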
--------------------------------------------------------------------------------
/tests/embedder/test_gemini.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# Running tests: pytest -xvs tests/embedder/test_gemini.py
from collections.abc import Generator
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from embedder_fixtures import create_embedding_values
from graphiti_core.embedder.gemini import (
DEFAULT_EMBEDDING_MODEL,
GeminiEmbedder,
GeminiEmbedderConfig,
)
def create_gemini_embedding(multiplier: float = 0.1, dimension: int = 1536) -> MagicMock:
"""Create a mock Gemini embedding with specified value multiplier and dimension."""
mock_embedding = MagicMock()
mock_embedding.values = create_embedding_values(multiplier, dimension)
return mock_embedding
@pytest.fixture
def mock_gemini_response() -> MagicMock:
"""Create a mock Gemini embeddings response."""
mock_result = MagicMock()
mock_result.embeddings = [create_gemini_embedding()]
return mock_result
@pytest.fixture
def mock_gemini_batch_response() -> MagicMock:
"""Create a mock Gemini batch embeddings response."""
mock_result = MagicMock()
mock_result.embeddings = [
create_gemini_embedding(0.1),
create_gemini_embedding(0.2),
create_gemini_embedding(0.3),
]
return mock_result
@pytest.fixture
def mock_gemini_client() -> Generator[Any, Any, None]:
"""Create a mocked Gemini client."""
with patch('google.genai.Client') as mock_client:
mock_instance = mock_client.return_value
mock_instance.aio = MagicMock()
mock_instance.aio.models = MagicMock()
mock_instance.aio.models.embed_content = AsyncMock()
yield mock_instance
@pytest.fixture
def gemini_embedder(mock_gemini_client: Any) -> GeminiEmbedder:
"""Create a GeminiEmbedder with a mocked client."""
config = GeminiEmbedderConfig(api_key='test_api_key')
client = GeminiEmbedder(config=config)
client.client = mock_gemini_client
return client
class TestGeminiEmbedderInitialization:
"""Tests for GeminiEmbedder initialization."""
@patch('google.genai.Client')
def test_init_with_config(self, mock_client):
"""Test initialization with a config object."""
config = GeminiEmbedderConfig(
api_key='test_api_key', embedding_model='custom-model', embedding_dim=768
)
embedder = GeminiEmbedder(config=config)
assert embedder.config == config
assert embedder.config.embedding_model == 'custom-model'
assert embedder.config.api_key == 'test_api_key'
assert embedder.config.embedding_dim == 768
@patch('google.genai.Client')
def test_init_without_config(self, mock_client):
"""Test initialization without a config uses defaults."""
embedder = GeminiEmbedder()
assert embedder.config is not None
assert embedder.config.embedding_model == DEFAULT_EMBEDDING_MODEL
@patch('google.genai.Client')
def test_init_with_partial_config(self, mock_client):
"""Test initialization with partial config."""
config = GeminiEmbedderConfig(api_key='test_api_key')
embedder = GeminiEmbedder(config=config)
assert embedder.config.api_key == 'test_api_key'
assert embedder.config.embedding_model == DEFAULT_EMBEDDING_MODEL
class TestGeminiEmbedderCreate:
"""Tests for GeminiEmbedder create method."""
@pytest.mark.asyncio
async def test_create_calls_api_correctly(
self,
gemini_embedder: GeminiEmbedder,
mock_gemini_client: Any,
mock_gemini_response: MagicMock,
) -> None:
"""Test that create method correctly calls the API and processes the response."""
# Setup
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
# Call method
result = await gemini_embedder.create('Test input')
# Verify API is called with correct parameters
mock_gemini_client.aio.models.embed_content.assert_called_once()
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL
assert kwargs['contents'] == ['Test input']
# Verify result is processed correctly
assert result == mock_gemini_response.embeddings[0].values
@pytest.mark.asyncio
@patch('google.genai.Client')
async def test_create_with_custom_model(
self, mock_client_class, mock_gemini_client: Any, mock_gemini_response: MagicMock
) -> None:
"""Test create method with custom embedding model."""
# Setup embedder with custom model
config = GeminiEmbedderConfig(api_key='test_api_key', embedding_model='custom-model')
embedder = GeminiEmbedder(config=config)
embedder.client = mock_gemini_client
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
# Call method
await embedder.create('Test input')
# Verify custom model is used
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
assert kwargs['model'] == 'custom-model'
@pytest.mark.asyncio
@patch('google.genai.Client')
async def test_create_with_custom_dimension(
self, mock_client_class, mock_gemini_client: Any
) -> None:
"""Test create method with custom embedding dimension."""
# Setup embedder with custom dimension
config = GeminiEmbedderConfig(api_key='test_api_key', embedding_dim=768)
embedder = GeminiEmbedder(config=config)
embedder.client = mock_gemini_client
# Setup mock response with custom dimension
mock_response = MagicMock()
mock_response.embeddings = [create_gemini_embedding(0.1, 768)]
mock_gemini_client.aio.models.embed_content.return_value = mock_response
# Call method
result = await embedder.create('Test input')
# Verify custom dimension is used in config
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
assert kwargs['config'].output_dimensionality == 768
# Verify result has correct dimension
assert len(result) == 768
@pytest.mark.asyncio
async def test_create_with_different_input_types(
self,
gemini_embedder: GeminiEmbedder,
mock_gemini_client: Any,
mock_gemini_response: MagicMock,
) -> None:
"""Test create method with different input types."""
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
# Test with string
await gemini_embedder.create('Test string')
# Test with list of strings
await gemini_embedder.create(['Test', 'List'])
# Test with iterable of integers
await gemini_embedder.create([1, 2, 3])
# Verify all calls were made
assert mock_gemini_client.aio.models.embed_content.call_count == 3
@pytest.mark.asyncio
async def test_create_no_embeddings_error(
self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
) -> None:
"""Test create method handling of no embeddings response."""
# Setup mock response with no embeddings
mock_response = MagicMock()
mock_response.embeddings = []
mock_gemini_client.aio.models.embed_content.return_value = mock_response
# Call method and expect exception
with pytest.raises(ValueError) as exc_info:
await gemini_embedder.create('Test input')
assert 'No embeddings returned from Gemini API in create()' in str(exc_info.value)
@pytest.mark.asyncio
async def test_create_no_values_error(
self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
) -> None:
"""Test create method handling of embeddings with no values."""
# Setup mock response with embedding but no values
mock_embedding = MagicMock()
mock_embedding.values = None
mock_response = MagicMock()
mock_response.embeddings = [mock_embedding]
mock_gemini_client.aio.models.embed_content.return_value = mock_response
# Call method and expect exception
with pytest.raises(ValueError) as exc_info:
await gemini_embedder.create('Test input')
assert 'No embeddings returned from Gemini API in create()' in str(exc_info.value)
class TestGeminiEmbedderCreateBatch:
"""Tests for GeminiEmbedder create_batch method."""
@pytest.mark.asyncio
async def test_create_batch_processes_multiple_inputs(
self,
gemini_embedder: GeminiEmbedder,
mock_gemini_client: Any,
mock_gemini_batch_response: MagicMock,
) -> None:
"""Test that create_batch method correctly processes multiple inputs."""
# Setup
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_batch_response
input_batch = ['Input 1', 'Input 2', 'Input 3']
# Call method
result = await gemini_embedder.create_batch(input_batch)
# Verify API is called with correct parameters
mock_gemini_client.aio.models.embed_content.assert_called_once()
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL
assert kwargs['contents'] == input_batch
# Verify all results are processed correctly
assert len(result) == 3
assert result == [
mock_gemini_batch_response.embeddings[0].values,
mock_gemini_batch_response.embeddings[1].values,
mock_gemini_batch_response.embeddings[2].values,
]
@pytest.mark.asyncio
async def test_create_batch_single_input(
self,
gemini_embedder: GeminiEmbedder,
mock_gemini_client: Any,
mock_gemini_response: MagicMock,
) -> None:
"""Test create_batch method with single input."""
mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
input_batch = ['Single input']
result = await gemini_embedder.create_batch(input_batch)
assert len(result) == 1
assert result[0] == mock_gemini_response.embeddings[0].values
@pytest.mark.asyncio
async def test_create_batch_empty_input(
self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
) -> None:
"""Test create_batch method with empty input."""
# Setup mock response with no embeddings
mock_response = MagicMock()
mock_response.embeddings = []
mock_gemini_client.aio.models.embed_content.return_value = mock_response
input_batch = []
result = await gemini_embedder.create_batch(input_batch)
assert result == []
mock_gemini_client.aio.models.embed_content.assert_not_called()
@pytest.mark.asyncio
async def test_create_batch_no_embeddings_error(
self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
) -> None:
"""Test create_batch method handling of no embeddings response."""
# Setup mock response with no embeddings
mock_response = MagicMock()
mock_response.embeddings = []
mock_gemini_client.aio.models.embed_content.return_value = mock_response
input_batch = ['Input 1', 'Input 2']
with pytest.raises(ValueError) as exc_info:
await gemini_embedder.create_batch(input_batch)
assert 'No embeddings returned from Gemini API' in str(exc_info.value)
@pytest.mark.asyncio
async def test_create_batch_empty_values_error(
self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
) -> None:
"""Test create_batch method handling of embeddings with empty values."""
# Setup mock response with embeddings but empty values
mock_embedding1 = MagicMock()
mock_embedding1.values = [0.1, 0.2, 0.3] # Valid values
mock_embedding2 = MagicMock()
mock_embedding2.values = None # Empty values
# Mock response for the initial batch call
mock_batch_response = MagicMock()
mock_batch_response.embeddings = [mock_embedding1, mock_embedding2]
# Mock response for individual processing of 'Input 1'
mock_individual_response_1 = MagicMock()
mock_individual_response_1.embeddings = [mock_embedding1]
# Mock response for individual processing of 'Input 2' (which has empty values)
mock_individual_response_2 = MagicMock()
mock_individual_response_2.embeddings = [mock_embedding2]
# Set side_effect for embed_content to control return values for each call
mock_gemini_client.aio.models.embed_content.side_effect = [
mock_batch_response, # First call for the batch
mock_individual_response_1, # Second call for individual item 1
mock_individual_response_2, # Third call for individual item 2
]
input_batch = ['Input 1', 'Input 2']
with pytest.raises(ValueError) as exc_info:
await gemini_embedder.create_batch(input_batch)
assert 'Empty embedding values returned' in str(exc_info.value)
@pytest.mark.asyncio
@patch('google.genai.Client')
async def test_create_batch_with_custom_model_and_dimension(
self, mock_client_class, mock_gemini_client: Any
) -> None:
"""Test create_batch method with custom model and dimension."""
# Setup embedder with custom settings
config = GeminiEmbedderConfig(
api_key='test_api_key', embedding_model='custom-batch-model', embedding_dim=512
)
embedder = GeminiEmbedder(config=config)
embedder.client = mock_gemini_client
# Setup mock response
mock_response = MagicMock()
mock_response.embeddings = [
create_gemini_embedding(0.1, 512),
create_gemini_embedding(0.2, 512),
]
mock_gemini_client.aio.models.embed_content.return_value = mock_response
input_batch = ['Input 1', 'Input 2']
result = await embedder.create_batch(input_batch)
# Verify custom settings are used
_, kwargs = mock_gemini_client.aio.models.embed_content.call_args
assert kwargs['model'] == 'custom-batch-model'
assert kwargs['config'].output_dimensionality == 512
# Verify results have correct dimension
assert len(result) == 2
assert all(len(embedding) == 512 for embedding in result)
if __name__ == '__main__':
pytest.main(['-xvs', __file__])
```
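The tests above pin down `create_batch`'s contract: an empty input list returns `[]` without calling the API, a response with no embeddings raises `ValueError('No embeddings returned from Gemini API')`, and embeddings that come back with empty values trigger per-item retries before failing. A minimal sketch of that batch-with-fallback pattern, using a hypothetical `client.embed` coroutine rather than the real google-genai API:
```python
# Minimal sketch of the behavior the tests above assert; `client.embed` is a
# placeholder coroutine, not the google-genai SDK call used by GeminiEmbedder.
async def embed_batch_with_fallback(client, inputs: list[str]) -> list[list[float]]:
    if not inputs:
        return []  # empty input short-circuits without an API call
    batch = await client.embed(inputs)
    if not batch:
        raise ValueError('No embeddings returned from Gemini API')
    if all(batch):
        return batch  # every item has non-empty values: done in one call
    results: list[list[float]] = []
    for item in inputs:  # otherwise retry each input individually
        single = await client.embed([item])
        if not single or not single[0]:
            raise ValueError(f'Empty embedding values returned for input: {item!r}')
        results.append(single[0])
    return results
```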
--------------------------------------------------------------------------------
/tests/utils/maintenance/test_entity_extraction.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.nodes import EpisodeType, EpisodicNode
from graphiti_core.prompts.extract_nodes import ExtractedEntity
from graphiti_core.utils import content_chunking
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance import node_operations
from graphiti_core.utils.maintenance.node_operations import (
_build_entity_types_context,
_merge_extracted_entities,
extract_nodes,
)
def _make_clients():
"""Create mock GraphitiClients for testing."""
driver = MagicMock()
embedder = MagicMock()
cross_encoder = MagicMock()
llm_client = MagicMock()
llm_generate = AsyncMock()
llm_client.generate_response = llm_generate
clients = GraphitiClients.model_construct( # bypass validation to allow test doubles
driver=driver,
embedder=embedder,
cross_encoder=cross_encoder,
llm_client=llm_client,
)
return clients, llm_generate
def _make_episode(
content: str = 'Test content',
source: EpisodeType = EpisodeType.text,
group_id: str = 'group',
) -> EpisodicNode:
"""Create a test episode node."""
return EpisodicNode(
name='test_episode',
group_id=group_id,
source=source,
source_description='test',
content=content,
valid_at=utc_now(),
)
class TestExtractNodesSmallInput:
@pytest.mark.asyncio
async def test_small_input_single_llm_call(self, monkeypatch):
"""Small inputs should use a single LLM call without chunking."""
clients, llm_generate = _make_clients()
# Mock LLM response
llm_generate.return_value = {
'extracted_entities': [
{'name': 'Alice', 'entity_type_id': 0},
{'name': 'Bob', 'entity_type_id': 0},
]
}
# Small content (below threshold)
episode = _make_episode(content='Alice talked to Bob.')
nodes = await extract_nodes(
clients,
episode,
previous_episodes=[],
)
# Verify results
assert len(nodes) == 2
assert {n.name for n in nodes} == {'Alice', 'Bob'}
# LLM should be called exactly once
llm_generate.assert_awaited_once()
@pytest.mark.asyncio
async def test_extracts_entity_types(self, monkeypatch):
"""Entity type classification should work correctly."""
clients, llm_generate = _make_clients()
from pydantic import BaseModel
class Person(BaseModel):
"""A human person."""
pass
llm_generate.return_value = {
'extracted_entities': [
{'name': 'Alice', 'entity_type_id': 1}, # Person
{'name': 'Acme Corp', 'entity_type_id': 0}, # Default Entity
]
}
episode = _make_episode(content='Alice works at Acme Corp.')
nodes = await extract_nodes(
clients,
episode,
previous_episodes=[],
entity_types={'Person': Person},
)
# Alice should have Person label
alice = next(n for n in nodes if n.name == 'Alice')
assert 'Person' in alice.labels
# Acme should have Entity label
acme = next(n for n in nodes if n.name == 'Acme Corp')
assert 'Entity' in acme.labels
@pytest.mark.asyncio
async def test_excludes_entity_types(self, monkeypatch):
"""Excluded entity types should not appear in results."""
clients, llm_generate = _make_clients()
from pydantic import BaseModel
class User(BaseModel):
"""A user of the system."""
pass
llm_generate.return_value = {
'extracted_entities': [
{'name': 'Alice', 'entity_type_id': 1}, # User (excluded)
{'name': 'Project X', 'entity_type_id': 0}, # Entity
]
}
episode = _make_episode(content='Alice created Project X.')
nodes = await extract_nodes(
clients,
episode,
previous_episodes=[],
entity_types={'User': User},
excluded_entity_types=['User'],
)
# Alice should be excluded
assert len(nodes) == 1
assert nodes[0].name == 'Project X'
@pytest.mark.asyncio
async def test_filters_empty_names(self, monkeypatch):
"""Entities with empty names should be filtered out."""
clients, llm_generate = _make_clients()
llm_generate.return_value = {
'extracted_entities': [
{'name': 'Alice', 'entity_type_id': 0},
{'name': '', 'entity_type_id': 0},
{'name': ' ', 'entity_type_id': 0},
]
}
episode = _make_episode(content='Alice is here.')
nodes = await extract_nodes(
clients,
episode,
previous_episodes=[],
)
assert len(nodes) == 1
assert nodes[0].name == 'Alice'
class TestExtractNodesChunking:
@pytest.mark.asyncio
async def test_large_input_triggers_chunking(self, monkeypatch):
"""Large inputs should be chunked and processed in parallel."""
clients, llm_generate = _make_clients()
# Track number of LLM calls
call_count = 0
async def mock_generate(*args, **kwargs):
nonlocal call_count
call_count += 1
return {
'extracted_entities': [
{'name': f'Entity{call_count}', 'entity_type_id': 0},
]
}
llm_generate.side_effect = mock_generate
# Patch should_chunk where it's imported in node_operations
monkeypatch.setattr(node_operations, 'should_chunk', lambda content, ep_type: True)
monkeypatch.setattr(content_chunking, 'CHUNK_TOKEN_SIZE', 50) # Small chunk size
# Large content that exceeds threshold
large_content = 'word ' * 1000
episode = _make_episode(content=large_content)
await extract_nodes(
clients,
episode,
previous_episodes=[],
)
# Multiple LLM calls should have been made
assert call_count > 1
@pytest.mark.asyncio
async def test_json_content_uses_json_chunking(self, monkeypatch):
"""JSON episodes should use JSON-aware chunking."""
clients, llm_generate = _make_clients()
llm_generate.return_value = {
'extracted_entities': [
{'name': 'Service1', 'entity_type_id': 0},
]
}
# Patch should_chunk where it's imported in node_operations
monkeypatch.setattr(node_operations, 'should_chunk', lambda content, ep_type: True)
monkeypatch.setattr(content_chunking, 'CHUNK_TOKEN_SIZE', 50) # Small chunk size
# JSON content
json_data = [{'service': f'Service{i}'} for i in range(50)]
episode = _make_episode(
content=json.dumps(json_data),
source=EpisodeType.json,
)
await extract_nodes(
clients,
episode,
previous_episodes=[],
)
# Verify JSON chunking was used (LLM called multiple times)
assert llm_generate.await_count > 1
@pytest.mark.asyncio
async def test_message_content_uses_message_chunking(self, monkeypatch):
"""Message episodes should use message-aware chunking."""
clients, llm_generate = _make_clients()
llm_generate.return_value = {
'extracted_entities': [
{'name': 'Speaker', 'entity_type_id': 0},
]
}
# Patch should_chunk where it's imported in node_operations
monkeypatch.setattr(node_operations, 'should_chunk', lambda content, ep_type: True)
monkeypatch.setattr(content_chunking, 'CHUNK_TOKEN_SIZE', 50) # Small chunk size
# Conversation content
messages = [f'Speaker{i}: Hello from speaker {i}!' for i in range(50)]
episode = _make_episode(
content='\n'.join(messages),
source=EpisodeType.message,
)
await extract_nodes(
clients,
episode,
previous_episodes=[],
)
assert llm_generate.await_count > 1
@pytest.mark.asyncio
async def test_deduplicates_across_chunks(self, monkeypatch):
"""Entities appearing in multiple chunks should be deduplicated."""
clients, llm_generate = _make_clients()
# Simulate same entity appearing in multiple chunks
call_count = 0
async def mock_generate(*args, **kwargs):
nonlocal call_count
call_count += 1
# Return 'Alice' in every chunk
return {
'extracted_entities': [
{'name': 'Alice', 'entity_type_id': 0},
{'name': f'Entity{call_count}', 'entity_type_id': 0},
]
}
llm_generate.side_effect = mock_generate
# Patch should_chunk where it's imported in node_operations
monkeypatch.setattr(node_operations, 'should_chunk', lambda content, ep_type: True)
monkeypatch.setattr(content_chunking, 'CHUNK_TOKEN_SIZE', 50) # Small chunk size
large_content = 'word ' * 1000
episode = _make_episode(content=large_content)
nodes = await extract_nodes(
clients,
episode,
previous_episodes=[],
)
# Alice should appear only once despite being in every chunk
alice_count = sum(1 for n in nodes if n.name == 'Alice')
assert alice_count == 1
@pytest.mark.asyncio
async def test_deduplication_case_insensitive(self, monkeypatch):
"""Deduplication should be case-insensitive."""
clients, llm_generate = _make_clients()
call_count = 0
async def mock_generate(*args, **kwargs):
nonlocal call_count
call_count += 1
if call_count == 1:
return {'extracted_entities': [{'name': 'alice', 'entity_type_id': 0}]}
return {'extracted_entities': [{'name': 'Alice', 'entity_type_id': 0}]}
llm_generate.side_effect = mock_generate
# Patch should_chunk where it's imported in node_operations
monkeypatch.setattr(node_operations, 'should_chunk', lambda content, ep_type: True)
monkeypatch.setattr(content_chunking, 'CHUNK_TOKEN_SIZE', 50) # Small chunk size
large_content = 'word ' * 1000
episode = _make_episode(content=large_content)
nodes = await extract_nodes(
clients,
episode,
previous_episodes=[],
)
# Should have only one Alice (case-insensitive dedup)
alice_variants = [n for n in nodes if n.name.lower() == 'alice']
assert len(alice_variants) == 1
class TestExtractNodesPromptSelection:
@pytest.mark.asyncio
async def test_uses_text_prompt_for_text_episodes(self, monkeypatch):
"""Text episodes should use extract_text prompt."""
clients, llm_generate = _make_clients()
llm_generate.return_value = {'extracted_entities': []}
episode = _make_episode(source=EpisodeType.text)
await extract_nodes(clients, episode, previous_episodes=[])
# Check prompt_name parameter
call_kwargs = llm_generate.call_args[1]
assert call_kwargs.get('prompt_name') == 'extract_nodes.extract_text'
@pytest.mark.asyncio
async def test_uses_json_prompt_for_json_episodes(self, monkeypatch):
"""JSON episodes should use extract_json prompt."""
clients, llm_generate = _make_clients()
llm_generate.return_value = {'extracted_entities': []}
episode = _make_episode(content='{}', source=EpisodeType.json)
await extract_nodes(clients, episode, previous_episodes=[])
call_kwargs = llm_generate.call_args[1]
assert call_kwargs.get('prompt_name') == 'extract_nodes.extract_json'
@pytest.mark.asyncio
async def test_uses_message_prompt_for_message_episodes(self, monkeypatch):
"""Message episodes should use extract_message prompt."""
clients, llm_generate = _make_clients()
llm_generate.return_value = {'extracted_entities': []}
episode = _make_episode(source=EpisodeType.message)
await extract_nodes(clients, episode, previous_episodes=[])
call_kwargs = llm_generate.call_args[1]
assert call_kwargs.get('prompt_name') == 'extract_nodes.extract_message'
class TestBuildEntityTypesContext:
def test_default_entity_type_always_included(self):
"""Default Entity type should always be at index 0."""
context = _build_entity_types_context(None)
assert len(context) == 1
assert context[0]['entity_type_id'] == 0
assert context[0]['entity_type_name'] == 'Entity'
def test_custom_types_added_after_default(self):
"""Custom entity types should be added with sequential IDs."""
from pydantic import BaseModel
class Person(BaseModel):
"""A human person."""
pass
class Organization(BaseModel):
"""A business or organization."""
pass
context = _build_entity_types_context(
{
'Person': Person,
'Organization': Organization,
}
)
assert len(context) == 3
assert context[0]['entity_type_name'] == 'Entity'
assert context[1]['entity_type_name'] == 'Person'
assert context[1]['entity_type_id'] == 1
assert context[2]['entity_type_name'] == 'Organization'
assert context[2]['entity_type_id'] == 2
class TestMergeExtractedEntities:
def test_merge_deduplicates_by_name(self):
"""Entities with same name should be deduplicated."""
chunk_results = [
[
ExtractedEntity(name='Alice', entity_type_id=0),
ExtractedEntity(name='Bob', entity_type_id=0),
],
[
ExtractedEntity(name='Alice', entity_type_id=0), # Duplicate
ExtractedEntity(name='Charlie', entity_type_id=0),
],
]
merged = _merge_extracted_entities(chunk_results)
assert len(merged) == 3
names = {e.name for e in merged}
assert names == {'Alice', 'Bob', 'Charlie'}
def test_merge_prefers_first_occurrence(self):
"""When duplicates exist, first occurrence should be preferred."""
chunk_results = [
[ExtractedEntity(name='Alice', entity_type_id=1)], # First: type 1
[ExtractedEntity(name='Alice', entity_type_id=2)], # Later: type 2
]
merged = _merge_extracted_entities(chunk_results)
assert len(merged) == 1
assert merged[0].entity_type_id == 1 # First occurrence wins
```
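`TestMergeExtractedEntities` and the chunking tests above fix the merge semantics: entities are deduplicated by name across chunk results, the first occurrence wins, and at the `extract_nodes` level deduplication is case-insensitive with blank names filtered out. A standalone sketch of that merge over plain dicts (the real code works on `ExtractedEntity` models and may key the merge differently):
```python
# Illustrative merge over plain dicts; assumes a case-insensitive key and
# blank-name filtering, which is what the tests above collectively require.
def merge_extracted_entities(chunk_results: list[list[dict]]) -> list[dict]:
    seen: set[str] = set()
    merged: list[dict] = []
    for chunk in chunk_results:
        for entity in chunk:
            key = entity['name'].strip().lower()
            if not key or key in seen:
                continue  # drop blanks and later duplicates
            seen.add(key)
            merged.append(entity)  # first occurrence (and its type id) is kept
    return merged


# Mirrors test_merge_prefers_first_occurrence: the first type id survives.
chunks = [[{'name': 'Alice', 'entity_type_id': 1}], [{'name': 'alice', 'entity_type_id': 2}]]
assert merge_extracted_entities(chunks)[0]['entity_type_id'] == 1
```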
--------------------------------------------------------------------------------
/tests/driver/test_falkordb_driver.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import unittest
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from graphiti_core.driver.driver import GraphProvider
try:
from graphiti_core.driver.falkordb_driver import FalkorDriver, FalkorDriverSession
HAS_FALKORDB = True
except ImportError:
FalkorDriver = None
HAS_FALKORDB = False
class TestFalkorDriver:
"""Comprehensive test suite for FalkorDB driver."""
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def setup_method(self):
"""Set up test fixtures."""
self.mock_client = MagicMock()
with patch('graphiti_core.driver.falkordb_driver.FalkorDB'):
self.driver = FalkorDriver()
self.driver.client = self.mock_client
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_init_with_connection_params(self):
"""Test initialization with connection parameters."""
with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as mock_falkor_db:
driver = FalkorDriver(
host='test-host', port='1234', username='test-user', password='test-pass'
)
assert driver.provider == GraphProvider.FALKORDB
mock_falkor_db.assert_called_once_with(
host='test-host', port='1234', username='test-user', password='test-pass'
)
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_init_with_falkor_db_instance(self):
"""Test initialization with a FalkorDB instance."""
with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as mock_falkor_db_class:
mock_falkor_db = MagicMock()
driver = FalkorDriver(falkor_db=mock_falkor_db)
assert driver.provider == GraphProvider.FALKORDB
assert driver.client is mock_falkor_db
mock_falkor_db_class.assert_not_called()
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_provider(self):
"""Test driver provider identification."""
assert self.driver.provider == GraphProvider.FALKORDB
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_get_graph_with_name(self):
"""Test _get_graph with specific graph name."""
mock_graph = MagicMock()
self.mock_client.select_graph.return_value = mock_graph
result = self.driver._get_graph('test_graph')
self.mock_client.select_graph.assert_called_once_with('test_graph')
assert result is mock_graph
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_get_graph_with_none_defaults_to_default_database(self):
"""Test _get_graph with None defaults to default_db."""
mock_graph = MagicMock()
self.mock_client.select_graph.return_value = mock_graph
result = self.driver._get_graph(None)
self.mock_client.select_graph.assert_called_once_with('default_db')
assert result is mock_graph
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_execute_query_success(self):
"""Test successful query execution."""
mock_graph = MagicMock()
mock_result = MagicMock()
mock_result.header = [('col1', 'column1'), ('col2', 'column2')]
mock_result.result_set = [['row1col1', 'row1col2']]
mock_graph.query = AsyncMock(return_value=mock_result)
self.mock_client.select_graph.return_value = mock_graph
result = await self.driver.execute_query('MATCH (n) RETURN n', param1='value1')
mock_graph.query.assert_called_once_with('MATCH (n) RETURN n', {'param1': 'value1'})
result_set, header, summary = result
assert result_set == [{'column1': 'row1col1', 'column2': 'row1col2'}]
assert header == ['column1', 'column2']
assert summary is None
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_execute_query_handles_index_already_exists_error(self):
"""Test handling of 'already indexed' error."""
mock_graph = MagicMock()
mock_graph.query = AsyncMock(side_effect=Exception('Index already indexed'))
self.mock_client.select_graph.return_value = mock_graph
with patch('graphiti_core.driver.falkordb_driver.logger') as mock_logger:
result = await self.driver.execute_query('CREATE INDEX ...')
mock_logger.info.assert_called_once()
assert result is None
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_execute_query_propagates_other_exceptions(self):
"""Test that other exceptions are properly propagated."""
mock_graph = MagicMock()
mock_graph.query = AsyncMock(side_effect=Exception('Other error'))
self.mock_client.select_graph.return_value = mock_graph
with patch('graphiti_core.driver.falkordb_driver.logger') as mock_logger:
with pytest.raises(Exception, match='Other error'):
await self.driver.execute_query('INVALID QUERY')
mock_logger.error.assert_called_once()
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_execute_query_converts_datetime_parameters(self):
"""Test that datetime objects in kwargs are converted to ISO strings."""
mock_graph = MagicMock()
mock_result = MagicMock()
mock_result.header = []
mock_result.result_set = []
mock_graph.query = AsyncMock(return_value=mock_result)
self.mock_client.select_graph.return_value = mock_graph
test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
await self.driver.execute_query(
'CREATE (n:Node) SET n.created_at = $created_at', created_at=test_datetime
)
call_args = mock_graph.query.call_args[0]
assert call_args[1]['created_at'] == test_datetime.isoformat()
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_session_creation(self):
"""Test session creation with specific database."""
mock_graph = MagicMock()
self.mock_client.select_graph.return_value = mock_graph
session = self.driver.session()
assert isinstance(session, FalkorDriverSession)
assert session.graph is mock_graph
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_session_creation_with_none_uses_default_database(self):
"""Test session creation with None uses default database."""
mock_graph = MagicMock()
self.mock_client.select_graph.return_value = mock_graph
session = self.driver.session()
assert isinstance(session, FalkorDriverSession)
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_close_calls_connection_close(self):
"""Test driver close method calls connection close."""
mock_connection = MagicMock()
mock_connection.close = AsyncMock()
self.mock_client.connection = mock_connection
# Ensure hasattr checks work correctly
del self.mock_client.aclose # Remove aclose if it exists
with patch('builtins.hasattr') as mock_hasattr:
# hasattr(self.client, 'aclose') returns False
# hasattr(self.client.connection, 'aclose') returns False
# hasattr(self.client.connection, 'close') returns True
mock_hasattr.side_effect = lambda obj, attr: (
attr == 'close' and obj is mock_connection
)
await self.driver.close()
mock_connection.close.assert_called_once()
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_delete_all_indexes(self):
"""Test delete_all_indexes method."""
with patch.object(self.driver, 'execute_query', new_callable=AsyncMock) as mock_execute:
# Return None to simulate no indexes found
mock_execute.return_value = None
await self.driver.delete_all_indexes()
mock_execute.assert_called_once_with('CALL db.indexes()')
class TestFalkorDriverSession:
"""Test FalkorDB driver session functionality."""
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def setup_method(self):
"""Set up test fixtures."""
self.mock_graph = MagicMock()
self.session = FalkorDriverSession(self.mock_graph)
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_session_async_context_manager(self):
"""Test session can be used as async context manager."""
async with self.session as s:
assert s is self.session
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_close_method(self):
"""Test session close method doesn't raise exceptions."""
await self.session.close() # Should not raise
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_execute_write_passes_session_and_args(self):
"""Test execute_write method passes session and arguments correctly."""
async def test_func(session, *args, **kwargs):
assert session is self.session
assert args == ('arg1', 'arg2')
assert kwargs == {'key': 'value'}
return 'result'
result = await self.session.execute_write(test_func, 'arg1', 'arg2', key='value')
assert result == 'result'
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_run_single_query_with_parameters(self):
"""Test running a single query with parameters."""
self.mock_graph.query = AsyncMock()
await self.session.run('MATCH (n) RETURN n', param1='value1', param2='value2')
self.mock_graph.query.assert_called_once_with(
'MATCH (n) RETURN n', {'param1': 'value1', 'param2': 'value2'}
)
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_run_multiple_queries_as_list(self):
"""Test running multiple queries passed as list."""
self.mock_graph.query = AsyncMock()
queries = [
('MATCH (n) RETURN n', {'param1': 'value1'}),
('CREATE (n:Node)', {'param2': 'value2'}),
]
await self.session.run(queries)
assert self.mock_graph.query.call_count == 2
calls = self.mock_graph.query.call_args_list
assert calls[0][0] == ('MATCH (n) RETURN n', {'param1': 'value1'})
assert calls[1][0] == ('CREATE (n:Node)', {'param2': 'value2'})
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_run_converts_datetime_objects_to_iso_strings(self):
"""Test that datetime objects are converted to ISO strings."""
self.mock_graph.query = AsyncMock()
test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
await self.session.run(
'CREATE (n:Node) SET n.created_at = $created_at', created_at=test_datetime
)
self.mock_graph.query.assert_called_once()
call_args = self.mock_graph.query.call_args[0]
assert call_args[1]['created_at'] == test_datetime.isoformat()
class TestDatetimeConversion:
"""Test datetime conversion utility function."""
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_convert_datetime_dict(self):
"""Test datetime conversion in nested dictionary."""
from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings
test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
input_dict = {
'string_val': 'test',
'datetime_val': test_datetime,
'nested_dict': {'nested_datetime': test_datetime, 'nested_string': 'nested_test'},
}
result = convert_datetimes_to_strings(input_dict)
assert result['string_val'] == 'test'
assert result['datetime_val'] == test_datetime.isoformat()
assert result['nested_dict']['nested_datetime'] == test_datetime.isoformat()
assert result['nested_dict']['nested_string'] == 'nested_test'
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_convert_datetime_list_and_tuple(self):
"""Test datetime conversion in lists and tuples."""
from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings
test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
# Test list
input_list = ['test', test_datetime, ['nested', test_datetime]]
result_list = convert_datetimes_to_strings(input_list)
assert result_list[0] == 'test'
assert result_list[1] == test_datetime.isoformat()
assert result_list[2][1] == test_datetime.isoformat()
# Test tuple
input_tuple = ('test', test_datetime)
result_tuple = convert_datetimes_to_strings(input_tuple)
assert isinstance(result_tuple, tuple)
assert result_tuple[0] == 'test'
assert result_tuple[1] == test_datetime.isoformat()
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_convert_single_datetime(self):
"""Test datetime conversion for single datetime object."""
from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings
test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
result = convert_datetimes_to_strings(test_datetime)
assert result == test_datetime.isoformat()
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
def test_convert_other_types_unchanged(self):
"""Test that non-datetime types are returned unchanged."""
from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings
assert convert_datetimes_to_strings('string') == 'string'
assert convert_datetimes_to_strings(123) == 123
assert convert_datetimes_to_strings(None) is None
assert convert_datetimes_to_strings(True) is True
# Simple integration test
class TestFalkorDriverIntegration:
"""Simple integration test for FalkorDB driver."""
@pytest.mark.asyncio
@unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
async def test_basic_integration_with_real_falkordb(self):
"""Basic integration test with real FalkorDB instance."""
pytest.importorskip('falkordb')
falkor_host = os.getenv('FALKORDB_HOST', 'localhost')
falkor_port = os.getenv('FALKORDB_PORT', '6379')
try:
driver = FalkorDriver(host=falkor_host, port=falkor_port)
# Test basic query execution
result = await driver.execute_query('RETURN 1 as test')
assert result is not None
result_set, header, summary = result
assert header == ['test']
assert result_set == [{'test': 1}]
await driver.close()
except Exception as e:
pytest.skip(f'FalkorDB not available for integration test: {e}')
```
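The `TestDatetimeConversion` cases fully specify `convert_datetimes_to_strings`: datetimes become ISO-8601 strings, dicts, lists, and tuples are walked recursively (tuples stay tuples), and everything else passes through untouched. A standalone sketch that satisfies those assertions; the real helper ships in `graphiti_core.driver.falkordb_driver`:
```python
from datetime import datetime
from typing import Any


def convert_datetimes_to_strings_sketch(obj: Any) -> Any:
    """Recursively replace datetime values with ISO-8601 strings (illustrative)."""
    if isinstance(obj, datetime):
        return obj.isoformat()
    if isinstance(obj, dict):
        return {key: convert_datetimes_to_strings_sketch(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert_datetimes_to_strings_sketch(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(convert_datetimes_to_strings_sketch(item) for item in obj)
    return obj  # strings, ints, None, bools, etc. are returned unchanged
```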
--------------------------------------------------------------------------------
/mcp_server/src/services/factories.py:
--------------------------------------------------------------------------------
```python
"""Factory classes for creating LLM, Embedder, and Database clients."""
from openai import AsyncAzureOpenAI
from config.schema import (
DatabaseConfig,
EmbedderConfig,
LLMConfig,
)
# Try to import FalkorDriver if available
try:
from graphiti_core.driver.falkordb_driver import FalkorDriver # noqa: F401
HAS_FALKOR = True
except ImportError:
HAS_FALKOR = False
# Kuzu support removed - FalkorDB is now the default
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
from graphiti_core.llm_client import LLMClient, OpenAIClient
from graphiti_core.llm_client.config import LLMConfig as GraphitiLLMConfig
# Try to import additional providers if available
try:
from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient
HAS_AZURE_EMBEDDER = True
except ImportError:
HAS_AZURE_EMBEDDER = False
try:
from graphiti_core.embedder.gemini import GeminiEmbedder
HAS_GEMINI_EMBEDDER = True
except ImportError:
HAS_GEMINI_EMBEDDER = False
try:
from graphiti_core.embedder.voyage import VoyageAIEmbedder
HAS_VOYAGE_EMBEDDER = True
except ImportError:
HAS_VOYAGE_EMBEDDER = False
try:
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
HAS_AZURE_LLM = True
except ImportError:
HAS_AZURE_LLM = False
try:
from graphiti_core.llm_client.anthropic_client import AnthropicClient
HAS_ANTHROPIC = True
except ImportError:
HAS_ANTHROPIC = False
try:
from graphiti_core.llm_client.gemini_client import GeminiClient
HAS_GEMINI = True
except ImportError:
HAS_GEMINI = False
try:
from graphiti_core.llm_client.groq_client import GroqClient
HAS_GROQ = True
except ImportError:
HAS_GROQ = False
from utils.utils import create_azure_credential_token_provider
def _validate_api_key(provider_name: str, api_key: str | None, logger) -> str:
"""Validate API key is present.
Args:
provider_name: Name of the provider (e.g., 'OpenAI', 'Anthropic')
api_key: The API key to validate
logger: Logger instance for output
Returns:
The validated API key
Raises:
ValueError: If API key is None or empty
"""
if not api_key:
raise ValueError(
f'{provider_name} API key is not configured. Please set the appropriate environment variable.'
)
logger.info(f'Creating {provider_name} client')
return api_key
class LLMClientFactory:
"""Factory for creating LLM clients based on configuration."""
@staticmethod
def create(config: LLMConfig) -> LLMClient:
"""Create an LLM client based on the configured provider."""
import logging
logger = logging.getLogger(__name__)
provider = config.provider.lower()
match provider:
case 'openai':
if not config.providers.openai:
raise ValueError('OpenAI provider configuration not found')
api_key = config.providers.openai.api_key
_validate_api_key('OpenAI', api_key, logger)
from graphiti_core.llm_client.config import LLMConfig as CoreLLMConfig
# Determine appropriate small model based on main model type
is_reasoning_model = (
config.model.startswith('gpt-5')
or config.model.startswith('o1')
or config.model.startswith('o3')
)
small_model = (
'gpt-5-nano' if is_reasoning_model else 'gpt-4.1-mini'
) # If the main model is a reasoning model, pair it with a small reasoning model
llm_config = CoreLLMConfig(
api_key=api_key,
model=config.model,
small_model=small_model,
temperature=config.temperature,
max_tokens=config.max_tokens,
)
# Only pass reasoning/verbosity parameters for reasoning models (gpt-5 family)
if is_reasoning_model:
return OpenAIClient(config=llm_config, reasoning='minimal', verbosity='low')
else:
# For non-reasoning models, explicitly pass None to disable these parameters
return OpenAIClient(config=llm_config, reasoning=None, verbosity=None)
case 'azure_openai':
if not HAS_AZURE_LLM:
raise ValueError(
'Azure OpenAI LLM client not available in current graphiti-core version'
)
if not config.providers.azure_openai:
raise ValueError('Azure OpenAI provider configuration not found')
azure_config = config.providers.azure_openai
if not azure_config.api_url:
raise ValueError('Azure OpenAI API URL is required')
# Handle Azure AD authentication if enabled
api_key: str | None = None
azure_ad_token_provider = None
if azure_config.use_azure_ad:
logger.info('Creating Azure OpenAI LLM client with Azure AD authentication')
azure_ad_token_provider = create_azure_credential_token_provider()
else:
api_key = azure_config.api_key
_validate_api_key('Azure OpenAI', api_key, logger)
# Create the Azure OpenAI client first
azure_client = AsyncAzureOpenAI(
api_key=api_key,
azure_endpoint=azure_config.api_url,
api_version=azure_config.api_version,
azure_deployment=azure_config.deployment_name,
azure_ad_token_provider=azure_ad_token_provider,
)
# Then create the LLMConfig
from graphiti_core.llm_client.config import LLMConfig as CoreLLMConfig
llm_config = CoreLLMConfig(
api_key=api_key,
base_url=azure_config.api_url,
model=config.model,
temperature=config.temperature,
max_tokens=config.max_tokens,
)
return AzureOpenAILLMClient(
azure_client=azure_client,
config=llm_config,
max_tokens=config.max_tokens,
)
case 'anthropic':
if not HAS_ANTHROPIC:
raise ValueError(
'Anthropic client not available in current graphiti-core version'
)
if not config.providers.anthropic:
raise ValueError('Anthropic provider configuration not found')
api_key = config.providers.anthropic.api_key
_validate_api_key('Anthropic', api_key, logger)
llm_config = GraphitiLLMConfig(
api_key=api_key,
model=config.model,
temperature=config.temperature,
max_tokens=config.max_tokens,
)
return AnthropicClient(config=llm_config)
case 'gemini':
if not HAS_GEMINI:
raise ValueError('Gemini client not available in current graphiti-core version')
if not config.providers.gemini:
raise ValueError('Gemini provider configuration not found')
api_key = config.providers.gemini.api_key
_validate_api_key('Gemini', api_key, logger)
llm_config = GraphitiLLMConfig(
api_key=api_key,
model=config.model,
temperature=config.temperature,
max_tokens=config.max_tokens,
)
return GeminiClient(config=llm_config)
case 'groq':
if not HAS_GROQ:
raise ValueError('Groq client not available in current graphiti-core version')
if not config.providers.groq:
raise ValueError('Groq provider configuration not found')
api_key = config.providers.groq.api_key
_validate_api_key('Groq', api_key, logger)
llm_config = GraphitiLLMConfig(
api_key=api_key,
base_url=config.providers.groq.api_url,
model=config.model,
temperature=config.temperature,
max_tokens=config.max_tokens,
)
return GroqClient(config=llm_config)
case _:
raise ValueError(f'Unsupported LLM provider: {provider}')
class EmbedderFactory:
"""Factory for creating Embedder clients based on configuration."""
@staticmethod
def create(config: EmbedderConfig) -> EmbedderClient:
"""Create an Embedder client based on the configured provider."""
import logging
logger = logging.getLogger(__name__)
provider = config.provider.lower()
match provider:
case 'openai':
if not config.providers.openai:
raise ValueError('OpenAI provider configuration not found')
api_key = config.providers.openai.api_key
_validate_api_key('OpenAI Embedder', api_key, logger)
from graphiti_core.embedder.openai import OpenAIEmbedderConfig
embedder_config = OpenAIEmbedderConfig(
api_key=api_key,
embedding_model=config.model,
)
return OpenAIEmbedder(config=embedder_config)
case 'azure_openai':
if not HAS_AZURE_EMBEDDER:
raise ValueError(
'Azure OpenAI embedder not available in current graphiti-core version'
)
if not config.providers.azure_openai:
raise ValueError('Azure OpenAI provider configuration not found')
azure_config = config.providers.azure_openai
if not azure_config.api_url:
raise ValueError('Azure OpenAI API URL is required')
# Handle Azure AD authentication if enabled
api_key: str | None = None
azure_ad_token_provider = None
if azure_config.use_azure_ad:
logger.info(
'Creating Azure OpenAI Embedder client with Azure AD authentication'
)
azure_ad_token_provider = create_azure_credential_token_provider()
else:
api_key = azure_config.api_key
_validate_api_key('Azure OpenAI Embedder', api_key, logger)
# Create the Azure OpenAI client first
azure_client = AsyncAzureOpenAI(
api_key=api_key,
azure_endpoint=azure_config.api_url,
api_version=azure_config.api_version,
azure_deployment=azure_config.deployment_name,
azure_ad_token_provider=azure_ad_token_provider,
)
return AzureOpenAIEmbedderClient(
azure_client=azure_client,
model=config.model or 'text-embedding-3-small',
)
case 'gemini':
if not HAS_GEMINI_EMBEDDER:
raise ValueError(
'Gemini embedder not available in current graphiti-core version'
)
if not config.providers.gemini:
raise ValueError('Gemini provider configuration not found')
api_key = config.providers.gemini.api_key
_validate_api_key('Gemini Embedder', api_key, logger)
from graphiti_core.embedder.gemini import GeminiEmbedderConfig
gemini_config = GeminiEmbedderConfig(
api_key=api_key,
embedding_model=config.model or 'models/text-embedding-004',
embedding_dim=config.dimensions or 768,
)
return GeminiEmbedder(config=gemini_config)
case 'voyage':
if not HAS_VOYAGE_EMBEDDER:
raise ValueError(
'Voyage embedder not available in current graphiti-core version'
)
if not config.providers.voyage:
raise ValueError('Voyage provider configuration not found')
api_key = config.providers.voyage.api_key
_validate_api_key('Voyage Embedder', api_key, logger)
from graphiti_core.embedder.voyage import VoyageAIEmbedderConfig
voyage_config = VoyageAIEmbedderConfig(
api_key=api_key,
embedding_model=config.model or 'voyage-3',
embedding_dim=config.dimensions or 1024,
)
return VoyageAIEmbedder(config=voyage_config)
case _:
raise ValueError(f'Unsupported Embedder provider: {provider}')
class DatabaseDriverFactory:
"""Factory for creating Database drivers based on configuration.
Note: This returns configuration dictionaries that can be passed to Graphiti(),
not driver instances directly, as the drivers require complex initialization.
"""
@staticmethod
def create_config(config: DatabaseConfig) -> dict:
"""Create database configuration dictionary based on the configured provider."""
provider = config.provider.lower()
match provider:
case 'neo4j':
# Use Neo4j config if provided, otherwise use defaults
if config.providers.neo4j:
neo4j_config = config.providers.neo4j
else:
# Create default Neo4j configuration
from config.schema import Neo4jProviderConfig
neo4j_config = Neo4jProviderConfig()
# Check for environment variable overrides (for CI/CD compatibility)
import os
uri = os.environ.get('NEO4J_URI', neo4j_config.uri)
username = os.environ.get('NEO4J_USER', neo4j_config.username)
password = os.environ.get('NEO4J_PASSWORD', neo4j_config.password)
return {
'uri': uri,
'user': username,
'password': password,
# Note: database and use_parallel_runtime would need to be passed
# to the driver after initialization if supported
}
case 'falkordb':
if not HAS_FALKOR:
raise ValueError(
'FalkorDB driver not available in current graphiti-core version'
)
# Use FalkorDB config if provided, otherwise use defaults
if config.providers.falkordb:
falkor_config = config.providers.falkordb
else:
# Create default FalkorDB configuration
from config.schema import FalkorDBProviderConfig
falkor_config = FalkorDBProviderConfig()
# Check for environment variable overrides (for CI/CD compatibility)
import os
from urllib.parse import urlparse
uri = os.environ.get('FALKORDB_URI', falkor_config.uri)
password = os.environ.get('FALKORDB_PASSWORD', falkor_config.password)
# Parse the URI to extract host and port
parsed = urlparse(uri)
host = parsed.hostname or 'localhost'
port = parsed.port or 6379
return {
'driver': 'falkordb',
'host': host,
'port': port,
'password': password,
'database': falkor_config.database,
}
case _:
raise ValueError(f'Unsupported Database provider: {provider}')
```
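For the FalkorDB branch, `create_config` derives host and port from a single URI, with the `FALKORDB_URI` and `FALKORDB_PASSWORD` environment variables overriding file configuration. A small self-contained illustration of that URI resolution (the URI value here is an example, not taken from the repo):
```python
import os
from urllib.parse import urlparse

# Resolve host/port the same way DatabaseDriverFactory.create_config does for
# FalkorDB; 'redis://my-falkordb:6380' is an example default, not a real endpoint.
uri = os.environ.get('FALKORDB_URI', 'redis://my-falkordb:6380')
parsed = urlparse(uri)
host = parsed.hostname or 'localhost'
port = parsed.port or 6379
print({'driver': 'falkordb', 'host': host, 'port': port})
```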
--------------------------------------------------------------------------------
/graphiti_core/llm_client/anthropic_client.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
import os
import typing
from json import JSONDecodeError
from typing import TYPE_CHECKING, Literal
from pydantic import BaseModel, ValidationError
from ..prompts.models import Message
from .client import LLMClient
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
from .errors import RateLimitError, RefusalError
if TYPE_CHECKING:
import anthropic
from anthropic import AsyncAnthropic
from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
else:
try:
import anthropic
from anthropic import AsyncAnthropic
from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
except ImportError:
raise ImportError(
'anthropic is required for AnthropicClient. '
'Install it with: pip install graphiti-core[anthropic]'
) from None
logger = logging.getLogger(__name__)
AnthropicModel = Literal[
'claude-sonnet-4-5-latest',
'claude-sonnet-4-5-20250929',
'claude-haiku-4-5-latest',
'claude-3-7-sonnet-latest',
'claude-3-7-sonnet-20250219',
'claude-3-5-haiku-latest',
'claude-3-5-haiku-20241022',
'claude-3-5-sonnet-latest',
'claude-3-5-sonnet-20241022',
'claude-3-5-sonnet-20240620',
'claude-3-opus-latest',
'claude-3-opus-20240229',
'claude-3-sonnet-20240229',
'claude-3-haiku-20240307',
'claude-2.1',
'claude-2.0',
]
DEFAULT_MODEL: AnthropicModel = 'claude-haiku-4-5-latest'
# Maximum output tokens for different Anthropic models
# Based on official Anthropic documentation (as of 2025)
# Note: These represent standard limits without beta headers.
# Some models support higher limits with additional configuration (e.g., Claude 3.7 supports
# 128K with 'anthropic-beta: output-128k-2025-02-19' header, but this is not currently implemented).
ANTHROPIC_MODEL_MAX_TOKENS = {
# Claude 4.5 models - 64K tokens
'claude-sonnet-4-5-latest': 65536,
'claude-sonnet-4-5-20250929': 65536,
'claude-haiku-4-5-latest': 65536,
# Claude 3.7 models - standard 64K tokens
'claude-3-7-sonnet-latest': 65536,
'claude-3-7-sonnet-20250219': 65536,
# Claude 3.5 models
'claude-3-5-haiku-latest': 8192,
'claude-3-5-haiku-20241022': 8192,
'claude-3-5-sonnet-latest': 8192,
'claude-3-5-sonnet-20241022': 8192,
'claude-3-5-sonnet-20240620': 8192,
# Claude 3 models - 4K tokens
'claude-3-opus-latest': 4096,
'claude-3-opus-20240229': 4096,
'claude-3-sonnet-20240229': 4096,
'claude-3-haiku-20240307': 4096,
# Claude 2 models - 4K tokens
'claude-2.1': 4096,
'claude-2.0': 4096,
}
# Default max tokens for models not in the mapping
DEFAULT_ANTHROPIC_MAX_TOKENS = 8192
class AnthropicClient(LLMClient):
"""
A client for the Anthropic LLM.
Args:
config: A configuration object for the LLM.
cache: Whether to cache the LLM responses.
client: An optional client instance to use.
max_tokens: The maximum number of tokens to generate.
Methods:
generate_response: Generate a response from the LLM.
Notes:
- If an LLMConfig is not provided, the api_key will be pulled from the ANTHROPIC_API_KEY environment
variable, and all default values will be used for the LLMConfig.
"""
model: AnthropicModel
def __init__(
self,
config: LLMConfig | None = None,
cache: bool = False,
client: AsyncAnthropic | None = None,
max_tokens: int = DEFAULT_MAX_TOKENS,
) -> None:
if config is None:
config = LLMConfig()
config.api_key = os.getenv('ANTHROPIC_API_KEY')
config.max_tokens = max_tokens
if config.model is None:
config.model = DEFAULT_MODEL
super().__init__(config, cache)
# Explicitly set the instance model to the config model to prevent type checking errors
self.model = typing.cast(AnthropicModel, config.model)
if not client:
self.client = AsyncAnthropic(
api_key=config.api_key,
max_retries=1,
)
else:
self.client = client
def _extract_json_from_text(self, text: str) -> dict[str, typing.Any]:
"""Extract JSON from text content.
A helper method to extract JSON from text content, used when tool use fails or
no response_model is provided.
Args:
text: The text to extract JSON from
Returns:
Extracted JSON as a dictionary
Raises:
ValueError: If JSON cannot be extracted or parsed
"""
try:
json_start = text.find('{')
json_end = text.rfind('}') + 1
if json_start >= 0 and json_end > json_start:
json_str = text[json_start:json_end]
return json.loads(json_str)
else:
raise ValueError(f'Could not extract JSON from model response: {text}')
except (JSONDecodeError, ValueError) as e:
raise ValueError(f'Could not extract JSON from model response: {text}') from e
def _create_tool(
self, response_model: type[BaseModel] | None = None
) -> tuple[list[ToolUnionParam], ToolChoiceParam]:
"""
Create a tool definition based on the response_model if provided, or a generic JSON tool if not.
Args:
response_model: Optional Pydantic model to use for structured output.
Returns:
A tuple of (tools, tool_choice) for the Anthropic API: a single-item tool list and a tool_choice that forces the model to call that tool.
"""
if response_model is not None:
# Use the response_model to define the tool
model_schema = response_model.model_json_schema()
tool_name = response_model.__name__
description = model_schema.get('description', f'Extract {tool_name} information')
else:
# Create a generic JSON output tool
tool_name = 'generic_json_output'
description = 'Output data in JSON format'
model_schema = {
'type': 'object',
'additionalProperties': True,
'description': 'Any JSON object containing the requested information',
}
tool = {
'name': tool_name,
'description': description,
'input_schema': model_schema,
}
tool_list = [tool]
tool_list_cast = typing.cast(list[ToolUnionParam], tool_list)
tool_choice = {'type': 'tool', 'name': tool_name}
tool_choice_cast = typing.cast(ToolChoiceParam, tool_choice)
return tool_list_cast, tool_choice_cast
def _get_max_tokens_for_model(self, model: str) -> int:
"""Get the maximum output tokens for a specific Anthropic model.
Args:
model: The model name to look up
Returns:
int: The maximum output tokens for the model
"""
return ANTHROPIC_MODEL_MAX_TOKENS.get(model, DEFAULT_ANTHROPIC_MAX_TOKENS)
def _resolve_max_tokens(self, requested_max_tokens: int | None, model: str) -> int:
"""
Resolve the maximum output tokens to use based on precedence rules.
Precedence order (highest to lowest):
1. Explicit max_tokens parameter passed to generate_response()
2. Instance max_tokens set during client initialization
3. Model-specific maximum tokens from ANTHROPIC_MODEL_MAX_TOKENS mapping
4. DEFAULT_ANTHROPIC_MAX_TOKENS as final fallback
Args:
requested_max_tokens: The max_tokens parameter passed to generate_response()
model: The model name to look up model-specific limits
Returns:
int: The resolved maximum tokens to use
"""
# 1. Use explicit parameter if provided
if requested_max_tokens is not None:
return requested_max_tokens
# 2. Use instance max_tokens if set during initialization
if self.max_tokens is not None:
return self.max_tokens
# 3. Use model-specific maximum or return DEFAULT_ANTHROPIC_MAX_TOKENS
return self._get_max_tokens_for_model(model)
async def _generate_response(
self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int | None = None,
model_size: ModelSize = ModelSize.medium,
) -> dict[str, typing.Any]:
"""
Generate a response from the Anthropic LLM using tool-based approach for all requests.
Args:
messages: List of message objects to send to the LLM.
response_model: Optional Pydantic model to use for structured output.
max_tokens: Maximum number of tokens to generate.
Returns:
Dictionary containing the structured response from the LLM.
Raises:
RateLimitError: If the rate limit is exceeded.
RefusalError: If the LLM refuses to respond.
Exception: If an error occurs during the generation process.
"""
system_message = messages[0]
user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]]
user_messages_cast = typing.cast(list[MessageParam], user_messages)
# Resolve max_tokens dynamically based on the model's capabilities
# This allows different models to use their full output capacity
max_creation_tokens: int = self._resolve_max_tokens(max_tokens, self.model)
try:
# Create the appropriate tool based on whether response_model is provided
tools, tool_choice = self._create_tool(response_model)
result = await self.client.messages.create(
system=system_message.content,
max_tokens=max_creation_tokens,
temperature=self.temperature,
messages=user_messages_cast,
model=self.model,
tools=tools,
tool_choice=tool_choice,
)
# Extract the tool output from the response
for content_item in result.content:
if content_item.type == 'tool_use':
if isinstance(content_item.input, dict):
tool_args: dict[str, typing.Any] = content_item.input
else:
tool_args = json.loads(str(content_item.input))
return tool_args
# If we didn't get a proper tool_use response, try to extract from text
for content_item in result.content:
if content_item.type == 'text':
return self._extract_json_from_text(content_item.text)
else:
raise ValueError(
f'Could not extract structured data from model response: {result.content}'
)
# If we get here, we couldn't parse a structured response
raise ValueError(
f'Could not extract structured data from model response: {result.content}'
)
except anthropic.RateLimitError as e:
raise RateLimitError(f'Rate limit exceeded. Please try again later. Error: {e}') from e
except anthropic.APIError as e:
# Special case for content policy violations. We convert these to RefusalError
# to bypass the retry mechanism, as retrying policy-violating content will always fail.
# This avoids wasting API calls and provides more specific error messaging to the user.
if 'refused to respond' in str(e).lower():
raise RefusalError(str(e)) from e
raise e
except Exception as e:
raise e
async def generate_response(
self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int | None = None,
model_size: ModelSize = ModelSize.medium,
group_id: str | None = None,
prompt_name: str | None = None,
) -> dict[str, typing.Any]:
"""
Generate a response from the LLM.
Args:
messages: List of message objects to send to the LLM.
response_model: Optional Pydantic model to use for structured output.
max_tokens: Maximum number of tokens to generate.
Returns:
Dictionary containing the structured response from the LLM.
Raises:
RateLimitError: If the rate limit is exceeded.
RefusalError: If the LLM refuses to respond.
Exception: If an error occurs during the generation process.
"""
if max_tokens is None:
max_tokens = self.max_tokens
# Wrap entire operation in tracing span
with self.tracer.start_span('llm.generate') as span:
attributes = {
'llm.provider': 'anthropic',
'model.size': model_size.value,
'max_tokens': max_tokens,
}
if prompt_name:
attributes['prompt.name'] = prompt_name
span.add_attributes(attributes)
retry_count = 0
max_retries = 2
last_error: Exception | None = None
while retry_count <= max_retries:
try:
response = await self._generate_response(
messages, response_model, max_tokens, model_size
)
# If we have a response_model, attempt to validate the response
if response_model is not None:
# Validate the response against the response_model
model_instance = response_model(**response)
return model_instance.model_dump()
# If no validation needed, return the response
return response
except (RateLimitError, RefusalError):
# These errors should not trigger retries
span.set_status('error', str(last_error))
raise
except Exception as e:
last_error = e
if retry_count >= max_retries:
if isinstance(e, ValidationError):
logger.error(
f'Validation error after {retry_count}/{max_retries} attempts: {e}'
)
else:
logger.error(f'Max retries ({max_retries}) exceeded. Last error: {e}')
span.set_status('error', str(e))
span.record_exception(e)
raise e
if isinstance(e, ValidationError):
response_model_cast = typing.cast(type[BaseModel], response_model)
error_context = f'The previous response was invalid. Please provide a valid {response_model_cast.__name__} object. Error: {e}'
else:
error_context = (
f'The previous response attempt was invalid. '
f'Error type: {e.__class__.__name__}. '
f'Error details: {str(e)}. '
f'Please try again with a valid response.'
)
# Common retry logic
retry_count += 1
messages.append(Message(role='user', content=error_context))
logger.warning(
f'Retrying after error (attempt {retry_count}/{max_retries}): {e}'
)
# If we somehow get here, raise the last error
span.set_status('error', str(last_error))
raise last_error or Exception('Max retries exceeded with no specific error')
```
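The precedence that `_resolve_max_tokens` documents is easiest to see with concrete numbers. A standalone illustration of the same lookup order (explicit argument, then the instance setting, then the per-model cap, then the default); the table below is a stand-in for `ANTHROPIC_MODEL_MAX_TOKENS`, not a copy of it:
```python
# Stand-in values for illustration; see ANTHROPIC_MODEL_MAX_TOKENS above for the real table.
MODEL_MAX_TOKENS = {'claude-3-5-haiku-latest': 8192, 'claude-3-opus-latest': 4096}
DEFAULT_MAX = 8192


def resolve_max_tokens(requested: int | None, instance_max: int | None, model: str) -> int:
    if requested is not None:  # 1. explicit max_tokens passed to generate_response()
        return requested
    if instance_max is not None:  # 2. max_tokens set when the client was constructed
        return instance_max
    return MODEL_MAX_TOKENS.get(model, DEFAULT_MAX)  # 3. per-model cap, 4. default fallback


assert resolve_max_tokens(1024, 8192, 'claude-3-opus-latest') == 1024
assert resolve_max_tokens(None, None, 'claude-3-opus-latest') == 4096
assert resolve_max_tokens(None, None, 'some-future-model') == 8192
```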
--------------------------------------------------------------------------------
/tests/utils/maintenance/test_bulk_utils.py:
--------------------------------------------------------------------------------
```python
from collections import deque
from unittest.mock import AsyncMock, MagicMock
import pytest
from graphiti_core.edges import EntityEdge
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils import bulk_utils
from graphiti_core.utils.bulk_utils import extract_nodes_and_edges_bulk
from graphiti_core.utils.datetime_utils import utc_now
def _make_episode(uuid_suffix: str, group_id: str = 'group') -> EpisodicNode:
return EpisodicNode(
name=f'episode-{uuid_suffix}',
group_id=group_id,
labels=[],
source=EpisodeType.message,
content='content',
source_description='test',
created_at=utc_now(),
valid_at=utc_now(),
)
def _make_clients() -> GraphitiClients:
driver = MagicMock()
embedder = MagicMock()
cross_encoder = MagicMock()
llm_client = MagicMock()
return GraphitiClients.model_construct( # bypass validation to allow test doubles
driver=driver,
embedder=embedder,
cross_encoder=cross_encoder,
llm_client=llm_client,
)
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch):
clients = _make_clients()
episode_one = _make_episode('1')
episode_two = _make_episode('2')
extracted_one = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
extracted_two = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
canonical = extracted_one
call_queue = deque()
async def fake_resolve(
clients_arg,
nodes_arg,
episode_arg,
previous_episodes_arg,
entity_types_arg,
existing_nodes_override=None,
):
call_queue.append(existing_nodes_override)
if nodes_arg == [extracted_one]:
return [canonical], {canonical.uuid: canonical.uuid}, []
assert nodes_arg == [extracted_two]
assert existing_nodes_override is None
return [canonical], {extracted_two.uuid: canonical.uuid}, [(extracted_two, canonical)]
monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)
nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
clients,
[[extracted_one], [extracted_two]],
[(episode_one, []), (episode_two, [])],
)
assert len(call_queue) == 2
assert call_queue[0] is None
assert call_queue[1] is None
assert nodes_by_episode[episode_one.uuid] == [canonical]
assert nodes_by_episode[episode_two.uuid] == [canonical]
assert compressed_map.get(extracted_two.uuid) == canonical.uuid
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_handles_empty_batch(monkeypatch):
clients = _make_clients()
resolve_mock = AsyncMock()
monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)
nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
clients,
[],
[],
)
assert nodes_by_episode == {}
assert compressed_map == {}
resolve_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_single_episode(monkeypatch):
clients = _make_clients()
episode = _make_episode('solo')
extracted = EntityNode(name='Solo', group_id='group', labels=['Entity'])
resolve_mock = AsyncMock(return_value=([extracted], {extracted.uuid: extracted.uuid}, []))
monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)
nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
clients,
[[extracted]],
[(episode, [])],
)
assert nodes_by_episode == {episode.uuid: [extracted]}
assert compressed_map == {extracted.uuid: extracted.uuid}
resolve_mock.assert_awaited_once()
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_uuid_map_respects_direction(monkeypatch):
clients = _make_clients()
episode_one = _make_episode('one')
episode_two = _make_episode('two')
extracted_one = EntityNode(uuid='b-uuid', name='Edge Case', group_id='group', labels=['Entity'])
extracted_two = EntityNode(uuid='a-uuid', name='Edge Case', group_id='group', labels=['Entity'])
canonical = extracted_one
alias = extracted_two
async def fake_resolve(
clients_arg,
nodes_arg,
episode_arg,
previous_episodes_arg,
entity_types_arg,
existing_nodes_override=None,
):
if nodes_arg == [extracted_one]:
return [canonical], {canonical.uuid: canonical.uuid}, []
assert nodes_arg == [extracted_two]
return [canonical], {alias.uuid: canonical.uuid}, [(alias, canonical)]
monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)
nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
clients,
[[extracted_one], [extracted_two]],
[(episode_one, []), (episode_two, [])],
)
assert nodes_by_episode[episode_one.uuid] == [canonical]
assert nodes_by_episode[episode_two.uuid] == [canonical]
assert compressed_map.get(alias.uuid) == canonical.uuid
@pytest.mark.asyncio
async def test_dedupe_nodes_bulk_missing_canonical_falls_back(monkeypatch, caplog):
clients = _make_clients()
episode = _make_episode('missing')
extracted = EntityNode(name='Fallback', group_id='group', labels=['Entity'])
resolve_mock = AsyncMock(return_value=([extracted], {extracted.uuid: 'missing-canonical'}, []))
monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)
with caplog.at_level('WARNING'):
nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
clients,
[[extracted]],
[(episode, [])],
)
assert nodes_by_episode[episode.uuid] == [extracted]
assert compressed_map.get(extracted.uuid) == 'missing-canonical'
assert any('Canonical node missing' in rec.message for rec in caplog.records)
def test_build_directed_uuid_map_empty():
assert bulk_utils._build_directed_uuid_map([]) == {}
def test_build_directed_uuid_map_chain():
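    # Chained pairs ('a' -> 'b', 'b' -> 'c') should collapse so every alias resolves to the final canonical uuid.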
mapping = bulk_utils._build_directed_uuid_map(
[
('a', 'b'),
('b', 'c'),
]
)
assert mapping['a'] == 'c'
assert mapping['b'] == 'c'
assert mapping['c'] == 'c'
def test_build_directed_uuid_map_preserves_direction():
mapping = bulk_utils._build_directed_uuid_map(
[
('alias', 'canonical'),
]
)
assert mapping['alias'] == 'canonical'
assert mapping['canonical'] == 'canonical'
def test_resolve_edge_pointers_updates_sources():
created_at = utc_now()
edge = EntityEdge(
name='knows',
fact='fact',
group_id='group',
source_node_uuid='alias',
target_node_uuid='target',
created_at=created_at,
)
bulk_utils.resolve_edge_pointers([edge], {'alias': 'canonical'})
assert edge.source_node_uuid == 'canonical'
assert edge.target_node_uuid == 'target'
@pytest.mark.asyncio
async def test_dedupe_edges_bulk_deduplicates_within_episode(monkeypatch):
"""Test that dedupe_edges_bulk correctly compares edges within the same episode.
This test verifies the fix that removed the `if i == j: continue` check,
which was preventing edges from the same episode from being compared against each other.
"""
clients = _make_clients()
# Track which edges are compared
comparisons_made = []
# Create mock embedder that sets embedding values
async def mock_create_embeddings(embedder, edges):
for edge in edges:
edge.fact_embedding = [0.1, 0.2, 0.3]
monkeypatch.setattr(bulk_utils, 'create_entity_edge_embeddings', mock_create_embeddings)
# Mock resolve_extracted_edge to track comparisons and mark duplicates
async def mock_resolve_extracted_edge(
llm_client,
extracted_edge,
related_edges,
existing_edges,
episode,
edge_type_candidates=None,
custom_edge_type_names=None,
):
# Track that this edge was compared against the related_edges
comparisons_made.append((extracted_edge.uuid, [r.uuid for r in related_edges]))
# If there are related edges with same source/target/fact, mark as duplicate
for related in related_edges:
if (
related.uuid != extracted_edge.uuid # Can't be duplicate of self
and related.source_node_uuid == extracted_edge.source_node_uuid
and related.target_node_uuid == extracted_edge.target_node_uuid
and related.fact.strip().lower() == extracted_edge.fact.strip().lower()
):
# Return the related edge and mark extracted_edge as duplicate
return related, [], [related]
# Otherwise return the extracted edge as-is
return extracted_edge, [], []
monkeypatch.setattr(bulk_utils, 'resolve_extracted_edge', mock_resolve_extracted_edge)
episode = _make_episode('1')
source_uuid = 'source-uuid'
target_uuid = 'target-uuid'
# Create 3 identical edges within the same episode
edge1 = EntityEdge(
name='recommends',
fact='assistant recommends yoga poses',
group_id='group',
source_node_uuid=source_uuid,
target_node_uuid=target_uuid,
created_at=utc_now(),
episodes=[episode.uuid],
)
edge2 = EntityEdge(
name='recommends',
fact='assistant recommends yoga poses',
group_id='group',
source_node_uuid=source_uuid,
target_node_uuid=target_uuid,
created_at=utc_now(),
episodes=[episode.uuid],
)
edge3 = EntityEdge(
name='recommends',
fact='assistant recommends yoga poses',
group_id='group',
source_node_uuid=source_uuid,
target_node_uuid=target_uuid,
created_at=utc_now(),
episodes=[episode.uuid],
)
await bulk_utils.dedupe_edges_bulk(
clients,
[[edge1, edge2, edge3]],
[(episode, [])],
[],
{},
{},
)
# Verify that edges were compared against each other (within same episode)
# Each edge should have been compared against all 3 edges (including itself, which gets filtered)
assert len(comparisons_made) == 3
for _, compared_against in comparisons_made:
# Each edge should have access to all 3 edges as candidates
assert len(compared_against) >= 2 # At least 2 others (self is filtered out)
@pytest.mark.asyncio
async def test_extract_nodes_and_edges_bulk_passes_custom_instructions_to_extract_nodes(
monkeypatch,
):
"""Test that custom_extraction_instructions is passed to extract_nodes."""
clients = _make_clients()
episode = _make_episode('1')
# Track calls to extract_nodes
extract_nodes_calls = []
async def mock_extract_nodes(
clients,
episode,
previous_episodes,
entity_types=None,
excluded_entity_types=None,
custom_extraction_instructions=None,
):
extract_nodes_calls.append(
{
'entity_types': entity_types,
'excluded_entity_types': excluded_entity_types,
'custom_extraction_instructions': custom_extraction_instructions,
}
)
return []
async def mock_extract_edges(
clients,
episode,
nodes,
previous_episodes,
edge_type_map,
group_id='',
edge_types=None,
custom_extraction_instructions=None,
):
return []
monkeypatch.setattr(bulk_utils, 'extract_nodes', mock_extract_nodes)
monkeypatch.setattr(bulk_utils, 'extract_edges', mock_extract_edges)
custom_instructions = 'Focus on extracting person entities and their relationships.'
await extract_nodes_and_edges_bulk(
clients,
[(episode, [])],
edge_type_map={},
custom_extraction_instructions=custom_instructions,
)
assert len(extract_nodes_calls) == 1
assert extract_nodes_calls[0]['custom_extraction_instructions'] == custom_instructions
@pytest.mark.asyncio
async def test_extract_nodes_and_edges_bulk_passes_custom_instructions_to_extract_edges(
monkeypatch,
):
"""Test that custom_extraction_instructions is passed to extract_edges."""
clients = _make_clients()
episode = _make_episode('1')
# Track calls to extract_edges
extract_edges_calls = []
extracted_node = EntityNode(name='Test', group_id='group', labels=['Entity'])
async def mock_extract_nodes(
clients,
episode,
previous_episodes,
entity_types=None,
excluded_entity_types=None,
custom_extraction_instructions=None,
):
return [extracted_node]
async def mock_extract_edges(
clients,
episode,
nodes,
previous_episodes,
edge_type_map,
group_id='',
edge_types=None,
custom_extraction_instructions=None,
):
extract_edges_calls.append(
{
'nodes': nodes,
'edge_type_map': edge_type_map,
'edge_types': edge_types,
'custom_extraction_instructions': custom_extraction_instructions,
}
)
return []
monkeypatch.setattr(bulk_utils, 'extract_nodes', mock_extract_nodes)
monkeypatch.setattr(bulk_utils, 'extract_edges', mock_extract_edges)
custom_instructions = 'Extract only professional relationships between people.'
await extract_nodes_and_edges_bulk(
clients,
[(episode, [])],
edge_type_map={('Entity', 'Entity'): ['knows']},
custom_extraction_instructions=custom_instructions,
)
assert len(extract_edges_calls) == 1
assert extract_edges_calls[0]['custom_extraction_instructions'] == custom_instructions
assert extract_edges_calls[0]['nodes'] == [extracted_node]
@pytest.mark.asyncio
async def test_extract_nodes_and_edges_bulk_custom_instructions_none_by_default(monkeypatch):
"""Test that custom_extraction_instructions defaults to None when not provided."""
clients = _make_clients()
episode = _make_episode('1')
extract_nodes_calls = []
extract_edges_calls = []
async def mock_extract_nodes(
clients,
episode,
previous_episodes,
entity_types=None,
excluded_entity_types=None,
custom_extraction_instructions=None,
):
extract_nodes_calls.append(
{'custom_extraction_instructions': custom_extraction_instructions}
)
return []
async def mock_extract_edges(
clients,
episode,
nodes,
previous_episodes,
edge_type_map,
group_id='',
edge_types=None,
custom_extraction_instructions=None,
):
extract_edges_calls.append(
{'custom_extraction_instructions': custom_extraction_instructions}
)
return []
monkeypatch.setattr(bulk_utils, 'extract_nodes', mock_extract_nodes)
monkeypatch.setattr(bulk_utils, 'extract_edges', mock_extract_edges)
# Call without custom_extraction_instructions
await extract_nodes_and_edges_bulk(
clients,
[(episode, [])],
edge_type_map={},
)
assert len(extract_nodes_calls) == 1
assert extract_nodes_calls[0]['custom_extraction_instructions'] is None
assert len(extract_edges_calls) == 1
assert extract_edges_calls[0]['custom_extraction_instructions'] is None
@pytest.mark.asyncio
async def test_extract_nodes_and_edges_bulk_custom_instructions_multiple_episodes(monkeypatch):
"""Test that custom_extraction_instructions is passed for all episodes in bulk."""
clients = _make_clients()
episode1 = _make_episode('1')
episode2 = _make_episode('2')
episode3 = _make_episode('3')
extract_nodes_calls = []
extract_edges_calls = []
async def mock_extract_nodes(
clients,
episode,
previous_episodes,
entity_types=None,
excluded_entity_types=None,
custom_extraction_instructions=None,
):
extract_nodes_calls.append(
{
'episode_name': episode.name,
'custom_extraction_instructions': custom_extraction_instructions,
}
)
return []
async def mock_extract_edges(
clients,
episode,
nodes,
previous_episodes,
edge_type_map,
group_id='',
edge_types=None,
custom_extraction_instructions=None,
):
extract_edges_calls.append(
{
'episode_name': episode.name,
'custom_extraction_instructions': custom_extraction_instructions,
}
)
return []
monkeypatch.setattr(bulk_utils, 'extract_nodes', mock_extract_nodes)
monkeypatch.setattr(bulk_utils, 'extract_edges', mock_extract_edges)
custom_instructions = 'Extract entities related to financial transactions.'
await extract_nodes_and_edges_bulk(
clients,
[(episode1, []), (episode2, []), (episode3, [])],
edge_type_map={},
custom_extraction_instructions=custom_instructions,
)
# All 3 episodes should have received the custom instructions
assert len(extract_nodes_calls) == 3
assert len(extract_edges_calls) == 3
for call in extract_nodes_calls:
assert call['custom_extraction_instructions'] == custom_instructions
for call in extract_edges_calls:
assert call['custom_extraction_instructions'] == custom_instructions
```
--------------------------------------------------------------------------------
/mcp_server/tests/test_async_operations.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Asynchronous operation tests for Graphiti MCP Server.
Tests concurrent operations, queue management, and async patterns.
"""
import asyncio
import contextlib
import json
import time
import pytest
from test_fixtures import (
TestDataGenerator,
graphiti_test_client,
)
class TestAsyncQueueManagement:
"""Test asynchronous queue operations and episode processing."""
@pytest.mark.asyncio
async def test_sequential_queue_processing(self):
"""Verify episodes are processed sequentially within a group."""
async with graphiti_test_client() as (session, group_id):
# Add multiple episodes quickly
episodes = []
for i in range(5):
result = await session.call_tool(
'add_memory',
{
'name': f'Sequential Test {i}',
'episode_body': f'Episode {i} with timestamp {time.time()}',
'source': 'text',
'source_description': 'sequential test',
'group_id': group_id,
'reference_id': f'seq_{i}', # Add reference for tracking
},
)
episodes.append(result)
# Wait for processing
await asyncio.sleep(10) # Allow time for sequential processing
# Retrieve episodes and verify order
result = await session.call_tool('get_episodes', {'group_id': group_id, 'last_n': 10})
processed_episodes = json.loads(result.content[0].text)['episodes']
# Verify all episodes were processed
assert len(processed_episodes) >= 5, (
f'Expected at least 5 episodes, got {len(processed_episodes)}'
)
# Verify sequential processing (timestamps should be ordered)
timestamps = [ep.get('created_at') for ep in processed_episodes]
assert timestamps == sorted(timestamps), 'Episodes not processed in order'
@pytest.mark.asyncio
async def test_concurrent_group_processing(self):
"""Test that different groups can process concurrently."""
async with graphiti_test_client() as (session, _):
groups = [f'group_{i}_{time.time()}' for i in range(3)]
tasks = []
# Create tasks for different groups
for group_id in groups:
for j in range(2):
task = session.call_tool(
'add_memory',
{
'name': f'Group {group_id} Episode {j}',
'episode_body': f'Content for {group_id}',
'source': 'text',
'source_description': 'concurrent test',
'group_id': group_id,
},
)
tasks.append(task)
# Execute all tasks concurrently
start_time = time.time()
results = await asyncio.gather(*tasks, return_exceptions=True)
execution_time = time.time() - start_time
# Verify all succeeded
failures = [r for r in results if isinstance(r, Exception)]
assert not failures, f'Concurrent operations failed: {failures}'
# Check that execution was actually concurrent (should be faster than sequential)
# Sequential would take at least 6 * processing_time
assert execution_time < 30, f'Concurrent execution too slow: {execution_time}s'
@pytest.mark.asyncio
async def test_queue_overflow_handling(self):
"""Test behavior when queue reaches capacity."""
async with graphiti_test_client() as (session, group_id):
# Attempt to add many episodes rapidly
tasks = []
for i in range(100): # Large number to potentially overflow
task = session.call_tool(
'add_memory',
{
'name': f'Overflow Test {i}',
'episode_body': f'Episode {i}',
'source': 'text',
'source_description': 'overflow test',
'group_id': group_id,
},
)
tasks.append(task)
# Execute with gathering to catch any failures
results = await asyncio.gather(*tasks, return_exceptions=True)
# Count successful queuing
successful = sum(1 for r in results if not isinstance(r, Exception))
# Should handle overflow gracefully
assert successful > 0, 'No episodes were queued successfully'
# Log overflow behavior
if successful < 100:
print(f'Queue overflow: {successful}/100 episodes queued')
class TestConcurrentOperations:
"""Test concurrent tool calls and operations."""
@pytest.mark.asyncio
async def test_concurrent_search_operations(self):
"""Test multiple concurrent search operations."""
async with graphiti_test_client() as (session, group_id):
# First, add some test data
data_gen = TestDataGenerator()
add_tasks = []
for _ in range(5):
task = session.call_tool(
'add_memory',
{
'name': 'Search Test Data',
'episode_body': data_gen.generate_technical_document(),
'source': 'text',
'source_description': 'search test',
'group_id': group_id,
},
)
add_tasks.append(task)
await asyncio.gather(*add_tasks)
await asyncio.sleep(15) # Wait for processing
# Now perform concurrent searches
search_queries = [
'architecture',
'performance',
'implementation',
'dependencies',
'latency',
]
search_tasks = []
for query in search_queries:
task = session.call_tool(
'search_memory_nodes',
{
'query': query,
'group_id': group_id,
'limit': 10,
},
)
search_tasks.append(task)
start_time = time.time()
results = await asyncio.gather(*search_tasks, return_exceptions=True)
search_time = time.time() - start_time
# Verify all searches completed
failures = [r for r in results if isinstance(r, Exception)]
assert not failures, f'Search operations failed: {failures}'
# Verify concurrent execution efficiency
assert search_time < len(search_queries) * 2, 'Searches not executing concurrently'
@pytest.mark.asyncio
async def test_mixed_operation_concurrency(self):
"""Test different types of operations running concurrently."""
async with graphiti_test_client() as (session, group_id):
operations = []
# Add memory operation
operations.append(
session.call_tool(
'add_memory',
{
'name': 'Mixed Op Test',
'episode_body': 'Testing mixed operations',
'source': 'text',
'source_description': 'test',
'group_id': group_id,
},
)
)
# Search operation
operations.append(
session.call_tool(
'search_memory_nodes',
{
'query': 'test',
'group_id': group_id,
'limit': 5,
},
)
)
# Get episodes operation
operations.append(
session.call_tool(
'get_episodes',
{
'group_id': group_id,
'last_n': 10,
},
)
)
# Get status operation
operations.append(session.call_tool('get_status', {}))
# Execute all concurrently
results = await asyncio.gather(*operations, return_exceptions=True)
# Check results
for i, result in enumerate(results):
assert not isinstance(result, Exception), f'Operation {i} failed: {result}'
class TestAsyncErrorHandling:
"""Test async error handling and recovery."""
@pytest.mark.asyncio
async def test_timeout_recovery(self):
"""Test recovery from operation timeouts."""
async with graphiti_test_client() as (session, group_id):
# Create a very large episode that might time out
large_content = 'x' * 1000000 # 1MB of data
with contextlib.suppress(asyncio.TimeoutError):
await asyncio.wait_for(
session.call_tool(
'add_memory',
{
'name': 'Timeout Test',
'episode_body': large_content,
'source': 'text',
'source_description': 'timeout test',
'group_id': group_id,
},
),
timeout=2.0, # Short timeout - expected to timeout
)
# Verify server is still responsive after timeout
status_result = await session.call_tool('get_status', {})
assert status_result is not None, 'Server unresponsive after timeout'
@pytest.mark.asyncio
async def test_cancellation_handling(self):
"""Test proper handling of cancelled operations."""
async with graphiti_test_client() as (session, group_id):
# Start a long-running operation
task = asyncio.create_task(
session.call_tool(
'add_memory',
{
'name': 'Cancellation Test',
'episode_body': TestDataGenerator.generate_technical_document(),
'source': 'text',
'source_description': 'cancel test',
'group_id': group_id,
},
)
)
# Cancel after a short delay
await asyncio.sleep(0.1)
task.cancel()
# Verify cancellation was handled
with pytest.raises(asyncio.CancelledError):
await task
# Server should still be operational
result = await session.call_tool('get_status', {})
assert result is not None
@pytest.mark.asyncio
async def test_exception_propagation(self):
"""Test that exceptions are properly propagated in async context."""
async with graphiti_test_client() as (session, group_id):
# Call with invalid arguments
with pytest.raises(ValueError):
await session.call_tool(
'add_memory',
{
# Missing required fields
'group_id': group_id,
},
)
# Server should remain operational
status = await session.call_tool('get_status', {})
assert status is not None
class TestAsyncPerformance:
"""Performance tests for async operations."""
@pytest.mark.asyncio
async def test_async_throughput(self, performance_benchmark):
"""Measure throughput of async operations."""
async with graphiti_test_client() as (session, group_id):
num_operations = 50
start_time = time.time()
# Create many concurrent operations
tasks = []
for i in range(num_operations):
task = session.call_tool(
'add_memory',
{
'name': f'Throughput Test {i}',
'episode_body': f'Content {i}',
'source': 'text',
'source_description': 'throughput test',
'group_id': group_id,
},
)
tasks.append(task)
# Execute all
results = await asyncio.gather(*tasks, return_exceptions=True)
total_time = time.time() - start_time
# Calculate metrics
successful = sum(1 for r in results if not isinstance(r, Exception))
throughput = successful / total_time
performance_benchmark.record('async_throughput', throughput)
# Log results
print('\nAsync Throughput Test:')
print(f' Operations: {num_operations}')
print(f' Successful: {successful}')
print(f' Total time: {total_time:.2f}s')
print(f' Throughput: {throughput:.2f} ops/s')
# Assert minimum throughput
assert throughput > 1.0, f'Throughput too low: {throughput:.2f} ops/s'
@pytest.mark.asyncio
async def test_latency_under_load(self, performance_benchmark):
"""Test operation latency under concurrent load."""
async with graphiti_test_client() as (session, group_id):
# Create background load
background_tasks = []
for i in range(10):
task = asyncio.create_task(
session.call_tool(
'add_memory',
{
'name': f'Background {i}',
'episode_body': TestDataGenerator.generate_technical_document(),
'source': 'text',
'source_description': 'background',
'group_id': f'background_{group_id}',
},
)
)
background_tasks.append(task)
# Measure latency of operations under load
latencies = []
for _ in range(5):
start = time.time()
await session.call_tool('get_status', {})
latency = time.time() - start
latencies.append(latency)
performance_benchmark.record('latency_under_load', latency)
# Clean up background tasks
for task in background_tasks:
task.cancel()
# Analyze latencies
avg_latency = sum(latencies) / len(latencies)
max_latency = max(latencies)
print('\nLatency Under Load:')
print(f' Average: {avg_latency:.3f}s')
print(f' Max: {max_latency:.3f}s')
# Assert acceptable latency
assert avg_latency < 2.0, f'Average latency too high: {avg_latency:.3f}s'
assert max_latency < 5.0, f'Max latency too high: {max_latency:.3f}s'
class TestAsyncStreamHandling:
"""Test handling of streaming responses and data."""
@pytest.mark.asyncio
async def test_large_response_streaming(self):
"""Test handling of large streamed responses."""
async with graphiti_test_client() as (session, group_id):
# Add many episodes
for i in range(20):
await session.call_tool(
'add_memory',
{
'name': f'Stream Test {i}',
'episode_body': f'Episode content {i}',
'source': 'text',
'source_description': 'stream test',
'group_id': group_id,
},
)
# Wait for processing
await asyncio.sleep(30)
# Request large result set
result = await session.call_tool(
'get_episodes',
{
'group_id': group_id,
'last_n': 100, # Request all
},
)
# Verify response handling
episodes = json.loads(result.content[0].text)['episodes']
assert len(episodes) >= 20, f'Expected at least 20 episodes, got {len(episodes)}'
@pytest.mark.asyncio
async def test_incremental_processing(self):
"""Test incremental processing of results."""
async with graphiti_test_client() as (session, group_id):
# Add episodes incrementally
for batch in range(3):
batch_tasks = []
for i in range(5):
task = session.call_tool(
'add_memory',
{
'name': f'Batch {batch} Item {i}',
'episode_body': f'Content for batch {batch}',
'source': 'text',
'source_description': 'incremental test',
'group_id': group_id,
},
)
batch_tasks.append(task)
# Process batch
await asyncio.gather(*batch_tasks)
# Wait for this batch to process
await asyncio.sleep(10)
# Verify incremental results
result = await session.call_tool(
'get_episodes',
{
'group_id': group_id,
'last_n': 100,
},
)
episodes = json.loads(result.content[0].text)['episodes']
expected_min = (batch + 1) * 5
assert len(episodes) >= expected_min, (
f'Batch {batch}: Expected at least {expected_min} episodes'
)
if __name__ == '__main__':
pytest.main([__file__, '-v', '--asyncio-mode=auto'])
```
--------------------------------------------------------------------------------
/graphiti_core/search/search.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
from collections import defaultdict
from time import time
from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.edges import EntityEdge
from graphiti_core.embedder.client import EMBEDDING_DIM
from graphiti_core.errors import SearchRerankerError
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import semaphore_gather
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
from graphiti_core.search.search_config import (
DEFAULT_SEARCH_LIMIT,
CommunityReranker,
CommunitySearchConfig,
CommunitySearchMethod,
EdgeReranker,
EdgeSearchConfig,
EdgeSearchMethod,
EpisodeReranker,
EpisodeSearchConfig,
NodeReranker,
NodeSearchConfig,
NodeSearchMethod,
SearchConfig,
SearchResults,
)
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import (
community_fulltext_search,
community_similarity_search,
edge_bfs_search,
edge_fulltext_search,
edge_similarity_search,
episode_fulltext_search,
episode_mentions_reranker,
get_embeddings_for_communities,
get_embeddings_for_edges,
get_embeddings_for_nodes,
maximal_marginal_relevance,
node_bfs_search,
node_distance_reranker,
node_fulltext_search,
node_similarity_search,
rrf,
)
logger = logging.getLogger(__name__)
async def search(
clients: GraphitiClients,
query: str,
group_ids: list[str] | None,
config: SearchConfig,
search_filter: SearchFilters,
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
query_vector: list[float] | None = None,
driver: GraphDriver | None = None,
) -> SearchResults:
start = time()
driver = driver or clients.driver
embedder = clients.embedder
cross_encoder = clients.cross_encoder
if query.strip() == '':
return SearchResults()
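    # A query embedding is only needed when a configured search method or reranker
    # relies on vectors (cosine similarity search or MMR reranking); otherwise a zero vector is used.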
if (
config.edge_config
and EdgeSearchMethod.cosine_similarity in config.edge_config.search_methods
or config.edge_config
and EdgeReranker.mmr == config.edge_config.reranker
or config.node_config
and NodeSearchMethod.cosine_similarity in config.node_config.search_methods
or config.node_config
and NodeReranker.mmr == config.node_config.reranker
or (
config.community_config
and CommunitySearchMethod.cosine_similarity in config.community_config.search_methods
)
or (config.community_config and CommunityReranker.mmr == config.community_config.reranker)
):
search_vector = (
query_vector
if query_vector is not None
else await embedder.create(input_data=[query.replace('\n', ' ')])
)
else:
search_vector = [0.0] * EMBEDDING_DIM
# if group_ids is empty, set it to None
group_ids = group_ids if group_ids and group_ids != [''] else None
(
(edges, edge_reranker_scores),
(nodes, node_reranker_scores),
(episodes, episode_reranker_scores),
(communities, community_reranker_scores),
) = await semaphore_gather(
edge_search(
driver,
cross_encoder,
query,
search_vector,
group_ids,
config.edge_config,
search_filter,
center_node_uuid,
bfs_origin_node_uuids,
config.limit,
config.reranker_min_score,
),
node_search(
driver,
cross_encoder,
query,
search_vector,
group_ids,
config.node_config,
search_filter,
center_node_uuid,
bfs_origin_node_uuids,
config.limit,
config.reranker_min_score,
),
episode_search(
driver,
cross_encoder,
query,
search_vector,
group_ids,
config.episode_config,
search_filter,
config.limit,
config.reranker_min_score,
),
community_search(
driver,
cross_encoder,
query,
search_vector,
group_ids,
config.community_config,
config.limit,
config.reranker_min_score,
),
)
results = SearchResults(
edges=edges,
edge_reranker_scores=edge_reranker_scores,
nodes=nodes,
node_reranker_scores=node_reranker_scores,
episodes=episodes,
episode_reranker_scores=episode_reranker_scores,
communities=communities,
community_reranker_scores=community_reranker_scores,
)
latency = (time() - start) * 1000
logger.debug(f'search returned context for query {query} in {latency} ms')
return results
async def edge_search(
driver: GraphDriver,
cross_encoder: CrossEncoderClient,
query: str,
query_vector: list[float],
group_ids: list[str] | None,
config: EdgeSearchConfig | None,
search_filter: SearchFilters,
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT,
reranker_min_score: float = 0,
) -> tuple[list[EntityEdge], list[float]]:
if config is None:
return [], []
# Build search tasks based on configured search methods
search_tasks = []
if EdgeSearchMethod.bm25 in config.search_methods:
search_tasks.append(
edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
)
if EdgeSearchMethod.cosine_similarity in config.search_methods:
search_tasks.append(
edge_similarity_search(
driver,
query_vector,
None,
None,
search_filter,
group_ids,
2 * limit,
config.sim_min_score,
)
)
if EdgeSearchMethod.bfs in config.search_methods:
search_tasks.append(
edge_bfs_search(
driver,
bfs_origin_node_uuids,
config.bfs_max_depth,
search_filter,
group_ids,
2 * limit,
)
)
# Execute only the configured search methods
search_results: list[list[EntityEdge]] = []
if search_tasks:
search_results = list(await semaphore_gather(*search_tasks))
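    # If BFS was requested without explicit origin nodes, run a follow-up BFS seeded
    # from the source nodes of the edges found by the other search methods.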
if EdgeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
search_results.append(
await edge_bfs_search(
driver,
source_node_uuids,
config.bfs_max_depth,
search_filter,
group_ids,
2 * limit,
)
)
edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
reranked_uuids: list[str] = []
edge_scores: list[float] = []
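    # episode_mentions reuses the RRF pass here; the sort by episode count happens after reranking below.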
if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
reranked_uuids, edge_scores = rrf(search_result_uuids, min_score=reranker_min_score)
elif config.reranker == EdgeReranker.mmr:
search_result_uuids_and_vectors = await get_embeddings_for_edges(
driver, list(edge_uuid_map.values())
)
reranked_uuids, edge_scores = maximal_marginal_relevance(
query_vector,
search_result_uuids_and_vectors,
config.mmr_lambda,
reranker_min_score,
)
elif config.reranker == EdgeReranker.cross_encoder:
fact_to_uuid_map = {edge.fact: edge.uuid for edge in list(edge_uuid_map.values())[:limit]}
reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
reranked_uuids = [
fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
]
edge_scores = [score for _, score in reranked_facts if score >= reranker_min_score]
elif config.reranker == EdgeReranker.node_distance:
if center_node_uuid is None:
raise SearchRerankerError('No center node provided for Node Distance reranker')
# use rrf as a preliminary sort
sorted_result_uuids, node_scores = rrf(
[[edge.uuid for edge in result] for result in search_results],
min_score=reranker_min_score,
)
sorted_results = [edge_uuid_map[uuid] for uuid in sorted_result_uuids]
# node distance reranking
source_to_edge_uuid_map = defaultdict(list)
for edge in sorted_results:
source_to_edge_uuid_map[edge.source_node_uuid].append(edge.uuid)
source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]
reranked_node_uuids, edge_scores = await node_distance_reranker(
driver, source_uuids, center_node_uuid, min_score=reranker_min_score
)
for node_uuid in reranked_node_uuids:
reranked_uuids.extend(source_to_edge_uuid_map[node_uuid])
reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
if config.reranker == EdgeReranker.episode_mentions:
reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))
return reranked_edges[:limit], edge_scores[:limit]
async def node_search(
driver: GraphDriver,
cross_encoder: CrossEncoderClient,
query: str,
query_vector: list[float],
group_ids: list[str] | None,
config: NodeSearchConfig | None,
search_filter: SearchFilters,
center_node_uuid: str | None = None,
bfs_origin_node_uuids: list[str] | None = None,
limit=DEFAULT_SEARCH_LIMIT,
reranker_min_score: float = 0,
) -> tuple[list[EntityNode], list[float]]:
if config is None:
return [], []
# Build search tasks based on configured search methods
search_tasks = []
if NodeSearchMethod.bm25 in config.search_methods:
search_tasks.append(
node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
)
if NodeSearchMethod.cosine_similarity in config.search_methods:
search_tasks.append(
node_similarity_search(
driver,
query_vector,
search_filter,
group_ids,
2 * limit,
config.sim_min_score,
)
)
if NodeSearchMethod.bfs in config.search_methods:
search_tasks.append(
node_bfs_search(
driver,
bfs_origin_node_uuids,
search_filter,
config.bfs_max_depth,
group_ids,
2 * limit,
)
)
# Execute only the configured search methods
search_results: list[list[EntityNode]] = []
if search_tasks:
search_results = list(await semaphore_gather(*search_tasks))
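    # If BFS was requested without explicit origin nodes, seed a follow-up BFS with
    # the nodes returned by the other search methods.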
if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
origin_node_uuids = [node.uuid for result in search_results for node in result]
search_results.append(
await node_bfs_search(
driver,
origin_node_uuids,
search_filter,
config.bfs_max_depth,
group_ids,
2 * limit,
)
)
search_result_uuids = [[node.uuid for node in result] for result in search_results]
node_uuid_map = {node.uuid: node for result in search_results for node in result}
reranked_uuids: list[str] = []
node_scores: list[float] = []
if config.reranker == NodeReranker.rrf:
reranked_uuids, node_scores = rrf(search_result_uuids, min_score=reranker_min_score)
elif config.reranker == NodeReranker.mmr:
search_result_uuids_and_vectors = await get_embeddings_for_nodes(
driver, list(node_uuid_map.values())
)
reranked_uuids, node_scores = maximal_marginal_relevance(
query_vector,
search_result_uuids_and_vectors,
config.mmr_lambda,
reranker_min_score,
)
elif config.reranker == NodeReranker.cross_encoder:
name_to_uuid_map = {node.name: node.uuid for node in list(node_uuid_map.values())}
reranked_node_names = await cross_encoder.rank(query, list(name_to_uuid_map.keys()))
reranked_uuids = [
name_to_uuid_map[name]
for name, score in reranked_node_names
if score >= reranker_min_score
]
node_scores = [score for _, score in reranked_node_names if score >= reranker_min_score]
elif config.reranker == NodeReranker.episode_mentions:
reranked_uuids, node_scores = await episode_mentions_reranker(
driver, search_result_uuids, min_score=reranker_min_score
)
elif config.reranker == NodeReranker.node_distance:
if center_node_uuid is None:
raise SearchRerankerError('No center node provided for Node Distance reranker')
reranked_uuids, node_scores = await node_distance_reranker(
driver,
rrf(search_result_uuids, min_score=reranker_min_score)[0],
center_node_uuid,
min_score=reranker_min_score,
)
reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
return reranked_nodes[:limit], node_scores[:limit]
async def episode_search(
driver: GraphDriver,
cross_encoder: CrossEncoderClient,
query: str,
_query_vector: list[float],
group_ids: list[str] | None,
config: EpisodeSearchConfig | None,
search_filter: SearchFilters,
limit=DEFAULT_SEARCH_LIMIT,
reranker_min_score: float = 0,
) -> tuple[list[EpisodicNode], list[float]]:
if config is None:
return [], []
search_results: list[list[EpisodicNode]] = list(
await semaphore_gather(
*[
episode_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
]
)
)
search_result_uuids = [[episode.uuid for episode in result] for result in search_results]
episode_uuid_map = {episode.uuid: episode for result in search_results for episode in result}
reranked_uuids: list[str] = []
episode_scores: list[float] = []
if config.reranker == EpisodeReranker.rrf:
reranked_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
elif config.reranker == EpisodeReranker.cross_encoder:
# use rrf as a preliminary reranker
rrf_result_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
rrf_results = [episode_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
content_to_uuid_map = {episode.content: episode.uuid for episode in rrf_results}
reranked_contents = await cross_encoder.rank(query, list(content_to_uuid_map.keys()))
reranked_uuids = [
content_to_uuid_map[content]
for content, score in reranked_contents
if score >= reranker_min_score
]
episode_scores = [score for _, score in reranked_contents if score >= reranker_min_score]
reranked_episodes = [episode_uuid_map[uuid] for uuid in reranked_uuids]
return reranked_episodes[:limit], episode_scores[:limit]
async def community_search(
driver: GraphDriver,
cross_encoder: CrossEncoderClient,
query: str,
query_vector: list[float],
group_ids: list[str] | None,
config: CommunitySearchConfig | None,
limit=DEFAULT_SEARCH_LIMIT,
reranker_min_score: float = 0,
) -> tuple[list[CommunityNode], list[float]]:
if config is None:
return [], []
search_results: list[list[CommunityNode]] = list(
await semaphore_gather(
*[
community_fulltext_search(driver, query, group_ids, 2 * limit),
community_similarity_search(
driver, query_vector, group_ids, 2 * limit, config.sim_min_score
),
]
)
)
search_result_uuids = [[community.uuid for community in result] for result in search_results]
community_uuid_map = {
community.uuid: community for result in search_results for community in result
}
reranked_uuids: list[str] = []
community_scores: list[float] = []
if config.reranker == CommunityReranker.rrf:
reranked_uuids, community_scores = rrf(search_result_uuids, min_score=reranker_min_score)
elif config.reranker == CommunityReranker.mmr:
search_result_uuids_and_vectors = await get_embeddings_for_communities(
driver, list(community_uuid_map.values())
)
reranked_uuids, community_scores = maximal_marginal_relevance(
query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
)
elif config.reranker == CommunityReranker.cross_encoder:
name_to_uuid_map = {node.name: node.uuid for result in search_results for node in result}
reranked_nodes = await cross_encoder.rank(query, list(name_to_uuid_map.keys()))
reranked_uuids = [
name_to_uuid_map[name] for name, score in reranked_nodes if score >= reranker_min_score
]
community_scores = [score for _, score in reranked_nodes if score >= reranker_min_score]
reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
return reranked_communities[:limit], community_scores[:limit]
```
--------------------------------------------------------------------------------
/mcp_server/tests/test_mcp_integration.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Integration test for the refactored Graphiti MCP Server using the official MCP Python SDK.
Tests all major MCP tools and handles episode processing latency.
"""
import asyncio
import json
import os
import time
from typing import Any
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
class GraphitiMCPIntegrationTest:
"""Integration test client for Graphiti MCP Server using official MCP SDK."""
def __init__(self):
self.test_group_id = f'test_group_{int(time.time())}'
self.session = None
async def __aenter__(self):
"""Start the MCP client session."""
# Configure server parameters to run our refactored server
server_params = StdioServerParameters(
command='uv',
args=['run', 'main.py', '--transport', 'stdio'],
env={
'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY', 'dummy_key_for_testing'),
},
)
print(f'🚀 Starting MCP client session with test group: {self.test_group_id}')
# Use the async context manager properly
self.client_context = stdio_client(server_params)
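        # Enter the transport manually here so __aexit__ can close it when the test client shuts down.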
read, write = await self.client_context.__aenter__()
self.session = ClientSession(read, write)
await self.session.initialize()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Close the MCP client session."""
if self.session:
await self.session.close()
if hasattr(self, 'client_context'):
await self.client_context.__aexit__(exc_type, exc_val, exc_tb)
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
"""Call an MCP tool and return the result."""
try:
result = await self.session.call_tool(tool_name, arguments)
return result.content[0].text if result.content else {'error': 'No content returned'}
except Exception as e:
return {'error': str(e)}
async def test_server_initialization(self) -> bool:
"""Test that the server initializes properly."""
print('🔍 Testing server initialization...')
try:
# List available tools to verify server is responding
tools_result = await self.session.list_tools()
tools = [tool.name for tool in tools_result.tools]
expected_tools = [
'add_memory',
'search_memory_nodes',
'search_memory_facts',
'get_episodes',
'delete_episode',
'delete_entity_edge',
'get_entity_edge',
'clear_graph',
]
available_tools = len([tool for tool in expected_tools if tool in tools])
print(
f' ✅ Server responding with {len(tools)} tools ({available_tools}/{len(expected_tools)} expected)'
)
print(f' Available tools: {", ".join(sorted(tools))}')
return available_tools >= len(expected_tools) * 0.8 # 80% of expected tools
except Exception as e:
print(f' ❌ Server initialization failed: {e}')
return False
async def test_add_memory_operations(self) -> dict[str, bool]:
"""Test adding various types of memory episodes."""
print('📝 Testing add_memory operations...')
results = {}
# Test 1: Add text episode
print(' Testing text episode...')
try:
result = await self.call_tool(
'add_memory',
{
'name': 'Test Company News',
'episode_body': 'Acme Corp announced a revolutionary new AI product that will transform the industry. The CEO mentioned this is their biggest launch since 2020.',
'source': 'text',
'source_description': 'news article',
'group_id': self.test_group_id,
},
)
if isinstance(result, str) and 'queued' in result.lower():
print(f' ✅ Text episode: {result}')
results['text'] = True
else:
print(f' ❌ Text episode failed: {result}')
results['text'] = False
except Exception as e:
print(f' ❌ Text episode error: {e}')
results['text'] = False
# Test 2: Add JSON episode
print(' Testing JSON episode...')
try:
json_data = {
'company': {'name': 'TechCorp', 'founded': 2010},
'products': [
{'id': 'P001', 'name': 'CloudSync', 'category': 'software'},
{'id': 'P002', 'name': 'DataMiner', 'category': 'analytics'},
],
'employees': 150,
}
result = await self.call_tool(
'add_memory',
{
'name': 'Company Profile',
'episode_body': json.dumps(json_data),
'source': 'json',
'source_description': 'CRM data',
'group_id': self.test_group_id,
},
)
if isinstance(result, str) and 'queued' in result.lower():
print(f' ✅ JSON episode: {result}')
results['json'] = True
else:
print(f' ❌ JSON episode failed: {result}')
results['json'] = False
except Exception as e:
print(f' ❌ JSON episode error: {e}')
results['json'] = False
# Test 3: Add message episode
print(' Testing message episode...')
try:
result = await self.call_tool(
'add_memory',
{
'name': 'Customer Support Chat',
'episode_body': "user: What's your return policy?\nassistant: You can return items within 30 days of purchase with receipt.\nuser: Thanks!",
'source': 'message',
'source_description': 'support chat log',
'group_id': self.test_group_id,
},
)
if isinstance(result, str) and 'queued' in result.lower():
print(f' ✅ Message episode: {result}')
results['message'] = True
else:
print(f' ❌ Message episode failed: {result}')
results['message'] = False
except Exception as e:
print(f' ❌ Message episode error: {e}')
results['message'] = False
return results
async def wait_for_processing(self, max_wait: int = 45) -> bool:
"""Wait for episode processing to complete."""
print(f'⏳ Waiting up to {max_wait} seconds for episode processing...')
for i in range(max_wait):
await asyncio.sleep(1)
try:
# Check if we have any episodes
result = await self.call_tool(
'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
)
# Parse the JSON result if it's a string
if isinstance(result, str):
try:
parsed_result = json.loads(result)
if isinstance(parsed_result, list) and len(parsed_result) > 0:
print(
f' ✅ Found {len(parsed_result)} processed episodes after {i + 1} seconds'
)
return True
except json.JSONDecodeError:
if 'episodes' in result.lower():
print(f' ✅ Episodes detected after {i + 1} seconds')
return True
except Exception as e:
if i == 0: # Only log first error to avoid spam
print(f' ⚠️ Waiting for processing... ({e})')
continue
print(f' ⚠️ Still waiting after {max_wait} seconds...')
return False
async def test_search_operations(self) -> dict[str, bool]:
"""Test search functionality."""
print('🔍 Testing search operations...')
results = {}
# Test search_memory_nodes
print(' Testing search_memory_nodes...')
try:
result = await self.call_tool(
'search_memory_nodes',
{
'query': 'Acme Corp product launch AI',
'group_ids': [self.test_group_id],
'max_nodes': 5,
},
)
success = False
if isinstance(result, str):
try:
parsed = json.loads(result)
nodes = parsed.get('nodes', [])
success = isinstance(nodes, list)
print(f' ✅ Node search returned {len(nodes)} nodes')
except json.JSONDecodeError:
success = 'nodes' in result.lower() and 'successfully' in result.lower()
if success:
print(' ✅ Node search completed successfully')
results['nodes'] = success
if not success:
print(f' ❌ Node search failed: {result}')
except Exception as e:
print(f' ❌ Node search error: {e}')
results['nodes'] = False
# Test search_memory_facts
print(' Testing search_memory_facts...')
try:
result = await self.call_tool(
'search_memory_facts',
{
'query': 'company products software TechCorp',
'group_ids': [self.test_group_id],
'max_facts': 5,
},
)
success = False
if isinstance(result, str):
try:
parsed = json.loads(result)
facts = parsed.get('facts', [])
success = isinstance(facts, list)
print(f' ✅ Fact search returned {len(facts)} facts')
except json.JSONDecodeError:
success = 'facts' in result.lower() and 'successfully' in result.lower()
if success:
print(' ✅ Fact search completed successfully')
results['facts'] = success
if not success:
print(f' ❌ Fact search failed: {result}')
except Exception as e:
print(f' ❌ Fact search error: {e}')
results['facts'] = False
return results
async def test_episode_retrieval(self) -> bool:
"""Test episode retrieval."""
print('📚 Testing episode retrieval...')
try:
result = await self.call_tool(
'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
)
if isinstance(result, str):
try:
parsed = json.loads(result)
if isinstance(parsed, list):
print(f' ✅ Retrieved {len(parsed)} episodes')
# Show episode details
for i, episode in enumerate(parsed[:3]):
name = episode.get('name', 'Unknown')
source = episode.get('source', 'unknown')
print(f' Episode {i + 1}: {name} (source: {source})')
return len(parsed) > 0
except json.JSONDecodeError:
# Check if response indicates success
if 'episode' in result.lower():
print(' ✅ Episode retrieval completed')
return True
print(f' ❌ Unexpected result format: {result}')
return False
except Exception as e:
print(f' ❌ Episode retrieval failed: {e}')
return False
async def test_error_handling(self) -> dict[str, bool]:
"""Test error handling and edge cases."""
print('🧪 Testing error handling...')
results = {}
# Test with nonexistent group
print(' Testing nonexistent group handling...')
try:
result = await self.call_tool(
'search_memory_nodes',
{
'query': 'nonexistent data',
'group_ids': ['nonexistent_group_12345'],
'max_nodes': 5,
},
)
# Should handle gracefully, not crash
success = (
'error' not in str(result).lower() or 'not initialized' not in str(result).lower()
)
if success:
print(' ✅ Nonexistent group handled gracefully')
else:
print(f' ❌ Nonexistent group caused issues: {result}')
results['nonexistent_group'] = success
except Exception as e:
print(f' ❌ Nonexistent group test failed: {e}')
results['nonexistent_group'] = False
# Test empty query
print(' Testing empty query handling...')
try:
result = await self.call_tool(
'search_memory_nodes',
{'query': '', 'group_ids': [self.test_group_id], 'max_nodes': 5},
)
# Should handle gracefully
success = (
'error' not in str(result).lower() or 'not initialized' not in str(result).lower()
)
if success:
print(' ✅ Empty query handled gracefully')
else:
print(f' ❌ Empty query caused issues: {result}')
results['empty_query'] = success
except Exception as e:
print(f' ❌ Empty query test failed: {e}')
results['empty_query'] = False
return results
async def run_comprehensive_test(self) -> dict[str, Any]:
"""Run the complete integration test suite."""
print('🚀 Starting Comprehensive Graphiti MCP Server Integration Test')
print(f' Test group ID: {self.test_group_id}')
print('=' * 70)
results = {
'server_init': False,
'add_memory': {},
'processing_wait': False,
'search': {},
'episodes': False,
'error_handling': {},
'overall_success': False,
}
# Test 1: Server Initialization
results['server_init'] = await self.test_server_initialization()
if not results['server_init']:
print('❌ Server initialization failed, aborting remaining tests')
return results
print()
# Test 2: Add Memory Operations
results['add_memory'] = await self.test_add_memory_operations()
print()
# Test 3: Wait for Processing
results['processing_wait'] = await self.wait_for_processing()
print()
# Test 4: Search Operations
results['search'] = await self.test_search_operations()
print()
# Test 5: Episode Retrieval
results['episodes'] = await self.test_episode_retrieval()
print()
# Test 6: Error Handling
results['error_handling'] = await self.test_error_handling()
print()
# Calculate overall success
memory_success = any(results['add_memory'].values())
search_success = any(results['search'].values()) if results['search'] else False
error_success = (
any(results['error_handling'].values()) if results['error_handling'] else True
)
results['overall_success'] = (
results['server_init']
and memory_success
and (results['episodes'] or results['processing_wait'])
and error_success
)
# Print comprehensive summary
print('=' * 70)
print('📊 COMPREHENSIVE TEST SUMMARY')
print('-' * 35)
print(f'Server Initialization: {"✅ PASS" if results["server_init"] else "❌ FAIL"}')
memory_stats = f'({sum(results["add_memory"].values())}/{len(results["add_memory"])} types)'
print(
f'Memory Operations: {"✅ PASS" if memory_success else "❌ FAIL"} {memory_stats}'
)
print(f'Processing Pipeline: {"✅ PASS" if results["processing_wait"] else "❌ FAIL"}')
search_stats = (
f'({sum(results["search"].values())}/{len(results["search"])} types)'
if results['search']
else '(0/0 types)'
)
print(
f'Search Operations: {"✅ PASS" if search_success else "❌ FAIL"} {search_stats}'
)
print(f'Episode Retrieval: {"✅ PASS" if results["episodes"] else "❌ FAIL"}')
error_stats = (
f'({sum(results["error_handling"].values())}/{len(results["error_handling"])} cases)'
if results['error_handling']
else '(0/0 cases)'
)
print(
f'Error Handling: {"✅ PASS" if error_success else "❌ FAIL"} {error_stats}'
)
print('-' * 35)
print(f'🎯 OVERALL RESULT: {"✅ SUCCESS" if results["overall_success"] else "❌ FAILED"}')
if results['overall_success']:
print('\n🎉 The refactored Graphiti MCP server is working correctly!')
print(' All core functionality has been successfully tested.')
else:
print('\n⚠️ Some issues were detected. Review the test results above.')
print(' The refactoring may need additional attention.')
return results
async def main():
"""Run the integration test."""
try:
async with GraphitiMCPIntegrationTest() as test:
results = await test.run_comprehensive_test()
# Exit with appropriate code
exit_code = 0 if results['overall_success'] else 1
exit(exit_code)
except Exception as e:
print(f'❌ Test setup failed: {e}')
exit(1)
if __name__ == '__main__':
asyncio.run(main())
```
--------------------------------------------------------------------------------
/graphiti_core/llm_client/gemini_client.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
import re
import typing
from typing import TYPE_CHECKING, ClassVar
from pydantic import BaseModel
from ..prompts.models import Message
from .client import LLMClient, get_extraction_language_instruction
from .config import LLMConfig, ModelSize
from .errors import RateLimitError
if TYPE_CHECKING:
from google import genai
from google.genai import types
else:
try:
from google import genai
from google.genai import types
except ImportError:
# If gemini client is not installed, raise an ImportError
raise ImportError(
'google-genai is required for GeminiClient. '
'Install it with: pip install graphiti-core[google-genai]'
) from None
logger = logging.getLogger(__name__)
DEFAULT_MODEL = 'gemini-2.5-flash'
DEFAULT_SMALL_MODEL = 'gemini-2.5-flash-lite'
# Maximum output tokens for different Gemini models
GEMINI_MODEL_MAX_TOKENS = {
# Gemini 2.5 models
'gemini-2.5-pro': 65536,
'gemini-2.5-flash': 65536,
'gemini-2.5-flash-lite': 64000,
# Gemini 2.0 models
'gemini-2.0-flash': 8192,
'gemini-2.0-flash-lite': 8192,
# Gemini 1.5 models
'gemini-1.5-pro': 8192,
'gemini-1.5-flash': 8192,
'gemini-1.5-flash-8b': 8192,
}
# Default max tokens for models not in the mapping
DEFAULT_GEMINI_MAX_TOKENS = 8192
class GeminiClient(LLMClient):
"""
GeminiClient is a client class for interacting with Google's Gemini language models.
This class extends the LLMClient and provides methods to initialize the client
and generate responses from the Gemini language model.
Attributes:
model (str): The model name to use for generating responses.
temperature (float): The temperature to use for generating responses.
max_tokens (int): The maximum number of tokens to generate in a response.
thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
Methods:
__init__(config: LLMConfig | None = None, cache: bool = False, thinking_config: types.ThinkingConfig | None = None):
Initializes the GeminiClient with the provided configuration, cache setting, and optional thinking config.
_generate_response(messages: list[Message]) -> dict[str, typing.Any]:
Generates a response from the language model based on the provided messages.
"""
# Class-level constants
MAX_RETRIES: ClassVar[int] = 2
def __init__(
self,
config: LLMConfig | None = None,
cache: bool = False,
max_tokens: int | None = None,
thinking_config: types.ThinkingConfig | None = None,
client: 'genai.Client | None' = None,
):
"""
Initialize the GeminiClient with the provided configuration, cache setting, and optional thinking config.
Args:
config (LLMConfig | None): The configuration for the LLM client, including API key, model, temperature, and max tokens.
cache (bool): Whether to use caching for responses. Defaults to False.
thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
Only use with models that support thinking (gemini-2.5+). Defaults to None.
client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
"""
if config is None:
config = LLMConfig()
super().__init__(config, cache)
self.model = config.model
if client is None:
self.client = genai.Client(api_key=config.api_key)
else:
self.client = client
self.max_tokens = max_tokens
self.thinking_config = thinking_config
def _check_safety_blocks(self, response) -> None:
"""Check if response was blocked for safety reasons and raise appropriate exceptions."""
# Check if the response was blocked for safety reasons
if not (hasattr(response, 'candidates') and response.candidates):
return
candidate = response.candidates[0]
if not (hasattr(candidate, 'finish_reason') and candidate.finish_reason == 'SAFETY'):
return
# Content was blocked for safety reasons - collect safety details
safety_info = []
safety_ratings = getattr(candidate, 'safety_ratings', None)
if safety_ratings:
for rating in safety_ratings:
if getattr(rating, 'blocked', False):
category = getattr(rating, 'category', 'Unknown')
probability = getattr(rating, 'probability', 'Unknown')
safety_info.append(f'{category}: {probability}')
safety_details = (
', '.join(safety_info) if safety_info else 'Content blocked for safety reasons'
)
raise Exception(f'Response blocked by Gemini safety filters: {safety_details}')
def _check_prompt_blocks(self, response) -> None:
"""Check if prompt was blocked and raise appropriate exceptions."""
prompt_feedback = getattr(response, 'prompt_feedback', None)
if not prompt_feedback:
return
block_reason = getattr(prompt_feedback, 'block_reason', None)
if block_reason:
raise Exception(f'Prompt blocked by Gemini: {block_reason}')
def _get_model_for_size(self, model_size: ModelSize) -> str:
"""Get the appropriate model name based on the requested size."""
if model_size == ModelSize.small:
return self.small_model or DEFAULT_SMALL_MODEL
else:
return self.model or DEFAULT_MODEL
def _get_max_tokens_for_model(self, model: str) -> int:
"""Get the maximum output tokens for a specific Gemini model."""
return GEMINI_MODEL_MAX_TOKENS.get(model, DEFAULT_GEMINI_MAX_TOKENS)
def _resolve_max_tokens(self, requested_max_tokens: int | None, model: str) -> int:
"""
Resolve the maximum output tokens to use based on precedence rules.
Precedence order (highest to lowest):
1. Explicit max_tokens parameter passed to generate_response()
2. Instance max_tokens set during client initialization
3. Model-specific maximum tokens from GEMINI_MODEL_MAX_TOKENS mapping
        4. DEFAULT_GEMINI_MAX_TOKENS as the final fallback
Args:
requested_max_tokens: The max_tokens parameter passed to generate_response()
model: The model name to look up model-specific limits
Returns:
int: The resolved maximum tokens to use
"""
# 1. Use explicit parameter if provided
if requested_max_tokens is not None:
return requested_max_tokens
# 2. Use instance max_tokens if set during initialization
if self.max_tokens is not None:
return self.max_tokens
# 3. Use model-specific maximum or return DEFAULT_GEMINI_MAX_TOKENS
return self._get_max_tokens_for_model(model)
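    # Worked example of the precedence rules above (illustrative comments only):
    # given GeminiClient(max_tokens=4096) and model 'gemini-2.5-flash',
    #   _resolve_max_tokens(8000, 'gemini-2.5-flash')  -> 8000   (explicit argument wins)
    #   _resolve_max_tokens(None, 'gemini-2.5-flash')  -> 4096   (instance setting)
    # and with no instance override (max_tokens=None at init),
    #   _resolve_max_tokens(None, 'gemini-2.5-flash')  -> 65536  (model-specific cap)
    #   _resolve_max_tokens(None, 'unknown-model')     -> 8192   (DEFAULT_GEMINI_MAX_TOKENS)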
def salvage_json(self, raw_output: str) -> dict[str, typing.Any] | None:
"""
Attempt to salvage a JSON object if the raw output is truncated.
        Salvaging works by checking whether the raw output ends with the closing bracket of a JSON array or object (ignoring trailing whitespace).
        If it does, the output up to that point is parsed with json.loads.
        If parsing still fails, None is returned.
Args:
raw_output (str): The raw output from the LLM.
Returns:
dict[str, typing.Any]: The salvaged JSON object.
None: If no salvage is possible.
"""
if not raw_output:
return None
# Try to salvage a JSON array
array_match = re.search(r'\]\s*$', raw_output)
if array_match:
try:
return json.loads(raw_output[: array_match.end()])
except Exception:
pass
# Try to salvage a JSON object
obj_match = re.search(r'\}\s*$', raw_output)
if obj_match:
try:
return json.loads(raw_output[: obj_match.end()])
except Exception:
pass
return None
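    # Illustrative behaviour (comments only; assumes a GeminiClient instance `client`):
    #   client.salvage_json('[{"name": "a"}, {"name": "b"}]\n')  -> parsed JSON (a list here)
    #   client.salvage_json('{"entities": []} and some text')    -> None (output does not end with a bracket)
    #   client.salvage_json('{"entities": [{"name": "a"')        -> None (truncated, no trailing bracket to anchor on)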
async def _generate_response(
self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int | None = None,
model_size: ModelSize = ModelSize.medium,
) -> dict[str, typing.Any]:
"""
Generate a response from the Gemini language model.
Args:
messages (list[Message]): A list of messages to send to the language model.
response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
max_tokens (int | None): The maximum number of tokens to generate in the response. If None, uses precedence rules.
model_size (ModelSize): The size of the model to use (small or medium).
Returns:
dict[str, typing.Any]: The response from the language model.
Raises:
RateLimitError: If the API rate limit is exceeded.
Exception: If there is an error generating the response or content is blocked.
"""
try:
gemini_messages: typing.Any = []
# If a response model is provided, add schema for structured output
system_prompt = ''
if response_model is not None:
# Get the schema from the Pydantic model
pydantic_schema = response_model.model_json_schema()
# Create instruction to output in the desired JSON format
system_prompt += (
f'Output ONLY valid JSON matching this schema: {json.dumps(pydantic_schema)}.\n'
'Do not include any explanatory text before or after the JSON.\n\n'
)
# Add messages content
# First check for a system message
if messages and messages[0].role == 'system':
system_prompt = f'{messages[0].content}\n\n {system_prompt}'
messages = messages[1:]
# Add the rest of the messages
for m in messages:
m.content = self._clean_input(m.content)
gemini_messages.append(
types.Content(role=m.role, parts=[types.Part.from_text(text=m.content)])
)
# Get the appropriate model for the requested size
model = self._get_model_for_size(model_size)
# Resolve max_tokens using precedence rules (see _resolve_max_tokens for details)
resolved_max_tokens = self._resolve_max_tokens(max_tokens, model)
# Create generation config
generation_config = types.GenerateContentConfig(
temperature=self.temperature,
max_output_tokens=resolved_max_tokens,
response_mime_type='application/json' if response_model else None,
response_schema=response_model if response_model else None,
system_instruction=system_prompt,
thinking_config=self.thinking_config,
)
            # Call the Gemini API asynchronously with the assembled contents and generation config
response = await self.client.aio.models.generate_content(
model=model,
contents=gemini_messages,
config=generation_config,
)
# Always capture the raw output for debugging
raw_output = getattr(response, 'text', None)
# Check for safety and prompt blocks
self._check_safety_blocks(response)
self._check_prompt_blocks(response)
# If this was a structured output request, parse the response into the Pydantic model
if response_model is not None:
try:
if not raw_output:
raise ValueError('No response text')
validated_model = response_model.model_validate(json.loads(raw_output))
# Return as a dictionary for API consistency
return validated_model.model_dump()
except Exception as e:
if raw_output:
logger.error(
'🦀 LLM generation failed parsing as JSON, will try to salvage.'
)
logger.error(self._get_failed_generation_log(gemini_messages, raw_output))
# Try to salvage
salvaged = self.salvage_json(raw_output)
if salvaged is not None:
logger.warning('Salvaged partial JSON from truncated/malformed output.')
return salvaged
raise Exception(f'Failed to parse structured response: {e}') from e
# Otherwise, return the response text as a dictionary
return {'content': raw_output}
except Exception as e:
# Check if it's a rate limit error based on Gemini API error codes
error_message = str(e).lower()
if (
'rate limit' in error_message
or 'quota' in error_message
or 'resource_exhausted' in error_message
or '429' in str(e)
):
raise RateLimitError from e
logger.error(f'Error in generating LLM response: {e}')
raise Exception from e
async def generate_response(
self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int | None = None,
model_size: ModelSize = ModelSize.medium,
group_id: str | None = None,
prompt_name: str | None = None,
) -> dict[str, typing.Any]:
"""
Generate a response from the Gemini language model with retry logic and error handling.
This method overrides the parent class method to provide a direct implementation with advanced retry logic.
Args:
messages (list[Message]): A list of messages to send to the language model.
response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
max_tokens (int | None): The maximum number of tokens to generate in the response.
model_size (ModelSize): The size of the model to use (small or medium).
group_id (str | None): Optional partition identifier for the graph.
prompt_name (str | None): Optional name of the prompt for tracing.
Returns:
dict[str, typing.Any]: The response from the language model.
"""
# Add multilingual extraction instructions
messages[0].content += get_extraction_language_instruction(group_id)
# Wrap entire operation in tracing span
with self.tracer.start_span('llm.generate') as span:
attributes = {
'llm.provider': 'gemini',
'model.size': model_size.value,
'max_tokens': max_tokens or self.max_tokens,
}
if prompt_name:
attributes['prompt.name'] = prompt_name
span.add_attributes(attributes)
retry_count = 0
last_error = None
last_output = None
while retry_count < self.MAX_RETRIES:
try:
response = await self._generate_response(
messages=messages,
response_model=response_model,
max_tokens=max_tokens,
model_size=model_size,
)
last_output = (
response.get('content')
if isinstance(response, dict) and 'content' in response
else None
)
return response
except RateLimitError as e:
# Rate limit errors should not trigger retries (fail fast)
span.set_status('error', str(e))
raise e
except Exception as e:
last_error = e
# Check if this is a safety block - these typically shouldn't be retried
error_text = str(e) or (str(e.__cause__) if e.__cause__ else '')
if 'safety' in error_text.lower() or 'blocked' in error_text.lower():
logger.warning(f'Content blocked by safety filters: {e}')
span.set_status('error', str(e))
raise Exception(f'Content blocked by safety filters: {e}') from e
retry_count += 1
# Construct a detailed error message for the LLM
error_context = (
f'The previous response attempt was invalid. '
f'Error type: {e.__class__.__name__}. '
f'Error details: {str(e)}. '
f'Please try again with a valid response, ensuring the output matches '
f'the expected format and constraints.'
)
error_message = Message(role='user', content=error_context)
messages.append(error_message)
logger.warning(
f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
)
# If we exit the loop without returning, all retries are exhausted
logger.error('🦀 LLM generation failed and retries are exhausted.')
logger.error(self._get_failed_generation_log(messages, last_output))
logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {last_error}')
span.set_status('error', str(last_error))
            if last_error:
                span.record_exception(last_error)
raise last_error or Exception('Max retries exceeded')
```
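
Below is a minimal usage sketch for `GeminiClient` (illustrative only, not a file in the repository). It assumes `graphiti-core[google-genai]` is installed, that a `GEMINI_API_KEY` environment variable is set, and it defines a hypothetical `Person` schema to demonstrate structured output via `response_model`.

```python
import asyncio
import os

from pydantic import BaseModel

from graphiti_core.llm_client.config import LLMConfig, ModelSize
from graphiti_core.llm_client.gemini_client import GeminiClient
from graphiti_core.prompts.models import Message


class Person(BaseModel):
    # Hypothetical schema used only for this example.
    name: str
    occupation: str


async def main() -> None:
    # Assumes GEMINI_API_KEY is set; model names fall back to the defaults defined above.
    client = GeminiClient(config=LLMConfig(api_key=os.environ['GEMINI_API_KEY']))

    messages = [
        Message(role='system', content='Extract the person described in the text.'),
        Message(role='user', content='Ada Lovelace was an early computing pioneer.'),
    ]

    # Structured output: the response is validated against the Pydantic schema and
    # returned as a plain dict; retry and rate-limit handling happen inside generate_response.
    result = await client.generate_response(
        messages,
        response_model=Person,
        model_size=ModelSize.small,
    )
    print(result)


if __name__ == '__main__':
    asyncio.run(main())
```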