This is page 8 of 12. Use http://codebase.md/getzep/graphiti?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── ISSUE_TEMPLATE
│ │ └── bug_report.md
│ ├── pull_request_template.md
│ ├── secret_scanning.yml
│ └── workflows
│ ├── ai-moderator.yml
│ ├── cla.yml
│ ├── claude-code-review-manual.yml
│ ├── claude-code-review.yml
│ ├── claude.yml
│ ├── codeql.yml
│ ├── daily_issue_maintenance.yml
│ ├── issue-triage.yml
│ ├── lint.yml
│ ├── release-graphiti-core.yml
│ ├── release-mcp-server.yml
│ ├── release-server-container.yml
│ ├── typecheck.yml
│ └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│ ├── azure-openai
│ │ ├── .env.example
│ │ ├── azure_openai_neo4j.py
│ │ └── README.md
│ ├── data
│ │ └── manybirds_products.json
│ ├── ecommerce
│ │ ├── runner.ipynb
│ │ └── runner.py
│ ├── langgraph-agent
│ │ ├── agent.ipynb
│ │ └── tinybirds-jess.png
│ ├── opentelemetry
│ │ ├── .env.example
│ │ ├── otel_stdout_example.py
│ │ ├── pyproject.toml
│ │ ├── README.md
│ │ └── uv.lock
│ ├── podcast
│ │ ├── podcast_runner.py
│ │ ├── podcast_transcript.txt
│ │ └── transcript_parser.py
│ ├── quickstart
│ │ ├── quickstart_falkordb.py
│ │ ├── quickstart_neo4j.py
│ │ ├── quickstart_neptune.py
│ │ ├── README.md
│ │ └── requirements.txt
│ └── wizard_of_oz
│ ├── parser.py
│ ├── runner.py
│ └── woo.txt
├── graphiti_core
│ ├── __init__.py
│ ├── cross_encoder
│ │ ├── __init__.py
│ │ ├── bge_reranker_client.py
│ │ ├── client.py
│ │ ├── gemini_reranker_client.py
│ │ └── openai_reranker_client.py
│ ├── decorators.py
│ ├── driver
│ │ ├── __init__.py
│ │ ├── driver.py
│ │ ├── falkordb_driver.py
│ │ ├── graph_operations
│ │ │ └── graph_operations.py
│ │ ├── kuzu_driver.py
│ │ ├── neo4j_driver.py
│ │ ├── neptune_driver.py
│ │ └── search_interface
│ │ └── search_interface.py
│ ├── edges.py
│ ├── embedder
│ │ ├── __init__.py
│ │ ├── azure_openai.py
│ │ ├── client.py
│ │ ├── gemini.py
│ │ ├── openai.py
│ │ └── voyage.py
│ ├── errors.py
│ ├── graph_queries.py
│ ├── graphiti_types.py
│ ├── graphiti.py
│ ├── helpers.py
│ ├── llm_client
│ │ ├── __init__.py
│ │ ├── anthropic_client.py
│ │ ├── azure_openai_client.py
│ │ ├── client.py
│ │ ├── config.py
│ │ ├── errors.py
│ │ ├── gemini_client.py
│ │ ├── groq_client.py
│ │ ├── openai_base_client.py
│ │ ├── openai_client.py
│ │ ├── openai_generic_client.py
│ │ └── utils.py
│ ├── migrations
│ │ └── __init__.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── edges
│ │ │ ├── __init__.py
│ │ │ └── edge_db_queries.py
│ │ └── nodes
│ │ ├── __init__.py
│ │ └── node_db_queries.py
│ ├── nodes.py
│ ├── prompts
│ │ ├── __init__.py
│ │ ├── dedupe_edges.py
│ │ ├── dedupe_nodes.py
│ │ ├── eval.py
│ │ ├── extract_edge_dates.py
│ │ ├── extract_edges.py
│ │ ├── extract_nodes.py
│ │ ├── invalidate_edges.py
│ │ ├── lib.py
│ │ ├── models.py
│ │ ├── prompt_helpers.py
│ │ ├── snippets.py
│ │ └── summarize_nodes.py
│ ├── py.typed
│ ├── search
│ │ ├── __init__.py
│ │ ├── search_config_recipes.py
│ │ ├── search_config.py
│ │ ├── search_filters.py
│ │ ├── search_helpers.py
│ │ ├── search_utils.py
│ │ └── search.py
│ ├── telemetry
│ │ ├── __init__.py
│ │ └── telemetry.py
│ ├── tracer.py
│ └── utils
│ ├── __init__.py
│ ├── bulk_utils.py
│ ├── datetime_utils.py
│ ├── maintenance
│ │ ├── __init__.py
│ │ ├── community_operations.py
│ │ ├── dedup_helpers.py
│ │ ├── edge_operations.py
│ │ ├── graph_data_operations.py
│ │ ├── node_operations.py
│ │ └── temporal_operations.py
│ ├── ontology_utils
│ │ └── entity_types_utils.py
│ └── text_utils.py
├── images
│ ├── arxiv-screenshot.png
│ ├── graphiti-graph-intro.gif
│ ├── graphiti-intro-slides-stock-2.gif
│ └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│ ├── .env.example
│ ├── .python-version
│ ├── config
│ │ ├── config-docker-falkordb-combined.yaml
│ │ ├── config-docker-falkordb.yaml
│ │ ├── config-docker-neo4j.yaml
│ │ ├── config.yaml
│ │ └── mcp_config_stdio_example.json
│ ├── docker
│ │ ├── build-standalone.sh
│ │ ├── build-with-version.sh
│ │ ├── docker-compose-falkordb.yml
│ │ ├── docker-compose-neo4j.yml
│ │ ├── docker-compose.yml
│ │ ├── Dockerfile
│ │ ├── Dockerfile.standalone
│ │ ├── github-actions-example.yml
│ │ ├── README-falkordb-combined.md
│ │ └── README.md
│ ├── docs
│ │ └── cursor_rules.md
│ ├── main.py
│ ├── pyproject.toml
│ ├── pytest.ini
│ ├── README.md
│ ├── src
│ │ ├── __init__.py
│ │ ├── config
│ │ │ ├── __init__.py
│ │ │ └── schema.py
│ │ ├── graphiti_mcp_server.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ ├── entity_types.py
│ │ │ └── response_types.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── factories.py
│ │ │ └── queue_service.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── formatting.py
│ │ └── utils.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── pytest.ini
│ │ ├── README.md
│ │ ├── run_tests.py
│ │ ├── test_async_operations.py
│ │ ├── test_comprehensive_integration.py
│ │ ├── test_configuration.py
│ │ ├── test_falkordb_integration.py
│ │ ├── test_fixtures.py
│ │ ├── test_http_integration.py
│ │ ├── test_integration.py
│ │ ├── test_mcp_integration.py
│ │ ├── test_mcp_transports.py
│ │ ├── test_stdio_simple.py
│ │ └── test_stress_load.py
│ └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│ ├── .env.example
│ ├── graph_service
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ ├── main.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ └── zep_graphiti.py
│ ├── Makefile
│ ├── pyproject.toml
│ ├── README.md
│ └── uv.lock
├── signatures
│ └── version1
│ └── cla.json
├── tests
│ ├── cross_encoder
│ │ ├── test_bge_reranker_client_int.py
│ │ └── test_gemini_reranker_client.py
│ ├── driver
│ │ ├── __init__.py
│ │ └── test_falkordb_driver.py
│ ├── embedder
│ │ ├── embedder_fixtures.py
│ │ ├── test_gemini.py
│ │ ├── test_openai.py
│ │ └── test_voyage.py
│ ├── evals
│ │ ├── data
│ │ │ └── longmemeval_data
│ │ │ ├── longmemeval_oracle.json
│ │ │ └── README.md
│ │ ├── eval_cli.py
│ │ ├── eval_e2e_graph_building.py
│ │ ├── pytest.ini
│ │ └── utils.py
│ ├── helpers_test.py
│ ├── llm_client
│ │ ├── test_anthropic_client_int.py
│ │ ├── test_anthropic_client.py
│ │ ├── test_azure_openai_client.py
│ │ ├── test_client.py
│ │ ├── test_errors.py
│ │ └── test_gemini_client.py
│ ├── test_edge_int.py
│ ├── test_entity_exclusion_int.py
│ ├── test_graphiti_int.py
│ ├── test_graphiti_mock.py
│ ├── test_node_int.py
│ ├── test_text_utils.py
│ └── utils
│ ├── maintenance
│ │ ├── test_bulk_utils.py
│ │ ├── test_edge_operations.py
│ │ ├── test_node_operations.py
│ │ └── test_temporal_operations_int.py
│ └── search
│ └── search_utils_test.py
├── uv.lock
└── Zep-CLA.md
```
# Files
--------------------------------------------------------------------------------
/graphiti_core/edges.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import json
18 | import logging
19 | from abc import ABC, abstractmethod
20 | from datetime import datetime
21 | from time import time
22 | from typing import Any
23 | from uuid import uuid4
24 |
25 | from pydantic import BaseModel, Field
26 | from typing_extensions import LiteralString
27 |
28 | from graphiti_core.driver.driver import GraphDriver, GraphProvider
29 | from graphiti_core.embedder import EmbedderClient
30 | from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
31 | from graphiti_core.helpers import parse_db_date
32 | from graphiti_core.models.edges.edge_db_queries import (
33 | COMMUNITY_EDGE_RETURN,
34 | EPISODIC_EDGE_RETURN,
35 | EPISODIC_EDGE_SAVE,
36 | get_community_edge_save_query,
37 | get_entity_edge_return_query,
38 | get_entity_edge_save_query,
39 | )
40 | from graphiti_core.nodes import Node
41 |
42 | logger = logging.getLogger(__name__)
43 |
44 |
class Edge(BaseModel, ABC):
    """Abstract base class for all graph edges (episodic, entity, community).

    Shared identity fields and deletion logic live here; concrete subclasses
    implement `save` and override `get_by_uuid`.
    """

    uuid: str = Field(default_factory=lambda: str(uuid4()))
    group_id: str = Field(description='partition of the graph')
    source_node_uuid: str
    target_node_uuid: str
    created_at: datetime

    @abstractmethod
    async def save(self, driver: GraphDriver): ...

    async def delete(self, driver: GraphDriver):
        """Delete this edge from the graph, whatever its concrete type.

        Kuzu models RELATES_TO edges as intermediate `RelatesToNode_` nodes,
        so it needs a second DETACH DELETE pass for those.
        """
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.edge_delete(self, driver)

        if driver.provider == GraphProvider.KUZU:
            await driver.execute_query(
                """
                MATCH (n)-[e:MENTIONS|HAS_MEMBER {uuid: $uuid}]->(m)
                DELETE e
                """,
                uuid=self.uuid,
            )
            await driver.execute_query(
                """
                MATCH (e:RelatesToNode_ {uuid: $uuid})
                DETACH DELETE e
                """,
                uuid=self.uuid,
            )
        else:
            await driver.execute_query(
                """
                MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
                DELETE e
                """,
                uuid=self.uuid,
            )

        logger.debug(f'Deleted Edge: {self.uuid}')

    @classmethod
    async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        """Delete every edge whose uuid appears in `uuids`."""
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.edge_delete_by_uuids(cls, driver, uuids)

        if driver.provider == GraphProvider.KUZU:
            await driver.execute_query(
                """
                MATCH (n)-[e:MENTIONS|HAS_MEMBER]->(m)
                WHERE e.uuid IN $uuids
                DELETE e
                """,
                uuids=uuids,
            )
            await driver.execute_query(
                """
                MATCH (e:RelatesToNode_)
                WHERE e.uuid IN $uuids
                DETACH DELETE e
                """,
                uuids=uuids,
            )
        else:
            await driver.execute_query(
                """
                MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
                WHERE e.uuid IN $uuids
                DELETE e
                """,
                uuids=uuids,
            )

        logger.debug(f'Deleted Edges: {uuids}')

    def __hash__(self):
        return hash(self.uuid)

    def __eq__(self, other):
        # Bug fix: this previously tested `isinstance(other, Node)`, so two
        # Edge instances never compared equal even though __hash__ hashes the
        # uuid. Edges are identified by uuid, mirroring Node equality.
        if isinstance(other, Edge):
            return self.uuid == other.uuid
        return False

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
129 |
130 |
class EpisodicEdge(Edge):
    """MENTIONS edge connecting an Episodic node to an Entity node."""

    async def save(self, driver: GraphDriver):
        """Write this MENTIONS edge to the graph and return the driver result."""
        saved = await driver.execute_query(
            EPISODIC_EDGE_SAVE,
            episode_uuid=self.source_node_uuid,
            entity_uuid=self.target_node_uuid,
            uuid=self.uuid,
            group_id=self.group_id,
            created_at=self.created_at,
        )
        logger.debug(f'Saved edge to Graph: {self.uuid}')
        return saved

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        """Fetch one MENTIONS edge by uuid; raise EdgeNotFoundError when absent."""
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
            RETURN
            """
            + EPISODIC_EDGE_RETURN,
            uuid=uuid,
            routing_='r',
        )

        matches = [get_episodic_edge_from_record(record) for record in records]
        if not matches:
            raise EdgeNotFoundError(uuid)
        return matches[0]

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        """Fetch all MENTIONS edges whose uuid is in `uuids`.

        Raises EdgeNotFoundError (reporting the first requested uuid) when the
        query returns nothing.
        """
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
            WHERE e.uuid IN $uuids
            RETURN
            """
            + EPISODIC_EDGE_RETURN,
            uuids=uuids,
            routing_='r',
        )

        matches = [get_episodic_edge_from_record(record) for record in records]
        if not matches:
            raise EdgeNotFoundError(uuids[0])
        return matches

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
    ):
        """Page MENTIONS edges for the given groups, ordered by uuid descending.

        `uuid_cursor` restricts results to uuids below the cursor and `limit`
        caps the page size. Raises GroupsEdgesNotFoundError on an empty result.
        """
        cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''

        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
            WHERE e.group_id IN $group_ids
            """
            + cursor_query
            + """
            RETURN
            """
            + EPISODIC_EDGE_RETURN
            + """
            ORDER BY e.uuid DESC
            """
            + limit_query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )

        matches = [get_episodic_edge_from_record(record) for record in records]
        if not matches:
            raise GroupsEdgesNotFoundError(group_ids)
        return matches
219 |
220 |
class EntityEdge(Edge):
    """RELATES_TO edge between two Entity nodes, carrying a fact and its embedding."""

    name: str = Field(description='name of the edge, relation name')
    fact: str = Field(description='fact representing the edge and nodes that it connects')
    fact_embedding: list[float] | None = Field(default=None, description='embedding of the fact')
    # default_factory avoids sharing one mutable instance across models
    episodes: list[str] = Field(
        default_factory=list,
        description='list of episode ids that reference these entity edges',
    )
    expired_at: datetime | None = Field(
        default=None, description='datetime of when the node was invalidated'
    )
    valid_at: datetime | None = Field(
        default=None, description='datetime of when the fact became true'
    )
    invalid_at: datetime | None = Field(
        default=None, description='datetime of when the fact stopped being true'
    )
    attributes: dict[str, Any] = Field(
        default_factory=dict,
        description='Additional attributes of the edge. Dependent on edge name',
    )

    async def generate_embedding(self, embedder: EmbedderClient):
        """Embed the fact text (newlines flattened to spaces) and store it on the edge."""
        start = time()

        text = self.fact.replace('\n', ' ')
        self.fact_embedding = await embedder.create(input_data=[text])

        end = time()
        # time() measures seconds; convert so the logged unit really is ms
        # (previously the raw seconds value was labeled "ms")
        logger.debug(f'embedded {text} in {(end - start) * 1000} ms')

        return self.fact_embedding

    async def load_fact_embedding(self, driver: GraphDriver):
        """Load this edge's stored fact embedding from the graph into memory.

        Raises EdgeNotFoundError if the edge does not exist in the graph.
        """
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.edge_load_embeddings(self, driver)

        query = """
            MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
            RETURN e.fact_embedding AS fact_embedding
        """

        # Neptune stores the embedding as a comma-separated string
        if driver.provider == GraphProvider.NEPTUNE:
            query = """
                MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
                RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
            """

        # Kuzu models RELATES_TO as an intermediate node
        if driver.provider == GraphProvider.KUZU:
            query = """
                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
                RETURN e.fact_embedding AS fact_embedding
            """

        records, _, _ = await driver.execute_query(
            query,
            uuid=self.uuid,
            routing_='r',
        )

        if len(records) == 0:
            raise EdgeNotFoundError(self.uuid)

        self.fact_embedding = records[0]['fact_embedding']

    async def save(self, driver: GraphDriver):
        """Persist this edge; custom attributes are serialized per provider.

        Kuzu receives attributes as a JSON string; other providers receive
        them flattened into the edge property map.
        """
        edge_data: dict[str, Any] = {
            'source_uuid': self.source_node_uuid,
            'target_uuid': self.target_node_uuid,
            'uuid': self.uuid,
            'name': self.name,
            'group_id': self.group_id,
            'fact': self.fact,
            'fact_embedding': self.fact_embedding,
            'episodes': self.episodes,
            'created_at': self.created_at,
            'expired_at': self.expired_at,
            'valid_at': self.valid_at,
            'invalid_at': self.invalid_at,
        }

        if driver.provider == GraphProvider.KUZU:
            edge_data['attributes'] = json.dumps(self.attributes)
            result = await driver.execute_query(
                get_entity_edge_save_query(driver.provider),
                **edge_data,
            )
        else:
            edge_data.update(self.attributes or {})
            result = await driver.execute_query(
                get_entity_edge_save_query(driver.provider),
                edge_data=edge_data,
            )

        logger.debug(f'Saved edge to Graph: {self.uuid}')

        return result

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        """Fetch one RELATES_TO edge by uuid; raise EdgeNotFoundError when absent."""
        match_query = """
            MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
        """
        if driver.provider == GraphProvider.KUZU:
            match_query = """
                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
            """

        records, _, _ = await driver.execute_query(
            match_query
            + """
            RETURN
            """
            + get_entity_edge_return_query(driver.provider),
            uuid=uuid,
            routing_='r',
        )

        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

        if len(edges) == 0:
            raise EdgeNotFoundError(uuid)
        return edges[0]

    @classmethod
    async def get_between_nodes(
        cls, driver: GraphDriver, source_node_uuid: str, target_node_uuid: str
    ):
        """Return all RELATES_TO edges from the source node to the target node."""
        match_query = """
            MATCH (n:Entity {uuid: $source_node_uuid})-[e:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
        """
        if driver.provider == GraphProvider.KUZU:
            match_query = """
                MATCH (n:Entity {uuid: $source_node_uuid})
                -[:RELATES_TO]->(e:RelatesToNode_)
                -[:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
            """

        records, _, _ = await driver.execute_query(
            match_query
            + """
            RETURN
            """
            + get_entity_edge_return_query(driver.provider),
            source_node_uuid=source_node_uuid,
            target_node_uuid=target_node_uuid,
            routing_='r',
        )

        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

        return edges

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        """Return all RELATES_TO edges whose uuid is in `uuids` (may be empty)."""
        if len(uuids) == 0:
            return []

        match_query = """
            MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
        """
        if driver.provider == GraphProvider.KUZU:
            match_query = """
                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
            """

        records, _, _ = await driver.execute_query(
            match_query
            + """
            WHERE e.uuid IN $uuids
            RETURN
            """
            + get_entity_edge_return_query(driver.provider),
            uuids=uuids,
            routing_='r',
        )

        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

        return edges

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
        with_embeddings: bool = False,
    ):
        """Page RELATES_TO edges for the given groups, ordered by uuid descending.

        `uuid_cursor` restricts results to uuids below the cursor, `limit` caps
        the page size, and `with_embeddings` additionally returns the stored
        fact embeddings. Raises GroupsEdgesNotFoundError on an empty result.
        """
        cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
        with_embeddings_query: LiteralString = (
            """,
            e.fact_embedding AS fact_embedding
            """
            if with_embeddings
            else ''
        )

        match_query = """
            MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
        """
        if driver.provider == GraphProvider.KUZU:
            match_query = """
                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
            """

        records, _, _ = await driver.execute_query(
            match_query
            + """
            WHERE e.group_id IN $group_ids
            """
            + cursor_query
            + """
            RETURN
            """
            + get_entity_edge_return_query(driver.provider)
            + with_embeddings_query
            + """
            ORDER BY e.uuid DESC
            """
            + limit_query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )

        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

        if len(edges) == 0:
            raise GroupsEdgesNotFoundError(group_ids)
        return edges

    @classmethod
    async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
        """Return all RELATES_TO edges touching the given entity node."""
        match_query = """
            MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
        """
        if driver.provider == GraphProvider.KUZU:
            match_query = """
                MATCH (n:Entity {uuid: $node_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
            """

        records, _, _ = await driver.execute_query(
            match_query
            + """
            RETURN
            """
            + get_entity_edge_return_query(driver.provider),
            node_uuid=node_uuid,
            routing_='r',
        )

        edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

        return edges
478 |
479 |
class CommunityEdge(Edge):
    """HAS_MEMBER edge linking a Community node to one of its member nodes."""

    async def save(self, driver: GraphDriver):
        """Persist this HAS_MEMBER edge; the community is the source node."""
        result = await driver.execute_query(
            get_community_edge_save_query(driver.provider),
            community_uuid=self.source_node_uuid,
            entity_uuid=self.target_node_uuid,
            uuid=self.uuid,
            group_id=self.group_id,
            created_at=self.created_at,
        )

        logger.debug(f'Saved edge to Graph: {self.uuid}')

        return result

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        """Fetch one HAS_MEMBER edge by uuid.

        Raises EdgeNotFoundError when no edge matches. (Previously an empty
        result raised an opaque IndexError; this now matches the behavior of
        EpisodicEdge.get_by_uuid and EntityEdge.get_by_uuid.)
        """
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m)
            RETURN
            """
            + COMMUNITY_EDGE_RETURN,
            uuid=uuid,
            routing_='r',
        )

        edges = [get_community_edge_from_record(record) for record in records]

        if len(edges) == 0:
            raise EdgeNotFoundError(uuid)
        return edges[0]

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        """Return all HAS_MEMBER edges whose uuid is in `uuids` (may be empty)."""
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Community)-[e:HAS_MEMBER]->(m)
            WHERE e.uuid IN $uuids
            RETURN
            """
            + COMMUNITY_EDGE_RETURN,
            uuids=uuids,
            routing_='r',
        )

        edges = [get_community_edge_from_record(record) for record in records]

        return edges

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
    ):
        """Page HAS_MEMBER edges for the given groups, ordered by uuid descending.

        `uuid_cursor` restricts results to uuids below the cursor and `limit`
        caps the page size. Returns an empty list when nothing matches.
        """
        cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''

        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Community)-[e:HAS_MEMBER]->(m)
            WHERE e.group_id IN $group_ids
            """
            + cursor_query
            + """
            RETURN
            """
            + COMMUNITY_EDGE_RETURN
            + """
            ORDER BY e.uuid DESC
            """
            + limit_query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )

        edges = [get_community_edge_from_record(record) for record in records]

        return edges
562 |
563 |
564 | # Edge helpers
def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
    """Hydrate an EpisodicEdge from a database result record."""
    created_at = parse_db_date(record['created_at'])  # type: ignore
    return EpisodicEdge(
        uuid=record['uuid'],
        group_id=record['group_id'],
        source_node_uuid=record['source_node_uuid'],
        target_node_uuid=record['target_node_uuid'],
        created_at=created_at,
    )
573 |
574 |
def get_entity_edge_from_record(record: Any, provider: GraphProvider) -> EntityEdge:
    """Hydrate an EntityEdge from a database result record.

    Kuzu returns custom attributes as a JSON string; other providers return
    the full property map, from which the known edge fields are stripped so
    only the custom attributes remain.
    """
    episodes = record['episodes']
    if provider == GraphProvider.KUZU:
        attributes = json.loads(record['attributes']) if record['attributes'] else {}
    else:
        attributes = record['attributes']
        # Remove the reserved edge fields; whatever is left is custom.
        for reserved in (
            'uuid',
            'source_node_uuid',
            'target_node_uuid',
            'fact',
            'fact_embedding',
            'name',
            'group_id',
            'episodes',
            'created_at',
            'expired_at',
            'valid_at',
            'invalid_at',
        ):
            attributes.pop(reserved, None)

    return EntityEdge(
        uuid=record['uuid'],
        source_node_uuid=record['source_node_uuid'],
        target_node_uuid=record['target_node_uuid'],
        fact=record['fact'],
        fact_embedding=record.get('fact_embedding'),
        name=record['name'],
        group_id=record['group_id'],
        episodes=episodes,
        created_at=parse_db_date(record['created_at']),  # type: ignore
        expired_at=parse_db_date(record['expired_at']),
        valid_at=parse_db_date(record['valid_at']),
        invalid_at=parse_db_date(record['invalid_at']),
        attributes=attributes,
    )
611 |
612 |
def get_community_edge_from_record(record: Any) -> CommunityEdge:
    """Hydrate a CommunityEdge from a database result record."""
    created_at = parse_db_date(record['created_at'])  # type: ignore
    return CommunityEdge(
        uuid=record['uuid'],
        group_id=record['group_id'],
        source_node_uuid=record['source_node_uuid'],
        target_node_uuid=record['target_node_uuid'],
        created_at=created_at,
    )
621 |
622 |
async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
    """Batch-embed the facts of the given edges, mutating them in place.

    Edges with an empty (falsey) fact are skipped; nothing is embedded when
    no edge has a fact.
    """
    embeddable = [edge for edge in edges if edge.fact]
    if not embeddable:
        return

    embeddings = await embedder.create_batch([edge.fact for edge in embeddable])
    for edge, embedding in zip(embeddable, embeddings, strict=True):
        edge.fact_embedding = embedding
632 |
```
--------------------------------------------------------------------------------
/tests/utils/maintenance/test_node_operations.py:
--------------------------------------------------------------------------------
```python
1 | import logging
2 | from collections import defaultdict
3 | from unittest.mock import AsyncMock, MagicMock
4 |
5 | import pytest
6 |
7 | from graphiti_core.graphiti_types import GraphitiClients
8 | from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
9 | from graphiti_core.search.search_config import SearchResults
10 | from graphiti_core.utils.datetime_utils import utc_now
11 | from graphiti_core.utils.maintenance.dedup_helpers import (
12 | DedupCandidateIndexes,
13 | DedupResolutionState,
14 | _build_candidate_indexes,
15 | _cached_shingles,
16 | _has_high_entropy,
17 | _hash_shingle,
18 | _jaccard_similarity,
19 | _lsh_bands,
20 | _minhash_signature,
21 | _name_entropy,
22 | _normalize_name_for_fuzzy,
23 | _normalize_string_exact,
24 | _resolve_with_similarity,
25 | _shingles,
26 | )
27 | from graphiti_core.utils.maintenance.node_operations import (
28 | _collect_candidate_nodes,
29 | _resolve_with_llm,
30 | extract_attributes_from_node,
31 | extract_attributes_from_nodes,
32 | resolve_extracted_nodes,
33 | )
34 |
35 |
def _make_clients():
    """Build a GraphitiClients test double and return it with the LLM generate mock."""
    llm_generate = AsyncMock()
    llm_client = MagicMock()
    llm_client.generate_response = llm_generate

    # model_construct bypasses pydantic validation so plain mocks are accepted
    clients = GraphitiClients.model_construct(
        driver=MagicMock(),
        embedder=MagicMock(),
        cross_encoder=MagicMock(),
        llm_client=llm_client,
    )

    return clients, llm_generate
52 |
53 |
def _make_episode(group_id: str = 'group'):
    """Return a minimal message-type EpisodicNode for the given group."""
    fields = {
        'name': 'episode',
        'group_id': group_id,
        'source': EpisodeType.message,
        'source_description': 'test',
        'content': 'content',
        'valid_at': utc_now(),
    }
    return EpisodicNode(**fields)
63 |
64 |
@pytest.mark.asyncio
async def test_resolve_nodes_exact_match_skips_llm(monkeypatch):
    """An exact name match against an existing candidate resolves without the LLM."""
    clients, llm_generate = _make_clients()

    existing = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])
    new_node = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.search',
        AsyncMock(return_value=SearchResults(nodes=[existing])),
    )
    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges',
        AsyncMock(return_value=[]),
    )

    resolved, uuid_map, _ = await resolve_extracted_nodes(
        clients,
        [new_node],
        episode=_make_episode(),
        previous_episodes=[],
    )

    # The extracted node is mapped onto the pre-existing candidate.
    assert resolved[0].uuid == existing.uuid
    assert uuid_map[new_node.uuid] == existing.uuid
    llm_generate.assert_not_awaited()
94 |
95 |
@pytest.mark.asyncio
async def test_resolve_nodes_low_entropy_uses_llm(monkeypatch):
    """A short low-entropy name cannot be resolved deterministically, so the LLM runs."""
    clients, llm_generate = _make_clients()
    llm_generate.return_value = {
        'entity_resolutions': [
            {
                'id': 0,
                'duplicate_idx': -1,
                'name': 'Joe',
                'duplicates': [],
            }
        ]
    }

    new_node = EntityNode(name='Joe', group_id='group', labels=['Entity'])

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.search',
        AsyncMock(return_value=SearchResults(nodes=[])),
    )
    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges',
        AsyncMock(return_value=[]),
    )

    resolved, uuid_map, _ = await resolve_extracted_nodes(
        clients,
        [new_node],
        episode=_make_episode(),
        previous_episodes=[],
    )

    # No duplicate found: the node resolves to itself, and the LLM was consulted.
    assert resolved[0].uuid == new_node.uuid
    assert uuid_map[new_node.uuid] == new_node.uuid
    llm_generate.assert_awaited()
134 |
135 |
@pytest.mark.asyncio
async def test_resolve_nodes_fuzzy_match(monkeypatch):
    """Names differing only in punctuation fuzzy-match without consulting the LLM."""
    clients, llm_generate = _make_clients()

    existing = EntityNode(name='Joe-Michaels', group_id='group', labels=['Entity'])
    new_node = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.search',
        AsyncMock(return_value=SearchResults(nodes=[existing])),
    )
    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges',
        AsyncMock(return_value=[]),
    )

    resolved, uuid_map, _ = await resolve_extracted_nodes(
        clients,
        [new_node],
        episode=_make_episode(),
        previous_episodes=[],
    )

    assert resolved[0].uuid == existing.uuid
    assert uuid_map[new_node.uuid] == existing.uuid
    llm_generate.assert_not_awaited()
165 |
166 |
@pytest.mark.asyncio
async def test_collect_candidate_nodes_dedupes_and_merges_override(monkeypatch):
    """A search hit and an override node sharing a uuid collapse into one candidate."""
    clients, _ = _make_clients()

    found = EntityNode(name='Alice', group_id='group', labels=['Entity'])
    override = EntityNode(
        uuid=found.uuid,
        name='Alice Alt',
        group_id='group',
        labels=['Entity'],
    )
    new_node = EntityNode(name='Alice', group_id='group', labels=['Entity'])

    search_mock = AsyncMock(return_value=SearchResults(nodes=[found]))
    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.search',
        search_mock,
    )

    candidates = await _collect_candidate_nodes(
        clients,
        [new_node],
        existing_nodes_override=[override],
    )

    assert len(candidates) == 1
    assert candidates[0].uuid == found.uuid
    search_mock.assert_awaited()
195 |
196 |
def test_build_candidate_indexes_populates_structures():
    """A single candidate must show up in every index structure produced."""
    node = EntityNode(name='Bob Dylan', group_id='group', labels=['Entity'])

    idx = _build_candidate_indexes([node])

    key = node.name.lower()
    assert idx.normalized_existing[key][0].uuid == node.uuid
    assert idx.nodes_by_uuid[node.uuid] is node
    assert node.uuid in idx.shingles_by_candidate
    # The candidate should land in at least one LSH bucket.
    bucket_hits = [bucket for bucket in idx.lsh_buckets.values() if node.uuid in bucket]
    assert bucket_hits
207 |
208 |
def test_normalize_helpers():
    """Exact normalization trims/lowercases; fuzzy also strips punctuation."""
    exact_result = _normalize_string_exact(' Alice Smith ')
    fuzzy_result = _normalize_name_for_fuzzy('Alice-Smith!')
    assert exact_result == 'alice smith'
    assert fuzzy_result == 'alice smith'
212 |
213 |
def test_name_entropy_variants():
    """A varied name scores higher entropy than a repeated char; empty is 0."""
    varied = _name_entropy('alice')
    repeated = _name_entropy('aaaaa')
    assert varied > repeated
    assert _name_entropy('') == 0.0
217 |
218 |
def test_has_high_entropy_rules():
    """Longer varied names pass the entropy gate; tiny ones are rejected."""
    accepted = _has_high_entropy('meaningful name')
    rejected = _has_high_entropy('aa')
    assert accepted is True
    assert rejected is False
222 |
223 |
def test_shingles_and_cache():
    """_shingles yields character 3-grams; _cached_shingles memoizes the set."""
    text = 'alice'
    expected = {'ali', 'lic', 'ice'}
    assert _shingles(text) == expected

    first = _cached_shingles(text)
    assert first == expected
    # Repeated lookups return the identical cached object.
    assert _cached_shingles(text) is first
230 |
231 |
def test_hash_minhash_and_lsh():
    """Check signature width, LSH band width, and per-shingle hash distinctness."""
    shingle_set = {'abc', 'bcd', 'cde'}

    sig = _minhash_signature(shingle_set)
    assert len(sig) == 32

    for band in _lsh_bands(sig):
        assert len(band) == 4

    # Distinct shingles hash to distinct values for a fixed seed.
    distinct_hashes = {_hash_shingle(s, 0) for s in shingle_set}
    assert len(distinct_hashes) == len(shingle_set)
240 |
241 |
def test_jaccard_similarity_edges():
    """One-in-three overlap, plus both empty-set conventions."""
    left = {'a', 'b'}
    right = {'a', 'c'}
    assert _jaccard_similarity(left, right) == pytest.approx(1 / 3)
    # Two empty sets are defined as identical ...
    assert _jaccard_similarity(set(), set()) == 1.0
    # ... while one empty set against a non-empty one has no overlap.
    assert _jaccard_similarity(left, set()) == 0.0
248 |
249 |
def test_resolve_with_similarity_exact_match_updates_state():
    """A unique exact-name hit resolves in place and records the duplicate pair."""
    existing = EntityNode(name='Charlie Parker', group_id='group', labels=['Entity'])
    incoming = EntityNode(name='Charlie Parker', group_id='group', labels=['Entity'])

    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])
    _resolve_with_similarity([incoming], _build_candidate_indexes([existing]), state)

    assert state.unresolved_indices == []
    assert state.resolved_nodes[0].uuid == existing.uuid
    assert state.uuid_map[incoming.uuid] == existing.uuid
    assert state.duplicate_pairs == [(incoming, existing)]
263 |
264 |
def test_resolve_with_similarity_low_entropy_defers_resolution():
    """Short low-entropy names are deferred to the LLM pass, not fuzzy-matched."""
    incoming = EntityNode(name='Bob', group_id='group', labels=['Entity'])
    empty_indexes = DedupCandidateIndexes(
        existing_nodes=[],
        nodes_by_uuid={},
        normalized_existing=defaultdict(list),
        shingles_by_candidate={},
        lsh_buckets=defaultdict(list),
    )
    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])

    _resolve_with_similarity([incoming], empty_indexes, state)

    # Nothing resolved, nothing paired — index 0 is queued for the LLM.
    assert state.resolved_nodes[0] is None
    assert state.unresolved_indices == [0]
    assert state.duplicate_pairs == []
281 |
282 |
def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():
    """Two candidates with the same exact name are ambiguous, so resolution defers."""
    twins = [
        EntityNode(name='Johnny Appleseed', group_id='group', labels=['Entity'])
        for _ in range(2)
    ]
    incoming = EntityNode(name='Johnny Appleseed', group_id='group', labels=['Entity'])

    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])
    _resolve_with_similarity([incoming], _build_candidate_indexes(twins), state)

    assert state.resolved_nodes[0] is None
    assert state.unresolved_indices == [0]
    assert state.duplicate_pairs == []
296 |
297 |
@pytest.mark.asyncio
async def test_resolve_with_llm_updates_unresolved(monkeypatch):
    """A valid LLM resolution maps the unresolved extracted node onto its
    duplicate candidate and records the duplicate pair."""
    extracted = EntityNode(name='Dizzy', group_id='group', labels=['Entity'])
    candidate = EntityNode(name='Dizzy Gillespie', group_id='group', labels=['Entity'])

    indexes = _build_candidate_indexes([candidate])
    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])

    captured_context = {}

    # Capture the prompt context so we can assert on the candidate payload below.
    def fake_prompt_nodes(context):
        captured_context.update(context)
        return ['prompt']

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
        fake_prompt_nodes,
    )

    # LLM verdict: extracted entity 0 duplicates the candidate at relative idx 0.
    async def fake_generate_response(*_, **__):
        return {
            'entity_resolutions': [
                {
                    'id': 0,
                    'duplicate_idx': 0,
                    'name': 'Dizzy Gillespie',
                    'duplicates': [0],
                }
            ]
        }

    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(side_effect=fake_generate_response)

    await _resolve_with_llm(
        llm_client,
        [extracted],
        indexes,
        state,
        episode=_make_episode(),
        previous_episodes=[],
        entity_types=None,
    )

    assert state.resolved_nodes[0].uuid == candidate.uuid
    assert state.uuid_map[extracted.uuid] == candidate.uuid
    # Candidates are presented to the prompt as an indexed list of dicts.
    assert captured_context['existing_nodes'][0]['idx'] == 0
    assert isinstance(captured_context['existing_nodes'], list)
    assert state.duplicate_pairs == [(extracted, candidate)]
347 |
348 |
@pytest.mark.asyncio
async def test_resolve_with_llm_ignores_out_of_range_relative_ids(monkeypatch, caplog):
    """An LLM resolution whose `id` is outside the extracted-node range is
    skipped with a warning, leaving the state untouched."""
    extracted = EntityNode(name='Dexter', group_id='group', labels=['Entity'])

    indexes = _build_candidate_indexes([])
    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
        lambda context: ['prompt'],
    )

    # `id: 5` does not exist — only extracted index 0 is valid here.
    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(
        return_value={
            'entity_resolutions': [
                {
                    'id': 5,
                    'duplicate_idx': -1,
                    'name': 'Dexter',
                    'duplicates': [],
                }
            ]
        }
    )

    with caplog.at_level(logging.WARNING):
        await _resolve_with_llm(
            llm_client,
            [extracted],
            indexes,
            state,
            episode=_make_episode(),
            previous_episodes=[],
            entity_types=None,
        )

    # Node stays unresolved and the bad id is surfaced in the logs.
    assert state.resolved_nodes[0] is None
    assert 'Skipping invalid LLM dedupe id 5' in caplog.text
388 |
389 |
@pytest.mark.asyncio
async def test_resolve_with_llm_ignores_duplicate_relative_ids(monkeypatch):
    """When the LLM returns two resolutions for the same `id`, only the first
    is applied; the conflicting second entry is ignored."""
    extracted = EntityNode(name='Dizzy', group_id='group', labels=['Entity'])
    candidate = EntityNode(name='Dizzy Gillespie', group_id='group', labels=['Entity'])

    indexes = _build_candidate_indexes([candidate])
    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
        lambda context: ['prompt'],
    )

    # Two entries for id 0: the first marks a duplicate, the second says "new".
    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(
        return_value={
            'entity_resolutions': [
                {
                    'id': 0,
                    'duplicate_idx': 0,
                    'name': 'Dizzy Gillespie',
                    'duplicates': [0],
                },
                {
                    'id': 0,
                    'duplicate_idx': -1,
                    'name': 'Dizzy',
                    'duplicates': [],
                },
            ]
        }
    )

    await _resolve_with_llm(
        llm_client,
        [extracted],
        indexes,
        state,
        episode=_make_episode(),
        previous_episodes=[],
        entity_types=None,
    )

    # First resolution wins: extracted maps to the candidate.
    assert state.resolved_nodes[0].uuid == candidate.uuid
    assert state.uuid_map[extracted.uuid] == candidate.uuid
    assert state.duplicate_pairs == [(extracted, candidate)]
436 |
437 |
@pytest.mark.asyncio
async def test_resolve_with_llm_invalid_duplicate_idx_defaults_to_extracted(monkeypatch):
    """A `duplicate_idx` beyond the candidate list falls back to keeping the
    extracted node as its own resolution, with no duplicate pair recorded."""
    extracted = EntityNode(name='Dexter', group_id='group', labels=['Entity'])

    indexes = _build_candidate_indexes([])
    state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])

    monkeypatch.setattr(
        'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
        lambda context: ['prompt'],
    )

    # duplicate_idx 10 points past the (empty) candidate list.
    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(
        return_value={
            'entity_resolutions': [
                {
                    'id': 0,
                    'duplicate_idx': 10,
                    'name': 'Dexter',
                    'duplicates': [],
                }
            ]
        }
    )

    await _resolve_with_llm(
        llm_client,
        [extracted],
        indexes,
        state,
        episode=_make_episode(),
        previous_episodes=[],
        entity_types=None,
    )

    # The extracted node resolves to itself.
    assert state.resolved_nodes[0] == extracted
    assert state.uuid_map[extracted.uuid] == extracted.uuid
    assert state.duplicate_pairs == []
477 |
478 |
@pytest.mark.asyncio
async def test_extract_attributes_without_callback_generates_summary():
    """Test that summary is generated when no callback is provided (default behavior)."""
    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(
        return_value={'summary': 'Generated summary', 'attributes': {}}
    )

    node = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')
    episode = _make_episode()

    result = await extract_attributes_from_node(
        llm_client,
        node,
        episode=episode,
        previous_episodes=[],
        entity_type=None,
        should_summarize_node=None,  # No callback provided
    )

    # Summary should be generated, replacing the stale one
    assert result.summary == 'Generated summary'
    # LLM should have been called exactly once, for the summary
    assert llm_client.generate_response.call_count == 1
503 |
504 |
@pytest.mark.asyncio
async def test_extract_attributes_with_callback_skip_summary():
    """Test that summary is NOT regenerated when callback returns False."""
    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(
        return_value={'summary': 'This should not be used', 'attributes': {}}
    )

    node = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')
    episode = _make_episode()

    # Callback that always returns False (skip summary generation)
    async def skip_summary_filter(node: EntityNode) -> bool:
        return False

    result = await extract_attributes_from_node(
        llm_client,
        node,
        episode=episode,
        previous_episodes=[],
        entity_type=None,
        should_summarize_node=skip_summary_filter,
    )

    # Summary should remain unchanged
    assert result.summary == 'Old summary'
    # LLM should NOT have been called for summary
    assert llm_client.generate_response.call_count == 0
533 |
534 |
@pytest.mark.asyncio
async def test_extract_attributes_with_callback_generate_summary():
    """Test that summary is regenerated when callback returns True."""
    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(
        return_value={'summary': 'New generated summary', 'attributes': {}}
    )

    node = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')
    episode = _make_episode()

    # Callback that always returns True (generate summary)
    async def generate_summary_filter(node: EntityNode) -> bool:
        return True

    result = await extract_attributes_from_node(
        llm_client,
        node,
        episode=episode,
        previous_episodes=[],
        entity_type=None,
        should_summarize_node=generate_summary_filter,
    )

    # Summary should be updated with the LLM's output
    assert result.summary == 'New generated summary'
    # LLM should have been called once for the summary
    assert llm_client.generate_response.call_count == 1
563 |
564 |
@pytest.mark.asyncio
async def test_extract_attributes_with_selective_callback():
    """Test callback that selectively skips summaries based on node properties."""
    llm_client = MagicMock()
    llm_client.generate_response = AsyncMock(
        return_value={'summary': 'Generated summary', 'attributes': {}}
    )

    user_node = EntityNode(name='User', group_id='group', labels=['Entity', 'User'], summary='Old')
    topic_node = EntityNode(
        name='Topic', group_id='group', labels=['Entity', 'Topic'], summary='Old'
    )

    episode = _make_episode()

    # Callback that skips User nodes but generates for others
    async def selective_filter(node: EntityNode) -> bool:
        return 'User' not in node.labels

    result_user = await extract_attributes_from_node(
        llm_client,
        user_node,
        episode=episode,
        previous_episodes=[],
        entity_type=None,
        should_summarize_node=selective_filter,
    )

    result_topic = await extract_attributes_from_node(
        llm_client,
        topic_node,
        episode=episode,
        previous_episodes=[],
        entity_type=None,
        should_summarize_node=selective_filter,
    )

    # User summary should remain unchanged (filter returned False)
    assert result_user.summary == 'Old'
    # Topic summary should be generated (filter returned True)
    assert result_topic.summary == 'Generated summary'
    # LLM should have been called only once (for topic)
    assert llm_client.generate_response.call_count == 1
608 |
609 |
@pytest.mark.asyncio
async def test_extract_attributes_from_nodes_with_callback():
    """Test that callback is properly passed through extract_attributes_from_nodes."""
    clients, _ = _make_clients()
    clients.llm_client.generate_response = AsyncMock(
        return_value={'summary': 'New summary', 'attributes': {}}
    )
    # Stub out embedding so no real embedder is needed for the batch path.
    clients.embedder.create = AsyncMock(return_value=[0.1, 0.2, 0.3])
    clients.embedder.create_batch = AsyncMock(return_value=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])

    node1 = EntityNode(name='Node1', group_id='group', labels=['Entity', 'User'], summary='Old1')
    node2 = EntityNode(name='Node2', group_id='group', labels=['Entity', 'Topic'], summary='Old2')

    episode = _make_episode()

    call_tracker = []

    # Callback that tracks which nodes it's called with
    async def tracking_filter(node: EntityNode) -> bool:
        call_tracker.append(node.name)
        return 'User' not in node.labels

    results = await extract_attributes_from_nodes(
        clients,
        [node1, node2],
        episode=episode,
        previous_episodes=[],
        entity_types=None,
        should_summarize_node=tracking_filter,
    )

    # Callback should have been called for both nodes
    assert len(call_tracker) == 2
    assert 'Node1' in call_tracker
    assert 'Node2' in call_tracker

    # Node1 (User) should keep old summary, Node2 (Topic) should get new summary
    node1_result = next(n for n in results if n.name == 'Node1')
    node2_result = next(n for n in results if n.name == 'Node2')

    assert node1_result.summary == 'Old1'
    assert node2_result.summary == 'New summary'
652 |
```
--------------------------------------------------------------------------------
/tests/llm_client/test_gemini_client.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | # Running tests: pytest -xvs tests/llm_client/test_gemini_client.py
18 |
19 | from unittest.mock import AsyncMock, MagicMock, patch
20 |
21 | import pytest
22 | from pydantic import BaseModel
23 |
24 | from graphiti_core.llm_client.config import LLMConfig, ModelSize
25 | from graphiti_core.llm_client.errors import RateLimitError
26 | from graphiti_core.llm_client.gemini_client import DEFAULT_MODEL, DEFAULT_SMALL_MODEL, GeminiClient
27 | from graphiti_core.prompts.models import Message
28 |
29 |
# Test model for response testing
class ResponseModel(BaseModel):
    """Minimal structured-output schema used to exercise response parsing."""

    # required string field — must appear in any valid structured response
    test_field: str
    # optional int with a default, to cover partially-populated responses
    optional_field: int = 0
