#
tokens: 37642/50000 2/234 files (page 10/12)
lines: on (toggle) GitHub
raw markdown copy reset
This is page 10 of 12. Use http://codebase.md/getzep/graphiti?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .env.example
├── .github
│   ├── dependabot.yml
│   ├── ISSUE_TEMPLATE
│   │   └── bug_report.md
│   ├── pull_request_template.md
│   ├── secret_scanning.yml
│   └── workflows
│       ├── ai-moderator.yml
│       ├── cla.yml
│       ├── claude-code-review-manual.yml
│       ├── claude-code-review.yml
│       ├── claude.yml
│       ├── codeql.yml
│       ├── daily_issue_maintenance.yml
│       ├── issue-triage.yml
│       ├── lint.yml
│       ├── release-graphiti-core.yml
│       ├── release-mcp-server.yml
│       ├── release-server-container.yml
│       ├── typecheck.yml
│       └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│   ├── azure-openai
│   │   ├── .env.example
│   │   ├── azure_openai_neo4j.py
│   │   └── README.md
│   ├── data
│   │   └── manybirds_products.json
│   ├── ecommerce
│   │   ├── runner.ipynb
│   │   └── runner.py
│   ├── langgraph-agent
│   │   ├── agent.ipynb
│   │   └── tinybirds-jess.png
│   ├── opentelemetry
│   │   ├── .env.example
│   │   ├── otel_stdout_example.py
│   │   ├── pyproject.toml
│   │   ├── README.md
│   │   └── uv.lock
│   ├── podcast
│   │   ├── podcast_runner.py
│   │   ├── podcast_transcript.txt
│   │   └── transcript_parser.py
│   ├── quickstart
│   │   ├── quickstart_falkordb.py
│   │   ├── quickstart_neo4j.py
│   │   ├── quickstart_neptune.py
│   │   ├── README.md
│   │   └── requirements.txt
│   └── wizard_of_oz
│       ├── parser.py
│       ├── runner.py
│       └── woo.txt
├── graphiti_core
│   ├── __init__.py
│   ├── cross_encoder
│   │   ├── __init__.py
│   │   ├── bge_reranker_client.py
│   │   ├── client.py
│   │   ├── gemini_reranker_client.py
│   │   └── openai_reranker_client.py
│   ├── decorators.py
│   ├── driver
│   │   ├── __init__.py
│   │   ├── driver.py
│   │   ├── falkordb_driver.py
│   │   ├── graph_operations
│   │   │   └── graph_operations.py
│   │   ├── kuzu_driver.py
│   │   ├── neo4j_driver.py
│   │   ├── neptune_driver.py
│   │   └── search_interface
│   │       └── search_interface.py
│   ├── edges.py
│   ├── embedder
│   │   ├── __init__.py
│   │   ├── azure_openai.py
│   │   ├── client.py
│   │   ├── gemini.py
│   │   ├── openai.py
│   │   └── voyage.py
│   ├── errors.py
│   ├── graph_queries.py
│   ├── graphiti_types.py
│   ├── graphiti.py
│   ├── helpers.py
│   ├── llm_client
│   │   ├── __init__.py
│   │   ├── anthropic_client.py
│   │   ├── azure_openai_client.py
│   │   ├── client.py
│   │   ├── config.py
│   │   ├── errors.py
│   │   ├── gemini_client.py
│   │   ├── groq_client.py
│   │   ├── openai_base_client.py
│   │   ├── openai_client.py
│   │   ├── openai_generic_client.py
│   │   └── utils.py
│   ├── migrations
│   │   └── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── edges
│   │   │   ├── __init__.py
│   │   │   └── edge_db_queries.py
│   │   └── nodes
│   │       ├── __init__.py
│   │       └── node_db_queries.py
│   ├── nodes.py
│   ├── prompts
│   │   ├── __init__.py
│   │   ├── dedupe_edges.py
│   │   ├── dedupe_nodes.py
│   │   ├── eval.py
│   │   ├── extract_edge_dates.py
│   │   ├── extract_edges.py
│   │   ├── extract_nodes.py
│   │   ├── invalidate_edges.py
│   │   ├── lib.py
│   │   ├── models.py
│   │   ├── prompt_helpers.py
│   │   ├── snippets.py
│   │   └── summarize_nodes.py
│   ├── py.typed
│   ├── search
│   │   ├── __init__.py
│   │   ├── search_config_recipes.py
│   │   ├── search_config.py
│   │   ├── search_filters.py
│   │   ├── search_helpers.py
│   │   ├── search_utils.py
│   │   └── search.py
│   ├── telemetry
│   │   ├── __init__.py
│   │   └── telemetry.py
│   ├── tracer.py
│   └── utils
│       ├── __init__.py
│       ├── bulk_utils.py
│       ├── datetime_utils.py
│       ├── maintenance
│       │   ├── __init__.py
│       │   ├── community_operations.py
│       │   ├── dedup_helpers.py
│       │   ├── edge_operations.py
│       │   ├── graph_data_operations.py
│       │   ├── node_operations.py
│       │   └── temporal_operations.py
│       ├── ontology_utils
│       │   └── entity_types_utils.py
│       └── text_utils.py
├── images
│   ├── arxiv-screenshot.png
│   ├── graphiti-graph-intro.gif
│   ├── graphiti-intro-slides-stock-2.gif
│   └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│   ├── .env.example
│   ├── .python-version
│   ├── config
│   │   ├── config-docker-falkordb-combined.yaml
│   │   ├── config-docker-falkordb.yaml
│   │   ├── config-docker-neo4j.yaml
│   │   ├── config.yaml
│   │   └── mcp_config_stdio_example.json
│   ├── docker
│   │   ├── build-standalone.sh
│   │   ├── build-with-version.sh
│   │   ├── docker-compose-falkordb.yml
│   │   ├── docker-compose-neo4j.yml
│   │   ├── docker-compose.yml
│   │   ├── Dockerfile
│   │   ├── Dockerfile.standalone
│   │   ├── github-actions-example.yml
│   │   ├── README-falkordb-combined.md
│   │   └── README.md
│   ├── docs
│   │   └── cursor_rules.md
│   ├── main.py
│   ├── pyproject.toml
│   ├── pytest.ini
│   ├── README.md
│   ├── src
│   │   ├── __init__.py
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   └── schema.py
│   │   ├── graphiti_mcp_server.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── entity_types.py
│   │   │   └── response_types.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   ├── factories.py
│   │   │   └── queue_service.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── formatting.py
│   │       └── utils.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── conftest.py
│   │   ├── pytest.ini
│   │   ├── README.md
│   │   ├── run_tests.py
│   │   ├── test_async_operations.py
│   │   ├── test_comprehensive_integration.py
│   │   ├── test_configuration.py
│   │   ├── test_falkordb_integration.py
│   │   ├── test_fixtures.py
│   │   ├── test_http_integration.py
│   │   ├── test_integration.py
│   │   ├── test_mcp_integration.py
│   │   ├── test_mcp_transports.py
│   │   ├── test_stdio_simple.py
│   │   └── test_stress_load.py
│   └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│   ├── .env.example
│   ├── graph_service
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── dto
│   │   │   ├── __init__.py
│   │   │   ├── common.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   ├── main.py
│   │   ├── routers
│   │   │   ├── __init__.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   └── zep_graphiti.py
│   ├── Makefile
│   ├── pyproject.toml
│   ├── README.md
│   └── uv.lock
├── signatures
│   └── version1
│       └── cla.json
├── tests
│   ├── cross_encoder
│   │   ├── test_bge_reranker_client_int.py
│   │   └── test_gemini_reranker_client.py
│   ├── driver
│   │   ├── __init__.py
│   │   └── test_falkordb_driver.py
│   ├── embedder
│   │   ├── embedder_fixtures.py
│   │   ├── test_gemini.py
│   │   ├── test_openai.py
│   │   └── test_voyage.py
│   ├── evals
│   │   ├── data
│   │   │   └── longmemeval_data
│   │   │       ├── longmemeval_oracle.json
│   │   │       └── README.md
│   │   ├── eval_cli.py
│   │   ├── eval_e2e_graph_building.py
│   │   ├── pytest.ini
│   │   └── utils.py
│   ├── helpers_test.py
│   ├── llm_client
│   │   ├── test_anthropic_client_int.py
│   │   ├── test_anthropic_client.py
│   │   ├── test_azure_openai_client.py
│   │   ├── test_client.py
│   │   ├── test_errors.py
│   │   └── test_gemini_client.py
│   ├── test_edge_int.py
│   ├── test_entity_exclusion_int.py
│   ├── test_graphiti_int.py
│   ├── test_graphiti_mock.py
│   ├── test_node_int.py
│   ├── test_text_utils.py
│   └── utils
│       ├── maintenance
│       │   ├── test_bulk_utils.py
│       │   ├── test_edge_operations.py
│       │   ├── test_node_operations.py
│       │   └── test_temporal_operations_int.py
│       └── search
│           └── search_utils_test.py
├── uv.lock
└── Zep-CLA.md
```

# Files

--------------------------------------------------------------------------------
/graphiti_core/graphiti.py:
--------------------------------------------------------------------------------

```python
   1 | """
   2 | Copyright 2024, Zep Software, Inc.
   3 | 
   4 | Licensed under the Apache License, Version 2.0 (the "License");
   5 | you may not use this file except in compliance with the License.
   6 | You may obtain a copy of the License at
   7 | 
   8 |     http://www.apache.org/licenses/LICENSE-2.0
   9 | 
  10 | Unless required by applicable law or agreed to in writing, software
  11 | distributed under the License is distributed on an "AS IS" BASIS,
  12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13 | See the License for the specific language governing permissions and
  14 | limitations under the License.
  15 | """
  16 | 
  17 | import logging
  18 | from datetime import datetime
  19 | from time import time
  20 | 
  21 | from dotenv import load_dotenv
  22 | from pydantic import BaseModel
  23 | from typing_extensions import LiteralString
  24 | 
  25 | from graphiti_core.cross_encoder.client import CrossEncoderClient
  26 | from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
  27 | from graphiti_core.decorators import handle_multiple_group_ids
  28 | from graphiti_core.driver.driver import GraphDriver
  29 | from graphiti_core.driver.neo4j_driver import Neo4jDriver
  30 | from graphiti_core.edges import (
  31 |     CommunityEdge,
  32 |     Edge,
  33 |     EntityEdge,
  34 |     EpisodicEdge,
  35 |     create_entity_edge_embeddings,
  36 | )
  37 | from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
  38 | from graphiti_core.graphiti_types import GraphitiClients
  39 | from graphiti_core.helpers import (
  40 |     get_default_group_id,
  41 |     semaphore_gather,
  42 |     validate_excluded_entity_types,
  43 |     validate_group_id,
  44 | )
  45 | from graphiti_core.llm_client import LLMClient, OpenAIClient
  46 | from graphiti_core.nodes import (
  47 |     CommunityNode,
  48 |     EntityNode,
  49 |     EpisodeType,
  50 |     EpisodicNode,
  51 |     Node,
  52 |     create_entity_node_embeddings,
  53 | )
  54 | from graphiti_core.search.search import SearchConfig, search
  55 | from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
  56 | from graphiti_core.search.search_config_recipes import (
  57 |     COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
  58 |     EDGE_HYBRID_SEARCH_NODE_DISTANCE,
  59 |     EDGE_HYBRID_SEARCH_RRF,
  60 | )
  61 | from graphiti_core.search.search_filters import SearchFilters
  62 | from graphiti_core.search.search_utils import (
  63 |     RELEVANT_SCHEMA_LIMIT,
  64 |     get_mentioned_nodes,
  65 | )
  66 | from graphiti_core.telemetry import capture_event
  67 | from graphiti_core.tracer import Tracer, create_tracer
  68 | from graphiti_core.utils.bulk_utils import (
  69 |     RawEpisode,
  70 |     add_nodes_and_edges_bulk,
  71 |     dedupe_edges_bulk,
  72 |     dedupe_nodes_bulk,
  73 |     extract_nodes_and_edges_bulk,
  74 |     resolve_edge_pointers,
  75 |     retrieve_previous_episodes_bulk,
  76 | )
  77 | from graphiti_core.utils.datetime_utils import utc_now
  78 | from graphiti_core.utils.maintenance.community_operations import (
  79 |     build_communities,
  80 |     remove_communities,
  81 |     update_community,
  82 | )
  83 | from graphiti_core.utils.maintenance.edge_operations import (
  84 |     build_episodic_edges,
  85 |     extract_edges,
  86 |     resolve_extracted_edge,
  87 |     resolve_extracted_edges,
  88 | )
  89 | from graphiti_core.utils.maintenance.graph_data_operations import (
  90 |     EPISODE_WINDOW_LEN,
  91 |     retrieve_episodes,
  92 | )
  93 | from graphiti_core.utils.maintenance.node_operations import (
  94 |     extract_attributes_from_nodes,
  95 |     extract_nodes,
  96 |     resolve_extracted_nodes,
  97 | )
  98 | from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types
  99 | 
 100 | logger = logging.getLogger(__name__)  # module-level logger named after this module
 101 | 
 102 | load_dotenv()  # executed at import time; python-dotenv loads variables from a .env file into the environment
 103 | 
 104 | 
 105 | class AddEpisodeResults(BaseModel):  # pydantic result model for a single episode; its producer is outside this excerpt (presumably add_episode — confirm)
 106 |     episode: EpisodicNode  # exactly one episode (compare AddBulkEpisodeResults.episodes)
 107 |     episodic_edges: list[EpisodicEdge]
 108 |     nodes: list[EntityNode]
 109 |     edges: list[EntityEdge]
 110 |     communities: list[CommunityNode]
 111 |     community_edges: list[CommunityEdge]
 112 | 
 113 | 
 114 | class AddBulkEpisodeResults(BaseModel):  # pydantic result model for a batch of episodes; its producer is outside this excerpt (presumably add_episode_bulk — confirm)
 115 |     episodes: list[EpisodicNode]  # one entry per episode in the batch — TODO confirm ordering guarantees
 116 |     episodic_edges: list[EpisodicEdge]
 117 |     nodes: list[EntityNode]
 118 |     edges: list[EntityEdge]
 119 |     communities: list[CommunityNode]
 120 |     community_edges: list[CommunityEdge]
 121 | 
 122 | 
 123 | class AddTripletResults(BaseModel):  # pydantic result model holding entity nodes and edges only; its producer is outside this excerpt (presumably add_triplet — confirm)
 124 |     nodes: list[EntityNode]
 125 |     edges: list[EntityEdge]
 126 | 
 127 | 
 128 | class Graphiti:
 129 |     def __init__(
 130 |         self,
 131 |         uri: str | None = None,
 132 |         user: str | None = None,
 133 |         password: str | None = None,
 134 |         llm_client: LLMClient | None = None,
 135 |         embedder: EmbedderClient | None = None,
 136 |         cross_encoder: CrossEncoderClient | None = None,
 137 |         store_raw_episode_content: bool = True,
 138 |         graph_driver: GraphDriver | None = None,
 139 |         max_coroutines: int | None = None,
 140 |         tracer: Tracer | None = None,
 141 |         trace_span_prefix: str = 'graphiti',
 142 |     ):
 143 |         """
 144 |         Initialize a Graphiti instance.
 145 | 
 146 |         This constructor sets up the graph driver, the LLM, embedder and cross-encoder clients
 147 |         (falling back to OpenAI-backed defaults), the tracer, and emits an initialization telemetry event.
 148 | 
 149 |         Parameters
 150 |         ----------
 151 |         uri : str | None, optional
 152 |             The URI of the Neo4j database. Required when graph_driver is not provided (ValueError otherwise).
 153 |         user : str | None, optional
 154 |             The username for authenticating with the Neo4j database. Unused when graph_driver is provided.
 155 |         password : str | None, optional
 156 |             The password for authenticating with the Neo4j database. Unused when graph_driver is provided.
 157 |         llm_client : LLMClient | None, optional
 158 |             An instance of LLMClient for natural language processing tasks.
 159 |             If not provided, a default OpenAIClient will be initialized.
 160 |         embedder : EmbedderClient | None, optional
 161 |             An instance of EmbedderClient for embedding tasks.
 162 |             If not provided, a default OpenAIEmbedder will be initialized.
 163 |         cross_encoder : CrossEncoderClient | None, optional
 164 |             An instance of CrossEncoderClient for reranking tasks.
 165 |             If not provided, a default OpenAIRerankerClient will be initialized.
 166 |         store_raw_episode_content : bool, optional
 167 |             Whether to store the raw content of episodes. Defaults to True.
 168 |         graph_driver : GraphDriver | None, optional
 169 |             An instance of GraphDriver for database operations.
 170 |             If not provided, a default Neo4jDriver will be initialized.
 171 |         max_coroutines : int | None, optional
 172 |             The maximum number of concurrent operations allowed. Overrides SEMAPHORE_LIMIT set in the environment.
 173 |             If not set, the Graphiti default is used.
 174 |         tracer : Tracer | None, optional
 175 |             An OpenTelemetry tracer instance for distributed tracing. If not provided, tracing is disabled (no-op).
 176 |         trace_span_prefix : str, optional
 177 |             Prefix to prepend to all span names. Defaults to 'graphiti'.
 178 | 
 179 |         Returns
 180 |         -------
 181 |         None
 182 | 
 183 |         Notes
 184 |         -----
 185 |         This method establishes a connection to a graph database (Neo4j by default) using the provided
 186 |         credentials. It also sets up the LLM client, either using the provided client
 187 |         or by creating a default OpenAIClient.
 188 | 
 189 |         The default database name is defined during the driver’s construction. If a different database name
 190 |         is required, it should be specified in the URI or set separately after
 191 |         initialization.
 192 | 
 193 |         The OpenAI API key is expected to be set in the environment variables.
 194 |         Make sure to set the OPENAI_API_KEY environment variable before initializing
 195 |         Graphiti if you're using the default OpenAIClient.
 196 |         """
 197 | 
 198 |         if graph_driver:  # an injected driver takes precedence; uri/user/password are then ignored
 199 |             self.driver = graph_driver
 200 |         else:
 201 |             if uri is None:
 202 |                 raise ValueError('uri must be provided when graph_driver is None')
 203 |             self.driver = Neo4jDriver(uri, user, password)  # default backend when no driver is injected
 204 | 
 205 |         self.store_raw_episode_content = store_raw_episode_content
 206 |         self.max_coroutines = max_coroutines
 207 |         if llm_client:
 208 |             self.llm_client = llm_client
 209 |         else:
 210 |             self.llm_client = OpenAIClient()
 211 |         if embedder:
 212 |             self.embedder = embedder
 213 |         else:
 214 |             self.embedder = OpenAIEmbedder()
 215 |         if cross_encoder:
 216 |             self.cross_encoder = cross_encoder
 217 |         else:
 218 |             self.cross_encoder = OpenAIRerankerClient()
 219 | 
 220 |         # Initialize tracer
 221 |         self.tracer = create_tracer(tracer, trace_span_prefix)
 222 | 
 223 |         # Set tracer on clients (only the LLM client has set_tracer called on it here)
 224 |         self.llm_client.set_tracer(self.tracer)
 225 | 
 226 |         self.clients = GraphitiClients(  # bundle handed to the extraction/resolution helpers as self.clients
 227 |             driver=self.driver,
 228 |             llm_client=self.llm_client,
 229 |             embedder=self.embedder,
 230 |             cross_encoder=self.cross_encoder,
 231 |             tracer=self.tracer,
 232 |         )
 233 | 
 234 |         # Capture telemetry event
 235 |         self._capture_initialization_telemetry()
 236 | 
 237 |     def _capture_initialization_telemetry(self):
 238 |         """Capture telemetry event for Graphiti initialization."""
 239 |         try:
 240 |             # Detect provider types from class names
 241 |             llm_provider = self._get_provider_type(self.llm_client)
 242 |             embedder_provider = self._get_provider_type(self.embedder)
 243 |             reranker_provider = self._get_provider_type(self.cross_encoder)
 244 |             database_provider = self._get_provider_type(self.driver)
 245 | 
 246 |             properties = {
 247 |                 'llm_provider': llm_provider,
 248 |                 'embedder_provider': embedder_provider,
 249 |                 'reranker_provider': reranker_provider,
 250 |                 'database_provider': database_provider,
 251 |             }
 252 | 
 253 |             capture_event('graphiti_initialized', properties)
 254 |         except Exception:
 255 |             # Silently handle telemetry errors
 256 |             pass
 257 | 
 258 |     def _get_provider_type(self, client) -> str:
 259 |         """Get provider type from client class name."""
 260 |         if client is None:
 261 |             return 'none'
 262 | 
 263 |         class_name = client.__class__.__name__.lower()
 264 | 
 265 |         # LLM providers
 266 |         if 'openai' in class_name:
 267 |             return 'openai'
 268 |         elif 'azure' in class_name:
 269 |             return 'azure'
 270 |         elif 'anthropic' in class_name:
 271 |             return 'anthropic'
 272 |         elif 'crossencoder' in class_name:
 273 |             return 'crossencoder'
 274 |         elif 'gemini' in class_name:
 275 |             return 'gemini'
 276 |         elif 'groq' in class_name:
 277 |             return 'groq'
 278 |         # Database providers
 279 |         elif 'neo4j' in class_name:
 280 |             return 'neo4j'
 281 |         elif 'falkor' in class_name:
 282 |             return 'falkordb'
 283 |         # Embedder providers
 284 |         elif 'voyage' in class_name:
 285 |             return 'voyage'
 286 |         else:
 287 |             return 'unknown'
 288 | 
 289 |     async def close(self):
 290 |         """
 291 |         Close the connection to the graph database.
 292 | 
 293 |         This method closes the connection held by the configured graph driver (Neo4j by default).
 294 |         It should be called when the Graphiti instance is no longer needed or
 295 |         when the application is shutting down.
 296 | 
 297 |         Parameters
 298 |         ----------
 299 |         self
 300 | 
 301 |         Returns
 302 |         -------
 303 |         None
 304 | 
 305 |         Notes
 306 |         -----
 307 |         It's important to close the driver connection to release system resources
 308 |         and ensure that all pending transactions are completed or rolled back.
 309 |         This method is a coroutine and must be awaited, typically as part of a
 310 |         cleanup process such as a `finally` block or a shutdown hook.
 311 | 
 312 |         Example:
 313 |             graphiti = Graphiti(uri, user, password)
 314 |             try:
 315 |                 # Use graphiti...
 316 |             finally:
 317 |                 await graphiti.close()
 318 |         """
 319 |         await self.driver.close()
 320 | 
 321 |     async def build_indices_and_constraints(self, delete_existing: bool = False):
 322 |         """
 323 |         Build indices and constraints in the graph database.
 324 | 
 325 |         This method sets up the necessary indices and constraints in the configured graph database
 326 |         to optimize query performance and ensure data integrity for the knowledge graph.
 327 | 
 328 |         Parameters
 329 |         ----------
 330 |         self
 331 |         delete_existing : bool, optional
 332 |             Whether to clear existing indices before creating new ones.
 333 |             Defaults to False.
 334 | 
 335 |         Returns
 336 |         -------
 337 |         None
 338 | 
 339 |         Notes
 340 |         -----
 341 |         This method should typically be called once during the initial setup of the
 342 |         knowledge graph or when updating the database schema. It uses the
 343 |         driver's `build_indices_and_constraints` method to perform
 344 |         the actual database operations.
 345 | 
 346 |         The specific indices and constraints created depend on the implementation
 347 |         of the driver's `build_indices_and_constraints` method. Refer to the specific
 348 |         driver documentation for details on the exact database schema modifications.
 349 | 
 350 |         Caution: Running this method on a large existing database may take some time
 351 |         and could impact database performance during execution.
 352 |         """
 353 |         await self.driver.build_indices_and_constraints(delete_existing)
 354 | 
 355 |     async def _extract_and_resolve_nodes(
 356 |         self,
 357 |         episode: EpisodicNode,
 358 |         previous_episodes: list[EpisodicNode],
 359 |         entity_types: dict[str, type[BaseModel]] | None,
 360 |         excluded_entity_types: list[str] | None,
 361 |     ) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
 362 |         """Extract nodes from the episode via extract_nodes, resolve them via resolve_extracted_nodes, and return (nodes, uuid_map, duplicates)."""
 363 |         extracted_nodes = await extract_nodes(  # excluded_entity_types is only passed to extraction, not to resolution
 364 |             self.clients, episode, previous_episodes, entity_types, excluded_entity_types
 365 |         )
 366 | 
 367 |         nodes, uuid_map, duplicates = await resolve_extracted_nodes(  # uuid_map presumably maps extracted uuids to resolved uuids — confirm
 368 |             self.clients,
 369 |             extracted_nodes,
 370 |             episode,
 371 |             previous_episodes,
 372 |             entity_types,
 373 |         )
 374 | 
 375 |         return nodes, uuid_map, duplicates
 376 | 
 377 |     async def _extract_and_resolve_edges(
 378 |         self,
 379 |         episode: EpisodicNode,
 380 |         extracted_nodes: list[EntityNode],
 381 |         previous_episodes: list[EpisodicNode],
 382 |         edge_type_map: dict[tuple[str, str], list[str]],
 383 |         group_id: str,
 384 |         edge_types: dict[str, type[BaseModel]] | None,
 385 |         nodes: list[EntityNode],
 386 |         uuid_map: dict[str, str],
 387 |     ) -> tuple[list[EntityEdge], list[EntityEdge]]:
 388 |         """Extract edges from the episode, remap them through uuid_map, resolve them, and return (resolved_edges, invalidated_edges)."""
 389 |         extracted_edges = await extract_edges(  # extraction works from extracted_nodes; resolution below is given `nodes`
 390 |             self.clients,
 391 |             episode,
 392 |             extracted_nodes,
 393 |             previous_episodes,
 394 |             edge_type_map,
 395 |             group_id,
 396 |             edge_types,
 397 |         )
 398 | 
 399 |         edges = resolve_edge_pointers(extracted_edges, uuid_map)  # presumably rewrites edge endpoints to the resolved node uuids — confirm
 400 | 
 401 |         resolved_edges, invalidated_edges = await resolve_extracted_edges(
 402 |             self.clients,
 403 |             edges,
 404 |             episode,
 405 |             nodes,
 406 |             edge_types or {},  # an empty dict is passed when no edge types were supplied
 407 |             edge_type_map,
 408 |         )
 409 | 
 410 |         return resolved_edges, invalidated_edges
 411 | 
 412 |     async def _process_episode_data(
 413 |         self,
 414 |         episode: EpisodicNode,
 415 |         nodes: list[EntityNode],
 416 |         entity_edges: list[EntityEdge],
 417 |         now: datetime,
 418 |     ) -> tuple[list[EpisodicEdge], EpisodicNode]:
 419 |         """Build episodic edges, update the episode in place, bulk-save everything, and return (episodic_edges, episode)."""
 420 |         episodic_edges = build_episodic_edges(nodes, episode.uuid, now)
 421 |         episode.entity_edges = [edge.uuid for edge in entity_edges]  # mutates the caller's episode object
 422 | 
 423 |         if not self.store_raw_episode_content:
 424 |             episode.content = ''  # raw content is blanked before saving; the returned episode is blanked too
 425 | 
 426 |         await add_nodes_and_edges_bulk(  # single bulk write of the episode, its episodic edges, the nodes and the entity edges
 427 |             self.driver,
 428 |             [episode],
 429 |             episodic_edges,
 430 |             nodes,
 431 |             entity_edges,
 432 |             self.embedder,
 433 |         )
 434 | 
 435 |         return episodic_edges, episode
 436 | 
 437 |     async def _extract_and_dedupe_nodes_bulk(
 438 |         self,
 439 |         episode_context: list[tuple[EpisodicNode, list[EpisodicNode]]],
 440 |         edge_type_map: dict[tuple[str, str], list[str]],
 441 |         edge_types: dict[str, type[BaseModel]] | None,
 442 |         entity_types: dict[str, type[BaseModel]] | None,
 443 |         excluded_entity_types: list[str] | None,
 444 |     ) -> tuple[
 445 |         dict[str, list[EntityNode]],
 446 |         dict[str, str],
 447 |         list[list[EntityEdge]],
 448 |     ]:
 449 |         """Extract nodes and edges for every episode, dedupe the nodes, and return (nodes_by_episode, uuid_map, extracted_edges_bulk)."""
 450 |         # Extract all nodes and edges for each episode (episode_context pairs each episode with its previous episodes)
 451 |         extracted_nodes_bulk, extracted_edges_bulk = await extract_nodes_and_edges_bulk(
 452 |             self.clients,
 453 |             episode_context,
 454 |             edge_type_map=edge_type_map,
 455 |             edge_types=edge_types,
 456 |             entity_types=entity_types,
 457 |             excluded_entity_types=excluded_entity_types,
 458 |         )
 459 | 
 460 |         # Dedupe extracted nodes in memory
 461 |         nodes_by_episode, uuid_map = await dedupe_nodes_bulk(
 462 |             self.clients, extracted_nodes_bulk, episode_context, entity_types
 463 |         )
 464 | 
 465 |         return nodes_by_episode, uuid_map, extracted_edges_bulk  # edges are returned as extracted; only nodes are deduped here
 466 | 
 467 |     async def _resolve_nodes_and_edges_bulk(
 468 |         self,
 469 |         nodes_by_episode: dict[str, list[EntityNode]],
 470 |         edges_by_episode: dict[str, list[EntityEdge]],
 471 |         episode_context: list[tuple[EpisodicNode, list[EpisodicNode]]],
 472 |         entity_types: dict[str, type[BaseModel]] | None,
 473 |         edge_types: dict[str, type[BaseModel]] | None,
 474 |         edge_type_map: dict[tuple[str, str], list[str]],
 475 |         episodes: list[EpisodicNode],
 476 |     ) -> tuple[list[EntityNode], list[EntityEdge], list[EntityEdge], dict[str, str]]:
 477 |         """Resolve bulk-extracted nodes and edges against the existing graph; returns (hydrated nodes, resolved edges, invalidated edges, uuid_map)."""
 478 |         nodes_by_uuid: dict[str, EntityNode] = {  # uuid -> node; for repeated uuids the last occurrence wins
 479 |             node.uuid: node for nodes in nodes_by_episode.values() for node in nodes
 480 |         }
 481 | 
 482 |         # Get unique nodes per episode: a uuid is kept only under the first episode (in episode_context order) that contains it
 483 |         nodes_by_episode_unique: dict[str, list[EntityNode]] = {}
 484 |         nodes_uuid_set: set[str] = set()
 485 |         for episode, _ in episode_context:
 486 |             nodes_by_episode_unique[episode.uuid] = []
 487 |             nodes = [nodes_by_uuid[node.uuid] for node in nodes_by_episode[episode.uuid]]  # re-read via nodes_by_uuid so equal uuids share one instance
 488 |             for node in nodes:
 489 |                 if node.uuid not in nodes_uuid_set:
 490 |                     nodes_by_episode_unique[episode.uuid].append(node)
 491 |                     nodes_uuid_set.add(node.uuid)
 492 | 
 493 |         # Resolve nodes (one resolve_extracted_nodes call per episode, gathered through semaphore_gather)
 494 |         node_results = await semaphore_gather(
 495 |             *[
 496 |                 resolve_extracted_nodes(
 497 |                     self.clients,
 498 |                     nodes_by_episode_unique[episode.uuid],
 499 |                     episode,
 500 |                     previous_episodes,
 501 |                     entity_types,
 502 |                 )
 503 |                 for episode, previous_episodes in episode_context
 504 |             ]
 505 |         )
 506 | 
 507 |         resolved_nodes: list[EntityNode] = []
 508 |         uuid_map: dict[str, str] = {}
 509 |         for result in node_results:
 510 |             resolved_nodes.extend(result[0])  # result is (nodes, uuid_map, duplicates); duplicates are not used here
 511 |             uuid_map.update(result[1])  # later results overwrite earlier entries for the same key
 512 | 
 513 |         # Update nodes_by_uuid with resolved nodes
 514 |         for resolved_node in resolved_nodes:
 515 |             nodes_by_uuid[resolved_node.uuid] = resolved_node
 516 | 
 517 |         # Update nodes_by_episode_unique with resolved pointers
 518 |         for episode_uuid, nodes in nodes_by_episode_unique.items():
 519 |             updated_nodes: list[EntityNode] = []
 520 |             for node in nodes:
 521 |                 updated_node_uuid = uuid_map.get(node.uuid, node.uuid)  # fall back to the node's own uuid when it was not remapped
 522 |                 updated_node = nodes_by_uuid[updated_node_uuid]
 523 |                 updated_nodes.append(updated_node)
 524 |             nodes_by_episode_unique[episode_uuid] = updated_nodes
 525 | 
 526 |         # Extract attributes for resolved nodes
 527 |         hydrated_nodes_results: list[list[EntityNode]] = await semaphore_gather(
 528 |             *[
 529 |                 extract_attributes_from_nodes(
 530 |                     self.clients,
 531 |                     nodes_by_episode_unique[episode.uuid],
 532 |                     episode,
 533 |                     previous_episodes,
 534 |                     entity_types,
 535 |                 )
 536 |                 for episode, previous_episodes in episode_context
 537 |             ]
 538 |         )
 539 | 
 540 |         final_hydrated_nodes = [node for nodes in hydrated_nodes_results for node in nodes]  # flatten the per-episode lists
 541 | 
 542 |         # Resolve edges with updated pointers; an edge uuid is kept only under the first episode that contains it
 543 |         edges_by_episode_unique: dict[str, list[EntityEdge]] = {}
 544 |         edges_uuid_set: set[str] = set()
 545 |         for episode_uuid, edges in edges_by_episode.items():
 546 |             edges_with_updated_pointers = resolve_edge_pointers(edges, uuid_map)
 547 |             edges_by_episode_unique[episode_uuid] = []
 548 | 
 549 |             for edge in edges_with_updated_pointers:
 550 |                 if edge.uuid not in edges_uuid_set:
 551 |                     edges_by_episode_unique[episode_uuid].append(edge)
 552 |                     edges_uuid_set.add(edge.uuid)
 553 | 
 554 |         edge_results = await semaphore_gather(
 555 |             *[
 556 |                 resolve_extracted_edges(
 557 |                     self.clients,
 558 |                     edges_by_episode_unique[episode.uuid],  # NOTE(review): KeyError if an episode has no entry in edges_by_episode — confirm callers always supply one
 559 |                     episode,
 560 |                     final_hydrated_nodes,  # every episode's edges are resolved against the full hydrated node list
 561 |                     edge_types or {},
 562 |                     edge_type_map,
 563 |                 )
 564 |                 for episode in episodes  # iterates `episodes`, not episode_context — presumably the same episodes; verify against caller
 565 |             ]
 566 |         )
 567 | 
 568 |         resolved_edges: list[EntityEdge] = []
 569 |         invalidated_edges: list[EntityEdge] = []
 570 |         for result in edge_results:
 571 |             resolved_edges.extend(result[0])  # result is (resolved_edges, invalidated_edges)
 572 |             invalidated_edges.extend(result[1])
 573 | 
 574 |         return final_hydrated_nodes, resolved_edges, invalidated_edges, uuid_map
 575 | 
 576 |     @handle_multiple_group_ids
 577 |     async def retrieve_episodes(
 578 |         self,
 579 |         reference_time: datetime,
 580 |         last_n: int = EPISODE_WINDOW_LEN,
 581 |         group_ids: list[str] | None = None,
 582 |         source: EpisodeType | None = None,
 583 |         driver: GraphDriver | None = None,
 584 |     ) -> list[EpisodicNode]:
 585 |         """
 586 |         Retrieve the last n episodic nodes from the graph.
 587 | 
 588 |         This method fetches a specified number of the most recent episodic nodes
 589 |         from the graph, relative to the given reference time.
 590 | 
 591 |         Parameters
 592 |         ----------
 593 |         reference_time : datetime
 594 |             The reference time to retrieve episodes before.
 595 |         last_n : int, optional
 596 |             The number of episodes to retrieve. Defaults to EPISODE_WINDOW_LEN.
 597 |         group_ids : list[str] | None, optional
 598 |             The group ids to return data from.
 599 | 
 600 |         Returns
 601 |         -------
 602 |         list[EpisodicNode]
 603 |             A list of the most recent EpisodicNode objects.
 604 | 
 605 |         Notes
 606 |         -----
 607 |         The actual retrieval is performed by the `retrieve_episodes` function
 608 |         from the `graphiti_core.utils` module.
 609 |         """
 610 |         if driver is None:
 611 |             driver = self.clients.driver
 612 | 
 613 |         return await retrieve_episodes(driver, reference_time, last_n, group_ids, source)
 614 | 
 615 |     async def add_episode(
 616 |         self,
 617 |         name: str,
 618 |         episode_body: str,
 619 |         source_description: str,
 620 |         reference_time: datetime,
 621 |         source: EpisodeType = EpisodeType.message,
 622 |         group_id: str | None = None,
 623 |         uuid: str | None = None,
 624 |         update_communities: bool = False,
 625 |         entity_types: dict[str, type[BaseModel]] | None = None,
 626 |         excluded_entity_types: list[str] | None = None,
 627 |         previous_episode_uuids: list[str] | None = None,
 628 |         edge_types: dict[str, type[BaseModel]] | None = None,
 629 |         edge_type_map: dict[tuple[str, str], list[str]] | None = None,
 630 |     ) -> AddEpisodeResults:
 631 |         """
 632 |         Process an episode and update the graph.
 633 | 
 634 |         This method extracts information from the episode, creates nodes and edges,
 635 |         and updates the graph database accordingly.
 636 | 
 637 |         Parameters
 638 |         ----------
 639 |         name : str
 640 |             The name of the episode.
 641 |         episode_body : str
 642 |             The content of the episode.
 643 |         source_description : str
 644 |             A description of the episode's source.
 645 |         reference_time : datetime
 646 |             The reference time for the episode.
 647 |         source : EpisodeType, optional
 648 |             The type of the episode. Defaults to EpisodeType.message.
 649 |         group_id : str | None
 650 |             An id for the graph partition the episode is a part of.
 651 |         uuid : str | None
 652 |             Optional uuid of the episode.
 653 |         update_communities : bool
 654 |             Optional. Whether to update communities with new node information
 655 |         entity_types : dict[str, type[BaseModel]] | None
 656 |             Optional. Dictionary mapping entity type names to their Pydantic model definitions.
 657 |         excluded_entity_types : list[str] | None
 658 |             Optional. List of entity type names to exclude from the graph. Entities classified
 659 |             into these types will not be added to the graph. Can include 'Entity' to exclude
 660 |             the default entity type.
 661 |         previous_episode_uuids : list[str] | None
 662 |             Optional. List of episode uuids to use as the previous episodes. If this is not provided,
 663 |             the most recent episodes by created_at date will be used.
 664 | 
 665 |         Returns
 666 |         -------
 667 |         AddEpisodeResults
 668 | 
 669 |         Notes
 670 |         -----
 671 |         This method performs several steps including node extraction, edge extraction,
 672 |         deduplication, and database updates. It also handles embedding generation
 673 |         and edge invalidation.
 674 | 
 675 |         It is recommended to run this method as a background process, such as in a queue.
 676 |         It's important that each episode is added sequentially and awaited before adding
 677 |         the next one. For web applications, consider using FastAPI's background tasks
 678 |         or a dedicated task queue like Celery for this purpose.
 679 | 
 680 |         Example using FastAPI background tasks:
 681 |             @app.post("/add_episode")
 682 |             async def add_episode_endpoint(episode_data: EpisodeData):
 683 |                 background_tasks.add_task(graphiti.add_episode, **episode_data.dict())
 684 |                 return {"message": "Episode processing started"}
 685 |         """
 686 |         start = time()
 687 |         now = utc_now()
 688 | 
 689 |         validate_entity_types(entity_types)
 690 |         validate_excluded_entity_types(excluded_entity_types, entity_types)
 691 | 
 692 |         if group_id is None:
 693 |             # if group_id is None, use the default group id by the provider
 694 |             # and the preset database name will be used
 695 |             group_id = get_default_group_id(self.driver.provider)
 696 |         else:
 697 |             validate_group_id(group_id)
 698 |             if group_id != self.driver._database:
 699 |                 # if group_id is provided, use it as the database name
 700 |                 self.driver = self.driver.clone(database=group_id)
 701 |                 self.clients.driver = self.driver
 702 | 
 703 |         with self.tracer.start_span('add_episode') as span:
 704 |             try:
 705 |                 # Retrieve previous episodes for context
 706 |                 previous_episodes = (
 707 |                     await self.retrieve_episodes(
 708 |                         reference_time,
 709 |                         last_n=RELEVANT_SCHEMA_LIMIT,
 710 |                         group_ids=[group_id],
 711 |                         source=source,
 712 |                     )
 713 |                     if previous_episode_uuids is None
 714 |                     else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
 715 |                 )
 716 | 
 717 |                 # Get or create episode
 718 |                 episode = (
 719 |                     await EpisodicNode.get_by_uuid(self.driver, uuid)
 720 |                     if uuid is not None
 721 |                     else EpisodicNode(
 722 |                         name=name,
 723 |                         group_id=group_id,
 724 |                         labels=[],
 725 |                         source=source,
 726 |                         content=episode_body,
 727 |                         source_description=source_description,
 728 |                         created_at=now,
 729 |                         valid_at=reference_time,
 730 |                     )
 731 |                 )
 732 | 
 733 |                 # Create default edge type map
 734 |                 edge_type_map_default = (
 735 |                     {('Entity', 'Entity'): list(edge_types.keys())}
 736 |                     if edge_types is not None
 737 |                     else {('Entity', 'Entity'): []}
 738 |                 )
 739 | 
 740 |                 # Extract and resolve nodes
 741 |                 extracted_nodes = await extract_nodes(
 742 |                     self.clients, episode, previous_episodes, entity_types, excluded_entity_types
 743 |                 )
 744 | 
 745 |                 nodes, uuid_map, _ = await resolve_extracted_nodes(
 746 |                     self.clients,
 747 |                     extracted_nodes,
 748 |                     episode,
 749 |                     previous_episodes,
 750 |                     entity_types,
 751 |                 )
 752 | 
 753 |                 # Extract and resolve edges (awaited before the attribute extraction below)
 754 |                 resolved_edges, invalidated_edges = await self._extract_and_resolve_edges(
 755 |                     episode,
 756 |                     extracted_nodes,
 757 |                     previous_episodes,
 758 |                     edge_type_map or edge_type_map_default,
 759 |                     group_id,
 760 |                     edge_types,
 761 |                     nodes,
 762 |                     uuid_map,
 763 |                 )
 764 | 
 765 |                 # Extract node attributes
 766 |                 hydrated_nodes = await extract_attributes_from_nodes(
 767 |                     self.clients, nodes, episode, previous_episodes, entity_types
 768 |                 )
 769 | 
 770 |                 entity_edges = resolved_edges + invalidated_edges
 771 | 
 772 |                 # Process and save episode data
 773 |                 episodic_edges, episode = await self._process_episode_data(
 774 |                     episode, hydrated_nodes, entity_edges, now
 775 |                 )
 776 | 
 777 |                 # Update communities if requested
 778 |                 communities = []
 779 |                 community_edges = []
 780 |                 if update_communities:
 781 |                     communities, community_edges = await semaphore_gather(
 782 |                         *[
 783 |                             update_community(self.driver, self.llm_client, self.embedder, node)
 784 |                             for node in nodes
 785 |                         ],
 786 |                         max_coroutines=self.max_coroutines,
 787 |                     )
 788 | 
 789 |                 end = time()
 790 | 
 791 |                 # Add span attributes
 792 |                 span.add_attributes(
 793 |                     {
 794 |                         'episode.uuid': episode.uuid,
 795 |                         'episode.source': source.value,
 796 |                         'episode.reference_time': reference_time.isoformat(),
 797 |                         'group_id': group_id,
 798 |                         'node.count': len(hydrated_nodes),
 799 |                         'edge.count': len(entity_edges),
 800 |                         'edge.invalidated_count': len(invalidated_edges),
 801 |                         'previous_episodes.count': len(previous_episodes),
 802 |                         'entity_types.count': len(entity_types) if entity_types else 0,
 803 |                         'edge_types.count': len(edge_types) if edge_types else 0,
 804 |                         'update_communities': update_communities,
 805 |                         'communities.count': len(communities) if update_communities else 0,
 806 |                         'duration_ms': (end - start) * 1000,
 807 |                     }
 808 |                 )
 809 | 
 810 |                 logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
 811 | 
 812 |                 return AddEpisodeResults(
 813 |                     episode=episode,
 814 |                     episodic_edges=episodic_edges,
 815 |                     nodes=hydrated_nodes,
 816 |                     edges=entity_edges,
 817 |                     communities=communities,
 818 |                     community_edges=community_edges,
 819 |                 )
 820 | 
 821 |             except Exception as e:
 822 |                 span.set_status('error', str(e))
 823 |                 span.record_exception(e)
 824 |                 raise e
 825 | 
 826 |     async def add_episode_bulk(
 827 |         self,
 828 |         bulk_episodes: list[RawEpisode],
 829 |         group_id: str | None = None,
 830 |         entity_types: dict[str, type[BaseModel]] | None = None,
 831 |         excluded_entity_types: list[str] | None = None,
 832 |         edge_types: dict[str, type[BaseModel]] | None = None,
 833 |         edge_type_map: dict[tuple[str, str], list[str]] | None = None,
 834 |     ) -> AddBulkEpisodeResults:
 835 |         """
 836 |         Process multiple episodes in bulk and update the graph.
 837 | 
 838 |         This method extracts information from multiple episodes, creates nodes and edges,
 839 |         and updates the graph database accordingly, all in a single batch operation.
 840 | 
 841 |         Parameters
 842 |         ----------
 843 |         bulk_episodes : list[RawEpisode]
 844 |             A list of RawEpisode objects to be processed and added to the graph.
 845 |         group_id : str | None
 846 |             An id for the graph partition the episode is a part of.
 847 | 
 848 |         Returns
 849 |         -------
 850 |         AddBulkEpisodeResults
 851 | 
 852 |         Notes
 853 |         -----
 854 |         This method performs several steps including:
 855 |         - Saving all episodes to the database
 856 |         - Retrieving previous episode context for each new episode
 857 |         - Extracting nodes and edges from all episodes
 858 |         - Generating embeddings for nodes and edges
 859 |         - Deduplicating nodes and edges
 860 |         - Saving nodes, episodic edges, and entity edges to the knowledge graph
 861 | 
 862 |         This bulk operation is designed for efficiency when processing multiple episodes
 863 |         at once. However, it's important to ensure that the bulk operation doesn't
 864 |         overwhelm system resources. Consider implementing rate limiting or chunking for
 865 |         very large batches of episodes.
 866 | 
 867 |         Important: This method does not perform edge invalidation or date extraction steps.
 868 |         If these operations are required, use the `add_episode` method instead for each
 869 |         individual episode.
 870 |         """
 871 |         with self.tracer.start_span('add_episode_bulk') as bulk_span:
 872 |             bulk_span.add_attributes({'episode.count': len(bulk_episodes)})
 873 | 
 874 |             try:
 875 |                 start = time()
 876 |                 now = utc_now()
 877 | 
 878 |                 # if group_id is None, use the default group id by the provider
 879 |                 if group_id is None:
 880 |                     group_id = get_default_group_id(self.driver.provider)
 881 |                 else:
 882 |                     validate_group_id(group_id)
 883 |                     if group_id != self.driver._database:
 884 |                         # if group_id is provided, use it as the database name
 885 |                         self.driver = self.driver.clone(database=group_id)
 886 |                         self.clients.driver = self.driver
 887 | 
 888 |                 # Create default edge type map
 889 |                 edge_type_map_default = (
 890 |                     {('Entity', 'Entity'): list(edge_types.keys())}
 891 |                     if edge_types is not None
 892 |                     else {('Entity', 'Entity'): []}
 893 |                 )
 894 | 
 895 |                 episodes = [
 896 |                     await EpisodicNode.get_by_uuid(self.driver, episode.uuid)
 897 |                     if episode.uuid is not None
 898 |                     else EpisodicNode(
 899 |                         name=episode.name,
 900 |                         labels=[],
 901 |                         source=episode.source,
 902 |                         content=episode.content,
 903 |                         source_description=episode.source_description,
 904 |                         group_id=group_id,
 905 |                         created_at=now,
 906 |                         valid_at=episode.reference_time,
 907 |                     )
 908 |                     for episode in bulk_episodes
 909 |                 ]
 910 | 
 911 |                 # Save all episodes
 912 |                 await add_nodes_and_edges_bulk(
 913 |                     driver=self.driver,
 914 |                     episodic_nodes=episodes,
 915 |                     episodic_edges=[],
 916 |                     entity_nodes=[],
 917 |                     entity_edges=[],
 918 |                     embedder=self.embedder,
 919 |                 )
 920 | 
 921 |                 # Get previous episode context for each episode
 922 |                 episode_context = await retrieve_previous_episodes_bulk(self.driver, episodes)
 923 | 
 924 |                 # Extract and dedupe nodes and edges
 925 |                 (
 926 |                     nodes_by_episode,
 927 |                     uuid_map,
 928 |                     extracted_edges_bulk,
 929 |                 ) = await self._extract_and_dedupe_nodes_bulk(
 930 |                     episode_context,
 931 |                     edge_type_map or edge_type_map_default,
 932 |                     edge_types,
 933 |                     entity_types,
 934 |                     excluded_entity_types,
 935 |                 )
 936 | 
 937 |                 # Create Episodic Edges
 938 |                 episodic_edges: list[EpisodicEdge] = []
 939 |                 for episode_uuid, nodes in nodes_by_episode.items():
 940 |                     episodic_edges.extend(build_episodic_edges(nodes, episode_uuid, now))
 941 | 
 942 |                 # Re-map edge pointers and dedupe edges
 943 |                 extracted_edges_bulk_updated: list[list[EntityEdge]] = [
 944 |                     resolve_edge_pointers(edges, uuid_map) for edges in extracted_edges_bulk
 945 |                 ]
 946 | 
 947 |                 edges_by_episode = await dedupe_edges_bulk(
 948 |                     self.clients,
 949 |                     extracted_edges_bulk_updated,
 950 |                     episode_context,
 951 |                     [],
 952 |                     edge_types or {},
 953 |                     edge_type_map or edge_type_map_default,
 954 |                 )
 955 | 
 956 |                 # Resolve nodes and edges against the existing graph
 957 |                 (
 958 |                     final_hydrated_nodes,
 959 |                     resolved_edges,
 960 |                     invalidated_edges,
 961 |                     final_uuid_map,
 962 |                 ) = await self._resolve_nodes_and_edges_bulk(
 963 |                     nodes_by_episode,
 964 |                     edges_by_episode,
 965 |                     episode_context,
 966 |                     entity_types,
 967 |                     edge_types,
 968 |                     edge_type_map or edge_type_map_default,
 969 |                     episodes,
 970 |                 )
 971 | 
 972 |                 # Resolved pointers for episodic edges
 973 |                 resolved_episodic_edges = resolve_edge_pointers(episodic_edges, final_uuid_map)
 974 | 
 975 |                 # save data to KG
 976 |                 await add_nodes_and_edges_bulk(
 977 |                     self.driver,
 978 |                     episodes,
 979 |                     resolved_episodic_edges,
 980 |                     final_hydrated_nodes,
 981 |                     resolved_edges + invalidated_edges,
 982 |                     self.embedder,
 983 |                 )
 984 | 
 985 |                 end = time()
 986 | 
 987 |                 # Add span attributes
 988 |                 bulk_span.add_attributes(
 989 |                     {
 990 |                         'group_id': group_id,
 991 |                         'node.count': len(final_hydrated_nodes),
 992 |                         'edge.count': len(resolved_edges + invalidated_edges),
 993 |                         'duration_ms': (end - start) * 1000,
 994 |                     }
 995 |                 )
 996 | 
 997 |                 logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')
 998 | 
 999 |                 return AddBulkEpisodeResults(
1000 |                     episodes=episodes,
1001 |                     episodic_edges=resolved_episodic_edges,
1002 |                     nodes=final_hydrated_nodes,
1003 |                     edges=resolved_edges + invalidated_edges,
1004 |                     communities=[],
1005 |                     community_edges=[],
1006 |                 )
1007 | 
1008 |             except Exception as e:
1009 |                 bulk_span.set_status('error', str(e))
1010 |                 bulk_span.record_exception(e)
1011 |                 raise e
1012 | 
1013 |     @handle_multiple_group_ids
1014 |     async def build_communities(
1015 |         self, group_ids: list[str] | None = None, driver: GraphDriver | None = None
1016 |     ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
1017 |         """
1018 |         Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
1019 |         the content of these communities.
1020 |         ----------
1021 |         group_ids : list[str] | None
1022 |             Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
1023 |         """
1024 |         if driver is None:
1025 |             driver = self.clients.driver
1026 | 
1027 |         # Clear existing communities
1028 |         await remove_communities(driver)
1029 | 
1030 |         community_nodes, community_edges = await build_communities(
1031 |             driver, self.llm_client, group_ids
1032 |         )
1033 | 
1034 |         await semaphore_gather(
1035 |             *[node.generate_name_embedding(self.embedder) for node in community_nodes],
1036 |             max_coroutines=self.max_coroutines,
1037 |         )
1038 | 
1039 |         await semaphore_gather(
1040 |             *[node.save(driver) for node in community_nodes],
1041 |             max_coroutines=self.max_coroutines,
1042 |         )
1043 |         await semaphore_gather(
1044 |             *[edge.save(driver) for edge in community_edges],
1045 |             max_coroutines=self.max_coroutines,
1046 |         )
1047 | 
1048 |         return community_nodes, community_edges
1049 | 
1050 |     @handle_multiple_group_ids
1051 |     async def search(
1052 |         self,
1053 |         query: str,
1054 |         center_node_uuid: str | None = None,
1055 |         group_ids: list[str] | None = None,
1056 |         num_results=DEFAULT_SEARCH_LIMIT,
1057 |         search_filter: SearchFilters | None = None,
1058 |         driver: GraphDriver | None = None,
1059 |     ) -> list[EntityEdge]:
1060 |         """
1061 |         Perform a hybrid search on the knowledge graph.
1062 | 
1063 |         This method executes a search query on the graph, combining vector and
1064 |         text-based search techniques to retrieve relevant facts, returning the matching edges.
1065 | 
1066 |         This is our basic out-of-the-box search, for more robust results we recommend using our more advanced
1067 |         search method graphiti.search_().
1068 | 
1069 |         Parameters
1070 |         ----------
1071 |         query : str
1072 |             The search query string.
1073 |         center_node_uuid: str, optional
1074 |             Facts will be reranked based on proximity to this node
1075 |         group_ids : list[str] | None, optional
1076 |             The graph partitions to return data from.
1077 |         num_results : int, optional
1078 |             The maximum number of results to return. Defaults to DEFAULT_SEARCH_LIMIT.
1079 | 
1080 |         Returns
1081 |         -------
1082 |         list[EntityEdge]
1083 |             A list of EntityEdge objects that are relevant to the search query.
1084 | 
1085 |         Notes
1086 |         -----
1087 |         This method uses EDGE_HYBRID_SEARCH_RRF, or EDGE_HYBRID_SEARCH_NODE_DISTANCE when
1088 |         center_node_uuid is given, with the config's limit set to num_results.
1089 | 
1090 |         The search is performed using the current date and time as the reference
1091 |         point for temporal relevance.
1092 |         """
1093 |         search_config = (
1094 |             EDGE_HYBRID_SEARCH_RRF if center_node_uuid is None else EDGE_HYBRID_SEARCH_NODE_DISTANCE
1095 |         )
1096 |         search_config.limit = num_results
1097 | 
1098 |         edges = (
1099 |             await search(
1100 |                 self.clients,
1101 |                 query,
1102 |                 group_ids,
1103 |                 search_config,
1104 |                 search_filter if search_filter is not None else SearchFilters(),
1105 |                 driver=driver,
1106 |                 center_node_uuid=center_node_uuid,
1107 |             )
1108 |         ).edges
1109 | 
1110 |         return edges
1111 | 
1112 |     async def _search(
1113 |         self,
1114 |         query: str,
1115 |         config: SearchConfig,
1116 |         group_ids: list[str] | None = None,
1117 |         center_node_uuid: str | None = None,
1118 |         bfs_origin_node_uuids: list[str] | None = None,
1119 |         search_filter: SearchFilters | None = None,
1120 |     ) -> SearchResults:
1121 |         """DEPRECATED"""
1122 |         return await self.search_(
1123 |             query, config, group_ids, center_node_uuid, bfs_origin_node_uuids, search_filter
1124 |         )
1125 | 
1126 |     @handle_multiple_group_ids
1127 |     async def search_(
1128 |         self,
1129 |         query: str,
1130 |         config: SearchConfig = COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
1131 |         group_ids: list[str] | None = None,
1132 |         center_node_uuid: str | None = None,
1133 |         bfs_origin_node_uuids: list[str] | None = None,
1134 |         search_filter: SearchFilters | None = None,
1135 |         driver: GraphDriver | None = None,
1136 |     ) -> SearchResults:
1137 |         """search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather
1138 |         than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and
1139 |         different search and reranker methodologies across different layers in the graph.
1140 | 
1141 |         For different config recipes refer to search/search_config_recipes.
1142 |         """
1143 | 
1144 |         return await search(
1145 |             self.clients,
1146 |             query,
1147 |             group_ids,
1148 |             config,
1149 |             search_filter if search_filter is not None else SearchFilters(),
1150 |             center_node_uuid,
1151 |             bfs_origin_node_uuids,
1152 |             driver=driver,
1153 |         )
1154 | 
1155 |     async def get_nodes_and_edges_by_episode(self, episode_uuids: list[str]) -> SearchResults:
1156 |         episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)
1157 | 
1158 |         edges_list = await semaphore_gather(
1159 |             *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes],
1160 |             max_coroutines=self.max_coroutines,
1161 |         )
1162 | 
1163 |         edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]
1164 | 
1165 |         nodes = await get_mentioned_nodes(self.driver, episodes)
1166 | 
1167 |         return SearchResults(edges=edges, nodes=nodes)
1168 | 
1169 |     async def add_triplet(
1170 |         self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
1171 |     ) -> AddTripletResults:
1172 |         if source_node.name_embedding is None:
1173 |             await source_node.generate_name_embedding(self.embedder)
1174 |         if target_node.name_embedding is None:
1175 |             await target_node.generate_name_embedding(self.embedder)
1176 |         if edge.fact_embedding is None:
1177 |             await edge.generate_embedding(self.embedder)
1178 | 
1179 |         nodes, uuid_map, _ = await resolve_extracted_nodes(
1180 |             self.clients,
1181 |             [source_node, target_node],
1182 |         )
1183 | 
1184 |         updated_edge = resolve_edge_pointers([edge], uuid_map)[0]
1185 | 
1186 |         valid_edges = await EntityEdge.get_between_nodes(
1187 |             self.driver, edge.source_node_uuid, edge.target_node_uuid
1188 |         )
1189 | 
1190 |         related_edges = (
1191 |             await search(
1192 |                 self.clients,
1193 |                 updated_edge.fact,
1194 |                 group_ids=[updated_edge.group_id],
1195 |                 config=EDGE_HYBRID_SEARCH_RRF,
1196 |                 search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
1197 |             )
1198 |         ).edges
1199 |         existing_edges = (
1200 |             await search(
1201 |                 self.clients,
1202 |                 updated_edge.fact,
1203 |                 group_ids=[updated_edge.group_id],
1204 |                 config=EDGE_HYBRID_SEARCH_RRF,
1205 |                 search_filter=SearchFilters(),
1206 |             )
1207 |         ).edges
1208 | 
1209 |         resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
1210 |             self.llm_client,
1211 |             updated_edge,
1212 |             related_edges,
1213 |             existing_edges,
1214 |             EpisodicNode(
1215 |                 name='',
1216 |                 source=EpisodeType.text,
1217 |                 source_description='',
1218 |                 content='',
1219 |                 valid_at=edge.valid_at or utc_now(),
1220 |                 entity_edges=[],
1221 |                 group_id=edge.group_id,
1222 |             ),
1223 |             None,
1224 |             None,
1225 |         )
1226 | 
1227 |         edges: list[EntityEdge] = [resolved_edge] + invalidated_edges
1228 | 
1229 |         await create_entity_edge_embeddings(self.embedder, edges)
1230 |         await create_entity_node_embeddings(self.embedder, nodes)
1231 | 
1232 |         await add_nodes_and_edges_bulk(self.driver, [], [], nodes, edges, self.embedder)
1233 |         return AddTripletResults(edges=edges, nodes=nodes)
1234 | 
1235 |     async def remove_episode(self, episode_uuid: str):
1236 |         # Find the episode to be deleted
1237 |         episode = await EpisodicNode.get_by_uuid(self.driver, episode_uuid)
1238 | 
1239 |         # Find edges mentioned by the episode
1240 |         edges = await EntityEdge.get_by_uuids(self.driver, episode.entity_edges)
1241 | 
1242 |         # We should only delete edges created by the episode
1243 |         edges_to_delete: list[EntityEdge] = []
1244 |         for edge in edges:
1245 |             if edge.episodes and edge.episodes[0] == episode.uuid:
1246 |                 edges_to_delete.append(edge)
1247 | 
1248 |         # Find nodes mentioned by the episode
1249 |         nodes = await get_mentioned_nodes(self.driver, [episode])
1250 |         # We should delete all nodes that are only mentioned in the deleted episode
1251 |         nodes_to_delete: list[EntityNode] = []
1252 |         for node in nodes:
1253 |             query: LiteralString = 'MATCH (e:Episodic)-[:MENTIONS]->(n:Entity {uuid: $uuid}) RETURN count(*) AS episode_count'
1254 |             records, _, _ = await self.driver.execute_query(query, uuid=node.uuid, routing_='r')
1255 | 
1256 |             for record in records:
1257 |                 if record['episode_count'] == 1:
1258 |                     nodes_to_delete.append(node)
1259 | 
1260 |         await Edge.delete_by_uuids(self.driver, [edge.uuid for edge in edges_to_delete])
1261 |         await Node.delete_by_uuids(self.driver, [node.uuid for node in nodes_to_delete])
1262 | 
1263 |         await episode.delete(self.driver)
1264 | 
```

--------------------------------------------------------------------------------
/tests/test_graphiti_mock.py:
--------------------------------------------------------------------------------

```python
   1 | """
   2 | Copyright 2024, Zep Software, Inc.
   3 | 
   4 | Licensed under the Apache License, Version 2.0 (the "License");
   5 | you may not use this file except in compliance with the License.
   6 | You may obtain a copy of the License at
   7 | 
   8 |     http://www.apache.org/licenses/LICENSE-2.0
   9 | 
  10 | Unless required by applicable law or agreed to in writing, software
  11 | distributed under the License is distributed on an "AS IS" BASIS,
  12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13 | See the License for the specific language governing permissions and
  14 | limitations under the License.
  15 | """
  16 | 
  17 | from datetime import datetime, timedelta
  18 | from unittest.mock import Mock
  19 | 
  20 | import numpy as np
  21 | import pytest
  22 | 
  23 | from graphiti_core.cross_encoder.client import CrossEncoderClient
  24 | from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
  25 | from graphiti_core.graphiti import Graphiti
  26 | from graphiti_core.llm_client import LLMClient
  27 | from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
  28 | from graphiti_core.search.search_filters import ComparisonOperator, DateFilter, SearchFilters
  29 | from graphiti_core.search.search_utils import (
  30 |     community_fulltext_search,
  31 |     community_similarity_search,
  32 |     edge_bfs_search,
  33 |     edge_fulltext_search,
  34 |     edge_similarity_search,
  35 |     episode_fulltext_search,
  36 |     episode_mentions_reranker,
  37 |     get_communities_by_nodes,
  38 |     get_edge_invalidation_candidates,
  39 |     get_embeddings_for_communities,
  40 |     get_embeddings_for_edges,
  41 |     get_embeddings_for_nodes,
  42 |     get_mentioned_nodes,
  43 |     get_relevant_edges,
  44 |     get_relevant_nodes,
  45 |     node_bfs_search,
  46 |     node_distance_reranker,
  47 |     node_fulltext_search,
  48 |     node_similarity_search,
  49 | )
  50 | from graphiti_core.utils.bulk_utils import add_nodes_and_edges_bulk
  51 | from graphiti_core.utils.maintenance.community_operations import (
  52 |     determine_entity_community,
  53 |     get_community_clusters,
  54 |     remove_communities,
  55 | )
  56 | from graphiti_core.utils.maintenance.edge_operations import filter_existing_duplicate_of_edges
  57 | from tests.helpers_test import (
  58 |     GraphProvider,
  59 |     assert_entity_edge_equals,
  60 |     assert_entity_node_equals,
  61 |     assert_episodic_edge_equals,
  62 |     assert_episodic_node_equals,
  63 |     get_edge_count,
  64 |     get_node_count,
  65 |     group_id,
  66 |     group_id_2,
  67 | )
  68 | 
  69 | pytest_plugins = ('pytest_asyncio',)
  70 | 
  71 | 
  72 | @pytest.fixture
  73 | def mock_llm_client():
  74 |     """Create a mock LLM"""
  75 |     mock_llm = Mock(spec=LLMClient)
  76 |     mock_llm.config = Mock()
  77 |     mock_llm.model = 'test-model'
  78 |     mock_llm.small_model = 'test-small-model'
  79 |     mock_llm.temperature = 0.0
  80 |     mock_llm.max_tokens = 1000
  81 |     mock_llm.cache_enabled = False
  82 |     mock_llm.cache_dir = None
  83 | 
  84 |     # Mock the public method that's actually called
  85 |     mock_llm.generate_response = Mock()
  86 |     mock_llm.generate_response.return_value = {
  87 |         'tool_calls': [
  88 |             {
  89 |                 'name': 'extract_entities',
  90 |                 'arguments': {'entities': [{'entity': 'test_entity', 'entity_type': 'test_type'}]},
  91 |             }
  92 |         ]
  93 |     }
  94 | 
  95 |     return mock_llm
  96 | 
  97 | 
  98 | @pytest.fixture
  99 | def mock_cross_encoder_client():
 100 |     """Create a mock LLM"""
 101 |     mock_llm = Mock(spec=CrossEncoderClient)
 102 |     mock_llm.config = Mock()
 103 | 
 104 |     # Mock the public method that's actually called
 105 |     mock_llm.rerank = Mock()
 106 |     mock_llm.rerank.return_value = {
 107 |         'tool_calls': [
 108 |             {
 109 |                 'name': 'extract_entities',
 110 |                 'arguments': {'entities': [{'entity': 'test_entity', 'entity_type': 'test_type'}]},
 111 |             }
 112 |         ]
 113 |     }
 114 | 
 115 |     return mock_llm
 116 | 
 117 | 
 118 | @pytest.mark.asyncio
 119 | async def test_add_bulk(graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client):
 120 |     if graph_driver.provider == GraphProvider.FALKORDB:
 121 |         pytest.skip('Skipping as test fails on FalkorDB')
 122 | 
 123 |     graphiti = Graphiti(
 124 |         graph_driver=graph_driver,
 125 |         llm_client=mock_llm_client,
 126 |         embedder=mock_embedder,
 127 |         cross_encoder=mock_cross_encoder_client,
 128 |     )
 129 | 
 130 |     await graphiti.build_indices_and_constraints()
 131 | 
 132 |     now = datetime.now()
 133 | 
 134 |     # Create episodic nodes
 135 |     episode_node_1 = EpisodicNode(
 136 |         name='test_episode',
 137 |         group_id=group_id,
 138 |         labels=[],
 139 |         created_at=now,
 140 |         source=EpisodeType.message,
 141 |         source_description='conversation message',
 142 |         content='Alice likes Bob',
 143 |         valid_at=now,
 144 |         entity_edges=[],  # Filled in later
 145 |     )
 146 |     episode_node_2 = EpisodicNode(
 147 |         name='test_episode_2',
 148 |         group_id=group_id,
 149 |         labels=[],
 150 |         created_at=now,
 151 |         source=EpisodeType.message,
 152 |         source_description='conversation message',
 153 |         content='Bob adores Alice',
 154 |         valid_at=now,
 155 |         entity_edges=[],  # Filled in later
 156 |     )
 157 | 
 158 |     # Create entity nodes
 159 |     entity_node_1 = EntityNode(
 160 |         name='test_entity_1',
 161 |         group_id=group_id,
 162 |         labels=['Entity', 'Person'],
 163 |         created_at=now,
 164 |         summary='test_entity_1 summary',
 165 |         attributes={'age': 30, 'location': 'New York'},
 166 |     )
 167 |     await entity_node_1.generate_name_embedding(mock_embedder)
 168 | 
 169 |     entity_node_2 = EntityNode(
 170 |         name='test_entity_2',
 171 |         group_id=group_id,
 172 |         labels=['Entity', 'Person2'],
 173 |         created_at=now,
 174 |         summary='test_entity_2 summary',
 175 |         attributes={'age': 25, 'location': 'Los Angeles'},
 176 |     )
 177 |     await entity_node_2.generate_name_embedding(mock_embedder)
 178 | 
 179 |     entity_node_3 = EntityNode(
 180 |         name='test_entity_3',
 181 |         group_id=group_id,
 182 |         labels=['Entity', 'City', 'Location'],
 183 |         created_at=now,
 184 |         summary='test_entity_3 summary',
 185 |         attributes={'age': 25, 'location': 'Los Angeles'},
 186 |     )
 187 |     await entity_node_3.generate_name_embedding(mock_embedder)
 188 | 
 189 |     entity_node_4 = EntityNode(
 190 |         name='test_entity_4',
 191 |         group_id=group_id,
 192 |         labels=['Entity'],
 193 |         created_at=now,
 194 |         summary='test_entity_4 summary',
 195 |         attributes={'age': 25, 'location': 'Los Angeles'},
 196 |     )
 197 |     await entity_node_4.generate_name_embedding(mock_embedder)
 198 | 
 199 |     # Create entity edges
 200 |     entity_edge_1 = EntityEdge(
 201 |         source_node_uuid=entity_node_1.uuid,
 202 |         target_node_uuid=entity_node_2.uuid,
 203 |         created_at=now,
 204 |         name='likes',
 205 |         fact='test_entity_1 relates to test_entity_2',
 206 |         episodes=[],
 207 |         expired_at=now,
 208 |         valid_at=now,
 209 |         invalid_at=now,
 210 |         group_id=group_id,
 211 |     )
 212 |     await entity_edge_1.generate_embedding(mock_embedder)
 213 | 
 214 |     entity_edge_2 = EntityEdge(
 215 |         source_node_uuid=entity_node_3.uuid,
 216 |         target_node_uuid=entity_node_4.uuid,
 217 |         created_at=now,
 218 |         name='relates_to',
 219 |         fact='test_entity_3 relates to test_entity_4',
 220 |         episodes=[],
 221 |         expired_at=now,
 222 |         valid_at=now,
 223 |         invalid_at=now,
 224 |         group_id=group_id,
 225 |     )
 226 |     await entity_edge_2.generate_embedding(mock_embedder)
 227 | 
 228 |     # Create episodic to entity edges
 229 |     episodic_edge_1 = EpisodicEdge(
 230 |         source_node_uuid=episode_node_1.uuid,
 231 |         target_node_uuid=entity_node_1.uuid,
 232 |         created_at=now,
 233 |         group_id=group_id,
 234 |     )
 235 |     episodic_edge_2 = EpisodicEdge(
 236 |         source_node_uuid=episode_node_1.uuid,
 237 |         target_node_uuid=entity_node_2.uuid,
 238 |         created_at=now,
 239 |         group_id=group_id,
 240 |     )
 241 |     episodic_edge_3 = EpisodicEdge(
 242 |         source_node_uuid=episode_node_2.uuid,
 243 |         target_node_uuid=entity_node_3.uuid,
 244 |         created_at=now,
 245 |         group_id=group_id,
 246 |     )
 247 |     episodic_edge_4 = EpisodicEdge(
 248 |         source_node_uuid=episode_node_2.uuid,
 249 |         target_node_uuid=entity_node_4.uuid,
 250 |         created_at=now,
 251 |         group_id=group_id,
 252 |     )
 253 | 
 254 |     # Cross reference the ids
 255 |     episode_node_1.entity_edges = [entity_edge_1.uuid]
 256 |     episode_node_2.entity_edges = [entity_edge_2.uuid]
 257 |     entity_edge_1.episodes = [episode_node_1.uuid, episode_node_2.uuid]
 258 |     entity_edge_2.episodes = [episode_node_2.uuid]
 259 | 
 260 |     # Test add bulk
 261 |     await add_nodes_and_edges_bulk(
 262 |         graph_driver,
 263 |         [episode_node_1, episode_node_2],
 264 |         [episodic_edge_1, episodic_edge_2, episodic_edge_3, episodic_edge_4],
 265 |         [entity_node_1, entity_node_2, entity_node_3, entity_node_4],
 266 |         [entity_edge_1, entity_edge_2],
 267 |         mock_embedder,
 268 |     )
 269 | 
 270 |     node_ids = [
 271 |         episode_node_1.uuid,
 272 |         episode_node_2.uuid,
 273 |         entity_node_1.uuid,
 274 |         entity_node_2.uuid,
 275 |         entity_node_3.uuid,
 276 |         entity_node_4.uuid,
 277 |     ]
 278 |     edge_ids = [
 279 |         episodic_edge_1.uuid,
 280 |         episodic_edge_2.uuid,
 281 |         episodic_edge_3.uuid,
 282 |         episodic_edge_4.uuid,
 283 |         entity_edge_1.uuid,
 284 |         entity_edge_2.uuid,
 285 |     ]
 286 |     node_count = await get_node_count(graph_driver, node_ids)
 287 |     assert node_count == len(node_ids)
 288 |     edge_count = await get_edge_count(graph_driver, edge_ids)
 289 |     assert edge_count == len(edge_ids)
 290 | 
 291 |     # Test episodic nodes
 292 |     retrieved_episode = await EpisodicNode.get_by_uuid(graph_driver, episode_node_1.uuid)
 293 |     await assert_episodic_node_equals(retrieved_episode, episode_node_1)
 294 | 
 295 |     retrieved_episode = await EpisodicNode.get_by_uuid(graph_driver, episode_node_2.uuid)
 296 |     await assert_episodic_node_equals(retrieved_episode, episode_node_2)
 297 | 
 298 |     # Test entity nodes
 299 |     retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_1.uuid)
 300 |     await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_1)
 301 | 
 302 |     retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_2.uuid)
 303 |     await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_2)
 304 | 
 305 |     retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_3.uuid)
 306 |     await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_3)
 307 | 
 308 |     retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_4.uuid)
 309 |     await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_4)
 310 | 
 311 |     # Test episodic edges
 312 |     retrieved_episode_edge = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge_1.uuid)
 313 |     await assert_episodic_edge_equals(retrieved_episode_edge, episodic_edge_1)
 314 | 
 315 |     retrieved_episode_edge = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge_2.uuid)
 316 |     await assert_episodic_edge_equals(retrieved_episode_edge, episodic_edge_2)
 317 | 
 318 |     retrieved_episode_edge = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge_3.uuid)
 319 |     await assert_episodic_edge_equals(retrieved_episode_edge, episodic_edge_3)
 320 | 
 321 |     retrieved_episode_edge = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge_4.uuid)
 322 |     await assert_episodic_edge_equals(retrieved_episode_edge, episodic_edge_4)
 323 | 
 324 |     # Test entity edges
 325 |     retrieved_entity_edge = await EntityEdge.get_by_uuid(graph_driver, entity_edge_1.uuid)
 326 |     await assert_entity_edge_equals(graph_driver, retrieved_entity_edge, entity_edge_1)
 327 | 
 328 |     retrieved_entity_edge = await EntityEdge.get_by_uuid(graph_driver, entity_edge_2.uuid)
 329 |     await assert_entity_edge_equals(graph_driver, retrieved_entity_edge, entity_edge_2)
 330 | 
 331 | 
 332 | @pytest.mark.asyncio
 333 | async def test_remove_episode(
 334 |     graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client
 335 | ):
 336 |     graphiti = Graphiti(
 337 |         graph_driver=graph_driver,
 338 |         llm_client=mock_llm_client,
 339 |         embedder=mock_embedder,
 340 |         cross_encoder=mock_cross_encoder_client,
 341 |     )
 342 | 
 343 |     await graphiti.build_indices_and_constraints()
 344 | 
 345 |     now = datetime.now()
 346 | 
 347 |     # Create episodic nodes
 348 |     episode_node = EpisodicNode(
 349 |         name='test_episode',
 350 |         group_id=group_id,
 351 |         labels=[],
 352 |         created_at=now,
 353 |         source=EpisodeType.message,
 354 |         source_description='conversation message',
 355 |         content='Alice likes Bob',
 356 |         valid_at=now,
 357 |         entity_edges=[],  # Filled in later
 358 |     )
 359 | 
 360 |     # Create entity nodes
 361 |     alice_node = EntityNode(
 362 |         name='Alice',
 363 |         group_id=group_id,
 364 |         labels=['Entity', 'Person'],
 365 |         created_at=now,
 366 |         summary='Alice summary',
 367 |         attributes={'age': 30, 'location': 'New York'},
 368 |     )
 369 |     await alice_node.generate_name_embedding(mock_embedder)
 370 | 
 371 |     bob_node = EntityNode(
 372 |         name='Bob',
 373 |         group_id=group_id,
 374 |         labels=['Entity', 'Person2'],
 375 |         created_at=now,
 376 |         summary='Bob summary',
 377 |         attributes={'age': 25, 'location': 'Los Angeles'},
 378 |     )
 379 |     await bob_node.generate_name_embedding(mock_embedder)
 380 | 
 381 |     # Create entity to entity edge
 382 |     entity_edge = EntityEdge(
 383 |         source_node_uuid=alice_node.uuid,
 384 |         target_node_uuid=bob_node.uuid,
 385 |         created_at=now,
 386 |         name='likes',
 387 |         fact='Alice likes Bob',
 388 |         episodes=[],
 389 |         expired_at=now,
 390 |         valid_at=now,
 391 |         invalid_at=now,
 392 |         group_id=group_id,
 393 |     )
 394 |     await entity_edge.generate_embedding(mock_embedder)
 395 | 
 396 |     # Create episodic to entity edges
 397 |     episodic_alice_edge = EpisodicEdge(
 398 |         source_node_uuid=episode_node.uuid,
 399 |         target_node_uuid=alice_node.uuid,
 400 |         created_at=now,
 401 |         group_id=group_id,
 402 |     )
 403 |     episodic_bob_edge = EpisodicEdge(
 404 |         source_node_uuid=episode_node.uuid,
 405 |         target_node_uuid=bob_node.uuid,
 406 |         created_at=now,
 407 |         group_id=group_id,
 408 |     )
 409 | 
 410 |     # Cross reference the ids
 411 |     episode_node.entity_edges = [entity_edge.uuid]
 412 |     entity_edge.episodes = [episode_node.uuid]
 413 | 
 414 |     # Test add bulk
 415 |     await add_nodes_and_edges_bulk(
 416 |         graph_driver,
 417 |         [episode_node],
 418 |         [episodic_alice_edge, episodic_bob_edge],
 419 |         [alice_node, bob_node],
 420 |         [entity_edge],
 421 |         mock_embedder,
 422 |     )
 423 | 
 424 |     node_ids = [episode_node.uuid, alice_node.uuid, bob_node.uuid]
 425 |     edge_ids = [episodic_alice_edge.uuid, episodic_bob_edge.uuid, entity_edge.uuid]
 426 |     node_count = await get_node_count(graph_driver, node_ids)
 427 |     assert node_count == 3
 428 |     edge_count = await get_edge_count(graph_driver, edge_ids)
 429 |     assert edge_count == 3
 430 | 
 431 |     # Test remove episode
 432 |     await graphiti.remove_episode(episode_node.uuid)
 433 |     node_count = await get_node_count(graph_driver, node_ids)
 434 |     assert node_count == 0
 435 |     edge_count = await get_edge_count(graph_driver, edge_ids)
 436 |     assert edge_count == 0
 437 | 
 438 |     # Test add bulk again
 439 |     await add_nodes_and_edges_bulk(
 440 |         graph_driver,
 441 |         [episode_node],
 442 |         [episodic_alice_edge, episodic_bob_edge],
 443 |         [alice_node, bob_node],
 444 |         [entity_edge],
 445 |         mock_embedder,
 446 |     )
 447 |     node_count = await get_node_count(graph_driver, node_ids)
 448 |     assert node_count == 3
 449 |     edge_count = await get_edge_count(graph_driver, edge_ids)
 450 |     assert edge_count == 3
 451 | 
 452 | 
 453 | @pytest.mark.asyncio
 454 | async def test_graphiti_retrieve_episodes(
 455 |     graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client
 456 | ):
 457 |     if graph_driver.provider == GraphProvider.FALKORDB:
 458 |         pytest.skip('Skipping as test fails on FalkorDB')
 459 | 
 460 |     graphiti = Graphiti(
 461 |         graph_driver=graph_driver,
 462 |         llm_client=mock_llm_client,
 463 |         embedder=mock_embedder,
 464 |         cross_encoder=mock_cross_encoder_client,
 465 |     )
 466 | 
 467 |     await graphiti.build_indices_and_constraints()
 468 | 
 469 |     now = datetime.now()
 470 |     valid_at_1 = now - timedelta(days=2)
 471 |     valid_at_2 = now - timedelta(days=4)
 472 |     valid_at_3 = now - timedelta(days=6)
 473 | 
 474 |     # Create episodic nodes
 475 |     episode_node_1 = EpisodicNode(
 476 |         name='test_episode_1',
 477 |         labels=[],
 478 |         created_at=now,
 479 |         valid_at=valid_at_1,
 480 |         source=EpisodeType.message,
 481 |         source_description='conversation message',
 482 |         content='Test message 1',
 483 |         entity_edges=[],
 484 |         group_id=group_id,
 485 |     )
 486 |     episode_node_2 = EpisodicNode(
 487 |         name='test_episode_2',
 488 |         labels=[],
 489 |         created_at=now,
 490 |         valid_at=valid_at_2,
 491 |         source=EpisodeType.message,
 492 |         source_description='conversation message',
 493 |         content='Test message 2',
 494 |         entity_edges=[],
 495 |         group_id=group_id,
 496 |     )
 497 |     episode_node_3 = EpisodicNode(
 498 |         name='test_episode_3',
 499 |         labels=[],
 500 |         created_at=now,
 501 |         valid_at=valid_at_3,
 502 |         source=EpisodeType.message,
 503 |         source_description='conversation message',
 504 |         content='Test message 3',
 505 |         entity_edges=[],
 506 |         group_id=group_id,
 507 |     )
 508 | 
 509 |     # Save the nodes
 510 |     await episode_node_1.save(graph_driver)
 511 |     await episode_node_2.save(graph_driver)
 512 |     await episode_node_3.save(graph_driver)
 513 | 
 514 |     node_ids = [episode_node_1.uuid, episode_node_2.uuid, episode_node_3.uuid]
 515 |     node_count = await get_node_count(graph_driver, node_ids)
 516 |     assert node_count == 3
 517 | 
 518 |     # Retrieve episodes
 519 |     query_time = now - timedelta(days=3)
 520 |     episodes = await graphiti.retrieve_episodes(
 521 |         query_time, last_n=5, group_ids=[group_id], source=EpisodeType.message
 522 |     )
 523 |     assert len(episodes) == 2
 524 |     assert episodes[0].name == episode_node_3.name
 525 |     assert episodes[1].name == episode_node_2.name
 526 | 
 527 | 
 528 | @pytest.mark.asyncio
 529 | async def test_filter_existing_duplicate_of_edges(graph_driver, mock_embedder):
 530 |     # Create entity nodes
 531 |     entity_node_1 = EntityNode(
 532 |         name='test_entity_1',
 533 |         labels=[],
 534 |         created_at=datetime.now(),
 535 |         group_id=group_id,
 536 |     )
 537 |     await entity_node_1.generate_name_embedding(mock_embedder)
 538 |     entity_node_2 = EntityNode(
 539 |         name='test_entity_2',
 540 |         labels=[],
 541 |         created_at=datetime.now(),
 542 |         group_id=group_id,
 543 |     )
 544 |     await entity_node_2.generate_name_embedding(mock_embedder)
 545 |     entity_node_3 = EntityNode(
 546 |         name='test_entity_3',
 547 |         labels=[],
 548 |         created_at=datetime.now(),
 549 |         group_id=group_id,
 550 |     )
 551 |     await entity_node_3.generate_name_embedding(mock_embedder)
 552 |     entity_node_4 = EntityNode(
 553 |         name='test_entity_4',
 554 |         labels=[],
 555 |         created_at=datetime.now(),
 556 |         group_id=group_id,
 557 |     )
 558 |     await entity_node_4.generate_name_embedding(mock_embedder)
 559 | 
 560 |     # Save the nodes
 561 |     await entity_node_1.save(graph_driver)
 562 |     await entity_node_2.save(graph_driver)
 563 |     await entity_node_3.save(graph_driver)
 564 |     await entity_node_4.save(graph_driver)
 565 | 
 566 |     node_ids = [entity_node_1.uuid, entity_node_2.uuid, entity_node_3.uuid, entity_node_4.uuid]
 567 |     node_count = await get_node_count(graph_driver, node_ids)
 568 |     assert node_count == 4
 569 | 
 570 |     # Create duplicate entity edge
 571 |     entity_edge = EntityEdge(
 572 |         source_node_uuid=entity_node_1.uuid,
 573 |         target_node_uuid=entity_node_2.uuid,
 574 |         name='IS_DUPLICATE_OF',
 575 |         fact='test_entity_1 is a duplicate of test_entity_2',
 576 |         created_at=datetime.now(),
 577 |         group_id=group_id,
 578 |     )
 579 |     await entity_edge.generate_embedding(mock_embedder)
 580 |     await entity_edge.save(graph_driver)
 581 | 
 582 |     # Filter duplicate entity edges
 583 |     duplicate_node_tuples = [
 584 |         (entity_node_1, entity_node_2),
 585 |         (entity_node_3, entity_node_4),
 586 |     ]
 587 |     node_tuples = await filter_existing_duplicate_of_edges(graph_driver, duplicate_node_tuples)
 588 |     assert len(node_tuples) == 1
 589 |     assert [node.name for node in node_tuples[0]] == [entity_node_3.name, entity_node_4.name]
 590 | 
 591 | 
 592 | @pytest.mark.asyncio
 593 | async def test_determine_entity_community(graph_driver, mock_embedder):
 594 |     if graph_driver.provider == GraphProvider.FALKORDB:
 595 |         pytest.skip('Skipping as test fails on FalkorDB')
 596 | 
 597 |     # Create entity nodes
 598 |     entity_node_1 = EntityNode(
 599 |         name='test_entity_1',
 600 |         labels=[],
 601 |         created_at=datetime.now(),
 602 |         group_id=group_id,
 603 |     )
 604 |     await entity_node_1.generate_name_embedding(mock_embedder)
 605 |     entity_node_2 = EntityNode(
 606 |         name='test_entity_2',
 607 |         labels=[],
 608 |         created_at=datetime.now(),
 609 |         group_id=group_id,
 610 |     )
 611 |     await entity_node_2.generate_name_embedding(mock_embedder)
 612 |     entity_node_3 = EntityNode(
 613 |         name='test_entity_3',
 614 |         labels=[],
 615 |         created_at=datetime.now(),
 616 |         group_id=group_id,
 617 |     )
 618 |     await entity_node_3.generate_name_embedding(mock_embedder)
 619 |     entity_node_4 = EntityNode(
 620 |         name='test_entity_4',
 621 |         labels=[],
 622 |         created_at=datetime.now(),
 623 |         group_id=group_id,
 624 |     )
 625 |     await entity_node_4.generate_name_embedding(mock_embedder)
 626 | 
 627 |     # Create entity edges
 628 |     entity_edge_1 = EntityEdge(
 629 |         source_node_uuid=entity_node_1.uuid,
 630 |         target_node_uuid=entity_node_4.uuid,
 631 |         name='RELATES_TO',
 632 |         fact='test_entity_1 relates to test_entity_4',
 633 |         created_at=datetime.now(),
 634 |         group_id=group_id,
 635 |     )
 636 |     await entity_edge_1.generate_embedding(mock_embedder)
 637 |     entity_edge_2 = EntityEdge(
 638 |         source_node_uuid=entity_node_2.uuid,
 639 |         target_node_uuid=entity_node_4.uuid,
 640 |         name='RELATES_TO',
 641 |         fact='test_entity_2 relates to test_entity_4',
 642 |         created_at=datetime.now(),
 643 |         group_id=group_id,
 644 |     )
 645 |     await entity_edge_2.generate_embedding(mock_embedder)
 646 |     entity_edge_3 = EntityEdge(
 647 |         source_node_uuid=entity_node_3.uuid,
 648 |         target_node_uuid=entity_node_4.uuid,
 649 |         name='RELATES_TO',
 650 |         fact='test_entity_3 relates to test_entity_4',
 651 |         created_at=datetime.now(),
 652 |         group_id=group_id,
 653 |     )
 654 |     await entity_edge_3.generate_embedding(mock_embedder)
 655 | 
 656 |     # Create community nodes
 657 |     community_node_1 = CommunityNode(
 658 |         name='test_community_1',
 659 |         labels=[],
 660 |         created_at=datetime.now(),
 661 |         group_id=group_id,
 662 |     )
 663 |     await community_node_1.generate_name_embedding(mock_embedder)
 664 |     community_node_2 = CommunityNode(
 665 |         name='test_community_2',
 666 |         labels=[],
 667 |         created_at=datetime.now(),
 668 |         group_id=group_id,
 669 |     )
 670 |     await community_node_2.generate_name_embedding(mock_embedder)
 671 | 
 672 |     # Create community to entity edges
 673 |     community_edge_1 = CommunityEdge(
 674 |         source_node_uuid=community_node_1.uuid,
 675 |         target_node_uuid=entity_node_1.uuid,
 676 |         created_at=datetime.now(),
 677 |         group_id=group_id,
 678 |     )
 679 |     community_edge_2 = CommunityEdge(
 680 |         source_node_uuid=community_node_1.uuid,
 681 |         target_node_uuid=entity_node_2.uuid,
 682 |         created_at=datetime.now(),
 683 |         group_id=group_id,
 684 |     )
 685 |     community_edge_3 = CommunityEdge(
 686 |         source_node_uuid=community_node_2.uuid,
 687 |         target_node_uuid=entity_node_3.uuid,
 688 |         created_at=datetime.now(),
 689 |         group_id=group_id,
 690 |     )
 691 | 
 692 |     # Save the graph
 693 |     await entity_node_1.save(graph_driver)
 694 |     await entity_node_2.save(graph_driver)
 695 |     await entity_node_3.save(graph_driver)
 696 |     await entity_node_4.save(graph_driver)
 697 |     await community_node_1.save(graph_driver)
 698 |     await community_node_2.save(graph_driver)
 699 | 
 700 |     await entity_edge_1.save(graph_driver)
 701 |     await entity_edge_2.save(graph_driver)
 702 |     await entity_edge_3.save(graph_driver)
 703 |     await community_edge_1.save(graph_driver)
 704 |     await community_edge_2.save(graph_driver)
 705 |     await community_edge_3.save(graph_driver)
 706 | 
 707 |     node_ids = [
 708 |         entity_node_1.uuid,
 709 |         entity_node_2.uuid,
 710 |         entity_node_3.uuid,
 711 |         entity_node_4.uuid,
 712 |         community_node_1.uuid,
 713 |         community_node_2.uuid,
 714 |     ]
 715 |     edge_ids = [
 716 |         entity_edge_1.uuid,
 717 |         entity_edge_2.uuid,
 718 |         entity_edge_3.uuid,
 719 |         community_edge_1.uuid,
 720 |         community_edge_2.uuid,
 721 |         community_edge_3.uuid,
 722 |     ]
 723 |     node_count = await get_node_count(graph_driver, node_ids)
 724 |     assert node_count == 6
 725 |     edge_count = await get_edge_count(graph_driver, edge_ids)
 726 |     assert edge_count == 6
 727 | 
 728 |     # Determine entity community
 729 |     community, is_new = await determine_entity_community(graph_driver, entity_node_4)
 730 |     assert community.name == community_node_1.name
 731 |     assert is_new
 732 | 
 733 |     # Add entity to community edge
 734 |     community_edge_4 = CommunityEdge(
 735 |         source_node_uuid=community_node_1.uuid,
 736 |         target_node_uuid=entity_node_4.uuid,
 737 |         created_at=datetime.now(),
 738 |         group_id=group_id,
 739 |     )
 740 |     await community_edge_4.save(graph_driver)
 741 | 
 742 |     # Determine entity community again
 743 |     community, is_new = await determine_entity_community(graph_driver, entity_node_4)
 744 |     assert community.name == community_node_1.name
 745 |     assert not is_new
 746 | 
 747 |     await remove_communities(graph_driver)
 748 |     node_count = await get_node_count(graph_driver, [community_node_1.uuid, community_node_2.uuid])
 749 |     assert node_count == 0
 750 | 
 751 | 
 752 | @pytest.mark.asyncio
 753 | async def test_get_community_clusters(graph_driver, mock_embedder):
 754 |     if graph_driver.provider == GraphProvider.FALKORDB:
 755 |         pytest.skip('Skipping as test fails on FalkorDB')
 756 | 
 757 |     # Create entity nodes
 758 |     entity_node_1 = EntityNode(
 759 |         name='test_entity_1',
 760 |         labels=[],
 761 |         created_at=datetime.now(),
 762 |         group_id=group_id,
 763 |     )
 764 |     await entity_node_1.generate_name_embedding(mock_embedder)
 765 |     entity_node_2 = EntityNode(
 766 |         name='test_entity_2',
 767 |         labels=[],
 768 |         created_at=datetime.now(),
 769 |         group_id=group_id,
 770 |     )
 771 |     await entity_node_2.generate_name_embedding(mock_embedder)
 772 |     entity_node_3 = EntityNode(
 773 |         name='test_entity_3',
 774 |         labels=[],
 775 |         created_at=datetime.now(),
 776 |         group_id=group_id_2,
 777 |     )
 778 |     await entity_node_3.generate_name_embedding(mock_embedder)
 779 |     entity_node_4 = EntityNode(
 780 |         name='test_entity_4',
 781 |         labels=[],
 782 |         created_at=datetime.now(),
 783 |         group_id=group_id_2,
 784 |     )
 785 |     await entity_node_4.generate_name_embedding(mock_embedder)
 786 | 
 787 |     # Create entity edges
 788 |     entity_edge_1 = EntityEdge(
 789 |         source_node_uuid=entity_node_1.uuid,
 790 |         target_node_uuid=entity_node_2.uuid,
 791 |         name='RELATES_TO',
 792 |         fact='test_entity_1 relates to test_entity_2',
 793 |         created_at=datetime.now(),
 794 |         group_id=group_id,
 795 |     )
 796 |     await entity_edge_1.generate_embedding(mock_embedder)
 797 |     entity_edge_2 = EntityEdge(
 798 |         source_node_uuid=entity_node_3.uuid,
 799 |         target_node_uuid=entity_node_4.uuid,
 800 |         name='RELATES_TO',
 801 |         fact='test_entity_3 relates to test_entity_4',
 802 |         created_at=datetime.now(),
 803 |         group_id=group_id_2,
 804 |     )
 805 |     await entity_edge_2.generate_embedding(mock_embedder)
 806 | 
 807 |     # Save the graph
 808 |     await entity_node_1.save(graph_driver)
 809 |     await entity_node_2.save(graph_driver)
 810 |     await entity_node_3.save(graph_driver)
 811 |     await entity_node_4.save(graph_driver)
 812 |     await entity_edge_1.save(graph_driver)
 813 |     await entity_edge_2.save(graph_driver)
 814 | 
 815 |     node_ids = [entity_node_1.uuid, entity_node_2.uuid, entity_node_3.uuid, entity_node_4.uuid]
 816 |     edge_ids = [entity_edge_1.uuid, entity_edge_2.uuid]
 817 |     node_count = await get_node_count(graph_driver, node_ids)
 818 |     assert node_count == 4
 819 |     edge_count = await get_edge_count(graph_driver, edge_ids)
 820 |     assert edge_count == 2
 821 | 
 822 |     # Get community clusters
 823 |     clusters = await get_community_clusters(graph_driver, group_ids=None)
 824 |     assert len(clusters) == 2
 825 |     assert len(clusters[0]) == 2
 826 |     assert len(clusters[1]) == 2
 827 |     entities_1 = set([node.name for node in clusters[0]])
 828 |     entities_2 = set([node.name for node in clusters[1]])
 829 |     assert entities_1 == set(['test_entity_1', 'test_entity_2']) or entities_2 == set(
 830 |         ['test_entity_1', 'test_entity_2']
 831 |     )
 832 |     assert entities_1 == set(['test_entity_3', 'test_entity_4']) or entities_2 == set(
 833 |         ['test_entity_3', 'test_entity_4']
 834 |     )
 835 | 
 836 | 
 837 | @pytest.mark.asyncio
 838 | async def test_get_mentioned_nodes(graph_driver, mock_embedder):
 839 |     # Create episodic nodes
 840 |     episodic_node_1 = EpisodicNode(
 841 |         name='test_episodic_1',
 842 |         labels=[],
 843 |         created_at=datetime.now(),
 844 |         group_id=group_id,
 845 |         source=EpisodeType.message,
 846 |         source_description='test_source_description',
 847 |         content='test_content',
 848 |         valid_at=datetime.now(),
 849 |     )
 850 |     # Create entity nodes
 851 |     entity_node_1 = EntityNode(
 852 |         name='test_entity_1',
 853 |         labels=[],
 854 |         created_at=datetime.now(),
 855 |         group_id=group_id,
 856 |     )
 857 |     await entity_node_1.generate_name_embedding(mock_embedder)
 858 | 
 859 |     # Create episodic to entity edges
 860 |     episodic_edge_1 = EpisodicEdge(
 861 |         source_node_uuid=episodic_node_1.uuid,
 862 |         target_node_uuid=entity_node_1.uuid,
 863 |         created_at=datetime.now(),
 864 |         group_id=group_id,
 865 |     )
 866 | 
 867 |     # Save the graph
 868 |     await episodic_node_1.save(graph_driver)
 869 |     await entity_node_1.save(graph_driver)
 870 |     await episodic_edge_1.save(graph_driver)
 871 | 
 872 |     # Get mentioned nodes
 873 |     mentioned_nodes = await get_mentioned_nodes(graph_driver, [episodic_node_1])
 874 |     assert len(mentioned_nodes) == 1
 875 |     assert mentioned_nodes[0].name == entity_node_1.name
 876 | 
 877 | 
 878 | @pytest.mark.asyncio
 879 | async def test_get_communities_by_nodes(graph_driver, mock_embedder):
 880 |     # Create entity nodes
 881 |     entity_node_1 = EntityNode(
 882 |         name='test_entity_1',
 883 |         labels=[],
 884 |         created_at=datetime.now(),
 885 |         group_id=group_id,
 886 |     )
 887 |     await entity_node_1.generate_name_embedding(mock_embedder)
 888 | 
 889 |     # Create community nodes
 890 |     community_node_1 = CommunityNode(
 891 |         name='test_community_1',
 892 |         labels=[],
 893 |         created_at=datetime.now(),
 894 |         group_id=group_id,
 895 |     )
 896 |     await community_node_1.generate_name_embedding(mock_embedder)
 897 | 
 898 |     # Create community to entity edges
 899 |     community_edge_1 = CommunityEdge(
 900 |         source_node_uuid=community_node_1.uuid,
 901 |         target_node_uuid=entity_node_1.uuid,
 902 |         created_at=datetime.now(),
 903 |         group_id=group_id,
 904 |     )
 905 | 
 906 |     # Save the graph
 907 |     await entity_node_1.save(graph_driver)
 908 |     await community_node_1.save(graph_driver)
 909 |     await community_edge_1.save(graph_driver)
 910 | 
 911 |     # Get communities by nodes
 912 |     communities = await get_communities_by_nodes(graph_driver, [entity_node_1])
 913 |     assert len(communities) == 1
 914 |     assert communities[0].name == community_node_1.name
 915 | 
 916 | 
 917 | @pytest.mark.asyncio
 918 | async def test_edge_fulltext_search(
 919 |     graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
 920 | ):
 921 |     if graph_driver.provider == GraphProvider.KUZU:
 922 |         pytest.skip('Skipping as fulltext indexing not supported for Kuzu')
 923 | 
 924 |     graphiti = Graphiti(
 925 |         graph_driver=graph_driver,
 926 |         llm_client=mock_llm_client,
 927 |         embedder=mock_embedder,
 928 |         cross_encoder=mock_cross_encoder_client,
 929 |     )
 930 |     await graphiti.build_indices_and_constraints()
 931 | 
 932 |     # Create entity nodes
 933 |     entity_node_1 = EntityNode(
 934 |         name='test_entity_1',
 935 |         labels=[],
 936 |         created_at=datetime.now(),
 937 |         group_id=group_id,
 938 |     )
 939 |     await entity_node_1.generate_name_embedding(mock_embedder)
 940 |     entity_node_2 = EntityNode(
 941 |         name='test_entity_2',
 942 |         labels=[],
 943 |         created_at=datetime.now(),
 944 |         group_id=group_id,
 945 |     )
 946 |     await entity_node_2.generate_name_embedding(mock_embedder)
 947 | 
 948 |     now = datetime.now()
 949 |     created_at = now
 950 |     expired_at = now + timedelta(days=6)
 951 |     valid_at = now + timedelta(days=2)
 952 |     invalid_at = now + timedelta(days=4)
 953 | 
 954 |     # Create entity edges
 955 |     entity_edge_1 = EntityEdge(
 956 |         source_node_uuid=entity_node_1.uuid,
 957 |         target_node_uuid=entity_node_2.uuid,
 958 |         name='RELATES_TO',
 959 |         fact='test_entity_1 relates to test_entity_2',
 960 |         created_at=created_at,
 961 |         valid_at=valid_at,
 962 |         invalid_at=invalid_at,
 963 |         expired_at=expired_at,
 964 |         group_id=group_id,
 965 |     )
 966 |     await entity_edge_1.generate_embedding(mock_embedder)
 967 | 
 968 |     # Save the graph
 969 |     await entity_node_1.save(graph_driver)
 970 |     await entity_node_2.save(graph_driver)
 971 |     await entity_edge_1.save(graph_driver)
 972 | 
 973 |     # Search for entity edges
 974 |     search_filters = SearchFilters(
 975 |         node_labels=['Entity'],
 976 |         edge_types=['RELATES_TO'],
 977 |         created_at=[
 978 |             [DateFilter(date=created_at, comparison_operator=ComparisonOperator.equals)],
 979 |         ],
 980 |         expired_at=[
 981 |             [DateFilter(date=now, comparison_operator=ComparisonOperator.not_equals)],
 982 |         ],
 983 |         valid_at=[
 984 |             [
 985 |                 DateFilter(
 986 |                     date=now + timedelta(days=1),
 987 |                     comparison_operator=ComparisonOperator.greater_than_equal,
 988 |                 )
 989 |             ],
 990 |             [
 991 |                 DateFilter(
 992 |                     date=now + timedelta(days=3),
 993 |                     comparison_operator=ComparisonOperator.less_than_equal,
 994 |                 )
 995 |             ],
 996 |         ],
 997 |         invalid_at=[
 998 |             [
 999 |                 DateFilter(
1000 |                     date=now + timedelta(days=3),
1001 |                     comparison_operator=ComparisonOperator.greater_than,
1002 |                 )
1003 |             ],
1004 |             [
1005 |                 DateFilter(
1006 |                     date=now + timedelta(days=5), comparison_operator=ComparisonOperator.less_than
1007 |                 )
1008 |             ],
1009 |         ],
1010 |     )
1011 |     edges = await edge_fulltext_search(
1012 |         graph_driver, 'test_entity_1 relates to test_entity_2', search_filters, group_ids=[group_id]
1013 |     )
1014 |     assert len(edges) == 1
1015 |     assert edges[0].name == entity_edge_1.name
1016 | 
1017 | 
@pytest.mark.asyncio
async def test_edge_similarity_search(graph_driver, mock_embedder):
    """Check that edge_similarity_search finds a stored edge by its own fact embedding.

    A single RELATES_TO edge with explicit created/valid/invalid/expired
    timestamps is stored, and the search is run between its two endpoint
    nodes with label, edge-type and date filters. Skipped on FalkorDB.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')

    # Endpoints of the edge under test.
    source = EntityNode(
        name='test_entity_1', labels=[], created_at=datetime.now(), group_id=group_id
    )
    await source.generate_name_embedding(mock_embedder)
    target = EntityNode(
        name='test_entity_2', labels=[], created_at=datetime.now(), group_id=group_id
    )
    await target.generate_name_embedding(mock_embedder)

    # All edge timestamps are offsets from this single reference instant.
    now = datetime.now()

    def date_group(offset_days, operator):
        # A one-element filter group comparing against `now` shifted by whole days.
        return [DateFilter(date=now + timedelta(days=offset_days), comparison_operator=operator)]

    relation = EntityEdge(
        source_node_uuid=source.uuid,
        target_node_uuid=target.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=now,
        valid_at=now + timedelta(days=2),
        invalid_at=now + timedelta(days=4),
        expired_at=now + timedelta(days=6),
        group_id=group_id,
    )
    await relation.generate_embedding(mock_embedder)

    # Persist both nodes before the edge that joins them.
    for record in (source, target, relation):
        await record.save(graph_driver)

    # Filters exercising every comparison operator against the stored timestamps.
    search_filters = SearchFilters(
        node_labels=['Entity'],
        edge_types=['RELATES_TO'],
        created_at=[date_group(0, ComparisonOperator.equals)],
        expired_at=[date_group(0, ComparisonOperator.not_equals)],
        valid_at=[
            date_group(1, ComparisonOperator.greater_than_equal),
            date_group(3, ComparisonOperator.less_than_equal),
        ],
        invalid_at=[
            date_group(3, ComparisonOperator.greater_than),
            date_group(5, ComparisonOperator.less_than),
        ],
    )

    # Query with the edge's own embedding, restricted to its two endpoints.
    matches = await edge_similarity_search(
        graph_driver,
        relation.fact_embedding,
        source.uuid,
        target.uuid,
        search_filters,
        group_ids=[group_id],
    )
    assert len(matches) == 1
    assert matches[0].name == relation.name
1112 | 
1113 | 
@pytest.mark.asyncio
async def test_edge_bfs_search(graph_driver, mock_embedder):
    """Check which edges edge_bfs_search returns at increasing depths.

    The stored graph is a chain: episode -> entity 1 -> entity 2 -> entity 3,
    where the first hop is an episodic edge and the other two are RELATES_TO
    edges. Searches start from the episode (depths 1, 2, 3) and from entity 1
    (depths 1, 2). Skipped on FalkorDB.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')

    fact_1_2 = 'test_entity_1 relates to test_entity_2'
    fact_2_3 = 'test_entity_2 relates to test_entity_3'

    # Starting episode of the chain.
    episode = EpisodicNode(
        name='test_episodic_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='test_source_description',
        content='test_content',
        valid_at=datetime.now(),
    )

    async def build_entity(name):
        # Create the node and give it a name embedding before it is persisted.
        node = EntityNode(name=name, labels=[], created_at=datetime.now(), group_id=group_id)
        await node.generate_name_embedding(mock_embedder)
        return node

    first = await build_entity('test_entity_1')
    second = await build_entity('test_entity_2')
    third = await build_entity('test_entity_3')

    # All edge timestamps are offsets from this single reference instant.
    now = datetime.now()

    def date_group(offset_days, operator):
        # A one-element filter group comparing against `now` shifted by whole days.
        return [DateFilter(date=now + timedelta(days=offset_days), comparison_operator=operator)]

    async def build_relation(source, target, fact):
        # Both entity edges share the same set of timestamps.
        relation = EntityEdge(
            source_node_uuid=source.uuid,
            target_node_uuid=target.uuid,
            name='RELATES_TO',
            fact=fact,
            created_at=now,
            valid_at=now + timedelta(days=2),
            invalid_at=now + timedelta(days=4),
            expired_at=now + timedelta(days=6),
            group_id=group_id,
        )
        await relation.generate_embedding(mock_embedder)
        return relation

    relation_1_2 = await build_relation(first, second, fact_1_2)
    relation_2_3 = await build_relation(second, third, fact_2_3)

    # Edge going from the episode to the first entity.
    mention = EpisodicEdge(
        source_node_uuid=episode.uuid,
        target_node_uuid=first.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )

    # Persist nodes first, then entity edges, then the episodic edge.
    for record in (episode, first, second, third, relation_1_2, relation_2_3, mention):
        await record.save(graph_driver)

    # Filters exercising every comparison operator against the stored timestamps.
    search_filters = SearchFilters(
        node_labels=['Entity'],
        edge_types=['RELATES_TO'],
        created_at=[date_group(0, ComparisonOperator.equals)],
        expired_at=[date_group(0, ComparisonOperator.not_equals)],
        valid_at=[
            date_group(1, ComparisonOperator.greater_than_equal),
            date_group(3, ComparisonOperator.less_than_equal),
        ],
        invalid_at=[
            date_group(3, ComparisonOperator.greater_than),
            date_group(5, ComparisonOperator.less_than),
        ],
    )

    async def search_from(origin_uuid, max_depth):
        # Return the raw hits plus the distinct facts after de-duplicating hits by uuid.
        hits = await edge_bfs_search(
            graph_driver,
            [origin_uuid],
            max_depth,
            search_filters,
            group_ids=[group_id],
        )
        return hits, set({hit.uuid: hit.fact for hit in hits}.values())

    # --- Starting at the episode ---

    # Depth 1 is expected to yield no edges at all.
    hits, _ = await search_from(episode.uuid, 1)
    assert len(hits) == 0

    _, facts = await search_from(episode.uuid, 2)
    assert len(facts) == 1
    assert facts == {fact_1_2}

    _, facts = await search_from(episode.uuid, 3)
    assert len(facts) == 2
    assert facts == {fact_1_2, fact_2_3}

    # --- Starting at the first entity ---

    _, facts = await search_from(first.uuid, 1)
    assert len(facts) == 1
    assert facts == {fact_1_2}

    _, facts = await search_from(first.uuid, 2)
    assert len(facts) == 2
    assert facts == {fact_1_2, fact_2_3}
