This is page 8 of 9. Use http://codebase.md/getzep/graphiti?page={x} to view the full context.
# Directory Structure
```
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── ISSUE_TEMPLATE
│ │ └── bug_report.md
│ ├── pull_request_template.md
│ ├── secret_scanning.yml
│ └── workflows
│ ├── ai-moderator.yml
│ ├── cla.yml
│ ├── claude-code-review-manual.yml
│ ├── claude-code-review.yml
│ ├── claude.yml
│ ├── codeql.yml
│ ├── daily_issue_maintenance.yml
│ ├── issue-triage.yml
│ ├── lint.yml
│ ├── release-graphiti-core.yml
│ ├── release-mcp-server.yml
│ ├── release-server-container.yml
│ ├── typecheck.yml
│ └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│ ├── azure-openai
│ │ ├── .env.example
│ │ ├── azure_openai_neo4j.py
│ │ └── README.md
│ ├── data
│ │ └── manybirds_products.json
│ ├── ecommerce
│ │ ├── runner.ipynb
│ │ └── runner.py
│ ├── langgraph-agent
│ │ ├── agent.ipynb
│ │ └── tinybirds-jess.png
│ ├── opentelemetry
│ │ ├── .env.example
│ │ ├── otel_stdout_example.py
│ │ ├── pyproject.toml
│ │ ├── README.md
│ │ └── uv.lock
│ ├── podcast
│ │ ├── podcast_runner.py
│ │ ├── podcast_transcript.txt
│ │ └── transcript_parser.py
│ ├── quickstart
│ │ ├── quickstart_falkordb.py
│ │ ├── quickstart_neo4j.py
│ │ ├── quickstart_neptune.py
│ │ ├── README.md
│ │ └── requirements.txt
│ └── wizard_of_oz
│ ├── parser.py
│ ├── runner.py
│ └── woo.txt
├── graphiti_core
│ ├── __init__.py
│ ├── cross_encoder
│ │ ├── __init__.py
│ │ ├── bge_reranker_client.py
│ │ ├── client.py
│ │ ├── gemini_reranker_client.py
│ │ └── openai_reranker_client.py
│ ├── decorators.py
│ ├── driver
│ │ ├── __init__.py
│ │ ├── driver.py
│ │ ├── falkordb_driver.py
│ │ ├── graph_operations
│ │ │ └── graph_operations.py
│ │ ├── kuzu_driver.py
│ │ ├── neo4j_driver.py
│ │ ├── neptune_driver.py
│ │ └── search_interface
│ │ └── search_interface.py
│ ├── edges.py
│ ├── embedder
│ │ ├── __init__.py
│ │ ├── azure_openai.py
│ │ ├── client.py
│ │ ├── gemini.py
│ │ ├── openai.py
│ │ └── voyage.py
│ ├── errors.py
│ ├── graph_queries.py
│ ├── graphiti_types.py
│ ├── graphiti.py
│ ├── helpers.py
│ ├── llm_client
│ │ ├── __init__.py
│ │ ├── anthropic_client.py
│ │ ├── azure_openai_client.py
│ │ ├── client.py
│ │ ├── config.py
│ │ ├── errors.py
│ │ ├── gemini_client.py
│ │ ├── groq_client.py
│ │ ├── openai_base_client.py
│ │ ├── openai_client.py
│ │ ├── openai_generic_client.py
│ │ └── utils.py
│ ├── migrations
│ │ └── __init__.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── edges
│ │ │ ├── __init__.py
│ │ │ └── edge_db_queries.py
│ │ └── nodes
│ │ ├── __init__.py
│ │ └── node_db_queries.py
│ ├── nodes.py
│ ├── prompts
│ │ ├── __init__.py
│ │ ├── dedupe_edges.py
│ │ ├── dedupe_nodes.py
│ │ ├── eval.py
│ │ ├── extract_edge_dates.py
│ │ ├── extract_edges.py
│ │ ├── extract_nodes.py
│ │ ├── invalidate_edges.py
│ │ ├── lib.py
│ │ ├── models.py
│ │ ├── prompt_helpers.py
│ │ ├── snippets.py
│ │ └── summarize_nodes.py
│ ├── py.typed
│ ├── search
│ │ ├── __init__.py
│ │ ├── search_config_recipes.py
│ │ ├── search_config.py
│ │ ├── search_filters.py
│ │ ├── search_helpers.py
│ │ ├── search_utils.py
│ │ └── search.py
│ ├── telemetry
│ │ ├── __init__.py
│ │ └── telemetry.py
│ ├── tracer.py
│ └── utils
│ ├── __init__.py
│ ├── bulk_utils.py
│ ├── datetime_utils.py
│ ├── maintenance
│ │ ├── __init__.py
│ │ ├── community_operations.py
│ │ ├── dedup_helpers.py
│ │ ├── edge_operations.py
│ │ ├── graph_data_operations.py
│ │ ├── node_operations.py
│ │ └── temporal_operations.py
│ ├── ontology_utils
│ │ └── entity_types_utils.py
│ └── text_utils.py
├── images
│ ├── arxiv-screenshot.png
│ ├── graphiti-graph-intro.gif
│ ├── graphiti-intro-slides-stock-2.gif
│ └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│ ├── .env.example
│ ├── .python-version
│ ├── config
│ │ ├── config-docker-falkordb-combined.yaml
│ │ ├── config-docker-falkordb.yaml
│ │ ├── config-docker-neo4j.yaml
│ │ ├── config.yaml
│ │ └── mcp_config_stdio_example.json
│ ├── docker
│ │ ├── build-standalone.sh
│ │ ├── build-with-version.sh
│ │ ├── docker-compose-falkordb.yml
│ │ ├── docker-compose-neo4j.yml
│ │ ├── docker-compose.yml
│ │ ├── Dockerfile
│ │ ├── Dockerfile.standalone
│ │ ├── github-actions-example.yml
│ │ ├── README-falkordb-combined.md
│ │ └── README.md
│ ├── docs
│ │ └── cursor_rules.md
│ ├── main.py
│ ├── pyproject.toml
│ ├── pytest.ini
│ ├── README.md
│ ├── src
│ │ ├── __init__.py
│ │ ├── config
│ │ │ ├── __init__.py
│ │ │ └── schema.py
│ │ ├── graphiti_mcp_server.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ ├── entity_types.py
│ │ │ └── response_types.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── factories.py
│ │ │ └── queue_service.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── formatting.py
│ │ └── utils.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── pytest.ini
│ │ ├── README.md
│ │ ├── run_tests.py
│ │ ├── test_async_operations.py
│ │ ├── test_comprehensive_integration.py
│ │ ├── test_configuration.py
│ │ ├── test_falkordb_integration.py
│ │ ├── test_fixtures.py
│ │ ├── test_http_integration.py
│ │ ├── test_integration.py
│ │ ├── test_mcp_integration.py
│ │ ├── test_mcp_transports.py
│ │ ├── test_stdio_simple.py
│ │ └── test_stress_load.py
│ └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│ ├── .env.example
│ ├── graph_service
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ ├── main.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ └── zep_graphiti.py
│ ├── Makefile
│ ├── pyproject.toml
│ ├── README.md
│ └── uv.lock
├── signatures
│ └── version1
│ └── cla.json
├── tests
│ ├── cross_encoder
│ │ ├── test_bge_reranker_client_int.py
│ │ └── test_gemini_reranker_client.py
│ ├── driver
│ │ ├── __init__.py
│ │ └── test_falkordb_driver.py
│ ├── embedder
│ │ ├── embedder_fixtures.py
│ │ ├── test_gemini.py
│ │ ├── test_openai.py
│ │ └── test_voyage.py
│ ├── evals
│ │ ├── data
│ │ │ └── longmemeval_data
│ │ │ ├── longmemeval_oracle.json
│ │ │ └── README.md
│ │ ├── eval_cli.py
│ │ ├── eval_e2e_graph_building.py
│ │ ├── pytest.ini
│ │ └── utils.py
│ ├── helpers_test.py
│ ├── llm_client
│ │ ├── test_anthropic_client_int.py
│ │ ├── test_anthropic_client.py
│ │ ├── test_azure_openai_client.py
│ │ ├── test_client.py
│ │ ├── test_errors.py
│ │ └── test_gemini_client.py
│ ├── test_edge_int.py
│ ├── test_entity_exclusion_int.py
│ ├── test_graphiti_int.py
│ ├── test_graphiti_mock.py
│ ├── test_node_int.py
│ ├── test_text_utils.py
│ └── utils
│ ├── maintenance
│ │ ├── test_bulk_utils.py
│ │ ├── test_edge_operations.py
│ │ ├── test_node_operations.py
│ │ └── test_temporal_operations_int.py
│ └── search
│ └── search_utils_test.py
├── uv.lock
└── Zep-CLA.md
```
# Files
--------------------------------------------------------------------------------
/tests/test_graphiti_mock.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from datetime import datetime, timedelta
from unittest.mock import Mock
import numpy as np
import pytest
from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.graphiti import Graphiti
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search_filters import ComparisonOperator, DateFilter, SearchFilters
from graphiti_core.search.search_utils import (
community_fulltext_search,
community_similarity_search,
edge_bfs_search,
edge_fulltext_search,
edge_similarity_search,
episode_fulltext_search,
episode_mentions_reranker,
get_communities_by_nodes,
get_edge_invalidation_candidates,
get_embeddings_for_communities,
get_embeddings_for_edges,
get_embeddings_for_nodes,
get_mentioned_nodes,
get_relevant_edges,
get_relevant_nodes,
node_bfs_search,
node_distance_reranker,
node_fulltext_search,
node_similarity_search,
)
from graphiti_core.utils.bulk_utils import add_nodes_and_edges_bulk
from graphiti_core.utils.maintenance.community_operations import (
determine_entity_community,
get_community_clusters,
remove_communities,
)
from graphiti_core.utils.maintenance.edge_operations import filter_existing_duplicate_of_edges
from tests.helpers_test import (
GraphProvider,
assert_entity_edge_equals,
assert_entity_node_equals,
assert_episodic_edge_equals,
assert_episodic_node_equals,
get_edge_count,
get_node_count,
group_id,
group_id_2,
)
# Enable pytest-asyncio so the `async def` tests below are collected and awaited.
pytest_plugins = ('pytest_asyncio',)
@pytest.fixture
def mock_llm_client():
    """Create a mock LLMClient with canned configuration and responses."""
    client = Mock(spec=LLMClient)
    client.config = Mock()

    # Static configuration attributes the code under test may read.
    client.model = 'test-model'
    client.small_model = 'test-small-model'
    client.temperature = 0.0
    client.max_tokens = 1000
    client.cache_enabled = False
    client.cache_dir = None

    # Stub the public entry point with a fixed tool-call payload.
    canned_response = {
        'tool_calls': [
            {
                'name': 'extract_entities',
                'arguments': {'entities': [{'entity': 'test_entity', 'entity_type': 'test_type'}]},
            }
        ]
    }
    client.generate_response = Mock(return_value=canned_response)
    return client
@pytest.fixture
def mock_cross_encoder_client():
    """Create a mock CrossEncoderClient.

    The previous stub's docstring and return value were copy-pasted from the
    LLM fixture (an LLM tool-call dict); a cross encoder instead ranks
    passages, so ``rerank`` now echoes the passages it is given paired with
    strictly decreasing placeholder scores.
    """
    mock_encoder = Mock(spec=CrossEncoderClient)
    mock_encoder.config = Mock()
    # Return the input passages in order, best-first, with monotonically
    # decreasing scores so callers that sort by score see a stable ordering.
    # NOTE(review): the real client's ranking method name/signature should be
    # confirmed against CrossEncoderClient — the stub accepts extra args to be
    # tolerant of either shape.
    mock_encoder.rerank = Mock(
        side_effect=lambda query, passages, *args, **kwargs: [
            (passage, 1.0 / (index + 1)) for index, passage in enumerate(passages)
        ]
    )
    return mock_encoder