36 |
37 |
@pytest.fixture
def mock_gemini_client():
    """Fixture to mock the Google Gemini client."""
    with patch('google.genai.Client') as mock_client:
        # Build the async surface the GeminiClient calls:
        # client.aio.models.generate_content(...)
        mock_instance = mock_client.return_value
        mock_instance.aio = MagicMock()
        mock_instance.aio.models = MagicMock()
        mock_instance.aio.models.generate_content = AsyncMock()
        yield mock_instance
48 |
49 |
@pytest.fixture
def gemini_client(mock_gemini_client):
    """Fixture to create a GeminiClient whose underlying genai client is mocked."""
    config = LLMConfig(api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000)
    client = GeminiClient(config=config, cache=False)
    # Swap in the mocked genai client so no real API calls are made
    client.client = mock_gemini_client
    return client
58 |
59 |
class TestGeminiClientInitialization:
    """Tests for GeminiClient initialization."""

    @patch('google.genai.Client')
    def test_init_with_config(self, mock_client):
        """Test initialization with a config object."""
        config = LLMConfig(
            api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
        )
        client = GeminiClient(config=config, cache=False, max_tokens=1000)

        # Config values are carried through onto the client instance
        assert client.config == config
        assert client.model == 'test-model'
        assert client.temperature == 0.5
        assert client.max_tokens == 1000

    @patch('google.genai.Client')
    def test_init_with_default_model(self, mock_client):
        """Test initialization with default model when none is provided."""
        config = LLMConfig(api_key='test_api_key', model=DEFAULT_MODEL)
        client = GeminiClient(config=config, cache=False)

        assert client.model == DEFAULT_MODEL

    @patch('google.genai.Client')
    def test_init_without_config(self, mock_client):
        """Test initialization without a config uses defaults."""
        client = GeminiClient(cache=False)

        assert client.config is not None
        # When no config.model is set, it will be None, not DEFAULT_MODEL
        assert client.model is None

    @patch('google.genai.Client')
    def test_init_with_thinking_config(self, mock_client):
        """Test initialization with thinking config."""
        with patch('google.genai.types.ThinkingConfig') as mock_thinking_config:
            thinking_config = mock_thinking_config.return_value
            client = GeminiClient(thinking_config=thinking_config)
            assert client.thinking_config == thinking_config
