This is page 10 of 12. Use http://codebase.md/getzep/graphiti?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── ISSUE_TEMPLATE
│ │ └── bug_report.md
│ ├── pull_request_template.md
│ ├── secret_scanning.yml
│ └── workflows
│ ├── ai-moderator.yml
│ ├── cla.yml
│ ├── claude-code-review-manual.yml
│ ├── claude-code-review.yml
│ ├── claude.yml
│ ├── codeql.yml
│ ├── daily_issue_maintenance.yml
│ ├── issue-triage.yml
│ ├── lint.yml
│ ├── release-graphiti-core.yml
│ ├── release-mcp-server.yml
│ ├── release-server-container.yml
│ ├── typecheck.yml
│ └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│ ├── azure-openai
│ │ ├── .env.example
│ │ ├── azure_openai_neo4j.py
│ │ └── README.md
│ ├── data
│ │ └── manybirds_products.json
│ ├── ecommerce
│ │ ├── runner.ipynb
│ │ └── runner.py
│ ├── langgraph-agent
│ │ ├── agent.ipynb
│ │ └── tinybirds-jess.png
│ ├── opentelemetry
│ │ ├── .env.example
│ │ ├── otel_stdout_example.py
│ │ ├── pyproject.toml
│ │ ├── README.md
│ │ └── uv.lock
│ ├── podcast
│ │ ├── podcast_runner.py
│ │ ├── podcast_transcript.txt
│ │ └── transcript_parser.py
│ ├── quickstart
│ │ ├── quickstart_falkordb.py
│ │ ├── quickstart_neo4j.py
│ │ ├── quickstart_neptune.py
│ │ ├── README.md
│ │ └── requirements.txt
│ └── wizard_of_oz
│ ├── parser.py
│ ├── runner.py
│ └── woo.txt
├── graphiti_core
│ ├── __init__.py
│ ├── cross_encoder
│ │ ├── __init__.py
│ │ ├── bge_reranker_client.py
│ │ ├── client.py
│ │ ├── gemini_reranker_client.py
│ │ └── openai_reranker_client.py
│ ├── decorators.py
│ ├── driver
│ │ ├── __init__.py
│ │ ├── driver.py
│ │ ├── falkordb_driver.py
│ │ ├── graph_operations
│ │ │ └── graph_operations.py
│ │ ├── kuzu_driver.py
│ │ ├── neo4j_driver.py
│ │ ├── neptune_driver.py
│ │ └── search_interface
│ │ └── search_interface.py
│ ├── edges.py
│ ├── embedder
│ │ ├── __init__.py
│ │ ├── azure_openai.py
│ │ ├── client.py
│ │ ├── gemini.py
│ │ ├── openai.py
│ │ └── voyage.py
│ ├── errors.py
│ ├── graph_queries.py
│ ├── graphiti_types.py
│ ├── graphiti.py
│ ├── helpers.py
│ ├── llm_client
│ │ ├── __init__.py
│ │ ├── anthropic_client.py
│ │ ├── azure_openai_client.py
│ │ ├── client.py
│ │ ├── config.py
│ │ ├── errors.py
│ │ ├── gemini_client.py
│ │ ├── groq_client.py
│ │ ├── openai_base_client.py
│ │ ├── openai_client.py
│ │ ├── openai_generic_client.py
│ │ └── utils.py
│ ├── migrations
│ │ └── __init__.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── edges
│ │ │ ├── __init__.py
│ │ │ └── edge_db_queries.py
│ │ └── nodes
│ │ ├── __init__.py
│ │ └── node_db_queries.py
│ ├── nodes.py
│ ├── prompts
│ │ ├── __init__.py
│ │ ├── dedupe_edges.py
│ │ ├── dedupe_nodes.py
│ │ ├── eval.py
│ │ ├── extract_edge_dates.py
│ │ ├── extract_edges.py
│ │ ├── extract_nodes.py
│ │ ├── invalidate_edges.py
│ │ ├── lib.py
│ │ ├── models.py
│ │ ├── prompt_helpers.py
│ │ ├── snippets.py
│ │ └── summarize_nodes.py
│ ├── py.typed
│ ├── search
│ │ ├── __init__.py
│ │ ├── search_config_recipes.py
│ │ ├── search_config.py
│ │ ├── search_filters.py
│ │ ├── search_helpers.py
│ │ ├── search_utils.py
│ │ └── search.py
│ ├── telemetry
│ │ ├── __init__.py
│ │ └── telemetry.py
│ ├── tracer.py
│ └── utils
│ ├── __init__.py
│ ├── bulk_utils.py
│ ├── datetime_utils.py
│ ├── maintenance
│ │ ├── __init__.py
│ │ ├── community_operations.py
│ │ ├── dedup_helpers.py
│ │ ├── edge_operations.py
│ │ ├── graph_data_operations.py
│ │ ├── node_operations.py
│ │ └── temporal_operations.py
│ ├── ontology_utils
│ │ └── entity_types_utils.py
│ └── text_utils.py
├── images
│ ├── arxiv-screenshot.png
│ ├── graphiti-graph-intro.gif
│ ├── graphiti-intro-slides-stock-2.gif
│ └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│ ├── .env.example
│ ├── .python-version
│ ├── config
│ │ ├── config-docker-falkordb-combined.yaml
│ │ ├── config-docker-falkordb.yaml
│ │ ├── config-docker-neo4j.yaml
│ │ ├── config.yaml
│ │ └── mcp_config_stdio_example.json
│ ├── docker
│ │ ├── build-standalone.sh
│ │ ├── build-with-version.sh
│ │ ├── docker-compose-falkordb.yml
│ │ ├── docker-compose-neo4j.yml
│ │ ├── docker-compose.yml
│ │ ├── Dockerfile
│ │ ├── Dockerfile.standalone
│ │ ├── github-actions-example.yml
│ │ ├── README-falkordb-combined.md
│ │ └── README.md
│ ├── docs
│ │ └── cursor_rules.md
│ ├── main.py
│ ├── pyproject.toml
│ ├── pytest.ini
│ ├── README.md
│ ├── src
│ │ ├── __init__.py
│ │ ├── config
│ │ │ ├── __init__.py
│ │ │ └── schema.py
│ │ ├── graphiti_mcp_server.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ ├── entity_types.py
│ │ │ └── response_types.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── factories.py
│ │ │ └── queue_service.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── formatting.py
│ │ └── utils.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── pytest.ini
│ │ ├── README.md
│ │ ├── run_tests.py
│ │ ├── test_async_operations.py
│ │ ├── test_comprehensive_integration.py
│ │ ├── test_configuration.py
│ │ ├── test_falkordb_integration.py
│ │ ├── test_fixtures.py
│ │ ├── test_http_integration.py
│ │ ├── test_integration.py
│ │ ├── test_mcp_integration.py
│ │ ├── test_mcp_transports.py
│ │ ├── test_stdio_simple.py
│ │ └── test_stress_load.py
│ └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│ ├── .env.example
│ ├── graph_service
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ ├── main.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ └── zep_graphiti.py
│ ├── Makefile
│ ├── pyproject.toml
│ ├── README.md
│ └── uv.lock
├── signatures
│ └── version1
│ └── cla.json
├── tests
│ ├── cross_encoder
│ │ ├── test_bge_reranker_client_int.py
│ │ └── test_gemini_reranker_client.py
│ ├── driver
│ │ ├── __init__.py
│ │ └── test_falkordb_driver.py
│ ├── embedder
│ │ ├── embedder_fixtures.py
│ │ ├── test_gemini.py
│ │ ├── test_openai.py
│ │ └── test_voyage.py
│ ├── evals
│ │ ├── data
│ │ │ └── longmemeval_data
│ │ │ ├── longmemeval_oracle.json
│ │ │ └── README.md
│ │ ├── eval_cli.py
│ │ ├── eval_e2e_graph_building.py
│ │ ├── pytest.ini
│ │ └── utils.py
│ ├── helpers_test.py
│ ├── llm_client
│ │ ├── test_anthropic_client_int.py
│ │ ├── test_anthropic_client.py
│ │ ├── test_azure_openai_client.py
│ │ ├── test_client.py
│ │ ├── test_errors.py
│ │ └── test_gemini_client.py
│ ├── test_edge_int.py
│ ├── test_entity_exclusion_int.py
│ ├── test_graphiti_int.py
│ ├── test_graphiti_mock.py
│ ├── test_node_int.py
│ ├── test_text_utils.py
│ └── utils
│ ├── maintenance
│ │ ├── test_bulk_utils.py
│ │ ├── test_edge_operations.py
│ │ ├── test_node_operations.py
│ │ └── test_temporal_operations_int.py
│ └── search
│ └── search_utils_test.py
├── uv.lock
└── Zep-CLA.md
```
# Files
--------------------------------------------------------------------------------
/graphiti_core/graphiti.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import logging
18 | from datetime import datetime
19 | from time import time
20 |
21 | from dotenv import load_dotenv
22 | from pydantic import BaseModel
23 | from typing_extensions import LiteralString
24 |
25 | from graphiti_core.cross_encoder.client import CrossEncoderClient
26 | from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
27 | from graphiti_core.decorators import handle_multiple_group_ids
28 | from graphiti_core.driver.driver import GraphDriver
29 | from graphiti_core.driver.neo4j_driver import Neo4jDriver
30 | from graphiti_core.edges import (
31 | CommunityEdge,
32 | Edge,
33 | EntityEdge,
34 | EpisodicEdge,
35 | create_entity_edge_embeddings,
36 | )
37 | from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
38 | from graphiti_core.graphiti_types import GraphitiClients
39 | from graphiti_core.helpers import (
40 | get_default_group_id,
41 | semaphore_gather,
42 | validate_excluded_entity_types,
43 | validate_group_id,
44 | )
45 | from graphiti_core.llm_client import LLMClient, OpenAIClient
46 | from graphiti_core.nodes import (
47 | CommunityNode,
48 | EntityNode,
49 | EpisodeType,
50 | EpisodicNode,
51 | Node,
52 | create_entity_node_embeddings,
53 | )
54 | from graphiti_core.search.search import SearchConfig, search
55 | from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
56 | from graphiti_core.search.search_config_recipes import (
57 | COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
58 | EDGE_HYBRID_SEARCH_NODE_DISTANCE,
59 | EDGE_HYBRID_SEARCH_RRF,
60 | )
61 | from graphiti_core.search.search_filters import SearchFilters
62 | from graphiti_core.search.search_utils import (
63 | RELEVANT_SCHEMA_LIMIT,
64 | get_mentioned_nodes,
65 | )
66 | from graphiti_core.telemetry import capture_event
67 | from graphiti_core.tracer import Tracer, create_tracer
68 | from graphiti_core.utils.bulk_utils import (
69 | RawEpisode,
70 | add_nodes_and_edges_bulk,
71 | dedupe_edges_bulk,
72 | dedupe_nodes_bulk,
73 | extract_nodes_and_edges_bulk,
74 | resolve_edge_pointers,
75 | retrieve_previous_episodes_bulk,
76 | )
77 | from graphiti_core.utils.datetime_utils import utc_now
78 | from graphiti_core.utils.maintenance.community_operations import (
79 | build_communities,
80 | remove_communities,
81 | update_community,
82 | )
83 | from graphiti_core.utils.maintenance.edge_operations import (
84 | build_episodic_edges,
85 | extract_edges,
86 | resolve_extracted_edge,
87 | resolve_extracted_edges,
88 | )
89 | from graphiti_core.utils.maintenance.graph_data_operations import (
90 | EPISODE_WINDOW_LEN,
91 | retrieve_episodes,
92 | )
93 | from graphiti_core.utils.maintenance.node_operations import (
94 | extract_attributes_from_nodes,
95 | extract_nodes,
96 | resolve_extracted_nodes,
97 | )
98 | from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types
99 |
# Module-level logger for this file.
logger = logging.getLogger(__name__)

# Load environment variables (e.g. OPENAI_API_KEY for the default clients)
# from a .env file, if one is present.
load_dotenv()
103 |
104 |
class AddEpisodeResults(BaseModel):
    """Graph artifacts produced by a single `add_episode` call."""

    episode: EpisodicNode  # the episodic node that was created (or fetched by uuid)
    episodic_edges: list[EpisodicEdge]  # edges linking the episode to its entity nodes
    nodes: list[EntityNode]  # entity nodes extracted and resolved for this episode
    edges: list[EntityEdge]  # entity edges (resolved plus invalidated)
    communities: list[CommunityNode]  # updated communities (empty unless community updates ran)
    community_edges: list[CommunityEdge]  # edges involving the updated communities
112 |
113 |
class AddBulkEpisodeResults(BaseModel):
    """Aggregated graph artifacts produced by a bulk episode-ingestion run."""

    episodes: list[EpisodicNode]  # all episodes processed in the batch
    episodic_edges: list[EpisodicEdge]  # edges linking each episode to its entity nodes
    nodes: list[EntityNode]  # deduplicated entity nodes across the batch
    edges: list[EntityEdge]  # entity edges (resolved plus invalidated)
    communities: list[CommunityNode]  # updated communities, when community updates ran
    community_edges: list[CommunityEdge]  # edges involving the updated communities
121 |
122 |
class AddTripletResults(BaseModel):
    """Entity nodes and edges returned from a triplet-insertion operation."""

    nodes: list[EntityNode]  # entity nodes involved in the triplet
    edges: list[EntityEdge]  # entity edges created or updated
126 |
127 |
128 | class Graphiti:
129 | def __init__(
130 | self,
131 | uri: str | None = None,
132 | user: str | None = None,
133 | password: str | None = None,
134 | llm_client: LLMClient | None = None,
135 | embedder: EmbedderClient | None = None,
136 | cross_encoder: CrossEncoderClient | None = None,
137 | store_raw_episode_content: bool = True,
138 | graph_driver: GraphDriver | None = None,
139 | max_coroutines: int | None = None,
140 | tracer: Tracer | None = None,
141 | trace_span_prefix: str = 'graphiti',
142 | ):
143 | """
144 | Initialize a Graphiti instance.
145 |
146 | This constructor sets up a connection to a graph database and initializes
147 | the LLM client for natural language processing tasks.
148 |
149 | Parameters
150 | ----------
151 | uri : str
152 | The URI of the Neo4j database.
153 | user : str
154 | The username for authenticating with the Neo4j database.
155 | password : str
156 | The password for authenticating with the Neo4j database.
157 | llm_client : LLMClient | None, optional
158 | An instance of LLMClient for natural language processing tasks.
159 | If not provided, a default OpenAIClient will be initialized.
160 | embedder : EmbedderClient | None, optional
161 | An instance of EmbedderClient for embedding tasks.
162 | If not provided, a default OpenAIEmbedder will be initialized.
163 | cross_encoder : CrossEncoderClient | None, optional
164 | An instance of CrossEncoderClient for reranking tasks.
165 | If not provided, a default OpenAIRerankerClient will be initialized.
166 | store_raw_episode_content : bool, optional
167 | Whether to store the raw content of episodes. Defaults to True.
168 | graph_driver : GraphDriver | None, optional
169 | An instance of GraphDriver for database operations.
170 | If not provided, a default Neo4jDriver will be initialized.
171 | max_coroutines : int | None, optional
172 | The maximum number of concurrent operations allowed. Overrides SEMAPHORE_LIMIT set in the environment.
173 | If not set, the Graphiti default is used.
174 | tracer : Tracer | None, optional
175 | An OpenTelemetry tracer instance for distributed tracing. If not provided, tracing is disabled (no-op).
176 | trace_span_prefix : str, optional
177 | Prefix to prepend to all span names. Defaults to 'graphiti'.
178 |
179 | Returns
180 | -------
181 | None
182 |
183 | Notes
184 | -----
185 | This method establishes a connection to a graph database (Neo4j by default) using the provided
186 | credentials. It also sets up the LLM client, either using the provided client
187 | or by creating a default OpenAIClient.
188 |
189 | The default database name is defined during the driver’s construction. If a different database name
190 | is required, it should be specified in the URI or set separately after
191 | initialization.
192 |
193 | The OpenAI API key is expected to be set in the environment variables.
194 | Make sure to set the OPENAI_API_KEY environment variable before initializing
195 | Graphiti if you're using the default OpenAIClient.
196 | """
197 |
198 | if graph_driver:
199 | self.driver = graph_driver
200 | else:
201 | if uri is None:
202 | raise ValueError('uri must be provided when graph_driver is None')
203 | self.driver = Neo4jDriver(uri, user, password)
204 |
205 | self.store_raw_episode_content = store_raw_episode_content
206 | self.max_coroutines = max_coroutines
207 | if llm_client:
208 | self.llm_client = llm_client
209 | else:
210 | self.llm_client = OpenAIClient()
211 | if embedder:
212 | self.embedder = embedder
213 | else:
214 | self.embedder = OpenAIEmbedder()
215 | if cross_encoder:
216 | self.cross_encoder = cross_encoder
217 | else:
218 | self.cross_encoder = OpenAIRerankerClient()
219 |
220 | # Initialize tracer
221 | self.tracer = create_tracer(tracer, trace_span_prefix)
222 |
223 | # Set tracer on clients
224 | self.llm_client.set_tracer(self.tracer)
225 |
226 | self.clients = GraphitiClients(
227 | driver=self.driver,
228 | llm_client=self.llm_client,
229 | embedder=self.embedder,
230 | cross_encoder=self.cross_encoder,
231 | tracer=self.tracer,
232 | )
233 |
234 | # Capture telemetry event
235 | self._capture_initialization_telemetry()
236 |
237 | def _capture_initialization_telemetry(self):
238 | """Capture telemetry event for Graphiti initialization."""
239 | try:
240 | # Detect provider types from class names
241 | llm_provider = self._get_provider_type(self.llm_client)
242 | embedder_provider = self._get_provider_type(self.embedder)
243 | reranker_provider = self._get_provider_type(self.cross_encoder)
244 | database_provider = self._get_provider_type(self.driver)
245 |
246 | properties = {
247 | 'llm_provider': llm_provider,
248 | 'embedder_provider': embedder_provider,
249 | 'reranker_provider': reranker_provider,
250 | 'database_provider': database_provider,
251 | }
252 |
253 | capture_event('graphiti_initialized', properties)
254 | except Exception:
255 | # Silently handle telemetry errors
256 | pass
257 |
258 | def _get_provider_type(self, client) -> str:
259 | """Get provider type from client class name."""
260 | if client is None:
261 | return 'none'
262 |
263 | class_name = client.__class__.__name__.lower()
264 |
265 | # LLM providers
266 | if 'openai' in class_name:
267 | return 'openai'
268 | elif 'azure' in class_name:
269 | return 'azure'
270 | elif 'anthropic' in class_name:
271 | return 'anthropic'
272 | elif 'crossencoder' in class_name:
273 | return 'crossencoder'
274 | elif 'gemini' in class_name:
275 | return 'gemini'
276 | elif 'groq' in class_name:
277 | return 'groq'
278 | # Database providers
279 | elif 'neo4j' in class_name:
280 | return 'neo4j'
281 | elif 'falkor' in class_name:
282 | return 'falkordb'
283 | # Embedder providers
284 | elif 'voyage' in class_name:
285 | return 'voyage'
286 | else:
287 | return 'unknown'
288 |
    async def close(self):
        """
        Close the underlying graph database driver connection.

        Call this when the Graphiti instance is no longer needed or when the
        application is shutting down, so that driver connections and other
        resources are released.

        Returns
        -------
        None

        Notes
        -----
        This is a coroutine and must be awaited, e.g. as part of a cleanup
        routine or shutdown hook:

            graphiti = Graphiti(uri, user, password)
            try:
                # Use graphiti...
            finally:
                await graphiti.close()
        """
        await self.driver.close()
320 |
    async def build_indices_and_constraints(self, delete_existing: bool = False):
        """
        Build database indices and constraints for the knowledge graph.

        Delegates to the configured driver's `build_indices_and_constraints`
        method, which sets up whatever indices and constraints that backend
        needs for query performance and data integrity.

        Parameters
        ----------
        delete_existing : bool, optional
            Whether to clear existing indices before creating new ones.
            Defaults to False.

        Returns
        -------
        None

        Notes
        -----
        Typically called once during initial setup of the knowledge graph, or
        when updating the database schema. The exact indices and constraints
        created depend on the driver implementation; see the specific driver's
        documentation for details.

        Caution: running this on a large existing database may take some time
        and could impact database performance during execution.
        """
        await self.driver.build_indices_and_constraints(delete_existing)
354 |
355 | async def _extract_and_resolve_nodes(
356 | self,
357 | episode: EpisodicNode,
358 | previous_episodes: list[EpisodicNode],
359 | entity_types: dict[str, type[BaseModel]] | None,
360 | excluded_entity_types: list[str] | None,
361 | ) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
362 | """Extract nodes from episode and resolve against existing graph."""
363 | extracted_nodes = await extract_nodes(
364 | self.clients, episode, previous_episodes, entity_types, excluded_entity_types
365 | )
366 |
367 | nodes, uuid_map, duplicates = await resolve_extracted_nodes(
368 | self.clients,
369 | extracted_nodes,
370 | episode,
371 | previous_episodes,
372 | entity_types,
373 | )
374 |
375 | return nodes, uuid_map, duplicates
376 |
377 | async def _extract_and_resolve_edges(
378 | self,
379 | episode: EpisodicNode,
380 | extracted_nodes: list[EntityNode],
381 | previous_episodes: list[EpisodicNode],
382 | edge_type_map: dict[tuple[str, str], list[str]],
383 | group_id: str,
384 | edge_types: dict[str, type[BaseModel]] | None,
385 | nodes: list[EntityNode],
386 | uuid_map: dict[str, str],
387 | ) -> tuple[list[EntityEdge], list[EntityEdge]]:
388 | """Extract edges from episode and resolve against existing graph."""
389 | extracted_edges = await extract_edges(
390 | self.clients,
391 | episode,
392 | extracted_nodes,
393 | previous_episodes,
394 | edge_type_map,
395 | group_id,
396 | edge_types,
397 | )
398 |
399 | edges = resolve_edge_pointers(extracted_edges, uuid_map)
400 |
401 | resolved_edges, invalidated_edges = await resolve_extracted_edges(
402 | self.clients,
403 | edges,
404 | episode,
405 | nodes,
406 | edge_types or {},
407 | edge_type_map,
408 | )
409 |
410 | return resolved_edges, invalidated_edges
411 |
412 | async def _process_episode_data(
413 | self,
414 | episode: EpisodicNode,
415 | nodes: list[EntityNode],
416 | entity_edges: list[EntityEdge],
417 | now: datetime,
418 | ) -> tuple[list[EpisodicEdge], EpisodicNode]:
419 | """Process and save episode data to the graph."""
420 | episodic_edges = build_episodic_edges(nodes, episode.uuid, now)
421 | episode.entity_edges = [edge.uuid for edge in entity_edges]
422 |
423 | if not self.store_raw_episode_content:
424 | episode.content = ''
425 |
426 | await add_nodes_and_edges_bulk(
427 | self.driver,
428 | [episode],
429 | episodic_edges,
430 | nodes,
431 | entity_edges,
432 | self.embedder,
433 | )
434 |
435 | return episodic_edges, episode
436 |
437 | async def _extract_and_dedupe_nodes_bulk(
438 | self,
439 | episode_context: list[tuple[EpisodicNode, list[EpisodicNode]]],
440 | edge_type_map: dict[tuple[str, str], list[str]],
441 | edge_types: dict[str, type[BaseModel]] | None,
442 | entity_types: dict[str, type[BaseModel]] | None,
443 | excluded_entity_types: list[str] | None,
444 | ) -> tuple[
445 | dict[str, list[EntityNode]],
446 | dict[str, str],
447 | list[list[EntityEdge]],
448 | ]:
449 | """Extract nodes and edges from all episodes and deduplicate."""
450 | # Extract all nodes and edges for each episode
451 | extracted_nodes_bulk, extracted_edges_bulk = await extract_nodes_and_edges_bulk(
452 | self.clients,
453 | episode_context,
454 | edge_type_map=edge_type_map,
455 | edge_types=edge_types,
456 | entity_types=entity_types,
457 | excluded_entity_types=excluded_entity_types,
458 | )
459 |
460 | # Dedupe extracted nodes in memory
461 | nodes_by_episode, uuid_map = await dedupe_nodes_bulk(
462 | self.clients, extracted_nodes_bulk, episode_context, entity_types
463 | )
464 |
465 | return nodes_by_episode, uuid_map, extracted_edges_bulk
466 |
    async def _resolve_nodes_and_edges_bulk(
        self,
        nodes_by_episode: dict[str, list[EntityNode]],
        edges_by_episode: dict[str, list[EntityEdge]],
        episode_context: list[tuple[EpisodicNode, list[EpisodicNode]]],
        entity_types: dict[str, type[BaseModel]] | None,
        edge_types: dict[str, type[BaseModel]] | None,
        edge_type_map: dict[tuple[str, str], list[str]],
        episodes: list[EpisodicNode],
    ) -> tuple[list[EntityNode], list[EntityEdge], list[EntityEdge], dict[str, str]]:
        """
        Resolve in-memory deduplicated nodes and edges against the existing graph.

        Pipeline: (1) keep only the first occurrence of each node uuid per
        episode, (2) resolve those nodes per episode concurrently and merge
        the resulting uuid maps, (3) swap episode node lists over to the
        resolved nodes, (4) extract attributes for the resolved nodes,
        (5) re-point edges via the uuid map, dedupe them, and resolve them
        per episode concurrently.

        Returns
        -------
        tuple
            (hydrated nodes, resolved edges, invalidated edges, uuid map from
            extracted-node uuid to resolved-node uuid).
        """
        # Index every node seen in the batch by uuid.
        nodes_by_uuid: dict[str, EntityNode] = {
            node.uuid: node for nodes in nodes_by_episode.values() for node in nodes
        }

        # Get unique nodes per episode: a node uuid is attributed only to the
        # first episode (in episode_context order) that mentions it.
        nodes_by_episode_unique: dict[str, list[EntityNode]] = {}
        nodes_uuid_set: set[str] = set()
        for episode, _ in episode_context:
            nodes_by_episode_unique[episode.uuid] = []
            nodes = [nodes_by_uuid[node.uuid] for node in nodes_by_episode[episode.uuid]]
            for node in nodes:
                if node.uuid not in nodes_uuid_set:
                    nodes_by_episode_unique[episode.uuid].append(node)
                    nodes_uuid_set.add(node.uuid)

        # Resolve nodes (one resolution task per episode, run concurrently).
        node_results = await semaphore_gather(
            *[
                resolve_extracted_nodes(
                    self.clients,
                    nodes_by_episode_unique[episode.uuid],
                    episode,
                    previous_episodes,
                    entity_types,
                )
                for episode, previous_episodes in episode_context
            ]
        )

        # Merge the per-episode results into one node list and one uuid map.
        resolved_nodes: list[EntityNode] = []
        uuid_map: dict[str, str] = {}
        for result in node_results:
            resolved_nodes.extend(result[0])
            uuid_map.update(result[1])

        # Update nodes_by_uuid with resolved nodes
        for resolved_node in resolved_nodes:
            nodes_by_uuid[resolved_node.uuid] = resolved_node

        # Update nodes_by_episode_unique with resolved pointers
        # (each node is replaced by the node its uuid resolved to).
        for episode_uuid, nodes in nodes_by_episode_unique.items():
            updated_nodes: list[EntityNode] = []
            for node in nodes:
                updated_node_uuid = uuid_map.get(node.uuid, node.uuid)
                updated_node = nodes_by_uuid[updated_node_uuid]
                updated_nodes.append(updated_node)
            nodes_by_episode_unique[episode_uuid] = updated_nodes

        # Extract attributes for resolved nodes
        hydrated_nodes_results: list[list[EntityNode]] = await semaphore_gather(
            *[
                extract_attributes_from_nodes(
                    self.clients,
                    nodes_by_episode_unique[episode.uuid],
                    episode,
                    previous_episodes,
                    entity_types,
                )
                for episode, previous_episodes in episode_context
            ]
        )

        final_hydrated_nodes = [node for nodes in hydrated_nodes_results for node in nodes]

        # Resolve edges with updated pointers; each edge uuid is kept only for
        # the first episode that contains it.
        edges_by_episode_unique: dict[str, list[EntityEdge]] = {}
        edges_uuid_set: set[str] = set()
        for episode_uuid, edges in edges_by_episode.items():
            edges_with_updated_pointers = resolve_edge_pointers(edges, uuid_map)
            edges_by_episode_unique[episode_uuid] = []

            for edge in edges_with_updated_pointers:
                if edge.uuid not in edges_uuid_set:
                    edges_by_episode_unique[episode_uuid].append(edge)
                    edges_uuid_set.add(edge.uuid)

        # NOTE(review): this loop iterates `episodes` while the dict above is
        # keyed from `edges_by_episode`; it assumes every episode uuid appears
        # in edges_by_episode, otherwise this raises KeyError — verify callers.
        edge_results = await semaphore_gather(
            *[
                resolve_extracted_edges(
                    self.clients,
                    edges_by_episode_unique[episode.uuid],
                    episode,
                    final_hydrated_nodes,
                    edge_types or {},
                    edge_type_map,
                )
                for episode in episodes
            ]
        )

        # Merge per-episode edge results.
        resolved_edges: list[EntityEdge] = []
        invalidated_edges: list[EntityEdge] = []
        for result in edge_results:
            resolved_edges.extend(result[0])
            invalidated_edges.extend(result[1])

        return final_hydrated_nodes, resolved_edges, invalidated_edges, uuid_map
575 |
576 | @handle_multiple_group_ids
577 | async def retrieve_episodes(
578 | self,
579 | reference_time: datetime,
580 | last_n: int = EPISODE_WINDOW_LEN,
581 | group_ids: list[str] | None = None,
582 | source: EpisodeType | None = None,
583 | driver: GraphDriver | None = None,
584 | ) -> list[EpisodicNode]:
585 | """
586 | Retrieve the last n episodic nodes from the graph.
587 |
588 | This method fetches a specified number of the most recent episodic nodes
589 | from the graph, relative to the given reference time.
590 |
591 | Parameters
592 | ----------
593 | reference_time : datetime
594 | The reference time to retrieve episodes before.
595 | last_n : int, optional
596 | The number of episodes to retrieve. Defaults to EPISODE_WINDOW_LEN.
597 | group_ids : list[str | None], optional
598 | The group ids to return data from.
599 |
600 | Returns
601 | -------
602 | list[EpisodicNode]
603 | A list of the most recent EpisodicNode objects.
604 |
605 | Notes
606 | -----
607 | The actual retrieval is performed by the `retrieve_episodes` function
608 | from the `graphiti_core.utils` module.
609 | """
610 | if driver is None:
611 | driver = self.clients.driver
612 |
613 | return await retrieve_episodes(driver, reference_time, last_n, group_ids, source)
614 |
615 | async def add_episode(
616 | self,
617 | name: str,
618 | episode_body: str,
619 | source_description: str,
620 | reference_time: datetime,
621 | source: EpisodeType = EpisodeType.message,
622 | group_id: str | None = None,
623 | uuid: str | None = None,
624 | update_communities: bool = False,
625 | entity_types: dict[str, type[BaseModel]] | None = None,
626 | excluded_entity_types: list[str] | None = None,
627 | previous_episode_uuids: list[str] | None = None,
628 | edge_types: dict[str, type[BaseModel]] | None = None,
629 | edge_type_map: dict[tuple[str, str], list[str]] | None = None,
630 | ) -> AddEpisodeResults:
631 | """
632 | Process an episode and update the graph.
633 |
634 | This method extracts information from the episode, creates nodes and edges,
635 | and updates the graph database accordingly.
636 |
637 | Parameters
638 | ----------
639 | name : str
640 | The name of the episode.
641 | episode_body : str
642 | The content of the episode.
643 | source_description : str
644 | A description of the episode's source.
645 | reference_time : datetime
646 | The reference time for the episode.
647 | source : EpisodeType, optional
648 | The type of the episode. Defaults to EpisodeType.message.
649 | group_id : str | None
650 | An id for the graph partition the episode is a part of.
651 | uuid : str | None
652 | Optional uuid of the episode.
653 | update_communities : bool
654 | Optional. Whether to update communities with new node information
655 | entity_types : dict[str, BaseModel] | None
656 | Optional. Dictionary mapping entity type names to their Pydantic model definitions.
657 | excluded_entity_types : list[str] | None
658 | Optional. List of entity type names to exclude from the graph. Entities classified
659 | into these types will not be added to the graph. Can include 'Entity' to exclude
660 | the default entity type.
661 | previous_episode_uuids : list[str] | None
662 | Optional. list of episode uuids to use as the previous episodes. If this is not provided,
663 | the most recent episodes by created_at date will be used.
664 |
665 | Returns
666 | -------
667 | None
668 |
669 | Notes
670 | -----
671 | This method performs several steps including node extraction, edge extraction,
672 | deduplication, and database updates. It also handles embedding generation
673 | and edge invalidation.
674 |
675 | It is recommended to run this method as a background process, such as in a queue.
676 | It's important that each episode is added sequentially and awaited before adding
677 | the next one. For web applications, consider using FastAPI's background tasks
678 | or a dedicated task queue like Celery for this purpose.
679 |
680 | Example using FastAPI background tasks:
681 | @app.post("/add_episode")
682 | async def add_episode_endpoint(episode_data: EpisodeData):
683 | background_tasks.add_task(graphiti.add_episode, **episode_data.dict())
684 | return {"message": "Episode processing started"}
685 | """
686 | start = time()
687 | now = utc_now()
688 |
689 | validate_entity_types(entity_types)
690 | validate_excluded_entity_types(excluded_entity_types, entity_types)
691 |
692 | if group_id is None:
693 | # if group_id is None, use the default group id by the provider
694 | # and the preset database name will be used
695 | group_id = get_default_group_id(self.driver.provider)
696 | else:
697 | validate_group_id(group_id)
698 | if group_id != self.driver._database:
699 | # if group_id is provided, use it as the database name
700 | self.driver = self.driver.clone(database=group_id)
701 | self.clients.driver = self.driver
702 |
703 | with self.tracer.start_span('add_episode') as span:
704 | try:
705 | # Retrieve previous episodes for context
706 | previous_episodes = (
707 | await self.retrieve_episodes(
708 | reference_time,
709 | last_n=RELEVANT_SCHEMA_LIMIT,
710 | group_ids=[group_id],
711 | source=source,
712 | )
713 | if previous_episode_uuids is None
714 | else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
715 | )
716 |
717 | # Get or create episode
718 | episode = (
719 | await EpisodicNode.get_by_uuid(self.driver, uuid)
720 | if uuid is not None
721 | else EpisodicNode(
722 | name=name,
723 | group_id=group_id,
724 | labels=[],
725 | source=source,
726 | content=episode_body,
727 | source_description=source_description,
728 | created_at=now,
729 | valid_at=reference_time,
730 | )
731 | )
732 |
733 | # Create default edge type map
734 | edge_type_map_default = (
735 | {('Entity', 'Entity'): list(edge_types.keys())}
736 | if edge_types is not None
737 | else {('Entity', 'Entity'): []}
738 | )
739 |
740 | # Extract and resolve nodes
741 | extracted_nodes = await extract_nodes(
742 | self.clients, episode, previous_episodes, entity_types, excluded_entity_types
743 | )
744 |
745 | nodes, uuid_map, _ = await resolve_extracted_nodes(
746 | self.clients,
747 | extracted_nodes,
748 | episode,
749 | previous_episodes,
750 | entity_types,
751 | )
752 |
753 | # Extract and resolve edges in parallel with attribute extraction
754 | resolved_edges, invalidated_edges = await self._extract_and_resolve_edges(
755 | episode,
756 | extracted_nodes,
757 | previous_episodes,
758 | edge_type_map or edge_type_map_default,
759 | group_id,
760 | edge_types,
761 | nodes,
762 | uuid_map,
763 | )
764 |
765 | # Extract node attributes
766 | hydrated_nodes = await extract_attributes_from_nodes(
767 | self.clients, nodes, episode, previous_episodes, entity_types
768 | )
769 |
770 | entity_edges = resolved_edges + invalidated_edges
771 |
772 | # Process and save episode data
773 | episodic_edges, episode = await self._process_episode_data(
774 | episode, hydrated_nodes, entity_edges, now
775 | )
776 |
777 | # Update communities if requested
778 | communities = []
779 | community_edges = []
780 | if update_communities:
781 | communities, community_edges = await semaphore_gather(
782 | *[
783 | update_community(self.driver, self.llm_client, self.embedder, node)
784 | for node in nodes
785 | ],
786 | max_coroutines=self.max_coroutines,
787 | )
788 |
789 | end = time()
790 |
791 | # Add span attributes
792 | span.add_attributes(
793 | {
794 | 'episode.uuid': episode.uuid,
795 | 'episode.source': source.value,
796 | 'episode.reference_time': reference_time.isoformat(),
797 | 'group_id': group_id,
798 | 'node.count': len(hydrated_nodes),
799 | 'edge.count': len(entity_edges),
800 | 'edge.invalidated_count': len(invalidated_edges),
801 | 'previous_episodes.count': len(previous_episodes),
802 | 'entity_types.count': len(entity_types) if entity_types else 0,
803 | 'edge_types.count': len(edge_types) if edge_types else 0,
804 | 'update_communities': update_communities,
805 | 'communities.count': len(communities) if update_communities else 0,
806 | 'duration_ms': (end - start) * 1000,
807 | }
808 | )
809 |
810 | logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
811 |
812 | return AddEpisodeResults(
813 | episode=episode,
814 | episodic_edges=episodic_edges,
815 | nodes=hydrated_nodes,
816 | edges=entity_edges,
817 | communities=communities,
818 | community_edges=community_edges,
819 | )
820 |
821 | except Exception as e:
822 | span.set_status('error', str(e))
823 | span.record_exception(e)
824 | raise e
825 |
async def add_episode_bulk(
    self,
    bulk_episodes: list[RawEpisode],
    group_id: str | None = None,
    entity_types: dict[str, type[BaseModel]] | None = None,
    excluded_entity_types: list[str] | None = None,
    edge_types: dict[str, type[BaseModel]] | None = None,
    edge_type_map: dict[tuple[str, str], list[str]] | None = None,
) -> AddBulkEpisodeResults:
    """
    Process multiple episodes in bulk and update the graph.

    This method extracts information from multiple episodes, creates nodes and edges,
    and updates the graph database accordingly, all in a single batch operation.

    Parameters
    ----------
    bulk_episodes : list[RawEpisode]
        A list of RawEpisode objects to be processed and added to the graph.
    group_id : str | None
        An id for the graph partition the episode is a part of.
    entity_types : dict[str, type[BaseModel]] | None
        Optional. Custom entity type definitions used during node extraction.
    excluded_entity_types : list[str] | None
        Optional. Names of entity types whose extracted nodes are dropped.
    edge_types : dict[str, type[BaseModel]] | None
        Optional. Custom edge type definitions used during edge extraction.
    edge_type_map : dict[tuple[str, str], list[str]] | None
        Optional. Maps (source entity type, target entity type) pairs to the
        edge type names allowed between them. When omitted, all `edge_types`
        are allowed between ('Entity', 'Entity').

    Returns
    -------
    AddBulkEpisodeResults

    Notes
    -----
    This method performs several steps including:
    - Saving all episodes to the database
    - Retrieving previous episode context for each new episode
    - Extracting nodes and edges from all episodes
    - Generating embeddings for nodes and edges
    - Deduplicating nodes and edges
    - Saving nodes, episodic edges, and entity edges to the knowledge graph

    This bulk operation is designed for efficiency when processing multiple episodes
    at once. However, it's important to ensure that the bulk operation doesn't
    overwhelm system resources. Consider implementing rate limiting or chunking for
    very large batches of episodes.

    Important: This method does not perform edge invalidation or date extraction steps.
    If these operations are required, use the `add_episode` method instead for each
    individual episode.
    """
    with self.tracer.start_span('add_episode_bulk') as bulk_span:
        bulk_span.add_attributes({'episode.count': len(bulk_episodes)})

        try:
            start = time()
            now = utc_now()

            # if group_id is None, use the default group id by the provider
            if group_id is None:
                group_id = get_default_group_id(self.driver.provider)
            else:
                validate_group_id(group_id)
                if group_id != self.driver._database:
                    # if group_id is provided, use it as the database name
                    self.driver = self.driver.clone(database=group_id)
                    self.clients.driver = self.driver

            # Create default edge type map (all edge types between generic entities)
            edge_type_map_default = (
                {('Entity', 'Entity'): list(edge_types.keys())}
                if edge_types is not None
                else {('Entity', 'Entity'): []}
            )

            # Load existing episodes by uuid when provided, otherwise create new ones
            episodes = [
                await EpisodicNode.get_by_uuid(self.driver, episode.uuid)
                if episode.uuid is not None
                else EpisodicNode(
                    name=episode.name,
                    labels=[],
                    source=episode.source,
                    content=episode.content,
                    source_description=episode.source_description,
                    group_id=group_id,
                    created_at=now,
                    valid_at=episode.reference_time,
                )
                for episode in bulk_episodes
            ]

            # Save all episodes (nodes/edges are persisted later, after resolution)
            await add_nodes_and_edges_bulk(
                driver=self.driver,
                episodic_nodes=episodes,
                episodic_edges=[],
                entity_nodes=[],
                entity_edges=[],
                embedder=self.embedder,
            )

            # Get previous episode context for each episode
            episode_context = await retrieve_previous_episodes_bulk(self.driver, episodes)

            # Extract and dedupe nodes and edges
            (
                nodes_by_episode,
                uuid_map,
                extracted_edges_bulk,
            ) = await self._extract_and_dedupe_nodes_bulk(
                episode_context,
                edge_type_map or edge_type_map_default,
                edge_types,
                entity_types,
                excluded_entity_types,
            )

            # Create Episodic Edges (episode -> mentioned entity)
            episodic_edges: list[EpisodicEdge] = []
            for episode_uuid, nodes in nodes_by_episode.items():
                episodic_edges.extend(build_episodic_edges(nodes, episode_uuid, now))

            # Re-map edge pointers and dedupe edges
            extracted_edges_bulk_updated: list[list[EntityEdge]] = [
                resolve_edge_pointers(edges, uuid_map) for edges in extracted_edges_bulk
            ]

            edges_by_episode = await dedupe_edges_bulk(
                self.clients,
                extracted_edges_bulk_updated,
                episode_context,
                [],
                edge_types or {},
                edge_type_map or edge_type_map_default,
            )

            # Resolve nodes and edges against the existing graph
            (
                final_hydrated_nodes,
                resolved_edges,
                invalidated_edges,
                final_uuid_map,
            ) = await self._resolve_nodes_and_edges_bulk(
                nodes_by_episode,
                edges_by_episode,
                episode_context,
                entity_types,
                edge_types,
                edge_type_map or edge_type_map_default,
                episodes,
            )

            # Resolved pointers for episodic edges
            resolved_episodic_edges = resolve_edge_pointers(episodic_edges, final_uuid_map)

            # save data to KG
            await add_nodes_and_edges_bulk(
                self.driver,
                episodes,
                resolved_episodic_edges,
                final_hydrated_nodes,
                resolved_edges + invalidated_edges,
                self.embedder,
            )

            end = time()

            # Add span attributes
            bulk_span.add_attributes(
                {
                    'group_id': group_id,
                    'node.count': len(final_hydrated_nodes),
                    'edge.count': len(resolved_edges + invalidated_edges),
                    'duration_ms': (end - start) * 1000,
                }
            )

            logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')

            return AddBulkEpisodeResults(
                episodes=episodes,
                episodic_edges=resolved_episodic_edges,
                nodes=final_hydrated_nodes,
                edges=resolved_edges + invalidated_edges,
                communities=[],
                community_edges=[],
            )

        except Exception as e:
            # Record the failure on the tracing span before propagating
            bulk_span.set_status('error', str(e))
            bulk_span.record_exception(e)
            raise e
1012 |
@handle_multiple_group_ids
async def build_communities(
    self, group_ids: list[str] | None = None, driver: GraphDriver | None = None
) -> tuple[list[CommunityNode], list[CommunityEdge]]:
    """
    Detect communities of entity nodes and create community nodes that
    summarise their content.

    Any previously built communities are removed first, then a community
    clustering algorithm groups the nodes and one community node is created
    per cluster.

    Parameters
    ----------
    group_ids : list[str] | None
        Optional. Build communities only for the listed group_ids. If None,
        the entire graph is used.
    driver : GraphDriver | None
        Optional. Driver to run against; defaults to the client's driver.

    Returns
    -------
    tuple[list[CommunityNode], list[CommunityEdge]]
        The newly created community nodes and their edges.
    """
    if driver is None:
        driver = self.clients.driver

    # Rebuild from scratch: drop whatever communities already exist.
    await remove_communities(driver)

    community_nodes, community_edges = await build_communities(
        driver, self.llm_client, group_ids
    )

    # Embed community names before persisting anything.
    await semaphore_gather(
        *(node.generate_name_embedding(self.embedder) for node in community_nodes),
        max_coroutines=self.max_coroutines,
    )

    # Persist the nodes first, then the edges that reference them.
    await semaphore_gather(
        *(node.save(driver) for node in community_nodes),
        max_coroutines=self.max_coroutines,
    )
    await semaphore_gather(
        *(edge.save(driver) for edge in community_edges),
        max_coroutines=self.max_coroutines,
    )

    return community_nodes, community_edges
1049 |
@handle_multiple_group_ids
async def search(
    self,
    query: str,
    center_node_uuid: str | None = None,
    group_ids: list[str] | None = None,
    num_results=DEFAULT_SEARCH_LIMIT,
    search_filter: SearchFilters | None = None,
    driver: GraphDriver | None = None,
) -> list[EntityEdge]:
    """
    Perform a hybrid search on the knowledge graph.

    This method executes a search query on the graph, combining vector and
    text-based search techniques to retrieve relevant facts, returning the edges as a string.

    This is our basic out-of-the-box search, for more robust results we recommend using our more advanced
    search method graphiti.search_().

    Parameters
    ----------
    query : str
        The search query string.
    center_node_uuid: str, optional
        Facts will be reranked based on proximity to this node
    group_ids : list[str | None] | None, optional
        The graph partitions to return data from.
    num_results : int, optional
        The maximum number of results to return. Defaults to 10.
    search_filter : SearchFilters | None, optional
        Additional filters applied to the search; defaults to no filtering.
    driver : GraphDriver | None, optional
        Driver to run against; defaults to the client's driver.

    Returns
    -------
    list
        A list of EntityEdge objects that are relevant to the search query.

    Notes
    -----
    This method uses a SearchConfig with num_episodes set to 0 and
    num_results set to the provided num_results parameter.

    The search is performed using the current date and time as the reference
    point for temporal relevance.
    """
    base_config = (
        EDGE_HYBRID_SEARCH_RRF if center_node_uuid is None else EDGE_HYBRID_SEARCH_NODE_DISTANCE
    )
    # Work on a copy of the recipe: assigning `limit` on the shared
    # module-level config object would leak this call's num_results into
    # every other caller (and race under concurrent searches).
    search_config = base_config.model_copy(deep=True)
    search_config.limit = num_results

    edges = (
        await search(
            self.clients,
            query,
            group_ids,
            search_config,
            search_filter if search_filter is not None else SearchFilters(),
            driver=driver,
            center_node_uuid=center_node_uuid,
        )
    ).edges

    return edges
1111 |
async def _search(
    self,
    query: str,
    config: SearchConfig,
    group_ids: list[str] | None = None,
    center_node_uuid: str | None = None,
    bfs_origin_node_uuids: list[str] | None = None,
    search_filter: SearchFilters | None = None,
) -> SearchResults:
    """DEPRECATED: use `search_` instead.

    Thin forwarding wrapper kept for backward compatibility; all arguments
    are passed through to `search_` positionally, unchanged.
    """
    return await self.search_(
        query, config, group_ids, center_node_uuid, bfs_origin_node_uuids, search_filter
    )
1125 |
@handle_multiple_group_ids
async def search_(
    self,
    query: str,
    config: SearchConfig = COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
    group_ids: list[str] | None = None,
    center_node_uuid: str | None = None,
    bfs_origin_node_uuids: list[str] | None = None,
    search_filter: SearchFilters | None = None,
    driver: GraphDriver | None = None,
) -> SearchResults:
    """Advanced search returning graph objects (nodes and edges) rather than a
    flat list of facts.

    `search_` replaces the deprecated `_search` and exposes the full search
    surface: filters plus configurable search and reranker methodologies
    across the different layers of the graph. For ready-made config recipes
    see search/search_config_recipes.
    """
    # Fall back to an empty filter set when the caller supplied none.
    filters = SearchFilters() if search_filter is None else search_filter

    return await search(
        self.clients,
        query,
        group_ids,
        config,
        filters,
        center_node_uuid,
        bfs_origin_node_uuids,
        driver=driver,
    )
1154 |
async def get_nodes_and_edges_by_episode(self, episode_uuids: list[str]) -> SearchResults:
    """Return the entity edges referenced by the given episodes together with
    the entity nodes those episodes mention."""
    episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)

    # Fetch each episode's referenced entity edges concurrently.
    per_episode_edges = await semaphore_gather(
        *(EntityEdge.get_by_uuids(self.driver, ep.entity_edges) for ep in episodes),
        max_coroutines=self.max_coroutines,
    )

    # Flatten the per-episode batches into a single list.
    flattened_edges: list[EntityEdge] = [edge for batch in per_episode_edges for edge in batch]

    mentioned_nodes = await get_mentioned_nodes(self.driver, episodes)

    return SearchResults(edges=flattened_edges, nodes=mentioned_nodes)
1168 |
async def add_triplet(
    self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
) -> AddTripletResults:
    """
    Add a single (source_node, edge, target_node) triplet directly to the graph.

    Missing embeddings are generated, the two endpoint nodes are deduplicated
    against the existing graph, and the edge is resolved against related and
    existing edges (which may invalidate contradicted edges) before
    everything is persisted.

    Returns
    -------
    AddTripletResults
        The resolved edge plus any invalidated edges, and the resolved nodes.
    """
    # Generate any missing embeddings up front.
    if source_node.name_embedding is None:
        await source_node.generate_name_embedding(self.embedder)
    if target_node.name_embedding is None:
        await target_node.generate_name_embedding(self.embedder)
    if edge.fact_embedding is None:
        await edge.generate_embedding(self.embedder)

    # Dedupe the endpoints against existing graph nodes; uuid_map maps the
    # provisional node uuids to the uuids of any matched existing nodes.
    nodes, uuid_map, _ = await resolve_extracted_nodes(
        self.clients,
        [source_node, target_node],
    )

    # Re-point the edge at the resolved node uuids.
    updated_edge = resolve_edge_pointers([edge], uuid_map)[0]

    # NOTE(review): this looks up edges between the ORIGINAL edge's node
    # uuids rather than updated_edge's resolved uuids — confirm intentional.
    valid_edges = await EntityEdge.get_between_nodes(
        self.driver, edge.source_node_uuid, edge.target_node_uuid
    )

    # Candidate duplicates: edges between the same two nodes, reranked by fact.
    related_edges = (
        await search(
            self.clients,
            updated_edge.fact,
            group_ids=[updated_edge.group_id],
            config=EDGE_HYBRID_SEARCH_RRF,
            search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
        )
    ).edges
    # Candidate contradictions: any edge in the group matching the fact.
    existing_edges = (
        await search(
            self.clients,
            updated_edge.fact,
            group_ids=[updated_edge.group_id],
            config=EDGE_HYBRID_SEARCH_RRF,
            search_filter=SearchFilters(),
        )
    ).edges

    # Resolve against candidates; a throwaway EpisodicNode supplies the
    # temporal context (valid_at) that edge resolution expects.
    resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
        self.llm_client,
        updated_edge,
        related_edges,
        existing_edges,
        EpisodicNode(
            name='',
            source=EpisodeType.text,
            source_description='',
            content='',
            valid_at=edge.valid_at or utc_now(),
            entity_edges=[],
            group_id=edge.group_id,
        ),
        None,
        None,
    )

    edges: list[EntityEdge] = [resolved_edge] + invalidated_edges

    # Ensure all persisted objects carry embeddings.
    await create_entity_edge_embeddings(self.embedder, edges)
    await create_entity_node_embeddings(self.embedder, nodes)

    await add_nodes_and_edges_bulk(self.driver, [], [], nodes, edges, self.embedder)
    return AddTripletResults(edges=edges, nodes=nodes)
1234 |
async def remove_episode(self, episode_uuid: str):
    """
    Delete an episode and any graph data that only it introduced.

    Removes entity edges whose first referencing episode is this one, entity
    nodes mentioned by no other episode, and finally the episode node itself.
    """
    # Find the episode to be deleted
    episode = await EpisodicNode.get_by_uuid(self.driver, episode_uuid)

    # Find edges mentioned by the episode
    edges = await EntityEdge.get_by_uuids(self.driver, episode.entity_edges)

    # We should only delete edges created by the episode
    # (i.e. edges whose first episode reference is this episode).
    edges_to_delete: list[EntityEdge] = []
    for edge in edges:
        if edge.episodes and edge.episodes[0] == episode.uuid:
            edges_to_delete.append(edge)

    # Find nodes mentioned by the episode
    nodes = await get_mentioned_nodes(self.driver, [episode])
    # We should delete all nodes that are only mentioned in the deleted episode
    nodes_to_delete: list[EntityNode] = []
    for node in nodes:
        # NOTE(review): one read query per mentioned node (N+1 pattern) —
        # acceptable for small episodes, revisit if episodes mention many entities.
        query: LiteralString = 'MATCH (e:Episodic)-[:MENTIONS]->(n:Entity {uuid: $uuid}) RETURN count(*) AS episode_count'
        records, _, _ = await self.driver.execute_query(query, uuid=node.uuid, routing_='r')

        for record in records:
            if record['episode_count'] == 1:
                nodes_to_delete.append(node)

    # Delete dependent edges and orphaned nodes before the episode itself.
    await Edge.delete_by_uuids(self.driver, [edge.uuid for edge in edges_to_delete])
    await Node.delete_by_uuids(self.driver, [node.uuid for node in nodes_to_delete])

    await episode.delete(self.driver)
1264 |
```
--------------------------------------------------------------------------------
/tests/test_graphiti_mock.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from datetime import datetime, timedelta
18 | from unittest.mock import Mock
19 |
20 | import numpy as np
21 | import pytest
22 |
23 | from graphiti_core.cross_encoder.client import CrossEncoderClient
24 | from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
25 | from graphiti_core.graphiti import Graphiti
26 | from graphiti_core.llm_client import LLMClient
27 | from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
28 | from graphiti_core.search.search_filters import ComparisonOperator, DateFilter, SearchFilters
29 | from graphiti_core.search.search_utils import (
30 | community_fulltext_search,
31 | community_similarity_search,
32 | edge_bfs_search,
33 | edge_fulltext_search,
34 | edge_similarity_search,
35 | episode_fulltext_search,
36 | episode_mentions_reranker,
37 | get_communities_by_nodes,
38 | get_edge_invalidation_candidates,
39 | get_embeddings_for_communities,
40 | get_embeddings_for_edges,
41 | get_embeddings_for_nodes,
42 | get_mentioned_nodes,
43 | get_relevant_edges,
44 | get_relevant_nodes,
45 | node_bfs_search,
46 | node_distance_reranker,
47 | node_fulltext_search,
48 | node_similarity_search,
49 | )
50 | from graphiti_core.utils.bulk_utils import add_nodes_and_edges_bulk
51 | from graphiti_core.utils.maintenance.community_operations import (
52 | determine_entity_community,
53 | get_community_clusters,
54 | remove_communities,
55 | )
56 | from graphiti_core.utils.maintenance.edge_operations import filter_existing_duplicate_of_edges
57 | from tests.helpers_test import (
58 | GraphProvider,
59 | assert_entity_edge_equals,
60 | assert_entity_node_equals,
61 | assert_episodic_edge_equals,
62 | assert_episodic_node_equals,
63 | get_edge_count,
64 | get_node_count,
65 | group_id,
66 | group_id_2,
67 | )
68 |
69 | pytest_plugins = ('pytest_asyncio',)
70 |
71 |
@pytest.fixture
def mock_llm_client():
    """Provide an LLMClient mock whose generate_response returns a canned
    entity-extraction tool-call payload."""
    client = Mock(spec=LLMClient)
    client.config = Mock()
    client.model = 'test-model'
    client.small_model = 'test-small-model'
    client.temperature = 0.0
    client.max_tokens = 1000
    client.cache_enabled = False
    client.cache_dir = None

    # Canned payload handed back by the public method the code under test calls.
    canned_response = {
        'tool_calls': [
            {
                'name': 'extract_entities',
                'arguments': {'entities': [{'entity': 'test_entity', 'entity_type': 'test_type'}]},
            }
        ]
    }
    client.generate_response = Mock(return_value=canned_response)

    return client