1304 | 
1305 | 
@pytest.mark.asyncio
async def test_node_fulltext_search(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """Check that node_fulltext_search matches an entity through its summary text.

    Two entities are stored whose summaries mention 'Alice' and 'Bob'
    respectively; searching for 'Alice' must return only the first one.
    Skipped on Kuzu.
    """
    if graph_driver.provider == GraphProvider.KUZU:
        pytest.skip('Skipping as fulltext indexing not supported for Kuzu')

    # Indices are built through a Graphiti instance before any data is written.
    client = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await client.build_indices_and_constraints()

    async def build_entity(name, summary):
        # Create the node and give it a name embedding before it is persisted.
        node = EntityNode(
            name=name,
            summary=summary,
            labels=[],
            created_at=datetime.now(),
            group_id=group_id,
        )
        await node.generate_name_embedding(mock_embedder)
        return node

    alice = await build_entity('test_entity_1', 'Summary about Alice')
    bob = await build_entity('test_entity_2', 'Summary about Bob')

    for record in (alice, bob):
        await record.save(graph_driver)

    # Full-text query for entity nodes; only the summary contains the term.
    entity_only = SearchFilters(node_labels=['Entity'])
    matches = await node_fulltext_search(
        graph_driver,
        'Alice',
        entity_only,
        group_ids=[group_id],
    )
    assert len(matches) == 1
    assert matches[0].name == alice.name
1353 | 
1354 | 
@pytest.mark.asyncio
async def test_node_similarity_search(graph_driver, mock_embedder):
    """Check that node_similarity_search returns the node whose embedding is queried.

    Two entities are stored; searching with the first entity's name embedding
    and ``min_score=0.9`` must return only that entity. Skipped on FalkorDB.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')

    async def build_entity(name, summary):
        # Create the node and give it a name embedding before it is persisted.
        node = EntityNode(
            name=name,
            summary=summary,
            labels=[],
            created_at=datetime.now(),
            group_id=group_id,
        )
        await node.generate_name_embedding(mock_embedder)
        return node

    alice = await build_entity('test_entity_alice', 'Summary about Alice')
    bob = await build_entity('test_entity_bob', 'Summary about Bob')

    for record in (alice, bob):
        await record.save(graph_driver)

    # Similarity query for entity nodes, using the first node's own embedding.
    entity_only = SearchFilters(node_labels=['Entity'])
    matches = await node_similarity_search(
        graph_driver,
        alice.name_embedding,
        entity_only,
        group_ids=[group_id],
        min_score=0.9,
    )
    assert len(matches) == 1
    assert matches[0].name == alice.name
1393 | 
1394 | 
@pytest.mark.asyncio
async def test_node_bfs_search(graph_driver, mock_embedder):
    """Check which entity nodes node_bfs_search returns at increasing depths.

    The stored graph is a chain: episode -> entity 1 -> entity 2 -> entity 3,
    where the first hop is an episodic edge and the other two are RELATES_TO
    edges. Searches start from the episode (depths 1, 2) and from entity 1
    (depth 1). Skipped on FalkorDB.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')

    # Starting episode of the chain.
    episode = EpisodicNode(
        name='test_episodic_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='test_source_description',
        content='test_content',
        valid_at=datetime.now(),
    )

    async def build_entity(name):
        # Create the node and give it a name embedding before it is persisted.
        node = EntityNode(name=name, labels=[], created_at=datetime.now(), group_id=group_id)
        await node.generate_name_embedding(mock_embedder)
        return node

    async def build_relation(source, target, fact):
        # Create a RELATES_TO edge between two nodes and embed its fact.
        relation = EntityEdge(
            source_node_uuid=source.uuid,
            target_node_uuid=target.uuid,
            name='RELATES_TO',
            fact=fact,
            created_at=datetime.now(),
            group_id=group_id,
        )
        await relation.generate_embedding(mock_embedder)
        return relation

    first = await build_entity('test_entity_1')
    second = await build_entity('test_entity_2')
    third = await build_entity('test_entity_3')

    relation_1_2 = await build_relation(first, second, 'test_entity_1 relates to test_entity_2')
    relation_2_3 = await build_relation(second, third, 'test_entity_2 relates to test_entity_3')

    # Edge going from the episode to the first entity.
    mention = EpisodicEdge(
        source_node_uuid=episode.uuid,
        target_node_uuid=first.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )

    # Persist nodes first, then entity edges, then the episodic edge.
    for record in (episode, first, second, third, relation_1_2, relation_2_3, mention):
        await record.save(graph_driver)

    entity_only = SearchFilters(
        node_labels=['Entity'],
    )

    async def names_from(origin_uuid, max_depth):
        # Run the BFS and return the distinct node names after de-duplicating hits by uuid.
        hits = await node_bfs_search(
            graph_driver,
            [origin_uuid],
            entity_only,
            max_depth,
            group_ids=[group_id],
        )
        return set({hit.uuid: hit.name for hit in hits}.values())

    # --- Starting at the episode ---

    names = await names_from(episode.uuid, 1)
    assert len(names) == 1
    assert names == {'test_entity_1'}

    names = await names_from(episode.uuid, 2)
    assert len(names) == 2
    assert names == {'test_entity_1', 'test_entity_2'}

    # --- Starting at the first entity ---

    # The origin entity itself is not expected in the result, only its neighbour.
    names = await names_from(first.uuid, 1)
    assert len(names) == 1
    assert names == {'test_entity_2'}
1513 | 
1514 | 
@pytest.mark.asyncio
async def test_episode_fulltext_search(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """Check that episode_fulltext_search matches an episode by its source description.

    Two episodes are stored whose source descriptions mention 'Alice' and
    'Bob' respectively; searching for 'Alice' must return only the first one.
    Skipped on Kuzu.
    """
    if graph_driver.provider == GraphProvider.KUZU:
        pytest.skip('Skipping as fulltext indexing not supported for Kuzu')

    # Indices are built through a Graphiti instance before any data is written.
    client = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await client.build_indices_and_constraints()

    def build_episode(name, content, description):
        # Message-type episode; only name, content and description vary between the two.
        return EpisodicNode(
            name=name,
            content=content,
            created_at=datetime.now(),
            valid_at=datetime.now(),
            group_id=group_id,
            source=EpisodeType.message,
            source_description=description,
        )

    alice_episode = build_episode('test_episodic_1', 'test_content', 'Description about Alice')
    bob_episode = build_episode('test_episodic_2', 'test_content_2', 'Description about Bob')

    for record in (alice_episode, bob_episode):
        await record.save(graph_driver)

    # Full-text query restricted to episodic nodes.
    episodic_only = SearchFilters(node_labels=['Episodic'])
    matches = await episode_fulltext_search(
        graph_driver,
        'Alice',
        episodic_only,
        group_ids=[group_id],
    )
    assert len(matches) == 1
    assert matches[0].name == alice_episode.name
1564 | 
1565 | 
@pytest.mark.asyncio
async def test_community_fulltext_search(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """Check that community_fulltext_search finds a community by its name.

    Two communities named 'Alice' and 'Bob' are stored; searching for 'Alice'
    must return only the first one. Skipped on Kuzu.
    """
    if graph_driver.provider == GraphProvider.KUZU:
        pytest.skip('Skipping as fulltext indexing not supported for Kuzu')

    # Indices are built through a Graphiti instance before any data is written.
    client = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await client.build_indices_and_constraints()

    async def build_community(name):
        # Create the community and give it a name embedding before it is persisted.
        community = CommunityNode(name=name, created_at=datetime.now(), group_id=group_id)
        await community.generate_name_embedding(mock_embedder)
        return community

    alice = await build_community('Alice')
    bob = await build_community('Bob')

    for record in (alice, bob):
        await record.save(graph_driver)

    # Full-text query by community name.
    matches = await community_fulltext_search(
        graph_driver,
        'Alice',
        group_ids=[group_id],
    )
    assert len(matches) == 1
    assert matches[0].name == alice.name
1607 | 
1608 | 
@pytest.mark.asyncio
async def test_community_similarity_search(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """Check that community_similarity_search returns the community whose embedding is queried.

    Two communities named 'Alice' and 'Bob' are stored; searching with the
    first one's name embedding and ``min_score=0.9`` must return only that
    community. Skipped on FalkorDB.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')

    # Indices are built through a Graphiti instance before any data is written.
    client = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await client.build_indices_and_constraints()

    async def build_community(name):
        # Create the community and give it a name embedding before it is persisted.
        community = CommunityNode(name=name, created_at=datetime.now(), group_id=group_id)
        await community.generate_name_embedding(mock_embedder)
        return community

    alice = await build_community('Alice')
    bob = await build_community('Bob')

    for record in (alice, bob):
        await record.save(graph_driver)

    # Similarity query using the first community's own embedding.
    matches = await community_similarity_search(
        graph_driver,
        alice.name_embedding,
        group_ids=[group_id],
        min_score=0.9,
    )
    assert len(matches) == 1
    assert matches[0].name == alice.name
1651 | 
1652 | 
@pytest.mark.asyncio
async def test_get_relevant_nodes(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """Check that ``get_relevant_nodes`` returns the nodes related to a seed entity.

    Three entities ('Alice', 'Bob', 'Alice Smith') are embedded and saved. Seeding
    the lookup with 'Alice' at ``min_score=0.9`` is expected to return 'Alice' and
    'Alice Smith' but not 'Bob'. Skipped on FalkorDB and Kuzu.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')

    if graph_driver.provider == GraphProvider.KUZU:
        pytest.skip('Skipping as tests fail on Kuzu')

    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()

    # Create entity nodes
    entity_node_1 = EntityNode(
        name='Alice',
        summary='Alice',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='Bob',
        summary='Bob',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    entity_node_3 = EntityNode(
        name='Alice Smith',
        summary='Alice Smith',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_3.generate_name_embedding(mock_embedder)

    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_node_3.save(graph_driver)

    # Look up nodes relevant to the first entity; one result list is returned per
    # seed node, so index 0 holds the matches for entity_node_1.
    search_filters = SearchFilters(node_labels=['Entity'])
    nodes = (
        await get_relevant_nodes(
            graph_driver,
            [entity_node_1],
            search_filters,
            min_score=0.9,
        )
    )[0]
    assert len(nodes) == 2
    # A set comprehension already yields a set; no extra set() wrapper is needed.
    assert {node.name for node in nodes} == {entity_node_1.name, entity_node_3.name}
1714 | 
1715 | 
@pytest.mark.asyncio
async def test_get_relevant_edges_and_invalidation_candidates(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """Check relevant-edge and invalidation-candidate lookups under date filters.

    Three entities are linked by three ``RELATES_TO`` edges that share identical
    timestamps; edges 1 and 3 carry the fact 'Alice' and edge 2 carries 'Bob'.
    Both lookups are seeded with edge 1 and use ``min_score=0.9``. The test expects
    one relevant edge and two invalidation candidates. Skipped on FalkorDB.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')

    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()

    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        summary='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        summary='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    entity_node_3 = EntityNode(
        name='test_entity_3',
        summary='test_entity_3',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_3.generate_name_embedding(mock_embedder)

    # Shared timeline for every edge: created now, valid at +2d, invalid at +4d,
    # expired at +6d. The filter bounds below are chosen around these offsets.
    now = datetime.now()
    created_at = now
    expired_at = now + timedelta(days=6)
    valid_at = now + timedelta(days=2)
    invalid_at = now + timedelta(days=4)

    # Create entity edges
    entity_edge_1 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        name='RELATES_TO',
        fact='Alice',
        created_at=created_at,
        expired_at=expired_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        group_id=group_id,
    )
    await entity_edge_1.generate_embedding(mock_embedder)
    entity_edge_2 = EntityEdge(
        source_node_uuid=entity_node_2.uuid,
        target_node_uuid=entity_node_3.uuid,
        name='RELATES_TO',
        fact='Bob',
        created_at=created_at,
        expired_at=expired_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        group_id=group_id,
    )
    await entity_edge_2.generate_embedding(mock_embedder)
    entity_edge_3 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_3.uuid,
        name='RELATES_TO',
        fact='Alice',
        created_at=created_at,
        expired_at=expired_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        group_id=group_id,
    )
    await entity_edge_3.generate_embedding(mock_embedder)

    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_node_3.save(graph_driver)
    await entity_edge_1.save(graph_driver)
    await entity_edge_2.save(graph_driver)
    await entity_edge_3.save(graph_driver)

    # Build edge filters that exercise every comparison operator.
    # NOTE(review): each date filter is a list of single-item lists; how the outer
    # and inner lists combine (OR vs AND) is defined by SearchFilters — confirm there.
    search_filters = SearchFilters(
        node_labels=['Entity'],
        edge_types=['RELATES_TO'],
        created_at=[
            [DateFilter(date=created_at, comparison_operator=ComparisonOperator.equals)],
        ],
        expired_at=[
            [DateFilter(date=now, comparison_operator=ComparisonOperator.not_equals)],
        ],
        valid_at=[
            [
                DateFilter(
                    date=now + timedelta(days=1),
                    comparison_operator=ComparisonOperator.greater_than_equal,
                )
            ],
            [
                DateFilter(
                    date=now + timedelta(days=3),
                    comparison_operator=ComparisonOperator.less_than_equal,
                )
            ],
        ],
        invalid_at=[
            [
                DateFilter(
                    date=now + timedelta(days=3),
                    comparison_operator=ComparisonOperator.greater_than,
                )
            ],
            [
                DateFilter(
                    date=now + timedelta(days=5), comparison_operator=ComparisonOperator.less_than
                )
            ],
        ],
    )

    # Search for edges relevant to edge 1; index 0 holds the results for that seed.
    # Every edge in this test is named 'RELATES_TO', so the name sets asserted below
    # collapse to a single element; the length assertions carry the count check.
    edges = (
        await get_relevant_edges(
            graph_driver,
            [entity_edge_1],
            search_filters,
            min_score=0.9,
        )
    )[0]
    assert len(edges) == 1
    assert {edge.name for edge in edges} == {entity_edge_1.name}

    # Search for invalidation candidates of edge 1 with the same filters.
    edges = (
        await get_edge_invalidation_candidates(
            graph_driver,
            [entity_edge_1],
            search_filters,
            min_score=0.9,
        )
    )[0]
    assert len(edges) == 2
    assert {edge.name for edge in edges} == {entity_edge_1.name, entity_edge_3.name}
1868 | 
1869 | 
@pytest.mark.asyncio
async def test_node_distance_reranker(graph_driver, mock_embedder):
    """Check that ``node_distance_reranker`` ranks a connected node above an isolated one.

    Entity 1 is linked to entity 2 by a single edge; entity 3 has no edges.
    Reranking entities 2 and 3 around entity 1 must keep that order, with scores
    close to ``[1.0, 0.0]``. Skipped on FalkorDB.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')

    # Build and embed three entities
    center = EntityNode(
        name='test_entity_1', labels=[], created_at=datetime.now(), group_id=group_id
    )
    await center.generate_name_embedding(mock_embedder)
    neighbor = EntityNode(
        name='test_entity_2', labels=[], created_at=datetime.now(), group_id=group_id
    )
    await neighbor.generate_name_embedding(mock_embedder)
    isolated = EntityNode(
        name='test_entity_3', labels=[], created_at=datetime.now(), group_id=group_id
    )
    await isolated.generate_name_embedding(mock_embedder)

    # The only edge connects the center node to its neighbor
    link = EntityEdge(
        source_node_uuid=center.uuid,
        target_node_uuid=neighbor.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await link.generate_embedding(mock_embedder)

    # Persist nodes first, then the edge
    for item in (center, neighbor, isolated, link):
        await item.save(graph_driver)

    # Rerank the two candidates around the center node
    candidate_uuids = [neighbor.uuid, isolated.uuid]
    reranked_uuids, reranked_scores = await node_distance_reranker(
        graph_driver,
        candidate_uuids,
        center.uuid,
    )

    name_by_uuid = {node.uuid: node.name for node in (center, neighbor, isolated)}
    ordered_names = [name_by_uuid[uuid] for uuid in reranked_uuids]
    assert ordered_names == [neighbor.name, isolated.name]
    assert np.allclose(reranked_scores, [1.0, 0.0])
1929 | 
1930 | 
@pytest.mark.asyncio
async def test_episode_mentions_reranker(graph_driver, mock_embedder):
    """Check that ``episode_mentions_reranker`` ranks a mentioned entity first.

    One episode is linked to entity 1 by an episodic edge; entity 2 has no such
    link. Reranking both must return entity 1 then entity 2, with scores close
    to ``[1.0, inf]``. Skipped on FalkorDB.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')

    # Build the episode
    episode = EpisodicNode(
        name='test_episodic_1',
        content='test_content',
        created_at=datetime.now(),
        valid_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='Description about Alice',
    )

    # Build and embed the two entities
    mentioned = EntityNode(
        name='test_entity_1', labels=[], created_at=datetime.now(), group_id=group_id
    )
    await mentioned.generate_name_embedding(mock_embedder)
    unmentioned = EntityNode(
        name='test_entity_2', labels=[], created_at=datetime.now(), group_id=group_id
    )
    await unmentioned.generate_name_embedding(mock_embedder)

    # Episodic edge from the episode to the first entity only
    mention = EpisodicEdge(
        source_node_uuid=episode.uuid,
        target_node_uuid=mentioned.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )

    # Persist entities, then the episode, then the edge
    for item in (mentioned, unmentioned, episode, mention):
        await item.save(graph_driver)

    # Rerank a single result list containing both entities
    reranked_uuids, reranked_scores = await episode_mentions_reranker(
        graph_driver,
        [[mentioned.uuid, unmentioned.uuid]],
    )

    name_by_uuid = {node.uuid: node.name for node in (mentioned, unmentioned)}
    ordered_names = [name_by_uuid[uuid] for uuid in reranked_uuids]
    assert ordered_names == [mentioned.name, unmentioned.name]
    assert np.allclose(reranked_scores, [1.0, float('inf')])
1986 | 
1987 | 
@pytest.mark.asyncio
async def test_get_embeddings_for_edges(graph_driver, mock_embedder):
    """Check that ``get_embeddings_for_edges`` returns a saved edge's fact embedding.

    The result must contain exactly one entry, keyed by the edge uuid, whose value
    is close to the embedding generated before saving.
    """
    # Build and embed the two endpoint entities
    source = EntityNode(
        name='test_entity_1', labels=[], created_at=datetime.now(), group_id=group_id
    )
    await source.generate_name_embedding(mock_embedder)
    target = EntityNode(
        name='test_entity_2', labels=[], created_at=datetime.now(), group_id=group_id
    )
    await target.generate_name_embedding(mock_embedder)

    # Build and embed the edge between them
    edge = EntityEdge(
        source_node_uuid=source.uuid,
        target_node_uuid=target.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await edge.generate_embedding(mock_embedder)

    # Persist nodes first, then the edge
    for item in (source, target, edge):
        await item.save(graph_driver)

    # Fetch the stored embedding back for the edge
    stored = await get_embeddings_for_edges(graph_driver, [edge])
    assert len(stored) == 1
    assert edge.uuid in stored
    assert np.allclose(stored[edge.uuid], edge.fact_embedding)
2027 | 
2028 | 
@pytest.mark.asyncio
async def test_get_embeddings_for_nodes(graph_driver, mock_embedder):
    """Check that ``get_embeddings_for_nodes`` returns a saved entity's name embedding.

    The result must contain exactly one entry, keyed by the node uuid, whose value
    is close to the embedding generated before saving.
    """
    # Build, embed and persist a single entity
    node = EntityNode(
        name='test_entity_1', labels=[], created_at=datetime.now(), group_id=group_id
    )
    await node.generate_name_embedding(mock_embedder)
    await node.save(graph_driver)

    # Fetch the stored embedding back for the node
    stored = await get_embeddings_for_nodes(graph_driver, [node])
    assert len(stored) == 1
    assert node.uuid in stored
    assert np.allclose(stored[node.uuid], node.name_embedding)
2048 | 
2049 | 
@pytest.mark.asyncio
async def test_get_embeddings_for_communities(graph_driver, mock_embedder):
    """Check that ``get_embeddings_for_communities`` returns a saved community's embedding.

    The result must contain exactly one entry, keyed by the community uuid, whose
    value is close to the name embedding generated before saving.
    """
    # Build, embed and persist a single community
    community = CommunityNode(
        name='test_community_1', labels=[], created_at=datetime.now(), group_id=group_id
    )
    await community.generate_name_embedding(mock_embedder)
    await community.save(graph_driver)

    # Fetch the stored embedding back for the community
    stored = await get_embeddings_for_communities(graph_driver, [community])
    assert len(stored) == 1
    assert community.uuid in stored
    assert np.allclose(stored[community.uuid], community.name_embedding)
2069 | 
```
Page 10/12FirstPrevNextLast