# getzep/graphiti — codebase listing (page 8 of 12)
tokens: 41528/50000 6/234 files (page 8/12)
lines: on (toggle) GitHub
raw markdown copy reset
This is page 8 of 12. To view the rest of the codebase, replace `{x}` with a page number from 1 to 12 in http://codebase.md/getzep/graphiti?lines=true&page={x}.

# Directory Structure

```
├── .env.example
├── .github
│   ├── dependabot.yml
│   ├── ISSUE_TEMPLATE
│   │   └── bug_report.md
│   ├── pull_request_template.md
│   ├── secret_scanning.yml
│   └── workflows
│       ├── ai-moderator.yml
│       ├── cla.yml
│       ├── claude-code-review-manual.yml
│       ├── claude-code-review.yml
│       ├── claude.yml
│       ├── codeql.yml
│       ├── daily_issue_maintenance.yml
│       ├── issue-triage.yml
│       ├── lint.yml
│       ├── release-graphiti-core.yml
│       ├── release-mcp-server.yml
│       ├── release-server-container.yml
│       ├── typecheck.yml
│       └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│   ├── azure-openai
│   │   ├── .env.example
│   │   ├── azure_openai_neo4j.py
│   │   └── README.md
│   ├── data
│   │   └── manybirds_products.json
│   ├── ecommerce
│   │   ├── runner.ipynb
│   │   └── runner.py
│   ├── langgraph-agent
│   │   ├── agent.ipynb
│   │   └── tinybirds-jess.png
│   ├── opentelemetry
│   │   ├── .env.example
│   │   ├── otel_stdout_example.py
│   │   ├── pyproject.toml
│   │   ├── README.md
│   │   └── uv.lock
│   ├── podcast
│   │   ├── podcast_runner.py
│   │   ├── podcast_transcript.txt
│   │   └── transcript_parser.py
│   ├── quickstart
│   │   ├── quickstart_falkordb.py
│   │   ├── quickstart_neo4j.py
│   │   ├── quickstart_neptune.py
│   │   ├── README.md
│   │   └── requirements.txt
│   └── wizard_of_oz
│       ├── parser.py
│       ├── runner.py
│       └── woo.txt
├── graphiti_core
│   ├── __init__.py
│   ├── cross_encoder
│   │   ├── __init__.py
│   │   ├── bge_reranker_client.py
│   │   ├── client.py
│   │   ├── gemini_reranker_client.py
│   │   └── openai_reranker_client.py
│   ├── decorators.py
│   ├── driver
│   │   ├── __init__.py
│   │   ├── driver.py
│   │   ├── falkordb_driver.py
│   │   ├── graph_operations
│   │   │   └── graph_operations.py
│   │   ├── kuzu_driver.py
│   │   ├── neo4j_driver.py
│   │   ├── neptune_driver.py
│   │   └── search_interface
│   │       └── search_interface.py
│   ├── edges.py
│   ├── embedder
│   │   ├── __init__.py
│   │   ├── azure_openai.py
│   │   ├── client.py
│   │   ├── gemini.py
│   │   ├── openai.py
│   │   └── voyage.py
│   ├── errors.py
│   ├── graph_queries.py
│   ├── graphiti_types.py
│   ├── graphiti.py
│   ├── helpers.py
│   ├── llm_client
│   │   ├── __init__.py
│   │   ├── anthropic_client.py
│   │   ├── azure_openai_client.py
│   │   ├── client.py
│   │   ├── config.py
│   │   ├── errors.py
│   │   ├── gemini_client.py
│   │   ├── groq_client.py
│   │   ├── openai_base_client.py
│   │   ├── openai_client.py
│   │   ├── openai_generic_client.py
│   │   └── utils.py
│   ├── migrations
│   │   └── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── edges
│   │   │   ├── __init__.py
│   │   │   └── edge_db_queries.py
│   │   └── nodes
│   │       ├── __init__.py
│   │       └── node_db_queries.py
│   ├── nodes.py
│   ├── prompts
│   │   ├── __init__.py
│   │   ├── dedupe_edges.py
│   │   ├── dedupe_nodes.py
│   │   ├── eval.py
│   │   ├── extract_edge_dates.py
│   │   ├── extract_edges.py
│   │   ├── extract_nodes.py
│   │   ├── invalidate_edges.py
│   │   ├── lib.py
│   │   ├── models.py
│   │   ├── prompt_helpers.py
│   │   ├── snippets.py
│   │   └── summarize_nodes.py
│   ├── py.typed
│   ├── search
│   │   ├── __init__.py
│   │   ├── search_config_recipes.py
│   │   ├── search_config.py
│   │   ├── search_filters.py
│   │   ├── search_helpers.py
│   │   ├── search_utils.py
│   │   └── search.py
│   ├── telemetry
│   │   ├── __init__.py
│   │   └── telemetry.py
│   ├── tracer.py
│   └── utils
│       ├── __init__.py
│       ├── bulk_utils.py
│       ├── datetime_utils.py
│       ├── maintenance
│       │   ├── __init__.py
│       │   ├── community_operations.py
│       │   ├── dedup_helpers.py
│       │   ├── edge_operations.py
│       │   ├── graph_data_operations.py
│       │   ├── node_operations.py
│       │   └── temporal_operations.py
│       ├── ontology_utils
│       │   └── entity_types_utils.py
│       └── text_utils.py
├── images
│   ├── arxiv-screenshot.png
│   ├── graphiti-graph-intro.gif
│   ├── graphiti-intro-slides-stock-2.gif
│   └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│   ├── .env.example
│   ├── .python-version
│   ├── config
│   │   ├── config-docker-falkordb-combined.yaml
│   │   ├── config-docker-falkordb.yaml
│   │   ├── config-docker-neo4j.yaml
│   │   ├── config.yaml
│   │   └── mcp_config_stdio_example.json
│   ├── docker
│   │   ├── build-standalone.sh
│   │   ├── build-with-version.sh
│   │   ├── docker-compose-falkordb.yml
│   │   ├── docker-compose-neo4j.yml
│   │   ├── docker-compose.yml
│   │   ├── Dockerfile
│   │   ├── Dockerfile.standalone
│   │   ├── github-actions-example.yml
│   │   ├── README-falkordb-combined.md
│   │   └── README.md
│   ├── docs
│   │   └── cursor_rules.md
│   ├── main.py
│   ├── pyproject.toml
│   ├── pytest.ini
│   ├── README.md
│   ├── src
│   │   ├── __init__.py
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   └── schema.py
│   │   ├── graphiti_mcp_server.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── entity_types.py
│   │   │   └── response_types.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   ├── factories.py
│   │   │   └── queue_service.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── formatting.py
│   │       └── utils.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── conftest.py
│   │   ├── pytest.ini
│   │   ├── README.md
│   │   ├── run_tests.py
│   │   ├── test_async_operations.py
│   │   ├── test_comprehensive_integration.py
│   │   ├── test_configuration.py
│   │   ├── test_falkordb_integration.py
│   │   ├── test_fixtures.py
│   │   ├── test_http_integration.py
│   │   ├── test_integration.py
│   │   ├── test_mcp_integration.py
│   │   ├── test_mcp_transports.py
│   │   ├── test_stdio_simple.py
│   │   └── test_stress_load.py
│   └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│   ├── .env.example
│   ├── graph_service
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── dto
│   │   │   ├── __init__.py
│   │   │   ├── common.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   ├── main.py
│   │   ├── routers
│   │   │   ├── __init__.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   └── zep_graphiti.py
│   ├── Makefile
│   ├── pyproject.toml
│   ├── README.md
│   └── uv.lock
├── signatures
│   └── version1
│       └── cla.json
├── tests
│   ├── cross_encoder
│   │   ├── test_bge_reranker_client_int.py
│   │   └── test_gemini_reranker_client.py
│   ├── driver
│   │   ├── __init__.py
│   │   └── test_falkordb_driver.py
│   ├── embedder
│   │   ├── embedder_fixtures.py
│   │   ├── test_gemini.py
│   │   ├── test_openai.py
│   │   └── test_voyage.py
│   ├── evals
│   │   ├── data
│   │   │   └── longmemeval_data
│   │   │       ├── longmemeval_oracle.json
│   │   │       └── README.md
│   │   ├── eval_cli.py
│   │   ├── eval_e2e_graph_building.py
│   │   ├── pytest.ini
│   │   └── utils.py
│   ├── helpers_test.py
│   ├── llm_client
│   │   ├── test_anthropic_client_int.py
│   │   ├── test_anthropic_client.py
│   │   ├── test_azure_openai_client.py
│   │   ├── test_client.py
│   │   ├── test_errors.py
│   │   └── test_gemini_client.py
│   ├── test_edge_int.py
│   ├── test_entity_exclusion_int.py
│   ├── test_graphiti_int.py
│   ├── test_graphiti_mock.py
│   ├── test_node_int.py
│   ├── test_text_utils.py
│   └── utils
│       ├── maintenance
│       │   ├── test_bulk_utils.py
│       │   ├── test_edge_operations.py
│       │   ├── test_node_operations.py
│       │   └── test_temporal_operations_int.py
│       └── search
│           └── search_utils_test.py
├── uv.lock
└── Zep-CLA.md
```

# Files