96 |
97 |
@pytest.fixture
def mock_cross_encoder_client():
    """Provide a CrossEncoderClient mock with a canned rerank() return value."""
    cross_encoder = Mock(spec=CrossEncoderClient)
    cross_encoder.config = Mock()

    # Canned payload returned by rerank().
    # NOTE(review): a real CrossEncoderClient.rerank would return ranked
    # passages, not a tool-call dict — confirm no test inspects this value.
    canned_response = {
        'tool_calls': [
            {
                'name': 'extract_entities',
                'arguments': {'entities': [{'entity': 'test_entity', 'entity_type': 'test_type'}]},
            }
        ]
    }
    cross_encoder.rerank = Mock(return_value=canned_response)

    return cross_encoder
116 |
117 |
@pytest.mark.asyncio
async def test_add_bulk(graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client):
    """Round-trip check of add_nodes_and_edges_bulk: persists two episodes,
    four entity nodes, four episodic edges and two entity edges, then reads
    every object back by uuid and asserts it equals what was written."""
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as test fails on FalkorDB')

    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )

    await graphiti.build_indices_and_constraints()

    now = datetime.now()

    # Create episodic nodes
    episode_node_1 = EpisodicNode(
        name='test_episode',
        group_id=group_id,
        labels=[],
        created_at=now,
        source=EpisodeType.message,
        source_description='conversation message',
        content='Alice likes Bob',
        valid_at=now,
        entity_edges=[],  # Filled in later
    )
    episode_node_2 = EpisodicNode(
        name='test_episode_2',
        group_id=group_id,
        labels=[],
        created_at=now,
        source=EpisodeType.message,
        source_description='conversation message',
        content='Bob adores Alice',
        valid_at=now,
        entity_edges=[],  # Filled in later
    )

    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        group_id=group_id,
        labels=['Entity', 'Person'],
        created_at=now,
        summary='test_entity_1 summary',
        attributes={'age': 30, 'location': 'New York'},
    )
    await entity_node_1.generate_name_embedding(mock_embedder)

    entity_node_2 = EntityNode(
        name='test_entity_2',
        group_id=group_id,
        labels=['Entity', 'Person2'],
        created_at=now,
        summary='test_entity_2 summary',
        attributes={'age': 25, 'location': 'Los Angeles'},
    )
    await entity_node_2.generate_name_embedding(mock_embedder)

    entity_node_3 = EntityNode(
        name='test_entity_3',
        group_id=group_id,
        labels=['Entity', 'City', 'Location'],
        created_at=now,
        summary='test_entity_3 summary',
        attributes={'age': 25, 'location': 'Los Angeles'},
    )
    await entity_node_3.generate_name_embedding(mock_embedder)

    entity_node_4 = EntityNode(
        name='test_entity_4',
        group_id=group_id,
        labels=['Entity'],
        created_at=now,
        summary='test_entity_4 summary',
        attributes={'age': 25, 'location': 'Los Angeles'},
    )
    await entity_node_4.generate_name_embedding(mock_embedder)

    # Create entity edges
    entity_edge_1 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        created_at=now,
        name='likes',
        fact='test_entity_1 relates to test_entity_2',
        episodes=[],
        expired_at=now,
        valid_at=now,
        invalid_at=now,
        group_id=group_id,
    )
    await entity_edge_1.generate_embedding(mock_embedder)

    entity_edge_2 = EntityEdge(
        source_node_uuid=entity_node_3.uuid,
        target_node_uuid=entity_node_4.uuid,
        created_at=now,
        name='relates_to',
        fact='test_entity_3 relates to test_entity_4',
        episodes=[],
        expired_at=now,
        valid_at=now,
        invalid_at=now,
        group_id=group_id,
    )
    await entity_edge_2.generate_embedding(mock_embedder)

    # Create episodic to entity edges
    episodic_edge_1 = EpisodicEdge(
        source_node_uuid=episode_node_1.uuid,
        target_node_uuid=entity_node_1.uuid,
        created_at=now,
        group_id=group_id,
    )
    episodic_edge_2 = EpisodicEdge(
        source_node_uuid=episode_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        created_at=now,
        group_id=group_id,
    )
    episodic_edge_3 = EpisodicEdge(
        source_node_uuid=episode_node_2.uuid,
        target_node_uuid=entity_node_3.uuid,
        created_at=now,
        group_id=group_id,
    )
    episodic_edge_4 = EpisodicEdge(
        source_node_uuid=episode_node_2.uuid,
        target_node_uuid=entity_node_4.uuid,
        created_at=now,
        group_id=group_id,
    )

    # Cross reference the ids
    episode_node_1.entity_edges = [entity_edge_1.uuid]
    episode_node_2.entity_edges = [entity_edge_2.uuid]
    entity_edge_1.episodes = [episode_node_1.uuid, episode_node_2.uuid]
    entity_edge_2.episodes = [episode_node_2.uuid]

    # Test add bulk
    await add_nodes_and_edges_bulk(
        graph_driver,
        [episode_node_1, episode_node_2],
        [episodic_edge_1, episodic_edge_2, episodic_edge_3, episodic_edge_4],
        [entity_node_1, entity_node_2, entity_node_3, entity_node_4],
        [entity_edge_1, entity_edge_2],
        mock_embedder,
    )

    # Verify everything written is countable by uuid
    node_ids = [
        episode_node_1.uuid,
        episode_node_2.uuid,
        entity_node_1.uuid,
        entity_node_2.uuid,
        entity_node_3.uuid,
        entity_node_4.uuid,
    ]
    edge_ids = [
        episodic_edge_1.uuid,
        episodic_edge_2.uuid,
        episodic_edge_3.uuid,
        episodic_edge_4.uuid,
        entity_edge_1.uuid,
        entity_edge_2.uuid,
    ]
    node_count = await get_node_count(graph_driver, node_ids)
    assert node_count == len(node_ids)
    edge_count = await get_edge_count(graph_driver, edge_ids)
    assert edge_count == len(edge_ids)

    # Test episodic nodes
    retrieved_episode = await EpisodicNode.get_by_uuid(graph_driver, episode_node_1.uuid)
    await assert_episodic_node_equals(retrieved_episode, episode_node_1)

    retrieved_episode = await EpisodicNode.get_by_uuid(graph_driver, episode_node_2.uuid)
    await assert_episodic_node_equals(retrieved_episode, episode_node_2)

    # Test entity nodes
    retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_1.uuid)
    await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_1)

    retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_2.uuid)
    await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_2)

    retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_3.uuid)
    await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_3)

    retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_4.uuid)
    await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_4)

    # Test episodic edges
    retrieved_episode_edge = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge_1.uuid)
    await assert_episodic_edge_equals(retrieved_episode_edge, episodic_edge_1)

    retrieved_episode_edge = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge_2.uuid)
    await assert_episodic_edge_equals(retrieved_episode_edge, episodic_edge_2)

    retrieved_episode_edge = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge_3.uuid)
    await assert_episodic_edge_equals(retrieved_episode_edge, episodic_edge_3)

    retrieved_episode_edge = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge_4.uuid)
    await assert_episodic_edge_equals(retrieved_episode_edge, episodic_edge_4)

    # Test entity edges
    retrieved_entity_edge = await EntityEdge.get_by_uuid(graph_driver, entity_edge_1.uuid)
    await assert_entity_edge_equals(graph_driver, retrieved_entity_edge, entity_edge_1)

    retrieved_entity_edge = await EntityEdge.get_by_uuid(graph_driver, entity_edge_2.uuid)
    await assert_entity_edge_equals(graph_driver, retrieved_entity_edge, entity_edge_2)
