This is page 7 of 12. Use http://codebase.md/getzep/graphiti?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── ISSUE_TEMPLATE
│ │ └── bug_report.md
│ ├── pull_request_template.md
│ ├── secret_scanning.yml
│ └── workflows
│ ├── ai-moderator.yml
│ ├── cla.yml
│ ├── claude-code-review-manual.yml
│ ├── claude-code-review.yml
│ ├── claude.yml
│ ├── codeql.yml
│ ├── daily_issue_maintenance.yml
│ ├── issue-triage.yml
│ ├── lint.yml
│ ├── release-graphiti-core.yml
│ ├── release-mcp-server.yml
│ ├── release-server-container.yml
│ ├── typecheck.yml
│ └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│ ├── azure-openai
│ │ ├── .env.example
│ │ ├── azure_openai_neo4j.py
│ │ └── README.md
│ ├── data
│ │ └── manybirds_products.json
│ ├── ecommerce
│ │ ├── runner.ipynb
│ │ └── runner.py
│ ├── langgraph-agent
│ │ ├── agent.ipynb
│ │ └── tinybirds-jess.png
│ ├── opentelemetry
│ │ ├── .env.example
│ │ ├── otel_stdout_example.py
│ │ ├── pyproject.toml
│ │ ├── README.md
│ │ └── uv.lock
│ ├── podcast
│ │ ├── podcast_runner.py
│ │ ├── podcast_transcript.txt
│ │ └── transcript_parser.py
│ ├── quickstart
│ │ ├── quickstart_falkordb.py
│ │ ├── quickstart_neo4j.py
│ │ ├── quickstart_neptune.py
│ │ ├── README.md
│ │ └── requirements.txt
│ └── wizard_of_oz
│ ├── parser.py
│ ├── runner.py
│ └── woo.txt
├── graphiti_core
│ ├── __init__.py
│ ├── cross_encoder
│ │ ├── __init__.py
│ │ ├── bge_reranker_client.py
│ │ ├── client.py
│ │ ├── gemini_reranker_client.py
│ │ └── openai_reranker_client.py
│ ├── decorators.py
│ ├── driver
│ │ ├── __init__.py
│ │ ├── driver.py
│ │ ├── falkordb_driver.py
│ │ ├── graph_operations
│ │ │ └── graph_operations.py
│ │ ├── kuzu_driver.py
│ │ ├── neo4j_driver.py
│ │ ├── neptune_driver.py
│ │ └── search_interface
│ │ └── search_interface.py
│ ├── edges.py
│ ├── embedder
│ │ ├── __init__.py
│ │ ├── azure_openai.py
│ │ ├── client.py
│ │ ├── gemini.py
│ │ ├── openai.py
│ │ └── voyage.py
│ ├── errors.py
│ ├── graph_queries.py
│ ├── graphiti_types.py
│ ├── graphiti.py
│ ├── helpers.py
│ ├── llm_client
│ │ ├── __init__.py
│ │ ├── anthropic_client.py
│ │ ├── azure_openai_client.py
│ │ ├── client.py
│ │ ├── config.py
│ │ ├── errors.py
│ │ ├── gemini_client.py
│ │ ├── groq_client.py
│ │ ├── openai_base_client.py
│ │ ├── openai_client.py
│ │ ├── openai_generic_client.py
│ │ └── utils.py
│ ├── migrations
│ │ └── __init__.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── edges
│ │ │ ├── __init__.py
│ │ │ └── edge_db_queries.py
│ │ └── nodes
│ │ ├── __init__.py
│ │ └── node_db_queries.py
│ ├── nodes.py
│ ├── prompts
│ │ ├── __init__.py
│ │ ├── dedupe_edges.py
│ │ ├── dedupe_nodes.py
│ │ ├── eval.py
│ │ ├── extract_edge_dates.py
│ │ ├── extract_edges.py
│ │ ├── extract_nodes.py
│ │ ├── invalidate_edges.py
│ │ ├── lib.py
│ │ ├── models.py
│ │ ├── prompt_helpers.py
│ │ ├── snippets.py
│ │ └── summarize_nodes.py
│ ├── py.typed
│ ├── search
│ │ ├── __init__.py
│ │ ├── search_config_recipes.py
│ │ ├── search_config.py
│ │ ├── search_filters.py
│ │ ├── search_helpers.py
│ │ ├── search_utils.py
│ │ └── search.py
│ ├── telemetry
│ │ ├── __init__.py
│ │ └── telemetry.py
│ ├── tracer.py
│ └── utils
│ ├── __init__.py
│ ├── bulk_utils.py
│ ├── datetime_utils.py
│ ├── maintenance
│ │ ├── __init__.py
│ │ ├── community_operations.py
│ │ ├── dedup_helpers.py
│ │ ├── edge_operations.py
│ │ ├── graph_data_operations.py
│ │ ├── node_operations.py
│ │ └── temporal_operations.py
│ ├── ontology_utils
│ │ └── entity_types_utils.py
│ └── text_utils.py
├── images
│ ├── arxiv-screenshot.png
│ ├── graphiti-graph-intro.gif
│ ├── graphiti-intro-slides-stock-2.gif
│ └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│ ├── .env.example
│ ├── .python-version
│ ├── config
│ │ ├── config-docker-falkordb-combined.yaml
│ │ ├── config-docker-falkordb.yaml
│ │ ├── config-docker-neo4j.yaml
│ │ ├── config.yaml
│ │ └── mcp_config_stdio_example.json
│ ├── docker
│ │ ├── build-standalone.sh
│ │ ├── build-with-version.sh
│ │ ├── docker-compose-falkordb.yml
│ │ ├── docker-compose-neo4j.yml
│ │ ├── docker-compose.yml
│ │ ├── Dockerfile
│ │ ├── Dockerfile.standalone
│ │ ├── github-actions-example.yml
│ │ ├── README-falkordb-combined.md
│ │ └── README.md
│ ├── docs
│ │ └── cursor_rules.md
│ ├── main.py
│ ├── pyproject.toml
│ ├── pytest.ini
│ ├── README.md
│ ├── src
│ │ ├── __init__.py
│ │ ├── config
│ │ │ ├── __init__.py
│ │ │ └── schema.py
│ │ ├── graphiti_mcp_server.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ ├── entity_types.py
│ │ │ └── response_types.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── factories.py
│ │ │ └── queue_service.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── formatting.py
│ │ └── utils.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── pytest.ini
│ │ ├── README.md
│ │ ├── run_tests.py
│ │ ├── test_async_operations.py
│ │ ├── test_comprehensive_integration.py
│ │ ├── test_configuration.py
│ │ ├── test_falkordb_integration.py
│ │ ├── test_fixtures.py
│ │ ├── test_http_integration.py
│ │ ├── test_integration.py
│ │ ├── test_mcp_integration.py
│ │ ├── test_mcp_transports.py
│ │ ├── test_stdio_simple.py
│ │ └── test_stress_load.py
│ └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│ ├── .env.example
│ ├── graph_service
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ ├── main.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ └── zep_graphiti.py
│ ├── Makefile
│ ├── pyproject.toml
│ ├── README.md
│ └── uv.lock
├── signatures
│ └── version1
│ └── cla.json
├── tests
│ ├── cross_encoder
│ │ ├── test_bge_reranker_client_int.py
│ │ └── test_gemini_reranker_client.py
│ ├── driver
│ │ ├── __init__.py
│ │ └── test_falkordb_driver.py
│ ├── embedder
│ │ ├── embedder_fixtures.py
│ │ ├── test_gemini.py
│ │ ├── test_openai.py
│ │ └── test_voyage.py
│ ├── evals
│ │ ├── data
│ │ │ └── longmemeval_data
│ │ │ ├── longmemeval_oracle.json
│ │ │ └── README.md
│ │ ├── eval_cli.py
│ │ ├── eval_e2e_graph_building.py
│ │ ├── pytest.ini
│ │ └── utils.py
│ ├── helpers_test.py
│ ├── llm_client
│ │ ├── test_anthropic_client_int.py
│ │ ├── test_anthropic_client.py
│ │ ├── test_azure_openai_client.py
│ │ ├── test_client.py
│ │ ├── test_errors.py
│ │ └── test_gemini_client.py
│ ├── test_edge_int.py
│ ├── test_entity_exclusion_int.py
│ ├── test_graphiti_int.py
│ ├── test_graphiti_mock.py
│ ├── test_node_int.py
│ ├── test_text_utils.py
│ └── utils
│ ├── maintenance
│ │ ├── test_bulk_utils.py
│ │ ├── test_edge_operations.py
│ │ ├── test_node_operations.py
│ │ └── test_temporal_operations_int.py
│ └── search
│ └── search_utils_test.py
├── uv.lock
└── Zep-CLA.md
```
# Files
--------------------------------------------------------------------------------
/mcp_server/tests/test_async_operations.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """
3 | Asynchronous operation tests for Graphiti MCP Server.
4 | Tests concurrent operations, queue management, and async patterns.
5 | """
6 |
7 | import asyncio
8 | import contextlib
9 | import json
10 | import time
11 |
12 | import pytest
13 | from test_fixtures import (
14 | TestDataGenerator,
15 | graphiti_test_client,
16 | )
17 |
18 |
19 | class TestAsyncQueueManagement:
20 | """Test asynchronous queue operations and episode processing."""
21 |
22 | @pytest.mark.asyncio
23 | async def test_sequential_queue_processing(self):
24 | """Verify episodes are processed sequentially within a group."""
25 | async with graphiti_test_client() as (session, group_id):
26 | # Add multiple episodes quickly
27 | episodes = []
28 | for i in range(5):
29 | result = await session.call_tool(
30 | 'add_memory',
31 | {
32 | 'name': f'Sequential Test {i}',
33 | 'episode_body': f'Episode {i} with timestamp {time.time()}',
34 | 'source': 'text',
35 | 'source_description': 'sequential test',
36 | 'group_id': group_id,
37 | 'reference_id': f'seq_{i}', # Add reference for tracking
38 | },
39 | )
40 | episodes.append(result)
41 |
42 | # Wait for processing
43 | await asyncio.sleep(10) # Allow time for sequential processing
44 |
45 | # Retrieve episodes and verify order
46 | result = await session.call_tool('get_episodes', {'group_id': group_id, 'last_n': 10})
47 |
48 | processed_episodes = json.loads(result.content[0].text)['episodes']
49 |
50 | # Verify all episodes were processed
51 | assert len(processed_episodes) >= 5, (
52 | f'Expected at least 5 episodes, got {len(processed_episodes)}'
53 | )
54 |
55 | # Verify sequential processing (timestamps should be ordered)
56 | timestamps = [ep.get('created_at') for ep in processed_episodes]
57 | assert timestamps == sorted(timestamps), 'Episodes not processed in order'
58 |
59 | @pytest.mark.asyncio
60 | async def test_concurrent_group_processing(self):
61 | """Test that different groups can process concurrently."""
62 | async with graphiti_test_client() as (session, _):
63 | groups = [f'group_{i}_{time.time()}' for i in range(3)]
64 | tasks = []
65 |
66 | # Create tasks for different groups
67 | for group_id in groups:
68 | for j in range(2):
69 | task = session.call_tool(
70 | 'add_memory',
71 | {
72 | 'name': f'Group {group_id} Episode {j}',
73 | 'episode_body': f'Content for {group_id}',
74 | 'source': 'text',
75 | 'source_description': 'concurrent test',
76 | 'group_id': group_id,
77 | },
78 | )
79 | tasks.append(task)
80 |
81 | # Execute all tasks concurrently
82 | start_time = time.time()
83 | results = await asyncio.gather(*tasks, return_exceptions=True)
84 | execution_time = time.time() - start_time
85 |
86 | # Verify all succeeded
87 | failures = [r for r in results if isinstance(r, Exception)]
88 | assert not failures, f'Concurrent operations failed: {failures}'
89 |
90 | # Check that execution was actually concurrent (should be faster than sequential)
91 | # Sequential would take at least 6 * processing_time
92 | assert execution_time < 30, f'Concurrent execution too slow: {execution_time}s'
93 |
94 | @pytest.mark.asyncio
95 | async def test_queue_overflow_handling(self):
96 | """Test behavior when queue reaches capacity."""
97 | async with graphiti_test_client() as (session, group_id):
98 | # Attempt to add many episodes rapidly
99 | tasks = []
100 | for i in range(100): # Large number to potentially overflow
101 | task = session.call_tool(
102 | 'add_memory',
103 | {
104 | 'name': f'Overflow Test {i}',
105 | 'episode_body': f'Episode {i}',
106 | 'source': 'text',
107 | 'source_description': 'overflow test',
108 | 'group_id': group_id,
109 | },
110 | )
111 | tasks.append(task)
112 |
113 | # Execute with gathering to catch any failures
114 | results = await asyncio.gather(*tasks, return_exceptions=True)
115 |
116 | # Count successful queuing
117 | successful = sum(1 for r in results if not isinstance(r, Exception))
118 |
119 | # Should handle overflow gracefully
120 | assert successful > 0, 'No episodes were queued successfully'
121 |
122 | # Log overflow behavior
123 | if successful < 100:
124 | print(f'Queue overflow: {successful}/100 episodes queued')
125 |
126 |
127 | class TestConcurrentOperations:
128 | """Test concurrent tool calls and operations."""
129 |
130 | @pytest.mark.asyncio
131 | async def test_concurrent_search_operations(self):
132 | """Test multiple concurrent search operations."""
133 | async with graphiti_test_client() as (session, group_id):
134 | # First, add some test data
135 | data_gen = TestDataGenerator()
136 |
137 | add_tasks = []
138 | for _ in range(5):
139 | task = session.call_tool(
140 | 'add_memory',
141 | {
142 | 'name': 'Search Test Data',
143 | 'episode_body': data_gen.generate_technical_document(),
144 | 'source': 'text',
145 | 'source_description': 'search test',
146 | 'group_id': group_id,
147 | },
148 | )
149 | add_tasks.append(task)
150 |
151 | await asyncio.gather(*add_tasks)
152 | await asyncio.sleep(15) # Wait for processing
153 |
154 | # Now perform concurrent searches
155 | search_queries = [
156 | 'architecture',
157 | 'performance',
158 | 'implementation',
159 | 'dependencies',
160 | 'latency',
161 | ]
162 |
163 | search_tasks = []
164 | for query in search_queries:
165 | task = session.call_tool(
166 | 'search_memory_nodes',
167 | {
168 | 'query': query,
169 | 'group_id': group_id,
170 | 'limit': 10,
171 | },
172 | )
173 | search_tasks.append(task)
174 |
175 | start_time = time.time()
176 | results = await asyncio.gather(*search_tasks, return_exceptions=True)
177 | search_time = time.time() - start_time
178 |
179 | # Verify all searches completed
180 | failures = [r for r in results if isinstance(r, Exception)]
181 | assert not failures, f'Search operations failed: {failures}'
182 |
183 | # Verify concurrent execution efficiency
184 | assert search_time < len(search_queries) * 2, 'Searches not executing concurrently'
185 |
186 | @pytest.mark.asyncio
187 | async def test_mixed_operation_concurrency(self):
188 | """Test different types of operations running concurrently."""
189 | async with graphiti_test_client() as (session, group_id):
190 | operations = []
191 |
192 | # Add memory operation
193 | operations.append(
194 | session.call_tool(
195 | 'add_memory',
196 | {
197 | 'name': 'Mixed Op Test',
198 | 'episode_body': 'Testing mixed operations',
199 | 'source': 'text',
200 | 'source_description': 'test',
201 | 'group_id': group_id,
202 | },
203 | )
204 | )
205 |
206 | # Search operation
207 | operations.append(
208 | session.call_tool(
209 | 'search_memory_nodes',
210 | {
211 | 'query': 'test',
212 | 'group_id': group_id,
213 | 'limit': 5,
214 | },
215 | )
216 | )
217 |
218 | # Get episodes operation
219 | operations.append(
220 | session.call_tool(
221 | 'get_episodes',
222 | {
223 | 'group_id': group_id,
224 | 'last_n': 10,
225 | },
226 | )
227 | )
228 |
229 | # Get status operation
230 | operations.append(session.call_tool('get_status', {}))
231 |
232 | # Execute all concurrently
233 | results = await asyncio.gather(*operations, return_exceptions=True)
234 |
235 | # Check results
236 | for i, result in enumerate(results):
237 | assert not isinstance(result, Exception), f'Operation {i} failed: {result}'
238 |
239 |
240 | class TestAsyncErrorHandling:
241 | """Test async error handling and recovery."""
242 |
243 | @pytest.mark.asyncio
244 | async def test_timeout_recovery(self):
245 | """Test recovery from operation timeouts."""
246 | async with graphiti_test_client() as (session, group_id):
247 | # Create a very large episode that might time out
248 | large_content = 'x' * 1000000 # 1MB of data
249 |
250 | with contextlib.suppress(asyncio.TimeoutError):
251 | await asyncio.wait_for(
252 | session.call_tool(
253 | 'add_memory',
254 | {
255 | 'name': 'Timeout Test',
256 | 'episode_body': large_content,
257 | 'source': 'text',
258 | 'source_description': 'timeout test',
259 | 'group_id': group_id,
260 | },
261 | ),
262 | timeout=2.0, # Short timeout - expected to timeout
263 | )
264 |
265 | # Verify server is still responsive after timeout
266 | status_result = await session.call_tool('get_status', {})
267 | assert status_result is not None, 'Server unresponsive after timeout'
268 |
269 | @pytest.mark.asyncio
270 | async def test_cancellation_handling(self):
271 | """Test proper handling of cancelled operations."""
272 | async with graphiti_test_client() as (session, group_id):
273 | # Start a long-running operation
274 | task = asyncio.create_task(
275 | session.call_tool(
276 | 'add_memory',
277 | {
278 | 'name': 'Cancellation Test',
279 | 'episode_body': TestDataGenerator.generate_technical_document(),
280 | 'source': 'text',
281 | 'source_description': 'cancel test',
282 | 'group_id': group_id,
283 | },
284 | )
285 | )
286 |
287 | # Cancel after a short delay
288 | await asyncio.sleep(0.1)
289 | task.cancel()
290 |
291 | # Verify cancellation was handled
292 | with pytest.raises(asyncio.CancelledError):
293 | await task
294 |
295 | # Server should still be operational
296 | result = await session.call_tool('get_status', {})
297 | assert result is not None
298 |
299 | @pytest.mark.asyncio
300 | async def test_exception_propagation(self):
301 | """Test that exceptions are properly propagated in async context."""
302 | async with graphiti_test_client() as (session, group_id):
303 | # Call with invalid arguments
304 | with pytest.raises(ValueError):
305 | await session.call_tool(
306 | 'add_memory',
307 | {
308 | # Missing required fields
309 | 'group_id': group_id,
310 | },
311 | )
312 |
313 | # Server should remain operational
314 | status = await session.call_tool('get_status', {})
315 | assert status is not None
316 |
317 |
318 | class TestAsyncPerformance:
319 | """Performance tests for async operations."""
320 |
321 | @pytest.mark.asyncio
322 | async def test_async_throughput(self, performance_benchmark):
323 | """Measure throughput of async operations."""
324 | async with graphiti_test_client() as (session, group_id):
325 | num_operations = 50
326 | start_time = time.time()
327 |
328 | # Create many concurrent operations
329 | tasks = []
330 | for i in range(num_operations):
331 | task = session.call_tool(
332 | 'add_memory',
333 | {
334 | 'name': f'Throughput Test {i}',
335 | 'episode_body': f'Content {i}',
336 | 'source': 'text',
337 | 'source_description': 'throughput test',
338 | 'group_id': group_id,
339 | },
340 | )
341 | tasks.append(task)
342 |
343 | # Execute all
344 | results = await asyncio.gather(*tasks, return_exceptions=True)
345 | total_time = time.time() - start_time
346 |
347 | # Calculate metrics
348 | successful = sum(1 for r in results if not isinstance(r, Exception))
349 | throughput = successful / total_time
350 |
351 | performance_benchmark.record('async_throughput', throughput)
352 |
353 | # Log results
354 | print('\nAsync Throughput Test:')
355 | print(f' Operations: {num_operations}')
356 | print(f' Successful: {successful}')
357 | print(f' Total time: {total_time:.2f}s')
358 | print(f' Throughput: {throughput:.2f} ops/s')
359 |
360 | # Assert minimum throughput
361 | assert throughput > 1.0, f'Throughput too low: {throughput:.2f} ops/s'
362 |
363 | @pytest.mark.asyncio
364 | async def test_latency_under_load(self, performance_benchmark):
365 | """Test operation latency under concurrent load."""
366 | async with graphiti_test_client() as (session, group_id):
367 | # Create background load
368 | background_tasks = []
369 | for i in range(10):
370 | task = asyncio.create_task(
371 | session.call_tool(
372 | 'add_memory',
373 | {
374 | 'name': f'Background {i}',
375 | 'episode_body': TestDataGenerator.generate_technical_document(),
376 | 'source': 'text',
377 | 'source_description': 'background',
378 | 'group_id': f'background_{group_id}',
379 | },
380 | )
381 | )
382 | background_tasks.append(task)
383 |
384 | # Measure latency of operations under load
385 | latencies = []
386 | for _ in range(5):
387 | start = time.time()
388 | await session.call_tool('get_status', {})
389 | latency = time.time() - start
390 | latencies.append(latency)
391 | performance_benchmark.record('latency_under_load', latency)
392 |
393 | # Clean up background tasks
394 | for task in background_tasks:
395 | task.cancel()
396 |
397 | # Analyze latencies
398 | avg_latency = sum(latencies) / len(latencies)
399 | max_latency = max(latencies)
400 |
401 | print('\nLatency Under Load:')
402 | print(f' Average: {avg_latency:.3f}s')
403 | print(f' Max: {max_latency:.3f}s')
404 |
405 | # Assert acceptable latency
406 | assert avg_latency < 2.0, f'Average latency too high: {avg_latency:.3f}s'
407 | assert max_latency < 5.0, f'Max latency too high: {max_latency:.3f}s'
408 |
409 |
410 | class TestAsyncStreamHandling:
411 | """Test handling of streaming responses and data."""
412 |
413 | @pytest.mark.asyncio
414 | async def test_large_response_streaming(self):
415 | """Test handling of large streamed responses."""
416 | async with graphiti_test_client() as (session, group_id):
417 | # Add many episodes
418 | for i in range(20):
419 | await session.call_tool(
420 | 'add_memory',
421 | {
422 | 'name': f'Stream Test {i}',
423 | 'episode_body': f'Episode content {i}',
424 | 'source': 'text',
425 | 'source_description': 'stream test',
426 | 'group_id': group_id,
427 | },
428 | )
429 |
430 | # Wait for processing
431 | await asyncio.sleep(30)
432 |
433 | # Request large result set
434 | result = await session.call_tool(
435 | 'get_episodes',
436 | {
437 | 'group_id': group_id,
438 | 'last_n': 100, # Request all
439 | },
440 | )
441 |
442 | # Verify response handling
443 | episodes = json.loads(result.content[0].text)['episodes']
444 | assert len(episodes) >= 20, f'Expected at least 20 episodes, got {len(episodes)}'
445 |
446 | @pytest.mark.asyncio
447 | async def test_incremental_processing(self):
448 | """Test incremental processing of results."""
449 | async with graphiti_test_client() as (session, group_id):
450 | # Add episodes incrementally
451 | for batch in range(3):
452 | batch_tasks = []
453 | for i in range(5):
454 | task = session.call_tool(
455 | 'add_memory',
456 | {
457 | 'name': f'Batch {batch} Item {i}',
458 | 'episode_body': f'Content for batch {batch}',
459 | 'source': 'text',
460 | 'source_description': 'incremental test',
461 | 'group_id': group_id,
462 | },
463 | )
464 | batch_tasks.append(task)
465 |
466 | # Process batch
467 | await asyncio.gather(*batch_tasks)
468 |
469 | # Wait for this batch to process
470 | await asyncio.sleep(10)
471 |
472 | # Verify incremental results
473 | result = await session.call_tool(
474 | 'get_episodes',
475 | {
476 | 'group_id': group_id,
477 | 'last_n': 100,
478 | },
479 | )
480 |
481 | episodes = json.loads(result.content[0].text)['episodes']
482 | expected_min = (batch + 1) * 5
483 | assert len(episodes) >= expected_min, (
484 | f'Batch {batch}: Expected at least {expected_min} episodes'
485 | )
486 |
487 |
488 | if __name__ == '__main__':
489 | pytest.main([__file__, '-v', '--asyncio-mode=auto'])
490 |
```
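The tests above import `graphiti_test_client` and `TestDataGenerator` from `test_fixtures.py`, which is not included on this page. Below is a minimal sketch of the fixture's apparent shape, inferred from how the tests use it (an async context manager yielding an MCP `ClientSession` plus a fresh `group_id`) and from the stdio startup code in `test_mcp_integration.py` further down; the real fixture may differ.

```python
# Hypothetical sketch only -- the actual fixture lives in test_fixtures.py (not shown here).
# Inferred from usage above: an async context manager yielding (session, group_id).
import time
from contextlib import asynccontextmanager

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client


@asynccontextmanager
async def graphiti_test_client():
    server_params = StdioServerParameters(
        command='uv',
        args=['run', 'main.py', '--transport', 'stdio'],
    )
    group_id = f'test_group_{int(time.time())}'  # unique group per test run
    async with stdio_client(server_params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            yield session, group_id
```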
--------------------------------------------------------------------------------
/graphiti_core/search/search.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import logging
18 | from collections import defaultdict
19 | from time import time
20 |
21 | from graphiti_core.cross_encoder.client import CrossEncoderClient
22 | from graphiti_core.driver.driver import GraphDriver
23 | from graphiti_core.edges import EntityEdge
24 | from graphiti_core.embedder.client import EMBEDDING_DIM
25 | from graphiti_core.errors import SearchRerankerError
26 | from graphiti_core.graphiti_types import GraphitiClients
27 | from graphiti_core.helpers import semaphore_gather
28 | from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
29 | from graphiti_core.search.search_config import (
30 | DEFAULT_SEARCH_LIMIT,
31 | CommunityReranker,
32 | CommunitySearchConfig,
33 | CommunitySearchMethod,
34 | EdgeReranker,
35 | EdgeSearchConfig,
36 | EdgeSearchMethod,
37 | EpisodeReranker,
38 | EpisodeSearchConfig,
39 | NodeReranker,
40 | NodeSearchConfig,
41 | NodeSearchMethod,
42 | SearchConfig,
43 | SearchResults,
44 | )
45 | from graphiti_core.search.search_filters import SearchFilters
46 | from graphiti_core.search.search_utils import (
47 | community_fulltext_search,
48 | community_similarity_search,
49 | edge_bfs_search,
50 | edge_fulltext_search,
51 | edge_similarity_search,
52 | episode_fulltext_search,
53 | episode_mentions_reranker,
54 | get_embeddings_for_communities,
55 | get_embeddings_for_edges,
56 | get_embeddings_for_nodes,
57 | maximal_marginal_relevance,
58 | node_bfs_search,
59 | node_distance_reranker,
60 | node_fulltext_search,
61 | node_similarity_search,
62 | rrf,
63 | )
64 |
65 | logger = logging.getLogger(__name__)
66 |
67 |
68 | async def search(
69 | clients: GraphitiClients,
70 | query: str,
71 | group_ids: list[str] | None,
72 | config: SearchConfig,
73 | search_filter: SearchFilters,
74 | center_node_uuid: str | None = None,
75 | bfs_origin_node_uuids: list[str] | None = None,
76 | query_vector: list[float] | None = None,
77 | driver: GraphDriver | None = None,
78 | ) -> SearchResults:
79 | start = time()
80 |
81 | driver = driver or clients.driver
82 | embedder = clients.embedder
83 | cross_encoder = clients.cross_encoder
84 |
85 | if query.strip() == '':
86 | return SearchResults()
87 |
88 | if (
89 | config.edge_config
90 | and EdgeSearchMethod.cosine_similarity in config.edge_config.search_methods
91 | or config.edge_config
92 | and EdgeReranker.mmr == config.edge_config.reranker
93 | or config.node_config
94 | and NodeSearchMethod.cosine_similarity in config.node_config.search_methods
95 | or config.node_config
96 | and NodeReranker.mmr == config.node_config.reranker
97 | or (
98 | config.community_config
99 | and CommunitySearchMethod.cosine_similarity in config.community_config.search_methods
100 | )
101 | or (config.community_config and CommunityReranker.mmr == config.community_config.reranker)
102 | ):
103 | search_vector = (
104 | query_vector
105 | if query_vector is not None
106 | else await embedder.create(input_data=[query.replace('\n', ' ')])
107 | )
108 | else:
109 | search_vector = [0.0] * EMBEDDING_DIM
110 |
111 | # if group_ids is empty, set it to None
112 | group_ids = group_ids if group_ids and group_ids != [''] else None
113 | (
114 | (edges, edge_reranker_scores),
115 | (nodes, node_reranker_scores),
116 | (episodes, episode_reranker_scores),
117 | (communities, community_reranker_scores),
118 | ) = await semaphore_gather(
119 | edge_search(
120 | driver,
121 | cross_encoder,
122 | query,
123 | search_vector,
124 | group_ids,
125 | config.edge_config,
126 | search_filter,
127 | center_node_uuid,
128 | bfs_origin_node_uuids,
129 | config.limit,
130 | config.reranker_min_score,
131 | ),
132 | node_search(
133 | driver,
134 | cross_encoder,
135 | query,
136 | search_vector,
137 | group_ids,
138 | config.node_config,
139 | search_filter,
140 | center_node_uuid,
141 | bfs_origin_node_uuids,
142 | config.limit,
143 | config.reranker_min_score,
144 | ),
145 | episode_search(
146 | driver,
147 | cross_encoder,
148 | query,
149 | search_vector,
150 | group_ids,
151 | config.episode_config,
152 | search_filter,
153 | config.limit,
154 | config.reranker_min_score,
155 | ),
156 | community_search(
157 | driver,
158 | cross_encoder,
159 | query,
160 | search_vector,
161 | group_ids,
162 | config.community_config,
163 | config.limit,
164 | config.reranker_min_score,
165 | ),
166 | )
167 |
168 | results = SearchResults(
169 | edges=edges,
170 | edge_reranker_scores=edge_reranker_scores,
171 | nodes=nodes,
172 | node_reranker_scores=node_reranker_scores,
173 | episodes=episodes,
174 | episode_reranker_scores=episode_reranker_scores,
175 | communities=communities,
176 | community_reranker_scores=community_reranker_scores,
177 | )
178 |
179 | latency = (time() - start) * 1000
180 |
181 | logger.debug(f'search returned context for query {query} in {latency} ms')
182 |
183 | return results
184 |
185 |
186 | async def edge_search(
187 | driver: GraphDriver,
188 | cross_encoder: CrossEncoderClient,
189 | query: str,
190 | query_vector: list[float],
191 | group_ids: list[str] | None,
192 | config: EdgeSearchConfig | None,
193 | search_filter: SearchFilters,
194 | center_node_uuid: str | None = None,
195 | bfs_origin_node_uuids: list[str] | None = None,
196 | limit=DEFAULT_SEARCH_LIMIT,
197 | reranker_min_score: float = 0,
198 | ) -> tuple[list[EntityEdge], list[float]]:
199 | if config is None:
200 | return [], []
201 |
202 | # Build search tasks based on configured search methods
203 | search_tasks = []
204 | if EdgeSearchMethod.bm25 in config.search_methods:
205 | search_tasks.append(
206 | edge_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
207 | )
208 | if EdgeSearchMethod.cosine_similarity in config.search_methods:
209 | search_tasks.append(
210 | edge_similarity_search(
211 | driver,
212 | query_vector,
213 | None,
214 | None,
215 | search_filter,
216 | group_ids,
217 | 2 * limit,
218 | config.sim_min_score,
219 | )
220 | )
221 | if EdgeSearchMethod.bfs in config.search_methods:
222 | search_tasks.append(
223 | edge_bfs_search(
224 | driver,
225 | bfs_origin_node_uuids,
226 | config.bfs_max_depth,
227 | search_filter,
228 | group_ids,
229 | 2 * limit,
230 | )
231 | )
232 |
233 | # Execute only the configured search methods
234 | search_results: list[list[EntityEdge]] = []
235 | if search_tasks:
236 | search_results = list(await semaphore_gather(*search_tasks))
237 |
238 | if EdgeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
239 | source_node_uuids = [edge.source_node_uuid for result in search_results for edge in result]
240 | search_results.append(
241 | await edge_bfs_search(
242 | driver,
243 | source_node_uuids,
244 | config.bfs_max_depth,
245 | search_filter,
246 | group_ids,
247 | 2 * limit,
248 | )
249 | )
250 |
251 | edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result}
252 |
253 | reranked_uuids: list[str] = []
254 | edge_scores: list[float] = []
255 | if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions:
256 | search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
257 |
258 | reranked_uuids, edge_scores = rrf(search_result_uuids, min_score=reranker_min_score)
259 | elif config.reranker == EdgeReranker.mmr:
260 | search_result_uuids_and_vectors = await get_embeddings_for_edges(
261 | driver, list(edge_uuid_map.values())
262 | )
263 | reranked_uuids, edge_scores = maximal_marginal_relevance(
264 | query_vector,
265 | search_result_uuids_and_vectors,
266 | config.mmr_lambda,
267 | reranker_min_score,
268 | )
269 | elif config.reranker == EdgeReranker.cross_encoder:
270 | fact_to_uuid_map = {edge.fact: edge.uuid for edge in list(edge_uuid_map.values())[:limit]}
271 | reranked_facts = await cross_encoder.rank(query, list(fact_to_uuid_map.keys()))
272 | reranked_uuids = [
273 | fact_to_uuid_map[fact] for fact, score in reranked_facts if score >= reranker_min_score
274 | ]
275 | edge_scores = [score for _, score in reranked_facts if score >= reranker_min_score]
276 | elif config.reranker == EdgeReranker.node_distance:
277 | if center_node_uuid is None:
278 | raise SearchRerankerError('No center node provided for Node Distance reranker')
279 |
280 | # use rrf as a preliminary sort
281 | sorted_result_uuids, node_scores = rrf(
282 | [[edge.uuid for edge in result] for result in search_results],
283 | min_score=reranker_min_score,
284 | )
285 | sorted_results = [edge_uuid_map[uuid] for uuid in sorted_result_uuids]
286 |
287 | # node distance reranking
288 | source_to_edge_uuid_map = defaultdict(list)
289 | for edge in sorted_results:
290 | source_to_edge_uuid_map[edge.source_node_uuid].append(edge.uuid)
291 |
292 | source_uuids = [source_node_uuid for source_node_uuid in source_to_edge_uuid_map]
293 |
294 | reranked_node_uuids, edge_scores = await node_distance_reranker(
295 | driver, source_uuids, center_node_uuid, min_score=reranker_min_score
296 | )
297 |
298 | for node_uuid in reranked_node_uuids:
299 | reranked_uuids.extend(source_to_edge_uuid_map[node_uuid])
300 |
301 | reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
302 |
303 | if config.reranker == EdgeReranker.episode_mentions:
304 | reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes))
305 |
306 | return reranked_edges[:limit], edge_scores[:limit]
307 |
308 |
309 | async def node_search(
310 | driver: GraphDriver,
311 | cross_encoder: CrossEncoderClient,
312 | query: str,
313 | query_vector: list[float],
314 | group_ids: list[str] | None,
315 | config: NodeSearchConfig | None,
316 | search_filter: SearchFilters,
317 | center_node_uuid: str | None = None,
318 | bfs_origin_node_uuids: list[str] | None = None,
319 | limit=DEFAULT_SEARCH_LIMIT,
320 | reranker_min_score: float = 0,
321 | ) -> tuple[list[EntityNode], list[float]]:
322 | if config is None:
323 | return [], []
324 |
325 | # Build search tasks based on configured search methods
326 | search_tasks = []
327 | if NodeSearchMethod.bm25 in config.search_methods:
328 | search_tasks.append(
329 | node_fulltext_search(driver, query, search_filter, group_ids, 2 * limit)
330 | )
331 | if NodeSearchMethod.cosine_similarity in config.search_methods:
332 | search_tasks.append(
333 | node_similarity_search(
334 | driver,
335 | query_vector,
336 | search_filter,
337 | group_ids,
338 | 2 * limit,
339 | config.sim_min_score,
340 | )
341 | )
342 | if NodeSearchMethod.bfs in config.search_methods:
343 | search_tasks.append(
344 | node_bfs_search(
345 | driver,
346 | bfs_origin_node_uuids,
347 | search_filter,
348 | config.bfs_max_depth,
349 | group_ids,
350 | 2 * limit,
351 | )
352 | )
353 |
354 | # Execute only the configured search methods
355 | search_results: list[list[EntityNode]] = []
356 | if search_tasks:
357 | search_results = list(await semaphore_gather(*search_tasks))
358 |
359 | if NodeSearchMethod.bfs in config.search_methods and bfs_origin_node_uuids is None:
360 | origin_node_uuids = [node.uuid for result in search_results for node in result]
361 | search_results.append(
362 | await node_bfs_search(
363 | driver,
364 | origin_node_uuids,
365 | search_filter,
366 | config.bfs_max_depth,
367 | group_ids,
368 | 2 * limit,
369 | )
370 | )
371 |
372 | search_result_uuids = [[node.uuid for node in result] for result in search_results]
373 | node_uuid_map = {node.uuid: node for result in search_results for node in result}
374 |
375 | reranked_uuids: list[str] = []
376 | node_scores: list[float] = []
377 | if config.reranker == NodeReranker.rrf:
378 | reranked_uuids, node_scores = rrf(search_result_uuids, min_score=reranker_min_score)
379 | elif config.reranker == NodeReranker.mmr:
380 | search_result_uuids_and_vectors = await get_embeddings_for_nodes(
381 | driver, list(node_uuid_map.values())
382 | )
383 |
384 | reranked_uuids, node_scores = maximal_marginal_relevance(
385 | query_vector,
386 | search_result_uuids_and_vectors,
387 | config.mmr_lambda,
388 | reranker_min_score,
389 | )
390 | elif config.reranker == NodeReranker.cross_encoder:
391 | name_to_uuid_map = {node.name: node.uuid for node in list(node_uuid_map.values())}
392 |
393 | reranked_node_names = await cross_encoder.rank(query, list(name_to_uuid_map.keys()))
394 | reranked_uuids = [
395 | name_to_uuid_map[name]
396 | for name, score in reranked_node_names
397 | if score >= reranker_min_score
398 | ]
399 | node_scores = [score for _, score in reranked_node_names if score >= reranker_min_score]
400 | elif config.reranker == NodeReranker.episode_mentions:
401 | reranked_uuids, node_scores = await episode_mentions_reranker(
402 | driver, search_result_uuids, min_score=reranker_min_score
403 | )
404 | elif config.reranker == NodeReranker.node_distance:
405 | if center_node_uuid is None:
406 | raise SearchRerankerError('No center node provided for Node Distance reranker')
407 | reranked_uuids, node_scores = await node_distance_reranker(
408 | driver,
409 | rrf(search_result_uuids, min_score=reranker_min_score)[0],
410 | center_node_uuid,
411 | min_score=reranker_min_score,
412 | )
413 |
414 | reranked_nodes = [node_uuid_map[uuid] for uuid in reranked_uuids]
415 |
416 | return reranked_nodes[:limit], node_scores[:limit]
417 |
418 |
419 | async def episode_search(
420 | driver: GraphDriver,
421 | cross_encoder: CrossEncoderClient,
422 | query: str,
423 | _query_vector: list[float],
424 | group_ids: list[str] | None,
425 | config: EpisodeSearchConfig | None,
426 | search_filter: SearchFilters,
427 | limit=DEFAULT_SEARCH_LIMIT,
428 | reranker_min_score: float = 0,
429 | ) -> tuple[list[EpisodicNode], list[float]]:
430 | if config is None:
431 | return [], []
432 | search_results: list[list[EpisodicNode]] = list(
433 | await semaphore_gather(
434 | *[
435 | episode_fulltext_search(driver, query, search_filter, group_ids, 2 * limit),
436 | ]
437 | )
438 | )
439 |
440 | search_result_uuids = [[episode.uuid for episode in result] for result in search_results]
441 | episode_uuid_map = {episode.uuid: episode for result in search_results for episode in result}
442 |
443 | reranked_uuids: list[str] = []
444 | episode_scores: list[float] = []
445 | if config.reranker == EpisodeReranker.rrf:
446 | reranked_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
447 |
448 | elif config.reranker == EpisodeReranker.cross_encoder:
449 | # use rrf as a preliminary reranker
450 | rrf_result_uuids, episode_scores = rrf(search_result_uuids, min_score=reranker_min_score)
451 | rrf_results = [episode_uuid_map[uuid] for uuid in rrf_result_uuids][:limit]
452 |
453 | content_to_uuid_map = {episode.content: episode.uuid for episode in rrf_results}
454 |
455 | reranked_contents = await cross_encoder.rank(query, list(content_to_uuid_map.keys()))
456 | reranked_uuids = [
457 | content_to_uuid_map[content]
458 | for content, score in reranked_contents
459 | if score >= reranker_min_score
460 | ]
461 | episode_scores = [score for _, score in reranked_contents if score >= reranker_min_score]
462 |
463 | reranked_episodes = [episode_uuid_map[uuid] for uuid in reranked_uuids]
464 |
465 | return reranked_episodes[:limit], episode_scores[:limit]
466 |
467 |
468 | async def community_search(
469 | driver: GraphDriver,
470 | cross_encoder: CrossEncoderClient,
471 | query: str,
472 | query_vector: list[float],
473 | group_ids: list[str] | None,
474 | config: CommunitySearchConfig | None,
475 | limit=DEFAULT_SEARCH_LIMIT,
476 | reranker_min_score: float = 0,
477 | ) -> tuple[list[CommunityNode], list[float]]:
478 | if config is None:
479 | return [], []
480 |
481 | search_results: list[list[CommunityNode]] = list(
482 | await semaphore_gather(
483 | *[
484 | community_fulltext_search(driver, query, group_ids, 2 * limit),
485 | community_similarity_search(
486 | driver, query_vector, group_ids, 2 * limit, config.sim_min_score
487 | ),
488 | ]
489 | )
490 | )
491 |
492 | search_result_uuids = [[community.uuid for community in result] for result in search_results]
493 | community_uuid_map = {
494 | community.uuid: community for result in search_results for community in result
495 | }
496 |
497 | reranked_uuids: list[str] = []
498 | community_scores: list[float] = []
499 | if config.reranker == CommunityReranker.rrf:
500 | reranked_uuids, community_scores = rrf(search_result_uuids, min_score=reranker_min_score)
501 | elif config.reranker == CommunityReranker.mmr:
502 | search_result_uuids_and_vectors = await get_embeddings_for_communities(
503 | driver, list(community_uuid_map.values())
504 | )
505 |
506 | reranked_uuids, community_scores = maximal_marginal_relevance(
507 | query_vector, search_result_uuids_and_vectors, config.mmr_lambda, reranker_min_score
508 | )
509 | elif config.reranker == CommunityReranker.cross_encoder:
510 | name_to_uuid_map = {node.name: node.uuid for result in search_results for node in result}
511 | reranked_nodes = await cross_encoder.rank(query, list(name_to_uuid_map.keys()))
512 | reranked_uuids = [
513 | name_to_uuid_map[name] for name, score in reranked_nodes if score >= reranker_min_score
514 | ]
515 | community_scores = [score for _, score in reranked_nodes if score >= reranker_min_score]
516 |
517 | reranked_communities = [community_uuid_map[uuid] for uuid in reranked_uuids]
518 |
519 | return reranked_communities[:limit], community_scores[:limit]
520 |
```
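For orientation, here is a usage sketch of the `search()` entry point defined above. The call shape follows the signature in `search.py`; constructing `GraphitiClients` and `SearchConfig` is assumed to happen elsewhere (typically a `Graphiti` instance wires them up), and the no-argument `SearchFilters()` default is an assumption.

```python
# Usage sketch, not verbatim repo code: parameter names match the search()
# signature above; `clients` and `config` are assumed to be pre-built.
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.search.search import search
from graphiti_core.search.search_config import SearchConfig
from graphiti_core.search.search_filters import SearchFilters


async def run_search(clients: GraphitiClients, config: SearchConfig) -> None:
    results = await search(
        clients=clients,
        query='What products has Acme Corp launched?',
        group_ids=None,                 # None (or ['']) applies no group filter
        config=config,
        search_filter=SearchFilters(),  # assumed: default, unfiltered
    )
    # SearchResults carries parallel lists of items and reranker scores.
    for edge, score in zip(results.edges, results.edge_reranker_scores):
        print(edge.fact, score)
```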
--------------------------------------------------------------------------------
/mcp_server/tests/test_mcp_integration.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """
3 | Integration test for the refactored Graphiti MCP Server using the official MCP Python SDK.
4 | Tests all major MCP tools and handles episode processing latency.
5 | """
6 |
7 | import asyncio
8 | import json
9 | import os
10 | import time
11 | from typing import Any
12 |
13 | from mcp import ClientSession, StdioServerParameters
14 | from mcp.client.stdio import stdio_client
15 |
16 |
17 | class GraphitiMCPIntegrationTest:
18 | """Integration test client for Graphiti MCP Server using official MCP SDK."""
19 |
20 | def __init__(self):
21 | self.test_group_id = f'test_group_{int(time.time())}'
22 | self.session = None
23 |
24 | async def __aenter__(self):
25 | """Start the MCP client session."""
26 | # Configure server parameters to run our refactored server
27 | server_params = StdioServerParameters(
28 | command='uv',
29 | args=['run', 'main.py', '--transport', 'stdio'],
30 | env={
31 | 'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
32 | 'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
33 | 'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
34 | 'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY', 'dummy_key_for_testing'),
35 | },
36 | )
37 |
38 | print(f'🚀 Starting MCP client session with test group: {self.test_group_id}')
39 |
40 | # Use the async context manager properly
41 | self.client_context = stdio_client(server_params)
42 | read, write = await self.client_context.__aenter__()
43 | self.session = ClientSession(read, write)
44 | await self.session.initialize()
45 |
46 | return self
47 |
48 | async def __aexit__(self, exc_type, exc_val, exc_tb):
49 | """Close the MCP client session."""
50 | if self.session:
51 | await self.session.close()
52 | if hasattr(self, 'client_context'):
53 | await self.client_context.__aexit__(exc_type, exc_val, exc_tb)
54 |
55 | async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
56 | """Call an MCP tool and return the result."""
57 | try:
58 | result = await self.session.call_tool(tool_name, arguments)
59 | return result.content[0].text if result.content else {'error': 'No content returned'}
60 | except Exception as e:
61 | return {'error': str(e)}
62 |
63 | async def test_server_initialization(self) -> bool:
64 | """Test that the server initializes properly."""
65 | print('🔍 Testing server initialization...')
66 |
67 | try:
68 | # List available tools to verify server is responding
69 | tools_result = await self.session.list_tools()
70 | tools = [tool.name for tool in tools_result.tools]
71 |
72 | expected_tools = [
73 | 'add_memory',
74 | 'search_memory_nodes',
75 | 'search_memory_facts',
76 | 'get_episodes',
77 | 'delete_episode',
78 | 'delete_entity_edge',
79 | 'get_entity_edge',
80 | 'clear_graph',
81 | ]
82 |
83 | available_tools = len([tool for tool in expected_tools if tool in tools])
84 | print(
85 | f' ✅ Server responding with {len(tools)} tools ({available_tools}/{len(expected_tools)} expected)'
86 | )
87 | print(f' Available tools: {", ".join(sorted(tools))}')
88 |
89 | return available_tools >= len(expected_tools) * 0.8 # 80% of expected tools
90 |
91 | except Exception as e:
92 | print(f' ❌ Server initialization failed: {e}')
93 | return False
94 |
95 | async def test_add_memory_operations(self) -> dict[str, bool]:
96 | """Test adding various types of memory episodes."""
97 | print('📝 Testing add_memory operations...')
98 |
99 | results = {}
100 |
101 | # Test 1: Add text episode
102 | print(' Testing text episode...')
103 | try:
104 | result = await self.call_tool(
105 | 'add_memory',
106 | {
107 | 'name': 'Test Company News',
108 | 'episode_body': 'Acme Corp announced a revolutionary new AI product that will transform the industry. The CEO mentioned this is their biggest launch since 2020.',
109 | 'source': 'text',
110 | 'source_description': 'news article',
111 | 'group_id': self.test_group_id,
112 | },
113 | )
114 |
115 | if isinstance(result, str) and 'queued' in result.lower():
116 | print(f' ✅ Text episode: {result}')
117 | results['text'] = True
118 | else:
119 | print(f' ❌ Text episode failed: {result}')
120 | results['text'] = False
121 | except Exception as e:
122 | print(f' ❌ Text episode error: {e}')
123 | results['text'] = False
124 |
125 | # Test 2: Add JSON episode
126 | print(' Testing JSON episode...')
127 | try:
128 | json_data = {
129 | 'company': {'name': 'TechCorp', 'founded': 2010},
130 | 'products': [
131 | {'id': 'P001', 'name': 'CloudSync', 'category': 'software'},
132 | {'id': 'P002', 'name': 'DataMiner', 'category': 'analytics'},
133 | ],
134 | 'employees': 150,
135 | }
136 |
137 | result = await self.call_tool(
138 | 'add_memory',
139 | {
140 | 'name': 'Company Profile',
141 | 'episode_body': json.dumps(json_data),
142 | 'source': 'json',
143 | 'source_description': 'CRM data',
144 | 'group_id': self.test_group_id,
145 | },
146 | )
147 |
148 | if isinstance(result, str) and 'queued' in result.lower():
149 | print(f' ✅ JSON episode: {result}')
150 | results['json'] = True
151 | else:
152 | print(f' ❌ JSON episode failed: {result}')
153 | results['json'] = False
154 | except Exception as e:
155 | print(f' ❌ JSON episode error: {e}')
156 | results['json'] = False
157 |
158 | # Test 3: Add message episode
159 | print(' Testing message episode...')
160 | try:
161 | result = await self.call_tool(
162 | 'add_memory',
163 | {
164 | 'name': 'Customer Support Chat',
165 | 'episode_body': "user: What's your return policy?\nassistant: You can return items within 30 days of purchase with receipt.\nuser: Thanks!",
166 | 'source': 'message',
167 | 'source_description': 'support chat log',
168 | 'group_id': self.test_group_id,
169 | },
170 | )
171 |
172 | if isinstance(result, str) and 'queued' in result.lower():
173 | print(f' ✅ Message episode: {result}')
174 | results['message'] = True
175 | else:
176 | print(f' ❌ Message episode failed: {result}')
177 | results['message'] = False
178 | except Exception as e:
179 | print(f' ❌ Message episode error: {e}')
180 | results['message'] = False
181 |
182 | return results
183 |
184 | async def wait_for_processing(self, max_wait: int = 45) -> bool:
185 | """Wait for episode processing to complete."""
186 | print(f'⏳ Waiting up to {max_wait} seconds for episode processing...')
187 |
188 | for i in range(max_wait):
189 | await asyncio.sleep(1)
190 |
191 | try:
192 | # Check if we have any episodes
193 | result = await self.call_tool(
194 | 'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
195 | )
196 |
197 | # Parse the JSON result if it's a string
198 | if isinstance(result, str):
199 | try:
200 | parsed_result = json.loads(result)
201 | if isinstance(parsed_result, list) and len(parsed_result) > 0:
202 | print(
203 | f' ✅ Found {len(parsed_result)} processed episodes after {i + 1} seconds'
204 | )
205 | return True
206 | except json.JSONDecodeError:
207 | if 'episodes' in result.lower():
208 | print(f' ✅ Episodes detected after {i + 1} seconds')
209 | return True
210 |
211 | except Exception as e:
212 | if i == 0: # Only log first error to avoid spam
213 | print(f' ⚠️ Waiting for processing... ({e})')
214 | continue
215 |
216 | print(f' ⚠️ Still waiting after {max_wait} seconds...')
217 | return False
218 |
219 | async def test_search_operations(self) -> dict[str, bool]:
220 | """Test search functionality."""
221 | print('🔍 Testing search operations...')
222 |
223 | results = {}
224 |
225 | # Test search_memory_nodes
226 | print(' Testing search_memory_nodes...')
227 | try:
228 | result = await self.call_tool(
229 | 'search_memory_nodes',
230 | {
231 | 'query': 'Acme Corp product launch AI',
232 | 'group_ids': [self.test_group_id],
233 | 'max_nodes': 5,
234 | },
235 | )
236 |
237 | success = False
238 | if isinstance(result, str):
239 | try:
240 | parsed = json.loads(result)
241 | nodes = parsed.get('nodes', [])
242 | success = isinstance(nodes, list)
243 | print(f' ✅ Node search returned {len(nodes)} nodes')
244 | except json.JSONDecodeError:
245 | success = 'nodes' in result.lower() and 'successfully' in result.lower()
246 | if success:
247 | print(' ✅ Node search completed successfully')
248 |
249 | results['nodes'] = success
250 | if not success:
251 | print(f' ❌ Node search failed: {result}')
252 |
253 | except Exception as e:
254 | print(f' ❌ Node search error: {e}')
255 | results['nodes'] = False
256 |
257 | # Test search_memory_facts
258 | print(' Testing search_memory_facts...')
259 | try:
260 | result = await self.call_tool(
261 | 'search_memory_facts',
262 | {
263 | 'query': 'company products software TechCorp',
264 | 'group_ids': [self.test_group_id],
265 | 'max_facts': 5,
266 | },
267 | )
268 |
269 | success = False
270 | if isinstance(result, str):
271 | try:
272 | parsed = json.loads(result)
273 | facts = parsed.get('facts', [])
274 | success = isinstance(facts, list)
275 | print(f' ✅ Fact search returned {len(facts)} facts')
276 | except json.JSONDecodeError:
277 | success = 'facts' in result.lower() and 'successfully' in result.lower()
278 | if success:
279 | print(' ✅ Fact search completed successfully')
280 |
281 | results['facts'] = success
282 | if not success:
283 | print(f' ❌ Fact search failed: {result}')
284 |
285 | except Exception as e:
286 | print(f' ❌ Fact search error: {e}')
287 | results['facts'] = False
288 |
289 | return results
290 |
291 | async def test_episode_retrieval(self) -> bool:
292 | """Test episode retrieval."""
293 | print('📚 Testing episode retrieval...')
294 |
295 | try:
296 | result = await self.call_tool(
297 | 'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
298 | )
299 |
300 | if isinstance(result, str):
301 | try:
302 | parsed = json.loads(result)
303 | if isinstance(parsed, list):
304 | print(f' ✅ Retrieved {len(parsed)} episodes')
305 |
306 | # Show episode details
307 | for i, episode in enumerate(parsed[:3]):
308 | name = episode.get('name', 'Unknown')
309 | source = episode.get('source', 'unknown')
310 | print(f' Episode {i + 1}: {name} (source: {source})')
311 |
312 | return len(parsed) > 0
313 | except json.JSONDecodeError:
314 | # Check if response indicates success
315 | if 'episode' in result.lower():
316 | print(' ✅ Episode retrieval completed')
317 | return True
318 |
319 | print(f' ❌ Unexpected result format: {result}')
320 | return False
321 |
322 | except Exception as e:
323 | print(f' ❌ Episode retrieval failed: {e}')
324 | return False
325 |
326 | async def test_error_handling(self) -> dict[str, bool]:
327 | """Test error handling and edge cases."""
328 | print('🧪 Testing error handling...')
329 |
330 | results = {}
331 |
332 | # Test with nonexistent group
333 | print(' Testing nonexistent group handling...')
334 | try:
335 | result = await self.call_tool(
336 | 'search_memory_nodes',
337 | {
338 | 'query': 'nonexistent data',
339 | 'group_ids': ['nonexistent_group_12345'],
340 | 'max_nodes': 5,
341 | },
342 | )
343 |
344 | # Should handle gracefully, not crash
345 | success = (
346 | 'error' not in str(result).lower() or 'not initialized' not in str(result).lower()
347 | )
348 | if success:
349 | print(' ✅ Nonexistent group handled gracefully')
350 | else:
351 | print(f' ❌ Nonexistent group caused issues: {result}')
352 |
353 | results['nonexistent_group'] = success
354 |
355 | except Exception as e:
356 | print(f' ❌ Nonexistent group test failed: {e}')
357 | results['nonexistent_group'] = False
358 |
359 | # Test empty query
360 | print(' Testing empty query handling...')
361 | try:
362 | result = await self.call_tool(
363 | 'search_memory_nodes',
364 | {'query': '', 'group_ids': [self.test_group_id], 'max_nodes': 5},
365 | )
366 |
367 | # Should handle gracefully
368 | success = (
369 | 'error' not in str(result).lower() or 'not initialized' not in str(result).lower()
370 | )
371 | if success:
372 | print(' ✅ Empty query handled gracefully')
373 | else:
374 | print(f' ❌ Empty query caused issues: {result}')
375 |
376 | results['empty_query'] = success
377 |
378 | except Exception as e:
379 | print(f' ❌ Empty query test failed: {e}')
380 | results['empty_query'] = False
381 |
382 | return results
383 |
384 | async def run_comprehensive_test(self) -> dict[str, Any]:
385 | """Run the complete integration test suite."""
386 | print('🚀 Starting Comprehensive Graphiti MCP Server Integration Test')
387 | print(f' Test group ID: {self.test_group_id}')
388 | print('=' * 70)
389 |
390 | results = {
391 | 'server_init': False,
392 | 'add_memory': {},
393 | 'processing_wait': False,
394 | 'search': {},
395 | 'episodes': False,
396 | 'error_handling': {},
397 | 'overall_success': False,
398 | }
399 |
400 | # Test 1: Server Initialization
401 | results['server_init'] = await self.test_server_initialization()
402 | if not results['server_init']:
403 | print('❌ Server initialization failed, aborting remaining tests')
404 | return results
405 |
406 | print()
407 |
408 | # Test 2: Add Memory Operations
409 | results['add_memory'] = await self.test_add_memory_operations()
410 | print()
411 |
412 | # Test 3: Wait for Processing
413 | results['processing_wait'] = await self.wait_for_processing()
414 | print()
415 |
416 | # Test 4: Search Operations
417 | results['search'] = await self.test_search_operations()
418 | print()
419 |
420 | # Test 5: Episode Retrieval
421 | results['episodes'] = await self.test_episode_retrieval()
422 | print()
423 |
424 | # Test 6: Error Handling
425 | results['error_handling'] = await self.test_error_handling()
426 | print()
427 |
428 | # Calculate overall success
429 | memory_success = any(results['add_memory'].values())
430 | search_success = any(results['search'].values()) if results['search'] else False
431 | error_success = (
432 | any(results['error_handling'].values()) if results['error_handling'] else True
433 | )
434 |
435 | results['overall_success'] = (
436 | results['server_init']
437 | and memory_success
438 | and (results['episodes'] or results['processing_wait'])
439 | and error_success
440 | )
441 |
442 | # Print comprehensive summary
443 | print('=' * 70)
444 | print('📊 COMPREHENSIVE TEST SUMMARY')
445 | print('-' * 35)
446 | print(f'Server Initialization: {"✅ PASS" if results["server_init"] else "❌ FAIL"}')
447 |
448 | memory_stats = f'({sum(results["add_memory"].values())}/{len(results["add_memory"])} types)'
449 | print(
450 | f'Memory Operations: {"✅ PASS" if memory_success else "❌ FAIL"} {memory_stats}'
451 | )
452 |
453 | print(f'Processing Pipeline: {"✅ PASS" if results["processing_wait"] else "❌ FAIL"}')
454 |
455 | search_stats = (
456 | f'({sum(results["search"].values())}/{len(results["search"])} types)'
457 | if results['search']
458 | else '(0/0 types)'
459 | )
460 | print(
461 | f'Search Operations: {"✅ PASS" if search_success else "❌ FAIL"} {search_stats}'
462 | )
463 |
464 | print(f'Episode Retrieval: {"✅ PASS" if results["episodes"] else "❌ FAIL"}')
465 |
466 | error_stats = (
467 | f'({sum(results["error_handling"].values())}/{len(results["error_handling"])} cases)'
468 | if results['error_handling']
469 | else '(0/0 cases)'
470 | )
471 | print(
472 | f'Error Handling: {"✅ PASS" if error_success else "❌ FAIL"} {error_stats}'
473 | )
474 |
475 | print('-' * 35)
476 | print(f'🎯 OVERALL RESULT: {"✅ SUCCESS" if results["overall_success"] else "❌ FAILED"}')
477 |
478 | if results['overall_success']:
479 | print('\n🎉 The refactored Graphiti MCP server is working correctly!')
480 | print(' All core functionality has been successfully tested.')
481 | else:
482 | print('\n⚠️ Some issues were detected. Review the test results above.')
483 | print(' The refactoring may need additional attention.')
484 |
485 | return results
486 |
487 |
488 | async def main():
489 | """Run the integration test."""
490 | try:
491 | async with GraphitiMCPIntegrationTest() as test:
492 | results = await test.run_comprehensive_test()
493 |
494 | # Exit with appropriate code
495 | exit_code = 0 if results['overall_success'] else 1
496 | exit(exit_code)
497 | except Exception as e:
498 | print(f'❌ Test setup failed: {e}')
499 | exit(1)
500 |
501 |
502 | if __name__ == '__main__':
503 | asyncio.run(main())
504 |
```
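The integration test above parses every tool result with the same inline pattern: try `json.loads`, then fall back to inspecting the raw string. A small helper along these lines could consolidate that parsing; this is an editorial sketch, not code from the repository.

```python
# Editorial sketch (not part of the repo): consolidates the repeated
# "json.loads the tool result, fall back to the raw string" pattern used
# throughout GraphitiMCPIntegrationTest above.
import json
from typing import Any


def parse_tool_result(result: Any) -> dict[str, Any] | list[Any] | str:
    """Return structured data when the tool result is JSON, else the raw text."""
    if not isinstance(result, str):
        return result
    try:
        return json.loads(result)
    except json.JSONDecodeError:
        return result
```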
--------------------------------------------------------------------------------
/graphiti_core/llm_client/gemini_client.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import json
18 | import logging
19 | import re
20 | import typing
21 | from typing import TYPE_CHECKING, ClassVar
22 |
23 | from pydantic import BaseModel
24 |
25 | from ..prompts.models import Message
26 | from .client import LLMClient, get_extraction_language_instruction
27 | from .config import LLMConfig, ModelSize
28 | from .errors import RateLimitError
29 |
30 | if TYPE_CHECKING:
31 | from google import genai
32 | from google.genai import types
33 | else:
34 | try:
35 | from google import genai
36 | from google.genai import types
37 | except ImportError:
38 | # If gemini client is not installed, raise an ImportError
39 | raise ImportError(
40 | 'google-genai is required for GeminiClient. '
41 | 'Install it with: pip install graphiti-core[google-genai]'
42 | ) from None
43 |
44 |
45 | logger = logging.getLogger(__name__)
46 |
47 | DEFAULT_MODEL = 'gemini-2.5-flash'
48 | DEFAULT_SMALL_MODEL = 'gemini-2.5-flash-lite'
49 |
50 | # Maximum output tokens for different Gemini models
51 | GEMINI_MODEL_MAX_TOKENS = {
52 | # Gemini 2.5 models
53 | 'gemini-2.5-pro': 65536,
54 | 'gemini-2.5-flash': 65536,
55 | 'gemini-2.5-flash-lite': 64000,
56 | # Gemini 2.0 models
57 | 'gemini-2.0-flash': 8192,
58 | 'gemini-2.0-flash-lite': 8192,
59 | # Gemini 1.5 models
60 | 'gemini-1.5-pro': 8192,
61 | 'gemini-1.5-flash': 8192,
62 | 'gemini-1.5-flash-8b': 8192,
63 | }
64 |
65 | # Default max tokens for models not in the mapping
66 | DEFAULT_GEMINI_MAX_TOKENS = 8192
67 |
68 |
69 | class GeminiClient(LLMClient):
70 | """
71 | GeminiClient is a client class for interacting with Google's Gemini language models.
72 |
73 | This class extends the LLMClient and provides methods to initialize the client
74 | and generate responses from the Gemini language model.
75 |
76 | Attributes:
77 | model (str): The model name to use for generating responses.
78 | temperature (float): The temperature to use for generating responses.
79 | max_tokens (int): The maximum number of tokens to generate in a response.
80 | thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
81 | Methods:
82 | __init__(config: LLMConfig | None = None, cache: bool = False, thinking_config: types.ThinkingConfig | None = None):
83 | Initializes the GeminiClient with the provided configuration, cache setting, and optional thinking config.
84 |
85 | _generate_response(messages: list[Message]) -> dict[str, typing.Any]:
86 | Generates a response from the language model based on the provided messages.
87 | """
88 |
89 | # Class-level constants
90 | MAX_RETRIES: ClassVar[int] = 2
91 |
92 | def __init__(
93 | self,
94 | config: LLMConfig | None = None,
95 | cache: bool = False,
96 | max_tokens: int | None = None,
97 | thinking_config: types.ThinkingConfig | None = None,
98 | client: 'genai.Client | None' = None,
99 | ):
100 | """
101 | Initialize the GeminiClient with the provided configuration, cache setting, and optional thinking config.
102 |
103 | Args:
104 | config (LLMConfig | None): The configuration for the LLM client, including API key, model, temperature, and max tokens.
105 | cache (bool): Whether to use caching for responses. Defaults to False.
106 | thinking_config (types.ThinkingConfig | None): Optional thinking configuration for models that support it.
107 | Only use with models that support thinking (gemini-2.5+). Defaults to None.
108 | client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
109 | """
110 | if config is None:
111 | config = LLMConfig()
112 |
113 | super().__init__(config, cache)
114 |
115 | self.model = config.model
116 |
117 | if client is None:
118 | self.client = genai.Client(api_key=config.api_key)
119 | else:
120 | self.client = client
121 |
122 | self.max_tokens = max_tokens
123 | self.thinking_config = thinking_config
124 |
125 | def _check_safety_blocks(self, response) -> None:
126 | """Check if response was blocked for safety reasons and raise appropriate exceptions."""
127 | # Check if the response was blocked for safety reasons
128 | if not (hasattr(response, 'candidates') and response.candidates):
129 | return
130 |
131 | candidate = response.candidates[0]
132 | if not (hasattr(candidate, 'finish_reason') and candidate.finish_reason == 'SAFETY'):
133 | return
134 |
135 | # Content was blocked for safety reasons - collect safety details
136 | safety_info = []
137 | safety_ratings = getattr(candidate, 'safety_ratings', None)
138 |
139 | if safety_ratings:
140 | for rating in safety_ratings:
141 | if getattr(rating, 'blocked', False):
142 | category = getattr(rating, 'category', 'Unknown')
143 | probability = getattr(rating, 'probability', 'Unknown')
144 | safety_info.append(f'{category}: {probability}')
145 |
146 | safety_details = (
147 | ', '.join(safety_info) if safety_info else 'Content blocked for safety reasons'
148 | )
149 | raise Exception(f'Response blocked by Gemini safety filters: {safety_details}')
150 |
151 | def _check_prompt_blocks(self, response) -> None:
152 | """Check if prompt was blocked and raise appropriate exceptions."""
153 | prompt_feedback = getattr(response, 'prompt_feedback', None)
154 | if not prompt_feedback:
155 | return
156 |
157 | block_reason = getattr(prompt_feedback, 'block_reason', None)
158 | if block_reason:
159 | raise Exception(f'Prompt blocked by Gemini: {block_reason}')
160 |
161 | def _get_model_for_size(self, model_size: ModelSize) -> str:
162 | """Get the appropriate model name based on the requested size."""
163 | if model_size == ModelSize.small:
164 | return self.small_model or DEFAULT_SMALL_MODEL
165 | else:
166 | return self.model or DEFAULT_MODEL
167 |
168 | def _get_max_tokens_for_model(self, model: str) -> int:
169 | """Get the maximum output tokens for a specific Gemini model."""
170 | return GEMINI_MODEL_MAX_TOKENS.get(model, DEFAULT_GEMINI_MAX_TOKENS)
171 |
172 | def _resolve_max_tokens(self, requested_max_tokens: int | None, model: str) -> int:
173 | """
174 | Resolve the maximum output tokens to use based on precedence rules.
175 |
176 | Precedence order (highest to lowest):
177 | 1. Explicit max_tokens parameter passed to generate_response()
178 | 2. Instance max_tokens set during client initialization
179 | 3. Model-specific maximum tokens from GEMINI_MODEL_MAX_TOKENS mapping
180 |         4. DEFAULT_GEMINI_MAX_TOKENS as the final fallback
181 |
182 | Args:
183 | requested_max_tokens: The max_tokens parameter passed to generate_response()
184 | model: The model name to look up model-specific limits
185 |
186 | Returns:
187 | int: The resolved maximum tokens to use
188 | """
189 | # 1. Use explicit parameter if provided
190 | if requested_max_tokens is not None:
191 | return requested_max_tokens
192 |
193 | # 2. Use instance max_tokens if set during initialization
194 | if self.max_tokens is not None:
195 | return self.max_tokens
196 |
197 | # 3. Use model-specific maximum or return DEFAULT_GEMINI_MAX_TOKENS
198 | return self._get_max_tokens_for_model(model)
199 |
200 | def salvage_json(self, raw_output: str) -> dict[str, typing.Any] | None:
201 | """
202 | Attempt to salvage a JSON object if the raw output is truncated.
203 |
204 |         This is done by checking whether the raw output ends with a closing bracket for an array or object.
205 | If found, it will try to load the JSON object from the raw output.
206 | If the JSON object is not valid, it will return None.
207 |
208 | Args:
209 | raw_output (str): The raw output from the LLM.
210 |
211 | Returns:
212 | dict[str, typing.Any]: The salvaged JSON object.
213 | None: If no salvage is possible.
214 | """
215 | if not raw_output:
216 | return None
217 | # Try to salvage a JSON array
218 | array_match = re.search(r'\]\s*$', raw_output)
219 | if array_match:
220 | try:
221 | return json.loads(raw_output[: array_match.end()])
222 | except Exception:
223 | pass
224 | # Try to salvage a JSON object
225 | obj_match = re.search(r'\}\s*$', raw_output)
226 | if obj_match:
227 | try:
228 | return json.loads(raw_output[: obj_match.end()])
229 | except Exception:
230 | pass
231 | return None
232 |
233 | async def _generate_response(
234 | self,
235 | messages: list[Message],
236 | response_model: type[BaseModel] | None = None,
237 | max_tokens: int | None = None,
238 | model_size: ModelSize = ModelSize.medium,
239 | ) -> dict[str, typing.Any]:
240 | """
241 | Generate a response from the Gemini language model.
242 |
243 | Args:
244 | messages (list[Message]): A list of messages to send to the language model.
245 | response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
246 | max_tokens (int | None): The maximum number of tokens to generate in the response. If None, uses precedence rules.
247 | model_size (ModelSize): The size of the model to use (small or medium).
248 |
249 | Returns:
250 | dict[str, typing.Any]: The response from the language model.
251 |
252 | Raises:
253 | RateLimitError: If the API rate limit is exceeded.
254 | Exception: If there is an error generating the response or content is blocked.
255 | """
256 | try:
257 | gemini_messages: typing.Any = []
258 | # If a response model is provided, add schema for structured output
259 | system_prompt = ''
260 | if response_model is not None:
261 | # Get the schema from the Pydantic model
262 | pydantic_schema = response_model.model_json_schema()
263 |
264 | # Create instruction to output in the desired JSON format
265 | system_prompt += (
266 | f'Output ONLY valid JSON matching this schema: {json.dumps(pydantic_schema)}.\n'
267 | 'Do not include any explanatory text before or after the JSON.\n\n'
268 | )
269 |
270 | # Add messages content
271 | # First check for a system message
272 | if messages and messages[0].role == 'system':
273 | system_prompt = f'{messages[0].content}\n\n {system_prompt}'
274 | messages = messages[1:]
275 |
276 | # Add the rest of the messages
277 | for m in messages:
278 | m.content = self._clean_input(m.content)
279 | gemini_messages.append(
280 | types.Content(role=m.role, parts=[types.Part.from_text(text=m.content)])
281 | )
282 |
283 | # Get the appropriate model for the requested size
284 | model = self._get_model_for_size(model_size)
285 |
286 | # Resolve max_tokens using precedence rules (see _resolve_max_tokens for details)
287 | resolved_max_tokens = self._resolve_max_tokens(max_tokens, model)
288 |
289 | # Create generation config
290 | generation_config = types.GenerateContentConfig(
291 | temperature=self.temperature,
292 | max_output_tokens=resolved_max_tokens,
293 | response_mime_type='application/json' if response_model else None,
294 | response_schema=response_model if response_model else None,
295 | system_instruction=system_prompt,
296 | thinking_config=self.thinking_config,
297 | )
298 |
299 |             # Generate the content asynchronously with the assembled messages and generation config
300 | response = await self.client.aio.models.generate_content(
301 | model=model,
302 | contents=gemini_messages,
303 | config=generation_config,
304 | )
305 |
306 | # Always capture the raw output for debugging
307 | raw_output = getattr(response, 'text', None)
308 |
309 | # Check for safety and prompt blocks
310 | self._check_safety_blocks(response)
311 | self._check_prompt_blocks(response)
312 |
313 | # If this was a structured output request, parse the response into the Pydantic model
314 | if response_model is not None:
315 | try:
316 | if not raw_output:
317 | raise ValueError('No response text')
318 |
319 | validated_model = response_model.model_validate(json.loads(raw_output))
320 |
321 | # Return as a dictionary for API consistency
322 | return validated_model.model_dump()
323 | except Exception as e:
324 | if raw_output:
325 | logger.error(
326 | '🦀 LLM generation failed parsing as JSON, will try to salvage.'
327 | )
328 | logger.error(self._get_failed_generation_log(gemini_messages, raw_output))
329 | # Try to salvage
330 | salvaged = self.salvage_json(raw_output)
331 | if salvaged is not None:
332 | logger.warning('Salvaged partial JSON from truncated/malformed output.')
333 | return salvaged
334 | raise Exception(f'Failed to parse structured response: {e}') from e
335 |
336 | # Otherwise, return the response text as a dictionary
337 | return {'content': raw_output}
338 |
339 | except Exception as e:
340 | # Check if it's a rate limit error based on Gemini API error codes
341 | error_message = str(e).lower()
342 | if (
343 | 'rate limit' in error_message
344 | or 'quota' in error_message
345 | or 'resource_exhausted' in error_message
346 | or '429' in str(e)
347 | ):
348 | raise RateLimitError from e
349 |
350 | logger.error(f'Error in generating LLM response: {e}')
351 | raise Exception from e
352 |
353 | async def generate_response(
354 | self,
355 | messages: list[Message],
356 | response_model: type[BaseModel] | None = None,
357 | max_tokens: int | None = None,
358 | model_size: ModelSize = ModelSize.medium,
359 | group_id: str | None = None,
360 | prompt_name: str | None = None,
361 | ) -> dict[str, typing.Any]:
362 | """
363 | Generate a response from the Gemini language model with retry logic and error handling.
364 | This method overrides the parent class method to provide a direct implementation with advanced retry logic.
365 |
366 | Args:
367 | messages (list[Message]): A list of messages to send to the language model.
368 | response_model (type[BaseModel] | None): An optional Pydantic model to parse the response into.
369 | max_tokens (int | None): The maximum number of tokens to generate in the response.
370 | model_size (ModelSize): The size of the model to use (small or medium).
371 | group_id (str | None): Optional partition identifier for the graph.
372 | prompt_name (str | None): Optional name of the prompt for tracing.
373 |
374 | Returns:
375 | dict[str, typing.Any]: The response from the language model.
376 | """
377 | # Add multilingual extraction instructions
378 | messages[0].content += get_extraction_language_instruction(group_id)
379 |
380 | # Wrap entire operation in tracing span
381 | with self.tracer.start_span('llm.generate') as span:
382 | attributes = {
383 | 'llm.provider': 'gemini',
384 | 'model.size': model_size.value,
385 | 'max_tokens': max_tokens or self.max_tokens,
386 | }
387 | if prompt_name:
388 | attributes['prompt.name'] = prompt_name
389 | span.add_attributes(attributes)
390 |
391 | retry_count = 0
392 | last_error = None
393 | last_output = None
394 |
395 | while retry_count < self.MAX_RETRIES:
396 | try:
397 | response = await self._generate_response(
398 | messages=messages,
399 | response_model=response_model,
400 | max_tokens=max_tokens,
401 | model_size=model_size,
402 | )
403 | last_output = (
404 | response.get('content')
405 | if isinstance(response, dict) and 'content' in response
406 | else None
407 | )
408 | return response
409 | except RateLimitError as e:
410 | # Rate limit errors should not trigger retries (fail fast)
411 | span.set_status('error', str(e))
412 | raise e
413 | except Exception as e:
414 | last_error = e
415 |
416 | # Check if this is a safety block - these typically shouldn't be retried
417 | error_text = str(e) or (str(e.__cause__) if e.__cause__ else '')
418 | if 'safety' in error_text.lower() or 'blocked' in error_text.lower():
419 | logger.warning(f'Content blocked by safety filters: {e}')
420 | span.set_status('error', str(e))
421 | raise Exception(f'Content blocked by safety filters: {e}') from e
422 |
423 | retry_count += 1
424 |
425 | # Construct a detailed error message for the LLM
426 | error_context = (
427 | f'The previous response attempt was invalid. '
428 | f'Error type: {e.__class__.__name__}. '
429 | f'Error details: {str(e)}. '
430 | f'Please try again with a valid response, ensuring the output matches '
431 | f'the expected format and constraints.'
432 | )
433 |
434 | error_message = Message(role='user', content=error_context)
435 | messages.append(error_message)
436 | logger.warning(
437 | f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
438 | )
439 |
440 | # If we exit the loop without returning, all retries are exhausted
441 | logger.error('🦀 LLM generation failed and retries are exhausted.')
442 | logger.error(self._get_failed_generation_log(messages, last_output))
443 | logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {last_error}')
444 | span.set_status('error', str(last_error))
445 |             if last_error: span.record_exception(last_error)
446 | raise last_error or Exception('Max retries exceeded')
447 |
```
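
For orientation, a minimal usage sketch of the client defined above. The prompt, response model, and API key are illustrative; it assumes the `google-genai` extra is installed and that `LLMConfig` accepts `api_key` and `model` as used in this module:

```python
# Illustrative sketch, not part of the library: structured output via GeminiClient.
import asyncio

from pydantic import BaseModel

from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.gemini_client import GeminiClient
from graphiti_core.prompts.models import Message


class CityFact(BaseModel):
    # Hypothetical response model used only for this example.
    city: str
    country: str


async def demo() -> None:
    client = GeminiClient(config=LLMConfig(api_key='YOUR_GEMINI_API_KEY', model='gemini-2.5-flash'))
    result = await client.generate_response(
        messages=[
            Message(role='system', content='Answer with a single city.'),
            Message(role='user', content='Name one city in France and its country.'),
        ],
        response_model=CityFact,
    )
    print(result)  # e.g. {'city': 'Paris', 'country': 'France'}


if __name__ == '__main__':
    asyncio.run(demo())
```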
--------------------------------------------------------------------------------
/tests/utils/maintenance/test_edge_operations.py:
--------------------------------------------------------------------------------
```python
1 | from datetime import datetime, timedelta, timezone
2 | from types import SimpleNamespace
3 | from unittest.mock import AsyncMock, MagicMock
4 |
5 | import pytest
6 | from pydantic import BaseModel
7 |
8 | from graphiti_core.edges import EntityEdge
9 | from graphiti_core.nodes import EntityNode, EpisodicNode
10 | from graphiti_core.search.search_config import SearchResults
11 | from graphiti_core.utils.maintenance.edge_operations import (
12 | DEFAULT_EDGE_NAME,
13 | resolve_extracted_edge,
14 | resolve_extracted_edges,
15 | )
16 |
17 |
18 | @pytest.fixture
19 | def mock_llm_client():
20 | client = MagicMock()
21 | client.generate_response = AsyncMock()
22 | return client
23 |
24 |
25 | @pytest.fixture
26 | def mock_extracted_edge():
27 | return EntityEdge(
28 | source_node_uuid='source_uuid',
29 | target_node_uuid='target_uuid',
30 | name='test_edge',
31 | group_id='group_1',
32 | fact='Test fact',
33 | episodes=['episode_1'],
34 | created_at=datetime.now(timezone.utc),
35 | valid_at=None,
36 | invalid_at=None,
37 | )
38 |
39 |
40 | @pytest.fixture
41 | def mock_related_edges():
42 | return [
43 | EntityEdge(
44 | source_node_uuid='source_uuid_2',
45 | target_node_uuid='target_uuid_2',
46 | name='related_edge',
47 | group_id='group_1',
48 | fact='Related fact',
49 | episodes=['episode_2'],
50 | created_at=datetime.now(timezone.utc) - timedelta(days=1),
51 | valid_at=datetime.now(timezone.utc) - timedelta(days=1),
52 | invalid_at=None,
53 | )
54 | ]
55 |
56 |
57 | @pytest.fixture
58 | def mock_existing_edges():
59 | return [
60 | EntityEdge(
61 | source_node_uuid='source_uuid_3',
62 | target_node_uuid='target_uuid_3',
63 | name='existing_edge',
64 | group_id='group_1',
65 | fact='Existing fact',
66 | episodes=['episode_3'],
67 | created_at=datetime.now(timezone.utc) - timedelta(days=2),
68 | valid_at=datetime.now(timezone.utc) - timedelta(days=2),
69 | invalid_at=None,
70 | )
71 | ]
72 |
73 |
74 | @pytest.fixture
75 | def mock_current_episode():
76 | return EpisodicNode(
77 | uuid='episode_1',
78 | content='Current episode content',
79 | valid_at=datetime.now(timezone.utc),
80 | name='Current Episode',
81 | group_id='group_1',
82 | source='message',
83 | source_description='Test source description',
84 | )
85 |
86 |
87 | @pytest.fixture
88 | def mock_previous_episodes():
89 | return [
90 | EpisodicNode(
91 | uuid='episode_2',
92 | content='Previous episode content',
93 | valid_at=datetime.now(timezone.utc) - timedelta(days=1),
94 | name='Previous Episode',
95 | group_id='group_1',
96 | source='message',
97 | source_description='Test source description',
98 | )
99 | ]
100 |
101 |
102 | # Run the tests
103 | if __name__ == '__main__':
104 | pytest.main([__file__])
105 |
106 |
107 | @pytest.mark.asyncio
108 | async def test_resolve_extracted_edge_exact_fact_short_circuit(
109 | mock_llm_client,
110 | mock_existing_edges,
111 | mock_current_episode,
112 | ):
113 | extracted = EntityEdge(
114 | source_node_uuid='source_uuid',
115 | target_node_uuid='target_uuid',
116 | name='test_edge',
117 | group_id='group_1',
118 | fact='Related fact',
119 | episodes=['episode_1'],
120 | created_at=datetime.now(timezone.utc),
121 | valid_at=None,
122 | invalid_at=None,
123 | )
124 |
125 | related_edges = [
126 | EntityEdge(
127 | source_node_uuid='source_uuid',
128 | target_node_uuid='target_uuid',
129 | name='related_edge',
130 | group_id='group_1',
131 | fact=' related FACT ',
132 | episodes=['episode_2'],
133 | created_at=datetime.now(timezone.utc) - timedelta(days=1),
134 | valid_at=None,
135 | invalid_at=None,
136 | )
137 | ]
138 |
139 | resolved_edge, duplicate_edges, invalidated = await resolve_extracted_edge(
140 | mock_llm_client,
141 | extracted,
142 | related_edges,
143 | mock_existing_edges,
144 | mock_current_episode,
145 | edge_type_candidates=None,
146 | )
147 |
148 | assert resolved_edge is related_edges[0]
149 | assert resolved_edge.episodes.count(mock_current_episode.uuid) == 1
150 | assert duplicate_edges == []
151 | assert invalidated == []
152 | mock_llm_client.generate_response.assert_not_called()
153 |
154 |
155 | class OccurredAtEdge(BaseModel):
156 | """Edge model stub for OCCURRED_AT."""
157 |
158 |
159 | @pytest.mark.asyncio
160 | async def test_resolve_extracted_edges_resets_unmapped_names(monkeypatch):
161 | from graphiti_core.utils.maintenance import edge_operations as edge_ops
162 |
163 | monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
164 | monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
165 |
166 | async def immediate_gather(*aws, max_coroutines=None):
167 | return [await aw for aw in aws]
168 |
169 | monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather)
170 | monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))
171 |
172 | llm_client = MagicMock()
173 | llm_client.generate_response = AsyncMock(
174 | return_value={
175 | 'duplicate_facts': [],
176 | 'contradicted_facts': [],
177 | 'fact_type': 'DEFAULT',
178 | }
179 | )
180 |
181 | clients = SimpleNamespace(
182 | driver=MagicMock(),
183 | llm_client=llm_client,
184 | embedder=MagicMock(),
185 | cross_encoder=MagicMock(),
186 | )
187 |
188 | source_node = EntityNode(
189 | uuid='source_uuid',
190 | name='Document Node',
191 | group_id='group_1',
192 | labels=['Document'],
193 | )
194 | target_node = EntityNode(
195 | uuid='target_uuid',
196 | name='Topic Node',
197 | group_id='group_1',
198 | labels=['Topic'],
199 | )
200 |
201 | extracted_edge = EntityEdge(
202 | source_node_uuid=source_node.uuid,
203 | target_node_uuid=target_node.uuid,
204 | name='OCCURRED_AT',
205 | group_id='group_1',
206 | fact='Document occurred at somewhere',
207 | episodes=[],
208 | created_at=datetime.now(timezone.utc),
209 | valid_at=None,
210 | invalid_at=None,
211 | )
212 |
213 | episode = EpisodicNode(
214 | uuid='episode_uuid',
215 | name='Episode',
216 | group_id='group_1',
217 | source='message',
218 | source_description='desc',
219 | content='Episode content',
220 | valid_at=datetime.now(timezone.utc),
221 | )
222 |
223 | edge_types = {'OCCURRED_AT': OccurredAtEdge}
224 | edge_type_map = {('Event', 'Entity'): ['OCCURRED_AT']}
225 |
226 | resolved_edges, invalidated_edges = await resolve_extracted_edges(
227 | clients,
228 | [extracted_edge],
229 | episode,
230 | [source_node, target_node],
231 | edge_types,
232 | edge_type_map,
233 | )
234 |
235 | assert resolved_edges[0].name == DEFAULT_EDGE_NAME
236 | assert invalidated_edges == []
237 |
238 |
239 | @pytest.mark.asyncio
240 | async def test_resolve_extracted_edges_keeps_unknown_names(monkeypatch):
241 | from graphiti_core.utils.maintenance import edge_operations as edge_ops
242 |
243 | monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
244 | monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
245 |
246 | async def immediate_gather(*aws, max_coroutines=None):
247 | return [await aw for aw in aws]
248 |
249 | monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather)
250 | monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))
251 |
252 | llm_client = MagicMock()
253 | llm_client.generate_response = AsyncMock(
254 | return_value={
255 | 'duplicate_facts': [],
256 | 'contradicted_facts': [],
257 | 'fact_type': 'DEFAULT',
258 | }
259 | )
260 |
261 | clients = SimpleNamespace(
262 | driver=MagicMock(),
263 | llm_client=llm_client,
264 | embedder=MagicMock(),
265 | cross_encoder=MagicMock(),
266 | )
267 |
268 | source_node = EntityNode(
269 | uuid='source_uuid',
270 | name='User Node',
271 | group_id='group_1',
272 | labels=['User'],
273 | )
274 | target_node = EntityNode(
275 | uuid='target_uuid',
276 | name='Topic Node',
277 | group_id='group_1',
278 | labels=['Topic'],
279 | )
280 |
281 | extracted_edge = EntityEdge(
282 | source_node_uuid=source_node.uuid,
283 | target_node_uuid=target_node.uuid,
284 | name='INTERACTED_WITH',
285 | group_id='group_1',
286 | fact='User interacted with topic',
287 | episodes=[],
288 | created_at=datetime.now(timezone.utc),
289 | valid_at=None,
290 | invalid_at=None,
291 | )
292 |
293 | episode = EpisodicNode(
294 | uuid='episode_uuid',
295 | name='Episode',
296 | group_id='group_1',
297 | source='message',
298 | source_description='desc',
299 | content='Episode content',
300 | valid_at=datetime.now(timezone.utc),
301 | )
302 |
303 | edge_types = {'OCCURRED_AT': OccurredAtEdge}
304 | edge_type_map = {('Event', 'Entity'): ['OCCURRED_AT']}
305 |
306 | resolved_edges, invalidated_edges = await resolve_extracted_edges(
307 | clients,
308 | [extracted_edge],
309 | episode,
310 | [source_node, target_node],
311 | edge_types,
312 | edge_type_map,
313 | )
314 |
315 | assert resolved_edges[0].name == 'INTERACTED_WITH'
316 | assert invalidated_edges == []
317 |
318 |
319 | @pytest.mark.asyncio
320 | async def test_resolve_extracted_edge_rejects_unmapped_fact_type(mock_llm_client):
321 | mock_llm_client.generate_response.return_value = {
322 | 'duplicate_facts': [],
323 | 'contradicted_facts': [],
324 | 'fact_type': 'OCCURRED_AT',
325 | }
326 |
327 | extracted_edge = EntityEdge(
328 | source_node_uuid='source_uuid',
329 | target_node_uuid='target_uuid',
330 | name='OCCURRED_AT',
331 | group_id='group_1',
332 | fact='Document occurred at somewhere',
333 | episodes=[],
334 | created_at=datetime.now(timezone.utc),
335 | valid_at=None,
336 | invalid_at=None,
337 | )
338 |
339 | episode = EpisodicNode(
340 | uuid='episode_uuid',
341 | name='Episode',
342 | group_id='group_1',
343 | source='message',
344 | source_description='desc',
345 | content='Episode content',
346 | valid_at=datetime.now(timezone.utc),
347 | )
348 |
349 | related_edge = EntityEdge(
350 | source_node_uuid='alt_source',
351 | target_node_uuid='alt_target',
352 | name='OTHER',
353 | group_id='group_1',
354 | fact='Different fact',
355 | episodes=[],
356 | created_at=datetime.now(timezone.utc),
357 | valid_at=None,
358 | invalid_at=None,
359 | )
360 |
361 | resolved_edge, duplicates, invalidated = await resolve_extracted_edge(
362 | mock_llm_client,
363 | extracted_edge,
364 | [related_edge],
365 | [],
366 | episode,
367 | edge_type_candidates={},
368 | custom_edge_type_names={'OCCURRED_AT'},
369 | )
370 |
371 | assert resolved_edge.name == DEFAULT_EDGE_NAME
372 | assert duplicates == []
373 | assert invalidated == []
374 |
375 |
376 | @pytest.mark.asyncio
377 | async def test_resolve_extracted_edge_accepts_unknown_fact_type(mock_llm_client):
378 | mock_llm_client.generate_response.return_value = {
379 | 'duplicate_facts': [],
380 | 'contradicted_facts': [],
381 | 'fact_type': 'INTERACTED_WITH',
382 | }
383 |
384 | extracted_edge = EntityEdge(
385 | source_node_uuid='source_uuid',
386 | target_node_uuid='target_uuid',
387 | name='DEFAULT',
388 | group_id='group_1',
389 | fact='User interacted with topic',
390 | episodes=[],
391 | created_at=datetime.now(timezone.utc),
392 | valid_at=None,
393 | invalid_at=None,
394 | )
395 |
396 | episode = EpisodicNode(
397 | uuid='episode_uuid',
398 | name='Episode',
399 | group_id='group_1',
400 | source='message',
401 | source_description='desc',
402 | content='Episode content',
403 | valid_at=datetime.now(timezone.utc),
404 | )
405 |
406 | related_edge = EntityEdge(
407 | source_node_uuid='source_uuid',
408 | target_node_uuid='target_uuid',
409 | name='DEFAULT',
410 | group_id='group_1',
411 | fact='User mentioned a topic',
412 | episodes=[],
413 | created_at=datetime.now(timezone.utc),
414 | valid_at=None,
415 | invalid_at=None,
416 | )
417 |
418 | resolved_edge, duplicates, invalidated = await resolve_extracted_edge(
419 | mock_llm_client,
420 | extracted_edge,
421 | [related_edge],
422 | [],
423 | episode,
424 | edge_type_candidates={'OCCURRED_AT': OccurredAtEdge},
425 | custom_edge_type_names={'OCCURRED_AT'},
426 | )
427 |
428 | assert resolved_edge.name == 'INTERACTED_WITH'
429 | assert resolved_edge.attributes == {}
430 | assert duplicates == []
431 | assert invalidated == []
432 |
433 |
434 | @pytest.mark.asyncio
435 | async def test_resolve_extracted_edge_uses_integer_indices_for_duplicates(mock_llm_client):
436 | """Test that resolve_extracted_edge correctly uses integer indices for LLM duplicate detection."""
437 | # Mock LLM to return duplicate_facts with integer indices
438 | mock_llm_client.generate_response.return_value = {
439 | 'duplicate_facts': [0, 1], # LLM identifies first two related edges as duplicates
440 | 'contradicted_facts': [],
441 | 'fact_type': 'DEFAULT',
442 | }
443 |
444 | extracted_edge = EntityEdge(
445 | source_node_uuid='source_uuid',
446 | target_node_uuid='target_uuid',
447 | name='test_edge',
448 | group_id='group_1',
449 | fact='User likes yoga',
450 | episodes=[],
451 | created_at=datetime.now(timezone.utc),
452 | valid_at=None,
453 | invalid_at=None,
454 | )
455 |
456 | episode = EpisodicNode(
457 | uuid='episode_uuid',
458 | name='Episode',
459 | group_id='group_1',
460 | source='message',
461 | source_description='desc',
462 | content='Episode content',
463 | valid_at=datetime.now(timezone.utc),
464 | )
465 |
466 | # Create multiple related edges - LLM should receive these with integer indices
467 | related_edge_0 = EntityEdge(
468 | source_node_uuid='source_uuid',
469 | target_node_uuid='target_uuid',
470 | name='test_edge',
471 | group_id='group_1',
472 | fact='User enjoys yoga',
473 | episodes=['episode_1'],
474 | created_at=datetime.now(timezone.utc) - timedelta(days=1),
475 | valid_at=None,
476 | invalid_at=None,
477 | )
478 |
479 | related_edge_1 = EntityEdge(
480 | source_node_uuid='source_uuid',
481 | target_node_uuid='target_uuid',
482 | name='test_edge',
483 | group_id='group_1',
484 | fact='User practices yoga',
485 | episodes=['episode_2'],
486 | created_at=datetime.now(timezone.utc) - timedelta(days=2),
487 | valid_at=None,
488 | invalid_at=None,
489 | )
490 |
491 | related_edge_2 = EntityEdge(
492 | source_node_uuid='source_uuid',
493 | target_node_uuid='target_uuid',
494 | name='test_edge',
495 | group_id='group_1',
496 | fact='User loves swimming',
497 | episodes=['episode_3'],
498 | created_at=datetime.now(timezone.utc) - timedelta(days=3),
499 | valid_at=None,
500 | invalid_at=None,
501 | )
502 |
503 | related_edges = [related_edge_0, related_edge_1, related_edge_2]
504 |
505 | resolved_edge, invalidated, duplicates = await resolve_extracted_edge(
506 | mock_llm_client,
507 | extracted_edge,
508 | related_edges,
509 | [],
510 | episode,
511 | edge_type_candidates=None,
512 | custom_edge_type_names=set(),
513 | )
514 |
515 | # Verify LLM was called
516 | mock_llm_client.generate_response.assert_called_once()
517 |
518 | # Verify the system correctly identified duplicates using integer indices
519 | # The LLM returned [0, 1], so related_edge_0 and related_edge_1 should be marked as duplicates
520 | assert len(duplicates) == 2
521 | assert related_edge_0 in duplicates
522 | assert related_edge_1 in duplicates
523 | assert invalidated == []
524 |
525 | # Verify that the resolved edge is one of the duplicates (the first one found)
526 | # Check UUID since the episode list gets modified
527 | assert resolved_edge.uuid == related_edge_0.uuid
528 | assert episode.uuid in resolved_edge.episodes
529 |
530 |
531 | @pytest.mark.asyncio
532 | async def test_resolve_extracted_edges_fast_path_deduplication(monkeypatch):
533 | """Test that resolve_extracted_edges deduplicates exact matches before parallel processing."""
534 | from graphiti_core.utils.maintenance import edge_operations as edge_ops
535 |
536 | monkeypatch.setattr(edge_ops, 'create_entity_edge_embeddings', AsyncMock(return_value=None))
537 | monkeypatch.setattr(EntityEdge, 'get_between_nodes', AsyncMock(return_value=[]))
538 |
539 | # Track how many times resolve_extracted_edge is called
540 | resolve_call_count = 0
541 |
542 | async def mock_resolve_extracted_edge(
543 | llm_client,
544 | extracted_edge,
545 | related_edges,
546 | existing_edges,
547 | episode,
548 | edge_type_candidates=None,
549 | custom_edge_type_names=None,
550 | ):
551 | nonlocal resolve_call_count
552 | resolve_call_count += 1
553 | return extracted_edge, [], []
554 |
555 | # Mock semaphore_gather to execute awaitable immediately
556 | async def immediate_gather(*aws, max_coroutines=None):
557 | results = []
558 | for aw in aws:
559 | results.append(await aw)
560 | return results
561 |
562 | monkeypatch.setattr(edge_ops, 'semaphore_gather', immediate_gather)
563 | monkeypatch.setattr(edge_ops, 'search', AsyncMock(return_value=SearchResults()))
564 | monkeypatch.setattr(edge_ops, 'resolve_extracted_edge', mock_resolve_extracted_edge)
565 |
566 | llm_client = MagicMock()
567 | clients = SimpleNamespace(
568 | driver=MagicMock(),
569 | llm_client=llm_client,
570 | embedder=MagicMock(),
571 | cross_encoder=MagicMock(),
572 | )
573 |
574 | source_node = EntityNode(
575 | uuid='source_uuid',
576 | name='Assistant',
577 | group_id='group_1',
578 | labels=['Entity'],
579 | )
580 | target_node = EntityNode(
581 | uuid='target_uuid',
582 | name='User',
583 | group_id='group_1',
584 | labels=['Entity'],
585 | )
586 |
587 | # Create 3 identical edges
588 | edge1 = EntityEdge(
589 | source_node_uuid=source_node.uuid,
590 | target_node_uuid=target_node.uuid,
591 | name='recommends',
592 | group_id='group_1',
593 | fact='assistant recommends yoga poses',
594 | episodes=[],
595 | created_at=datetime.now(timezone.utc),
596 | valid_at=None,
597 | invalid_at=None,
598 | )
599 |
600 | edge2 = EntityEdge(
601 | source_node_uuid=source_node.uuid,
602 | target_node_uuid=target_node.uuid,
603 | name='recommends',
604 | group_id='group_1',
605 | fact=' Assistant Recommends YOGA Poses ', # Different whitespace/case
606 | episodes=[],
607 | created_at=datetime.now(timezone.utc),
608 | valid_at=None,
609 | invalid_at=None,
610 | )
611 |
612 | edge3 = EntityEdge(
613 | source_node_uuid=source_node.uuid,
614 | target_node_uuid=target_node.uuid,
615 | name='recommends',
616 | group_id='group_1',
617 | fact='assistant recommends yoga poses',
618 | episodes=[],
619 | created_at=datetime.now(timezone.utc),
620 | valid_at=None,
621 | invalid_at=None,
622 | )
623 |
624 | episode = EpisodicNode(
625 | uuid='episode_uuid',
626 | name='Episode',
627 | group_id='group_1',
628 | source='message',
629 | source_description='desc',
630 | content='Episode content',
631 | valid_at=datetime.now(timezone.utc),
632 | )
633 |
634 | resolved_edges, invalidated_edges = await resolve_extracted_edges(
635 | clients,
636 | [edge1, edge2, edge3],
637 | episode,
638 | [source_node, target_node],
639 | {},
640 | {},
641 | )
642 |
643 | # Fast path should have deduplicated the 3 identical edges to 1
644 | # So resolve_extracted_edge should only be called once
645 | assert resolve_call_count == 1
646 | assert len(resolved_edges) == 1
647 | assert invalidated_edges == []
648 |
```
--------------------------------------------------------------------------------
/graphiti_core/utils/bulk_utils.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import json
18 | import logging
19 | import typing
20 | from datetime import datetime
21 |
22 | import numpy as np
23 | from pydantic import BaseModel, Field
24 | from typing_extensions import Any
25 |
26 | from graphiti_core.driver.driver import (
27 | GraphDriver,
28 | GraphDriverSession,
29 | GraphProvider,
30 | )
31 | from graphiti_core.edges import Edge, EntityEdge, EpisodicEdge, create_entity_edge_embeddings
32 | from graphiti_core.embedder import EmbedderClient
33 | from graphiti_core.graphiti_types import GraphitiClients
34 | from graphiti_core.helpers import normalize_l2, semaphore_gather
35 | from graphiti_core.models.edges.edge_db_queries import (
36 | get_entity_edge_save_bulk_query,
37 | get_episodic_edge_save_bulk_query,
38 | )
39 | from graphiti_core.models.nodes.node_db_queries import (
40 | get_entity_node_save_bulk_query,
41 | get_episode_node_save_bulk_query,
42 | )
43 | from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
44 | from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
45 | from graphiti_core.utils.maintenance.dedup_helpers import (
46 | DedupResolutionState,
47 | _build_candidate_indexes,
48 | _normalize_string_exact,
49 | _resolve_with_similarity,
50 | )
51 | from graphiti_core.utils.maintenance.edge_operations import (
52 | extract_edges,
53 | resolve_extracted_edge,
54 | )
55 | from graphiti_core.utils.maintenance.graph_data_operations import (
56 | EPISODE_WINDOW_LEN,
57 | retrieve_episodes,
58 | )
59 | from graphiti_core.utils.maintenance.node_operations import (
60 | extract_nodes,
61 | resolve_extracted_nodes,
62 | )
63 |
64 | logger = logging.getLogger(__name__)
65 |
66 | CHUNK_SIZE = 10
67 |
68 |
69 | def _build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]:
70 | """Collapse alias -> canonical chains while preserving direction.
71 |
72 | The incoming pairs represent directed mappings discovered during node dedupe. We use a simple
73 | union-find with iterative path compression to ensure every source UUID resolves to its ultimate
74 | canonical target, even if aliases appear lexicographically smaller than the canonical UUID.
75 | """
76 |
77 | parent: dict[str, str] = {}
78 |
79 | def find(uuid: str) -> str:
80 | """Directed union-find lookup using iterative path compression."""
81 | parent.setdefault(uuid, uuid)
82 | root = uuid
83 | while parent[root] != root:
84 | root = parent[root]
85 |
86 | while parent[uuid] != root:
87 | next_uuid = parent[uuid]
88 | parent[uuid] = root
89 | uuid = next_uuid
90 |
91 | return root
92 |
93 | for source_uuid, target_uuid in pairs:
94 | parent.setdefault(source_uuid, source_uuid)
95 | parent.setdefault(target_uuid, target_uuid)
96 | parent[find(source_uuid)] = find(target_uuid)
97 |
98 | return {uuid: find(uuid) for uuid in parent}
99 |
100 |
101 | class RawEpisode(BaseModel):
102 | name: str
103 | uuid: str | None = Field(default=None)
104 | content: str
105 | source_description: str
106 | source: EpisodeType
107 | reference_time: datetime
108 |
109 |
110 | async def retrieve_previous_episodes_bulk(
111 | driver: GraphDriver, episodes: list[EpisodicNode]
112 | ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]:
113 | previous_episodes_list = await semaphore_gather(
114 | *[
115 | retrieve_episodes(
116 | driver, episode.valid_at, last_n=EPISODE_WINDOW_LEN, group_ids=[episode.group_id]
117 | )
118 | for episode in episodes
119 | ]
120 | )
121 | episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] = [
122 | (episode, previous_episodes_list[i]) for i, episode in enumerate(episodes)
123 | ]
124 |
125 | return episode_tuples
126 |
127 |
128 | async def add_nodes_and_edges_bulk(
129 | driver: GraphDriver,
130 | episodic_nodes: list[EpisodicNode],
131 | episodic_edges: list[EpisodicEdge],
132 | entity_nodes: list[EntityNode],
133 | entity_edges: list[EntityEdge],
134 | embedder: EmbedderClient,
135 | ):
136 | session = driver.session()
137 | try:
138 | await session.execute_write(
139 | add_nodes_and_edges_bulk_tx,
140 | episodic_nodes,
141 | episodic_edges,
142 | entity_nodes,
143 | entity_edges,
144 | embedder,
145 | driver=driver,
146 | )
147 | finally:
148 | await session.close()
149 |
150 |
151 | async def add_nodes_and_edges_bulk_tx(
152 | tx: GraphDriverSession,
153 | episodic_nodes: list[EpisodicNode],
154 | episodic_edges: list[EpisodicEdge],
155 | entity_nodes: list[EntityNode],
156 | entity_edges: list[EntityEdge],
157 | embedder: EmbedderClient,
158 | driver: GraphDriver,
159 | ):
160 | episodes = [dict(episode) for episode in episodic_nodes]
161 | for episode in episodes:
162 | episode['source'] = str(episode['source'].value)
163 | episode.pop('labels', None)
164 |
165 | nodes = []
166 |
167 | for node in entity_nodes:
168 | if node.name_embedding is None:
169 | await node.generate_name_embedding(embedder)
170 |
171 | entity_data: dict[str, Any] = {
172 | 'uuid': node.uuid,
173 | 'name': node.name,
174 | 'group_id': node.group_id,
175 | 'summary': node.summary,
176 | 'created_at': node.created_at,
177 | 'name_embedding': node.name_embedding,
178 | 'labels': list(set(node.labels + ['Entity'])),
179 | }
180 |
181 | if driver.provider == GraphProvider.KUZU:
182 | attributes = convert_datetimes_to_strings(node.attributes) if node.attributes else {}
183 | entity_data['attributes'] = json.dumps(attributes)
184 | else:
185 | entity_data.update(node.attributes or {})
186 |
187 | nodes.append(entity_data)
188 |
189 | edges = []
190 | for edge in entity_edges:
191 | if edge.fact_embedding is None:
192 | await edge.generate_embedding(embedder)
193 | edge_data: dict[str, Any] = {
194 | 'uuid': edge.uuid,
195 | 'source_node_uuid': edge.source_node_uuid,
196 | 'target_node_uuid': edge.target_node_uuid,
197 | 'name': edge.name,
198 | 'fact': edge.fact,
199 | 'group_id': edge.group_id,
200 | 'episodes': edge.episodes,
201 | 'created_at': edge.created_at,
202 | 'expired_at': edge.expired_at,
203 | 'valid_at': edge.valid_at,
204 | 'invalid_at': edge.invalid_at,
205 | 'fact_embedding': edge.fact_embedding,
206 | }
207 |
208 | if driver.provider == GraphProvider.KUZU:
209 | attributes = convert_datetimes_to_strings(edge.attributes) if edge.attributes else {}
210 | edge_data['attributes'] = json.dumps(attributes)
211 | else:
212 | edge_data.update(edge.attributes or {})
213 |
214 | edges.append(edge_data)
215 |
216 | if driver.graph_operations_interface:
217 | await driver.graph_operations_interface.episodic_node_save_bulk(None, driver, tx, episodes)
218 | await driver.graph_operations_interface.node_save_bulk(None, driver, tx, nodes)
219 | await driver.graph_operations_interface.episodic_edge_save_bulk(
220 | None, driver, tx, [edge.model_dump() for edge in episodic_edges]
221 | )
222 | await driver.graph_operations_interface.edge_save_bulk(None, driver, tx, edges)
223 |
224 | elif driver.provider == GraphProvider.KUZU:
225 | # FIXME: Kuzu's UNWIND does not currently support STRUCT[] type properly, so we insert the data one by one instead for now.
226 | episode_query = get_episode_node_save_bulk_query(driver.provider)
227 | for episode in episodes:
228 | await tx.run(episode_query, **episode)
229 | entity_node_query = get_entity_node_save_bulk_query(driver.provider, nodes)
230 | for node in nodes:
231 | await tx.run(entity_node_query, **node)
232 | entity_edge_query = get_entity_edge_save_bulk_query(driver.provider)
233 | for edge in edges:
234 | await tx.run(entity_edge_query, **edge)
235 | episodic_edge_query = get_episodic_edge_save_bulk_query(driver.provider)
236 | for edge in episodic_edges:
237 | await tx.run(episodic_edge_query, **edge.model_dump())
238 | else:
239 | await tx.run(get_episode_node_save_bulk_query(driver.provider), episodes=episodes)
240 | await tx.run(
241 | get_entity_node_save_bulk_query(driver.provider, nodes),
242 | nodes=nodes,
243 | )
244 | await tx.run(
245 | get_episodic_edge_save_bulk_query(driver.provider),
246 | episodic_edges=[edge.model_dump() for edge in episodic_edges],
247 | )
248 | await tx.run(
249 | get_entity_edge_save_bulk_query(driver.provider),
250 | entity_edges=edges,
251 | )
252 |
253 |
254 | async def extract_nodes_and_edges_bulk(
255 | clients: GraphitiClients,
256 | episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
257 | edge_type_map: dict[tuple[str, str], list[str]],
258 | entity_types: dict[str, type[BaseModel]] | None = None,
259 | excluded_entity_types: list[str] | None = None,
260 | edge_types: dict[str, type[BaseModel]] | None = None,
261 | ) -> tuple[list[list[EntityNode]], list[list[EntityEdge]]]:
262 | extracted_nodes_bulk: list[list[EntityNode]] = await semaphore_gather(
263 | *[
264 | extract_nodes(clients, episode, previous_episodes, entity_types, excluded_entity_types)
265 | for episode, previous_episodes in episode_tuples
266 | ]
267 | )
268 |
269 | extracted_edges_bulk: list[list[EntityEdge]] = await semaphore_gather(
270 | *[
271 | extract_edges(
272 | clients,
273 | episode,
274 | extracted_nodes_bulk[i],
275 | previous_episodes,
276 | edge_type_map=edge_type_map,
277 | group_id=episode.group_id,
278 | edge_types=edge_types,
279 | )
280 | for i, (episode, previous_episodes) in enumerate(episode_tuples)
281 | ]
282 | )
283 |
284 | return extracted_nodes_bulk, extracted_edges_bulk
285 |
286 |
287 | async def dedupe_nodes_bulk(
288 | clients: GraphitiClients,
289 | extracted_nodes: list[list[EntityNode]],
290 | episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
291 | entity_types: dict[str, type[BaseModel]] | None = None,
292 | ) -> tuple[dict[str, list[EntityNode]], dict[str, str]]:
293 | """Resolve entity duplicates across an in-memory batch using a two-pass strategy.
294 |
295 | 1. Run :func:`resolve_extracted_nodes` for every episode in parallel so each batch item is
296 | reconciled against the live graph just like the non-batch flow.
297 | 2. Re-run the deterministic similarity heuristics across the union of resolved nodes to catch
298 | duplicates that only co-occur inside this batch, emitting a canonical UUID map that callers
299 | can apply to edges and persistence.
300 | """
301 |
302 | first_pass_results = await semaphore_gather(
303 | *[
304 | resolve_extracted_nodes(
305 | clients,
306 | nodes,
307 | episode_tuples[i][0],
308 | episode_tuples[i][1],
309 | entity_types,
310 | )
311 | for i, nodes in enumerate(extracted_nodes)
312 | ]
313 | )
314 |
315 | episode_resolutions: list[tuple[str, list[EntityNode]]] = []
316 | per_episode_uuid_maps: list[dict[str, str]] = []
317 | duplicate_pairs: list[tuple[str, str]] = []
318 |
319 | for (resolved_nodes, uuid_map, duplicates), (episode, _) in zip(
320 | first_pass_results, episode_tuples, strict=True
321 | ):
322 | episode_resolutions.append((episode.uuid, resolved_nodes))
323 | per_episode_uuid_maps.append(uuid_map)
324 | duplicate_pairs.extend((source.uuid, target.uuid) for source, target in duplicates)
325 |
326 | canonical_nodes: dict[str, EntityNode] = {}
327 | for _, resolved_nodes in episode_resolutions:
328 | for node in resolved_nodes:
329 | # NOTE: this loop is O(n^2) in the number of nodes inside the batch because we rebuild
330 | # the MinHash index for the accumulated canonical pool each time. The LRU-backed
331 | # shingle cache keeps the constant factors low for typical batch sizes (≤ CHUNK_SIZE),
332 | # but if batches grow significantly we should switch to an incremental index or chunked
333 | # processing.
334 | if not canonical_nodes:
335 | canonical_nodes[node.uuid] = node
336 | continue
337 |
338 | existing_candidates = list(canonical_nodes.values())
339 | normalized = _normalize_string_exact(node.name)
340 | exact_match = next(
341 | (
342 | candidate
343 | for candidate in existing_candidates
344 | if _normalize_string_exact(candidate.name) == normalized
345 | ),
346 | None,
347 | )
348 | if exact_match is not None:
349 | if exact_match.uuid != node.uuid:
350 | duplicate_pairs.append((node.uuid, exact_match.uuid))
351 | continue
352 |
353 | indexes = _build_candidate_indexes(existing_candidates)
354 | state = DedupResolutionState(
355 | resolved_nodes=[None],
356 | uuid_map={},
357 | unresolved_indices=[],
358 | )
359 | _resolve_with_similarity([node], indexes, state)
360 |
361 | resolved = state.resolved_nodes[0]
362 | if resolved is None:
363 | canonical_nodes[node.uuid] = node
364 | continue
365 |
366 | canonical_uuid = resolved.uuid
367 | canonical_nodes.setdefault(canonical_uuid, resolved)
368 | if canonical_uuid != node.uuid:
369 | duplicate_pairs.append((node.uuid, canonical_uuid))
370 |
371 | union_pairs: list[tuple[str, str]] = []
372 | for uuid_map in per_episode_uuid_maps:
373 | union_pairs.extend(uuid_map.items())
374 | union_pairs.extend(duplicate_pairs)
375 |
376 | compressed_map: dict[str, str] = _build_directed_uuid_map(union_pairs)
377 |
378 | nodes_by_episode: dict[str, list[EntityNode]] = {}
379 | for episode_uuid, resolved_nodes in episode_resolutions:
380 | deduped_nodes: list[EntityNode] = []
381 | seen: set[str] = set()
382 | for node in resolved_nodes:
383 | canonical_uuid = compressed_map.get(node.uuid, node.uuid)
384 | if canonical_uuid in seen:
385 | continue
386 | seen.add(canonical_uuid)
387 | canonical_node = canonical_nodes.get(canonical_uuid)
388 | if canonical_node is None:
389 | logger.error(
390 | 'Canonical node %s missing during batch dedupe; falling back to %s',
391 | canonical_uuid,
392 | node.uuid,
393 | )
394 | canonical_node = node
395 | deduped_nodes.append(canonical_node)
396 |
397 | nodes_by_episode[episode_uuid] = deduped_nodes
398 |
399 | return nodes_by_episode, compressed_map
400 |
401 |
402 | async def dedupe_edges_bulk(
403 | clients: GraphitiClients,
404 | extracted_edges: list[list[EntityEdge]],
405 | episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]],
406 | _entities: list[EntityNode],
407 | edge_types: dict[str, type[BaseModel]],
408 | _edge_type_map: dict[tuple[str, str], list[str]],
409 | ) -> dict[str, list[EntityEdge]]:
410 | embedder = clients.embedder
411 | min_score = 0.6
412 |
413 | # generate embeddings
414 | await semaphore_gather(
415 | *[create_entity_edge_embeddings(embedder, edges) for edges in extracted_edges]
416 | )
417 |
418 | # Find similar results
419 | dedupe_tuples: list[tuple[EpisodicNode, EntityEdge, list[EntityEdge]]] = []
420 | for i, edges_i in enumerate(extracted_edges):
421 | existing_edges: list[EntityEdge] = []
422 | for edges_j in extracted_edges:
423 | existing_edges += edges_j
424 |
425 | for edge in edges_i:
426 | candidates: list[EntityEdge] = []
427 | for existing_edge in existing_edges:
428 | # Skip self-comparison
429 | if edge.uuid == existing_edge.uuid:
430 | continue
431 | # Approximate BM25 by checking for word overlaps (this is faster than creating many in-memory indices)
432 | # This approach will cast a wider net than BM25, which is ideal for this use case
433 | if (
434 | edge.source_node_uuid != existing_edge.source_node_uuid
435 | or edge.target_node_uuid != existing_edge.target_node_uuid
436 | ):
437 | continue
438 |
439 | edge_words = set(edge.fact.lower().split())
440 | existing_edge_words = set(existing_edge.fact.lower().split())
441 | has_overlap = not edge_words.isdisjoint(existing_edge_words)
442 | if has_overlap:
443 | candidates.append(existing_edge)
444 | continue
445 |
446 | # Check for semantic similarity even if there is no overlap
447 | similarity = np.dot(
448 | normalize_l2(edge.fact_embedding or []),
449 | normalize_l2(existing_edge.fact_embedding or []),
450 | )
451 | if similarity >= min_score:
452 | candidates.append(existing_edge)
453 |
454 | dedupe_tuples.append((episode_tuples[i][0], edge, candidates))
455 |
456 | bulk_edge_resolutions: list[
457 | tuple[EntityEdge, EntityEdge, list[EntityEdge]]
458 | ] = await semaphore_gather(
459 | *[
460 | resolve_extracted_edge(
461 | clients.llm_client,
462 | edge,
463 | candidates,
464 | candidates,
465 | episode,
466 | edge_types,
467 | set(edge_types),
468 | )
469 | for episode, edge, candidates in dedupe_tuples
470 | ]
471 | )
472 |
473 | # For now we won't track edge invalidation
474 | duplicate_pairs: list[tuple[str, str]] = []
475 | for i, (_, _, duplicates) in enumerate(bulk_edge_resolutions):
476 | episode, edge, candidates = dedupe_tuples[i]
477 | for duplicate in duplicates:
478 | duplicate_pairs.append((edge.uuid, duplicate.uuid))
479 |
480 |     # Now compress the duplicate map so that chains like 3 -> 2 and 2 -> 1 collapse to 3 -> 1 and 2 -> 1 (canonical uuid = lexicographically smallest)
481 | compressed_map: dict[str, str] = compress_uuid_map(duplicate_pairs)
482 |
483 | edge_uuid_map: dict[str, EntityEdge] = {
484 | edge.uuid: edge for edges in extracted_edges for edge in edges
485 | }
486 |
487 | edges_by_episode: dict[str, list[EntityEdge]] = {}
488 | for i, edges in enumerate(extracted_edges):
489 | episode = episode_tuples[i][0]
490 |
491 | edges_by_episode[episode.uuid] = [
492 | edge_uuid_map[compressed_map.get(edge.uuid, edge.uuid)] for edge in edges
493 | ]
494 |
495 | return edges_by_episode
496 |
497 |
498 | class UnionFind:
499 | def __init__(self, elements):
500 | # start each element in its own set
501 | self.parent = {e: e for e in elements}
502 |
503 | def find(self, x):
504 | # path‐compression
505 | if self.parent[x] != x:
506 | self.parent[x] = self.find(self.parent[x])
507 | return self.parent[x]
508 |
509 | def union(self, a, b):
510 | ra, rb = self.find(a), self.find(b)
511 | if ra == rb:
512 | return
513 | # attach the lexicographically larger root under the smaller
514 | if ra < rb:
515 | self.parent[rb] = ra
516 | else:
517 | self.parent[ra] = rb
518 |
519 |
520 | def compress_uuid_map(duplicate_pairs: list[tuple[str, str]]) -> dict[str, str]:
521 | """
522 |     duplicate_pairs: iterable of (id1, id2) UUID pairs marking duplicates;
523 |                      all ids are collected from these pairs, so no separate id list is needed
524 | returns: dict mapping each id -> lexicographically smallest id in its duplicate set
525 | """
526 | all_uuids = set()
527 | for pair in duplicate_pairs:
528 | all_uuids.add(pair[0])
529 | all_uuids.add(pair[1])
530 |
531 | uf = UnionFind(all_uuids)
532 | for a, b in duplicate_pairs:
533 | uf.union(a, b)
534 | # ensure full path‐compression before mapping
535 | return {uuid: uf.find(uuid) for uuid in all_uuids}
536 |
537 |
538 | E = typing.TypeVar('E', bound=Edge)
539 |
540 |
541 | def resolve_edge_pointers(edges: list[E], uuid_map: dict[str, str]):
542 | for edge in edges:
543 | source_uuid = edge.source_node_uuid
544 | target_uuid = edge.target_node_uuid
545 | edge.source_node_uuid = uuid_map.get(source_uuid, source_uuid)
546 | edge.target_node_uuid = uuid_map.get(target_uuid, target_uuid)
547 |
548 | return edges
549 |
```
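
To make the two UUID-collapsing strategies above concrete, a small illustrative sketch (toy ids rather than real UUIDs; `_build_directed_uuid_map` is a private helper, imported here only for demonstration):

```python
# Illustrative only: how the two duplicate-map helpers collapse alias chains.
from graphiti_core.utils.bulk_utils import _build_directed_uuid_map, compress_uuid_map

# Directed chains resolve every alias to its ultimate canonical target.
print(_build_directed_uuid_map([('a', 'b'), ('b', 'c')]))
# {'a': 'c', 'b': 'c', 'c': 'c'}

# The undirected union-find picks the lexicographically smallest id as canonical.
print(compress_uuid_map([('3', '2'), ('2', '1')]))
# {'3': '1', '2': '1', '1': '1'} (key order may vary)
```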
--------------------------------------------------------------------------------
/mcp_server/tests/test_stress_load.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """
3 | Stress and load testing for Graphiti MCP Server.
4 | Tests system behavior under high load, resource constraints, and edge conditions.
5 | """
6 |
7 | import asyncio
8 | import gc
9 | import random
10 | import time
11 | from dataclasses import dataclass
12 |
13 | import psutil
14 | import pytest
15 | from test_fixtures import TestDataGenerator, graphiti_test_client
16 |
17 |
18 | @dataclass
19 | class LoadTestConfig:
20 | """Configuration for load testing scenarios."""
21 |
22 | num_clients: int = 10
23 | operations_per_client: int = 100
24 | ramp_up_time: float = 5.0 # seconds
25 | test_duration: float = 60.0 # seconds
26 | target_throughput: float | None = None # ops/sec
27 | think_time: float = 0.1 # seconds between ops
28 |
29 |
30 | @dataclass
31 | class LoadTestResult:
32 | """Results from a load test run."""
33 |
34 | total_operations: int
35 | successful_operations: int
36 | failed_operations: int
37 | duration: float
38 | throughput: float
39 | average_latency: float
40 | p50_latency: float
41 | p95_latency: float
42 | p99_latency: float
43 | max_latency: float
44 | errors: dict[str, int]
45 | resource_usage: dict[str, float]
46 |
47 |
48 | class LoadTester:
49 | """Orchestrate load testing scenarios."""
50 |
51 | def __init__(self, config: LoadTestConfig):
52 | self.config = config
53 | self.metrics: list[tuple[float, float, bool]] = [] # (start, duration, success)
54 | self.errors: dict[str, int] = {}
55 | self.start_time: float | None = None
56 |
57 | async def run_client_workload(self, client_id: int, session, group_id: str) -> dict[str, int]:
58 | """Run workload for a single simulated client."""
59 | stats = {'success': 0, 'failure': 0}
60 | data_gen = TestDataGenerator()
61 |
62 | # Ramp-up delay
63 | ramp_delay = (client_id / self.config.num_clients) * self.config.ramp_up_time
64 | await asyncio.sleep(ramp_delay)
65 |
66 | for op_num in range(self.config.operations_per_client):
67 | operation_start = time.time()
68 |
69 | try:
70 | # Randomly select operation type
71 | operation = random.choice(
72 | [
73 | 'add_memory',
74 | 'search_memory_nodes',
75 | 'get_episodes',
76 | ]
77 | )
78 |
79 | if operation == 'add_memory':
80 | args = {
81 | 'name': f'Load Test {client_id}-{op_num}',
82 | 'episode_body': data_gen.generate_technical_document(),
83 | 'source': 'text',
84 | 'source_description': 'load test',
85 | 'group_id': group_id,
86 | }
87 | elif operation == 'search_memory_nodes':
88 | args = {
89 | 'query': random.choice(['performance', 'architecture', 'test', 'data']),
90 | 'group_id': group_id,
91 | 'limit': 10,
92 | }
93 | else: # get_episodes
94 | args = {
95 | 'group_id': group_id,
96 | 'last_n': 10,
97 | }
98 |
99 | # Execute operation with timeout
100 | await asyncio.wait_for(session.call_tool(operation, args), timeout=30.0)
101 |
102 | duration = time.time() - operation_start
103 | self.metrics.append((operation_start, duration, True))
104 | stats['success'] += 1
105 |
106 | except asyncio.TimeoutError:
107 | duration = time.time() - operation_start
108 | self.metrics.append((operation_start, duration, False))
109 | self.errors['timeout'] = self.errors.get('timeout', 0) + 1
110 | stats['failure'] += 1
111 |
112 | except Exception as e:
113 | duration = time.time() - operation_start
114 | self.metrics.append((operation_start, duration, False))
115 | error_type = type(e).__name__
116 | self.errors[error_type] = self.errors.get(error_type, 0) + 1
117 | stats['failure'] += 1
118 |
119 | # Think time between operations
120 | await asyncio.sleep(self.config.think_time)
121 |
122 | # Stop if we've exceeded test duration
123 | if self.start_time and (time.time() - self.start_time) > self.config.test_duration:
124 | break
125 |
126 | return stats
127 |
128 | def calculate_results(self) -> LoadTestResult:
129 | """Calculate load test results from metrics."""
130 | if not self.metrics:
131 | return LoadTestResult(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, {}, {})
132 |
133 | successful = [m for m in self.metrics if m[2]]
134 | failed = [m for m in self.metrics if not m[2]]
135 |
136 | latencies = sorted([m[1] for m in self.metrics])
137 | duration = max([m[0] + m[1] for m in self.metrics]) - min([m[0] for m in self.metrics])
138 |
139 | # Calculate percentiles
140 | def percentile(data: list[float], p: float) -> float:
141 | if not data:
142 | return 0.0
143 | idx = int(len(data) * p / 100)
144 | return data[min(idx, len(data) - 1)]
145 |
146 | # Get resource usage
147 | process = psutil.Process()
148 | resource_usage = {
149 | 'cpu_percent': process.cpu_percent(),
150 | 'memory_mb': process.memory_info().rss / 1024 / 1024,
151 | 'num_threads': process.num_threads(),
152 | }
153 |
154 | return LoadTestResult(
155 | total_operations=len(self.metrics),
156 | successful_operations=len(successful),
157 | failed_operations=len(failed),
158 | duration=duration,
159 | throughput=len(self.metrics) / duration if duration > 0 else 0,
160 | average_latency=sum(latencies) / len(latencies) if latencies else 0,
161 | p50_latency=percentile(latencies, 50),
162 | p95_latency=percentile(latencies, 95),
163 | p99_latency=percentile(latencies, 99),
164 | max_latency=max(latencies) if latencies else 0,
165 | errors=self.errors,
166 | resource_usage=resource_usage,
167 | )
168 |
169 |
170 | class TestLoadScenarios:
171 | """Various load testing scenarios."""
172 |
173 | @pytest.mark.asyncio
174 | @pytest.mark.slow
175 | async def test_sustained_load(self):
176 | """Test system under sustained moderate load."""
177 | config = LoadTestConfig(
178 | num_clients=5,
179 | operations_per_client=20,
180 | ramp_up_time=2.0,
181 | test_duration=30.0,
182 | think_time=0.5,
183 | )
184 |
185 | async with graphiti_test_client() as (session, group_id):
186 | tester = LoadTester(config)
187 | tester.start_time = time.time()
188 |
189 | # Run client workloads
190 | client_tasks = []
191 | for client_id in range(config.num_clients):
192 | task = tester.run_client_workload(client_id, session, group_id)
193 | client_tasks.append(task)
194 |
195 | # Execute all clients
196 | await asyncio.gather(*client_tasks)
197 |
198 | # Calculate results
199 | results = tester.calculate_results()
200 |
201 | # Assertions
202 | assert results.successful_operations > results.failed_operations
203 | assert results.average_latency < 5.0, (
204 | f'Average latency too high: {results.average_latency:.2f}s'
205 | )
206 | assert results.p95_latency < 10.0, f'P95 latency too high: {results.p95_latency:.2f}s'
207 |
208 | # Report results
209 | print('\nSustained Load Test Results:')
210 | print(f' Total operations: {results.total_operations}')
211 | print(
212 | f' Success rate: {results.successful_operations / results.total_operations * 100:.1f}%'
213 | )
214 | print(f' Throughput: {results.throughput:.2f} ops/s')
215 | print(f' Avg latency: {results.average_latency:.2f}s')
216 | print(f' P95 latency: {results.p95_latency:.2f}s')
217 |
218 | @pytest.mark.asyncio
219 | @pytest.mark.slow
220 | async def test_spike_load(self):
221 | """Test system response to sudden load spikes."""
222 | async with graphiti_test_client() as (session, group_id):
223 | # Normal load phase
224 | normal_tasks = []
225 | for i in range(3):
226 | task = session.call_tool(
227 | 'add_memory',
228 | {
229 | 'name': f'Normal Load {i}',
230 | 'episode_body': 'Normal operation',
231 | 'source': 'text',
232 | 'source_description': 'normal',
233 | 'group_id': group_id,
234 | },
235 | )
236 | normal_tasks.append(task)
237 | await asyncio.sleep(0.5)
238 |
239 | await asyncio.gather(*normal_tasks)
240 |
241 | # Spike phase - sudden burst of requests
242 | spike_start = time.time()
243 | spike_tasks = []
244 | for i in range(50):
245 | task = session.call_tool(
246 | 'add_memory',
247 | {
248 | 'name': f'Spike Load {i}',
249 | 'episode_body': TestDataGenerator.generate_technical_document(),
250 | 'source': 'text',
251 | 'source_description': 'spike',
252 | 'group_id': group_id,
253 | },
254 | )
255 | spike_tasks.append(task)
256 |
257 | # Execute spike
258 | spike_results = await asyncio.gather(*spike_tasks, return_exceptions=True)
259 | spike_duration = time.time() - spike_start
260 |
261 | # Analyze spike handling
262 | spike_failures = sum(1 for r in spike_results if isinstance(r, Exception))
263 | spike_success_rate = (len(spike_results) - spike_failures) / len(spike_results)
264 |
265 | print('\nSpike Load Test Results:')
266 | print(f' Spike size: {len(spike_tasks)} operations')
267 | print(f' Duration: {spike_duration:.2f}s')
268 | print(f' Success rate: {spike_success_rate * 100:.1f}%')
269 | print(f' Throughput: {len(spike_tasks) / spike_duration:.2f} ops/s')
270 |
271 | # System should handle at least 80% of spike
272 | assert spike_success_rate > 0.8, f'Too many failures during spike: {spike_failures}'
273 |
274 | @pytest.mark.asyncio
275 | @pytest.mark.slow
276 | async def test_memory_leak_detection(self):
277 | """Test for memory leaks during extended operation."""
278 | async with graphiti_test_client() as (session, group_id):
279 | process = psutil.Process()
280 | gc.collect() # Force garbage collection
281 | initial_memory = process.memory_info().rss / 1024 / 1024 # MB
282 |
283 | # Perform many operations
284 | for batch in range(10):
285 | batch_tasks = []
286 | for i in range(10):
287 | task = session.call_tool(
288 | 'add_memory',
289 | {
290 | 'name': f'Memory Test {batch}-{i}',
291 | 'episode_body': TestDataGenerator.generate_technical_document(),
292 | 'source': 'text',
293 | 'source_description': 'memory test',
294 | 'group_id': group_id,
295 | },
296 | )
297 | batch_tasks.append(task)
298 |
299 | await asyncio.gather(*batch_tasks)
300 |
301 | # Force garbage collection between batches
302 | gc.collect()
303 | await asyncio.sleep(1)
304 |
305 | # Check memory after operations
306 | gc.collect()
307 | final_memory = process.memory_info().rss / 1024 / 1024 # MB
308 | memory_growth = final_memory - initial_memory
309 |
310 | print('\nMemory Leak Test:')
311 | print(f' Initial memory: {initial_memory:.1f} MB')
312 | print(f' Final memory: {final_memory:.1f} MB')
313 | print(f' Growth: {memory_growth:.1f} MB')
314 |
315 | # Allow for some memory growth but flag potential leaks
316 | # This is a soft check - actual threshold depends on system
317 | if memory_growth > 100: # More than 100MB growth
318 | print(f' ⚠️ Potential memory leak detected: {memory_growth:.1f} MB growth')
319 |
320 | @pytest.mark.asyncio
321 | @pytest.mark.slow
322 | async def test_connection_pool_exhaustion(self):
323 | """Test behavior when connection pools are exhausted."""
324 | async with graphiti_test_client() as (session, group_id):
325 | # Create many concurrent long-running operations
326 | long_tasks = []
327 | for i in range(100): # Many more than typical pool size
328 | task = session.call_tool(
329 | 'search_memory_nodes',
330 | {
331 | 'query': f'complex query {i} '
332 | + ' '.join([TestDataGenerator.fake.word() for _ in range(10)]),
333 | 'group_id': group_id,
334 | 'limit': 100,
335 | },
336 | )
337 | long_tasks.append(task)
338 |
339 | # Execute with timeout
340 | try:
341 | results = await asyncio.wait_for(
342 | asyncio.gather(*long_tasks, return_exceptions=True), timeout=60.0
343 | )
344 |
345 | # Count connection-related errors
346 | connection_errors = sum(
347 | 1
348 | for r in results
349 | if isinstance(r, Exception) and 'connection' in str(r).lower()
350 | )
351 |
352 | print('\nConnection Pool Test:')
353 | print(f' Total requests: {len(long_tasks)}')
354 | print(f' Connection errors: {connection_errors}')
355 |
356 | except asyncio.TimeoutError:
357 | print(' Test timed out - possible deadlock or exhaustion')
358 |
359 | @pytest.mark.asyncio
360 | @pytest.mark.slow
361 | async def test_gradual_degradation(self):
362 | """Test system degradation under increasing load."""
363 | async with graphiti_test_client() as (session, group_id):
364 | load_levels = [5, 10, 20, 40, 80] # Increasing concurrent operations
365 | results_by_level = {}
366 |
367 | for level in load_levels:
368 | level_start = time.time()
369 | tasks = []
370 |
371 | for i in range(level):
372 | task = session.call_tool(
373 | 'add_memory',
374 | {
375 | 'name': f'Load Level {level} Op {i}',
376 | 'episode_body': f'Testing at load level {level}',
377 | 'source': 'text',
378 | 'source_description': 'degradation test',
379 | 'group_id': group_id,
380 | },
381 | )
382 | tasks.append(task)
383 |
384 | # Execute level
385 | level_results = await asyncio.gather(*tasks, return_exceptions=True)
386 | level_duration = time.time() - level_start
387 |
388 | # Calculate metrics
389 | failures = sum(1 for r in level_results if isinstance(r, Exception))
390 | success_rate = (level - failures) / level * 100
391 | throughput = level / level_duration
392 |
393 | results_by_level[level] = {
394 | 'success_rate': success_rate,
395 | 'throughput': throughput,
396 | 'duration': level_duration,
397 | }
398 |
399 | print(f'\nLoad Level {level}:')
400 | print(f' Success rate: {success_rate:.1f}%')
401 | print(f' Throughput: {throughput:.2f} ops/s')
402 | print(f' Duration: {level_duration:.2f}s')
403 |
404 | # Brief pause between levels
405 | await asyncio.sleep(2)
406 |
407 | # Verify graceful degradation
408 | # Success rate should not drop below 50% even at high load
409 | for level, metrics in results_by_level.items():
410 | assert metrics['success_rate'] > 50, f'Poor performance at load level {level}'
411 |
412 |
413 | class TestResourceLimits:
414 | """Test behavior at resource limits."""
415 |
416 | @pytest.mark.asyncio
417 | async def test_large_payload_handling(self):
418 | """Test handling of very large payloads."""
419 | async with graphiti_test_client() as (session, group_id):
420 | payload_sizes = [
421 | (1_000, '1KB'),
422 | (10_000, '10KB'),
423 | (100_000, '100KB'),
424 | (1_000_000, '1MB'),
425 | ]
426 |
427 | for size, label in payload_sizes:
428 | content = 'x' * size
429 |
430 | start_time = time.time()
431 | try:
432 | await asyncio.wait_for(
433 | session.call_tool(
434 | 'add_memory',
435 | {
436 | 'name': f'Large Payload {label}',
437 | 'episode_body': content,
438 | 'source': 'text',
439 | 'source_description': 'payload test',
440 | 'group_id': group_id,
441 | },
442 | ),
443 | timeout=30.0,
444 | )
445 | duration = time.time() - start_time
446 | status = '✅ Success'
447 |
448 | except asyncio.TimeoutError:
449 | duration = 30.0
450 | status = '⏱️ Timeout'
451 |
452 | except Exception as e:
453 | duration = time.time() - start_time
454 | status = f'❌ Error: {type(e).__name__}'
455 |
456 | print(f'Payload {label}: {status} ({duration:.2f}s)')
457 |
458 | @pytest.mark.asyncio
459 | async def test_rate_limit_handling(self):
460 | """Test handling of rate limits."""
461 | async with graphiti_test_client() as (session, group_id):
462 | # Rapid fire requests to trigger rate limits
463 | rapid_tasks = []
464 | for i in range(100):
465 | task = session.call_tool(
466 | 'add_memory',
467 | {
468 | 'name': f'Rate Limit Test {i}',
469 | 'episode_body': f'Testing rate limit {i}',
470 | 'source': 'text',
471 | 'source_description': 'rate test',
472 | 'group_id': group_id,
473 | },
474 | )
475 | rapid_tasks.append(task)
476 |
477 | # Execute without delays
478 | results = await asyncio.gather(*rapid_tasks, return_exceptions=True)
479 |
480 | # Count rate limit errors
481 | rate_limit_errors = sum(
482 | 1
483 | for r in results
484 | if isinstance(r, Exception) and ('rate' in str(r).lower() or '429' in str(r))
485 | )
486 |
487 | print('\nRate Limit Test:')
488 | print(f' Total requests: {len(rapid_tasks)}')
489 | print(f' Rate limit errors: {rate_limit_errors}')
490 | print(
491 | f' Success rate: {(len(rapid_tasks) - rate_limit_errors) / len(rapid_tasks) * 100:.1f}%'
492 | )
493 |
494 |
495 | def generate_load_test_report(results: list[LoadTestResult]) -> str:
496 | """Generate comprehensive load test report."""
497 | report = []
498 | report.append('\n' + '=' * 60)
499 | report.append('LOAD TEST REPORT')
500 | report.append('=' * 60)
501 |
502 | for i, result in enumerate(results):
503 | report.append(f'\nTest Run {i + 1}:')
504 | report.append(f' Total Operations: {result.total_operations}')
505 | report.append(
506 | f' Success Rate: {result.successful_operations / result.total_operations * 100:.1f}%'
507 | )
508 | report.append(f' Throughput: {result.throughput:.2f} ops/s')
509 | report.append(
510 | f' Latency (avg/p50/p95/p99/max): {result.average_latency:.2f}/{result.p50_latency:.2f}/{result.p95_latency:.2f}/{result.p99_latency:.2f}/{result.max_latency:.2f}s'
511 | )
512 |
513 | if result.errors:
514 | report.append(' Errors:')
515 | for error_type, count in result.errors.items():
516 | report.append(f' {error_type}: {count}')
517 |
518 | report.append(' Resource Usage:')
519 | for metric, value in result.resource_usage.items():
520 | report.append(f' {metric}: {value:.2f}')
521 |
522 | report.append('=' * 60)
523 | return '\n'.join(report)
524 |
525 |
526 | if __name__ == '__main__':
527 | pytest.main([__file__, '-v', '--asyncio-mode=auto', '-m', 'slow'])
528 |
```
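
`LoadTester.calculate_results` estimates latency percentiles with a simple index lookup into the sorted latency list rather than interpolating between samples. A standalone sketch of that estimate on synthetic data (the `percentile` copy and the sample latencies below are illustrative only):

```python
# Illustrative sketch of the index-based percentile estimate used in calculate_results.
def percentile(data: list[float], p: float) -> float:
    if not data:
        return 0.0
    idx = int(len(data) * p / 100)        # e.g. p95 of 100 samples -> index 95
    return data[min(idx, len(data) - 1)]  # clamp so p100 stays in range


latencies = sorted(0.01 * i for i in range(1, 101))   # 0.01s .. 1.00s
assert percentile(latencies, 50) == latencies[50]     # ~0.51s
assert percentile(latencies, 95) == latencies[95]     # ~0.96s
assert percentile(latencies, 100) == latencies[99]    # clamped to the max sample
```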
--------------------------------------------------------------------------------
/graphiti_core/utils/maintenance/node_operations.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import logging
18 | from collections.abc import Awaitable, Callable
19 | from time import time
20 | from typing import Any
21 |
22 | from pydantic import BaseModel
23 |
24 | from graphiti_core.graphiti_types import GraphitiClients
25 | from graphiti_core.helpers import MAX_REFLEXION_ITERATIONS, semaphore_gather
26 | from graphiti_core.llm_client import LLMClient
27 | from graphiti_core.llm_client.config import ModelSize
28 | from graphiti_core.nodes import (
29 | EntityNode,
30 | EpisodeType,
31 | EpisodicNode,
32 | create_entity_node_embeddings,
33 | )
34 | from graphiti_core.prompts import prompt_library
35 | from graphiti_core.prompts.dedupe_nodes import NodeDuplicate, NodeResolutions
36 | from graphiti_core.prompts.extract_nodes import (
37 | EntitySummary,
38 | ExtractedEntities,
39 | ExtractedEntity,
40 | MissedEntities,
41 | )
42 | from graphiti_core.search.search import search
43 | from graphiti_core.search.search_config import SearchResults
44 | from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
45 | from graphiti_core.search.search_filters import SearchFilters
46 | from graphiti_core.utils.datetime_utils import utc_now
47 | from graphiti_core.utils.maintenance.dedup_helpers import (
48 | DedupCandidateIndexes,
49 | DedupResolutionState,
50 | _build_candidate_indexes,
51 | _resolve_with_similarity,
52 | )
53 | from graphiti_core.utils.maintenance.edge_operations import (
54 | filter_existing_duplicate_of_edges,
55 | )
56 | from graphiti_core.utils.text_utils import MAX_SUMMARY_CHARS, truncate_at_sentence
57 |
58 | logger = logging.getLogger(__name__)
59 |
60 | NodeSummaryFilter = Callable[[EntityNode], Awaitable[bool]]
61 |
62 |
63 | async def extract_nodes_reflexion(
64 | llm_client: LLMClient,
65 | episode: EpisodicNode,
66 | previous_episodes: list[EpisodicNode],
67 | node_names: list[str],
68 | group_id: str | None = None,
69 | ) -> list[str]:
70 | # Prepare context for LLM
71 | context = {
72 | 'episode_content': episode.content,
73 | 'previous_episodes': [ep.content for ep in previous_episodes],
74 | 'extracted_entities': node_names,
75 | }
76 |
77 | llm_response = await llm_client.generate_response(
78 | prompt_library.extract_nodes.reflexion(context),
79 | MissedEntities,
80 | group_id=group_id,
81 | prompt_name='extract_nodes.reflexion',
82 | )
83 | missed_entities = llm_response.get('missed_entities', [])
84 |
85 | return missed_entities
86 |
87 |
88 | async def extract_nodes(
89 | clients: GraphitiClients,
90 | episode: EpisodicNode,
91 | previous_episodes: list[EpisodicNode],
92 | entity_types: dict[str, type[BaseModel]] | None = None,
93 | excluded_entity_types: list[str] | None = None,
94 | ) -> list[EntityNode]:
95 | start = time()
96 | llm_client = clients.llm_client
97 | llm_response = {}
98 | custom_prompt = ''
99 | entities_missed = True
100 | reflexion_iterations = 0
101 |
102 | entity_types_context = [
103 | {
104 | 'entity_type_id': 0,
105 | 'entity_type_name': 'Entity',
106 | 'entity_type_description': 'Default entity classification. Use this entity type if the entity is not one of the other listed types.',
107 | }
108 | ]
109 |
110 | entity_types_context += (
111 | [
112 | {
113 | 'entity_type_id': i + 1,
114 | 'entity_type_name': type_name,
115 | 'entity_type_description': type_model.__doc__,
116 | }
117 | for i, (type_name, type_model) in enumerate(entity_types.items())
118 | ]
119 | if entity_types is not None
120 | else []
121 | )
122 |
123 | context = {
124 | 'episode_content': episode.content,
125 | 'episode_timestamp': episode.valid_at.isoformat(),
126 | 'previous_episodes': [ep.content for ep in previous_episodes],
127 | 'custom_prompt': custom_prompt,
128 | 'entity_types': entity_types_context,
129 | 'source_description': episode.source_description,
130 | }
131 |
132 | while entities_missed and reflexion_iterations <= MAX_REFLEXION_ITERATIONS:
133 | if episode.source == EpisodeType.message:
134 | llm_response = await llm_client.generate_response(
135 | prompt_library.extract_nodes.extract_message(context),
136 | response_model=ExtractedEntities,
137 | group_id=episode.group_id,
138 | prompt_name='extract_nodes.extract_message',
139 | )
140 | elif episode.source == EpisodeType.text:
141 | llm_response = await llm_client.generate_response(
142 | prompt_library.extract_nodes.extract_text(context),
143 | response_model=ExtractedEntities,
144 | group_id=episode.group_id,
145 | prompt_name='extract_nodes.extract_text',
146 | )
147 | elif episode.source == EpisodeType.json:
148 | llm_response = await llm_client.generate_response(
149 | prompt_library.extract_nodes.extract_json(context),
150 | response_model=ExtractedEntities,
151 | group_id=episode.group_id,
152 | prompt_name='extract_nodes.extract_json',
153 | )
154 |
155 | response_object = ExtractedEntities(**llm_response)
156 |
157 | extracted_entities: list[ExtractedEntity] = response_object.extracted_entities
158 |
159 | reflexion_iterations += 1
160 | if reflexion_iterations < MAX_REFLEXION_ITERATIONS:
161 | missing_entities = await extract_nodes_reflexion(
162 | llm_client,
163 | episode,
164 | previous_episodes,
165 | [entity.name for entity in extracted_entities],
166 | episode.group_id,
167 | )
168 |
169 | entities_missed = len(missing_entities) != 0
170 |
171 | custom_prompt = 'Make sure that the following entities are extracted: '
172 | for entity in missing_entities:
173 | custom_prompt += f'\n{entity},'
174 |
175 | filtered_extracted_entities = [entity for entity in extracted_entities if entity.name.strip()]
176 | end = time()
177 | logger.debug(f'Extracted new nodes: {filtered_extracted_entities} in {(end - start) * 1000} ms')
178 | # Convert the extracted data into EntityNode objects
179 | extracted_nodes = []
180 | for extracted_entity in filtered_extracted_entities:
181 | type_id = extracted_entity.entity_type_id
182 | if 0 <= type_id < len(entity_types_context):
183 | entity_type_name = entity_types_context[extracted_entity.entity_type_id].get(
184 | 'entity_type_name'
185 | )
186 | else:
187 | entity_type_name = 'Entity'
188 |
189 | # Check if this entity type should be excluded
190 | if excluded_entity_types and entity_type_name in excluded_entity_types:
191 | logger.debug(f'Excluding entity "{extracted_entity.name}" of type "{entity_type_name}"')
192 | continue
193 |
194 | labels: list[str] = list({'Entity', str(entity_type_name)})
195 |
196 | new_node = EntityNode(
197 | name=extracted_entity.name,
198 | group_id=episode.group_id,
199 | labels=labels,
200 | summary='',
201 | created_at=utc_now(),
202 | )
203 | extracted_nodes.append(new_node)
204 | logger.debug(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
205 |
206 | logger.debug(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
207 |
208 | return extracted_nodes
209 |
210 |
211 | async def _collect_candidate_nodes(
212 | clients: GraphitiClients,
213 | extracted_nodes: list[EntityNode],
214 | existing_nodes_override: list[EntityNode] | None,
215 | ) -> list[EntityNode]:
216 | """Search per extracted name and return unique candidates with overrides honored in order."""
217 | search_results: list[SearchResults] = await semaphore_gather(
218 | *[
219 | search(
220 | clients=clients,
221 | query=node.name,
222 | group_ids=[node.group_id],
223 | search_filter=SearchFilters(),
224 | config=NODE_HYBRID_SEARCH_RRF,
225 | )
226 | for node in extracted_nodes
227 | ]
228 | )
229 |
230 | candidate_nodes: list[EntityNode] = [node for result in search_results for node in result.nodes]
231 |
232 | if existing_nodes_override is not None:
233 | candidate_nodes.extend(existing_nodes_override)
234 |
235 | seen_candidate_uuids: set[str] = set()
236 | ordered_candidates: list[EntityNode] = []
237 | for candidate in candidate_nodes:
238 | if candidate.uuid in seen_candidate_uuids:
239 | continue
240 | seen_candidate_uuids.add(candidate.uuid)
241 | ordered_candidates.append(candidate)
242 |
243 | return ordered_candidates
244 |
245 |
246 | async def _resolve_with_llm(
247 | llm_client: LLMClient,
248 | extracted_nodes: list[EntityNode],
249 | indexes: DedupCandidateIndexes,
250 | state: DedupResolutionState,
251 | episode: EpisodicNode | None,
252 | previous_episodes: list[EpisodicNode] | None,
253 | entity_types: dict[str, type[BaseModel]] | None,
254 | ) -> None:
255 | """Escalate unresolved nodes to the dedupe prompt so the LLM can select or reject duplicates.
256 |
257 | The guardrails below defensively ignore malformed or duplicate LLM responses so the
258 | ingestion workflow remains deterministic even when the model misbehaves.
259 | """
260 | if not state.unresolved_indices:
261 | return
262 |
263 | entity_types_dict: dict[str, type[BaseModel]] = entity_types if entity_types is not None else {}
264 |
265 | llm_extracted_nodes = [extracted_nodes[i] for i in state.unresolved_indices]
266 |
267 | extracted_nodes_context = [
268 | {
269 | 'id': i,
270 | 'name': node.name,
271 | 'entity_type': node.labels,
272 | 'entity_type_description': entity_types_dict.get(
273 | next((item for item in node.labels if item != 'Entity'), '')
274 | ).__doc__
275 | or 'Default Entity Type',
276 | }
277 | for i, node in enumerate(llm_extracted_nodes)
278 | ]
279 |
280 | sent_ids = [ctx['id'] for ctx in extracted_nodes_context]
281 | logger.debug(
282 | 'Sending %d entities to LLM for deduplication with IDs 0-%d (actual IDs sent: %s)',
283 | len(llm_extracted_nodes),
284 | len(llm_extracted_nodes) - 1,
285 | sent_ids if len(sent_ids) < 20 else f'{sent_ids[:10]}...{sent_ids[-10:]}',
286 | )
287 | if llm_extracted_nodes:
288 | sample_size = min(3, len(extracted_nodes_context))
289 | logger.debug(
290 | 'First %d entities: %s',
291 | sample_size,
292 | [(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[:sample_size]],
293 | )
294 | if len(extracted_nodes_context) > 3:
295 | logger.debug(
296 | 'Last %d entities: %s',
297 | sample_size,
298 | [(ctx['id'], ctx['name']) for ctx in extracted_nodes_context[-sample_size:]],
299 | )
300 |
301 | existing_nodes_context = [
302 | {
303 | **{
304 | 'idx': i,
305 | 'name': candidate.name,
306 | 'entity_types': candidate.labels,
307 | },
308 | **candidate.attributes,
309 | }
310 | for i, candidate in enumerate(indexes.existing_nodes)
311 | ]
312 |
313 | context = {
314 | 'extracted_nodes': extracted_nodes_context,
315 | 'existing_nodes': existing_nodes_context,
316 | 'episode_content': episode.content if episode is not None else '',
317 | 'previous_episodes': (
318 | [ep.content for ep in previous_episodes] if previous_episodes is not None else []
319 | ),
320 | }
321 |
322 | llm_response = await llm_client.generate_response(
323 | prompt_library.dedupe_nodes.nodes(context),
324 | response_model=NodeResolutions,
325 | prompt_name='dedupe_nodes.nodes',
326 | )
327 |
328 | node_resolutions: list[NodeDuplicate] = NodeResolutions(**llm_response).entity_resolutions
329 |
330 | valid_relative_range = range(len(state.unresolved_indices))
331 | processed_relative_ids: set[int] = set()
332 |
333 | received_ids = {r.id for r in node_resolutions}
334 | expected_ids = set(valid_relative_range)
335 | missing_ids = expected_ids - received_ids
336 | extra_ids = received_ids - expected_ids
337 |
338 | logger.debug(
339 | 'Received %d resolutions for %d entities',
340 | len(node_resolutions),
341 | len(state.unresolved_indices),
342 | )
343 |
344 | if missing_ids:
345 | logger.warning('LLM did not return resolutions for IDs: %s', sorted(missing_ids))
346 |
347 | if extra_ids:
348 | logger.warning(
349 | 'LLM returned invalid IDs outside valid range 0-%d: %s (all returned IDs: %s)',
350 | len(state.unresolved_indices) - 1,
351 | sorted(extra_ids),
352 | sorted(received_ids),
353 | )
354 |
355 | for resolution in node_resolutions:
356 | relative_id: int = resolution.id
357 | duplicate_idx: int = resolution.duplicate_idx
358 |
359 | if relative_id not in valid_relative_range:
360 | logger.warning(
361 | 'Skipping invalid LLM dedupe id %d (valid range: 0-%d, received %d resolutions)',
362 | relative_id,
363 | len(state.unresolved_indices) - 1,
364 | len(node_resolutions),
365 | )
366 | continue
367 |
368 | if relative_id in processed_relative_ids:
369 | logger.warning('Duplicate LLM dedupe id %s received; ignoring.', relative_id)
370 | continue
371 | processed_relative_ids.add(relative_id)
372 |
373 | original_index = state.unresolved_indices[relative_id]
374 | extracted_node = extracted_nodes[original_index]
375 |
376 | resolved_node: EntityNode
377 | if duplicate_idx == -1:
378 | resolved_node = extracted_node
379 | elif 0 <= duplicate_idx < len(indexes.existing_nodes):
380 | resolved_node = indexes.existing_nodes[duplicate_idx]
381 | else:
382 | logger.warning(
383 | 'Invalid duplicate_idx %s for extracted node %s; treating as no duplicate.',
384 | duplicate_idx,
385 | extracted_node.uuid,
386 | )
387 | resolved_node = extracted_node
388 |
389 | state.resolved_nodes[original_index] = resolved_node
390 | state.uuid_map[extracted_node.uuid] = resolved_node.uuid
391 | if resolved_node.uuid != extracted_node.uuid:
392 | state.duplicate_pairs.append((extracted_node, resolved_node))
393 |
394 |
395 | async def resolve_extracted_nodes(
396 | clients: GraphitiClients,
397 | extracted_nodes: list[EntityNode],
398 | episode: EpisodicNode | None = None,
399 | previous_episodes: list[EpisodicNode] | None = None,
400 | entity_types: dict[str, type[BaseModel]] | None = None,
401 | existing_nodes_override: list[EntityNode] | None = None,
402 | ) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
403 | """Search for existing nodes, resolve deterministic matches, then escalate holdouts to the LLM dedupe prompt."""
404 | llm_client = clients.llm_client
405 | driver = clients.driver
406 | existing_nodes = await _collect_candidate_nodes(
407 | clients,
408 | extracted_nodes,
409 | existing_nodes_override,
410 | )
411 |
412 | indexes: DedupCandidateIndexes = _build_candidate_indexes(existing_nodes)
413 |
414 | state = DedupResolutionState(
415 | resolved_nodes=[None] * len(extracted_nodes),
416 | uuid_map={},
417 | unresolved_indices=[],
418 | )
419 |
420 | _resolve_with_similarity(extracted_nodes, indexes, state)
421 |
422 | await _resolve_with_llm(
423 | llm_client,
424 | extracted_nodes,
425 | indexes,
426 | state,
427 | episode,
428 | previous_episodes,
429 | entity_types,
430 | )
431 |
432 | for idx, node in enumerate(extracted_nodes):
433 | if state.resolved_nodes[idx] is None:
434 | state.resolved_nodes[idx] = node
435 | state.uuid_map[node.uuid] = node.uuid
436 |
437 | logger.debug(
438 | 'Resolved nodes: %s',
439 | [(node.name, node.uuid) for node in state.resolved_nodes if node is not None],
440 | )
441 |
442 | new_node_duplicates: list[
443 | tuple[EntityNode, EntityNode]
444 | ] = await filter_existing_duplicate_of_edges(driver, state.duplicate_pairs)
445 |
446 | return (
447 | [node for node in state.resolved_nodes if node is not None],
448 | state.uuid_map,
449 | new_node_duplicates,
450 | )
451 |
452 |
453 | async def extract_attributes_from_nodes(
454 | clients: GraphitiClients,
455 | nodes: list[EntityNode],
456 | episode: EpisodicNode | None = None,
457 | previous_episodes: list[EpisodicNode] | None = None,
458 | entity_types: dict[str, type[BaseModel]] | None = None,
459 | should_summarize_node: NodeSummaryFilter | None = None,
460 | ) -> list[EntityNode]:
461 | llm_client = clients.llm_client
462 | embedder = clients.embedder
463 | updated_nodes: list[EntityNode] = await semaphore_gather(
464 | *[
465 | extract_attributes_from_node(
466 | llm_client,
467 | node,
468 | episode,
469 | previous_episodes,
470 | (
471 | entity_types.get(next((item for item in node.labels if item != 'Entity'), ''))
472 | if entity_types is not None
473 | else None
474 | ),
475 | should_summarize_node,
476 | )
477 | for node in nodes
478 | ]
479 | )
480 |
481 | await create_entity_node_embeddings(embedder, updated_nodes)
482 |
483 | return updated_nodes
484 |
485 |
486 | async def extract_attributes_from_node(
487 | llm_client: LLMClient,
488 | node: EntityNode,
489 | episode: EpisodicNode | None = None,
490 | previous_episodes: list[EpisodicNode] | None = None,
491 | entity_type: type[BaseModel] | None = None,
492 | should_summarize_node: NodeSummaryFilter | None = None,
493 | ) -> EntityNode:
494 | # Extract attributes if entity type is defined and has attributes
495 | llm_response = await _extract_entity_attributes(
496 | llm_client, node, episode, previous_episodes, entity_type
497 | )
498 |
499 | # Extract summary if needed
500 | await _extract_entity_summary(
501 | llm_client, node, episode, previous_episodes, should_summarize_node
502 | )
503 |
504 | node.attributes.update(llm_response)
505 |
506 | return node
507 |
508 |
509 | async def _extract_entity_attributes(
510 | llm_client: LLMClient,
511 | node: EntityNode,
512 | episode: EpisodicNode | None,
513 | previous_episodes: list[EpisodicNode] | None,
514 | entity_type: type[BaseModel] | None,
515 | ) -> dict[str, Any]:
516 | if entity_type is None or len(entity_type.model_fields) == 0:
517 | return {}
518 |
519 | attributes_context = _build_episode_context(
520 | # should not include summary
521 | node_data={
522 | 'name': node.name,
523 | 'entity_types': node.labels,
524 | 'attributes': node.attributes,
525 | },
526 | episode=episode,
527 | previous_episodes=previous_episodes,
528 | )
529 |
530 | llm_response = await llm_client.generate_response(
531 | prompt_library.extract_nodes.extract_attributes(attributes_context),
532 | response_model=entity_type,
533 | model_size=ModelSize.small,
534 | group_id=node.group_id,
535 | prompt_name='extract_nodes.extract_attributes',
536 | )
537 |
538 | # validate response
539 | entity_type(**llm_response)
540 |
541 | return llm_response
542 |
543 |
544 | async def _extract_entity_summary(
545 | llm_client: LLMClient,
546 | node: EntityNode,
547 | episode: EpisodicNode | None,
548 | previous_episodes: list[EpisodicNode] | None,
549 | should_summarize_node: NodeSummaryFilter | None,
550 | ) -> None:
551 | if should_summarize_node is not None and not await should_summarize_node(node):
552 | return
553 |
554 | summary_context = _build_episode_context(
555 | node_data={
556 | 'name': node.name,
557 | 'summary': truncate_at_sentence(node.summary, MAX_SUMMARY_CHARS),
558 | 'entity_types': node.labels,
559 | 'attributes': node.attributes,
560 | },
561 | episode=episode,
562 | previous_episodes=previous_episodes,
563 | )
564 |
565 | summary_response = await llm_client.generate_response(
566 | prompt_library.extract_nodes.extract_summary(summary_context),
567 | response_model=EntitySummary,
568 | model_size=ModelSize.small,
569 | group_id=node.group_id,
570 | prompt_name='extract_nodes.extract_summary',
571 | )
572 |
573 | node.summary = truncate_at_sentence(summary_response.get('summary', ''), MAX_SUMMARY_CHARS)
574 |
575 |
576 | def _build_episode_context(
577 | node_data: dict[str, Any],
578 | episode: EpisodicNode | None,
579 | previous_episodes: list[EpisodicNode] | None,
580 | ) -> dict[str, Any]:
581 | return {
582 | 'node': node_data,
583 | 'episode_content': episode.content if episode is not None else '',
584 | 'previous_episodes': (
585 | [ep.content for ep in previous_episodes] if previous_episodes is not None else []
586 | ),
587 | }
588 |
```
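
The guardrails in `_resolve_with_llm` amount to this: accept each returned resolution id at most once, only when it falls within the range of unresolved indices, and treat an out-of-range `duplicate_idx` as "no duplicate found". A compact, self-contained sketch of that filtering on made-up data (the names and tuples below are illustrative, not the module's actual types):

```python
# Illustrative sketch of the id-validation guardrails in _resolve_with_llm.
# `resolutions` mimics what the dedupe prompt might return: (id, duplicate_idx).

unresolved_indices = [2, 5, 7]            # positions of still-unresolved extracted nodes
existing_candidates = ['cand-a', 'cand-b']

resolutions = [(0, 1), (0, 0), (3, -1), (2, 9)]  # repeated id 0, invalid id 3, bad idx 9

valid_ids = range(len(unresolved_indices))
processed: set[int] = set()
outcome: dict[int, str] = {}              # original index -> resolved target

for rel_id, dup_idx in resolutions:
    if rel_id not in valid_ids or rel_id in processed:
        continue                          # skip out-of-range or repeated ids
    processed.add(rel_id)
    original_index = unresolved_indices[rel_id]
    if 0 <= dup_idx < len(existing_candidates):
        outcome[original_index] = existing_candidates[dup_idx]   # merge with candidate
    else:
        outcome[original_index] = 'keep-extracted'               # -1 or invalid idx

assert outcome == {2: 'cand-b', 7: 'keep-extracted'}
```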