tokens: 44447/50000 8/234 files (page 7/12)
This is page 7 of 12. Use http://codebase.md/getzep/graphiti?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .env.example
├── .github
│   ├── dependabot.yml
│   ├── ISSUE_TEMPLATE
│   │   └── bug_report.md
│   ├── pull_request_template.md
│   ├── secret_scanning.yml
│   └── workflows
│       ├── ai-moderator.yml
│       ├── cla.yml
│       ├── claude-code-review-manual.yml
│       ├── claude-code-review.yml
│       ├── claude.yml
│       ├── codeql.yml
│       ├── daily_issue_maintenance.yml
│       ├── issue-triage.yml
│       ├── lint.yml
│       ├── release-graphiti-core.yml
│       ├── release-mcp-server.yml
│       ├── release-server-container.yml
│       ├── typecheck.yml
│       └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│   ├── azure-openai
│   │   ├── .env.example
│   │   ├── azure_openai_neo4j.py
│   │   └── README.md
│   ├── data
│   │   └── manybirds_products.json
│   ├── ecommerce
│   │   ├── runner.ipynb
│   │   └── runner.py
│   ├── langgraph-agent
│   │   ├── agent.ipynb
│   │   └── tinybirds-jess.png
│   ├── opentelemetry
│   │   ├── .env.example
│   │   ├── otel_stdout_example.py
│   │   ├── pyproject.toml
│   │   ├── README.md
│   │   └── uv.lock
│   ├── podcast
│   │   ├── podcast_runner.py
│   │   ├── podcast_transcript.txt
│   │   └── transcript_parser.py
│   ├── quickstart
│   │   ├── quickstart_falkordb.py
│   │   ├── quickstart_neo4j.py
│   │   ├── quickstart_neptune.py
│   │   ├── README.md
│   │   └── requirements.txt
│   └── wizard_of_oz
│       ├── parser.py
│       ├── runner.py
│       └── woo.txt
├── graphiti_core
│   ├── __init__.py
│   ├── cross_encoder
│   │   ├── __init__.py
│   │   ├── bge_reranker_client.py
│   │   ├── client.py
│   │   ├── gemini_reranker_client.py
│   │   └── openai_reranker_client.py
│   ├── decorators.py
│   ├── driver
│   │   ├── __init__.py
│   │   ├── driver.py
│   │   ├── falkordb_driver.py
│   │   ├── graph_operations
│   │   │   └── graph_operations.py
│   │   ├── kuzu_driver.py
│   │   ├── neo4j_driver.py
│   │   ├── neptune_driver.py
│   │   └── search_interface
│   │       └── search_interface.py
│   ├── edges.py
│   ├── embedder
│   │   ├── __init__.py
│   │   ├── azure_openai.py
│   │   ├── client.py
│   │   ├── gemini.py
│   │   ├── openai.py
│   │   └── voyage.py
│   ├── errors.py
│   ├── graph_queries.py
│   ├── graphiti_types.py
│   ├── graphiti.py
│   ├── helpers.py
│   ├── llm_client
│   │   ├── __init__.py
│   │   ├── anthropic_client.py
│   │   ├── azure_openai_client.py
│   │   ├── client.py
│   │   ├── config.py
│   │   ├── errors.py
│   │   ├── gemini_client.py
│   │   ├── groq_client.py
│   │   ├── openai_base_client.py
│   │   ├── openai_client.py
│   │   ├── openai_generic_client.py
│   │   └── utils.py
│   ├── migrations
│   │   └── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── edges
│   │   │   ├── __init__.py
│   │   │   └── edge_db_queries.py
│   │   └── nodes
│   │       ├── __init__.py
│   │       └── node_db_queries.py
│   ├── nodes.py
│   ├── prompts
│   │   ├── __init__.py
│   │   ├── dedupe_edges.py
│   │   ├── dedupe_nodes.py
│   │   ├── eval.py
│   │   ├── extract_edge_dates.py
│   │   ├── extract_edges.py
│   │   ├── extract_nodes.py
│   │   ├── invalidate_edges.py
│   │   ├── lib.py
│   │   ├── models.py
│   │   ├── prompt_helpers.py
│   │   ├── snippets.py
│   │   └── summarize_nodes.py
│   ├── py.typed
│   ├── search
│   │   ├── __init__.py
│   │   ├── search_config_recipes.py
│   │   ├── search_config.py
│   │   ├── search_filters.py
│   │   ├── search_helpers.py
│   │   ├── search_utils.py
│   │   └── search.py
│   ├── telemetry
│   │   ├── __init__.py
│   │   └── telemetry.py
│   ├── tracer.py
│   └── utils
│       ├── __init__.py
│       ├── bulk_utils.py
│       ├── datetime_utils.py
│       ├── maintenance
│       │   ├── __init__.py
│       │   ├── community_operations.py
│       │   ├── dedup_helpers.py
│       │   ├── edge_operations.py
│       │   ├── graph_data_operations.py
│       │   ├── node_operations.py
│       │   └── temporal_operations.py
│       ├── ontology_utils
│       │   └── entity_types_utils.py
│       └── text_utils.py
├── images
│   ├── arxiv-screenshot.png
│   ├── graphiti-graph-intro.gif
│   ├── graphiti-intro-slides-stock-2.gif
│   └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│   ├── .env.example
│   ├── .python-version
│   ├── config
│   │   ├── config-docker-falkordb-combined.yaml
│   │   ├── config-docker-falkordb.yaml
│   │   ├── config-docker-neo4j.yaml
│   │   ├── config.yaml
│   │   └── mcp_config_stdio_example.json
│   ├── docker
│   │   ├── build-standalone.sh
│   │   ├── build-with-version.sh
│   │   ├── docker-compose-falkordb.yml
│   │   ├── docker-compose-neo4j.yml
│   │   ├── docker-compose.yml
│   │   ├── Dockerfile
│   │   ├── Dockerfile.standalone
│   │   ├── github-actions-example.yml
│   │   ├── README-falkordb-combined.md
│   │   └── README.md
│   ├── docs
│   │   └── cursor_rules.md
│   ├── main.py
│   ├── pyproject.toml
│   ├── pytest.ini
│   ├── README.md
│   ├── src
│   │   ├── __init__.py
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   └── schema.py
│   │   ├── graphiti_mcp_server.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── entity_types.py
│   │   │   └── response_types.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   ├── factories.py
│   │   │   └── queue_service.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── formatting.py
│   │       └── utils.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── conftest.py
│   │   ├── pytest.ini
│   │   ├── README.md
│   │   ├── run_tests.py
│   │   ├── test_async_operations.py
│   │   ├── test_comprehensive_integration.py
│   │   ├── test_configuration.py
│   │   ├── test_falkordb_integration.py
│   │   ├── test_fixtures.py
│   │   ├── test_http_integration.py
│   │   ├── test_integration.py
│   │   ├── test_mcp_integration.py
│   │   ├── test_mcp_transports.py
│   │   ├── test_stdio_simple.py
│   │   └── test_stress_load.py
│   └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│   ├── .env.example
│   ├── graph_service
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── dto
│   │   │   ├── __init__.py
│   │   │   ├── common.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   ├── main.py
│   │   ├── routers
│   │   │   ├── __init__.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   └── zep_graphiti.py
│   ├── Makefile
│   ├── pyproject.toml
│   ├── README.md
│   └── uv.lock
├── signatures
│   └── version1
│       └── cla.json
├── tests
│   ├── cross_encoder
│   │   ├── test_bge_reranker_client_int.py
│   │   └── test_gemini_reranker_client.py
│   ├── driver
│   │   ├── __init__.py
│   │   └── test_falkordb_driver.py
│   ├── embedder
│   │   ├── embedder_fixtures.py
│   │   ├── test_gemini.py
│   │   ├── test_openai.py
│   │   └── test_voyage.py
│   ├── evals
│   │   ├── data
│   │   │   └── longmemeval_data
│   │   │       ├── longmemeval_oracle.json
│   │   │       └── README.md
│   │   ├── eval_cli.py
│   │   ├── eval_e2e_graph_building.py
│   │   ├── pytest.ini
│   │   └── utils.py
│   ├── helpers_test.py
│   ├── llm_client
│   │   ├── test_anthropic_client_int.py
│   │   ├── test_anthropic_client.py
│   │   ├── test_azure_openai_client.py
│   │   ├── test_client.py
│   │   ├── test_errors.py
│   │   └── test_gemini_client.py
│   ├── test_edge_int.py
│   ├── test_entity_exclusion_int.py
│   ├── test_graphiti_int.py
│   ├── test_graphiti_mock.py
│   ├── test_node_int.py
│   ├── test_text_utils.py
│   └── utils
│       ├── maintenance
│       │   ├── test_bulk_utils.py
│       │   ├── test_edge_operations.py
│       │   ├── test_node_operations.py
│       │   └── test_temporal_operations_int.py
│       └── search
│           └── search_utils_test.py
├── uv.lock
└── Zep-CLA.md
```

# Files

--------------------------------------------------------------------------------
/mcp_server/tests/test_async_operations.py:
--------------------------------------------------------------------------------

```python
  1 | #!/usr/bin/env python3
  2 | """
  3 | Asynchronous operation tests for Graphiti MCP Server.
  4 | Tests concurrent operations, queue management, and async patterns.
  5 | """
  6 | 
  7 | import asyncio
  8 | import contextlib
  9 | import json
 10 | import time
 11 | 
 12 | import pytest
 13 | from test_fixtures import (
 14 |     TestDataGenerator,
 15 |     graphiti_test_client,
 16 | )
 17 | 
 18 | 
 19 | class TestAsyncQueueManagement:
 20 |     """Test asynchronous queue operations and episode processing."""
 21 | 
 22 |     @pytest.mark.asyncio
 23 |     async def test_sequential_queue_processing(self):
 24 |         """Verify episodes are processed sequentially within a group."""
 25 |         async with graphiti_test_client() as (session, group_id):
 26 |             # Add multiple episodes quickly
 27 |             episodes = []
 28 |             for i in range(5):
 29 |                 result = await session.call_tool(
 30 |                     'add_memory',
 31 |                     {
 32 |                         'name': f'Sequential Test {i}',
 33 |                         'episode_body': f'Episode {i} with timestamp {time.time()}',
 34 |                         'source': 'text',
 35 |                         'source_description': 'sequential test',
 36 |                         'group_id': group_id,
 37 |                         'reference_id': f'seq_{i}',  # Add reference for tracking
 38 |                     },
 39 |                 )
 40 |                 episodes.append(result)
 41 | 
 42 |             # Wait for processing
 43 |             await asyncio.sleep(10)  # Allow time for sequential processing
 44 | 
 45 |             # Retrieve episodes and verify order
 46 |             result = await session.call_tool('get_episodes', {'group_id': group_id, 'last_n': 10})
 47 | 
 48 |             processed_episodes = json.loads(result.content[0].text)['episodes']
 49 | 
 50 |             # Verify all episodes were processed
 51 |             assert len(processed_episodes) >= 5, (
 52 |                 f'Expected at least 5 episodes, got {len(processed_episodes)}'
 53 |             )
 54 | 
 55 |             # Verify sequential processing (timestamps should be ordered)
 56 |             timestamps = [ep.get('created_at') for ep in processed_episodes]
 57 |             assert timestamps == sorted(timestamps), 'Episodes not processed in order'
 58 | 
 59 |     @pytest.mark.asyncio
 60 |     async def test_concurrent_group_processing(self):
 61 |         """Test that different groups can process concurrently."""
 62 |         async with graphiti_test_client() as (session, _):
 63 |             groups = [f'group_{i}_{time.time()}' for i in range(3)]
 64 |             tasks = []
 65 | 
 66 |             # Create tasks for different groups
 67 |             for group_id in groups:
 68 |                 for j in range(2):
 69 |                     task = session.call_tool(
 70 |                         'add_memory',
 71 |                         {
 72 |                             'name': f'Group {group_id} Episode {j}',
 73 |                             'episode_body': f'Content for {group_id}',
 74 |                             'source': 'text',
 75 |                             'source_description': 'concurrent test',
 76 |                             'group_id': group_id,
 77 |                         },
 78 |                     )
 79 |                     tasks.append(task)
 80 | 
 81 |             # Execute all tasks concurrently
 82 |             start_time = time.time()
 83 |             results = await asyncio.gather(*tasks, return_exceptions=True)
 84 |             execution_time = time.time() - start_time
 85 | 
 86 |             # Verify all succeeded
 87 |             failures = [r for r in results if isinstance(r, Exception)]
 88 |             assert not failures, f'Concurrent operations failed: {failures}'
 89 | 
 90 |             # Check that execution was actually concurrent (should be faster than sequential)
 91 |             # Sequential would take at least 6 * processing_time
 92 |             assert execution_time < 30, f'Concurrent execution too slow: {execution_time}s'
 93 | 
 94 |     @pytest.mark.asyncio
 95 |     async def test_queue_overflow_handling(self):
 96 |         """Test behavior when queue reaches capacity."""
 97 |         async with graphiti_test_client() as (session, group_id):
 98 |             # Attempt to add many episodes rapidly
 99 |             tasks = []
100 |             for i in range(100):  # Large number to potentially overflow
101 |                 task = session.call_tool(
102 |                     'add_memory',
103 |                     {
104 |                         'name': f'Overflow Test {i}',
105 |                         'episode_body': f'Episode {i}',
106 |                         'source': 'text',
107 |                         'source_description': 'overflow test',
108 |                         'group_id': group_id,
109 |                     },
110 |                 )
111 |                 tasks.append(task)
112 | 
113 |             # Execute with gathering to catch any failures
114 |             results = await asyncio.gather(*tasks, return_exceptions=True)
115 | 
116 |             # Count successful queuing
117 |             successful = sum(1 for r in results if not isinstance(r, Exception))
118 | 
119 |             # Should handle overflow gracefully
120 |             assert successful > 0, 'No episodes were queued successfully'
121 | 
122 |             # Log overflow behavior
123 |             if successful < 100:
124 |                 print(f'Queue overflow: {successful}/100 episodes queued')
125 | 
126 | 
127 | class TestConcurrentOperations:
128 |     """Test concurrent tool calls and operations."""
129 | 
130 |     @pytest.mark.asyncio
131 |     async def test_concurrent_search_operations(self):
132 |         """Test multiple concurrent search operations."""
133 |         async with graphiti_test_client() as (session, group_id):
134 |             # First, add some test data
135 |             data_gen = TestDataGenerator()
136 | 
137 |             add_tasks = []
138 |             for _ in range(5):
139 |                 task = session.call_tool(
140 |                     'add_memory',
141 |                     {
142 |                         'name': 'Search Test Data',
143 |                         'episode_body': data_gen.generate_technical_document(),
144 |                         'source': 'text',
145 |                         'source_description': 'search test',
146 |                         'group_id': group_id,
147 |                     },
148 |                 )
149 |                 add_tasks.append(task)
150 | 
151 |             await asyncio.gather(*add_tasks)
152 |             await asyncio.sleep(15)  # Wait for processing
153 | 
154 |             # Now perform concurrent searches
155 |             search_queries = [
156 |                 'architecture',
157 |                 'performance',
158 |                 'implementation',
159 |                 'dependencies',
160 |                 'latency',
161 |             ]
162 | 
163 |             search_tasks = []
164 |             for query in search_queries:
165 |                 task = session.call_tool(
166 |                     'search_memory_nodes',
167 |                     {
168 |                         'query': query,
169 |                         'group_id': group_id,
170 |                         'limit': 10,
171 |                     },
172 |                 )
173 |                 search_tasks.append(task)
174 | 
175 |             start_time = time.time()
176 |             results = await asyncio.gather(*search_tasks, return_exceptions=True)
177 |             search_time = time.time() - start_time
178 | 
179 |             # Verify all searches completed
180 |             failures = [r for r in results if isinstance(r, Exception)]
181 |             assert not failures, f'Search operations failed: {failures}'
182 | 
183 |             # Verify concurrent execution efficiency
184 |             assert search_time < len(search_queries) * 2, 'Searches not executing concurrently'
185 | 
186 |     @pytest.mark.asyncio
187 |     async def test_mixed_operation_concurrency(self):
188 |         """Test different types of operations running concurrently."""
189 |         async with graphiti_test_client() as (session, group_id):
190 |             operations = []
191 | 
192 |             # Add memory operation
193 |             operations.append(
194 |                 session.call_tool(
195 |                     'add_memory',
196 |                     {
197 |                         'name': 'Mixed Op Test',
198 |                         'episode_body': 'Testing mixed operations',
199 |                         'source': 'text',
200 |                         'source_description': 'test',
201 |                         'group_id': group_id,
202 |                     },
203 |                 )
204 |             )
205 | 
206 |             # Search operation
207 |             operations.append(
208 |                 session.call_tool(
209 |                     'search_memory_nodes',
210 |                     {
211 |                         'query': 'test',
212 |                         'group_id': group_id,
213 |                         'limit': 5,
214 |                     },
215 |                 )
216 |             )
217 | 
218 |             # Get episodes operation
219 |             operations.append(
220 |                 session.call_tool(
221 |                     'get_episodes',
222 |                     {
223 |                         'group_id': group_id,
224 |                         'last_n': 10,
225 |                     },
226 |                 )
227 |             )
228 | 
229 |             # Get status operation
230 |             operations.append(session.call_tool('get_status', {}))
231 | 
232 |             # Execute all concurrently
233 |             results = await asyncio.gather(*operations, return_exceptions=True)
234 | 
235 |             # Check results
236 |             for i, result in enumerate(results):
237 |                 assert not isinstance(result, Exception), f'Operation {i} failed: {result}'
238 | 
239 | 
240 | class TestAsyncErrorHandling:
241 |     """Test async error handling and recovery."""
242 | 
243 |     @pytest.mark.asyncio
244 |     async def test_timeout_recovery(self):
245 |         """Test recovery from operation timeouts."""
246 |         async with graphiti_test_client() as (session, group_id):
247 |             # Create a very large episode that might time out
248 |             large_content = 'x' * 1000000  # 1MB of data
249 | 
250 |             with contextlib.suppress(asyncio.TimeoutError):
251 |                 await asyncio.wait_for(
252 |                     session.call_tool(
253 |                         'add_memory',
254 |                         {
255 |                             'name': 'Timeout Test',
256 |                             'episode_body': large_content,
257 |                             'source': 'text',
258 |                             'source_description': 'timeout test',
259 |                             'group_id': group_id,
260 |                         },
261 |                     ),
262 |                     timeout=2.0,  # Short timeout - expected to timeout
263 |                 )
264 | 
265 |             # Verify server is still responsive after timeout
266 |             status_result = await session.call_tool('get_status', {})
267 |             assert status_result is not None, 'Server unresponsive after timeout'
268 | 
269 |     @pytest.mark.asyncio
270 |     async def test_cancellation_handling(self):
271 |         """Test proper handling of cancelled operations."""
272 |         async with graphiti_test_client() as (session, group_id):
273 |             # Start a long-running operation
274 |             task = asyncio.create_task(
275 |                 session.call_tool(
276 |                     'add_memory',
277 |                     {
278 |                         'name': 'Cancellation Test',
279 |                         'episode_body': TestDataGenerator.generate_technical_document(),
280 |                         'source': 'text',
281 |                         'source_description': 'cancel test',
282 |                         'group_id': group_id,
283 |                     },
284 |                 )
285 |             )
286 | 
287 |             # Cancel after a short delay
288 |             await asyncio.sleep(0.1)
289 |             task.cancel()
290 | 
291 |             # Verify cancellation was handled
292 |             with pytest.raises(asyncio.CancelledError):
293 |                 await task
294 | 
295 |             # Server should still be operational
296 |             result = await session.call_tool('get_status', {})
297 |             assert result is not None
298 | 
299 |     @pytest.mark.asyncio
300 |     async def test_exception_propagation(self):
301 |         """Test that exceptions are properly propagated in async context."""
302 |         async with graphiti_test_client() as (session, group_id):
303 |             # Call with invalid arguments
304 |             with pytest.raises(ValueError):
305 |                 await session.call_tool(
306 |                     'add_memory',
307 |                     {
308 |                         # Missing required fields
309 |                         'group_id': group_id,
310 |                     },
311 |                 )
312 | 
313 |             # Server should remain operational
314 |             status = await session.call_tool('get_status', {})
315 |             assert status is not None
316 | 
317 | 
318 | class TestAsyncPerformance:
319 |     """Performance tests for async operations."""
320 | 
321 |     @pytest.mark.asyncio
322 |     async def test_async_throughput(self, performance_benchmark):
323 |         """Measure throughput of async operations."""
324 |         async with graphiti_test_client() as (session, group_id):
325 |             num_operations = 50
326 |             start_time = time.time()
327 | 
328 |             # Create many concurrent operations
329 |             tasks = []
330 |             for i in range(num_operations):
331 |                 task = session.call_tool(
332 |                     'add_memory',
333 |                     {
334 |                         'name': f'Throughput Test {i}',
335 |                         'episode_body': f'Content {i}',
336 |                         'source': 'text',
337 |                         'source_description': 'throughput test',
338 |                         'group_id': group_id,
339 |                     },
340 |                 )
341 |                 tasks.append(task)
342 | 
343 |             # Execute all
344 |             results = await asyncio.gather(*tasks, return_exceptions=True)
345 |             total_time = time.time() - start_time
346 | 
347 |             # Calculate metrics
348 |             successful = sum(1 for r in results if not isinstance(r, Exception))
349 |             throughput = successful / total_time
350 | 
351 |             performance_benchmark.record('async_throughput', throughput)
352 | 
353 |             # Log results
354 |             print('\nAsync Throughput Test:')
355 |             print(f'  Operations: {num_operations}')
356 |             print(f'  Successful: {successful}')
357 |             print(f'  Total time: {total_time:.2f}s')
358 |             print(f'  Throughput: {throughput:.2f} ops/s')
359 | 
360 |             # Assert minimum throughput
361 |             assert throughput > 1.0, f'Throughput too low: {throughput:.2f} ops/s'
362 | 
363 |     @pytest.mark.asyncio
364 |     async def test_latency_under_load(self, performance_benchmark):
365 |         """Test operation latency under concurrent load."""
366 |         async with graphiti_test_client() as (session, group_id):
367 |             # Create background load
368 |             background_tasks = []
369 |             for i in range(10):
370 |                 task = asyncio.create_task(
371 |                     session.call_tool(
372 |                         'add_memory',
373 |                         {
374 |                             'name': f'Background {i}',
375 |                             'episode_body': TestDataGenerator.generate_technical_document(),
376 |                             'source': 'text',
377 |                             'source_description': 'background',
378 |                             'group_id': f'background_{group_id}',
379 |                         },
380 |                     )
381 |                 )
382 |                 background_tasks.append(task)
383 | 
384 |             # Measure latency of operations under load
385 |             latencies = []
386 |             for _ in range(5):
387 |                 start = time.time()
388 |                 await session.call_tool('get_status', {})
389 |                 latency = time.time() - start
390 |                 latencies.append(latency)
391 |                 performance_benchmark.record('latency_under_load', latency)
392 | 
393 |             # Clean up background tasks
394 |             for task in background_tasks:
395 |                 task.cancel()
396 | 
397 |             # Analyze latencies
398 |             avg_latency = sum(latencies) / len(latencies)
399 |             max_latency = max(latencies)
400 | 
401 |             print('\nLatency Under Load:')
402 |             print(f'  Average: {avg_latency:.3f}s')
403 |             print(f'  Max: {max_latency:.3f}s')
404 | 
405 |             # Assert acceptable latency
406 |             assert avg_latency < 2.0, f'Average latency too high: {avg_latency:.3f}s'
407 |             assert max_latency < 5.0, f'Max latency too high: {max_latency:.3f}s'
408 | 
409 | 
410 | class TestAsyncStreamHandling:
411 |     """Test handling of streaming responses and data."""
412 | 
413 |     @pytest.mark.asyncio
414 |     async def test_large_response_streaming(self):
415 |         """Test handling of large streamed responses."""
416 |         async with graphiti_test_client() as (session, group_id):
417 |             # Add many episodes
418 |             for i in range(20):
419 |                 await session.call_tool(
420 |                     'add_memory',
421 |                     {
422 |                         'name': f'Stream Test {i}',
423 |                         'episode_body': f'Episode content {i}',
424 |                         'source': 'text',
425 |                         'source_description': 'stream test',
426 |                         'group_id': group_id,
427 |                     },
428 |                 )
429 | 
430 |             # Wait for processing
431 |             await asyncio.sleep(30)
432 | 
433 |             # Request large result set
434 |             result = await session.call_tool(
435 |                 'get_episodes',
436 |                 {
437 |                     'group_id': group_id,
438 |                     'last_n': 100,  # Request all
439 |                 },
440 |             )
441 | 
442 |             # Verify response handling
443 |             episodes = json.loads(result.content[0].text)['episodes']
444 |             assert len(episodes) >= 20, f'Expected at least 20 episodes, got {len(episodes)}'
445 | 
446 |     @pytest.mark.asyncio
447 |     async def test_incremental_processing(self):
448 |         """Test incremental processing of results."""
449 |         async with graphiti_test_client() as (session, group_id):
450 |             # Add episodes incrementally
451 |             for batch in range(3):
452 |                 batch_tasks = []
453 |                 for i in range(5):
454 |                     task = session.call_tool(
455 |                         'add_memory',
456 |                         {
457 |                             'name': f'Batch {batch} Item {i}',
458 |                             'episode_body': f'Content for batch {batch}',
459 |                             'source': 'text',
460 |                             'source_description': 'incremental test',
461 |                             'group_id': group_id,
462 |                         },
463 |                     )
464 |                     batch_tasks.append(task)
465 | 
466 |                 # Process batch
467 |                 await asyncio.gather(*batch_tasks)
468 | 
469 |                 # Wait for this batch to process
470 |                 await asyncio.sleep(10)
471 | 
472 |                 # Verify incremental results
473 |                 result = await session.call_tool(
474 |                     'get_episodes',
475 |                     {
476 |                         'group_id': group_id,
477 |                         'last_n': 100,
478 |                     },
479 |                 )
480 | 
481 |                 episodes = json.loads(result.content[0].text)['episodes']
482 |                 expected_min = (batch + 1) * 5
483 |                 assert len(episodes) >= expected_min, (
484 |                     f'Batch {batch}: Expected at least {expected_min} episodes'
485 |                 )
486 | 
487 | 
488 | if __name__ == '__main__':
489 |     pytest.main([__file__, '-v', '--asyncio-mode=auto'])
490 | 
```
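
The tests above drive the server through `graphiti_test_client`, an async context manager imported from `test_fixtures.py` (not included on this page). As a rough orientation, the sketch below shows one plausible shape for that fixture, assuming it wraps the stdio transport the same way `test_mcp_integration.py` does later on this page; the command, arguments, and structure are assumptions, not the repository's actual implementation.

```python
# Hedged sketch of a graphiti_test_client-style fixture (assumed, not the
# real mcp_server/tests/test_fixtures.py implementation).
import time
from contextlib import asynccontextmanager

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client


@asynccontextmanager
async def graphiti_test_client():
    """Yield an initialized MCP ClientSession plus a unique test group_id."""
    group_id = f'test_group_{int(time.time())}'  # isolate each test run
    server_params = StdioServerParameters(
        command='uv',
        args=['run', 'main.py', '--transport', 'stdio'],
    )
    async with stdio_client(server_params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            yield session, group_id
```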

--------------------------------------------------------------------------------
/graphiti_core/search/search.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import logging
 18 | from collections import defaultdict
 19 | from time import time
 20 | 
 21 | from graphiti_core.cross_encoder.client import CrossEncoderClient
 22 | from graphiti_core.driver.driver import GraphDriver
 23 | from graphiti_core.edges import EntityEdge
 24 | from graphiti_core.embedder.client import EMBEDDING_DIM
 25 | from graphiti_core.errors import SearchRerankerError
 26 | from graphiti_core.graphiti_types import GraphitiClients
 27 | from graphiti_core.helpers import semaphore_gather
 28 | from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
 29 | from graphiti_core.search.search_config import (
 30 |     DEFAULT_SEARCH_LIMIT,
 31 |     CommunityReranker,
 32 |     CommunitySearchConfig,
 33 |     CommunitySearchMethod,
 34 |     EdgeReranker,
 35 |     EdgeSearchConfig,
 36 |     EdgeSearchMethod,
 37 |     EpisodeReranker,
 38 |     EpisodeSearchConfig,
 39 |     NodeReranker,
 40 |     NodeSearchConfig,
 41 |     NodeSearchMethod,
 42 |     SearchConfig,
 43 |     SearchResults,
 44 | )
 45 | from graphiti_core.search.search_filters import SearchFilters
 46 | from graphiti_core.search.search_utils import (
 47 |     community_fulltext_search,
 48 |     community_similarity_search,
 49 |     edge_bfs_search,
 50 |     edge_fulltext_search,
 51 |     edge_similarity_search,
 52 |     episode_fulltext_search,
 53 |     episode_mentions_reranker,
 54 |     get_embeddings_for_communities,
 55 |     get_embeddings_for_edges,
 56 |     get_embeddings_for_nodes,
 57 |     maximal_marginal_relevance,
 58 |     node_bfs_search,
 59 |     node_distance_reranker,
 60 |     node_fulltext_search,
 61 |     node_similarity_search,
 62 |     rrf,
 63 | )
 64 | 
 65 | logger = logging.getLogger(__name__)
 66 | 
 67 | 
 68 | async def search(
 69 |     clients: GraphitiClients,
 70 |     query: str,
 71 |     group_ids: list[str] | None,
 72 |     config: SearchConfig,
 73 |     search_filter: SearchFilters,
 74 |     center_node_uuid: str | None = None,
 75 |     bfs_origin_node_uuids: list[str] | None = None,
 76 |     query_vector: list[float] | None = None,
 77 |     driver: GraphDriver | None = None,
 78 | ) -> SearchResults:
 79 |     start = time()
 80 | 
 81 |     driver = driver or clients.driver
 82 |     embedder = clients.embedder
 83 |     cross_encoder = clients.cross_encoder
 84 | 
 85 |     if query.strip() == '':
 86 |         return SearchResults()
 87 | 
 88 |     if (
 89 |         config.edge_config
 90 |         and EdgeSearchMethod.cosine_similarity in config.edge_config.search_methods
 91 |         or config.edge_config
 92 |         and EdgeReranker.mmr == config.edge_config.reranker
 93 |         or config.node_config
 94 |         and NodeSearchMethod.cosine_similarity in config.node_config.search_methods
 95 |         or config.node_config
 96 |         and NodeReranker.mmr == config.node_config.reranker
 97 |         or (
 98 |             config.community_config
 99 |             and CommunitySearchMethod.cosine_similarity in config.community_config.search_methods
100 |         )
101 |         or (config.community_config and CommunityReranker.mmr == config.community_config.reranker)
102 |     ):
103 |         search_vector = (
104 |             query_vector
105 |             if query_vector is not None
106 |             else await embedder.create(input_data=[query.replace('\n', ' ')])
107 |         )
108 |     else:
109 |         search_vector = [0.0] * EMBEDDING_DIM
110 | 
111 |     # if group_ids is empty, set it to None
112 |     group_ids = group_ids if group_ids and group_ids != [''] else None
113 |     (
114 |         (edges, edge_reranker_scores),
115 |         (nodes, node_reranker_scores),
116 |         (episodes, episode_reranker_scores),
117 |         (communities, community_reranker_scores),
118 |     ) = await semaphore_gather(
119 |         edge_search(
120 |             driver,
121 |             cross_encoder,
122 |             query,
123 |             search_vector,
124 |             group_ids,
125 |             config.edge_config,
126 |             search_filter,
127 |             center_node_uuid,
128 |             bfs_origin_node_uuids,
129 |             config.limit,
130 |             config.reranker_min_score,
131 |         ),
132 |         node_search(
133 |             driver,
134 |             cross_encoder,
135 |             query,
136 |             search_vector,
137 |             group_ids,
138 |             config.node_config,
139 |             search_filter,
140 |             center_node_uuid,
141 |             bfs_origin_node_uuids,
142 |             config.limit,
143 |             config.reranker_min_score,
144 |         ),
145 |         episode_search(
146 |             driver,
147 |             cross_encoder,
148 |             query,
149 |             search_vector,
150 |             group_ids,
151 |             config.episode_config,
152 |             search_filter,
153 |             config.limit,
154 |             config.reranker_min_score,
155 |         ),
156 |         community_search(
157 |             driver,
158 |             cross_encoder,
159 |             query,
160 |             search_vector,
161 |             group_ids,
162 |             config.community_config,
163 |             config.limit,
164 |             config.reranker_min_score,
165 |         ),
166 |     )
167 | 
168 |     results = SearchResults(
169 |         edges=edges,
170 |         edge_reranker_scores=edge_reranker_scores,
171 |         nodes=nodes,
172 |         node_reranker_scores=node_reranker_scores,
173 |         episodes=episodes,
174 |         episode_reranker_scores=episode_reranker_scores,
175 |         communities=communities,
176 |         community_reranker_scores=community_reranker_scores,
177 |     )
178 | 
179 |     latency = (time() - start) * 1000
180 | 
181 |     logger.debug(f'search returned context for query {query} in {latency} ms')
182 | 
183 |     return results
184 | 
185 | 
186 | async def edge_search(
187 |     driver: GraphDriver,
188 |     cross_encoder: CrossEncoderClient,
189 |     query: str,
190 |     query_vector: list[float],
191 |     group_ids: list[str] | None,
192 |     config: EdgeSearchConfig | None,
193 |     search_filter: SearchFilters,
194 |     center_node_uuid: str | None = None,
195 |     bfs_origin_node_uuids: list[str] | None = None,
196 |     limit=DEFAULT_SEARCH_LIMIT,
197 |     reranker_min_score: float = 0,
198 | ) -> tuple[list[EntityEdge], list[float]]:
199 |     if config is None:
200 |         return [], []
201 | 
202 |     # Build search tasks based on configured search methods
203 |     search_tasks = []
204 |     if EdgeSearchMethod.bm25 in config.search_methods:
205 |         search_tasks.append(
206 |             edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
207 |         )
208 |     if EdgeSearchMethod.cosine_similarity in config.search_methods:
209 |         search_tasks.append(
210 |             edge_similarity_search(
211 |                 driver,
212 |                 query_vector,
213 |                 None,
214 |                 None,
215 |                 search_filter,
216 |                 group_ids,
217 |                 2 * limit,
218 |                 config.sim_min_score,
219 |             )
220 |         )
221 |     if EdgeSearchMethod.bfs in config.search_methods:
222 |         search_tasks.append(
223 |             edge_bfs_search(
224 |                 driver,
225 |                 bfs_origin_node_uuids,
226 |                 config.bfs_max_depth,
227 |                 search_filter,
228 |                 group_ids,
229 |                 2 * limit,
230 |             )
231 |         )
232 | 
233 |     # Execute only the configured search methods
234 |     search_results: list[list[EntityEdge]] = []
235 |     if search_tasks:
236 |         search_results = list(await semaphore_gather(*search_tasks))
237 | 
238 |     if EdgeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
239 |         source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
240 |         search_results.append(
241 |             await edge_bfs_search(
242 |                 driver,
243 |                 source_node_uuids,
244 |                 config.bfs_max_depth,
245 |                 search_filter,
246 |                 group_ids,
247 |                 2 * limit,
248 |             )
249 |         )
250 | 
251 |     edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
252 | 
253 |     reranked_uuids: list[str] = []
254 |     edge_scores: list[float] = []
255 |     if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
256 |         search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
257 | 
258 |         reranked_uuids, edge_scores = rrf(search_result_uuids, min_score=reranker_min_score)
259 |     elif config.reranker == EdgeReranker.mmr:
260 |         search_result_uuids_and_vectors = await get_embeddings_for_edges(
261 |             driver, list(edge_uuid_map.values())
262 |         )
263 |         reranked_uuids, edge_scores = maximal_marginal_relevance(
264 |             query_vector,
265 |             search_result_uuids_and_vectors,
266 |             config.mmr_lambda,
267 |             reranker_min_score,
268 |         )
269 |     elif config.reranker == EdgeReranker.cross_encoder:
270 |         fact_to_uuid_map = {edge.fact: edge.uuid for edge in list(edge_uuid_map.values())[:limit]}
271 |         reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
272 |         reranked_uuids = [
273 |             fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
274 |         ]
275 |         edge_scores = [score for _, score in reranked_facts if score >= reranker_min_score]
276 |     elif config.reranker == EdgeReranker.node_distance:
277 |         if center_node_uuid is None:
278 |             raise SearchRerankerError('No center node provided for Node Distance reranker')
279 | 
280 |         # use rrf as a preliminary sort
281 |         sorted_result_uuids, node_scores = rrf(
282 |             [[edge.uuid for edge in result] for result in search_results],
283 |             min_score=reranker_min_score,
284 |         )
285 |         sorted_results = [edge_uuid_map[uuid] for uuid in sorted_result_uuids]
286 | 
287 |         # node distance reranking
288 |         source_to_edge_uuid_map = defaultdict(list)
289 |         for edge in sorted_results:
290 |             source_to_edge_uuid_map[edge.source_node_uuid].append(edge.uuid)
291 | 
292 |         source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]
293 | 
294 |         reranked_node_uuids, edge_scores = await node_distance_reranker(
295 |             driver, source_uuids, center_node_uuid, min_score=reranker_min_score
296 |         )
297 | 
298 |         for node_uuid in reranked_node_uuids:
299 |             reranked_uuids.extend(source_to_edge_uuid_map[node_uuid])
300 | 
301 |     reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
302 | 
303 |     if config.reranker == EdgeReranker.episode_mentions:
304 |         reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))
305 | 
306 |     return reranked_edges[:limit], edge_scores[:limit]
307 | 
308 | 
309 | async def node_search(
310 |     driver: GraphDriver,
311 |     cross_encoder: CrossEncoderClient,
312 |     query: str,
313 |     query_vector: list[float],
314 |     group_ids: list[str] | None,
315 |     config: NodeSearchConfig | None,
316 |     search_filter: SearchFilters,
317 |     center_node_uuid: str | None = None,
318 |     bfs_origin_node_uuids: list[str] | None = None,
319 |     limit=DEFAULT_SEARCH_LIMIT,
320 |     reranker_min_score: float = 0,
321 | ) -> tuple[list[EntityNode], list[float]]:
322 |     if config is None:
323 |         return [], []
324 | 
325 |     # Build search tasks based on configured search methods
326 |     search_tasks = []
327 |     if NodeSearchMethod.bm25 in config.search_methods:
328 |         search_tasks.append(
329 |             node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
330 |         )
331 |     if NodeSearchMethod.cosine_similarity in config.search_methods:
332 |         search_tasks.append(
333 |             node_similarity_search(
334 |                 driver,
335 |                 query_vector,
336 |                 search_filter,
337 |                 group_ids,
338 |                 2 * limit,
339 |                 config.sim_min_score,
340 |             )
341 |         )
342 |     if NodeSearchMethod.bfs in config.search_methods:
343 |         search_tasks.append(
344 |             node_bfs_search(
345 |                 driver,
346 |                 bfs_origin_node_uuids,
347 |                 search_filter,
348 |                 config.bfs_max_depth,
349 |                 group_ids,
350 |                 2 * limit,
351 |             )
352 |         )
353 | 
354 |     # Execute only the configured search methods
355 |     search_results: list[list[EntityNode]] = []
356 |     if search_tasks:
357 |         search_results = list(await semaphore_gather(*search_tasks))
358 | 
359 |     if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
360 |         origin_node_uuids = [node.uuid for result in search_results for node in result]
361 |         search_results.append(
362 |             await node_bfs_search(
363 |                 driver,
364 |                 origin_node_uuids,
365 |                 search_filter,
366 |                 config.bfs_max_depth,
367 |                 group_ids,
368 |                 2 * limit,
369 |             )
370 |         )
371 | 
372 |     search_result_uuids = [[node.uuid for node in result] for result in search_results]
373 |     node_uuid_map = {node.uuid: node for result in search_results for node in result}
374 | 
375 |     reranked_uuids: list[str] = []
376 |     node_scores: list[float] = []
377 |     if config.reranker == NodeReranker.rrf:
378 |         reranked_uuids, node_scores = rrf(search_result_uuids, min_score=reranker_min_score)
379 |     elif config.reranker == NodeReranker.mmr:
380 |         search_result_uuids_and_vectors = await get_embeddings_for_nodes(
381 |             driver, list(node_uuid_map.values())
382 |         )
383 | 
384 |         reranked_uuids, node_scores = maximal_marginal_relevance(
385 |             query_vector,
386 |             search_result_uuids_and_vectors,
387 |             config.mmr_lambda,
388 |             reranker_min_score,
389 |         )
390 |     elif config.reranker == NodeReranker.cross_encoder:
391 |         name_to_uuid_map = {node.name: node.uuid for node in list(node_uuid_map.values())}
392 | 
393 |         reranked_node_names = await cross_encoder.rank(query, list(name_to_uuid_map.keys()))
394 |         reranked_uuids = [
395 |             name_to_uuid_map[name]
396 |             for name, score in reranked_node_names
397 |             if score >= reranker_min_score
398 |         ]
399 |         node_scores = [score for _, score in reranked_node_names if score >= reranker_min_score]
400 |     elif config.reranker == NodeReranker.episode_mentions:
401 |         reranked_uuids, node_scores = await episode_mentions_reranker(
402 |             driver, search_result_uuids, min_score=reranker_min_score
403 |         )
404 |     elif config.reranker == NodeReranker.node_distance:
405 |         if center_node_uuid is None:
406 |             raise SearchRerankerError('No center node provided for Node Distance reranker')
407 |         reranked_uuids, node_scores = await node_distance_reranker(
408 |             driver,
409 |             rrf(search_result_uuids, min_score=reranker_min_score)[0],
410 |             center_node_uuid,
411 |             min_score=reranker_min_score,
412 |         )
413 | 
414 |     reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
415 | 
416 |     return reranked_nodes[:limit], node_scores[:limit]
417 | 
418 | 
419 | async def episode_search(
420 |     driver: GraphDriver,
421 |     cross_encoder: CrossEncoderClient,
422 |     query: str,
423 |     _query_vector: list[float],
424 |     group_ids: list[str] | None,
425 |     config: EpisodeSearchConfig | None,
426 |     search_filter: SearchFilters,
427 |     limit=DEFAULT_SEARCH_LIMIT,
428 |     reranker_min_score: float = 0,
429 | ) -> tuple[list[EpisodicNode], list[float]]:
430 |     if config is None:
431 |         return [], []
432 |     search_results: list[list[EpisodicNode]] = list(
433 |         await semaphore_gather(
434 |             *[
435 |                 episode_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
436 |             ]
437 |         )
438 |     )
439 | 
440 |     search_result_uuids = [[episode.uuid for episode in result] for result in search_results]
441 |     episode_uuid_map = {episode.uuid: episode for result in search_results for episode in result}
442 | 
443 |     reranked_uuids: list[str] = []
444 |     episode_scores: list[float] = []
445 |     if config.reranker == EpisodeReranker.rrf:
446 |         reranked_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
447 | 
448 |     elif config.reranker == EpisodeReranker.cross_encoder:
449 |         # use rrf as a preliminary reranker
450 |         rrf_result_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
451 |         rrf_results = [episode_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
452 | 
453 |         content_to_uuid_map = {episode.content: episode.uuid for episode in rrf_results}
454 | 
455 |         reranked_contents = await cross_encoder.rank(query, list(content_to_uuid_map.keys()))
456 |         reranked_uuids = [
457 |             content_to_uuid_map[content]
458 |             for content, score in reranked_contents
459 |             if score >= reranker_min_score
460 |         ]
461 |         episode_scores = [score for _, score in reranked_contents if score >= reranker_min_score]
462 | 
463 |     reranked_episodes = [episode_uuid_map[uuid] for uuid in reranked_uuids]
464 | 
465 |     return reranked_episodes[:limit], episode_scores[:limit]
466 | 
467 | 
468 | async def community_search(
469 |     driver: GraphDriver,
470 |     cross_encoder: CrossEncoderClient,
471 |     query: str,
472 |     query_vector: list[float],
473 |     group_ids: list[str] | None,
474 |     config: CommunitySearchConfig | None,
475 |     limit=DEFAULT_SEARCH_LIMIT,
476 |     reranker_min_score: float = 0,
477 | ) -> tuple[list[CommunityNode], list[float]]:
478 |     if config is None:
479 |         return [], []
480 | 
481 |     search_results: list[list[CommunityNode]] = list(
482 |         await semaphore_gather(
483 |             *[
484 |                 community_fulltext_search(driver, query, group_ids, 2 * limit),
485 |                 community_similarity_search(
486 |                     driver, query_vector, group_ids, 2 * limit, config.sim_min_score
487 |                 ),
488 |             ]
489 |         )
490 |     )
491 | 
492 |     search_result_uuids = [[community.uuid for community in result] for result in search_results]
493 |     community_uuid_map = {
494 |         community.uuid: community for result in search_results for community in result
495 |     }
496 | 
497 |     reranked_uuids: list[str] = []
498 |     community_scores: list[float] = []
499 |     if config.reranker == CommunityReranker.rrf:
500 |         reranked_uuids, community_scores = rrf(search_result_uuids, min_score=reranker_min_score)
501 |     elif config.reranker == CommunityReranker.mmr:
502 |         search_result_uuids_and_vectors = await get_embeddings_for_communities(
503 |             driver, list(community_uuid_map.values())
504 |         )
505 | 
506 |         reranked_uuids, community_scores = maximal_marginal_relevance(
507 |             query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
508 |         )
509 |     elif config.reranker == CommunityReranker.cross_encoder:
510 |         name_to_uuid_map = {node.name: node.uuid for result in search_results for node in result}
511 |         reranked_nodes = await cross_encoder.rank(query, list(name_to_uuid_map.keys()))
512 |         reranked_uuids = [
513 |             name_to_uuid_map[name] for name, score in reranked_nodes if score >= reranker_min_score
514 |         ]
515 |         community_scores = [score for _, score in reranked_nodes if score >= reranker_min_score]
516 | 
517 |     reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
518 | 
519 |     return reranked_communities[:limit], community_scores[:limit]
520 | 
```
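
`search()` fans out the configured edge, node, episode, and community searches concurrently via `semaphore_gather`, and each helper then reranks its candidates (RRF, MMR, cross-encoder, node distance, or episode mentions) before trimming results to `config.limit`. The sketch below illustrates one way to call it, assuming `clients` is an already constructed `GraphitiClients` and that the config models accept these fields as constructor keywords (they are the same attributes the functions above read); it is an illustration, not the library's documented entry point.

```python
# Hedged usage sketch for graphiti_core.search.search.search().
# Assumes `clients` (GraphitiClients) is already built elsewhere and that the
# SearchConfig/EdgeSearchConfig/NodeSearchConfig models accept these keyword
# arguments, matching the attributes read in the module above.
from graphiti_core.search.search import search
from graphiti_core.search.search_config import (
    EdgeReranker,
    EdgeSearchConfig,
    EdgeSearchMethod,
    NodeReranker,
    NodeSearchConfig,
    NodeSearchMethod,
    SearchConfig,
)
from graphiti_core.search.search_filters import SearchFilters


async def run_hybrid_search(clients, query: str):
    config = SearchConfig(
        edge_config=EdgeSearchConfig(
            search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
            reranker=EdgeReranker.rrf,
        ),
        node_config=NodeSearchConfig(
            search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
            reranker=NodeReranker.rrf,
        ),
    )
    # group_ids=None searches across all groups (see the normalization near
    # the top of search()); SearchFilters() applies no additional filtering.
    return await search(clients, query, None, config, SearchFilters())
```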

--------------------------------------------------------------------------------
/mcp_server/tests/test_mcp_integration.py:
--------------------------------------------------------------------------------

```python
  1 | #!/usr/bin/env python3
  2 | """
  3 | Integration test for the refactored Graphiti MCP Server using the official MCP Python SDK.
  4 | Tests all major MCP tools and handles episode processing latency.
  5 | """
  6 | 
  7 | import asyncio
  8 | import json
  9 | import os
 10 | import time
 11 | from typing import Any
 12 | 
 13 | from mcp import ClientSession, StdioServerParameters
 14 | from mcp.client.stdio import stdio_client
 15 | 
 16 | 
 17 | class GraphitiMCPIntegrationTest:
 18 |     """Integration test client for Graphiti MCP Server using official MCP SDK."""
 19 | 
 20 |     def __init__(self):
 21 |         self.test_group_id = f'test_group_{int(time.time())}'
 22 |         self.session = None
 23 | 
 24 |     async def __aenter__(self):
 25 |         """Start the MCP client session."""
 26 |         # Configure server parameters to run our refactored server
 27 |         server_params = StdioServerParameters(
 28 |             command='uv',
 29 |             args=['run', 'main.py', '--transport', 'stdio'],
 30 |             env={
 31 |                 'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
 32 |                 'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
 33 |                 'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
 34 |                 'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY', 'dummy_key_for_testing'),
 35 |             },
 36 |         )
 37 | 
 38 |         print(f'🚀 Starting MCP client session with test group: {self.test_group_id}')
 39 | 
 40 |         # Use the async context manager properly
 41 |         self.client_context = stdio_client(server_params)
 42 |         read, write = await self.client_context.__aenter__()
 43 |         self.session = ClientSession(read, write)
 44 |         await self.session.initialize()
 45 | 
 46 |         return self
 47 | 
 48 |     async def __aexit__(self, exc_type, exc_val, exc_tb):
 49 |         """Close the MCP client session."""
 50 |         if self.session:
 51 |             await self.session.close()
 52 |         if hasattr(self, 'client_context'):
 53 |             await self.client_context.__aexit__(exc_type, exc_val, exc_tb)
 54 | 
 55 |     async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
 56 |         """Call an MCP tool and return the result."""
 57 |         try:
 58 |             result = await self.session.call_tool(tool_name, arguments)
 59 |             return result.content[0].text if result.content else {'error': 'No content returned'}
 60 |         except Exception as e:
 61 |             return {'error': str(e)}
 62 | 
 63 |     async def test_server_initialization(self) -> bool:
 64 |         """Test that the server initializes properly."""
 65 |         print('🔍 Testing server initialization...')
 66 | 
 67 |         try:
 68 |             # List available tools to verify server is responding
 69 |             tools_result = await self.session.list_tools()
 70 |             tools = [tool.name for tool in tools_result.tools]
 71 | 
 72 |             expected_tools = [
 73 |                 'add_memory',
 74 |                 'search_memory_nodes',
 75 |                 'search_memory_facts',
 76 |                 'get_episodes',
 77 |                 'delete_episode',
 78 |                 'delete_entity_edge',
 79 |                 'get_entity_edge',
 80 |                 'clear_graph',
 81 |             ]
 82 | 
 83 |             available_tools = len([tool for tool in expected_tools if tool in tools])
 84 |             print(
 85 |                 f'   ✅ Server responding with {len(tools)} tools ({available_tools}/{len(expected_tools)} expected)'
 86 |             )
 87 |             print(f'   Available tools: {", ".join(sorted(tools))}')
 88 | 
 89 |             return available_tools >= len(expected_tools) * 0.8  # 80% of expected tools
 90 | 
 91 |         except Exception as e:
 92 |             print(f'   ❌ Server initialization failed: {e}')
 93 |             return False
 94 | 
 95 |     async def test_add_memory_operations(self) -> dict[str, bool]:
 96 |         """Test adding various types of memory episodes."""
 97 |         print('📝 Testing add_memory operations...')
 98 | 
 99 |         results = {}
100 | 
101 |         # Test 1: Add text episode
102 |         print('   Testing text episode...')
103 |         try:
104 |             result = await self.call_tool(
105 |                 'add_memory',
106 |                 {
107 |                     'name': 'Test Company News',
108 |                     'episode_body': 'Acme Corp announced a revolutionary new AI product that will transform the industry. The CEO mentioned this is their biggest launch since 2020.',
109 |                     'source': 'text',
110 |                     'source_description': 'news article',
111 |                     'group_id': self.test_group_id,
112 |                 },
113 |             )
114 | 
115 |             if isinstance(result, str) and 'queued' in result.lower():
116 |                 print(f'   ✅ Text episode: {result}')
117 |                 results['text'] = True
118 |             else:
119 |                 print(f'   ❌ Text episode failed: {result}')
120 |                 results['text'] = False
121 |         except Exception as e:
122 |             print(f'   ❌ Text episode error: {e}')
123 |             results['text'] = False
124 | 
125 |         # Test 2: Add JSON episode
126 |         print('   Testing JSON episode...')
127 |         try:
128 |             json_data = {
129 |                 'company': {'name': 'TechCorp', 'founded': 2010},
130 |                 'products': [
131 |                     {'id': 'P001', 'name': 'CloudSync', 'category': 'software'},
132 |                     {'id': 'P002', 'name': 'DataMiner', 'category': 'analytics'},
133 |                 ],
134 |                 'employees': 150,
135 |             }
136 | 
137 |             result = await self.call_tool(
138 |                 'add_memory',
139 |                 {
140 |                     'name': 'Company Profile',
141 |                     'episode_body': json.dumps(json_data),
142 |                     'source': 'json',
143 |                     'source_description': 'CRM data',
144 |                     'group_id': self.test_group_id,
145 |                 },
146 |             )
147 | 
148 |             if isinstance(result, str) and 'queued' in result.lower():
149 |                 print(f'   ✅ JSON episode: {result}')
150 |                 results['json'] = True
151 |             else:
152 |                 print(f'   ❌ JSON episode failed: {result}')
153 |                 results['json'] = False
154 |         except Exception as e:
155 |             print(f'   ❌ JSON episode error: {e}')
156 |             results['json'] = False
157 | 
158 |         # Test 3: Add message episode
159 |         print('   Testing message episode...')
160 |         try:
161 |             result = await self.call_tool(
162 |                 'add_memory',
163 |                 {
164 |                     'name': 'Customer Support Chat',
165 |                     'episode_body': "user: What's your return policy?\nassistant: You can return items within 30 days of purchase with receipt.\nuser: Thanks!",
166 |                     'source': 'message',
167 |                     'source_description': 'support chat log',
168 |                     'group_id': self.test_group_id,
169 |                 },
170 |             )
171 | 
172 |             if isinstance(result, str) and 'queued' in result.lower():
173 |                 print(f'   ✅ Message episode: {result}')
174 |                 results['message'] = True
175 |             else:
176 |                 print(f'   ❌ Message episode failed: {result}')
177 |                 results['message'] = False
178 |         except Exception as e:
179 |             print(f'   ❌ Message episode error: {e}')
180 |             results['message'] = False
181 | 
182 |         return results
183 | 
184 |     async def wait_for_processing(self, max_wait: int = 45) -> bool:
185 |         """Wait for episode processing to complete."""
186 |         print(f'⏳ Waiting up to {max_wait} seconds for episode processing...')
187 | 
188 |         for i in range(max_wait):
189 |             await asyncio.sleep(1)
190 | 
191 |             try:
192 |                 # Check if we have any episodes
193 |                 result = await self.call_tool(
194 |                     'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
195 |                 )
196 | 
197 |                 # Parse the JSON result if it's a string
198 |                 if isinstance(result, str):
199 |                     try:
200 |                         parsed_result = json.loads(result)
201 |                         if isinstance(parsed_result, list) and len(parsed_result) > 0:
202 |                             print(
203 |                                 f'   ✅ Found {len(parsed_result)} processed episodes after {i + 1} seconds'
204 |                             )
205 |                             return True
206 |                     except json.JSONDecodeError:
207 |                         if 'episodes' in result.lower():
208 |                             print(f'   ✅ Episodes detected after {i + 1} seconds')
209 |                             return True
210 | 
211 |             except Exception as e:
212 |                 if i == 0:  # Only log first error to avoid spam
213 |                     print(f'   ⚠️  Waiting for processing... ({e})')
214 |                 continue
215 | 
216 |         print(f'   ⚠️  Still waiting after {max_wait} seconds...')
217 |         return False
218 | 
219 |     async def test_search_operations(self) -> dict[str, bool]:
220 |         """Test search functionality."""
221 |         print('🔍 Testing search operations...')
222 | 
223 |         results = {}
224 | 
225 |         # Test search_memory_nodes
226 |         print('   Testing search_memory_nodes...')
227 |         try:
228 |             result = await self.call_tool(
229 |                 'search_memory_nodes',
230 |                 {
231 |                     'query': 'Acme Corp product launch AI',
232 |                     'group_ids': [self.test_group_id],
233 |                     'max_nodes': 5,
234 |                 },
235 |             )
236 | 
237 |             success = False
238 |             if isinstance(result, str):
239 |                 try:
240 |                     parsed = json.loads(result)
241 |                     nodes = parsed.get('nodes', [])
242 |                     success = isinstance(nodes, list)
243 |                     print(f'   ✅ Node search returned {len(nodes)} nodes')
244 |                 except json.JSONDecodeError:
245 |                     success = 'nodes' in result.lower() and 'successfully' in result.lower()
246 |                     if success:
247 |                         print('   ✅ Node search completed successfully')
248 | 
249 |             results['nodes'] = success
250 |             if not success:
251 |                 print(f'   ❌ Node search failed: {result}')
252 | 
253 |         except Exception as e:
254 |             print(f'   ❌ Node search error: {e}')
255 |             results['nodes'] = False
256 | 
257 |         # Test search_memory_facts
258 |         print('   Testing search_memory_facts...')
259 |         try:
260 |             result = await self.call_tool(
261 |                 'search_memory_facts',
262 |                 {
263 |                     'query': 'company products software TechCorp',
264 |                     'group_ids': [self.test_group_id],
265 |                     'max_facts': 5,
266 |                 },
267 |             )
268 | 
269 |             success = False
270 |             if isinstance(result, str):
271 |                 try:
272 |                     parsed = json.loads(result)
273 |                     facts = parsed.get('facts', [])
274 |                     success = isinstance(facts, list)
275 |                     print(f'   ✅ Fact search returned {len(facts)} facts')
276 |                 except json.JSONDecodeError:
277 |                     success = 'facts' in result.lower() and 'successfully' in result.lower()
278 |                     if success:
279 |                         print('   ✅ Fact search completed successfully')
280 | 
281 |             results['facts'] = success
282 |             if not success:
283 |                 print(f'   ❌ Fact search failed: {result}')
284 | 
285 |         except Exception as e:
286 |             print(f'   ❌ Fact search error: {e}')
287 |             results['facts'] = False
288 | 
289 |         return results
290 | 
291 |     async def test_episode_retrieval(self) -> bool:
292 |         """Test episode retrieval."""
293 |         print('📚 Testing episode retrieval...')
294 | 
295 |         try:
296 |             result = await self.call_tool(
297 |                 'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
298 |             )
299 | 
300 |             if isinstance(result, str):
301 |                 try:
302 |                     parsed = json.loads(result)
303 |                     if isinstance(parsed, list):
304 |                         print(f'   ✅ Retrieved {len(parsed)} episodes')
305 | 
306 |                         # Show episode details
307 |                         for i, episode in enumerate(parsed[:3]):
308 |                             name = episode.get('name', 'Unknown')
309 |                             source = episode.get('source', 'unknown')
310 |                             print(f'     Episode {i + 1}: {name} (source: {source})')
311 | 
312 |                         return len(parsed) > 0
313 |                 except json.JSONDecodeError:
314 |                     # Check if response indicates success
315 |                     if 'episode' in result.lower():
316 |                         print('   ✅ Episode retrieval completed')
317 |                         return True
318 | 
319 |             print(f'   ❌ Unexpected result format: {result}')
320 |             return False
321 | 
322 |         except Exception as e:
323 |             print(f'   ❌ Episode retrieval failed: {e}')
324 |             return False
325 | 
326 |     async def test_error_handling(self) -> dict[str, bool]:
327 |         """Test error handling and edge cases."""
328 |         print('🧪 Testing error handling...')
329 | 
330 |         results = {}
331 | 
332 |         # Test with nonexistent group
333 |         print('   Testing nonexistent group handling...')
334 |         try:
335 |             result = await self.call_tool(
336 |                 'search_memory_nodes',
337 |                 {
338 |                     'query': 'nonexistent data',
339 |                     'group_ids': ['nonexistent_group_12345'],
340 |                     'max_nodes': 5,
341 |                 },
342 |             )
343 | 
344 |             # Should handle gracefully, not crash
345 |             success = (
346 |                 'error' not in str(result).lower() or 'not initialized' not in str(result).lower()
347 |             )
348 |             if success:
349 |                 print('   ✅ Nonexistent group handled gracefully')
350 |             else:
351 |                 print(f'   ❌ Nonexistent group caused issues: {result}')
352 | 
353 |             results['nonexistent_group'] = success
354 | 
355 |         except Exception as e:
356 |             print(f'   ❌ Nonexistent group test failed: {e}')
357 |             results['nonexistent_group'] = False
358 | 
359 |         # Test empty query
360 |         print('   Testing empty query handling...')
361 |         try:
362 |             result = await self.call_tool(
363 |                 'search_memory_nodes',
364 |                 {'query': '', 'group_ids': [self.test_group_id], 'max_nodes': 5},
365 |             )
366 | 
367 |             # Should handle gracefully
368 |             success = (
369 |                 'error' not in str(result).lower() or 'not initialized' not in str(result).lower()
370 |             )
371 |             if success:
372 |                 print('   ✅ Empty query handled gracefully')
373 |             else:
374 |                 print(f'   ❌ Empty query caused issues: {result}')
375 | 
376 |             results['empty_query'] = success
377 | 
378 |         except Exception as e:
379 |             print(f'   ❌ Empty query test failed: {e}')
380 |             results['empty_query'] = False
381 | 
382 |         return results
383 | 
384 |     async def run_comprehensive_test(self) -> dict[str, Any]:
385 |         """Run the complete integration test suite."""
386 |         print('🚀 Starting Comprehensive Graphiti MCP Server Integration Test')
387 |         print(f'   Test group ID: {self.test_group_id}')
388 |         print('=' * 70)
389 | 
390 |         results = {
391 |             'server_init': False,
392 |             'add_memory': {},
393 |             'processing_wait': False,
394 |             'search': {},
395 |             'episodes': False,
396 |             'error_handling': {},
397 |             'overall_success': False,
398 |         }
399 | 
400 |         # Test 1: Server Initialization
401 |         results['server_init'] = await self.test_server_initialization()
402 |         if not results['server_init']:
403 |             print('❌ Server initialization failed, aborting remaining tests')
404 |             return results
405 | 
406 |         print()
407 | 
408 |         # Test 2: Add Memory Operations
409 |         results['add_memory'] = await self.test_add_memory_operations()
410 |         print()
411 | 
412 |         # Test 3: Wait for Processing
413 |         results['processing_wait'] = await self.wait_for_processing()
414 |         print()
415 | 
416 |         # Test 4: Search Operations
417 |         results['search'] = await self.test_search_operations()
418 |         print()
419 | 
420 |         # Test 5: Episode Retrieval
421 |         results['episodes'] = await self.test_episode_retrieval()
422 |         print()
423 | 
424 |         # Test 6: Error Handling
425 |         results['error_handling'] = await self.test_error_handling()
426 |         print()
427 | 
428 |         # Calculate overall success
429 |         memory_success = any(results['add_memory'].values())
430 |         search_success = any(results['search'].values()) if results['search'] else False
431 |         error_success = (
432 |             any(results['error_handling'].values()) if results['error_handling'] else True
433 |         )
434 | 
435 |         results['overall_success'] = (
436 |             results['server_init']
437 |             and memory_success
438 |             and (results['episodes'] or results['processing_wait'])
439 |             and error_success
440 |         )
441 | 
442 |         # Print comprehensive summary
443 |         print('=' * 70)
444 |         print('📊 COMPREHENSIVE TEST SUMMARY')
445 |         print('-' * 35)
446 |         print(f'Server Initialization:    {"✅ PASS" if results["server_init"] else "❌ FAIL"}')
447 | 
448 |         memory_stats = f'({sum(results["add_memory"].values())}/{len(results["add_memory"])} types)'
449 |         print(
450 |             f'Memory Operations:        {"✅ PASS" if memory_success else "❌ FAIL"} {memory_stats}'
451 |         )
452 | 
453 |         print(f'Processing Pipeline:      {"✅ PASS" if results["processing_wait"] else "❌ FAIL"}')
454 | 
455 |         search_stats = (
456 |             f'({sum(results["search"].values())}/{len(results["search"])} types)'
457 |             if results['search']
458 |             else '(0/0 types)'
459 |         )
460 |         print(
461 |             f'Search Operations:        {"✅ PASS" if search_success else "❌ FAIL"} {search_stats}'
462 |         )
463 | 
464 |         print(f'Episode Retrieval:        {"✅ PASS" if results["episodes"] else "❌ FAIL"}')
465 | 
466 |         error_stats = (
467 |             f'({sum(results["error_handling"].values())}/{len(results["error_handling"])} cases)'
468 |             if results['error_handling']
469 |             else '(0/0 cases)'
470 |         )
471 |         print(
472 |             f'Error Handling:           {"✅ PASS" if error_success else "❌ FAIL"} {error_stats}'
473 |         )
474 | 
475 |         print('-' * 35)
476 |         print(f'🎯 OVERALL RESULT: {"✅ SUCCESS" if results["overall_success"] else "❌ FAILED"}')
477 | 
478 |         if results['overall_success']:
479 |             print('\n🎉 The refactored Graphiti MCP server is working correctly!')
480 |             print('   All core functionality has been successfully tested.')
481 |         else:
482 |             print('\n⚠️  Some issues were detected. Review the test results above.')
483 |             print('   The refactoring may need additional attention.')
484 | 
485 |         return results
486 | 
487 | 
488 | async def main():
489 |     """Run the integration test."""
490 |     try:
491 |         async with GraphitiMCPIntegrationTest() as test:
492 |             results = await test.run_comprehensive_test()
493 | 
494 |             # Exit with appropriate code
495 |             exit_code = 0 if results['overall_success'] else 1
496 |             exit(exit_code)
497 |     except Exception as e:
498 |         print(f'❌ Test setup failed: {e}')
499 |         exit(1)
500 | 
501 | 
502 | if __name__ == '__main__':
503 |     asyncio.run(main())
504 | 
```
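
The harness above enters `stdio_client` and `ClientSession` by hand inside `__aenter__`/`__aexit__`. For reference, the same wiring can be written with nested `async with` blocks so the SDK handles teardown. This is a minimal sketch assuming the `mcp` Python SDK's stdio client API (`StdioServerParameters`, `stdio_client`, `ClientSession`); the server command and arguments are placeholders, not the project's actual entry point.

```python
# Minimal sketch: the same stdio wiring the test harness performs manually,
# expressed with nested async context managers so cleanup is automatic.
# The command/args below are placeholders (hypothetical), not the real server entry point.
import asyncio

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client


async def smoke_test() -> None:
    server_params = StdioServerParameters(
        command='uv',  # placeholder launcher
        args=['run', 'graphiti_mcp_server.py', '--transport', 'stdio'],  # hypothetical args
    )
    async with stdio_client(server_params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            tools = await session.list_tools()
            # Print the advertised tool names, mirroring test_server_initialization above
            print('tools:', [tool.name for tool in tools.tools])


if __name__ == '__main__':
    asyncio.run(smoke_test())
```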

--------------------------------------------------------------------------------
/graphiti_core/llm_client/gemini_client.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import json
 18 | import logging
 19 | import re
 20 | import typing
 21 | from typing import TYPE_CHECKING, ClassVar
 22 | 
 23 | from pydantic import BaseModel
 24 | 
 25 | from ..prompts.models import Message
 26 | from .client import LLMClient, get_extraction_language_instruction
 27 | from .config import LLMConfig, ModelSize
 28 | from .errors import RateLimitError
 29 | 
 30 | if TYPE_CHECKING:
 31 |     from google import genai
 32 |     from google.genai import types
 33 | else:
 34 |     try:
 35 |         from google import genai
 36 |         from google.genai import types
 37 |     except ImportError:
 38 |         # If gemini client is not installed, raise an ImportError
 39 |         raise ImportError(
 40 |             'google-genai is required for GeminiClient. '
 41 |             'Install it with: pip install graphiti-core[google-genai]'
 42 |         ) from None
 43 | 
 44 | 
 45 | logger = logging.getLogger(__name__)
 46 | 
 47 | DEFAULT_MODEL = 'gemini-2.5-flash'
 48 | DEFAULT_SMALL_MODEL = 'gemini-2.5-flash-lite'
 49 | 
 50 | # Maximum output tokens for different Gemini models
 51 | GEMINI_MODEL_MAX_TOKENS = {
 52 |     # Gemini 2.5 models
 53 |     'gemini-2.5-pro': 65536,
 54 |     'gemini-2.5-flash': 65536,
 55 |     'gemini-2.5-flash-lite': 64000,
 56 |     # Gemini 2.0 models
 57 |     'gemini-2.0-flash': 8192,
 58 |     'gemini-2.0-flash-lite': 8192,
 59 |     # Gemini 1.5 models
 60 |     'gemini-1.5-pro': 8192,
 61 |     'gemini-1.5-flash': 8192,
 62 |     'gemini-1.5-flash-8b': 8192,
 63 | }
 64 | 
 65 | # Default max tokens for models not in the mapping
 66 | DEFAULT_GEMINI_MAX_TOKENS = 8192
 67 | 
 68 | 
 69 | class GeminiClient(LLMClient):
 70 |     """
 71 |     GeminiClient is a client class for interacting with Google's Gemini language models.
 72 | 
 73 |     This class extends the LLMClient and provides methods to initialize the client
 74 |     and generate responses from the Gemini language model.
 75 | 
 76 |     Attributes:
 77 |         model (str): The model name to use for generating responses.
 78 |         temperature (float): The temperature to use for generating responses.
 79 |         max_tokens (int): The maximum number of tokens to generate in a response.
 80 |         thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
 81 |     Methods:
 82 |         __init__(config: LLMConfig | None = None, cache: bool = False, thinking_config: types.ThinkingConfig | None = None):
 83 |             Initializes the GeminiClient with the provided configuration, cache setting, and optional thinking config.
 84 | 
 85 |         _generate_response(messages: list[Message]) -> dict[str, typing.Any]:
 86 |             Generates a response from the language model based on the provided messages.
 87 |     """
 88 | 
 89 |     # Class-level constants
 90 |     MAX_RETRIES: ClassVar[int] = 2
 91 | 
 92 |     def __init__(
 93 |         self,
 94 |         config: LLMConfig | None = None,
 95 |         cache: bool = False,
 96 |         max_tokens: int | None = None,
 97 |         thinking_config: types.ThinkingConfig | None = None,
 98 |         client: 'genai.Client | None' = None,
 99 |     ):
100 |         """
101 |         Initialize the GeminiClient with the provided configuration, cache setting, and optional thinking config.
102 | 
103 |         Args:
104 |             config (LLMConfig | None): The configuration for the LLM client, including API key, model, temperature, and max tokens.
105 |             cache (bool): Whether to use caching for responses. Defaults to False.
106 |             thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
107 |                 Only use with models that support thinking (gemini-2.5+). Defaults to None.
108 |             client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
109 |         """
110 |         if config is None:
111 |             config = LLMConfig()
112 | 
113 |         super().__init__(config, cache)
114 | 
115 |         self.model = config.model
116 | 
117 |         if client is None:
118 |             self.client = genai.Client(api_key=config.api_key)
119 |         else:
120 |             self.client = client
121 | 
122 |         self.max_tokens = max_tokens
123 |         self.thinking_config = thinking_config
124 | 
125 |     def _check_safety_blocks(self, response) -> None:
126 |         """Check if response was blocked for safety reasons and raise appropriate exceptions."""
127 |         # Check if the response was blocked for safety reasons
128 |         if not (hasattr(response, 'candidates') and response.candidates):
129 |             return
130 | 
131 |         candidate = response.candidates[0]
132 |         if not (hasattr(candidate, 'finish_reason') and candidate.finish_reason == 'SAFETY'):
133 |             return
134 | 
135 |         # Content was blocked for safety reasons - collect safety details
136 |         safety_info = []
137 |         safety_ratings = getattr(candidate, 'safety_ratings', None)
138 | 
139 |         if safety_ratings:
140 |             for rating in safety_ratings:
141 |                 if getattr(rating, 'blocked', False):
142 |                     category = getattr(rating, 'category', 'Unknown')
143 |                     probability = getattr(rating, 'probability', 'Unknown')
144 |                     safety_info.append(f'{category}: {probability}')
145 | 
146 |         safety_details = (
147 |             ', '.join(safety_info) if safety_info else 'Content blocked for safety reasons'
148 |         )
149 |         raise Exception(f'Response blocked by Gemini safety filters: {safety_details}')
150 | 
151 |     def _check_prompt_blocks(self, response) -> None:
152 |         """Check if prompt was blocked and raise appropriate exceptions."""
153 |         prompt_feedback = getattr(response, 'prompt_feedback', None)
154 |         if not prompt_feedback:
155 |             return
156 | 
157 |         block_reason = getattr(prompt_feedback, 'block_reason', None)
158 |         if block_reason:
159 |             raise Exception(f'Prompt blocked by Gemini: {block_reason}')
160 | 
161 |     def _get_model_for_size(self, model_size: ModelSize) -> str:
162 |         """Get the appropriate model name based on the requested size."""
163 |         if model_size == ModelSize.small:
164 |             return self.small_model or DEFAULT_SMALL_MODEL
165 |         else:
166 |             return self.model or DEFAULT_MODEL
167 | 
168 |     def _get_max_tokens_for_model(self, model: str) -> int:
169 |         """Get the maximum output tokens for a specific Gemini model."""
170 |         return GEMINI_MODEL_MAX_TOKENS.get(model, DEFAULT_GEMINI_MAX_TOKENS)
171 | 
172 |     def _resolve_max_tokens(self, requested_max_tokens: int | None, model: str) -> int:
173 |         """
174 |         Resolve the maximum output tokens to use based on precedence rules.
175 | 
176 |         Precedence order (highest to lowest):
177 |         1. Explicit max_tokens parameter passed to generate_response()
178 |         2. Instance max_tokens set during client initialization
179 |         3. Model-specific maximum tokens from GEMINI_MODEL_MAX_TOKENS mapping
 180 |         4. DEFAULT_GEMINI_MAX_TOKENS as final fallback
181 | 
182 |         Args:
183 |             requested_max_tokens: The max_tokens parameter passed to generate_response()
184 |             model: The model name to look up model-specific limits
185 | 
186 |         Returns:
187 |             int: The resolved maximum tokens to use
188 |         """
189 |         # 1. Use explicit parameter if provided
190 |         if requested_max_tokens is not None:
191 |             return requested_max_tokens
192 | 
193 |         # 2. Use instance max_tokens if set during initialization
194 |         if self.max_tokens is not None:
195 |             return self.max_tokens
196 | 
197 |         # 3. Use model-specific maximum or return DEFAULT_GEMINI_MAX_TOKENS
198 |         return self._get_max_tokens_for_model(model)
199 | 
200 |     def salvage_json(self, raw_output: str) -> dict[str, typing.Any] | None:
201 |         """
202 |         Attempt to salvage a JSON object if the raw output is truncated.
203 | 
204 |         This is accomplished by looking for the last closing bracket for an array or object.
205 |         If found, it will try to load the JSON object from the raw output.
206 |         If the JSON object is not valid, it will return None.
207 | 
208 |         Args:
209 |             raw_output (str): The raw output from the LLM.
210 | 
211 |         Returns:
212 |             dict[str, typing.Any]: The salvaged JSON object.
213 |             None: If no salvage is possible.
214 |         """
215 |         if not raw_output:
216 |             return None
217 |         # Try to salvage a JSON array
218 |         array_match = re.search(r'\]\s*$', raw_output)
219 |         if array_match:
220 |             try:
221 |                 return json.loads(raw_output[: array_match.end()])
222 |             except Exception:
223 |                 pass
224 |         # Try to salvage a JSON object
225 |         obj_match = re.search(r'\}\s*$', raw_output)
226 |         if obj_match:
227 |             try:
228 |                 return json.loads(raw_output[: obj_match.end()])
229 |             except Exception:
230 |                 pass
231 |         return None
232 | 
233 |     async def _generate_response(
234 |         self,
235 |         messages: list[Message],
236 |         response_model: type[BaseModel] | None = None,
237 |         max_tokens: int | None = None,
238 |         model_size: ModelSize = ModelSize.medium,
239 |     ) -> dict[str, typing.Any]:
240 |         """
241 |         Generate a response from the Gemini language model.
242 | 
243 |         Args:
244 |             messages (list[Message]): A list of messages to send to the language model.
245 |             response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
246 |             max_tokens (int | None): The maximum number of tokens to generate in the response. If None, uses precedence rules.
247 |             model_size (ModelSize): The size of the model to use (small or medium).
248 | 
249 |         Returns:
250 |             dict[str, typing.Any]: The response from the language model.
251 | 
252 |         Raises:
253 |             RateLimitError: If the API rate limit is exceeded.
254 |             Exception: If there is an error generating the response or content is blocked.
255 |         """
256 |         try:
257 |             gemini_messages: typing.Any = []
258 |             # If a response model is provided, add schema for structured output
259 |             system_prompt = ''
260 |             if response_model is not None:
261 |                 # Get the schema from the Pydantic model
262 |                 pydantic_schema = response_model.model_json_schema()
263 | 
264 |                 # Create instruction to output in the desired JSON format
265 |                 system_prompt += (
266 |                     f'Output ONLY valid JSON matching this schema: {json.dumps(pydantic_schema)}.\n'
267 |                     'Do not include any explanatory text before or after the JSON.\n\n'
268 |                 )
269 | 
270 |             # Add messages content
271 |             # First check for a system message
272 |             if messages and messages[0].role == 'system':
273 |                 system_prompt = f'{messages[0].content}\n\n {system_prompt}'
274 |                 messages = messages[1:]
275 | 
276 |             # Add the rest of the messages
277 |             for m in messages:
278 |                 m.content = self._clean_input(m.content)
279 |                 gemini_messages.append(
280 |                     types.Content(role=m.role, parts=[types.Part.from_text(text=m.content)])
281 |                 )
282 | 
283 |             # Get the appropriate model for the requested size
284 |             model = self._get_model_for_size(model_size)
285 | 
286 |             # Resolve max_tokens using precedence rules (see _resolve_max_tokens for details)
287 |             resolved_max_tokens = self._resolve_max_tokens(max_tokens, model)
288 | 
289 |             # Create generation config
290 |             generation_config = types.GenerateContentConfig(
291 |                 temperature=self.temperature,
292 |                 max_output_tokens=resolved_max_tokens,
293 |                 response_mime_type='application/json' if response_model else None,
294 |                 response_schema=response_model if response_model else None,
295 |                 system_instruction=system_prompt,
296 |                 thinking_config=self.thinking_config,
297 |             )
298 | 
299 |             # Generate content using the simple string approach
300 |             response = await self.client.aio.models.generate_content(
301 |                 model=model,
302 |                 contents=gemini_messages,
303 |                 config=generation_config,
304 |             )
305 | 
306 |             # Always capture the raw output for debugging
307 |             raw_output = getattr(response, 'text', None)
308 | 
309 |             # Check for safety and prompt blocks
310 |             self._check_safety_blocks(response)
311 |             self._check_prompt_blocks(response)
312 | 
313 |             # If this was a structured output request, parse the response into the Pydantic model
314 |             if response_model is not None:
315 |                 try:
316 |                     if not raw_output:
317 |                         raise ValueError('No response text')
318 | 
319 |                     validated_model = response_model.model_validate(json.loads(raw_output))
320 | 
321 |                     # Return as a dictionary for API consistency
322 |                     return validated_model.model_dump()
323 |                 except Exception as e:
324 |                     if raw_output:
325 |                         logger.error(
326 |                             '🦀 LLM generation failed parsing as JSON, will try to salvage.'
327 |                         )
328 |                         logger.error(self._get_failed_generation_log(gemini_messages, raw_output))
329 |                         # Try to salvage
330 |                         salvaged = self.salvage_json(raw_output)
331 |                         if salvaged is not None:
332 |                             logger.warning('Salvaged partial JSON from truncated/malformed output.')
333 |                             return salvaged
334 |                     raise Exception(f'Failed to parse structured response: {e}') from e
335 | 
336 |             # Otherwise, return the response text as a dictionary
337 |             return {'content': raw_output}
338 | 
339 |         except Exception as e:
340 |             # Check if it's a rate limit error based on Gemini API error codes
341 |             error_message = str(e).lower()
342 |             if (
343 |                 'rate limit' in error_message
344 |                 or 'quota' in error_message
345 |                 or 'resource_exhausted' in error_message
346 |                 or '429' in str(e)
347 |             ):
348 |                 raise RateLimitError from e
349 | 
350 |             logger.error(f'Error in generating LLM response: {e}')
351 |             raise Exception from e
352 | 
353 |     async def generate_response(
354 |         self,
355 |         messages: list[Message],
356 |         response_model: type[BaseModel] | None = None,
357 |         max_tokens: int | None = None,
358 |         model_size: ModelSize = ModelSize.medium,
359 |         group_id: str | None = None,
360 |         prompt_name: str | None = None,
361 |     ) -> dict[str, typing.Any]:
362 |         """
363 |         Generate a response from the Gemini language model with retry logic and error handling.
364 |         This method overrides the parent class method to provide a direct implementation with advanced retry logic.
365 | 
366 |         Args:
367 |             messages (list[Message]): A list of messages to send to the language model.
368 |             response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
369 |             max_tokens (int | None): The maximum number of tokens to generate in the response.
370 |             model_size (ModelSize): The size of the model to use (small or medium).
371 |             group_id (str | None): Optional partition identifier for the graph.
372 |             prompt_name (str | None): Optional name of the prompt for tracing.
373 | 
374 |         Returns:
375 |             dict[str, typing.Any]: The response from the language model.
376 |         """
377 |         # Add multilingual extraction instructions
378 |         messages[0].content += get_extraction_language_instruction(group_id)
379 | 
380 |         # Wrap entire operation in tracing span
381 |         with self.tracer.start_span('llm.generate') as span:
382 |             attributes = {
383 |                 'llm.provider': 'gemini',
384 |                 'model.size': model_size.value,
385 |                 'max_tokens': max_tokens or self.max_tokens,
386 |             }
387 |             if prompt_name:
388 |                 attributes['prompt.name'] = prompt_name
389 |             span.add_attributes(attributes)
390 | 
391 |             retry_count = 0
392 |             last_error = None
393 |             last_output = None
394 | 
395 |             while retry_count < self.MAX_RETRIES:
396 |                 try:
397 |                     response = await self._generate_response(
398 |                         messages=messages,
399 |                         response_model=response_model,
400 |                         max_tokens=max_tokens,
401 |                         model_size=model_size,
402 |                     )
403 |                     last_output = (
404 |                         response.get('content')
405 |                         if isinstance(response, dict) and 'content' in response
406 |                         else None
407 |                     )
408 |                     return response
409 |                 except RateLimitError as e:
410 |                     # Rate limit errors should not trigger retries (fail fast)
411 |                     span.set_status('error', str(e))
412 |                     raise e
413 |                 except Exception as e:
414 |                     last_error = e
415 | 
416 |                     # Check if this is a safety block - these typically shouldn't be retried
417 |                     error_text = str(e) or (str(e.__cause__) if e.__cause__ else '')
418 |                     if 'safety' in error_text.lower() or 'blocked' in error_text.lower():
419 |                         logger.warning(f'Content blocked by safety filters: {e}')
420 |                         span.set_status('error', str(e))
421 |                         raise Exception(f'Content blocked by safety filters: {e}') from e
422 | 
423 |                     retry_count += 1
424 | 
425 |                     # Construct a detailed error message for the LLM
426 |                     error_context = (
427 |                         f'The previous response attempt was invalid. '
428 |                         f'Error type: {e.__class__.__name__}. '
429 |                         f'Error details: {str(e)}. '
430 |                         f'Please try again with a valid response, ensuring the output matches '
431 |                         f'the expected format and constraints.'
432 |                     )
433 | 
434 |                     error_message = Message(role='user', content=error_context)
435 |                     messages.append(error_message)
436 |                     logger.warning(
437 |                         f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
438 |                     )
439 | 
440 |             # If we exit the loop without returning, all retries are exhausted
441 |             logger.error('🦀 LLM generation failed and retries are exhausted.')
442 |             logger.error(self._get_failed_generation_log(messages, last_output))
443 |             logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {last_error}')
444 |             span.set_status('error', str(last_error))
445 |             span.record_exception(last_error) if last_error else None
446 |             raise last_error or Exception('Max retries exceeded')
447 | 
```
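
The max-token precedence implemented by `_resolve_max_tokens` (explicit argument, then instance setting, then the per-model table, then the 8192 default) can be checked in isolation. A minimal sketch follows, assuming `LLMConfig` accepts `api_key`/`model` keyword arguments and that the optional `google-genai` dependency is installed; the API key is a placeholder and no network call is made.

```python
# Sketch of the _resolve_max_tokens precedence rules; values come from
# GEMINI_MODEL_MAX_TOKENS and DEFAULT_GEMINI_MAX_TOKENS in gemini_client.py.
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.gemini_client import GeminiClient

# Placeholder key: genai.Client is constructed but never called here (assumption).
config = LLMConfig(api_key='YOUR_GEMINI_API_KEY', model='gemini-2.5-flash')

# No instance-level max_tokens: falls through to the per-model table (65536 for 2.5-flash).
client = GeminiClient(config=config)
assert client._resolve_max_tokens(None, 'gemini-2.5-flash') == 65536

# Unknown model names fall back to DEFAULT_GEMINI_MAX_TOKENS (8192).
assert client._resolve_max_tokens(None, 'gemini-unknown') == 8192

# An instance max_tokens (step 2) overrides the model table, and an explicit
# call-time value (step 1) overrides both.
capped = GeminiClient(config=config, max_tokens=4096)
assert capped._resolve_max_tokens(None, 'gemini-2.5-flash') == 4096
assert capped._resolve_max_tokens(1024, 'gemini-2.5-flash') == 1024
```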

--------------------------------------------------------------------------------
/tests/utils/maintenance/test_edge_operations.py:
--------------------------------------------------------------------------------

```python
  1 | from datetime import datetime, timedelta, timezone
  2 | from types import SimpleNamespace
  3 | from unittest.mock import AsyncMock, MagicMock
  4 | 
  5 | import pytest
  6 | from pydantic import BaseModel
  7 | 
  8 | from graphiti_core.edges import EntityEdge
  9 | from graphiti_core.nodes import EntityNode, EpisodicNode
 10 | from graphiti_core.search.search_config import SearchResults
 11 | from graphiti_core.utils.maintenance.edge_operations import (
 12 |     DEFAULT_EDGE_NAME,
 13 |     resolve_extracted_edge,
 14 |     resolve_extracted_edges,
 15 | )
 16 | 
 17 | 
 18 | @pytest.fixture
 19 | def mock_llm_client():
 20 |     client = MagicMock()
 21 |     client.generate_response = AsyncMock()
 22 |     return client
 23 | 
 24 | 
 25 | @pytest.fixture
 26 | def mock_extracted_edge():
 27 |     return EntityEdge(
 28 |         source_node_uuid='source_uuid',
 29 |         target_node_uuid='target_uuid',
 30 |         name='test_edge',
 31 |         group_id='group_1',
 32 |         fact='Test fact',
 33 |         episodes=['episode_1'],
 34 |         created_at=datetime.now(timezone.utc),
 35 |         valid_at=None,
 36 |         invalid_at=None,
 37 |     )
 38 | 
 39 | 
 40 | @pytest.fixture
 41 | def mock_related_edges():
 42 |     return [
 43 |         EntityEdge(
 44 |             source_node_uuid='source_uuid_2',
 45 |             target_node_uuid='target_uuid_2',
 46 |             name='related_edge',
 47 |             group_id='group_1',
 48 |             fact='Related fact',
 49 |             episodes=['episode_2'],
 50 |             created_at=datetime.now(timezone.utc) - timedelta(days=1),
 51 |             valid_at=datetime.now(timezone.utc) - timedelta(days=1),
 52 |             invalid_at=None,
 53 |         )
 54 |     ]
 55 | 
 56 | 
 57 | @pytest.fixture
 58 | def mock_existing_edges():
 59 |     return [
 60 |         EntityEdge(
 61 |             source_node_uuid='source_uuid_3',
 62 |             target_node_uuid='target_uuid_3',
 63 |             name='existing_edge',
 64 |             group_id='group_1',
 65 |             fact='Existing fact',
 66 |             episodes=['episode_3'],
 67 |             created_at=datetime.now(timezone.utc) - timedelta(days=2),
 68 |             valid_at=datetime.now(timezone.utc) - timedelta(days=2),
 69 |             invalid_at=None,
 70 |         )
 71 |     ]
 72 | 
 73 | 
 74 | @pytest.fixture
 75 | def mock_current_episode():
 76 |     return EpisodicNode(
 77 |         uuid='episode_1',
 78 |         content='Current episode content',
 79 |         valid_at=datetime.now(timezone.utc),
 80 |         name='Current Episode',
 81 |         group_id='group_1',
 82 |         source='message',
 83 |         source_description='Test source description',
 84 |     )
 85 | 
 86 | 
 87 | @pytest.fixture
 88 | def mock_previous_episodes():
 89 |     return [
 90 |         EpisodicNode(
 91 |             uuid='episode_2',
 92 |             content='Previous episode content',
 93 |             valid_at=datetime.now(timezone.utc) - timedelta(days=1),
 94 |             name='Previous Episode',
 95 |             group_id='group_1',
 96 |             source='message',
 97 |             source_description='Test source description',
 98 |         )
 99 |     ]
100 | 
101 | 
102 | # Run the tests
103 | if __name__ == '__main__':
104 |     pytest.main([__file__])
105 | 
106 | 
107 | @pytest.mark.asyncio
108 | async def test_resolve_extracted_edge_exact_fact_short_circuit(
109 |     mock_llm_client,
110 |     mock_existing_edges,
111 |     mock_current_episode,
112 | ):
113 |     extracted = EntityEdge(
114 |         source_node_uuid='source_uuid',
115 |         target_node_uuid='target_uuid',
116 |         name='test_edge',
117 |         group_id='group_1',
118 |         fact='Related fact',
119 |         episodes=['episode_1'],
120 |         created_at=datetime.now(timezone.utc),
121 |         valid_at=None,
122 |         invalid_at=None,
123 |     )
124 | 
125 |     related_edges = [
126 |         EntityEdge(
127 |             source_node_uuid='source_uuid',
128 |             target_node_uuid='target_uuid',
129 |             name='related_edge',
130 |             group_id='group_1',
131 |             fact=' related FACT  ',
132 |             episodes=['episode_2'],
133 |             created_at=datetime.now(timezone.utc) - timedelta(days=1),
134 |             valid_at=None,
135 |             invalid_at=None,
136 |         )
137 |     ]
138 | 
139 |     resolved_edge, duplicate_edges, invalidated = await resolve_extracted_edge(
140 |         mock_llm_client,
141 |         extracted,
142 |         related_edges,
143 |         mock_existing_edges,
144 |         mock_current_episode,
145 |         edge_type_candidates=None,
146 |     )
147 | 
148 |     assert resolved_edge is related_edges[0]
149 |     assert resolved_edge.episodes.count(mock_current_episode.uuid) == 1
150 |     assert duplicate_edges == []
151 |     assert invalidated == []
152 |     mock_llm_client.generate_response.assert_not_called()
153 | 
154 | 
155 | class OccurredAtEdge(BaseModel):
156 |     """Edge model stub for OCCURRED_AT."""
157 | 
158 | 
159 | @pytest.mark.asyncio
160 | async def test_resolve_extracted_edges_resets_unmapped_names(monkeypatch):
161 |     from graphiti_core.utils.maintenance import edge_operations as edge_ops
162 | 
163 |     monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
164 |     monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
165 | 
166 |     async def immediate_gather(*aws, max_coroutines=None):
167 |         return [await aw for aw in aws]
168 | 
169 |     monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather)
170 |     monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))
171 | 
172 |     llm_client = MagicMock()
173 |     llm_client.generate_response = AsyncMock(
174 |         return_value={
175 |             'duplicate_facts': [],
176 |             'contradicted_facts': [],
177 |             'fact_type': 'DEFAULT',
178 |         }
179 |     )
180 | 
181 |     clients = SimpleNamespace(
182 |         driver=MagicMock(),
183 |         llm_client=llm_client,
184 |         embedder=MagicMock(),
185 |         cross_encoder=MagicMock(),
186 |     )
187 | 
188 |     source_node = EntityNode(
189 |         uuid='source_uuid',
190 |         name='Document Node',
191 |         group_id='group_1',
192 |         labels=['Document'],
193 |     )
194 |     target_node = EntityNode(
195 |         uuid='target_uuid',
196 |         name='Topic Node',
197 |         group_id='group_1',
198 |         labels=['Topic'],
199 |     )
200 | 
201 |     extracted_edge = EntityEdge(
202 |         source_node_uuid=source_node.uuid,
203 |         target_node_uuid=target_node.uuid,
204 |         name='OCCURRED_AT',
205 |         group_id='group_1',
206 |         fact='Document occurred at somewhere',
207 |         episodes=[],
208 |         created_at=datetime.now(timezone.utc),
209 |         valid_at=None,
210 |         invalid_at=None,
211 |     )
212 | 
213 |     episode = EpisodicNode(
214 |         uuid='episode_uuid',
215 |         name='Episode',
216 |         group_id='group_1',
217 |         source='message',
218 |         source_description='desc',
219 |         content='Episode content',
220 |         valid_at=datetime.now(timezone.utc),
221 |     )
222 | 
223 |     edge_types = {'OCCURRED_AT': OccurredAtEdge}
224 |     edge_type_map = {('Event', 'Entity'): ['OCCURRED_AT']}
225 | 
226 |     resolved_edges, invalidated_edges = await resolve_extracted_edges(
227 |         clients,
228 |         [extracted_edge],
229 |         episode,
230 |         [source_node, target_node],
231 |         edge_types,
232 |         edge_type_map,
233 |     )
234 | 
235 |     assert resolved_edges[0].name == DEFAULT_EDGE_NAME
236 |     assert invalidated_edges == []
237 | 
238 | 
239 | @pytest.mark.asyncio
240 | async def test_resolve_extracted_edges_keeps_unknown_names(monkeypatch):
241 |     from graphiti_core.utils.maintenance import edge_operations as edge_ops
242 | 
243 |     monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
244 |     monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
245 | 
246 |     async def immediate_gather(*aws, max_coroutines=None):
247 |         return [await aw for aw in aws]
248 | 
249 |     monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather)
250 |     monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))
251 | 
252 |     llm_client = MagicMock()
253 |     llm_client.generate_response = AsyncMock(
254 |         return_value={
255 |             'duplicate_facts': [],
256 |             'contradicted_facts': [],
257 |             'fact_type': 'DEFAULT',
258 |         }
259 |     )
260 | 
261 |     clients = SimpleNamespace(
262 |         driver=MagicMock(),
263 |         llm_client=llm_client,
264 |         embedder=MagicMock(),
265 |         cross_encoder=MagicMock(),
266 |     )
267 | 
268 |     source_node = EntityNode(
269 |         uuid='source_uuid',
270 |         name='User Node',
271 |         group_id='group_1',
272 |         labels=['User'],
273 |     )
274 |     target_node = EntityNode(
275 |         uuid='target_uuid',
276 |         name='Topic Node',
277 |         group_id='group_1',
278 |         labels=['Topic'],
279 |     )
280 | 
281 |     extracted_edge = EntityEdge(
282 |         source_node_uuid=source_node.uuid,
283 |         target_node_uuid=target_node.uuid,
284 |         name='INTERACTED_WITH',
285 |         group_id='group_1',
286 |         fact='User interacted with topic',
287 |         episodes=[],
288 |         created_at=datetime.now(timezone.utc),
289 |         valid_at=None,
290 |         invalid_at=None,
291 |     )
292 | 
293 |     episode = EpisodicNode(
294 |         uuid='episode_uuid',
295 |         name='Episode',
296 |         group_id='group_1',
297 |         source='message',
298 |         source_description='desc',
299 |         content='Episode content',
300 |         valid_at=datetime.now(timezone.utc),
301 |     )
302 | 
303 |     edge_types = {'OCCURRED_AT': OccurredAtEdge}
304 |     edge_type_map = {('Event', 'Entity'): ['OCCURRED_AT']}
305 | 
306 |     resolved_edges, invalidated_edges = await resolve_extracted_edges(
307 |         clients,
308 |         [extracted_edge],
309 |         episode,
310 |         [source_node, target_node],
311 |         edge_types,
312 |         edge_type_map,
313 |     )
314 | 
315 |     assert resolved_edges[0].name == 'INTERACTED_WITH'
316 |     assert invalidated_edges == []
317 | 
318 | 
319 | @pytest.mark.asyncio
320 | async def test_resolve_extracted_edge_rejects_unmapped_fact_type(mock_llm_client):
321 |     mock_llm_client.generate_response.return_value = {
322 |         'duplicate_facts': [],
323 |         'contradicted_facts': [],
324 |         'fact_type': 'OCCURRED_AT',
325 |     }
326 | 
327 |     extracted_edge = EntityEdge(
328 |         source_node_uuid='source_uuid',
329 |         target_node_uuid='target_uuid',
330 |         name='OCCURRED_AT',
331 |         group_id='group_1',
332 |         fact='Document occurred at somewhere',
333 |         episodes=[],
334 |         created_at=datetime.now(timezone.utc),
335 |         valid_at=None,
336 |         invalid_at=None,
337 |     )
338 | 
339 |     episode = EpisodicNode(
340 |         uuid='episode_uuid',
341 |         name='Episode',
342 |         group_id='group_1',
343 |         source='message',
344 |         source_description='desc',
345 |         content='Episode content',
346 |         valid_at=datetime.now(timezone.utc),
347 |     )
348 | 
349 |     related_edge = EntityEdge(
350 |         source_node_uuid='alt_source',
351 |         target_node_uuid='alt_target',
352 |         name='OTHER',
353 |         group_id='group_1',
354 |         fact='Different fact',
355 |         episodes=[],
356 |         created_at=datetime.now(timezone.utc),
357 |         valid_at=None,
358 |         invalid_at=None,
359 |     )
360 | 
361 |     resolved_edge, duplicates, invalidated = await resolve_extracted_edge(
362 |         mock_llm_client,
363 |         extracted_edge,
364 |         [related_edge],
365 |         [],
366 |         episode,
367 |         edge_type_candidates={},
368 |         custom_edge_type_names={'OCCURRED_AT'},
369 |     )
370 | 
371 |     assert resolved_edge.name == DEFAULT_EDGE_NAME
372 |     assert duplicates == []
373 |     assert invalidated == []
374 | 
375 | 
376 | @pytest.mark.asyncio
377 | async def test_resolve_extracted_edge_accepts_unknown_fact_type(mock_llm_client):
378 |     mock_llm_client.generate_response.return_value = {
379 |         'duplicate_facts': [],
380 |         'contradicted_facts': [],
381 |         'fact_type': 'INTERACTED_WITH',
382 |     }
383 | 
384 |     extracted_edge = EntityEdge(
385 |         source_node_uuid='source_uuid',
386 |         target_node_uuid='target_uuid',
387 |         name='DEFAULT',
388 |         group_id='group_1',
389 |         fact='User interacted with topic',
390 |         episodes=[],
391 |         created_at=datetime.now(timezone.utc),
392 |         valid_at=None,
393 |         invalid_at=None,
394 |     )
395 | 
396 |     episode = EpisodicNode(
397 |         uuid='episode_uuid',
398 |         name='Episode',
399 |         group_id='group_1',
400 |         source='message',
401 |         source_description='desc',
402 |         content='Episode content',
403 |         valid_at=datetime.now(timezone.utc),
404 |     )
405 | 
406 |     related_edge = EntityEdge(
407 |         source_node_uuid='source_uuid',
408 |         target_node_uuid='target_uuid',
409 |         name='DEFAULT',
410 |         group_id='group_1',
411 |         fact='User mentioned a topic',
412 |         episodes=[],
413 |         created_at=datetime.now(timezone.utc),
414 |         valid_at=None,
415 |         invalid_at=None,
416 |     )
417 | 
418 |     resolved_edge, duplicates, invalidated = await resolve_extracted_edge(
419 |         mock_llm_client,
420 |         extracted_edge,
421 |         [related_edge],
422 |         [],
423 |         episode,
424 |         edge_type_candidates={'OCCURRED_AT': OccurredAtEdge},
425 |         custom_edge_type_names={'OCCURRED_AT'},
426 |     )
427 | 
428 |     assert resolved_edge.name == 'INTERACTED_WITH'
429 |     assert resolved_edge.attributes == {}
430 |     assert duplicates == []
431 |     assert invalidated == []
432 | 
433 | 
434 | @pytest.mark.asyncio
435 | async def test_resolve_extracted_edge_uses_integer_indices_for_duplicates(mock_llm_client):
436 |     """Test that resolve_extracted_edge correctly uses integer indices for LLM duplicate detection."""
437 |     # Mock LLM to return duplicate_facts with integer indices
438 |     mock_llm_client.generate_response.return_value = {
439 |         'duplicate_facts': [0, 1],  # LLM identifies first two related edges as duplicates
440 |         'contradicted_facts': [],
441 |         'fact_type': 'DEFAULT',
442 |     }
443 | 
444 |     extracted_edge = EntityEdge(
445 |         source_node_uuid='source_uuid',
446 |         target_node_uuid='target_uuid',
447 |         name='test_edge',
448 |         group_id='group_1',
449 |         fact='User likes yoga',
450 |         episodes=[],
451 |         created_at=datetime.now(timezone.utc),
452 |         valid_at=None,
453 |         invalid_at=None,
454 |     )
455 | 
456 |     episode = EpisodicNode(
457 |         uuid='episode_uuid',
458 |         name='Episode',
459 |         group_id='group_1',
460 |         source='message',
461 |         source_description='desc',
462 |         content='Episode content',
463 |         valid_at=datetime.now(timezone.utc),
464 |     )
465 | 
466 |     # Create multiple related edges - LLM should receive these with integer indices
467 |     related_edge_0 = EntityEdge(
468 |         source_node_uuid='source_uuid',
469 |         target_node_uuid='target_uuid',
470 |         name='test_edge',
471 |         group_id='group_1',
472 |         fact='User enjoys yoga',
473 |         episodes=['episode_1'],
474 |         created_at=datetime.now(timezone.utc) - timedelta(days=1),
475 |         valid_at=None,
476 |         invalid_at=None,
477 |     )
478 | 
479 |     related_edge_1 = EntityEdge(
480 |         source_node_uuid='source_uuid',
481 |         target_node_uuid='target_uuid',
482 |         name='test_edge',
483 |         group_id='group_1',
484 |         fact='User practices yoga',
485 |         episodes=['episode_2'],
486 |         created_at=datetime.now(timezone.utc) - timedelta(days=2),
487 |         valid_at=None,
488 |         invalid_at=None,
489 |     )
490 | 
491 |     related_edge_2 = EntityEdge(
492 |         source_node_uuid='source_uuid',
493 |         target_node_uuid='target_uuid',
494 |         name='test_edge',
495 |         group_id='group_1',
496 |         fact='User loves swimming',
497 |         episodes=['episode_3'],
498 |         created_at=datetime.now(timezone.utc) - timedelta(days=3),
499 |         valid_at=None,
500 |         invalid_at=None,
501 |     )
502 | 
503 |     related_edges = [related_edge_0, related_edge_1, related_edge_2]
504 | 
505 |     resolved_edge, invalidated, duplicates = await resolve_extracted_edge(
506 |         mock_llm_client,
507 |         extracted_edge,
508 |         related_edges,
509 |         [],
510 |         episode,
511 |         edge_type_candidates=None,
512 |         custom_edge_type_names=set(),
513 |     )
514 | 
515 |     # Verify LLM was called
516 |     mock_llm_client.generate_response.assert_called_once()
517 | 
518 |     # Verify the system correctly identified duplicates using integer indices
519 |     # The LLM returned [0, 1], so related_edge_0 and related_edge_1 should be marked as duplicates
520 |     assert len(duplicates) == 2
521 |     assert related_edge_0 in duplicates
522 |     assert related_edge_1 in duplicates
523 |     assert invalidated == []
524 | 
525 |     # Verify that the resolved edge is one of the duplicates (the first one found)
526 |     # Check UUID since the episode list gets modified
527 |     assert resolved_edge.uuid == related_edge_0.uuid
528 |     assert episode.uuid in resolved_edge.episodes
529 | 
530 | 
531 | @pytest.mark.asyncio
532 | async def test_resolve_extracted_edges_fast_path_deduplication(monkeypatch):
533 |     """Test that resolve_extracted_edges deduplicates exact matches before parallel processing."""
534 |     from graphiti_core.utils.maintenance import edge_operations as edge_ops
535 | 
536 |     monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
537 |     monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
538 | 
539 |     # Track how many times resolve_extracted_edge is called
540 |     resolve_call_count = 0
541 | 
542 |     async def mock_resolve_extracted_edge(
543 |         llm_client,
544 |         extracted_edge,
545 |         related_edges,
546 |         existing_edges,
547 |         episode,
548 |         edge_type_candidates=None,
549 |         custom_edge_type_names=None,
550 |     ):
551 |         nonlocal resolve_call_count
552 |         resolve_call_count += 1
553 |         return extracted_edge, [], []
554 | 
555 |     # Mock semaphore_gather to execute awaitables immediately
556 |     async def immediate_gather(*aws, max_coroutines=None):
557 |         results = []
558 |         for aw in aws:
559 |             results.append(await aw)
560 |         return results
561 | 
562 |     monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather)
563 |     monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))
564 |     monkeypatch.setattr(edge_ops, 'resolve_extracted_edge', mock_resolve_extracted_edge)
565 | 
566 |     llm_client = MagicMock()
567 |     clients = SimpleNamespace(
568 |         driver=MagicMock(),
569 |         llm_client=llm_client,
570 |         embedder=MagicMock(),
571 |         cross_encoder=MagicMock(),
572 |     )
573 | 
574 |     source_node = EntityNode(
575 |         uuid='source_uuid',
576 |         name='Assistant',
577 |         group_id='group_1',
578 |         labels=['Entity'],
579 |     )
580 |     target_node = EntityNode(
581 |         uuid='target_uuid',
582 |         name='User',
583 |         group_id='group_1',
584 |         labels=['Entity'],
585 |     )
586 | 
587 |     # Create 3 identical edges
588 |     edge1 = EntityEdge(
589 |         source_node_uuid=source_node.uuid,
590 |         target_node_uuid=target_node.uuid,
591 |         name='recommends',
592 |         group_id='group_1',
593 |         fact='assistant recommends yoga poses',
594 |         episodes=[],
595 |         created_at=datetime.now(timezone.utc),
596 |         valid_at=None,
597 |         invalid_at=None,
598 |     )
599 | 
600 |     edge2 = EntityEdge(
601 |         source_node_uuid=source_node.uuid,
602 |         target_node_uuid=target_node.uuid,
603 |         name='recommends',
604 |         group_id='group_1',
605 |         fact='  Assistant Recommends YOGA Poses  ',  # Different whitespace/case
606 |         episodes=[],
607 |         created_at=datetime.now(timezone.utc),
608 |         valid_at=None,
609 |         invalid_at=None,
610 |     )
611 | 
612 |     edge3 = EntityEdge(
613 |         source_node_uuid=source_node.uuid,
614 |         target_node_uuid=target_node.uuid,
615 |         name='recommends',
616 |         group_id='group_1',
617 |         fact='assistant recommends yoga poses',
618 |         episodes=[],
619 |         created_at=datetime.now(timezone.utc),
620 |         valid_at=None,
621 |         invalid_at=None,
622 |     )
623 | 
624 |     episode = EpisodicNode(
625 |         uuid='episode_uuid',
626 |         name='Episode',
627 |         group_id='group_1',
628 |         source='message',
629 |         source_description='desc',
630 |         content='Episode content',
631 |         valid_at=datetime.now(timezone.utc),
632 |     )
633 | 
634 |     resolved_edges, invalidated_edges = await resolve_extracted_edges(
635 |         clients,
636 |         [edge1, edge2, edge3],
637 |         episode,
638 |         [source_node, target_node],
639 |         {},
640 |         {},
641 |     )
642 | 
643 |     # Fast path should have deduplicated the 3 identical edges to 1
644 |     # So resolve_extracted_edge should only be called once
645 |     assert resolve_call_count == 1
646 |     assert len(resolved_edges) == 1
647 |     assert invalidated_edges == []
648 | 
```
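
The fast-path test above expects `resolve_extracted_edges` to collapse edges whose facts differ only by case and whitespace before any LLM calls are made. Below is a minimal, standalone sketch of that idea; the helper names and the exact normalization rule (lowercasing plus whitespace collapsing) are assumptions for illustration, not the library's actual helper.

```python
from collections import OrderedDict


def _edge_key(source_uuid: str, target_uuid: str, fact: str) -> tuple[str, str, str]:
    # Assumed normalization: lowercase and collapse runs of whitespace, so
    # '  Assistant Recommends YOGA Poses  ' and 'assistant recommends yoga poses'
    # produce the same key.
    return (source_uuid, target_uuid, ' '.join(fact.lower().split()))


def fast_path_dedupe(edges: list) -> list:
    """Keep only the first edge seen for each (source, target, normalized fact) triple."""
    unique: OrderedDict[tuple[str, str, str], object] = OrderedDict()
    for edge in edges:
        unique.setdefault(_edge_key(edge.source_node_uuid, edge.target_node_uuid, edge.fact), edge)
    return list(unique.values())
```

Applied to the three edges constructed in the test, a pass like this would leave a single edge, which is consistent with the expectation that `resolve_extracted_edge` is invoked exactly once.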

--------------------------------------------------------------------------------
/graphiti_core/utils/bulk_utils.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import json
 18 | import logging
 19 | import typing
 20 | from datetime import datetime
 21 | 
 22 | import numpy as np
 23 | from pydantic import BaseModel, Field
 24 | from typing_extensions import Any
 25 | 
 26 | from graphiti_core.driver.driver import (
 27 |     GraphDriver,
 28 |     GraphDriverSession,
 29 |     GraphProvider,
 30 | )
 31 | from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
 32 | from graphiti_core.embedder import EmbedderClient
 33 | from graphiti_core.graphiti_types import GraphitiClients
 34 | from graphiti_core.helpers import normalize_l2, semaphore_gather
 35 | from graphiti_core.models.edges.edge_db_queries import (
 36 |     get_entity_edge_save_bulk_query,
 37 |     get_episodic_edge_save_bulk_query,
 38 | )
 39 | from graphiti_core.models.nodes.node_db_queries import (
 40 |     get_entity_node_save_bulk_query,
 41 |     get_episode_node_save_bulk_query,
 42 | )
 43 | from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
 44 | from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
 45 | from graphiti_core.utils.maintenance.dedup_helpers import (
 46 |     DedupResolutionState,
 47 |     _build_candidate_indexes,
 48 |     _normalize_string_exact,
 49 |     _resolve_with_similarity,
 50 | )
 51 | from graphiti_core.utils.maintenance.edge_operations import (
 52 |     extract_edges,
 53 |     resolve_extracted_edge,
 54 | )
 55 | from graphiti_core.utils.maintenance.graph_data_operations import (
 56 |     EPISODE_WINDOW_LEN,
 57 |     retrieve_episodes,
 58 | )
 59 | from graphiti_core.utils.maintenance.node_operations import (
 60 |     extract_nodes,
 61 |     resolve_extracted_nodes,
 62 | )
 63 | 
 64 | logger = logging.getLogger(__name__)
 65 | 
 66 | CHUNK_SIZE = 10
 67 | 
 68 | 
 69 | def _build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]:
 70 |     """Collapse alias -> canonical chains while preserving direction.
 71 | 
 72 |     The incoming pairs represent directed mappings discovered during node dedupe. We use a simple
 73 |     union-find with iterative path compression to ensure every source UUID resolves to its ultimate
 74 |     canonical target, even if aliases appear lexicographically smaller than the canonical UUID.
 75 |     """
 76 | 
 77 |     parent: dict[str, str] = {}
 78 | 
 79 |     def find(uuid: str) -> str:
 80 |         """Directed union-find lookup using iterative path compression."""
 81 |         parent.setdefault(uuid, uuid)
 82 |         root = uuid
 83 |         while parent[root] != root:
 84 |             root = parent[root]
 85 | 
 86 |         while parent[uuid] != root:
 87 |             next_uuid = parent[uuid]
 88 |             parent[uuid] = root
 89 |             uuid = next_uuid
 90 | 
 91 |         return root
 92 | 
 93 |     for source_uuid, target_uuid in pairs:
 94 |         parent.setdefault(source_uuid, source_uuid)
 95 |         parent.setdefault(target_uuid, target_uuid)
 96 |         parent[find(source_uuid)] = find(target_uuid)
 97 | 
 98 |     return {uuid: find(uuid) for uuid in parent}
 99 | 
100 | 
101 | class RawEpisode(BaseModel):
102 |     name: str
103 |     uuid: str | None = Field(default=None)
104 |     content: str
105 |     source_description: str
106 |     source: EpisodeType
107 |     reference_time: datetime
108 | 
109 | 
110 | async def retrieve_previous_episodes_bulk(
111 |     driver: GraphDriver, episodes: list[EpisodicNode]
112 | ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
113 |     previous_episodes_list = await semaphore_gather(
114 |         *[
115 |             retrieve_episodes(
116 |                 driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id]
117 |             )
118 |             for episode in episodes
119 |         ]
120 |     )
121 |     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] = [
122 |         (episode, previous_episodes_list[i]) for i, episode in enumerate(episodes)
123 |     ]
124 | 
125 |     return episode_tuples
126 | 
127 | 
128 | async def add_nodes_and_edges_bulk(
129 |     driver: GraphDriver,
130 |     episodic_nodes: list[EpisodicNode],
131 |     episodic_edges: list[EpisodicEdge],
132 |     entity_nodes: list[EntityNode],
133 |     entity_edges: list[EntityEdge],
134 |     embedder: EmbedderClient,
135 | ):
136 |     session = driver.session()
137 |     try:
138 |         await session.execute_write(
139 |             add_nodes_and_edges_bulk_tx,
140 |             episodic_nodes,
141 |             episodic_edges,
142 |             entity_nodes,
143 |             entity_edges,
144 |             embedder,
145 |             driver=driver,
146 |         )
147 |     finally:
148 |         await session.close()
149 | 
150 | 
151 | async def add_nodes_and_edges_bulk_tx(
152 |     tx: GraphDriverSession,
153 |     episodic_nodes: list[EpisodicNode],
154 |     episodic_edges: list[EpisodicEdge],
155 |     entity_nodes: list[EntityNode],
156 |     entity_edges: list[EntityEdge],
157 |     embedder: EmbedderClient,
158 |     driver: GraphDriver,
159 | ):
160 |     episodes = [dict(episode) for episode in episodic_nodes]
161 |     for episode in episodes:
162 |         episode['source'] = str(episode['source'].value)
163 |         episode.pop('labels', None)
164 | 
165 |     nodes = []
166 | 
167 |     for node in entity_nodes:
168 |         if node.name_embedding is None:
169 |             await node.generate_name_embedding(embedder)
170 | 
171 |         entity_data: dict[str, Any] = {
172 |             'uuid': node.uuid,
173 |             'name': node.name,
174 |             'group_id': node.group_id,
175 |             'summary': node.summary,
176 |             'created_at': node.created_at,
177 |             'name_embedding': node.name_embedding,
178 |             'labels': list(set(node.labels + ['Entity'])),
179 |         }
180 | 
181 |         if driver.provider == GraphProvider.KUZU:
182 |             attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
183 |             entity_data['attributes'] = json.dumps(attributes)
184 |         else:
185 |             entity_data.update(node.attributes or {})
186 | 
187 |         nodes.append(entity_data)
188 | 
189 |     edges = []
190 |     for edge in entity_edges:
191 |         if edge.fact_embedding is None:
192 |             await edge.generate_embedding(embedder)
193 |         edge_data: dict[str, Any] = {
194 |             'uuid': edge.uuid,
195 |             'source_node_uuid': edge.source_node_uuid,
196 |             'target_node_uuid': edge.target_node_uuid,
197 |             'name': edge.name,
198 |             'fact': edge.fact,
199 |             'group_id': edge.group_id,
200 |             'episodes': edge.episodes,
201 |             'created_at': edge.created_at,
202 |             'expired_at': edge.expired_at,
203 |             'valid_at': edge.valid_at,
204 |             'invalid_at': edge.invalid_at,
205 |             'fact_embedding': edge.fact_embedding,
206 |         }
207 | 
208 |         if driver.provider == GraphProvider.KUZU:
209 |             attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
210 |             edge_data['attributes'] = json.dumps(attributes)
211 |         else:
212 |             edge_data.update(edge.attributes or {})
213 | 
214 |         edges.append(edge_data)
215 | 
216 |     if driver.graph_operations_interface:
217 |         await driver.graph_operations_interface.episodic_node_save_bulk(None, driver, tx, episodes)
218 |         await driver.graph_operations_interface.node_save_bulk(None, driver, tx, nodes)
219 |         await driver.graph_operations_interface.episodic_edge_save_bulk(
220 |             None, driver, tx, [edge.model_dump() for edge in episodic_edges]
221 |         )
222 |         await driver.graph_operations_interface.edge_save_bulk(None, driver, tx, edges)
223 | 
224 |     elif driver.provider == GraphProvider.KUZU:
225 |         # FIXME: Kuzu's UNWIND does not currently support STRUCT[] type properly, so we insert the data one by one instead for now.
226 |         episode_query = get_episode_node_save_bulk_query(driver.provider)
227 |         for episode in episodes:
228 |             await tx.run(episode_query, **episode)
229 |         entity_node_query = get_entity_node_save_bulk_query(driver.provider, nodes)
230 |         for node in nodes:
231 |             await tx.run(entity_node_query, **node)
232 |         entity_edge_query = get_entity_edge_save_bulk_query(driver.provider)
233 |         for edge in edges:
234 |             await tx.run(entity_edge_query, **edge)
235 |         episodic_edge_query = get_episodic_edge_save_bulk_query(driver.provider)
236 |         for edge in episodic_edges:
237 |             await tx.run(episodic_edge_query, **edge.model_dump())
238 |     else:
239 |         await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
240 |         await tx.run(
241 |             get_entity_node_save_bulk_query(driver.provider, nodes),
242 |             nodes=nodes,
243 |         )
244 |         await tx.run(
245 |             get_episodic_edge_save_bulk_query(driver.provider),
246 |             episodic_edges=[edge.model_dump() for edge in episodic_edges],
247 |         )
248 |         await tx.run(
249 |             get_entity_edge_save_bulk_query(driver.provider),
250 |             entity_edges=edges,
251 |         )
252 | 
253 | 
254 | async def extract_nodes_and_edges_bulk(
255 |     clients: GraphitiClients,
256 |     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
257 |     edge_type_map: dict[tuple[str, str], list[str]],
258 |     entity_types: dict[str, type[BaseModel]] | None = None,
259 |     excluded_entity_types: list[str] | None = None,
260 |     edge_types: dict[str, type[BaseModel]] | None = None,
261 | ) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]:
262 |     extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather(
263 |         *[
264 |             extract_nodes(clients, episode, previous_episodes, entity_types, excluded_entity_types)
265 |             for episode, previous_episodes in episode_tuples
266 |         ]
267 |     )
268 | 
269 |     extracted_edges_bulk: list[list[EntityEdge]] = await semaphore_gather(
270 |         *[
271 |             extract_edges(
272 |                 clients,
273 |                 episode,
274 |                 extracted_nodes_bulk[i],
275 |                 previous_episodes,
276 |                 edge_type_map=edge_type_map,
277 |                 group_id=episode.group_id,
278 |                 edge_types=edge_types,
279 |             )
280 |             for i, (episode, previous_episodes) in enumerate(episode_tuples)
281 |         ]
282 |     )
283 | 
284 |     return extracted_nodes_bulk, extracted_edges_bulk
285 | 
286 | 
287 | async def dedupe_nodes_bulk(
288 |     clients: GraphitiClients,
289 |     extracted_nodes: list[list[EntityNode]],
290 |     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
291 |     entity_types: dict[str, type[BaseModel]] | None = None,
292 | ) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
293 |     """Resolve entity duplicates across an in-memory batch using a two-pass strategy.
294 | 
295 |     1. Run :func:`resolve_extracted_nodes` for every episode in parallel so each batch item is
296 |        reconciled against the live graph just like the non-batch flow.
297 |     2. Re-run the deterministic similarity heuristics across the union of resolved nodes to catch
298 |        duplicates that only co-occur inside this batch, emitting a canonical UUID map that callers
299 |        can apply to edges and persistence.
300 |     """
301 | 
302 |     first_pass_results = await semaphore_gather(
303 |         *[
304 |             resolve_extracted_nodes(
305 |                 clients,
306 |                 nodes,
307 |                 episode_tuples[i][0],
308 |                 episode_tuples[i][1],
309 |                 entity_types,
310 |             )
311 |             for i, nodes in enumerate(extracted_nodes)
312 |         ]
313 |     )
314 | 
315 |     episode_resolutions: list[tuple[str, list[EntityNode]]] = []
316 |     per_episode_uuid_maps: list[dict[str, str]] = []
317 |     duplicate_pairs: list[tuple[str, str]] = []
318 | 
319 |     for (resolved_nodes, uuid_map, duplicates), (episode, _) in zip(
320 |         first_pass_results, episode_tuples, strict=True
321 |     ):
322 |         episode_resolutions.append((episode.uuid, resolved_nodes))
323 |         per_episode_uuid_maps.append(uuid_map)
324 |         duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)
325 | 
326 |     canonical_nodes: dict[str, EntityNode] = {}
327 |     for _, resolved_nodes in episode_resolutions:
328 |         for node in resolved_nodes:
329 |             # NOTE: this loop is O(n^2) in the number of nodes inside the batch because we rebuild
330 |             # the MinHash index for the accumulated canonical pool each time. The LRU-backed
331 |             # shingle cache keeps the constant factors low for typical batch sizes (≤ CHUNK_SIZE),
332 |             # but if batches grow significantly we should switch to an incremental index or chunked
333 |             # processing.
334 |             if not canonical_nodes:
335 |                 canonical_nodes[node.uuid] = node
336 |                 continue
337 | 
338 |             existing_candidates = list(canonical_nodes.values())
339 |             normalized = _normalize_string_exact(node.name)
340 |             exact_match = next(
341 |                 (
342 |                     candidate
343 |                     for candidate in existing_candidates
344 |                     if _normalize_string_exact(candidate.name) == normalized
345 |                 ),
346 |                 None,
347 |             )
348 |             if exact_match is not None:
349 |                 if exact_match.uuid != node.uuid:
350 |                     duplicate_pairs.append((node.uuid, exact_match.uuid))
351 |                 continue
352 | 
353 |             indexes = _build_candidate_indexes(existing_candidates)
354 |             state = DedupResolutionState(
355 |                 resolved_nodes=[None],
356 |                 uuid_map={},
357 |                 unresolved_indices=[],
358 |             )
359 |             _resolve_with_similarity([node], indexes, state)
360 | 
361 |             resolved = state.resolved_nodes[0]
362 |             if resolved is None:
363 |                 canonical_nodes[node.uuid] = node
364 |                 continue
365 | 
366 |             canonical_uuid = resolved.uuid
367 |             canonical_nodes.setdefault(canonical_uuid, resolved)
368 |             if canonical_uuid != node.uuid:
369 |                 duplicate_pairs.append((node.uuid, canonical_uuid))
370 | 
371 |     union_pairs: list[tuple[str, str]] = []
372 |     for uuid_map in per_episode_uuid_maps:
373 |         union_pairs.extend(uuid_map.items())
374 |     union_pairs.extend(duplicate_pairs)
375 | 
376 |     compressed_map: dict[str, str] = _build_directed_uuid_map(union_pairs)
377 | 
378 |     nodes_by_episode: dict[str, list[EntityNode]] = {}
379 |     for episode_uuid, resolved_nodes in episode_resolutions:
380 |         deduped_nodes: list[EntityNode] = []
381 |         seen: set[str] = set()
382 |         for node in resolved_nodes:
383 |             canonical_uuid = compressed_map.get(node.uuid, node.uuid)
384 |             if canonical_uuid in seen:
385 |                 continue
386 |             seen.add(canonical_uuid)
387 |             canonical_node = canonical_nodes.get(canonical_uuid)
388 |             if canonical_node is None:
389 |                 logger.error(
390 |                     'Canonical node %s missing during batch dedupe; falling back to %s',
391 |                     canonical_uuid,
392 |                     node.uuid,
393 |                 )
394 |                 canonical_node = node
395 |             deduped_nodes.append(canonical_node)
396 | 
397 |         nodes_by_episode[episode_uuid] = deduped_nodes
398 | 
399 |     return nodes_by_episode, compressed_map
400 | 
401 | 
402 | async def dedupe_edges_bulk(
403 |     clients: GraphitiClients,
404 |     extracted_edges: list[list[EntityEdge]],
405 |     episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
406 |     _entities: list[EntityNode],
407 |     edge_types: dict[str, type[BaseModel]],
408 |     _edge_type_map: dict[tuple[str, str], list[str]],
409 | ) -> dict[str, list[EntityEdge]]:
410 |     embedder = clients.embedder
411 |     min_score = 0.6
412 | 
413 |     # generate embeddings
414 |     await semaphore_gather(
415 |         *[create_entity_edge_embeddings(embedder, edges) for edges in extracted_edges]
416 |     )
417 | 
418 |     # Find similar results
419 |     dedupe_tuples: list[tuple[EpisodicNode, EntityEdge, list[EntityEdge]]] = []
420 |     for i, edges_i in enumerate(extracted_edges):
421 |         existing_edges: list[EntityEdge] = []
422 |         for edges_j in extracted_edges:
423 |             existing_edges += edges_j
424 | 
425 |         for edge in edges_i:
426 |             candidates: list[EntityEdge] = []
427 |             for existing_edge in existing_edges:
428 |                 # Skip self-comparison
429 |                 if edge.uuid == existing_edge.uuid:
430 |                     continue
431 |                 # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
432 |                 # This approach will cast a wider net than BM25, which is ideal for this use case
433 |                 if (
434 |                     edge.source_node_uuid != existing_edge.source_node_uuid
435 |                     or edge.target_node_uuid != existing_edge.target_node_uuid
436 |                 ):
437 |                     continue
438 | 
439 |                 edge_words = set(edge.fact.lower().split())
440 |                 existing_edge_words = set(existing_edge.fact.lower().split())
441 |                 has_overlap = not edge_words.isdisjoint(existing_edge_words)
442 |                 if has_overlap:
443 |                     candidates.append(existing_edge)
444 |                     continue
445 | 
446 |                 # Check for semantic similarity even if there is no overlap
447 |                 similarity = np.dot(
448 |                     normalize_l2(edge.fact_embedding or []),
449 |                     normalize_l2(existing_edge.fact_embedding or []),
450 |                 )
451 |                 if similarity >= min_score:
452 |                     candidates.append(existing_edge)
453 | 
454 |             dedupe_tuples.append((episode_tuples[i][0], edge, candidates))
455 | 
456 |     bulk_edge_resolutions: list[
457 |         tuple[EntityEdge, EntityEdge, list[EntityEdge]]
458 |     ] = await semaphore_gather(
459 |         *[
460 |             resolve_extracted_edge(
461 |                 clients.llm_client,
462 |                 edge,
463 |                 candidates,
464 |                 candidates,
465 |                 episode,
466 |                 edge_types,
467 |                 set(edge_types),
468 |             )
469 |             for episode, edge, candidates in dedupe_tuples
470 |         ]
471 |     )
472 | 
473 |     # For now we won't track edge invalidation
474 |     duplicate_pairs: list[tuple[str, str]] = []
475 |     for i, (_, _, duplicates) in enumerate(bulk_edge_resolutions):
476 |         episode, edge, candidates = dedupe_tuples[i]
477 |         for duplicate in duplicates:
478 |             duplicate_pairs.append((edge.uuid, duplicate.uuid))
479 | 
480 |     # Now we compress the duplicate_map, so that 3 -> 2 and 2 -> 1 becomes 3 -> 1 (sorted by uuid)
481 |     compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
482 | 
483 |     edge_uuid_map: dict[str, EntityEdge] = {
484 |         edge.uuid: edge for edges in extracted_edges for edge in edges
485 |     }
486 | 
487 |     edges_by_episode: dict[str, list[EntityEdge]] = {}
488 |     for i, edges in enumerate(extracted_edges):
489 |         episode = episode_tuples[i][0]
490 | 
491 |         edges_by_episode[episode.uuid] = [
492 |             edge_uuid_map[compressed_map.get(edge.uuid, edge.uuid)] for edge in edges
493 |         ]
494 | 
495 |     return edges_by_episode
496 | 
497 | 
498 | class UnionFind:
499 |     def __init__(self, elements):
500 |         # start each element in its own set
501 |         self.parent = {e: e for e in elements}
502 | 
503 |     def find(self, x):
504 |         # path‐compression
505 |         if self.parent[x] != x:
506 |             self.parent[x] = self.find(self.parent[x])
507 |         return self.parent[x]
508 | 
509 |     def union(self, a, b):
510 |         ra, rb = self.find(a), self.find(b)
511 |         if ra == rb:
512 |             return
513 |         # attach the lexicographically larger root under the smaller
514 |         if ra < rb:
515 |             self.parent[rb] = ra
516 |         else:
517 |             self.parent[ra] = rb
518 | 
519 | 
520 | def compress_uuid_map(duplicate_pairs: list[tuple[str, str]]) -> dict[str, str]:
521 |     """
522 |     Collapse duplicate sets into a canonical mapping (the set of ids is derived from the pairs).
523 |     duplicate_pairs: iterable of (id1, id2) pairs
524 |     returns: dict mapping each id -> lexicographically smallest id in its duplicate set
525 |     """
526 |     all_uuids = set()
527 |     for pair in duplicate_pairs:
528 |         all_uuids.add(pair[0])
529 |         all_uuids.add(pair[1])
530 | 
531 |     uf = UnionFind(all_uuids)
532 |     for a, b in duplicate_pairs:
533 |         uf.union(a, b)
534 |     # ensure full path‐compression before mapping
535 |     return {uuid: uf.find(uuid) for uuid in all_uuids}
536 | 
537 | 
538 | E = typing.TypeVar('E', bound=Edge)
539 | 
540 | 
541 | def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
542 |     for edge in edges:
543 |         source_uuid = edge.source_node_uuid
544 |         target_uuid = edge.target_node_uuid
545 |         edge.source_node_uuid = uuid_map.get(source_uuid, source_uuid)
546 |         edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)
547 | 
548 |     return edges
549 | 
```
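
Note that the two UUID-map helpers in this module collapse chains differently: `compress_uuid_map` treats pairs as undirected and canonicalizes every member of a duplicate set to its lexicographically smallest UUID, while `_build_directed_uuid_map` follows the direction of each pair to its ultimate target. A small usage sketch, assuming `graphiti_core` is importable (`_build_directed_uuid_map` is a private helper, so that import is shown for illustration only):

```python
from graphiti_core.utils.bulk_utils import _build_directed_uuid_map, compress_uuid_map

# Undirected union-find: every uuid in a duplicate set maps to the
# lexicographically smallest member of that set.
assert compress_uuid_map([('c', 'b'), ('b', 'a')]) == {'a': 'a', 'b': 'a', 'c': 'a'}

# Directed chains are preserved: aliases resolve to their ultimate canonical
# target even when that target sorts higher than the alias.
assert _build_directed_uuid_map([('a', 'b'), ('b', 'c')]) == {'a': 'c', 'b': 'c', 'c': 'c'}
```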

--------------------------------------------------------------------------------
/mcp_server/tests/test_stress_load.py:
--------------------------------------------------------------------------------

```python
  1 | #!/usr/bin/env python3
  2 | """
  3 | Stress and load testing for Graphiti MCP Server.
  4 | Tests system behavior under high load, resource constraints, and edge conditions.
  5 | """
  6 | 
  7 | import asyncio
  8 | import gc
  9 | import random
 10 | import time
 11 | from dataclasses import dataclass
 12 | 
 13 | import psutil
 14 | import pytest
 15 | from test_fixtures import TestDataGenerator, graphiti_test_client
 16 | 
 17 | 
 18 | @dataclass
 19 | class LoadTestConfig:
 20 |     """Configuration for load testing scenarios."""
 21 | 
 22 |     num_clients: int = 10
 23 |     operations_per_client: int = 100
 24 |     ramp_up_time: float = 5.0  # seconds
 25 |     test_duration: float = 60.0  # seconds
 26 |     target_throughput: float | None = None  # ops/sec
 27 |     think_time: float = 0.1  # seconds between ops
 28 | 
 29 | 
 30 | @dataclass
 31 | class LoadTestResult:
 32 |     """Results from a load test run."""
 33 | 
 34 |     total_operations: int
 35 |     successful_operations: int
 36 |     failed_operations: int
 37 |     duration: float
 38 |     throughput: float
 39 |     average_latency: float
 40 |     p50_latency: float
 41 |     p95_latency: float
 42 |     p99_latency: float
 43 |     max_latency: float
 44 |     errors: dict[str, int]
 45 |     resource_usage: dict[str, float]
 46 | 
 47 | 
 48 | class LoadTester:
 49 |     """Orchestrate load testing scenarios."""
 50 | 
 51 |     def __init__(self, config: LoadTestConfig):
 52 |         self.config = config
 53 |         self.metrics: list[tuple[float, float, bool]] = []  # (start, duration, success)
 54 |         self.errors: dict[str, int] = {}
 55 |         self.start_time: float | None = None
 56 | 
 57 |     async def run_client_workload(self, client_id: int, session, group_id: str) -> dict[str, int]:
 58 |         """Run workload for a single simulated client."""
 59 |         stats = {'success': 0, 'failure': 0}
 60 |         data_gen = TestDataGenerator()
 61 | 
 62 |         # Ramp-up delay
 63 |         ramp_delay = (client_id / self.config.num_clients) * self.config.ramp_up_time
 64 |         await asyncio.sleep(ramp_delay)
 65 | 
 66 |         for op_num in range(self.config.operations_per_client):
 67 |             operation_start = time.time()
 68 | 
 69 |             try:
 70 |                 # Randomly select operation type
 71 |                 operation = random.choice(
 72 |                     [
 73 |                         'add_memory',
 74 |                         'search_memory_nodes',
 75 |                         'get_episodes',
 76 |                     ]
 77 |                 )
 78 | 
 79 |                 if operation == 'add_memory':
 80 |                     args = {
 81 |                         'name': f'Load Test {client_id}-{op_num}',
 82 |                         'episode_body': data_gen.generate_technical_document(),
 83 |                         'source': 'text',
 84 |                         'source_description': 'load test',
 85 |                         'group_id': group_id,
 86 |                     }
 87 |                 elif operation == 'search_memory_nodes':
 88 |                     args = {
 89 |                         'query': random.choice(['performance', 'architecture', 'test', 'data']),
 90 |                         'group_id': group_id,
 91 |                         'limit': 10,
 92 |                     }
 93 |                 else:  # get_episodes
 94 |                     args = {
 95 |                         'group_id': group_id,
 96 |                         'last_n': 10,
 97 |                     }
 98 | 
 99 |                 # Execute operation with timeout
100 |                 await asyncio.wait_for(session.call_tool(operation, args), timeout=30.0)
101 | 
102 |                 duration = time.time() - operation_start
103 |                 self.metrics.append((operation_start, duration, True))
104 |                 stats['success'] += 1
105 | 
106 |             except asyncio.TimeoutError:
107 |                 duration = time.time() - operation_start
108 |                 self.metrics.append((operation_start, duration, False))
109 |                 self.errors['timeout'] = self.errors.get('timeout', 0) + 1
110 |                 stats['failure'] += 1
111 | 
112 |             except Exception as e:
113 |                 duration = time.time() - operation_start
114 |                 self.metrics.append((operation_start, duration, False))
115 |                 error_type = type(e).__name__
116 |                 self.errors[error_type] = self.errors.get(error_type, 0) + 1
117 |                 stats['failure'] += 1
118 | 
119 |             # Think time between operations
120 |             await asyncio.sleep(self.config.think_time)
121 | 
122 |             # Stop if we've exceeded test duration
123 |             if self.start_time and (time.time() - self.start_time) > self.config.test_duration:
124 |                 break
125 | 
126 |         return stats
127 | 
128 |     def calculate_results(self) -> LoadTestResult:
129 |         """Calculate load test results from metrics."""
130 |         if not self.metrics:
131 |             return LoadTestResult(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, {}, {})
132 | 
133 |         successful = [m for m in self.metrics if m[2]]
134 |         failed = [m for m in self.metrics if not m[2]]
135 | 
136 |         latencies = sorted([m[1] for m in self.metrics])
137 |         duration = max([m[0] + m[1] for m in self.metrics]) - min([m[0] for m in self.metrics])
138 | 
139 |         # Calculate percentiles
140 |         def percentile(data: list[float], p: float) -> float:
141 |             if not data:
142 |                 return 0.0
143 |             idx = int(len(data) * p / 100)
144 |             return data[min(idx, len(data) - 1)]
145 | 
146 |         # Get resource usage
147 |         process = psutil.Process()
148 |         resource_usage = {
149 |             'cpu_percent': process.cpu_percent(),
150 |             'memory_mb': process.memory_info().rss / 1024 / 1024,
151 |             'num_threads': process.num_threads(),
152 |         }
153 | 
154 |         return LoadTestResult(
155 |             total_operations=len(self.metrics),
156 |             successful_operations=len(successful),
157 |             failed_operations=len(failed),
158 |             duration=duration,
159 |             throughput=len(self.metrics) / duration if duration > 0 else 0,
160 |             average_latency=sum(latencies) / len(latencies) if latencies else 0,
161 |             p50_latency=percentile(latencies, 50),
162 |             p95_latency=percentile(latencies, 95),
163 |             p99_latency=percentile(latencies, 99),
164 |             max_latency=max(latencies) if latencies else 0,
165 |             errors=self.errors,
166 |             resource_usage=resource_usage,
167 |         )
168 | 
169 | 
170 | class TestLoadScenarios:
171 |     """Various load testing scenarios."""
172 | 
173 |     @pytest.mark.asyncio
174 |     @pytest.mark.slow
175 |     async def test_sustained_load(self):
176 |         """Test system under sustained moderate load."""
177 |         config = LoadTestConfig(
178 |             num_clients=5,
179 |             operations_per_client=20,
180 |             ramp_up_time=2.0,
181 |             test_duration=30.0,
182 |             think_time=0.5,
183 |         )
184 | 
185 |         async with graphiti_test_client() as (session, group_id):
186 |             tester = LoadTester(config)
187 |             tester.start_time = time.time()
188 | 
189 |             # Run client workloads
190 |             client_tasks = []
191 |             for client_id in range(config.num_clients):
192 |                 task = tester.run_client_workload(client_id, session, group_id)
193 |                 client_tasks.append(task)
194 | 
195 |             # Execute all clients
196 |             await asyncio.gather(*client_tasks)
197 | 
198 |             # Calculate results
199 |             results = tester.calculate_results()
200 | 
201 |             # Assertions
202 |             assert results.successful_operations > results.failed_operations
203 |             assert results.average_latency < 5.0, (
204 |                 f'Average latency too high: {results.average_latency:.2f}s'
205 |             )
206 |             assert results.p95_latency < 10.0, f'P95 latency too high: {results.p95_latency:.2f}s'
207 | 
208 |             # Report results
209 |             print('\nSustained Load Test Results:')
210 |             print(f'  Total operations: {results.total_operations}')
211 |             print(
212 |                 f'  Success rate: {results.successful_operations / results.total_operations * 100:.1f}%'
213 |             )
214 |             print(f'  Throughput: {results.throughput:.2f} ops/s')
215 |             print(f'  Avg latency: {results.average_latency:.2f}s')
216 |             print(f'  P95 latency: {results.p95_latency:.2f}s')
217 | 
218 |     @pytest.mark.asyncio
219 |     @pytest.mark.slow
220 |     async def test_spike_load(self):
221 |         """Test system response to sudden load spikes."""
222 |         async with graphiti_test_client() as (session, group_id):
223 |             # Normal load phase
224 |             normal_tasks = []
225 |             for i in range(3):
226 |                 task = session.call_tool(
227 |                     'add_memory',
228 |                     {
229 |                         'name': f'Normal Load {i}',
230 |                         'episode_body': 'Normal operation',
231 |                         'source': 'text',
232 |                         'source_description': 'normal',
233 |                         'group_id': group_id,
234 |                     },
235 |                 )
236 |                 normal_tasks.append(task)
237 |                 await asyncio.sleep(0.5)
238 | 
239 |             await asyncio.gather(*normal_tasks)
240 | 
241 |             # Spike phase - sudden burst of requests
242 |             spike_start = time.time()
243 |             spike_tasks = []
244 |             for i in range(50):
245 |                 task = session.call_tool(
246 |                     'add_memory',
247 |                     {
248 |                         'name': f'Spike Load {i}',
249 |                         'episode_body': TestDataGenerator.generate_technical_document(),
250 |                         'source': 'text',
251 |                         'source_description': 'spike',
252 |                         'group_id': group_id,
253 |                     },
254 |                 )
255 |                 spike_tasks.append(task)
256 | 
257 |             # Execute spike
258 |             spike_results = await asyncio.gather(*spike_tasks, return_exceptions=True)
259 |             spike_duration = time.time() - spike_start
260 | 
261 |             # Analyze spike handling
262 |             spike_failures = sum(1 for r in spike_results if isinstance(r, Exception))
263 |             spike_success_rate = (len(spike_results) - spike_failures) / len(spike_results)
264 | 
265 |             print('\nSpike Load Test Results:')
266 |             print(f'  Spike size: {len(spike_tasks)} operations')
267 |             print(f'  Duration: {spike_duration:.2f}s')
268 |             print(f'  Success rate: {spike_success_rate * 100:.1f}%')
269 |             print(f'  Throughput: {len(spike_tasks) / spike_duration:.2f} ops/s')
270 | 
271 |             # System should handle at least 80% of spike
272 |             assert spike_success_rate > 0.8, f'Too many failures during spike: {spike_failures}'
273 | 
274 |     @pytest.mark.asyncio
275 |     @pytest.mark.slow
276 |     async def test_memory_leak_detection(self):
277 |         """Test for memory leaks during extended operation."""
278 |         async with graphiti_test_client() as (session, group_id):
279 |             process = psutil.Process()
280 |             gc.collect()  # Force garbage collection
281 |             initial_memory = process.memory_info().rss / 1024 / 1024  # MB
282 | 
283 |             # Perform many operations
284 |             for batch in range(10):
285 |                 batch_tasks = []
286 |                 for i in range(10):
287 |                     task = session.call_tool(
288 |                         'add_memory',
289 |                         {
290 |                             'name': f'Memory Test {batch}-{i}',
291 |                             'episode_body': TestDataGenerator.generate_technical_document(),
292 |                             'source': 'text',
293 |                             'source_description': 'memory test',
294 |                             'group_id': group_id,
295 |                         },
296 |                     )
297 |                     batch_tasks.append(task)
298 | 
299 |                 await asyncio.gather(*batch_tasks)
300 | 
301 |                 # Force garbage collection between batches
302 |                 gc.collect()
303 |                 await asyncio.sleep(1)
304 | 
305 |             # Check memory after operations
306 |             gc.collect()
307 |             final_memory = process.memory_info().rss / 1024 / 1024  # MB
308 |             memory_growth = final_memory - initial_memory
309 | 
310 |             print('\nMemory Leak Test:')
311 |             print(f'  Initial memory: {initial_memory:.1f} MB')
312 |             print(f'  Final memory: {final_memory:.1f} MB')
313 |             print(f'  Growth: {memory_growth:.1f} MB')
314 | 
315 |             # Allow for some memory growth but flag potential leaks
316 |             # This is a soft check - actual threshold depends on system
317 |             if memory_growth > 100:  # More than 100MB growth
318 |                 print(f'  ⚠️  Potential memory leak detected: {memory_growth:.1f} MB growth')
319 | 
320 |     @pytest.mark.asyncio
321 |     @pytest.mark.slow
322 |     async def test_connection_pool_exhaustion(self):
323 |         """Test behavior when connection pools are exhausted."""
324 |         async with graphiti_test_client() as (session, group_id):
325 |             # Create many concurrent long-running operations
326 |             long_tasks = []
327 |             for i in range(100):  # Many more than typical pool size
328 |                 task = session.call_tool(
329 |                     'search_memory_nodes',
330 |                     {
331 |                         'query': f'complex query {i} '
332 |                         + ' '.join([TestDataGenerator.fake.word() for _ in range(10)]),
333 |                         'group_id': group_id,
334 |                         'limit': 100,
335 |                     },
336 |                 )
337 |                 long_tasks.append(task)
338 | 
339 |             # Execute with timeout
340 |             try:
341 |                 results = await asyncio.wait_for(
342 |                     asyncio.gather(*long_tasks, return_exceptions=True), timeout=60.0
343 |                 )
344 | 
345 |                 # Count connection-related errors
346 |                 connection_errors = sum(
347 |                     1
348 |                     for r in results
349 |                     if isinstance(r, Exception) and 'connection' in str(r).lower()
350 |                 )
351 | 
352 |                 print('\nConnection Pool Test:')
353 |                 print(f'  Total requests: {len(long_tasks)}')
354 |                 print(f'  Connection errors: {connection_errors}')
355 | 
356 |             except asyncio.TimeoutError:
357 |                 print('  Test timed out - possible deadlock or exhaustion')
358 | 
359 |     @pytest.mark.asyncio
360 |     @pytest.mark.slow
361 |     async def test_gradual_degradation(self):
362 |         """Test system degradation under increasing load."""
363 |         async with graphiti_test_client() as (session, group_id):
364 |             load_levels = [5, 10, 20, 40, 80]  # Increasing concurrent operations
365 |             results_by_level = {}
366 | 
367 |             for level in load_levels:
368 |                 level_start = time.time()
369 |                 tasks = []
370 | 
371 |                 for i in range(level):
372 |                     task = session.call_tool(
373 |                         'add_memory',
374 |                         {
375 |                             'name': f'Load Level {level} Op {i}',
376 |                             'episode_body': f'Testing at load level {level}',
377 |                             'source': 'text',
378 |                             'source_description': 'degradation test',
379 |                             'group_id': group_id,
380 |                         },
381 |                     )
382 |                     tasks.append(task)
383 | 
384 |                 # Execute level
385 |                 level_results = await asyncio.gather(*tasks, return_exceptions=True)
386 |                 level_duration = time.time() - level_start
387 | 
388 |                 # Calculate metrics
389 |                 failures = sum(1 for r in level_results if isinstance(r, Exception))
390 |                 success_rate = (level - failures) / level * 100
391 |                 throughput = level / level_duration
392 | 
393 |                 results_by_level[level] = {
394 |                     'success_rate': success_rate,
395 |                     'throughput': throughput,
396 |                     'duration': level_duration,
397 |                 }
398 | 
399 |                 print(f'\nLoad Level {level}:')
400 |                 print(f'  Success rate: {success_rate:.1f}%')
401 |                 print(f'  Throughput: {throughput:.2f} ops/s')
402 |                 print(f'  Duration: {level_duration:.2f}s')
403 | 
404 |                 # Brief pause between levels
405 |                 await asyncio.sleep(2)
406 | 
407 |             # Verify graceful degradation
408 |             # Success rate should not drop below 50% even at high load
409 |             for level, metrics in results_by_level.items():
410 |                 assert metrics['success_rate'] > 50, f'Poor performance at load level {level}'
411 | 
412 | 
413 | class TestResourceLimits:
414 |     """Test behavior at resource limits."""
415 | 
416 |     @pytest.mark.asyncio
417 |     async def test_large_payload_handling(self):
418 |         """Test handling of very large payloads."""
419 |         async with graphiti_test_client() as (session, group_id):
420 |             payload_sizes = [
421 |                 (1_000, '1KB'),
422 |                 (10_000, '10KB'),
423 |                 (100_000, '100KB'),
424 |                 (1_000_000, '1MB'),
425 |             ]
426 | 
427 |             for size, label in payload_sizes:
428 |                 content = 'x' * size
429 | 
430 |                 start_time = time.time()
431 |                 try:
432 |                     await asyncio.wait_for(
433 |                         session.call_tool(
434 |                             'add_memory',
435 |                             {
436 |                                 'name': f'Large Payload {label}',
437 |                                 'episode_body': content,
438 |                                 'source': 'text',
439 |                                 'source_description': 'payload test',
440 |                                 'group_id': group_id,
441 |                             },
442 |                         ),
443 |                         timeout=30.0,
444 |                     )
445 |                     duration = time.time() - start_time
446 |                     status = '✅ Success'
447 | 
448 |                 except asyncio.TimeoutError:
449 |                     duration = 30.0
450 |                     status = '⏱️  Timeout'
451 | 
452 |                 except Exception as e:
453 |                     duration = time.time() - start_time
454 |                     status = f'❌ Error: {type(e).__name__}'
455 | 
456 |                 print(f'Payload {label}: {status} ({duration:.2f}s)')
457 | 
458 |     @pytest.mark.asyncio
459 |     async def test_rate_limit_handling(self):
460 |         """Test handling of rate limits."""
461 |         async with graphiti_test_client() as (session, group_id):
462 |             # Rapid fire requests to trigger rate limits
463 |             rapid_tasks = []
464 |             for i in range(100):
465 |                 task = session.call_tool(
466 |                     'add_memory',
467 |                     {
468 |                         'name': f'Rate Limit Test {i}',
469 |                         'episode_body': f'Testing rate limit {i}',
470 |                         'source': 'text',
471 |                         'source_description': 'rate test',
472 |                         'group_id': group_id,
473 |                     },
474 |                 )
475 |                 rapid_tasks.append(task)
476 | 
477 |             # Execute without delays
478 |             results = await asyncio.gather(*rapid_tasks, return_exceptions=True)
479 | 
480 |             # Count rate limit errors
481 |             rate_limit_errors = sum(
482 |                 1
483 |                 for r in results
484 |                 if isinstance(r, Exception) and ('rate' in str(r).lower() or '429' in str(r))
485 |             )
486 | 
487 |             print('\nRate Limit Test:')
488 |             print(f'  Total requests: {len(rapid_tasks)}')
489 |             print(f'  Rate limit errors: {rate_limit_errors}')
490 |             print(
491 |                 f'  Success rate: {(len(rapid_tasks) - rate_limit_errors) / len(rapid_tasks) * 100:.1f}%'
492 |             )
493 | 
494 | 
495 | def generate_load_test_report(results: list[LoadTestResult]) -> str:
496 |     """Generate comprehensive load test report."""
497 |     report = []
498 |     report.append('\n' + '=' * 60)
499 |     report.append('LOAD TEST REPORT')
500 |     report.append('=' * 60)
501 | 
502 |     for i, result in enumerate(results):
503 |         report.append(f'\nTest Run {i + 1}:')
504 |         report.append(f'  Total Operations: {result.total_operations}')
505 |         report.append(
506 |             f'  Success Rate: {result.successful_operations / result.total_operations * 100:.1f}%'
507 |         )
508 |         report.append(f'  Throughput: {result.throughput:.2f} ops/s')
509 |         report.append(
510 |             f'  Latency (avg/p50/p95/p99/max): {result.average_latency:.2f}/{result.p50_latency:.2f}/{result.p95_latency:.2f}/{result.p99_latency:.2f}/{result.max_latency:.2f}s'
511 |         )
512 | 
513 |         if result.errors:
514 |             report.append('  Errors:')
515 |             for error_type, count in result.errors.items():
516 |                 report.append(f'    {error_type}: {count}')
517 | 
518 |         report.append('  Resource Usage:')
519 |         for metric, value in result.resource_usage.items():
520 |             report.append(f'    {metric}: {value:.2f}')
521 | 
522 |     report.append('=' * 60)
523 |     return '\n'.join(report)
524 | 
525 | 
526 | if __name__ == '__main__':
527 |     pytest.main([__file__, '-v', '--asyncio-mode=auto', '-m', 'slow'])
528 | 
```
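
For reference, the latency percentiles reported by `LoadTester.calculate_results` come from a nearest-rank lookup into the sorted latency list. The standalone sketch below replays that logic on a made-up latency sample:

```python
# Nearest-rank percentile as used in calculate_results: index into the sorted
# latency list at len(data) * p / 100, clamped to the last element.
def percentile(data: list[float], p: float) -> float:
    if not data:
        return 0.0
    idx = int(len(data) * p / 100)
    return data[min(idx, len(data) - 1)]


latencies = sorted([0.12, 0.35, 0.08, 0.91, 0.47, 0.22, 1.30, 0.18, 0.64, 0.29])
print(percentile(latencies, 50))  # 0.35 (index 5 of 10 sorted values)
print(percentile(latencies, 95))  # 1.30 (index 9, the maximum)
print(percentile(latencies, 99))  # 1.30 (clamped to the last element)
```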

--------------------------------------------------------------------------------
/graphiti_core/utils/maintenance/node_operations.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import logging
 18 | from collections.abc import Awaitable, Callable
 19 | from time import time
 20 | from typing import Any
 21 | 
 22 | from pydantic import BaseModel
 23 | 
 24 | from graphiti_core.graphiti_types import GraphitiClients
 25 | from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
 26 | from graphiti_core.llm_client import LLMClient
 27 | from graphiti_core.llm_client.config import ModelSize
 28 | from graphiti_core.nodes import (
 29 |     EntityNode,
 30 |     EpisodeType,
 31 |     EpisodicNode,
 32 |     create_entity_node_embeddings,
 33 | )
 34 | from graphiti_core.prompts import prompt_library
 35 | from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
 36 | from graphiti_core.prompts.extract_nodes import (
 37 |     EntitySummary,
 38 |     ExtractedEntities,
 39 |     ExtractedEntity,
 40 |     MissedEntities,
 41 | )
 42 | from graphiti_core.search.search import search
 43 | from graphiti_core.search.search_config import SearchResults
 44 | from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
 45 | from graphiti_core.search.search_filters import SearchFilters
 46 | from graphiti_core.utils.datetime_utils import utc_now
 47 | from graphiti_core.utils.maintenance.dedup_helpers import (
 48 |     DedupCandidateIndexes,
 49 |     DedupResolutionState,
 50 |     _build_candidate_indexes,
 51 |     _resolve_with_similarity,
 52 | )
 53 | from graphiti_core.utils.maintenance.edge_operations import (
 54 |     filter_existing_duplicate_of_edges,
 55 | )
 56 | from graphiti_core.utils.text_utils import MAX_SUMMARY_CHARS, truncate_at_sentence
 57 | 
 58 | logger = logging.getLogger(__name__)
 59 | 
 60 | NodeSummaryFilter = Callable[[EntityNode], Awaitable[bool]]
 61 | 
 62 | 
 63 | async def extract_nodes_reflexion(
 64 |     llm_client: LLMClient,
 65 |     episode: EpisodicNode,
 66 |     previous_episodes: list[EpisodicNode],
 67 |     node_names: list[str],
 68 |     group_id: str | None = None,
 69 | ) -> list[str]:
 70 |     # Prepare context for LLM
 71 |     context = {
 72 |         'episode_content': episode.content,
 73 |         'previous_episodes': [ep.content for ep in previous_episodes],
 74 |         'extracted_entities': node_names,
 75 |     }
 76 | 
 77 |     llm_response = await llm_client.generate_response(
 78 |         prompt_library.extract_nodes.reflexion(context),
 79 |         MissedEntities,
 80 |         group_id=group_id,
 81 |         prompt_name='extract_nodes.reflexion',
 82 |     )
 83 |     missed_entities = llm_response.get('missed_entities', [])
 84 | 
 85 |     return missed_entities
 86 | 
 87 | 
 88 | async def extract_nodes(
 89 |     clients: GraphitiClients,
 90 |     episode: EpisodicNode,
 91 |     previous_episodes: list[EpisodicNode],
 92 |     entity_types: dict[str, type[BaseModel]] | None = None,
 93 |     excluded_entity_types: list[str] | None = None,
 94 | ) -> list[EntityNode]:
 95 |     start = time()
 96 |     llm_client = clients.llm_client
 97 |     llm_response = {}
 98 |     custom_prompt = ''
 99 |     entities_missed = True
100 |     reflexion_iterations = 0
101 | 
102 |     entity_types_context = [
103 |         {
104 |             'entity_type_id': 0,
105 |             'entity_type_name': 'Entity',
106 |             'entity_type_description': 'Default entity classification. Use this entity type if the entity is not one of the other listed types.',
107 |         }
108 |     ]
109 | 
110 |     entity_types_context += (
111 |         [
112 |             {
113 |                 'entity_type_id': i + 1,
114 |                 'entity_type_name': type_name,
115 |                 'entity_type_description': type_model.__doc__,
116 |             }
117 |             for i, (type_name, type_model) in enumerate(entity_types.items())
118 |         ]
119 |         if entity_types is not None
120 |         else []
121 |     )
122 | 
123 |     context = {
124 |         'episode_content': episode.content,
125 |         'episode_timestamp': episode.valid_at.isoformat(),
126 |         'previous_episodes': [ep.content for ep in previous_episodes],
127 |         'custom_prompt': custom_prompt,
128 |         'entity_types': entity_types_context,
129 |         'source_description': episode.source_description,
130 |     }
131 | 
132 |     while entities_missed and reflexion_iterations <= MAX_REFLEXION_ITERATIONS:
133 |         if episode.source == EpisodeType.message:
134 |             llm_response = await llm_client.generate_response(
135 |                 prompt_library.extract_nodes.extract_message(context),
136 |                 response_model=ExtractedEntities,
137 |                 group_id=episode.group_id,
138 |                 prompt_name='extract_nodes.extract_message',
139 |             )
140 |         elif episode.source == EpisodeType.text:
141 |             llm_response = await llm_client.generate_response(
142 |                 prompt_library.extract_nodes.extract_text(context),
143 |                 response_model=ExtractedEntities,
144 |                 group_id=episode.group_id,
145 |                 prompt_name='extract_nodes.extract_text',
146 |             )
147 |         elif episode.source == EpisodeType.json:
148 |             llm_response = await llm_client.generate_response(
149 |                 prompt_library.extract_nodes.extract_json(context),
150 |                 response_model=ExtractedEntities,
151 |                 group_id=episode.group_id,
152 |                 prompt_name='extract_nodes.extract_json',
153 |             )
154 | 
155 |         response_object = ExtractedEntities(**llm_response)
156 | 
157 |         extracted_entities: list[ExtractedEntity] = response_object.extracted_entities
158 | 
159 |         reflexion_iterations += 1
160 |         if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
161 |             missing_entities = await extract_nodes_reflexion(
162 |                 llm_client,
163 |                 episode,
164 |                 previous_episodes,
165 |                 [entity.name for entity in extracted_entities],
166 |                 episode.group_id,
167 |             )
168 | 
169 |             entities_missed = len(missing_entities) != 0
170 | 
 171 |             custom_prompt = 'Make sure that the following entities are extracted: '
 172 |             custom_prompt += ''.join(f'\n{entity},' for entity in missing_entities)
 173 |             context['custom_prompt'] = custom_prompt  # feed the reflexion hint into the next extraction pass
174 | 
175 |     filtered_extracted_entities = [entity for entity in extracted_entities if entity.name.strip()]
176 |     end = time()
177 |     logger.debug(f'Extracted new nodes: {filtered_extracted_entities} in {(end - start) * 1000} ms')
178 |     # Convert the extracted data into EntityNode objects
179 |     extracted_nodes = []
180 |     for extracted_entity in filtered_extracted_entities:
181 |         type_id = extracted_entity.entity_type_id
182 |         if 0 <= type_id < len(entity_types_context):
 183 |             entity_type_name = entity_types_context[type_id].get(
184 |                 'entity_type_name'
185 |             )
186 |         else:
187 |             entity_type_name = 'Entity'
188 | 
189 |         # Check if this entity type should be excluded
190 |         if excluded_entity_types and entity_type_name in excluded_entity_types:
191 |             logger.debug(f'Excluding entity "{extracted_entity.name}" of type "{entity_type_name}"')
192 |             continue
193 | 
194 |         labels: list[str] = list({'Entity', str(entity_type_name)})
195 | 
196 |         new_node = EntityNode(
197 |             name=extracted_entity.name,
198 |             group_id=episode.group_id,
199 |             labels=labels,
200 |             summary='',
201 |             created_at=utc_now(),
202 |         )
203 |         extracted_nodes.append(new_node)
204 |         logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
205 | 
206 |     logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
207 | 
208 |     return extracted_nodes
209 | 
210 | 
211 | async def _collect_candidate_nodes(
212 |     clients: GraphitiClients,
213 |     extracted_nodes: list[EntityNode],
214 |     existing_nodes_override: list[EntityNode] | None,
215 | ) -> list[EntityNode]:
216 |     """Search per extracted name and return unique candidates with overrides honored in order."""
217 |     search_results: list[SearchResults] = await semaphore_gather(
218 |         *[
219 |             search(
220 |                 clients=clients,
221 |                 query=node.name,
222 |                 group_ids=[node.group_id],
223 |                 search_filter=SearchFilters(),
224 |                 config=NODE_HYBRID_SEARCH_RRF,
225 |             )
226 |             for node in extracted_nodes
227 |         ]
228 |     )
229 | 
230 |     candidate_nodes: list[EntityNode] = [node for result in search_results for node in result.nodes]
231 | 
232 |     if existing_nodes_override is not None:
233 |         candidate_nodes.extend(existing_nodes_override)
234 | 
235 |     seen_candidate_uuids: set[str] = set()
236 |     ordered_candidates: list[EntityNode] = []
237 |     for candidate in candidate_nodes:
238 |         if candidate.uuid in seen_candidate_uuids:
239 |             continue
240 |         seen_candidate_uuids.add(candidate.uuid)
241 |         ordered_candidates.append(candidate)
242 | 
243 |     return ordered_candidates
244 | 
245 | 
246 | async def _resolve_with_llm(
247 |     llm_client: LLMClient,
248 |     extracted_nodes: list[EntityNode],
249 |     indexes: DedupCandidateIndexes,
250 |     state: DedupResolutionState,
251 |     episode: EpisodicNode | None,
252 |     previous_episodes: list[EpisodicNode] | None,
253 |     entity_types: dict[str, type[BaseModel]] | None,
254 | ) -> None:
255 |     """Escalate unresolved nodes to the dedupe prompt so the LLM can select or reject duplicates.
256 | 
257 |     The guardrails below defensively ignore malformed or duplicate LLM responses so the
258 |     ingestion workflow remains deterministic even when the model misbehaves.
259 |     """
260 |     if not state.unresolved_indices:
261 |         return
262 | 
263 |     entity_types_dict: dict[str, type[BaseModel]] = entity_types if entity_types is not None else {}
264 | 
265 |     llm_extracted_nodes = [extracted_nodes[i] for i in state.unresolved_indices]
266 | 
267 |     extracted_nodes_context = [
268 |         {
269 |             'id': i,
270 |             'name': node.name,
271 |             'entity_type': node.labels,
272 |             'entity_type_description': entity_types_dict.get(
273 |                 next((item for item in node.labels if item != 'Entity'), '')
274 |             ).__doc__
275 |             or 'Default Entity Type',
276 |         }
277 |         for i, node in enumerate(llm_extracted_nodes)
278 |     ]
279 | 
280 |     sent_ids = [ctx['id'] for ctx in extracted_nodes_context]
281 |     logger.debug(
282 |         'Sending %d entities to LLM for deduplication with IDs 0-%d (actual IDs sent: %s)',
283 |         len(llm_extracted_nodes),
284 |         len(llm_extracted_nodes) - 1,
285 |         sent_ids if len(sent_ids) < 20 else f'{sent_ids[:10]}...{sent_ids[-10:]}',
286 |     )
287 |     if llm_extracted_nodes:
288 |         sample_size = min(3, len(extracted_nodes_context))
289 |         logger.debug(
290 |             'First %d entities: %s',
291 |             sample_size,
292 |             [(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[:sample_size]],
293 |         )
294 |         if len(extracted_nodes_context) > 3:
295 |             logger.debug(
296 |                 'Last %d entities: %s',
297 |                 sample_size,
298 |                 [(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[-sample_size:]],
299 |             )
300 | 
301 |     existing_nodes_context = [
302 |         {
303 |             **{
304 |                 'idx': i,
305 |                 'name': candidate.name,
306 |                 'entity_types': candidate.labels,
307 |             },
308 |             **candidate.attributes,
309 |         }
310 |         for i, candidate in enumerate(indexes.existing_nodes)
311 |     ]
312 | 
313 |     context = {
314 |         'extracted_nodes': extracted_nodes_context,
315 |         'existing_nodes': existing_nodes_context,
316 |         'episode_content': episode.content if episode is not None else '',
317 |         'previous_episodes': (
318 |             [ep.content for ep in previous_episodes] if previous_episodes is not None else []
319 |         ),
320 |     }
321 | 
322 |     llm_response = await llm_client.generate_response(
323 |         prompt_library.dedupe_nodes.nodes(context),
324 |         response_model=NodeResolutions,
325 |         prompt_name='dedupe_nodes.nodes',
326 |     )
327 | 
328 |     node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions
329 | 
330 |     valid_relative_range = range(len(state.unresolved_indices))
331 |     processed_relative_ids: set[int] = set()
332 | 
333 |     received_ids = {r.id for r in node_resolutions}
334 |     expected_ids = set(valid_relative_range)
335 |     missing_ids = expected_ids - received_ids
336 |     extra_ids = received_ids - expected_ids
337 | 
338 |     logger.debug(
339 |         'Received %d resolutions for %d entities',
340 |         len(node_resolutions),
341 |         len(state.unresolved_indices),
342 |     )
343 | 
344 |     if missing_ids:
345 |         logger.warning('LLM did not return resolutions for IDs: %s', sorted(missing_ids))
346 | 
347 |     if extra_ids:
348 |         logger.warning(
349 |             'LLM returned invalid IDs outside valid range 0-%d: %s (all returned IDs: %s)',
350 |             len(state.unresolved_indices) - 1,
351 |             sorted(extra_ids),
352 |             sorted(received_ids),
353 |         )
354 | 
355 |     for resolution in node_resolutions:
356 |         relative_id: int = resolution.id
357 |         duplicate_idx: int = resolution.duplicate_idx
358 | 
359 |         if relative_id not in valid_relative_range:
360 |             logger.warning(
361 |                 'Skipping invalid LLM dedupe id %d (valid range: 0-%d, received %d resolutions)',
362 |                 relative_id,
363 |                 len(state.unresolved_indices) - 1,
364 |                 len(node_resolutions),
365 |             )
366 |             continue
367 | 
368 |         if relative_id in processed_relative_ids:
369 |             logger.warning('Duplicate LLM dedupe id %s received; ignoring.', relative_id)
370 |             continue
371 |         processed_relative_ids.add(relative_id)
372 | 
373 |         original_index = state.unresolved_indices[relative_id]
374 |         extracted_node = extracted_nodes[original_index]
375 | 
376 |         resolved_node: EntityNode
377 |         if duplicate_idx == -1:
378 |             resolved_node = extracted_node
379 |         elif 0 <= duplicate_idx < len(indexes.existing_nodes):
380 |             resolved_node = indexes.existing_nodes[duplicate_idx]
381 |         else:
382 |             logger.warning(
383 |                 'Invalid duplicate_idx %s for extracted node %s; treating as no duplicate.',
384 |                 duplicate_idx,
385 |                 extracted_node.uuid,
386 |             )
387 |             resolved_node = extracted_node
388 | 
389 |         state.resolved_nodes[original_index] = resolved_node
390 |         state.uuid_map[extracted_node.uuid] = resolved_node.uuid
391 |         if resolved_node.uuid != extracted_node.uuid:
392 |             state.duplicate_pairs.append((extracted_node, resolved_node))
393 | 
394 | 
395 | async def resolve_extracted_nodes(
396 |     clients: GraphitiClients,
397 |     extracted_nodes: list[EntityNode],
398 |     episode: EpisodicNode | None = None,
399 |     previous_episodes: list[EpisodicNode] | None = None,
400 |     entity_types: dict[str, type[BaseModel]] | None = None,
401 |     existing_nodes_override: list[EntityNode] | None = None,
402 | ) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
403 |     """Search for existing nodes, resolve deterministic matches, then escalate holdouts to the LLM dedupe prompt."""
404 |     llm_client = clients.llm_client
405 |     driver = clients.driver
406 |     existing_nodes = await _collect_candidate_nodes(
407 |         clients,
408 |         extracted_nodes,
409 |         existing_nodes_override,
410 |     )
411 | 
412 |     indexes: DedupCandidateIndexes = _build_candidate_indexes(existing_nodes)
413 | 
414 |     state = DedupResolutionState(
415 |         resolved_nodes=[None] * len(extracted_nodes),
416 |         uuid_map={},
417 |         unresolved_indices=[],
418 |     )
419 | 
420 |     _resolve_with_similarity(extracted_nodes, indexes, state)
421 | 
422 |     await _resolve_with_llm(
423 |         llm_client,
424 |         extracted_nodes,
425 |         indexes,
426 |         state,
427 |         episode,
428 |         previous_episodes,
429 |         entity_types,
430 |     )
431 | 
432 |     for idx, node in enumerate(extracted_nodes):
433 |         if state.resolved_nodes[idx] is None:
434 |             state.resolved_nodes[idx] = node
435 |             state.uuid_map[node.uuid] = node.uuid
436 | 
437 |     logger.debug(
438 |         'Resolved nodes: %s',
439 |         [(node.name, node.uuid) for node in state.resolved_nodes if node is not None],
440 |     )
441 | 
442 |     new_node_duplicates: list[
443 |         tuple[EntityNode, EntityNode]
444 |     ] = await filter_existing_duplicate_of_edges(driver, state.duplicate_pairs)
445 | 
446 |     return (
447 |         [node for node in state.resolved_nodes if node is not None],
448 |         state.uuid_map,
449 |         new_node_duplicates,
450 |     )
451 | 
452 | 
453 | async def extract_attributes_from_nodes(
454 |     clients: GraphitiClients,
455 |     nodes: list[EntityNode],
456 |     episode: EpisodicNode | None = None,
457 |     previous_episodes: list[EpisodicNode] | None = None,
458 |     entity_types: dict[str, type[BaseModel]] | None = None,
459 |     should_summarize_node: NodeSummaryFilter | None = None,
460 | ) -> list[EntityNode]:
461 |     llm_client = clients.llm_client
462 |     embedder = clients.embedder
463 |     updated_nodes: list[EntityNode] = await semaphore_gather(
464 |         *[
465 |             extract_attributes_from_node(
466 |                 llm_client,
467 |                 node,
468 |                 episode,
469 |                 previous_episodes,
470 |                 (
471 |                     entity_types.get(next((item for item in node.labels if item != 'Entity'), ''))
472 |                     if entity_types is not None
473 |                     else None
474 |                 ),
475 |                 should_summarize_node,
476 |             )
477 |             for node in nodes
478 |         ]
479 |     )
480 | 
481 |     await create_entity_node_embeddings(embedder, updated_nodes)
482 | 
483 |     return updated_nodes
484 | 
485 | 
486 | async def extract_attributes_from_node(
487 |     llm_client: LLMClient,
488 |     node: EntityNode,
489 |     episode: EpisodicNode | None = None,
490 |     previous_episodes: list[EpisodicNode] | None = None,
491 |     entity_type: type[BaseModel] | None = None,
492 |     should_summarize_node: NodeSummaryFilter | None = None,
493 | ) -> EntityNode:
494 |     # Extract attributes if entity type is defined and has attributes
495 |     llm_response = await _extract_entity_attributes(
496 |         llm_client, node, episode, previous_episodes, entity_type
497 |     )
498 | 
499 |     # Extract summary if needed
500 |     await _extract_entity_summary(
501 |         llm_client, node, episode, previous_episodes, should_summarize_node
502 |     )
503 | 
504 |     node.attributes.update(llm_response)
505 | 
506 |     return node
507 | 
508 | 
509 | async def _extract_entity_attributes(
510 |     llm_client: LLMClient,
511 |     node: EntityNode,
512 |     episode: EpisodicNode | None,
513 |     previous_episodes: list[EpisodicNode] | None,
514 |     entity_type: type[BaseModel] | None,
515 | ) -> dict[str, Any]:
516 |     if entity_type is None or len(entity_type.model_fields) == 0:
517 |         return {}
518 | 
519 |     attributes_context = _build_episode_context(
520 |         # should not include summary
521 |         node_data={
522 |             'name': node.name,
523 |             'entity_types': node.labels,
524 |             'attributes': node.attributes,
525 |         },
526 |         episode=episode,
527 |         previous_episodes=previous_episodes,
528 |     )
529 | 
530 |     llm_response = await llm_client.generate_response(
531 |         prompt_library.extract_nodes.extract_attributes(attributes_context),
532 |         response_model=entity_type,
533 |         model_size=ModelSize.small,
534 |         group_id=node.group_id,
535 |         prompt_name='extract_nodes.extract_attributes',
536 |     )
537 | 
538 |     # validate response
539 |     entity_type(**llm_response)
540 | 
541 |     return llm_response
542 | 
543 | 
544 | async def _extract_entity_summary(
545 |     llm_client: LLMClient,
546 |     node: EntityNode,
547 |     episode: EpisodicNode | None,
548 |     previous_episodes: list[EpisodicNode] | None,
549 |     should_summarize_node: NodeSummaryFilter | None,
550 | ) -> None:
551 |     if should_summarize_node is not None and not await should_summarize_node(node):
552 |         return
553 | 
554 |     summary_context = _build_episode_context(
555 |         node_data={
556 |             'name': node.name,
557 |             'summary': truncate_at_sentence(node.summary, MAX_SUMMARY_CHARS),
558 |             'entity_types': node.labels,
559 |             'attributes': node.attributes,
560 |         },
561 |         episode=episode,
562 |         previous_episodes=previous_episodes,
563 |     )
564 | 
565 |     summary_response = await llm_client.generate_response(
566 |         prompt_library.extract_nodes.extract_summary(summary_context),
567 |         response_model=EntitySummary,
568 |         model_size=ModelSize.small,
569 |         group_id=node.group_id,
570 |         prompt_name='extract_nodes.extract_summary',
571 |     )
572 | 
573 |     node.summary = truncate_at_sentence(summary_response.get('summary', ''), MAX_SUMMARY_CHARS)
574 | 
575 | 
576 | def _build_episode_context(
577 |     node_data: dict[str, Any],
578 |     episode: EpisodicNode | None,
579 |     previous_episodes: list[EpisodicNode] | None,
580 | ) -> dict[str, Any]:
581 |     return {
582 |         'node': node_data,
583 |         'episode_content': episode.content if episode is not None else '',
584 |         'previous_episodes': (
585 |             [ep.content for ep in previous_episodes] if previous_episodes is not None else []
586 |         ),
587 |     }
588 | 
```
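The helpers above form a three-stage pipeline: entity extraction, deduplication against existing graph nodes, and attribute/summary hydration. The sketch below is a minimal, hypothetical wiring of that pipeline, assuming the file lives at `graphiti_core/utils/maintenance/node_operations.py`, that a configured `GraphitiClients` instance and `EpisodicNode` objects are provided by the caller, and that `skip_trivial_summaries` and `ingest_episode_nodes` are illustrative names rather than part of the library.

```python
# Hypothetical wiring of the node pipeline; `clients`, `episode`, and
# `previous_episodes` are assumed to be created elsewhere in the codebase.
from graphiti_core.nodes import EntityNode
from graphiti_core.utils.maintenance.node_operations import (  # assumed module path
    extract_attributes_from_nodes,
    extract_nodes,
    resolve_extracted_nodes,
)


async def skip_trivial_summaries(node: EntityNode) -> bool:
    # NodeSummaryFilter: only spend an LLM call summarizing nodes that carry attributes.
    return bool(node.attributes)


async def ingest_episode_nodes(clients, episode, previous_episodes):
    # 1. Extract candidate entities from the episode content.
    extracted = await extract_nodes(clients, episode, previous_episodes)

    # 2. Deduplicate against existing graph nodes (similarity first, LLM fallback).
    resolved, uuid_map, duplicate_pairs = await resolve_extracted_nodes(
        clients, extracted, episode, previous_episodes
    )

    # 3. Fill in typed attributes and summaries, then embed the results.
    hydrated = await extract_attributes_from_nodes(
        clients,
        resolved,
        episode,
        previous_episodes,
        should_summarize_node=skip_trivial_summaries,
    )
    return hydrated, uuid_map, duplicate_pairs
```

Passing a `NodeSummaryFilter` such as `skip_trivial_summaries` only suppresses the per-node summary call; attribute extraction and embedding still run for every resolved node.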
Page 7/12