330 |
331 |
@pytest.mark.asyncio
async def test_remove_episode(
    graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client
):
    """Checks Graphiti.remove_episode: after bulk-saving one episode with two
    entity nodes and their edges, removing the episode deletes all of them
    (the nodes are mentioned only by this episode); the same data can then
    be bulk-saved again."""
    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )

    await graphiti.build_indices_and_constraints()

    now = datetime.now()

    # Create episodic nodes
    episode_node = EpisodicNode(
        name='test_episode',
        group_id=group_id,
        labels=[],
        created_at=now,
        source=EpisodeType.message,
        source_description='conversation message',
        content='Alice likes Bob',
        valid_at=now,
        entity_edges=[],  # Filled in later
    )

    # Create entity nodes
    alice_node = EntityNode(
        name='Alice',
        group_id=group_id,
        labels=['Entity', 'Person'],
        created_at=now,
        summary='Alice summary',
        attributes={'age': 30, 'location': 'New York'},
    )
    await alice_node.generate_name_embedding(mock_embedder)

    bob_node = EntityNode(
        name='Bob',
        group_id=group_id,
        labels=['Entity', 'Person2'],
        created_at=now,
        summary='Bob summary',
        attributes={'age': 25, 'location': 'Los Angeles'},
    )
    await bob_node.generate_name_embedding(mock_embedder)

    # Create entity to entity edge
    entity_edge = EntityEdge(
        source_node_uuid=alice_node.uuid,
        target_node_uuid=bob_node.uuid,
        created_at=now,
        name='likes',
        fact='Alice likes Bob',
        episodes=[],
        expired_at=now,
        valid_at=now,
        invalid_at=now,
        group_id=group_id,
    )
    await entity_edge.generate_embedding(mock_embedder)

    # Create episodic to entity edges
    episodic_alice_edge = EpisodicEdge(
        source_node_uuid=episode_node.uuid,
        target_node_uuid=alice_node.uuid,
        created_at=now,
        group_id=group_id,
    )
    episodic_bob_edge = EpisodicEdge(
        source_node_uuid=episode_node.uuid,
        target_node_uuid=bob_node.uuid,
        created_at=now,
        group_id=group_id,
    )

    # Cross reference the ids
    episode_node.entity_edges = [entity_edge.uuid]
    entity_edge.episodes = [episode_node.uuid]

    # Test add bulk
    await add_nodes_and_edges_bulk(
        graph_driver,
        [episode_node],
        [episodic_alice_edge, episodic_bob_edge],
        [alice_node, bob_node],
        [entity_edge],
        mock_embedder,
    )

    node_ids = [episode_node.uuid, alice_node.uuid, bob_node.uuid]
    edge_ids = [episodic_alice_edge.uuid, episodic_bob_edge.uuid, entity_edge.uuid]
    node_count = await get_node_count(graph_driver, node_ids)
    assert node_count == 3
    edge_count = await get_edge_count(graph_driver, edge_ids)
    assert edge_count == 3

    # Test remove episode: episode, orphaned nodes, and their edges all go
    await graphiti.remove_episode(episode_node.uuid)
    node_count = await get_node_count(graph_driver, node_ids)
    assert node_count == 0
    edge_count = await get_edge_count(graph_driver, edge_ids)
    assert edge_count == 0

    # Test add bulk again (re-insertion after removal must succeed)
    await add_nodes_and_edges_bulk(
        graph_driver,
        [episode_node],
        [episodic_alice_edge, episodic_bob_edge],
        [alice_node, bob_node],
        [entity_edge],
        mock_embedder,
    )
    node_count = await get_node_count(graph_driver, node_ids)
    assert node_count == 3
    edge_count = await get_edge_count(graph_driver, edge_ids)
    assert edge_count == 3