100 |
101 |
102 | class TestGeminiClientGenerateResponse:
103 | """Tests for GeminiClient generate_response method."""
104 |
    @pytest.mark.asyncio
    async def test_generate_response_simple_text(self, gemini_client, mock_gemini_client):
        """Test successful response generation with simple text."""
        # Setup mock response: no candidates / no prompt_feedback => nothing blocked
        mock_response = MagicMock()
        mock_response.text = 'Test response text'
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method
        messages = [Message(role='user', content='Test message')]
        result = await gemini_client.generate_response(messages)

        # Assertions: plain text is wrapped under the 'content' key
        assert isinstance(result, dict)
        assert result['content'] == 'Test response text'
        mock_gemini_client.aio.models.generate_content.assert_called_once()
123 |
    @pytest.mark.asyncio
    async def test_generate_response_with_structured_output(
        self, gemini_client, mock_gemini_client
    ):
        """Test response generation with structured output."""
        # Setup mock response whose text is valid JSON for ResponseModel
        mock_response = MagicMock()
        mock_response.text = '{"test_field": "test_value", "optional_field": 42}'
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method
        messages = [
            Message(role='system', content='System message'),
            Message(role='user', content='User message'),
        ]
        result = await gemini_client.generate_response(
            messages=messages, response_model=ResponseModel
        )

        # Assertions: the JSON payload is parsed into a plain dict
        assert isinstance(result, dict)
        assert result['test_field'] == 'test_value'
        assert result['optional_field'] == 42
        mock_gemini_client.aio.models.generate_content.assert_called_once()
150 |
    @pytest.mark.asyncio
    async def test_generate_response_with_system_message(self, gemini_client, mock_gemini_client):
        """Test response generation with system message handling."""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.text = 'Response with system context'
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method
        messages = [
            Message(role='system', content='System message'),
            Message(role='user', content='User message'),
        ]
        await gemini_client.generate_response(messages)

        # Verify the system message ends up in the request config's
        # system_instruction rather than in the message contents
        call_args = mock_gemini_client.aio.models.generate_content.call_args
        config = call_args[1]['config']
        assert 'System message' in config.system_instruction