@pytest.mark.asyncio
async def test_add_bulk(graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client):
    """Round-trip two episodes, four entities, and their connecting edges
    through add_nodes_and_edges_bulk, then read every node and edge back by
    uuid and verify it was persisted unchanged.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as test fails on FalkorDB')
    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()
    now = datetime.now()
    # Create episodic nodes
    episode_node_1 = EpisodicNode(
        name='test_episode',
        group_id=group_id,
        labels=[],
        created_at=now,
        source=EpisodeType.message,
        source_description='conversation message',
        content='Alice likes Bob',
        valid_at=now,
        entity_edges=[],  # Filled in later
    )
    episode_node_2 = EpisodicNode(
        name='test_episode_2',
        group_id=group_id,
        labels=[],
        created_at=now,
        source=EpisodeType.message,
        source_description='conversation message',
        content='Bob adores Alice',
        valid_at=now,
        entity_edges=[],  # Filled in later
    )
    # Create entity nodes (embeddings generated up front; bulk write persists them)
    entity_node_1 = EntityNode(
        name='test_entity_1',
        group_id=group_id,
        labels=['Entity', 'Person'],
        created_at=now,
        summary='test_entity_1 summary',
        attributes={'age': 30, 'location': 'New York'},
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        group_id=group_id,
        labels=['Entity', 'Person2'],
        created_at=now,
        summary='test_entity_2 summary',
        attributes={'age': 25, 'location': 'Los Angeles'},
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    entity_node_3 = EntityNode(
        name='test_entity_3',
        group_id=group_id,
        labels=['Entity', 'City', 'Location'],
        created_at=now,
        summary='test_entity_3 summary',
        attributes={'age': 25, 'location': 'Los Angeles'},
    )
    await entity_node_3.generate_name_embedding(mock_embedder)
    entity_node_4 = EntityNode(
        name='test_entity_4',
        group_id=group_id,
        labels=['Entity'],
        created_at=now,
        summary='test_entity_4 summary',
        attributes={'age': 25, 'location': 'Los Angeles'},
    )
    await entity_node_4.generate_name_embedding(mock_embedder)
    # Create entity edges
    entity_edge_1 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        created_at=now,
        name='likes',
        fact='test_entity_1 relates to test_entity_2',
        episodes=[],
        expired_at=now,
        valid_at=now,
        invalid_at=now,
        group_id=group_id,
    )
    await entity_edge_1.generate_embedding(mock_embedder)
    entity_edge_2 = EntityEdge(
        source_node_uuid=entity_node_3.uuid,
        target_node_uuid=entity_node_4.uuid,
        created_at=now,
        name='relates_to',
        fact='test_entity_3 relates to test_entity_4',
        episodes=[],
        expired_at=now,
        valid_at=now,
        invalid_at=now,
        group_id=group_id,
    )
    await entity_edge_2.generate_embedding(mock_embedder)
    # Create episodic to entity edges (episode 1 mentions entities 1-2,
    # episode 2 mentions entities 3-4)
    episodic_edge_1 = EpisodicEdge(
        source_node_uuid=episode_node_1.uuid,
        target_node_uuid=entity_node_1.uuid,
        created_at=now,
        group_id=group_id,
    )
    episodic_edge_2 = EpisodicEdge(
        source_node_uuid=episode_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        created_at=now,
        group_id=group_id,
    )
    episodic_edge_3 = EpisodicEdge(
        source_node_uuid=episode_node_2.uuid,
        target_node_uuid=entity_node_3.uuid,
        created_at=now,
        group_id=group_id,
    )
    episodic_edge_4 = EpisodicEdge(
        source_node_uuid=episode_node_2.uuid,
        target_node_uuid=entity_node_4.uuid,
        created_at=now,
        group_id=group_id,
    )
    # Cross reference the ids — must happen before the bulk write so the
    # persisted rows carry the back-references
    episode_node_1.entity_edges = [entity_edge_1.uuid]
    episode_node_2.entity_edges = [entity_edge_2.uuid]
    entity_edge_1.episodes = [episode_node_1.uuid, episode_node_2.uuid]
    entity_edge_2.episodes = [episode_node_2.uuid]
    # Test add bulk
    await add_nodes_and_edges_bulk(
        graph_driver,
        [episode_node_1, episode_node_2],
        [episodic_edge_1, episodic_edge_2, episodic_edge_3, episodic_edge_4],
        [entity_node_1, entity_node_2, entity_node_3, entity_node_4],
        [entity_edge_1, entity_edge_2],
        mock_embedder,
    )
    node_ids = [
        episode_node_1.uuid,
        episode_node_2.uuid,
        entity_node_1.uuid,
        entity_node_2.uuid,
        entity_node_3.uuid,
        entity_node_4.uuid,
    ]
    edge_ids = [
        episodic_edge_1.uuid,
        episodic_edge_2.uuid,
        episodic_edge_3.uuid,
        episodic_edge_4.uuid,
        entity_edge_1.uuid,
        entity_edge_2.uuid,
    ]
    # Every node and edge we wrote must be countable by uuid
    node_count = await get_node_count(graph_driver, node_ids)
    assert node_count == len(node_ids)
    edge_count = await get_edge_count(graph_driver, edge_ids)
    assert edge_count == len(edge_ids)
    # Test episodic nodes
    retrieved_episode = await EpisodicNode.get_by_uuid(graph_driver, episode_node_1.uuid)
    await assert_episodic_node_equals(retrieved_episode, episode_node_1)
    retrieved_episode = await EpisodicNode.get_by_uuid(graph_driver, episode_node_2.uuid)
    await assert_episodic_node_equals(retrieved_episode, episode_node_2)
    # Test entity nodes
    retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_1.uuid)
    await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_1)
    retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_2.uuid)
    await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_2)
    retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_3.uuid)
    await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_3)
    retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_4.uuid)
    await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_4)
    # Test episodic edges
    retrieved_episode_edge = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge_1.uuid)
    await assert_episodic_edge_equals(retrieved_episode_edge, episodic_edge_1)
    retrieved_episode_edge = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge_2.uuid)
    await assert_episodic_edge_equals(retrieved_episode_edge, episodic_edge_2)
    retrieved_episode_edge = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge_3.uuid)
    await assert_episodic_edge_equals(retrieved_episode_edge, episodic_edge_3)
    retrieved_episode_edge = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge_4.uuid)
    await assert_episodic_edge_equals(retrieved_episode_edge, episodic_edge_4)
    # Test entity edges
    retrieved_entity_edge = await EntityEdge.get_by_uuid(graph_driver, entity_edge_1.uuid)
    await assert_entity_edge_equals(graph_driver, retrieved_entity_edge, entity_edge_1)
    retrieved_entity_edge = await EntityEdge.get_by_uuid(graph_driver, entity_edge_2.uuid)
    await assert_entity_edge_equals(graph_driver, retrieved_entity_edge, entity_edge_2)
@pytest.mark.asyncio
async def test_remove_episode(
    graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client
):
    """Verify remove_episode cascades: the episode, the entities it mentions,
    and all connecting edges are deleted (counts drop to zero), and the same
    objects can be bulk-written again afterwards.
    """
    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()
    now = datetime.now()
    # Create episodic nodes
    episode_node = EpisodicNode(
        name='test_episode',
        group_id=group_id,
        labels=[],
        created_at=now,
        source=EpisodeType.message,
        source_description='conversation message',
        content='Alice likes Bob',
        valid_at=now,
        entity_edges=[],  # Filled in later
    )
    # Create entity nodes
    alice_node = EntityNode(
        name='Alice',
        group_id=group_id,
        labels=['Entity', 'Person'],
        created_at=now,
        summary='Alice summary',
        attributes={'age': 30, 'location': 'New York'},
    )
    await alice_node.generate_name_embedding(mock_embedder)
    bob_node = EntityNode(
        name='Bob',
        group_id=group_id,
        labels=['Entity', 'Person2'],
        created_at=now,
        summary='Bob summary',
        attributes={'age': 25, 'location': 'Los Angeles'},
    )
    await bob_node.generate_name_embedding(mock_embedder)
    # Create entity to entity edge
    entity_edge = EntityEdge(
        source_node_uuid=alice_node.uuid,
        target_node_uuid=bob_node.uuid,
        created_at=now,
        name='likes',
        fact='Alice likes Bob',
        episodes=[],
        expired_at=now,
        valid_at=now,
        invalid_at=now,
        group_id=group_id,
    )
    await entity_edge.generate_embedding(mock_embedder)
    # Create episodic to entity edges
    episodic_alice_edge = EpisodicEdge(
        source_node_uuid=episode_node.uuid,
        target_node_uuid=alice_node.uuid,
        created_at=now,
        group_id=group_id,
    )
    episodic_bob_edge = EpisodicEdge(
        source_node_uuid=episode_node.uuid,
        target_node_uuid=bob_node.uuid,
        created_at=now,
        group_id=group_id,
    )
    # Cross reference the ids — done before the bulk write so the persisted
    # rows carry the back-references
    episode_node.entity_edges = [entity_edge.uuid]
    entity_edge.episodes = [episode_node.uuid]
    # Test add bulk
    await add_nodes_and_edges_bulk(
        graph_driver,
        [episode_node],
        [episodic_alice_edge, episodic_bob_edge],
        [alice_node, bob_node],
        [entity_edge],
        mock_embedder,
    )
    node_ids = [episode_node.uuid, alice_node.uuid, bob_node.uuid]
    edge_ids = [episodic_alice_edge.uuid, episodic_bob_edge.uuid, entity_edge.uuid]
    node_count = await get_node_count(graph_driver, node_ids)
    assert node_count == 3
    edge_count = await get_edge_count(graph_driver, edge_ids)
    assert edge_count == 3
    # Test remove episode — expects the mentioned entities and all edges to be
    # deleted along with the episode itself
    await graphiti.remove_episode(episode_node.uuid)
    node_count = await get_node_count(graph_driver, node_ids)
    assert node_count == 0
    edge_count = await get_edge_count(graph_driver, edge_ids)
    assert edge_count == 0
    # Test add bulk again — the same in-memory objects can be re-persisted
    await add_nodes_and_edges_bulk(
        graph_driver,
        [episode_node],
        [episodic_alice_edge, episodic_bob_edge],
        [alice_node, bob_node],
        [entity_edge],
        mock_embedder,
    )
    node_count = await get_node_count(graph_driver, node_ids)
    assert node_count == 3
    edge_count = await get_edge_count(graph_driver, edge_ids)
    assert edge_count == 3
@pytest.mark.asyncio
async def test_graphiti_retrieve_episodes(
    graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client
):
    """retrieve_episodes returns only episodes valid before the reference
    time, oldest first."""
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as test fails on FalkorDB')
    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()
    now = datetime.now()
    # Three message episodes valid 2, 4 and 6 days in the past.
    specs = [
        ('test_episode_1', 'Test message 1', now - timedelta(days=2)),
        ('test_episode_2', 'Test message 2', now - timedelta(days=4)),
        ('test_episode_3', 'Test message 3', now - timedelta(days=6)),
    ]
    saved_episodes = []
    for episode_name, episode_content, episode_valid_at in specs:
        episode = EpisodicNode(
            name=episode_name,
            labels=[],
            created_at=now,
            valid_at=episode_valid_at,
            source=EpisodeType.message,
            source_description='conversation message',
            content=episode_content,
            entity_edges=[],
            group_id=group_id,
        )
        await episode.save(graph_driver)
        saved_episodes.append(episode)
    assert await get_node_count(graph_driver, [e.uuid for e in saved_episodes]) == 3
    # Query at 3 days ago: only episodes 2 and 3 are valid before that point,
    # and they come back oldest first.
    episodes = await graphiti.retrieve_episodes(
        now - timedelta(days=3), last_n=5, group_ids=[group_id], source=EpisodeType.message
    )
    assert len(episodes) == 2
    assert episodes[0].name == saved_episodes[2].name
    assert episodes[1].name == saved_episodes[1].name
@pytest.mark.asyncio
async def test_filter_existing_duplicate_of_edges(graph_driver, mock_embedder):
    """Pairs that already have an IS_DUPLICATE_OF edge are filtered out;
    pairs without one pass through."""
    # Build and persist four entity nodes.
    nodes = []
    for entity_name in ('test_entity_1', 'test_entity_2', 'test_entity_3', 'test_entity_4'):
        node = EntityNode(
            name=entity_name,
            labels=[],
            created_at=datetime.now(),
            group_id=group_id,
        )
        await node.generate_name_embedding(mock_embedder)
        await node.save(graph_driver)
        nodes.append(node)
    entity_node_1, entity_node_2, entity_node_3, entity_node_4 = nodes
    assert await get_node_count(graph_driver, [node.uuid for node in nodes]) == 4
    # Persist an IS_DUPLICATE_OF edge between the first pair only.
    existing_duplicate_edge = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        name='IS_DUPLICATE_OF',
        fact='test_entity_1 is a duplicate of test_entity_2',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await existing_duplicate_edge.generate_embedding(mock_embedder)
    await existing_duplicate_edge.save(graph_driver)
    # Only the pair without a pre-existing edge should survive the filter.
    remaining_pairs = await filter_existing_duplicate_of_edges(
        graph_driver,
        [
            (entity_node_1, entity_node_2),
            (entity_node_3, entity_node_4),
        ],
    )
    assert len(remaining_pairs) == 1
    assert [node.name for node in remaining_pairs[0]] == [
        entity_node_3.name,
        entity_node_4.name,
    ]
@pytest.mark.asyncio
async def test_determine_entity_community(graph_driver, mock_embedder):
    """Verify determine_entity_community assigns an entity to the community
    most of its neighbours belong to, reports whether the membership edge is
    new, and that remove_communities deletes all community nodes.

    Topology: entities 1-3 all relate to entity 4; community 1 holds entities
    1-2, community 2 holds entity 3 — so entity 4's majority community is
    community 1.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as test fails on FalkorDB')
    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    entity_node_3 = EntityNode(
        name='test_entity_3',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_3.generate_name_embedding(mock_embedder)
    entity_node_4 = EntityNode(
        name='test_entity_4',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_4.generate_name_embedding(mock_embedder)
    # Create entity edges — every other entity relates to entity 4
    entity_edge_1 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_4.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_4',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_edge_1.generate_embedding(mock_embedder)
    entity_edge_2 = EntityEdge(
        source_node_uuid=entity_node_2.uuid,
        target_node_uuid=entity_node_4.uuid,
        name='RELATES_TO',
        fact='test_entity_2 relates to test_entity_4',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_edge_2.generate_embedding(mock_embedder)
    entity_edge_3 = EntityEdge(
        source_node_uuid=entity_node_3.uuid,
        target_node_uuid=entity_node_4.uuid,
        name='RELATES_TO',
        fact='test_entity_3 relates to test_entity_4',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_edge_3.generate_embedding(mock_embedder)
    # Create community nodes
    community_node_1 = CommunityNode(
        name='test_community_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await community_node_1.generate_name_embedding(mock_embedder)
    community_node_2 = CommunityNode(
        name='test_community_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await community_node_2.generate_name_embedding(mock_embedder)
    # Create community to entity edges — community 1: entities 1-2,
    # community 2: entity 3; entity 4 is unassigned
    community_edge_1 = CommunityEdge(
        source_node_uuid=community_node_1.uuid,
        target_node_uuid=entity_node_1.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )
    community_edge_2 = CommunityEdge(
        source_node_uuid=community_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )
    community_edge_3 = CommunityEdge(
        source_node_uuid=community_node_2.uuid,
        target_node_uuid=entity_node_3.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )
    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_node_3.save(graph_driver)
    await entity_node_4.save(graph_driver)
    await community_node_1.save(graph_driver)
    await community_node_2.save(graph_driver)
    await entity_edge_1.save(graph_driver)
    await entity_edge_2.save(graph_driver)
    await entity_edge_3.save(graph_driver)
    await community_edge_1.save(graph_driver)
    await community_edge_2.save(graph_driver)
    await community_edge_3.save(graph_driver)
    node_ids = [
        entity_node_1.uuid,
        entity_node_2.uuid,
        entity_node_3.uuid,
        entity_node_4.uuid,
        community_node_1.uuid,
        community_node_2.uuid,
    ]
    edge_ids = [
        entity_edge_1.uuid,
        entity_edge_2.uuid,
        entity_edge_3.uuid,
        community_edge_1.uuid,
        community_edge_2.uuid,
        community_edge_3.uuid,
    ]
    node_count = await get_node_count(graph_driver, node_ids)
    assert node_count == 6
    edge_count = await get_edge_count(graph_driver, edge_ids)
    assert edge_count == 6
    # Determine entity community — majority of entity 4's neighbours are in
    # community 1, and no membership edge exists yet, so is_new is True
    community, is_new = await determine_entity_community(graph_driver, entity_node_4)
    assert community.name == community_node_1.name
    assert is_new
    # Add entity to community edge
    community_edge_4 = CommunityEdge(
        source_node_uuid=community_node_1.uuid,
        target_node_uuid=entity_node_4.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )
    await community_edge_4.save(graph_driver)
    # Determine entity community again — membership now already recorded
    community, is_new = await determine_entity_community(graph_driver, entity_node_4)
    assert community.name == community_node_1.name
    assert not is_new
    # remove_communities deletes every community node
    await remove_communities(graph_driver)
    node_count = await get_node_count(graph_driver, [community_node_1.uuid, community_node_2.uuid])
    assert node_count == 0
@pytest.mark.asyncio
async def test_get_community_clusters(graph_driver, mock_embedder):
    """Clusters are computed per group: two connected pairs in different
    group_ids yield two 2-node clusters, one per pair.

    Idiom cleanup: replaces `set([...])` constructions (ruff C403) with set
    literals/comprehensions, and the duplicated `or`-assertions with simple
    membership checks — with exactly two clusters the logic is equivalent.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as test fails on FalkorDB')
    # Two entities in group_id, two in group_id_2.
    nodes = []
    for entity_name, gid in (
        ('test_entity_1', group_id),
        ('test_entity_2', group_id),
        ('test_entity_3', group_id_2),
        ('test_entity_4', group_id_2),
    ):
        node = EntityNode(
            name=entity_name,
            labels=[],
            created_at=datetime.now(),
            group_id=gid,
        )
        await node.generate_name_embedding(mock_embedder)
        await node.save(graph_driver)
        nodes.append(node)
    entity_node_1, entity_node_2, entity_node_3, entity_node_4 = nodes
    # Connect each same-group pair with a RELATES_TO edge.
    edges = []
    for source_node, target_node, fact, gid in (
        (entity_node_1, entity_node_2, 'test_entity_1 relates to test_entity_2', group_id),
        (entity_node_3, entity_node_4, 'test_entity_3 relates to test_entity_4', group_id_2),
    ):
        edge = EntityEdge(
            source_node_uuid=source_node.uuid,
            target_node_uuid=target_node.uuid,
            name='RELATES_TO',
            fact=fact,
            created_at=datetime.now(),
            group_id=gid,
        )
        await edge.generate_embedding(mock_embedder)
        await edge.save(graph_driver)
        edges.append(edge)
    assert await get_node_count(graph_driver, [node.uuid for node in nodes]) == 4
    assert await get_edge_count(graph_driver, [edge.uuid for edge in edges]) == 2
    # Get community clusters — one 2-node cluster per group, in any order.
    clusters = await get_community_clusters(graph_driver, group_ids=None)
    assert len(clusters) == 2
    assert all(len(cluster) == 2 for cluster in clusters)
    cluster_names = [{node.name for node in cluster} for cluster in clusters]
    assert {'test_entity_1', 'test_entity_2'} in cluster_names
    assert {'test_entity_3', 'test_entity_4'} in cluster_names
@pytest.mark.asyncio
async def test_get_mentioned_nodes(graph_driver, mock_embedder):
    """An entity linked from an episode via an episodic edge is reported as
    mentioned by that episode."""
    # One episode, one entity, and the episodic edge joining them.
    episode = EpisodicNode(
        name='test_episodic_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='test_source_description',
        content='test_content',
        valid_at=datetime.now(),
    )
    entity = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity.generate_name_embedding(mock_embedder)
    mention_edge = EpisodicEdge(
        source_node_uuid=episode.uuid,
        target_node_uuid=entity.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )
    # Persist the whole triple.
    await episode.save(graph_driver)
    await entity.save(graph_driver)
    await mention_edge.save(graph_driver)
    # The entity should be the sole mentioned node for the episode.
    mentioned_nodes = await get_mentioned_nodes(graph_driver, [episode])
    assert len(mentioned_nodes) == 1
    assert mentioned_nodes[0].name == entity.name
@pytest.mark.asyncio
async def test_get_communities_by_nodes(graph_driver, mock_embedder):
    """A community linked to an entity via a community edge is returned for
    that entity."""
    # One entity, one community, and the membership edge joining them.
    entity = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity.generate_name_embedding(mock_embedder)
    community = CommunityNode(
        name='test_community_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await community.generate_name_embedding(mock_embedder)
    membership_edge = CommunityEdge(
        source_node_uuid=community.uuid,
        target_node_uuid=entity.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )
    # Persist the whole triple.
    await entity.save(graph_driver)
    await community.save(graph_driver)
    await membership_edge.save(graph_driver)
    # The community should be the sole result for the entity.
    communities = await get_communities_by_nodes(graph_driver, [entity])
    assert len(communities) == 1
    assert communities[0].name == community.name
@pytest.mark.asyncio
async def test_edge_fulltext_search(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """Fulltext search over edge facts combined with type and date filters.

    A single edge matching the query text and every date filter is saved,
    so exactly one edge must come back.
    """
    if graph_driver.provider == GraphProvider.KUZU:
        pytest.skip('Skipping as fulltext indexing not supported for Kuzu')
    # Graphiti is constructed only to build the fulltext indices/constraints.
    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()
    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    # Edge timestamps that the filters below are checked against.
    now = datetime.now()
    created_at = now
    expired_at = now + timedelta(days=6)
    valid_at = now + timedelta(days=2)
    invalid_at = now + timedelta(days=4)
    # Create entity edges
    entity_edge_1 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=created_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        expired_at=expired_at,
        group_id=group_id,
    )
    await entity_edge_1.generate_embedding(mock_embedder)
    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_edge_1.save(graph_driver)
    # Search for entity edges.  Every date filter is satisfied by the edge:
    # valid_at (now+2d) lies in [now+1d, now+3d] and invalid_at (now+4d)
    # lies in (now+3d, now+5d).
    search_filters = SearchFilters(
        node_labels=['Entity'],
        edge_types=['RELATES_TO'],
        created_at=[
            [DateFilter(date=created_at, comparison_operator=ComparisonOperator.equals)],
        ],
        expired_at=[
            [DateFilter(date=now, comparison_operator=ComparisonOperator.not_equals)],
        ],
        valid_at=[
            [
                DateFilter(
                    date=now + timedelta(days=1),
                    comparison_operator=ComparisonOperator.greater_than_equal,
                )
            ],
            [
                DateFilter(
                    date=now + timedelta(days=3),
                    comparison_operator=ComparisonOperator.less_than_equal,
                )
            ],
        ],
        invalid_at=[
            [
                DateFilter(
                    date=now + timedelta(days=3),
                    comparison_operator=ComparisonOperator.greater_than,
                )
            ],
            [
                DateFilter(
                    date=now + timedelta(days=5), comparison_operator=ComparisonOperator.less_than
                )
            ],
        ],
    )
    edges = await edge_fulltext_search(
        graph_driver, 'test_entity_1 relates to test_entity_2', search_filters, group_ids=[group_id]
    )
    assert len(edges) == 1
    assert edges[0].name == entity_edge_1.name
@pytest.mark.asyncio
async def test_edge_similarity_search(graph_driver, mock_embedder):
    """Similarity search over edge fact embeddings with endpoint and date filters."""
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')
    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    # Edge timestamps that the filters below are checked against.
    now = datetime.now()
    created_at = now
    expired_at = now + timedelta(days=6)
    valid_at = now + timedelta(days=2)
    invalid_at = now + timedelta(days=4)
    # Create entity edges
    entity_edge_1 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=created_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        expired_at=expired_at,
        group_id=group_id,
    )
    await entity_edge_1.generate_embedding(mock_embedder)
    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_edge_1.save(graph_driver)
    # Search for entity edges.  The edge satisfies every filter:
    # valid_at (now+2d) lies in [now+1d, now+3d]; invalid_at (now+4d)
    # lies in (now+3d, now+5d).
    search_filters = SearchFilters(
        node_labels=['Entity'],
        edge_types=['RELATES_TO'],
        created_at=[
            [DateFilter(date=created_at, comparison_operator=ComparisonOperator.equals)],
        ],
        expired_at=[
            [DateFilter(date=now, comparison_operator=ComparisonOperator.not_equals)],
        ],
        valid_at=[
            [
                DateFilter(
                    date=now + timedelta(days=1),
                    comparison_operator=ComparisonOperator.greater_than_equal,
                )
            ],
            [
                DateFilter(
                    date=now + timedelta(days=3),
                    comparison_operator=ComparisonOperator.less_than_equal,
                )
            ],
        ],
        invalid_at=[
            [
                DateFilter(
                    date=now + timedelta(days=3),
                    comparison_operator=ComparisonOperator.greater_than,
                )
            ],
            [
                DateFilter(
                    date=now + timedelta(days=5), comparison_operator=ComparisonOperator.less_than
                )
            ],
        ],
    )
    # Query with the edge's own embedding, constrained to its endpoints.
    edges = await edge_similarity_search(
        graph_driver,
        entity_edge_1.fact_embedding,
        entity_node_1.uuid,
        entity_node_2.uuid,
        search_filters,
        group_ids=[group_id],
    )
    assert len(edges) == 1
    assert edges[0].name == entity_edge_1.name
@pytest.mark.asyncio
async def test_edge_bfs_search(graph_driver, mock_embedder):
    """BFS over edges: search depth bounds which RELATES_TO edges are reachable.

    Graph shape: episodic_1 -> entity_1 -> entity_2 -> entity_3.  From the
    episode, depth 1 reaches no entity edge, depth 2 reaches edge 1->2, and
    depth 3 reaches both entity edges.  From entity_1, depth 1 reaches edge
    1->2 and depth 2 reaches both.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')
    # Create episodic nodes
    episodic_node_1 = EpisodicNode(
        name='test_episodic_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='test_source_description',
        content='test_content',
        valid_at=datetime.now(),
    )
    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    entity_node_3 = EntityNode(
        name='test_entity_3',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_3.generate_name_embedding(mock_embedder)
    # Edge timestamps that the filters below are checked against.
    now = datetime.now()
    created_at = now
    expired_at = now + timedelta(days=6)
    valid_at = now + timedelta(days=2)
    invalid_at = now + timedelta(days=4)
    # Create entity edges
    entity_edge_1 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=created_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        expired_at=expired_at,
        group_id=group_id,
    )
    await entity_edge_1.generate_embedding(mock_embedder)
    entity_edge_2 = EntityEdge(
        source_node_uuid=entity_node_2.uuid,
        target_node_uuid=entity_node_3.uuid,
        name='RELATES_TO',
        fact='test_entity_2 relates to test_entity_3',
        created_at=created_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        expired_at=expired_at,
        group_id=group_id,
    )
    await entity_edge_2.generate_embedding(mock_embedder)
    # Create episodic to entity edges
    episodic_edge_1 = EpisodicEdge(
        source_node_uuid=episodic_node_1.uuid,
        target_node_uuid=entity_node_1.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )
    # Save the graph
    await episodic_node_1.save(graph_driver)
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_node_3.save(graph_driver)
    await entity_edge_1.save(graph_driver)
    await entity_edge_2.save(graph_driver)
    await episodic_edge_1.save(graph_driver)
    # Search for entity edges.  Both edges satisfy every date filter (same
    # timestamps), so only BFS depth decides what is returned.
    search_filters = SearchFilters(
        node_labels=['Entity'],
        edge_types=['RELATES_TO'],
        created_at=[
            [DateFilter(date=created_at, comparison_operator=ComparisonOperator.equals)],
        ],
        expired_at=[
            [DateFilter(date=now, comparison_operator=ComparisonOperator.not_equals)],
        ],
        valid_at=[
            [
                DateFilter(
                    date=now + timedelta(days=1),
                    comparison_operator=ComparisonOperator.greater_than_equal,
                )
            ],
            [
                DateFilter(
                    date=now + timedelta(days=3),
                    comparison_operator=ComparisonOperator.less_than_equal,
                )
            ],
        ],
        invalid_at=[
            [
                DateFilter(
                    date=now + timedelta(days=3),
                    comparison_operator=ComparisonOperator.greater_than,
                )
            ],
            [
                DateFilter(
                    date=now + timedelta(days=5), comparison_operator=ComparisonOperator.less_than
                )
            ],
        ],
    )
    # Test bfs from episodic node
    edges = await edge_bfs_search(
        graph_driver,
        [episodic_node_1.uuid],
        1,
        search_filters,
        group_ids=[group_id],
    )
    # Depth 1 from the episode only reaches entity_1, so no entity edge yet.
    assert len(edges) == 0
    edges = await edge_bfs_search(
        graph_driver,
        [episodic_node_1.uuid],
        2,
        search_filters,
        group_ids=[group_id],
    )
    # Dedupe by uuid (BFS may surface the same edge more than once), keep facts.
    edges_deduplicated = set({edge.uuid: edge.fact for edge in edges}.values())
    assert len(edges_deduplicated) == 1
    assert edges_deduplicated == {'test_entity_1 relates to test_entity_2'}
    edges = await edge_bfs_search(
        graph_driver,
        [episodic_node_1.uuid],
        3,
        search_filters,
        group_ids=[group_id],
    )
    edges_deduplicated = set({edge.uuid: edge.fact for edge in edges}.values())
    assert len(edges_deduplicated) == 2
    assert edges_deduplicated == {
        'test_entity_1 relates to test_entity_2',
        'test_entity_2 relates to test_entity_3',
    }
    # Test bfs from entity node
    edges = await edge_bfs_search(
        graph_driver,
        [entity_node_1.uuid],
        1,
        search_filters,
        group_ids=[group_id],
    )
    edges_deduplicated = set({edge.uuid: edge.fact for edge in edges}.values())
    assert len(edges_deduplicated) == 1
    assert edges_deduplicated == {'test_entity_1 relates to test_entity_2'}
    edges = await edge_bfs_search(
        graph_driver,
        [entity_node_1.uuid],
        2,
        search_filters,
        group_ids=[group_id],
    )
    edges_deduplicated = set({edge.uuid: edge.fact for edge in edges}.values())
    assert len(edges_deduplicated) == 2
    assert edges_deduplicated == {
        'test_entity_1 relates to test_entity_2',
        'test_entity_2 relates to test_entity_3',
    }
@pytest.mark.asyncio
async def test_node_fulltext_search(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """Fulltext node search on a summary term returns only the matching entity."""
    if graph_driver.provider == GraphProvider.KUZU:
        pytest.skip('Skipping as fulltext indexing not supported for Kuzu')
    # Graphiti is constructed only to build the fulltext indices.
    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()
    alice = EntityNode(
        name='test_entity_1',
        summary='Summary about Alice',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await alice.generate_name_embedding(mock_embedder)
    bob = EntityNode(
        name='test_entity_2',
        summary='Summary about Bob',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await bob.generate_name_embedding(mock_embedder)
    await alice.save(graph_driver)
    await bob.save(graph_driver)
    found = await node_fulltext_search(
        graph_driver,
        'Alice',
        SearchFilters(node_labels=['Entity']),
        group_ids=[group_id],
    )
    assert [node.name for node in found] == [alice.name]
@pytest.mark.asyncio
async def test_node_similarity_search(graph_driver, mock_embedder):
    """Similarity search with a high min_score returns only the query node itself."""
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')
    alice = EntityNode(
        name='test_entity_alice',
        summary='Summary about Alice',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await alice.generate_name_embedding(mock_embedder)
    bob = EntityNode(
        name='test_entity_bob',
        summary='Summary about Bob',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await bob.generate_name_embedding(mock_embedder)
    await alice.save(graph_driver)
    await bob.save(graph_driver)
    found = await node_similarity_search(
        graph_driver,
        alice.name_embedding,
        SearchFilters(node_labels=['Entity']),
        group_ids=[group_id],
        min_score=0.9,
    )
    assert [node.name for node in found] == [alice.name]
@pytest.mark.asyncio
async def test_node_bfs_search(graph_driver, mock_embedder):
    """BFS over nodes: search depth bounds which entities are reachable.

    Graph shape: episodic_1 -> entity_1 -> entity_2 -> entity_3.  From the
    episode, depth 1 reaches entity_1 only and depth 2 adds entity_2.  From
    entity_1, depth 1 yields entity_2 (the origin itself is not returned).
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')
    # Create episodic nodes
    episodic_node_1 = EpisodicNode(
        name='test_episodic_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='test_source_description',
        content='test_content',
        valid_at=datetime.now(),
    )
    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    entity_node_3 = EntityNode(
        name='test_entity_3',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_3.generate_name_embedding(mock_embedder)
    # Create entity edges
    entity_edge_1 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_edge_1.generate_embedding(mock_embedder)
    entity_edge_2 = EntityEdge(
        source_node_uuid=entity_node_2.uuid,
        target_node_uuid=entity_node_3.uuid,
        name='RELATES_TO',
        fact='test_entity_2 relates to test_entity_3',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_edge_2.generate_embedding(mock_embedder)
    # Create episodic to entity edges
    episodic_edge_1 = EpisodicEdge(
        source_node_uuid=episodic_node_1.uuid,
        target_node_uuid=entity_node_1.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )
    # Save the graph
    await episodic_node_1.save(graph_driver)
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_node_3.save(graph_driver)
    await entity_edge_1.save(graph_driver)
    await entity_edge_2.save(graph_driver)
    await episodic_edge_1.save(graph_driver)
    # Search for entity nodes
    search_filters = SearchFilters(
        node_labels=['Entity'],
    )
    # Test bfs from episodic node
    nodes = await node_bfs_search(
        graph_driver,
        [episodic_node_1.uuid],
        search_filters,
        1,
        group_ids=[group_id],
    )
    # Dedupe by uuid (BFS may surface the same node more than once), keep names.
    nodes_deduplicated = set({node.uuid: node.name for node in nodes}.values())
    assert len(nodes_deduplicated) == 1
    assert nodes_deduplicated == {'test_entity_1'}
    nodes = await node_bfs_search(
        graph_driver,
        [episodic_node_1.uuid],
        search_filters,
        2,
        group_ids=[group_id],
    )
    nodes_deduplicated = set({node.uuid: node.name for node in nodes}.values())
    assert len(nodes_deduplicated) == 2
    assert nodes_deduplicated == {'test_entity_1', 'test_entity_2'}
    # Test bfs from entity node
    nodes = await node_bfs_search(
        graph_driver,
        [entity_node_1.uuid],
        search_filters,
        1,
        group_ids=[group_id],
    )
    nodes_deduplicated = set({node.uuid: node.name for node in nodes}.values())
    assert len(nodes_deduplicated) == 1
    assert nodes_deduplicated == {'test_entity_2'}
@pytest.mark.asyncio
async def test_episode_fulltext_search(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """Fulltext search over episodes matches on the source description."""
    if graph_driver.provider == GraphProvider.KUZU:
        pytest.skip('Skipping as fulltext indexing not supported for Kuzu')
    # Graphiti is constructed only to build the fulltext indices.
    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()
    about_alice = EpisodicNode(
        name='test_episodic_1',
        content='test_content',
        created_at=datetime.now(),
        valid_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='Description about Alice',
    )
    about_bob = EpisodicNode(
        name='test_episodic_2',
        content='test_content_2',
        created_at=datetime.now(),
        valid_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='Description about Bob',
    )
    await about_alice.save(graph_driver)
    await about_bob.save(graph_driver)
    found = await episode_fulltext_search(
        graph_driver,
        'Alice',
        SearchFilters(node_labels=['Episodic']),
        group_ids=[group_id],
    )
    assert [node.name for node in found] == [about_alice.name]
@pytest.mark.asyncio
async def test_community_fulltext_search(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """Fulltext search over community names returns only the matching community."""
    if graph_driver.provider == GraphProvider.KUZU:
        pytest.skip('Skipping as fulltext indexing not supported for Kuzu')
    # Graphiti is constructed only to build the fulltext indices.
    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()
    alice = CommunityNode(
        name='Alice',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await alice.generate_name_embedding(mock_embedder)
    bob = CommunityNode(
        name='Bob',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await bob.generate_name_embedding(mock_embedder)
    await alice.save(graph_driver)
    await bob.save(graph_driver)
    found = await community_fulltext_search(
        graph_driver,
        'Alice',
        group_ids=[group_id],
    )
    assert [node.name for node in found] == [alice.name]
@pytest.mark.asyncio
async def test_community_similarity_search(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """Similarity search with a high min_score returns only the query community."""
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')
    # Graphiti is constructed only to build the indices.
    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()
    alice = CommunityNode(
        name='Alice',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await alice.generate_name_embedding(mock_embedder)
    bob = CommunityNode(
        name='Bob',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await bob.generate_name_embedding(mock_embedder)
    await alice.save(graph_driver)
    await bob.save(graph_driver)
    found = await community_similarity_search(
        graph_driver,
        alice.name_embedding,
        group_ids=[group_id],
        min_score=0.9,
    )
    assert [node.name for node in found] == [alice.name]
@pytest.mark.asyncio
async def test_get_relevant_nodes(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """Nodes relevant to 'Alice' include both Alice entities but not Bob."""
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')
    if graph_driver.provider == GraphProvider.KUZU:
        pytest.skip('Skipping as tests fail on Kuzu')
    # Graphiti is constructed only to build the indices.
    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()
    nodes_by_name = {}
    for name in ('Alice', 'Bob', 'Alice Smith'):
        node = EntityNode(
            name=name,
            summary=name,
            labels=[],
            created_at=datetime.now(),
            group_id=group_id,
        )
        await node.generate_name_embedding(mock_embedder)
        nodes_by_name[name] = node
    for node in nodes_by_name.values():
        await node.save(graph_driver)
    results = (
        await get_relevant_nodes(
            graph_driver,
            [nodes_by_name['Alice']],
            SearchFilters(node_labels=['Entity']),
            min_score=0.9,
        )
    )[0]
    assert len(results) == 2
    assert {node.name for node in results} == {'Alice', 'Alice Smith'}
@pytest.mark.asyncio
async def test_get_relevant_edges_and_invalidation_candidates(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """get_relevant_edges vs get_edge_invalidation_candidates for a shared fact.

    Edges 1 and 3 share the fact 'Alice' but connect different node pairs.
    The asserts below expect get_relevant_edges to return only edge 1 for
    edge 1, while get_edge_invalidation_candidates also returns edge 3
    (presumably because it shares an endpoint and a similar fact — semantics
    inferred from the expectations; confirm against the implementation).
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')
    # Graphiti is constructed only to build the indices.
    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()
    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        summary='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        summary='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    entity_node_3 = EntityNode(
        name='test_entity_3',
        summary='test_entity_3',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_3.generate_name_embedding(mock_embedder)
    # Edge timestamps that the filters below are checked against.
    now = datetime.now()
    created_at = now
    expired_at = now + timedelta(days=6)
    valid_at = now + timedelta(days=2)
    invalid_at = now + timedelta(days=4)
    # Create entity edges
    entity_edge_1 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        name='RELATES_TO',
        fact='Alice',
        created_at=created_at,
        expired_at=expired_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        group_id=group_id,
    )
    await entity_edge_1.generate_embedding(mock_embedder)
    entity_edge_2 = EntityEdge(
        source_node_uuid=entity_node_2.uuid,
        target_node_uuid=entity_node_3.uuid,
        name='RELATES_TO',
        fact='Bob',
        created_at=created_at,
        expired_at=expired_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        group_id=group_id,
    )
    await entity_edge_2.generate_embedding(mock_embedder)
    entity_edge_3 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_3.uuid,
        name='RELATES_TO',
        fact='Alice',
        created_at=created_at,
        expired_at=expired_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        group_id=group_id,
    )
    await entity_edge_3.generate_embedding(mock_embedder)
    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_node_3.save(graph_driver)
    await entity_edge_1.save(graph_driver)
    await entity_edge_2.save(graph_driver)
    await entity_edge_3.save(graph_driver)
    # Search for entity nodes.  All three edges share the same timestamps and
    # satisfy every date filter below, so filtering does not discriminate here.
    search_filters = SearchFilters(
        node_labels=['Entity'],
        edge_types=['RELATES_TO'],
        created_at=[
            [DateFilter(date=created_at, comparison_operator=ComparisonOperator.equals)],
        ],
        expired_at=[
            [DateFilter(date=now, comparison_operator=ComparisonOperator.not_equals)],
        ],
        valid_at=[
            [
                DateFilter(
                    date=now + timedelta(days=1),
                    comparison_operator=ComparisonOperator.greater_than_equal,
                )
            ],
            [
                DateFilter(
                    date=now + timedelta(days=3),
                    comparison_operator=ComparisonOperator.less_than_equal,
                )
            ],
        ],
        invalid_at=[
            [
                DateFilter(
                    date=now + timedelta(days=3),
                    comparison_operator=ComparisonOperator.greater_than,
                )
            ],
            [
                DateFilter(
                    date=now + timedelta(days=5), comparison_operator=ComparisonOperator.less_than
                )
            ],
        ],
    )
    edges = (
        await get_relevant_edges(
            graph_driver,
            [entity_edge_1],
            search_filters,
            min_score=0.9,
        )
    )[0]
    assert len(edges) == 1
    assert set({edge.name for edge in edges}) == {entity_edge_1.name}
    edges = (
        await get_edge_invalidation_candidates(
            graph_driver,
            [entity_edge_1],
            search_filters,
            min_score=0.9,
        )
    )[0]
    assert len(edges) == 2
    assert set({edge.name for edge in edges}) == {entity_edge_1.name, entity_edge_3.name}
@pytest.mark.asyncio
async def test_node_distance_reranker(graph_driver, mock_embedder):
    """Nodes directly connected to the center rank above unconnected ones."""
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')
    nodes = []
    for node_name in ('test_entity_1', 'test_entity_2', 'test_entity_3'):
        node = EntityNode(
            name=node_name,
            labels=[],
            created_at=datetime.now(),
            group_id=group_id,
        )
        await node.generate_name_embedding(mock_embedder)
        nodes.append(node)
    entity_node_1, entity_node_2, entity_node_3 = nodes
    # Only nodes 1 and 2 are connected; node 3 is isolated.
    edge = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await edge.generate_embedding(mock_embedder)
    for item in (entity_node_1, entity_node_2, entity_node_3, edge):
        await item.save(graph_driver)
    reranked_uuids, reranked_scores = await node_distance_reranker(
        graph_driver,
        [entity_node_2.uuid, entity_node_3.uuid],
        entity_node_1.uuid,
    )
    uuid_to_name = {node.uuid: node.name for node in nodes}
    reranked_names = [uuid_to_name[uuid] for uuid in reranked_uuids]
    assert reranked_names == [entity_node_2.name, entity_node_3.name]
    # The connected node scores 1.0; the unreachable one scores 0.0.
    assert np.allclose(reranked_scores, [1.0, 0.0])
@pytest.mark.asyncio
async def test_episode_mentions_reranker(graph_driver, mock_embedder):
    """Entities mentioned by an episode outrank entities that are not."""
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')
    episode = EpisodicNode(
        name='test_episodic_1',
        content='test_content',
        created_at=datetime.now(),
        valid_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='Description about Alice',
    )
    mentioned = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await mentioned.generate_name_embedding(mock_embedder)
    unmentioned = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await unmentioned.generate_name_embedding(mock_embedder)
    # Only the first entity is mentioned by the episode.
    mention_edge = EpisodicEdge(
        source_node_uuid=episode.uuid,
        target_node_uuid=mentioned.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )
    await mentioned.save(graph_driver)
    await unmentioned.save(graph_driver)
    await episode.save(graph_driver)
    await mention_edge.save(graph_driver)
    reranked_uuids, reranked_scores = await episode_mentions_reranker(
        graph_driver,
        [[mentioned.uuid, unmentioned.uuid]],
    )
    names = {mentioned.uuid: mentioned.name, unmentioned.uuid: unmentioned.name}
    assert [names[uuid] for uuid in reranked_uuids] == [mentioned.name, unmentioned.name]
    # One mention scores 1.0; the expected score for zero mentions is inf.
    assert np.allclose(reranked_scores, [1.0, float('inf')])
@pytest.mark.asyncio
async def test_get_embeddings_for_edges(graph_driver, mock_embedder):
    """A saved edge's fact embedding can be read back keyed by edge uuid."""
    source = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await source.generate_name_embedding(mock_embedder)
    target = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await target.generate_name_embedding(mock_embedder)
    edge = EntityEdge(
        source_node_uuid=source.uuid,
        target_node_uuid=target.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await edge.generate_embedding(mock_embedder)
    await source.save(graph_driver)
    await target.save(graph_driver)
    await edge.save(graph_driver)
    embeddings = await get_embeddings_for_edges(graph_driver, [edge])
    assert set(embeddings) == {edge.uuid}
    assert np.allclose(embeddings[edge.uuid], edge.fact_embedding)
@pytest.mark.asyncio
async def test_get_embeddings_for_nodes(graph_driver, mock_embedder):
    """A saved node's name embedding can be read back keyed by node uuid."""
    node = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await node.generate_name_embedding(mock_embedder)
    await node.save(graph_driver)
    embeddings = await get_embeddings_for_nodes(graph_driver, [node])
    assert set(embeddings) == {node.uuid}
    assert np.allclose(embeddings[node.uuid], node.name_embedding)
@pytest.mark.asyncio
async def test_get_embeddings_for_communities(graph_driver, mock_embedder):
    """A saved community's name embedding can be read back keyed by its uuid."""
    community = CommunityNode(
        name='test_community_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await community.generate_name_embedding(mock_embedder)
    await community.save(graph_driver)
    embeddings = await get_embeddings_for_communities(graph_driver, [community])
    assert set(embeddings) == {community.uuid}
    assert np.allclose(embeddings[community.uuid], community.name_embedding)
```
--------------------------------------------------------------------------------
/graphiti_core/search/search_utils.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
from collections import defaultdict
from time import time
from typing import Any
import numpy as np
from numpy._typing import NDArray
from typing_extensions import LiteralString
from graphiti_core.driver.driver import (
GraphDriver,
GraphProvider,
)
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
from graphiti_core.graph_queries import (
get_nodes_query,
get_relationships_query,
get_vector_cosine_func_query,
)
from graphiti_core.helpers import (
lucene_sanitize,
normalize_l2,
semaphore_gather,
)
from graphiti_core.models.edges.edge_db_queries import get_entity_edge_return_query
from graphiti_core.models.nodes.node_db_queries import (
COMMUNITY_NODE_RETURN,
EPISODIC_NODE_RETURN,
get_entity_node_return_query,
)
from graphiti_core.nodes import (
CommunityNode,
EntityNode,
EpisodicNode,
get_community_node_from_record,
get_entity_node_from_record,
get_episodic_node_from_record,
)
from graphiti_core.search.search_filters import (
SearchFilters,
edge_search_filter_query_constructor,
node_search_filter_query_constructor,
)
logger = logging.getLogger(__name__)
RELEVANT_SCHEMA_LIMIT = 10
DEFAULT_MIN_SCORE = 0.6
DEFAULT_MMR_LAMBDA = 0.5
MAX_SEARCH_DEPTH = 3
MAX_QUERY_LENGTH = 128
def calculate_cosine_similarity(vector1: list[float], vector2: list[float]) -> float:
    """Return the cosine similarity between two vectors using NumPy.

    Args:
        vector1: First vector.
        vector2: Second vector (same length as ``vector1``).

    Returns:
        Cosine similarity in ``[-1.0, 1.0]``, or ``0.0`` when either vector
        has zero magnitude (similarity is undefined there, so it is treated
        as no similarity).
    """
    dot_product = np.dot(vector1, vector2)
    norm_vector1 = np.linalg.norm(vector1)
    norm_vector2 = np.linalg.norm(vector2)
    if norm_vector1 == 0 or norm_vector2 == 0:
        # Guard against division by zero.  Return a float (not int) to
        # honor the declared return type.
        return 0.0
    # Cast the NumPy scalar so callers really get a builtin float.
    return float(dot_product / (norm_vector1 * norm_vector2))
def fulltext_query(query: str, group_ids: list[str] | None, driver: GraphDriver):
    """Build a driver-appropriate fulltext query string scoped to group_ids.

    Returns an empty string when the query is too long (more than
    MAX_QUERY_LENGTH terms), signalling the caller to skip the search.
    """
    provider = driver.provider
    if provider == GraphProvider.KUZU:
        # Kuzu only supports simple queries.
        return '' if len(query.split(' ')) > MAX_QUERY_LENGTH else query
    if provider == GraphProvider.FALKORDB:
        return driver.build_fulltext_query(query, group_ids, MAX_QUERY_LENGTH)
    # Lucene-style syntax: OR the group_id terms together, then AND the result
    # with the sanitized query.
    group_terms = (
        [driver.fulltext_syntax + f'group_id:"{g}"' for g in group_ids]
        if group_ids is not None
        else []
    )
    group_ids_filter = ' OR '.join(group_terms)
    if group_ids_filter:
        group_ids_filter += ' AND '
    lucene_query = lucene_sanitize(query)
    # If the lucene query is too long return no query
    if len(lucene_query.split(' ')) + len(group_ids or '') >= MAX_QUERY_LENGTH:
        return ''
    return group_ids_filter + '(' + lucene_query + ')'
async def get_episodes_by_mentions(
    driver: GraphDriver,
    nodes: list[EntityNode],
    edges: list[EntityEdge],
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EpisodicNode]:
    """Fetch up to ``limit`` episodes referenced by the given edges' episode lists."""
    episode_uuids = [uuid for edge in edges for uuid in edge.episodes]
    return await EpisodicNode.get_by_uuids(driver, episode_uuids[:limit])
async def get_mentioned_nodes(
    driver: GraphDriver, episodes: list[EpisodicNode]
) -> list[EntityNode]:
    """Return the distinct entity nodes mentioned by any of the given episodes."""
    uuids = [episode.uuid for episode in episodes]
    query = (
        """
        MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity)
        WHERE episode.uuid IN $uuids
        RETURN DISTINCT
        """
        + get_entity_node_return_query(driver.provider)
    )
    records, _, _ = await driver.execute_query(query, uuids=uuids, routing_='r')
    return [get_entity_node_from_record(record, driver.provider) for record in records]
async def get_communities_by_nodes(
    driver: GraphDriver, nodes: list[EntityNode]
) -> list[CommunityNode]:
    """Return the distinct communities that any of the given nodes belong to."""
    member_uuids = [node.uuid for node in nodes]
    query = (
        """
        MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)
        WHERE m.uuid IN $uuids
        RETURN DISTINCT
        """
        + COMMUNITY_NODE_RETURN
    )
    records, _, _ = await driver.execute_query(query, uuids=member_uuids, routing_='r')
    return [get_community_node_from_record(record) for record in records]
async def edge_fulltext_search(
    driver: GraphDriver,
    query: str,
    search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
    """BM25 fulltext search over edge names and facts.

    Delegates to the driver's pluggable search interface when configured.
    Returns up to `limit` entity edges ordered by fulltext relevance score,
    or [] when the sanitized query is empty (query too long to search).
    """
    if driver.search_interface:
        return await driver.search_interface.edge_fulltext_search(
            driver, query, search_filter, group_ids, limit
        )

    # fulltext search over facts
    fuzzy_query = fulltext_query(query, group_ids, driver)
    if fuzzy_query == '':
        return []

    match_query = """
        YIELD relationship AS rel, score
        MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
        """
    if driver.provider == GraphProvider.KUZU:
        # Kuzu models entity edges through an intermediate RelatesToNode_ node,
        # so the fulltext index yields nodes rather than relationships.
        match_query = """
        YIELD node, score
        MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: node.uuid})-[:RELATES_TO]->(m:Entity)
        """

    filter_queries, filter_params = edge_search_filter_query_constructor(
        search_filter, driver.provider
    )

    if group_ids is not None:
        filter_queries.append('e.group_id IN $group_ids')
        filter_params['group_ids'] = group_ids

    filter_query = ''
    if filter_queries:
        filter_query = ' WHERE ' + (' AND '.join(filter_queries))

    if driver.provider == GraphProvider.NEPTUNE:
        # Neptune uses an external OpenSearch (AOSS) index for fulltext search,
        # then hydrates the matched edges from the graph by internal id.
        res = driver.run_aoss_query('edge_name_and_fact', query)  # pyright: ignore reportAttributeAccessIssue
        if res['hits']['total']['value'] > 0:
            input_ids = []
            for r in res['hits']['hits']:
                input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})

            # Match the edge ids and return the values
            # NOTE(review): `filter_query` begins with ' WHERE ' but is inserted
            # after a query that already has a WHERE clause, and `AND id(e)=id`
            # appears twice — verify this query parses as intended on Neptune.
            query = (
                """
                UNWIND $ids as id
                MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
                WHERE e.group_id IN $group_ids
                AND id(e)=id
                """
                + filter_query
                + """
                AND id(e)=id
                WITH e, id.score as score, startNode(e) AS n, endNode(e) AS m
                RETURN
                    e.uuid AS uuid,
                    e.group_id AS group_id,
                    n.uuid AS source_node_uuid,
                    m.uuid AS target_node_uuid,
                    e.created_at AS created_at,
                    e.name AS name,
                    e.fact AS fact,
                    split(e.episodes, ",") AS episodes,
                    e.expired_at AS expired_at,
                    e.valid_at AS valid_at,
                    e.invalid_at AS invalid_at,
                    properties(e) AS attributes
                ORDER BY score DESC LIMIT $limit
                """
            )
            records, _, _ = await driver.execute_query(
                query,
                query=fuzzy_query,
                ids=input_ids,
                limit=limit,
                routing_='r',
                **filter_params,
            )
        else:
            # No AOSS hits at all: nothing to hydrate.
            return []
    else:
        # Native fulltext index path (Neo4j / FalkorDB / Kuzu).
        query = (
            get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider)
            + match_query
            + filter_query
            + """
            WITH e, score, n, m
            RETURN
            """
            + get_entity_edge_return_query(driver.provider)
            + """
            ORDER BY score DESC
            LIMIT $limit
            """
        )
        records, _, _ = await driver.execute_query(
            query,
            query=fuzzy_query,
            limit=limit,
            routing_='r',
            **filter_params,
        )

    edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

    return edges
async def edge_similarity_search(
    driver: GraphDriver,
    search_vector: list[float],
    source_node_uuid: str | None,
    target_node_uuid: str | None,
    search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
    min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityEdge]:
    """Vector similarity search over edge fact embeddings.

    Optionally pins the edge's source and/or target node uuid. Only edges
    with cosine similarity above `min_score` are returned, ordered by score.
    Delegates to the driver's pluggable search interface when configured.
    """
    if driver.search_interface:
        return await driver.search_interface.edge_similarity_search(
            driver,
            search_vector,
            source_node_uuid,
            target_node_uuid,
            search_filter,
            group_ids,
            limit,
            min_score,
        )

    match_query = """
        MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
        """
    if driver.provider == GraphProvider.KUZU:
        # Kuzu models entity edges through an intermediate RelatesToNode_ node.
        match_query = """
        MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
        """

    filter_queries, filter_params = edge_search_filter_query_constructor(
        search_filter, driver.provider
    )

    if group_ids is not None:
        filter_queries.append('e.group_id IN $group_ids')
        filter_params['group_ids'] = group_ids

    if source_node_uuid is not None:
        filter_params['source_uuid'] = source_node_uuid
        filter_queries.append('n.uuid = $source_uuid')

    if target_node_uuid is not None:
        filter_params['target_uuid'] = target_node_uuid
        filter_queries.append('m.uuid = $target_uuid')

    filter_query = ''
    if filter_queries:
        filter_query = ' WHERE ' + (' AND '.join(filter_queries))

    search_vector_var = '$search_vector'
    if driver.provider == GraphProvider.KUZU:
        # Kuzu requires a fixed-size typed array for its vector functions.
        search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'

    if driver.provider == GraphProvider.NEPTUNE:
        # Neptune has no server-side cosine function here: fetch candidate
        # embeddings, score them in Python, then hydrate matching edges by id.
        query = (
            """
            MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
            """
            + filter_query
            + """
            RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
            """
        )
        resp, header, _ = await driver.execute_query(
            query,
            search_vector=search_vector,
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )
        if len(resp) > 0:
            # Calculate Cosine similarity then return the edge ids
            input_ids = []
            for r in resp:
                if r['embedding']:
                    # Neptune stores embeddings as comma-separated strings.
                    score = calculate_cosine_similarity(
                        search_vector, list(map(float, r['embedding'].split(',')))
                    )
                    if score > min_score:
                        input_ids.append({'id': r['id'], 'score': score})

            # Match the edge ides and return the values
            query = """
                UNWIND $ids as i
                MATCH ()-[r]->()
                WHERE id(r) = i.id
                RETURN
                    r.uuid AS uuid,
                    r.group_id AS group_id,
                    startNode(r).uuid AS source_node_uuid,
                    endNode(r).uuid AS target_node_uuid,
                    r.created_at AS created_at,
                    r.name AS name,
                    r.fact AS fact,
                    split(r.episodes, ",") AS episodes,
                    r.expired_at AS expired_at,
                    r.valid_at AS valid_at,
                    r.invalid_at AS invalid_at,
                    properties(r) AS attributes
                ORDER BY i.score DESC
                LIMIT $limit
                """
            records, _, _ = await driver.execute_query(
                query,
                ids=input_ids,
                search_vector=search_vector,
                limit=limit,
                min_score=min_score,
                routing_='r',
                **filter_params,
            )
        else:
            return []
    else:
        # Providers with a native cosine function score in the database.
        query = (
            match_query
            + filter_query
            + """
            WITH DISTINCT e, n, m, """
            + get_vector_cosine_func_query('e.fact_embedding', search_vector_var, driver.provider)
            + """ AS score
            WHERE score > $min_score
            RETURN
            """
            + get_entity_edge_return_query(driver.provider)
            + """
            ORDER BY score DESC
            LIMIT $limit
            """
        )
        records, _, _ = await driver.execute_query(
            query,
            search_vector=search_vector,
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )

    edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

    return edges
async def edge_bfs_search(
    driver: GraphDriver,
    bfs_origin_node_uuids: list[str] | None,
    bfs_max_depth: int,
    search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
    """Breadth-first search for edges reachable from the given origin nodes.

    Origins may be Entity or Episodic nodes; traversal follows RELATES_TO
    (and MENTIONS from episodes) up to `bfs_max_depth` hops and returns the
    distinct edges encountered, subject to `search_filter` and `group_ids`.
    """
    # vector similarity search over embedded facts
    if bfs_origin_node_uuids is None or len(bfs_origin_node_uuids) == 0:
        return []

    filter_queries, filter_params = edge_search_filter_query_constructor(
        search_filter, driver.provider
    )

    if group_ids is not None:
        filter_queries.append('e.group_id IN $group_ids')
        filter_params['group_ids'] = group_ids

    filter_query = ''
    if filter_queries:
        filter_query = ' WHERE ' + (' AND '.join(filter_queries))

    if driver.provider == GraphProvider.KUZU:
        # Kuzu stores entity edges twice with an intermediate node, so we need to match them
        # separately for the correct BFS depth.
        depth = bfs_max_depth * 2 - 1
        match_queries = [
            f"""
            UNWIND $bfs_origin_node_uuids AS origin_uuid
            MATCH path = (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
            UNWIND nodes(path) AS relNode
            MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
            """,
        ]
        if bfs_max_depth > 1:
            # Episodic origins spend their first hop on MENTIONS, so the
            # remaining RELATES_TO depth is one logical hop shorter.
            depth = (bfs_max_depth - 1) * 2 - 1
            match_queries.append(f"""
                UNWIND $bfs_origin_node_uuids AS origin_uuid
                MATCH path = (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
                UNWIND nodes(path) AS relNode
                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
                """)

        records = []
        for match_query in match_queries:
            sub_records, _, _ = await driver.execute_query(
                match_query
                + filter_query
                + """
                RETURN DISTINCT
                """
                + get_entity_edge_return_query(driver.provider)
                + """
                LIMIT $limit
                """,
                bfs_origin_node_uuids=bfs_origin_node_uuids,
                limit=limit,
                routing_='r',
                **filter_params,
            )
            records.extend(sub_records)
    else:
        if driver.provider == GraphProvider.NEPTUNE:
            query = (
                f"""
                UNWIND $bfs_origin_node_uuids AS origin_uuid
                MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS *1..{bfs_max_depth}]->(n:Entity)
                WHERE origin:Entity OR origin:Episodic
                UNWIND relationships(path) AS rel
                MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
                """
                + filter_query
                + """
                RETURN DISTINCT
                    e.uuid AS uuid,
                    e.group_id AS group_id,
                    startNode(e).uuid AS source_node_uuid,
                    endNode(e).uuid AS target_node_uuid,
                    e.created_at AS created_at,
                    e.name AS name,
                    e.fact AS fact,
                    split(e.episodes, ',') AS episodes,
                    e.expired_at AS expired_at,
                    e.valid_at AS valid_at,
                    e.invalid_at AS invalid_at,
                    properties(e) AS attributes
                LIMIT $limit
                """
            )
        else:
            query = (
                f"""
                UNWIND $bfs_origin_node_uuids AS origin_uuid
                MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
                UNWIND relationships(path) AS rel
                MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
                """
                + filter_query
                + """
                RETURN DISTINCT
                """
                + get_entity_edge_return_query(driver.provider)
                + """
                LIMIT $limit
                """
            )
        # NOTE(review): `depth` is passed as a parameter but the depth is
        # interpolated into the query via the f-string above — the kwarg
        # appears unused; confirm before removing.
        records, _, _ = await driver.execute_query(
            query,
            bfs_origin_node_uuids=bfs_origin_node_uuids,
            depth=bfs_max_depth,
            limit=limit,
            routing_='r',
            **filter_params,
        )

    edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

    return edges
async def node_fulltext_search(
    driver: GraphDriver,
    query: str,
    search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
    """BM25 fulltext search over entity names and summaries.

    Delegates to the driver's pluggable search interface when configured.
    Returns up to `limit` nodes ordered by fulltext relevance score, or []
    when the sanitized query is empty (query too long to search).
    """
    if driver.search_interface:
        return await driver.search_interface.node_fulltext_search(
            driver, query, search_filter, group_ids, limit
        )

    # BM25 search to get top nodes
    fuzzy_query = fulltext_query(query, group_ids, driver)
    if fuzzy_query == '':
        return []

    filter_queries, filter_params = node_search_filter_query_constructor(
        search_filter, driver.provider
    )

    if group_ids is not None:
        filter_queries.append('n.group_id IN $group_ids')
        filter_params['group_ids'] = group_ids

    filter_query = ''
    if filter_queries:
        filter_query = ' WHERE ' + (' AND '.join(filter_queries))

    yield_query = 'YIELD node AS n, score'
    if driver.provider == GraphProvider.KUZU:
        # Kuzu's fulltext call returns rows directly rather than via YIELD.
        yield_query = 'WITH node AS n, score'

    if driver.provider == GraphProvider.NEPTUNE:
        # Neptune uses an external OpenSearch (AOSS) index, then hydrates the
        # matched nodes from the graph by uuid.
        res = driver.run_aoss_query('node_name_and_summary', query, limit=limit)  # pyright: ignore reportAttributeAccessIssue
        if res['hits']['total']['value'] > 0:
            input_ids = []
            for r in res['hits']['hits']:
                input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})

            # Match the edge ides and return the values
            query = (
                """
                UNWIND $ids as i
                MATCH (n:Entity)
                WHERE n.uuid=i.id
                RETURN
                """
                + get_entity_node_return_query(driver.provider)
                + """
                ORDER BY i.score DESC
                LIMIT $limit
                """
            )
            records, _, _ = await driver.execute_query(
                query,
                ids=input_ids,
                query=fuzzy_query,
                limit=limit,
                routing_='r',
                **filter_params,
            )
        else:
            return []
    else:
        # Native fulltext index path (Neo4j / FalkorDB / Kuzu).
        query = (
            get_nodes_query(
                'node_name_and_summary', '$query', limit=limit, provider=driver.provider
            )
            + yield_query
            + filter_query
            + """
            WITH n, score
            ORDER BY score DESC
            LIMIT $limit
            RETURN
            """
            + get_entity_node_return_query(driver.provider)
        )
        records, _, _ = await driver.execute_query(
            query,
            query=fuzzy_query,
            limit=limit,
            routing_='r',
            **filter_params,
        )

    nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

    return nodes
async def node_similarity_search(
    driver: GraphDriver,
    search_vector: list[float],
    search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
    min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityNode]:
    """Vector similarity search over entity name embeddings.

    Only nodes with cosine similarity above `min_score` are returned, ordered
    by score. Delegates to the driver's pluggable search interface when
    configured.
    """
    if driver.search_interface:
        return await driver.search_interface.node_similarity_search(
            driver, search_vector, search_filter, group_ids, limit, min_score
        )

    filter_queries, filter_params = node_search_filter_query_constructor(
        search_filter, driver.provider
    )

    if group_ids is not None:
        filter_queries.append('n.group_id IN $group_ids')
        filter_params['group_ids'] = group_ids

    filter_query = ''
    if filter_queries:
        filter_query = ' WHERE ' + (' AND '.join(filter_queries))

    search_vector_var = '$search_vector'
    if driver.provider == GraphProvider.KUZU:
        # Kuzu requires a fixed-size typed array for its vector functions.
        search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'

    if driver.provider == GraphProvider.NEPTUNE:
        # Neptune has no server-side cosine function here: fetch candidate
        # embeddings, score them in Python, then hydrate matching nodes by id.
        query = (
            """
            MATCH (n:Entity)
            """
            + filter_query
            + """
            RETURN DISTINCT id(n) as id, n.name_embedding as embedding
            """
        )
        # Fix: spread the filter parameters as query parameters so that the
        # placeholders in `filter_query` (e.g. $group_ids) are bound.
        # Previously they were passed as a single `params=filter_params`
        # kwarg, unlike every sibling search function.
        resp, header, _ = await driver.execute_query(
            query,
            search_vector=search_vector,
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )
        if len(resp) > 0:
            # Calculate Cosine similarity then return the edge ids
            input_ids = []
            for r in resp:
                if r['embedding']:
                    # Neptune stores embeddings as comma-separated strings.
                    score = calculate_cosine_similarity(
                        search_vector, list(map(float, r['embedding'].split(',')))
                    )
                    if score > min_score:
                        input_ids.append({'id': r['id'], 'score': score})

            # Match the edge ides and return the values
            query = (
                """
                UNWIND $ids as i
                MATCH (n:Entity)
                WHERE id(n)=i.id
                RETURN
                """
                + get_entity_node_return_query(driver.provider)
                + """
                ORDER BY i.score DESC
                LIMIT $limit
                """
            )
            records, header, _ = await driver.execute_query(
                query,
                ids=input_ids,
                search_vector=search_vector,
                limit=limit,
                min_score=min_score,
                routing_='r',
                **filter_params,
            )
        else:
            return []
    else:
        # Providers with a native cosine function score in the database.
        query = (
            """
            MATCH (n:Entity)
            """
            + filter_query
            + """
            WITH n, """
            + get_vector_cosine_func_query('n.name_embedding', search_vector_var, driver.provider)
            + """ AS score
            WHERE score > $min_score
            RETURN
            """
            + get_entity_node_return_query(driver.provider)
            + """
            ORDER BY score DESC
            LIMIT $limit
            """
        )
        records, _, _ = await driver.execute_query(
            query,
            search_vector=search_vector,
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )

    nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

    return nodes
async def node_bfs_search(
    driver: GraphDriver,
    bfs_origin_node_uuids: list[str] | None,
    search_filter: SearchFilters,
    bfs_max_depth: int,
    group_ids: list[str] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
    """Breadth-first search for entity nodes reachable from the given origins.

    Origins may be Entity or Episodic nodes; traversal follows RELATES_TO
    (and MENTIONS from episodes) up to `bfs_max_depth` hops. Results are
    restricted to the origin's group_id and the supplied filters.
    """
    if bfs_origin_node_uuids is None or len(bfs_origin_node_uuids) == 0 or bfs_max_depth < 1:
        return []

    filter_queries, filter_params = node_search_filter_query_constructor(
        search_filter, driver.provider
    )

    if group_ids is not None:
        filter_queries.append('n.group_id IN $group_ids')
        filter_queries.append('origin.group_id IN $group_ids')
        filter_params['group_ids'] = group_ids

    filter_query = ''
    if filter_queries:
        # Each match query below ends in a WHERE clause, so extra filters are
        # appended with AND.
        filter_query = ' AND ' + (' AND '.join(filter_queries))

    match_queries = [
        f"""
        UNWIND $bfs_origin_node_uuids AS origin_uuid
        MATCH (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
        WHERE n.group_id = origin.group_id
        """
    ]
    if driver.provider == GraphProvider.NEPTUNE:
        # Fix: the label check was written as the property access
        # `origin.Episode` instead of the label predicate `origin:Episodic`,
        # and the missing parentheses let AND bind tighter than OR, dropping
        # the group-id guard for Entity origins.
        match_queries = [
            f"""
            UNWIND $bfs_origin_node_uuids AS origin_uuid
            MATCH (origin {{uuid: origin_uuid}})-[e:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
            WHERE (origin:Entity OR origin:Episodic)
            AND n.group_id = origin.group_id
            """
        ]
    if driver.provider == GraphProvider.KUZU:
        # Kuzu stores entity edges through an intermediate node, so each
        # logical hop is two RELATES_TO hops.
        depth = bfs_max_depth * 2
        match_queries = [
            """
            UNWIND $bfs_origin_node_uuids AS origin_uuid
            MATCH (origin:Episodic {uuid: origin_uuid})-[:MENTIONS]->(n:Entity)
            WHERE n.group_id = origin.group_id
            """,
            f"""
            UNWIND $bfs_origin_node_uuids AS origin_uuid
            MATCH (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*2..{depth}]->(n:Entity)
            WHERE n.group_id = origin.group_id
            """,
        ]
        if bfs_max_depth > 1:
            # Episodic origins spend their first hop on MENTIONS.
            depth = (bfs_max_depth - 1) * 2
            match_queries.append(f"""
                UNWIND $bfs_origin_node_uuids AS origin_uuid
                MATCH (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*2..{depth}]->(n:Entity)
                WHERE n.group_id = origin.group_id
                """)

    records = []
    for match_query in match_queries:
        sub_records, _, _ = await driver.execute_query(
            match_query
            + filter_query
            + """
            RETURN
            """
            + get_entity_node_return_query(driver.provider)
            + """
            LIMIT $limit
            """,
            bfs_origin_node_uuids=bfs_origin_node_uuids,
            limit=limit,
            routing_='r',
            **filter_params,
        )
        records.extend(sub_records)

    nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

    return nodes
async def episode_fulltext_search(
    driver: GraphDriver,
    query: str,
    _search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EpisodicNode]:
    """BM25 fulltext search over episode content.

    Delegates to the driver's pluggable search interface when configured.
    Returns up to `limit` episodes ordered by fulltext relevance score, or
    [] when the sanitized query is empty (query too long to search).
    """
    if driver.search_interface:
        return await driver.search_interface.episode_fulltext_search(
            driver, query, _search_filter, group_ids, limit
        )

    # BM25 search to get top episodes
    fuzzy_query = fulltext_query(query, group_ids, driver)
    if fuzzy_query == '':
        return []

    filter_params: dict[str, Any] = {}
    group_filter_query: LiteralString = ''
    if group_ids is not None:
        group_filter_query += '\nAND e.group_id IN $group_ids'
        filter_params['group_ids'] = group_ids

    if driver.provider == GraphProvider.NEPTUNE:
        # Neptune uses an external OpenSearch (AOSS) index, then hydrates the
        # matched episodes from the graph by uuid.
        res = driver.run_aoss_query('episode_content', query, limit=limit)  # pyright: ignore reportAttributeAccessIssue
        if res['hits']['total']['value'] > 0:
            input_ids = []
            for r in res['hits']['hits']:
                input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})

            # Match the edge ides and return the values
            # Fix: input_ids entries are keyed 'id'/'score', so the uuid
            # comparison must use `i.id` (previously `i.uuid`, a key that does
            # not exist — the match could never succeed).
            query = """
                UNWIND $ids as i
                MATCH (e:Episodic)
                WHERE e.uuid=i.id
                RETURN
                    e.content AS content,
                    e.created_at AS created_at,
                    e.valid_at AS valid_at,
                    e.uuid AS uuid,
                    e.name AS name,
                    e.group_id AS group_id,
                    e.source_description AS source_description,
                    e.source AS source,
                    e.entity_edges AS entity_edges
                ORDER BY i.score DESC
                LIMIT $limit
                """
            records, _, _ = await driver.execute_query(
                query,
                ids=input_ids,
                query=fuzzy_query,
                limit=limit,
                routing_='r',
                **filter_params,
            )
        else:
            return []
    else:
        # Native fulltext index path.
        query = (
            get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider)
            + """
            YIELD node AS episode, score
            MATCH (e:Episodic)
            WHERE e.uuid = episode.uuid
            """
            + group_filter_query
            + """
            RETURN
            """
            + EPISODIC_NODE_RETURN
            + """
            ORDER BY score DESC
            LIMIT $limit
            """
        )
        records, _, _ = await driver.execute_query(
            query, query=fuzzy_query, limit=limit, routing_='r', **filter_params
        )

    episodes = [get_episodic_node_from_record(record) for record in records]

    return episodes
async def community_fulltext_search(
    driver: GraphDriver,
    query: str,
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[CommunityNode]:
    """BM25 fulltext search over community names.

    Returns up to `limit` communities ordered by fulltext relevance score, or
    [] when the sanitized query is empty (query too long to search).
    """
    # BM25 search to get top communities
    fuzzy_query = fulltext_query(query, group_ids, driver)
    if fuzzy_query == '':
        return []

    filter_params: dict[str, Any] = {}
    group_filter_query: LiteralString = ''
    if group_ids is not None:
        group_filter_query = 'WHERE c.group_id IN $group_ids'
        filter_params['group_ids'] = group_ids

    yield_query = 'YIELD node AS c, score'
    if driver.provider == GraphProvider.KUZU:
        # Kuzu's fulltext call returns rows directly rather than via YIELD.
        yield_query = 'WITH node AS c, score'

    if driver.provider == GraphProvider.NEPTUNE:
        # Neptune uses an external OpenSearch (AOSS) index, then hydrates the
        # matched communities from the graph by uuid.
        res = driver.run_aoss_query('community_name', query, limit=limit)  # pyright: ignore reportAttributeAccessIssue
        if res['hits']['total']['value'] > 0:
            # Calculate Cosine similarity then return the edge ids
            input_ids = []
            for r in res['hits']['hits']:
                input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})

            # Match the edge ides and return the values
            query = """
                UNWIND $ids as i
                MATCH (comm:Community)
                WHERE comm.uuid=i.id
                RETURN
                    comm.uuid AS uuid,
                    comm.group_id AS group_id,
                    comm.name AS name,
                    comm.created_at AS created_at,
                    comm.summary AS summary,
                    [x IN split(comm.name_embedding, ",") | toFloat(x)]AS name_embedding
                ORDER BY i.score DESC
                LIMIT $limit
                """
            records, _, _ = await driver.execute_query(
                query,
                ids=input_ids,
                query=fuzzy_query,
                limit=limit,
                routing_='r',
                **filter_params,
            )
        else:
            return []
    else:
        # Native fulltext index path.
        query = (
            get_nodes_query('community_name', '$query', limit=limit, provider=driver.provider)
            + yield_query
            + """
            WITH c, score
            """
            + group_filter_query
            + """
            RETURN
            """
            + COMMUNITY_NODE_RETURN
            + """
            ORDER BY score DESC
            LIMIT $limit
            """
        )
        records, _, _ = await driver.execute_query(
            query, query=fuzzy_query, limit=limit, routing_='r', **filter_params
        )

    communities = [get_community_node_from_record(record) for record in records]

    return communities
async def community_similarity_search(
    driver: GraphDriver,
    search_vector: list[float],
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
    min_score=DEFAULT_MIN_SCORE,
) -> list[CommunityNode]:
    """Vector similarity search over community name embeddings.

    Only communities with cosine similarity above `min_score` are returned,
    ordered by score.
    """
    # vector similarity search over entity names
    query_params: dict[str, Any] = {}

    group_filter_query: LiteralString = ''
    if group_ids is not None:
        group_filter_query += ' WHERE c.group_id IN $group_ids'
        query_params['group_ids'] = group_ids

    if driver.provider == GraphProvider.NEPTUNE:
        # Neptune has no server-side cosine function here: fetch candidate
        # embeddings, score them in Python, then hydrate matching rows by id.
        query = (
            """
            MATCH (n:Community)
            """
            + group_filter_query
            + """
            RETURN DISTINCT id(n) as id, n.name_embedding as embedding
            """
        )
        resp, header, _ = await driver.execute_query(
            query,
            search_vector=search_vector,
            limit=limit,
            min_score=min_score,
            routing_='r',
            **query_params,
        )
        if len(resp) > 0:
            # Calculate Cosine similarity then return the edge ids
            input_ids = []
            for r in resp:
                if r['embedding']:
                    # Neptune stores embeddings as comma-separated strings.
                    score = calculate_cosine_similarity(
                        search_vector, list(map(float, r['embedding'].split(',')))
                    )
                    if score > min_score:
                        input_ids.append({'id': r['id'], 'score': score})

            # Match the edge ides and return the values
            # NOTE(review): name_embedding is returned as the raw stored
            # string here, while community_fulltext_search converts it with
            # split/toFloat — confirm downstream parsing tolerates this.
            query = """
                UNWIND $ids as i
                MATCH (comm:Community)
                WHERE id(comm)=i.id
                RETURN
                    comm.uuid As uuid,
                    comm.group_id AS group_id,
                    comm.name AS name,
                    comm.created_at AS created_at,
                    comm.summary AS summary,
                    comm.name_embedding AS name_embedding
                ORDER BY i.score DESC
                LIMIT $limit
                """
            records, header, _ = await driver.execute_query(
                query,
                ids=input_ids,
                search_vector=search_vector,
                limit=limit,
                min_score=min_score,
                routing_='r',
                **query_params,
            )
        else:
            return []
    else:
        search_vector_var = '$search_vector'
        if driver.provider == GraphProvider.KUZU:
            # Kuzu requires a fixed-size typed array for its vector functions.
            search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
        query = (
            """
            MATCH (c:Community)
            """
            + group_filter_query
            + """
            WITH c,
            """
            + get_vector_cosine_func_query('c.name_embedding', search_vector_var, driver.provider)
            + """ AS score
            WHERE score > $min_score
            RETURN
            """
            + COMMUNITY_NODE_RETURN
            + """
            ORDER BY score DESC
            LIMIT $limit
            """
        )
        records, _, _ = await driver.execute_query(
            query,
            search_vector=search_vector,
            limit=limit,
            min_score=min_score,
            routing_='r',
            **query_params,
        )

    communities = [get_community_node_from_record(record) for record in records]

    return communities
async def hybrid_node_search(
    queries: list[str],
    embeddings: list[list[float]],
    driver: GraphDriver,
    search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
    """
    Hybrid node retrieval: a fulltext search per query string plus a vector
    similarity search per embedding, merged with reciprocal rank fusion (RRF).

    Parameters
    ----------
    queries : list[str]
        Text queries for BM25 fulltext search.
    embeddings : list[list[float]]
        Embedding vectors for similarity search; when empty, only fulltext
        search runs.
    driver : GraphDriver
        Graph database driver.
    search_filter : SearchFilters
        Filters applied to every individual search.
    group_ids : list[str] | None, optional
        Restrict results to these group ids.
    limit : int, optional
        Per-search-method cap, applied before deduplication; each individual
        search is run with 2 * limit to give RRF more candidates.

    Returns
    -------
    list[EntityNode]
        Unique nodes ranked by RRF across all individual result lists.
    """
    start = time()

    fulltext_coros = [
        node_fulltext_search(driver, q, search_filter, group_ids, 2 * limit) for q in queries
    ]
    similarity_coros = [
        node_similarity_search(driver, e, search_filter, group_ids, 2 * limit)
        for e in embeddings
    ]
    results: list[list[EntityNode]] = list(
        await semaphore_gather(*fulltext_coros, *similarity_coros)
    )

    # Deduplicate by uuid while keeping each result list's ordering for RRF.
    nodes_by_uuid: dict[str, EntityNode] = {}
    result_uuids: list[list[str]] = []
    for result in results:
        uuids: list[str] = []
        for node in result:
            nodes_by_uuid[node.uuid] = node
            uuids.append(node.uuid)
        result_uuids.append(uuids)

    ranked_uuids, _ = rrf(result_uuids)
    relevant_nodes: list[EntityNode] = [nodes_by_uuid[uuid] for uuid in ranked_uuids]

    end = time()
    logger.debug(f'Found relevant nodes: {ranked_uuids} in {(end - start) * 1000} ms')
    return relevant_nodes
async def get_relevant_nodes(
    driver: GraphDriver,
    nodes: list[EntityNode],
    search_filter: SearchFilters,
    min_score: float = DEFAULT_MIN_SCORE,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[list[EntityNode]]:
    """For each input node, find existing candidate matches in the graph.

    Runs a combined vector-similarity + fulltext search per input node in a
    single query and returns one candidate list per input node, in input
    order (empty list when a node had no matches).

    NOTE(review): all inputs are assumed to share the group_id of nodes[0];
    the query only matches against that single group.
    """
    if len(nodes) == 0:
        return []

    group_id = nodes[0].group_id

    # Per-node search payloads consumed by the UNWIND in the queries below.
    query_nodes = [
        {
            'uuid': node.uuid,
            'name': node.name,
            'name_embedding': node.name_embedding,
            'fulltext_query': fulltext_query(node.name, [node.group_id], driver),
        }
        for node in nodes
    ]

    filter_queries, filter_params = node_search_filter_query_constructor(
        search_filter, driver.provider
    )

    filter_query = ''
    if filter_queries:
        filter_query = 'WHERE ' + (' AND '.join(filter_queries))

    if driver.provider == GraphProvider.KUZU:
        # Kuzu's CAST requires a fixed embedding size; bail out when the first
        # node has no embedding to size against.
        embedding_size = len(nodes[0].name_embedding) if nodes[0].name_embedding is not None else 0
        if embedding_size == 0:
            return []
        # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
        query = (
            """
            UNWIND $nodes AS node
            MATCH (n:Entity {group_id: $group_id})
            """
            + filter_query
            + """
            WITH node, n, """
            + get_vector_cosine_func_query(
                'n.name_embedding',
                f'CAST(node.name_embedding AS FLOAT[{embedding_size}])',
                driver.provider,
            )
            + """ AS score
            WHERE score > $min_score
            WITH node, collect(n)[:$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
            """
            + get_nodes_query(
                'node_name_and_summary',
                'node.fulltext_query',
                limit=limit,
                provider=driver.provider,
            )
            + """
            WITH node AS m
            WHERE m.group_id = $group_id AND NOT m.uuid IN vector_node_uuids
            WITH node, top_vector_nodes, collect(m) AS fulltext_nodes
            WITH node, list_concat(top_vector_nodes, fulltext_nodes) AS combined_nodes
            UNWIND combined_nodes AS x
            WITH node, collect(DISTINCT {
                uuid: x.uuid,
                name: x.name,
                name_embedding: x.name_embedding,
                group_id: x.group_id,
                created_at: x.created_at,
                summary: x.summary,
                labels: x.labels,
                attributes: x.attributes
            }) AS matches
            RETURN
                node.uuid AS search_node_uuid, matches
            """
        )
    else:
        # Vector matches first, then fulltext matches excluding any uuid
        # already found by the vector leg; both are merged per input node.
        query = (
            """
            UNWIND $nodes AS node
            MATCH (n:Entity {group_id: $group_id})
            """
            + filter_query
            + """
            WITH node, n, """
            + get_vector_cosine_func_query(
                'n.name_embedding', 'node.name_embedding', driver.provider
            )
            + """ AS score
            WHERE score > $min_score
            WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
            """
            + get_nodes_query(
                'node_name_and_summary',
                'node.fulltext_query',
                limit=limit,
                provider=driver.provider,
            )
            + """
            YIELD node AS m
            WHERE m.group_id = $group_id
            WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes
            WITH node,
                top_vector_nodes,
                [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes
            WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes
            UNWIND combined_nodes AS combined_node
            WITH node, collect(DISTINCT combined_node) AS deduped_nodes
            RETURN
                node.uuid AS search_node_uuid,
                [x IN deduped_nodes | {
                    uuid: x.uuid,
                    name: x.name,
                    name_embedding: x.name_embedding,
                    group_id: x.group_id,
                    created_at: x.created_at,
                    summary: x.summary,
                    labels: labels(x),
                    attributes: properties(x)
                }] AS matches
            """
        )

    results, _, _ = await driver.execute_query(
        query,
        nodes=query_nodes,
        group_id=group_id,
        limit=limit,
        min_score=min_score,
        routing_='r',
        **filter_params,
    )

    # Re-key results by the searched node's uuid, then emit in input order.
    relevant_nodes_dict: dict[str, list[EntityNode]] = {
        result['search_node_uuid']: [
            get_entity_node_from_record(record, driver.provider) for record in result['matches']
        ]
        for result in results
    }

    relevant_nodes = [relevant_nodes_dict.get(node.uuid, []) for node in nodes]

    return relevant_nodes
async def get_relevant_edges(
    driver: GraphDriver,
    edges: list[EntityEdge],
    search_filter: SearchFilters,
    min_score: float = DEFAULT_MIN_SCORE,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[list[EntityEdge]]:
    """For each input edge, find existing similar edges between the same node pair.

    Candidates are edges between the input edge's source/target nodes (in
    either direction) within the same group, scored by cosine similarity of
    fact embeddings against `min_score`. Returns one candidate list per input
    edge, in input order (empty list when an edge had no matches).
    """
    if len(edges) == 0:
        return []

    filter_queries, filter_params = edge_search_filter_query_constructor(
        search_filter, driver.provider
    )

    filter_query = ''
    if filter_queries:
        filter_query = ' WHERE ' + (' AND '.join(filter_queries))

    if driver.provider == GraphProvider.NEPTUNE:
        # Neptune has no server-side cosine function: fetch candidate
        # embeddings, score them in Python, then hydrate matching edges by id.
        query = (
            """
            UNWIND $edges AS edge
            MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
            """
            + filter_query
            + """
            WITH e, edge
            RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding, edge.uuid as search_edge_uuid,
                edge.fact_embedding as target_embedding
            """
        )
        resp, _, _ = await driver.execute_query(
            query,
            edges=[edge.model_dump() for edge in edges],
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )
        # Calculate Cosine similarity then return the edge ids
        input_ids = []
        for r in resp:
            # Stored embeddings are comma-separated strings; the search edge's
            # embedding comes from model_dump() and is already a list.
            score = calculate_cosine_similarity(
                list(map(float, r['source_embedding'].split(','))), r['target_embedding']
            )
            if score > min_score:
                input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})

        # Match the edge ides and return the values
        query = """
            UNWIND $ids AS edge
            MATCH ()-[e]->()
            WHERE id(e) = edge.id
            WITH edge, e
            ORDER BY edge.score DESC
            RETURN edge.uuid AS search_edge_uuid,
                collect({
                    uuid: e.uuid,
                    source_node_uuid: startNode(e).uuid,
                    target_node_uuid: endNode(e).uuid,
                    created_at: e.created_at,
                    name: e.name,
                    group_id: e.group_id,
                    fact: e.fact,
                    fact_embedding: [x IN split(e.fact_embedding, ",") | toFloat(x)],
                    episodes: split(e.episodes, ","),
                    expired_at: e.expired_at,
                    valid_at: e.valid_at,
                    invalid_at: e.invalid_at,
                    attributes: properties(e)
                })[..$limit] AS matches
            """
        results, _, _ = await driver.execute_query(
            query,
            ids=input_ids,
            edges=[edge.model_dump() for edge in edges],
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )
    else:
        if driver.provider == GraphProvider.KUZU:
            # Kuzu's CAST requires a fixed embedding size; bail out when the
            # first edge has no embedding to size against.
            embedding_size = (
                len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
            )
            if embedding_size == 0:
                return []
            query = (
                """
                UNWIND $edges AS edge
                MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
                """
                + filter_query
                + """
                WITH e, edge, n, m, """
                + get_vector_cosine_func_query(
                    'e.fact_embedding',
                    f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
                    driver.provider,
                )
                + """ AS score
                WHERE score > $min_score
                WITH e, edge, n, m, score
                ORDER BY score DESC
                LIMIT $limit
                RETURN
                    edge.uuid AS search_edge_uuid,
                    collect({
                        uuid: e.uuid,
                        source_node_uuid: n.uuid,
                        target_node_uuid: m.uuid,
                        created_at: e.created_at,
                        name: e.name,
                        group_id: e.group_id,
                        fact: e.fact,
                        fact_embedding: e.fact_embedding,
                        episodes: e.episodes,
                        expired_at: e.expired_at,
                        valid_at: e.valid_at,
                        invalid_at: e.invalid_at,
                        attributes: e.attributes
                    }) AS matches
                """
            )
        else:
            # Providers with a native cosine function score in the database.
            query = (
                """
                UNWIND $edges AS edge
                MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
                """
                + filter_query
                + """
                WITH e, edge, """
                + get_vector_cosine_func_query(
                    'e.fact_embedding', 'edge.fact_embedding', driver.provider
                )
                + """ AS score
                WHERE score > $min_score
                WITH edge, e, score
                ORDER BY score DESC
                RETURN
                    edge.uuid AS search_edge_uuid,
                    collect({
                        uuid: e.uuid,
                        source_node_uuid: startNode(e).uuid,
                        target_node_uuid: endNode(e).uuid,
                        created_at: e.created_at,
                        name: e.name,
                        group_id: e.group_id,
                        fact: e.fact,
                        fact_embedding: e.fact_embedding,
                        episodes: e.episodes,
                        expired_at: e.expired_at,
                        valid_at: e.valid_at,
                        invalid_at: e.invalid_at,
                        attributes: properties(e)
                    })[..$limit] AS matches
                """
            )
        results, _, _ = await driver.execute_query(
            query,
            edges=[edge.model_dump() for edge in edges],
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )

    # Re-key results by the searched edge's uuid, then emit in input order.
    relevant_edges_dict: dict[str, list[EntityEdge]] = {
        result['search_edge_uuid']: [
            get_entity_edge_from_record(record, driver.provider) for record in result['matches']
        ]
        for result in results
    }

    relevant_edges = [relevant_edges_dict.get(edge.uuid, []) for edge in edges]

    return relevant_edges
async def get_edge_invalidation_candidates(
    driver: GraphDriver,
    edges: list[EntityEdge],
    search_filter: SearchFilters,
    min_score: float = DEFAULT_MIN_SCORE,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[list[EntityEdge]]:
    """Find existing edges that could be invalidated by each of the given edges.

    A candidate is any RELATES_TO edge in the same group that shares at least one
    endpoint with the input edge and whose fact embedding has cosine similarity
    strictly greater than `min_score` with the input edge's fact embedding.

    Args:
        driver: Graph driver; the query text is specialized per provider.
        edges: New edges to collect invalidation candidates for.
        search_filter: Extra edge filters ANDed into the match clause.
        min_score: Minimum cosine similarity for a candidate to be kept.
        limit: Maximum number of candidates returned per input edge.

    Returns:
        One list of candidate EntityEdges per input edge, aligned with the input
        order; an empty list for edges with no matches.
    """
    if len(edges) == 0:
        return []
    filter_queries, filter_params = edge_search_filter_query_constructor(
        search_filter, driver.provider
    )
    # Filters are appended to an existing WHERE clause, hence the leading ' AND '.
    filter_query = ''
    if filter_queries:
        filter_query = ' AND ' + (' AND '.join(filter_queries))
    if driver.provider == GraphProvider.NEPTUNE:
        # Neptune stores fact embeddings as comma-separated strings and the
        # similarity cannot be computed in-query, so this branch does two round
        # trips with a client-side cosine similarity computation in between.
        query = (
            """
            UNWIND $edges AS edge
            MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
            WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
            """
            + filter_query
            + """
            WITH e, edge
            RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding,
                edge.fact_embedding as target_embedding,
                edge.uuid as search_edge_uuid
            """
        )
        resp, _, _ = await driver.execute_query(
            query,
            edges=[edge.model_dump() for edge in edges],
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )
        # Calculate cosine similarity client-side, keeping only edges above the
        # threshold. The stored embedding is a comma-separated string; the input
        # edge's embedding (from model_dump) is already a list of floats.
        input_ids = []
        for r in resp:
            score = calculate_cosine_similarity(
                list(map(float, r['source_embedding'].split(','))), r['target_embedding']
            )
            if score > min_score:
                input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})
        # Fetch the surviving edges by internal id and return their values.
        query = """
            UNWIND $ids AS edge
            MATCH ()-[e]->()
            WHERE id(e) = edge.id
            WITH edge, e
            ORDER BY edge.score DESC
            RETURN edge.uuid AS search_edge_uuid,
                collect({
                    uuid: e.uuid,
                    source_node_uuid: startNode(e).uuid,
                    target_node_uuid: endNode(e).uuid,
                    created_at: e.created_at,
                    name: e.name,
                    group_id: e.group_id,
                    fact: e.fact,
                    fact_embedding: [x IN split(e.fact_embedding, ",") | toFloat(x)],
                    episodes: split(e.episodes, ","),
                    expired_at: e.expired_at,
                    valid_at: e.valid_at,
                    invalid_at: e.invalid_at,
                    attributes: properties(e)
                })[..$limit] AS matches
        """
        # NOTE(review): this second query only references $ids and $limit;
        # `edges`, `min_score`, and the filter params look unused here — confirm.
        results, _, _ = await driver.execute_query(
            query,
            ids=input_ids,
            edges=[edge.model_dump() for edge in edges],
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )
    else:
        if driver.provider == GraphProvider.KUZU:
            # Kuzu represents RELATES_TO via an intermediate RelatesToNode_ node,
            # and its CAST needs a fixed embedding dimensionality, taken from the
            # first input edge (all edges are assumed to share it).
            embedding_size = (
                len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
            )
            if embedding_size == 0:
                return []
            query = (
                """
                UNWIND $edges AS edge
                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
                WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
                """
                + filter_query
                + """
                WITH edge, e, n, m, """
                + get_vector_cosine_func_query(
                    'e.fact_embedding',
                    f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
                    driver.provider,
                )
                + """ AS score
                WHERE score > $min_score
                WITH edge, e, n, m, score
                ORDER BY score DESC
                LIMIT $limit
                RETURN
                    edge.uuid AS search_edge_uuid,
                    collect({
                        uuid: e.uuid,
                        source_node_uuid: n.uuid,
                        target_node_uuid: m.uuid,
                        created_at: e.created_at,
                        name: e.name,
                        group_id: e.group_id,
                        fact: e.fact,
                        fact_embedding: e.fact_embedding,
                        episodes: e.episodes,
                        expired_at: e.expired_at,
                        valid_at: e.valid_at,
                        invalid_at: e.invalid_at,
                        attributes: e.attributes
                    }) AS matches
                """
            )
        else:
            # Neo4j / FalkorDB: cosine similarity is computed in-query via the
            # provider-specific vector function.
            query = (
                """
                UNWIND $edges AS edge
                MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
                WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
                """
                + filter_query
                + """
                WITH edge, e, """
                + get_vector_cosine_func_query(
                    'e.fact_embedding', 'edge.fact_embedding', driver.provider
                )
                + """ AS score
                WHERE score > $min_score
                WITH edge, e, score
                ORDER BY score DESC
                RETURN
                    edge.uuid AS search_edge_uuid,
                    collect({
                        uuid: e.uuid,
                        source_node_uuid: startNode(e).uuid,
                        target_node_uuid: endNode(e).uuid,
                        created_at: e.created_at,
                        name: e.name,
                        group_id: e.group_id,
                        fact: e.fact,
                        fact_embedding: e.fact_embedding,
                        episodes: e.episodes,
                        expired_at: e.expired_at,
                        valid_at: e.valid_at,
                        invalid_at: e.invalid_at,
                        attributes: properties(e)
                    })[..$limit] AS matches
                """
            )
        results, _, _ = await driver.execute_query(
            query,
            edges=[edge.model_dump() for edge in edges],
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )
    # Group candidates by the input edge uuid they matched, then realign with the
    # input order; edges without any candidates map to an empty list.
    invalidation_edges_dict: dict[str, list[EntityEdge]] = {
        result['search_edge_uuid']: [
            get_entity_edge_from_record(record, driver.provider) for record in result['matches']
        ]
        for result in results
    }
    invalidation_edges = [invalidation_edges_dict.get(edge.uuid, []) for edge in edges]
    return invalidation_edges
def rrf(
    results: list[list[str]], rank_const: int = 1, min_score: float = 0
) -> tuple[list[str], list[float]]:
    """Fuse multiple ranked uuid lists via Reciprocal Rank Fusion.

    Each uuid accumulates `1 / (rank + rank_const)` across every input ranking
    (rank is 0-based), then uuids are sorted by descending fused score and any
    uuid scoring below `min_score` is dropped.

    Args:
        results: Independent rankings, each a list of uuids ordered best-first.
        rank_const: Smoothing constant added to the 0-based rank.
        min_score: Minimum fused score required to keep a uuid.

    Returns:
        A `(uuids, scores)` pair, both ordered by descending fused score.
    """
    scores: dict[str, float] = defaultdict(float)
    for result in results:
        for rank, uuid in enumerate(result):
            scores[uuid] += 1 / (rank + rank_const)

    # Stable sort by descending score preserves first-seen order on ties,
    # matching the prior behavior; filter and unzip in a single pass.
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    kept = [(uuid, score) for uuid, score in ranked if score >= min_score]
    return [uuid for uuid, _ in kept], [score for _, score in kept]
async def node_distance_reranker(
    driver: GraphDriver,
    node_uuids: list[str],
    center_node_uuid: str,
    min_score: float = 0,
) -> tuple[list[str], list[float]]:
    """Rerank node uuids by proximity to a center node.

    Nodes directly connected to the center get distance 1 (score 1.0), the
    center itself gets distance 0.1 (score 10.0, ranked first), and nodes with
    no direct connection get distance inf (score 0.0, ranked last).
    """
    # Drop the center node from the candidates; it is re-inserted at the front
    # below if it was part of the request.
    candidate_uuids = [uuid for uuid in node_uuids if uuid != center_node_uuid]
    distance: dict[str, float] = {center_node_uuid: 0.0}
    query = """
        UNWIND $node_uuids AS node_uuid
        MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
        RETURN 1 AS score, node_uuid AS uuid
        """
    if driver.provider == GraphProvider.KUZU:
        query = """
            UNWIND $node_uuids AS node_uuid
            MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(n:Entity {uuid: node_uuid})
            RETURN 1 AS score, node_uuid AS uuid
        """
    # Find which candidates are directly connected to the center node.
    records, header, _ = await driver.execute_query(
        query,
        node_uuids=candidate_uuids,
        center_uuid=center_node_uuid,
        routing_='r',
    )
    if driver.provider == GraphProvider.FALKORDB:
        # FalkorDB returns bare rows; pair them with the header to get dicts.
        records = [dict(zip(header, row, strict=True)) for row in records]
    for record in records:
        distance[record['uuid']] = record['score']
    # Candidates the query did not return are treated as unreachable.
    for uuid in candidate_uuids:
        distance.setdefault(uuid, float('inf'))
    # Closest nodes first.
    candidate_uuids.sort(key=lambda uuid: distance[uuid])
    # Re-insert the center node at the head of the ranking when requested.
    if center_node_uuid in node_uuids:
        distance[center_node_uuid] = 0.1
        candidate_uuids = [center_node_uuid] + candidate_uuids
    # Report inverse distance as the score and filter by the threshold.
    ranked = [
        (uuid, 1 / distance[uuid])
        for uuid in candidate_uuids
        if (1 / distance[uuid]) >= min_score
    ]
    return [uuid for uuid, _ in ranked], [score for _, score in ranked]
async def episode_mentions_reranker(
    driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0
) -> tuple[list[str], list[float]]:
    """Rerank node uuids using episode MENTIONS counts.

    Fuses the input rankings with RRF, counts Episodic-[:MENTIONS]->Entity
    edges per node, then sorts by that count ascending.
    """
    # use rrf as a preliminary ranker
    sorted_uuids, _ = rrf(node_uuids)
    scores: dict[str, float] = {}
    # Count MENTIONS edges from episodes to each candidate node.
    results, _, _ = await driver.execute_query(
        """
        UNWIND $node_uuids AS node_uuid
        MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
        RETURN count(*) AS score, n.uuid AS uuid
        """,
        node_uuids=sorted_uuids,
        routing_='r',
    )
    for result in results:
        scores[result['uuid']] = result['score']
    # Nodes with no MENTIONS edges get an infinite score, placing them last in
    # the ascending sort below.
    for uuid in sorted_uuids:
        if uuid not in scores:
            scores[uuid] = float('inf')
    # NOTE(review): this ascending sort ranks nodes with FEWER mentions first
    # (the structure mirrors node_distance_reranker, where lower is better).
    # For mention counts "more mentions first" may be the intent — confirm.
    sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])
    return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
        scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
    ]
def maximal_marginal_relevance(
    query_vector: list[float],
    candidates: dict[str, list[float]],
    mmr_lambda: float = DEFAULT_MMR_LAMBDA,
    min_score: float = -2.0,
) -> tuple[list[str], list[float]]:
    """Rerank candidates with Maximal Marginal Relevance (MMR).

    Balances relevance against redundancy:
    ``score = mmr_lambda * sim(query, c) - (1 - mmr_lambda) * max_sim(c, others)``.

    Args:
        query_vector: Query embedding (used as-is, not re-normalized here).
        candidates: Mapping of uuid -> embedding; embeddings are L2-normalized
            before pairwise comparison.
        mmr_lambda: Relevance/diversity trade-off weight.
        min_score: Minimum MMR score required to keep a candidate.

    Returns:
        ``(uuids, scores)`` ordered by descending MMR score.
    """
    start = time()

    query_array = np.array(query_vector)
    normalized: dict[str, NDArray] = {
        uuid: normalize_l2(embedding) for uuid, embedding in candidates.items()
    }
    uuids: list[str] = list(normalized.keys())

    # Symmetric pairwise similarity between candidates; the diagonal stays 0,
    # so a candidate never counts as redundant with itself.
    count = len(uuids)
    pairwise = np.zeros((count, count))
    for i in range(count):
        for j in range(i):
            sim = np.dot(normalized[uuids[i]], normalized[uuids[j]])
            pairwise[i, j] = sim
            pairwise[j, i] = sim

    mmr_scores: dict[str, float] = {}
    for i, uuid in enumerate(uuids):
        redundancy = np.max(pairwise[i, :])
        relevance = np.dot(query_array, normalized[uuid])
        # (mmr_lambda - 1) * redundancy == -(1 - mmr_lambda) * redundancy
        mmr_scores[uuid] = mmr_lambda * relevance + (mmr_lambda - 1) * redundancy

    uuids.sort(reverse=True, key=lambda candidate: mmr_scores[candidate])

    end = time()
    logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms')

    kept = [uuid for uuid in uuids if mmr_scores[uuid] >= min_score]
    return kept, [mmr_scores[uuid] for uuid in kept]
async def get_embeddings_for_nodes(
    driver: GraphDriver, nodes: list[EntityNode]
) -> dict[str, list[float]]:
    """Bulk-load name embeddings for the given entity nodes.

    Returns a mapping of node uuid -> name embedding; nodes without a stored
    embedding are omitted from the result.
    """
    # Prefer the driver's dedicated bulk loader when one is available.
    if driver.graph_operations_interface:
        return await driver.graph_operations_interface.node_load_embeddings_bulk(driver, nodes)

    if driver.provider == GraphProvider.NEPTUNE:
        # Neptune stores embeddings as comma-separated strings; split in-query.
        query = """
            MATCH (n:Entity)
            WHERE n.uuid IN $node_uuids
            RETURN DISTINCT
                n.uuid AS uuid,
                split(n.name_embedding, ",") AS name_embedding
            """
    else:
        query = """
            MATCH (n:Entity)
            WHERE n.uuid IN $node_uuids
            RETURN DISTINCT
                n.uuid AS uuid,
                n.name_embedding AS name_embedding
            """

    records, _, _ = await driver.execute_query(
        query,
        node_uuids=[node.uuid for node in nodes],
        routing_='r',
    )

    # Skip records missing either field so callers only see usable embeddings.
    return {
        record['uuid']: record['name_embedding']
        for record in records
        if record.get('uuid') is not None and record.get('name_embedding') is not None
    }
async def get_embeddings_for_communities(
    driver: GraphDriver, communities: list[CommunityNode]
) -> dict[str, list[float]]:
    """Bulk-load name embeddings for the given community nodes.

    Returns a mapping of community uuid -> name embedding; communities without
    a stored embedding are omitted from the result.
    """
    if driver.provider == GraphProvider.NEPTUNE:
        # Neptune stores embeddings as comma-separated strings; split in-query.
        query = """
            MATCH (c:Community)
            WHERE c.uuid IN $community_uuids
            RETURN DISTINCT
                c.uuid AS uuid,
                split(c.name_embedding, ",") AS name_embedding
            """
    else:
        query = """
            MATCH (c:Community)
            WHERE c.uuid IN $community_uuids
            RETURN DISTINCT
                c.uuid AS uuid,
                c.name_embedding AS name_embedding
            """

    records, _, _ = await driver.execute_query(
        query,
        community_uuids=[community.uuid for community in communities],
        routing_='r',
    )

    # Skip records missing either field so callers only see usable embeddings.
    return {
        record['uuid']: record['name_embedding']
        for record in records
        if record.get('uuid') is not None and record.get('name_embedding') is not None
    }
async def get_embeddings_for_edges(
    driver: GraphDriver, edges: list[EntityEdge]
) -> dict[str, list[float]]:
    """Bulk-load fact embeddings for the given entity edges.

    Returns a mapping of edge uuid -> fact embedding; edges without a stored
    embedding are omitted from the result.
    """
    # Prefer the driver's dedicated bulk loader when one is available.
    if driver.graph_operations_interface:
        return await driver.graph_operations_interface.edge_load_embeddings_bulk(driver, edges)

    if driver.provider == GraphProvider.NEPTUNE:
        # Neptune stores embeddings as comma-separated strings; split in-query.
        query = """
            MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
            WHERE e.uuid IN $edge_uuids
            RETURN DISTINCT
                e.uuid AS uuid,
                split(e.fact_embedding, ",") AS fact_embedding
            """
    else:
        match_query = """
            MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
            """
        if driver.provider == GraphProvider.KUZU:
            # Kuzu models RELATES_TO via an intermediate RelatesToNode_ node.
            match_query = """
                MATCH (n:Entity)-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m:Entity)
            """
        query = (
            match_query
            + """
            WHERE e.uuid IN $edge_uuids
            RETURN DISTINCT
                e.uuid AS uuid,
                e.fact_embedding AS fact_embedding
            """
        )

    records, _, _ = await driver.execute_query(
        query,
        edge_uuids=[edge.uuid for edge in edges],
        routing_='r',
    )

    # Skip records missing either field so callers only see usable embeddings.
    return {
        record['uuid']: record['fact_embedding']
        for record in records
        if record.get('uuid') is not None and record.get('fact_embedding') is not None
    }
```