451 |
452 |
@pytest.mark.asyncio
async def test_graphiti_retrieve_episodes(
    graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client
):
    """retrieve_episodes should return only episodes valid at or before the
    reference time, ordered oldest first."""
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as test fails on FalkorDB')

    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )

    await graphiti.build_indices_and_constraints()

    now = datetime.now()

    # Three message episodes, valid 2, 4 and 6 days in the past respectively.
    episode_specs = [
        ('test_episode_1', 'Test message 1', 2),
        ('test_episode_2', 'Test message 2', 4),
        ('test_episode_3', 'Test message 3', 6),
    ]
    episode_nodes = [
        EpisodicNode(
            name=name,
            labels=[],
            created_at=now,
            valid_at=now - timedelta(days=days_ago),
            source=EpisodeType.message,
            source_description='conversation message',
            content=content,
            entity_edges=[],
            group_id=group_id,
        )
        for name, content, days_ago in episode_specs
    ]

    for episode in episode_nodes:
        await episode.save(graph_driver)

    node_ids = [episode.uuid for episode in episode_nodes]
    assert await get_node_count(graph_driver, node_ids) == 3

    # Only the two episodes valid more than three days ago should be returned,
    # oldest first.
    reference_time = now - timedelta(days=3)
    episodes = await graphiti.retrieve_episodes(
        reference_time, last_n=5, group_ids=[group_id], source=EpisodeType.message
    )
    assert len(episodes) == 2
    assert episodes[0].name == episode_nodes[2].name
    assert episodes[1].name == episode_nodes[1].name
