This is page 8 of 12. Use http://codebase.md/getzep/graphiti?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── ISSUE_TEMPLATE
│ │ └── bug_report.md
│ ├── pull_request_template.md
│ ├── secret_scanning.yml
│ └── workflows
│ ├── ai-moderator.yml
│ ├── cla.yml
│ ├── claude-code-review-manual.yml
│ ├── claude-code-review.yml
│ ├── claude.yml
│ ├── codeql.yml
│ ├── lint.yml
│ ├── release-graphiti-core.yml
│ ├── release-mcp-server.yml
│ ├── release-server-container.yml
│ ├── typecheck.yml
│ └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│ ├── azure-openai
│ │ ├── .env.example
│ │ ├── azure_openai_neo4j.py
│ │ └── README.md
│ ├── data
│ │ └── manybirds_products.json
│ ├── ecommerce
│ │ ├── runner.ipynb
│ │ └── runner.py
│ ├── langgraph-agent
│ │ ├── agent.ipynb
│ │ └── tinybirds-jess.png
│ ├── opentelemetry
│ │ ├── .env.example
│ │ ├── otel_stdout_example.py
│ │ ├── pyproject.toml
│ │ ├── README.md
│ │ └── uv.lock
│ ├── podcast
│ │ ├── podcast_runner.py
│ │ ├── podcast_transcript.txt
│ │ └── transcript_parser.py
│ ├── quickstart
│ │ ├── dense_vs_normal_ingestion.py
│ │ ├── quickstart_falkordb.py
│ │ ├── quickstart_neo4j.py
│ │ ├── quickstart_neptune.py
│ │ ├── README.md
│ │ └── requirements.txt
│ └── wizard_of_oz
│ ├── parser.py
│ ├── runner.py
│ └── woo.txt
├── graphiti_core
│ ├── __init__.py
│ ├── cross_encoder
│ │ ├── __init__.py
│ │ ├── bge_reranker_client.py
│ │ ├── client.py
│ │ ├── gemini_reranker_client.py
│ │ └── openai_reranker_client.py
│ ├── decorators.py
│ ├── driver
│ │ ├── __init__.py
│ │ ├── driver.py
│ │ ├── falkordb_driver.py
│ │ ├── graph_operations
│ │ │ └── graph_operations.py
│ │ ├── kuzu_driver.py
│ │ ├── neo4j_driver.py
│ │ ├── neptune_driver.py
│ │ └── search_interface
│ │ └── search_interface.py
│ ├── edges.py
│ ├── embedder
│ │ ├── __init__.py
│ │ ├── azure_openai.py
│ │ ├── client.py
│ │ ├── gemini.py
│ │ ├── openai.py
│ │ └── voyage.py
│ ├── errors.py
│ ├── graph_queries.py
│ ├── graphiti_types.py
│ ├── graphiti.py
│ ├── helpers.py
│ ├── llm_client
│ │ ├── __init__.py
│ │ ├── anthropic_client.py
│ │ ├── azure_openai_client.py
│ │ ├── client.py
│ │ ├── config.py
│ │ ├── errors.py
│ │ ├── gemini_client.py
│ │ ├── groq_client.py
│ │ ├── openai_base_client.py
│ │ ├── openai_client.py
│ │ ├── openai_generic_client.py
│ │ └── utils.py
│ ├── migrations
│ │ └── __init__.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── edges
│ │ │ ├── __init__.py
│ │ │ └── edge_db_queries.py
│ │ └── nodes
│ │ ├── __init__.py
│ │ └── node_db_queries.py
│ ├── nodes.py
│ ├── prompts
│ │ ├── __init__.py
│ │ ├── dedupe_edges.py
│ │ ├── dedupe_nodes.py
│ │ ├── eval.py
│ │ ├── extract_edge_dates.py
│ │ ├── extract_edges.py
│ │ ├── extract_nodes.py
│ │ ├── invalidate_edges.py
│ │ ├── lib.py
│ │ ├── models.py
│ │ ├── prompt_helpers.py
│ │ ├── snippets.py
│ │ └── summarize_nodes.py
│ ├── py.typed
│ ├── search
│ │ ├── __init__.py
│ │ ├── search_config_recipes.py
│ │ ├── search_config.py
│ │ ├── search_filters.py
│ │ ├── search_helpers.py
│ │ ├── search_utils.py
│ │ └── search.py
│ ├── telemetry
│ │ ├── __init__.py
│ │ └── telemetry.py
│ ├── tracer.py
│ └── utils
│ ├── __init__.py
│ ├── bulk_utils.py
│ ├── content_chunking.py
│ ├── datetime_utils.py
│ ├── maintenance
│ │ ├── __init__.py
│ │ ├── community_operations.py
│ │ ├── dedup_helpers.py
│ │ ├── edge_operations.py
│ │ ├── graph_data_operations.py
│ │ ├── node_operations.py
│ │ └── temporal_operations.py
│ ├── ontology_utils
│ │ └── entity_types_utils.py
│ └── text_utils.py
├── images
│ ├── arxiv-screenshot.png
│ ├── graphiti-graph-intro.gif
│ ├── graphiti-intro-slides-stock-2.gif
│ └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│ ├── .env.example
│ ├── .python-version
│ ├── config
│ │ ├── config-docker-falkordb-combined.yaml
│ │ ├── config-docker-falkordb.yaml
│ │ ├── config-docker-neo4j.yaml
│ │ ├── config.yaml
│ │ └── mcp_config_stdio_example.json
│ ├── docker
│ │ ├── build-standalone.sh
│ │ ├── build-with-version.sh
│ │ ├── docker-compose-falkordb.yml
│ │ ├── docker-compose-neo4j.yml
│ │ ├── docker-compose.yml
│ │ ├── Dockerfile
│ │ ├── Dockerfile.standalone
│ │ ├── github-actions-example.yml
│ │ ├── README-falkordb-combined.md
│ │ └── README.md
│ ├── docs
│ │ └── cursor_rules.md
│ ├── main.py
│ ├── pyproject.toml
│ ├── pytest.ini
│ ├── README.md
│ ├── src
│ │ ├── __init__.py
│ │ ├── config
│ │ │ ├── __init__.py
│ │ │ └── schema.py
│ │ ├── graphiti_mcp_server.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ ├── entity_types.py
│ │ │ └── response_types.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── factories.py
│ │ │ └── queue_service.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── formatting.py
│ │ └── utils.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── pytest.ini
│ │ ├── README.md
│ │ ├── run_tests.py
│ │ ├── test_async_operations.py
│ │ ├── test_comprehensive_integration.py
│ │ ├── test_configuration.py
│ │ ├── test_falkordb_integration.py
│ │ ├── test_fixtures.py
│ │ ├── test_http_integration.py
│ │ ├── test_integration.py
│ │ ├── test_mcp_integration.py
│ │ ├── test_mcp_transports.py
│ │ ├── test_stdio_simple.py
│ │ └── test_stress_load.py
│ └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│ ├── .env.example
│ ├── graph_service
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ ├── main.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ └── zep_graphiti.py
│ ├── Makefile
│ ├── pyproject.toml
│ ├── README.md
│ └── uv.lock
├── signatures
│ └── version1
│ └── cla.json
├── tests
│ ├── cross_encoder
│ │ ├── test_bge_reranker_client_int.py
│ │ └── test_gemini_reranker_client.py
│ ├── driver
│ │ ├── __init__.py
│ │ └── test_falkordb_driver.py
│ ├── embedder
│ │ ├── embedder_fixtures.py
│ │ ├── test_gemini.py
│ │ ├── test_openai.py
│ │ └── test_voyage.py
│ ├── evals
│ │ ├── data
│ │ │ └── longmemeval_data
│ │ │ ├── longmemeval_oracle.json
│ │ │ └── README.md
│ │ ├── eval_cli.py
│ │ ├── eval_e2e_graph_building.py
│ │ ├── pytest.ini
│ │ └── utils.py
│ ├── helpers_test.py
│ ├── llm_client
│ │ ├── test_anthropic_client_int.py
│ │ ├── test_anthropic_client.py
│ │ ├── test_azure_openai_client.py
│ │ ├── test_client.py
│ │ ├── test_errors.py
│ │ └── test_gemini_client.py
│ ├── test_edge_int.py
│ ├── test_entity_exclusion_int.py
│ ├── test_graphiti_int.py
│ ├── test_graphiti_mock.py
│ ├── test_node_int.py
│ ├── test_text_utils.py
│ └── utils
│ ├── maintenance
│ │ ├── test_bulk_utils.py
│ │ ├── test_edge_operations.py
│ │ ├── test_entity_extraction.py
│ │ ├── test_node_operations.py
│ │ └── test_temporal_operations_int.py
│ ├── search
│ │ └── search_utils_test.py
│ └── test_content_chunking.py
├── uv.lock
└── Zep-CLA.md
```
# Files
--------------------------------------------------------------------------------
/graphiti_core/utils/bulk_utils.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import json
18 | import logging
19 | import typing
20 | from datetime import datetime
21 |
22 | import numpy as np
23 | from pydantic import BaseModel, Field
24 | from typing_extensions import Any
25 |
26 | from graphiti_core.driver.driver import (
27 | GraphDriver,
28 | GraphDriverSession,
29 | GraphProvider,
30 | )
31 | from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
32 | from graphiti_core.embedder import EmbedderClient
33 | from graphiti_core.graphiti_types import GraphitiClients
34 | from graphiti_core.helpers import normalize_l2, semaphore_gather
35 | from graphiti_core.models.edges.edge_db_queries import (
36 | get_entity_edge_save_bulk_query,
37 | get_episodic_edge_save_bulk_query,
38 | )
39 | from graphiti_core.models.nodes.node_db_queries import (
40 | get_entity_node_save_bulk_query,
41 | get_episode_node_save_bulk_query,
42 | )
43 | from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
44 | from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
45 | from graphiti_core.utils.maintenance.dedup_helpers import (
46 | DedupResolutionState,
47 | _build_candidate_indexes,
48 | _normalize_string_exact,
49 | _resolve_with_similarity,
50 | )
51 | from graphiti_core.utils.maintenance.edge_operations import (
52 | extract_edges,
53 | resolve_extracted_edge,
54 | )
55 | from graphiti_core.utils.maintenance.graph_data_operations import (
56 | EPISODE_WINDOW_LEN,
57 | retrieve_episodes,
58 | )
59 | from graphiti_core.utils.maintenance.node_operations import (
60 | extract_nodes,
61 | resolve_extracted_nodes,
62 | )
63 |
64 | logger = logging.getLogger(__name__)
65 |
66 | CHUNK_SIZE = 10
67 |
68 |
69 | def _build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]:
70 | """Collapse alias -> canonical chains while preserving direction.
71 |
72 | The incoming pairs represent directed mappings discovered during node dedupe. We use a simple
73 | union-find with iterative path compression to ensure every source UUID resolves to its ultimate
74 | canonical target, even if aliases appear lexicographically smaller than the canonical UUID.
75 | """
76 |
77 | parent: dict[str, str] = {}
78 |
79 | def find(uuid: str) -> str:
80 | """Directed union-find lookup using iterative path compression."""
81 | parent.setdefault(uuid, uuid)
82 | root = uuid
83 | while parent[root] != root:
84 | root = parent[root]
85 |
86 | while parent[uuid] != root:
87 | next_uuid = parent[uuid]
88 | parent[uuid] = root
89 | uuid = next_uuid
90 |
91 | return root
92 |
93 | for source_uuid, target_uuid in pairs:
94 | parent.setdefault(source_uuid, source_uuid)
95 | parent.setdefault(target_uuid, target_uuid)
96 | parent[find(source_uuid)] = find(target_uuid)
97 |
98 | return {uuid: find(uuid) for uuid in parent}
99 |
100 |
101 | class RawEpisode(BaseModel):
102 | name: str
103 | uuid: str | None = Field(default=None)
104 | content: str
105 | source_description: str
106 | source: EpisodeType
107 | reference_time: datetime
108 |
109 |
110 | async def retrieve_previous_episodes_bulk(
111 | driver: GraphDriver, episodes: list[EpisodicNode]
112 | ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
113 | previous_episodes_list = await semaphore_gather(
114 | *[
115 | retrieve_episodes(
116 | driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id]
117 | )
118 | for episode in episodes
119 | ]
120 | )
121 | episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] = [
122 | (episode, previous_episodes_list[i]) for i, episode in enumerate(episodes)
123 | ]
124 |
125 | return episode_tuples
126 |
127 |
128 | async def add_nodes_and_edges_bulk(
129 | driver: GraphDriver,
130 | episodic_nodes: list[EpisodicNode],
131 | episodic_edges: list[EpisodicEdge],
132 | entity_nodes: list[EntityNode],
133 | entity_edges: list[EntityEdge],
134 | embedder: EmbedderClient,
135 | ):
136 | session = driver.session()
137 | try:
138 | await session.execute_write(
139 | add_nodes_and_edges_bulk_tx,
140 | episodic_nodes,
141 | episodic_edges,
142 | entity_nodes,
143 | entity_edges,
144 | embedder,
145 | driver=driver,
146 | )
147 | finally:
148 | await session.close()
149 |
150 |
151 | async def add_nodes_and_edges_bulk_tx(
152 | tx: GraphDriverSession,
153 | episodic_nodes: list[EpisodicNode],
154 | episodic_edges: list[EpisodicEdge],
155 | entity_nodes: list[EntityNode],
156 | entity_edges: list[EntityEdge],
157 | embedder: EmbedderClient,
158 | driver: GraphDriver,
159 | ):
160 | episodes = [dict(episode) for episode in episodic_nodes]
161 | for episode in episodes:
162 | episode['source'] = str(episode['source'].value)
163 | episode.pop('labels', None)
164 |
165 | nodes = []
166 |
167 | for node in entity_nodes:
168 | if node.name_embedding is None:
169 | await node.generate_name_embedding(embedder)
170 |
171 | entity_data: dict[str, Any] = {
172 | 'uuid': node.uuid,
173 | 'name': node.name,
174 | 'group_id': node.group_id,
175 | 'summary': node.summary,
176 | 'created_at': node.created_at,
177 | 'name_embedding': node.name_embedding,
178 | 'labels': list(set(node.labels + ['Entity'])),
179 | }
180 |
181 | if driver.provider == GraphProvider.KUZU:
182 | attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
183 | entity_data['attributes'] = json.dumps(attributes)
184 | else:
185 | entity_data.update(node.attributes or {})
186 |
187 | nodes.append(entity_data)
188 |
189 | edges = []
190 | for edge in entity_edges:
191 | if edge.fact_embedding is None:
192 | await edge.generate_embedding(embedder)
193 | edge_data: dict[str, Any] = {
194 | 'uuid': edge.uuid,
195 | 'source_node_uuid': edge.source_node_uuid,
196 | 'target_node_uuid': edge.target_node_uuid,
197 | 'name': edge.name,
198 | 'fact': edge.fact,
199 | 'group_id': edge.group_id,
200 | 'episodes': edge.episodes,
201 | 'created_at': edge.created_at,
202 | 'expired_at': edge.expired_at,
203 | 'valid_at': edge.valid_at,
204 | 'invalid_at': edge.invalid_at,
205 | 'fact_embedding': edge.fact_embedding,
206 | }
207 |
208 | if driver.provider == GraphProvider.KUZU:
209 | attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
210 | edge_data['attributes'] = json.dumps(attributes)
211 | else:
212 | edge_data.update(edge.attributes or {})
213 |
214 | edges.append(edge_data)
215 |
216 | if driver.graph_operations_interface:
217 | await driver.graph_operations_interface.episodic_node_save_bulk(None, driver, tx, episodes)
218 | await driver.graph_operations_interface.node_save_bulk(None, driver, tx, nodes)
219 | await driver.graph_operations_interface.episodic_edge_save_bulk(
220 | None, driver, tx, [edge.model_dump() for edge in episodic_edges]
221 | )
222 | await driver.graph_operations_interface.edge_save_bulk(None, driver, tx, edges)
223 |
224 | elif driver.provider == GraphProvider.KUZU:
225 | # FIXME: Kuzu's UNWIND does not currently support the STRUCT[] type properly, so for now we insert the data one row at a time.
226 | episode_query = get_episode_node_save_bulk_query(driver.provider)
227 | for episode in episodes:
228 | await tx.run(episode_query, **episode)
229 | entity_node_query = get_entity_node_save_bulk_query(driver.provider, nodes)
230 | for node in nodes:
231 | await tx.run(entity_node_query, **node)
232 | entity_edge_query = get_entity_edge_save_bulk_query(driver.provider)
233 | for edge in edges:
234 | await tx.run(entity_edge_query, **edge)
235 | episodic_edge_query = get_episodic_edge_save_bulk_query(driver.provider)
236 | for edge in episodic_edges:
237 | await tx.run(episodic_edge_query, **edge.model_dump())
238 | else:
239 | await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
240 | await tx.run(
241 | get_entity_node_save_bulk_query(driver.provider, nodes),
242 | nodes=nodes,
243 | )
244 | await tx.run(
245 | get_episodic_edge_save_bulk_query(driver.provider),
246 | episodic_edges=[edge.model_dump() for edge in episodic_edges],
247 | )
248 | await tx.run(
249 | get_entity_edge_save_bulk_query(driver.provider),
250 | entity_edges=edges,
251 | )
252 |
253 |
254 | async def extract_nodes_and_edges_bulk(
255 | clients: GraphitiClients,
256 | episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
257 | edge_type_map: dict[tuple[str, str], list[str]],
258 | entity_types: dict[str, type[BaseModel]] | None = None,
259 | excluded_entity_types: list[str] | None = None,
260 | edge_types: dict[str, type[BaseModel]] | None = None,
261 | custom_extraction_instructions: str | None = None,
262 | ) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]:
263 | extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather(
264 | *[
265 | extract_nodes(
266 | clients,
267 | episode,
268 | previous_episodes,
269 | entity_types=entity_types,
270 | excluded_entity_types=excluded_entity_types,
271 | custom_extraction_instructions=custom_extraction_instructions,
272 | )
273 | for episode, previous_episodes in episode_tuples
274 | ]
275 | )
276 |
277 | extracted_edges_bulk: list[list[EntityEdge]] = await semaphore_gather(
278 | *[
279 | extract_edges(
280 | clients,
281 | episode,
282 | extracted_nodes_bulk[i],
283 | previous_episodes,
284 | edge_type_map=edge_type_map,
285 | group_id=episode.group_id,
286 | edge_types=edge_types,
287 | custom_extraction_instructions=custom_extraction_instructions,
288 | )
289 | for i, (episode, previous_episodes) in enumerate(episode_tuples)
290 | ]
291 | )
292 |
293 | return extracted_nodes_bulk, extracted_edges_bulk
294 |
295 |
296 | async def dedupe_nodes_bulk(
297 | clients: GraphitiClients,
298 | extracted_nodes: list[list[EntityNode]],
299 | episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
300 | entity_types: dict[str, type[BaseModel]] | None = None,
301 | ) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
302 | """Resolve entity duplicates across an in-memory batch using a two-pass strategy.
303 |
304 | 1. Run :func:`resolve_extracted_nodes` for every episode in parallel so each batch item is
305 | reconciled against the live graph just like the non-batch flow.
306 | 2. Re-run the deterministic similarity heuristics across the union of resolved nodes to catch
307 | duplicates that only co-occur inside this batch, emitting a canonical UUID map that callers
308 | can apply to edges and persistence.
309 | """
310 |
311 | first_pass_results = await semaphore_gather(
312 | *[
313 | resolve_extracted_nodes(
314 | clients,
315 | nodes,
316 | episode_tuples[i][0],
317 | episode_tuples[i][1],
318 | entity_types,
319 | )
320 | for i, nodes in enumerate(extracted_nodes)
321 | ]
322 | )
323 |
324 | episode_resolutions: list[tuple[str, list[EntityNode]]] = []
325 | per_episode_uuid_maps: list[dict[str, str]] = []
326 | duplicate_pairs: list[tuple[str, str]] = []
327 |
328 | for (resolved_nodes, uuid_map, duplicates), (episode, _) in zip(
329 | first_pass_results, episode_tuples, strict=True
330 | ):
331 | episode_resolutions.append((episode.uuid, resolved_nodes))
332 | per_episode_uuid_maps.append(uuid_map)
333 | duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)
334 |
335 | canonical_nodes: dict[str, EntityNode] = {}
336 | for _, resolved_nodes in episode_resolutions:
337 | for node in resolved_nodes:
338 | # NOTE: this loop is O(n^2) in the number of nodes inside the batch because we rebuild
339 | # the MinHash index for the accumulated canonical pool each time. The LRU-backed
340 | # shingle cache keeps the constant factors low for typical batch sizes (≤ CHUNK_SIZE),
341 | # but if batches grow significantly we should switch to an incremental index or chunked
342 | # processing.
343 | if not canonical_nodes:
344 | canonical_nodes[node.uuid] = node
345 | continue
346 |
347 | existing_candidates = list(canonical_nodes.values())
348 | normalized = _normalize_string_exact(node.name)
349 | exact_match = next(
350 | (
351 | candidate
352 | for candidate in existing_candidates
353 | if _normalize_string_exact(candidate.name) == normalized
354 | ),
355 | None,
356 | )
357 | if exact_match is not None:
358 | if exact_match.uuid != node.uuid:
359 | duplicate_pairs.append((node.uuid, exact_match.uuid))
360 | continue
361 |
362 | indexes = _build_candidate_indexes(existing_candidates)
363 | state = DedupResolutionState(
364 | resolved_nodes=[None],
365 | uuid_map={},
366 | unresolved_indices=[],
367 | )
368 | _resolve_with_similarity([node], indexes, state)
369 |
370 | resolved = state.resolved_nodes[0]
371 | if resolved is None:
372 | canonical_nodes[node.uuid] = node
373 | continue
374 |
375 | canonical_uuid = resolved.uuid
376 | canonical_nodes.setdefault(canonical_uuid, resolved)
377 | if canonical_uuid != node.uuid:
378 | duplicate_pairs.append((node.uuid, canonical_uuid))
379 |
380 | union_pairs: list[tuple[str, str]] = []
381 | for uuid_map in per_episode_uuid_maps:
382 | union_pairs.extend(uuid_map.items())
383 | union_pairs.extend(duplicate_pairs)
384 |
385 | compressed_map: dict[str, str] = _build_directed_uuid_map(union_pairs)
386 |
387 | nodes_by_episode: dict[str, list[EntityNode]] = {}
388 | for episode_uuid, resolved_nodes in episode_resolutions:
389 | deduped_nodes: list[EntityNode] = []
390 | seen: set[str] = set()
391 | for node in resolved_nodes:
392 | canonical_uuid = compressed_map.get(node.uuid, node.uuid)
393 | if canonical_uuid in seen:
394 | continue
395 | seen.add(canonical_uuid)
396 | canonical_node = canonical_nodes.get(canonical_uuid)
397 | if canonical_node is None:
398 | logger.error(
399 | 'Canonical node %s missing during batch dedupe; falling back to %s',
400 | canonical_uuid,
401 | node.uuid,
402 | )
403 | canonical_node = node
404 | deduped_nodes.append(canonical_node)
405 |
406 | nodes_by_episode[episode_uuid] = deduped_nodes
407 |
408 | return nodes_by_episode, compressed_map
409 |
410 |
411 | async def dedupe_edges_bulk(
412 | clients: GraphitiClients,
413 | extracted_edges: list[list[EntityEdge]],
414 | episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
415 | _entities: list[EntityNode],
416 | edge_types: dict[str, type[BaseModel]],
417 | _edge_type_map: dict[tuple[str, str], list[str]],
418 | ) -> dict[str, list[EntityEdge]]:
419 | embedder = clients.embedder
420 | min_score = 0.6
421 |
422 | # generate embeddings
423 | await semaphore_gather(
424 | *[create_entity_edge_embeddings(embedder, edges) for edges in extracted_edges]
425 | )
426 |
427 | # Find similar results
428 | dedupe_tuples: list[tuple[EpisodicNode, EntityEdge, list[EntityEdge]]] = []
429 | for i, edges_i in enumerate(extracted_edges):
430 | existing_edges: list[EntityEdge] = []
431 | for edges_j in extracted_edges:
432 | existing_edges += edges_j
433 |
434 | for edge in edges_i:
435 | candidates: list[EntityEdge] = []
436 | for existing_edge in existing_edges:
437 | # Skip self-comparison
438 | if edge.uuid == existing_edge.uuid:
439 | continue
440 | # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
441 | # This approach will cast a wider net than BM25, which is ideal for this use case
442 | if (
443 | edge.source_node_uuid != existing_edge.source_node_uuid
444 | or edge.target_node_uuid != existing_edge.target_node_uuid
445 | ):
446 | continue
447 |
448 | edge_words = set(edge.fact.lower().split())
449 | existing_edge_words = set(existing_edge.fact.lower().split())
450 | has_overlap = not edge_words.isdisjoint(existing_edge_words)
451 | if has_overlap:
452 | candidates.append(existing_edge)
453 | continue
454 |
455 | # Check for semantic similarity even if there is no overlap
456 | similarity = np.dot(
457 | normalize_l2(edge.fact_embedding or []),
458 | normalize_l2(existing_edge.fact_embedding or []),
459 | )
460 | if similarity >= min_score:
461 | candidates.append(existing_edge)
462 |
463 | dedupe_tuples.append((episode_tuples[i][0], edge, candidates))
464 |
465 | bulk_edge_resolutions: list[
466 | tuple[EntityEdge, EntityEdge, list[EntityEdge]]
467 | ] = await semaphore_gather(
468 | *[
469 | resolve_extracted_edge(
470 | clients.llm_client,
471 | edge,
472 | candidates,
473 | candidates,
474 | episode,
475 | edge_types,
476 | set(edge_types),
477 | )
478 | for episode, edge, candidates in dedupe_tuples
479 | ]
480 | )
481 |
482 | # For now we won't track edge invalidation
483 | duplicate_pairs: list[tuple[str, str]] = []
484 | for i, (_, _, duplicates) in enumerate(bulk_edge_resolutions):
485 | episode, edge, candidates = dedupe_tuples[i]
486 | for duplicate in duplicates:
487 | duplicate_pairs.append((edge.uuid, duplicate.uuid))
488 |
489 | # Now we compress the duplicate map so that chains like 3 -> 2 and 2 -> 1 collapse to 3 -> 1 (each uuid maps to the lexicographically smallest uuid in its duplicate set)
490 | compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
491 |
492 | edge_uuid_map: dict[str, EntityEdge] = {
493 | edge.uuid: edge for edges in extracted_edges for edge in edges
494 | }
495 |
496 | edges_by_episode: dict[str, list[EntityEdge]] = {}
497 | for i, edges in enumerate(extracted_edges):
498 | episode = episode_tuples[i][0]
499 |
500 | edges_by_episode[episode.uuid] = [
501 | edge_uuid_map[compressed_map.get(edge.uuid, edge.uuid)] for edge in edges
502 | ]
503 |
504 | return edges_by_episode
505 |
506 |
507 | class UnionFind:
508 | def __init__(self, elements):
509 | # start each element in its own set
510 | self.parent = {e: e for e in elements}
511 |
512 | def find(self, x):
513 | # path compression
514 | if self.parent[x] != x:
515 | self.parent[x] = self.find(self.parent[x])
516 | return self.parent[x]
517 |
518 | def union(self, a, b):
519 | ra, rb = self.find(a), self.find(b)
520 | if ra == rb:
521 | return
522 | # attach the lexicographically larger root under the smaller
523 | if ra < rb:
524 | self.parent[rb] = ra
525 | else:
526 | self.parent[ra] = rb
527 |
528 |
529 | def compress_uuid_map(duplicate_pairs: list[tuple[str, str]]) -> dict[str, str]:
530 | """
531 | Compress duplicate chains with union-find over the ids appearing in the pairs.
532 | duplicate_pairs: iterable of (id1, id2) pairs
533 | returns: dict mapping each id -> lexicographically smallest id in its duplicate set
534 | """
535 | all_uuids = set()
536 | for pair in duplicate_pairs:
537 | all_uuids.add(pair[0])
538 | all_uuids.add(pair[1])
539 |
540 | uf = UnionFind(all_uuids)
541 | for a, b in duplicate_pairs:
542 | uf.union(a, b)
543 | # ensure full path compression before mapping
544 | return {uuid: uf.find(uuid) for uuid in all_uuids}
545 |
546 |
547 | E = typing.TypeVar('E', bound=Edge)
548 |
549 |
550 | def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
551 | for edge in edges:
552 | source_uuid = edge.source_node_uuid
553 | target_uuid = edge.target_node_uuid
554 | edge.source_node_uuid = uuid_map.get(source_uuid, source_uuid)
555 | edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)
556 |
557 | return edges
558 |
```
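A minimal usage sketch (not part of the repository) exercising the uuid-compression helpers defined in `bulk_utils.py` above; the uuids and embedding vectors are illustrative.

```python
import numpy as np

from graphiti_core.helpers import normalize_l2
from graphiti_core.utils.bulk_utils import _build_directed_uuid_map, compress_uuid_map

# Chained alias mappings discovered during node dedupe: c -> b and b -> a.
pairs = [('uuid-c', 'uuid-b'), ('uuid-b', 'uuid-a')]

# _build_directed_uuid_map resolves every alias to its ultimate canonical
# target while preserving the direction of each mapping.
directed = _build_directed_uuid_map(pairs)
assert directed == {'uuid-c': 'uuid-a', 'uuid-b': 'uuid-a', 'uuid-a': 'uuid-a'}

# compress_uuid_map instead canonicalizes each duplicate set to its
# lexicographically smallest uuid, so 3 -> 2 and 2 -> 1 collapse to 3 -> 1.
assert compress_uuid_map([('3', '2'), ('2', '1')]) == {'3': '1', '2': '1', '1': '1'}

# dedupe_edges_bulk treats two facts as dedupe candidates when their
# unit-normalized embeddings reach cosine similarity >= 0.6 (min_score).
similarity = np.dot(normalize_l2([1.0, 0.0]), normalize_l2([0.8, 0.6]))
assert similarity >= 0.6
```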
--------------------------------------------------------------------------------
/graphiti_core/edges.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import json
18 | import logging
19 | from abc import ABC, abstractmethod
20 | from datetime import datetime
21 | from time import time
22 | from typing import Any
23 | from uuid import uuid4
24 |
25 | from pydantic import BaseModel, Field
26 | from typing_extensions import LiteralString
27 |
28 | from graphiti_core.driver.driver import GraphDriver, GraphProvider
29 | from graphiti_core.embedder import EmbedderClient
30 | from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
31 | from graphiti_core.helpers import parse_db_date
32 | from graphiti_core.models.edges.edge_db_queries import (
33 | COMMUNITY_EDGE_RETURN,
34 | EPISODIC_EDGE_RETURN,
35 | EPISODIC_EDGE_SAVE,
36 | get_community_edge_save_query,
37 | get_entity_edge_return_query,
38 | get_entity_edge_save_query,
39 | )
40 | from graphiti_core.nodes import Node
41 |
42 | logger = logging.getLogger(__name__)
43 |
44 |
45 | class Edge(BaseModel, ABC):
46 | uuid: str = Field(default_factory=lambda: str(uuid4()))
47 | group_id: str = Field(description='partition of the graph')
48 | source_node_uuid: str
49 | target_node_uuid: str
50 | created_at: datetime
51 |
52 | @abstractmethod
53 | async def save(self, driver: GraphDriver): ...
54 |
55 | async def delete(self, driver: GraphDriver):
56 | if driver.graph_operations_interface:
57 | return await driver.graph_operations_interface.edge_delete(self, driver)
58 |
59 | if driver.provider == GraphProvider.KUZU:
60 | await driver.execute_query(
61 | """
62 | MATCH (n)-[e:MENTIONS|HAS_MEMBER {uuid: $uuid}]->(m)
63 | DELETE e
64 | """,
65 | uuid=self.uuid,
66 | )
67 | await driver.execute_query(
68 | """
69 | MATCH (e:RelatesToNode_ {uuid: $uuid})
70 | DETACH DELETE e
71 | """,
72 | uuid=self.uuid,
73 | )
74 | else:
75 | await driver.execute_query(
76 | """
77 | MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
78 | DELETE e
79 | """,
80 | uuid=self.uuid,
81 | )
82 |
83 | logger.debug(f'Deleted Edge: {self.uuid}')
84 |
85 | @classmethod
86 | async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
87 | if driver.graph_operations_interface:
88 | return await driver.graph_operations_interface.edge_delete_by_uuids(cls, driver, uuids)
89 |
90 | if driver.provider == GraphProvider.KUZU:
91 | await driver.execute_query(
92 | """
93 | MATCH (n)-[e:MENTIONS|HAS_MEMBER]->(m)
94 | WHERE e.uuid IN $uuids
95 | DELETE e
96 | """,
97 | uuids=uuids,
98 | )
99 | await driver.execute_query(
100 | """
101 | MATCH (e:RelatesToNode_)
102 | WHERE e.uuid IN $uuids
103 | DETACH DELETE e
104 | """,
105 | uuids=uuids,
106 | )
107 | else:
108 | await driver.execute_query(
109 | """
110 | MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
111 | WHERE e.uuid IN $uuids
112 | DELETE e
113 | """,
114 | uuids=uuids,
115 | )
116 |
117 | logger.debug(f'Deleted Edges: {uuids}')
118 |
119 | def __hash__(self):
120 | return hash(self.uuid)
121 |
122 | def __eq__(self, other):
123 | if isinstance(other, Edge):
124 | return self.uuid == other.uuid
125 | return False
126 |
127 | @classmethod
128 | async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
129 |
130 |
131 | class EpisodicEdge(Edge):
132 | async def save(self, driver: GraphDriver):
133 | result = await driver.execute_query(
134 | EPISODIC_EDGE_SAVE,
135 | episode_uuid=self.source_node_uuid,
136 | entity_uuid=self.target_node_uuid,
137 | uuid=self.uuid,
138 | group_id=self.group_id,
139 | created_at=self.created_at,
140 | )
141 |
142 | logger.debug(f'Saved edge to Graph: {self.uuid}')
143 |
144 | return result
145 |
146 | @classmethod
147 | async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
148 | records, _, _ = await driver.execute_query(
149 | """
150 | MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
151 | RETURN
152 | """
153 | + EPISODIC_EDGE_RETURN,
154 | uuid=uuid,
155 | routing_='r',
156 | )
157 |
158 | edges = [get_episodic_edge_from_record(record) for record in records]
159 |
160 | if len(edges) == 0:
161 | raise EdgeNotFoundError(uuid)
162 | return edges[0]
163 |
164 | @classmethod
165 | async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
166 | records, _, _ = await driver.execute_query(
167 | """
168 | MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
169 | WHERE e.uuid IN $uuids
170 | RETURN
171 | """
172 | + EPISODIC_EDGE_RETURN,
173 | uuids=uuids,
174 | routing_='r',
175 | )
176 |
177 | edges = [get_episodic_edge_from_record(record) for record in records]
178 |
179 | if len(edges) == 0:
180 | raise EdgeNotFoundError(uuids[0])
181 | return edges
182 |
183 | @classmethod
184 | async def get_by_group_ids(
185 | cls,
186 | driver: GraphDriver,
187 | group_ids: list[str],
188 | limit: int | None = None,
189 | uuid_cursor: str | None = None,
190 | ):
191 | cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
192 | limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
193 |
194 | records, _, _ = await driver.execute_query(
195 | """
196 | MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
197 | WHERE e.group_id IN $group_ids
198 | """
199 | + cursor_query
200 | + """
201 | RETURN
202 | """
203 | + EPISODIC_EDGE_RETURN
204 | + """
205 | ORDER BY e.uuid DESC
206 | """
207 | + limit_query,
208 | group_ids=group_ids,
209 | uuid=uuid_cursor,
210 | limit=limit,
211 | routing_='r',
212 | )
213 |
214 | edges = [get_episodic_edge_from_record(record) for record in records]
215 |
216 | if len(edges) == 0:
217 | raise GroupsEdgesNotFoundError(group_ids)
218 | return edges
219 |
220 |
221 | class EntityEdge(Edge):
222 | name: str = Field(description='name of the edge, relation name')
223 | fact: str = Field(description='fact representing the edge and nodes that it connects')
224 | fact_embedding: list[float] | None = Field(default=None, description='embedding of the fact')
225 | episodes: list[str] = Field(
226 | default=[],
227 | description='list of episode ids that reference this entity edge',
228 | )
229 | expired_at: datetime | None = Field(
230 | default=None, description='datetime of when the edge was invalidated'
231 | )
232 | valid_at: datetime | None = Field(
233 | default=None, description='datetime of when the fact became true'
234 | )
235 | invalid_at: datetime | None = Field(
236 | default=None, description='datetime of when the fact stopped being true'
237 | )
238 | attributes: dict[str, Any] = Field(
239 | default={}, description='Additional attributes of the edge. Dependent on edge name'
240 | )
241 |
242 | async def generate_embedding(self, embedder: EmbedderClient):
243 | start = time()
244 |
245 | text = self.fact.replace('\n', ' ')
246 | self.fact_embedding = await embedder.create(input_data=[text])
247 |
248 | end = time()
249 | logger.debug(f'embedded {text} in {(end - start) * 1000:.0f} ms')
250 |
251 | return self.fact_embedding
252 |
253 | async def load_fact_embedding(self, driver: GraphDriver):
254 | if driver.graph_operations_interface:
255 | return await driver.graph_operations_interface.edge_load_embeddings(self, driver)
256 |
257 | query = """
258 | MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
259 | RETURN e.fact_embedding AS fact_embedding
260 | """
261 |
262 | if driver.provider == GraphProvider.NEPTUNE:
263 | query = """
264 | MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
265 | RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
266 | """
267 |
268 | if driver.provider == GraphProvider.KUZU:
269 | query = """
270 | MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
271 | RETURN e.fact_embedding AS fact_embedding
272 | """
273 |
274 | records, _, _ = await driver.execute_query(
275 | query,
276 | uuid=self.uuid,
277 | routing_='r',
278 | )
279 |
280 | if len(records) == 0:
281 | raise EdgeNotFoundError(self.uuid)
282 |
283 | self.fact_embedding = records[0]['fact_embedding']
284 |
285 | async def save(self, driver: GraphDriver):
286 | edge_data: dict[str, Any] = {
287 | 'source_uuid': self.source_node_uuid,
288 | 'target_uuid': self.target_node_uuid,
289 | 'uuid': self.uuid,
290 | 'name': self.name,
291 | 'group_id': self.group_id,
292 | 'fact': self.fact,
293 | 'fact_embedding': self.fact_embedding,
294 | 'episodes': self.episodes,
295 | 'created_at': self.created_at,
296 | 'expired_at': self.expired_at,
297 | 'valid_at': self.valid_at,
298 | 'invalid_at': self.invalid_at,
299 | }
300 |
301 | if driver.provider == GraphProvider.KUZU:
302 | edge_data['attributes'] = json.dumps(self.attributes)
303 | result = await driver.execute_query(
304 | get_entity_edge_save_query(driver.provider),
305 | **edge_data,
306 | )
307 | else:
308 | edge_data.update(self.attributes or {})
309 | result = await driver.execute_query(
310 | get_entity_edge_save_query(driver.provider),
311 | edge_data=edge_data,
312 | )
313 |
314 | logger.debug(f'Saved edge to Graph: {self.uuid}')
315 |
316 | return result
317 |
318 | @classmethod
319 | async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
320 | match_query = """
321 | MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
322 | """
323 | if driver.provider == GraphProvider.KUZU:
324 | match_query = """
325 | MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
326 | """
327 |
328 | records, _, _ = await driver.execute_query(
329 | match_query
330 | + """
331 | RETURN
332 | """
333 | + get_entity_edge_return_query(driver.provider),
334 | uuid=uuid,
335 | routing_='r',
336 | )
337 |
338 | edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
339 |
340 | if len(edges) == 0:
341 | raise EdgeNotFoundError(uuid)
342 | return edges[0]
343 |
344 | @classmethod
345 | async def get_between_nodes(
346 | cls, driver: GraphDriver, source_node_uuid: str, target_node_uuid: str
347 | ):
348 | match_query = """
349 | MATCH (n:Entity {uuid: $source_node_uuid})-[e:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
350 | """
351 | if driver.provider == GraphProvider.KUZU:
352 | match_query = """
353 | MATCH (n:Entity {uuid: $source_node_uuid})
354 | -[:RELATES_TO]->(e:RelatesToNode_)
355 | -[:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
356 | """
357 |
358 | records, _, _ = await driver.execute_query(
359 | match_query
360 | + """
361 | RETURN
362 | """
363 | + get_entity_edge_return_query(driver.provider),
364 | source_node_uuid=source_node_uuid,
365 | target_node_uuid=target_node_uuid,
366 | routing_='r',
367 | )
368 |
369 | edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
370 |
371 | return edges
372 |
373 | @classmethod
374 | async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
375 | if len(uuids) == 0:
376 | return []
377 |
378 | match_query = """
379 | MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
380 | """
381 | if driver.provider == GraphProvider.KUZU:
382 | match_query = """
383 | MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
384 | """
385 |
386 | records, _, _ = await driver.execute_query(
387 | match_query
388 | + """
389 | WHERE e.uuid IN $uuids
390 | RETURN
391 | """
392 | + get_entity_edge_return_query(driver.provider),
393 | uuids=uuids,
394 | routing_='r',
395 | )
396 |
397 | edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
398 |
399 | return edges
400 |
401 | @classmethod
402 | async def get_by_group_ids(
403 | cls,
404 | driver: GraphDriver,
405 | group_ids: list[str],
406 | limit: int | None = None,
407 | uuid_cursor: str | None = None,
408 | with_embeddings: bool = False,
409 | ):
410 | cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
411 | limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
412 | with_embeddings_query: LiteralString = (
413 | """,
414 | e.fact_embedding AS fact_embedding
415 | """
416 | if with_embeddings
417 | else ''
418 | )
419 |
420 | match_query = """
421 | MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
422 | """
423 | if driver.provider == GraphProvider.KUZU:
424 | match_query = """
425 | MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
426 | """
427 |
428 | records, _, _ = await driver.execute_query(
429 | match_query
430 | + """
431 | WHERE e.group_id IN $group_ids
432 | """
433 | + cursor_query
434 | + """
435 | RETURN
436 | """
437 | + get_entity_edge_return_query(driver.provider)
438 | + with_embeddings_query
439 | + """
440 | ORDER BY e.uuid DESC
441 | """
442 | + limit_query,
443 | group_ids=group_ids,
444 | uuid=uuid_cursor,
445 | limit=limit,
446 | routing_='r',
447 | )
448 |
449 | edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
450 |
451 | if len(edges) == 0:
452 | raise GroupsEdgesNotFoundError(group_ids)
453 | return edges
454 |
455 | @classmethod
456 | async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
457 | match_query = """
458 | MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
459 | """
460 | if driver.provider == GraphProvider.KUZU:
461 | match_query = """
462 | MATCH (n:Entity {uuid: $node_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
463 | """
464 |
465 | records, _, _ = await driver.execute_query(
466 | match_query
467 | + """
468 | RETURN
469 | """
470 | + get_entity_edge_return_query(driver.provider),
471 | node_uuid=node_uuid,
472 | routing_='r',
473 | )
474 |
475 | edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
476 |
477 | return edges
478 |
479 |
480 | class CommunityEdge(Edge):
481 | async def save(self, driver: GraphDriver):
482 | result = await driver.execute_query(
483 | get_community_edge_save_query(driver.provider),
484 | community_uuid=self.source_node_uuid,
485 | entity_uuid=self.target_node_uuid,
486 | uuid=self.uuid,
487 | group_id=self.group_id,
488 | created_at=self.created_at,
489 | )
490 |
491 | logger.debug(f'Saved edge to Graph: {self.uuid}')
492 |
493 | return result
494 |
495 | @classmethod
496 | async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
497 | records, _, _ = await driver.execute_query(
498 | """
499 | MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m)
500 | RETURN
501 | """
502 | + COMMUNITY_EDGE_RETURN,
503 | uuid=uuid,
504 | routing_='r',
505 | )
506 |
507 | edges = [get_community_edge_from_record(record) for record in records]
508 |
509 | return edges[0]
510 |
511 | @classmethod
512 | async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
513 | records, _, _ = await driver.execute_query(
514 | """
515 | MATCH (n:Community)-[e:HAS_MEMBER]->(m)
516 | WHERE e.uuid IN $uuids
517 | RETURN
518 | """
519 | + COMMUNITY_EDGE_RETURN,
520 | uuids=uuids,
521 | routing_='r',
522 | )
523 |
524 | edges = [get_community_edge_from_record(record) for record in records]
525 |
526 | return edges
527 |
528 | @classmethod
529 | async def get_by_group_ids(
530 | cls,
531 | driver: GraphDriver,
532 | group_ids: list[str],
533 | limit: int | None = None,
534 | uuid_cursor: str | None = None,
535 | ):
536 | cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
537 | limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
538 |
539 | records, _, _ = await driver.execute_query(
540 | """
541 | MATCH (n:Community)-[e:HAS_MEMBER]->(m)
542 | WHERE e.group_id IN $group_ids
543 | """
544 | + cursor_query
545 | + """
546 | RETURN
547 | """
548 | + COMMUNITY_EDGE_RETURN
549 | + """
550 | ORDER BY e.uuid DESC
551 | """
552 | + limit_query,
553 | group_ids=group_ids,
554 | uuid=uuid_cursor,
555 | limit=limit,
556 | routing_='r',
557 | )
558 |
559 | edges = [get_community_edge_from_record(record) for record in records]
560 |
561 | return edges
562 |
563 |
564 | # Edge helpers
565 | def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
566 | return EpisodicEdge(
567 | uuid=record['uuid'],
568 | group_id=record['group_id'],
569 | source_node_uuid=record['source_node_uuid'],
570 | target_node_uuid=record['target_node_uuid'],
571 | created_at=parse_db_date(record['created_at']), # type: ignore
572 | )
573 |
574 |
575 | def get_entity_edge_from_record(record: Any, provider: GraphProvider) -> EntityEdge:
576 | episodes = record['episodes']
577 | if provider == GraphProvider.KUZU:
578 | attributes = json.loads(record['attributes']) if record['attributes'] else {}
579 | else:
580 | attributes = record['attributes']
581 | attributes.pop('uuid', None)
582 | attributes.pop('source_node_uuid', None)
583 | attributes.pop('target_node_uuid', None)
584 | attributes.pop('fact', None)
585 | attributes.pop('fact_embedding', None)
586 | attributes.pop('name', None)
587 | attributes.pop('group_id', None)
588 | attributes.pop('episodes', None)
589 | attributes.pop('created_at', None)
590 | attributes.pop('expired_at', None)
591 | attributes.pop('valid_at', None)
592 | attributes.pop('invalid_at', None)
593 |
594 | edge = EntityEdge(
595 | uuid=record['uuid'],
596 | source_node_uuid=record['source_node_uuid'],
597 | target_node_uuid=record['target_node_uuid'],
598 | fact=record['fact'],
599 | fact_embedding=record.get('fact_embedding'),
600 | name=record['name'],
601 | group_id=record['group_id'],
602 | episodes=episodes,
603 | created_at=parse_db_date(record['created_at']), # type: ignore
604 | expired_at=parse_db_date(record['expired_at']),
605 | valid_at=parse_db_date(record['valid_at']),
606 | invalid_at=parse_db_date(record['invalid_at']),
607 | attributes=attributes,
608 | )
609 |
610 | return edge
611 |
612 |
613 | def get_community_edge_from_record(record: Any):
614 | return CommunityEdge(
615 | uuid=record['uuid'],
616 | group_id=record['group_id'],
617 | source_node_uuid=record['source_node_uuid'],
618 | target_node_uuid=record['target_node_uuid'],
619 | created_at=parse_db_date(record['created_at']), # type: ignore
620 | )
621 |
622 |
623 | async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
624 | # filter out edges with falsy (empty) facts
625 | filtered_edges = [edge for edge in edges if edge.fact]
626 |
627 | if len(filtered_edges) == 0:
628 | return
629 | fact_embeddings = await embedder.create_batch([edge.fact for edge in filtered_edges])
630 | for edge, fact_embedding in zip(filtered_edges, fact_embeddings, strict=True):
631 | edge.fact_embedding = fact_embedding
632 |
```
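A small illustration (not from the repository) of the attribute-scrubbing step in `get_entity_edge_from_record` above: first-class edge fields are popped from the record's attribute dict so that only custom, edge-type-specific attributes remain. The record contents here are a hypothetical stand-in for a database row on a non-Kuzu provider.

```python
# Hypothetical row attributes as a non-Kuzu provider might return them.
record_attributes = {
    'uuid': 'edge-1',               # first-class field, stripped below
    'fact': 'Alice works at Acme',  # first-class field, stripped below
    'confidence': 0.9,              # custom attribute, survives the scrub
}

# The same reserved keys that get_entity_edge_from_record pops.
for reserved in (
    'uuid', 'source_node_uuid', 'target_node_uuid', 'fact', 'fact_embedding',
    'name', 'group_id', 'episodes', 'created_at', 'expired_at', 'valid_at',
    'invalid_at',
):
    record_attributes.pop(reserved, None)

assert record_attributes == {'confidence': 0.9}
```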
--------------------------------------------------------------------------------
/tests/utils/maintenance/test_node_operations.py:
--------------------------------------------------------------------------------
```python
1 | import logging
2 | from collections import defaultdict
3 | from unittest.mock import AsyncMock, MagicMock
4 |
5 | import pytest
6 |
7 | from graphiti_core.graphiti_types import GraphitiClients
8 | from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
9 | from graphiti_core.search.search_config import SearchResults
10 | from graphiti_core.utils.datetime_utils import utc_now
11 | from graphiti_core.utils.maintenance.dedup_helpers import (
12 | DedupCandidateIndexes,
13 | DedupResolutionState,
14 | _build_candidate_indexes,
15 | _cached_shingles,
16 | _has_high_entropy,
17 | _hash_shingle,
18 | _jaccard_similarity,
19 | _lsh_bands,
20 | _minhash_signature,
21 | _name_entropy,
22 | _normalize_name_for_fuzzy,
23 | _normalize_string_exact,
24 | _resolve_with_similarity,
25 | _shingles,
26 | )
27 | from graphiti_core.utils.maintenance.node_operations import (
28 | _collect_candidate_nodes,
29 | _resolve_with_llm,
30 | extract_attributes_from_node,
31 | extract_attributes_from_nodes,
32 | resolve_extracted_nodes,
33 | )
34 |
35 |
36 | def _make_clients():
37 | driver = MagicMock()
38 | embedder = MagicMock()
39 | cross_encoder = MagicMock()
40 | llm_client = MagicMock()
41 | llm_generate = AsyncMock()
42 | llm_client.generate_response = llm_generate
43 |
44 | clients = GraphitiClients.model_construct( # bypass validation to allow test doubles
45 | driver=driver,
46 | embedder=embedder,
47 | cross_encoder=cross_encoder,
48 | llm_client=llm_client,
49 | )
50 |
51 | return clients, llm_generate
52 |
53 |
54 | def _make_episode(group_id: str = 'group'):
55 | return EpisodicNode(
56 | name='episode',
57 | group_id=group_id,
58 | source=EpisodeType.message,
59 | source_description='test',
60 | content='content',
61 | valid_at=utc_now(),
62 | )
63 |
64 |
65 | @pytest.mark.asyncio
66 | async def test_resolve_nodes_exact_match_skips_llm(monkeypatch):
67 | clients, llm_generate = _make_clients()
68 |
69 | candidate = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])
70 | extracted = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])
71 |
72 | async def fake_search(*_, **__):
73 | return SearchResults(nodes=[candidate])
74 |
75 | monkeypatch.setattr(
76 | 'graphiti_core.utils.maintenance.node_operations.search',
77 | fake_search,
78 | )
79 | monkeypatch.setattr(
80 | 'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges',
81 | AsyncMock(return_value=[]),
82 | )
83 |
84 | resolved, uuid_map, _ = await resolve_extracted_nodes(
85 | clients,
86 | [extracted],
87 | episode=_make_episode(),
88 | previous_episodes=[],
89 | )
90 |
91 | assert resolved[0].uuid == candidate.uuid
92 | assert uuid_map[extracted.uuid] == candidate.uuid
93 | llm_generate.assert_not_awaited()
94 |
95 |
96 | @pytest.mark.asyncio
97 | async def test_resolve_nodes_low_entropy_uses_llm(monkeypatch):
98 | clients, llm_generate = _make_clients()
99 | llm_generate.return_value = {
100 | 'entity_resolutions': [
101 | {
102 | 'id': 0,
103 | 'duplicate_idx': -1,
104 | 'name': 'Joe',
105 | 'duplicates': [],
106 | }
107 | ]
108 | }
109 |
110 | extracted = EntityNode(name='Joe', group_id='group', labels=['Entity'])
111 |
112 | async def fake_search(*_, **__):
113 | return SearchResults(nodes=[])
114 |
115 | monkeypatch.setattr(
116 | 'graphiti_core.utils.maintenance.node_operations.search',
117 | fake_search,
118 | )
119 | monkeypatch.setattr(
120 | 'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges',
121 | AsyncMock(return_value=[]),
122 | )
123 |
124 | resolved, uuid_map, _ = await resolve_extracted_nodes(
125 | clients,
126 | [extracted],
127 | episode=_make_episode(),
128 | previous_episodes=[],
129 | )
130 |
131 | assert resolved[0].uuid == extracted.uuid
132 | assert uuid_map[extracted.uuid] == extracted.uuid
133 | llm_generate.assert_awaited()
134 |
135 |
136 | @pytest.mark.asyncio
137 | async def test_resolve_nodes_fuzzy_match(monkeypatch):
138 | clients, llm_generate = _make_clients()
139 |
140 | candidate = EntityNode(name='Joe-Michaels', group_id='group', labels=['Entity'])
141 | extracted = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])
142 |
143 | async def fake_search(*_, **__):
144 | return SearchResults(nodes=[candidate])
145 |
146 | monkeypatch.setattr(
147 | 'graphiti_core.utils.maintenance.node_operations.search',
148 | fake_search,
149 | )
150 | monkeypatch.setattr(
151 | 'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges',
152 | AsyncMock(return_value=[]),
153 | )
154 |
155 | resolved, uuid_map, _ = await resolve_extracted_nodes(
156 | clients,
157 | [extracted],
158 | episode=_make_episode(),
159 | previous_episodes=[],
160 | )
161 |
162 | assert resolved[0].uuid == candidate.uuid
163 | assert uuid_map[extracted.uuid] == candidate.uuid
164 | llm_generate.assert_not_awaited()
165 |
166 |
167 | @pytest.mark.asyncio
168 | async def test_collect_candidate_nodes_dedupes_and_merges_override(monkeypatch):
169 | clients, _ = _make_clients()
170 |
171 | candidate = EntityNode(name='Alice', group_id='group', labels=['Entity'])
172 | override_duplicate = EntityNode(
173 | uuid=candidate.uuid,
174 | name='Alice Alt',
175 | group_id='group',
176 | labels=['Entity'],
177 | )
178 | extracted = EntityNode(name='Alice', group_id='group', labels=['Entity'])
179 |
180 | search_mock = AsyncMock(return_value=SearchResults(nodes=[candidate]))
181 | monkeypatch.setattr(
182 | 'graphiti_core.utils.maintenance.node_operations.search',
183 | search_mock,
184 | )
185 |
186 | result = await _collect_candidate_nodes(
187 | clients,
188 | [extracted],
189 | existing_nodes_override=[override_duplicate],
190 | )
191 |
192 | assert len(result) == 1
193 | assert result[0].uuid == candidate.uuid
194 | search_mock.assert_awaited()
195 |
196 |
197 | def test_build_candidate_indexes_populates_structures():
198 | candidate = EntityNode(name='Bob Dylan', group_id='group', labels=['Entity'])
199 |
200 | indexes = _build_candidate_indexes([candidate])
201 |
202 | normalized_key = candidate.name.lower()
203 | assert indexes.normalized_existing[normalized_key][0].uuid == candidate.uuid
204 | assert indexes.nodes_by_uuid[candidate.uuid] is candidate
205 | assert candidate.uuid in indexes.shingles_by_candidate
206 | assert any(candidate.uuid in bucket for bucket in indexes.lsh_buckets.values())
207 |
208 |
209 | def test_normalize_helpers():
210 | assert _normalize_string_exact(' Alice Smith ') == 'alice smith'
211 | assert _normalize_name_for_fuzzy('Alice-Smith!') == 'alice smith'
212 |
213 |
214 | def test_name_entropy_variants():
215 | assert _name_entropy('alice') > _name_entropy('aaaaa')
216 | assert _name_entropy('') == 0.0
217 |
218 |
219 | def test_has_high_entropy_rules():
220 | assert _has_high_entropy('meaningful name') is True
221 | assert _has_high_entropy('aa') is False
222 |
223 |
224 | def test_shingles_and_cache():
225 | raw = 'alice'
226 | shingle_set = _shingles(raw)
227 | assert shingle_set == {'ali', 'lic', 'ice'}
228 | assert _cached_shingles(raw) == shingle_set
229 | assert _cached_shingles(raw) is _cached_shingles(raw)
230 |
231 |
232 | def test_hash_minhash_and_lsh():
233 | shingles = {'abc', 'bcd', 'cde'}
234 | signature = _minhash_signature(shingles)
235 | assert len(signature) == 32
236 | bands = _lsh_bands(signature)
237 | assert all(len(band) == 4 for band in bands)
238 | hashed = {_hash_shingle(s, 0) for s in shingles}
239 | assert len(hashed) == len(shingles)
240 |
241 |
242 | def test_jaccard_similarity_edges():
243 | a = {'a', 'b'}
244 | b = {'a', 'c'}
245 | assert _jaccard_similarity(a, b) == pytest.approx(1 / 3)
246 | assert _jaccard_similarity(set(), set()) == 1.0
247 | assert _jaccard_similarity(a, set()) == 0.0
248 |
249 |
250 | def test_resolve_with_similarity_exact_match_updates_state():
251 | candidate = EntityNode(name='Charlie Parker', group_id='group', labels=['Entity'])
252 | extracted = EntityNode(name='Charlie Parker', group_id='group', labels=['Entity'])
253 |
254 | indexes = _build_candidate_indexes([candidate])
255 | state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])
256 |
257 | _resolve_with_similarity([extracted], indexes, state)
258 |
259 | assert state.resolved_nodes[0].uuid == candidate.uuid
260 | assert state.uuid_map[extracted.uuid] == candidate.uuid
261 | assert state.unresolved_indices == []
262 | assert state.duplicate_pairs == [(extracted, candidate)]
263 |
264 |
265 | def test_resolve_with_similarity_low_entropy_defers_resolution():
266 | extracted = EntityNode(name='Bob', group_id='group', labels=['Entity'])
267 | indexes = DedupCandidateIndexes(
268 | existing_nodes=[],
269 | nodes_by_uuid={},
270 | normalized_existing=defaultdict(list),
271 | shingles_by_candidate={},
272 | lsh_buckets=defaultdict(list),
273 | )
274 | state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])
275 |
276 | _resolve_with_similarity([extracted], indexes, state)
277 |
278 | assert state.resolved_nodes[0] is None
279 | assert state.unresolved_indices == [0]
280 | assert state.duplicate_pairs == []
281 |
282 |
283 | def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():
284 | candidate1 = EntityNode(name='Johnny Appleseed', group_id='group', labels=['Entity'])
285 | candidate2 = EntityNode(name='Johnny Appleseed', group_id='group', labels=['Entity'])
286 | extracted = EntityNode(name='Johnny Appleseed', group_id='group', labels=['Entity'])
287 |
288 | indexes = _build_candidate_indexes([candidate1, candidate2])
289 | state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])
290 |
291 | _resolve_with_similarity([extracted], indexes, state)
292 |
293 | assert state.resolved_nodes[0] is None
294 | assert state.unresolved_indices == [0]
295 | assert state.duplicate_pairs == []
296 |
297 |
298 | @pytest.mark.asyncio
299 | async def test_resolve_with_llm_updates_unresolved(monkeypatch):
300 | extracted = EntityNode(name='Dizzy', group_id='group', labels=['Entity'])
301 | candidate = EntityNode(name='Dizzy Gillespie', group_id='group', labels=['Entity'])
302 |
303 | indexes = _build_candidate_indexes([candidate])
304 | state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
305 |
306 | captured_context = {}
307 |
308 | def fake_prompt_nodes(context):
309 | captured_context.update(context)
310 | return ['prompt']
311 |
312 | monkeypatch.setattr(
313 | 'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
314 | fake_prompt_nodes,
315 | )
316 |
317 | async def fake_generate_response(*_, **__):
318 | return {
319 | 'entity_resolutions': [
320 | {
321 | 'id': 0,
322 | 'duplicate_idx': 0,
323 | 'name': 'Dizzy Gillespie',
324 | 'duplicates': [0],
325 | }
326 | ]
327 | }
328 |
329 | llm_client = MagicMock()
330 | llm_client.generate_response = AsyncMock(side_effect=fake_generate_response)
331 |
332 | await _resolve_with_llm(
333 | llm_client,
334 | [extracted],
335 | indexes,
336 | state,
337 | episode=_make_episode(),
338 | previous_episodes=[],
339 | entity_types=None,
340 | )
341 |
342 | assert state.resolved_nodes[0].uuid == candidate.uuid
343 | assert state.uuid_map[extracted.uuid] == candidate.uuid
344 |     assert isinstance(captured_context['existing_nodes'], list)
345 |     assert captured_context['existing_nodes'][0]['idx'] == 0
346 | assert state.duplicate_pairs == [(extracted, candidate)]
347 |
348 |
349 | @pytest.mark.asyncio
350 | async def test_resolve_with_llm_ignores_out_of_range_relative_ids(monkeypatch, caplog):
351 | extracted = EntityNode(name='Dexter', group_id='group', labels=['Entity'])
352 |
353 | indexes = _build_candidate_indexes([])
354 | state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
355 |
356 | monkeypatch.setattr(
357 | 'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
358 | lambda context: ['prompt'],
359 | )
360 |
361 | llm_client = MagicMock()
362 | llm_client.generate_response = AsyncMock(
363 | return_value={
364 | 'entity_resolutions': [
365 | {
366 | 'id': 5,
367 | 'duplicate_idx': -1,
368 | 'name': 'Dexter',
369 | 'duplicates': [],
370 | }
371 | ]
372 | }
373 | )
374 |
375 | with caplog.at_level(logging.WARNING):
376 | await _resolve_with_llm(
377 | llm_client,
378 | [extracted],
379 | indexes,
380 | state,
381 | episode=_make_episode(),
382 | previous_episodes=[],
383 | entity_types=None,
384 | )
385 |
386 | assert state.resolved_nodes[0] is None
387 | assert 'Skipping invalid LLM dedupe id 5' in caplog.text
388 |
389 |
390 | @pytest.mark.asyncio
391 | async def test_resolve_with_llm_ignores_duplicate_relative_ids(monkeypatch):
392 | extracted = EntityNode(name='Dizzy', group_id='group', labels=['Entity'])
393 | candidate = EntityNode(name='Dizzy Gillespie', group_id='group', labels=['Entity'])
394 |
395 | indexes = _build_candidate_indexes([candidate])
396 | state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
397 |
398 | monkeypatch.setattr(
399 | 'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
400 | lambda context: ['prompt'],
401 | )
402 |
403 | llm_client = MagicMock()
404 | llm_client.generate_response = AsyncMock(
405 | return_value={
406 | 'entity_resolutions': [
407 | {
408 | 'id': 0,
409 | 'duplicate_idx': 0,
410 | 'name': 'Dizzy Gillespie',
411 | 'duplicates': [0],
412 | },
413 | {
414 | 'id': 0,
415 | 'duplicate_idx': -1,
416 | 'name': 'Dizzy',
417 | 'duplicates': [],
418 | },
419 | ]
420 | }
421 | )
422 |
423 | await _resolve_with_llm(
424 | llm_client,
425 | [extracted],
426 | indexes,
427 | state,
428 | episode=_make_episode(),
429 | previous_episodes=[],
430 | entity_types=None,
431 | )
432 |
433 | assert state.resolved_nodes[0].uuid == candidate.uuid
434 | assert state.uuid_map[extracted.uuid] == candidate.uuid
435 | assert state.duplicate_pairs == [(extracted, candidate)]
436 |
437 |
438 | @pytest.mark.asyncio
439 | async def test_resolve_with_llm_invalid_duplicate_idx_defaults_to_extracted(monkeypatch):
440 | extracted = EntityNode(name='Dexter', group_id='group', labels=['Entity'])
441 |
442 | indexes = _build_candidate_indexes([])
443 | state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
444 |
445 | monkeypatch.setattr(
446 | 'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
447 | lambda context: ['prompt'],
448 | )
449 |
450 | llm_client = MagicMock()
451 | llm_client.generate_response = AsyncMock(
452 | return_value={
453 | 'entity_resolutions': [
454 | {
455 | 'id': 0,
456 | 'duplicate_idx': 10,
457 | 'name': 'Dexter',
458 | 'duplicates': [],
459 | }
460 | ]
461 | }
462 | )
463 |
464 | await _resolve_with_llm(
465 | llm_client,
466 | [extracted],
467 | indexes,
468 | state,
469 | episode=_make_episode(),
470 | previous_episodes=[],
471 | entity_types=None,
472 | )
473 |
474 | assert state.resolved_nodes[0] == extracted
475 | assert state.uuid_map[extracted.uuid] == extracted.uuid
476 | assert state.duplicate_pairs == []
477 |
478 |
479 | @pytest.mark.asyncio
480 | async def test_extract_attributes_without_callback_generates_summary():
481 | """Test that summary is generated when no callback is provided (default behavior)."""
482 | llm_client = MagicMock()
483 | llm_client.generate_response = AsyncMock(
484 | return_value={'summary': 'Generated summary', 'attributes': {}}
485 | )
486 |
487 | node = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')
488 | episode = _make_episode()
489 |
490 | result = await extract_attributes_from_node(
491 | llm_client,
492 | node,
493 | episode=episode,
494 | previous_episodes=[],
495 | entity_type=None,
496 | should_summarize_node=None, # No callback provided
497 | )
498 |
499 | # Summary should be generated
500 | assert result.summary == 'Generated summary'
501 | # LLM should have been called for summary
502 | assert llm_client.generate_response.call_count == 1
503 |
504 |
505 | @pytest.mark.asyncio
506 | async def test_extract_attributes_with_callback_skip_summary():
507 | """Test that summary is NOT regenerated when callback returns False."""
508 | llm_client = MagicMock()
509 | llm_client.generate_response = AsyncMock(
510 | return_value={'summary': 'This should not be used', 'attributes': {}}
511 | )
512 |
513 | node = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')
514 | episode = _make_episode()
515 |
516 | # Callback that always returns False (skip summary generation)
517 | async def skip_summary_filter(node: EntityNode) -> bool:
518 | return False
519 |
520 | result = await extract_attributes_from_node(
521 | llm_client,
522 | node,
523 | episode=episode,
524 | previous_episodes=[],
525 | entity_type=None,
526 | should_summarize_node=skip_summary_filter,
527 | )
528 |
529 | # Summary should remain unchanged
530 | assert result.summary == 'Old summary'
531 | # LLM should NOT have been called for summary
532 | assert llm_client.generate_response.call_count == 0
533 |
534 |
535 | @pytest.mark.asyncio
536 | async def test_extract_attributes_with_callback_generate_summary():
537 | """Test that summary is regenerated when callback returns True."""
538 | llm_client = MagicMock()
539 | llm_client.generate_response = AsyncMock(
540 | return_value={'summary': 'New generated summary', 'attributes': {}}
541 | )
542 |
543 | node = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')
544 | episode = _make_episode()
545 |
546 | # Callback that always returns True (generate summary)
547 | async def generate_summary_filter(node: EntityNode) -> bool:
548 | return True
549 |
550 | result = await extract_attributes_from_node(
551 | llm_client,
552 | node,
553 | episode=episode,
554 | previous_episodes=[],
555 | entity_type=None,
556 | should_summarize_node=generate_summary_filter,
557 | )
558 |
559 | # Summary should be updated
560 | assert result.summary == 'New generated summary'
561 | # LLM should have been called for summary
562 | assert llm_client.generate_response.call_count == 1
563 |
564 |
565 | @pytest.mark.asyncio
566 | async def test_extract_attributes_with_selective_callback():
567 | """Test callback that selectively skips summaries based on node properties."""
568 | llm_client = MagicMock()
569 | llm_client.generate_response = AsyncMock(
570 | return_value={'summary': 'Generated summary', 'attributes': {}}
571 | )
572 |
573 | user_node = EntityNode(name='User', group_id='group', labels=['Entity', 'User'], summary='Old')
574 | topic_node = EntityNode(
575 | name='Topic', group_id='group', labels=['Entity', 'Topic'], summary='Old'
576 | )
577 |
578 | episode = _make_episode()
579 |
580 | # Callback that skips User nodes but generates for others
581 | async def selective_filter(node: EntityNode) -> bool:
582 | return 'User' not in node.labels
583 |
584 | result_user = await extract_attributes_from_node(
585 | llm_client,
586 | user_node,
587 | episode=episode,
588 | previous_episodes=[],
589 | entity_type=None,
590 | should_summarize_node=selective_filter,
591 | )
592 |
593 | result_topic = await extract_attributes_from_node(
594 | llm_client,
595 | topic_node,
596 | episode=episode,
597 | previous_episodes=[],
598 | entity_type=None,
599 | should_summarize_node=selective_filter,
600 | )
601 |
602 | # User summary should remain unchanged
603 | assert result_user.summary == 'Old'
604 | # Topic summary should be generated
605 | assert result_topic.summary == 'Generated summary'
606 | # LLM should have been called only once (for topic)
607 | assert llm_client.generate_response.call_count == 1
608 |
609 |
610 | @pytest.mark.asyncio
611 | async def test_extract_attributes_from_nodes_with_callback():
612 | """Test that callback is properly passed through extract_attributes_from_nodes."""
613 | clients, _ = _make_clients()
614 | clients.llm_client.generate_response = AsyncMock(
615 | return_value={'summary': 'New summary', 'attributes': {}}
616 | )
617 | clients.embedder.create = AsyncMock(return_value=[0.1, 0.2, 0.3])
618 | clients.embedder.create_batch = AsyncMock(return_value=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
619 |
620 | node1 = EntityNode(name='Node1', group_id='group', labels=['Entity', 'User'], summary='Old1')
621 | node2 = EntityNode(name='Node2', group_id='group', labels=['Entity', 'Topic'], summary='Old2')
622 |
623 | episode = _make_episode()
624 |
625 | call_tracker = []
626 |
627 | # Callback that tracks which nodes it's called with
628 | async def tracking_filter(node: EntityNode) -> bool:
629 | call_tracker.append(node.name)
630 | return 'User' not in node.labels
631 |
632 | results = await extract_attributes_from_nodes(
633 | clients,
634 | [node1, node2],
635 | episode=episode,
636 | previous_episodes=[],
637 | entity_types=None,
638 | should_summarize_node=tracking_filter,
639 | )
640 |
641 | # Callback should have been called for both nodes
642 | assert len(call_tracker) == 2
643 | assert 'Node1' in call_tracker
644 | assert 'Node2' in call_tracker
645 |
646 | # Node1 (User) should keep old summary, Node2 (Topic) should get new summary
647 | node1_result = next(n for n in results if n.name == 'Node1')
648 | node2_result = next(n for n in results if n.name == 'Node2')
649 |
650 | assert node1_result.summary == 'Old1'
651 | assert node2_result.summary == 'New summary'
652 |
```
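
For orientation, a minimal sketch of the similarity-resolution fast path that the tests above exercise. It uses only names that appear elsewhere on this page (`EntityNode`, `DedupResolutionState`, `_build_candidate_indexes`, `_resolve_with_similarity`); the leading underscores mark these helpers as internal, so treat this as an illustration rather than a supported public API.

```python
# Exact-name match resolves without an LLM call, mirroring
# test_resolve_with_similarity_exact_match_updates_state above.
from graphiti_core.nodes import EntityNode
from graphiti_core.utils.maintenance.dedup_helpers import (
    DedupResolutionState,
    _build_candidate_indexes,
    _resolve_with_similarity,
)

candidate = EntityNode(name='Charlie Parker', group_id='group', labels=['Entity'])
extracted = EntityNode(name='Charlie Parker', group_id='group', labels=['Entity'])

indexes = _build_candidate_indexes([candidate])
state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])
_resolve_with_similarity([extracted], indexes, state)

assert state.resolved_nodes[0].uuid == candidate.uuid  # resolved to the existing node
assert state.unresolved_indices == []  # nothing deferred to the LLM pass
```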
--------------------------------------------------------------------------------
/tests/llm_client/test_gemini_client.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | # Running tests: pytest -xvs tests/llm_client/test_gemini_client.py
18 |
19 | from unittest.mock import AsyncMock, MagicMock, patch
20 |
21 | import pytest
22 | from pydantic import BaseModel
23 |
24 | from graphiti_core.llm_client.config import LLMConfig, ModelSize
25 | from graphiti_core.llm_client.errors import RateLimitError
26 | from graphiti_core.llm_client.gemini_client import DEFAULT_MODEL, DEFAULT_SMALL_MODEL, GeminiClient
27 | from graphiti_core.prompts.models import Message
28 |
29 |
30 | # Pydantic schema used to exercise structured-output parsing in these tests
31 | class ResponseModel(BaseModel):
32 |     """Schema for structured-output response tests."""
33 |
34 | test_field: str
35 | optional_field: int = 0
36 |
37 |
38 | @pytest.fixture
39 | def mock_gemini_client():
40 | """Fixture to mock the Google Gemini client."""
41 | with patch('google.genai.Client') as mock_client:
42 | # Setup mock instance and its methods
43 | mock_instance = mock_client.return_value
44 | mock_instance.aio = MagicMock()
45 | mock_instance.aio.models = MagicMock()
46 | mock_instance.aio.models.generate_content = AsyncMock()
47 | yield mock_instance
48 |
49 |
50 | @pytest.fixture
51 | def gemini_client(mock_gemini_client):
52 | """Fixture to create a GeminiClient with a mocked client."""
53 | config = LLMConfig(api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000)
54 | client = GeminiClient(config=config, cache=False)
55 |     # Swap in the mocked google.genai client so no real API calls are made
56 | client.client = mock_gemini_client
57 | return client
58 |
59 |
60 | class TestGeminiClientInitialization:
61 | """Tests for GeminiClient initialization."""
62 |
63 | @patch('google.genai.Client')
64 | def test_init_with_config(self, mock_client):
65 | """Test initialization with a config object."""
66 | config = LLMConfig(
67 | api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
68 | )
69 | client = GeminiClient(config=config, cache=False, max_tokens=1000)
70 |
71 | assert client.config == config
72 | assert client.model == 'test-model'
73 | assert client.temperature == 0.5
74 | assert client.max_tokens == 1000
75 |
76 | @patch('google.genai.Client')
77 | def test_init_with_default_model(self, mock_client):
78 | """Test initialization with default model when none is provided."""
79 | config = LLMConfig(api_key='test_api_key', model=DEFAULT_MODEL)
80 | client = GeminiClient(config=config, cache=False)
81 |
82 | assert client.model == DEFAULT_MODEL
83 |
84 | @patch('google.genai.Client')
85 | def test_init_without_config(self, mock_client):
86 | """Test initialization without a config uses defaults."""
87 | client = GeminiClient(cache=False)
88 |
89 | assert client.config is not None
90 | # When no config.model is set, it will be None, not DEFAULT_MODEL
91 | assert client.model is None
92 |
93 | @patch('google.genai.Client')
94 | def test_init_with_thinking_config(self, mock_client):
95 | """Test initialization with thinking config."""
96 | with patch('google.genai.types.ThinkingConfig') as mock_thinking_config:
97 | thinking_config = mock_thinking_config.return_value
98 | client = GeminiClient(thinking_config=thinking_config)
99 | assert client.thinking_config == thinking_config
100 |
101 |
102 | class TestGeminiClientGenerateResponse:
103 | """Tests for GeminiClient generate_response method."""
104 |
105 | @pytest.mark.asyncio
106 | async def test_generate_response_simple_text(self, gemini_client, mock_gemini_client):
107 | """Test successful response generation with simple text."""
108 | # Setup mock response
109 | mock_response = MagicMock()
110 | mock_response.text = 'Test response text'
111 | mock_response.candidates = []
112 | mock_response.prompt_feedback = None
113 | mock_gemini_client.aio.models.generate_content.return_value = mock_response
114 |
115 | # Call method
116 | messages = [Message(role='user', content='Test message')]
117 | result = await gemini_client.generate_response(messages)
118 |
119 | # Assertions
120 | assert isinstance(result, dict)
121 | assert result['content'] == 'Test response text'
122 | mock_gemini_client.aio.models.generate_content.assert_called_once()
123 |
124 | @pytest.mark.asyncio
125 | async def test_generate_response_with_structured_output(
126 | self, gemini_client, mock_gemini_client
127 | ):
128 | """Test response generation with structured output."""
129 | # Setup mock response
130 | mock_response = MagicMock()
131 | mock_response.text = '{"test_field": "test_value", "optional_field": 42}'
132 | mock_response.candidates = []
133 | mock_response.prompt_feedback = None
134 | mock_gemini_client.aio.models.generate_content.return_value = mock_response
135 |
136 | # Call method
137 | messages = [
138 | Message(role='system', content='System message'),
139 | Message(role='user', content='User message'),
140 | ]
141 | result = await gemini_client.generate_response(
142 | messages=messages, response_model=ResponseModel
143 | )
144 |
145 | # Assertions
146 | assert isinstance(result, dict)
147 | assert result['test_field'] == 'test_value'
148 | assert result['optional_field'] == 42
149 | mock_gemini_client.aio.models.generate_content.assert_called_once()
150 |
151 | @pytest.mark.asyncio
152 | async def test_generate_response_with_system_message(self, gemini_client, mock_gemini_client):
153 | """Test response generation with system message handling."""
154 | # Setup mock response
155 | mock_response = MagicMock()
156 | mock_response.text = 'Response with system context'
157 | mock_response.candidates = []
158 | mock_response.prompt_feedback = None
159 | mock_gemini_client.aio.models.generate_content.return_value = mock_response
160 |
161 | # Call method
162 | messages = [
163 | Message(role='system', content='System message'),
164 | Message(role='user', content='User message'),
165 | ]
166 | await gemini_client.generate_response(messages)
167 |
168 | # Verify system message is processed correctly
169 | call_args = mock_gemini_client.aio.models.generate_content.call_args
170 | config = call_args[1]['config']
171 | assert 'System message' in config.system_instruction
172 |
173 | @pytest.mark.asyncio
174 | async def test_get_model_for_size(self, gemini_client):
175 | """Test model selection based on size."""
176 | # Test small model
177 | small_model = gemini_client._get_model_for_size(ModelSize.small)
178 | assert small_model == DEFAULT_SMALL_MODEL
179 |
180 | # Test medium/large model
181 | medium_model = gemini_client._get_model_for_size(ModelSize.medium)
182 | assert medium_model == gemini_client.model
183 |
184 | @pytest.mark.asyncio
185 | async def test_rate_limit_error_handling(self, gemini_client, mock_gemini_client):
186 | """Test handling of rate limit errors."""
187 | # Setup mock to raise rate limit error
188 | mock_gemini_client.aio.models.generate_content.side_effect = Exception(
189 | 'Rate limit exceeded'
190 | )
191 |
192 | # Call method and check exception
193 | messages = [Message(role='user', content='Test message')]
194 | with pytest.raises(RateLimitError):
195 | await gemini_client.generate_response(messages)
196 |
197 | @pytest.mark.asyncio
198 | async def test_quota_error_handling(self, gemini_client, mock_gemini_client):
199 | """Test handling of quota errors."""
200 | # Setup mock to raise quota error
201 | mock_gemini_client.aio.models.generate_content.side_effect = Exception(
202 | 'Quota exceeded for requests'
203 | )
204 |
205 | # Call method and check exception
206 | messages = [Message(role='user', content='Test message')]
207 | with pytest.raises(RateLimitError):
208 | await gemini_client.generate_response(messages)
209 |
210 | @pytest.mark.asyncio
211 | async def test_resource_exhausted_error_handling(self, gemini_client, mock_gemini_client):
212 | """Test handling of resource exhausted errors."""
213 | # Setup mock to raise resource exhausted error
214 | mock_gemini_client.aio.models.generate_content.side_effect = Exception(
215 | 'resource_exhausted: Request limit exceeded'
216 | )
217 |
218 | # Call method and check exception
219 | messages = [Message(role='user', content='Test message')]
220 | with pytest.raises(RateLimitError):
221 | await gemini_client.generate_response(messages)
222 |
223 | @pytest.mark.asyncio
224 | async def test_safety_block_handling(self, gemini_client, mock_gemini_client):
225 | """Test handling of safety blocks."""
226 | # Setup mock response with safety block
227 | mock_candidate = MagicMock()
228 | mock_candidate.finish_reason = 'SAFETY'
229 | mock_candidate.safety_ratings = [
230 | MagicMock(blocked=True, category='HARM_CATEGORY_HARASSMENT', probability='HIGH')
231 | ]
232 |
233 | mock_response = MagicMock()
234 | mock_response.candidates = [mock_candidate]
235 | mock_response.prompt_feedback = None
236 | mock_response.text = ''
237 | mock_gemini_client.aio.models.generate_content.return_value = mock_response
238 |
239 | # Call method and check exception
240 | messages = [Message(role='user', content='Test message')]
241 | with pytest.raises(Exception, match='Content blocked by safety filters'):
242 | await gemini_client.generate_response(messages)
243 |
244 | @pytest.mark.asyncio
245 | async def test_prompt_block_handling(self, gemini_client, mock_gemini_client):
246 | """Test handling of prompt blocks."""
247 | # Setup mock response with prompt block
248 | mock_prompt_feedback = MagicMock()
249 | mock_prompt_feedback.block_reason = 'BLOCKED_REASON_OTHER'
250 |
251 | mock_response = MagicMock()
252 | mock_response.candidates = []
253 | mock_response.prompt_feedback = mock_prompt_feedback
254 | mock_response.text = ''
255 | mock_gemini_client.aio.models.generate_content.return_value = mock_response
256 |
257 | # Call method and check exception
258 | messages = [Message(role='user', content='Test message')]
259 | with pytest.raises(Exception, match='Content blocked by safety filters'):
260 | await gemini_client.generate_response(messages)
261 |
262 | @pytest.mark.asyncio
263 | async def test_structured_output_parsing_error(self, gemini_client, mock_gemini_client):
264 | """Test handling of structured output parsing errors."""
265 | # Setup mock response with invalid JSON that will exhaust retries
266 | mock_response = MagicMock()
267 | mock_response.text = 'Invalid JSON that cannot be parsed'
268 | mock_response.candidates = []
269 | mock_response.prompt_feedback = None
270 | mock_gemini_client.aio.models.generate_content.return_value = mock_response
271 |
272 | # Call method and check exception - should exhaust retries
273 | messages = [Message(role='user', content='Test message')]
274 | with pytest.raises(Exception): # noqa: B017
275 | await gemini_client.generate_response(messages, response_model=ResponseModel)
276 |
277 | # Should have called generate_content MAX_RETRIES times (2 attempts total)
278 | assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
279 |
280 | @pytest.mark.asyncio
281 | async def test_retry_logic_with_safety_block(self, gemini_client, mock_gemini_client):
282 | """Test that safety blocks are not retried."""
283 | # Setup mock response with safety block
284 | mock_candidate = MagicMock()
285 | mock_candidate.finish_reason = 'SAFETY'
286 | mock_candidate.safety_ratings = [
287 | MagicMock(blocked=True, category='HARM_CATEGORY_HARASSMENT', probability='HIGH')
288 | ]
289 |
290 | mock_response = MagicMock()
291 | mock_response.candidates = [mock_candidate]
292 | mock_response.prompt_feedback = None
293 | mock_response.text = ''
294 | mock_gemini_client.aio.models.generate_content.return_value = mock_response
295 |
296 | # Call method and check that it doesn't retry
297 | messages = [Message(role='user', content='Test message')]
298 | with pytest.raises(Exception, match='Content blocked by safety filters'):
299 | await gemini_client.generate_response(messages)
300 |
301 | # Should only be called once (no retries for safety blocks)
302 | assert mock_gemini_client.aio.models.generate_content.call_count == 1
303 |
304 | @pytest.mark.asyncio
305 | async def test_retry_logic_with_validation_error(self, gemini_client, mock_gemini_client):
306 | """Test retry behavior on validation error."""
307 | # First call returns invalid JSON, second call returns valid data
308 | mock_response1 = MagicMock()
309 | mock_response1.text = 'Invalid JSON that cannot be parsed'
310 | mock_response1.candidates = []
311 | mock_response1.prompt_feedback = None
312 |
313 | mock_response2 = MagicMock()
314 | mock_response2.text = '{"test_field": "correct_value"}'
315 | mock_response2.candidates = []
316 | mock_response2.prompt_feedback = None
317 |
318 | mock_gemini_client.aio.models.generate_content.side_effect = [
319 | mock_response1,
320 | mock_response2,
321 | ]
322 |
323 | # Call method
324 | messages = [Message(role='user', content='Test message')]
325 | result = await gemini_client.generate_response(messages, response_model=ResponseModel)
326 |
327 | # Should have called generate_content twice due to retry
328 | assert mock_gemini_client.aio.models.generate_content.call_count == 2
329 | assert result['test_field'] == 'correct_value'
330 |
331 | @pytest.mark.asyncio
332 | async def test_max_retries_exceeded(self, gemini_client, mock_gemini_client):
333 | """Test behavior when max retries are exceeded."""
334 | # Setup mock to always return invalid JSON
335 | mock_response = MagicMock()
336 | mock_response.text = 'Invalid JSON that cannot be parsed'
337 | mock_response.candidates = []
338 | mock_response.prompt_feedback = None
339 | mock_gemini_client.aio.models.generate_content.return_value = mock_response
340 |
341 | # Call method and check exception
342 | messages = [Message(role='user', content='Test message')]
343 | with pytest.raises(Exception): # noqa: B017
344 | await gemini_client.generate_response(messages, response_model=ResponseModel)
345 |
346 | # Should have called generate_content MAX_RETRIES times (2 attempts total)
347 | assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
348 |
349 | @pytest.mark.asyncio
350 | async def test_empty_response_handling(self, gemini_client, mock_gemini_client):
351 | """Test handling of empty responses."""
352 | # Setup mock response with no text
353 | mock_response = MagicMock()
354 | mock_response.text = ''
355 | mock_response.candidates = []
356 | mock_response.prompt_feedback = None
357 | mock_gemini_client.aio.models.generate_content.return_value = mock_response
358 |
359 | # Call method with structured output and check exception
360 | messages = [Message(role='user', content='Test message')]
361 | with pytest.raises(Exception): # noqa: B017
362 | await gemini_client.generate_response(messages, response_model=ResponseModel)
363 |
364 | # Should have exhausted retries due to empty response (2 attempts total)
365 | assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
366 |
367 | @pytest.mark.asyncio
368 | async def test_custom_max_tokens(self, gemini_client, mock_gemini_client):
369 | """Test that explicit max_tokens parameter takes precedence over all other values."""
370 | # Setup mock response
371 | mock_response = MagicMock()
372 | mock_response.text = 'Test response'
373 | mock_response.candidates = []
374 | mock_response.prompt_feedback = None
375 | mock_gemini_client.aio.models.generate_content.return_value = mock_response
376 |
377 | # Call method with custom max tokens (should take precedence)
378 | messages = [Message(role='user', content='Test message')]
379 | await gemini_client.generate_response(messages, max_tokens=500)
380 |
381 | # Verify explicit max_tokens parameter takes precedence
382 | call_args = mock_gemini_client.aio.models.generate_content.call_args
383 | config = call_args[1]['config']
384 | # Explicit parameter should override everything else
385 | assert config.max_output_tokens == 500
386 |
387 | @pytest.mark.asyncio
388 | async def test_max_tokens_precedence_fallback(self, mock_gemini_client):
389 | """Test max_tokens precedence when no explicit parameter is provided."""
390 | # Setup mock response
391 | mock_response = MagicMock()
392 | mock_response.text = 'Test response'
393 | mock_response.candidates = []
394 | mock_response.prompt_feedback = None
395 | mock_gemini_client.aio.models.generate_content.return_value = mock_response
396 |
397 | # Test case 1: No explicit max_tokens, has instance max_tokens
398 | config = LLMConfig(
399 | api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
400 | )
401 | client = GeminiClient(
402 | config=config, cache=False, max_tokens=2000, client=mock_gemini_client
403 | )
404 |
405 | messages = [Message(role='user', content='Test message')]
406 | await client.generate_response(messages)
407 |
408 | call_args = mock_gemini_client.aio.models.generate_content.call_args
409 | config = call_args[1]['config']
410 | # Instance max_tokens should be used
411 | assert config.max_output_tokens == 2000
412 |
413 | # Test case 2: No explicit max_tokens, no instance max_tokens, uses model mapping
414 | config = LLMConfig(api_key='test_api_key', model='gemini-2.5-flash', temperature=0.5)
415 | client = GeminiClient(config=config, cache=False, client=mock_gemini_client)
416 |
417 | messages = [Message(role='user', content='Test message')]
418 | await client.generate_response(messages)
419 |
420 | call_args = mock_gemini_client.aio.models.generate_content.call_args
421 | config = call_args[1]['config']
422 | # Model mapping should be used
423 | assert config.max_output_tokens == 65536
424 |
425 | @pytest.mark.asyncio
426 | async def test_model_size_selection(self, gemini_client, mock_gemini_client):
427 | """Test that the correct model is selected based on model size."""
428 | # Setup mock response
429 | mock_response = MagicMock()
430 | mock_response.text = 'Test response'
431 | mock_response.candidates = []
432 | mock_response.prompt_feedback = None
433 | mock_gemini_client.aio.models.generate_content.return_value = mock_response
434 |
435 | # Call method with small model size
436 | messages = [Message(role='user', content='Test message')]
437 | await gemini_client.generate_response(messages, model_size=ModelSize.small)
438 |
439 | # Verify correct model is used
440 | call_args = mock_gemini_client.aio.models.generate_content.call_args
441 | assert call_args[1]['model'] == DEFAULT_SMALL_MODEL
442 |
443 | @pytest.mark.asyncio
444 | async def test_gemini_model_max_tokens_mapping(self, mock_gemini_client):
445 | """Test that different Gemini models use their correct max tokens."""
446 | # Setup mock response
447 | mock_response = MagicMock()
448 | mock_response.text = 'Test response'
449 | mock_response.candidates = []
450 | mock_response.prompt_feedback = None
451 | mock_gemini_client.aio.models.generate_content.return_value = mock_response
452 |
453 | # Test data: (model_name, expected_max_tokens)
454 | test_cases = [
455 | ('gemini-2.5-flash', 65536),
456 | ('gemini-2.5-pro', 65536),
457 | ('gemini-2.5-flash-lite', 64000),
458 | ('gemini-2.0-flash', 8192),
459 | ('gemini-1.5-pro', 8192),
460 | ('gemini-1.5-flash', 8192),
461 | ('unknown-model', 8192), # Fallback case
462 | ]
463 |
464 | for model_name, expected_max_tokens in test_cases:
465 | # Create client with specific model, no explicit max_tokens to test mapping
466 | config = LLMConfig(api_key='test_api_key', model=model_name, temperature=0.5)
467 | client = GeminiClient(config=config, cache=False, client=mock_gemini_client)
468 |
469 | # Call method without explicit max_tokens to test model mapping fallback
470 | messages = [Message(role='user', content='Test message')]
471 | await client.generate_response(messages)
472 |
473 | # Verify correct max tokens is used from model mapping
474 | call_args = mock_gemini_client.aio.models.generate_content.call_args
475 | config = call_args[1]['config']
476 | assert config.max_output_tokens == expected_max_tokens, (
477 | f'Model {model_name} should use {expected_max_tokens} tokens'
478 | )
479 |
480 |
481 | if __name__ == '__main__':
482 | pytest.main(['-v', 'test_gemini_client.py'])
483 |
```
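
The tests above pin down the `GeminiClient` call surface; the sketch below shows the same calls outside the mocks. It assumes the optional `google-genai` dependency is installed and that a valid Gemini API key replaces the placeholder; only constructors and arguments demonstrated in the tests are used.

```python
import asyncio

from pydantic import BaseModel

from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.gemini_client import GeminiClient
from graphiti_core.prompts.models import Message


class Answer(BaseModel):
    test_field: str


async def main() -> None:
    # LLMConfig and cache=False, as in the fixtures above
    client = GeminiClient(config=LLMConfig(api_key='YOUR_API_KEY'), cache=False)
    result = await client.generate_response(
        [Message(role='user', content='Return test_field set to "ok".')],
        response_model=Answer,  # structured output is parsed into a dict
    )
    print(result['test_field'])


asyncio.run(main())
```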
--------------------------------------------------------------------------------
/graphiti_core/utils/content_chunking.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import json
18 | import logging
19 | import re
20 |
21 | from graphiti_core.helpers import (
22 | CHUNK_DENSITY_THRESHOLD,
23 | CHUNK_MIN_TOKENS,
24 | CHUNK_OVERLAP_TOKENS,
25 | CHUNK_TOKEN_SIZE,
26 | )
27 | from graphiti_core.nodes import EpisodeType
28 |
29 | logger = logging.getLogger(__name__)
30 |
31 | # Approximate characters per token (conservative estimate)
32 | CHARS_PER_TOKEN = 4
33 |
34 |
35 | def estimate_tokens(text: str) -> int:
36 | """Estimate token count using character-based heuristic.
37 |
38 | Uses ~4 characters per token as a conservative estimate.
39 | This is faster than actual tokenization and works across all LLM providers.
40 |
41 | Args:
42 | text: The text to estimate tokens for
43 |
44 | Returns:
45 | Estimated token count
46 | """
47 | return len(text) // CHARS_PER_TOKEN
48 |
49 |
50 | def _tokens_to_chars(tokens: int) -> int:
51 | """Convert token count to approximate character count."""
52 | return tokens * CHARS_PER_TOKEN
53 |
54 |
55 | def should_chunk(content: str, episode_type: EpisodeType) -> bool:
56 | """Determine whether content should be chunked based on size and entity density.
57 |
58 | Only chunks content that is both:
59 | 1. Large enough to potentially cause LLM issues (>= CHUNK_MIN_TOKENS)
60 |     2. High in entity density (many entities per token)
61 |
62 | Short content processes fine regardless of density. This targets the specific
63 | failure case of large entity-dense inputs while preserving context for
64 | prose/narrative content and avoiding unnecessary chunking of small inputs.
65 |
66 | Args:
67 | content: The content to evaluate
68 | episode_type: Type of episode (json, message, text)
69 |
70 | Returns:
71 | True if content is large and has high entity density
72 | """
73 | tokens = estimate_tokens(content)
74 |
75 | # Short content always processes fine - no need to chunk
76 | if tokens < CHUNK_MIN_TOKENS:
77 | return False
78 |
79 | return _estimate_high_density(content, episode_type, tokens)
80 |
81 |
82 | def _estimate_high_density(content: str, episode_type: EpisodeType, tokens: int) -> bool:
83 | """Estimate whether content has high entity density.
84 |
85 | High-density content (many entities per token) benefits from chunking.
86 | Low-density content (prose, narratives) loses context when chunked.
87 |
88 | Args:
89 | content: The content to analyze
90 | episode_type: Type of episode
91 | tokens: Pre-computed token count
92 |
93 | Returns:
94 | True if content appears to have high entity density
95 | """
96 | if episode_type == EpisodeType.json:
97 | return _json_likely_dense(content, tokens)
98 | else:
99 | return _text_likely_dense(content, tokens)
100 |
101 |
102 | def _json_likely_dense(content: str, tokens: int) -> bool:
103 | """Estimate entity density for JSON content.
104 |
105 | JSON is considered dense if it has many array elements or object keys,
106 | as each typically represents a distinct entity or data point.
107 |
108 | Heuristics:
109 | - Array: Count elements, estimate entities per 1000 tokens
110 | - Object: Count top-level keys
111 |
112 | Args:
113 | content: JSON string content
114 | tokens: Token count
115 |
116 | Returns:
117 | True if JSON appears to have high entity density
118 | """
119 | try:
120 | data = json.loads(content)
121 | except json.JSONDecodeError:
122 | # Invalid JSON, fall back to text heuristics
123 | return _text_likely_dense(content, tokens)
124 |
125 | if isinstance(data, list):
126 | # For arrays, each element likely contains entities
127 | element_count = len(data)
128 | # Estimate density: elements per 1000 tokens
129 | density = (element_count / tokens) * 1000 if tokens > 0 else 0
130 |         return density > CHUNK_DENSITY_THRESHOLD * 1000  # scale per-token threshold to the per-1000-token density
131 | elif isinstance(data, dict):
132 |         # For objects, count keys recursively down to a shallow max depth
133 | key_count = _count_json_keys(data, max_depth=2)
134 | density = (key_count / tokens) * 1000 if tokens > 0 else 0
135 | return density > CHUNK_DENSITY_THRESHOLD * 1000
136 | else:
137 | # Scalar value, no need to chunk
138 | return False
139 |
140 |
141 | def _count_json_keys(data: dict, max_depth: int = 2, current_depth: int = 0) -> int:
142 | """Count keys in a JSON object up to a certain depth.
143 |
144 | Args:
145 | data: Dictionary to count keys in
146 | max_depth: Maximum depth to traverse
147 | current_depth: Current recursion depth
148 |
149 | Returns:
150 | Count of keys
151 | """
152 | if current_depth >= max_depth:
153 | return 0
154 |
155 | count = len(data)
156 | for value in data.values():
157 | if isinstance(value, dict):
158 | count += _count_json_keys(value, max_depth, current_depth + 1)
159 | elif isinstance(value, list):
160 | for item in value:
161 | if isinstance(item, dict):
162 | count += _count_json_keys(item, max_depth, current_depth + 1)
163 | return count
164 |
165 |
166 | def _text_likely_dense(content: str, tokens: int) -> bool:
167 | """Estimate entity density for text content.
168 |
169 | Uses capitalized words as a proxy for named entities (people, places,
170 | organizations, products). High ratio of capitalized words suggests
171 | high entity density.
172 |
173 | Args:
174 | content: Text content
175 | tokens: Token count
176 |
177 | Returns:
178 | True if text appears to have high entity density
179 | """
180 | if tokens == 0:
181 | return False
182 |
183 | # Split into words
184 | words = content.split()
185 | if not words:
186 | return False
187 |
188 | # Count capitalized words (excluding sentence starters)
189 | # A word is "capitalized" if it starts with uppercase and isn't all caps
190 | capitalized_count = 0
191 | for i, word in enumerate(words):
192 | # Skip if it's likely a sentence starter (after . ! ? or first word)
193 | if i == 0:
194 | continue
195 |         if words[i - 1].rstrip()[-1:] in '.!?':
196 | continue
197 |
198 | # Check if capitalized (first char upper, not all caps)
199 | cleaned = word.strip('.,!?;:\'"()[]{}')
200 | if cleaned and cleaned[0].isupper() and not cleaned.isupper():
201 | capitalized_count += 1
202 |
203 | # Calculate density: capitalized words per 1000 tokens
204 | density = (capitalized_count / tokens) * 1000 if tokens > 0 else 0
205 |
206 | # Text density threshold is typically lower than JSON
207 | # A well-written article might have 5-10% named entities
208 | return density > CHUNK_DENSITY_THRESHOLD * 500 # Half the JSON threshold
209 |
210 |
211 | def chunk_json_content(
212 | content: str,
213 | chunk_size_tokens: int | None = None,
214 | overlap_tokens: int | None = None,
215 | ) -> list[str]:
216 | """Split JSON content into chunks while preserving structure.
217 |
218 | For arrays: splits at element boundaries, keeping complete objects.
219 | For objects: splits at top-level key boundaries.
220 |
221 | Args:
222 | content: JSON string to chunk
223 | chunk_size_tokens: Target size per chunk in tokens (default from env)
224 | overlap_tokens: Overlap between chunks in tokens (default from env)
225 |
226 | Returns:
227 | List of JSON string chunks
228 | """
229 | chunk_size_tokens = chunk_size_tokens or CHUNK_TOKEN_SIZE
230 | overlap_tokens = overlap_tokens or CHUNK_OVERLAP_TOKENS
231 |
232 | chunk_size_chars = _tokens_to_chars(chunk_size_tokens)
233 | overlap_chars = _tokens_to_chars(overlap_tokens)
234 |
235 | try:
236 | data = json.loads(content)
237 | except json.JSONDecodeError:
238 | logger.warning('Failed to parse JSON, falling back to text chunking')
239 | return chunk_text_content(content, chunk_size_tokens, overlap_tokens)
240 |
241 | if isinstance(data, list):
242 | return _chunk_json_array(data, chunk_size_chars, overlap_chars)
243 | elif isinstance(data, dict):
244 | return _chunk_json_object(data, chunk_size_chars, overlap_chars)
245 | else:
246 | # Scalar value, return as-is
247 | return [content]
248 |
249 |
250 | def _chunk_json_array(
251 | data: list,
252 | chunk_size_chars: int,
253 | overlap_chars: int,
254 | ) -> list[str]:
255 | """Chunk a JSON array by splitting at element boundaries."""
256 | if not data:
257 | return ['[]']
258 |
259 | chunks: list[str] = []
260 | current_elements: list = []
261 | current_size = 2 # Account for '[]'
262 |
263 | for element in data:
264 | element_json = json.dumps(element)
265 | element_size = len(element_json) + 2 # Account for comma and space
266 |
267 | # Check if adding this element would exceed chunk size
268 | if current_elements and current_size + element_size > chunk_size_chars:
269 | # Save current chunk
270 | chunks.append(json.dumps(current_elements))
271 |
272 | # Start new chunk with overlap (include last few elements)
273 | overlap_elements = _get_overlap_elements(current_elements, overlap_chars)
274 | current_elements = overlap_elements
275 | current_size = len(json.dumps(current_elements)) if current_elements else 2
276 |
277 | current_elements.append(element)
278 | current_size += element_size
279 |
280 | # Don't forget the last chunk
281 | if current_elements:
282 | chunks.append(json.dumps(current_elements))
283 |
284 | return chunks if chunks else ['[]']
285 |
286 |
287 | def _get_overlap_elements(elements: list, overlap_chars: int) -> list:
288 | """Get elements from the end of a list that fit within overlap_chars."""
289 | if not elements:
290 | return []
291 |
292 | overlap_elements: list = []
293 | current_size = 2 # Account for '[]'
294 |
295 | for element in reversed(elements):
296 | element_json = json.dumps(element)
297 | element_size = len(element_json) + 2
298 |
299 | if current_size + element_size > overlap_chars:
300 | break
301 |
302 | overlap_elements.insert(0, element)
303 | current_size += element_size
304 |
305 | return overlap_elements
306 |
307 |
308 | def _chunk_json_object(
309 | data: dict,
310 | chunk_size_chars: int,
311 | overlap_chars: int,
312 | ) -> list[str]:
313 | """Chunk a JSON object by splitting at top-level key boundaries."""
314 | if not data:
315 | return ['{}']
316 |
317 | chunks: list[str] = []
318 | current_keys: list[str] = []
319 | current_dict: dict = {}
320 | current_size = 2 # Account for '{}'
321 |
322 | for key, value in data.items():
323 | entry_json = json.dumps({key: value})
324 | entry_size = len(entry_json)
325 |
326 | # Check if adding this entry would exceed chunk size
327 | if current_dict and current_size + entry_size > chunk_size_chars:
328 | # Save current chunk
329 | chunks.append(json.dumps(current_dict))
330 |
331 | # Start new chunk with overlap (include last few keys)
332 | overlap_dict = _get_overlap_dict(current_dict, current_keys, overlap_chars)
333 | current_dict = overlap_dict
334 | current_keys = list(overlap_dict.keys())
335 | current_size = len(json.dumps(current_dict)) if current_dict else 2
336 |
337 | current_dict[key] = value
338 | current_keys.append(key)
339 | current_size += entry_size
340 |
341 | # Don't forget the last chunk
342 | if current_dict:
343 | chunks.append(json.dumps(current_dict))
344 |
345 | return chunks if chunks else ['{}']
346 |
347 |
348 | def _get_overlap_dict(data: dict, keys: list[str], overlap_chars: int) -> dict:
349 | """Get key-value pairs from the end of a dict that fit within overlap_chars."""
350 | if not data or not keys:
351 | return {}
352 |
353 | overlap_dict: dict = {}
354 | current_size = 2 # Account for '{}'
355 |
356 | for key in reversed(keys):
357 | if key not in data:
358 | continue
359 | entry_json = json.dumps({key: data[key]})
360 | entry_size = len(entry_json)
361 |
362 | if current_size + entry_size > overlap_chars:
363 | break
364 |
365 | overlap_dict[key] = data[key]
366 | current_size += entry_size
367 |
368 | # Reverse to maintain original order
369 | return dict(reversed(list(overlap_dict.items())))
370 |
371 |
372 | def chunk_text_content(
373 | content: str,
374 | chunk_size_tokens: int | None = None,
375 | overlap_tokens: int | None = None,
376 | ) -> list[str]:
377 | """Split text content at natural boundaries (paragraphs, sentences).
378 |
379 | Includes overlap to capture entities at chunk boundaries.
380 |
381 | Args:
382 | content: Text to chunk
383 | chunk_size_tokens: Target size per chunk in tokens (default from env)
384 | overlap_tokens: Overlap between chunks in tokens (default from env)
385 |
386 | Returns:
387 | List of text chunks
388 | """
389 | chunk_size_tokens = chunk_size_tokens or CHUNK_TOKEN_SIZE
390 | overlap_tokens = overlap_tokens or CHUNK_OVERLAP_TOKENS
391 |
392 | chunk_size_chars = _tokens_to_chars(chunk_size_tokens)
393 | overlap_chars = _tokens_to_chars(overlap_tokens)
394 |
395 | if len(content) <= chunk_size_chars:
396 | return [content]
397 |
398 | # Split into paragraphs first
399 | paragraphs = re.split(r'\n\s*\n', content)
400 |
401 | chunks: list[str] = []
402 | current_chunk: list[str] = []
403 | current_size = 0
404 |
405 | for paragraph in paragraphs:
406 | paragraph = paragraph.strip()
407 | if not paragraph:
408 | continue
409 |
410 | para_size = len(paragraph)
411 |
412 | # If a single paragraph is too large, split it by sentences
413 | if para_size > chunk_size_chars:
414 | # First, save current chunk if any
415 | if current_chunk:
416 | chunks.append('\n\n'.join(current_chunk))
417 | current_chunk = []
418 | current_size = 0
419 |
420 | # Split large paragraph by sentences
421 | sentence_chunks = _chunk_by_sentences(paragraph, chunk_size_chars, overlap_chars)
422 | chunks.extend(sentence_chunks)
423 | continue
424 |
425 | # Check if adding this paragraph would exceed chunk size
426 | if current_chunk and current_size + para_size + 2 > chunk_size_chars:
427 | # Save current chunk
428 | chunks.append('\n\n'.join(current_chunk))
429 |
430 | # Start new chunk with overlap
431 | overlap_text = _get_overlap_text('\n\n'.join(current_chunk), overlap_chars)
432 | if overlap_text:
433 | current_chunk = [overlap_text]
434 | current_size = len(overlap_text)
435 | else:
436 | current_chunk = []
437 | current_size = 0
438 |
439 | current_chunk.append(paragraph)
440 | current_size += para_size + 2 # Account for '\n\n'
441 |
442 | # Don't forget the last chunk
443 | if current_chunk:
444 | chunks.append('\n\n'.join(current_chunk))
445 |
446 | return chunks if chunks else [content]
447 |
448 |
449 | def _chunk_by_sentences(
450 | text: str,
451 | chunk_size_chars: int,
452 | overlap_chars: int,
453 | ) -> list[str]:
454 | """Split text by sentence boundaries."""
455 | # Split on sentence-ending punctuation followed by whitespace
456 | sentence_pattern = r'(?<=[.!?])\s+'
457 | sentences = re.split(sentence_pattern, text)
458 |
459 | chunks: list[str] = []
460 | current_chunk: list[str] = []
461 | current_size = 0
462 |
463 | for sentence in sentences:
464 | sentence = sentence.strip()
465 | if not sentence:
466 | continue
467 |
468 | sent_size = len(sentence)
469 |
470 | # If a single sentence is too large, split it by fixed size
471 | if sent_size > chunk_size_chars:
472 | if current_chunk:
473 | chunks.append(' '.join(current_chunk))
474 | current_chunk = []
475 | current_size = 0
476 |
477 | # Split by fixed size as last resort
478 | fixed_chunks = _chunk_by_size(sentence, chunk_size_chars, overlap_chars)
479 | chunks.extend(fixed_chunks)
480 | continue
481 |
482 | # Check if adding this sentence would exceed chunk size
483 | if current_chunk and current_size + sent_size + 1 > chunk_size_chars:
484 | chunks.append(' '.join(current_chunk))
485 |
486 | # Start new chunk with overlap
487 | overlap_text = _get_overlap_text(' '.join(current_chunk), overlap_chars)
488 | if overlap_text:
489 | current_chunk = [overlap_text]
490 | current_size = len(overlap_text)
491 | else:
492 | current_chunk = []
493 | current_size = 0
494 |
495 | current_chunk.append(sentence)
496 | current_size += sent_size + 1
497 |
498 | if current_chunk:
499 | chunks.append(' '.join(current_chunk))
500 |
501 | return chunks
502 |
503 |
504 | def _chunk_by_size(
505 | text: str,
506 | chunk_size_chars: int,
507 | overlap_chars: int,
508 | ) -> list[str]:
509 | """Split text by fixed character size (last resort)."""
510 | chunks: list[str] = []
511 | start = 0
512 |
513 | while start < len(text):
514 | end = min(start + chunk_size_chars, len(text))
515 |
516 | # Try to break at word boundary
517 | if end < len(text):
518 | space_idx = text.rfind(' ', start, end)
519 | if space_idx > start:
520 | end = space_idx
521 |
522 | chunks.append(text[start:end].strip())
523 |
524 | # Move start forward, ensuring progress even if overlap >= chunk_size
525 | # Always advance by at least (chunk_size - overlap) or 1 char minimum
526 | min_progress = max(1, chunk_size_chars - overlap_chars)
527 | start = max(start + min_progress, end - overlap_chars)
528 |
529 | return chunks
530 |
531 |
532 | def _get_overlap_text(text: str, overlap_chars: int) -> str:
533 | """Get the last overlap_chars characters of text, breaking at word boundary."""
534 | if len(text) <= overlap_chars:
535 | return text
536 |
537 | overlap_start = len(text) - overlap_chars
538 | # Find the next word boundary after overlap_start
539 | space_idx = text.find(' ', overlap_start)
540 | if space_idx != -1:
541 | return text[space_idx + 1 :]
542 | return text[overlap_start:]
543 |
544 |
545 | def chunk_message_content(
546 | content: str,
547 | chunk_size_tokens: int | None = None,
548 | overlap_tokens: int | None = None,
549 | ) -> list[str]:
550 | """Split conversation content preserving message boundaries.
551 |
552 | Never splits mid-message. Messages are identified by patterns like:
553 | - "Speaker: message"
554 | - JSON message arrays
555 | - Newline-separated messages
556 |
557 | Args:
558 | content: Conversation content to chunk
559 | chunk_size_tokens: Target size per chunk in tokens (default from env)
560 | overlap_tokens: Overlap between chunks in tokens (default from env)
561 |
562 | Returns:
563 | List of conversation chunks
564 | """
565 | chunk_size_tokens = chunk_size_tokens or CHUNK_TOKEN_SIZE
566 | overlap_tokens = overlap_tokens or CHUNK_OVERLAP_TOKENS
567 |
568 | chunk_size_chars = _tokens_to_chars(chunk_size_tokens)
569 | overlap_chars = _tokens_to_chars(overlap_tokens)
570 |
571 | if len(content) <= chunk_size_chars:
572 | return [content]
573 |
574 | # Try to detect message format
575 | # Check if it's JSON (array of message objects)
576 | try:
577 | data = json.loads(content)
578 | if isinstance(data, list):
579 | return _chunk_message_array(data, chunk_size_chars, overlap_chars)
580 | except json.JSONDecodeError:
581 | pass
582 |
583 | # Try speaker pattern (e.g., "Alice: Hello")
584 | speaker_pattern = r'^([A-Za-z_][A-Za-z0-9_\s]*):(.+?)(?=^[A-Za-z_][A-Za-z0-9_\s]*:|$)'
585 | if re.search(speaker_pattern, content, re.MULTILINE | re.DOTALL):
586 | return _chunk_speaker_messages(content, chunk_size_chars, overlap_chars)
587 |
588 | # Fallback to line-based chunking
589 | return _chunk_by_lines(content, chunk_size_chars, overlap_chars)
590 |
591 |
592 | def _chunk_message_array(
593 | messages: list,
594 | chunk_size_chars: int,
595 | overlap_chars: int,
596 | ) -> list[str]:
597 | """Chunk a JSON array of message objects."""
598 | # Delegate to JSON array chunking
599 | chunks = _chunk_json_array(messages, chunk_size_chars, overlap_chars)
600 | return chunks
601 |
602 |
603 | def _chunk_speaker_messages(
604 | content: str,
605 | chunk_size_chars: int,
606 | overlap_chars: int,
607 | ) -> list[str]:
608 | """Chunk messages in 'Speaker: message' format."""
609 | # Split on speaker patterns
610 | pattern = r'(?=^[A-Za-z_][A-Za-z0-9_\s]*:)'
611 | messages = re.split(pattern, content, flags=re.MULTILINE)
612 | messages = [m.strip() for m in messages if m.strip()]
613 |
614 | if not messages:
615 | return [content]
616 |
617 | chunks: list[str] = []
618 | current_messages: list[str] = []
619 | current_size = 0
620 |
621 | for message in messages:
622 | msg_size = len(message)
623 |
624 | # If a single message is too large, include it as its own chunk
625 | if msg_size > chunk_size_chars:
626 | if current_messages:
627 | chunks.append('\n'.join(current_messages))
628 | current_messages = []
629 | current_size = 0
630 | chunks.append(message)
631 | continue
632 |
633 | if current_messages and current_size + msg_size + 1 > chunk_size_chars:
634 | chunks.append('\n'.join(current_messages))
635 |
636 | # Get overlap (last message(s) that fit)
637 | overlap_messages = _get_overlap_messages(current_messages, overlap_chars)
638 | current_messages = overlap_messages
639 |             current_size = sum(len(m) for m in current_messages) + len(current_messages) - 1  # n messages, n-1 newlines; -1 when empty cancels the +1 below
640 |
641 | current_messages.append(message)
642 | current_size += msg_size + 1
643 |
644 | if current_messages:
645 | chunks.append('\n'.join(current_messages))
646 |
647 | return chunks if chunks else [content]
648 |
649 |
650 | def _get_overlap_messages(messages: list[str], overlap_chars: int) -> list[str]:
651 | """Get messages from the end that fit within overlap_chars."""
652 | if not messages:
653 | return []
654 |
655 | overlap: list[str] = []
656 | current_size = 0
657 |
658 | for msg in reversed(messages):
659 | msg_size = len(msg) + 1
660 | if current_size + msg_size > overlap_chars:
661 | break
662 | overlap.insert(0, msg)
663 | current_size += msg_size
664 |
665 | return overlap
666 |
667 |
668 | def _chunk_by_lines(
669 | content: str,
670 | chunk_size_chars: int,
671 | overlap_chars: int,
672 | ) -> list[str]:
673 | """Chunk content by line boundaries."""
674 | lines = content.split('\n')
675 |
676 | chunks: list[str] = []
677 | current_lines: list[str] = []
678 | current_size = 0
679 |
680 | for line in lines:
681 | line_size = len(line) + 1
682 |
683 | if current_lines and current_size + line_size > chunk_size_chars:
684 | chunks.append('\n'.join(current_lines))
685 |
686 | # Get overlap lines
687 | overlap_text = '\n'.join(current_lines)
688 | overlap = _get_overlap_text(overlap_text, overlap_chars)
689 | if overlap:
690 | current_lines = overlap.split('\n')
691 | current_size = len(overlap)
692 | else:
693 | current_lines = []
694 | current_size = 0
695 |
696 | current_lines.append(line)
697 | current_size += line_size
698 |
699 | if current_lines:
700 | chunks.append('\n'.join(current_lines))
701 |
702 | return chunks if chunks else [content]
703 |
```
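
A short usage sketch of the module above: estimate size, decide whether to chunk, then chunk at natural boundaries. The one assumption beyond this file is that `EpisodeType.text` is the plain-text member of the enum (the docstrings above name json, message, and text).

```python
import json

from graphiti_core.nodes import EpisodeType
from graphiti_core.utils.content_chunking import (
    chunk_json_content,
    chunk_text_content,
    estimate_tokens,
    should_chunk,
)

content = 'Alice met Bob at Acme Corp in Paris. ' * 2000  # large, entity-dense text
print(estimate_tokens(content))  # roughly len(content) // 4

if should_chunk(content, EpisodeType.text):
    # Chunk size and overlap default to CHUNK_TOKEN_SIZE / CHUNK_OVERLAP_TOKENS
    chunks = chunk_text_content(content)
    print(f'{len(chunks)} text chunks')

records = [{'name': f'Item {i}', 'sku': i} for i in range(5000)]
json_chunks = chunk_json_content(json.dumps(records))  # splits at element boundaries
print(f'{len(json_chunks)} JSON chunks')
```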
--------------------------------------------------------------------------------
/graphiti_core/utils/maintenance/node_operations.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import logging
18 | from collections.abc import Awaitable, Callable
19 | from time import time
20 | from typing import Any
21 |
22 | from pydantic import BaseModel
23 |
24 | from graphiti_core.graphiti_types import GraphitiClients
25 | from graphiti_core.helpers import semaphore_gather
26 | from graphiti_core.llm_client import LLMClient
27 | from graphiti_core.llm_client.config import ModelSize
28 | from graphiti_core.nodes import (
29 | EntityNode,
30 | EpisodeType,
31 | EpisodicNode,
32 | create_entity_node_embeddings,
33 | )
34 | from graphiti_core.prompts import prompt_library
35 | from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
36 | from graphiti_core.prompts.extract_nodes import (
37 | EntitySummary,
38 | ExtractedEntities,
39 | ExtractedEntity,
40 | MissedEntities,
41 | )
42 | from graphiti_core.search.search import search
43 | from graphiti_core.search.search_config import SearchResults
44 | from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
45 | from graphiti_core.search.search_filters import SearchFilters
46 | from graphiti_core.utils.content_chunking import (
47 | chunk_json_content,
48 | chunk_message_content,
49 | chunk_text_content,
50 | should_chunk,
51 | )
52 | from graphiti_core.utils.datetime_utils import utc_now
53 | from graphiti_core.utils.maintenance.dedup_helpers import (
54 | DedupCandidateIndexes,
55 | DedupResolutionState,
56 | _build_candidate_indexes,
57 | _resolve_with_similarity,
58 | )
59 | from graphiti_core.utils.maintenance.edge_operations import (
60 | filter_existing_duplicate_of_edges,
61 | )
62 | from graphiti_core.utils.text_utils import MAX_SUMMARY_CHARS, truncate_at_sentence
63 |
64 | logger = logging.getLogger(__name__)
65 |
66 | NodeSummaryFilter = Callable[[EntityNode], Awaitable[bool]]
67 |
68 |
69 | async def extract_nodes_reflexion(
70 | llm_client: LLMClient,
71 | episode: EpisodicNode,
72 | previous_episodes: list[EpisodicNode],
73 | node_names: list[str],
74 | group_id: str | None = None,
75 | ) -> list[str]:
76 | # Prepare context for LLM
77 | context = {
78 | 'episode_content': episode.content,
79 | 'previous_episodes': [ep.content for ep in previous_episodes],
80 | 'extracted_entities': node_names,
81 | }
82 |
83 | llm_response = await llm_client.generate_response(
84 | prompt_library.extract_nodes.reflexion(context),
85 | MissedEntities,
86 | group_id=group_id,
87 | prompt_name='extract_nodes.reflexion',
88 | )
89 | missed_entities = llm_response.get('missed_entities', [])
90 |
91 | return missed_entities
92 |
93 |
94 | async def extract_nodes(
95 | clients: GraphitiClients,
96 | episode: EpisodicNode,
97 | previous_episodes: list[EpisodicNode],
98 | entity_types: dict[str, type[BaseModel]] | None = None,
99 | excluded_entity_types: list[str] | None = None,
100 | custom_extraction_instructions: str | None = None,
101 | ) -> list[EntityNode]:
102 | """Extract entity nodes from an episode with adaptive chunking.
103 |
104 | For high-density content (many entities per token), the content is chunked
105 | and processed in parallel to avoid LLM timeouts and truncation issues.
106 | """
107 | start = time()
108 | llm_client = clients.llm_client
109 |
110 | # Build entity types context
111 | entity_types_context = _build_entity_types_context(entity_types)
112 |
113 | # Build base context
114 | context = {
115 | 'episode_content': episode.content,
116 | 'episode_timestamp': episode.valid_at.isoformat(),
117 | 'previous_episodes': [ep.content for ep in previous_episodes],
118 | 'custom_extraction_instructions': custom_extraction_instructions or '',
119 | 'entity_types': entity_types_context,
120 | 'source_description': episode.source_description,
121 | }
122 |
123 | # Check if chunking is needed (based on entity density)
124 | if should_chunk(episode.content, episode.source):
125 | extracted_entities = await _extract_nodes_chunked(llm_client, episode, context)
126 | else:
127 | extracted_entities = await _extract_nodes_single(llm_client, episode, context)
128 |
129 | # Filter empty names
130 | filtered_entities = [e for e in extracted_entities if e.name.strip()]
131 |
132 | end = time()
133 | logger.debug(f'Extracted {len(filtered_entities)} entities in {(end - start) * 1000:.0f} ms')
134 |
135 | # Convert to EntityNode objects
136 | extracted_nodes = _create_entity_nodes(
137 | filtered_entities, entity_types_context, excluded_entity_types, episode
138 | )
139 |
140 | logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
141 | return extracted_nodes
142 |
143 |
144 | def _build_entity_types_context(
145 | entity_types: dict[str, type[BaseModel]] | None,
146 | ) -> list[dict]:
147 | """Build entity types context with ID mappings."""
148 | entity_types_context = [
149 | {
150 | 'entity_type_id': 0,
151 | 'entity_type_name': 'Entity',
152 | 'entity_type_description': (
153 | 'Default entity classification. Use this entity type '
154 | 'if the entity is not one of the other listed types.'
155 | ),
156 | }
157 | ]
158 |
159 | if entity_types is not None:
160 | entity_types_context += [
161 | {
162 | 'entity_type_id': i + 1,
163 | 'entity_type_name': type_name,
164 | 'entity_type_description': type_model.__doc__,
165 | }
166 | for i, (type_name, type_model) in enumerate(entity_types.items())
167 | ]
168 |
169 | return entity_types_context
170 |
171 |
172 | async def _extract_nodes_single(
173 | llm_client: LLMClient,
174 | episode: EpisodicNode,
175 | context: dict,
176 | ) -> list[ExtractedEntity]:
177 | """Extract entities using a single LLM call."""
178 | llm_response = await _call_extraction_llm(llm_client, episode, context)
179 | response_object = ExtractedEntities(**llm_response)
180 | return response_object.extracted_entities
181 |
182 |
183 | async def _extract_nodes_chunked(
184 | llm_client: LLMClient,
185 | episode: EpisodicNode,
186 | context: dict,
187 | ) -> list[ExtractedEntity]:
188 | """Extract entities from large content using chunking."""
189 | # Chunk the content based on episode type
190 | if episode.source == EpisodeType.json:
191 | chunks = chunk_json_content(episode.content)
192 | elif episode.source == EpisodeType.message:
193 | chunks = chunk_message_content(episode.content)
194 | else:
195 | chunks = chunk_text_content(episode.content)
196 |
197 | logger.debug(f'Chunked content into {len(chunks)} chunks for entity extraction')
198 |
199 | # Extract entities from each chunk in parallel
200 | chunk_results = await semaphore_gather(
201 | *[_extract_from_chunk(llm_client, chunk, context, episode) for chunk in chunks]
202 | )
203 |
204 | # Merge and deduplicate entities across chunks
205 | merged_entities = _merge_extracted_entities(chunk_results)
206 | logger.debug(
207 | f'Merged {sum(len(r) for r in chunk_results)} entities into {len(merged_entities)} unique'
208 | )
209 |
210 | return merged_entities
211 |
212 |
213 | async def _extract_from_chunk(
214 | llm_client: LLMClient,
215 | chunk: str,
216 | base_context: dict,
217 | episode: EpisodicNode,
218 | ) -> list[ExtractedEntity]:
219 | """Extract entities from a single chunk."""
220 | chunk_context = {**base_context, 'episode_content': chunk}
221 | llm_response = await _call_extraction_llm(llm_client, episode, chunk_context)
222 | return ExtractedEntities(**llm_response).extracted_entities
223 |
224 |
225 | async def _call_extraction_llm(
226 | llm_client: LLMClient,
227 | episode: EpisodicNode,
228 | context: dict,
229 | ) -> dict:
230 | """Call the appropriate extraction prompt based on episode type."""
231 | if episode.source == EpisodeType.message:
232 | prompt = prompt_library.extract_nodes.extract_message(context)
233 | prompt_name = 'extract_nodes.extract_message'
234 | elif episode.source == EpisodeType.text:
235 | prompt = prompt_library.extract_nodes.extract_text(context)
236 | prompt_name = 'extract_nodes.extract_text'
237 | elif episode.source == EpisodeType.json:
238 | prompt = prompt_library.extract_nodes.extract_json(context)
239 | prompt_name = 'extract_nodes.extract_json'
240 | else:
241 | # Fallback to text extraction
242 | prompt = prompt_library.extract_nodes.extract_text(context)
243 | prompt_name = 'extract_nodes.extract_text'
244 |
245 | return await llm_client.generate_response(
246 | prompt,
247 | response_model=ExtractedEntities,
248 | group_id=episode.group_id,
249 | prompt_name=prompt_name,
250 | )
251 |
252 |
253 | def _merge_extracted_entities(
254 | chunk_results: list[list[ExtractedEntity]],
255 | ) -> list[ExtractedEntity]:
256 | """Merge entities from multiple chunks, deduplicating by normalized name.
257 |
258 | When duplicates occur, prefer the first occurrence (maintains ordering).
259 | """
260 | seen_names: set[str] = set()
261 | merged: list[ExtractedEntity] = []
262 |
263 | for entities in chunk_results:
264 | for entity in entities:
265 | normalized = entity.name.strip().lower()
266 | if normalized and normalized not in seen_names:
267 | seen_names.add(normalized)
268 | merged.append(entity)
269 |
270 | return merged
271 |
272 |
273 | def _create_entity_nodes(
274 | extracted_entities: list[ExtractedEntity],
275 | entity_types_context: list[dict],
276 | excluded_entity_types: list[str] | None,
277 | episode: EpisodicNode,
278 | ) -> list[EntityNode]:
279 | """Convert ExtractedEntity objects to EntityNode objects."""
280 | extracted_nodes = []
281 |
282 | for extracted_entity in extracted_entities:
283 | type_id = extracted_entity.entity_type_id
284 | if 0 <= type_id < len(entity_types_context):
285 | entity_type_name = entity_types_context[type_id].get('entity_type_name')
286 | else:
287 | entity_type_name = 'Entity'
288 |
289 | # Check if this entity type should be excluded
290 | if excluded_entity_types and entity_type_name in excluded_entity_types:
291 | logger.debug(f'Excluding entity "{extracted_entity.name}" of type "{entity_type_name}"')
292 | continue
293 |
294 | labels: list[str] = list({'Entity', str(entity_type_name)})
295 |
296 | new_node = EntityNode(
297 | name=extracted_entity.name,
298 | group_id=episode.group_id,
299 | labels=labels,
300 | summary='',
301 | created_at=utc_now(),
302 | )
303 | extracted_nodes.append(new_node)
304 | logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
305 |
306 | return extracted_nodes
307 |
308 |
309 | async def _collect_candidate_nodes(
310 | clients: GraphitiClients,
311 | extracted_nodes: list[EntityNode],
312 | existing_nodes_override: list[EntityNode] | None,
313 | ) -> list[EntityNode]:
314 | """Search per extracted name and return unique candidates with overrides honored in order."""
315 | search_results: list[SearchResults] = await semaphore_gather(
316 | *[
317 | search(
318 | clients=clients,
319 | query=node.name,
320 | group_ids=[node.group_id],
321 | search_filter=SearchFilters(),
322 | config=NODE_HYBRID_SEARCH_RRF,
323 | )
324 | for node in extracted_nodes
325 | ]
326 | )
327 |
328 | candidate_nodes: list[EntityNode] = [node for result in search_results for node in result.nodes]
329 |
330 | if existing_nodes_override is not None:
331 | candidate_nodes.extend(existing_nodes_override)
332 |
333 | seen_candidate_uuids: set[str] = set()
334 | ordered_candidates: list[EntityNode] = []
335 | for candidate in candidate_nodes:
336 | if candidate.uuid in seen_candidate_uuids:
337 | continue
338 | seen_candidate_uuids.add(candidate.uuid)
339 | ordered_candidates.append(candidate)
340 |
341 | return ordered_candidates
342 |
343 |
344 | async def _resolve_with_llm(
345 | llm_client: LLMClient,
346 | extracted_nodes: list[EntityNode],
347 | indexes: DedupCandidateIndexes,
348 | state: DedupResolutionState,
349 | episode: EpisodicNode | None,
350 | previous_episodes: list[EpisodicNode] | None,
351 | entity_types: dict[str, type[BaseModel]] | None,
352 | ) -> None:
353 | """Escalate unresolved nodes to the dedupe prompt so the LLM can select or reject duplicates.
354 |
355 | The guardrails below defensively ignore malformed or duplicate LLM responses so the
356 | ingestion workflow remains deterministic even when the model misbehaves.
357 | """
358 | if not state.unresolved_indices:
359 | return
360 |
361 | entity_types_dict: dict[str, type[BaseModel]] = entity_types if entity_types is not None else {}
362 |
363 | llm_extracted_nodes = [extracted_nodes[i] for i in state.unresolved_indices]
364 |
365 | extracted_nodes_context = [
366 | {
367 | 'id': i,
368 | 'name': node.name,
369 | 'entity_type': node.labels,
370 |             'entity_type_description': (
371 |                 (m := entity_types_dict.get(next((item for item in node.labels if item != 'Entity'), '')))
372 |                 and m.__doc__  # None when the type is unregistered, avoiding NoneType.__doc__
373 |             ) or 'Default Entity Type',
374 | }
375 | for i, node in enumerate(llm_extracted_nodes)
376 | ]
377 |
378 | sent_ids = [ctx['id'] for ctx in extracted_nodes_context]
379 | logger.debug(
380 | 'Sending %d entities to LLM for deduplication with IDs 0-%d (actual IDs sent: %s)',
381 | len(llm_extracted_nodes),
382 | len(llm_extracted_nodes) - 1,
383 | sent_ids if len(sent_ids) < 20 else f'{sent_ids[:10]}...{sent_ids[-10:]}',
384 | )
385 | if llm_extracted_nodes:
386 | sample_size = min(3, len(extracted_nodes_context))
387 | logger.debug(
388 | 'First %d entities: %s',
389 | sample_size,
390 | [(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[:sample_size]],
391 | )
392 | if len(extracted_nodes_context) > 3:
393 | logger.debug(
394 | 'Last %d entities: %s',
395 | sample_size,
396 | [(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[-sample_size:]],
397 | )
398 |
399 | existing_nodes_context = [
400 | {
401 | **{
402 | 'idx': i,
403 | 'name': candidate.name,
404 | 'entity_types': candidate.labels,
405 | },
406 | **candidate.attributes,
407 | }
408 | for i, candidate in enumerate(indexes.existing_nodes)
409 | ]
410 |
411 | context = {
412 | 'extracted_nodes': extracted_nodes_context,
413 | 'existing_nodes': existing_nodes_context,
414 | 'episode_content': episode.content if episode is not None else '',
415 | 'previous_episodes': (
416 | [ep.content for ep in previous_episodes] if previous_episodes is not None else []
417 | ),
418 | }
419 |
420 | llm_response = await llm_client.generate_response(
421 | prompt_library.dedupe_nodes.nodes(context),
422 | response_model=NodeResolutions,
423 | prompt_name='dedupe_nodes.nodes',
424 | )
425 |
426 | node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions
427 |
428 | valid_relative_range = range(len(state.unresolved_indices))
429 | processed_relative_ids: set[int] = set()
430 |
431 | received_ids = {r.id for r in node_resolutions}
432 | expected_ids = set(valid_relative_range)
433 | missing_ids = expected_ids - received_ids
434 | extra_ids = received_ids - expected_ids
435 |
436 | logger.debug(
437 | 'Received %d resolutions for %d entities',
438 | len(node_resolutions),
439 | len(state.unresolved_indices),
440 | )
441 |
442 | if missing_ids:
443 | logger.warning('LLM did not return resolutions for IDs: %s', sorted(missing_ids))
444 |
445 | if extra_ids:
446 | logger.warning(
447 | 'LLM returned invalid IDs outside valid range 0-%d: %s (all returned IDs: %s)',
448 | len(state.unresolved_indices) - 1,
449 | sorted(extra_ids),
450 | sorted(received_ids),
451 | )
452 |
453 | for resolution in node_resolutions:
454 | relative_id: int = resolution.id
455 | duplicate_idx: int = resolution.duplicate_idx
456 |
457 | if relative_id not in valid_relative_range:
458 | logger.warning(
459 | 'Skipping invalid LLM dedupe id %d (valid range: 0-%d, received %d resolutions)',
460 | relative_id,
461 | len(state.unresolved_indices) - 1,
462 | len(node_resolutions),
463 | )
464 | continue
465 |
466 | if relative_id in processed_relative_ids:
467 | logger.warning('Duplicate LLM dedupe id %s received; ignoring.', relative_id)
468 | continue
469 | processed_relative_ids.add(relative_id)
470 |
471 | original_index = state.unresolved_indices[relative_id]
472 | extracted_node = extracted_nodes[original_index]
473 |
474 | resolved_node: EntityNode
475 | if duplicate_idx == -1:
476 | resolved_node = extracted_node
477 | elif 0 <= duplicate_idx < len(indexes.existing_nodes):
478 | resolved_node = indexes.existing_nodes[duplicate_idx]
479 | else:
480 | logger.warning(
481 | 'Invalid duplicate_idx %s for extracted node %s; treating as no duplicate.',
482 | duplicate_idx,
483 | extracted_node.uuid,
484 | )
485 | resolved_node = extracted_node
486 |
487 | state.resolved_nodes[original_index] = resolved_node
488 | state.uuid_map[extracted_node.uuid] = resolved_node.uuid
489 | if resolved_node.uuid != extracted_node.uuid:
490 | state.duplicate_pairs.append((extracted_node, resolved_node))
491 |
492 |
493 | async def resolve_extracted_nodes(
494 | clients: GraphitiClients,
495 | extracted_nodes: list[EntityNode],
496 | episode: EpisodicNode | None = None,
497 | previous_episodes: list[EpisodicNode] | None = None,
498 | entity_types: dict[str, type[BaseModel]] | None = None,
499 | existing_nodes_override: list[EntityNode] | None = None,
500 | ) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
501 | """Search for existing nodes, resolve deterministic matches, then escalate holdouts to the LLM dedupe prompt."""
502 | llm_client = clients.llm_client
503 | driver = clients.driver
504 | existing_nodes = await _collect_candidate_nodes(
505 | clients,
506 | extracted_nodes,
507 | existing_nodes_override,
508 | )
509 |
510 | indexes: DedupCandidateIndexes = _build_candidate_indexes(existing_nodes)
511 |
512 | state = DedupResolutionState(
513 | resolved_nodes=[None] * len(extracted_nodes),
514 | uuid_map={},
515 | unresolved_indices=[],
516 | )
517 |
518 | _resolve_with_similarity(extracted_nodes, indexes, state)
519 |
520 | await _resolve_with_llm(
521 | llm_client,
522 | extracted_nodes,
523 | indexes,
524 | state,
525 | episode,
526 | previous_episodes,
527 | entity_types,
528 | )
529 |
530 | for idx, node in enumerate(extracted_nodes):
531 | if state.resolved_nodes[idx] is None:
532 | state.resolved_nodes[idx] = node
533 | state.uuid_map[node.uuid] = node.uuid
534 |
535 | logger.debug(
536 | 'Resolved nodes: %s',
537 | [(node.name, node.uuid) for node in state.resolved_nodes if node is not None],
538 | )
539 |
540 | new_node_duplicates: list[
541 | tuple[EntityNode, EntityNode]
542 | ] = await filter_existing_duplicate_of_edges(driver, state.duplicate_pairs)
543 |
544 | return (
545 | [node for node in state.resolved_nodes if node is not None],
546 | state.uuid_map,
547 | new_node_duplicates,
548 | )
549 |
550 |
551 | async def extract_attributes_from_nodes(
552 | clients: GraphitiClients,
553 | nodes: list[EntityNode],
554 | episode: EpisodicNode | None = None,
555 | previous_episodes: list[EpisodicNode] | None = None,
556 | entity_types: dict[str, type[BaseModel]] | None = None,
557 | should_summarize_node: NodeSummaryFilter | None = None,
558 | ) -> list[EntityNode]:
559 | llm_client = clients.llm_client
560 | embedder = clients.embedder
561 | updated_nodes: list[EntityNode] = await semaphore_gather(
562 | *[
563 | extract_attributes_from_node(
564 | llm_client,
565 | node,
566 | episode,
567 | previous_episodes,
568 | (
569 | entity_types.get(next((item for item in node.labels if item != 'Entity'), ''))
570 | if entity_types is not None
571 | else None
572 | ),
573 | should_summarize_node,
574 | )
575 | for node in nodes
576 | ]
577 | )
578 |
579 | await create_entity_node_embeddings(embedder, updated_nodes)
580 |
581 | return updated_nodes
582 |
583 |
584 | async def extract_attributes_from_node(
585 | llm_client: LLMClient,
586 | node: EntityNode,
587 | episode: EpisodicNode | None = None,
588 | previous_episodes: list[EpisodicNode] | None = None,
589 | entity_type: type[BaseModel] | None = None,
590 | should_summarize_node: NodeSummaryFilter | None = None,
591 | ) -> EntityNode:
592 | # Extract attributes if entity type is defined and has attributes
593 | llm_response = await _extract_entity_attributes(
594 | llm_client, node, episode, previous_episodes, entity_type
595 | )
596 |
597 | # Extract summary if needed
598 | await _extract_entity_summary(
599 | llm_client, node, episode, previous_episodes, should_summarize_node
600 | )
601 |
602 | node.attributes.update(llm_response)
603 |
604 | return node
605 |
606 |
607 | async def _extract_entity_attributes(
608 | llm_client: LLMClient,
609 | node: EntityNode,
610 | episode: EpisodicNode | None,
611 | previous_episodes: list[EpisodicNode] | None,
612 | entity_type: type[BaseModel] | None,
613 | ) -> dict[str, Any]:
614 | if entity_type is None or len(entity_type.model_fields) == 0:
615 | return {}
616 |
617 | attributes_context = _build_episode_context(
618 | # should not include summary
619 | node_data={
620 | 'name': node.name,
621 | 'entity_types': node.labels,
622 | 'attributes': node.attributes,
623 | },
624 | episode=episode,
625 | previous_episodes=previous_episodes,
626 | )
627 |
628 | llm_response = await llm_client.generate_response(
629 | prompt_library.extract_nodes.extract_attributes(attributes_context),
630 | response_model=entity_type,
631 | model_size=ModelSize.small,
632 | group_id=node.group_id,
633 | prompt_name='extract_nodes.extract_attributes',
634 | )
635 |
636 | # validate response
637 | entity_type(**llm_response)
638 |
639 | return llm_response
640 |
641 |
642 | async def _extract_entity_summary(
643 | llm_client: LLMClient,
644 | node: EntityNode,
645 | episode: EpisodicNode | None,
646 | previous_episodes: list[EpisodicNode] | None,
647 | should_summarize_node: NodeSummaryFilter | None,
648 | ) -> None:
649 | if should_summarize_node is not None and not await should_summarize_node(node):
650 | return
651 |
652 | summary_context = _build_episode_context(
653 | node_data={
654 | 'name': node.name,
655 | 'summary': truncate_at_sentence(node.summary, MAX_SUMMARY_CHARS),
656 | 'entity_types': node.labels,
657 | 'attributes': node.attributes,
658 | },
659 | episode=episode,
660 | previous_episodes=previous_episodes,
661 | )
662 |
663 | summary_response = await llm_client.generate_response(
664 | prompt_library.extract_nodes.extract_summary(summary_context),
665 | response_model=EntitySummary,
666 | model_size=ModelSize.small,
667 | group_id=node.group_id,
668 | prompt_name='extract_nodes.extract_summary',
669 | )
670 |
671 | node.summary = truncate_at_sentence(summary_response.get('summary', ''), MAX_SUMMARY_CHARS)
672 |
673 |
674 | def _build_episode_context(
675 | node_data: dict[str, Any],
676 | episode: EpisodicNode | None,
677 | previous_episodes: list[EpisodicNode] | None,
678 | ) -> dict[str, Any]:
679 | return {
680 | 'node': node_data,
681 | 'episode_content': episode.content if episode is not None else '',
682 | 'previous_episodes': (
683 | [ep.content for ep in previous_episodes] if previous_episodes is not None else []
684 | ),
685 | }
686 |
```
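The merge step above is what makes chunked extraction safe to parallelize: `_merge_extracted_entities` deduplicates across chunks by `name.strip().lower()`, keeping the first occurrence so chunk order decides which instance survives. A small sketch, not part of the repository, assuming `ExtractedEntity` takes only the `name` and `entity_type_id` fields used in this file:

```python
from graphiti_core.prompts.extract_nodes import ExtractedEntity
from graphiti_core.utils.maintenance.node_operations import _merge_extracted_entities

# Entities extracted from two overlapping chunks; 'Alice' appears in both.
chunk_results = [
    [
        ExtractedEntity(name='Alice', entity_type_id=0),
        ExtractedEntity(name='TechCorp', entity_type_id=1),
    ],
    [
        ExtractedEntity(name='  alice ', entity_type_id=0),  # duplicate after strip().lower()
        ExtractedEntity(name='Bob', entity_type_id=0),
    ],
]

merged = _merge_extracted_entities(chunk_results)
print([e.name for e in merged])  # ['Alice', 'TechCorp', 'Bob']: first occurrence wins
```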
--------------------------------------------------------------------------------
/mcp_server/tests/test_comprehensive_integration.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """
3 | Comprehensive integration test suite for Graphiti MCP Server.
4 | Covers all MCP tools with consideration for LLM inference latency.
5 | """
6 |
7 | import asyncio
8 | import json
9 | import os
10 | import time
11 | from dataclasses import dataclass
12 | from typing import Any
13 |
14 | import pytest
15 | from mcp import ClientSession, StdioServerParameters
16 | from mcp.client.stdio import stdio_client
17 |
18 |
19 | @dataclass
20 | class TestMetrics:
21 | """Track test performance metrics."""
22 |
23 | operation: str
24 | start_time: float
25 | end_time: float
26 | success: bool
27 | details: dict[str, Any]
28 |
29 | @property
30 | def duration(self) -> float:
31 | """Calculate operation duration in seconds."""
32 | return self.end_time - self.start_time
33 |
34 |
35 | class GraphitiTestClient:
36 | """Enhanced test client for comprehensive Graphiti MCP testing."""
37 |
38 | def __init__(self, test_group_id: str | None = None):
39 | self.test_group_id = test_group_id or f'test_{int(time.time())}'
40 | self.session = None
41 | self.metrics: list[TestMetrics] = []
42 | self.default_timeout = 30 # seconds
43 |
44 | async def __aenter__(self):
45 | """Initialize MCP client session."""
46 | server_params = StdioServerParameters(
47 | command='uv',
48 | args=['run', '../main.py', '--transport', 'stdio'],
49 | env={
50 | 'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
51 | 'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
52 | 'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
53 | 'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY', 'test_key_for_mock'),
54 | 'FALKORDB_URI': os.environ.get('FALKORDB_URI', 'redis://localhost:6379'),
55 | },
56 | )
57 |
58 | self.client_context = stdio_client(server_params)
59 | read, write = await self.client_context.__aenter__()
60 |         self.session = await ClientSession(read, write).__aenter__()  # enter the session context so initialize() can receive responses
61 | await self.session.initialize()
62 |
63 | # Wait for server to be fully ready
64 | await asyncio.sleep(2)
65 |
66 | return self
67 |
68 | async def __aexit__(self, exc_type, exc_val, exc_tb):
69 | """Clean up client session."""
70 | if self.session:
71 |             await self.session.__aexit__(exc_type, exc_val, exc_tb)  # close via the async context protocol
72 | if hasattr(self, 'client_context'):
73 | await self.client_context.__aexit__(exc_type, exc_val, exc_tb)
74 |
75 | async def call_tool_with_metrics(
76 | self, tool_name: str, arguments: dict[str, Any], timeout: float | None = None
77 | ) -> tuple[Any, TestMetrics]:
78 | """Call a tool and capture performance metrics."""
79 | start_time = time.time()
80 | timeout = timeout or self.default_timeout
81 |
82 | try:
83 | result = await asyncio.wait_for(
84 | self.session.call_tool(tool_name, arguments), timeout=timeout
85 | )
86 |
87 | content = result.content[0].text if result.content else None
88 | success = True
89 | details = {'result': content, 'tool': tool_name}
90 |
91 | except asyncio.TimeoutError:
92 | content = None
93 | success = False
94 | details = {'error': f'Timeout after {timeout}s', 'tool': tool_name}
95 |
96 | except Exception as e:
97 | content = None
98 | success = False
99 | details = {'error': str(e), 'tool': tool_name}
100 |
101 | end_time = time.time()
102 | metric = TestMetrics(
103 | operation=f'call_{tool_name}',
104 | start_time=start_time,
105 | end_time=end_time,
106 | success=success,
107 | details=details,
108 | )
109 | self.metrics.append(metric)
110 |
111 | return content, metric
112 |
113 | async def wait_for_episode_processing(
114 | self, expected_count: int = 1, max_wait: int = 60, poll_interval: int = 2
115 | ) -> bool:
116 | """
117 | Wait for episodes to be processed with intelligent polling.
118 |
119 | Args:
120 | expected_count: Number of episodes expected to be processed
121 | max_wait: Maximum seconds to wait
122 | poll_interval: Seconds between status checks
123 |
124 | Returns:
125 | True if episodes were processed successfully
126 | """
127 | start_time = time.time()
128 |
129 | while (time.time() - start_time) < max_wait:
130 | result, _ = await self.call_tool_with_metrics(
131 | 'get_episodes', {'group_id': self.test_group_id, 'last_n': 100}
132 | )
133 |
134 | if result:
135 | try:
136 | episodes = json.loads(result) if isinstance(result, str) else result
137 | if len(episodes.get('episodes', [])) >= expected_count:
138 | return True
139 | except (json.JSONDecodeError, AttributeError):
140 | pass
141 |
142 | await asyncio.sleep(poll_interval)
143 |
144 | return False
145 |
146 |
147 | class TestCoreOperations:
148 | """Test core Graphiti operations."""
149 |
150 | @pytest.mark.asyncio
151 | async def test_server_initialization(self):
152 | """Verify server initializes with all required tools."""
153 | async with GraphitiTestClient() as client:
154 | tools_result = await client.session.list_tools()
155 | tools = {tool.name for tool in tools_result.tools}
156 |
157 | required_tools = {
158 | 'add_memory',
159 | 'search_memory_nodes',
160 | 'search_memory_facts',
161 | 'get_episodes',
162 | 'delete_episode',
163 | 'delete_entity_edge',
164 | 'get_entity_edge',
165 | 'clear_graph',
166 | 'get_status',
167 | }
168 |
169 | missing_tools = required_tools - tools
170 | assert not missing_tools, f'Missing required tools: {missing_tools}'
171 |
172 | @pytest.mark.asyncio
173 | async def test_add_text_memory(self):
174 | """Test adding text-based memories."""
175 | async with GraphitiTestClient() as client:
176 | # Add memory
177 | result, metric = await client.call_tool_with_metrics(
178 | 'add_memory',
179 | {
180 | 'name': 'Tech Conference Notes',
181 | 'episode_body': 'The AI conference featured talks on LLMs, RAG systems, and knowledge graphs. Notable speakers included researchers from OpenAI and Anthropic.',
182 | 'source': 'text',
183 | 'source_description': 'conference notes',
184 | 'group_id': client.test_group_id,
185 | },
186 | )
187 |
188 | assert metric.success, f'Failed to add memory: {metric.details}'
189 | assert 'queued' in str(result).lower()
190 |
191 | # Wait for processing
192 | processed = await client.wait_for_episode_processing(expected_count=1)
193 | assert processed, 'Episode was not processed within timeout'
194 |
195 | @pytest.mark.asyncio
196 | async def test_add_json_memory(self):
197 | """Test adding structured JSON memories."""
198 | async with GraphitiTestClient() as client:
199 | json_data = {
200 | 'project': {
201 | 'name': 'GraphitiDB',
202 | 'version': '2.0.0',
203 | 'features': ['temporal-awareness', 'hybrid-search', 'custom-entities'],
204 | },
205 | 'team': {'size': 5, 'roles': ['engineering', 'product', 'research']},
206 | }
207 |
208 | result, metric = await client.call_tool_with_metrics(
209 | 'add_memory',
210 | {
211 | 'name': 'Project Data',
212 | 'episode_body': json.dumps(json_data),
213 | 'source': 'json',
214 | 'source_description': 'project database',
215 | 'group_id': client.test_group_id,
216 | },
217 | )
218 |
219 | assert metric.success
220 | assert 'queued' in str(result).lower()
221 |
222 | @pytest.mark.asyncio
223 | async def test_add_message_memory(self):
224 | """Test adding conversation/message memories."""
225 | async with GraphitiTestClient() as client:
226 | conversation = """
227 | user: What are the key features of Graphiti?
228 | assistant: Graphiti offers temporal-aware knowledge graphs, hybrid retrieval, and real-time updates.
229 | user: How does it handle entity resolution?
230 | assistant: It uses LLM-based entity extraction and deduplication with semantic similarity matching.
231 | """
232 |
233 | result, metric = await client.call_tool_with_metrics(
234 | 'add_memory',
235 | {
236 | 'name': 'Feature Discussion',
237 | 'episode_body': conversation,
238 | 'source': 'message',
239 | 'source_description': 'support chat',
240 | 'group_id': client.test_group_id,
241 | },
242 | )
243 |
244 | assert metric.success
245 | assert metric.duration < 5, f'Add memory took too long: {metric.duration}s'
246 |
247 |
248 | class TestSearchOperations:
249 | """Test search and retrieval operations."""
250 |
251 | @pytest.mark.asyncio
252 | async def test_search_nodes_semantic(self):
253 | """Test semantic search for nodes."""
254 | async with GraphitiTestClient() as client:
255 | # First add some test data
256 | await client.call_tool_with_metrics(
257 | 'add_memory',
258 | {
259 | 'name': 'Product Launch',
260 | 'episode_body': 'Our new AI assistant product launches in Q2 2024 with advanced NLP capabilities.',
261 | 'source': 'text',
262 | 'source_description': 'product roadmap',
263 | 'group_id': client.test_group_id,
264 | },
265 | )
266 |
267 | # Wait for processing
268 | await client.wait_for_episode_processing()
269 |
270 | # Search for nodes
271 | result, metric = await client.call_tool_with_metrics(
272 | 'search_memory_nodes',
273 | {'query': 'AI product features', 'group_id': client.test_group_id, 'limit': 10},
274 | )
275 |
276 | assert metric.success
277 | assert result is not None
278 |
279 | @pytest.mark.asyncio
280 | async def test_search_facts_with_filters(self):
281 | """Test fact search with various filters."""
282 | async with GraphitiTestClient() as client:
283 | # Add test data
284 | await client.call_tool_with_metrics(
285 | 'add_memory',
286 | {
287 | 'name': 'Company Facts',
288 | 'episode_body': 'Acme Corp was founded in 2020. They have 50 employees and $10M in revenue.',
289 | 'source': 'text',
290 | 'source_description': 'company profile',
291 | 'group_id': client.test_group_id,
292 | },
293 | )
294 |
295 | await client.wait_for_episode_processing()
296 |
297 | # Search with date filter
298 | result, metric = await client.call_tool_with_metrics(
299 | 'search_memory_facts',
300 | {
301 | 'query': 'company information',
302 | 'group_id': client.test_group_id,
303 | 'created_after': '2020-01-01T00:00:00Z',
304 | 'limit': 20,
305 | },
306 | )
307 |
308 | assert metric.success
309 |
310 | @pytest.mark.asyncio
311 | async def test_hybrid_search(self):
312 | """Test hybrid search combining semantic and keyword search."""
313 | async with GraphitiTestClient() as client:
314 | # Add diverse test data
315 | test_memories = [
316 | {
317 | 'name': 'Technical Doc',
318 | 'episode_body': 'GraphQL API endpoints support pagination, filtering, and real-time subscriptions.',
319 | 'source': 'text',
320 | },
321 | {
322 | 'name': 'Architecture',
323 | 'episode_body': 'The system uses Neo4j for graph storage and OpenAI embeddings for semantic search.',
324 | 'source': 'text',
325 | },
326 | ]
327 |
328 | for memory in test_memories:
329 | memory['group_id'] = client.test_group_id
330 | memory['source_description'] = 'documentation'
331 | await client.call_tool_with_metrics('add_memory', memory)
332 |
333 | await client.wait_for_episode_processing(expected_count=2)
334 |
335 | # Test semantic + keyword search
336 | result, metric = await client.call_tool_with_metrics(
337 | 'search_memory_nodes',
338 | {'query': 'Neo4j graph database', 'group_id': client.test_group_id, 'limit': 10},
339 | )
340 |
341 | assert metric.success
342 |
343 |
344 | class TestEpisodeManagement:
345 | """Test episode lifecycle operations."""
346 |
347 | @pytest.mark.asyncio
348 | async def test_get_episodes_pagination(self):
349 | """Test retrieving episodes with pagination."""
350 | async with GraphitiTestClient() as client:
351 | # Add multiple episodes
352 | for i in range(5):
353 | await client.call_tool_with_metrics(
354 | 'add_memory',
355 | {
356 | 'name': f'Episode {i}',
357 | 'episode_body': f'This is test episode number {i}',
358 | 'source': 'text',
359 | 'source_description': 'test',
360 | 'group_id': client.test_group_id,
361 | },
362 | )
363 |
364 | await client.wait_for_episode_processing(expected_count=5)
365 |
366 | # Test pagination
367 | result, metric = await client.call_tool_with_metrics(
368 | 'get_episodes', {'group_id': client.test_group_id, 'last_n': 3}
369 | )
370 |
371 | assert metric.success
372 | episodes = json.loads(result) if isinstance(result, str) else result
373 | assert len(episodes.get('episodes', [])) <= 3
374 |
375 | @pytest.mark.asyncio
376 | async def test_delete_episode(self):
377 | """Test deleting specific episodes."""
378 | async with GraphitiTestClient() as client:
379 | # Add an episode
380 | await client.call_tool_with_metrics(
381 | 'add_memory',
382 | {
383 | 'name': 'To Delete',
384 | 'episode_body': 'This episode will be deleted',
385 | 'source': 'text',
386 | 'source_description': 'test',
387 | 'group_id': client.test_group_id,
388 | },
389 | )
390 |
391 | await client.wait_for_episode_processing()
392 |
393 | # Get episode UUID
394 | result, _ = await client.call_tool_with_metrics(
395 | 'get_episodes', {'group_id': client.test_group_id, 'last_n': 1}
396 | )
397 |
398 | episodes = json.loads(result) if isinstance(result, str) else result
399 | episode_uuid = episodes['episodes'][0]['uuid']
400 |
401 | # Delete the episode
402 | result, metric = await client.call_tool_with_metrics(
403 | 'delete_episode', {'episode_uuid': episode_uuid}
404 | )
405 |
406 | assert metric.success
407 | assert 'deleted' in str(result).lower()
408 |
409 |
410 | class TestEntityAndEdgeOperations:
411 | """Test entity and edge management."""
412 |
413 | @pytest.mark.asyncio
414 | async def test_get_entity_edge(self):
415 | """Test retrieving entity edges."""
416 | async with GraphitiTestClient() as client:
417 | # Add data to create entities and edges
418 | await client.call_tool_with_metrics(
419 | 'add_memory',
420 | {
421 | 'name': 'Relationship Data',
422 | 'episode_body': 'Alice works at TechCorp. Bob is the CEO of TechCorp.',
423 | 'source': 'text',
424 | 'source_description': 'org chart',
425 | 'group_id': client.test_group_id,
426 | },
427 | )
428 |
429 | await client.wait_for_episode_processing()
430 |
431 | # Search for nodes to get UUIDs
432 | result, _ = await client.call_tool_with_metrics(
433 | 'search_memory_nodes',
434 | {'query': 'TechCorp', 'group_id': client.test_group_id, 'limit': 5},
435 | )
436 |
437 | # Note: This test assumes edges are created between entities
438 | # Actual edge retrieval would require valid edge UUIDs
439 |
440 | @pytest.mark.asyncio
441 | async def test_delete_entity_edge(self):
442 | """Test deleting entity edges."""
443 | # Similar structure to get_entity_edge but with deletion
444 | pass # Implement based on actual edge creation patterns
445 |
446 |
447 | class TestErrorHandling:
448 | """Test error conditions and edge cases."""
449 |
450 | @pytest.mark.asyncio
451 | async def test_invalid_tool_arguments(self):
452 | """Test handling of invalid tool arguments."""
453 | async with GraphitiTestClient() as client:
454 | # Missing required arguments
455 | result, metric = await client.call_tool_with_metrics(
456 | 'add_memory',
457 | {'name': 'Incomplete'}, # Missing required fields
458 | )
459 |
460 | assert not metric.success
461 | assert 'error' in str(metric.details).lower()
462 |
463 | @pytest.mark.asyncio
464 | async def test_timeout_handling(self):
465 | """Test timeout handling for long operations."""
466 | async with GraphitiTestClient() as client:
467 | # Simulate a very large episode that might time out
468 | large_text = 'Large document content. ' * 10000
469 |
470 | result, metric = await client.call_tool_with_metrics(
471 | 'add_memory',
472 | {
473 | 'name': 'Large Document',
474 | 'episode_body': large_text,
475 | 'source': 'text',
476 | 'source_description': 'large file',
477 | 'group_id': client.test_group_id,
478 | },
479 | timeout=5, # Short timeout
480 | )
481 |
482 | # Check if timeout was handled gracefully
483 | if not metric.success:
484 | assert 'timeout' in str(metric.details).lower()
485 |
486 | @pytest.mark.asyncio
487 | async def test_concurrent_operations(self):
488 | """Test handling of concurrent operations."""
489 | async with GraphitiTestClient() as client:
490 | # Launch multiple operations concurrently
491 | tasks = []
492 | for i in range(5):
493 | task = client.call_tool_with_metrics(
494 | 'add_memory',
495 | {
496 | 'name': f'Concurrent {i}',
497 | 'episode_body': f'Concurrent operation {i}',
498 | 'source': 'text',
499 | 'source_description': 'concurrent test',
500 | 'group_id': client.test_group_id,
501 | },
502 | )
503 | tasks.append(task)
504 |
505 | results = await asyncio.gather(*tasks, return_exceptions=True)
506 |
507 |             # Count queued successes; with return_exceptions=True, failed tasks yield exceptions
508 |             successful = sum(1 for r in results if not isinstance(r, BaseException) and r[1].success)
509 | assert successful >= 3 # At least 60% should succeed
510 |
511 |
512 | class TestPerformance:
513 | """Test performance characteristics and optimization."""
514 |
515 | @pytest.mark.asyncio
516 | async def test_latency_metrics(self):
517 | """Measure and validate operation latencies."""
518 | async with GraphitiTestClient() as client:
519 | operations = [
520 | (
521 | 'add_memory',
522 | {
523 | 'name': 'Perf Test',
524 | 'episode_body': 'Simple text',
525 | 'source': 'text',
526 | 'source_description': 'test',
527 | 'group_id': client.test_group_id,
528 | },
529 | ),
530 | (
531 | 'search_memory_nodes',
532 | {'query': 'test', 'group_id': client.test_group_id, 'limit': 10},
533 | ),
534 | ('get_episodes', {'group_id': client.test_group_id, 'last_n': 10}),
535 | ]
536 |
537 | for tool_name, args in operations:
538 | _, metric = await client.call_tool_with_metrics(tool_name, args)
539 |
540 | # Log performance metrics
541 | print(f'{tool_name}: {metric.duration:.2f}s')
542 |
543 | # Basic latency assertions
544 | if tool_name == 'get_episodes':
545 | assert metric.duration < 2, f'{tool_name} too slow'
546 | elif tool_name == 'search_memory_nodes':
547 | assert metric.duration < 10, f'{tool_name} too slow'
548 |
549 | @pytest.mark.asyncio
550 | async def test_batch_processing_efficiency(self):
551 | """Test efficiency of batch operations."""
552 | async with GraphitiTestClient() as client:
553 | batch_size = 10
554 | start_time = time.time()
555 |
556 | # Batch add memories
557 | for i in range(batch_size):
558 | await client.call_tool_with_metrics(
559 | 'add_memory',
560 | {
561 | 'name': f'Batch {i}',
562 | 'episode_body': f'Batch content {i}',
563 | 'source': 'text',
564 | 'source_description': 'batch test',
565 | 'group_id': client.test_group_id,
566 | },
567 | )
568 |
569 | # Wait for all to process
570 | processed = await client.wait_for_episode_processing(
571 | expected_count=batch_size,
572 | max_wait=120, # Allow more time for batch
573 | )
574 |
575 | total_time = time.time() - start_time
576 | avg_time_per_item = total_time / batch_size
577 |
578 | assert processed, f'Failed to process {batch_size} items'
579 | assert avg_time_per_item < 15, (
580 | f'Batch processing too slow: {avg_time_per_item:.2f}s per item'
581 | )
582 |
583 | # Generate performance report
584 | print('\nBatch Performance Report:')
585 | print(f' Total items: {batch_size}')
586 | print(f' Total time: {total_time:.2f}s')
587 | print(f' Avg per item: {avg_time_per_item:.2f}s')
588 |
589 |
590 | class TestDatabaseBackends:
591 | """Test different database backend configurations."""
592 |
593 | @pytest.mark.asyncio
594 | @pytest.mark.parametrize('database', ['neo4j', 'falkordb'])
595 | async def test_database_operations(self, database):
596 | """Test operations with different database backends."""
597 | env_vars = {
598 | 'DATABASE_PROVIDER': database,
599 | 'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY'),
600 | }
601 |
602 | if database == 'neo4j':
603 | env_vars.update(
604 | {
605 | 'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
606 | 'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
607 | 'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
608 | }
609 | )
610 | elif database == 'falkordb':
611 | env_vars['FALKORDB_URI'] = os.environ.get('FALKORDB_URI', 'redis://localhost:6379')
612 |
613 | # This test would require setting up server with specific database
614 | # Implementation depends on database availability
615 | pass # Placeholder for database-specific tests
616 |
617 |
618 | def generate_test_report(client: GraphitiTestClient) -> str:
619 | """Generate a comprehensive test report from metrics."""
620 | if not client.metrics:
621 | return 'No metrics collected'
622 |
623 | report = []
624 | report.append('\n' + '=' * 60)
625 | report.append('GRAPHITI MCP TEST REPORT')
626 | report.append('=' * 60)
627 |
628 | # Summary statistics
629 | total_ops = len(client.metrics)
630 | successful_ops = sum(1 for m in client.metrics if m.success)
631 | avg_duration = sum(m.duration for m in client.metrics) / total_ops
632 |
633 | report.append(f'\nTotal Operations: {total_ops}')
634 | report.append(f'Successful: {successful_ops} ({successful_ops / total_ops * 100:.1f}%)')
635 | report.append(f'Average Duration: {avg_duration:.2f}s')
636 |
637 | # Operation breakdown
638 | report.append('\nOperation Breakdown:')
639 | operation_stats = {}
640 | for metric in client.metrics:
641 | if metric.operation not in operation_stats:
642 | operation_stats[metric.operation] = {'count': 0, 'success': 0, 'total_duration': 0}
643 | stats = operation_stats[metric.operation]
644 | stats['count'] += 1
645 | stats['success'] += 1 if metric.success else 0
646 | stats['total_duration'] += metric.duration
647 |
648 | for op, stats in sorted(operation_stats.items()):
649 | avg_dur = stats['total_duration'] / stats['count']
650 | success_rate = stats['success'] / stats['count'] * 100
651 | report.append(
652 | f' {op}: {stats["count"]} calls, {success_rate:.0f}% success, {avg_dur:.2f}s avg'
653 | )
654 |
655 | # Slowest operations
656 | slowest = sorted(client.metrics, key=lambda m: m.duration, reverse=True)[:5]
657 | report.append('\nSlowest Operations:')
658 | for metric in slowest:
659 | report.append(f' {metric.operation}: {metric.duration:.2f}s')
660 |
661 | report.append('=' * 60)
662 | return '\n'.join(report)
663 |
664 |
665 | if __name__ == '__main__':
666 | # Run tests with pytest
667 | pytest.main([__file__, '-v', '--asyncio-mode=auto'])
668 |
```
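For ad-hoc runs outside pytest, the metrics the client accumulates can be fed straight into `generate_test_report`. A hypothetical harness, not part of the suite; it assumes an MCP server is reachable with the stdio parameters configured in `__aenter__` and that this file is importable as `test_comprehensive_integration`:

```python
import asyncio

from test_comprehensive_integration import GraphitiTestClient, generate_test_report


async def main() -> None:
    # Every call_tool_with_metrics invocation appends a TestMetrics record,
    # which generate_test_report aggregates into per-tool statistics.
    async with GraphitiTestClient(test_group_id='report_demo') as client:
        await client.call_tool_with_metrics(
            'get_episodes', {'group_id': client.test_group_id, 'last_n': 5}
        )
        print(generate_test_report(client))


asyncio.run(main())
```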