172 |
    @pytest.mark.asyncio
    async def test_get_model_for_size(self, gemini_client):
        """Test model selection based on size."""
        # Small requests route to the dedicated small model
        small_model = gemini_client._get_model_for_size(ModelSize.small)
        assert small_model == DEFAULT_SMALL_MODEL

        # Medium (and larger) requests use the configured model
        medium_model = gemini_client._get_model_for_size(ModelSize.medium)
        assert medium_model == gemini_client.model
183 |
    @pytest.mark.asyncio
    async def test_rate_limit_error_handling(self, gemini_client, mock_gemini_client):
        """Test handling of rate limit errors."""
        # Setup mock to raise an error whose message mentions the rate limit
        mock_gemini_client.aio.models.generate_content.side_effect = Exception(
            'Rate limit exceeded'
        )

        # The generic exception should be surfaced as a RateLimitError
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(RateLimitError):
            await gemini_client.generate_response(messages)
196 |
    @pytest.mark.asyncio
    async def test_quota_error_handling(self, gemini_client, mock_gemini_client):
        """Test handling of quota errors."""
        # Setup mock to raise a quota-style error message
        mock_gemini_client.aio.models.generate_content.side_effect = Exception(
            'Quota exceeded for requests'
        )

        # Quota errors are also mapped to RateLimitError
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(RateLimitError):
            await gemini_client.generate_response(messages)
209 |
    @pytest.mark.asyncio
    async def test_resource_exhausted_error_handling(self, gemini_client, mock_gemini_client):
        """Test handling of resource exhausted errors."""
        # Setup mock to raise a gRPC-style resource_exhausted error message
        mock_gemini_client.aio.models.generate_content.side_effect = Exception(
            'resource_exhausted: Request limit exceeded'
        )

        # resource_exhausted is likewise mapped to RateLimitError
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(RateLimitError):
            await gemini_client.generate_response(messages)
222 |
    @pytest.mark.asyncio
    async def test_safety_block_handling(self, gemini_client, mock_gemini_client):
        """Test handling of safety blocks."""
        # Candidate whose finish_reason and safety_ratings mark a safety block
        mock_candidate = MagicMock()
        mock_candidate.finish_reason = 'SAFETY'
        mock_candidate.safety_ratings = [
            MagicMock(blocked=True, category='HARM_CATEGORY_HARASSMENT', probability='HIGH')
        ]

        mock_response = MagicMock()
        mock_response.candidates = [mock_candidate]
        mock_response.prompt_feedback = None
        mock_response.text = ''
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # A safety-blocked response must raise, not return empty content
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(Exception, match='Content blocked by safety filters'):
            await gemini_client.generate_response(messages)
243 |
    @pytest.mark.asyncio
    async def test_prompt_block_handling(self, gemini_client, mock_gemini_client):
        """Test handling of prompt blocks."""
        # prompt_feedback.block_reason set => the prompt itself was rejected
        mock_prompt_feedback = MagicMock()
        mock_prompt_feedback.block_reason = 'BLOCKED_REASON_OTHER'

        mock_response = MagicMock()
        mock_response.candidates = []
        mock_response.prompt_feedback = mock_prompt_feedback
        mock_response.text = ''
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # A blocked prompt surfaces the same safety-filter exception
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(Exception, match='Content blocked by safety filters'):
            await gemini_client.generate_response(messages)
261 |
    @pytest.mark.asyncio
    async def test_structured_output_parsing_error(self, gemini_client, mock_gemini_client):
        """Test handling of structured output parsing errors."""
        # Setup mock response with invalid JSON that will exhaust retries
        mock_response = MagicMock()
        mock_response.text = 'Invalid JSON that cannot be parsed'
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method and check exception - should exhaust retries
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(Exception):  # noqa: B017
            await gemini_client.generate_response(messages, response_model=ResponseModel)

        # Should have called generate_content MAX_RETRIES times (2 attempts total)
        assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
279 |
    @pytest.mark.asyncio
    async def test_retry_logic_with_safety_block(self, gemini_client, mock_gemini_client):
        """Test that safety blocks are not retried."""
        # Setup mock response with safety block (same shape as the safety test)
        mock_candidate = MagicMock()
        mock_candidate.finish_reason = 'SAFETY'
        mock_candidate.safety_ratings = [
            MagicMock(blocked=True, category='HARM_CATEGORY_HARASSMENT', probability='HIGH')
        ]

        mock_response = MagicMock()
        mock_response.candidates = [mock_candidate]
        mock_response.prompt_feedback = None
        mock_response.text = ''
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method and check that it doesn't retry
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(Exception, match='Content blocked by safety filters'):
            await gemini_client.generate_response(messages)

        # Should only be called once (no retries for safety blocks)
        assert mock_gemini_client.aio.models.generate_content.call_count == 1
303 |
    @pytest.mark.asyncio
    async def test_retry_logic_with_validation_error(self, gemini_client, mock_gemini_client):
        """Test retry behavior on validation error."""
        # First call returns invalid JSON, second call returns valid data
        mock_response1 = MagicMock()
        mock_response1.text = 'Invalid JSON that cannot be parsed'
        mock_response1.candidates = []
        mock_response1.prompt_feedback = None

        mock_response2 = MagicMock()
        mock_response2.text = '{"test_field": "correct_value"}'
        mock_response2.candidates = []
        mock_response2.prompt_feedback = None

        # side_effect list yields each response in order across retries
        mock_gemini_client.aio.models.generate_content.side_effect = [
            mock_response1,
            mock_response2,
        ]

        # Call method
        messages = [Message(role='user', content='Test message')]
        result = await gemini_client.generate_response(messages, response_model=ResponseModel)

        # Should have called generate_content twice due to retry
        assert mock_gemini_client.aio.models.generate_content.call_count == 2
        assert result['test_field'] == 'correct_value'
330 |
    @pytest.mark.asyncio
    async def test_max_retries_exceeded(self, gemini_client, mock_gemini_client):
        """Test behavior when max retries are exceeded."""
        # Setup mock to always return invalid JSON, so every attempt fails
        mock_response = MagicMock()
        mock_response.text = 'Invalid JSON that cannot be parsed'
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method and check exception
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(Exception):  # noqa: B017
            await gemini_client.generate_response(messages, response_model=ResponseModel)

        # Should have called generate_content MAX_RETRIES times (2 attempts total)
        assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
348 |
    @pytest.mark.asyncio
    async def test_empty_response_handling(self, gemini_client, mock_gemini_client):
        """Test handling of empty responses."""
        # Setup mock response with no text at all
        mock_response = MagicMock()
        mock_response.text = ''
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # With structured output an empty body cannot be parsed, so it raises
        messages = [Message(role='user', content='Test message')]
        with pytest.raises(Exception):  # noqa: B017
            await gemini_client.generate_response(messages, response_model=ResponseModel)

        # Should have exhausted retries due to empty response (2 attempts total)
        assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
366 |
    @pytest.mark.asyncio
    async def test_custom_max_tokens(self, gemini_client, mock_gemini_client):
        """Test that explicit max_tokens parameter takes precedence over all other values."""
        # Setup mock response
        mock_response = MagicMock()
        mock_response.text = 'Test response'
        mock_response.candidates = []
        mock_response.prompt_feedback = None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        # Call method with custom max tokens (should take precedence over the
        # instance-level max_tokens=1000 configured by the fixture)
        messages = [Message(role='user', content='Test message')]
        await gemini_client.generate_response(messages, max_tokens=500)

        # Verify explicit max_tokens parameter takes precedence
        call_args = mock_gemini_client.aio.models.generate_content.call_args
        config = call_args[1]['config']
        # Explicit parameter should override everything else
        assert config.max_output_tokens == 500
386 |
387 | @pytest.mark.asyncio
388 | async def test_max_tokens_precedence_fallback(self, mock_gemini_client):
389 | """Test max_tokens precedence when no explicit parameter is provided."""
390 | # Setup mock response
391 | mock_response = MagicMock()
392 | mock_response.text = 'Test response'
393 | mock_response.candidates = []
394 | mock_response.prompt_feedback = None
395 | mock_gemini_client.aio.models.generate_content.return_value = mock_response
396 |
397 | # Test case 1: No explicit max_tokens, has instance max_tokens
398 | config = LLMConfig(
399 | api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
400 | )
401 | client = GeminiClient(
402 | config=config, cache=False, max_tokens=2000, client=mock_gemini_client
403 | )
404 |
405 | messages = [Message(role='user', content='Test message')]
406 | await client.generate_response(messages)
407 |
408 | call_args = mock_gemini_client.aio.models.generate_content.call_args
409 | config = call_args[1]['config']
410 | # Instance max_tokens should be used
411 | assert config.max_output_tokens == 2000
412 |
413 | # Test case 2: No explicit max_tokens, no instance max_tokens, uses model mapping
414 | config = LLMConfig(api_key='test_api_key', model='gemini-2.5-flash', temperature=0.5)
415 | client = GeminiClient(config=config, cache=False, client=mock_gemini_client)
416 |
417 | messages = [Message(role='user', content='Test message')]
418 | await client.generate_response(messages)
419 |
420 | call_args = mock_gemini_client.aio.models.generate_content.call_args
421 | config = call_args[1]['config']
422 | # Model mapping should be used
423 | assert config.max_output_tokens == 65536
424 |
425 | @pytest.mark.asyncio
426 | async def test_model_size_selection(self, gemini_client, mock_gemini_client):
427 | """Test that the correct model is selected based on model size."""
428 | # Setup mock response
429 | mock_response = MagicMock()
430 | mock_response.text = 'Test response'
431 | mock_response.candidates = []
432 | mock_response.prompt_feedback = None
433 | mock_gemini_client.aio.models.generate_content.return_value = mock_response
434 |
435 | # Call method with small model size
436 | messages = [Message(role='user', content='Test message')]
437 | await gemini_client.generate_response(messages, model_size=ModelSize.small)
438 |
439 | # Verify correct model is used
440 | call_args = mock_gemini_client.aio.models.generate_content.call_args
441 | assert call_args[1]['model'] == DEFAULT_SMALL_MODEL
442 |
443 | @pytest.mark.asyncio
444 | async def test_gemini_model_max_tokens_mapping(self, mock_gemini_client):
445 | """Test that different Gemini models use their correct max tokens."""
446 | # Setup mock response
447 | mock_response = MagicMock()
448 | mock_response.text = 'Test response'
449 | mock_response.candidates = []
450 | mock_response.prompt_feedback = None
451 | mock_gemini_client.aio.models.generate_content.return_value = mock_response
452 |
453 | # Test data: (model_name, expected_max_tokens)
454 | test_cases = [
455 | ('gemini-2.5-flash', 65536),
456 | ('gemini-2.5-pro', 65536),
457 | ('gemini-2.5-flash-lite', 64000),
458 | ('gemini-2.0-flash', 8192),
459 | ('gemini-1.5-pro', 8192),
460 | ('gemini-1.5-flash', 8192),
461 | ('unknown-model', 8192), # Fallback case
462 | ]
463 |
464 | for model_name, expected_max_tokens in test_cases:
465 | # Create client with specific model, no explicit max_tokens to test mapping
466 | config = LLMConfig(api_key='test_api_key', model=model_name, temperature=0.5)
467 | client = GeminiClient(config=config, cache=False, client=mock_gemini_client)
468 |
469 | # Call method without explicit max_tokens to test model mapping fallback
470 | messages = [Message(role='user', content='Test message')]
471 | await client.generate_response(messages)
472 |
473 | # Verify correct max tokens is used from model mapping
474 | call_args = mock_gemini_client.aio.models.generate_content.call_args
475 | config = call_args[1]['config']
476 | assert config.max_output_tokens == expected_max_tokens, (
477 | f'Model {model_name} should use {expected_max_tokens} tokens'
478 | )
479 |
480 |
if __name__ == '__main__':
    # Use __file__ so the suite runs regardless of the current working
    # directory (a bare filename only resolves when invoked from this
    # file's own directory); matches the pattern used by sibling modules.
    pytest.main([__file__, '-v'])