526 |
527 |
@pytest.mark.asyncio
async def test_filter_existing_duplicate_of_edges(graph_driver, mock_embedder):
    """Candidate pairs already linked by an IS_DUPLICATE_OF edge must be
    filtered out; pairs without one are kept."""
    # Four entities, embedded and saved.
    entity_nodes = []
    for suffix in ('1', '2', '3', '4'):
        node = EntityNode(
            name='test_entity_' + suffix,
            labels=[],
            created_at=datetime.now(),
            group_id=group_id,
        )
        await node.generate_name_embedding(mock_embedder)
        entity_nodes.append(node)
    entity_node_1, entity_node_2, entity_node_3, entity_node_4 = entity_nodes

    for node in entity_nodes:
        await node.save(graph_driver)

    assert await get_node_count(graph_driver, [node.uuid for node in entity_nodes]) == 4

    # Mark entities 1 and 2 as already-known duplicates.
    duplicate_edge = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        name='IS_DUPLICATE_OF',
        fact='test_entity_1 is a duplicate of test_entity_2',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await duplicate_edge.generate_embedding(mock_embedder)
    await duplicate_edge.save(graph_driver)

    # Only the (3, 4) pair has no existing IS_DUPLICATE_OF edge.
    candidate_pairs = [
        (entity_node_1, entity_node_2),
        (entity_node_3, entity_node_4),
    ]
    remaining = await filter_existing_duplicate_of_edges(graph_driver, candidate_pairs)
    assert len(remaining) == 1
    assert [node.name for node in remaining[0]] == [entity_node_3.name, entity_node_4.name]
