This is page 8 of 12. Use http://codebase.md/getzep/graphiti?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .env.example
├── .github
│   ├── dependabot.yml
│   ├── ISSUE_TEMPLATE
│   │   └── bug_report.md
│   ├── pull_request_template.md
│   ├── secret_scanning.yml
│   └── workflows
│       ├── ai-moderator.yml
│       ├── cla.yml
│       ├── claude-code-review-manual.yml
│       ├── claude-code-review.yml
│       ├── claude.yml
│       ├── codeql.yml
│       ├── lint.yml
│       ├── release-graphiti-core.yml
│       ├── release-mcp-server.yml
│       ├── release-server-container.yml
│       ├── typecheck.yml
│       └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│   ├── azure-openai
│   │   ├── .env.example
│   │   ├── azure_openai_neo4j.py
│   │   └── README.md
│   ├── data
│   │   └── manybirds_products.json
│   ├── ecommerce
│   │   ├── runner.ipynb
│   │   └── runner.py
│   ├── langgraph-agent
│   │   ├── agent.ipynb
│   │   └── tinybirds-jess.png
│   ├── opentelemetry
│   │   ├── .env.example
│   │   ├── otel_stdout_example.py
│   │   ├── pyproject.toml
│   │   ├── README.md
│   │   └── uv.lock
│   ├── podcast
│   │   ├── podcast_runner.py
│   │   ├── podcast_transcript.txt
│   │   └── transcript_parser.py
│   ├── quickstart
│   │   ├── dense_vs_normal_ingestion.py
│   │   ├── quickstart_falkordb.py
│   │   ├── quickstart_neo4j.py
│   │   ├── quickstart_neptune.py
│   │   ├── README.md
│   │   └── requirements.txt
│   └── wizard_of_oz
│       ├── parser.py
│       ├── runner.py
│       └── woo.txt
├── graphiti_core
│   ├── __init__.py
│   ├── cross_encoder
│   │   ├── __init__.py
│   │   ├── bge_reranker_client.py
│   │   ├── client.py
│   │   ├── gemini_reranker_client.py
│   │   └── openai_reranker_client.py
│   ├── decorators.py
│   ├── driver
│   │   ├── __init__.py
│   │   ├── driver.py
│   │   ├── falkordb_driver.py
│   │   ├── graph_operations
│   │   │   └── graph_operations.py
│   │   ├── kuzu_driver.py
│   │   ├── neo4j_driver.py
│   │   ├── neptune_driver.py
│   │   └── search_interface
│   │       └── search_interface.py
│   ├── edges.py
│   ├── embedder
│   │   ├── __init__.py
│   │   ├── azure_openai.py
│   │   ├── client.py
│   │   ├── gemini.py
│   │   ├── openai.py
│   │   └── voyage.py
│   ├── errors.py
│   ├── graph_queries.py
│   ├── graphiti_types.py
│   ├── graphiti.py
│   ├── helpers.py
│   ├── llm_client
│   │   ├── __init__.py
│   │   ├── anthropic_client.py
│   │   ├── azure_openai_client.py
│   │   ├── client.py
│   │   ├── config.py
│   │   ├── errors.py
│   │   ├── gemini_client.py
│   │   ├── groq_client.py
│   │   ├── openai_base_client.py
│   │   ├── openai_client.py
│   │   ├── openai_generic_client.py
│   │   └── utils.py
│   ├── migrations
│   │   └── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── edges
│   │   │   ├── __init__.py
│   │   │   └── edge_db_queries.py
│   │   └── nodes
│   │       ├── __init__.py
│   │       └── node_db_queries.py
│   ├── nodes.py
│   ├── prompts
│   │   ├── __init__.py
│   │   ├── dedupe_edges.py
│   │   ├── dedupe_nodes.py
│   │   ├── eval.py
│   │   ├── extract_edge_dates.py
│   │   ├── extract_edges.py
│   │   ├── extract_nodes.py
│   │   ├── invalidate_edges.py
│   │   ├── lib.py
│   │   ├── models.py
│   │   ├── prompt_helpers.py
│   │   ├── snippets.py
│   │   └── summarize_nodes.py
│   ├── py.typed
│   ├── search
│   │   ├── __init__.py
│   │   ├── search_config_recipes.py
│   │   ├── search_config.py
│   │   ├── search_filters.py
│   │   ├── search_helpers.py
│   │   ├── search_utils.py
│   │   └── search.py
│   ├── telemetry
│   │   ├── __init__.py
│   │   └── telemetry.py
│   ├── tracer.py
│   └── utils
│       ├── __init__.py
│       ├── bulk_utils.py
│       ├── content_chunking.py
│       ├── datetime_utils.py
│       ├── maintenance
│       │   ├── __init__.py
│       │   ├── community_operations.py
│       │   ├── dedup_helpers.py
│       │   ├── edge_operations.py
│       │   ├── graph_data_operations.py
│       │   ├── node_operations.py
│       │   └── temporal_operations.py
│       ├── ontology_utils
│       │   └── entity_types_utils.py
│       └── text_utils.py
├── images
│   ├── arxiv-screenshot.png
│   ├── graphiti-graph-intro.gif
│   ├── graphiti-intro-slides-stock-2.gif
│   └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│   ├── .env.example
│   ├── .python-version
│   ├── config
│   │   ├── config-docker-falkordb-combined.yaml
│   │   ├── config-docker-falkordb.yaml
│   │   ├── config-docker-neo4j.yaml
│   │   ├── config.yaml
│   │   └── mcp_config_stdio_example.json
│   ├── docker
│   │   ├── build-standalone.sh
│   │   ├── build-with-version.sh
│   │   ├── docker-compose-falkordb.yml
│   │   ├── docker-compose-neo4j.yml
│   │   ├── docker-compose.yml
│   │   ├── Dockerfile
│   │   ├── Dockerfile.standalone
│   │   ├── github-actions-example.yml
│   │   ├── README-falkordb-combined.md
│   │   └── README.md
│   ├── docs
│   │   └── cursor_rules.md
│   ├── main.py
│   ├── pyproject.toml
│   ├── pytest.ini
│   ├── README.md
│   ├── src
│   │   ├── __init__.py
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   └── schema.py
│   │   ├── graphiti_mcp_server.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── entity_types.py
│   │   │   └── response_types.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   ├── factories.py
│   │   │   └── queue_service.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── formatting.py
│   │       └── utils.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── conftest.py
│   │   ├── pytest.ini
│   │   ├── README.md
│   │   ├── run_tests.py
│   │   ├── test_async_operations.py
│   │   ├── test_comprehensive_integration.py
│   │   ├── test_configuration.py
│   │   ├── test_falkordb_integration.py
│   │   ├── test_fixtures.py
│   │   ├── test_http_integration.py
│   │   ├── test_integration.py
│   │   ├── test_mcp_integration.py
│   │   ├── test_mcp_transports.py
│   │   ├── test_stdio_simple.py
│   │   └── test_stress_load.py
│   └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│   ├── .env.example
│   ├── graph_service
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── dto
│   │   │   ├── __init__.py
│   │   │   ├── common.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   ├── main.py
│   │   ├── routers
│   │   │   ├── __init__.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   └── zep_graphiti.py
│   ├── Makefile
│   ├── pyproject.toml
│   ├── README.md
│   └── uv.lock
├── signatures
│   └── version1
│       └── cla.json
├── tests
│   ├── cross_encoder
│   │   ├── test_bge_reranker_client_int.py
│   │   └── test_gemini_reranker_client.py
│   ├── driver
│   │   ├── __init__.py
│   │   └── test_falkordb_driver.py
│   ├── embedder
│   │   ├── embedder_fixtures.py
│   │   ├── test_gemini.py
│   │   ├── test_openai.py
│   │   └── test_voyage.py
│   ├── evals
│   │   ├── data
│   │   │   └── longmemeval_data
│   │   │       ├── longmemeval_oracle.json
│   │   │       └── README.md
│   │   ├── eval_cli.py
│   │   ├── eval_e2e_graph_building.py
│   │   ├── pytest.ini
│   │   └── utils.py
│   ├── helpers_test.py
│   ├── llm_client
│   │   ├── test_anthropic_client_int.py
│   │   ├── test_anthropic_client.py
│   │   ├── test_azure_openai_client.py
│   │   ├── test_client.py
│   │   ├── test_errors.py
│   │   └── test_gemini_client.py
│   ├── test_edge_int.py
│   ├── test_entity_exclusion_int.py
│   ├── test_graphiti_int.py
│   ├── test_graphiti_mock.py
│   ├── test_node_int.py
│   ├── test_text_utils.py
│   └── utils
│       ├── maintenance
│       │   ├── test_bulk_utils.py
│       │   ├── test_edge_operations.py
│       │   ├── test_entity_extraction.py
│       │   ├── test_node_operations.py
│       │   └── test_temporal_operations_int.py
│       ├── search
│       │   └── search_utils_test.py
│       └── test_content_chunking.py
├── uv.lock
└── Zep-CLA.md
```

# Files

--------------------------------------------------------------------------------
/graphiti_core/utils/bulk_utils.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import json
 18 | import logging
 19 | import typing
 20 | from datetime import datetime
 21 | 
 22 | import numpy as np
 23 | from pydantic import BaseModel, Field
 24 | from typing_extensions import Any
 25 | 
 26 | from graphiti_core.driver.driver import (
 27 |     GraphDriver,
 28 |     GraphDriverSession,
 29 |     GraphProvider,
 30 | )
 31 | from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
 32 | from graphiti_core.embedder import EmbedderClient
 33 | from graphiti_core.graphiti_types import GraphitiClients
 34 | from graphiti_core.helpers import normalize_l2, semaphore_gather
 35 | from graphiti_core.models.edges.edge_db_queries import (
 36 |     get_entity_edge_save_bulk_query,
 37 |     get_episodic_edge_save_bulk_query,
 38 | )
 39 | from graphiti_core.models.nodes.node_db_queries import (
 40 |     get_entity_node_save_bulk_query,
 41 |     get_episode_node_save_bulk_query,
 42 | )
 43 | from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
 44 | from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
 45 | from graphiti_core.utils.maintenance.dedup_helpers import (
 46 |     DedupResolutionState,
 47 |     _build_candidate_indexes,
 48 |     _normalize_string_exact,
 49 |     _resolve_with_similarity,
 50 | )
 51 | from graphiti_core.utils.maintenance.edge_operations import (
 52 |     extract_edges,
 53 |     resolve_extracted_edge,
 54 | )
 55 | from graphiti_core.utils.maintenance.graph_data_operations import (
 56 |     EPISODE_WINDOW_LEN,
 57 |     retrieve_episodes,
 58 | )
 59 | from graphiti_core.utils.maintenance.node_operations import (
 60 |     extract_nodes,
 61 |     resolve_extracted_nodes,
 62 | )
 63 | 
 64 | logger = logging.getLogger(__name__)
 65 | 
 66 | CHUNK_SIZE = 10
 67 | 
 68 | 
 69 | def _build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]:
 70 |     """Collapse alias -> canonical chains while preserving direction.
 71 | 
 72 |     The incoming pairs represent directed mappings discovered during node dedupe. We use a simple
 73 |     union-find with iterative path compression to ensure every source UUID resolves to its ultimate
 74 |     canonical target, even if aliases appear lexicographically smaller than the canonical UUID.
 75 |     """
 76 | 
 77 |     parent: dict[str, str] = {}
 78 | 
 79 |     def find(uuid: str) -> str:
 80 |         """Directed union-find lookup using iterative path compression."""
 81 |         parent.setdefault(uuid, uuid)
 82 |         root = uuid
 83 |         while parent[root] != root:
 84 |             root = parent[root]
 85 | 
 86 |         while parent[uuid] != root:
 87 |             next_uuid = parent[uuid]
 88 |             parent[uuid] = root
 89 |             uuid = next_uuid
 90 | 
 91 |         return root
 92 | 
 93 |     for source_uuid, target_uuid in pairs:
 94 |         parent.setdefault(source_uuid, source_uuid)
 95 |         parent.setdefault(target_uuid, target_uuid)
 96 |         parent[find(source_uuid)] = find(target_uuid)
 97 | 
 98 |     return {uuid: find(uuid) for uuid in parent}
 99 | 
100 | 
101 | class RawEpisode(BaseModel):
102 |     name: str
103 |     uuid: str | None = Field(default=None)
104 |     content: str
105 |     source_description: str
106 |     source: EpisodeType
107 |     reference_time: datetime
108 | 
109 | 
110 | async def retrieve_previous_episodes_bulk(
111 |     driver: GraphDriver, episodes: list[EpisodicNode]
112 | ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
113 |     previous_episodes_list = await semaphore_gather(
114 |         *[
115 |             retrieve_episodes(
116 |                 driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id]
117 |             )
118 |             for episode in episodes
119 |         ]
120 |     )
121 |     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] = [
122 |         (episode, previous_episodes_list[i]) for i, episode in enumerate(episodes)
123 |     ]
124 | 
125 |     return episode_tuples
126 | 
127 | 
128 | async def add_nodes_and_edges_bulk(
129 |     driver: GraphDriver,
130 |     episodic_nodes: list[EpisodicNode],
131 |     episodic_edges: list[EpisodicEdge],
132 |     entity_nodes: list[EntityNode],
133 |     entity_edges: list[EntityEdge],
134 |     embedder: EmbedderClient,
135 | ):
136 |     session = driver.session()
137 |     try:
138 |         await session.execute_write(
139 |             add_nodes_and_edges_bulk_tx,
140 |             episodic_nodes,
141 |             episodic_edges,
142 |             entity_nodes,
143 |             entity_edges,
144 |             embedder,
145 |             driver=driver,
146 |         )
147 |     finally:
148 |         await session.close()
149 | 
150 | 
151 | async def add_nodes_and_edges_bulk_tx(
152 |     tx: GraphDriverSession,
153 |     episodic_nodes: list[EpisodicNode],
154 |     episodic_edges: list[EpisodicEdge],
155 |     entity_nodes: list[EntityNode],
156 |     entity_edges: list[EntityEdge],
157 |     embedder: EmbedderClient,
158 |     driver: GraphDriver,
159 | ):
160 |     episodes = [dict(episode) for episode in episodic_nodes]
161 |     for episode in episodes:
162 |         episode['source'] = str(episode['source'].value)
163 |         episode.pop('labels', None)
164 | 
165 |     nodes = []
166 | 
167 |     for node in entity_nodes:
168 |         if node.name_embedding is None:
169 |             await node.generate_name_embedding(embedder)
170 | 
171 |         entity_data: dict[str, Any] = {
172 |             'uuid': node.uuid,
173 |             'name': node.name,
174 |             'group_id': node.group_id,
175 |             'summary': node.summary,
176 |             'created_at': node.created_at,
177 |             'name_embedding': node.name_embedding,
178 |             'labels': list(set(node.labels + ['Entity'])),
179 |         }
180 | 
181 |         if driver.provider == GraphProvider.KUZU:
182 |             attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
183 |             entity_data['attributes'] = json.dumps(attributes)
184 |         else:
185 |             entity_data.update(node.attributes or {})
186 | 
187 |         nodes.append(entity_data)
188 | 
189 |     edges = []
190 |     for edge in entity_edges:
191 |         if edge.fact_embedding is None:
192 |             await edge.generate_embedding(embedder)
193 |         edge_data: dict[str, Any] = {
194 |             'uuid': edge.uuid,
195 |             'source_node_uuid': edge.source_node_uuid,
196 |             'target_node_uuid': edge.target_node_uuid,
197 |             'name': edge.name,
198 |             'fact': edge.fact,
199 |             'group_id': edge.group_id,
200 |             'episodes': edge.episodes,
201 |             'created_at': edge.created_at,
202 |             'expired_at': edge.expired_at,
203 |             'valid_at': edge.valid_at,
204 |             'invalid_at': edge.invalid_at,
205 |             'fact_embedding': edge.fact_embedding,
206 |         }
207 | 
208 |         if driver.provider == GraphProvider.KUZU:
209 |             attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
210 |             edge_data['attributes'] = json.dumps(attributes)
211 |         else:
212 |             edge_data.update(edge.attributes or {})
213 | 
214 |         edges.append(edge_data)
215 | 
216 |     if driver.graph_operations_interface:
217 |         await driver.graph_operations_interface.episodic_node_save_bulk(None, driver, tx, episodes)
218 |         await driver.graph_operations_interface.node_save_bulk(None, driver, tx, nodes)
219 |         await driver.graph_operations_interface.episodic_edge_save_bulk(
220 |             None, driver, tx, [edge.model_dump() for edge in episodic_edges]
221 |         )
222 |         await driver.graph_operations_interface.edge_save_bulk(None, driver, tx, edges)
223 | 
224 |     elif driver.provider == GraphProvider.KUZU:
225 |         # FIXME: Kuzu's UNWIND does not currently support STRUCT[] type properly, so we insert the data one by one instead for now.
226 |         episode_query = get_episode_node_save_bulk_query(driver.provider)
227 |         for episode in episodes:
228 |             await tx.run(episode_query, **episode)
229 |         entity_node_query = get_entity_node_save_bulk_query(driver.provider, nodes)
230 |         for node in nodes:
231 |             await tx.run(entity_node_query, **node)
232 |         entity_edge_query = get_entity_edge_save_bulk_query(driver.provider)
233 |         for edge in edges:
234 |             await tx.run(entity_edge_query, **edge)
235 |         episodic_edge_query = get_episodic_edge_save_bulk_query(driver.provider)
236 |         for edge in episodic_edges:
237 |             await tx.run(episodic_edge_query, **edge.model_dump())
238 |     else:
239 |         await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
240 |         await tx.run(
241 |             get_entity_node_save_bulk_query(driver.provider, nodes),
242 |             nodes=nodes,
243 |         )
244 |         await tx.run(
245 |             get_episodic_edge_save_bulk_query(driver.provider),
246 |             episodic_edges=[edge.model_dump() for edge in episodic_edges],
247 |         )
248 |         await tx.run(
249 |             get_entity_edge_save_bulk_query(driver.provider),
250 |             entity_edges=edges,
251 |         )
252 | 
253 | 
254 | async def extract_nodes_and_edges_bulk(
255 |     clients: GraphitiClients,
256 |     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
257 |     edge_type_map: dict[tuple[str, str], list[str]],
258 |     entity_types: dict[str, type[BaseModel]] | None = None,
259 |     excluded_entity_types: list[str] | None = None,
260 |     edge_types: dict[str, type[BaseModel]] | None = None,
261 |     custom_extraction_instructions: str | None = None,
262 | ) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]:
263 |     extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather(
264 |         *[
265 |             extract_nodes(
266 |                 clients,
267 |                 episode,
268 |                 previous_episodes,
269 |                 entity_types=entity_types,
270 |                 excluded_entity_types=excluded_entity_types,
271 |                 custom_extraction_instructions=custom_extraction_instructions,
272 |             )
273 |             for episode, previous_episodes in episode_tuples
274 |         ]
275 |     )
276 | 
277 |     extracted_edges_bulk: list[list[EntityEdge]] = await semaphore_gather(
278 |         *[
279 |             extract_edges(
280 |                 clients,
281 |                 episode,
282 |                 extracted_nodes_bulk[i],
283 |                 previous_episodes,
284 |                 edge_type_map=edge_type_map,
285 |                 group_id=episode.group_id,
286 |                 edge_types=edge_types,
287 |                 custom_extraction_instructions=custom_extraction_instructions,
288 |             )
289 |             for i, (episode, previous_episodes) in enumerate(episode_tuples)
290 |         ]
291 |     )
292 | 
293 |     return extracted_nodes_bulk, extracted_edges_bulk
294 | 
295 | 
296 | async def dedupe_nodes_bulk(
297 |     clients: GraphitiClients,
298 |     extracted_nodes: list[list[EntityNode]],
299 |     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
300 |     entity_types: dict[str, type[BaseModel]] | None = None,
301 | ) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
302 |     """Resolve entity duplicates across an in-memory batch using a two-pass strategy.
303 | 
304 |     1. Run :func:`resolve_extracted_nodes` for every episode in parallel so each batch item is
305 |        reconciled against the live graph just like the non-batch flow.
306 |     2. Re-run the deterministic similarity heuristics across the union of resolved nodes to catch
307 |        duplicates that only co-occur inside this batch, emitting a canonical UUID map that callers
308 |        can apply to edges and persistence.
309 |     """
310 | 
311 |     first_pass_results = await semaphore_gather(
312 |         *[
313 |             resolve_extracted_nodes(
314 |                 clients,
315 |                 nodes,
316 |                 episode_tuples[i][0],
317 |                 episode_tuples[i][1],
318 |                 entity_types,
319 |             )
320 |             for i, nodes in enumerate(extracted_nodes)
321 |         ]
322 |     )
323 | 
324 |     episode_resolutions: list[tuple[str, list[EntityNode]]] = []
325 |     per_episode_uuid_maps: list[dict[str, str]] = []
326 |     duplicate_pairs: list[tuple[str, str]] = []
327 | 
328 |     for (resolved_nodes, uuid_map, duplicates), (episode, _) in zip(
329 |         first_pass_results, episode_tuples, strict=True
330 |     ):
331 |         episode_resolutions.append((episode.uuid, resolved_nodes))
332 |         per_episode_uuid_maps.append(uuid_map)
333 |         duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)
334 | 
335 |     canonical_nodes: dict[str, EntityNode] = {}
336 |     for _, resolved_nodes in episode_resolutions:
337 |         for node in resolved_nodes:
338 |             # NOTE: this loop is O(n^2) in the number of nodes inside the batch because we rebuild
339 |             # the MinHash index for the accumulated canonical pool each time. The LRU-backed
340 |             # shingle cache keeps the constant factors low for typical batch sizes (≤ CHUNK_SIZE),
341 |             # but if batches grow significantly we should switch to an incremental index or chunked
342 |             # processing.
343 |             if not canonical_nodes:
344 |                 canonical_nodes[node.uuid] = node
345 |                 continue
346 | 
347 |             existing_candidates = list(canonical_nodes.values())
348 |             normalized = _normalize_string_exact(node.name)
349 |             exact_match = next(
350 |                 (
351 |                     candidate
352 |                     for candidate in existing_candidates
353 |                     if _normalize_string_exact(candidate.name) == normalized
354 |                 ),
355 |                 None,
356 |             )
357 |             if exact_match is not None:
358 |                 if exact_match.uuid != node.uuid:
359 |                     duplicate_pairs.append((node.uuid, exact_match.uuid))
360 |                 continue
361 | 
362 |             indexes = _build_candidate_indexes(existing_candidates)
363 |             state = DedupResolutionState(
364 |                 resolved_nodes=[None],
365 |                 uuid_map={},
366 |                 unresolved_indices=[],
367 |             )
368 |             _resolve_with_similarity([node], indexes, state)
369 | 
370 |             resolved = state.resolved_nodes[0]
371 |             if resolved is None:
372 |                 canonical_nodes[node.uuid] = node
373 |                 continue
374 | 
375 |             canonical_uuid = resolved.uuid
376 |             canonical_nodes.setdefault(canonical_uuid, resolved)
377 |             if canonical_uuid != node.uuid:
378 |                 duplicate_pairs.append((node.uuid, canonical_uuid))
379 | 
380 |     union_pairs: list[tuple[str, str]] = []
381 |     for uuid_map in per_episode_uuid_maps:
382 |         union_pairs.extend(uuid_map.items())
383 |     union_pairs.extend(duplicate_pairs)
384 | 
385 |     compressed_map: dict[str, str] = _build_directed_uuid_map(union_pairs)
386 | 
387 |     nodes_by_episode: dict[str, list[EntityNode]] = {}
388 |     for episode_uuid, resolved_nodes in episode_resolutions:
389 |         deduped_nodes: list[EntityNode] = []
390 |         seen: set[str] = set()
391 |         for node in resolved_nodes:
392 |             canonical_uuid = compressed_map.get(node.uuid, node.uuid)
393 |             if canonical_uuid in seen:
394 |                 continue
395 |             seen.add(canonical_uuid)
396 |             canonical_node = canonical_nodes.get(canonical_uuid)
397 |             if canonical_node is None:
398 |                 logger.error(
399 |                     'Canonical node %s missing during batch dedupe; falling back to %s',
400 |                     canonical_uuid,
401 |                     node.uuid,
402 |                 )
403 |                 canonical_node = node
404 |             deduped_nodes.append(canonical_node)
405 | 
406 |         nodes_by_episode[episode_uuid] = deduped_nodes
407 | 
408 |     return nodes_by_episode, compressed_map
409 | 
410 | 
411 | async def dedupe_edges_bulk(
412 |     clients: GraphitiClients,
413 |     extracted_edges: list[list[EntityEdge]],
414 |     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
415 |     _entities: list[EntityNode],
416 |     edge_types: dict[str, type[BaseModel]],
417 |     _edge_type_map: dict[tuple[str, str], list[str]],
418 | ) -> dict[str, list[EntityEdge]]:
419 |     embedder = clients.embedder
420 |     min_score = 0.6
421 | 
422 |     # generate embeddings
423 |     await semaphore_gather(
424 |         *[create_entity_edge_embeddings(embedder, edges) for edges in extracted_edges]
425 |     )
426 | 
427 |     # Find similar results
428 |     dedupe_tuples: list[tuple[EpisodicNode, EntityEdge, list[EntityEdge]]] = []
429 |     for i, edges_i in enumerate(extracted_edges):
430 |         existing_edges: list[EntityEdge] = []
431 |         for edges_j in extracted_edges:
432 |             existing_edges += edges_j
433 | 
434 |         for edge in edges_i:
435 |             candidates: list[EntityEdge] = []
436 |             for existing_edge in existing_edges:
437 |                 # Skip self-comparison
438 |                 if edge.uuid == existing_edge.uuid:
439 |                     continue
440 |                 # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
441 |                 # This approach will cast a wider net than BM25, which is ideal for this use case
442 |                 if (
443 |                     edge.source_node_uuid != existing_edge.source_node_uuid
444 |                     or edge.target_node_uuid != existing_edge.target_node_uuid
445 |                 ):
446 |                     continue
447 | 
448 |                 edge_words = set(edge.fact.lower().split())
449 |                 existing_edge_words = set(existing_edge.fact.lower().split())
450 |                 has_overlap = not edge_words.isdisjoint(existing_edge_words)
451 |                 if has_overlap:
452 |                     candidates.append(existing_edge)
453 |                     continue
454 | 
455 |                 # Check for semantic similarity even if there is no overlap
456 |                 similarity = np.dot(
457 |                     normalize_l2(edge.fact_embedding or []),
458 |                     normalize_l2(existing_edge.fact_embedding or []),
459 |                 )
460 |                 if similarity >= min_score:
461 |                     candidates.append(existing_edge)
462 | 
463 |             dedupe_tuples.append((episode_tuples[i][0], edge, candidates))
464 | 
465 |     bulk_edge_resolutions: list[
466 |         tuple[EntityEdge, EntityEdge, list[EntityEdge]]
467 |     ] = await semaphore_gather(
468 |         *[
469 |             resolve_extracted_edge(
470 |                 clients.llm_client,
471 |                 edge,
472 |                 candidates,
473 |                 candidates,
474 |                 episode,
475 |                 edge_types,
476 |                 set(edge_types),
477 |             )
478 |             for episode, edge, candidates in dedupe_tuples
479 |         ]
480 |     )
481 | 
482 |     # For now we won't track edge invalidation
483 |     duplicate_pairs: list[tuple[str, str]] = []
484 |     for i, (_, _, duplicates) in enumerate(bulk_edge_resolutions):
485 |         episode, edge, candidates = dedupe_tuples[i]
486 |         for duplicate in duplicates:
487 |             duplicate_pairs.append((edge.uuid, duplicate.uuid))
488 | 
489 |     # Now we compress duplicate_pairs so that chains collapse: 3 -> 2 and 2 -> 1 becomes 3 -> 1 and 2 -> 1 (canonical chosen by uuid sort order)
490 |     compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
491 | 
492 |     edge_uuid_map: dict[str, EntityEdge] = {
493 |         edge.uuid: edge for edges in extracted_edges for edge in edges
494 |     }
495 | 
496 |     edges_by_episode: dict[str, list[EntityEdge]] = {}
497 |     for i, edges in enumerate(extracted_edges):
498 |         episode = episode_tuples[i][0]
499 | 
500 |         edges_by_episode[episode.uuid] = [
501 |             edge_uuid_map[compressed_map.get(edge.uuid, edge.uuid)] for edge in edges
502 |         ]
503 | 
504 |     return edges_by_episode
505 | 
506 | 
507 | class UnionFind:
508 |     def __init__(self, elements):
509 |         # start each element in its own set
510 |         self.parent = {e: e for e in elements}
511 | 
512 |     def find(self, x):
513 |         # path-compression
514 |         if self.parent[x] != x:
515 |             self.parent[x] = self.find(self.parent[x])
516 |         return self.parent[x]
517 | 
518 |     def union(self, a, b):
519 |         ra, rb = self.find(a), self.find(b)
520 |         if ra == rb:
521 |             return
522 |         # attach the lexicographically larger root under the smaller
523 |         if ra < rb:
524 |             self.parent[rb] = ra
525 |         else:
526 |             self.parent[ra] = rb
527 | 
528 | 
529 | def compress_uuid_map(duplicate_pairs: list[tuple[str, str]]) -> dict[str, str]:
530 |     """
531 |     duplicate_pairs: iterable of (id1, id2) pairs of duplicate entity IDs (strings)
532 |     returns: dict mapping each id that appears in a pair -> lexicographically
533 |     smallest id in its duplicate set
534 |     """
535 |     all_uuids = set()
536 |     for pair in duplicate_pairs:
537 |         all_uuids.add(pair[0])
538 |         all_uuids.add(pair[1])
539 | 
540 |     uf = UnionFind(all_uuids)
541 |     for a, b in duplicate_pairs:
542 |         uf.union(a, b)
543 |     # ensure full path-compression before mapping
544 |     return {uuid: uf.find(uuid) for uuid in all_uuids}
545 | 
546 | 
547 | E = typing.TypeVar('E', bound=Edge)
548 | 
549 | 
550 | def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
551 |     for edge in edges:
552 |         source_uuid = edge.source_node_uuid
553 |         target_uuid = edge.target_node_uuid
554 |         edge.source_node_uuid = uuid_map.get(source_uuid, source_uuid)
555 |         edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)
556 | 
557 |     return edges
558 | 
```
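
The alias-resolution helpers above (`_build_directed_uuid_map`, `UnionFind`, `compress_uuid_map`) all reduce duplicate pairs to a canonical-UUID map via union-find with path compression. A minimal standalone sketch of the directed variant, re-implemented here for illustration rather than imported from `graphiti_core` (the function name is hypothetical):

```python
# Illustrative sketch of the directed alias-collapse technique used by
# _build_directed_uuid_map above; collapse_aliases is a hypothetical name.

def collapse_aliases(pairs: list[tuple[str, str]]) -> dict[str, str]:
    """Resolve every id to its ultimate canonical target, preserving direction."""
    parent: dict[str, str] = {}

    def find(uid: str) -> str:
        parent.setdefault(uid, uid)
        root = uid
        while parent[root] != root:  # walk up to the root
            root = parent[root]
        while parent[uid] != root:  # compress the path behind us
            parent[uid], uid = root, parent[uid]
        return root

    for source, target in pairs:
        parent.setdefault(source, source)
        parent.setdefault(target, target)
        # direction matters: the source's root is attached under the target's root
        parent[find(source)] = find(target)

    return {uid: find(uid) for uid in parent}


# A chain b -> c discovered after a -> b still resolves a all the way to c:
assert collapse_aliases([('a', 'b'), ('b', 'c')]) == {'a': 'c', 'b': 'c', 'c': 'c'}
```

`compress_uuid_map` uses the same machinery but drops direction: its `UnionFind.union` attaches the lexicographically larger root under the smaller, so every duplicate set maps to its smallest UUID no matter which order the pairs arrive in.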

--------------------------------------------------------------------------------
/graphiti_core/edges.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import json
 18 | import logging
 19 | from abc import ABC, abstractmethod
 20 | from datetime import datetime
 21 | from time import time
 22 | from typing import Any
 23 | from uuid import uuid4
 24 | 
 25 | from pydantic import BaseModel, Field
 26 | from typing_extensions import LiteralString
 27 | 
 28 | from graphiti_core.driver.driver import GraphDriver, GraphProvider
 29 | from graphiti_core.embedder import EmbedderClient
 30 | from graphiti_core.errors import EdgeNotFoundError, GroupsEdgesNotFoundError
 31 | from graphiti_core.helpers import parse_db_date
 32 | from graphiti_core.models.edges.edge_db_queries import (
 33 |     COMMUNITY_EDGE_RETURN,
 34 |     EPISODIC_EDGE_RETURN,
 35 |     EPISODIC_EDGE_SAVE,
 36 |     get_community_edge_save_query,
 37 |     get_entity_edge_return_query,
 38 |     get_entity_edge_save_query,
 39 | )
 40 | from graphiti_core.nodes import Node
 41 | 
 42 | logger = logging.getLogger(__name__)
 43 | 
 44 | 
 45 | class Edge(BaseModel, ABC):
 46 |     uuid: str = Field(default_factory=lambda: str(uuid4()))
 47 |     group_id: str = Field(description='partition of the graph')
 48 |     source_node_uuid: str
 49 |     target_node_uuid: str
 50 |     created_at: datetime
 51 | 
 52 |     @abstractmethod
 53 |     async def save(self, driver: GraphDriver): ...
 54 | 
 55 |     async def delete(self, driver: GraphDriver):
 56 |         if driver.graph_operations_interface:
 57 |             return await driver.graph_operations_interface.edge_delete(self, driver)
 58 | 
 59 |         if driver.provider == GraphProvider.KUZU:
 60 |             await driver.execute_query(
 61 |                 """
 62 |                 MATCH (n)-[e:MENTIONS|HAS_MEMBER {uuid: $uuid}]->(m)
 63 |                 DELETE e
 64 |                 """,
 65 |                 uuid=self.uuid,
 66 |             )
 67 |             await driver.execute_query(
 68 |                 """
 69 |                 MATCH (e:RelatesToNode_ {uuid: $uuid})
 70 |                 DETACH DELETE e
 71 |                 """,
 72 |                 uuid=self.uuid,
 73 |             )
 74 |         else:
 75 |             await driver.execute_query(
 76 |                 """
 77 |                 MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER {uuid: $uuid}]->(m)
 78 |                 DELETE e
 79 |                 """,
 80 |                 uuid=self.uuid,
 81 |             )
 82 | 
 83 |         logger.debug(f'Deleted Edge: {self.uuid}')
 84 | 
 85 |     @classmethod
 86 |     async def delete_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
 87 |         if driver.graph_operations_interface:
 88 |             return await driver.graph_operations_interface.edge_delete_by_uuids(cls, driver, uuids)
 89 | 
 90 |         if driver.provider == GraphProvider.KUZU:
 91 |             await driver.execute_query(
 92 |                 """
 93 |                 MATCH (n)-[e:MENTIONS|HAS_MEMBER]->(m)
 94 |                 WHERE e.uuid IN $uuids
 95 |                 DELETE e
 96 |                 """,
 97 |                 uuids=uuids,
 98 |             )
 99 |             await driver.execute_query(
100 |                 """
101 |                 MATCH (e:RelatesToNode_)
102 |                 WHERE e.uuid IN $uuids
103 |                 DETACH DELETE e
104 |                 """,
105 |                 uuids=uuids,
106 |             )
107 |         else:
108 |             await driver.execute_query(
109 |                 """
110 |                 MATCH (n)-[e:MENTIONS|RELATES_TO|HAS_MEMBER]->(m)
111 |                 WHERE e.uuid IN $uuids
112 |                 DELETE e
113 |                 """,
114 |                 uuids=uuids,
115 |             )
116 | 
117 |         logger.debug(f'Deleted Edges: {uuids}')
118 | 
119 |     def __hash__(self):
120 |         return hash(self.uuid)
121 | 
122 |     def __eq__(self, other):
123 |         if isinstance(other, Edge):
124 |             return self.uuid == other.uuid
125 |         return False
126 | 
127 |     @classmethod
128 |     async def get_by_uuid(cls, driver: GraphDriver, uuid: str): ...
129 | 
130 | 
131 | class EpisodicEdge(Edge):
132 |     async def save(self, driver: GraphDriver):
133 |         result = await driver.execute_query(
134 |             EPISODIC_EDGE_SAVE,
135 |             episode_uuid=self.source_node_uuid,
136 |             entity_uuid=self.target_node_uuid,
137 |             uuid=self.uuid,
138 |             group_id=self.group_id,
139 |             created_at=self.created_at,
140 |         )
141 | 
142 |         logger.debug(f'Saved edge to Graph: {self.uuid}')
143 | 
144 |         return result
145 | 
146 |     @classmethod
147 |     async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
148 |         records, _, _ = await driver.execute_query(
149 |             """
150 |             MATCH (n:Episodic)-[e:MENTIONS {uuid: $uuid}]->(m:Entity)
151 |             RETURN
152 |             """
153 |             + EPISODIC_EDGE_RETURN,
154 |             uuid=uuid,
155 |             routing_='r',
156 |         )
157 | 
158 |         edges = [get_episodic_edge_from_record(record) for record in records]
159 | 
160 |         if len(edges) == 0:
161 |             raise EdgeNotFoundError(uuid)
162 |         return edges[0]
163 | 
164 |     @classmethod
165 |     async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
166 |         records, _, _ = await driver.execute_query(
167 |             """
168 |             MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
169 |             WHERE e.uuid IN $uuids
170 |             RETURN
171 |             """
172 |             + EPISODIC_EDGE_RETURN,
173 |             uuids=uuids,
174 |             routing_='r',
175 |         )
176 | 
177 |         edges = [get_episodic_edge_from_record(record) for record in records]
178 | 
179 |         if len(edges) == 0:
180 |             raise EdgeNotFoundError(uuids[0])
181 |         return edges
182 | 
183 |     @classmethod
184 |     async def get_by_group_ids(
185 |         cls,
186 |         driver: GraphDriver,
187 |         group_ids: list[str],
188 |         limit: int | None = None,
189 |         uuid_cursor: str | None = None,
190 |     ):
191 |         cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
192 |         limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
193 | 
194 |         records, _, _ = await driver.execute_query(
195 |             """
196 |             MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity)
197 |             WHERE e.group_id IN $group_ids
198 |             """
199 |             + cursor_query
200 |             + """
201 |             RETURN
202 |             """
203 |             + EPISODIC_EDGE_RETURN
204 |             + """
205 |             ORDER BY e.uuid DESC
206 |             """
207 |             + limit_query,
208 |             group_ids=group_ids,
209 |             uuid=uuid_cursor,
210 |             limit=limit,
211 |             routing_='r',
212 |         )
213 | 
214 |         edges = [get_episodic_edge_from_record(record) for record in records]
215 | 
216 |         if len(edges) == 0:
217 |             raise GroupsEdgesNotFoundError(group_ids)
218 |         return edges
219 | 
220 | 
221 | class EntityEdge(Edge):
222 |     name: str = Field(description='name of the edge, relation name')
223 |     fact: str = Field(description='fact representing the edge and nodes that it connects')
224 |     fact_embedding: list[float] | None = Field(default=None, description='embedding of the fact')
225 |     episodes: list[str] = Field(
226 |         default=[],
227 |         description='list of episode ids that reference this entity edge',
228 |     )
229 |     expired_at: datetime | None = Field(
230 |         default=None, description='datetime of when the edge was invalidated'
231 |     )
232 |     valid_at: datetime | None = Field(
233 |         default=None, description='datetime of when the fact became true'
234 |     )
235 |     invalid_at: datetime | None = Field(
236 |         default=None, description='datetime of when the fact stopped being true'
237 |     )
238 |     attributes: dict[str, Any] = Field(
239 |         default={}, description='Additional attributes of the edge. Dependent on edge name'
240 |     )
241 | 
242 |     async def generate_embedding(self, embedder: EmbedderClient):
243 |         start = time()
244 | 
245 |         text = self.fact.replace('\n', ' ')
246 |         self.fact_embedding = await embedder.create(input_data=[text])
247 | 
248 |         end = time()
249 |         logger.debug(f'embedded {text} in {(end - start) * 1000} ms')
250 | 
251 |         return self.fact_embedding
252 | 
253 |     async def load_fact_embedding(self, driver: GraphDriver):
254 |         if driver.graph_operations_interface:
255 |             return await driver.graph_operations_interface.edge_load_embeddings(self, driver)
256 | 
257 |         query = """
258 |             MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
259 |             RETURN e.fact_embedding AS fact_embedding
260 |         """
261 | 
262 |         if driver.provider == GraphProvider.NEPTUNE:
263 |             query = """
264 |                 MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
265 |                 RETURN [x IN split(e.fact_embedding, ",") | toFloat(x)] as fact_embedding
266 |             """
267 | 
268 |         if driver.provider == GraphProvider.KUZU:
269 |             query = """
270 |                 MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
271 |                 RETURN e.fact_embedding AS fact_embedding
272 |             """
273 | 
274 |         records, _, _ = await driver.execute_query(
275 |             query,
276 |             uuid=self.uuid,
277 |             routing_='r',
278 |         )
279 | 
280 |         if len(records) == 0:
281 |             raise EdgeNotFoundError(self.uuid)
282 | 
283 |         self.fact_embedding = records[0]['fact_embedding']
284 | 
285 |     async def save(self, driver: GraphDriver):
286 |         edge_data: dict[str, Any] = {
287 |             'source_uuid': self.source_node_uuid,
288 |             'target_uuid': self.target_node_uuid,
289 |             'uuid': self.uuid,
290 |             'name': self.name,
291 |             'group_id': self.group_id,
292 |             'fact': self.fact,
293 |             'fact_embedding': self.fact_embedding,
294 |             'episodes': self.episodes,
295 |             'created_at': self.created_at,
296 |             'expired_at': self.expired_at,
297 |             'valid_at': self.valid_at,
298 |             'invalid_at': self.invalid_at,
299 |         }
300 | 
301 |         if driver.provider == GraphProvider.KUZU:
302 |             edge_data['attributes'] = json.dumps(self.attributes)
303 |             result = await driver.execute_query(
304 |                 get_entity_edge_save_query(driver.provider),
305 |                 **edge_data,
306 |             )
307 |         else:
308 |             edge_data.update(self.attributes or {})
309 |             result = await driver.execute_query(
310 |                 get_entity_edge_save_query(driver.provider),
311 |                 edge_data=edge_data,
312 |             )
313 | 
314 |         logger.debug(f'Saved edge to Graph: {self.uuid}')
315 | 
316 |         return result
317 | 
318 |     @classmethod
319 |     async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
320 |         match_query = """
321 |             MATCH (n:Entity)-[e:RELATES_TO {uuid: $uuid}]->(m:Entity)
322 |         """
323 |         if driver.provider == GraphProvider.KUZU:
324 |             match_query = """
325 |                 MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(m:Entity)
326 |             """
327 | 
328 |         records, _, _ = await driver.execute_query(
329 |             match_query
330 |             + """
331 |             RETURN
332 |             """
333 |             + get_entity_edge_return_query(driver.provider),
334 |             uuid=uuid,
335 |             routing_='r',
336 |         )
337 | 
338 |         edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
339 | 
340 |         if len(edges) == 0:
341 |             raise EdgeNotFoundError(uuid)
342 |         return edges[0]
343 | 
344 |     @classmethod
345 |     async def get_between_nodes(
346 |         cls, driver: GraphDriver, source_node_uuid: str, target_node_uuid: str
347 |     ):
348 |         match_query = """
349 |             MATCH (n:Entity {uuid: $source_node_uuid})-[e:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
350 |         """
351 |         if driver.provider == GraphProvider.KUZU:
352 |             match_query = """
353 |                 MATCH (n:Entity {uuid: $source_node_uuid})
354 |                       -[:RELATES_TO]->(e:RelatesToNode_)
355 |                       -[:RELATES_TO]->(m:Entity {uuid: $target_node_uuid})
356 |             """
357 | 
358 |         records, _, _ = await driver.execute_query(
359 |             match_query
360 |             + """
361 |             RETURN
362 |             """
363 |             + get_entity_edge_return_query(driver.provider),
364 |             source_node_uuid=source_node_uuid,
365 |             target_node_uuid=target_node_uuid,
366 |             routing_='r',
367 |         )
368 | 
369 |         edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
370 | 
371 |         return edges
372 | 
373 |     @classmethod
374 |     async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
375 |         if len(uuids) == 0:
376 |             return []
377 | 
378 |         match_query = """
379 |             MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
380 |         """
381 |         if driver.provider == GraphProvider.KUZU:
382 |             match_query = """
383 |                 MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
384 |             """
385 | 
386 |         records, _, _ = await driver.execute_query(
387 |             match_query
388 |             + """
389 |             WHERE e.uuid IN $uuids
390 |             RETURN
391 |             """
392 |             + get_entity_edge_return_query(driver.provider),
393 |             uuids=uuids,
394 |             routing_='r',
395 |         )
396 | 
397 |         edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
398 | 
399 |         return edges
400 | 
401 |     @classmethod
402 |     async def get_by_group_ids(
403 |         cls,
404 |         driver: GraphDriver,
405 |         group_ids: list[str],
406 |         limit: int | None = None,
407 |         uuid_cursor: str | None = None,
408 |         with_embeddings: bool = False,
409 |     ):
410 |         cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
411 |         limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
412 |         with_embeddings_query: LiteralString = (
413 |             """,
414 |                 e.fact_embedding AS fact_embedding
415 |                 """
416 |             if with_embeddings
417 |             else ''
418 |         )
419 | 
420 |         match_query = """
421 |             MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
422 |         """
423 |         if driver.provider == GraphProvider.KUZU:
424 |             match_query = """
425 |                 MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
426 |             """
427 | 
428 |         records, _, _ = await driver.execute_query(
429 |             match_query
430 |             + """
431 |             WHERE e.group_id IN $group_ids
432 |             """
433 |             + cursor_query
434 |             + """
435 |             RETURN
436 |             """
437 |             + get_entity_edge_return_query(driver.provider)
438 |             + with_embeddings_query
439 |             + """
440 |             ORDER BY e.uuid DESC
441 |             """
442 |             + limit_query,
443 |             group_ids=group_ids,
444 |             uuid=uuid_cursor,
445 |             limit=limit,
446 |             routing_='r',
447 |         )
448 | 
449 |         edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
450 | 
451 |         if len(edges) == 0:
452 |             raise GroupsEdgesNotFoundError(group_ids)
453 |         return edges
454 | 
455 |     @classmethod
456 |     async def get_by_node_uuid(cls, driver: GraphDriver, node_uuid: str):
457 |         match_query = """
458 |             MATCH (n:Entity {uuid: $node_uuid})-[e:RELATES_TO]-(m:Entity)
459 |         """
460 |         if driver.provider == GraphProvider.KUZU:
461 |             match_query = """
462 |                 MATCH (n:Entity {uuid: $node_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
463 |             """
464 | 
465 |         records, _, _ = await driver.execute_query(
466 |             match_query
467 |             + """
468 |             RETURN
469 |             """
470 |             + get_entity_edge_return_query(driver.provider),
471 |             node_uuid=node_uuid,
472 |             routing_='r',
473 |         )
474 | 
475 |         edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
476 | 
477 |         return edges
478 | 
479 | 
480 | class CommunityEdge(Edge):
481 |     async def save(self, driver: GraphDriver):
482 |         result = await driver.execute_query(
483 |             get_community_edge_save_query(driver.provider),
484 |             community_uuid=self.source_node_uuid,
485 |             entity_uuid=self.target_node_uuid,
486 |             uuid=self.uuid,
487 |             group_id=self.group_id,
488 |             created_at=self.created_at,
489 |         )
490 | 
491 |         logger.debug(f'Saved edge to Graph: {self.uuid}')
492 | 
493 |         return result
494 | 
495 |     @classmethod
496 |     async def get_by_uuid(cls, driver: GraphDriver, uuid: str):
497 |         records, _, _ = await driver.execute_query(
498 |             """
499 |             MATCH (n:Community)-[e:HAS_MEMBER {uuid: $uuid}]->(m)
500 |             RETURN
501 |             """
502 |             + COMMUNITY_EDGE_RETURN,
503 |             uuid=uuid,
504 |             routing_='r',
505 |         )
506 | 
507 |         edges = [get_community_edge_from_record(record) for record in records]
508 | 
509 |         return edges[0]
510 | 
511 |     @classmethod
512 |     async def get_by_uuids(cls, driver: GraphDriver, uuids: list[str]):
513 |         records, _, _ = await driver.execute_query(
514 |             """
515 |             MATCH (n:Community)-[e:HAS_MEMBER]->(m)
516 |             WHERE e.uuid IN $uuids
517 |             RETURN
518 |             """
519 |             + COMMUNITY_EDGE_RETURN,
520 |             uuids=uuids,
521 |             routing_='r',
522 |         )
523 | 
524 |         edges = [get_community_edge_from_record(record) for record in records]
525 | 
526 |         return edges
527 | 
528 |     @classmethod
529 |     async def get_by_group_ids(
530 |         cls,
531 |         driver: GraphDriver,
532 |         group_ids: list[str],
533 |         limit: int | None = None,
534 |         uuid_cursor: str | None = None,
535 |     ):
536 |         cursor_query: LiteralString = 'AND e.uuid < $uuid' if uuid_cursor else ''
537 |         limit_query: LiteralString = 'LIMIT $limit' if limit is not None else ''
538 | 
539 |         records, _, _ = await driver.execute_query(
540 |             """
541 |             MATCH (n:Community)-[e:HAS_MEMBER]->(m)
542 |             WHERE e.group_id IN $group_ids
543 |             """
544 |             + cursor_query
545 |             + """
546 |             RETURN
547 |             """
548 |             + COMMUNITY_EDGE_RETURN
549 |             + """
550 |             ORDER BY e.uuid DESC
551 |             """
552 |             + limit_query,
553 |             group_ids=group_ids,
554 |             uuid=uuid_cursor,
555 |             limit=limit,
556 |             routing_='r',
557 |         )
558 | 
559 |         edges = [get_community_edge_from_record(record) for record in records]
560 | 
561 |         return edges
562 | 
563 | 
564 | # Edge helpers
565 | def get_episodic_edge_from_record(record: Any) -> EpisodicEdge:
566 |     return EpisodicEdge(
567 |         uuid=record['uuid'],
568 |         group_id=record['group_id'],
569 |         source_node_uuid=record['source_node_uuid'],
570 |         target_node_uuid=record['target_node_uuid'],
571 |         created_at=parse_db_date(record['created_at']),  # type: ignore
572 |     )
573 | 
574 | 
575 | def get_entity_edge_from_record(record: Any, provider: GraphProvider) -> EntityEdge:
576 |     episodes = record['episodes']
577 |     if provider == GraphProvider.KUZU:
578 |         attributes = json.loads(record['attributes']) if record['attributes'] else {}
579 |     else:
580 |         attributes = record['attributes']
581 |         attributes.pop('uuid', None)
582 |         attributes.pop('source_node_uuid', None)
583 |         attributes.pop('target_node_uuid', None)
584 |         attributes.pop('fact', None)
585 |         attributes.pop('fact_embedding', None)
586 |         attributes.pop('name', None)
587 |         attributes.pop('group_id', None)
588 |         attributes.pop('episodes', None)
589 |         attributes.pop('created_at', None)
590 |         attributes.pop('expired_at', None)
591 |         attributes.pop('valid_at', None)
592 |         attributes.pop('invalid_at', None)
593 | 
594 |     edge = EntityEdge(
595 |         uuid=record['uuid'],
596 |         source_node_uuid=record['source_node_uuid'],
597 |         target_node_uuid=record['target_node_uuid'],
598 |         fact=record['fact'],
599 |         fact_embedding=record.get('fact_embedding'),
600 |         name=record['name'],
601 |         group_id=record['group_id'],
602 |         episodes=episodes,
603 |         created_at=parse_db_date(record['created_at']),  # type: ignore
604 |         expired_at=parse_db_date(record['expired_at']),
605 |         valid_at=parse_db_date(record['valid_at']),
606 |         invalid_at=parse_db_date(record['invalid_at']),
607 |         attributes=attributes,
608 |     )
609 | 
610 |     return edge
611 | 
612 | 
613 | def get_community_edge_from_record(record: Any):
614 |     return CommunityEdge(
615 |         uuid=record['uuid'],
616 |         group_id=record['group_id'],
617 |         source_node_uuid=record['source_node_uuid'],
618 |         target_node_uuid=record['target_node_uuid'],
619 |         created_at=parse_db_date(record['created_at']),  # type: ignore
620 |     )
621 | 
622 | 
623 | async def create_entity_edge_embeddings(embedder: EmbedderClient, edges: list[EntityEdge]):
624 |     # filter out edges whose fact is falsy (empty)
625 |     filtered_edges = [edge for edge in edges if edge.fact]
626 | 
627 |     if len(filtered_edges) == 0:
628 |         return
629 |     fact_embeddings = await embedder.create_batch([edge.fact for edge in filtered_edges])
630 |     for edge, fact_embedding in zip(filtered_edges, fact_embeddings, strict=True):
631 |         edge.fact_embedding = fact_embedding
632 | 
```
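
The `get_by_group_ids` class methods above share a keyset-pagination pattern: rows are ordered by `e.uuid DESC`, and passing `uuid_cursor` adds `AND e.uuid < $uuid`, so each page resumes strictly below the last UUID already seen. A sketch of paging through a group's entity edges this way, assuming an already-configured `GraphDriver` instance named `driver` (the helper function itself is hypothetical):

```python
# Sketch of cursor-based paging over EntityEdge.get_by_group_ids, following the
# signature shown above. `driver` is assumed to be a configured GraphDriver.
from graphiti_core.edges import EntityEdge
from graphiti_core.errors import GroupsEdgesNotFoundError


async def fetch_all_group_edges(driver, group_id: str, page_size: int = 100) -> list[EntityEdge]:
    edges: list[EntityEdge] = []
    cursor: str | None = None
    while True:
        try:
            page = await EntityEdge.get_by_group_ids(
                driver, [group_id], limit=page_size, uuid_cursor=cursor
            )
        except GroupsEdgesNotFoundError:
            break  # the method raises rather than returning an empty page
        edges.extend(page)
        if len(page) < page_size:
            break
        # results are ordered by uuid DESC, so the last uuid becomes the next cursor
        cursor = page[-1].uuid
    return edges
```

Because the cursor is the UUID itself rather than an offset, pages stay stable even if edges are inserted concurrently; `EpisodicEdge` and `CommunityEdge` expose the same `uuid_cursor` parameter.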

--------------------------------------------------------------------------------
/tests/utils/maintenance/test_node_operations.py:
--------------------------------------------------------------------------------

```python
  1 | import logging
  2 | from collections import defaultdict
  3 | from unittest.mock import AsyncMock, MagicMock
  4 | 
  5 | import pytest
  6 | 
  7 | from graphiti_core.graphiti_types import GraphitiClients
  8 | from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
  9 | from graphiti_core.search.search_config import SearchResults
 10 | from graphiti_core.utils.datetime_utils import utc_now
 11 | from graphiti_core.utils.maintenance.dedup_helpers import (
 12 |     DedupCandidateIndexes,
 13 |     DedupResolutionState,
 14 |     _build_candidate_indexes,
 15 |     _cached_shingles,
 16 |     _has_high_entropy,
 17 |     _hash_shingle,
 18 |     _jaccard_similarity,
 19 |     _lsh_bands,
 20 |     _minhash_signature,
 21 |     _name_entropy,
 22 |     _normalize_name_for_fuzzy,
 23 |     _normalize_string_exact,
 24 |     _resolve_with_similarity,
 25 |     _shingles,
 26 | )
 27 | from graphiti_core.utils.maintenance.node_operations import (
 28 |     _collect_candidate_nodes,
 29 |     _resolve_with_llm,
 30 |     extract_attributes_from_node,
 31 |     extract_attributes_from_nodes,
 32 |     resolve_extracted_nodes,
 33 | )
 34 | 
 35 | 
 36 | def _make_clients():
 37 |     driver = MagicMock()
 38 |     embedder = MagicMock()
 39 |     cross_encoder = MagicMock()
 40 |     llm_client = MagicMock()
 41 |     llm_generate = AsyncMock()
 42 |     llm_client.generate_response = llm_generate
 43 | 
 44 |     clients = GraphitiClients.model_construct(  # bypass validation to allow test doubles
 45 |         driver=driver,
 46 |         embedder=embedder,
 47 |         cross_encoder=cross_encoder,
 48 |         llm_client=llm_client,
 49 |     )
 50 | 
 51 |     return clients, llm_generate
 52 | 
 53 | 
 54 | def _make_episode(group_id: str = 'group'):
 55 |     return EpisodicNode(
 56 |         name='episode',
 57 |         group_id=group_id,
 58 |         source=EpisodeType.message,
 59 |         source_description='test',
 60 |         content='content',
 61 |         valid_at=utc_now(),
 62 |     )
 63 | 
 64 | 
 65 | @pytest.mark.asyncio
 66 | async def test_resolve_nodes_exact_match_skips_llm(monkeypatch):
 67 |     clients, llm_generate = _make_clients()
 68 | 
 69 |     candidate = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])
 70 |     extracted = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])
 71 | 
 72 |     async def fake_search(*_, **__):
 73 |         return SearchResults(nodes=[candidate])
 74 | 
 75 |     monkeypatch.setattr(
 76 |         'graphiti_core.utils.maintenance.node_operations.search',
 77 |         fake_search,
 78 |     )
 79 |     monkeypatch.setattr(
 80 |         'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges',
 81 |         AsyncMock(return_value=[]),
 82 |     )
 83 | 
 84 |     resolved, uuid_map, _ = await resolve_extracted_nodes(
 85 |         clients,
 86 |         [extracted],
 87 |         episode=_make_episode(),
 88 |         previous_episodes=[],
 89 |     )
 90 | 
 91 |     assert resolved[0].uuid == candidate.uuid
 92 |     assert uuid_map[extracted.uuid] == candidate.uuid
 93 |     llm_generate.assert_not_awaited()
 94 | 
 95 | 
 96 | @pytest.mark.asyncio
 97 | async def test_resolve_nodes_low_entropy_uses_llm(monkeypatch):
 98 |     clients, llm_generate = _make_clients()
 99 |     llm_generate.return_value = {
100 |         'entity_resolutions': [
101 |             {
102 |                 'id': 0,
103 |                 'duplicate_idx': -1,
104 |                 'name': 'Joe',
105 |                 'duplicates': [],
106 |             }
107 |         ]
108 |     }
109 | 
110 |     extracted = EntityNode(name='Joe', group_id='group', labels=['Entity'])
111 | 
112 |     async def fake_search(*_, **__):
113 |         return SearchResults(nodes=[])
114 | 
115 |     monkeypatch.setattr(
116 |         'graphiti_core.utils.maintenance.node_operations.search',
117 |         fake_search,
118 |     )
119 |     monkeypatch.setattr(
120 |         'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges',
121 |         AsyncMock(return_value=[]),
122 |     )
123 | 
124 |     resolved, uuid_map, _ = await resolve_extracted_nodes(
125 |         clients,
126 |         [extracted],
127 |         episode=_make_episode(),
128 |         previous_episodes=[],
129 |     )
130 | 
131 |     assert resolved[0].uuid == extracted.uuid
132 |     assert uuid_map[extracted.uuid] == extracted.uuid
133 |     llm_generate.assert_awaited()
134 | 
135 | 
136 | @pytest.mark.asyncio
137 | async def test_resolve_nodes_fuzzy_match(monkeypatch):
138 |     clients, llm_generate = _make_clients()
139 | 
140 |     candidate = EntityNode(name='Joe-Michaels', group_id='group', labels=['Entity'])
141 |     extracted = EntityNode(name='Joe Michaels', group_id='group', labels=['Entity'])
142 | 
143 |     async def fake_search(*_, **__):
144 |         return SearchResults(nodes=[candidate])
145 | 
146 |     monkeypatch.setattr(
147 |         'graphiti_core.utils.maintenance.node_operations.search',
148 |         fake_search,
149 |     )
150 |     monkeypatch.setattr(
151 |         'graphiti_core.utils.maintenance.node_operations.filter_existing_duplicate_of_edges',
152 |         AsyncMock(return_value=[]),
153 |     )
154 | 
155 |     resolved, uuid_map, _ = await resolve_extracted_nodes(
156 |         clients,
157 |         [extracted],
158 |         episode=_make_episode(),
159 |         previous_episodes=[],
160 |     )
161 | 
162 |     assert resolved[0].uuid == candidate.uuid
163 |     assert uuid_map[extracted.uuid] == candidate.uuid
164 |     llm_generate.assert_not_awaited()
165 | 
166 | 
167 | @pytest.mark.asyncio
168 | async def test_collect_candidate_nodes_dedupes_and_merges_override(monkeypatch):
169 |     clients, _ = _make_clients()
170 | 
171 |     candidate = EntityNode(name='Alice', group_id='group', labels=['Entity'])
172 |     override_duplicate = EntityNode(
173 |         uuid=candidate.uuid,
174 |         name='Alice Alt',
175 |         group_id='group',
176 |         labels=['Entity'],
177 |     )
178 |     extracted = EntityNode(name='Alice', group_id='group', labels=['Entity'])
179 | 
180 |     search_mock = AsyncMock(return_value=SearchResults(nodes=[candidate]))
181 |     monkeypatch.setattr(
182 |         'graphiti_core.utils.maintenance.node_operations.search',
183 |         search_mock,
184 |     )
185 | 
186 |     result = await _collect_candidate_nodes(
187 |         clients,
188 |         [extracted],
189 |         existing_nodes_override=[override_duplicate],
190 |     )
191 | 
192 |     assert len(result) == 1
193 |     assert result[0].uuid == candidate.uuid
194 |     search_mock.assert_awaited()
195 | 
196 | 
197 | def test_build_candidate_indexes_populates_structures():
198 |     candidate = EntityNode(name='Bob Dylan', group_id='group', labels=['Entity'])
199 | 
200 |     indexes = _build_candidate_indexes([candidate])
201 | 
202 |     normalized_key = candidate.name.lower()
203 |     assert indexes.normalized_existing[normalized_key][0].uuid == candidate.uuid
204 |     assert indexes.nodes_by_uuid[candidate.uuid] is candidate
205 |     assert candidate.uuid in indexes.shingles_by_candidate
206 |     assert any(candidate.uuid in bucket for bucket in indexes.lsh_buckets.values())
207 | 
208 | 
209 | def test_normalize_helpers():
210 |     assert _normalize_string_exact('  Alice   Smith ') == 'alice smith'
211 |     assert _normalize_name_for_fuzzy('Alice-Smith!') == 'alice smith'
212 | 
213 | 
214 | def test_name_entropy_variants():
215 |     assert _name_entropy('alice') > _name_entropy('aaaaa')
216 |     assert _name_entropy('') == 0.0
217 | 
218 | 
219 | def test_has_high_entropy_rules():
220 |     assert _has_high_entropy('meaningful name') is True
221 |     assert _has_high_entropy('aa') is False
222 | 
223 | 
224 | def test_shingles_and_cache():
225 |     raw = 'alice'
226 |     shingle_set = _shingles(raw)
227 |     assert shingle_set == {'ali', 'lic', 'ice'}
228 |     assert _cached_shingles(raw) == shingle_set
229 |     assert _cached_shingles(raw) is _cached_shingles(raw)
230 | 
231 | 
232 | def test_hash_minhash_and_lsh():
233 |     shingles = {'abc', 'bcd', 'cde'}
234 |     signature = _minhash_signature(shingles)
235 |     assert len(signature) == 32
236 |     bands = _lsh_bands(signature)
237 |     assert all(len(band) == 4 for band in bands)
238 |     hashed = {_hash_shingle(s, 0) for s in shingles}
239 |     assert len(hashed) == len(shingles)
240 | 
241 | 
242 | def test_jaccard_similarity_edges():
243 |     a = {'a', 'b'}
244 |     b = {'a', 'c'}
245 |     assert _jaccard_similarity(a, b) == pytest.approx(1 / 3)
246 |     assert _jaccard_similarity(set(), set()) == 1.0
247 |     assert _jaccard_similarity(a, set()) == 0.0
248 | 
249 | 
250 | def test_resolve_with_similarity_exact_match_updates_state():
251 |     candidate = EntityNode(name='Charlie Parker', group_id='group', labels=['Entity'])
252 |     extracted = EntityNode(name='Charlie Parker', group_id='group', labels=['Entity'])
253 | 
254 |     indexes = _build_candidate_indexes([candidate])
255 |     state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])
256 | 
257 |     _resolve_with_similarity([extracted], indexes, state)
258 | 
259 |     assert state.resolved_nodes[0].uuid == candidate.uuid
260 |     assert state.uuid_map[extracted.uuid] == candidate.uuid
261 |     assert state.unresolved_indices == []
262 |     assert state.duplicate_pairs == [(extracted, candidate)]
263 | 
264 | 
265 | def test_resolve_with_similarity_low_entropy_defers_resolution():
266 |     extracted = EntityNode(name='Bob', group_id='group', labels=['Entity'])
267 |     indexes = DedupCandidateIndexes(
268 |         existing_nodes=[],
269 |         nodes_by_uuid={},
270 |         normalized_existing=defaultdict(list),
271 |         shingles_by_candidate={},
272 |         lsh_buckets=defaultdict(list),
273 |     )
274 |     state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])
275 | 
276 |     _resolve_with_similarity([extracted], indexes, state)
277 | 
278 |     assert state.resolved_nodes[0] is None
279 |     assert state.unresolved_indices == [0]
280 |     assert state.duplicate_pairs == []
281 | 
282 | 
283 | def test_resolve_with_similarity_multiple_exact_matches_defers_to_llm():
284 |     candidate1 = EntityNode(name='Johnny Appleseed', group_id='group', labels=['Entity'])
285 |     candidate2 = EntityNode(name='Johnny Appleseed', group_id='group', labels=['Entity'])
286 |     extracted = EntityNode(name='Johnny Appleseed', group_id='group', labels=['Entity'])
287 | 
288 |     indexes = _build_candidate_indexes([candidate1, candidate2])
289 |     state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[])
290 | 
291 |     _resolve_with_similarity([extracted], indexes, state)
292 | 
293 |     assert state.resolved_nodes[0] is None
294 |     assert state.unresolved_indices == [0]
295 |     assert state.duplicate_pairs == []
296 | 
297 | 
298 | @pytest.mark.asyncio
299 | async def test_resolve_with_llm_updates_unresolved(monkeypatch):
300 |     extracted = EntityNode(name='Dizzy', group_id='group', labels=['Entity'])
301 |     candidate = EntityNode(name='Dizzy Gillespie', group_id='group', labels=['Entity'])
302 | 
303 |     indexes = _build_candidate_indexes([candidate])
304 |     state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
305 | 
306 |     captured_context = {}
307 | 
308 |     def fake_prompt_nodes(context):
309 |         captured_context.update(context)
310 |         return ['prompt']
311 | 
312 |     monkeypatch.setattr(
313 |         'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
314 |         fake_prompt_nodes,
315 |     )
316 | 
317 |     async def fake_generate_response(*_, **__):
318 |         return {
319 |             'entity_resolutions': [
320 |                 {
321 |                     'id': 0,
322 |                     'duplicate_idx': 0,
323 |                     'name': 'Dizzy Gillespie',
324 |                     'duplicates': [0],
325 |                 }
326 |             ]
327 |         }
328 | 
329 |     llm_client = MagicMock()
330 |     llm_client.generate_response = AsyncMock(side_effect=fake_generate_response)
331 | 
332 |     await _resolve_with_llm(
333 |         llm_client,
334 |         [extracted],
335 |         indexes,
336 |         state,
337 |         episode=_make_episode(),
338 |         previous_episodes=[],
339 |         entity_types=None,
340 |     )
341 | 
342 |     assert state.resolved_nodes[0].uuid == candidate.uuid
343 |     assert state.uuid_map[extracted.uuid] == candidate.uuid
344 |     assert captured_context['existing_nodes'][0]['idx'] == 0
345 |     assert isinstance(captured_context['existing_nodes'], list)
346 |     assert state.duplicate_pairs == [(extracted, candidate)]
347 | 
348 | 
349 | @pytest.mark.asyncio
350 | async def test_resolve_with_llm_ignores_out_of_range_relative_ids(monkeypatch, caplog):
351 |     extracted = EntityNode(name='Dexter', group_id='group', labels=['Entity'])
352 | 
353 |     indexes = _build_candidate_indexes([])
354 |     state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
355 | 
356 |     monkeypatch.setattr(
357 |         'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
358 |         lambda context: ['prompt'],
359 |     )
360 | 
361 |     llm_client = MagicMock()
362 |     llm_client.generate_response = AsyncMock(
363 |         return_value={
364 |             'entity_resolutions': [
365 |                 {
366 |                     'id': 5,
367 |                     'duplicate_idx': -1,
368 |                     'name': 'Dexter',
369 |                     'duplicates': [],
370 |                 }
371 |             ]
372 |         }
373 |     )
374 | 
375 |     with caplog.at_level(logging.WARNING):
376 |         await _resolve_with_llm(
377 |             llm_client,
378 |             [extracted],
379 |             indexes,
380 |             state,
381 |             episode=_make_episode(),
382 |             previous_episodes=[],
383 |             entity_types=None,
384 |         )
385 | 
386 |     assert state.resolved_nodes[0] is None
387 |     assert 'Skipping invalid LLM dedupe id 5' in caplog.text
388 | 
389 | 
390 | @pytest.mark.asyncio
391 | async def test_resolve_with_llm_ignores_duplicate_relative_ids(monkeypatch):
392 |     extracted = EntityNode(name='Dizzy', group_id='group', labels=['Entity'])
393 |     candidate = EntityNode(name='Dizzy Gillespie', group_id='group', labels=['Entity'])
394 | 
395 |     indexes = _build_candidate_indexes([candidate])
396 |     state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
397 | 
398 |     monkeypatch.setattr(
399 |         'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
400 |         lambda context: ['prompt'],
401 |     )
402 | 
403 |     llm_client = MagicMock()
404 |     llm_client.generate_response = AsyncMock(
405 |         return_value={
406 |             'entity_resolutions': [
407 |                 {
408 |                     'id': 0,
409 |                     'duplicate_idx': 0,
410 |                     'name': 'Dizzy Gillespie',
411 |                     'duplicates': [0],
412 |                 },
413 |                 {
414 |                     'id': 0,
415 |                     'duplicate_idx': -1,
416 |                     'name': 'Dizzy',
417 |                     'duplicates': [],
418 |                 },
419 |             ]
420 |         }
421 |     )
422 | 
423 |     await _resolve_with_llm(
424 |         llm_client,
425 |         [extracted],
426 |         indexes,
427 |         state,
428 |         episode=_make_episode(),
429 |         previous_episodes=[],
430 |         entity_types=None,
431 |     )
432 | 
433 |     assert state.resolved_nodes[0].uuid == candidate.uuid
434 |     assert state.uuid_map[extracted.uuid] == candidate.uuid
435 |     assert state.duplicate_pairs == [(extracted, candidate)]
436 | 
437 | 
438 | @pytest.mark.asyncio
439 | async def test_resolve_with_llm_invalid_duplicate_idx_defaults_to_extracted(monkeypatch):
440 |     extracted = EntityNode(name='Dexter', group_id='group', labels=['Entity'])
441 | 
442 |     indexes = _build_candidate_indexes([])
443 |     state = DedupResolutionState(resolved_nodes=[None], uuid_map={}, unresolved_indices=[0])
444 | 
445 |     monkeypatch.setattr(
446 |         'graphiti_core.utils.maintenance.node_operations.prompt_library.dedupe_nodes.nodes',
447 |         lambda context: ['prompt'],
448 |     )
449 | 
450 |     llm_client = MagicMock()
451 |     llm_client.generate_response = AsyncMock(
452 |         return_value={
453 |             'entity_resolutions': [
454 |                 {
455 |                     'id': 0,
456 |                     'duplicate_idx': 10,
457 |                     'name': 'Dexter',
458 |                     'duplicates': [],
459 |                 }
460 |             ]
461 |         }
462 |     )
463 | 
464 |     await _resolve_with_llm(
465 |         llm_client,
466 |         [extracted],
467 |         indexes,
468 |         state,
469 |         episode=_make_episode(),
470 |         previous_episodes=[],
471 |         entity_types=None,
472 |     )
473 | 
474 |     assert state.resolved_nodes[0] == extracted
475 |     assert state.uuid_map[extracted.uuid] == extracted.uuid
476 |     assert state.duplicate_pairs == []
477 | 
478 | 
479 | @pytest.mark.asyncio
480 | async def test_extract_attributes_without_callback_generates_summary():
481 |     """Test that summary is generated when no callback is provided (default behavior)."""
482 |     llm_client = MagicMock()
483 |     llm_client.generate_response = AsyncMock(
484 |         return_value={'summary': 'Generated summary', 'attributes': {}}
485 |     )
486 | 
487 |     node = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')
488 |     episode = _make_episode()
489 | 
490 |     result = await extract_attributes_from_node(
491 |         llm_client,
492 |         node,
493 |         episode=episode,
494 |         previous_episodes=[],
495 |         entity_type=None,
496 |         should_summarize_node=None,  # No callback provided
497 |     )
498 | 
499 |     # Summary should be generated
500 |     assert result.summary == 'Generated summary'
501 |     # LLM should have been called for summary
502 |     assert llm_client.generate_response.call_count == 1
503 | 
504 | 
505 | @pytest.mark.asyncio
506 | async def test_extract_attributes_with_callback_skip_summary():
507 |     """Test that summary is NOT regenerated when callback returns False."""
508 |     llm_client = MagicMock()
509 |     llm_client.generate_response = AsyncMock(
510 |         return_value={'summary': 'This should not be used', 'attributes': {}}
511 |     )
512 | 
513 |     node = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')
514 |     episode = _make_episode()
515 | 
516 |     # Callback that always returns False (skip summary generation)
517 |     async def skip_summary_filter(node: EntityNode) -> bool:
518 |         return False
519 | 
520 |     result = await extract_attributes_from_node(
521 |         llm_client,
522 |         node,
523 |         episode=episode,
524 |         previous_episodes=[],
525 |         entity_type=None,
526 |         should_summarize_node=skip_summary_filter,
527 |     )
528 | 
529 |     # Summary should remain unchanged
530 |     assert result.summary == 'Old summary'
531 |     # LLM should NOT have been called for summary
532 |     assert llm_client.generate_response.call_count == 0
533 | 
534 | 
535 | @pytest.mark.asyncio
536 | async def test_extract_attributes_with_callback_generate_summary():
537 |     """Test that summary is regenerated when callback returns True."""
538 |     llm_client = MagicMock()
539 |     llm_client.generate_response = AsyncMock(
540 |         return_value={'summary': 'New generated summary', 'attributes': {}}
541 |     )
542 | 
543 |     node = EntityNode(name='Test Node', group_id='group', labels=['Entity'], summary='Old summary')
544 |     episode = _make_episode()
545 | 
546 |     # Callback that always returns True (generate summary)
547 |     async def generate_summary_filter(node: EntityNode) -> bool:
548 |         return True
549 | 
550 |     result = await extract_attributes_from_node(
551 |         llm_client,
552 |         node,
553 |         episode=episode,
554 |         previous_episodes=[],
555 |         entity_type=None,
556 |         should_summarize_node=generate_summary_filter,
557 |     )
558 | 
559 |     # Summary should be updated
560 |     assert result.summary == 'New generated summary'
561 |     # LLM should have been called for summary
562 |     assert llm_client.generate_response.call_count == 1
563 | 
564 | 
565 | @pytest.mark.asyncio
566 | async def test_extract_attributes_with_selective_callback():
567 |     """Test callback that selectively skips summaries based on node properties."""
568 |     llm_client = MagicMock()
569 |     llm_client.generate_response = AsyncMock(
570 |         return_value={'summary': 'Generated summary', 'attributes': {}}
571 |     )
572 | 
573 |     user_node = EntityNode(name='User', group_id='group', labels=['Entity', 'User'], summary='Old')
574 |     topic_node = EntityNode(
575 |         name='Topic', group_id='group', labels=['Entity', 'Topic'], summary='Old'
576 |     )
577 | 
578 |     episode = _make_episode()
579 | 
580 |     # Callback that skips User nodes but generates for others
581 |     async def selective_filter(node: EntityNode) -> bool:
582 |         return 'User' not in node.labels
583 | 
584 |     result_user = await extract_attributes_from_node(
585 |         llm_client,
586 |         user_node,
587 |         episode=episode,
588 |         previous_episodes=[],
589 |         entity_type=None,
590 |         should_summarize_node=selective_filter,
591 |     )
592 | 
593 |     result_topic = await extract_attributes_from_node(
594 |         llm_client,
595 |         topic_node,
596 |         episode=episode,
597 |         previous_episodes=[],
598 |         entity_type=None,
599 |         should_summarize_node=selective_filter,
600 |     )
601 | 
602 |     # User summary should remain unchanged
603 |     assert result_user.summary == 'Old'
604 |     # Topic summary should be generated
605 |     assert result_topic.summary == 'Generated summary'
606 |     # LLM should have been called only once (for topic)
607 |     assert llm_client.generate_response.call_count == 1
608 | 
609 | 
610 | @pytest.mark.asyncio
611 | async def test_extract_attributes_from_nodes_with_callback():
612 |     """Test that callback is properly passed through extract_attributes_from_nodes."""
613 |     clients, _ = _make_clients()
614 |     clients.llm_client.generate_response = AsyncMock(
615 |         return_value={'summary': 'New summary', 'attributes': {}}
616 |     )
617 |     clients.embedder.create = AsyncMock(return_value=[0.1, 0.2, 0.3])
618 |     clients.embedder.create_batch = AsyncMock(return_value=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
619 | 
620 |     node1 = EntityNode(name='Node1', group_id='group', labels=['Entity', 'User'], summary='Old1')
621 |     node2 = EntityNode(name='Node2', group_id='group', labels=['Entity', 'Topic'], summary='Old2')
622 | 
623 |     episode = _make_episode()
624 | 
625 |     call_tracker = []
626 | 
627 |     # Callback that tracks which nodes it's called with
628 |     async def tracking_filter(node: EntityNode) -> bool:
629 |         call_tracker.append(node.name)
630 |         return 'User' not in node.labels
631 | 
632 |     results = await extract_attributes_from_nodes(
633 |         clients,
634 |         [node1, node2],
635 |         episode=episode,
636 |         previous_episodes=[],
637 |         entity_types=None,
638 |         should_summarize_node=tracking_filter,
639 |     )
640 | 
641 |     # Callback should have been called for both nodes
642 |     assert len(call_tracker) == 2
643 |     assert 'Node1' in call_tracker
644 |     assert 'Node2' in call_tracker
645 | 
646 |     # Node1 (User) should keep old summary, Node2 (Topic) should get new summary
647 |     node1_result = next(n for n in results if n.name == 'Node1')
648 |     node2_result = next(n for n in results if n.name == 'Node2')
649 | 
650 |     assert node1_result.summary == 'Old1'
651 |     assert node2_result.summary == 'New summary'
652 | 
```
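
The dedupe tests above pin down several invariants of the fuzzy-matching pipeline: 3-character shingles, a 32-permutation MinHash signature, LSH bands of width 4, and the Jaccard edge cases. The sketch below re-implements that flow standalone to make the mechanics concrete; the hash function here is an assumption for illustration, not the repo's `_hash_shingle`.

```python
import hashlib


def shingles(name: str, k: int = 3) -> set[str]:
    # 'alice' -> {'ali', 'lic', 'ice'}, matching test_shingles_and_cache
    return {name[i : i + k] for i in range(len(name) - k + 1)}


def hash_shingle(shingle: str, seed: int) -> int:
    # Assumed hashing scheme (seeded blake2b), for illustration only
    digest = hashlib.blake2b(f'{seed}:{shingle}'.encode(), digest_size=8).digest()
    return int.from_bytes(digest, 'big')


def minhash_signature(shingle_set: set[str], perms: int = 32) -> tuple[int, ...]:
    # One min-hash per seeded permutation; similar sets share many minima
    return tuple(min(hash_shingle(s, seed) for s in shingle_set) for seed in range(perms))


def lsh_bands(signature: tuple[int, ...], band_size: int = 4) -> list[tuple[int, ...]]:
    # Candidates that collide on any band become dedupe candidates
    sig = list(signature)
    return [tuple(sig[i : i + band_size]) for i in range(0, len(sig), band_size)]


def jaccard(a: set[str], b: set[str]) -> float:
    if not a and not b:
        return 1.0
    return len(a & b) / len(a | b)


sig = minhash_signature(shingles('alice'))
assert len(sig) == 32 and all(len(band) == 4 for band in lsh_bands(sig))
```

Names that collide on a band are then confirmed or rejected with a cheap Jaccard check, so only the genuinely ambiguous cases fall through to the LLM resolver.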

--------------------------------------------------------------------------------
/tests/llm_client/test_gemini_client.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | # Running tests: pytest -xvs tests/llm_client/test_gemini_client.py
 18 | 
 19 | from unittest.mock import AsyncMock, MagicMock, patch
 20 | 
 21 | import pytest
 22 | from pydantic import BaseModel
 23 | 
 24 | from graphiti_core.llm_client.config import LLMConfig, ModelSize
 25 | from graphiti_core.llm_client.errors import RateLimitError
 26 | from graphiti_core.llm_client.gemini_client import DEFAULT_MODEL, DEFAULT_SMALL_MODEL, GeminiClient
 27 | from graphiti_core.prompts.models import Message
 28 | 
 29 | 
 30 | # Pydantic model used to exercise structured-output parsing in these tests
 31 | class ResponseModel(BaseModel):
 32 |     """Test model for response testing."""
 33 | 
 34 |     test_field: str
 35 |     optional_field: int = 0
 36 | 
 37 | 
 38 | @pytest.fixture
 39 | def mock_gemini_client():
 40 |     """Fixture to mock the Google Gemini client."""
 41 |     with patch('google.genai.Client') as mock_client:
 42 |         # Setup mock instance and its methods
 43 |         mock_instance = mock_client.return_value
 44 |         mock_instance.aio = MagicMock()
 45 |         mock_instance.aio.models = MagicMock()
 46 |         mock_instance.aio.models.generate_content = AsyncMock()
 47 |         yield mock_instance
 48 | 
 49 | 
 50 | @pytest.fixture
 51 | def gemini_client(mock_gemini_client):
 52 |     """Fixture to create a GeminiClient with a mocked client."""
 53 |     config = LLMConfig(api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000)
 54 |     client = GeminiClient(config=config, cache=False)
 55 |     # Swap the underlying genai client for our mock so every call is intercepted
 56 |     client.client = mock_gemini_client
 57 |     return client
 58 | 
 59 | 
 60 | class TestGeminiClientInitialization:
 61 |     """Tests for GeminiClient initialization."""
 62 | 
 63 |     @patch('google.genai.Client')
 64 |     def test_init_with_config(self, mock_client):
 65 |         """Test initialization with a config object."""
 66 |         config = LLMConfig(
 67 |             api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
 68 |         )
 69 |         client = GeminiClient(config=config, cache=False, max_tokens=1000)
 70 | 
 71 |         assert client.config == config
 72 |         assert client.model == 'test-model'
 73 |         assert client.temperature == 0.5
 74 |         assert client.max_tokens == 1000
 75 | 
 76 |     @patch('google.genai.Client')
 77 |     def test_init_with_default_model(self, mock_client):
 78 |         """Test initialization with default model when none is provided."""
 79 |         config = LLMConfig(api_key='test_api_key', model=DEFAULT_MODEL)
 80 |         client = GeminiClient(config=config, cache=False)
 81 | 
 82 |         assert client.model == DEFAULT_MODEL
 83 | 
 84 |     @patch('google.genai.Client')
 85 |     def test_init_without_config(self, mock_client):
 86 |         """Test initialization without a config uses defaults."""
 87 |         client = GeminiClient(cache=False)
 88 | 
 89 |         assert client.config is not None
 90 |         # When no config.model is set, it will be None, not DEFAULT_MODEL
 91 |         assert client.model is None
 92 | 
 93 |     @patch('google.genai.Client')
 94 |     def test_init_with_thinking_config(self, mock_client):
 95 |         """Test initialization with thinking config."""
 96 |         with patch('google.genai.types.ThinkingConfig') as mock_thinking_config:
 97 |             thinking_config = mock_thinking_config.return_value
 98 |             client = GeminiClient(thinking_config=thinking_config)
 99 |             assert client.thinking_config == thinking_config
100 | 
101 | 
102 | class TestGeminiClientGenerateResponse:
103 |     """Tests for GeminiClient generate_response method."""
104 | 
105 |     @pytest.mark.asyncio
106 |     async def test_generate_response_simple_text(self, gemini_client, mock_gemini_client):
107 |         """Test successful response generation with simple text."""
108 |         # Setup mock response
109 |         mock_response = MagicMock()
110 |         mock_response.text = 'Test response text'
111 |         mock_response.candidates = []
112 |         mock_response.prompt_feedback = None
113 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
114 | 
115 |         # Call method
116 |         messages = [Message(role='user', content='Test message')]
117 |         result = await gemini_client.generate_response(messages)
118 | 
119 |         # Assertions
120 |         assert isinstance(result, dict)
121 |         assert result['content'] == 'Test response text'
122 |         mock_gemini_client.aio.models.generate_content.assert_called_once()
123 | 
124 |     @pytest.mark.asyncio
125 |     async def test_generate_response_with_structured_output(
126 |         self, gemini_client, mock_gemini_client
127 |     ):
128 |         """Test response generation with structured output."""
129 |         # Setup mock response
130 |         mock_response = MagicMock()
131 |         mock_response.text = '{"test_field": "test_value", "optional_field": 42}'
132 |         mock_response.candidates = []
133 |         mock_response.prompt_feedback = None
134 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
135 | 
136 |         # Call method
137 |         messages = [
138 |             Message(role='system', content='System message'),
139 |             Message(role='user', content='User message'),
140 |         ]
141 |         result = await gemini_client.generate_response(
142 |             messages=messages, response_model=ResponseModel
143 |         )
144 | 
145 |         # Assertions
146 |         assert isinstance(result, dict)
147 |         assert result['test_field'] == 'test_value'
148 |         assert result['optional_field'] == 42
149 |         mock_gemini_client.aio.models.generate_content.assert_called_once()
150 | 
151 |     @pytest.mark.asyncio
152 |     async def test_generate_response_with_system_message(self, gemini_client, mock_gemini_client):
153 |         """Test response generation with system message handling."""
154 |         # Setup mock response
155 |         mock_response = MagicMock()
156 |         mock_response.text = 'Response with system context'
157 |         mock_response.candidates = []
158 |         mock_response.prompt_feedback = None
159 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
160 | 
161 |         # Call method
162 |         messages = [
163 |             Message(role='system', content='System message'),
164 |             Message(role='user', content='User message'),
165 |         ]
166 |         await gemini_client.generate_response(messages)
167 | 
168 |         # Verify system message is processed correctly
169 |         call_args = mock_gemini_client.aio.models.generate_content.call_args
170 |         config = call_args[1]['config']
171 |         assert 'System message' in config.system_instruction
172 | 
173 |     @pytest.mark.asyncio
174 |     async def test_get_model_for_size(self, gemini_client):
175 |         """Test model selection based on size."""
176 |         # Test small model
177 |         small_model = gemini_client._get_model_for_size(ModelSize.small)
178 |         assert small_model == DEFAULT_SMALL_MODEL
179 | 
180 |         # Test medium/large model
181 |         medium_model = gemini_client._get_model_for_size(ModelSize.medium)
182 |         assert medium_model == gemini_client.model
183 | 
184 |     @pytest.mark.asyncio
185 |     async def test_rate_limit_error_handling(self, gemini_client, mock_gemini_client):
186 |         """Test handling of rate limit errors."""
187 |         # Setup mock to raise rate limit error
188 |         mock_gemini_client.aio.models.generate_content.side_effect = Exception(
189 |             'Rate limit exceeded'
190 |         )
191 | 
192 |         # Call method and check exception
193 |         messages = [Message(role='user', content='Test message')]
194 |         with pytest.raises(RateLimitError):
195 |             await gemini_client.generate_response(messages)
196 | 
197 |     @pytest.mark.asyncio
198 |     async def test_quota_error_handling(self, gemini_client, mock_gemini_client):
199 |         """Test handling of quota errors."""
200 |         # Setup mock to raise quota error
201 |         mock_gemini_client.aio.models.generate_content.side_effect = Exception(
202 |             'Quota exceeded for requests'
203 |         )
204 | 
205 |         # Call method and check exception
206 |         messages = [Message(role='user', content='Test message')]
207 |         with pytest.raises(RateLimitError):
208 |             await gemini_client.generate_response(messages)
209 | 
210 |     @pytest.mark.asyncio
211 |     async def test_resource_exhausted_error_handling(self, gemini_client, mock_gemini_client):
212 |         """Test handling of resource exhausted errors."""
213 |         # Setup mock to raise resource exhausted error
214 |         mock_gemini_client.aio.models.generate_content.side_effect = Exception(
215 |             'resource_exhausted: Request limit exceeded'
216 |         )
217 | 
218 |         # Call method and check exception
219 |         messages = [Message(role='user', content='Test message')]
220 |         with pytest.raises(RateLimitError):
221 |             await gemini_client.generate_response(messages)
222 | 
223 |     @pytest.mark.asyncio
224 |     async def test_safety_block_handling(self, gemini_client, mock_gemini_client):
225 |         """Test handling of safety blocks."""
226 |         # Setup mock response with safety block
227 |         mock_candidate = MagicMock()
228 |         mock_candidate.finish_reason = 'SAFETY'
229 |         mock_candidate.safety_ratings = [
230 |             MagicMock(blocked=True, category='HARM_CATEGORY_HARASSMENT', probability='HIGH')
231 |         ]
232 | 
233 |         mock_response = MagicMock()
234 |         mock_response.candidates = [mock_candidate]
235 |         mock_response.prompt_feedback = None
236 |         mock_response.text = ''
237 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
238 | 
239 |         # Call method and check exception
240 |         messages = [Message(role='user', content='Test message')]
241 |         with pytest.raises(Exception, match='Content blocked by safety filters'):
242 |             await gemini_client.generate_response(messages)
243 | 
244 |     @pytest.mark.asyncio
245 |     async def test_prompt_block_handling(self, gemini_client, mock_gemini_client):
246 |         """Test handling of prompt blocks."""
247 |         # Setup mock response with prompt block
248 |         mock_prompt_feedback = MagicMock()
249 |         mock_prompt_feedback.block_reason = 'BLOCKED_REASON_OTHER'
250 | 
251 |         mock_response = MagicMock()
252 |         mock_response.candidates = []
253 |         mock_response.prompt_feedback = mock_prompt_feedback
254 |         mock_response.text = ''
255 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
256 | 
257 |         # Call method and check exception
258 |         messages = [Message(role='user', content='Test message')]
259 |         with pytest.raises(Exception, match='Content blocked by safety filters'):
260 |             await gemini_client.generate_response(messages)
261 | 
262 |     @pytest.mark.asyncio
263 |     async def test_structured_output_parsing_error(self, gemini_client, mock_gemini_client):
264 |         """Test handling of structured output parsing errors."""
265 |         # Setup mock response with invalid JSON that will exhaust retries
266 |         mock_response = MagicMock()
267 |         mock_response.text = 'Invalid JSON that cannot be parsed'
268 |         mock_response.candidates = []
269 |         mock_response.prompt_feedback = None
270 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
271 | 
272 |         # Call method and check exception - should exhaust retries
273 |         messages = [Message(role='user', content='Test message')]
274 |         with pytest.raises(Exception):  # noqa: B017
275 |             await gemini_client.generate_response(messages, response_model=ResponseModel)
276 | 
277 |         # Should have called generate_content MAX_RETRIES times (2 attempts total)
278 |         assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
279 | 
280 |     @pytest.mark.asyncio
281 |     async def test_retry_logic_with_safety_block(self, gemini_client, mock_gemini_client):
282 |         """Test that safety blocks are not retried."""
283 |         # Setup mock response with safety block
284 |         mock_candidate = MagicMock()
285 |         mock_candidate.finish_reason = 'SAFETY'
286 |         mock_candidate.safety_ratings = [
287 |             MagicMock(blocked=True, category='HARM_CATEGORY_HARASSMENT', probability='HIGH')
288 |         ]
289 | 
290 |         mock_response = MagicMock()
291 |         mock_response.candidates = [mock_candidate]
292 |         mock_response.prompt_feedback = None
293 |         mock_response.text = ''
294 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
295 | 
296 |         # Call method and check that it doesn't retry
297 |         messages = [Message(role='user', content='Test message')]
298 |         with pytest.raises(Exception, match='Content blocked by safety filters'):
299 |             await gemini_client.generate_response(messages)
300 | 
301 |         # Should only be called once (no retries for safety blocks)
302 |         assert mock_gemini_client.aio.models.generate_content.call_count == 1
303 | 
304 |     @pytest.mark.asyncio
305 |     async def test_retry_logic_with_validation_error(self, gemini_client, mock_gemini_client):
306 |         """Test retry behavior on validation error."""
307 |         # First call returns invalid JSON, second call returns valid data
308 |         mock_response1 = MagicMock()
309 |         mock_response1.text = 'Invalid JSON that cannot be parsed'
310 |         mock_response1.candidates = []
311 |         mock_response1.prompt_feedback = None
312 | 
313 |         mock_response2 = MagicMock()
314 |         mock_response2.text = '{"test_field": "correct_value"}'
315 |         mock_response2.candidates = []
316 |         mock_response2.prompt_feedback = None
317 | 
318 |         mock_gemini_client.aio.models.generate_content.side_effect = [
319 |             mock_response1,
320 |             mock_response2,
321 |         ]
322 | 
323 |         # Call method
324 |         messages = [Message(role='user', content='Test message')]
325 |         result = await gemini_client.generate_response(messages, response_model=ResponseModel)
326 | 
327 |         # Should have called generate_content twice due to retry
328 |         assert mock_gemini_client.aio.models.generate_content.call_count == 2
329 |         assert result['test_field'] == 'correct_value'
330 | 
331 |     @pytest.mark.asyncio
332 |     async def test_max_retries_exceeded(self, gemini_client, mock_gemini_client):
333 |         """Test behavior when max retries are exceeded."""
334 |         # Setup mock to always return invalid JSON
335 |         mock_response = MagicMock()
336 |         mock_response.text = 'Invalid JSON that cannot be parsed'
337 |         mock_response.candidates = []
338 |         mock_response.prompt_feedback = None
339 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
340 | 
341 |         # Call method and check exception
342 |         messages = [Message(role='user', content='Test message')]
343 |         with pytest.raises(Exception):  # noqa: B017
344 |             await gemini_client.generate_response(messages, response_model=ResponseModel)
345 | 
346 |         # Should have called generate_content MAX_RETRIES times (2 attempts total)
347 |         assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
348 | 
349 |     @pytest.mark.asyncio
350 |     async def test_empty_response_handling(self, gemini_client, mock_gemini_client):
351 |         """Test handling of empty responses."""
352 |         # Setup mock response with no text
353 |         mock_response = MagicMock()
354 |         mock_response.text = ''
355 |         mock_response.candidates = []
356 |         mock_response.prompt_feedback = None
357 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
358 | 
359 |         # Call method with structured output and check exception
360 |         messages = [Message(role='user', content='Test message')]
361 |         with pytest.raises(Exception):  # noqa: B017
362 |             await gemini_client.generate_response(messages, response_model=ResponseModel)
363 | 
364 |         # Should have exhausted retries due to empty response (2 attempts total)
365 |         assert mock_gemini_client.aio.models.generate_content.call_count == GeminiClient.MAX_RETRIES
366 | 
367 |     @pytest.mark.asyncio
368 |     async def test_custom_max_tokens(self, gemini_client, mock_gemini_client):
369 |         """Test that explicit max_tokens parameter takes precedence over all other values."""
370 |         # Setup mock response
371 |         mock_response = MagicMock()
372 |         mock_response.text = 'Test response'
373 |         mock_response.candidates = []
374 |         mock_response.prompt_feedback = None
375 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
376 | 
377 |         # Call method with custom max tokens (should take precedence)
378 |         messages = [Message(role='user', content='Test message')]
379 |         await gemini_client.generate_response(messages, max_tokens=500)
380 | 
381 |         # Verify explicit max_tokens parameter takes precedence
382 |         call_args = mock_gemini_client.aio.models.generate_content.call_args
383 |         config = call_args[1]['config']
384 |         # Explicit parameter should override everything else
385 |         assert config.max_output_tokens == 500
386 | 
387 |     @pytest.mark.asyncio
388 |     async def test_max_tokens_precedence_fallback(self, mock_gemini_client):
389 |         """Test max_tokens precedence when no explicit parameter is provided."""
390 |         # Setup mock response
391 |         mock_response = MagicMock()
392 |         mock_response.text = 'Test response'
393 |         mock_response.candidates = []
394 |         mock_response.prompt_feedback = None
395 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
396 | 
397 |         # Test case 1: No explicit max_tokens, has instance max_tokens
398 |         config = LLMConfig(
399 |             api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
400 |         )
401 |         client = GeminiClient(
402 |             config=config, cache=False, max_tokens=2000, client=mock_gemini_client
403 |         )
404 | 
405 |         messages = [Message(role='user', content='Test message')]
406 |         await client.generate_response(messages)
407 | 
408 |         call_args = mock_gemini_client.aio.models.generate_content.call_args
409 |         config = call_args[1]['config']
410 |         # Instance max_tokens should be used
411 |         assert config.max_output_tokens == 2000
412 | 
413 |         # Test case 2: No explicit max_tokens, no instance max_tokens, uses model mapping
414 |         config = LLMConfig(api_key='test_api_key', model='gemini-2.5-flash', temperature=0.5)
415 |         client = GeminiClient(config=config, cache=False, client=mock_gemini_client)
416 | 
417 |         messages = [Message(role='user', content='Test message')]
418 |         await client.generate_response(messages)
419 | 
420 |         call_args = mock_gemini_client.aio.models.generate_content.call_args
421 |         config = call_args[1]['config']
422 |         # Model mapping should be used
423 |         assert config.max_output_tokens == 65536
424 | 
425 |     @pytest.mark.asyncio
426 |     async def test_model_size_selection(self, gemini_client, mock_gemini_client):
427 |         """Test that the correct model is selected based on model size."""
428 |         # Setup mock response
429 |         mock_response = MagicMock()
430 |         mock_response.text = 'Test response'
431 |         mock_response.candidates = []
432 |         mock_response.prompt_feedback = None
433 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
434 | 
435 |         # Call method with small model size
436 |         messages = [Message(role='user', content='Test message')]
437 |         await gemini_client.generate_response(messages, model_size=ModelSize.small)
438 | 
439 |         # Verify correct model is used
440 |         call_args = mock_gemini_client.aio.models.generate_content.call_args
441 |         assert call_args[1]['model'] == DEFAULT_SMALL_MODEL
442 | 
443 |     @pytest.mark.asyncio
444 |     async def test_gemini_model_max_tokens_mapping(self, mock_gemini_client):
445 |         """Test that different Gemini models use their correct max tokens."""
446 |         # Setup mock response
447 |         mock_response = MagicMock()
448 |         mock_response.text = 'Test response'
449 |         mock_response.candidates = []
450 |         mock_response.prompt_feedback = None
451 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
452 | 
453 |         # Test data: (model_name, expected_max_tokens)
454 |         test_cases = [
455 |             ('gemini-2.5-flash', 65536),
456 |             ('gemini-2.5-pro', 65536),
457 |             ('gemini-2.5-flash-lite', 64000),
458 |             ('gemini-2.0-flash', 8192),
459 |             ('gemini-1.5-pro', 8192),
460 |             ('gemini-1.5-flash', 8192),
461 |             ('unknown-model', 8192),  # Fallback case
462 |         ]
463 | 
464 |         for model_name, expected_max_tokens in test_cases:
465 |             # Create client with specific model, no explicit max_tokens to test mapping
466 |             config = LLMConfig(api_key='test_api_key', model=model_name, temperature=0.5)
467 |             client = GeminiClient(config=config, cache=False, client=mock_gemini_client)
468 | 
469 |             # Call method without explicit max_tokens to test model mapping fallback
470 |             messages = [Message(role='user', content='Test message')]
471 |             await client.generate_response(messages)
472 | 
473 |             # Verify correct max tokens is used from model mapping
474 |             call_args = mock_gemini_client.aio.models.generate_content.call_args
475 |             config = call_args[1]['config']
476 |             assert config.max_output_tokens == expected_max_tokens, (
477 |                 f'Model {model_name} should use {expected_max_tokens} tokens'
478 |             )
479 | 
480 | 
481 | if __name__ == '__main__':
482 |     pytest.main(['-v', 'test_gemini_client.py'])
483 | 
```
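
For context, a minimal end-to-end sketch of the client these tests mock. It uses only the surface exercised above (`LLMConfig`, `cache=False`, `response_model` results returned as a dict) and assumes a real Google API key and the `google-genai` package are available.

```python
import asyncio

from pydantic import BaseModel

from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.gemini_client import GeminiClient
from graphiti_core.prompts.models import Message


class Greeting(BaseModel):
    text: str


async def main():
    # 'gemini-2.5-flash' is one of the models the max-token tests cover
    config = LLMConfig(api_key='YOUR_API_KEY', model='gemini-2.5-flash')
    client = GeminiClient(config=config, cache=False)
    result = await client.generate_response(
        [Message(role='user', content='Reply with a one-line JSON greeting.')],
        response_model=Greeting,
    )
    print(result['text'])  # generate_response returns a dict, per the tests


asyncio.run(main())
```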

--------------------------------------------------------------------------------
/graphiti_core/utils/content_chunking.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import json
 18 | import logging
 19 | import re
 20 | 
 21 | from graphiti_core.helpers import (
 22 |     CHUNK_DENSITY_THRESHOLD,
 23 |     CHUNK_MIN_TOKENS,
 24 |     CHUNK_OVERLAP_TOKENS,
 25 |     CHUNK_TOKEN_SIZE,
 26 | )
 27 | from graphiti_core.nodes import EpisodeType
 28 | 
 29 | logger = logging.getLogger(__name__)
 30 | 
 31 | # Approximate characters per token (conservative estimate)
 32 | CHARS_PER_TOKEN = 4
 33 | 
 34 | 
 35 | def estimate_tokens(text: str) -> int:
 36 |     """Estimate token count using character-based heuristic.
 37 | 
 38 |     Uses ~4 characters per token as a conservative estimate.
 39 |     This is faster than actual tokenization and works across all LLM providers.
 40 | 
 41 |     Args:
 42 |         text: The text to estimate tokens for
 43 | 
 44 |     Returns:
 45 |         Estimated token count
 46 |     """
 47 |     return len(text) // CHARS_PER_TOKEN
 48 | 
 49 | 
 50 | def _tokens_to_chars(tokens: int) -> int:
 51 |     """Convert token count to approximate character count."""
 52 |     return tokens * CHARS_PER_TOKEN
 53 | 
 54 | 
 55 | def should_chunk(content: str, episode_type: EpisodeType) -> bool:
 56 |     """Determine whether content should be chunked based on size and entity density.
 57 | 
 58 |     Only chunks content that is both:
 59 |     1. Large enough to potentially cause LLM issues (>= CHUNK_MIN_TOKENS)
 60 |     2. High entity density (many entities per token)
 61 | 
 62 |     Short content processes fine regardless of density. This targets the specific
 63 |     failure case of large entity-dense inputs while preserving context for
 64 |     prose/narrative content and avoiding unnecessary chunking of small inputs.
 65 | 
 66 |     Args:
 67 |         content: The content to evaluate
 68 |         episode_type: Type of episode (json, message, text)
 69 | 
 70 |     Returns:
 71 |         True if content is large and has high entity density
 72 |     """
 73 |     tokens = estimate_tokens(content)
 74 | 
 75 |     # Short content always processes fine - no need to chunk
 76 |     if tokens < CHUNK_MIN_TOKENS:
 77 |         return False
 78 | 
 79 |     return _estimate_high_density(content, episode_type, tokens)
 80 | 
 81 | 
 82 | def _estimate_high_density(content: str, episode_type: EpisodeType, tokens: int) -> bool:
 83 |     """Estimate whether content has high entity density.
 84 | 
 85 |     High-density content (many entities per token) benefits from chunking.
 86 |     Low-density content (prose, narratives) loses context when chunked.
 87 | 
 88 |     Args:
 89 |         content: The content to analyze
 90 |         episode_type: Type of episode
 91 |         tokens: Pre-computed token count
 92 | 
 93 |     Returns:
 94 |         True if content appears to have high entity density
 95 |     """
 96 |     if episode_type == EpisodeType.json:
 97 |         return _json_likely_dense(content, tokens)
 98 |     else:
 99 |         return _text_likely_dense(content, tokens)
100 | 
101 | 
102 | def _json_likely_dense(content: str, tokens: int) -> bool:
103 |     """Estimate entity density for JSON content.
104 | 
105 |     JSON is considered dense if it has many array elements or object keys,
106 |     as each typically represents a distinct entity or data point.
107 | 
108 |     Heuristics:
109 |     - Array: Count elements, estimate entities per 1000 tokens
110 |     - Object: Count top-level keys
111 | 
112 |     Args:
113 |         content: JSON string content
114 |         tokens: Token count
115 | 
116 |     Returns:
117 |         True if JSON appears to have high entity density
118 |     """
119 |     try:
120 |         data = json.loads(content)
121 |     except json.JSONDecodeError:
122 |         # Invalid JSON, fall back to text heuristics
123 |         return _text_likely_dense(content, tokens)
124 | 
125 |     if isinstance(data, list):
126 |         # For arrays, each element likely contains entities
127 |         element_count = len(data)
128 |         # Estimate density: elements per 1000 tokens
129 |         density = (element_count / tokens) * 1000 if tokens > 0 else 0
130 |         return density > CHUNK_DENSITY_THRESHOLD * 1000  # Scale threshold
131 |     elif isinstance(data, dict):
132 |         # For objects, count keys down to a shallow depth (max_depth=2)
133 |         key_count = _count_json_keys(data, max_depth=2)
134 |         density = (key_count / tokens) * 1000 if tokens > 0 else 0
135 |         return density > CHUNK_DENSITY_THRESHOLD * 1000
136 |     else:
137 |         # Scalar value, no need to chunk
138 |         return False
139 | 
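# Worked example with illustrative numbers: a 200-element array that parses to
# ~2,000 estimated tokens has density (200 / 2000) * 1000 = 100 elements per
# 1,000 tokens, so it is chunked iff 100 > CHUNK_DENSITY_THRESHOLD * 1000,
# i.e. whenever CHUNK_DENSITY_THRESHOLD < 0.1.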
140 | 
141 | def _count_json_keys(data: dict, max_depth: int = 2, current_depth: int = 0) -> int:
142 |     """Count keys in a JSON object up to a certain depth.
143 | 
144 |     Args:
145 |         data: Dictionary to count keys in
146 |         max_depth: Maximum depth to traverse
147 |         current_depth: Current recursion depth
148 | 
149 |     Returns:
150 |         Count of keys
151 |     """
152 |     if current_depth >= max_depth:
153 |         return 0
154 | 
155 |     count = len(data)
156 |     for value in data.values():
157 |         if isinstance(value, dict):
158 |             count += _count_json_keys(value, max_depth, current_depth + 1)
159 |         elif isinstance(value, list):
160 |             for item in value:
161 |                 if isinstance(item, dict):
162 |                     count += _count_json_keys(item, max_depth, current_depth + 1)
163 |     return count
164 | 
165 | 
166 | def _text_likely_dense(content: str, tokens: int) -> bool:
167 |     """Estimate entity density for text content.
168 | 
169 |     Uses capitalized words as a proxy for named entities (people, places,
170 |     organizations, products). High ratio of capitalized words suggests
171 |     high entity density.
172 | 
173 |     Args:
174 |         content: Text content
175 |         tokens: Token count
176 | 
177 |     Returns:
178 |         True if text appears to have high entity density
179 |     """
180 |     if tokens == 0:
181 |         return False
182 | 
183 |     # Split into words
184 |     words = content.split()
185 |     if not words:
186 |         return False
187 | 
188 |     # Count capitalized words (excluding sentence starters)
189 |     # A word is "capitalized" if it starts with uppercase and isn't all caps
190 |     capitalized_count = 0
191 |     for i, word in enumerate(words):
192 |         # Skip if it's likely a sentence starter (after . ! ? or first word)
193 |         if i == 0:
194 |             continue
195 |         if words[i - 1].rstrip()[-1:] in '.!?':
196 |             continue
197 | 
198 |         # Check if capitalized (first char upper, not all caps)
199 |         cleaned = word.strip('.,!?;:\'"()[]{}')
200 |         if cleaned and cleaned[0].isupper() and not cleaned.isupper():
201 |             capitalized_count += 1
202 | 
203 |     # Calculate density: capitalized words per 1000 tokens
204 |     density = (capitalized_count / tokens) * 1000 if tokens > 0 else 0
205 | 
206 |     # Text density threshold is typically lower than JSON
207 |     # A well-written article might have 5-10% named entities
208 |     return density > CHUNK_DENSITY_THRESHOLD * 500  # Half the JSON threshold
209 | 
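# Worked example with illustrative numbers: 40 mid-sentence capitalized words
# in ~4,000 estimated tokens gives (40 / 4000) * 1000 = 10 per 1,000 tokens;
# the text counts as dense iff 10 > CHUNK_DENSITY_THRESHOLD * 500.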
210 | 
211 | def chunk_json_content(
212 |     content: str,
213 |     chunk_size_tokens: int | None = None,
214 |     overlap_tokens: int | None = None,
215 | ) -> list[str]:
216 |     """Split JSON content into chunks while preserving structure.
217 | 
218 |     For arrays: splits at element boundaries, keeping complete objects.
219 |     For objects: splits at top-level key boundaries.
220 | 
221 |     Args:
222 |         content: JSON string to chunk
223 |         chunk_size_tokens: Target size per chunk in tokens (default from env)
224 |         overlap_tokens: Overlap between chunks in tokens (default from env)
225 | 
226 |     Returns:
227 |         List of JSON string chunks
228 |     """
229 |     chunk_size_tokens = chunk_size_tokens or CHUNK_TOKEN_SIZE
230 |     overlap_tokens = overlap_tokens or CHUNK_OVERLAP_TOKENS
231 | 
232 |     chunk_size_chars = _tokens_to_chars(chunk_size_tokens)
233 |     overlap_chars = _tokens_to_chars(overlap_tokens)
234 | 
235 |     try:
236 |         data = json.loads(content)
237 |     except json.JSONDecodeError:
238 |         logger.warning('Failed to parse JSON, falling back to text chunking')
239 |         return chunk_text_content(content, chunk_size_tokens, overlap_tokens)
240 | 
241 |     if isinstance(data, list):
242 |         return _chunk_json_array(data, chunk_size_chars, overlap_chars)
243 |     elif isinstance(data, dict):
244 |         return _chunk_json_object(data, chunk_size_chars, overlap_chars)
245 |     else:
246 |         # Scalar value, return as-is
247 |         return [content]
248 | 
249 | 
250 | def _chunk_json_array(
251 |     data: list,
252 |     chunk_size_chars: int,
253 |     overlap_chars: int,
254 | ) -> list[str]:
255 |     """Chunk a JSON array by splitting at element boundaries."""
256 |     if not data:
257 |         return ['[]']
258 | 
259 |     chunks: list[str] = []
260 |     current_elements: list = []
261 |     current_size = 2  # Account for '[]'
262 | 
263 |     for element in data:
264 |         element_json = json.dumps(element)
265 |         element_size = len(element_json) + 2  # Account for comma and space
266 | 
267 |         # Check if adding this element would exceed chunk size
268 |         if current_elements and current_size + element_size > chunk_size_chars:
269 |             # Save current chunk
270 |             chunks.append(json.dumps(current_elements))
271 | 
272 |             # Start new chunk with overlap (include last few elements)
273 |             overlap_elements = _get_overlap_elements(current_elements, overlap_chars)
274 |             current_elements = overlap_elements
275 |             current_size = len(json.dumps(current_elements)) if current_elements else 2
276 | 
277 |         current_elements.append(element)
278 |         current_size += element_size
279 | 
280 |     # Don't forget the last chunk
281 |     if current_elements:
282 |         chunks.append(json.dumps(current_elements))
283 | 
284 |     return chunks if chunks else ['[]']
285 | 
286 | 
287 | def _get_overlap_elements(elements: list, overlap_chars: int) -> list:
288 |     """Get elements from the end of a list that fit within overlap_chars."""
289 |     if not elements:
290 |         return []
291 | 
292 |     overlap_elements: list = []
293 |     current_size = 2  # Account for '[]'
294 | 
295 |     for element in reversed(elements):
296 |         element_json = json.dumps(element)
297 |         element_size = len(element_json) + 2
298 | 
299 |         if current_size + element_size > overlap_chars:
300 |             break
301 | 
302 |         overlap_elements.insert(0, element)
303 |         current_size += element_size
304 | 
305 |     return overlap_elements
306 | 
307 | 
308 | def _chunk_json_object(
309 |     data: dict,
310 |     chunk_size_chars: int,
311 |     overlap_chars: int,
312 | ) -> list[str]:
313 |     """Chunk a JSON object by splitting at top-level key boundaries."""
314 |     if not data:
315 |         return ['{}']
316 | 
317 |     chunks: list[str] = []
318 |     current_keys: list[str] = []
319 |     current_dict: dict = {}
320 |     current_size = 2  # Account for '{}'
321 | 
322 |     for key, value in data.items():
323 |         entry_json = json.dumps({key: value})
324 |         entry_size = len(entry_json)
325 | 
326 |         # Check if adding this entry would exceed chunk size
327 |         if current_dict and current_size + entry_size > chunk_size_chars:
328 |             # Save current chunk
329 |             chunks.append(json.dumps(current_dict))
330 | 
331 |             # Start new chunk with overlap (include last few keys)
332 |             overlap_dict = _get_overlap_dict(current_dict, current_keys, overlap_chars)
333 |             current_dict = overlap_dict
334 |             current_keys = list(overlap_dict.keys())
335 |             current_size = len(json.dumps(current_dict)) if current_dict else 2
336 | 
337 |         current_dict[key] = value
338 |         current_keys.append(key)
339 |         current_size += entry_size
340 | 
341 |     # Don't forget the last chunk
342 |     if current_dict:
343 |         chunks.append(json.dumps(current_dict))
344 | 
345 |     return chunks if chunks else ['{}']
346 | 
347 | 
348 | def _get_overlap_dict(data: dict, keys: list[str], overlap_chars: int) -> dict:
349 |     """Get key-value pairs from the end of a dict that fit within overlap_chars."""
350 |     if not data or not keys:
351 |         return {}
352 | 
353 |     overlap_dict: dict = {}
354 |     current_size = 2  # Account for '{}'
355 | 
356 |     for key in reversed(keys):
357 |         if key not in data:
358 |             continue
359 |         entry_json = json.dumps({key: data[key]})
360 |         entry_size = len(entry_json)
361 | 
362 |         if current_size + entry_size > overlap_chars:
363 |             break
364 | 
365 |         overlap_dict[key] = data[key]
366 |         current_size += entry_size
367 | 
368 |     # Reverse to maintain original order
369 |     return dict(reversed(list(overlap_dict.items())))
370 | 
371 | 
372 | def chunk_text_content(
373 |     content: str,
374 |     chunk_size_tokens: int | None = None,
375 |     overlap_tokens: int | None = None,
376 | ) -> list[str]:
377 |     """Split text content at natural boundaries (paragraphs, sentences).
378 | 
379 |     Includes overlap to capture entities at chunk boundaries.
380 | 
381 |     Args:
382 |         content: Text to chunk
383 |         chunk_size_tokens: Target size per chunk in tokens (default from env)
384 |         overlap_tokens: Overlap between chunks in tokens (default from env)
385 | 
386 |     Returns:
387 |         List of text chunks
388 |     """
389 |     chunk_size_tokens = chunk_size_tokens or CHUNK_TOKEN_SIZE
390 |     overlap_tokens = overlap_tokens or CHUNK_OVERLAP_TOKENS
391 | 
392 |     chunk_size_chars = _tokens_to_chars(chunk_size_tokens)
393 |     overlap_chars = _tokens_to_chars(overlap_tokens)
394 | 
395 |     if len(content) <= chunk_size_chars:
396 |         return [content]
397 | 
398 |     # Split into paragraphs first
399 |     paragraphs = re.split(r'\n\s*\n', content)
400 | 
401 |     chunks: list[str] = []
402 |     current_chunk: list[str] = []
403 |     current_size = 0
404 | 
405 |     for paragraph in paragraphs:
406 |         paragraph = paragraph.strip()
407 |         if not paragraph:
408 |             continue
409 | 
410 |         para_size = len(paragraph)
411 | 
412 |         # If a single paragraph is too large, split it by sentences
413 |         if para_size > chunk_size_chars:
414 |             # First, save current chunk if any
415 |             if current_chunk:
416 |                 chunks.append('\n\n'.join(current_chunk))
417 |                 current_chunk = []
418 |                 current_size = 0
419 | 
420 |             # Split large paragraph by sentences
421 |             sentence_chunks = _chunk_by_sentences(paragraph, chunk_size_chars, overlap_chars)
422 |             chunks.extend(sentence_chunks)
423 |             continue
424 | 
425 |         # Check if adding this paragraph would exceed chunk size
426 |         if current_chunk and current_size + para_size + 2 > chunk_size_chars:
427 |             # Save current chunk
428 |             chunks.append('\n\n'.join(current_chunk))
429 | 
430 |             # Start new chunk with overlap
431 |             overlap_text = _get_overlap_text('\n\n'.join(current_chunk), overlap_chars)
432 |             if overlap_text:
433 |                 current_chunk = [overlap_text]
434 |                 current_size = len(overlap_text)
435 |             else:
436 |                 current_chunk = []
437 |                 current_size = 0
438 | 
439 |         current_chunk.append(paragraph)
440 |         current_size += para_size + 2  # Account for '\n\n'
441 | 
442 |     # Don't forget the last chunk
443 |     if current_chunk:
444 |         chunks.append('\n\n'.join(current_chunk))
445 | 
446 |     return chunks if chunks else [content]
447 | 
448 | 
449 | def _chunk_by_sentences(
450 |     text: str,
451 |     chunk_size_chars: int,
452 |     overlap_chars: int,
453 | ) -> list[str]:
454 |     """Split text by sentence boundaries."""
455 |     # Split on sentence-ending punctuation followed by whitespace
456 |     sentence_pattern = r'(?<=[.!?])\s+'
457 |     sentences = re.split(sentence_pattern, text)
458 | 
459 |     chunks: list[str] = []
460 |     current_chunk: list[str] = []
461 |     current_size = 0
462 | 
463 |     for sentence in sentences:
464 |         sentence = sentence.strip()
465 |         if not sentence:
466 |             continue
467 | 
468 |         sent_size = len(sentence)
469 | 
470 |         # If a single sentence is too large, split it by fixed size
471 |         if sent_size > chunk_size_chars:
472 |             if current_chunk:
473 |                 chunks.append(' '.join(current_chunk))
474 |                 current_chunk = []
475 |                 current_size = 0
476 | 
477 |             # Split by fixed size as last resort
478 |             fixed_chunks = _chunk_by_size(sentence, chunk_size_chars, overlap_chars)
479 |             chunks.extend(fixed_chunks)
480 |             continue
481 | 
482 |         # Check if adding this sentence would exceed chunk size
483 |         if current_chunk and current_size + sent_size + 1 > chunk_size_chars:
484 |             chunks.append(' '.join(current_chunk))
485 | 
486 |             # Start new chunk with overlap
487 |             overlap_text = _get_overlap_text(' '.join(current_chunk), overlap_chars)
488 |             if overlap_text:
489 |                 current_chunk = [overlap_text]
490 |                 current_size = len(overlap_text)
491 |             else:
492 |                 current_chunk = []
493 |                 current_size = 0
494 | 
495 |         current_chunk.append(sentence)
496 |         current_size += sent_size + 1
497 | 
498 |     if current_chunk:
499 |         chunks.append(' '.join(current_chunk))
500 | 
501 |     return chunks
502 | 
503 | 
504 | def _chunk_by_size(
505 |     text: str,
506 |     chunk_size_chars: int,
507 |     overlap_chars: int,
508 | ) -> list[str]:
509 |     """Split text by fixed character size (last resort)."""
510 |     chunks: list[str] = []
511 |     start = 0
512 | 
513 |     while start < len(text):
514 |         end = min(start + chunk_size_chars, len(text))
515 | 
516 |         # Try to break at word boundary
517 |         if end < len(text):
518 |             space_idx = text.rfind(' ', start, end)
519 |             if space_idx > start:
520 |                 end = space_idx
521 | 
522 |         chunks.append(text[start:end].strip())
523 | 
524 |         # Move start forward, ensuring progress even if overlap >= chunk_size
525 |         # Always advance by at least (chunk_size - overlap) or 1 char minimum
526 |         min_progress = max(1, chunk_size_chars - overlap_chars)
527 |         start = max(start + min_progress, end - overlap_chars)
528 | 
529 |     return chunks
530 | 
531 | 
532 | def _get_overlap_text(text: str, overlap_chars: int) -> str:
533 |     """Get the last overlap_chars characters of text, breaking at word boundary."""
534 |     if len(text) <= overlap_chars:
535 |         return text
536 | 
537 |     overlap_start = len(text) - overlap_chars
538 |     # Find the next word boundary after overlap_start
539 |     space_idx = text.find(' ', overlap_start)
540 |     if space_idx != -1:
541 |         return text[space_idx + 1 :]
542 |     return text[overlap_start:]
543 | 
544 | 
545 | def chunk_message_content(
546 |     content: str,
547 |     chunk_size_tokens: int | None = None,
548 |     overlap_tokens: int | None = None,
549 | ) -> list[str]:
550 |     """Split conversation content preserving message boundaries.
551 | 
552 |     Never splits mid-message. Messages are identified by patterns like:
553 |     - "Speaker: message"
554 |     - JSON message arrays
555 |     - Newline-separated messages
556 | 
557 |     Args:
558 |         content: Conversation content to chunk
559 |         chunk_size_tokens: Target size per chunk in tokens (default from env)
560 |         overlap_tokens: Overlap between chunks in tokens (default from env)
561 | 
562 |     Returns:
563 |         List of conversation chunks
564 |     """
565 |     chunk_size_tokens = chunk_size_tokens or CHUNK_TOKEN_SIZE
566 |     overlap_tokens = overlap_tokens or CHUNK_OVERLAP_TOKENS
567 | 
568 |     chunk_size_chars = _tokens_to_chars(chunk_size_tokens)
569 |     overlap_chars = _tokens_to_chars(overlap_tokens)
570 | 
571 |     if len(content) <= chunk_size_chars:
572 |         return [content]
573 | 
574 |     # Try to detect message format
575 |     # Check if it's JSON (array of message objects)
576 |     try:
577 |         data = json.loads(content)
578 |         if isinstance(data, list):
579 |             return _chunk_message_array(data, chunk_size_chars, overlap_chars)
580 |     except json.JSONDecodeError:
581 |         pass
582 | 
583 |     # Try speaker pattern (e.g., "Alice: Hello")
584 |     speaker_pattern = r'^([A-Za-z_][A-Za-z0-9_\s]*):(.+?)(?=^[A-Za-z_][A-Za-z0-9_\s]*:|$)'
585 |     if re.search(speaker_pattern, content, re.MULTILINE | re.DOTALL):
586 |         return _chunk_speaker_messages(content, chunk_size_chars, overlap_chars)
587 | 
588 |     # Fallback to line-based chunking
589 |     return _chunk_by_lines(content, chunk_size_chars, overlap_chars)
590 | 
591 | 
592 | def _chunk_message_array(
593 |     messages: list,
594 |     chunk_size_chars: int,
595 |     overlap_chars: int,
596 | ) -> list[str]:
597 |     """Chunk a JSON array of message objects."""
598 |     # Delegate to JSON array chunking
599 |     chunks = _chunk_json_array(messages, chunk_size_chars, overlap_chars)
600 |     return chunks
601 | 
602 | 
603 | def _chunk_speaker_messages(
604 |     content: str,
605 |     chunk_size_chars: int,
606 |     overlap_chars: int,
607 | ) -> list[str]:
608 |     """Chunk messages in 'Speaker: message' format."""
609 |     # Split on speaker patterns
610 |     pattern = r'(?=^[A-Za-z_][A-Za-z0-9_\s]*:)'
611 |     messages = re.split(pattern, content, flags=re.MULTILINE)
612 |     messages = [m.strip() for m in messages if m.strip()]
613 | 
614 |     if not messages:
615 |         return [content]
616 | 
617 |     chunks: list[str] = []
618 |     current_messages: list[str] = []
619 |     current_size = 0
620 | 
621 |     for message in messages:
622 |         msg_size = len(message)
623 | 
624 |         # If a single message is too large, include it as its own chunk
625 |         if msg_size > chunk_size_chars:
626 |             if current_messages:
627 |                 chunks.append('\n'.join(current_messages))
628 |                 current_messages = []
629 |                 current_size = 0
630 |             chunks.append(message)
631 |             continue
632 | 
633 |         if current_messages and current_size + msg_size + 1 > chunk_size_chars:
634 |             chunks.append('\n'.join(current_messages))
635 | 
636 |             # Get overlap (last message(s) that fit)
637 |             overlap_messages = _get_overlap_messages(current_messages, overlap_chars)
638 |             current_messages = overlap_messages
639 |             current_size = max(sum(len(m) for m in current_messages) + len(current_messages) - 1, 0)
640 | 
641 |         current_messages.append(message)
642 |         current_size += msg_size + 1
643 | 
644 |     if current_messages:
645 |         chunks.append('\n'.join(current_messages))
646 | 
647 |     return chunks if chunks else [content]
648 | 
649 | 
650 | def _get_overlap_messages(messages: list[str], overlap_chars: int) -> list[str]:
651 |     """Get messages from the end that fit within overlap_chars."""
652 |     if not messages:
653 |         return []
654 | 
655 |     overlap: list[str] = []
656 |     current_size = 0
657 | 
658 |     for msg in reversed(messages):
659 |         msg_size = len(msg) + 1
660 |         if current_size + msg_size > overlap_chars:
661 |             break
662 |         overlap.insert(0, msg)
663 |         current_size += msg_size
664 | 
665 |     return overlap
666 | 
667 | 
668 | def _chunk_by_lines(
669 |     content: str,
670 |     chunk_size_chars: int,
671 |     overlap_chars: int,
672 | ) -> list[str]:
673 |     """Chunk content by line boundaries."""
674 |     lines = content.split('\n')
675 | 
676 |     chunks: list[str] = []
677 |     current_lines: list[str] = []
678 |     current_size = 0
679 | 
680 |     for line in lines:
681 |         line_size = len(line) + 1
682 | 
683 |         if current_lines and current_size + line_size > chunk_size_chars:
684 |             chunks.append('\n'.join(current_lines))
685 | 
686 |             # Get overlap lines
687 |             overlap_text = '\n'.join(current_lines)
688 |             overlap = _get_overlap_text(overlap_text, overlap_chars)
689 |             if overlap:
690 |                 current_lines = overlap.split('\n')
691 |                 current_size = len(overlap)
692 |             else:
693 |                 current_lines = []
694 |                 current_size = 0
695 | 
696 |         current_lines.append(line)
697 |         current_size += line_size
698 | 
699 |     if current_lines:
700 |         chunks.append('\n'.join(current_lines))
701 | 
702 |     return chunks if chunks else [content]
703 | 
```
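
The chunkers above are plain, deterministic functions, so their boundary guarantees can be exercised without the rest of the pipeline. A minimal sketch, assuming the `should_chunk(content, source)` signature used by the extraction code below and illustrative (non-default) token budgets:

```python
# Usage sketch for the chunking helpers; the sizes are illustrative, and the
# should_chunk(content, source) signature is assumed from its call site in
# node_operations.py.
import json

from graphiti_core.nodes import EpisodeType
from graphiti_core.utils.content_chunking import (
    chunk_json_content,
    chunk_text_content,
    should_chunk,
)

# A large JSON array is split at element boundaries, never mid-object,
# so every chunk parses on its own.
records = [{'id': i, 'name': f'Entity {i}'} for i in range(500)]
json_chunks = chunk_json_content(json.dumps(records), chunk_size_tokens=256, overlap_tokens=32)
assert all(isinstance(json.loads(chunk), list) for chunk in json_chunks)

# Plain text falls back from paragraph to sentence to word boundaries.
text = '\n\n'.join('Alice met Bob in Paris. They founded Acme Corp.' for _ in range(200))
text_chunks = chunk_text_content(text, chunk_size_tokens=256, overlap_tokens=32)

print(should_chunk(text, EpisodeType.text), len(json_chunks), len(text_chunks))
```

Overlap is applied element-wise for JSON and word-wise for text, so entities that straddle a chunk boundary appear in both neighboring chunks.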

--------------------------------------------------------------------------------
/graphiti_core/utils/maintenance/node_operations.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import logging
 18 | from collections.abc import Awaitable, Callable
 19 | from time import time
 20 | from typing import Any
 21 | 
 22 | from pydantic import BaseModel
 23 | 
 24 | from graphiti_core.graphiti_types import GraphitiClients
 25 | from graphiti_core.helpers import semaphore_gather
 26 | from graphiti_core.llm_client import LLMClient
 27 | from graphiti_core.llm_client.config import ModelSize
 28 | from graphiti_core.nodes import (
 29 |     EntityNode,
 30 |     EpisodeType,
 31 |     EpisodicNode,
 32 |     create_entity_node_embeddings,
 33 | )
 34 | from graphiti_core.prompts import prompt_library
 35 | from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
 36 | from graphiti_core.prompts.extract_nodes import (
 37 |     EntitySummary,
 38 |     ExtractedEntities,
 39 |     ExtractedEntity,
 40 |     MissedEntities,
 41 | )
 42 | from graphiti_core.search.search import search
 43 | from graphiti_core.search.search_config import SearchResults
 44 | from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
 45 | from graphiti_core.search.search_filters import SearchFilters
 46 | from graphiti_core.utils.content_chunking import (
 47 |     chunk_json_content,
 48 |     chunk_message_content,
 49 |     chunk_text_content,
 50 |     should_chunk,
 51 | )
 52 | from graphiti_core.utils.datetime_utils import utc_now
 53 | from graphiti_core.utils.maintenance.dedup_helpers import (
 54 |     DedupCandidateIndexes,
 55 |     DedupResolutionState,
 56 |     _build_candidate_indexes,
 57 |     _resolve_with_similarity,
 58 | )
 59 | from graphiti_core.utils.maintenance.edge_operations import (
 60 |     filter_existing_duplicate_of_edges,
 61 | )
 62 | from graphiti_core.utils.text_utils import MAX_SUMMARY_CHARS, truncate_at_sentence
 63 | 
 64 | logger = logging.getLogger(__name__)
 65 | 
 66 | NodeSummaryFilter = Callable[[EntityNode], Awaitable[bool]]
 67 | 
 68 | 
 69 | async def extract_nodes_reflexion(
 70 |     llm_client: LLMClient,
 71 |     episode: EpisodicNode,
 72 |     previous_episodes: list[EpisodicNode],
 73 |     node_names: list[str],
 74 |     group_id: str | None = None,
 75 | ) -> list[str]:
 76 |     # Prepare context for LLM
 77 |     context = {
 78 |         'episode_content': episode.content,
 79 |         'previous_episodes': [ep.content for ep in previous_episodes],
 80 |         'extracted_entities': node_names,
 81 |     }
 82 | 
 83 |     llm_response = await llm_client.generate_response(
 84 |         prompt_library.extract_nodes.reflexion(context),
 85 |         MissedEntities,
 86 |         group_id=group_id,
 87 |         prompt_name='extract_nodes.reflexion',
 88 |     )
 89 |     missed_entities = llm_response.get('missed_entities', [])
 90 | 
 91 |     return missed_entities
 92 | 
 93 | 
 94 | async def extract_nodes(
 95 |     clients: GraphitiClients,
 96 |     episode: EpisodicNode,
 97 |     previous_episodes: list[EpisodicNode],
 98 |     entity_types: dict[str, type[BaseModel]] | None = None,
 99 |     excluded_entity_types: list[str] | None = None,
100 |     custom_extraction_instructions: str | None = None,
101 | ) -> list[EntityNode]:
102 |     """Extract entity nodes from an episode with adaptive chunking.
103 | 
104 |     For high-density content (many entities per token), the content is chunked
105 |     and processed in parallel to avoid LLM timeouts and truncation issues.
106 |     """
107 |     start = time()
108 |     llm_client = clients.llm_client
109 | 
110 |     # Build entity types context
111 |     entity_types_context = _build_entity_types_context(entity_types)
112 | 
113 |     # Build base context
114 |     context = {
115 |         'episode_content': episode.content,
116 |         'episode_timestamp': episode.valid_at.isoformat(),
117 |         'previous_episodes': [ep.content for ep in previous_episodes],
118 |         'custom_extraction_instructions': custom_extraction_instructions or '',
119 |         'entity_types': entity_types_context,
120 |         'source_description': episode.source_description,
121 |     }
122 | 
123 |     # Check if chunking is needed (based on entity density)
124 |     if should_chunk(episode.content, episode.source):
125 |         extracted_entities = await _extract_nodes_chunked(llm_client, episode, context)
126 |     else:
127 |         extracted_entities = await _extract_nodes_single(llm_client, episode, context)
128 | 
129 |     # Filter empty names
130 |     filtered_entities = [e for e in extracted_entities if e.name.strip()]
131 | 
132 |     end = time()
133 |     logger.debug(f'Extracted {len(filtered_entities)} entities in {(end - start) * 1000:.0f} ms')
134 | 
135 |     # Convert to EntityNode objects
136 |     extracted_nodes = _create_entity_nodes(
137 |         filtered_entities, entity_types_context, excluded_entity_types, episode
138 |     )
139 | 
140 |     logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
141 |     return extracted_nodes
142 | 
143 | 
144 | def _build_entity_types_context(
145 |     entity_types: dict[str, type[BaseModel]] | None,
146 | ) -> list[dict]:
147 |     """Build entity types context with ID mappings."""
148 |     entity_types_context = [
149 |         {
150 |             'entity_type_id': 0,
151 |             'entity_type_name': 'Entity',
152 |             'entity_type_description': (
153 |                 'Default entity classification. Use this entity type '
154 |                 'if the entity is not one of the other listed types.'
155 |             ),
156 |         }
157 |     ]
158 | 
159 |     if entity_types is not None:
160 |         entity_types_context += [
161 |             {
162 |                 'entity_type_id': i + 1,
163 |                 'entity_type_name': type_name,
164 |                 'entity_type_description': type_model.__doc__,
165 |             }
166 |             for i, (type_name, type_model) in enumerate(entity_types.items())
167 |         ]
168 | 
169 |     return entity_types_context
170 | 
171 | 
172 | async def _extract_nodes_single(
173 |     llm_client: LLMClient,
174 |     episode: EpisodicNode,
175 |     context: dict,
176 | ) -> list[ExtractedEntity]:
177 |     """Extract entities using a single LLM call."""
178 |     llm_response = await _call_extraction_llm(llm_client, episode, context)
179 |     response_object = ExtractedEntities(**llm_response)
180 |     return response_object.extracted_entities
181 | 
182 | 
183 | async def _extract_nodes_chunked(
184 |     llm_client: LLMClient,
185 |     episode: EpisodicNode,
186 |     context: dict,
187 | ) -> list[ExtractedEntity]:
188 |     """Extract entities from large content using chunking."""
189 |     # Chunk the content based on episode type
190 |     if episode.source == EpisodeType.json:
191 |         chunks = chunk_json_content(episode.content)
192 |     elif episode.source == EpisodeType.message:
193 |         chunks = chunk_message_content(episode.content)
194 |     else:
195 |         chunks = chunk_text_content(episode.content)
196 | 
197 |     logger.debug(f'Chunked content into {len(chunks)} chunks for entity extraction')
198 | 
199 |     # Extract entities from each chunk in parallel
200 |     chunk_results = await semaphore_gather(
201 |         *[_extract_from_chunk(llm_client, chunk, context, episode) for chunk in chunks]
202 |     )
203 | 
204 |     # Merge and deduplicate entities across chunks
205 |     merged_entities = _merge_extracted_entities(chunk_results)
206 |     logger.debug(
207 |         f'Merged {sum(len(r) for r in chunk_results)} entities into {len(merged_entities)} unique'
208 |     )
209 | 
210 |     return merged_entities
211 | 
212 | 
213 | async def _extract_from_chunk(
214 |     llm_client: LLMClient,
215 |     chunk: str,
216 |     base_context: dict,
217 |     episode: EpisodicNode,
218 | ) -> list[ExtractedEntity]:
219 |     """Extract entities from a single chunk."""
220 |     chunk_context = {**base_context, 'episode_content': chunk}
221 |     llm_response = await _call_extraction_llm(llm_client, episode, chunk_context)
222 |     return ExtractedEntities(**llm_response).extracted_entities
223 | 
224 | 
225 | async def _call_extraction_llm(
226 |     llm_client: LLMClient,
227 |     episode: EpisodicNode,
228 |     context: dict,
229 | ) -> dict:
230 |     """Call the appropriate extraction prompt based on episode type."""
231 |     if episode.source == EpisodeType.message:
232 |         prompt = prompt_library.extract_nodes.extract_message(context)
233 |         prompt_name = 'extract_nodes.extract_message'
234 |     elif episode.source == EpisodeType.text:
235 |         prompt = prompt_library.extract_nodes.extract_text(context)
236 |         prompt_name = 'extract_nodes.extract_text'
237 |     elif episode.source == EpisodeType.json:
238 |         prompt = prompt_library.extract_nodes.extract_json(context)
239 |         prompt_name = 'extract_nodes.extract_json'
240 |     else:
241 |         # Fallback to text extraction
242 |         prompt = prompt_library.extract_nodes.extract_text(context)
243 |         prompt_name = 'extract_nodes.extract_text'
244 | 
245 |     return await llm_client.generate_response(
246 |         prompt,
247 |         response_model=ExtractedEntities,
248 |         group_id=episode.group_id,
249 |         prompt_name=prompt_name,
250 |     )
251 | 
252 | 
253 | def _merge_extracted_entities(
254 |     chunk_results: list[list[ExtractedEntity]],
255 | ) -> list[ExtractedEntity]:
256 |     """Merge entities from multiple chunks, deduplicating by normalized name.
257 | 
258 |     When duplicates occur, prefer the first occurrence (maintains ordering).
259 |     """
260 |     seen_names: set[str] = set()
261 |     merged: list[ExtractedEntity] = []
262 | 
263 |     for entities in chunk_results:
264 |         for entity in entities:
265 |             normalized = entity.name.strip().lower()
266 |             if normalized and normalized not in seen_names:
267 |                 seen_names.add(normalized)
268 |                 merged.append(entity)
269 | 
270 |     return merged
271 | 
272 | 
273 | def _create_entity_nodes(
274 |     extracted_entities: list[ExtractedEntity],
275 |     entity_types_context: list[dict],
276 |     excluded_entity_types: list[str] | None,
277 |     episode: EpisodicNode,
278 | ) -> list[EntityNode]:
279 |     """Convert ExtractedEntity objects to EntityNode objects."""
280 |     extracted_nodes = []
281 | 
282 |     for extracted_entity in extracted_entities:
283 |         type_id = extracted_entity.entity_type_id
284 |         if 0 <= type_id < len(entity_types_context):
285 |             entity_type_name = entity_types_context[type_id].get('entity_type_name')
286 |         else:
287 |             entity_type_name = 'Entity'
288 | 
289 |         # Check if this entity type should be excluded
290 |         if excluded_entity_types and entity_type_name in excluded_entity_types:
291 |             logger.debug(f'Excluding entity "{extracted_entity.name}" of type "{entity_type_name}"')
292 |             continue
293 | 
294 |         labels: list[str] = list({'Entity', str(entity_type_name)})
295 | 
296 |         new_node = EntityNode(
297 |             name=extracted_entity.name,
298 |             group_id=episode.group_id,
299 |             labels=labels,
300 |             summary='',
301 |             created_at=utc_now(),
302 |         )
303 |         extracted_nodes.append(new_node)
304 |         logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
305 | 
306 |     return extracted_nodes
307 | 
308 | 
309 | async def _collect_candidate_nodes(
310 |     clients: GraphitiClients,
311 |     extracted_nodes: list[EntityNode],
312 |     existing_nodes_override: list[EntityNode] | None,
313 | ) -> list[EntityNode]:
314 |     """Search per extracted name and return unique candidates with overrides honored in order."""
315 |     search_results: list[SearchResults] = await semaphore_gather(
316 |         *[
317 |             search(
318 |                 clients=clients,
319 |                 query=node.name,
320 |                 group_ids=[node.group_id],
321 |                 search_filter=SearchFilters(),
322 |                 config=NODE_HYBRID_SEARCH_RRF,
323 |             )
324 |             for node in extracted_nodes
325 |         ]
326 |     )
327 | 
328 |     candidate_nodes: list[EntityNode] = [node for result in search_results for node in result.nodes]
329 | 
330 |     if existing_nodes_override is not None:
331 |         candidate_nodes.extend(existing_nodes_override)
332 | 
333 |     seen_candidate_uuids: set[str] = set()
334 |     ordered_candidates: list[EntityNode] = []
335 |     for candidate in candidate_nodes:
336 |         if candidate.uuid in seen_candidate_uuids:
337 |             continue
338 |         seen_candidate_uuids.add(candidate.uuid)
339 |         ordered_candidates.append(candidate)
340 | 
341 |     return ordered_candidates
342 | 
343 | 
344 | async def _resolve_with_llm(
345 |     llm_client: LLMClient,
346 |     extracted_nodes: list[EntityNode],
347 |     indexes: DedupCandidateIndexes,
348 |     state: DedupResolutionState,
349 |     episode: EpisodicNode | None,
350 |     previous_episodes: list[EpisodicNode] | None,
351 |     entity_types: dict[str, type[BaseModel]] | None,
352 | ) -> None:
353 |     """Escalate unresolved nodes to the dedupe prompt so the LLM can select or reject duplicates.
354 | 
355 |     The guardrails below defensively ignore malformed or duplicate LLM responses so the
356 |     ingestion workflow remains deterministic even when the model misbehaves.
357 |     """
358 |     if not state.unresolved_indices:
359 |         return
360 | 
361 |     entity_types_dict: dict[str, type[BaseModel]] = entity_types if entity_types is not None else {}
362 | 
363 |     llm_extracted_nodes = [extracted_nodes[i] for i in state.unresolved_indices]
364 | 
365 |     extracted_nodes_context = [
366 |         {
367 |             'id': i,
368 |             'name': node.name,
369 |             'entity_type': node.labels,
370 |             'entity_type_description': entity_types_dict.get(
371 |                 next((item for item in node.labels if item != 'Entity'), '')
372 |             ).__doc__
373 |             or 'Default Entity Type',
374 |         }
375 |         for i, node in enumerate(llm_extracted_nodes)
376 |     ]
377 | 
378 |     sent_ids = [ctx['id'] for ctx in extracted_nodes_context]
379 |     logger.debug(
380 |         'Sending %d entities to LLM for deduplication with IDs 0-%d (actual IDs sent: %s)',
381 |         len(llm_extracted_nodes),
382 |         len(llm_extracted_nodes) - 1,
383 |         sent_ids if len(sent_ids) < 20 else f'{sent_ids[:10]}...{sent_ids[-10:]}',
384 |     )
385 |     if llm_extracted_nodes:
386 |         sample_size = min(3, len(extracted_nodes_context))
387 |         logger.debug(
388 |             'First %d entities: %s',
389 |             sample_size,
390 |             [(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[:sample_size]],
391 |         )
392 |         if len(extracted_nodes_context) > 3:
393 |             logger.debug(
394 |                 'Last %d entities: %s',
395 |                 sample_size,
396 |                 [(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[-sample_size:]],
397 |             )
398 | 
399 |     existing_nodes_context = [
400 |         {
401 |             **{
402 |                 'idx': i,
403 |                 'name': candidate.name,
404 |                 'entity_types': candidate.labels,
405 |             },
406 |             **candidate.attributes,
407 |         }
408 |         for i, candidate in enumerate(indexes.existing_nodes)
409 |     ]
410 | 
411 |     context = {
412 |         'extracted_nodes': extracted_nodes_context,
413 |         'existing_nodes': existing_nodes_context,
414 |         'episode_content': episode.content if episode is not None else '',
415 |         'previous_episodes': (
416 |             [ep.content for ep in previous_episodes] if previous_episodes is not None else []
417 |         ),
418 |     }
419 | 
420 |     llm_response = await llm_client.generate_response(
421 |         prompt_library.dedupe_nodes.nodes(context),
422 |         response_model=NodeResolutions,
423 |         prompt_name='dedupe_nodes.nodes',
424 |     )
425 | 
426 |     node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions
427 | 
428 |     valid_relative_range = range(len(state.unresolved_indices))
429 |     processed_relative_ids: set[int] = set()
430 | 
431 |     received_ids = {r.id for r in node_resolutions}
432 |     expected_ids = set(valid_relative_range)
433 |     missing_ids = expected_ids - received_ids
434 |     extra_ids = received_ids - expected_ids
435 | 
436 |     logger.debug(
437 |         'Received %d resolutions for %d entities',
438 |         len(node_resolutions),
439 |         len(state.unresolved_indices),
440 |     )
441 | 
442 |     if missing_ids:
443 |         logger.warning('LLM did not return resolutions for IDs: %s', sorted(missing_ids))
444 | 
445 |     if extra_ids:
446 |         logger.warning(
447 |             'LLM returned invalid IDs outside valid range 0-%d: %s (all returned IDs: %s)',
448 |             len(state.unresolved_indices) - 1,
449 |             sorted(extra_ids),
450 |             sorted(received_ids),
451 |         )
452 | 
453 |     for resolution in node_resolutions:
454 |         relative_id: int = resolution.id
455 |         duplicate_idx: int = resolution.duplicate_idx
456 | 
457 |         if relative_id not in valid_relative_range:
458 |             logger.warning(
459 |                 'Skipping invalid LLM dedupe id %d (valid range: 0-%d, received %d resolutions)',
460 |                 relative_id,
461 |                 len(state.unresolved_indices) - 1,
462 |                 len(node_resolutions),
463 |             )
464 |             continue
465 | 
466 |         if relative_id in processed_relative_ids:
467 |             logger.warning('Duplicate LLM dedupe id %s received; ignoring.', relative_id)
468 |             continue
469 |         processed_relative_ids.add(relative_id)
470 | 
471 |         original_index = state.unresolved_indices[relative_id]
472 |         extracted_node = extracted_nodes[original_index]
473 | 
474 |         resolved_node: EntityNode
475 |         if duplicate_idx == -1:
476 |             resolved_node = extracted_node
477 |         elif 0 <= duplicate_idx < len(indexes.existing_nodes):
478 |             resolved_node = indexes.existing_nodes[duplicate_idx]
479 |         else:
480 |             logger.warning(
481 |                 'Invalid duplicate_idx %s for extracted node %s; treating as no duplicate.',
482 |                 duplicate_idx,
483 |                 extracted_node.uuid,
484 |             )
485 |             resolved_node = extracted_node
486 | 
487 |         state.resolved_nodes[original_index] = resolved_node
488 |         state.uuid_map[extracted_node.uuid] = resolved_node.uuid
489 |         if resolved_node.uuid != extracted_node.uuid:
490 |             state.duplicate_pairs.append((extracted_node, resolved_node))
491 | 
492 | 
493 | async def resolve_extracted_nodes(
494 |     clients: GraphitiClients,
495 |     extracted_nodes: list[EntityNode],
496 |     episode: EpisodicNode | None = None,
497 |     previous_episodes: list[EpisodicNode] | None = None,
498 |     entity_types: dict[str, type[BaseModel]] | None = None,
499 |     existing_nodes_override: list[EntityNode] | None = None,
500 | ) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
501 |     """Search for existing nodes, resolve deterministic matches, then escalate holdouts to the LLM dedupe prompt."""
502 |     llm_client = clients.llm_client
503 |     driver = clients.driver
504 |     existing_nodes = await _collect_candidate_nodes(
505 |         clients,
506 |         extracted_nodes,
507 |         existing_nodes_override,
508 |     )
509 | 
510 |     indexes: DedupCandidateIndexes = _build_candidate_indexes(existing_nodes)
511 | 
512 |     state = DedupResolutionState(
513 |         resolved_nodes=[None] * len(extracted_nodes),
514 |         uuid_map={},
515 |         unresolved_indices=[],
516 |     )
517 | 
518 |     _resolve_with_similarity(extracted_nodes, indexes, state)
519 | 
520 |     await _resolve_with_llm(
521 |         llm_client,
522 |         extracted_nodes,
523 |         indexes,
524 |         state,
525 |         episode,
526 |         previous_episodes,
527 |         entity_types,
528 |     )
529 | 
530 |     for idx, node in enumerate(extracted_nodes):
531 |         if state.resolved_nodes[idx] is None:
532 |             state.resolved_nodes[idx] = node
533 |             state.uuid_map[node.uuid] = node.uuid
534 | 
535 |     logger.debug(
536 |         'Resolved nodes: %s',
537 |         [(node.name, node.uuid) for node in state.resolved_nodes if node is not None],
538 |     )
539 | 
540 |     new_node_duplicates: list[
541 |         tuple[EntityNode, EntityNode]
542 |     ] = await filter_existing_duplicate_of_edges(driver, state.duplicate_pairs)
543 | 
544 |     return (
545 |         [node for node in state.resolved_nodes if node is not None],
546 |         state.uuid_map,
547 |         new_node_duplicates,
548 |     )
549 | 
550 | 
551 | async def extract_attributes_from_nodes(
552 |     clients: GraphitiClients,
553 |     nodes: list[EntityNode],
554 |     episode: EpisodicNode | None = None,
555 |     previous_episodes: list[EpisodicNode] | None = None,
556 |     entity_types: dict[str, type[BaseModel]] | None = None,
557 |     should_summarize_node: NodeSummaryFilter | None = None,
558 | ) -> list[EntityNode]:
559 |     llm_client = clients.llm_client
560 |     embedder = clients.embedder
561 |     updated_nodes: list[EntityNode] = await semaphore_gather(
562 |         *[
563 |             extract_attributes_from_node(
564 |                 llm_client,
565 |                 node,
566 |                 episode,
567 |                 previous_episodes,
568 |                 (
569 |                     entity_types.get(next((item for item in node.labels if item != 'Entity'), ''))
570 |                     if entity_types is not None
571 |                     else None
572 |                 ),
573 |                 should_summarize_node,
574 |             )
575 |             for node in nodes
576 |         ]
577 |     )
578 | 
579 |     await create_entity_node_embeddings(embedder, updated_nodes)
580 | 
581 |     return updated_nodes
582 | 
583 | 
584 | async def extract_attributes_from_node(
585 |     llm_client: LLMClient,
586 |     node: EntityNode,
587 |     episode: EpisodicNode | None = None,
588 |     previous_episodes: list[EpisodicNode] | None = None,
589 |     entity_type: type[BaseModel] | None = None,
590 |     should_summarize_node: NodeSummaryFilter | None = None,
591 | ) -> EntityNode:
592 |     # Extract attributes if entity type is defined and has attributes
593 |     llm_response = await _extract_entity_attributes(
594 |         llm_client, node, episode, previous_episodes, entity_type
595 |     )
596 | 
597 |     # Extract summary if needed
598 |     await _extract_entity_summary(
599 |         llm_client, node, episode, previous_episodes, should_summarize_node
600 |     )
601 | 
602 |     node.attributes.update(llm_response)
603 | 
604 |     return node
605 | 
606 | 
607 | async def _extract_entity_attributes(
608 |     llm_client: LLMClient,
609 |     node: EntityNode,
610 |     episode: EpisodicNode | None,
611 |     previous_episodes: list[EpisodicNode] | None,
612 |     entity_type: type[BaseModel] | None,
613 | ) -> dict[str, Any]:
614 |     if entity_type is None or len(entity_type.model_fields) == 0:
615 |         return {}
616 | 
617 |     attributes_context = _build_episode_context(
618 |         # should not include summary
619 |         node_data={
620 |             'name': node.name,
621 |             'entity_types': node.labels,
622 |             'attributes': node.attributes,
623 |         },
624 |         episode=episode,
625 |         previous_episodes=previous_episodes,
626 |     )
627 | 
628 |     llm_response = await llm_client.generate_response(
629 |         prompt_library.extract_nodes.extract_attributes(attributes_context),
630 |         response_model=entity_type,
631 |         model_size=ModelSize.small,
632 |         group_id=node.group_id,
633 |         prompt_name='extract_nodes.extract_attributes',
634 |     )
635 | 
636 |     # validate response
637 |     entity_type(**llm_response)
638 | 
639 |     return llm_response
640 | 
641 | 
642 | async def _extract_entity_summary(
643 |     llm_client: LLMClient,
644 |     node: EntityNode,
645 |     episode: EpisodicNode | None,
646 |     previous_episodes: list[EpisodicNode] | None,
647 |     should_summarize_node: NodeSummaryFilter | None,
648 | ) -> None:
649 |     if should_summarize_node is not None and not await should_summarize_node(node):
650 |         return
651 | 
652 |     summary_context = _build_episode_context(
653 |         node_data={
654 |             'name': node.name,
655 |             'summary': truncate_at_sentence(node.summary, MAX_SUMMARY_CHARS),
656 |             'entity_types': node.labels,
657 |             'attributes': node.attributes,
658 |         },
659 |         episode=episode,
660 |         previous_episodes=previous_episodes,
661 |     )
662 | 
663 |     summary_response = await llm_client.generate_response(
664 |         prompt_library.extract_nodes.extract_summary(summary_context),
665 |         response_model=EntitySummary,
666 |         model_size=ModelSize.small,
667 |         group_id=node.group_id,
668 |         prompt_name='extract_nodes.extract_summary',
669 |     )
670 | 
671 |     node.summary = truncate_at_sentence(summary_response.get('summary', ''), MAX_SUMMARY_CHARS)
672 | 
673 | 
674 | def _build_episode_context(
675 |     node_data: dict[str, Any],
676 |     episode: EpisodicNode | None,
677 |     previous_episodes: list[EpisodicNode] | None,
678 | ) -> dict[str, Any]:
679 |     return {
680 |         'node': node_data,
681 |         'episode_content': episode.content if episode is not None else '',
682 |         'previous_episodes': (
683 |             [ep.content for ep in previous_episodes] if previous_episodes is not None else []
684 |         ),
685 |     }
686 | 
```
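
Because `_merge_extracted_entities` dedupes purely on the normalized (stripped, lowercased) name, the cross-chunk merge semantics can be checked in isolation. A minimal sketch; the `ExtractedEntity(name=..., entity_type_id=...)` constructor arguments are assumed from how those fields are read above:

```python
# Sketch of the cross-chunk merge: the first occurrence wins, and later
# chunks only contribute names not yet seen. The ExtractedEntity kwargs
# are assumed from the field accesses in node_operations.py.
from graphiti_core.prompts.extract_nodes import ExtractedEntity
from graphiti_core.utils.maintenance.node_operations import _merge_extracted_entities

chunk_results = [
    [ExtractedEntity(name='Acme Corp', entity_type_id=0)],
    [
        ExtractedEntity(name='acme corp', entity_type_id=0),  # same normalized name: dropped
        ExtractedEntity(name='Alice', entity_type_id=0),
    ],
]

merged = _merge_extracted_entities(chunk_results)
assert [entity.name for entity in merged] == ['Acme Corp', 'Alice']
```

Since the first occurrence wins, chunk order determines which surface form of a duplicated entity survives into node creation.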

--------------------------------------------------------------------------------
/mcp_server/tests/test_comprehensive_integration.py:
--------------------------------------------------------------------------------

```python
  1 | #!/usr/bin/env python3
  2 | """
  3 | Comprehensive integration test suite for Graphiti MCP Server.
  4 | Covers all MCP tools with consideration for LLM inference latency.
  5 | """
  6 | 
  7 | import asyncio
  8 | import json
  9 | import os
 10 | import time
 11 | from dataclasses import dataclass
 12 | from typing import Any
 13 | 
 14 | import pytest
 15 | from mcp import ClientSession, StdioServerParameters
 16 | from mcp.client.stdio import stdio_client
 17 | 
 18 | 
 19 | @dataclass
 20 | class TestMetrics:
 21 |     """Track test performance metrics."""
 22 | 
 23 |     operation: str
 24 |     start_time: float
 25 |     end_time: float
 26 |     success: bool
 27 |     details: dict[str, Any]
 28 | 
 29 |     @property
 30 |     def duration(self) -> float:
 31 |         """Calculate operation duration in seconds."""
 32 |         return self.end_time - self.start_time
 33 | 
 34 | 
 35 | class GraphitiTestClient:
 36 |     """Enhanced test client for comprehensive Graphiti MCP testing."""
 37 | 
 38 |     def __init__(self, test_group_id: str | None = None):
 39 |         self.test_group_id = test_group_id or f'test_{int(time.time())}'
 40 |         self.session = None
 41 |         self.metrics: list[TestMetrics] = []
 42 |         self.default_timeout = 30  # seconds
 43 | 
 44 |     async def __aenter__(self):
 45 |         """Initialize MCP client session."""
 46 |         server_params = StdioServerParameters(
 47 |             command='uv',
 48 |             args=['run', '../main.py', '--transport', 'stdio'],
 49 |             env={
 50 |                 'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
 51 |                 'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
 52 |                 'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
 53 |                 'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY', 'test_key_for_mock'),
 54 |                 'FALKORDB_URI': os.environ.get('FALKORDB_URI', 'redis://localhost:6379'),
 55 |             },
 56 |         )
 57 | 
 58 |         self.client_context = stdio_client(server_params)
 59 |         read, write = await self.client_context.__aenter__()
 60 |         self.session = ClientSession(read, write)
 61 |         await self.session.initialize()
 62 | 
 63 |         # Wait for server to be fully ready
 64 |         await asyncio.sleep(2)
 65 | 
 66 |         return self
 67 | 
 68 |     async def __aexit__(self, exc_type, exc_val, exc_tb):
 69 |         """Clean up client session."""
 70 |         if self.session:
 71 |             await self.session.close()
 72 |         if hasattr(self, 'client_context'):
 73 |             await self.client_context.__aexit__(exc_type, exc_val, exc_tb)
 74 | 
 75 |     async def call_tool_with_metrics(
 76 |         self, tool_name: str, arguments: dict[str, Any], timeout: float | None = None
 77 |     ) -> tuple[Any, TestMetrics]:
 78 |         """Call a tool and capture performance metrics."""
 79 |         start_time = time.time()
 80 |         timeout = timeout or self.default_timeout
 81 | 
 82 |         try:
 83 |             result = await asyncio.wait_for(
 84 |                 self.session.call_tool(tool_name, arguments), timeout=timeout
 85 |             )
 86 | 
 87 |             content = result.content[0].text if result.content else None
 88 |             success = True
 89 |             details = {'result': content, 'tool': tool_name}
 90 | 
 91 |         except asyncio.TimeoutError:
 92 |             content = None
 93 |             success = False
 94 |             details = {'error': f'Timeout after {timeout}s', 'tool': tool_name}
 95 | 
 96 |         except Exception as e:
 97 |             content = None
 98 |             success = False
 99 |             details = {'error': str(e), 'tool': tool_name}
100 | 
101 |         end_time = time.time()
102 |         metric = TestMetrics(
103 |             operation=f'call_{tool_name}',
104 |             start_time=start_time,
105 |             end_time=end_time,
106 |             success=success,
107 |             details=details,
108 |         )
109 |         self.metrics.append(metric)
110 | 
111 |         return content, metric
112 | 
113 |     async def wait_for_episode_processing(
114 |         self, expected_count: int = 1, max_wait: int = 60, poll_interval: int = 2
115 |     ) -> bool:
116 |         """
117 |         Wait for episodes to be processed with intelligent polling.
118 | 
119 |         Args:
120 |             expected_count: Number of episodes expected to be processed
121 |             max_wait: Maximum seconds to wait
122 |             poll_interval: Seconds between status checks
123 | 
124 |         Returns:
125 |             True if episodes were processed successfully
126 |         """
127 |         start_time = time.time()
128 | 
129 |         while (time.time() - start_time) < max_wait:
130 |             result, _ = await self.call_tool_with_metrics(
131 |                 'get_episodes', {'group_id': self.test_group_id, 'last_n': 100}
132 |             )
133 | 
134 |             if result:
135 |                 try:
136 |                     episodes = json.loads(result) if isinstance(result, str) else result
137 |                     if len(episodes.get('episodes', [])) >= expected_count:
138 |                         return True
139 |                 except (json.JSONDecodeError, AttributeError):
140 |                     pass
141 | 
142 |             await asyncio.sleep(poll_interval)
143 | 
144 |         return False
145 | 
146 | 
147 | class TestCoreOperations:
148 |     """Test core Graphiti operations."""
149 | 
150 |     @pytest.mark.asyncio
151 |     async def test_server_initialization(self):
152 |         """Verify server initializes with all required tools."""
153 |         async with GraphitiTestClient() as client:
154 |             tools_result = await client.session.list_tools()
155 |             tools = {tool.name for tool in tools_result.tools}
156 | 
157 |             required_tools = {
158 |                 'add_memory',
159 |                 'search_memory_nodes',
160 |                 'search_memory_facts',
161 |                 'get_episodes',
162 |                 'delete_episode',
163 |                 'delete_entity_edge',
164 |                 'get_entity_edge',
165 |                 'clear_graph',
166 |                 'get_status',
167 |             }
168 | 
169 |             missing_tools = required_tools - tools
170 |             assert not missing_tools, f'Missing required tools: {missing_tools}'
171 | 
172 |     @pytest.mark.asyncio
173 |     async def test_add_text_memory(self):
174 |         """Test adding text-based memories."""
175 |         async with GraphitiTestClient() as client:
176 |             # Add memory
177 |             result, metric = await client.call_tool_with_metrics(
178 |                 'add_memory',
179 |                 {
180 |                     'name': 'Tech Conference Notes',
181 |                     'episode_body': 'The AI conference featured talks on LLMs, RAG systems, and knowledge graphs. Notable speakers included researchers from OpenAI and Anthropic.',
182 |                     'source': 'text',
183 |                     'source_description': 'conference notes',
184 |                     'group_id': client.test_group_id,
185 |                 },
186 |             )
187 | 
188 |             assert metric.success, f'Failed to add memory: {metric.details}'
189 |             assert 'queued' in str(result).lower()
190 | 
191 |             # Wait for processing
192 |             processed = await client.wait_for_episode_processing(expected_count=1)
193 |             assert processed, 'Episode was not processed within timeout'
194 | 
195 |     @pytest.mark.asyncio
196 |     async def test_add_json_memory(self):
197 |         """Test adding structured JSON memories."""
198 |         async with GraphitiTestClient() as client:
199 |             json_data = {
200 |                 'project': {
201 |                     'name': 'GraphitiDB',
202 |                     'version': '2.0.0',
203 |                     'features': ['temporal-awareness', 'hybrid-search', 'custom-entities'],
204 |                 },
205 |                 'team': {'size': 5, 'roles': ['engineering', 'product', 'research']},
206 |             }
207 | 
208 |             result, metric = await client.call_tool_with_metrics(
209 |                 'add_memory',
210 |                 {
211 |                     'name': 'Project Data',
212 |                     'episode_body': json.dumps(json_data),
213 |                     'source': 'json',
214 |                     'source_description': 'project database',
215 |                     'group_id': client.test_group_id,
216 |                 },
217 |             )
218 | 
219 |             assert metric.success
220 |             assert 'queued' in str(result).lower()
221 | 
222 |     @pytest.mark.asyncio
223 |     async def test_add_message_memory(self):
224 |         """Test adding conversation/message memories."""
225 |         async with GraphitiTestClient() as client:
226 |             conversation = """
227 |             user: What are the key features of Graphiti?
228 |             assistant: Graphiti offers temporal-aware knowledge graphs, hybrid retrieval, and real-time updates.
229 |             user: How does it handle entity resolution?
230 |             assistant: It uses LLM-based entity extraction and deduplication with semantic similarity matching.
231 |             """
232 | 
233 |             result, metric = await client.call_tool_with_metrics(
234 |                 'add_memory',
235 |                 {
236 |                     'name': 'Feature Discussion',
237 |                     'episode_body': conversation,
238 |                     'source': 'message',
239 |                     'source_description': 'support chat',
240 |                     'group_id': client.test_group_id,
241 |                 },
242 |             )
243 | 
244 |             assert metric.success
245 |             assert metric.duration < 5, f'Add memory took too long: {metric.duration}s'
246 | 
247 | 
248 | class TestSearchOperations:
249 |     """Test search and retrieval operations."""
250 | 
251 |     @pytest.mark.asyncio
252 |     async def test_search_nodes_semantic(self):
253 |         """Test semantic search for nodes."""
254 |         async with GraphitiTestClient() as client:
255 |             # First add some test data
256 |             await client.call_tool_with_metrics(
257 |                 'add_memory',
258 |                 {
259 |                     'name': 'Product Launch',
260 |                     'episode_body': 'Our new AI assistant product launches in Q2 2024 with advanced NLP capabilities.',
261 |                     'source': 'text',
262 |                     'source_description': 'product roadmap',
263 |                     'group_id': client.test_group_id,
264 |                 },
265 |             )
266 | 
267 |             # Wait for processing
268 |             await client.wait_for_episode_processing()
269 | 
270 |             # Search for nodes
271 |             result, metric = await client.call_tool_with_metrics(
272 |                 'search_memory_nodes',
273 |                 {'query': 'AI product features', 'group_id': client.test_group_id, 'limit': 10},
274 |             )
275 | 
276 |             assert metric.success
277 |             assert result is not None
278 | 
279 |     @pytest.mark.asyncio
280 |     async def test_search_facts_with_filters(self):
281 |         """Test fact search with various filters."""
282 |         async with GraphitiTestClient() as client:
283 |             # Add test data
284 |             await client.call_tool_with_metrics(
285 |                 'add_memory',
286 |                 {
287 |                     'name': 'Company Facts',
288 |                     'episode_body': 'Acme Corp was founded in 2020. They have 50 employees and $10M in revenue.',
289 |                     'source': 'text',
290 |                     'source_description': 'company profile',
291 |                     'group_id': client.test_group_id,
292 |                 },
293 |             )
294 | 
295 |             await client.wait_for_episode_processing()
296 | 
297 |             # Search with date filter
298 |             result, metric = await client.call_tool_with_metrics(
299 |                 'search_memory_facts',
300 |                 {
301 |                     'query': 'company information',
302 |                     'group_id': client.test_group_id,
303 |                     'created_after': '2020-01-01T00:00:00Z',
304 |                     'limit': 20,
305 |                 },
306 |             )
307 | 
308 |             assert metric.success
309 | 
310 |     @pytest.mark.asyncio
311 |     async def test_hybrid_search(self):
312 |         """Test hybrid search combining semantic and keyword search."""
313 |         async with GraphitiTestClient() as client:
314 |             # Add diverse test data
315 |             test_memories = [
316 |                 {
317 |                     'name': 'Technical Doc',
318 |                     'episode_body': 'GraphQL API endpoints support pagination, filtering, and real-time subscriptions.',
319 |                     'source': 'text',
320 |                 },
321 |                 {
322 |                     'name': 'Architecture',
323 |                     'episode_body': 'The system uses Neo4j for graph storage and OpenAI embeddings for semantic search.',
324 |                     'source': 'text',
325 |                 },
326 |             ]
327 | 
328 |             for memory in test_memories:
329 |                 memory['group_id'] = client.test_group_id
330 |                 memory['source_description'] = 'documentation'
331 |                 await client.call_tool_with_metrics('add_memory', memory)
332 | 
333 |             await client.wait_for_episode_processing(expected_count=2)
334 | 
335 |             # Test semantic + keyword search
336 |             result, metric = await client.call_tool_with_metrics(
337 |                 'search_memory_nodes',
338 |                 {'query': 'Neo4j graph database', 'group_id': client.test_group_id, 'limit': 10},
339 |             )
340 | 
341 |             assert metric.success
342 | 
343 | 
344 | class TestEpisodeManagement:
345 |     """Test episode lifecycle operations."""
346 | 
347 |     @pytest.mark.asyncio
348 |     async def test_get_episodes_pagination(self):
349 |         """Test retrieving episodes with pagination."""
350 |         async with GraphitiTestClient() as client:
351 |             # Add multiple episodes
352 |             for i in range(5):
353 |                 await client.call_tool_with_metrics(
354 |                     'add_memory',
355 |                     {
356 |                         'name': f'Episode {i}',
357 |                         'episode_body': f'This is test episode number {i}',
358 |                         'source': 'text',
359 |                         'source_description': 'test',
360 |                         'group_id': client.test_group_id,
361 |                     },
362 |                 )
363 | 
364 |             await client.wait_for_episode_processing(expected_count=5)
365 | 
366 |             # Retrieve only the 3 most recent episodes via last_n
367 |             result, metric = await client.call_tool_with_metrics(
368 |                 'get_episodes', {'group_id': client.test_group_id, 'last_n': 3}
369 |             )
370 | 
371 |             assert metric.success
372 |             episodes = json.loads(result) if isinstance(result, str) else result
373 |             assert len(episodes.get('episodes', [])) <= 3
374 | 
375 |     @pytest.mark.asyncio
376 |     async def test_delete_episode(self):
377 |         """Test deleting specific episodes."""
378 |         async with GraphitiTestClient() as client:
379 |             # Add an episode
380 |             await client.call_tool_with_metrics(
381 |                 'add_memory',
382 |                 {
383 |                     'name': 'To Delete',
384 |                     'episode_body': 'This episode will be deleted',
385 |                     'source': 'text',
386 |                     'source_description': 'test',
387 |                     'group_id': client.test_group_id,
388 |                 },
389 |             )
390 | 
391 |             await client.wait_for_episode_processing()
392 | 
393 |             # Get episode UUID
394 |             result, _ = await client.call_tool_with_metrics(
395 |                 'get_episodes', {'group_id': client.test_group_id, 'last_n': 1}
396 |             )
397 | 
398 |             episodes = json.loads(result) if isinstance(result, str) else result
399 |             episode_uuid = episodes['episodes'][0]['uuid']
400 | 
401 |             # Delete the episode
402 |             result, metric = await client.call_tool_with_metrics(
403 |                 'delete_episode', {'episode_uuid': episode_uuid}
404 |             )
405 | 
406 |             assert metric.success
407 |             assert 'deleted' in str(result).lower()
408 | 
409 | 
410 | class TestEntityAndEdgeOperations:
411 |     """Test entity and edge management."""
412 | 
413 |     @pytest.mark.asyncio
414 |     async def test_get_entity_edge(self):
415 |         """Test retrieving entity edges."""
416 |         async with GraphitiTestClient() as client:
417 |             # Add data to create entities and edges
418 |             await client.call_tool_with_metrics(
419 |                 'add_memory',
420 |                 {
421 |                     'name': 'Relationship Data',
422 |                     'episode_body': 'Alice works at TechCorp. Bob is the CEO of TechCorp.',
423 |                     'source': 'text',
424 |                     'source_description': 'org chart',
425 |                     'group_id': client.test_group_id,
426 |                 },
427 |             )
428 | 
429 |             await client.wait_for_episode_processing()
430 | 
431 |             # Search for nodes to get UUIDs
432 |             result, _ = await client.call_tool_with_metrics(
433 |                 'search_memory_nodes',
434 |                 {'query': 'TechCorp', 'group_id': client.test_group_id, 'limit': 5},
435 |             )
436 | 
437 |             assert result is not None  # extraction should yield searchable nodes
438 |             # Full edge retrieval would require a valid edge UUID created between entities
439 | 
440 |     @pytest.mark.asyncio
441 |     async def test_delete_entity_edge(self):
442 |         """Test deleting entity edges."""
443 |         # Similar structure to test_get_entity_edge, but exercising deletion
444 |         pass  # Not implemented; a hedged sketch follows this listing
445 | 
446 | 
447 | class TestErrorHandling:
448 |     """Test error conditions and edge cases."""
449 | 
450 |     @pytest.mark.asyncio
451 |     async def test_invalid_tool_arguments(self):
452 |         """Test handling of invalid tool arguments."""
453 |         async with GraphitiTestClient() as client:
454 |             # Missing required arguments
455 |             result, metric = await client.call_tool_with_metrics(
456 |                 'add_memory',
457 |                 {'name': 'Incomplete'},  # Missing required fields
458 |             )
459 | 
460 |             assert not metric.success
461 |             assert 'error' in str(metric.details).lower()
462 | 
463 |     @pytest.mark.asyncio
464 |     async def test_timeout_handling(self):
465 |         """Test timeout handling for long operations."""
466 |         async with GraphitiTestClient() as client:
467 |             # Simulate a very large episode that might time out
468 |             large_text = 'Large document content. ' * 10000
469 | 
470 |             result, metric = await client.call_tool_with_metrics(
471 |                 'add_memory',
472 |                 {
473 |                     'name': 'Large Document',
474 |                     'episode_body': large_text,
475 |                     'source': 'text',
476 |                     'source_description': 'large file',
477 |                     'group_id': client.test_group_id,
478 |                 },
479 |                 timeout=5,  # Deliberately short timeout to force the timeout path
480 |             )
481 | 
482 |             # If the call timed out, the failure should be reported as a timeout
483 |             if not metric.success:
484 |                 assert 'timeout' in str(metric.details).lower()
485 | 
486 |     @pytest.mark.asyncio
487 |     async def test_concurrent_operations(self):
488 |         """Test handling of concurrent operations."""
489 |         async with GraphitiTestClient() as client:
490 |             # Launch multiple operations concurrently
491 |             tasks = []
492 |             for i in range(5):
493 |                 task = client.call_tool_with_metrics(
494 |                     'add_memory',
495 |                     {
496 |                         'name': f'Concurrent {i}',
497 |                         'episode_body': f'Concurrent operation {i}',
498 |                         'source': 'text',
499 |                         'source_description': 'concurrent test',
500 |                         'group_id': client.test_group_id,
501 |                     },
502 |                 )
503 |                 tasks.append(task)
504 | 
505 |             results = await asyncio.gather(*tasks, return_exceptions=True)
506 | 
507 |             # With return_exceptions=True, gather may yield exceptions; count only completed, successful calls
508 |             successful = sum(1 for item in results if not isinstance(item, BaseException) and item[1].success)
509 |             assert successful >= 3  # At least 3 of 5 (60%) should succeed
510 | 
511 | 
512 | class TestPerformance:
513 |     """Test performance characteristics and optimization."""
514 | 
515 |     @pytest.mark.asyncio
516 |     async def test_latency_metrics(self):
517 |         """Measure and validate operation latencies."""
518 |         async with GraphitiTestClient() as client:
519 |             operations = [
520 |                 (
521 |                     'add_memory',
522 |                     {
523 |                         'name': 'Perf Test',
524 |                         'episode_body': 'Simple text',
525 |                         'source': 'text',
526 |                         'source_description': 'test',
527 |                         'group_id': client.test_group_id,
528 |                     },
529 |                 ),
530 |                 (
531 |                     'search_memory_nodes',
532 |                     {'query': 'test', 'group_id': client.test_group_id, 'limit': 10},
533 |                 ),
534 |                 ('get_episodes', {'group_id': client.test_group_id, 'last_n': 10}),
535 |             ]
536 | 
537 |             for tool_name, args in operations:
538 |                 _, metric = await client.call_tool_with_metrics(tool_name, args)
539 | 
540 |                 # Log performance metrics
541 |                 print(f'{tool_name}: {metric.duration:.2f}s')
542 | 
543 |                 # Basic latency assertions
544 |                 if tool_name == 'get_episodes':
545 |                     assert metric.duration < 2, f'{tool_name} too slow'
546 |                 elif tool_name == 'search_memory_nodes':
547 |                     assert metric.duration < 10, f'{tool_name} too slow'
548 | 
549 |     @pytest.mark.asyncio
550 |     async def test_batch_processing_efficiency(self):
551 |         """Test efficiency of batch operations."""
552 |         async with GraphitiTestClient() as client:
553 |             batch_size = 10
554 |             start_time = time.time()
555 | 
556 |             # Batch add memories
557 |             for i in range(batch_size):
558 |                 await client.call_tool_with_metrics(
559 |                     'add_memory',
560 |                     {
561 |                         'name': f'Batch {i}',
562 |                         'episode_body': f'Batch content {i}',
563 |                         'source': 'text',
564 |                         'source_description': 'batch test',
565 |                         'group_id': client.test_group_id,
566 |                     },
567 |                 )
568 | 
569 |             # Wait for all to process
570 |             processed = await client.wait_for_episode_processing(
571 |                 expected_count=batch_size,
572 |                 max_wait=120,  # Allow more time for batch
573 |             )
574 | 
575 |             total_time = time.time() - start_time
576 |             avg_time_per_item = total_time / batch_size
577 | 
578 |             assert processed, f'Failed to process {batch_size} items'
579 |             assert avg_time_per_item < 15, (
580 |                 f'Batch processing too slow: {avg_time_per_item:.2f}s per item'
581 |             )
582 | 
583 |             # Generate performance report
584 |             print('\nBatch Performance Report:')
585 |             print(f'  Total items: {batch_size}')
586 |             print(f'  Total time: {total_time:.2f}s')
587 |             print(f'  Avg per item: {avg_time_per_item:.2f}s')
588 | 
589 | 
590 | class TestDatabaseBackends:
591 |     """Test different database backend configurations."""
592 | 
593 |     @pytest.mark.asyncio
594 |     @pytest.mark.parametrize('database', ['neo4j', 'falkordb'])
595 |     async def test_database_operations(self, database):
596 |         """Test operations with different database backends."""
597 |         env_vars = {
598 |             'DATABASE_PROVIDER': database,
599 |             'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY'),
600 |         }
601 | 
602 |         if database == 'neo4j':
603 |             env_vars.update(
604 |                 {
605 |                     'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
606 |                     'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
607 |                     'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
608 |                 }
609 |             )
610 |         elif database == 'falkordb':
611 |             env_vars['FALKORDB_URI'] = os.environ.get('FALKORDB_URI', 'redis://localhost:6379')
612 | 
613 |         # This test would require starting the MCP server against the selected
614 |         # backend; implementation depends on which databases are available
615 |         pass  # Placeholder for database-specific tests
616 | 
617 | 
618 | def generate_test_report(client: GraphitiTestClient) -> str:
619 |     """Generate a comprehensive test report from metrics."""
620 |     if not client.metrics:
621 |         return 'No metrics collected'
622 | 
623 |     report = []
624 |     report.append('\n' + '=' * 60)
625 |     report.append('GRAPHITI MCP TEST REPORT')
626 |     report.append('=' * 60)
627 | 
628 |     # Summary statistics
629 |     total_ops = len(client.metrics)
630 |     successful_ops = sum(1 for m in client.metrics if m.success)
631 |     avg_duration = sum(m.duration for m in client.metrics) / total_ops
632 | 
633 |     report.append(f'\nTotal Operations: {total_ops}')
634 |     report.append(f'Successful: {successful_ops} ({successful_ops / total_ops * 100:.1f}%)')
635 |     report.append(f'Average Duration: {avg_duration:.2f}s')
636 | 
637 |     # Operation breakdown
638 |     report.append('\nOperation Breakdown:')
639 |     operation_stats = {}
640 |     for metric in client.metrics:
641 |         if metric.operation not in operation_stats:
642 |             operation_stats[metric.operation] = {'count': 0, 'success': 0, 'total_duration': 0}
643 |         stats = operation_stats[metric.operation]
644 |         stats['count'] += 1
645 |         stats['success'] += 1 if metric.success else 0
646 |         stats['total_duration'] += metric.duration
647 | 
648 |     for op, stats in sorted(operation_stats.items()):
649 |         avg_dur = stats['total_duration'] / stats['count']
650 |         success_rate = stats['success'] / stats['count'] * 100
651 |         report.append(
652 |             f'  {op}: {stats["count"]} calls, {success_rate:.0f}% success, {avg_dur:.2f}s avg'
653 |         )
654 | 
655 |     # Slowest operations
656 |     slowest = sorted(client.metrics, key=lambda m: m.duration, reverse=True)[:5]
657 |     report.append('\nSlowest Operations:')
658 |     for metric in slowest:
659 |         report.append(f'  {metric.operation}: {metric.duration:.2f}s')
660 | 
661 |     report.append('=' * 60)
662 |     return '\n'.join(report)
663 | 
664 | 
665 | if __name__ == '__main__':
666 |     # Run tests with pytest
667 |     pytest.main([__file__, '-v', '--asyncio-mode=auto'])
668 | 
```
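
The stubbed `test_delete_entity_edge` in the listing above is left as a placeholder. Below is a minimal, hedged sketch of one way to implement it, reusing the `GraphitiTestClient` helper defined earlier in this file. The fact-search response shape (`{'facts': [{'uuid': ...}]}`) and the `delete_entity_edge` tool's `uuid` argument are assumptions and should be verified against the server's actual tool schema.

```python
import json

import pytest


@pytest.mark.asyncio
async def test_delete_entity_edge_sketch():
    """Sketch: create an edge, look up its UUID via fact search, delete it."""
    async with GraphitiTestClient() as client:
        # Create an entity relationship so at least one edge exists
        await client.call_tool_with_metrics(
            'add_memory',
            {
                'name': 'Edge Source',
                'episode_body': 'Carol manages the data platform team at TechCorp.',
                'source': 'text',
                'source_description': 'org notes',
                'group_id': client.test_group_id,
            },
        )
        await client.wait_for_episode_processing()

        # Assumed response shape: {'facts': [{'uuid': ...}, ...]}
        result, _ = await client.call_tool_with_metrics(
            'search_memory_facts',
            {'query': 'Carol TechCorp', 'group_id': client.test_group_id, 'limit': 1},
        )
        facts = json.loads(result) if isinstance(result, str) else result
        assert facts.get('facts'), 'expected at least one extracted fact'
        edge_uuid = facts['facts'][0]['uuid']

        # Assumed tool signature: delete_entity_edge(uuid: str)
        _, metric = await client.call_tool_with_metrics(
            'delete_entity_edge', {'uuid': edge_uuid}
        )
        assert metric.success
```

Deriving the edge UUID from a fact search keeps the test self-contained, mirroring how `test_delete_episode` resolves an episode UUID before deleting it.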
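Similarly, `generate_test_report` is defined at module level but never invoked by the tests themselves. The sketch below shows one way to surface it; the `run_and_report` entry point is illustrative only and reuses `GraphitiTestClient` and `generate_test_report` from the listing above.

```python
import asyncio


async def run_and_report() -> None:
    """Illustrative entry point: run one call, then print the metrics report."""
    async with GraphitiTestClient() as client:
        await client.call_tool_with_metrics(
            'get_episodes', {'group_id': client.test_group_id, 'last_n': 5}
        )
        # Summarize all metrics the client collected during the session
        print(generate_test_report(client))


if __name__ == '__main__':
    asyncio.run(run_and_report())
```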