483 |
```
--------------------------------------------------------------------------------
/mcp_server/tests/test_comprehensive_integration.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """
3 | Comprehensive integration test suite for Graphiti MCP Server.
4 | Covers all MCP tools with consideration for LLM inference latency.
5 | """
6 |
7 | import asyncio
8 | import json
9 | import os
10 | import time
11 | from dataclasses import dataclass
12 | from typing import Any
13 |
14 | import pytest
15 | from mcp import ClientSession, StdioServerParameters
16 | from mcp.client.stdio import stdio_client
17 |
18 |
@dataclass
class TestMetrics:
    """Track test performance metrics.

    One record per tool invocation:
        operation: name of the operation measured (e.g. ``call_add_memory``).
        start_time: wall-clock start, seconds since the epoch.
        end_time: wall-clock end, seconds since the epoch.
        success: whether the call completed without error or timeout.
        details: tool name plus result payload or error description.
    """

    # Prevent pytest from trying to collect this helper as a test class:
    # its name starts with "Test" and it has an __init__ (from @dataclass),
    # which otherwise triggers a PytestCollectionWarning.
    __test__ = False

    operation: str
    start_time: float
    end_time: float
    success: bool
    details: dict[str, Any]

    @property
    def duration(self) -> float:
        """Calculate operation duration in seconds."""
        return self.end_time - self.start_time
33 |
34 |
class GraphitiTestClient:
    """Enhanced test client for comprehensive Graphiti MCP testing.

    Async context manager that launches the MCP server as a stdio subprocess,
    records a TestMetrics entry for every tool call, and polls the server
    until queued episodes have been processed.
    """

    def __init__(self, test_group_id: str | None = None):
        # A timestamp-based group id isolates each test run's graph data.
        self.test_group_id = test_group_id or f'test_{int(time.time())}'
        self.session: ClientSession | None = None
        self.metrics: list[TestMetrics] = []
        self.default_timeout = 30  # seconds

    async def __aenter__(self):
        """Initialize MCP client session."""
        # Connection settings come from the environment, with local-dev defaults.
        server_params = StdioServerParameters(
            command='uv',
            args=['run', '../main.py', '--transport', 'stdio'],
            env={
                'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
                'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
                'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
                'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY', 'test_key_for_mock'),
                'FALKORDB_URI': os.environ.get('FALKORDB_URI', 'redis://localhost:6379'),
            },
        )

        self.client_context = stdio_client(server_params)
        read, write = await self.client_context.__aenter__()
        # NOTE(review): ClientSession is constructed and initialized without
        # entering its own async context — confirm this matches the installed
        # mcp SDK's expected session lifecycle.
        self.session = ClientSession(read, write)
        await self.session.initialize()

        # Wait for server to be fully ready
        await asyncio.sleep(2)

        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Clean up client session."""
        if self.session:
            # NOTE(review): assumes ClientSession exposes close(); verify
            # against the mcp SDK version in use.
            await self.session.close()
        if hasattr(self, 'client_context'):
            await self.client_context.__aexit__(exc_type, exc_val, exc_tb)

    async def call_tool_with_metrics(
        self, tool_name: str, arguments: dict[str, Any], timeout: float | None = None
    ) -> tuple[Any, TestMetrics]:
        """Call a tool and capture performance metrics.

        Args:
            tool_name: Name of the MCP tool to invoke.
            arguments: Arguments passed through to the tool.
            timeout: Per-call timeout in seconds; defaults to default_timeout.

        Returns:
            Tuple of (result text or None, TestMetrics for the call). Timeouts
            and exceptions are recorded in the metric rather than raised.
        """
        start_time = time.time()
        timeout = timeout or self.default_timeout

        try:
            result = await asyncio.wait_for(
                self.session.call_tool(tool_name, arguments), timeout=timeout
            )

            # Tool results arrive as a list of content parts; take the first text part.
            content = result.content[0].text if result.content else None
            success = True
            details = {'result': content, 'tool': tool_name}

        except asyncio.TimeoutError:
            content = None
            success = False
            details = {'error': f'Timeout after {timeout}s', 'tool': tool_name}

        except Exception as e:
            content = None
            success = False
            details = {'error': str(e), 'tool': tool_name}

        end_time = time.time()
        metric = TestMetrics(
            operation=f'call_{tool_name}',
            start_time=start_time,
            end_time=end_time,
            success=success,
            details=details,
        )
        self.metrics.append(metric)

        return content, metric

    async def wait_for_episode_processing(
        self, expected_count: int = 1, max_wait: int = 60, poll_interval: int = 2
    ) -> bool:
        """
        Wait for episodes to be processed with intelligent polling.

        Args:
            expected_count: Number of episodes expected to be processed
            max_wait: Maximum seconds to wait
            poll_interval: Seconds between status checks

        Returns:
            True if episodes were processed successfully
        """
        start_time = time.time()

        while (time.time() - start_time) < max_wait:
            result, _ = await self.call_tool_with_metrics(
                'get_episodes', {'group_id': self.test_group_id, 'last_n': 100}
            )

            if result:
                try:
                    # The result may be a JSON string or an already-decoded object.
                    episodes = json.loads(result) if isinstance(result, str) else result
                    if len(episodes.get('episodes', [])) >= expected_count:
                        return True
                except (json.JSONDecodeError, AttributeError):
                    # Malformed or partial payloads mean "not ready yet".
                    pass

            await asyncio.sleep(poll_interval)

        # Timed out before the expected number of episodes appeared.
        return False
145 |
146 |
class TestCoreOperations:
    """Test core Graphiti operations.

    End-to-end checks that the MCP server advertises the expected tools and
    can ingest text, JSON, and message-style episodes.
    """

    @pytest.mark.asyncio
    async def test_server_initialization(self):
        """Verify server initializes with all required tools."""
        async with GraphitiTestClient() as client:
            tools_result = await client.session.list_tools()
            tools = {tool.name for tool in tools_result.tools}

            # The full tool surface the server must expose.
            required_tools = {
                'add_memory',
                'search_memory_nodes',
                'search_memory_facts',
                'get_episodes',
                'delete_episode',
                'delete_entity_edge',
                'get_entity_edge',
                'clear_graph',
                'get_status',
            }

            missing_tools = required_tools - tools
            assert not missing_tools, f'Missing required tools: {missing_tools}'

    @pytest.mark.asyncio
    async def test_add_text_memory(self):
        """Test adding text-based memories."""
        async with GraphitiTestClient() as client:
            # Add memory
            result, metric = await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Tech Conference Notes',
                    'episode_body': 'The AI conference featured talks on LLMs, RAG systems, and knowledge graphs. Notable speakers included researchers from OpenAI and Anthropic.',
                    'source': 'text',
                    'source_description': 'conference notes',
                    'group_id': client.test_group_id,
                },
            )

            assert metric.success, f'Failed to add memory: {metric.details}'
            # add_memory queues the episode; ingestion happens asynchronously.
            assert 'queued' in str(result).lower()

            # Wait for processing
            processed = await client.wait_for_episode_processing(expected_count=1)
            assert processed, 'Episode was not processed within timeout'

    @pytest.mark.asyncio
    async def test_add_json_memory(self):
        """Test adding structured JSON memories."""
        async with GraphitiTestClient() as client:
            json_data = {
                'project': {
                    'name': 'GraphitiDB',
                    'version': '2.0.0',
                    'features': ['temporal-awareness', 'hybrid-search', 'custom-entities'],
                },
                'team': {'size': 5, 'roles': ['engineering', 'product', 'research']},
            }

            result, metric = await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Project Data',
                    'episode_body': json.dumps(json_data),
                    'source': 'json',
                    'source_description': 'project database',
                    'group_id': client.test_group_id,
                },
            )

            assert metric.success
            assert 'queued' in str(result).lower()

    @pytest.mark.asyncio
    async def test_add_message_memory(self):
        """Test adding conversation/message memories."""
        async with GraphitiTestClient() as client:
            conversation = """
            user: What are the key features of Graphiti?
            assistant: Graphiti offers temporal-aware knowledge graphs, hybrid retrieval, and real-time updates.
            user: How does it handle entity resolution?
            assistant: It uses LLM-based entity extraction and deduplication with semantic similarity matching.
            """

            result, metric = await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Feature Discussion',
                    'episode_body': conversation,
                    'source': 'message',
                    'source_description': 'support chat',
                    'group_id': client.test_group_id,
                },
            )

            assert metric.success
            # Queueing should be fast; only background ingestion may be slow.
            assert metric.duration < 5, f'Add memory took too long: {metric.duration}s'
246 |
247 |
class TestSearchOperations:
    """Test search and retrieval operations.

    Each test seeds the graph with episodes, waits for ingestion, then
    exercises a search tool.
    """

    @pytest.mark.asyncio
    async def test_search_nodes_semantic(self):
        """Test semantic search for nodes."""
        async with GraphitiTestClient() as client:
            # First add some test data
            await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Product Launch',
                    'episode_body': 'Our new AI assistant product launches in Q2 2024 with advanced NLP capabilities.',
                    'source': 'text',
                    'source_description': 'product roadmap',
                    'group_id': client.test_group_id,
                },
            )

            # Wait for processing
            await client.wait_for_episode_processing()

            # Search for nodes
            result, metric = await client.call_tool_with_metrics(
                'search_memory_nodes',
                {'query': 'AI product features', 'group_id': client.test_group_id, 'limit': 10},
            )

            assert metric.success
            assert result is not None

    @pytest.mark.asyncio
    async def test_search_facts_with_filters(self):
        """Test fact search with various filters."""
        async with GraphitiTestClient() as client:
            # Add test data
            await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Company Facts',
                    'episode_body': 'Acme Corp was founded in 2020. They have 50 employees and $10M in revenue.',
                    'source': 'text',
                    'source_description': 'company profile',
                    'group_id': client.test_group_id,
                },
            )

            await client.wait_for_episode_processing()

            # Search with date filter
            result, metric = await client.call_tool_with_metrics(
                'search_memory_facts',
                {
                    'query': 'company information',
                    'group_id': client.test_group_id,
                    'created_after': '2020-01-01T00:00:00Z',
                    'limit': 20,
                },
            )

            assert metric.success

    @pytest.mark.asyncio
    async def test_hybrid_search(self):
        """Test hybrid search combining semantic and keyword search."""
        async with GraphitiTestClient() as client:
            # Add diverse test data
            test_memories = [
                {
                    'name': 'Technical Doc',
                    'episode_body': 'GraphQL API endpoints support pagination, filtering, and real-time subscriptions.',
                    'source': 'text',
                },
                {
                    'name': 'Architecture',
                    'episode_body': 'The system uses Neo4j for graph storage and OpenAI embeddings for semantic search.',
                    'source': 'text',
                },
            ]

            # Fill in the fields shared by every memory before submitting.
            for memory in test_memories:
                memory['group_id'] = client.test_group_id
                memory['source_description'] = 'documentation'
                await client.call_tool_with_metrics('add_memory', memory)

            await client.wait_for_episode_processing(expected_count=2)

            # Test semantic + keyword search
            result, metric = await client.call_tool_with_metrics(
                'search_memory_nodes',
                {'query': 'Neo4j graph database', 'group_id': client.test_group_id, 'limit': 10},
            )

            assert metric.success
342 |
343 |
class TestEpisodeManagement:
    """Test episode lifecycle operations."""

    @pytest.mark.asyncio
    async def test_get_episodes_pagination(self):
        """Test retrieving episodes with pagination."""
        async with GraphitiTestClient() as client:
            # Add multiple episodes
            for i in range(5):
                await client.call_tool_with_metrics(
                    'add_memory',
                    {
                        'name': f'Episode {i}',
                        'episode_body': f'This is test episode number {i}',
                        'source': 'text',
                        'source_description': 'test',
                        'group_id': client.test_group_id,
                    },
                )

            await client.wait_for_episode_processing(expected_count=5)

            # Test pagination
            result, metric = await client.call_tool_with_metrics(
                'get_episodes', {'group_id': client.test_group_id, 'last_n': 3}
            )

            assert metric.success
            # Result may be a JSON string or an already-decoded object.
            episodes = json.loads(result) if isinstance(result, str) else result
            assert len(episodes.get('episodes', [])) <= 3

    @pytest.mark.asyncio
    async def test_delete_episode(self):
        """Test deleting specific episodes."""
        async with GraphitiTestClient() as client:
            # Add an episode
            await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'To Delete',
                    'episode_body': 'This episode will be deleted',
                    'source': 'text',
                    'source_description': 'test',
                    'group_id': client.test_group_id,
                },
            )

            await client.wait_for_episode_processing()

            # Get episode UUID
            result, _ = await client.call_tool_with_metrics(
                'get_episodes', {'group_id': client.test_group_id, 'last_n': 1}
            )

            # Assumes at least one episode was returned; a KeyError/IndexError
            # here means ingestion did not complete in time.
            episodes = json.loads(result) if isinstance(result, str) else result
            episode_uuid = episodes['episodes'][0]['uuid']

            # Delete the episode
            result, metric = await client.call_tool_with_metrics(
                'delete_episode', {'episode_uuid': episode_uuid}
            )

            assert metric.success
            assert 'deleted' in str(result).lower()
408 |
409 |
class TestEntityAndEdgeOperations:
    """Test entity and edge management.

    These tests are partially implemented: edge retrieval/deletion requires
    valid edge UUIDs, which depend on server-side extraction behavior.
    """

    @pytest.mark.asyncio
    async def test_get_entity_edge(self):
        """Test retrieving entity edges."""
        async with GraphitiTestClient() as client:
            # Add data to create entities and edges
            await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Relationship Data',
                    'episode_body': 'Alice works at TechCorp. Bob is the CEO of TechCorp.',
                    'source': 'text',
                    'source_description': 'org chart',
                    'group_id': client.test_group_id,
                },
            )

            await client.wait_for_episode_processing()

            # Search for nodes to get UUIDs
            result, _ = await client.call_tool_with_metrics(
                'search_memory_nodes',
                {'query': 'TechCorp', 'group_id': client.test_group_id, 'limit': 5},
            )

            # Note: This test assumes edges are created between entities.
            # Actual edge retrieval would require valid edge UUIDs.

    @pytest.mark.asyncio
    async def test_delete_entity_edge(self):
        """Test deleting entity edges."""
        # TODO: Similar structure to get_entity_edge but with deletion.
        pass  # Implement based on actual edge creation patterns
445 |
446 |
class TestErrorHandling:
    """Test error conditions and edge cases.

    Covers invalid arguments, timeout handling, and concurrent submissions.
    """

    @pytest.mark.asyncio
    async def test_invalid_tool_arguments(self):
        """Test handling of invalid tool arguments."""
        async with GraphitiTestClient() as client:
            # Missing required arguments
            result, metric = await client.call_tool_with_metrics(
                'add_memory',
                {'name': 'Incomplete'},  # Missing required fields
            )

            assert not metric.success
            assert 'error' in str(metric.details).lower()

    @pytest.mark.asyncio
    async def test_timeout_handling(self):
        """Test timeout handling for long operations."""
        async with GraphitiTestClient() as client:
            # Simulate a very large episode that might time out
            large_text = 'Large document content. ' * 10000

            result, metric = await client.call_tool_with_metrics(
                'add_memory',
                {
                    'name': 'Large Document',
                    'episode_body': large_text,
                    'source': 'text',
                    'source_description': 'large file',
                    'group_id': client.test_group_id,
                },
                timeout=5,  # Short timeout
            )

            # Check if timeout was handled gracefully
            if not metric.success:
                assert 'timeout' in str(metric.details).lower()

    @pytest.mark.asyncio
    async def test_concurrent_operations(self):
        """Test handling of concurrent operations."""
        async with GraphitiTestClient() as client:
            # Launch multiple operations concurrently
            tasks = []
            for i in range(5):
                task = client.call_tool_with_metrics(
                    'add_memory',
                    {
                        'name': f'Concurrent {i}',
                        'episode_body': f'Concurrent operation {i}',
                        'source': 'text',
                        'source_description': 'concurrent test',
                        'group_id': client.test_group_id,
                    },
                )
                tasks.append(task)

            results = await asyncio.gather(*tasks, return_exceptions=True)

            # FIX: with return_exceptions=True, gather() may return raw
            # exception objects; unpacking those as (result, metric) raised
            # TypeError instead of counting the failure. Skip exceptions
            # before inspecting the (content, metric) tuples.
            successful = sum(
                1
                for item in results
                if not isinstance(item, BaseException) and item[1].success
            )
            assert successful >= 3  # At least 60% should succeed
510 |
511 |
class TestPerformance:
    """Test performance characteristics and optimization."""

    @pytest.mark.asyncio
    async def test_latency_metrics(self):
        """Measure and validate operation latencies."""
        async with GraphitiTestClient() as client:
            # (tool name, arguments) pairs to time in sequence.
            operations = [
                (
                    'add_memory',
                    {
                        'name': 'Perf Test',
                        'episode_body': 'Simple text',
                        'source': 'text',
                        'source_description': 'test',
                        'group_id': client.test_group_id,
                    },
                ),
                (
                    'search_memory_nodes',
                    {'query': 'test', 'group_id': client.test_group_id, 'limit': 10},
                ),
                ('get_episodes', {'group_id': client.test_group_id, 'last_n': 10}),
            ]

            for tool_name, args in operations:
                _, metric = await client.call_tool_with_metrics(tool_name, args)

                # Log performance metrics
                print(f'{tool_name}: {metric.duration:.2f}s')

                # Basic latency assertions
                if tool_name == 'get_episodes':
                    assert metric.duration < 2, f'{tool_name} too slow'
                elif tool_name == 'search_memory_nodes':
                    # Search involves LLM/embedding work, so allow more headroom.
                    assert metric.duration < 10, f'{tool_name} too slow'

    @pytest.mark.asyncio
    async def test_batch_processing_efficiency(self):
        """Test efficiency of batch operations."""
        async with GraphitiTestClient() as client:
            batch_size = 10
            start_time = time.time()

            # Batch add memories
            for i in range(batch_size):
                await client.call_tool_with_metrics(
                    'add_memory',
                    {
                        'name': f'Batch {i}',
                        'episode_body': f'Batch content {i}',
                        'source': 'text',
                        'source_description': 'batch test',
                        'group_id': client.test_group_id,
                    },
                )

            # Wait for all to process
            processed = await client.wait_for_episode_processing(
                expected_count=batch_size,
                max_wait=120,  # Allow more time for batch
            )

            total_time = time.time() - start_time
            avg_time_per_item = total_time / batch_size

            assert processed, f'Failed to process {batch_size} items'
            assert avg_time_per_item < 15, (
                f'Batch processing too slow: {avg_time_per_item:.2f}s per item'
            )

            # Generate performance report
            print('\nBatch Performance Report:')
            print(f'  Total items: {batch_size}')
            print(f'  Total time: {total_time:.2f}s')
            print(f'  Avg per item: {avg_time_per_item:.2f}s')
588 |
589 |
class TestDatabaseBackends:
    """Test different database backend configurations.

    Currently a placeholder: env_vars is prepared but no server is launched.
    """

    @pytest.mark.asyncio
    @pytest.mark.parametrize('database', ['neo4j', 'falkordb'])
    async def test_database_operations(self, database):
        """Test operations with different database backends."""
        env_vars = {
            'DATABASE_PROVIDER': database,
            # NOTE(review): may be None when the variable is unset — an invalid
            # subprocess env value once this test is actually implemented.
            'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY'),
        }

        if database == 'neo4j':
            env_vars.update(
                {
                    'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
                    'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
                    'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
                }
            )
        elif database == 'falkordb':
            env_vars['FALKORDB_URI'] = os.environ.get('FALKORDB_URI', 'redis://localhost:6379')

        # This test would require setting up server with specific database.
        # Implementation depends on database availability.
        pass  # Placeholder for database-specific tests
616 |
617 |
def generate_test_report(client: GraphitiTestClient) -> str:
    """Build a human-readable summary of all metrics collected by *client*.

    Includes overall totals, per-operation aggregates, and the five slowest
    calls. Returns a placeholder string when no metrics were recorded.
    """
    metrics = client.metrics
    if not metrics:
        return 'No metrics collected'

    divider = '=' * 60
    lines = ['\n' + divider, 'GRAPHITI MCP TEST REPORT', divider]

    # Overall totals.
    total_ops = len(metrics)
    successful_ops = sum(1 for m in metrics if m.success)
    avg_duration = sum(m.duration for m in metrics) / total_ops

    lines.append(f'\nTotal Operations: {total_ops}')
    lines.append(f'Successful: {successful_ops} ({successful_ops / total_ops * 100:.1f}%)')
    lines.append(f'Average Duration: {avg_duration:.2f}s')

    # Aggregate count/success/duration per operation name.
    lines.append('\nOperation Breakdown:')
    operation_stats: dict[str, dict[str, float]] = {}
    for metric in metrics:
        stats = operation_stats.setdefault(
            metric.operation, {'count': 0, 'success': 0, 'total_duration': 0}
        )
        stats['count'] += 1
        stats['success'] += 1 if metric.success else 0
        stats['total_duration'] += metric.duration

    for op, stats in sorted(operation_stats.items()):
        avg_dur = stats['total_duration'] / stats['count']
        success_rate = stats['success'] / stats['count'] * 100
        lines.append(
            f'  {op}: {stats["count"]} calls, {success_rate:.0f}% success, {avg_dur:.2f}s avg'
        )

    # Five slowest individual calls.
    lines.append('\nSlowest Operations:')
    for metric in sorted(metrics, key=lambda m: m.duration, reverse=True)[:5]:
        lines.append(f'  {metric.operation}: {metric.duration:.2f}s')

    lines.append(divider)
    return '\n'.join(lines)
663 |
664 |
if __name__ == '__main__':
    # Allow running this module directly: delegate to pytest in asyncio mode.
    pytest_args = [__file__, '-v', '--asyncio-mode=auto']
    pytest.main(pytest_args)
668 |
```
--------------------------------------------------------------------------------
/graphiti_core/utils/maintenance/edge_operations.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import logging
18 | from datetime import datetime
19 | from time import time
20 |
21 | from pydantic import BaseModel
22 | from typing_extensions import LiteralString
23 |
24 | from graphiti_core.driver.driver import GraphDriver, GraphProvider
25 | from graphiti_core.edges import (
26 | CommunityEdge,
27 | EntityEdge,
28 | EpisodicEdge,
29 | create_entity_edge_embeddings,
30 | )
31 | from graphiti_core.graphiti_types import GraphitiClients
32 | from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
33 | from graphiti_core.llm_client import LLMClient
34 | from graphiti_core.llm_client.config import ModelSize
35 | from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
36 | from graphiti_core.prompts import prompt_library
37 | from graphiti_core.prompts.dedupe_edges import EdgeDuplicate
38 | from graphiti_core.prompts.extract_edges import ExtractedEdges, MissingFacts
39 | from graphiti_core.search.search import search
40 | from graphiti_core.search.search_config import SearchResults
41 | from graphiti_core.search.search_config_recipes import EDGE_HYBRID_SEARCH_RRF
42 | from graphiti_core.search.search_filters import SearchFilters
43 | from graphiti_core.utils.datetime_utils import ensure_utc, utc_now
44 | from graphiti_core.utils.maintenance.dedup_helpers import _normalize_string_exact
45 |
46 | DEFAULT_EDGE_NAME = 'RELATES_TO'
47 |
48 | logger = logging.getLogger(__name__)
49 |
50 |
def build_episodic_edges(
    entity_nodes: list[EntityNode],
    episode_uuid: str,
    created_at: datetime,
) -> list[EpisodicEdge]:
    """Create one episodic edge from the episode to each entity node.

    Each edge carries the entity node's group_id and the supplied timestamp.
    """
    edges: list[EpisodicEdge] = []
    for entity in entity_nodes:
        edges.append(
            EpisodicEdge(
                source_node_uuid=episode_uuid,
                target_node_uuid=entity.uuid,
                created_at=created_at,
                group_id=entity.group_id,
            )
        )

    logger.debug(f'Built episodic edges: {edges}')

    return edges