590 |
591 |
@pytest.mark.asyncio
async def test_determine_entity_community(graph_driver, mock_embedder):
    """determine_entity_community picks the community shared by most of a
    node's neighbours and reports whether the membership is new."""
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as test fails on FalkorDB')

    # Four entities; 1, 2 and 3 each relate to 4.
    entity_nodes = []
    for suffix in ('1', '2', '3', '4'):
        node = EntityNode(
            name='test_entity_' + suffix,
            labels=[],
            created_at=datetime.now(),
            group_id=group_id,
        )
        await node.generate_name_embedding(mock_embedder)
        entity_nodes.append(node)
    entity_node_1, entity_node_2, entity_node_3, entity_node_4 = entity_nodes

    entity_edges = []
    for source, fact in (
        (entity_node_1, 'test_entity_1 relates to test_entity_4'),
        (entity_node_2, 'test_entity_2 relates to test_entity_4'),
        (entity_node_3, 'test_entity_3 relates to test_entity_4'),
    ):
        edge = EntityEdge(
            source_node_uuid=source.uuid,
            target_node_uuid=entity_node_4.uuid,
            name='RELATES_TO',
            fact=fact,
            created_at=datetime.now(),
            group_id=group_id,
        )
        await edge.generate_embedding(mock_embedder)
        entity_edges.append(edge)

    # Community 1 holds entities 1 and 2; community 2 holds entity 3.
    community_nodes = []
    for name in ('test_community_1', 'test_community_2'):
        community = CommunityNode(
            name=name,
            labels=[],
            created_at=datetime.now(),
            group_id=group_id,
        )
        await community.generate_name_embedding(mock_embedder)
        community_nodes.append(community)
    community_node_1, community_node_2 = community_nodes

    community_edges = [
        CommunityEdge(
            source_node_uuid=community.uuid,
            target_node_uuid=member.uuid,
            created_at=datetime.now(),
            group_id=group_id,
        )
        for community, member in (
            (community_node_1, entity_node_1),
            (community_node_1, entity_node_2),
            (community_node_2, entity_node_3),
        )
    ]

    # Persist nodes first, then edges.
    for node in entity_nodes + community_nodes:
        await node.save(graph_driver)
    for edge in entity_edges + community_edges:
        await edge.save(graph_driver)

    node_ids = [node.uuid for node in entity_nodes + community_nodes]
    edge_ids = [edge.uuid for edge in entity_edges + community_edges]
    assert await get_node_count(graph_driver, node_ids) == 6
    assert await get_edge_count(graph_driver, edge_ids) == 6

    # Entity 4's neighbours are mostly in community 1, and entity 4 itself is
    # not yet a member, so the membership is reported as new.
    community, is_new = await determine_entity_community(graph_driver, entity_node_4)
    assert community.name == community_node_1.name
    assert is_new

    # After linking entity 4 into community 1 the membership is no longer new.
    membership_edge = CommunityEdge(
        source_node_uuid=community_node_1.uuid,
        target_node_uuid=entity_node_4.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )
    await membership_edge.save(graph_driver)

    community, is_new = await determine_entity_community(graph_driver, entity_node_4)
    assert community.name == community_node_1.name
    assert not is_new

    # remove_communities should delete every community node.
    await remove_communities(graph_driver)
    assert (
        await get_node_count(graph_driver, [community_node_1.uuid, community_node_2.uuid]) == 0
    )
750 |
751 |
@pytest.mark.asyncio
async def test_get_community_clusters(graph_driver, mock_embedder):
    """Two connected entity pairs in two different groups should produce two
    two-node clusters."""
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as test fails on FalkorDB')

    # Entities 1 and 2 live in group_id; entities 3 and 4 in group_id_2.
    entity_nodes = []
    for name, owner_group in (
        ('test_entity_1', group_id),
        ('test_entity_2', group_id),
        ('test_entity_3', group_id_2),
        ('test_entity_4', group_id_2),
    ):
        node = EntityNode(
            name=name,
            labels=[],
            created_at=datetime.now(),
            group_id=owner_group,
        )
        await node.generate_name_embedding(mock_embedder)
        entity_nodes.append(node)
    entity_node_1, entity_node_2, entity_node_3, entity_node_4 = entity_nodes

    # One RELATES_TO edge inside each group.
    entity_edges = []
    for source, target, fact, owner_group in (
        (entity_node_1, entity_node_2, 'test_entity_1 relates to test_entity_2', group_id),
        (entity_node_3, entity_node_4, 'test_entity_3 relates to test_entity_4', group_id_2),
    ):
        edge = EntityEdge(
            source_node_uuid=source.uuid,
            target_node_uuid=target.uuid,
            name='RELATES_TO',
            fact=fact,
            created_at=datetime.now(),
            group_id=owner_group,
        )
        await edge.generate_embedding(mock_embedder)
        entity_edges.append(edge)

    for node in entity_nodes:
        await node.save(graph_driver)
    for edge in entity_edges:
        await edge.save(graph_driver)

    assert await get_node_count(graph_driver, [node.uuid for node in entity_nodes]) == 4
    assert await get_edge_count(graph_driver, [edge.uuid for edge in entity_edges]) == 2

    # Expect one two-node cluster per group, in either order.
    clusters = await get_community_clusters(graph_driver, group_ids=None)
    assert len(clusters) == 2
    assert all(len(cluster) == 2 for cluster in clusters)
    cluster_names = [{node.name for node in cluster} for cluster in clusters]
    assert {'test_entity_1', 'test_entity_2'} in cluster_names
    assert {'test_entity_3', 'test_entity_4'} in cluster_names