--------------------------------------------------------------------------------
/graphiti_core/edges.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import json
 18 | import logging
 19 | from abc import ABC, abstractmethod
 20 | from datetime import datetime
 21 | from time import time
 22 | from typing import Any
 23 | from uuid import uuid4
 24 | 
 25 | from pydantic import BaseModel, Field
 26 | from typing_extensions import LiteralString
 27 | 
 28 | from graphiti_core.driver.driver import GraphDriver, GraphProvider
 29 | from graphiti_core.embedder import EmbedderClient
 30 | from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
 31 | from graphiti_core.helpers import parse_db_date
 32 | from graphiti_core.models.edges.edge_db_queries import (
 33 |     COMMUNITY_EDGE_RETURN,
 34 |     EPISODIC_EDGE_RETURN,
 35 |     EPISODIC_EDGE_SAVE,
 36 |     get_community_edge_save_query,
 37 |     get_entity_edge_return_query,
 38 |     get_entity_edge_save_query,
 39 | )
 40 | from graphiti_core.nodes import Node
 41 | 
 42 | logger = logging.getLogger(__name__)  # module logger for the debug messages emitted below
 43 | 
 44 | 
 45 | class Edge(BaseModel, ABC):
 46 |     uuid: str = Field(default_factory=lambda: str(uuid4()))
 47 |     group_id: str = Field(description='partition of the graph')
 48 |     source_node_uuid: str
 49 |     target_node_uuid: str
 50 |     created_at: datetime
 51 | 
 52 |     @abstractmethod
 53 |     async def save(self, driver: GraphDriver): ...
 54 | 
 55 |     async def delete(self, driver: GraphDriver):
 56 |         if driver.graph_operations_interface:
 57 |             return await driver.graph_operations_interface.edge_delete(self, driver)
 58 | 
 59 |         if driver.provider == GraphProvider.KUZU:
 60 |             await driver.execute_query(
 61 |                 """
 62 |                 MATCH (n)-[e:MENTIONS|HAS_MEMBER {uuid: $uuid}]->(m)
 63 |                 DELETE e
 64 |                 """,
 65 |                 uuid=self.uuid,
 66 |             )
 67 |             await driver.execute_query(
 68 |                 """
 69 |                 MATCH (e:RelatesToNode_ {uuid: $uuid})
 70 |                 DETACH DELETE e
 71 |                 """,
 72 |                 uuid=self.uuid,
 73 |             )
 74 |         else:
 75 |             await driver.execute_query(
 76 |                 """
 77 |                 MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
 78 |                 DELETE e
 79 |                 """,
 80 |                 uuid=self.uuid,
 81 |             )
 82 | 
 83 |         logger.debug(f'Deleted Edge: {self.uuid}')
 84 | 
 85 |     @classmethod
 86 |     async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
 87 |         if driver.graph_operations_interface:
 88 |             return await driver.graph_operations_interface.edge_delete_by_uuids(cls, driver, uuids)
 89 | 
 90 |         if driver.provider == GraphProvider.KUZU:
 91 |             await driver.execute_query(
 92 |                 """
 93 |                 MATCH (n)-[e:MENTIONS|HAS_MEMBER]->(m)
 94 |                 WHERE e.uuid IN $uuids
 95 |                 DELETE e
 96 |                 """,
 97 |                 uuids=uuids,
 98 |             )
 99 |             await driver.execute_query(
100 |                 """
101 |                 MATCH (e:RelatesToNode_)
102 |                 WHERE e.uuid IN $uuids
103 |                 DETACH DELETE e
104 |                 """,
105 |                 uuids=uuids,
106 |             )
107 |         else:
108 |             await driver.execute_query(
109 |                 """
110 |                 MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
111 |                 WHERE e.uuid IN $uuids
112 |                 DELETE e
113 |                 """,
114 |                 uuids=uuids,
115 |             )
116 | 
117 |         logger.debug(f'Deleted Edges: {uuids}')
118 | 
119 |     def __hash__(self):
120 |         return hash(self.uuid)
121 | 
122 |     def __eq__(self, other):
123 |         if isinstance(other, Node):
124 |             return self.uuid == other.uuid
125 |         return False
126 | 
127 |     @classmethod
128 |     async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
129 | 
130 | 
131 | class EpisodicEdge(Edge):
132 |     async def save(self, driver: GraphDriver):
133 |         result = await driver.execute_query(
134 |             EPISODIC_EDGE_SAVE,
135 |             episode_uuid=self.source_node_uuid,
136 |             entity_uuid=self.target_node_uuid,
137 |             uuid=self.uuid,
138 |             group_id=self.group_id,
139 |             created_at=self.created_at,
140 |         )
141 | 
142 |         logger.debug(f'Saved edge to Graph: {self.uuid}')
143 | 
144 |         return result
145 | 
146 |     @classmethod
147 |     async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
148 |         records, _, _ = await driver.execute_query(
149 |             """
150 |             MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
151 |             RETURN
152 |             """
153 |             + EPISODIC_EDGE_RETURN,
154 |             uuid=uuid,
155 |             routing_='r',
156 |         )
157 | 
158 |         edges = [get_episodic_edge_from_record(record) for record in records]
159 | 
160 |         if len(edges) == 0:
161 |             raise EdgeNotFoundError(uuid)
162 |         return edges[0]
163 | 
164 |     @classmethod
165 |     async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
166 |         records, _, _ = await driver.execute_query(
167 |             """
168 |             MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
169 |             WHERE e.uuid IN $uuids
170 |             RETURN
171 |             """
172 |             + EPISODIC_EDGE_RETURN,
173 |             uuids=uuids,
174 |             routing_='r',
175 |         )
176 | 
177 |         edges = [get_episodic_edge_from_record(record) for record in records]
178 | 
179 |         if len(edges) == 0:
180 |             raise EdgeNotFoundError(uuids[0])
181 |         return edges
182 | 
183 |     @classmethod
184 |     async def get_by_group_ids(
185 |         cls,
186 |         driver: GraphDriver,
187 |         group_ids: list[str],
188 |         limit: int | None = None,
189 |         uuid_cursor: str | None = None,
190 |     ):
191 |         cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
192 |         limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
193 | 
194 |         records, _, _ = await driver.execute_query(
195 |             """
196 |             MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
197 |             WHERE e.group_id IN $group_ids
198 |             """
199 |             + cursor_query
200 |             + """
201 |             RETURN
202 |             """
203 |             + EPISODIC_EDGE_RETURN
204 |             + """
205 |             ORDER BY e.uuid DESC
206 |             """
207 |             + limit_query,
208 |             group_ids=group_ids,
209 |             uuid=uuid_cursor,
210 |             limit=limit,
211 |             routing_='r',
212 |         )
213 | 
214 |         edges = [get_episodic_edge_from_record(record) for record in records]
215 | 
216 |         if len(edges) == 0:
217 |             raise GroupsEdgesNotFoundError(group_ids)
218 |         return edges
219 | 
220 | 
221 | class EntityEdge(Edge):
222 |     name: str = Field(description='name of the edge, relation name')
223 |     fact: str = Field(description='fact representing the edge and nodes that it connects')
224 |     fact_embedding: list[float] | None = Field(default=None, description='embedding of the fact')
225 |     episodes: list[str] = Field(
226 |         default=[],
227 |         description='list of episode ids that reference these entity edges',
228 |     )
229 |     expired_at: datetime | None = Field(
230 |         default=None, description='datetime of when the node was invalidated'
231 |     )
232 |     valid_at: datetime | None = Field(
233 |         default=None, description='datetime of when the fact became true'
234 |     )
235 |     invalid_at: datetime | None = Field(
236 |         default=None, description='datetime of when the fact stopped being true'
237 |     )
238 |     attributes: dict[str, Any] = Field(
239 |         default={}, description='Additional attributes of the edge. Dependent on edge name'
240 |     )
241 | 
242 |     async def generate_embedding(self, embedder: EmbedderClient):
243 |         start = time()
244 | 
245 |         text = self.fact.replace('\n', ' ')
246 |         self.fact_embedding = await embedder.create(input_data=[text])
247 | 
248 |         end = time()
249 |         logger.debug(f'embedded {text} in {end - start} ms')
250 | 
251 |         return self.fact_embedding
252 | 
253 |     async def load_fact_embedding(self, driver: GraphDriver):
254 |         if driver.graph_operations_interface:
255 |             return await driver.graph_operations_interface.edge_load_embeddings(self, driver)
256 | 
257 |         query = """
258 |             MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
259 |             RETURN e.fact_embedding AS fact_embedding
260 |         """
261 | 
262 |         if driver.provider == GraphProvider.NEPTUNE:
263 |             query = """
264 |                 MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
265 |                 RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
266 |             """
267 | 
268 |         if driver.provider == GraphProvider.KUZU:
269 |             query = """
270 |                 MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
271 |                 RETURN e.fact_embedding AS fact_embedding
272 |             """
273 | 
274 |         records, _, _ = await driver.execute_query(
275 |             query,
276 |             uuid=self.uuid,
277 |             routing_='r',
278 |         )
279 | 
280 |         if len(records) == 0:
281 |             raise EdgeNotFoundError(self.uuid)
282 | 
283 |         self.fact_embedding = records[0]['fact_embedding']
284 | 
285 |     async def save(self, driver: GraphDriver):
286 |         edge_data: dict[str, Any] = {
287 |             'source_uuid': self.source_node_uuid,
288 |             'target_uuid': self.target_node_uuid,
289 |             'uuid': self.uuid,
290 |             'name': self.name,
291 |             'group_id': self.group_id,
292 |             'fact': self.fact,
293 |             'fact_embedding': self.fact_embedding,
294 |             'episodes': self.episodes,
295 |             'created_at': self.created_at,
296 |             'expired_at': self.expired_at,
297 |             'valid_at': self.valid_at,
298 |             'invalid_at': self.invalid_at,
299 |         }
300 | 
301 |         if driver.provider == GraphProvider.KUZU:
302 |             edge_data['attributes'] = json.dumps(self.attributes)
303 |             result = await driver.execute_query(
304 |                 get_entity_edge_save_query(driver.provider),
305 |                 **edge_data,
306 |             )
307 |         else:
308 |             edge_data.update(self.attributes or {})
309 |             result = await driver.execute_query(
310 |                 get_entity_edge_save_query(driver.provider),
311 |                 edge_data=edge_data,
312 |             )
313 | 
314 |         logger.debug(f'Saved edge to Graph: {self.uuid}')
315 | 
316 |         return result
317 | 
318 |     @classmethod
319 |     async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
320 |         match_query = """
321 |             MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
322 |         """
323 |         if driver.provider == GraphProvider.KUZU:
324 |             match_query = """
325 |                 MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
326 |             """
327 | 
328 |         records, _, _ = await driver.execute_query(
329 |             match_query
330 |             + """
331 |             RETURN
332 |             """
333 |             + get_entity_edge_return_query(driver.provider),
334 |             uuid=uuid,
335 |             routing_='r',
336 |         )
337 | 
338 |         edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
339 | 
340 |         if len(edges) == 0:
341 |             raise EdgeNotFoundError(uuid)
342 |         return edges[0]
343 | 
344 |     @classmethod
345 |     async def get_between_nodes(
346 |         cls, driver: GraphDriver, source_node_uuid: str, target_node_uuid: str
347 |     ):
348 |         match_query = """
349 |             MATCH (n:Entity {uuid: $source_node_uuid})-[e:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
350 |         """
351 |         if driver.provider == GraphProvider.KUZU:
352 |             match_query = """
353 |                 MATCH (n:Entity {uuid: $source_node_uuid})
354 |                       -[:RELATES_TO]->(e:RelatesToNode_)
355 |                       -[:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
356 |             """
357 | 
358 |         records, _, _ = await driver.execute_query(
359 |             match_query
360 |             + """
361 |             RETURN
362 |             """
363 |             + get_entity_edge_return_query(driver.provider),
364 |             source_node_uuid=source_node_uuid,
365 |             target_node_uuid=target_node_uuid,
366 |             routing_='r',
367 |         )
368 | 
369 |         edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
370 | 
371 |         return edges
372 | 
373 |     @classmethod
374 |     async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
375 |         if len(uuids) == 0:
376 |             return []
377 | 
378 |         match_query = """
379 |             MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
380 |         """
381 |         if driver.provider == GraphProvider.KUZU:
382 |             match_query = """
383 |                 MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
384 |             """
385 | 
386 |         records, _, _ = await driver.execute_query(
387 |             match_query
388 |             + """
389 |             WHERE e.uuid IN $uuids
390 |             RETURN
391 |             """
392 |             + get_entity_edge_return_query(driver.provider),
393 |             uuids=uuids,
394 |             routing_='r',
395 |         )
396 | 
397 |         edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
398 | 
399 |         return edges
400 | 
401 |     @classmethod
402 |     async def get_by_group_ids(
403 |         cls,
404 |         driver: GraphDriver,
405 |         group_ids: list[str],
406 |         limit: int | None = None,
407 |         uuid_cursor: str | None = None,
408 |         with_embeddings: bool = False,
409 |     ):
410 |         cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
411 |         limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
412 |         with_embeddings_query: LiteralString = (
413 |             """,
414 |                 e.fact_embedding AS fact_embedding
415 |                 """
416 |             if with_embeddings
417 |             else ''
418 |         )
419 | 
420 |         match_query = """
421 |             MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
422 |         """
423 |         if driver.provider == GraphProvider.KUZU:
424 |             match_query = """
425 |                 MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
426 |             """
427 | 
428 |         records, _, _ = await driver.execute_query(
429 |             match_query
430 |             + """
431 |             WHERE e.group_id IN $group_ids
432 |             """
433 |             + cursor_query
434 |             + """
435 |             RETURN
436 |             """
437 |             + get_entity_edge_return_query(driver.provider)
438 |             + with_embeddings_query
439 |             + """
440 |             ORDER BY e.uuid DESC
441 |             """
442 |             + limit_query,
443 |             group_ids=group_ids,
444 |             uuid=uuid_cursor,
445 |             limit=limit,
446 |             routing_='r',
447 |         )
448 | 
449 |         edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
450 | 
451 |         if len(edges) == 0:
452 |             raise GroupsEdgesNotFoundError(group_ids)
453 |         return edges
454 | 
455 |     @classmethod
456 |     async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
457 |         match_query = """
458 |             MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
459 |         """
460 |         if driver.provider == GraphProvider.KUZU:
461 |             match_query = """
462 |                 MATCH (n:Entity {uuid: $node_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
463 |             """
464 | 
465 |         records, _, _ = await driver.execute_query(
466 |             match_query
467 |             + """
468 |             RETURN
469 |             """
470 |             + get_entity_edge_return_query(driver.provider),
471 |             node_uuid=node_uuid,
472 |             routing_='r',
473 |         )
474 | 
475 |         edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
476 | 
477 |         return edges
478 | 
479 | 
480 | class CommunityEdge(Edge):
481 |     async def save(self, driver: GraphDriver):
482 |         result = await driver.execute_query(
483 |             get_community_edge_save_query(driver.provider),
484 |             community_uuid=self.source_node_uuid,
485 |             entity_uuid=self.target_node_uuid,
486 |             uuid=self.uuid,
487 |             group_id=self.group_id,
488 |             created_at=self.created_at,
489 |         )
490 | 
491 |         logger.debug(f'Saved edge to Graph: {self.uuid}')
492 | 
493 |         return result
494 | 
495 |     @classmethod
496 |     async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
497 |         records, _, _ = await driver.execute_query(
498 |             """
499 |             MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m)
500 |             RETURN
501 |             """
502 |             + COMMUNITY_EDGE_RETURN,
503 |             uuid=uuid,
504 |             routing_='r',
505 |         )
506 | 
507 |         edges = [get_community_edge_from_record(record) for record in records]
508 | 
509 |         return edges[0]
510 | 
511 |     @classmethod
512 |     async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
513 |         records, _, _ = await driver.execute_query(
514 |             """
515 |             MATCH (n:Community)-[e:HAS_MEMBER]->(m)
516 |             WHERE e.uuid IN $uuids
517 |             RETURN
518 |             """
519 |             + COMMUNITY_EDGE_RETURN,
520 |             uuids=uuids,
521 |             routing_='r',
522 |         )
523 | 
524 |         edges = [get_community_edge_from_record(record) for record in records]
525 | 
526 |         return edges
527 | 
528 |     @classmethod
529 |     async def get_by_group_ids(
530 |         cls,
531 |         driver: GraphDriver,
532 |         group_ids: list[str],
533 |         limit: int | None = None,
534 |         uuid_cursor: str | None = None,
535 |     ):
536 |         cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
537 |         limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
538 | 
539 |         records, _, _ = await driver.execute_query(
540 |             """
541 |             MATCH (n:Community)-[e:HAS_MEMBER]->(m)
542 |             WHERE e.group_id IN $group_ids
543 |             """
544 |             + cursor_query
545 |             + """
546 |             RETURN
547 |             """
548 |             + COMMUNITY_EDGE_RETURN
549 |             + """
550 |             ORDER BY e.uuid DESC
551 |             """
552 |             + limit_query,
553 |             group_ids=group_ids,
554 |             uuid=uuid_cursor,
555 |             limit=limit,
556 |             routing_='r',
557 |         )
558 | 
559 |         edges = [get_community_edge_from_record(record) for record in records]
560 | 
561 |         return edges
562 | 
563 | 
564 | # Edge helpers
565 | def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
566 |     return EpisodicEdge(
567 |         uuid=record['uuid'],
568 |         group_id=record['group_id'],
569 |         source_node_uuid=record['source_node_uuid'],
570 |         target_node_uuid=record['target_node_uuid'],
571 |         created_at=parse_db_date(record['created_at']),  # type: ignore
572 |     )
573 | 
574 | 
575 | def get_entity_edge_from_record(record: Any, provider: GraphProvider) -> EntityEdge:
576 |     episodes = record['episodes']
577 |     if provider == GraphProvider.KUZU:
578 |         attributes = json.loads(record['attributes']) if record['attributes'] else {}
579 |     else:
580 |         attributes = record['attributes']
581 |         attributes.pop('uuid', None)
582 |         attributes.pop('source_node_uuid', None)
583 |         attributes.pop('target_node_uuid', None)
584 |         attributes.pop('fact', None)
585 |         attributes.pop('fact_embedding', None)
586 |         attributes.pop('name', None)
587 |         attributes.pop('group_id', None)
588 |         attributes.pop('episodes', None)
589 |         attributes.pop('created_at', None)
590 |         attributes.pop('expired_at', None)
591 |         attributes.pop('valid_at', None)
592 |         attributes.pop('invalid_at', None)
593 | 
594 |     edge = EntityEdge(
595 |         uuid=record['uuid'],
596 |         source_node_uuid=record['source_node_uuid'],
597 |         target_node_uuid=record['target_node_uuid'],
598 |         fact=record['fact'],
599 |         fact_embedding=record.get('fact_embedding'),
600 |         name=record['name'],
601 |         group_id=record['group_id'],
602 |         episodes=episodes,
603 |         created_at=parse_db_date(record['created_at']),  # type: ignore
604 |         expired_at=parse_db_date(record['expired_at']),
605 |         valid_at=parse_db_date(record['valid_at']),
606 |         invalid_at=parse_db_date(record['invalid_at']),
607 |         attributes=attributes,
608 |     )
609 | 
610 |     return edge
611 | 
612 | 
613 | def get_community_edge_from_record(record: Any):
614 |     return CommunityEdge(
615 |         uuid=record['uuid'],
616 |         group_id=record['group_id'],
617 |         source_node_uuid=record['source_node_uuid'],
618 |         target_node_uuid=record['target_node_uuid'],
619 |         created_at=parse_db_date(record['created_at']),  # type: ignore
620 |     )
621 | 
622 | 
623 | async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
624 |     # filter out falsey values from edges
625 |     filtered_edges = [edge for edge in edges if edge.fact]
626 | 
627 |     if len(filtered_edges) == 0:
628 |         return
629 |     fact_embeddings = await embedder.create_batch([edge.fact for edge in filtered_edges])
630 |     for edge, fact_embedding in zip(filtered_edges, fact_embeddings, strict=True):
631 |         edge.fact_embedding = fact_embedding
632 | 
```

--------------------------------------------------------------------------------
/tests/utils/maintenance/test_node_operations.py:
--------------------------------------------------------------------------------

```python
  1 | import logging
  2 | from collections import defaultdict
  3 | from unittest.mock import AsyncMock, MagicMock
  4 | 
  5 | import pytest
  6 | 
  7 | from graphiti_core.graphiti_types import GraphitiClients
  8 | from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
  9 | from graphiti_core.search.search_config import SearchResults
 10 | from graphiti_core.utils.datetime_utils import utc_now
 11 | from graphiti_core.utils.maintenance.dedup_helpers import (
 12 |     DedupCandidateIndexes,
 13 |     DedupResolutionState,
 14 |     _build_candidate_indexes,
 15 |     _cached_shingles,
 16 |     _has_high_entropy,
 17 |     _hash_shingle,
 18 |     _jaccard_similarity,
 19 |     _lsh_bands,
 20 |     _minhash_signature,
 21 |     _name_entropy,
 22 |     _normalize_name_for_fuzzy,
 23 |     _normalize_string_exact,
 24 |     _resolve_with_similarity,
 25 |     _shingles,
 26 | )
 27 | from graphiti_core.utils.maintenance.node_operations import (
 28 |     _collect_candidate_nodes,
 29 |     _resolve_with_llm,
 30 |     extract_attributes_from_node,
 31 |     extract_attributes_from_nodes,
 32 |     resolve_extracted_nodes,
 33 | )
 34 | 
 35 | 
 36 | def _make_clients():
 37 |     driver = MagicMock()
 38 |     embedder = MagicMock()
 39 |     cross_encoder = MagicMock()
 40 |     llm_client = MagicMock()
 41 |     llm_generate = AsyncMock()
 42 |     llm_client.generate_response = llm_generate
 43 | 
 44 |     clients = GraphitiClients.model_construct(  # bypass validation to allow test doubles
 45 |         driver=driver,
 46 |         embedder=embedder,
 47 |         cross_encoder=cross_encoder,
 48 |         llm_client=llm_client,
 49 |     )
 50 | 
 51 |     return clients, llm_generate
 52 | 
 53 | 
 54 | def _make_episode(group_id: str = 'group'):
 55 |     return EpisodicNode(
 56 |         name='episode',
 57 |         group_id=group_id,
 58 |         source=EpisodeType.message,
 59 |         source_description='test',
 60 |         content='content',
 61 |         valid_at=utc_now(),
 62 |     )
 63 | 
 64 | 
 65 | @pytest.mark.asyncio
 66 | async def test_resolve_nodes_exact_match_skips_llm(monkeypatch):
 67 |     clients, llm_generate = _make_clients()
 68 | 
 69 |     candidate = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])
 70 |     extracted = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])
 71 | 
 72 |     async def fake_search(*_, **__):
 73 |         return SearchResults(nodes=[candidate])
 74 | 
 75 |     monkeypatch.setattr(
 76 |         'graphiti_core.utils.maintenance.node_operations.search',
 77 |         fake_search,
 78 |     )
 79 |     monkeypatch.setattr(
 80 |         'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges',
 81 |         AsyncMock(return_value=[]),
 82 |     )
 83 | 
 84 |     resolved, uuid_map, _ = await resolve_extracted_nodes(
 85 |         clients,
 86 |         [extracted],
 87 |         episode=_make_episode(),
 88 |         previous_episodes=[],
 89 |     )
 90 | 
 91 |     assert resolved[0].uuid == candidate.uuid
 92 |     assert uuid_map[extracted.uuid] == candidate.uuid
 93 |     llm_generate.assert_not_awaited()
 94 | 
 95 | 
 96 | @pytest.mark.asyncio
 97 | async def test_resolve_nodes_low_entropy_uses_llm(monkeypatch):
 98 |     clients, llm_generate = _make_clients()
 99 |     llm_generate.return_value = {
100 |         'entity_resolutions': [
101 |             {
102 |                 'id': 0,
103 |                 'duplicate_idx': -1,
104 |                 'name': 'Joe',
105 |                 'duplicates': [],
106 |             }
107 |         ]
108 |     }
109 | 
110 |     extracted = EntityNode(name='Joe', group_id='group', labels=['Entity'])
111 | 
112 |     async def fake_search(*_, **__):
113 |         return SearchResults(nodes=[])
114 | 
115 |     monkeypatch.setattr(
116 |         'graphiti_core.utils.maintenance.node_operations.search',
117 |         fake_search,
118 |     )
119 |     monkeypatch.setattr(
120 |         'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges',
121 |         AsyncMock(return_value=[]),
122 |     )
123 | 
124 |     resolved, uuid_map, _ = await resolve_extracted_nodes(
125 |         clients,
126 |         [extracted],
127 |         episode=_make_episode(),
128 |         previous_episodes=[],
129 |     )
130 | 
131 |     assert resolved[0].uuid == extracted.uuid
132 |     assert uuid_map[extracted.uuid] == extracted.uuid
133 |     llm_generate.assert_awaited()
134 | 
135 | 
136 | @pytest.mark.asyncio
137 | async def test_resolve_nodes_fuzzy_match(monkeypatch):
138 |     clients, llm_generate = _make_clients()
139 | 
140 |     candidate = EntityNode(name='Joe-Michaels', group_id='group', labels=['Entity'])
141 |     extracted = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])
142 | 
143 |     async def fake_search(*_, **__):
144 |         return SearchResults(nodes=[candidate])
145 | 
146 |     monkeypatch.setattr(
147 |         'graphiti_core.utils.maintenance.node_operations.search',
148 |         fake_search,
149 |     )
150 |     monkeypatch.setattr(
151 |         'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges',
152 |         AsyncMock(return_value=[]),
153 |     )
154 | 
155 |     resolved, uuid_map, _ = await resolve_extracted_nodes(
156 |         clients,
157 |         [extracted],
158 |         episode=_make_episode(),
159 |         previous_episodes=[],
160 |     )
161 | 
162 |     assert resolved[0].uuid == candidate.uuid
163 |     assert uuid_map[extracted.uuid] == candidate.uuid
164 |     llm_generate.assert_not_awaited()
165 | 
166 | 
167 | @pytest.mark.asyncio
168 | async def test_collect_candidate_nodes_dedupes_and_merges_override(monkeypatch):
169 |     clients, _ = _make_clients()
170 | 
171 |     candidate = EntityNode(name='Alice', group_id='group', labels=['Entity'])
172 |     override_duplicate = EntityNode(
173 |         uuid=candidate.uuid,
174 |         name='Alice Alt',
175 |         group_id='group',
176 |         labels=['Entity'],
177 |     )
178 |     extracted = EntityNode(name='Alice', group_id='group', labels=['Entity'])
179 | 
180 |     search_mock = AsyncMock(return_value=SearchResults(nodes=[candidate]))
181 |     monkeypatch.setattr(
182 |         'graphiti_core.utils.maintenance.node_operations.search',
183 |         search_mock,
184 |     )
185 | 
186 |     result = await _collect_candidate_nodes(
187 |         clients,
188 |         [extracted],
189 |         existing_nodes_override=[override_duplicate],
190 |     )
191 | 
192 |     assert len(result) == 1
193 |     assert result[0].uuid == candidate.uuid
194 |     search_mock.assert_awaited()
195 | 
196 | 
197 | def test_build_candidate_indexes_populates_structures():
198 |     candidate = EntityNode(name='Bob Dylan', group_id='group', labels=['Entity'])
199 | 
200 |     indexes = _build_candidate_indexes([candidate])
201 | 
202 |     normalized_key = candidate.name.lower()
203 |     assert indexes.normalized_existing[normalized_key][0].uuid == candidate.uuid
204 |     assert indexes.nodes_by_uuid[candidate.uuid] is candidate
205 |     assert candidate.uuid in indexes.shingles_by_candidate
206 |     assert any(candidate.uuid in bucket for bucket in indexes.lsh_buckets.values())
207 | 
208 | 
209 | def test_normalize_helpers():
210 |     assert _normalize_string_exact('  Alice   Smith ') == 'alice smith'
211 |     assert _normalize_name_for_fuzzy('Alice-Smith!') == 'alice smith'
212 | 
213 | 
214 | def test_name_entropy_variants():
215 |     assert _name_entropy('alice') > _name_entropy('aaaaa')
216 |     assert _name_entropy('') == 0.0
217 | 
218 | 
219 | def test_has_high_entropy_rules():
220 |     assert _has_high_entropy('meaningful name') is True
221 |     assert _has_high_entropy('aa') is False
222 | 
223 | 
224 | def test_shingles_and_cache():
225 |     raw = 'alice'
226 |     shingle_set = _shingles(raw)
227 |     assert shingle_set == {'ali', 'lic', 'ice'}
228 |     assert _cached_shingles(raw) == shingle_set
229 |     assert _cached_shingles(raw) is _cached_shingles(raw)
230 | 
231 | 
232 | def test_hash_minhash_and_lsh():
233 |     shingles = {'abc', 'bcd', 'cde'}
234 |     signature = _minhash_signature(shingles)
235 |     assert len(signature) == 32
236 |     bands = _lsh_bands(signature)
237 |     assert all(len(band) == 4 for band in bands)
238 |     hashed = {_hash_shingle(s, 0) for s in shingles}
239 |     assert len(hashed) == len(shingles)
240 | 
241 | 
242 | def test_jaccard_similarity_edges():
243 |     a = {'a', 'b'}
244 |     b = {'a', 'c'}
245 |     assert _jaccard_similarity(a, b) == pytest.approx(1 / 3)
246 |     assert _jaccard_similarity(set(), set()) == 1.0
247 |     assert _jaccard_similarity(a, set()) == 0.0
248 | 
249 | 
250 | def test_resolve_with_similarity_exact_match_updates_state():
251 |     candidate = EntityNode(name='Charlie Parker', group_id='group', labels=['Entity'])
252 |     extracted = EntityNode(name='Charlie Parker', group_id='group', labels=['Entity'])
253 | 
254 |     indexes = _build_candidate_indexes([candidate])
255 |     state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])
256 | 
257 |     _resolve_with_similarity([extracted], indexes, state)
258 | 
259 |     assert state.resolved_nodes[0].uuid == candidate.uuid
260 |     assert state.uuid_map[extracted.uuid] == candidate.uuid
261 |     assert state.unresolved_indices == []
262 |     assert state.duplicate_pairs == [(extracted, candidate)]
263 | 
264 | 
265 | def test_resolve_with_similarity_low_entropy_defers_resolution():
266 |     extracted = EntityNode(name='Bob', group_id='group', labels=['Entity'])
267 |     indexes = DedupCandidateIndexes(
268 |         existing_nodes=[],
269 |         nodes_by_uuid={},
270 |         normalized_existing=defaultdict(list),
271 |         shingles_by_candidate={},
272 |         lsh_buckets=defaultdict(list),
273 |     )
274 |     state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])
275 | 
276 |     _resolve_with_similarity([extracted], indexes, state)
277 | 
278 |     assert state.resolved_nodes[0] is None
279 |     assert state.unresolved_indices == [0]
280 |     assert state.duplicate_pairs == []
281 | 
282 | 
283 | def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():
284 |     candidate1 = EntityNode(name='Johnny Appleseed', group_id='group', labels=['Entity'])
285 |     candidate2 = EntityNode(name='Johnny Appleseed', group_id='group', labels=['Entity'])
286 |     extracted = EntityNode(name='Johnny Appleseed', group_id='group', labels=['Entity'])
287 | 
288 |     indexes = _build_candidate_indexes([candidate1, candidate2])
289 |     state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])
290 | 
291 |     _resolve_with_similarity([extracted], indexes, state)
292 | 
293 |     assert state.resolved_nodes[0] is None
294 |     assert state.unresolved_indices == [0]
295 |     assert state.duplicate_pairs == []
296 | 
297 | 
298 | @pytest.mark.asyncio
299 | async def test_resolve_with_llm_updates_unresolved(monkeypatch):
300 |     extracted = EntityNode(name='Dizzy', group_id='group', labels=['Entity'])
301 |     candidate = EntityNode(name='Dizzy Gillespie', group_id='group', labels=['Entity'])
302 | 
303 |     indexes = _build_candidate_indexes([candidate])
304 |     state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
305 | 
306 |     captured_context = {}
307 | 
308 |     def fake_prompt_nodes(context):
309 |         captured_context.update(context)
310 |         return ['prompt']
311 | 
312 |     monkeypatch.setattr(
313 |         'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
314 |         fake_prompt_nodes,
315 |     )
316 | 
317 |     async def fake_generate_response(*_, **__):
318 |         return {
319 |             'entity_resolutions': [
320 |                 {
321 |                     'id': 0,
322 |                     'duplicate_idx': 0,
323 |                     'name': 'Dizzy Gillespie',
324 |                     'duplicates': [0],
325 |                 }
326 |             ]
327 |         }
328 | 
329 |     llm_client = MagicMock()
330 |     llm_client.generate_response = AsyncMock(side_effect=fake_generate_response)
331 | 
332 |     await _resolve_with_llm(
333 |         llm_client,
334 |         [extracted],
335 |         indexes,
336 |         state,
337 |         episode=_make_episode(),
338 |         previous_episodes=[],
339 |         entity_types=None,
340 |     )
341 | 
342 |     assert state.resolved_nodes[0].uuid == candidate.uuid
343 |     assert state.uuid_map[extracted.uuid] == candidate.uuid
344 |     assert captured_context['existing_nodes'][0]['idx'] == 0
345 |     assert isinstance(captured_context['existing_nodes'], list)
346 |     assert state.duplicate_pairs == [(extracted, candidate)]
347 | 
348 | 
349 | @pytest.mark.asyncio
350 | async def test_resolve_with_llm_ignores_out_of_range_relative_ids(monkeypatch, caplog):
351 |     extracted = EntityNode(name='Dexter', group_id='group', labels=['Entity'])
352 | 
353 |     indexes = _build_candidate_indexes([])
354 |     state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
355 | 
356 |     monkeypatch.setattr(
357 |         'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
358 |         lambda context: ['prompt'],
359 |     )
360 | 
361 |     llm_client = MagicMock()
362 |     llm_client.generate_response = AsyncMock(
363 |         return_value={
364 |             'entity_resolutions': [
365 |                 {
366 |                     'id': 5,
367 |                     'duplicate_idx': -1,
368 |                     'name': 'Dexter',
369 |                     'duplicates': [],
370 |                 }
371 |             ]
372 |         }
373 |     )
374 | 
375 |     with caplog.at_level(logging.WARNING):
376 |         await _resolve_with_llm(
377 |             llm_client,
378 |             [extracted],
379 |             indexes,
380 |             state,
381 |             episode=_make_episode(),
382 |             previous_episodes=[],
383 |             entity_types=None,
384 |         )
385 | 
386 |     assert state.resolved_nodes[0] is None
387 |     assert 'Skipping invalid LLM dedupe id 5' in caplog.text
388 | 
389 | 
390 | @pytest.mark.asyncio
391 | async def test_resolve_with_llm_ignores_duplicate_relative_ids(monkeypatch):
392 |     extracted = EntityNode(name='Dizzy', group_id='group', labels=['Entity'])
393 |     candidate = EntityNode(name='Dizzy Gillespie', group_id='group', labels=['Entity'])
394 | 
395 |     indexes = _build_candidate_indexes([candidate])
396 |     state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
397 | 
398 |     monkeypatch.setattr(
399 |         'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
400 |         lambda context: ['prompt'],
401 |     )
402 | 
403 |     llm_client = MagicMock()
404 |     llm_client.generate_response = AsyncMock(
405 |         return_value={
406 |             'entity_resolutions': [
407 |                 {
408 |                     'id': 0,
409 |                     'duplicate_idx': 0,
410 |                     'name': 'Dizzy Gillespie',
411 |                     'duplicates': [0],
412 |                 },
413 |                 {
414 |                     'id': 0,
415 |                     'duplicate_idx': -1,
416 |                     'name': 'Dizzy',
417 |                     'duplicates': [],
418 |                 },
419 |             ]
420 |         }
421 |     )
422 | 
423 |     await _resolve_with_llm(
424 |         llm_client,
425 |         [extracted],
426 |         indexes,
427 |         state,
428 |         episode=_make_episode(),
429 |         previous_episodes=[],
430 |         entity_types=None,
431 |     )
432 | 
433 |     assert state.resolved_nodes[0].uuid == candidate.uuid
434 |     assert state.uuid_map[extracted.uuid] == candidate.uuid
435 |     assert state.duplicate_pairs == [(extracted, candidate)]
436 | 
437 | 
438 | @pytest.mark.asyncio
439 | async def test_resolve_with_llm_invalid_duplicate_idx_defaults_to_extracted(monkeypatch):
440 |     extracted = EntityNode(name='Dexter', group_id='group', labels=['Entity'])
441 | 
442 |     indexes = _build_candidate_indexes([])
443 |     state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
444 | 
445 |     monkeypatch.setattr(
446 |         'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
447 |         lambda context: ['prompt'],
448 |     )
449 | 
450 |     llm_client = MagicMock()
451 |     llm_client.generate_response = AsyncMock(
452 |         return_value={
453 |             'entity_resolutions': [
454 |                 {
455 |                     'id': 0,
456 |                     'duplicate_idx': 10,
457 |                     'name': 'Dexter',
458 |                     'duplicates': [],
459 |                 }
460 |             ]
461 |         }
462 |     )
463 | 
464 |     await _resolve_with_llm(
465 |         llm_client,
466 |         [extracted],
467 |         indexes,
468 |         state,
469 |         episode=_make_episode(),
470 |         previous_episodes=[],
471 |         entity_types=None,
472 |     )
473 | 
474 |     assert state.resolved_nodes[0] == extracted
475 |     assert state.uuid_map[extracted.uuid] == extracted.uuid
476 |     assert state.duplicate_pairs == []
477 | 
478 | 
479 | @pytest.mark.asyncio
480 | async def test_extract_attributes_without_callback_generates_summary():
481 |     """Test that summary is generated when no callback is provided (default behavior)."""
482 |     llm_client = MagicMock()
483 |     llm_client.generate_response = AsyncMock(
484 |         return_value={'summary': 'Generated summary', 'attributes': {}}
485 |     )
486 | 
487 |     node = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')
488 |     episode = _make_episode()
489 | 
490 |     result = await extract_attributes_from_node(
491 |         llm_client,
492 |         node,
493 |         episode=episode,
494 |         previous_episodes=[],
495 |         entity_type=None,
496 |         should_summarize_node=None,  # No callback provided
497 |     )
498 | 
499 |     # Summary should be generated
500 |     assert result.summary == 'Generated summary'
501 |     # LLM should have been called for summary
502 |     assert llm_client.generate_response.call_count == 1
503 | 
504 | 
505 | @pytest.mark.asyncio
506 | async def test_extract_attributes_with_callback_skip_summary():
507 |     """Test that summary is NOT regenerated when callback returns False."""
508 |     llm_client = MagicMock()
509 |     llm_client.generate_response = AsyncMock(
510 |         return_value={'summary': 'This should not be used', 'attributes': {}}
511 |     )
512 | 
513 |     node = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')
514 |     episode = _make_episode()
515 | 
516 |     # Callback that always returns False (skip summary generation)
517 |     async def skip_summary_filter(node: EntityNode) -> bool:
518 |         return False
519 | 
520 |     result = await extract_attributes_from_node(
521 |         llm_client,
522 |         node,
523 |         episode=episode,
524 |         previous_episodes=[],
525 |         entity_type=None,
526 |         should_summarize_node=skip_summary_filter,
527 |     )
528 | 
529 |     # Summary should remain unchanged
530 |     assert result.summary == 'Old summary'
531 |     # LLM should NOT have been called for summary
532 |     assert llm_client.generate_response.call_count == 0
533 | 
534 | 
535 | @pytest.mark.asyncio
536 | async def test_extract_attributes_with_callback_generate_summary():
537 |     """Test that summary is regenerated when callback returns True."""
538 |     llm_client = MagicMock()
539 |     llm_client.generate_response = AsyncMock(
540 |         return_value={'summary': 'New generated summary', 'attributes': {}}
541 |     )
542 | 
543 |     node = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')
544 |     episode = _make_episode()
545 | 
546 |     # Callback that always returns True (generate summary)
547 |     async def generate_summary_filter(node: EntityNode) -> bool:
548 |         return True
549 | 
550 |     result = await extract_attributes_from_node(
551 |         llm_client,
552 |         node,
553 |         episode=episode,
554 |         previous_episodes=[],
555 |         entity_type=None,
556 |         should_summarize_node=generate_summary_filter,
557 |     )
558 | 
559 |     # Summary should be updated
560 |     assert result.summary == 'New generated summary'
561 |     # LLM should have been called for summary
562 |     assert llm_client.generate_response.call_count == 1
563 | 
564 | 
565 | @pytest.mark.asyncio
566 | async def test_extract_attributes_with_selective_callback():
567 |     """Test callback that selectively skips summaries based on node properties."""
568 |     llm_client = MagicMock()
569 |     llm_client.generate_response = AsyncMock(
570 |         return_value={'summary': 'Generated summary', 'attributes': {}}
571 |     )
572 | 
573 |     user_node = EntityNode(name='User', group_id='group', labels=['Entity', 'User'], summary='Old')
574 |     topic_node = EntityNode(
575 |         name='Topic', group_id='group', labels=['Entity', 'Topic'], summary='Old'
576 |     )
577 | 
578 |     episode = _make_episode()
579 | 
580 |     # Callback that skips User nodes but generates for others
581 |     async def selective_filter(node: EntityNode) -> bool:
582 |         return 'User' not in node.labels
583 | 
584 |     result_user = await extract_attributes_from_node(
585 |         llm_client,
586 |         user_node,
587 |         episode=episode,
588 |         previous_episodes=[],
589 |         entity_type=None,
590 |         should_summarize_node=selective_filter,
591 |     )
592 | 
593 |     result_topic = await extract_attributes_from_node(
594 |         llm_client,
595 |         topic_node,
596 |         episode=episode,
597 |         previous_episodes=[],
598 |         entity_type=None,
599 |         should_summarize_node=selective_filter,
600 |     )
601 | 
602 |     # User summary should remain unchanged
603 |     assert result_user.summary == 'Old'
604 |     # Topic summary should be generated
605 |     assert result_topic.summary == 'Generated summary'
606 |     # LLM should have been called only once (for topic)
607 |     assert llm_client.generate_response.call_count == 1
608 | 
609 | 
610 | @pytest.mark.asyncio
611 | async def test_extract_attributes_from_nodes_with_callback():
612 |     """Test that callback is properly passed through extract_attributes_from_nodes."""
613 |     clients, _ = _make_clients()
614 |     clients.llm_client.generate_response = AsyncMock(
615 |         return_value={'summary': 'New summary', 'attributes': {}}
616 |     )
617 |     clients.embedder.create = AsyncMock(return_value=[0.1, 0.2, 0.3])
618 |     clients.embedder.create_batch = AsyncMock(return_value=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
619 | 
620 |     node1 = EntityNode(name='Node1', group_id='group', labels=['Entity', 'User'], summary='Old1')
621 |     node2 = EntityNode(name='Node2', group_id='group', labels=['Entity', 'Topic'], summary='Old2')
622 | 
623 |     episode = _make_episode()
624 | 
625 |     call_tracker = []
626 | 
627 |     # Callback that tracks which nodes it's called with
628 |     async def tracking_filter(node: EntityNode) -> bool:
629 |         call_tracker.append(node.name)
630 |         return 'User' not in node.labels
631 | 
632 |     results = await extract_attributes_from_nodes(
633 |         clients,
634 |         [node1, node2],
635 |         episode=episode,
636 |         previous_episodes=[],
637 |         entity_types=None,
638 |         should_summarize_node=tracking_filter,
639 |     )
640 | 
641 |     # Callback should have been called for both nodes
642 |     assert len(call_tracker) == 2
643 |     assert 'Node1' in call_tracker
644 |     assert 'Node2' in call_tracker
645 | 
646 |     # Node1 (User) should keep old summary, Node2 (Topic) should get new summary
647 |     node1_result = next(n for n in results if n.name == 'Node1')
648 |     node2_result = next(n for n in results if n.name == 'Node2')
649 | 
650 |     assert node1_result.summary == 'Old1'
651 |     assert node2_result.summary == 'New summary'
652 | 
```

--------------------------------------------------------------------------------
/tests/llm_client/test_gemini_client.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | # Running tests: pytest -xvs tests/llm_client/test_gemini_client.py
 18 | 
 19 | from unittest.mock import AsyncMock, MagicMock, patch
 20 | 
 21 | import pytest
 22 | from pydantic import BaseModel
 23 | 
 24 | from graphiti_core.llm_client.config import LLMConfig, ModelSize
 25 | from graphiti_core.llm_client.errors import RateLimitError
 26 | from graphiti_core.llm_client.gemini_client import DEFAULT_MODEL, DEFAULT_SMALL_MODEL, GeminiClient
 27 | from graphiti_core.prompts.models import Message
 28 | 
 29 | 
 30 | # Schema passed as response_model in the structured-output generate_response test.
 31 | class ResponseModel(BaseModel):
 32 |     """Structured-output schema: a required string plus an int that defaults to 0."""
 33 | 
 34 |     test_field: str  # required
 35 |     optional_field: int = 0  # may be omitted from the JSON reply
 36 | 
 37 | 
 38 | @pytest.fixture
 39 | def mock_gemini_client():
 40 |     """Fixture to mock the Google Gemini client."""
 41 |     with patch('google.genai.Client') as mock_client:
 42 |         # Setup mock instance and its methods
 43 |         mock_instance = mock_client.return_value
 44 |         mock_instance.aio = MagicMock()
 45 |         mock_instance.aio.models = MagicMock()
 46 |         mock_instance.aio.models.generate_content = AsyncMock()
 47 |         yield mock_instance
 48 | 
 49 | 
 50 | @pytest.fixture
 51 | def gemini_client(mock_gemini_client):
 52 |     """Fixture to create a GeminiClient with a mocked client."""
 53 |     config = LLMConfig(api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000)
 54 |     client = GeminiClient(config=config, cache=False)
 55 |     # Replace the client's client with our mock to ensure we're using the mock
 56 |     client.client = mock_gemini_client
 57 |     return client
 58 | 
 59 | 
 60 | class TestGeminiClientInitialization:
 61 |     """Tests for GeminiClient initialization."""
 62 | 
 63 |     @patch('google.genai.Client')
 64 |     def test_init_with_config(self, mock_client):
 65 |         """Test initialization with a config object."""
 66 |         config = LLMConfig(
 67 |             api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
 68 |         )
 69 |         client = GeminiClient(config=config, cache=False, max_tokens=1000)
 70 | 
 71 |         assert client.config == config
 72 |         assert client.model == 'test-model'
 73 |         assert client.temperature == 0.5
 74 |         assert client.max_tokens == 1000
 75 | 
 76 |     @patch('google.genai.Client')
 77 |     def test_init_with_default_model(self, mock_client):
 78 |         """Test initialization with default model when none is provided."""
 79 |         config = LLMConfig(api_key='test_api_key', model=DEFAULT_MODEL)
 80 |         client = GeminiClient(config=config, cache=False)
 81 | 
 82 |         assert client.model == DEFAULT_MODEL
 83 | 
 84 |     @patch('google.genai.Client')
 85 |     def test_init_without_config(self, mock_client):
 86 |         """Test initialization without a config uses defaults."""
 87 |         client = GeminiClient(cache=False)
 88 | 
 89 |         assert client.config is not None
 90 |         # When no config.model is set, it will be None, not DEFAULT_MODEL
 91 |         assert client.model is None
 92 | 
 93 |     @patch('google.genai.Client')
 94 |     def test_init_with_thinking_config(self, mock_client):
 95 |         """Test initialization with thinking config."""
 96 |         with patch('google.genai.types.ThinkingConfig') as mock_thinking_config:
 97 |             thinking_config = mock_thinking_config.return_value
 98 |             client = GeminiClient(thinking_config=thinking_config)
 99 |             assert client.thinking_config == thinking_config
100 | 
101 | 
102 | class TestGeminiClientGenerateResponse:
103 |     """Tests for GeminiClient generate_response method."""
104 | 
105 |     @pytest.mark.asyncio
106 |     async def test_generate_response_simple_text(self, gemini_client, mock_gemini_client):
107 |         """Test successful response generation with simple text."""
108 |         # Setup mock response
109 |         mock_response = MagicMock()
110 |         mock_response.text = 'Test response text'
111 |         mock_response.candidates = []
112 |         mock_response.prompt_feedback = None
113 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
114 | 
115 |         # Call method
116 |         messages = [Message(role='user', content='Test message')]
117 |         result = await gemini_client.generate_response(messages)
118 | 
119 |         # Assertions
120 |         assert isinstance(result, dict)
121 |         assert result['content'] == 'Test response text'
122 |         mock_gemini_client.aio.models.generate_content.assert_called_once()
123 | 
124 |     @pytest.mark.asyncio
125 |     async def test_generate_response_with_structured_output(
126 |         self, gemini_client, mock_gemini_client
127 |     ):
128 |         """Test response generation with structured output."""
129 |         # Setup mock response
130 |         mock_response = MagicMock()
131 |         mock_response.text = '{"test_field": "test_value", "optional_field": 42}'
132 |         mock_response.candidates = []
133 |         mock_response.prompt_feedback = None
134 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
135 | 
136 |         # Call method
137 |         messages = [
138 |             Message(role='system', content='System message'),
139 |             Message(role='user', content='User message'),
140 |         ]
141 |         result = await gemini_client.generate_response(
142 |             messages=messages, response_model=ResponseModel
143 |         )
144 | 
145 |         # Assertions
146 |         assert isinstance(result, dict)
147 |         assert result['test_field'] == 'test_value'
148 |         assert result['optional_field'] == 42
149 |         mock_gemini_client.aio.models.generate_content.assert_called_once()
150 | 
151 |     @pytest.mark.asyncio
152 |     async def test_generate_response_with_system_message(self, gemini_client, mock_gemini_client):
153 |         """Test response generation with system message handling."""
154 |         # Setup mock response
155 |         mock_response = MagicMock()
156 |         mock_response.text = 'Response with system context'
157 |         mock_response.candidates = []
158 |         mock_response.prompt_feedback = None
159 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
160 | 
161 |         # Call method
162 |         messages = [
163 |             Message(role='system', content='System message'),
164 |             Message(role='user', content='User message'),
165 |         ]
166 |         await gemini_client.generate_response(messages)
167 | 
168 |         # Verify system message is processed correctly
169 |         call_args = mock_gemini_client.aio.models.generate_content.call_args
170 |         config = call_args[1]['config']
171 |         assert 'System message' in config.system_instruction
172 | 
173 |     @pytest.mark.asyncio
174 |     async def test_get_model_for_size(self, gemini_client):
175 |         """Test model selection based on size."""
176 |         # Test small model
177 |         small_model = gemini_client._get_model_for_size(ModelSize.small)
178 |         assert small_model == DEFAULT_SMALL_MODEL
179 | 
180 |         # Test medium/large model
181 |         medium_model = gemini_client._get_model_for_size(ModelSize.medium)
182 |         assert medium_model == gemini_client.model
183 | 
184 |     @pytest.mark.asyncio
185 |     async def test_rate_limit_error_handling(self, gemini_client, mock_gemini_client):
186 |         """Test handling of rate limit errors."""
187 |         # Setup mock to raise rate limit error
188 |         mock_gemini_client.aio.models.generate_content.side_effect = Exception(
189 |             'Rate limit exceeded'
190 |         )
191 | 
192 |         # Call method and check exception
193 |         messages = [Message(role='user', content='Test message')]
194 |         with pytest.raises(RateLimitError):
195 |             await gemini_client.generate_response(messages)
196 | 
197 |     @pytest.mark.asyncio
198 |     async def test_quota_error_handling(self, gemini_client, mock_gemini_client):
199 |         """Test handling of quota errors."""
200 |         # Setup mock to raise quota error
201 |         mock_gemini_client.aio.models.generate_content.side_effect = Exception(
202 |             'Quota exceeded for requests'
203 |         )
204 | 
205 |         # Call method and check exception
206 |         messages = [Message(role='user', content='Test message')]
207 |         with pytest.raises(RateLimitError):
208 |             await gemini_client.generate_response(messages)
209 | 
210 |     @pytest.mark.asyncio
211 |     async def test_resource_exhausted_error_handling(self, gemini_client, mock_gemini_client):
212 |         """Test handling of resource exhausted errors."""
213 |         # Setup mock to raise resource exhausted error
214 |         mock_gemini_client.aio.models.generate_content.side_effect = Exception(
215 |             'resource_exhausted: Request limit exceeded'
216 |         )
217 | 
218 |         # Call method and check exception
219 |         messages = [Message(role='user', content='Test message')]
220 |         with pytest.raises(RateLimitError):
221 |             await gemini_client.generate_response(messages)
222 | 
223 |     @pytest.mark.asyncio
224 |     async def test_safety_block_handling(self, gemini_client, mock_gemini_client):
225 |         """Test handling of safety blocks."""
226 |         # Setup mock response with safety block
227 |         mock_candidate = MagicMock()
228 |         mock_candidate.finish_reason = 'SAFETY'
229 |         mock_candidate.safety_ratings = [
230 |             MagicMock(blocked=True, category='HARM_CATEGORY_HARASSMENT', probability='HIGH')
231 |         ]
232 | 
233 |         mock_response = MagicMock()
234 |         mock_response.candidates = [mock_candidate]
235 |         mock_response.prompt_feedback = None
236 |         mock_response.text = ''
237 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
238 | 
239 |         # Call method and check exception
240 |         messages = [Message(role='user', content='Test message')]
241 |         with pytest.raises(Exception, match='Content blocked by safety filters'):
242 |             await gemini_client.generate_response(messages)
243 | 
244 |     @pytest.mark.asyncio
245 |     async def test_prompt_block_handling(self, gemini_client, mock_gemini_client):
246 |         """Test handling of prompt blocks."""
247 |         # Setup mock response with prompt block
248 |         mock_prompt_feedback = MagicMock()
249 |         mock_prompt_feedback.block_reason = 'BLOCKED_REASON_OTHER'
250 | 
251 |         mock_response = MagicMock()
252 |         mock_response.candidates = []
253 |         mock_response.prompt_feedback = mock_prompt_feedback
254 |         mock_response.text = ''
255 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
256 | 
257 |         # Call method and check exception
258 |         messages = [Message(role='user', content='Test message')]
259 |         with pytest.raises(Exception, match='Content blocked by safety filters'):
260 |             await gemini_client.generate_response(messages)
261 | 
262 |     @pytest.mark.asyncio
263 |     async def test_structured_output_parsing_error(self, gemini_client, mock_gemini_client):
264 |         """Test handling of structured output parsing errors."""
265 |         # Setup mock response with invalid JSON that will exhaust retries
266 |         mock_response = MagicMock()
267 |         mock_response.text = 'Invalid JSON that cannot be parsed'
268 |         mock_response.candidates = []
269 |         mock_response.prompt_feedback = None
270 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
271 | 
272 |         # Call method and check exception - should exhaust retries
273 |         messages = [Message(role='user', content='Test message')]
274 |         with pytest.raises(Exception):  # noqa: B017
275 |             await gemini_client.generate_response(messages, response_model=ResponseModel)
276 | 
277 |         # Should have called generate_content MAX_RETRIES times (2 attempts total)
278 |         assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
279 | 
280 |     @pytest.mark.asyncio
281 |     async def test_retry_logic_with_safety_block(self, gemini_client, mock_gemini_client):
282 |         """Test that safety blocks are not retried."""
283 |         # Setup mock response with safety block
284 |         mock_candidate = MagicMock()
285 |         mock_candidate.finish_reason = 'SAFETY'
286 |         mock_candidate.safety_ratings = [
287 |             MagicMock(blocked=True, category='HARM_CATEGORY_HARASSMENT', probability='HIGH')
288 |         ]
289 | 
290 |         mock_response = MagicMock()
291 |         mock_response.candidates = [mock_candidate]
292 |         mock_response.prompt_feedback = None
293 |         mock_response.text = ''
294 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
295 | 
296 |         # Call method and check that it doesn't retry
297 |         messages = [Message(role='user', content='Test message')]
298 |         with pytest.raises(Exception, match='Content blocked by safety filters'):
299 |             await gemini_client.generate_response(messages)
300 | 
301 |         # Should only be called once (no retries for safety blocks)
302 |         assert mock_gemini_client.aio.models.generate_content.call_count == 1
303 | 
304 |     @pytest.mark.asyncio
305 |     async def test_retry_logic_with_validation_error(self, gemini_client, mock_gemini_client):
306 |         """Test retry behavior on validation error."""
307 |         # First call returns invalid JSON, second call returns valid data
308 |         mock_response1 = MagicMock()
309 |         mock_response1.text = 'Invalid JSON that cannot be parsed'
310 |         mock_response1.candidates = []
311 |         mock_response1.prompt_feedback = None
312 | 
313 |         mock_response2 = MagicMock()
314 |         mock_response2.text = '{"test_field": "correct_value"}'
315 |         mock_response2.candidates = []
316 |         mock_response2.prompt_feedback = None
317 | 
318 |         mock_gemini_client.aio.models.generate_content.side_effect = [
319 |             mock_response1,
320 |             mock_response2,
321 |         ]
322 | 
323 |         # Call method
324 |         messages = [Message(role='user', content='Test message')]
325 |         result = await gemini_client.generate_response(messages, response_model=ResponseModel)
326 | 
327 |         # Should have called generate_content twice due to retry
328 |         assert mock_gemini_client.aio.models.generate_content.call_count == 2
329 |         assert result['test_field'] == 'correct_value'
330 | 
331 |     @pytest.mark.asyncio
332 |     async def test_max_retries_exceeded(self, gemini_client, mock_gemini_client):
333 |         """Test behavior when max retries are exceeded."""
334 |         # Setup mock to always return invalid JSON
335 |         mock_response = MagicMock()
336 |         mock_response.text = 'Invalid JSON that cannot be parsed'
337 |         mock_response.candidates = []
338 |         mock_response.prompt_feedback = None
339 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
340 | 
341 |         # Call method and check exception
342 |         messages = [Message(role='user', content='Test message')]
343 |         with pytest.raises(Exception):  # noqa: B017
344 |             await gemini_client.generate_response(messages, response_model=ResponseModel)
345 | 
346 |         # Should have called generate_content MAX_RETRIES times (2 attempts total)
347 |         assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
348 | 
349 |     @pytest.mark.asyncio
350 |     async def test_empty_response_handling(self, gemini_client, mock_gemini_client):
351 |         """Test handling of empty responses."""
352 |         # Setup mock response with no text
353 |         mock_response = MagicMock()
354 |         mock_response.text = ''
355 |         mock_response.candidates = []
356 |         mock_response.prompt_feedback = None
357 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
358 | 
359 |         # Call method with structured output and check exception
360 |         messages = [Message(role='user', content='Test message')]
361 |         with pytest.raises(Exception):  # noqa: B017
362 |             await gemini_client.generate_response(messages, response_model=ResponseModel)
363 | 
364 |         # Should have exhausted retries due to empty response (2 attempts total)
365 |         assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
366 | 
367 |     @pytest.mark.asyncio
368 |     async def test_custom_max_tokens(self, gemini_client, mock_gemini_client):
369 |         """Test that explicit max_tokens parameter takes precedence over all other values."""
370 |         # Setup mock response
371 |         mock_response = MagicMock()
372 |         mock_response.text = 'Test response'
373 |         mock_response.candidates = []
374 |         mock_response.prompt_feedback = None
375 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
376 | 
377 |         # Call method with custom max tokens (should take precedence)
378 |         messages = [Message(role='user', content='Test message')]
379 |         await gemini_client.generate_response(messages, max_tokens=500)
380 | 
381 |         # Verify explicit max_tokens parameter takes precedence
382 |         call_args = mock_gemini_client.aio.models.generate_content.call_args
383 |         config = call_args[1]['config']
384 |         # Explicit parameter should override everything else
385 |         assert config.max_output_tokens == 500
386 | 
387 |     @pytest.mark.asyncio
388 |     async def test_max_tokens_precedence_fallback(self, mock_gemini_client):
389 |         """Test max_tokens precedence when no explicit parameter is provided."""
390 |         # Setup mock response
391 |         mock_response = MagicMock()
392 |         mock_response.text = 'Test response'
393 |         mock_response.candidates = []
394 |         mock_response.prompt_feedback = None
395 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
396 | 
397 |         # Test case 1: No explicit max_tokens, has instance max_tokens
398 |         config = LLMConfig(
399 |             api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
400 |         )
401 |         client = GeminiClient(
402 |             config=config, cache=False, max_tokens=2000, client=mock_gemini_client
403 |         )
404 | 
405 |         messages = [Message(role='user', content='Test message')]
406 |         await client.generate_response(messages)
407 | 
408 |         call_args = mock_gemini_client.aio.models.generate_content.call_args
409 |         config = call_args[1]['config']
410 |         # Instance max_tokens should be used
411 |         assert config.max_output_tokens == 2000
412 | 
413 |         # Test case 2: No explicit max_tokens, no instance max_tokens, uses model mapping
414 |         config = LLMConfig(api_key='test_api_key', model='gemini-2.5-flash', temperature=0.5)
415 |         client = GeminiClient(config=config, cache=False, client=mock_gemini_client)
416 | 
417 |         messages = [Message(role='user', content='Test message')]
418 |         await client.generate_response(messages)
419 | 
420 |         call_args = mock_gemini_client.aio.models.generate_content.call_args
421 |         config = call_args[1]['config']
422 |         # Model mapping should be used
423 |         assert config.max_output_tokens == 65536
424 | 
425 |     @pytest.mark.asyncio
426 |     async def test_model_size_selection(self, gemini_client, mock_gemini_client):
427 |         """Test that the correct model is selected based on model size."""
428 |         # Setup mock response
429 |         mock_response = MagicMock()
430 |         mock_response.text = 'Test response'
431 |         mock_response.candidates = []
432 |         mock_response.prompt_feedback = None
433 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
434 | 
435 |         # Call method with small model size
436 |         messages = [Message(role='user', content='Test message')]
437 |         await gemini_client.generate_response(messages, model_size=ModelSize.small)
438 | 
439 |         # Verify correct model is used
440 |         call_args = mock_gemini_client.aio.models.generate_content.call_args
441 |         assert call_args[1]['model'] == DEFAULT_SMALL_MODEL
442 | 
443 |     @pytest.mark.asyncio
444 |     async def test_gemini_model_max_tokens_mapping(self, mock_gemini_client):
445 |         """Test that different Gemini models use their correct max tokens."""
446 |         # Setup mock response
447 |         mock_response = MagicMock()
448 |         mock_response.text = 'Test response'
449 |         mock_response.candidates = []
450 |         mock_response.prompt_feedback = None
451 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
452 | 
453 |         # Test data: (model_name, expected_max_tokens)
454 |         test_cases = [
455 |             ('gemini-2.5-flash', 65536),
456 |             ('gemini-2.5-pro', 65536),
457 |             ('gemini-2.5-flash-lite', 64000),
458 |             ('gemini-2.0-flash', 8192),
459 |             ('gemini-1.5-pro', 8192),
460 |             ('gemini-1.5-flash', 8192),
461 |             ('unknown-model', 8192),  # Fallback case
462 |         ]
463 | 
464 |         for model_name, expected_max_tokens in test_cases:
465 |             # Create client with specific model, no explicit max_tokens to test mapping
466 |             config = LLMConfig(api_key='test_api_key', model=model_name, temperature=0.5)
467 |             client = GeminiClient(config=config, cache=False, client=mock_gemini_client)
468 | 
469 |             # Call method without explicit max_tokens to test model mapping fallback
470 |             messages = [Message(role='user', content='Test message')]
471 |             await client.generate_response(messages)
472 | 
473 |             # Verify correct max tokens is used from model mapping
474 |             call_args = mock_gemini_client.aio.models.generate_content.call_args
475 |             config = call_args[1]['config']
476 |             assert config.max_output_tokens == expected_max_tokens, (
477 |                 f'Model {model_name} should use {expected_max_tokens} tokens'
478 |             )
479 | 
480 | 
481 | if __name__ == '__main__':
482 |     pytest.main(['-v', 'test_gemini_client.py'])
483 | 
```

--------------------------------------------------------------------------------
/mcp_server/tests/test_comprehensive_integration.py:
--------------------------------------------------------------------------------

```python
  1 | #!/usr/bin/env python3
  2 | """
  3 | Comprehensive integration test suite for Graphiti MCP Server.
  4 | Covers all MCP tools with consideration for LLM inference latency.
  5 | """
  6 | 
  7 | import asyncio
  8 | import json
  9 | import os
 10 | import time
 11 | from dataclasses import dataclass
 12 | from typing import Any
 13 | 
 14 | import pytest
 15 | from mcp import ClientSession, StdioServerParameters
 16 | from mcp.client.stdio import stdio_client
 17 | 
 18 | 
 19 | @dataclass
 20 | class TestMetrics:
 21 |     """Track test performance metrics."""
 22 | 
 23 |     operation: str
 24 |     start_time: float
 25 |     end_time: float
 26 |     success: bool
 27 |     details: dict[str, Any]
 28 | 
 29 |     @property
 30 |     def duration(self) -> float:
 31 |         """Calculate operation duration in seconds."""
 32 |         return self.end_time - self.start_time
 33 | 
 34 | 
 35 | class GraphitiTestClient:
 36 |     """Enhanced test client for comprehensive Graphiti MCP testing."""
 37 | 
 38 |     def __init__(self, test_group_id: str | None = None):
 39 |         self.test_group_id = test_group_id or f'test_{int(time.time())}'
 40 |         self.session = None
 41 |         self.metrics: list[TestMetrics] = []
 42 |         self.default_timeout = 30  # seconds
 43 | 
 44 |     async def __aenter__(self):
 45 |         """Initialize MCP client session."""
 46 |         server_params = StdioServerParameters(
 47 |             command='uv',
 48 |             args=['run', '../main.py', '--transport', 'stdio'],
 49 |             env={
 50 |                 'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
 51 |                 'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
 52 |                 'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
 53 |                 'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY', 'test_key_for_mock'),
 54 |                 'FALKORDB_URI': os.environ.get('FALKORDB_URI', 'redis://localhost:6379'),
 55 |             },
 56 |         )
 57 | 
 58 |         self.client_context = stdio_client(server_params)
 59 |         read, write = await self.client_context.__aenter__()
 60 |         self.session = ClientSession(read, write)
 61 |         await self.session.initialize()
 62 | 
 63 |         # Wait for server to be fully ready
 64 |         await asyncio.sleep(2)
 65 | 
 66 |         return self
 67 | 
 68 |     async def __aexit__(self, exc_type, exc_val, exc_tb):
 69 |         """Clean up client session."""
 70 |         if self.session:
 71 |             await self.session.close()
 72 |         if hasattr(self, 'client_context'):
 73 |             await self.client_context.__aexit__(exc_type, exc_val, exc_tb)
 74 | 
 75 |     async def call_tool_with_metrics(
 76 |         self, tool_name: str, arguments: dict[str, Any], timeout: float | None = None
 77 |     ) -> tuple[Any, TestMetrics]:
 78 |         """Call a tool and capture performance metrics."""
 79 |         start_time = time.time()
 80 |         timeout = timeout or self.default_timeout
 81 | 
 82 |         try:
 83 |             result = await asyncio.wait_for(
 84 |                 self.session.call_tool(tool_name, arguments), timeout=timeout
 85 |             )
 86 | 
 87 |             content = result.content[0].text if result.content else None
 88 |             success = True
 89 |             details = {'result': content, 'tool': tool_name}
 90 | 
 91 |         except asyncio.TimeoutError:
 92 |             content = None
 93 |             success = False
 94 |             details = {'error': f'Timeout after {timeout}s', 'tool': tool_name}
 95 | 
 96 |         except Exception as e:
 97 |             content = None
 98 |             success = False
 99 |             details = {'error': str(e), 'tool': tool_name}
100 | 
101 |         end_time = time.time()
102 |         metric = TestMetrics(
103 |             operation=f'call_{tool_name}',
104 |             start_time=start_time,
105 |             end_time=end_time,
106 |             success=success,
107 |             details=details,
108 |         )
109 |         self.metrics.append(metric)
110 | 
111 |         return content, metric
112 | 
113 |     async def wait_for_episode_processing(
114 |         self, expected_count: int = 1, max_wait: int = 60, poll_interval: int = 2
115 |     ) -> bool:
116 |         """
117 |         Wait for episodes to be processed with intelligent polling.
118 | 
119 |         Args:
120 |             expected_count: Number of episodes expected to be processed
121 |             max_wait: Maximum seconds to wait
122 |             poll_interval: Seconds between status checks
123 | 
124 |         Returns:
125 |             True if episodes were processed successfully
126 |         """
127 |         start_time = time.time()
128 | 
129 |         while (time.time() - start_time) < max_wait:
130 |             result, _ = await self.call_tool_with_metrics(
131 |                 'get_episodes', {'group_id': self.test_group_id, 'last_n': 100}
132 |             )
133 | 
134 |             if result:
135 |                 try:
136 |                     episodes = json.loads(result) if isinstance(result, str) else result
137 |                     if len(episodes.get('episodes', [])) >= expected_count:
138 |                         return True
139 |                 except (json.JSONDecodeError, AttributeError):
140 |                     pass
141 | 
142 |             await asyncio.sleep(poll_interval)
143 | 
144 |         return False
145 | 
146 | 
147 | class TestCoreOperations:
148 |     """Test core Graphiti operations."""
149 | 
150 |     @pytest.mark.asyncio
151 |     async def test_server_initialization(self):
152 |         """Verify server initializes with all required tools."""
153 |         async with GraphitiTestClient() as client:
154 |             tools_result = await client.session.list_tools()
155 |             tools = {tool.name for tool in tools_result.tools}
156 | 
157 |             required_tools = {
158 |                 'add_memory',
159 |                 'search_memory_nodes',
160 |                 'search_memory_facts',
161 |                 'get_episodes',
162 |                 'delete_episode',
163 |                 'delete_entity_edge',
164 |                 'get_entity_edge',
165 |                 'clear_graph',
166 |                 'get_status',
167 |             }
168 | 
169 |             missing_tools = required_tools - tools
170 |             assert not missing_tools, f'Missing required tools: {missing_tools}'
171 | 
172 |     @pytest.mark.asyncio
173 |     async def test_add_text_memory(self):
174 |         """Test adding text-based memories."""
175 |         async with GraphitiTestClient() as client:
176 |             # Add memory
177 |             result, metric = await client.call_tool_with_metrics(
178 |                 'add_memory',
179 |                 {
180 |                     'name': 'Tech Conference Notes',
181 |                     'episode_body': 'The AI conference featured talks on LLMs, RAG systems, and knowledge graphs. Notable speakers included researchers from OpenAI and Anthropic.',
182 |                     'source': 'text',
183 |                     'source_description': 'conference notes',
184 |                     'group_id': client.test_group_id,
185 |                 },
186 |             )
187 | 
188 |             assert metric.success, f'Failed to add memory: {metric.details}'
189 |             assert 'queued' in str(result).lower()
190 | 
191 |             # Wait for processing
192 |             processed = await client.wait_for_episode_processing(expected_count=1)
193 |             assert processed, 'Episode was not processed within timeout'
194 | 
195 |     @pytest.mark.asyncio
196 |     async def test_add_json_memory(self):
197 |         """Test adding structured JSON memories."""
198 |         async with GraphitiTestClient() as client:
199 |             json_data = {
200 |                 'project': {
201 |                     'name': 'GraphitiDB',
202 |                     'version': '2.0.0',
203 |                     'features': ['temporal-awareness', 'hybrid-search', 'custom-entities'],
204 |                 },
205 |                 'team': {'size': 5, 'roles': ['engineering', 'product', 'research']},
206 |             }
207 | 
208 |             result, metric = await client.call_tool_with_metrics(
209 |                 'add_memory',
210 |                 {
211 |                     'name': 'Project Data',
212 |                     'episode_body': json.dumps(json_data),
213 |                     'source': 'json',
214 |                     'source_description': 'project database',
215 |                     'group_id': client.test_group_id,
216 |                 },
217 |             )
218 | 
219 |             assert metric.success
220 |             assert 'queued' in str(result).lower()
221 | 
222 |     @pytest.mark.asyncio
223 |     async def test_add_message_memory(self):
224 |         """Test adding conversation/message memories."""
225 |         async with GraphitiTestClient() as client:
226 |             conversation = """
227 |             user: What are the key features of Graphiti?
228 |             assistant: Graphiti offers temporal-aware knowledge graphs, hybrid retrieval, and real-time updates.
229 |             user: How does it handle entity resolution?
230 |             assistant: It uses LLM-based entity extraction and deduplication with semantic similarity matching.
231 |             """
232 | 
233 |             result, metric = await client.call_tool_with_metrics(
234 |                 'add_memory',
235 |                 {
236 |                     'name': 'Feature Discussion',
237 |                     'episode_body': conversation,
238 |                     'source': 'message',
239 |                     'source_description': 'support chat',
240 |                     'group_id': client.test_group_id,
241 |                 },
242 |             )
243 | 
244 |             assert metric.success
245 |             assert metric.duration < 5, f'Add memory took too long: {metric.duration}s'
246 | 
247 | 
248 | class TestSearchOperations:
249 |     """Test search and retrieval operations."""
250 | 
251 |     @pytest.mark.asyncio
252 |     async def test_search_nodes_semantic(self):
253 |         """Test semantic search for nodes."""
254 |         async with GraphitiTestClient() as client:
255 |             # First add some test data
256 |             await client.call_tool_with_metrics(
257 |                 'add_memory',
258 |                 {
259 |                     'name': 'Product Launch',
260 |                     'episode_body': 'Our new AI assistant product launches in Q2 2024 with advanced NLP capabilities.',
261 |                     'source': 'text',
262 |                     'source_description': 'product roadmap',
263 |                     'group_id': client.test_group_id,
264 |                 },
265 |             )
266 | 
267 |             # Wait for processing
268 |             await client.wait_for_episode_processing()
269 | 
270 |             # Search for nodes
271 |             result, metric = await client.call_tool_with_metrics(
272 |                 'search_memory_nodes',
273 |                 {'query': 'AI product features', 'group_id': client.test_group_id, 'limit': 10},
274 |             )
275 | 
276 |             assert metric.success
277 |             assert result is not None
278 | 
279 |     @pytest.mark.asyncio
280 |     async def test_search_facts_with_filters(self):
281 |         """Test fact search with various filters."""
282 |         async with GraphitiTestClient() as client:
283 |             # Add test data
284 |             await client.call_tool_with_metrics(
285 |                 'add_memory',
286 |                 {
287 |                     'name': 'Company Facts',
288 |                     'episode_body': 'Acme Corp was founded in 2020. They have 50 employees and $10M in revenue.',
289 |                     'source': 'text',
290 |                     'source_description': 'company profile',
291 |                     'group_id': client.test_group_id,
292 |                 },
293 |             )
294 | 
295 |             await client.wait_for_episode_processing()
296 | 
297 |             # Search with date filter
298 |             result, metric = await client.call_tool_with_metrics(
299 |                 'search_memory_facts',
300 |                 {
301 |                     'query': 'company information',
302 |                     'group_id': client.test_group_id,
303 |                     'created_after': '2020-01-01T00:00:00Z',
304 |                     'limit': 20,
305 |                 },
306 |             )
307 | 
308 |             assert metric.success
309 | 
310 |     @pytest.mark.asyncio
311 |     async def test_hybrid_search(self):
312 |         """Test hybrid search combining semantic and keyword search."""
313 |         async with GraphitiTestClient() as client:
314 |             # Add diverse test data
315 |             test_memories = [
316 |                 {
317 |                     'name': 'Technical Doc',
318 |                     'episode_body': 'GraphQL API endpoints support pagination, filtering, and real-time subscriptions.',
319 |                     'source': 'text',
320 |                 },
321 |                 {
322 |                     'name': 'Architecture',
323 |                     'episode_body': 'The system uses Neo4j for graph storage and OpenAI embeddings for semantic search.',
324 |                     'source': 'text',
325 |                 },
326 |             ]
327 | 
328 |             for memory in test_memories:
329 |                 memory['group_id'] = client.test_group_id
330 |                 memory['source_description'] = 'documentation'
331 |                 await client.call_tool_with_metrics('add_memory', memory)
332 | 
333 |             await client.wait_for_episode_processing(expected_count=2)
334 | 
335 |             # Test semantic + keyword search
336 |             result, metric = await client.call_tool_with_metrics(
337 |                 'search_memory_nodes',
338 |                 {'query': 'Neo4j graph database', 'group_id': client.test_group_id, 'limit': 10},
339 |             )
340 | 
341 |             assert metric.success
342 | 
343 | 
class TestEpisodeManagement:
    """Episode lifecycle: listing a bounded number of episodes and deleting one."""

    @pytest.mark.asyncio
    async def test_get_episodes_pagination(self):
        """Add five episodes, request the last three, and check the cap is respected."""
        async with GraphitiTestClient() as client:
            group = client.test_group_id
            episode_count = 5

            for idx in range(episode_count):
                await client.call_tool_with_metrics(
                    'add_memory',
                    {
                        'name': f'Episode {idx}',
                        'episode_body': f'This is test episode number {idx}',
                        'source': 'text',
                        'source_description': 'test',
                        'group_id': group,
                    },
                )

            await client.wait_for_episode_processing(expected_count=episode_count)

            raw, metric = await client.call_tool_with_metrics(
                'get_episodes', {'group_id': group, 'last_n': 3}
            )

            assert metric.success
            # The tool may hand back either a JSON string or an already-decoded object.
            payload = raw if not isinstance(raw, str) else json.loads(raw)
            assert len(payload.get('episodes', [])) <= 3

    @pytest.mark.asyncio
    async def test_delete_episode(self):
        """Add one episode, look up its UUID via get_episodes, then delete it by UUID."""
        async with GraphitiTestClient() as client:
            group = client.test_group_id

            await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'To Delete',
                    'episode_body': 'This episode will be deleted',
                    'source': 'text',
                    'source_description': 'test',
                    'group_id': group,
                },
            )

            await client.wait_for_episode_processing()

            listing, _ = await client.call_tool_with_metrics(
                'get_episodes', {'group_id': group, 'last_n': 1}
            )
            if isinstance(listing, str):
                listing = json.loads(listing)
            target_uuid = listing['episodes'][0]['uuid']

            outcome, metric = await client.call_tool_with_metrics(
                'delete_episode', {'episode_uuid': target_uuid}
            )

            assert metric.success
            assert 'deleted' in str(outcome).lower()
408 | 
409 | 
class TestEntityAndEdgeOperations:
    """Entity and edge management."""

    @pytest.mark.asyncio
    async def test_get_entity_edge(self):
        """Ingest a small org chart and verify that node search over it succeeds.

        Retrieving a specific edge needs a valid edge UUID, which this test does
        not obtain yet. For now it only checks the ingestion + search path.
        """
        async with GraphitiTestClient() as client:
            # Add data to create entities and edges
            await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Relationship Data',
                    'episode_body': 'Alice works at TechCorp. Bob is the CEO of TechCorp.',
                    'source': 'text',
                    'source_description': 'org chart',
                    'group_id': client.test_group_id,
                },
            )

            await client.wait_for_episode_processing()

            # Search for nodes to get UUIDs
            _, metric = await client.call_tool_with_metrics(
                'search_memory_nodes',
                {'query': 'TechCorp', 'group_id': client.test_group_id, 'limit': 5},
            )

            # Assert the part that is actually exercised, so a failing search is
            # not reported as a passing test.
            assert metric.success

            # NOTE: edge retrieval itself is not exercised yet; it requires a valid
            # edge UUID taken from the relationships created above.

    @pytest.mark.asyncio
    async def test_delete_entity_edge(self):
        """Placeholder for edge deletion; skipped until implemented."""
        # An empty body would report this test as passing without testing anything,
        # so mark it skipped explicitly.
        pytest.skip('test_delete_entity_edge is not implemented yet')
445 | 
446 | 
class TestErrorHandling:
    """Error conditions and edge cases: bad arguments, timeouts, concurrency."""

    @pytest.mark.asyncio
    async def test_invalid_tool_arguments(self):
        """Calling add_memory with only a name must be reported as a failed call."""
        async with GraphitiTestClient() as client:
            # Missing required arguments
            result, metric = await client.call_tool_with_metrics(
                'add_memory',
                {'name': 'Incomplete'},  # Missing required fields
            )

            assert not metric.success
            assert 'error' in str(metric.details).lower()

    @pytest.mark.asyncio
    async def test_timeout_handling(self):
        """Send a ~240,000-character episode with a 5 s timeout.

        If the call fails, the failure details must mention a timeout. A
        successful call is also accepted.
        """
        async with GraphitiTestClient() as client:
            # Simulate a very large episode that might time out
            large_text = 'Large document content. ' * 10000

            result, metric = await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Large Document',
                    'episode_body': large_text,
                    'source': 'text',
                    'source_description': 'large file',
                    'group_id': client.test_group_id,
                },
                timeout=5,  # Short timeout
            )

            # Check if timeout was handled gracefully
            if not metric.success:
                assert 'timeout' in str(metric.details).lower()

    @pytest.mark.asyncio
    async def test_concurrent_operations(self):
        """Launch five add_memory calls concurrently; at least three must succeed."""
        async with GraphitiTestClient() as client:
            # Launch multiple operations concurrently
            tasks = []
            for i in range(5):
                task = client.call_tool_with_metrics(
                    'add_memory',
                    {
                        'name': f'Concurrent {i}',
                        'episode_body': f'Concurrent operation {i}',
                        'source': 'text',
                        'source_description': 'concurrent test',
                        'group_id': client.test_group_id,
                    },
                )
                tasks.append(task)

            results = await asyncio.gather(*tasks, return_exceptions=True)

            # With return_exceptions=True a failed call yields the exception object
            # instead of a (result, metric) tuple. Count it as a failure rather than
            # letting tuple unpacking raise TypeError and mask the real error.
            successful = sum(
                1
                for outcome in results
                if not isinstance(outcome, BaseException) and outcome[1].success
            )
            # At least 60% should succeed
            assert successful >= 3, (
                f'Only {successful}/{len(tasks)} concurrent operations succeeded'
            )
510 | 
511 | 
class TestPerformance:
    """Latency and batch-throughput checks."""

    @pytest.mark.asyncio
    async def test_latency_metrics(self):
        """Time one call of each core tool and enforce loose latency ceilings.

        add_memory is timed but has no ceiling. get_episodes must finish in
        under 2 s and search_memory_nodes in under 10 s.
        """
        async with GraphitiTestClient() as client:
            operations = [
                (
                    'add_memory',
                    {
                        'name': 'Perf Test',
                        'episode_body': 'Simple text',
                        'source': 'text',
                        'source_description': 'test',
                        'group_id': client.test_group_id,
                    },
                ),
                (
                    'search_memory_nodes',
                    {'query': 'test', 'group_id': client.test_group_id, 'limit': 10},
                ),
                ('get_episodes', {'group_id': client.test_group_id, 'last_n': 10}),
            ]

            for tool_name, args in operations:
                _, metric = await client.call_tool_with_metrics(tool_name, args)

                # Log performance metrics
                print(f'{tool_name}: {metric.duration:.2f}s')

                # Basic latency assertions
                if tool_name == 'get_episodes':
                    assert metric.duration < 2, f'{tool_name} too slow'
                elif tool_name == 'search_memory_nodes':
                    assert metric.duration < 10, f'{tool_name} too slow'

    @pytest.mark.asyncio
    async def test_batch_processing_efficiency(self):
        """Add ``batch_size`` episodes sequentially and check the average time per item.

        The clock runs from before the first add_memory call until the episodes
        are reported processed. The average must stay under 15 s per item.
        """
        async with GraphitiTestClient() as client:
            batch_size = 10
            # perf_counter is monotonic. time.time() follows the system clock, which
            # can be adjusted (e.g. by NTP) mid-measurement and skew the elapsed time.
            start_time = time.perf_counter()

            # Batch add memories
            for i in range(batch_size):
                await client.call_tool_with_metrics(
                    'add_memory',
                    {
                        'name': f'Batch {i}',
                        'episode_body': f'Batch content {i}',
                        'source': 'text',
                        'source_description': 'batch test',
                        'group_id': client.test_group_id,
                    },
                )

            # Wait for all to process
            processed = await client.wait_for_episode_processing(
                expected_count=batch_size,
                max_wait=120,  # Allow more time for batch
            )

            total_time = time.perf_counter() - start_time
            avg_time_per_item = total_time / batch_size

            assert processed, f'Failed to process {batch_size} items'
            assert avg_time_per_item < 15, (
                f'Batch processing too slow: {avg_time_per_item:.2f}s per item'
            )

            # Generate performance report
            print('\nBatch Performance Report:')
            print(f'  Total items: {batch_size}')
            print(f'  Total time: {total_time:.2f}s')
            print(f'  Avg per item: {avg_time_per_item:.2f}s')
588 | 
589 | 
class TestDatabaseBackends:
    """Database backend configurations (not implemented yet)."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize('database', ['neo4j', 'falkordb'])
    async def test_database_operations(self, database):
        """Placeholder for running operations against a specific database backend.

        Builds the environment a server for ``database`` would need, then skips,
        because starting a server per backend is not implemented yet.
        """
        env_vars = {
            'DATABASE_PROVIDER': database,
            'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY'),
        }

        if database == 'neo4j':
            env_vars.update(
                {
                    'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
                    'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
                    'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
                }
            )
        elif database == 'falkordb':
            env_vars['FALKORDB_URI'] = os.environ.get('FALKORDB_URI', 'redis://localhost:6379')

        # Skip explicitly: falling off the end would report both parametrized cases
        # as passed without testing anything. Only variable NAMES go into the
        # message, never their values (the API key and password are secrets).
        pytest.skip(
            f'{database} backend test not implemented; needs a server started with '
            f'{", ".join(sorted(env_vars))}'
        )
616 | 
617 | 
def generate_test_report(client: GraphitiTestClient) -> str:
    """Render the metrics recorded by *client* as a plain-text report.

    The report lists overall totals (operation count, success rate, mean
    duration), per-operation aggregates ordered by operation name, and the five
    slowest individual calls.

    Args:
        client: Test client whose ``metrics`` sequence holds records exposing
            ``operation``, ``success`` and ``duration`` attributes.

    Returns:
        The report text, or ``'No metrics collected'`` if nothing was recorded.
    """
    metrics = client.metrics
    if not metrics:
        return 'No metrics collected'

    rule = '=' * 60
    total_ops = len(metrics)
    successful_ops = sum(1 for m in metrics if m.success)
    avg_duration = sum(m.duration for m in metrics) / total_ops

    lines = [
        '\n' + rule,
        'GRAPHITI MCP TEST REPORT',
        rule,
        # Summary statistics
        f'\nTotal Operations: {total_ops}',
        f'Successful: {successful_ops} ({successful_ops / total_ops * 100:.1f}%)',
        f'Average Duration: {avg_duration:.2f}s',
        # Operation breakdown
        '\nOperation Breakdown:',
    ]

    per_op = {}
    for m in metrics:
        bucket = per_op.setdefault(
            m.operation, {'count': 0, 'success': 0, 'total_duration': 0}
        )
        bucket['count'] += 1
        bucket['success'] += 1 if m.success else 0
        bucket['total_duration'] += m.duration

    for op in sorted(per_op):
        bucket = per_op[op]
        mean = bucket['total_duration'] / bucket['count']
        rate = bucket['success'] / bucket['count'] * 100
        lines.append(
            f'  {op}: {bucket["count"]} calls, {rate:.0f}% success, {mean:.2f}s avg'
        )

    # Slowest operations (stable sort keeps recording order among ties)
    lines.append('\nSlowest Operations:')
    lines.extend(
        f'  {m.operation}: {m.duration:.2f}s'
        for m in sorted(metrics, key=lambda m: m.duration, reverse=True)[:5]
    )

    lines.append(rule)
    return '\n'.join(lines)
663 | 
664 | 
if __name__ == '__main__':
    # Run this module's tests with pytest and propagate its exit status, so that
    # `python <this file>` exits non-zero when a test fails (the return value of
    # pytest.main was previously discarded and the script always exited 0).
    raise SystemExit(pytest.main([__file__, '-v', '--asyncio-mode=auto']))
668 | 
```

--------------------------------------------------------------------------------
/graphiti_core/utils/maintenance/edge_operations.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import logging
 18 | from datetime import datetime
 19 | from time import time
 20 | 
 21 | from pydantic import BaseModel
 22 | from typing_extensions import LiteralString
 23 | 
 24 | from graphiti_core.driver.driver import GraphDriver, GraphProvider
 25 | from graphiti_core.edges import (
 26 |     CommunityEdge,
 27 |     EntityEdge,
 28 |     EpisodicEdge,
 29 |     create_entity_edge_embeddings,
 30 | )
 31 | from graphiti_core.graphiti_types import GraphitiClients
 32 | from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
 33 | from graphiti_core.llm_client import LLMClient
 34 | from graphiti_core.llm_client.config import ModelSize
 35 | from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
 36 | from graphiti_core.prompts import prompt_library
 37 | from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
 38 | from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
 39 | from graphiti_core.search.search import search
 40 | from graphiti_core.search.search_config import SearchResults
 41 | from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
 42 | from graphiti_core.search.search_filters import SearchFilters
 43 | from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
 44 | from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact
 45 | 
 46 | DEFAULT_EDGE_NAME = 'RELATES_TO'
 47 | 
 48 | logger = logging.getLogger(__name__)
 49 | 
 50 | 
 51 | def build_episodic_edges(
 52 |     entity_nodes: list[EntityNode],
 53 |     episode_uuid: str,
 54 |     created_at: datetime,
 55 | ) -> list[EpisodicEdge]:
 56 |     episodic_edges: list[EpisodicEdge] = [
 57 |         EpisodicEdge(
 58 |             source_node_uuid=episode_uuid,
 59 |             target_node_uuid=node.uuid,
 60 |             created_at=created_at,
 61 |             group_id=node.group_id,
 62 |         )
 63 |         for node in entity_nodes
 64 |     ]
 65 | 
 66 |     logger.debug(f'Built episodic edges: {episodic_edges}')
 67 | 
 68 |     return episodic_edges
 69 | 
 70 | 
 71 | def build_community_edges(
 72 |     entity_nodes: list[EntityNode],
 73 |     community_node: CommunityNode,
 74 |     created_at: datetime,
 75 | ) -> list[CommunityEdge]:
 76 |     edges: list[CommunityEdge] = [
 77 |         CommunityEdge(
 78 |             source_node_uuid=community_node.uuid,
 79 |             target_node_uuid=node.uuid,
 80 |             created_at=created_at,
 81 |             group_id=community_node.group_id,
 82 |         )
 83 |         for node in entity_nodes
 84 |     ]
 85 | 
 86 |     return edges
 87 | 
 88 | 
 89 | async def extract_edges(
 90 |     clients: GraphitiClients,
 91 |     episode: EpisodicNode,
 92 |     nodes: list[EntityNode],
 93 |     previous_episodes: list[EpisodicNode],
 94 |     edge_type_map: dict[tuple[str, str], list[str]],
 95 |     group_id: str = '',
 96 |     edge_types: dict[str, type[BaseModel]] | None = None,
 97 | ) -> list[EntityEdge]:
 98 |     start = time()
 99 | 
100 |     extract_edges_max_tokens = 16384
101 |     llm_client = clients.llm_client
102 | 
103 |     edge_type_signature_map: dict[str, tuple[str, str]] = {
104 |         edge_type: signature
105 |         for signature, edge_types in edge_type_map.items()
106 |         for edge_type in edge_types
107 |     }
108 | 
109 |     edge_types_context = (
110 |         [
111 |             {
112 |                 'fact_type_name': type_name,
113 |                 'fact_type_signature': edge_type_signature_map.get(type_name, ('Entity', 'Entity')),
114 |                 'fact_type_description': type_model.__doc__,
115 |             }
116 |             for type_name, type_model in edge_types.items()
117 |         ]
118 |         if edge_types is not None
119 |         else []
120 |     )
121 | 
122 |     # Prepare context for LLM
123 |     context = {
124 |         'episode_content': episode.content,
125 |         'nodes': [
126 |             {'id': idx, 'name': node.name, 'entity_types': node.labels}
127 |             for idx, node in enumerate(nodes)
128 |         ],
129 |         'previous_episodes': [ep.content for ep in previous_episodes],
130 |         'reference_time': episode.valid_at,
131 |         'edge_types': edge_types_context,
132 |         'custom_prompt': '',
133 |     }
134 | 
135 |     facts_missed = True
136 |     reflexion_iterations = 0
137 |     while facts_missed and reflexion_iterations <= MAX_REFLEXION_ITERATIONS:
138 |         llm_response = await llm_client.generate_response(
139 |             prompt_library.extract_edges.edge(context),
140 |             response_model=ExtractedEdges,
141 |             max_tokens=extract_edges_max_tokens,
142 |             group_id=group_id,
143 |             prompt_name='extract_edges.edge',
144 |         )
145 |         edges_data = ExtractedEdges(**llm_response).edges
146 | 
147 |         context['extracted_facts'] = [edge_data.fact for edge_data in edges_data]
148 | 
149 |         reflexion_iterations += 1
150 |         if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
151 |             reflexion_response = await llm_client.generate_response(
152 |                 prompt_library.extract_edges.reflexion(context),
153 |                 response_model=MissingFacts,
154 |                 max_tokens=extract_edges_max_tokens,
155 |                 group_id=group_id,
156 |                 prompt_name='extract_edges.reflexion',
157 |             )
158 | 
159 |             missing_facts = reflexion_response.get('missing_facts', [])
160 | 
161 |             custom_prompt = 'The following facts were missed in a previous extraction: '
162 |             for fact in missing_facts:
163 |                 custom_prompt += f'\n{fact},'
164 | 
165 |             context['custom_prompt'] = custom_prompt
166 | 
167 |             facts_missed = len(missing_facts) != 0
168 | 
169 |     end = time()
170 |     logger.debug(f'Extracted new edges: {edges_data} in {(end - start) * 1000} ms')
171 | 
172 |     if len(edges_data) == 0:
173 |         return []
174 | 
175 |     # Convert the extracted data into EntityEdge objects
176 |     edges = []
177 |     for edge_data in edges_data:
178 |         # Validate Edge Date information
179 |         valid_at = edge_data.valid_at
180 |         invalid_at = edge_data.invalid_at
181 |         valid_at_datetime = None
182 |         invalid_at_datetime = None
183 | 
184 |         # Filter out empty edges
185 |         if not edge_data.fact.strip():
186 |             continue
187 | 
188 |         source_node_idx = edge_data.source_entity_id
189 |         target_node_idx = edge_data.target_entity_id
190 | 
191 |         if len(nodes) == 0:
192 |             logger.warning('No entities provided for edge extraction')
193 |             continue
194 | 
195 |         if not (0 <= source_node_idx < len(nodes) and 0 <= target_node_idx < len(nodes)):
196 |             logger.warning(
197 |                 f'Invalid entity IDs in edge extraction for {edge_data.relation_type}. '
198 |                 f'source_entity_id: {source_node_idx}, target_entity_id: {target_node_idx}, '
199 |                 f'but only {len(nodes)} entities available (valid range: 0-{len(nodes) - 1})'
200 |             )
201 |             continue
202 |         source_node_uuid = nodes[source_node_idx].uuid
203 |         target_node_uuid = nodes[target_node_idx].uuid
204 | 
205 |         if valid_at:
206 |             try:
207 |                 valid_at_datetime = ensure_utc(
208 |                     datetime.fromisoformat(valid_at.replace('Z', '+00:00'))
209 |                 )
210 |             except ValueError as e:
211 |                 logger.warning(f'WARNING: Error parsing valid_at date: {e}. Input: {valid_at}')
212 | 
213 |         if invalid_at:
214 |             try:
215 |                 invalid_at_datetime = ensure_utc(
216 |                     datetime.fromisoformat(invalid_at.replace('Z', '+00:00'))
217 |                 )
218 |             except ValueError as e:
219 |                 logger.warning(f'WARNING: Error parsing invalid_at date: {e}. Input: {invalid_at}')
220 |         edge = EntityEdge(
221 |             source_node_uuid=source_node_uuid,
222 |             target_node_uuid=target_node_uuid,
223 |             name=edge_data.relation_type,
224 |             group_id=group_id,
225 |             fact=edge_data.fact,
226 |             episodes=[episode.uuid],
227 |             created_at=utc_now(),
228 |             valid_at=valid_at_datetime,
229 |             invalid_at=invalid_at_datetime,
230 |         )
231 |         edges.append(edge)
232 |         logger.debug(
233 |             f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
234 |         )
235 | 
236 |     logger.debug(f'Extracted edges: {[(e.name, e.uuid) for e in edges]}')
237 | 
238 |     return edges
239 | 
240 | 
async def resolve_extracted_edges(
    clients: GraphitiClients,
    extracted_edges: list[EntityEdge],
    episode: EpisodicNode,
    entities: list[EntityNode],
    edge_types: dict[str, type[BaseModel]] | None,
    edge_type_map: dict[tuple[str, str], list[str]],
) -> tuple[list[EntityEdge], list[EntityEdge]]:
    """Resolve newly extracted edges against the graph and collect invalidated edges.

    Steps:

    1. Drop exact duplicates within ``extracted_edges`` (same source, target and
       normalized fact). The first occurrence is kept.
    2. Embed the remaining edges. For each edge, run two hybrid searches on its
       fact within its group:
       - one restricted to the existing edges between the same two nodes
         (duplicate candidates);
       - one without a UUID restriction (invalidation candidates).
    3. Constrain each edge's name to the custom edge types allowed for its node
       labels. A disallowed custom name is replaced with ``DEFAULT_EDGE_NAME``;
       non-custom (LLM generated) names are kept. Names are changed in place.
    4. Resolve every edge concurrently via ``resolve_extracted_edge``, then
       re-embed the resolved and invalidated edges.

    Parameters
    ----------
    clients : GraphitiClients
        Provides the driver, LLM client and embedder.
    extracted_edges : list[EntityEdge]
        Edges produced by extraction for ``episode``.
    episode : EpisodicNode
        The episode being ingested, forwarded to ``resolve_extracted_edge``.
    entities : list[EntityNode]
        Nodes whose labels determine which custom edge types apply.
    edge_types : dict[str, type[BaseModel]] | None
        Custom edge type models by name. ``None`` is treated as no custom types.
    edge_type_map : dict[tuple[str, str], list[str]]
        Allowed custom edge type names per (source label, target label) pair.

    Returns
    -------
    tuple[list[EntityEdge], list[EntityEdge]]
        ``(resolved_edges, invalidated_edges)``. ``resolved_edges`` has one entry
        per deduplicated extracted edge, in order.
    """
    # ``set(edge_types or {})`` below tolerated None, but the per-label
    # ``edge_types.get`` lookup did not. Normalize once so both uses agree.
    edge_types = edge_types or {}

    # Fast path: deduplicate exact matches within the extracted edges before parallel processing
    seen: dict[tuple[str, str, str], EntityEdge] = {}
    deduplicated_edges: list[EntityEdge] = []

    for edge in extracted_edges:
        key = (
            edge.source_node_uuid,
            edge.target_node_uuid,
            _normalize_string_exact(edge.fact),
        )
        if key not in seen:
            seen[key] = edge
            deduplicated_edges.append(edge)

    extracted_edges = deduplicated_edges

    driver = clients.driver
    llm_client = clients.llm_client
    embedder = clients.embedder
    await create_entity_edge_embeddings(embedder, extracted_edges)

    # Existing edges directly between each extracted edge's source and target nodes.
    valid_edges_list: list[list[EntityEdge]] = await semaphore_gather(
        *[
            EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
            for edge in extracted_edges
        ]
    )

    # Duplicate candidates: hybrid search on the fact, restricted to those edges.
    related_edges_results: list[SearchResults] = await semaphore_gather(
        *[
            search(
                clients,
                extracted_edge.fact,
                group_ids=[extracted_edge.group_id],
                config=EDGE_HYBRID_SEARCH_RRF,
                search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
            )
            for extracted_edge, valid_edges in zip(extracted_edges, valid_edges_list, strict=True)
        ]
    )

    related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results]

    # Invalidation candidates: the same search within the group, with no UUID filter.
    edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather(
        *[
            search(
                clients,
                extracted_edge.fact,
                group_ids=[extracted_edge.group_id],
                config=EDGE_HYBRID_SEARCH_RRF,
                search_filter=SearchFilters(),
            )
            for extracted_edge in extracted_edges
        ]
    )

    edge_invalidation_candidates: list[list[EntityEdge]] = [
        result.edges for result in edge_invalidation_candidate_results
    ]

    logger.debug(
        f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
    )

    # Build entity hash table
    uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}

    # Determine which edge types are relevant for each edge.
    # `edge_types_lst` stores the subset of custom edge definitions whose
    # node signature matches each extracted edge. Anything outside this subset
    # should only stay on the edge if it is a non-custom (LLM generated) label.
    edge_types_lst: list[dict[str, type[BaseModel]]] = []
    custom_type_names = set(edge_types)
    for extracted_edge in extracted_edges:
        source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
        target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
        # Every node also matches the generic 'Entity' label.
        source_node_labels = (
            source_node.labels + ['Entity'] if source_node is not None else ['Entity']
        )
        target_node_labels = (
            target_node.labels + ['Entity'] if target_node is not None else ['Entity']
        )
        label_tuples = [
            (source_label, target_label)
            for source_label in source_node_labels
            for target_label in target_node_labels
        ]

        extracted_edge_types = {}
        for label_tuple in label_tuples:
            type_names = edge_type_map.get(label_tuple, [])
            for type_name in type_names:
                type_model = edge_types.get(type_name)
                if type_model is None:
                    continue

                extracted_edge_types[type_name] = type_model

        edge_types_lst.append(extracted_edge_types)

    for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True):
        allowed_type_names = set(extracted_edge_types)
        is_custom_name = extracted_edge.name in custom_type_names
        if not allowed_type_names:
            # No custom types are valid for this node pairing. Keep LLM generated
            # labels, but flip disallowed custom names back to the default.
            if is_custom_name and extracted_edge.name != DEFAULT_EDGE_NAME:
                extracted_edge.name = DEFAULT_EDGE_NAME
            continue
        if is_custom_name and extracted_edge.name not in allowed_type_names:
            # Custom name exists but it is not permitted for this source/target
            # signature, so fall back to the default edge label.
            extracted_edge.name = DEFAULT_EDGE_NAME

    # resolve edges with related edges in the graph and find invalidation candidates
    results: list[tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]] = list(
        await semaphore_gather(
            *[
                resolve_extracted_edge(
                    llm_client,
                    extracted_edge,
                    related_edges,
                    existing_edges,
                    episode,
                    extracted_edge_types,
                    custom_type_names,
                )
                for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
                    extracted_edges,
                    related_edges_lists,
                    edge_invalidation_candidates,
                    edge_types_lst,
                    strict=True,
                )
            ]
        )
    )

    resolved_edges: list[EntityEdge] = []
    invalidated_edges: list[EntityEdge] = []
    for result in results:
        resolved_edge = result[0]
        invalidated_edge_chunk = result[1]

        resolved_edges.append(resolved_edge)
        invalidated_edges.extend(invalidated_edge_chunk)

    logger.debug(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')

    await semaphore_gather(
        create_entity_edge_embeddings(embedder, resolved_edges),
        create_entity_edge_embeddings(embedder, invalidated_edges),
    )

    return resolved_edges, invalidated_edges
404 | 
405 | 
def resolve_edge_contradictions(
    resolved_edge: EntityEdge, invalidation_candidates: list[EntityEdge]
) -> list[EntityEdge]:
    """Expire candidate edges that ``resolved_edge`` supersedes in time.

    A candidate is left untouched when the two validity windows are disjoint:

    - the candidate became invalid at or before ``resolved_edge`` became valid, or
    - ``resolved_edge`` became invalid at or before the candidate became valid.

    Otherwise, if the candidate became valid strictly before ``resolved_edge``,
    it is invalidated in place. Its ``invalid_at`` is set to
    ``resolved_edge.valid_at``, and its ``expired_at`` is set to now unless it
    was already set.

    Timestamps are compared after ``ensure_utc`` normalization. A comparison
    involving a missing timestamp never holds.

    Returns
    -------
    list[EntityEdge]
        The candidates that were invalidated, in input order.
    """
    invalidated: list[EntityEdge] = []

    for candidate in invalidation_candidates:
        old_invalid = ensure_utc(candidate.invalid_at)
        new_valid = ensure_utc(resolved_edge.valid_at)
        old_valid = ensure_utc(candidate.valid_at)
        new_invalid = ensure_utc(resolved_edge.invalid_at)

        old_ended_before_new_began = (
            old_invalid is not None and new_valid is not None and old_invalid <= new_valid
        )
        new_ended_before_old_began = (
            old_valid is not None and new_invalid is not None and new_invalid <= old_valid
        )
        if old_ended_before_new_began or new_ended_before_old_began:
            # Disjoint validity windows: the facts do not contradict each other in time.
            continue

        if old_valid is None or new_valid is None or not old_valid < new_valid:
            continue

        # The newer edge takes over from the moment it became valid.
        candidate.invalid_at = resolved_edge.valid_at
        if candidate.expired_at is None:
            candidate.expired_at = utc_now()
        invalidated.append(candidate)

    return invalidated
442 | 
443 | 
444 | async def resolve_extracted_edge(
445 |     llm_client: LLMClient,
446 |     extracted_edge: EntityEdge,
447 |     related_edges: list[EntityEdge],
448 |     existing_edges: list[EntityEdge],
449 |     episode: EpisodicNode,
450 |     edge_type_candidates: dict[str, type[BaseModel]] | None = None,
451 |     custom_edge_type_names: set[str] | None = None,
452 | ) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
453 |     """Resolve an extracted edge against existing graph context.
454 | 
455 |     Parameters
456 |     ----------
457 |     llm_client : LLMClient
458 |         Client used to invoke the LLM for deduplication and attribute extraction.
459 |     extracted_edge : EntityEdge
460 |         Newly extracted edge whose canonical representation is being resolved.
461 |     related_edges : list[EntityEdge]
462 |         Candidate edges with identical endpoints used for duplicate detection.
463 |     existing_edges : list[EntityEdge]
464 |         Broader set of edges evaluated for contradiction / invalidation.
465 |     episode : EpisodicNode
466 |         Episode providing content context when extracting edge attributes.
467 |     edge_type_candidates : dict[str, type[BaseModel]] | None
468 |         Custom edge types permitted for the current source/target signature.
469 |     custom_edge_type_names : set[str] | None
470 |         Full catalog of registered custom edge names. Used to distinguish
471 |         between disallowed custom types (which fall back to the default label)
472 |         and ad-hoc labels emitted by the LLM.
473 | 
474 |     Returns
475 |     -------
476 |     tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]
477 |         The resolved edge, any duplicates, and edges to invalidate.
478 |     """
479 |     if len(related_edges) == 0 and len(existing_edges) == 0:
480 |         return extracted_edge, [], []
481 | 
482 |     # Fast path: if the fact text and endpoints already exist verbatim, reuse the matching edge.
483 |     normalized_fact = _normalize_string_exact(extracted_edge.fact)
484 |     for edge in related_edges:
485 |         if (
486 |             edge.source_node_uuid == extracted_edge.source_node_uuid
487 |             and edge.target_node_uuid == extracted_edge.target_node_uuid
488 |             and _normalize_string_exact(edge.fact) == normalized_fact
489 |         ):
490 |             resolved = edge
491 |             if episode is not None and episode.uuid not in resolved.episodes:
492 |                 resolved.episodes.append(episode.uuid)
493 |             return resolved, [], []
494 | 
495 |     start = time()
496 | 
497 |     # Prepare context for LLM
498 |     related_edges_context = [{'idx': i, 'fact': edge.fact} for i, edge in enumerate(related_edges)]
499 | 
500 |     invalidation_edge_candidates_context = [
501 |         {'idx': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
502 |     ]
503 | 
504 |     edge_types_context = (
505 |         [
506 |             {
507 |                 'fact_type_name': type_name,
508 |                 'fact_type_description': type_model.__doc__,
509 |             }
510 |             for type_name, type_model in edge_type_candidates.items()
511 |         ]
512 |         if edge_type_candidates is not None
513 |         else []
514 |     )
515 | 
516 |     context = {
517 |         'existing_edges': related_edges_context,
518 |         'new_edge': extracted_edge.fact,
519 |         'edge_invalidation_candidates': invalidation_edge_candidates_context,
520 |         'edge_types': edge_types_context,
521 |     }
522 | 
523 |     if related_edges or existing_edges:
524 |         logger.debug(
525 |             'Resolving edge: sent %d EXISTING FACTS%s and %d INVALIDATION CANDIDATES%s',
526 |             len(related_edges),
527 |             f' (idx 0-{len(related_edges) - 1})' if related_edges else '',
528 |             len(existing_edges),
529 |             f' (idx 0-{len(existing_edges) - 1})' if existing_edges else '',
530 |         )
531 | 
532 |     llm_response = await llm_client.generate_response(
533 |         prompt_library.dedupe_edges.resolve_edge(context),
534 |         response_model=EdgeDuplicate,
535 |         model_size=ModelSize.small,
536 |         prompt_name='dedupe_edges.resolve_edge',
537 |     )
538 |     response_object = EdgeDuplicate(**llm_response)
539 |     duplicate_facts = response_object.duplicate_facts
540 | 
541 |     # Validate duplicate_facts are in valid range for EXISTING FACTS
542 |     invalid_duplicates = [i for i in duplicate_facts if i < 0 or i >= len(related_edges)]
543 |     if invalid_duplicates:
544 |         logger.warning(
545 |             'LLM returned invalid duplicate_facts idx values %s (valid range: 0-%d for EXISTING FACTS)',
546 |             invalid_duplicates,
547 |             len(related_edges) - 1,
548 |         )
549 | 
550 |     duplicate_fact_ids: list[int] = [i for i in duplicate_facts if 0 <= i < len(related_edges)]
551 | 
552 |     resolved_edge = extracted_edge
553 |     for duplicate_fact_id in duplicate_fact_ids:
554 |         resolved_edge = related_edges[duplicate_fact_id]
555 |         break
556 | 
557 |     if duplicate_fact_ids and episode is not None:
558 |         resolved_edge.episodes.append(episode.uuid)
559 | 
560 |     contradicted_facts: list[int] = response_object.contradicted_facts
561 | 
562 |     # Validate contradicted_facts are in valid range for INVALIDATION CANDIDATES
563 |     invalid_contradictions = [i for i in contradicted_facts if i < 0 or i >= len(existing_edges)]
564 |     if invalid_contradictions:
565 |         logger.warning(
566 |             'LLM returned invalid contradicted_facts idx values %s (valid range: 0-%d for INVALIDATION CANDIDATES)',
567 |             invalid_contradictions,
568 |             len(existing_edges) - 1,
569 |         )
570 | 
571 |     invalidation_candidates: list[EntityEdge] = [
572 |         existing_edges[i] for i in contradicted_facts if 0 <= i < len(existing_edges)
573 |     ]
574 | 
575 |     fact_type: str = response_object.fact_type
576 |     candidate_type_names = set(edge_type_candidates or {})
577 |     custom_type_names = custom_edge_type_names or set()
578 | 
579 |     is_default_type = fact_type.upper() == 'DEFAULT'
580 |     is_custom_type = fact_type in custom_type_names
581 |     is_allowed_custom_type = fact_type in candidate_type_names
582 | 
583 |     if is_allowed_custom_type:
584 |         # The LLM selected a custom type that is allowed for the node pair.
585 |         # Adopt the custom type and, if needed, extract its structured attributes.
586 |         resolved_edge.name = fact_type
587 | 
588 |         edge_attributes_context = {
589 |             'episode_content': episode.content,
590 |             'reference_time': episode.valid_at,
591 |             'fact': resolved_edge.fact,
592 |         }
593 | 
594 |         edge_model = edge_type_candidates.get(fact_type) if edge_type_candidates else None
595 |         if edge_model is not None and len(edge_model.model_fields) != 0:
596 |             edge_attributes_response = await llm_client.generate_response(
597 |                 prompt_library.extract_edges.extract_attributes(edge_attributes_context),
598 |                 response_model=edge_model,  # type: ignore
599 |                 model_size=ModelSize.small,
600 |                 prompt_name='extract_edges.extract_attributes',
601 |             )
602 | 
603 |             resolved_edge.attributes = edge_attributes_response
604 |     elif not is_default_type and is_custom_type:
605 |         # The LLM picked a custom type that is not allowed for this signature.
606 |         # Reset to the default label and drop any structured attributes.
607 |         resolved_edge.name = DEFAULT_EDGE_NAME
608 |         resolved_edge.attributes = {}
609 |     elif not is_default_type:
610 |         # Non-custom labels are allowed to pass through so long as the LLM does
611 |         # not return the sentinel DEFAULT value.
612 |         resolved_edge.name = fact_type
613 |         resolved_edge.attributes = {}
614 | 
615 |     end = time()
616 |     logger.debug(
617 |         f'Resolved Edge: {extracted_edge.name} is {resolved_edge.name}, in {(end - start) * 1000} ms'
618 |     )
619 | 
620 |     now = utc_now()
621 | 
622 |     if resolved_edge.invalid_at and not resolved_edge.expired_at:
623 |         resolved_edge.expired_at = now
624 | 
625 |     # Determine if the new_edge needs to be expired
626 |     if resolved_edge.expired_at is None:
627 |         invalidation_candidates.sort(key=lambda c: (c.valid_at is None, ensure_utc(c.valid_at)))
628 |         for candidate in invalidation_candidates:
629 |             candidate_valid_at_utc = ensure_utc(candidate.valid_at)
630 |             resolved_edge_valid_at_utc = ensure_utc(resolved_edge.valid_at)
631 |             if (
632 |                 candidate_valid_at_utc is not None
633 |                 and resolved_edge_valid_at_utc is not None
634 |                 and candidate_valid_at_utc > resolved_edge_valid_at_utc
635 |             ):
636 |                 # Expire new edge since we have information about more recent events
637 |                 resolved_edge.invalid_at = candidate.valid_at
638 |                 resolved_edge.expired_at = now
639 |                 break
640 | 
641 |     # Determine which contradictory edges need to be expired
642 |     invalidated_edges: list[EntityEdge] = resolve_edge_contradictions(
643 |         resolved_edge, invalidation_candidates
644 |     )
645 |     duplicate_edges: list[EntityEdge] = [related_edges[idx] for idx in duplicate_fact_ids]
646 | 
647 |     return resolved_edge, invalidated_edges, duplicate_edges
648 | 
649 | 
650 | async def filter_existing_duplicate_of_edges(
651 |     driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
652 | ) -> list[tuple[EntityNode, EntityNode]]:
653 |     if not duplicates_node_tuples:
654 |         return []
655 | 
656 |     duplicate_nodes_map = {
657 |         (source.uuid, target.uuid): (source, target) for source, target in duplicates_node_tuples
658 |     }
659 | 
660 |     if driver.provider == GraphProvider.NEPTUNE:
661 |         query: LiteralString = """
662 |             UNWIND $duplicate_node_uuids AS duplicate_tuple
663 |             MATCH (n:Entity {uuid: duplicate_tuple.source})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple.target})
664 |             RETURN DISTINCT
665 |                 n.uuid AS source_uuid,
666 |                 m.uuid AS target_uuid
667 |         """
668 | 
669 |         duplicate_nodes = [
670 |             {'source': source.uuid, 'target': target.uuid}
671 |             for source, target in duplicates_node_tuples
672 |         ]
673 | 
674 |         records, _, _ = await driver.execute_query(
675 |             query,
676 |             duplicate_node_uuids=duplicate_nodes,
677 |             routing_='r',
678 |         )
679 |     else:
680 |         if driver.provider == GraphProvider.KUZU:
681 |             query = """
682 |                 UNWIND $duplicate_node_uuids AS duplicate
683 |                 MATCH (n:Entity {uuid: duplicate.src})-[:RELATES_TO]->(e:RelatesToNode_ {name: 'IS_DUPLICATE_OF'})-[:RELATES_TO]->(m:Entity {uuid: duplicate.dst})
684 |                 RETURN DISTINCT
685 |                     n.uuid AS source_uuid,
686 |                     m.uuid AS target_uuid
687 |             """
688 |             duplicate_node_uuids = [{'src': src, 'dst': dst} for src, dst in duplicate_nodes_map]
689 |         else:
690 |             query: LiteralString = """
691 |                 UNWIND $duplicate_node_uuids AS duplicate_tuple
692 |                 MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
693 |                 RETURN DISTINCT
694 |                     n.uuid AS source_uuid,
695 |                     m.uuid AS target_uuid
696 |             """
697 |             duplicate_node_uuids = list(duplicate_nodes_map.keys())
698 | 
699 |         records, _, _ = await driver.execute_query(
700 |             query,
701 |             duplicate_node_uuids=duplicate_node_uuids,
702 |             routing_='r',
703 |         )
704 | 
705 |     # Remove duplicates that already have the IS_DUPLICATE_OF edge
706 |     for record in records:
707 |         duplicate_tuple = (record.get('source_uuid'), record.get('target_uuid'))
708 |         if duplicate_nodes_map.get(duplicate_tuple):
709 |             duplicate_nodes_map.pop(duplicate_tuple)
710 | 
711 |     return list(duplicate_nodes_map.values())
712 | 
```

--------------------------------------------------------------------------------
/graphiti_core/nodes.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import json
 18 | import logging
 19 | from abc import ABC, abstractmethod
 20 | from datetime import datetime
 21 | from enum import Enum
 22 | from time import time
 23 | from typing import Any
 24 | from uuid import uuid4
 25 | 
 26 | from pydantic import BaseModel, Field
 27 | from typing_extensions import LiteralString
 28 | 
 29 | from graphiti_core.driver.driver import (
 30 |     GraphDriver,
 31 |     GraphProvider,
 32 | )
 33 | from graphiti_core.embedder import EmbedderClient
 34 | from graphiti_core.errors import NodeNotFoundError
 35 | from graphiti_core.helpers import parse_db_date
 36 | from graphiti_core.models.nodes.node_db_queries import (
 37 |     COMMUNITY_NODE_RETURN,
 38 |     COMMUNITY_NODE_RETURN_NEPTUNE,
 39 |     EPISODIC_NODE_RETURN,
 40 |     EPISODIC_NODE_RETURN_NEPTUNE,
 41 |     get_community_node_save_query,
 42 |     get_entity_node_return_query,
 43 |     get_entity_node_save_query,
 44 |     get_episode_node_save_query,
 45 | )
 46 | from graphiti_core.utils.datetime_utils import utc_now
 47 | 
 48 | logger = logging.getLogger(__name__)
 49 | 
 50 | 
 51 | class EpisodeType(Enum):
 52 |     """
 53 |     Enumeration of different types of episodes that can be processed.
 54 | 
 55 |     This enum defines the various sources or formats of episodes that the system
 56 |     can handle. It's used to categorize and potentially handle different types
 57 |     of input data differently.
 58 | 
 59 |     Attributes:
 60 |     -----------
 61 |     message : str
 62 |         Represents a standard message-type episode. The content for this type
 63 |         should be formatted as "actor: content". For example, "user: Hello, how are you?"
 64 |         or "assistant: I'm doing well, thank you for asking."
 65 |     json : str
 66 |         Represents an episode containing a JSON string object with structured data.
 67 |     text : str
 68 |         Represents a plain text episode.
 69 |     """
 70 | 
 71 |     message = 'message'
 72 |     json = 'json'
 73 |     text = 'text'
 74 | 
 75 |     @staticmethod
 76 |     def from_str(episode_type: str):
 77 |         if episode_type == 'message':
 78 |             return EpisodeType.message
 79 |         if episode_type == 'json':
 80 |             return EpisodeType.json
 81 |         if episode_type == 'text':
 82 |             return EpisodeType.text
 83 |         logger.error(f'Episode type: {episode_type} not implemented')
 84 |         raise NotImplementedError
 85 | 
 86 | 
 87 | class Node(BaseModel, ABC):
 88 |     uuid: str = Field(default_factory=lambda: str(uuid4()))
 89 |     name: str = Field(description='name of the node')
 90 |     group_id: str = Field(description='partition of the graph')
 91 |     labels: list[str] = Field(default_factory=list)
 92 |     created_at: datetime = Field(default_factory=lambda: utc_now())
 93 | 
 94 |     @abstractmethod
 95 |     async def save(self, driver: GraphDriver): ...
 96 | 
 97 |     async def delete(self, driver: GraphDriver):
 98 |         if driver.graph_operations_interface:
 99 |             return await driver.graph_operations_interface.node_delete(self, driver)
100 | 
101 |         match driver.provider:
102 |             case GraphProvider.NEO4J:
103 |                 records, _, _ = await driver.execute_query(
104 |                     """
105 |                     MATCH (n {uuid: $uuid})
106 |                     WHERE n:Entity OR n:Episodic OR n:Community
107 |                     OPTIONAL MATCH (n)-[r]-()
108 |                     WITH collect(r.uuid) AS edge_uuids, n
109 |                     DETACH DELETE n
110 |                     RETURN edge_uuids
111 |                     """,
112 |                     uuid=self.uuid,
113 |                 )
114 | 
115 |             case GraphProvider.KUZU:
116 |                 for label in ['Episodic', 'Community']:
117 |                     await driver.execute_query(
118 |                         f"""
119 |                         MATCH (n:{label} {{uuid: $uuid}})
120 |                         DETACH DELETE n
121 |                         """,
122 |                         uuid=self.uuid,
123 |                     )
124 |                 # Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
125 |                 # Explicitly delete the "edge" nodes first, then the entity node.
126 |                 await driver.execute_query(
127 |                     """
128 |                     MATCH (n:Entity {uuid: $uuid})-[:RELATES_TO]->(e:RelatesToNode_)
129 |                     DETACH DELETE e
130 |                     """,
131 |                     uuid=self.uuid,
132 |                 )
133 |                 await driver.execute_query(
134 |                     """
135 |                     MATCH (n:Entity {uuid: $uuid})
136 |                     DETACH DELETE n
137 |                     """,
138 |                     uuid=self.uuid,
139 |                 )
140 |             case _:  # FalkorDB, Neptune
141 |                 for label in ['Entity', 'Episodic', 'Community']:
142 |                     await driver.execute_query(
143 |                         f"""
144 |                         MATCH (n:{label} {{uuid: $uuid}})
145 |                         DETACH DELETE n
146 |                         """,
147 |                         uuid=self.uuid,
148 |                     )
149 | 
150 |         logger.debug(f'Deleted Node: {self.uuid}')
151 | 
152 |     def __hash__(self):
153 |         return hash(self.uuid)
154 | 
155 |     def __eq__(self, other):
156 |         if isinstance(other, Node):
157 |             return self.uuid == other.uuid
158 |         return False
159 | 
160 |     @classmethod
161 |     async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100):
162 |         if driver.graph_operations_interface:
163 |             return await driver.graph_operations_interface.node_delete_by_group_id(
164 |                 cls, driver, group_id, batch_size
165 |             )
166 | 
167 |         match driver.provider:
168 |             case GraphProvider.NEO4J:
169 |                 async with driver.session() as session:
170 |                     await session.run(
171 |                         """
172 |                         MATCH (n:Entity|Episodic|Community {group_id: $group_id})
173 |                         CALL (n) {
174 |                             DETACH DELETE n
175 |                         } IN TRANSACTIONS OF $batch_size ROWS
176 |                         """,
177 |                         group_id=group_id,
178 |                         batch_size=batch_size,
179 |                     )
180 | 
181 |             case GraphProvider.KUZU:
182 |                 for label in ['Episodic', 'Community']:
183 |                     await driver.execute_query(
184 |                         f"""
185 |                         MATCH (n:{label} {{group_id: $group_id}})
186 |                         DETACH DELETE n
187 |                         """,
188 |                         group_id=group_id,
189 |                     )
190 |                 # Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
191 |                 # Explicitly delete the "edge" nodes first, then the entity node.
192 |                 await driver.execute_query(
193 |                     """
194 |                     MATCH (n:Entity {group_id: $group_id})-[:RELATES_TO]->(e:RelatesToNode_)
195 |                     DETACH DELETE e
196 |                     """,
197 |                     group_id=group_id,
198 |                 )
199 |                 await driver.execute_query(
200 |                     """
201 |                     MATCH (n:Entity {group_id: $group_id})
202 |                     DETACH DELETE n
203 |                     """,
204 |                     group_id=group_id,
205 |                 )
206 |             case _:  # FalkorDB, Neptune
207 |                 for label in ['Entity', 'Episodic', 'Community']:
208 |                     await driver.execute_query(
209 |                         f"""
210 |                         MATCH (n:{label} {{group_id: $group_id}})
211 |                         DETACH DELETE n
212 |                         """,
213 |                         group_id=group_id,
214 |                     )
215 | 
216 |     @classmethod
217 |     async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
218 |         if driver.graph_operations_interface:
219 |             return await driver.graph_operations_interface.node_delete_by_uuids(
220 |                 cls, driver, uuids, group_id=None, batch_size=batch_size
221 |             )
222 | 
223 |         match driver.provider:
224 |             case GraphProvider.FALKORDB:
225 |                 for label in ['Entity', 'Episodic', 'Community']:
226 |                     await driver.execute_query(
227 |                         f"""
228 |                         MATCH (n:{label})
229 |                         WHERE n.uuid IN $uuids
230 |                         DETACH DELETE n
231 |                         """,
232 |                         uuids=uuids,
233 |                     )
234 |             case GraphProvider.KUZU:
235 |                 for label in ['Episodic', 'Community']:
236 |                     await driver.execute_query(
237 |                         f"""
238 |                         MATCH (n:{label})
239 |                         WHERE n.uuid IN $uuids
240 |                         DETACH DELETE n
241 |                         """,
242 |                         uuids=uuids,
243 |                     )
244 |                 # Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
245 |                 # Explicitly delete the "edge" nodes first, then the entity node.
246 |                 await driver.execute_query(
247 |                     """
248 |                     MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)
249 |                     WHERE n.uuid IN $uuids
250 |                     DETACH DELETE e
251 |                     """,
252 |                     uuids=uuids,
253 |                 )
254 |                 await driver.execute_query(
255 |                     """
256 |                     MATCH (n:Entity)
257 |                     WHERE n.uuid IN $uuids
258 |                     DETACH DELETE n
259 |                     """,
260 |                     uuids=uuids,
261 |                 )
262 |             case _:  # Neo4J, Neptune
263 |                 async with driver.session() as session:
264 |                     # Collect all edge UUIDs before deleting nodes
265 |                     await session.run(
266 |                         """
267 |                         MATCH (n:Entity|Episodic|Community)
268 |                         WHERE n.uuid IN $uuids
269 |                         MATCH (n)-[r]-()
270 |                         RETURN collect(r.uuid) AS edge_uuids
271 |                         """,
272 |                         uuids=uuids,
273 |                     )
274 | 
275 |                     # Now delete the nodes in batches
276 |                     await session.run(
277 |                         """
278 |                         MATCH (n:Entity|Episodic|Community)
279 |                         WHERE n.uuid IN $uuids
280 |                         CALL (n) {
281 |                             DETACH DELETE n
282 |                         } IN TRANSACTIONS OF $batch_size ROWS
283 |                         """,
284 |                         uuids=uuids,
285 |                         batch_size=batch_size,
286 |                     )
287 | 
288 |     @classmethod
289 |     async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
290 | 
291 |     @classmethod
292 |     async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): ...
293 | 
294 | 
295 | class EpisodicNode(Node):
296 |     source: EpisodeType = Field(description='source type')
297 |     source_description: str = Field(description='description of the data source')
298 |     content: str = Field(description='raw episode data')
299 |     valid_at: datetime = Field(
300 |         description='datetime of when the original document was created',
301 |     )
302 |     entity_edges: list[str] = Field(
303 |         description='list of entity edges referenced in this episode',
304 |         default_factory=list,
305 |     )
306 | 
307 |     async def save(self, driver: GraphDriver):
308 |         if driver.graph_operations_interface:
309 |             return await driver.graph_operations_interface.episodic_node_save(self, driver)
310 | 
311 |         episode_args = {
312 |             'uuid': self.uuid,
313 |             'name': self.name,
314 |             'group_id': self.group_id,
315 |             'source_description': self.source_description,
316 |             'content': self.content,
317 |             'entity_edges': self.entity_edges,
318 |             'created_at': self.created_at,
319 |             'valid_at': self.valid_at,
320 |             'source': self.source.value,
321 |         }
322 | 
323 |         result = await driver.execute_query(
324 |             get_episode_node_save_query(driver.provider), **episode_args
325 |         )
326 | 
327 |         logger.debug(f'Saved Node to Graph: {self.uuid}')
328 | 
329 |         return result
330 | 
331 |     @classmethod
332 |     async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
333 |         records, _, _ = await driver.execute_query(
334 |             """
335 |             MATCH (e:Episodic {uuid: $uuid})
336 |             RETURN
337 |             """
338 |             + (
339 |                 EPISODIC_NODE_RETURN_NEPTUNE
340 |                 if driver.provider == GraphProvider.NEPTUNE
341 |                 else EPISODIC_NODE_RETURN
342 |             ),
343 |             uuid=uuid,
344 |             routing_='r',
345 |         )
346 | 
347 |         episodes = [get_episodic_node_from_record(record) for record in records]
348 | 
349 |         if len(episodes) == 0:
350 |             raise NodeNotFoundError(uuid)
351 | 
352 |         return episodes[0]
353 | 
354 |     @classmethod
355 |     async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
356 |         records, _, _ = await driver.execute_query(
357 |             """
358 |             MATCH (e:Episodic)
359 |             WHERE e.uuid IN $uuids
360 |             RETURN DISTINCT
361 |             """
362 |             + (
363 |                 EPISODIC_NODE_RETURN_NEPTUNE
364 |                 if driver.provider == GraphProvider.NEPTUNE
365 |                 else EPISODIC_NODE_RETURN
366 |             ),
367 |             uuids=uuids,
368 |             routing_='r',
369 |         )
370 | 
371 |         episodes = [get_episodic_node_from_record(record) for record in records]
372 | 
373 |         return episodes
374 | 
375 |     @classmethod
376 |     async def get_by_group_ids(
377 |         cls,
378 |         driver: GraphDriver,
379 |         group_ids: list[str],
380 |         limit: int | None = None,
381 |         uuid_cursor: str | None = None,
382 |     ):
383 |         cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
384 |         limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
385 | 
386 |         records, _, _ = await driver.execute_query(
387 |             """
388 |             MATCH (e:Episodic)
389 |             WHERE e.group_id IN $group_ids
390 |             """
391 |             + cursor_query
392 |             + """
393 |             RETURN DISTINCT
394 |             """
395 |             + (
396 |                 EPISODIC_NODE_RETURN_NEPTUNE
397 |                 if driver.provider == GraphProvider.NEPTUNE
398 |                 else EPISODIC_NODE_RETURN
399 |             )
400 |             + """
401 |             ORDER BY uuid DESC
402 |             """
403 |             + limit_query,
404 |             group_ids=group_ids,
405 |             uuid=uuid_cursor,
406 |             limit=limit,
407 |             routing_='r',
408 |         )
409 | 
410 |         episodes = [get_episodic_node_from_record(record) for record in records]
411 | 
412 |         return episodes
413 | 
414 |     @classmethod
415 |     async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str):
416 |         records, _, _ = await driver.execute_query(
417 |             """
418 |             MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
419 |             RETURN DISTINCT
420 |             """
421 |             + (
422 |                 EPISODIC_NODE_RETURN_NEPTUNE
423 |                 if driver.provider == GraphProvider.NEPTUNE
424 |                 else EPISODIC_NODE_RETURN
425 |             ),
426 |             entity_node_uuid=entity_node_uuid,
427 |             routing_='r',
428 |         )
429 | 
430 |         episodes = [get_episodic_node_from_record(record) for record in records]
431 | 
432 |         return episodes
433 | 
434 | 
class EntityNode(Node):
    """A named entity node with an optional name embedding, a summary, and free-form attributes.

    ``attributes`` holds extra, label-dependent properties. ``save`` stores them as a
    single JSON ``attributes`` property on Kuzu. On other providers it merges them into
    the property map passed to the save query.
    """

    name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
    summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
    attributes: dict[str, Any] = Field(
        default_factory=dict,
        description='Additional attributes of the node. Dependent on node labels',
    )

    async def generate_name_embedding(self, embedder: EmbedderClient):
        """Embed this node's name (newlines replaced by spaces) and store it on the node.

        Args:
            embedder: Client whose ``create`` produces the embedding.

        Returns:
            The embedding that was assigned to ``self.name_embedding``.
        """
        start = time()
        text = self.name.replace('\n', ' ')
        self.name_embedding = await embedder.create(input_data=[text])
        end = time()
        # time() measures seconds; convert so the logged unit really is milliseconds.
        logger.debug(f'embedded {text} in {(end - start) * 1000} ms')

        return self.name_embedding

    async def load_name_embedding(self, driver: GraphDriver):
        """Load this node's stored name embedding from the graph into ``self.name_embedding``.

        Delegates to ``driver.graph_operations_interface`` when one is configured.

        Raises:
            NodeNotFoundError: if no ``Entity`` node with this uuid exists.
        """
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.node_load_embeddings(self, driver)

        query: LiteralString
        if driver.provider == GraphProvider.NEPTUNE:
            # The Neptune query splits a comma-separated string property back into floats.
            query = """
                MATCH (n:Entity {uuid: $uuid})
                RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
            """
        else:
            query = """
                MATCH (n:Entity {uuid: $uuid})
                RETURN n.name_embedding AS name_embedding
            """

        records, _, _ = await driver.execute_query(
            query,
            uuid=self.uuid,
            routing_='r',
        )

        if not records:
            raise NodeNotFoundError(self.uuid)

        self.name_embedding = records[0]['name_embedding']

    async def save(self, driver: GraphDriver):
        """Write this node to the graph.

        Delegates to ``driver.graph_operations_interface`` when one is configured.
        Otherwise it runs the provider-specific query from ``get_entity_node_save_query``.

        Returns:
            Whatever ``driver.execute_query`` (or the interface) returns.
        """
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.node_save(self, driver)

        entity_data: dict[str, Any] = {
            'uuid': self.uuid,
            'name': self.name,
            'name_embedding': self.name_embedding,
            'group_id': self.group_id,
            'summary': self.summary,
            'created_at': self.created_at,
        }

        if driver.provider == GraphProvider.KUZU:
            # Kuzu stores attributes as one JSON property and labels as a list property.
            entity_data['attributes'] = json.dumps(self.attributes)
            entity_data['labels'] = list(set(self.labels + ['Entity']))
            result = await driver.execute_query(
                get_entity_node_save_query(driver.provider, labels=''),
                **entity_data,
            )
        else:
            # Attributes are flattened into the property map. The core fields are applied
            # last so an attribute named e.g. 'uuid' or 'name' cannot overwrite them.
            entity_data = {**(self.attributes or {}), **entity_data}
            # De-duplicate (order-preserving) so 'Entity' is not emitted twice when
            # self.labels already contains it.
            labels = ':'.join(dict.fromkeys([*self.labels, 'Entity']))

            result = await driver.execute_query(
                get_entity_node_save_query(driver.provider, labels),
                entity_data=entity_data,
            )

        logger.debug(f'Saved Node to Graph: {self.uuid}')

        return result

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        """Fetch a single ``EntityNode`` by uuid.

        Raises:
            NodeNotFoundError: if no ``Entity`` node with that uuid exists.
        """
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Entity {uuid: $uuid})
            RETURN
            """
            + get_entity_node_return_query(driver.provider),
            uuid=uuid,
            routing_='r',
        )

        nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

        if not nodes:
            raise NodeNotFoundError(uuid)

        return nodes[0]

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        """Fetch every ``EntityNode`` whose uuid is in ``uuids``; missing uuids are skipped."""
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Entity)
            WHERE n.uuid IN $uuids
            RETURN
            """
            + get_entity_node_return_query(driver.provider),
            uuids=uuids,
            routing_='r',
        )

        nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

        return nodes

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
        with_embeddings: bool = False,
    ):
        """Page through ``EntityNode`` objects in ``group_ids``, ordered by uuid descending.

        Args:
            driver: Graph driver used to run the read-routed query.
            group_ids: Group ids to match.
            limit: Maximum number of nodes to return; no limit when ``None``.
            uuid_cursor: When set, only nodes with ``uuid`` strictly less than this are returned.
            with_embeddings: When ``True``, ``n.name_embedding`` is also projected.
        """
        cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else ''
        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
        # NOTE(review): this projects the raw property. On Neptune the embedding appears to
        # be stored as a comma-separated string (see load_name_embedding) — confirm whether
        # it needs the same split/toFloat conversion here.
        with_embeddings_query: LiteralString = (
            """,
            n.name_embedding AS name_embedding
            """
            if with_embeddings
            else ''
        )

        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Entity)
            WHERE n.group_id IN $group_ids
            """
            + cursor_query
            + """
            RETURN
            """
            + get_entity_node_return_query(driver.provider)
            + with_embeddings_query
            + """
            ORDER BY n.uuid DESC
            """
            + limit_query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )

        nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

        return nodes
589 | 
590 | 
class CommunityNode(Node):
    """A community node: a named grouping of member nodes with a summary and name embedding."""

    name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
    summary: str = Field(description='region summary of member nodes', default_factory=str)

    async def save(self, driver: GraphDriver):
        """Write this community to the graph.

        On Neptune it first calls ``driver.save_to_aoss('communities', ...)`` with the
        name, uuid and group_id. It then runs the provider-specific save query from
        ``get_community_node_save_query``.

        Returns:
            Whatever ``driver.execute_query`` returns.
        """
        if driver.provider == GraphProvider.NEPTUNE:
            await driver.save_to_aoss(  # pyright: ignore reportAttributeAccessIssue
                'communities',
                [{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
            )
        result = await driver.execute_query(
            get_community_node_save_query(driver.provider),  # type: ignore
            uuid=self.uuid,
            name=self.name,
            group_id=self.group_id,
            summary=self.summary,
            name_embedding=self.name_embedding,
            created_at=self.created_at,
        )

        logger.debug(f'Saved Node to Graph: {self.uuid}')

        return result

    async def generate_name_embedding(self, embedder: EmbedderClient):
        """Embed this community's name (newlines replaced by spaces) and store it on the node.

        Returns:
            The embedding that was assigned to ``self.name_embedding``.
        """
        start = time()
        text = self.name.replace('\n', ' ')
        self.name_embedding = await embedder.create(input_data=[text])
        end = time()
        # time() measures seconds; convert so the logged unit really is milliseconds.
        logger.debug(f'embedded {text} in {(end - start) * 1000} ms')

        return self.name_embedding

    async def load_name_embedding(self, driver: GraphDriver):
        """Load this community's stored name embedding into ``self.name_embedding``.

        Raises:
            NodeNotFoundError: if no ``Community`` node with this uuid exists.
        """
        query: LiteralString
        if driver.provider == GraphProvider.NEPTUNE:
            # The Neptune query splits a comma-separated string property back into floats.
            query = """
                MATCH (c:Community {uuid: $uuid})
                RETURN [x IN split(c.name_embedding, ",") | toFloat(x)] as name_embedding
            """
        else:
            query = """
            MATCH (c:Community {uuid: $uuid})
            RETURN c.name_embedding AS name_embedding
            """

        records, _, _ = await driver.execute_query(
            query,
            uuid=self.uuid,
            routing_='r',
        )

        if not records:
            raise NodeNotFoundError(self.uuid)

        self.name_embedding = records[0]['name_embedding']

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        """Fetch a single ``CommunityNode`` by uuid.

        Raises:
            NodeNotFoundError: if no ``Community`` node with that uuid exists.
        """
        records, _, _ = await driver.execute_query(
            """
            MATCH (c:Community {uuid: $uuid})
            RETURN
            """
            + (
                COMMUNITY_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else COMMUNITY_NODE_RETURN
            ),
            uuid=uuid,
            routing_='r',
        )

        nodes = [get_community_node_from_record(record) for record in records]

        if not nodes:
            raise NodeNotFoundError(uuid)

        return nodes[0]

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        """Fetch every ``CommunityNode`` whose uuid is in ``uuids``; missing uuids are skipped."""
        records, _, _ = await driver.execute_query(
            """
            MATCH (c:Community)
            WHERE c.uuid IN $uuids
            RETURN
            """
            + (
                COMMUNITY_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else COMMUNITY_NODE_RETURN
            ),
            uuids=uuids,
            routing_='r',
        )

        communities = [get_community_node_from_record(record) for record in records]

        return communities

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
    ):
        """Page through ``CommunityNode`` objects in ``group_ids``, ordered by uuid descending.

        Args:
            driver: Graph driver used to run the read-routed query.
            group_ids: Group ids to match.
            limit: Maximum number of communities to return; no limit when ``None``.
            uuid_cursor: When set, only communities with ``uuid`` strictly less than this
                are returned.
        """
        cursor_query: LiteralString = 'AND c.uuid < $uuid' if uuid_cursor else ''
        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''

        records, _, _ = await driver.execute_query(
            """
            MATCH (c:Community)
            WHERE c.group_id IN $group_ids
            """
            + cursor_query
            + """
            RETURN
            """
            + (
                COMMUNITY_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else COMMUNITY_NODE_RETURN
            )
            + """
            ORDER BY c.uuid DESC
            """
            + limit_query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )

        communities = [get_community_node_from_record(record) for record in records]

        return communities
729 | 
730 | 
731 | # Node helpers
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
    """Build an ``EpisodicNode`` from a query record.

    Both timestamps are parsed with ``parse_db_date``.

    Raises:
        ValueError: if the parsed ``created_at`` is None (checked first) or the parsed
            ``valid_at`` is None.
    """
    created = parse_db_date(record['created_at'])
    valid = parse_db_date(record['valid_at'])

    if created is None or valid is None:
        missing_field = 'created_at' if created is None else 'valid_at'
        raise ValueError(
            f'{missing_field} cannot be None for episode {record.get("uuid", "unknown")}'
        )

    return EpisodicNode(
        content=record['content'],
        created_at=created,
        valid_at=valid,
        uuid=record['uuid'],
        group_id=record['group_id'],
        source=EpisodeType.from_str(record['source']),
        name=record['name'],
        source_description=record['source_description'],
        entity_edges=record['entity_edges'],
    )
752 | 
753 | 
def get_entity_node_from_record(record: Any, provider: GraphProvider) -> EntityNode:
    """Build an ``EntityNode`` from a query record without mutating the record.

    Args:
        record: Mapping-like query record read via ``record[...]`` for ``uuid``, ``name``,
            ``created_at``, ``summary`` and ``attributes``, and via ``record.get`` for
            ``labels``, ``group_id`` and ``name_embedding``.
        provider: Graph provider. On Kuzu, ``attributes`` is a JSON string. Otherwise it is
            a property mapping that may also contain the node's core fields, which are
            stripped out.

    Returns:
        The reconstructed ``EntityNode``. The per-group label ``Entity_<group_id without
        dashes>`` is removed from its labels if present.
    """
    if provider == GraphProvider.KUZU:
        raw_attributes = record['attributes']
        attributes = json.loads(raw_attributes) if raw_attributes else {}
    else:
        # Copy before stripping the core fields so the caller's record is left untouched.
        attributes = dict(record['attributes'] or {})
        for core_field in (
            'uuid',
            'name',
            'group_id',
            'name_embedding',
            'summary',
            'created_at',
            'labels',
        ):
            attributes.pop(core_field, None)

    group_id = record.get('group_id')
    # Copy so removing the per-group label does not mutate the record's list.
    labels = list(record.get('labels') or [])
    if group_id:
        group_label = 'Entity_' + group_id.replace('-', '')
        if group_label in labels:
            labels.remove(group_label)

    entity_node = EntityNode(
        uuid=record['uuid'],
        name=record['name'],
        name_embedding=record.get('name_embedding'),
        group_id=group_id,
        labels=labels,
        created_at=parse_db_date(record['created_at']),  # type: ignore
        summary=record['summary'],
        attributes=attributes,
    )

    return entity_node
784 | 
785 | 
def get_community_node_from_record(record: Any) -> CommunityNode:
    """Build a ``CommunityNode`` from a query record.

    The record must expose ``uuid``, ``name``, ``group_id``, ``name_embedding``,
    ``created_at`` and ``summary``. ``created_at`` is parsed with ``parse_db_date``.
    """
    # Fields copied verbatim from the record, read in the same order as before.
    passthrough = {key: record[key] for key in ('uuid', 'name', 'group_id', 'name_embedding')}
    return CommunityNode(
        **passthrough,
        created_at=parse_db_date(record['created_at']),  # type: ignore
        summary=record['summary'],
    )
795 | 
796 | 
async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
    """Embed the names of ``nodes`` in a single batch and assign them in place.

    Nodes with an empty (falsy) ``name`` are skipped and keep their current embedding.
    If no node has a name, the embedder is not called.

    Raises:
        ValueError: from ``zip(..., strict=True)`` if the embedder returns a different
            number of embeddings than names it was given.
    """
    nodes_to_embed = [node for node in nodes if node.name]

    if nodes_to_embed:
        texts = [node.name for node in nodes_to_embed]
        vectors = await embedder.create_batch(texts)
        for node, vector in zip(nodes_to_embed, vectors, strict=True):
            node.name_embedding = vector
807 | 
```
Page 8/12 · First · Prev · Next · Last