69 |
70 |
def build_community_edges(
    entity_nodes: list[EntityNode],
    community_node: CommunityNode,
    created_at: datetime,
) -> list[CommunityEdge]:
    """Create a community edge from the community node to every member entity node.

    Each edge carries the community node's group_id and the supplied timestamp.
    """
    membership_edges: list[CommunityEdge] = []
    for member in entity_nodes:
        membership_edges.append(
            CommunityEdge(
                source_node_uuid=community_node.uuid,
                target_node_uuid=member.uuid,
                created_at=created_at,
                group_id=community_node.group_id,
            )
        )

    return membership_edges
87 |
88 |
def _parse_edge_timestamp(raw: str | None, field_name: str) -> datetime | None:
    """Parse an ISO-8601 timestamp string from the LLM into a UTC datetime.

    Returns None for empty input or on a parse failure; failures are logged
    rather than raised so one malformed date does not drop the whole edge.
    """
    if not raw:
        return None
    try:
        # The LLM may emit a trailing 'Z' suffix; normalize it to an explicit
        # offset that datetime.fromisoformat accepts.
        return ensure_utc(datetime.fromisoformat(raw.replace('Z', '+00:00')))
    except ValueError as e:
        logger.warning(f'WARNING: Error parsing {field_name} date: {e}. Input: {raw}')
        return None


async def extract_edges(
    clients: GraphitiClients,
    episode: EpisodicNode,
    nodes: list[EntityNode],
    previous_episodes: list[EpisodicNode],
    edge_type_map: dict[tuple[str, str], list[str]],
    group_id: str = '',
    edge_types: dict[str, type[BaseModel]] | None = None,
) -> list[EntityEdge]:
    """Extract factual edges between the given entities from an episode.

    Runs the edge-extraction prompt, then up to MAX_REFLEXION_ITERATIONS
    reflexion passes that ask the LLM which facts it missed and re-extract
    with those hints appended to the prompt. Extracted facts are converted
    into EntityEdge objects; entries with blank facts or out-of-range entity
    indices are skipped with a warning.

    Parameters
    ----------
    clients : GraphitiClients
        Bundle providing the LLM client used for extraction.
    episode : EpisodicNode
        Episode whose content is mined for facts; its ``valid_at`` is passed
        to the LLM as the reference time.
    nodes : list[EntityNode]
        Entities extracted from the episode; edge endpoints are indices into
        this list.
    previous_episodes : list[EpisodicNode]
        Prior episodes supplied to the LLM as context.
    edge_type_map : dict[tuple[str, str], list[str]]
        Maps (source_label, target_label) signatures to allowed edge type names.
    group_id : str
        Graph partition the new edges belong to.
    edge_types : dict[str, type[BaseModel]] | None
        Custom edge type definitions keyed by type name.

    Returns
    -------
    list[EntityEdge]
        Newly constructed (unsaved) edges; empty if nothing usable was extracted.
    """
    start = time()

    extract_edges_max_tokens = 16384
    llm_client = clients.llm_client

    # Invert edge_type_map so each custom type name can report the node-label
    # signature it applies to. If a name appears under several signatures,
    # the last one wins.
    edge_type_signature_map: dict[str, tuple[str, str]] = {
        type_name: signature
        for signature, type_names in edge_type_map.items()
        for type_name in type_names
    }

    edge_types_context = (
        [
            {
                'fact_type_name': type_name,
                'fact_type_signature': edge_type_signature_map.get(type_name, ('Entity', 'Entity')),
                'fact_type_description': type_model.__doc__,
            }
            for type_name, type_model in edge_types.items()
        ]
        if edge_types is not None
        else []
    )

    # Prepare context for LLM
    context = {
        'episode_content': episode.content,
        'nodes': [
            {'id': idx, 'name': node.name, 'entity_types': node.labels}
            for idx, node in enumerate(nodes)
        ],
        'previous_episodes': [ep.content for ep in previous_episodes],
        'reference_time': episode.valid_at,
        'edge_types': edge_types_context,
        'custom_prompt': '',
    }

    # Reflexion loop: extract, ask the LLM what was missed, and re-extract
    # with the missed facts folded into the prompt until nothing is missing
    # or the iteration budget is exhausted.
    facts_missed = True
    reflexion_iterations = 0
    while facts_missed and reflexion_iterations <= MAX_REFLEXION_ITERATIONS:
        llm_response = await llm_client.generate_response(
            prompt_library.extract_edges.edge(context),
            response_model=ExtractedEdges,
            max_tokens=extract_edges_max_tokens,
            group_id=group_id,
            prompt_name='extract_edges.edge',
        )
        edges_data = ExtractedEdges(**llm_response).edges

        context['extracted_facts'] = [edge_data.fact for edge_data in edges_data]

        reflexion_iterations += 1
        if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
            reflexion_response = await llm_client.generate_response(
                prompt_library.extract_edges.reflexion(context),
                response_model=MissingFacts,
                max_tokens=extract_edges_max_tokens,
                group_id=group_id,
                prompt_name='extract_edges.reflexion',
            )

            missing_facts = reflexion_response.get('missing_facts', [])

            custom_prompt = 'The following facts were missed in a previous extraction: '
            for fact in missing_facts:
                custom_prompt += f'\n{fact},'

            context['custom_prompt'] = custom_prompt

            facts_missed = len(missing_facts) != 0

    end = time()
    logger.debug(f'Extracted new edges: {edges_data} in {(end - start) * 1000} ms')

    if len(edges_data) == 0:
        return []

    if len(nodes) == 0:
        # Without entities there is nothing to anchor edges to; warn once
        # instead of once per extracted fact.
        logger.warning('No entities provided for edge extraction')
        return []

    # Convert the extracted data into EntityEdge objects
    edges = []
    for edge_data in edges_data:
        # Filter out empty edges
        if not edge_data.fact.strip():
            continue

        source_node_idx = edge_data.source_entity_id
        target_node_idx = edge_data.target_entity_id

        # The LLM returns endpoint ids as indices into `nodes`; reject
        # anything out of range rather than crash on a bad index.
        if not (0 <= source_node_idx < len(nodes) and 0 <= target_node_idx < len(nodes)):
            logger.warning(
                f'Invalid entity IDs in edge extraction for {edge_data.relation_type}. '
                f'source_entity_id: {source_node_idx}, target_entity_id: {target_node_idx}, '
                f'but only {len(nodes)} entities available (valid range: 0-{len(nodes) - 1})'
            )
            continue

        edge = EntityEdge(
            source_node_uuid=nodes[source_node_idx].uuid,
            target_node_uuid=nodes[target_node_idx].uuid,
            name=edge_data.relation_type,
            group_id=group_id,
            fact=edge_data.fact,
            episodes=[episode.uuid],
            created_at=utc_now(),
            valid_at=_parse_edge_timestamp(edge_data.valid_at, 'valid_at'),
            invalid_at=_parse_edge_timestamp(edge_data.invalid_at, 'invalid_at'),
        )
        edges.append(edge)
        logger.debug(
            f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})'
        )

    logger.debug(f'Extracted edges: {[(e.name, e.uuid) for e in edges]}')

    return edges
239 |
240 |
async def resolve_extracted_edges(
    clients: GraphitiClients,
    extracted_edges: list[EntityEdge],
    episode: EpisodicNode,
    entities: list[EntityNode],
    edge_types: dict[str, type[BaseModel]],
    edge_type_map: dict[tuple[str, str], list[str]],
) -> tuple[list[EntityEdge], list[EntityEdge]]:
    """Deduplicate extracted edges against each other and against the graph.

    Pipeline: (1) drop verbatim in-batch duplicates, (2) embed the remaining
    edges, (3) concurrently gather per-edge duplicate candidates (edges between
    the same endpoints) and invalidation candidates (graph-wide search on the
    fact text), (4) restrict custom edge-type names to those allowed for each
    edge's source/target label signature, and (5) resolve each edge with
    `resolve_extracted_edge`.

    Returns
    -------
    tuple[list[EntityEdge], list[EntityEdge]]
        (resolved edges, edges invalidated by them). The duplicate edges that
        `resolve_extracted_edge` also reports (its third tuple element) are
        discarded here.
    """
    # Fast path: deduplicate exact matches within the extracted edges before parallel processing
    seen: dict[tuple[str, str, str], EntityEdge] = {}
    deduplicated_edges: list[EntityEdge] = []

    for edge in extracted_edges:
        # Identity for exact-duplicate purposes: endpoints + normalized fact text.
        key = (
            edge.source_node_uuid,
            edge.target_node_uuid,
            _normalize_string_exact(edge.fact),
        )
        if key not in seen:
            seen[key] = edge
            deduplicated_edges.append(edge)

    extracted_edges = deduplicated_edges

    driver = clients.driver
    llm_client = clients.llm_client
    embedder = clients.embedder
    await create_entity_edge_embeddings(embedder, extracted_edges)

    # For each extracted edge, fetch the edges already connecting the same node pair.
    valid_edges_list: list[list[EntityEdge]] = await semaphore_gather(
        *[
            EntityEdge.get_between_nodes(driver, edge.source_node_uuid, edge.target_node_uuid)
            for edge in extracted_edges
        ]
    )

    # Hybrid search scoped to those same-endpoint edges: duplicate candidates.
    related_edges_results: list[SearchResults] = await semaphore_gather(
        *[
            search(
                clients,
                extracted_edge.fact,
                group_ids=[extracted_edge.group_id],
                config=EDGE_HYBRID_SEARCH_RRF,
                search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
            )
            for extracted_edge, valid_edges in zip(extracted_edges, valid_edges_list, strict=True)
        ]
    )

    related_edges_lists: list[list[EntityEdge]] = [result.edges for result in related_edges_results]

    # Unscoped hybrid search on the fact text: invalidation candidates.
    edge_invalidation_candidate_results: list[SearchResults] = await semaphore_gather(
        *[
            search(
                clients,
                extracted_edge.fact,
                group_ids=[extracted_edge.group_id],
                config=EDGE_HYBRID_SEARCH_RRF,
                search_filter=SearchFilters(),
            )
            for extracted_edge in extracted_edges
        ]
    )

    edge_invalidation_candidates: list[list[EntityEdge]] = [
        result.edges for result in edge_invalidation_candidate_results
    ]

    logger.debug(
        f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_lists for e in edges_lst]}'
    )

    # Build entity hash table
    uuid_entity_map: dict[str, EntityNode] = {entity.uuid: entity for entity in entities}

    # Determine which edge types are relevant for each edge.
    # `edge_types_lst` stores the subset of custom edge definitions whose
    # node signature matches each extracted edge. Anything outside this subset
    # should only stay on the edge if it is a non-custom (LLM generated) label.
    edge_types_lst: list[dict[str, type[BaseModel]]] = []
    custom_type_names = set(edge_types or {})
    for extracted_edge in extracted_edges:
        source_node = uuid_entity_map.get(extracted_edge.source_node_uuid)
        target_node = uuid_entity_map.get(extracted_edge.target_node_uuid)
        # Every node implicitly carries the 'Entity' label; unknown endpoints
        # fall back to just 'Entity'.
        source_node_labels = (
            source_node.labels + ['Entity'] if source_node is not None else ['Entity']
        )
        target_node_labels = (
            target_node.labels + ['Entity'] if target_node is not None else ['Entity']
        )
        # Cartesian product of labels: every signature this edge could satisfy.
        label_tuples = [
            (source_label, target_label)
            for source_label in source_node_labels
            for target_label in target_node_labels
        ]

        extracted_edge_types = {}
        for label_tuple in label_tuples:
            type_names = edge_type_map.get(label_tuple, [])
            for type_name in type_names:
                type_model = edge_types.get(type_name)
                if type_model is None:
                    continue

                extracted_edge_types[type_name] = type_model

        edge_types_lst.append(extracted_edge_types)

    for extracted_edge, extracted_edge_types in zip(extracted_edges, edge_types_lst, strict=True):
        allowed_type_names = set(extracted_edge_types)
        is_custom_name = extracted_edge.name in custom_type_names
        if not allowed_type_names:
            # No custom types are valid for this node pairing. Keep LLM generated
            # labels, but flip disallowed custom names back to the default.
            if is_custom_name and extracted_edge.name != DEFAULT_EDGE_NAME:
                extracted_edge.name = DEFAULT_EDGE_NAME
            continue
        if is_custom_name and extracted_edge.name not in allowed_type_names:
            # Custom name exists but it is not permitted for this source/target
            # signature, so fall back to the default edge label.
            extracted_edge.name = DEFAULT_EDGE_NAME

    # resolve edges with related edges in the graph and find invalidation candidates
    results: list[tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]] = list(
        await semaphore_gather(
            *[
                resolve_extracted_edge(
                    llm_client,
                    extracted_edge,
                    related_edges,
                    existing_edges,
                    episode,
                    extracted_edge_types,
                    custom_type_names,
                )
                for extracted_edge, related_edges, existing_edges, extracted_edge_types in zip(
                    extracted_edges,
                    related_edges_lists,
                    edge_invalidation_candidates,
                    edge_types_lst,
                    strict=True,
                )
            ]
        )
    )

    # Flatten the per-edge results; the third element (duplicate edges) is
    # intentionally not returned by this function.
    resolved_edges: list[EntityEdge] = []
    invalidated_edges: list[EntityEdge] = []
    for result in results:
        resolved_edge = result[0]
        invalidated_edge_chunk = result[1]

        resolved_edges.append(resolved_edge)
        invalidated_edges.extend(invalidated_edge_chunk)

    logger.debug(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')

    await semaphore_gather(
        create_entity_edge_embeddings(embedder, resolved_edges),
        create_entity_edge_embeddings(embedder, invalidated_edges),
    )

    return resolved_edges, invalidated_edges
404 |
405 |
def resolve_edge_contradictions(
    resolved_edge: EntityEdge, invalidation_candidates: list[EntityEdge]
) -> list[EntityEdge]:
    """Expire candidate edges that the resolved edge contradicts.

    A candidate is left untouched when its validity window does not overlap
    the resolved edge's window. A candidate that overlaps and became valid
    strictly before the resolved edge is invalidated: its `invalid_at` is set
    to the resolved edge's `valid_at` and it is expired (if not already).

    Returns the list of candidates that were invalidated.
    """
    if not invalidation_candidates:
        return []

    # The resolved edge's bounds are loop-invariant; convert them once.
    new_valid_utc = ensure_utc(resolved_edge.valid_at)
    new_invalid_utc = ensure_utc(resolved_edge.invalid_at)

    invalidated: list[EntityEdge] = []
    for candidate in invalidation_candidates:
        cand_valid_utc = ensure_utc(candidate.valid_at)
        cand_invalid_utc = ensure_utc(candidate.invalid_at)

        # Disjoint windows: the candidate ended before the new edge began, or
        # the new edge ended before the candidate began — no contradiction.
        candidate_ended_first = (
            cand_invalid_utc is not None
            and new_valid_utc is not None
            and cand_invalid_utc <= new_valid_utc
        )
        new_edge_ended_first = (
            cand_valid_utc is not None
            and new_invalid_utc is not None
            and new_invalid_utc <= cand_valid_utc
        )
        if candidate_ended_first or new_edge_ended_first:
            continue

        # Overlapping windows where the candidate started first: the newer
        # edge supersedes it.
        if (
            cand_valid_utc is not None
            and new_valid_utc is not None
            and cand_valid_utc < new_valid_utc
        ):
            candidate.invalid_at = resolved_edge.valid_at
            if candidate.expired_at is None:
                candidate.expired_at = utc_now()
            invalidated.append(candidate)

    return invalidated
442 |
443 |
async def resolve_extracted_edge(
    llm_client: LLMClient,
    extracted_edge: EntityEdge,
    related_edges: list[EntityEdge],
    existing_edges: list[EntityEdge],
    episode: EpisodicNode,
    edge_type_candidates: dict[str, type[BaseModel]] | None = None,
    custom_edge_type_names: set[str] | None = None,
) -> tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]:
    """Resolve an extracted edge against existing graph context.

    Parameters
    ----------
    llm_client : LLMClient
        Client used to invoke the LLM for deduplication and attribute extraction.
    extracted_edge : EntityEdge
        Newly extracted edge whose canonical representation is being resolved.
    related_edges : list[EntityEdge]
        Candidate edges with identical endpoints used for duplicate detection.
    existing_edges : list[EntityEdge]
        Broader set of edges evaluated for contradiction / invalidation.
    episode : EpisodicNode
        Episode providing content context when extracting edge attributes.
    edge_type_candidates : dict[str, type[BaseModel]] | None
        Custom edge types permitted for the current source/target signature.
    custom_edge_type_names : set[str] | None
        Full catalog of registered custom edge names. Used to distinguish
        between disallowed custom types (which fall back to the default label)
        and ad-hoc labels emitted by the LLM.

    Returns
    -------
    tuple[EntityEdge, list[EntityEdge], list[EntityEdge]]
        The resolved edge, any duplicates, and edges to invalidate.
    """
    # Nothing to compare against: the extracted edge is canonical as-is.
    if len(related_edges) == 0 and len(existing_edges) == 0:
        return extracted_edge, [], []

    # Fast path: if the fact text and endpoints already exist verbatim, reuse the matching edge.
    normalized_fact = _normalize_string_exact(extracted_edge.fact)
    for edge in related_edges:
        if (
            edge.source_node_uuid == extracted_edge.source_node_uuid
            and edge.target_node_uuid == extracted_edge.target_node_uuid
            and _normalize_string_exact(edge.fact) == normalized_fact
        ):
            resolved = edge
            if episode is not None and episode.uuid not in resolved.episodes:
                resolved.episodes.append(episode.uuid)
            return resolved, [], []

    start = time()

    # Prepare context for LLM
    related_edges_context = [{'idx': i, 'fact': edge.fact} for i, edge in enumerate(related_edges)]

    invalidation_edge_candidates_context = [
        {'idx': i, 'fact': existing_edge.fact} for i, existing_edge in enumerate(existing_edges)
    ]

    edge_types_context = (
        [
            {
                'fact_type_name': type_name,
                'fact_type_description': type_model.__doc__,
            }
            for type_name, type_model in edge_type_candidates.items()
        ]
        if edge_type_candidates is not None
        else []
    )

    context = {
        'existing_edges': related_edges_context,
        'new_edge': extracted_edge.fact,
        'edge_invalidation_candidates': invalidation_edge_candidates_context,
        'edge_types': edge_types_context,
    }

    if related_edges or existing_edges:
        logger.debug(
            'Resolving edge: sent %d EXISTING FACTS%s and %d INVALIDATION CANDIDATES%s',
            len(related_edges),
            f' (idx 0-{len(related_edges) - 1})' if related_edges else '',
            len(existing_edges),
            f' (idx 0-{len(existing_edges) - 1})' if existing_edges else '',
        )

    llm_response = await llm_client.generate_response(
        prompt_library.dedupe_edges.resolve_edge(context),
        response_model=EdgeDuplicate,
        model_size=ModelSize.small,
        prompt_name='dedupe_edges.resolve_edge',
    )
    response_object = EdgeDuplicate(**llm_response)
    duplicate_facts = response_object.duplicate_facts

    # Validate duplicate_facts are in valid range for EXISTING FACTS
    invalid_duplicates = [i for i in duplicate_facts if i < 0 or i >= len(related_edges)]
    if invalid_duplicates:
        logger.warning(
            'LLM returned invalid duplicate_facts idx values %s (valid range: 0-%d for EXISTING FACTS)',
            invalid_duplicates,
            len(related_edges) - 1,
        )

    duplicate_fact_ids: list[int] = [i for i in duplicate_facts if 0 <= i < len(related_edges)]

    # Reuse the first confirmed duplicate as the canonical edge; otherwise the
    # freshly extracted edge stands.
    resolved_edge = extracted_edge
    if duplicate_fact_ids:
        resolved_edge = related_edges[duplicate_fact_ids[0]]

    # Record the episode on the canonical edge, guarding against appending the
    # same episode uuid twice (mirrors the verbatim fast path above).
    if duplicate_fact_ids and episode is not None and episode.uuid not in resolved_edge.episodes:
        resolved_edge.episodes.append(episode.uuid)

    contradicted_facts: list[int] = response_object.contradicted_facts

    # Validate contradicted_facts are in valid range for INVALIDATION CANDIDATES
    invalid_contradictions = [i for i in contradicted_facts if i < 0 or i >= len(existing_edges)]
    if invalid_contradictions:
        logger.warning(
            'LLM returned invalid contradicted_facts idx values %s (valid range: 0-%d for INVALIDATION CANDIDATES)',
            invalid_contradictions,
            len(existing_edges) - 1,
        )

    invalidation_candidates: list[EntityEdge] = [
        existing_edges[i] for i in contradicted_facts if 0 <= i < len(existing_edges)
    ]

    fact_type: str = response_object.fact_type
    candidate_type_names = set(edge_type_candidates or {})
    custom_type_names = custom_edge_type_names or set()

    is_default_type = fact_type.upper() == 'DEFAULT'
    is_custom_type = fact_type in custom_type_names
    is_allowed_custom_type = fact_type in candidate_type_names

    if is_allowed_custom_type:
        # The LLM selected a custom type that is allowed for the node pair.
        # Adopt the custom type and, if needed, extract its structured attributes.
        resolved_edge.name = fact_type

        edge_attributes_context = {
            'episode_content': episode.content,
            'reference_time': episode.valid_at,
            'fact': resolved_edge.fact,
        }

        edge_model = edge_type_candidates.get(fact_type) if edge_type_candidates else None
        if edge_model is not None and len(edge_model.model_fields) != 0:
            edge_attributes_response = await llm_client.generate_response(
                prompt_library.extract_edges.extract_attributes(edge_attributes_context),
                response_model=edge_model,  # type: ignore
                model_size=ModelSize.small,
                prompt_name='extract_edges.extract_attributes',
            )

            resolved_edge.attributes = edge_attributes_response
    elif not is_default_type and is_custom_type:
        # The LLM picked a custom type that is not allowed for this signature.
        # Reset to the default label and drop any structured attributes.
        resolved_edge.name = DEFAULT_EDGE_NAME
        resolved_edge.attributes = {}
    elif not is_default_type:
        # Non-custom labels are allowed to pass through so long as the LLM does
        # not return the sentinel DEFAULT value.
        resolved_edge.name = fact_type
        resolved_edge.attributes = {}

    end = time()
    logger.debug(
        f'Resolved Edge: {extracted_edge.name} is {resolved_edge.name}, in {(end - start) * 1000} ms'
    )

    now = utc_now()

    # An edge that already carries an invalid_at is immediately expired.
    if resolved_edge.invalid_at and not resolved_edge.expired_at:
        resolved_edge.expired_at = now

    # Determine if the new_edge needs to be expired
    if resolved_edge.expired_at is None:
        # Sort candidates so dated ones come first, earliest valid_at first.
        invalidation_candidates.sort(key=lambda c: (c.valid_at is None, ensure_utc(c.valid_at)))
        for candidate in invalidation_candidates:
            candidate_valid_at_utc = ensure_utc(candidate.valid_at)
            resolved_edge_valid_at_utc = ensure_utc(resolved_edge.valid_at)
            if (
                candidate_valid_at_utc is not None
                and resolved_edge_valid_at_utc is not None
                and candidate_valid_at_utc > resolved_edge_valid_at_utc
            ):
                # Expire new edge since we have information about more recent events
                resolved_edge.invalid_at = candidate.valid_at
                resolved_edge.expired_at = now
                break

    # Determine which contradictory edges need to be expired
    invalidated_edges: list[EntityEdge] = resolve_edge_contradictions(
        resolved_edge, invalidation_candidates
    )
    duplicate_edges: list[EntityEdge] = [related_edges[idx] for idx in duplicate_fact_ids]

    return resolved_edge, invalidated_edges, duplicate_edges