835 |
836 |
@pytest.mark.asyncio
async def test_get_mentioned_nodes(graph_driver, mock_embedder):
    """Entities connected to an episode via an episodic edge count as
    mentioned by that episode."""
    episode = EpisodicNode(
        name='test_episodic_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='test_source_description',
        content='test_content',
        valid_at=datetime.now(),
    )

    entity = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity.generate_name_embedding(mock_embedder)

    # Episode -> entity mention edge.
    mention_edge = EpisodicEdge(
        source_node_uuid=episode.uuid,
        target_node_uuid=entity.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )

    for item in (episode, entity, mention_edge):
        await item.save(graph_driver)

    mentioned_nodes = await get_mentioned_nodes(graph_driver, [episode])
    assert len(mentioned_nodes) == 1
    assert mentioned_nodes[0].name == entity.name
876 |
877 |
@pytest.mark.asyncio
async def test_get_communities_by_nodes(graph_driver, mock_embedder):
    """An entity linked to a community by a community edge should be reported
    as belonging to that community."""
    entity = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity.generate_name_embedding(mock_embedder)

    community = CommunityNode(
        name='test_community_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await community.generate_name_embedding(mock_embedder)

    # Community -> entity membership edge.
    membership_edge = CommunityEdge(
        source_node_uuid=community.uuid,
        target_node_uuid=entity.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )

    for item in (entity, community, membership_edge):
        await item.save(graph_driver)

    communities = await get_communities_by_nodes(graph_driver, [entity])
    assert len(communities) == 1
    assert communities[0].name == community.name
915 |
916 |
@pytest.mark.asyncio
async def test_edge_fulltext_search(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """Fulltext edge search should match the fact text while honouring the
    label, edge-type and temporal filters."""
    if graph_driver.provider == GraphProvider.KUZU:
        pytest.skip('Skipping as fulltext indexing not supported for Kuzu')

    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()

    source_node = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await source_node.generate_name_embedding(mock_embedder)
    target_node = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await target_node.generate_name_embedding(mock_embedder)

    now = datetime.now()
    created_at = now
    valid_at = now + timedelta(days=2)
    invalid_at = now + timedelta(days=4)
    expired_at = now + timedelta(days=6)

    relates_edge = EntityEdge(
        source_node_uuid=source_node.uuid,
        target_node_uuid=target_node.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=created_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        expired_at=expired_at,
        group_id=group_id,
    )
    await relates_edge.generate_embedding(mock_embedder)

    for item in (source_node, target_node, relates_edge):
        await item.save(graph_driver)

    def date_clause(date, operator):
        # One single-condition DateFilter list for SearchFilters.
        return [DateFilter(date=date, comparison_operator=operator)]

    # Filters chosen so the saved edge's temporal fields all satisfy them.
    search_filters = SearchFilters(
        node_labels=['Entity'],
        edge_types=['RELATES_TO'],
        created_at=[date_clause(created_at, ComparisonOperator.equals)],
        expired_at=[date_clause(now, ComparisonOperator.not_equals)],
        valid_at=[
            date_clause(now + timedelta(days=1), ComparisonOperator.greater_than_equal),
            date_clause(now + timedelta(days=3), ComparisonOperator.less_than_equal),
        ],
        invalid_at=[
            date_clause(now + timedelta(days=3), ComparisonOperator.greater_than),
            date_clause(now + timedelta(days=5), ComparisonOperator.less_than),
        ],
    )
    edges = await edge_fulltext_search(
        graph_driver, 'test_entity_1 relates to test_entity_2', search_filters, group_ids=[group_id]
    )
    assert len(edges) == 1
    assert edges[0].name == relates_edge.name
1016 |
1017 |
@pytest.mark.asyncio
async def test_edge_similarity_search(graph_driver, mock_embedder):
    """Similarity edge search on the edge's own embedding should return that
    edge while honouring the label, edge-type and temporal filters."""
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')

    source_node = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await source_node.generate_name_embedding(mock_embedder)
    target_node = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await target_node.generate_name_embedding(mock_embedder)

    now = datetime.now()
    created_at = now
    valid_at = now + timedelta(days=2)
    invalid_at = now + timedelta(days=4)
    expired_at = now + timedelta(days=6)

    relates_edge = EntityEdge(
        source_node_uuid=source_node.uuid,
        target_node_uuid=target_node.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=created_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        expired_at=expired_at,
        group_id=group_id,
    )
    await relates_edge.generate_embedding(mock_embedder)

    for item in (source_node, target_node, relates_edge):
        await item.save(graph_driver)

    def date_clause(date, operator):
        # One single-condition DateFilter list for SearchFilters.
        return [DateFilter(date=date, comparison_operator=operator)]

    # Filters chosen so the saved edge's temporal fields all satisfy them.
    search_filters = SearchFilters(
        node_labels=['Entity'],
        edge_types=['RELATES_TO'],
        created_at=[date_clause(created_at, ComparisonOperator.equals)],
        expired_at=[date_clause(now, ComparisonOperator.not_equals)],
        valid_at=[
            date_clause(now + timedelta(days=1), ComparisonOperator.greater_than_equal),
            date_clause(now + timedelta(days=3), ComparisonOperator.less_than_equal),
        ],
        invalid_at=[
            date_clause(now + timedelta(days=3), ComparisonOperator.greater_than),
            date_clause(now + timedelta(days=5), ComparisonOperator.less_than),
        ],
    )
    edges = await edge_similarity_search(
        graph_driver,
        relates_edge.fact_embedding,
        source_node.uuid,
        target_node.uuid,
        search_filters,
        group_ids=[group_id],
    )
    assert len(edges) == 1
    assert edges[0].name == relates_edge.name
1112 |
1113 |
@pytest.mark.asyncio
async def test_edge_bfs_search(graph_driver, mock_embedder):
    """BFS edge search returns the entity edges reachable within a depth
    limit, starting from either an episodic node or an entity node."""
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')

    # Chain: episode -> entity_1 -> entity_2 -> entity_3.
    episodic_node_1 = EpisodicNode(
        name='test_episodic_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='test_source_description',
        content='test_content',
        valid_at=datetime.now(),
    )

    entity_nodes = []
    for suffix in ('1', '2', '3'):
        node = EntityNode(
            name='test_entity_' + suffix,
            labels=[],
            created_at=datetime.now(),
            group_id=group_id,
        )
        await node.generate_name_embedding(mock_embedder)
        entity_nodes.append(node)
    entity_node_1, entity_node_2, entity_node_3 = entity_nodes

    now = datetime.now()
    created_at = now
    valid_at = now + timedelta(days=2)
    invalid_at = now + timedelta(days=4)
    expired_at = now + timedelta(days=6)

    entity_edges = []
    for source, target, fact in (
        (entity_node_1, entity_node_2, 'test_entity_1 relates to test_entity_2'),
        (entity_node_2, entity_node_3, 'test_entity_2 relates to test_entity_3'),
    ):
        edge = EntityEdge(
            source_node_uuid=source.uuid,
            target_node_uuid=target.uuid,
            name='RELATES_TO',
            fact=fact,
            created_at=created_at,
            valid_at=valid_at,
            invalid_at=invalid_at,
            expired_at=expired_at,
            group_id=group_id,
        )
        await edge.generate_embedding(mock_embedder)
        entity_edges.append(edge)

    episodic_edge_1 = EpisodicEdge(
        source_node_uuid=episodic_node_1.uuid,
        target_node_uuid=entity_node_1.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )

    for item in [episodic_node_1] + entity_nodes + entity_edges + [episodic_edge_1]:
        await item.save(graph_driver)

    def date_clause(date, operator):
        # One single-condition DateFilter list for SearchFilters.
        return [DateFilter(date=date, comparison_operator=operator)]

    # Filters chosen so both saved entity edges satisfy them.
    search_filters = SearchFilters(
        node_labels=['Entity'],
        edge_types=['RELATES_TO'],
        created_at=[date_clause(created_at, ComparisonOperator.equals)],
        expired_at=[date_clause(now, ComparisonOperator.not_equals)],
        valid_at=[
            date_clause(now + timedelta(days=1), ComparisonOperator.greater_than_equal),
            date_clause(now + timedelta(days=3), ComparisonOperator.less_than_equal),
        ],
        invalid_at=[
            date_clause(now + timedelta(days=3), ComparisonOperator.greater_than),
            date_clause(now + timedelta(days=5), ComparisonOperator.less_than),
        ],
    )

    async def reachable_facts(origin_uuid, depth):
        # Run BFS from a single origin, deduplicate edges by uuid, and
        # return the distinct facts found.
        edges = await edge_bfs_search(
            graph_driver,
            [origin_uuid],
            depth,
            search_filters,
            group_ids=[group_id],
        )
        return set({edge.uuid: edge.fact for edge in edges}.values())

    # From the episodic node: entity edges only appear from depth 2 onwards.
    assert await reachable_facts(episodic_node_1.uuid, 1) == set()
    assert await reachable_facts(episodic_node_1.uuid, 2) == {
        'test_entity_1 relates to test_entity_2'
    }
    assert await reachable_facts(episodic_node_1.uuid, 3) == {
        'test_entity_1 relates to test_entity_2',
        'test_entity_2 relates to test_entity_3',
    }

    # From the first entity node: one additional edge per extra hop.
    assert await reachable_facts(entity_node_1.uuid, 1) == {
        'test_entity_1 relates to test_entity_2'
    }
    assert await reachable_facts(entity_node_1.uuid, 2) == {
        'test_entity_1 relates to test_entity_2',
        'test_entity_2 relates to test_entity_3',
    }
1304 |
1305 |
@pytest.mark.asyncio
async def test_node_fulltext_search(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """Fulltext node search for 'Alice' should return only the node whose
    summary mentions Alice."""
    if graph_driver.provider == GraphProvider.KUZU:
        pytest.skip('Skipping as fulltext indexing not supported for Kuzu')

    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()

    # Two entities whose summaries mention different people.
    alice_entity = EntityNode(
        name='test_entity_1',
        summary='Summary about Alice',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    bob_entity = EntityNode(
        name='test_entity_2',
        summary='Summary about Bob',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    for entity in (alice_entity, bob_entity):
        await entity.generate_name_embedding(mock_embedder)
        await entity.save(graph_driver)

    search_filters = SearchFilters(node_labels=['Entity'])
    nodes = await node_fulltext_search(
        graph_driver,
        'Alice',
        search_filters,
        group_ids=[group_id],
    )
    assert len(nodes) == 1
    assert nodes[0].name == alice_entity.name
1353 |
1354 |
@pytest.mark.asyncio
async def test_node_similarity_search(graph_driver, mock_embedder):
    """Similarity node search with a high min_score should return only the
    node whose embedding was used as the query."""
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')

    alice_entity = EntityNode(
        name='test_entity_alice',
        summary='Summary about Alice',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    bob_entity = EntityNode(
        name='test_entity_bob',
        summary='Summary about Bob',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    for entity in (alice_entity, bob_entity):
        await entity.generate_name_embedding(mock_embedder)
        await entity.save(graph_driver)

    search_filters = SearchFilters(node_labels=['Entity'])
    nodes = await node_similarity_search(
        graph_driver,
        alice_entity.name_embedding,
        search_filters,
        group_ids=[group_id],
        min_score=0.9,
    )
    assert len(nodes) == 1
    assert nodes[0].name == alice_entity.name
1393 |
1394 |
@pytest.mark.asyncio
async def test_node_bfs_search(graph_driver, mock_embedder):
    """BFS node search honors its depth limit from both episodic and entity roots.

    Graph shape: episodic_1 -> entity_1 -> entity_2 -> entity_3. A depth-1
    walk from the episode reaches only entity_1; depth 2 adds entity_2. A
    depth-1 walk starting at entity_1 reaches entity_2.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')

    def reached_names(nodes):
        # Distinct names returned by a search. The uuid-keyed dict previously
        # built here was redundant: a set of names already deduplicates.
        return {node.name for node in nodes}

    # Create episodic nodes
    episodic_node_1 = EpisodicNode(
        name='test_episodic_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='test_source_description',
        content='test_content',
        valid_at=datetime.now(),
    )

    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    entity_node_3 = EntityNode(
        name='test_entity_3',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_3.generate_name_embedding(mock_embedder)

    # Create entity edges forming the chain entity_1 -> entity_2 -> entity_3
    entity_edge_1 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_edge_1.generate_embedding(mock_embedder)
    entity_edge_2 = EntityEdge(
        source_node_uuid=entity_node_2.uuid,
        target_node_uuid=entity_node_3.uuid,
        name='RELATES_TO',
        fact='test_entity_2 relates to test_entity_3',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_edge_2.generate_embedding(mock_embedder)

    # Attach the episode to the head of the chain
    episodic_edge_1 = EpisodicEdge(
        source_node_uuid=episodic_node_1.uuid,
        target_node_uuid=entity_node_1.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )

    # Save the graph
    await episodic_node_1.save(graph_driver)
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_node_3.save(graph_driver)
    await entity_edge_1.save(graph_driver)
    await entity_edge_2.save(graph_driver)
    await episodic_edge_1.save(graph_driver)

    search_filters = SearchFilters(
        node_labels=['Entity'],
    )

    # Depth 1 from the episode reaches only the directly mentioned entity.
    nodes = await node_bfs_search(
        graph_driver,
        [episodic_node_1.uuid],
        search_filters,
        1,
        group_ids=[group_id],
    )
    assert reached_names(nodes) == {'test_entity_1'}

    # Depth 2 additionally reaches the next entity along the chain.
    nodes = await node_bfs_search(
        graph_driver,
        [episodic_node_1.uuid],
        search_filters,
        2,
        group_ids=[group_id],
    )
    assert reached_names(nodes) == {'test_entity_1', 'test_entity_2'}

    # Depth 1 from entity_1 reaches its direct neighbor only.
    nodes = await node_bfs_search(
        graph_driver,
        [entity_node_1.uuid],
        search_filters,
        1,
        group_ids=[group_id],
    )
    assert reached_names(nodes) == {'test_entity_2'}
1513 |
1514 |
@pytest.mark.asyncio
async def test_episode_fulltext_search(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """Fulltext search over episodes matches on the source description."""
    if graph_driver.provider == GraphProvider.KUZU:
        pytest.skip('Skipping as fulltext indexing not supported for Kuzu')

    # Fulltext indices must exist before querying.
    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()

    episode_alice = EpisodicNode(
        name='test_episodic_1',
        content='test_content',
        created_at=datetime.now(),
        valid_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='Description about Alice',
    )
    episode_bob = EpisodicNode(
        name='test_episodic_2',
        content='test_content_2',
        created_at=datetime.now(),
        valid_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='Description about Bob',
    )
    for episode in (episode_alice, episode_bob):
        await episode.save(graph_driver)

    # Only the episode whose description mentions Alice should match.
    results = await episode_fulltext_search(
        graph_driver,
        'Alice',
        SearchFilters(node_labels=['Episodic']),
        group_ids=[group_id],
    )
    assert [episode.name for episode in results] == [episode_alice.name]
1564 |
1565 |
@pytest.mark.asyncio
async def test_community_fulltext_search(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """Fulltext search over community names returns only matching communities."""
    if graph_driver.provider == GraphProvider.KUZU:
        pytest.skip('Skipping as fulltext indexing not supported for Kuzu')

    # Fulltext indices must exist before querying.
    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()

    alice = CommunityNode(
        name='Alice',
        created_at=datetime.now(),
        group_id=group_id,
    )
    bob = CommunityNode(
        name='Bob',
        created_at=datetime.now(),
        group_id=group_id,
    )
    for community in (alice, bob):
        await community.generate_name_embedding(mock_embedder)
        await community.save(graph_driver)

    # Only the community named 'Alice' should match the query term.
    results = await community_fulltext_search(
        graph_driver,
        'Alice',
        group_ids=[group_id],
    )
    assert [community.name for community in results] == [alice.name]
1607 |
1608 |
@pytest.mark.asyncio
async def test_community_similarity_search(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """Similarity search over community name embeddings excludes distant names."""
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')

    # Indices must exist before searching.
    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()

    alice = CommunityNode(
        name='Alice',
        created_at=datetime.now(),
        group_id=group_id,
    )
    bob = CommunityNode(
        name='Bob',
        created_at=datetime.now(),
        group_id=group_id,
    )
    for community in (alice, bob):
        await community.generate_name_embedding(mock_embedder)
        await community.save(graph_driver)

    # Querying with Alice's own embedding should clear the 0.9 cutoff only for Alice.
    results = await community_similarity_search(
        graph_driver,
        alice.name_embedding,
        group_ids=[group_id],
        min_score=0.9,
    )
    assert [community.name for community in results] == [alice.name]
1651 |
1652 |
@pytest.mark.asyncio
async def test_get_relevant_nodes(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """get_relevant_nodes surfaces nodes similar to the probe node.

    The test expects 'Alice' and 'Alice Smith' to embed close together under
    the mock embedder, while 'Bob' falls below the 0.9 similarity cutoff.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')

    if graph_driver.provider == GraphProvider.KUZU:
        pytest.skip('Skipping as tests fail on Kuzu')

    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()

    # Create entity nodes
    entity_node_1 = EntityNode(
        name='Alice',
        summary='Alice',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='Bob',
        summary='Bob',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    entity_node_3 = EntityNode(
        name='Alice Smith',
        summary='Alice Smith',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_3.generate_name_embedding(mock_embedder)

    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_node_3.save(graph_driver)

    # Search for nodes relevant to the 'Alice' probe node.
    search_filters = SearchFilters(node_labels=['Entity'])
    nodes = (
        await get_relevant_nodes(
            graph_driver,
            [entity_node_1],
            search_filters,
            min_score=0.9,
        )
    )[0]
    assert len(nodes) == 2
    # Fixed: set() wrapped around a set comprehension was redundant.
    assert {node.name for node in nodes} == {entity_node_1.name, entity_node_3.name}
1714 |
1715 |
@pytest.mark.asyncio
async def test_get_relevant_edges_and_invalidation_candidates(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    """Relevant-edge lookup and invalidation candidates honor the date filters.

    All three edges share one temporal window (created now, valid now+2d,
    invalid now+4d, expired now+6d), and every date filter below brackets
    that window, so filtering removes nothing. The result sets are then
    driven by fact similarity: edges 1 and 3 share the fact 'Alice' while
    edge 2 has 'Bob'.
    """
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')

    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()

    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        summary='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        summary='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    entity_node_3 = EntityNode(
        name='test_entity_3',
        summary='test_entity_3',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_3.generate_name_embedding(mock_embedder)

    # One shared temporal window for all three edges.
    now = datetime.now()
    created_at = now
    expired_at = now + timedelta(days=6)
    valid_at = now + timedelta(days=2)
    invalid_at = now + timedelta(days=4)

    # Create entity edges
    entity_edge_1 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        name='RELATES_TO',
        fact='Alice',
        created_at=created_at,
        expired_at=expired_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        group_id=group_id,
    )
    await entity_edge_1.generate_embedding(mock_embedder)
    entity_edge_2 = EntityEdge(
        source_node_uuid=entity_node_2.uuid,
        target_node_uuid=entity_node_3.uuid,
        name='RELATES_TO',
        fact='Bob',
        created_at=created_at,
        expired_at=expired_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        group_id=group_id,
    )
    await entity_edge_2.generate_embedding(mock_embedder)
    entity_edge_3 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_3.uuid,
        name='RELATES_TO',
        fact='Alice',
        created_at=created_at,
        expired_at=expired_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        group_id=group_id,
    )
    await entity_edge_3.generate_embedding(mock_embedder)

    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_node_3.save(graph_driver)
    await entity_edge_1.save(graph_driver)
    await entity_edge_2.save(graph_driver)
    await entity_edge_3.save(graph_driver)

    # Date filters chosen so every edge passes each constraint:
    # created_at == now, expired_at != now, now+1d <= valid_at <= now+3d,
    # now+3d < invalid_at < now+5d.
    search_filters = SearchFilters(
        node_labels=['Entity'],
        edge_types=['RELATES_TO'],
        created_at=[
            [DateFilter(date=created_at, comparison_operator=ComparisonOperator.equals)],
        ],
        expired_at=[
            [DateFilter(date=now, comparison_operator=ComparisonOperator.not_equals)],
        ],
        valid_at=[
            [
                DateFilter(
                    date=now + timedelta(days=1),
                    comparison_operator=ComparisonOperator.greater_than_equal,
                )
            ],
            [
                DateFilter(
                    date=now + timedelta(days=3),
                    comparison_operator=ComparisonOperator.less_than_equal,
                )
            ],
        ],
        invalid_at=[
            [
                DateFilter(
                    date=now + timedelta(days=3),
                    comparison_operator=ComparisonOperator.greater_than,
                )
            ],
            [
                DateFilter(
                    date=now + timedelta(days=5), comparison_operator=ComparisonOperator.less_than
                )
            ],
        ],
    )

    # Only the probe edge itself is close enough to count as relevant.
    edges = (
        await get_relevant_edges(
            graph_driver,
            [entity_edge_1],
            search_filters,
            min_score=0.9,
        )
    )[0]
    assert len(edges) == 1
    # Fixed: set() wrapped around a set comprehension was redundant.
    assert {edge.name for edge in edges} == {entity_edge_1.name}

    # Invalidation candidates also include the other edge sharing the 'Alice' fact.
    edges = (
        await get_edge_invalidation_candidates(
            graph_driver,
            [entity_edge_1],
            search_filters,
            min_score=0.9,
        )
    )[0]
    assert len(edges) == 2
    assert {edge.name for edge in edges} == {entity_edge_1.name, entity_edge_3.name}
1868 |
1869 |
@pytest.mark.asyncio
async def test_node_distance_reranker(graph_driver, mock_embedder):
    """Nodes directly connected to the center node rank ahead of unconnected ones."""
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')

    # Graph: center -- RELATES_TO --> neighbor; stray has no connection.
    center = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    neighbor = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    stray = EntityNode(
        name='test_entity_3',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    for entity in (center, neighbor, stray):
        await entity.generate_name_embedding(mock_embedder)
        await entity.save(graph_driver)

    edge = EntityEdge(
        source_node_uuid=center.uuid,
        target_node_uuid=neighbor.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await edge.generate_embedding(mock_embedder)
    await edge.save(graph_driver)

    # Rerank the two candidates by graph distance from the center node.
    reranked_uuids, reranked_scores = await node_distance_reranker(
        graph_driver,
        [neighbor.uuid, stray.uuid],
        center.uuid,
    )
    lookup = {
        center.uuid: center.name,
        neighbor.uuid: neighbor.name,
        stray.uuid: stray.name,
    }
    assert [lookup[uuid] for uuid in reranked_uuids] == [neighbor.name, stray.name]
    # The direct neighbor scores 1.0; the unreachable node scores 0.0.
    assert np.allclose(reranked_scores, [1.0, 0.0])
1929 |
1930 |
@pytest.mark.asyncio
async def test_episode_mentions_reranker(graph_driver, mock_embedder):
    """Entities mentioned by an episode rank above entities that are not."""
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on Falkordb')

    episode = EpisodicNode(
        name='test_episodic_1',
        content='test_content',
        created_at=datetime.now(),
        valid_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='Description about Alice',
    )

    mentioned = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    unmentioned = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    for entity in (mentioned, unmentioned):
        await entity.generate_name_embedding(mock_embedder)
        await entity.save(graph_driver)

    # Only the first entity is linked to the episode.
    mention_edge = EpisodicEdge(
        source_node_uuid=episode.uuid,
        target_node_uuid=mentioned.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )
    await episode.save(graph_driver)
    await mention_edge.save(graph_driver)

    reranked_uuids, reranked_scores = await episode_mentions_reranker(
        graph_driver,
        [[mentioned.uuid, unmentioned.uuid]],
    )
    lookup = {mentioned.uuid: mentioned.name, unmentioned.uuid: unmentioned.name}
    assert [lookup[uuid] for uuid in reranked_uuids] == [mentioned.name, unmentioned.name]
    # One mention scores 1.0; the unmentioned node gets an infinite sentinel.
    # np.allclose treats same-signed infinities as equal, so this comparison works.
    assert np.allclose(reranked_scores, [1.0, float('inf')])
1986 |
1987 |
@pytest.mark.asyncio
async def test_get_embeddings_for_edges(graph_driver, mock_embedder):
    """Stored fact embeddings can be fetched back for a list of edges."""
    source = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    target = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    for entity in (source, target):
        await entity.generate_name_embedding(mock_embedder)
        await entity.save(graph_driver)

    edge = EntityEdge(
        source_node_uuid=source.uuid,
        target_node_uuid=target.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await edge.generate_embedding(mock_embedder)
    await edge.save(graph_driver)

    # The fetched embedding must round-trip the one generated above.
    embeddings = await get_embeddings_for_edges(graph_driver, [edge])
    assert set(embeddings) == {edge.uuid}
    assert np.allclose(embeddings[edge.uuid], edge.fact_embedding)
2027 |
2028 |
@pytest.mark.asyncio
async def test_get_embeddings_for_nodes(graph_driver, mock_embedder):
    """Stored name embeddings can be fetched back for a list of entity nodes."""
    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)

    # Save the graph
    await entity_node_1.save(graph_driver)

    # Get embeddings for nodes (original comment said "edges" — copy-paste slip)
    # and check the fetched embedding round-trips the one generated above.
    embeddings = await get_embeddings_for_nodes(graph_driver, [entity_node_1])
    assert len(embeddings) == 1
    assert entity_node_1.uuid in embeddings
    assert np.allclose(embeddings[entity_node_1.uuid], entity_node_1.name_embedding)
2048 |
2049 |
@pytest.mark.asyncio
async def test_get_embeddings_for_communities(graph_driver, mock_embedder):
    """Stored name embeddings can be fetched back for community nodes."""
    community = CommunityNode(
        name='test_community_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await community.generate_name_embedding(mock_embedder)
    await community.save(graph_driver)

    # The fetched embedding must round-trip the one generated above.
    embeddings = await get_embeddings_for_communities(graph_driver, [community])
    assert set(embeddings) == {community.uuid}
    assert np.allclose(embeddings[community.uuid], community.name_embedding)
2069 |
```