648 |
649 |
async def filter_existing_duplicate_of_edges(
    driver: GraphDriver, duplicates_node_tuples: list[tuple[EntityNode, EntityNode]]
) -> list[tuple[EntityNode, EntityNode]]:
    """Drop node pairs that already share an IS_DUPLICATE_OF edge in the graph.

    Queries the backend for existing IS_DUPLICATE_OF relationships between the
    given (source, target) pairs and returns only the pairs that are not yet
    linked.
    """
    if not duplicates_node_tuples:
        return []

    # Keyed by (source_uuid, target_uuid) so pairs confirmed by the query can
    # be filtered out by the uuids it returns.
    duplicate_nodes_map = {
        (source.uuid, target.uuid): (source, target) for source, target in duplicates_node_tuples
    }

    # Pick the provider-specific query and its parameter shape.
    query: LiteralString
    if driver.provider == GraphProvider.NEPTUNE:
        query = """
            UNWIND $duplicate_node_uuids AS duplicate_tuple
            MATCH (n:Entity {uuid: duplicate_tuple.source})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple.target})
            RETURN DISTINCT
                n.uuid AS source_uuid,
                m.uuid AS target_uuid
        """
        params = [
            {'source': source.uuid, 'target': target.uuid}
            for source, target in duplicates_node_tuples
        ]
    elif driver.provider == GraphProvider.KUZU:
        # Kuzu models entity edges as RelatesToNode_ nodes, so match through them.
        query = """
                UNWIND $duplicate_node_uuids AS duplicate
                MATCH (n:Entity {uuid: duplicate.src})-[:RELATES_TO]->(e:RelatesToNode_ {name: 'IS_DUPLICATE_OF'})-[:RELATES_TO]->(m:Entity {uuid: duplicate.dst})
                RETURN DISTINCT
                    n.uuid AS source_uuid,
                    m.uuid AS target_uuid
            """
        params = [{'src': src, 'dst': dst} for src, dst in duplicate_nodes_map]
    else:
        query = """
                UNWIND $duplicate_node_uuids AS duplicate_tuple
                MATCH (n:Entity {uuid: duplicate_tuple[0]})-[r:RELATES_TO {name: 'IS_DUPLICATE_OF'}]->(m:Entity {uuid: duplicate_tuple[1]})
                RETURN DISTINCT
                    n.uuid AS source_uuid,
                    m.uuid AS target_uuid
            """
        params = list(duplicate_nodes_map.keys())

    records, _, _ = await driver.execute_query(
        query,
        duplicate_node_uuids=params,
        routing_='r',
    )

    # Remove duplicates that already have the IS_DUPLICATE_OF edge.
    already_linked = {
        (record.get('source_uuid'), record.get('target_uuid')) for record in records
    }
    return [pair for key, pair in duplicate_nodes_map.items() if key not in already_linked]
712 |
```
--------------------------------------------------------------------------------
/graphiti_core/nodes.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import json
18 | import logging
19 | from abc import ABC, abstractmethod
20 | from datetime import datetime
21 | from enum import Enum
22 | from time import time
23 | from typing import Any
24 | from uuid import uuid4
25 |
26 | from pydantic import BaseModel, Field
27 | from typing_extensions import LiteralString
28 |
29 | from graphiti_core.driver.driver import (
30 | GraphDriver,
31 | GraphProvider,
32 | )
33 | from graphiti_core.embedder import EmbedderClient
34 | from graphiti_core.errors import NodeNotFoundError
35 | from graphiti_core.helpers import parse_db_date
36 | from graphiti_core.models.nodes.node_db_queries import (
37 | COMMUNITY_NODE_RETURN,
38 | COMMUNITY_NODE_RETURN_NEPTUNE,
39 | EPISODIC_NODE_RETURN,
40 | EPISODIC_NODE_RETURN_NEPTUNE,
41 | get_community_node_save_query,
42 | get_entity_node_return_query,
43 | get_entity_node_save_query,
44 | get_episode_node_save_query,
45 | )
46 | from graphiti_core.utils.datetime_utils import utc_now
47 |
48 | logger = logging.getLogger(__name__)
49 |
50 |
class EpisodeType(Enum):
    """Enumerates the supported source formats of an episode.

    Members
    -------
    message
        Conversation-style content formatted as "actor: content", e.g.
        "user: Hello, how are you?" or "assistant: I'm doing well, thank
        you for asking."
    json
        A JSON string object carrying structured data.
    text
        Plain, unstructured text.
    """

    message = 'message'
    json = 'json'
    text = 'text'

    @staticmethod
    def from_str(episode_type: str):
        """Return the member whose value equals *episode_type*.

        Logs an error and raises NotImplementedError for unrecognized values.
        """
        for member in EpisodeType:
            if member.value == episode_type:
                return member
        logger.error(f'Episode type: {episode_type} not implemented')
        raise NotImplementedError
85 |
86 |
class Node(BaseModel, ABC):
    """Abstract base class for graph nodes (Entity, Episodic, Community).

    Carries the common identity/partition fields and implements
    provider-aware deletion; subclasses supply `save` and the lookup
    classmethods.
    """

    # uuid is generated eagerly, so a node has a stable identity before it is saved.
    uuid: str = Field(default_factory=lambda: str(uuid4()))
    name: str = Field(description='name of the node')
    group_id: str = Field(description='partition of the graph')
    labels: list[str] = Field(default_factory=list)
    created_at: datetime = Field(default_factory=lambda: utc_now())

    @abstractmethod
    async def save(self, driver: GraphDriver): ...

    async def delete(self, driver: GraphDriver):
        """Remove this node (and its incident edges) from the graph."""
        # A configured graph-operations interface overrides the raw queries below.
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.node_delete(self, driver)

        match driver.provider:
            case GraphProvider.NEO4J:
                # NOTE(review): the query returns edge_uuids, but `records` is
                # never read here — the result appears to be discarded.
                records, _, _ = await driver.execute_query(
                    """
                    MATCH (n {uuid: $uuid})
                    WHERE n:Entity OR n:Episodic OR n:Community
                    OPTIONAL MATCH (n)-[r]-()
                    WITH collect(r.uuid) AS edge_uuids, n
                    DETACH DELETE n
                    RETURN edge_uuids
                    """,
                    uuid=self.uuid,
                )

            case GraphProvider.KUZU:
                for label in ['Episodic', 'Community']:
                    await driver.execute_query(
                        f"""
                        MATCH (n:{label} {{uuid: $uuid}})
                        DETACH DELETE n
                        """,
                        uuid=self.uuid,
                    )
                # Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
                # Explicitly delete the "edge" nodes first, then the entity node.
                await driver.execute_query(
                    """
                    MATCH (n:Entity {uuid: $uuid})-[:RELATES_TO]->(e:RelatesToNode_)
                    DETACH DELETE e
                    """,
                    uuid=self.uuid,
                )
                await driver.execute_query(
                    """
                    MATCH (n:Entity {uuid: $uuid})
                    DETACH DELETE n
                    """,
                    uuid=self.uuid,
                )
            case _: # FalkorDB, Neptune
                # These providers take one DETACH DELETE per node label.
                for label in ['Entity', 'Episodic', 'Community']:
                    await driver.execute_query(
                        f"""
                        MATCH (n:{label} {{uuid: $uuid}})
                        DETACH DELETE n
                        """,
                        uuid=self.uuid,
                    )

        logger.debug(f'Deleted Node: {self.uuid}')

    def __hash__(self):
        # Hash on uuid, consistent with __eq__ below.
        return hash(self.uuid)

    def __eq__(self, other):
        # Nodes are equal iff they share a uuid; non-Node values never compare equal.
        if isinstance(other, Node):
            return self.uuid == other.uuid
        return False

    @classmethod
    async def delete_by_group_id(cls, driver: GraphDriver, group_id: str, batch_size: int = 100):
        """Delete every node in a group partition, batching where the backend supports it."""
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.node_delete_by_group_id(
                cls, driver, group_id, batch_size
            )

        match driver.provider:
            case GraphProvider.NEO4J:
                # Neo4j can batch the deletes with CALL ... IN TRANSACTIONS.
                async with driver.session() as session:
                    await session.run(
                        """
                        MATCH (n:Entity|Episodic|Community {group_id: $group_id})
                        CALL (n) {
                            DETACH DELETE n
                        } IN TRANSACTIONS OF $batch_size ROWS
                        """,
                        group_id=group_id,
                        batch_size=batch_size,
                    )

            case GraphProvider.KUZU:
                for label in ['Episodic', 'Community']:
                    await driver.execute_query(
                        f"""
                        MATCH (n:{label} {{group_id: $group_id}})
                        DETACH DELETE n
                        """,
                        group_id=group_id,
                    )
                # Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
                # Explicitly delete the "edge" nodes first, then the entity node.
                await driver.execute_query(
                    """
                    MATCH (n:Entity {group_id: $group_id})-[:RELATES_TO]->(e:RelatesToNode_)
                    DETACH DELETE e
                    """,
                    group_id=group_id,
                )
                await driver.execute_query(
                    """
                    MATCH (n:Entity {group_id: $group_id})
                    DETACH DELETE n
                    """,
                    group_id=group_id,
                )
            case _: # FalkorDB, Neptune
                for label in ['Entity', 'Episodic', 'Community']:
                    await driver.execute_query(
                        f"""
                        MATCH (n:{label} {{group_id: $group_id}})
                        DETACH DELETE n
                        """,
                        group_id=group_id,
                    )

    @classmethod
    async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str], batch_size: int = 100):
        """Delete all nodes whose uuid is in *uuids*."""
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.node_delete_by_uuids(
                cls, driver, uuids, group_id=None, batch_size=batch_size
            )

        match driver.provider:
            case GraphProvider.FALKORDB:
                for label in ['Entity', 'Episodic', 'Community']:
                    await driver.execute_query(
                        f"""
                        MATCH (n:{label})
                        WHERE n.uuid IN $uuids
                        DETACH DELETE n
                        """,
                        uuids=uuids,
                    )
            case GraphProvider.KUZU:
                for label in ['Episodic', 'Community']:
                    await driver.execute_query(
                        f"""
                        MATCH (n:{label})
                        WHERE n.uuid IN $uuids
                        DETACH DELETE n
                        """,
                        uuids=uuids,
                    )
                # Entity edges are actually nodes in Kuzu, so simple `DETACH DELETE` will not work.
                # Explicitly delete the "edge" nodes first, then the entity node.
                await driver.execute_query(
                    """
                    MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)
                    WHERE n.uuid IN $uuids
                    DETACH DELETE e
                    """,
                    uuids=uuids,
                )
                await driver.execute_query(
                    """
                    MATCH (n:Entity)
                    WHERE n.uuid IN $uuids
                    DETACH DELETE n
                    """,
                    uuids=uuids,
                )
            case _: # Neo4J, Neptune
                async with driver.session() as session:
                    # Collect all edge UUIDs before deleting nodes
                    # NOTE(review): the collected edge UUIDs are never read —
                    # this query's result is discarded.
                    await session.run(
                        """
                        MATCH (n:Entity|Episodic|Community)
                        WHERE n.uuid IN $uuids
                        MATCH (n)-[r]-()
                        RETURN collect(r.uuid) AS edge_uuids
                        """,
                        uuids=uuids,
                    )

                    # Now delete the nodes in batches
                    await session.run(
                        """
                        MATCH (n:Entity|Episodic|Community)
                        WHERE n.uuid IN $uuids
                        CALL (n) {
                            DETACH DELETE n
                        } IN TRANSACTIONS OF $batch_size ROWS
                        """,
                        uuids=uuids,
                        batch_size=batch_size,
                    )

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]): ...
293 |
294 |
class EpisodicNode(Node):
    """A single ingested episode: the raw input from which entities and edges are derived."""

    source: EpisodeType = Field(description='source type')
    source_description: str = Field(description='description of the data source')
    content: str = Field(description='raw episode data')
    valid_at: datetime = Field(
        description='datetime of when the original document was created',
    )
    entity_edges: list[str] = Field(
        description='list of entity edges referenced in this episode',
        default_factory=list,
    )

    async def save(self, driver: GraphDriver):
        """Persist this episode via the provider-specific save query; returns the driver result."""
        # A pluggable graph-operations interface (when the driver supplies one)
        # overrides the default save path.
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.episodic_node_save(self, driver)

        episode_args = {
            'uuid': self.uuid,
            'name': self.name,
            'group_id': self.group_id,
            'source_description': self.source_description,
            'content': self.content,
            'entity_edges': self.entity_edges,
            'created_at': self.created_at,
            'valid_at': self.valid_at,
            'source': self.source.value,  # enum -> stored string value
        }

        result = await driver.execute_query(
            get_episode_node_save_query(driver.provider), **episode_args
        )

        logger.debug(f'Saved Node to Graph: {self.uuid}')

        return result

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        """Fetch one episode by uuid.

        Raises:
            NodeNotFoundError: if no episode with the given uuid exists.
        """
        records, _, _ = await driver.execute_query(
            """
            MATCH (e:Episodic {uuid: $uuid})
            RETURN
            """
            # Neptune requires its own RETURN projection (see the *_NEPTUNE constant).
            + (
                EPISODIC_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else EPISODIC_NODE_RETURN
            ),
            uuid=uuid,
            routing_='r',  # read-query routing hint
        )

        episodes = [get_episodic_node_from_record(record) for record in records]

        if len(episodes) == 0:
            raise NodeNotFoundError(uuid)

        return episodes[0]

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        """Fetch all episodes whose uuid is in `uuids`; unknown uuids are silently skipped."""
        records, _, _ = await driver.execute_query(
            """
            MATCH (e:Episodic)
            WHERE e.uuid IN $uuids
            RETURN DISTINCT
            """
            + (
                EPISODIC_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else EPISODIC_NODE_RETURN
            ),
            uuids=uuids,
            routing_='r',
        )

        episodes = [get_episodic_node_from_record(record) for record in records]

        return episodes

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
    ):
        """Fetch episodes for the given group ids, ordered by uuid descending.

        Supports keyset pagination: pass the smallest uuid from the previous
        page as `uuid_cursor` to fetch the next page, and `limit` to cap size.
        """
        # Cursor pages backwards through uuids (matches ORDER BY uuid DESC below).
        cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''

        records, _, _ = await driver.execute_query(
            """
            MATCH (e:Episodic)
            WHERE e.group_id IN $group_ids
            """
            + cursor_query
            + """
            RETURN DISTINCT
            """
            + (
                EPISODIC_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else EPISODIC_NODE_RETURN
            )
            + """
            ORDER BY uuid DESC
            """
            + limit_query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )

        episodes = [get_episodic_node_from_record(record) for record in records]

        return episodes

    @classmethod
    async def get_by_entity_node_uuid(cls, driver: GraphDriver, entity_node_uuid: str):
        """Fetch all episodes that MENTION the entity node with the given uuid."""
        records, _, _ = await driver.execute_query(
            """
            MATCH (e:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $entity_node_uuid})
            RETURN DISTINCT
            """
            + (
                EPISODIC_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else EPISODIC_NODE_RETURN
            ),
            entity_node_uuid=entity_node_uuid,
            routing_='r',
        )

        episodes = [get_episodic_node_from_record(record) for record in records]

        return episodes
433 |
434 |
class EntityNode(Node):
    """An extracted entity with an optional name embedding and label-dependent attributes."""

    name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
    summary: str = Field(description='regional summary of surrounding edges', default_factory=str)
    attributes: dict[str, Any] = Field(
        default={}, description='Additional attributes of the node. Dependent on node labels'
    )

    async def generate_name_embedding(self, embedder: EmbedderClient):
        """Embed this node's name (newlines flattened to spaces); stores and returns the vector."""
        start = time()
        text = self.name.replace('\n', ' ')
        self.name_embedding = await embedder.create(input_data=[text])
        end = time()
        # NOTE(review): time() measures seconds, so the duration logged here is
        # in seconds, not ms.
        logger.debug(f'embedded {text} in {end - start} ms')

        return self.name_embedding

    async def load_name_embedding(self, driver: GraphDriver):
        """Load this node's name embedding from the graph into `self.name_embedding`.

        Raises:
            NodeNotFoundError: if no node with this uuid exists in the graph.
        """
        # Pluggable interface overrides the default per-provider queries.
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.node_load_embeddings(self, driver)

        if driver.provider == GraphProvider.NEPTUNE:
            # Neptune stores the embedding as a comma-separated string; split
            # it and convert each element back to a float.
            query: LiteralString = """
                MATCH (n:Entity {uuid: $uuid})
                RETURN [x IN split(n.name_embedding, ",") | toFloat(x)] as name_embedding
            """

        else:
            query: LiteralString = """
                MATCH (n:Entity {uuid: $uuid})
                RETURN n.name_embedding AS name_embedding
            """
        records, _, _ = await driver.execute_query(
            query,
            uuid=self.uuid,
            routing_='r',
        )

        if len(records) == 0:
            raise NodeNotFoundError(self.uuid)

        self.name_embedding = records[0]['name_embedding']

    async def save(self, driver: GraphDriver):
        """Upsert this entity (core fields plus custom attributes); returns the driver result."""
        if driver.graph_operations_interface:
            return await driver.graph_operations_interface.node_save(self, driver)

        entity_data: dict[str, Any] = {
            'uuid': self.uuid,
            'name': self.name,
            'name_embedding': self.name_embedding,
            'group_id': self.group_id,
            'summary': self.summary,
            'created_at': self.created_at,
        }

        if driver.provider == GraphProvider.KUZU:
            # Kuzu: attributes are serialized to a JSON string and labels are
            # passed as a query parameter rather than baked into the query text.
            entity_data['attributes'] = json.dumps(self.attributes)
            entity_data['labels'] = list(set(self.labels + ['Entity']))
            result = await driver.execute_query(
                get_entity_node_save_query(driver.provider, labels=''),
                **entity_data,
            )
        else:
            # Other providers: flatten custom attributes into the property map
            # and bake the colon-joined label string into the query.
            entity_data.update(self.attributes or {})
            labels = ':'.join(self.labels + ['Entity'])

            result = await driver.execute_query(
                get_entity_node_save_query(driver.provider, labels),
                entity_data=entity_data,
            )

        logger.debug(f'Saved Node to Graph: {self.uuid}')

        return result

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        """Fetch one entity by uuid.

        Raises:
            NodeNotFoundError: if no entity with the given uuid exists.
        """
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Entity {uuid: $uuid})
            RETURN
            """
            + get_entity_node_return_query(driver.provider),
            uuid=uuid,
            routing_='r',
        )

        nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

        if len(nodes) == 0:
            raise NodeNotFoundError(uuid)

        return nodes[0]

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        """Fetch all entities whose uuid is in `uuids`; unknown uuids are silently skipped."""
        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Entity)
            WHERE n.uuid IN $uuids
            RETURN
            """
            + get_entity_node_return_query(driver.provider),
            uuids=uuids,
            routing_='r',
        )

        nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

        return nodes

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
        with_embeddings: bool = False,
    ):
        """Fetch entities for the given group ids, ordered by uuid descending.

        Supports keyset pagination via `uuid_cursor`/`limit`; pass
        `with_embeddings=True` to also project the name embedding column.
        """
        cursor_query: LiteralString = 'AND n.uuid < $uuid' if uuid_cursor else ''
        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
        with_embeddings_query: LiteralString = (
            """,
            n.name_embedding AS name_embedding
            """
            if with_embeddings
            else ''
        )

        records, _, _ = await driver.execute_query(
            """
            MATCH (n:Entity)
            WHERE n.group_id IN $group_ids
            """
            + cursor_query
            + """
            RETURN
            """
            + get_entity_node_return_query(driver.provider)
            + with_embeddings_query
            + """
            ORDER BY n.uuid DESC
            """
            + limit_query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )

        nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

        return nodes
589 |
590 |
class CommunityNode(Node):
    """A community node summarizing a cluster of member entity nodes."""

    name_embedding: list[float] | None = Field(default=None, description='embedding of the name')
    summary: str = Field(description='region summary of member nodes', default_factory=str)

    async def save(self, driver: GraphDriver):
        """Upsert this community; on Neptune, also index name/uuid/group_id into AOSS."""
        if driver.provider == GraphProvider.NEPTUNE:
            # Side effect: Neptune mirrors community metadata into an AOSS
            # 'communities' collection before writing the node itself.
            await driver.save_to_aoss(  # pyright: ignore reportAttributeAccessIssue
                'communities',
                [{'name': self.name, 'uuid': self.uuid, 'group_id': self.group_id}],
            )
        result = await driver.execute_query(
            get_community_node_save_query(driver.provider),  # type: ignore
            uuid=self.uuid,
            name=self.name,
            group_id=self.group_id,
            summary=self.summary,
            name_embedding=self.name_embedding,
            created_at=self.created_at,
        )

        logger.debug(f'Saved Node to Graph: {self.uuid}')

        return result

    async def generate_name_embedding(self, embedder: EmbedderClient):
        """Embed this community's name (newlines flattened); stores and returns the vector."""
        start = time()
        text = self.name.replace('\n', ' ')
        self.name_embedding = await embedder.create(input_data=[text])
        end = time()
        # NOTE(review): time() measures seconds, so the duration logged here is
        # in seconds, not ms.
        logger.debug(f'embedded {text} in {end - start} ms')

        return self.name_embedding

    async def load_name_embedding(self, driver: GraphDriver):
        """Load this community's name embedding from the graph into `self.name_embedding`.

        Raises:
            NodeNotFoundError: if no community with this uuid exists.
        """
        if driver.provider == GraphProvider.NEPTUNE:
            # Neptune stores the embedding as a comma-separated string; split
            # it and convert each element back to a float.
            query: LiteralString = """
                MATCH (c:Community {uuid: $uuid})
                RETURN [x IN split(c.name_embedding, ",") | toFloat(x)] as name_embedding
            """
        else:
            query: LiteralString = """
                MATCH (c:Community {uuid: $uuid})
                RETURN c.name_embedding AS name_embedding
            """

        records, _, _ = await driver.execute_query(
            query,
            uuid=self.uuid,
            routing_='r',
        )

        if len(records) == 0:
            raise NodeNotFoundError(self.uuid)

        self.name_embedding = records[0]['name_embedding']

    @classmethod
    async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
        """Fetch one community by uuid.

        Raises:
            NodeNotFoundError: if no community with the given uuid exists.
        """
        records, _, _ = await driver.execute_query(
            """
            MATCH (c:Community {uuid: $uuid})
            RETURN
            """
            + (
                COMMUNITY_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else COMMUNITY_NODE_RETURN
            ),
            uuid=uuid,
            routing_='r',
        )

        nodes = [get_community_node_from_record(record) for record in records]

        if len(nodes) == 0:
            raise NodeNotFoundError(uuid)

        return nodes[0]

    @classmethod
    async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
        """Fetch all communities whose uuid is in `uuids`; unknown uuids are silently skipped."""
        records, _, _ = await driver.execute_query(
            """
            MATCH (c:Community)
            WHERE c.uuid IN $uuids
            RETURN
            """
            + (
                COMMUNITY_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else COMMUNITY_NODE_RETURN
            ),
            uuids=uuids,
            routing_='r',
        )

        communities = [get_community_node_from_record(record) for record in records]

        return communities

    @classmethod
    async def get_by_group_ids(
        cls,
        driver: GraphDriver,
        group_ids: list[str],
        limit: int | None = None,
        uuid_cursor: str | None = None,
    ):
        """Fetch communities for the given group ids, ordered by uuid descending.

        Supports keyset pagination via `uuid_cursor`/`limit`.
        """
        cursor_query: LiteralString = 'AND c.uuid < $uuid' if uuid_cursor else ''
        limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''

        records, _, _ = await driver.execute_query(
            """
            MATCH (c:Community)
            WHERE c.group_id IN $group_ids
            """
            + cursor_query
            + """
            RETURN
            """
            + (
                COMMUNITY_NODE_RETURN_NEPTUNE
                if driver.provider == GraphProvider.NEPTUNE
                else COMMUNITY_NODE_RETURN
            )
            + """
            ORDER BY c.uuid DESC
            """
            + limit_query,
            group_ids=group_ids,
            uuid=uuid_cursor,
            limit=limit,
            routing_='r',
        )

        communities = [get_community_node_from_record(record) for record in records]

        return communities
729 |
730 |
731 | # Node helpers
# Node helpers
def get_episodic_node_from_record(record: Any) -> EpisodicNode:
    """Hydrate an EpisodicNode from a raw driver record.

    Raises ValueError when either timestamp is missing, since both are
    required on an episode.
    """
    when_created = parse_db_date(record['created_at'])
    when_valid = parse_db_date(record['valid_at'])

    if when_created is None:
        raise ValueError(f'created_at cannot be None for episode {record.get("uuid", "unknown")}')
    if when_valid is None:
        raise ValueError(f'valid_at cannot be None for episode {record.get("uuid", "unknown")}')

    return EpisodicNode(
        uuid=record['uuid'],
        name=record['name'],
        group_id=record['group_id'],
        source=EpisodeType.from_str(record['source']),
        source_description=record['source_description'],
        content=record['content'],
        entity_edges=record['entity_edges'],
        created_at=when_created,
        valid_at=when_valid,
    )
752 |
753 |
def get_entity_node_from_record(record: Any, provider: GraphProvider) -> EntityNode:
    """Hydrate an EntityNode from a raw driver record.

    Kuzu stores attributes as a JSON string; other providers return an
    attribute map that also carries the node's core fields, which are
    stripped here so `attributes` holds only user-defined properties.
    """
    if provider == GraphProvider.KUZU:
        attributes = json.loads(record['attributes']) if record['attributes'] else {}
    else:
        attributes = record['attributes']
        # Remove core node fields from the attribute map (single pass).
        for key in ('uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at', 'labels'):
            attributes.pop(key, None)

    labels = record.get('labels', [])
    group_id = record.get('group_id')
    # Strip the per-group label variant ("Entity_" + group id without dashes);
    # compute it once instead of twice. Guard against a missing group_id so we
    # don't call .replace on None.
    if group_id is not None:
        group_label = 'Entity_' + group_id.replace('-', '')
        if group_label in labels:
            labels.remove(group_label)

    entity_node = EntityNode(
        uuid=record['uuid'],
        name=record['name'],
        name_embedding=record.get('name_embedding'),
        group_id=group_id,
        labels=labels,
        created_at=parse_db_date(record['created_at']),  # type: ignore
        summary=record['summary'],
        attributes=attributes,
    )

    return entity_node
784 |
785 |
def get_community_node_from_record(record: Any) -> CommunityNode:
    """Hydrate a CommunityNode from a raw driver record."""
    fields = {
        'uuid': record['uuid'],
        'name': record['name'],
        'group_id': record['group_id'],
        'name_embedding': record['name_embedding'],
        'created_at': parse_db_date(record['created_at']),
        'summary': record['summary'],
    }
    return CommunityNode(**fields)  # type: ignore
795 |
796 |
797 | async def create_entity_node_embeddings(embedder: EmbedderClient, nodes: list[EntityNode]):
798 | # filter out falsey values from nodes
799 | filtered_nodes = [node for node in nodes if node.name]
800 |
801 | if not filtered_nodes:
802 | return
803 |
804 | name_embeddings = await embedder.create_batch([node.name for node in filtered_nodes])
805 | for node, name_embedding in zip(filtered_nodes, name_embeddings, strict=True):
806 | node.name_embedding = name_embedding
807 |
```