This is page 6 of 12. Use http://codebase.md/getzep/graphiti?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── ISSUE_TEMPLATE
│ │ └── bug_report.md
│ ├── pull_request_template.md
│ ├── secret_scanning.yml
│ └── workflows
│ ├── ai-moderator.yml
│ ├── cla.yml
│ ├── claude-code-review-manual.yml
│ ├── claude-code-review.yml
│ ├── claude.yml
│ ├── codeql.yml
│ ├── daily_issue_maintenance.yml
│ ├── issue-triage.yml
│ ├── lint.yml
│ ├── release-graphiti-core.yml
│ ├── release-mcp-server.yml
│ ├── release-server-container.yml
│ ├── typecheck.yml
│ └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│ ├── azure-openai
│ │ ├── .env.example
│ │ ├── azure_openai_neo4j.py
│ │ └── README.md
│ ├── data
│ │ └── manybirds_products.json
│ ├── ecommerce
│ │ ├── runner.ipynb
│ │ └── runner.py
│ ├── langgraph-agent
│ │ ├── agent.ipynb
│ │ └── tinybirds-jess.png
│ ├── opentelemetry
│ │ ├── .env.example
│ │ ├── otel_stdout_example.py
│ │ ├── pyproject.toml
│ │ ├── README.md
│ │ └── uv.lock
│ ├── podcast
│ │ ├── podcast_runner.py
│ │ ├── podcast_transcript.txt
│ │ └── transcript_parser.py
│ ├── quickstart
│ │ ├── quickstart_falkordb.py
│ │ ├── quickstart_neo4j.py
│ │ ├── quickstart_neptune.py
│ │ ├── README.md
│ │ └── requirements.txt
│ └── wizard_of_oz
│ ├── parser.py
│ ├── runner.py
│ └── woo.txt
├── graphiti_core
│ ├── __init__.py
│ ├── cross_encoder
│ │ ├── __init__.py
│ │ ├── bge_reranker_client.py
│ │ ├── client.py
│ │ ├── gemini_reranker_client.py
│ │ └── openai_reranker_client.py
│ ├── decorators.py
│ ├── driver
│ │ ├── __init__.py
│ │ ├── driver.py
│ │ ├── falkordb_driver.py
│ │ ├── graph_operations
│ │ │ └── graph_operations.py
│ │ ├── kuzu_driver.py
│ │ ├── neo4j_driver.py
│ │ ├── neptune_driver.py
│ │ └── search_interface
│ │ └── search_interface.py
│ ├── edges.py
│ ├── embedder
│ │ ├── __init__.py
│ │ ├── azure_openai.py
│ │ ├── client.py
│ │ ├── gemini.py
│ │ ├── openai.py
│ │ └── voyage.py
│ ├── errors.py
│ ├── graph_queries.py
│ ├── graphiti_types.py
│ ├── graphiti.py
│ ├── helpers.py
│ ├── llm_client
│ │ ├── __init__.py
│ │ ├── anthropic_client.py
│ │ ├── azure_openai_client.py
│ │ ├── client.py
│ │ ├── config.py
│ │ ├── errors.py
│ │ ├── gemini_client.py
│ │ ├── groq_client.py
│ │ ├── openai_base_client.py
│ │ ├── openai_client.py
│ │ ├── openai_generic_client.py
│ │ └── utils.py
│ ├── migrations
│ │ └── __init__.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── edges
│ │ │ ├── __init__.py
│ │ │ └── edge_db_queries.py
│ │ └── nodes
│ │ ├── __init__.py
│ │ └── node_db_queries.py
│ ├── nodes.py
│ ├── prompts
│ │ ├── __init__.py
│ │ ├── dedupe_edges.py
│ │ ├── dedupe_nodes.py
│ │ ├── eval.py
│ │ ├── extract_edge_dates.py
│ │ ├── extract_edges.py
│ │ ├── extract_nodes.py
│ │ ├── invalidate_edges.py
│ │ ├── lib.py
│ │ ├── models.py
│ │ ├── prompt_helpers.py
│ │ ├── snippets.py
│ │ └── summarize_nodes.py
│ ├── py.typed
│ ├── search
│ │ ├── __init__.py
│ │ ├── search_config_recipes.py
│ │ ├── search_config.py
│ │ ├── search_filters.py
│ │ ├── search_helpers.py
│ │ ├── search_utils.py
│ │ └── search.py
│ ├── telemetry
│ │ ├── __init__.py
│ │ └── telemetry.py
│ ├── tracer.py
│ └── utils
│ ├── __init__.py
│ ├── bulk_utils.py
│ ├── datetime_utils.py
│ ├── maintenance
│ │ ├── __init__.py
│ │ ├── community_operations.py
│ │ ├── dedup_helpers.py
│ │ ├── edge_operations.py
│ │ ├── graph_data_operations.py
│ │ ├── node_operations.py
│ │ └── temporal_operations.py
│ ├── ontology_utils
│ │ └── entity_types_utils.py
│ └── text_utils.py
├── images
│ ├── arxiv-screenshot.png
│ ├── graphiti-graph-intro.gif
│ ├── graphiti-intro-slides-stock-2.gif
│ └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│ ├── .env.example
│ ├── .python-version
│ ├── config
│ │ ├── config-docker-falkordb-combined.yaml
│ │ ├── config-docker-falkordb.yaml
│ │ ├── config-docker-neo4j.yaml
│ │ ├── config.yaml
│ │ └── mcp_config_stdio_example.json
│ ├── docker
│ │ ├── build-standalone.sh
│ │ ├── build-with-version.sh
│ │ ├── docker-compose-falkordb.yml
│ │ ├── docker-compose-neo4j.yml
│ │ ├── docker-compose.yml
│ │ ├── Dockerfile
│ │ ├── Dockerfile.standalone
│ │ ├── github-actions-example.yml
│ │ ├── README-falkordb-combined.md
│ │ └── README.md
│ ├── docs
│ │ └── cursor_rules.md
│ ├── main.py
│ ├── pyproject.toml
│ ├── pytest.ini
│ ├── README.md
│ ├── src
│ │ ├── __init__.py
│ │ ├── config
│ │ │ ├── __init__.py
│ │ │ └── schema.py
│ │ ├── graphiti_mcp_server.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ ├── entity_types.py
│ │ │ └── response_types.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── factories.py
│ │ │ └── queue_service.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── formatting.py
│ │ └── utils.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── pytest.ini
│ │ ├── README.md
│ │ ├── run_tests.py
│ │ ├── test_async_operations.py
│ │ ├── test_comprehensive_integration.py
│ │ ├── test_configuration.py
│ │ ├── test_falkordb_integration.py
│ │ ├── test_fixtures.py
│ │ ├── test_http_integration.py
│ │ ├── test_integration.py
│ │ ├── test_mcp_integration.py
│ │ ├── test_mcp_transports.py
│ │ ├── test_stdio_simple.py
│ │ └── test_stress_load.py
│ └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│ ├── .env.example
│ ├── graph_service
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ ├── main.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ └── zep_graphiti.py
│ ├── Makefile
│ ├── pyproject.toml
│ ├── README.md
│ └── uv.lock
├── signatures
│ └── version1
│ └── cla.json
├── tests
│ ├── cross_encoder
│ │ ├── test_bge_reranker_client_int.py
│ │ └── test_gemini_reranker_client.py
│ ├── driver
│ │ ├── __init__.py
│ │ └── test_falkordb_driver.py
│ ├── embedder
│ │ ├── embedder_fixtures.py
│ │ ├── test_gemini.py
│ │ ├── test_openai.py
│ │ └── test_voyage.py
│ ├── evals
│ │ ├── data
│ │ │ └── longmemeval_data
│ │ │ ├── longmemeval_oracle.json
│ │ │ └── README.md
│ │ ├── eval_cli.py
│ │ ├── eval_e2e_graph_building.py
│ │ ├── pytest.ini
│ │ └── utils.py
│ ├── helpers_test.py
│ ├── llm_client
│ │ ├── test_anthropic_client_int.py
│ │ ├── test_anthropic_client.py
│ │ ├── test_azure_openai_client.py
│ │ ├── test_client.py
│ │ ├── test_errors.py
│ │ └── test_gemini_client.py
│ ├── test_edge_int.py
│ ├── test_entity_exclusion_int.py
│ ├── test_graphiti_int.py
│ ├── test_graphiti_mock.py
│ ├── test_node_int.py
│ ├── test_text_utils.py
│ └── utils
│ ├── maintenance
│ │ ├── test_bulk_utils.py
│ │ ├── test_edge_operations.py
│ │ ├── test_node_operations.py
│ │ └── test_temporal_operations_int.py
│ └── search
│ └── search_utils_test.py
├── uv.lock
└── Zep-CLA.md
```
# Files
--------------------------------------------------------------------------------
/signatures/version1/cla.json:
--------------------------------------------------------------------------------
```json
1 | {
2 | "signedContributors": [
3 | {
4 | "name": "colombod",
5 | "id": 375556,
6 | "comment_id": 2761979440,
7 | "created_at": "2025-03-28T17:21:29Z",
8 | "repoId": 840056306,
9 | "pullRequestNo": 310
10 | },
11 | {
12 | "name": "evanmschultz",
13 | "id": 3806601,
14 | "comment_id": 2813673237,
15 | "created_at": "2025-04-17T17:56:24Z",
16 | "repoId": 840056306,
17 | "pullRequestNo": 372
18 | },
19 | {
20 | "name": "soichisumi",
21 | "id": 30210641,
22 | "comment_id": 2818469528,
23 | "created_at": "2025-04-21T14:02:11Z",
24 | "repoId": 840056306,
25 | "pullRequestNo": 382
26 | },
27 | {
28 | "name": "drumnation",
29 | "id": 18486434,
30 | "comment_id": 2822330188,
31 | "created_at": "2025-04-22T19:51:09Z",
32 | "repoId": 840056306,
33 | "pullRequestNo": 389
34 | },
35 | {
36 | "name": "jackaldenryan",
37 | "id": 61809814,
38 | "comment_id": 2845356793,
39 | "created_at": "2025-05-01T17:51:11Z",
40 | "repoId": 840056306,
41 | "pullRequestNo": 429
42 | },
43 | {
44 | "name": "t41372",
45 | "id": 36402030,
46 | "comment_id": 2849035400,
47 | "created_at": "2025-05-04T06:24:37Z",
48 | "repoId": 840056306,
49 | "pullRequestNo": 438
50 | },
51 | {
52 | "name": "markalosey",
53 | "id": 1949914,
54 | "comment_id": 2878173826,
55 | "created_at": "2025-05-13T23:27:16Z",
56 | "repoId": 840056306,
57 | "pullRequestNo": 486
58 | },
59 | {
60 | "name": "adamkatav",
61 | "id": 13109136,
62 | "comment_id": 2887184706,
63 | "created_at": "2025-05-16T16:29:22Z",
64 | "repoId": 840056306,
65 | "pullRequestNo": 493
66 | },
67 | {
68 | "name": "realugbun",
69 | "id": 74101927,
70 | "comment_id": 2899731784,
71 | "created_at": "2025-05-22T02:36:44Z",
72 | "repoId": 840056306,
73 | "pullRequestNo": 513
74 | },
75 | {
76 | "name": "dudizimber",
77 | "id": 16744955,
78 | "comment_id": 2912211548,
79 | "created_at": "2025-05-27T11:45:57Z",
80 | "repoId": 840056306,
81 | "pullRequestNo": 525
82 | },
83 | {
84 | "name": "galshubeli",
85 | "id": 124919062,
86 | "comment_id": 2912289100,
87 | "created_at": "2025-05-27T12:15:03Z",
88 | "repoId": 840056306,
89 | "pullRequestNo": 525
90 | },
91 | {
92 | "name": "TheEpTic",
93 | "id": 326774,
94 | "comment_id": 2917970901,
95 | "created_at": "2025-05-29T01:26:54Z",
96 | "repoId": 840056306,
97 | "pullRequestNo": 541
98 | },
99 | {
100 | "name": "PrettyWood",
101 | "id": 18406791,
102 | "comment_id": 2938495182,
103 | "created_at": "2025-06-04T04:44:59Z",
104 | "repoId": 840056306,
105 | "pullRequestNo": 558
106 | },
107 | {
108 | "name": "denyska",
109 | "id": 1242726,
110 | "comment_id": 2957480685,
111 | "created_at": "2025-06-10T02:08:05Z",
112 | "repoId": 840056306,
113 | "pullRequestNo": 574
114 | },
115 | {
116 | "name": "LongPML",
117 | "id": 59755436,
118 | "comment_id": 2965391879,
119 | "created_at": "2025-06-12T07:10:01Z",
120 | "repoId": 840056306,
121 | "pullRequestNo": 579
122 | },
123 | {
124 | "name": "karn09",
125 | "id": 3743119,
126 | "comment_id": 2973492225,
127 | "created_at": "2025-06-15T04:45:13Z",
128 | "repoId": 840056306,
129 | "pullRequestNo": 584
130 | },
131 | {
132 | "name": "abab-dev",
133 | "id": 146825408,
134 | "comment_id": 2975719469,
135 | "created_at": "2025-06-16T09:12:53Z",
136 | "repoId": 840056306,
137 | "pullRequestNo": 588
138 | },
139 | {
140 | "name": "thorchh",
141 | "id": 75025911,
142 | "comment_id": 2982990164,
143 | "created_at": "2025-06-18T07:19:38Z",
144 | "repoId": 840056306,
145 | "pullRequestNo": 601
146 | },
147 | {
148 | "name": "robrichardson13",
149 | "id": 9492530,
150 | "comment_id": 2989798338,
151 | "created_at": "2025-06-20T04:59:06Z",
152 | "repoId": 840056306,
153 | "pullRequestNo": 611
154 | },
155 | {
156 | "name": "gkorland",
157 | "id": 753206,
158 | "comment_id": 2993690025,
159 | "created_at": "2025-06-21T17:35:37Z",
160 | "repoId": 840056306,
161 | "pullRequestNo": 609
162 | },
163 | {
164 | "name": "urmzd",
165 | "id": 45431570,
166 | "comment_id": 3027098935,
167 | "created_at": "2025-07-02T09:16:46Z",
168 | "repoId": 840056306,
169 | "pullRequestNo": 661
170 | },
171 | {
172 | "name": "jawwadfirdousi",
173 | "id": 10913083,
174 | "comment_id": 3027808026,
175 | "created_at": "2025-07-02T13:02:22Z",
176 | "repoId": 840056306,
177 | "pullRequestNo": 663
178 | },
179 | {
180 | "name": "jamesindeed",
181 | "id": 60527576,
182 | "comment_id": 3028293328,
183 | "created_at": "2025-07-02T15:24:23Z",
184 | "repoId": 840056306,
185 | "pullRequestNo": 664
186 | },
187 | {
188 | "name": "dev-mirzabicer",
189 | "id": 90691873,
190 | "comment_id": 3035836506,
191 | "created_at": "2025-07-04T11:47:08Z",
192 | "repoId": 840056306,
193 | "pullRequestNo": 672
194 | },
195 | {
196 | "name": "zeroasterisk",
197 | "id": 23422,
198 | "comment_id": 3040716245,
199 | "created_at": "2025-07-06T03:41:19Z",
200 | "repoId": 840056306,
201 | "pullRequestNo": 679
202 | },
203 | {
204 | "name": "charlesmcchan",
205 | "id": 425857,
206 | "comment_id": 3066732289,
207 | "created_at": "2025-07-13T08:54:26Z",
208 | "repoId": 840056306,
209 | "pullRequestNo": 711
210 | },
211 | {
212 | "name": "soraxas",
213 | "id": 22362177,
214 | "comment_id": 3084093750,
215 | "created_at": "2025-07-17T13:33:25Z",
216 | "repoId": 840056306,
217 | "pullRequestNo": 741
218 | },
219 | {
220 | "name": "sdht0",
221 | "id": 867424,
222 | "comment_id": 3092540466,
223 | "created_at": "2025-07-19T19:52:21Z",
224 | "repoId": 840056306,
225 | "pullRequestNo": 748
226 | },
227 | {
228 | "name": "Naseem77",
229 | "id": 34807727,
230 | "comment_id": 3093746709,
231 | "created_at": "2025-07-20T07:07:33Z",
232 | "repoId": 840056306,
233 | "pullRequestNo": 742
234 | },
235 | {
236 | "name": "kavenGw",
237 | "id": 3193355,
238 | "comment_id": 3100620568,
239 | "created_at": "2025-07-22T02:58:50Z",
240 | "repoId": 840056306,
241 | "pullRequestNo": 750
242 | },
243 | {
244 | "name": "paveljakov",
245 | "id": 45147436,
246 | "comment_id": 3113955940,
247 | "created_at": "2025-07-24T15:39:36Z",
248 | "repoId": 840056306,
249 | "pullRequestNo": 764
250 | },
251 | {
252 | "name": "gifflet",
253 | "id": 33522742,
254 | "comment_id": 3133869379,
255 | "created_at": "2025-07-29T20:00:27Z",
256 | "repoId": 840056306,
257 | "pullRequestNo": 782
258 | },
259 | {
260 | "name": "bechbd",
261 | "id": 6898505,
262 | "comment_id": 3140501814,
263 | "created_at": "2025-07-31T15:58:08Z",
264 | "repoId": 840056306,
265 | "pullRequestNo": 793
266 | },
267 | {
268 | "name": "hugo-son",
269 | "id": 141999572,
270 | "comment_id": 3155009405,
271 | "created_at": "2025-08-05T12:27:09Z",
272 | "repoId": 840056306,
273 | "pullRequestNo": 805
274 | },
275 | {
276 | "name": "mvanders",
277 | "id": 758617,
278 | "comment_id": 3160523661,
279 | "created_at": "2025-08-06T14:56:21Z",
280 | "repoId": 840056306,
281 | "pullRequestNo": 808
282 | },
283 | {
284 | "name": "v-khanna",
285 | "id": 102773390,
286 | "comment_id": 3162200130,
287 | "created_at": "2025-08-07T02:23:09Z",
288 | "repoId": 840056306,
289 | "pullRequestNo": 812
290 | },
291 | {
292 | "name": "vjeeva",
293 | "id": 13189349,
294 | "comment_id": 3165600173,
295 | "created_at": "2025-08-07T20:24:08Z",
296 | "repoId": 840056306,
297 | "pullRequestNo": 814
298 | },
299 | {
300 | "name": "liebertar",
301 | "id": 99405438,
302 | "comment_id": 3166905812,
303 | "created_at": "2025-08-08T07:52:27Z",
304 | "repoId": 840056306,
305 | "pullRequestNo": 816
306 | },
307 | {
308 | "name": "CaroLe-prw",
309 | "id": 42695882,
310 | "comment_id": 3187949734,
311 | "created_at": "2025-08-14T10:29:25Z",
312 | "repoId": 840056306,
313 | "pullRequestNo": 833
314 | },
315 | {
316 | "name": "Wizmann",
317 | "id": 1270921,
318 | "comment_id": 3196208374,
319 | "created_at": "2025-08-18T11:09:35Z",
320 | "repoId": 840056306,
321 | "pullRequestNo": 842
322 | },
323 | {
324 | "name": "liangyuanpeng",
325 | "id": 28711504,
326 | "comment_id": 3205841804,
327 | "created_at": "2025-08-20T11:35:42Z",
328 | "repoId": 840056306,
329 | "pullRequestNo": 847
330 | },
331 | {
332 | "name": "aktek-yazge",
333 | "id": 218602044,
334 | "comment_id": 3078757968,
335 | "created_at": "2025-07-16T14:00:40Z",
336 | "repoId": 840056306,
337 | "pullRequestNo": 735
338 | },
339 | {
340 | "name": "Shelvak",
341 | "id": 873323,
342 | "comment_id": 3243330690,
343 | "created_at": "2025-09-01T22:26:32Z",
344 | "repoId": 840056306,
345 | "pullRequestNo": 885
346 | },
347 | {
348 | "name": "maskshell",
349 | "id": 5113279,
350 | "comment_id": 3244187860,
351 | "created_at": "2025-09-02T07:48:05Z",
352 | "repoId": 840056306,
353 | "pullRequestNo": 886
354 | },
355 | {
356 | "name": "jeanlucthumm",
357 | "id": 4934853,
358 | "comment_id": 3255120747,
359 | "created_at": "2025-09-04T18:49:57Z",
360 | "repoId": 840056306,
361 | "pullRequestNo": 892
362 | },
363 | {
364 | "name": "Bit-urd",
365 | "id": 43745133,
366 | "comment_id": 3264006888,
367 | "created_at": "2025-09-07T20:01:08Z",
368 | "repoId": 840056306,
369 | "pullRequestNo": 895
370 | },
371 | {
372 | "name": "DavIvek",
373 | "id": 88043717,
374 | "comment_id": 3269895491,
375 | "created_at": "2025-09-09T09:59:47Z",
376 | "repoId": 840056306,
377 | "pullRequestNo": 900
378 | },
379 | {
380 | "name": "gsw945",
381 | "id": 6281968,
382 | "comment_id": 3270396586,
383 | "created_at": "2025-09-09T12:05:27Z",
384 | "repoId": 840056306,
385 | "pullRequestNo": 901
386 | },
387 | {
388 | "name": "luan122",
389 | "id": 5606023,
390 | "comment_id": 3287095238,
391 | "created_at": "2025-09-12T23:14:21Z",
392 | "repoId": 840056306,
393 | "pullRequestNo": 908
394 | },
395 | {
396 | "name": "Brandtweary",
397 | "id": 7968557,
398 | "comment_id": 3314191937,
399 | "created_at": "2025-09-19T23:37:33Z",
400 | "repoId": 840056306,
401 | "pullRequestNo": 916
402 | },
403 | {
404 | "name": "clsferguson",
405 | "id": 48876201,
406 | "comment_id": 3368715688,
407 | "created_at": "2025-10-05T03:30:10Z",
408 | "repoId": 840056306,
409 | "pullRequestNo": 981
410 | },
411 | {
412 | "name": "ngaiyuc",
413 | "id": 69293565,
414 | "comment_id": 3407383300,
415 | "created_at": "2025-10-15T16:45:10Z",
416 | "repoId": 840056306,
417 | "pullRequestNo": 1005
418 | },
419 | {
420 | "name": "0fism",
421 | "id": 63762457,
422 | "comment_id": 3407328042,
423 | "created_at": "2025-10-15T16:29:33Z",
424 | "repoId": 840056306,
425 | "pullRequestNo": 1005
426 | },
427 | {
428 | "name": "dontang97",
429 | "id": 88384441,
430 | "comment_id": 3431443627,
431 | "created_at": "2025-10-22T09:52:01Z",
432 | "repoId": 840056306,
433 | "pullRequestNo": 1020
434 | },
435 | {
436 | "name": "didier-durand",
437 | "id": 2927957,
438 | "comment_id": 3460571645,
439 | "created_at": "2025-10-29T09:31:25Z",
440 | "repoId": 840056306,
441 | "pullRequestNo": 1028
442 | },
443 | {
444 | "name": "anubhavgirdhar1",
445 | "id": 85768253,
446 | "comment_id": 3468525446,
447 | "created_at": "2025-10-30T15:11:58Z",
448 | "repoId": 840056306,
449 | "pullRequestNo": 1035
450 | },
451 | {
452 | "name": "Galleons2029",
453 | "id": 88185941,
454 | "comment_id": 3495884964,
455 | "created_at": "2025-11-06T08:39:46Z",
456 | "repoId": 840056306,
457 | "pullRequestNo": 1053
458 | },
459 | {
460 | "name": "supmo668",
461 | "id": 28805779,
462 | "comment_id": 3550309664,
463 | "created_at": "2025-11-19T01:56:25Z",
464 | "repoId": 840056306,
465 | "pullRequestNo": 1072
466 | },
467 | {
468 | "name": "donbr",
469 | "id": 7340008,
470 | "comment_id": 3568970102,
471 | "created_at": "2025-11-24T05:19:42Z",
472 | "repoId": 840056306,
473 | "pullRequestNo": 1081
474 | },
475 | {
476 | "name": "apetti1920",
477 | "id": 4706645,
478 | "comment_id": 3572726648,
479 | "created_at": "2025-11-24T21:07:34Z",
480 | "repoId": 840056306,
481 | "pullRequestNo": 1084
482 | }
483 | ]
484 | }
```
--------------------------------------------------------------------------------
/tests/test_entity_exclusion_int.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from datetime import datetime, timezone
18 |
19 | import pytest
20 | from pydantic import BaseModel, Field
21 |
22 | from graphiti_core.graphiti import Graphiti
23 | from graphiti_core.helpers import validate_excluded_entity_types
24 | from tests.helpers_test import drivers, get_driver
25 |
26 | pytestmark = pytest.mark.integration
27 | pytest_plugins = ('pytest_asyncio',)
28 |
29 |
30 | # Test entity type definitions
class Person(BaseModel):
    """A human person mentioned in the conversation."""

    # Every attribute is optional: extraction may surface only partial details.
    first_name: str | None = Field(default=None, description='First name of the person')
    last_name: str | None = Field(default=None, description='Last name of the person')
    occupation: str | None = Field(default=None, description='Job or profession of the person')
37 |
38 |
class Organization(BaseModel):
    """A company, institution, or organized group."""

    # Optional descriptors; either may be absent from the extracted entity.
    organization_type: str | None = Field(
        default=None, description='Type of organization (company, NGO, etc.)'
    )
    industry: str | None = Field(
        default=None, description='Industry or sector the organization operates in'
    )
48 |
49 |
class Location(BaseModel):
    """A geographic location, place, or address."""

    # Optional descriptors; either may be absent from the extracted entity.
    location_type: str | None = Field(
        default=None, description='Type of location (city, country, building, etc.)'
    )
    coordinates: str | None = Field(default=None, description='Geographic coordinates if available')
57 |
58 |
@pytest.mark.asyncio
@pytest.mark.parametrize('driver', drivers)
async def test_exclude_default_entity_type(driver):
    """Test excluding the default 'Entity' type while keeping custom types."""
    graphiti = Graphiti(graph_driver=get_driver(driver))

    try:
        await graphiti.build_indices_and_constraints()

        # Register custom types; the default 'Entity' type is excluded below.
        custom_types = {'Person': Person, 'Organization': Organization}

        result = await graphiti.add_episode(
            name='Business Meeting',
            episode_body=(
                'John Smith works at Acme Corporation in New York. The weather is nice today.'
            ),
            source_description='Meeting notes',
            reference_time=datetime.now(timezone.utc),
            entity_types=custom_types,
            excluded_entity_types=['Entity'],  # Exclude default type
            group_id='test_exclude_default',
        )

        # Ingestion should still succeed with the default type excluded.
        assert result is not None

        search_results = await graphiti.search_(
            query='John Smith Acme Corporation', group_ids=['test_exclude_default']
        )

        # Every surviving node keeps the base Entity label, and must also carry
        # one of the registered custom-type labels.
        for node in search_results.nodes:
            assert 'Entity' in node.labels
            assert any(label in ['Person', 'Organization'] for label in node.labels), (
                f'Node {node.name} should have a specific type label, got: {node.labels}'
            )

        await _cleanup_test_nodes(graphiti, 'test_exclude_default')

    finally:
        await graphiti.close()
114 |
115 |
@pytest.mark.asyncio
@pytest.mark.parametrize('driver', drivers)
async def test_exclude_specific_custom_types(driver):
    """Test excluding specific custom entity types while keeping others."""
    graphiti = Graphiti(graph_driver=get_driver(driver))

    try:
        await graphiti.build_indices_and_constraints()

        # Register three custom types; two of them are excluded below.
        all_types = {'Person': Person, 'Organization': Organization, 'Location': Location}

        result = await graphiti.add_episode(
            name='Office Visit',
            episode_body=(
                'Sarah Johnson from Google visited the San Francisco office to discuss the new project.'
            ),
            source_description='Visit report',
            reference_time=datetime.now(timezone.utc),
            entity_types=all_types,
            excluded_entity_types=['Organization', 'Location'],  # Exclude these types
            group_id='test_exclude_custom',
        )

        assert result is not None

        search_results = await graphiti.search_(
            query='Sarah Johnson Google San Francisco', group_ids=['test_exclude_custom']
        )
        nodes = search_results.nodes

        # No node may carry an excluded label; all nodes keep the Entity label.
        for node in nodes:
            assert 'Entity' in node.labels
            assert 'Organization' not in node.labels, (
                f'Found excluded Organization in node: {node.name}'
            )
            assert 'Location' not in node.labels, f'Found excluded Location in node: {node.name}'

        # The non-excluded Person type should still be extracted (Sarah Johnson).
        person_nodes = [n for n in nodes if 'Person' in n.labels]
        assert len(person_nodes) > 0, 'Should have found at least one Person entity'

        await _cleanup_test_nodes(graphiti, 'test_exclude_custom')

    finally:
        await graphiti.close()
177 |
178 |
@pytest.mark.asyncio
@pytest.mark.parametrize('driver', drivers)
async def test_exclude_all_types(driver):
    """Test excluding all entity types (edge case)."""
    graphiti = Graphiti(graph_driver=get_driver(driver))

    try:
        await graphiti.build_indices_and_constraints()

        entity_types = {'Person': Person, 'Organization': Organization}

        # Excluding the default type and every custom type means no entity
        # should be extracted from the episode at all.
        result = await graphiti.add_episode(
            name='No Entities',
            episode_body='This text mentions John and Microsoft but no entities should be created.',
            source_description='Test content',
            reference_time=datetime.now(timezone.utc),
            entity_types=entity_types,
            excluded_entity_types=['Entity', 'Person', 'Organization'],  # Exclude everything
            group_id='test_exclude_all',
        )

        assert result is not None

        search_results = await graphiti.search_(
            query='John Microsoft', group_ids=['test_exclude_all']
        )

        found = search_results.nodes
        assert len(found) == 0, f'Expected no entities, but found: {[n.name for n in found]}'

        await _cleanup_test_nodes(graphiti, 'test_exclude_all')

    finally:
        await graphiti.close()
225 |
226 |
@pytest.mark.asyncio
@pytest.mark.parametrize('driver', drivers)
async def test_exclude_no_types(driver):
    """Test normal behavior when no types are excluded (baseline test)."""
    graphiti = Graphiti(graph_driver=get_driver(driver))

    try:
        await graphiti.build_indices_and_constraints()

        entity_types = {'Person': Person, 'Organization': Organization}

        # Baseline: with no exclusions, both custom types should be extracted.
        result = await graphiti.add_episode(
            name='Normal Behavior',
            episode_body='Alice Smith works at TechCorp.',
            source_description='Normal test',
            reference_time=datetime.now(timezone.utc),
            entity_types=entity_types,
            excluded_entity_types=None,  # No exclusions
            group_id='test_exclude_none',
        )

        assert result is not None

        search_results = await graphiti.search_(
            query='Alice Smith TechCorp', group_ids=['test_exclude_none']
        )
        nodes = search_results.nodes
        assert len(nodes) > 0, 'Should have found some entities'

        # Both a Person (Alice Smith) and an Organization (TechCorp) are expected.
        assert [n for n in nodes if 'Person' in n.labels], 'Should have found Person entities'
        assert [n for n in nodes if 'Organization' in n.labels], (
            'Should have found Organization entities'
        )

        await _cleanup_test_nodes(graphiti, 'test_exclude_none')

    finally:
        await graphiti.close()
277 |
278 |
def test_validation_valid_excluded_types():
    """Test validation function with valid excluded types."""
    entity_types = {'Person': Person, 'Organization': Organization}

    # Each exclusion list names only known types (or nothing at all),
    # so validation must accept every one of them.
    for excluded in (['Entity'], ['Person'], ['Entity', 'Person'], None, []):
        assert validate_excluded_entity_types(excluded, entity_types) is True
292 |
293 |
def test_validation_invalid_excluded_types():
    """Test validation function with invalid excluded types."""
    entity_types = {'Person': Person, 'Organization': Organization}

    # Any unknown name in the exclusion list must raise ValueError, even
    # when mixed with valid type names.
    for excluded in (['InvalidType'], ['Person', 'NonExistentType']):
        with pytest.raises(ValueError, match='Invalid excluded entity types'):
            validate_excluded_entity_types(excluded, entity_types)
307 |
308 |
@pytest.mark.asyncio
@pytest.mark.parametrize('driver', drivers)
async def test_excluded_types_parameter_validation_in_add_episode(driver):
    """Test that add_episode validates excluded_entity_types parameter."""
    graphiti = Graphiti(graph_driver=get_driver(driver))

    try:
        # An exclusion naming an unregistered type must be rejected by
        # add_episode before any ingestion work happens.
        with pytest.raises(ValueError, match='Invalid excluded entity types'):
            await graphiti.add_episode(
                name='Invalid Test',
                episode_body='Test content',
                source_description='Test',
                reference_time=datetime.now(timezone.utc),
                entity_types={'Person': Person},
                excluded_entity_types=['NonExistentType'],
                group_id='test_validation',
            )

    finally:
        await graphiti.close()
337 |
338 |
async def _cleanup_test_nodes(graphiti: Graphiti, group_id: str):
    """Helper function to clean up test nodes."""
    try:
        # A wildcard search pulls back every node saved under this test group.
        results = await graphiti.search_(query='*', group_ids=[group_id])
        for node in results.nodes:
            await node.delete(graphiti.driver)
    except Exception as e:
        # Cleanup is best-effort: report the failure but never fail the test.
        print(f'Warning: Failed to clean up test nodes for group {group_id}: {e}')
352 |
```
--------------------------------------------------------------------------------
/mcp_server/tests/test_integration.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """
3 | HTTP/SSE Integration test for the refactored Graphiti MCP Server.
4 | Tests server functionality when running in SSE (Server-Sent Events) mode over HTTP.
5 | Note: This test requires the server to be running with --transport sse.
6 | """
7 |
8 | import asyncio
9 | import json
10 | import time
11 | from typing import Any
12 |
13 | import httpx
14 |
15 |
class MCPIntegrationTest:
    """Integration test client for Graphiti MCP Server.

    Drives a running MCP server over HTTP (SSE transport) by posting
    JSON-RPC 2.0 `tools/call` messages and checking the responses.
    """

    def __init__(self, base_url: str = 'http://localhost:8000'):
        self.base_url = base_url
        self.client = httpx.AsyncClient(timeout=30.0)
        # Unique per-run group id so repeated runs don't see each other's data.
        self.test_group_id = f'test_group_{int(time.time())}'

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.client.aclose()

    async def call_mcp_tool(self, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any]:
        """Call an MCP tool via the SSE endpoint.

        Returns the JSON-RPC `result` payload on success, or a dict with an
        'error' key on HTTP/transport failure.
        """
        # MCP protocol message structure
        message = {
            'jsonrpc': '2.0',
            'id': int(time.time() * 1000),
            'method': 'tools/call',
            'params': {'name': tool_name, 'arguments': arguments},
        }

        try:
            response = await self.client.post(
                f'{self.base_url}/message',
                json=message,
                headers={'Content-Type': 'application/json'},
            )

            if response.status_code != 200:
                return {'error': f'HTTP {response.status_code}: {response.text}'}

            result = response.json()
            return result.get('result', result)

        except Exception as e:
            return {'error': str(e)}

    async def test_server_status(self) -> bool:
        """Test the get_status resource."""
        print('🔍 Testing server status...')

        try:
            response = await self.client.get(f'{self.base_url}/resources/http://graphiti/status')
            if response.status_code == 200:
                status = response.json()
                print(f'   ✅ Server status: {status.get("status", "unknown")}')
                return status.get('status') == 'ok'
            else:
                print(f'   ❌ Status check failed: HTTP {response.status_code}')
                return False
        except Exception as e:
            print(f'   ❌ Status check failed: {e}')
            return False

    async def test_add_memory(self) -> dict[str, str]:
        """Test adding various types of memory episodes.

        Exercises the three supported episode sources: text, json, message.
        """
        print('📝 Testing add_memory functionality...')

        episode_results = {}

        # Test 1: Add text episode
        print('  Testing text episode...')
        result = await self.call_mcp_tool(
            'add_memory',
            {
                'name': 'Test Company News',
                'episode_body': 'Acme Corp announced a revolutionary new AI product that will transform the industry. The CEO mentioned this is their biggest launch since 2020.',
                'source': 'text',
                'source_description': 'news article',
                'group_id': self.test_group_id,
            },
        )

        if 'error' in result:
            print(f'   ❌ Text episode failed: {result["error"]}')
        else:
            print(f'   ✅ Text episode queued: {result.get("message", "Success")}')
            episode_results['text'] = 'success'

        # Test 2: Add JSON episode
        print('  Testing JSON episode...')
        json_data = {
            'company': {'name': 'TechCorp', 'founded': 2010},
            'products': [
                {'id': 'P001', 'name': 'CloudSync', 'category': 'software'},
                {'id': 'P002', 'name': 'DataMiner', 'category': 'analytics'},
            ],
            'employees': 150,
        }

        result = await self.call_mcp_tool(
            'add_memory',
            {
                'name': 'Company Profile',
                'episode_body': json.dumps(json_data),
                'source': 'json',
                'source_description': 'CRM data',
                'group_id': self.test_group_id,
            },
        )

        if 'error' in result:
            print(f'   ❌ JSON episode failed: {result["error"]}')
        else:
            print(f'   ✅ JSON episode queued: {result.get("message", "Success")}')
            episode_results['json'] = 'success'

        # Test 3: Add message episode
        print('  Testing message episode...')
        result = await self.call_mcp_tool(
            'add_memory',
            {
                'name': 'Customer Support Chat',
                'episode_body': "user: What's your return policy?\nassistant: You can return items within 30 days of purchase with receipt.\nuser: Thanks!",
                'source': 'message',
                'source_description': 'support chat log',
                'group_id': self.test_group_id,
            },
        )

        if 'error' in result:
            print(f'   ❌ Message episode failed: {result["error"]}')
        else:
            print(f'   ✅ Message episode queued: {result.get("message", "Success")}')
            episode_results['message'] = 'success'

        return episode_results

    async def wait_for_processing(self, max_wait: int = 30) -> None:
        """Wait for episode processing to complete.

        Polls `get_episodes` once per second until at least one episode shows
        up for the test group, or `max_wait` seconds elapse.
        """
        print(f'⏳ Waiting up to {max_wait} seconds for episode processing...')

        for i in range(max_wait):
            await asyncio.sleep(1)

            # Check if we have any episodes
            result = await self.call_mcp_tool(
                'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
            )

            # BUG FIX: the previous guard (`not isinstance(result, dict) or 'error'
            # in result`) also skipped successful list responses, making the
            # success branch below unreachable. Only skip explicit error dicts.
            if isinstance(result, dict) and 'error' in result:
                continue

            if isinstance(result, list) and len(result) > 0:
                print(f'   ✅ Found {len(result)} processed episodes after {i + 1} seconds')
                return

        print(f'   ⚠️ Still waiting after {max_wait} seconds...')

    async def test_search_functions(self) -> dict[str, bool]:
        """Test search functionality."""
        print('🔍 Testing search functions...')

        results = {}

        # Test search_memory_nodes
        print('  Testing search_memory_nodes...')
        result = await self.call_mcp_tool(
            'search_memory_nodes',
            {
                'query': 'Acme Corp product launch',
                'group_ids': [self.test_group_id],
                'max_nodes': 5,
            },
        )

        if 'error' in result:
            print(f'   ❌ Node search failed: {result["error"]}')
            results['nodes'] = False
        else:
            nodes = result.get('nodes', [])
            print(f'   ✅ Node search returned {len(nodes)} nodes')
            results['nodes'] = True

        # Test search_memory_facts
        print('  Testing search_memory_facts...')
        result = await self.call_mcp_tool(
            'search_memory_facts',
            {
                'query': 'company products software',
                'group_ids': [self.test_group_id],
                'max_facts': 5,
            },
        )

        if 'error' in result:
            print(f'   ❌ Fact search failed: {result["error"]}')
            results['facts'] = False
        else:
            facts = result.get('facts', [])
            print(f'   ✅ Fact search returned {len(facts)} facts')
            results['facts'] = True

        return results

    async def test_episode_retrieval(self) -> bool:
        """Test episode retrieval."""
        print('📚 Testing episode retrieval...')

        result = await self.call_mcp_tool(
            'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
        )

        if 'error' in result:
            print(f'   ❌ Episode retrieval failed: {result["error"]}')
            return False

        if isinstance(result, list):
            print(f'   ✅ Retrieved {len(result)} episodes')

            # Print episode details
            for i, episode in enumerate(result[:3]):  # Show first 3
                name = episode.get('name', 'Unknown')
                source = episode.get('source', 'unknown')
                print(f'      Episode {i + 1}: {name} (source: {source})')

            return len(result) > 0
        else:
            print(f'   ❌ Unexpected result format: {type(result)}')
            return False

    async def test_edge_cases(self) -> dict[str, bool]:
        """Test edge cases and error handling."""
        print('🧪 Testing edge cases...')

        results = {}

        # Test with invalid group_id
        print('  Testing invalid group_id...')
        result = await self.call_mcp_tool(
            'search_memory_nodes',
            {'query': 'nonexistent data', 'group_ids': ['nonexistent_group'], 'max_nodes': 5},
        )

        # Should not error, just return empty results
        if 'error' not in result:
            nodes = result.get('nodes', [])
            print(f'   ✅ Invalid group_id handled gracefully (returned {len(nodes)} nodes)')
            results['invalid_group'] = True
        else:
            print(f'   ❌ Invalid group_id caused error: {result["error"]}')
            results['invalid_group'] = False

        # Test empty query
        print('  Testing empty query...')
        result = await self.call_mcp_tool(
            'search_memory_nodes', {'query': '', 'group_ids': [self.test_group_id], 'max_nodes': 5}
        )

        if 'error' not in result:
            print('   ✅ Empty query handled gracefully')
            results['empty_query'] = True
        else:
            print(f'   ❌ Empty query caused error: {result["error"]}')
            results['empty_query'] = False

        return results

    async def run_full_test_suite(self) -> dict[str, Any]:
        """Run the complete integration test suite."""
        print('🚀 Starting Graphiti MCP Server Integration Test')
        print(f'   Test group ID: {self.test_group_id}')
        print('=' * 60)

        results = {
            'server_status': False,
            'add_memory': {},
            'search': {},
            'episodes': False,
            'edge_cases': {},
            'overall_success': False,
        }

        # Test 1: Server Status
        results['server_status'] = await self.test_server_status()
        if not results['server_status']:
            print('❌ Server not responding, aborting tests')
            return results

        print()

        # Test 2: Add Memory
        results['add_memory'] = await self.test_add_memory()
        print()

        # Test 3: Wait for processing
        await self.wait_for_processing()
        print()

        # Test 4: Search Functions
        results['search'] = await self.test_search_functions()
        print()

        # Test 5: Episode Retrieval
        results['episodes'] = await self.test_episode_retrieval()
        print()

        # Test 6: Edge Cases
        results['edge_cases'] = await self.test_edge_cases()
        print()

        # Calculate overall success
        memory_success = len(results['add_memory']) > 0
        search_success = any(results['search'].values())
        edge_case_success = any(results['edge_cases'].values())

        results['overall_success'] = (
            results['server_status']
            and memory_success
            and results['episodes']
            and (search_success or edge_case_success)  # At least some functionality working
        )

        # Print summary
        print('=' * 60)
        print('📊 TEST SUMMARY')
        print(f'   Server Status: {"✅" if results["server_status"] else "❌"}')
        print(
            f'   Memory Operations: {"✅" if memory_success else "❌"} ({len(results["add_memory"])} types)'
        )
        print(f'   Search Functions: {"✅" if search_success else "❌"}')
        print(f'   Episode Retrieval: {"✅" if results["episodes"] else "❌"}')
        print(f'   Edge Cases: {"✅" if edge_case_success else "❌"}')
        print()
        print(f'🎯 OVERALL: {"✅ SUCCESS" if results["overall_success"] else "❌ FAILED"}')

        if results['overall_success']:
            print('   The refactored MCP server is working correctly!')
        else:
            print('   Some issues detected. Check individual test results above.')

        return results
351 |
352 |
async def main():
    """Run the integration test."""
    async with MCPIntegrationTest() as suite:
        outcome = await suite.run_full_test_suite()

        # Exit 0 on overall success, 1 otherwise (still inside the context
        # manager, matching the original control flow).
        if outcome['overall_success']:
            exit(0)
        else:
            exit(1)


if __name__ == '__main__':
    asyncio.run(main())
365 |
```
--------------------------------------------------------------------------------
/graphiti_core/driver/falkordb_driver.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import asyncio
18 | import datetime
19 | import logging
20 | from typing import TYPE_CHECKING, Any
21 |
22 | if TYPE_CHECKING:
23 | from falkordb import Graph as FalkorGraph
24 | from falkordb.asyncio import FalkorDB
25 | else:
26 | try:
27 | from falkordb import Graph as FalkorGraph
28 | from falkordb.asyncio import FalkorDB
29 | except ImportError:
30 | # If falkordb is not installed, raise an ImportError
31 | raise ImportError(
32 | 'falkordb is required for FalkorDriver. '
33 | 'Install it with: pip install graphiti-core[falkordb]'
34 | ) from None
35 |
36 | from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
37 | from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
38 | from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
39 |
40 | logger = logging.getLogger(__name__)
41 |
# English stopwords stripped from fulltext queries before they are sent to
# FalkorDB (see `FalkorDriver.build_fulltext_query`); matching is
# case-insensitive there. List resembles the common RediSearch stopword set —
# TODO confirm it stays in sync with the server-side configuration.
STOPWORDS = [
    'a',
    'is',
    'the',
    'an',
    'and',
    'are',
    'as',
    'at',
    'be',
    'but',
    'by',
    'for',
    'if',
    'in',
    'into',
    'it',
    'no',
    'not',
    'of',
    'on',
    'or',
    'such',
    'that',
    'their',
    'then',
    'there',
    'these',
    'they',
    'this',
    'to',
    'was',
    'will',
    'with',
]
77 |
78 |
class FalkorDriverSession(GraphDriverSession):
    """Async session wrapper around a single FalkorDB graph handle."""

    provider = GraphProvider.FALKORDB

    def __init__(self, graph: FalkorGraph):
        self.graph = graph

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        # FalkorDB sessions hold no resources; nothing to release here,
        # but the method must exist to satisfy the context-manager protocol.
        pass

    async def close(self):
        # No explicit close needed for FalkorDB, but method must exist.
        pass

    async def execute_write(self, func, *args, **kwargs):
        # The session itself serves as the transaction object handed to `func`.
        return await func(self, *args, **kwargs)

    async def run(self, query: str | list, **kwargs: Any) -> Any:
        """Execute one query, or a list of (cypher, params) pairs.

        FalkorDB does not support a parameter for label sets, so callers may
        pass the work as multiple pre-built queries instead of one.
        """
        if isinstance(query, list):
            for cypher, cypher_params in query:
                coerced = convert_datetimes_to_strings(cypher_params)
                await self.graph.query(str(cypher), coerced)  # type: ignore[reportUnknownArgumentType]
        else:
            coerced = convert_datetimes_to_strings(dict(kwargs))
            await self.graph.query(str(query), coerced)  # type: ignore[reportUnknownArgumentType]
        # Results are intentionally discarded; callers use execute_query for reads.
        return None
112 |
113 |
class FalkorDriver(GraphDriver):
    """Graphiti driver for FalkorDB, a multi-tenant Redis-based graph database."""

    provider = GraphProvider.FALKORDB
    default_group_id: str = '\\_'
    fulltext_syntax: str = '@'  # FalkorDB uses a redisearch-like syntax for fulltext queries
    aoss_client: None = None

    def __init__(
        self,
        host: str = 'localhost',
        port: int = 6379,
        username: str | None = None,
        password: str | None = None,
        falkor_db: FalkorDB | None = None,
        database: str = 'default_db',
    ):
        """
        Initialize the FalkorDB driver.

        FalkorDB is a multi-tenant graph database.
        To connect, provide the host and port.
        The default parameters assume a local (on-premises) FalkorDB instance.

        Args:
            host (str): The host where FalkorDB is running.
            port (int): The port on which FalkorDB is listening.
            username (str | None): The username for authentication (if required).
            password (str | None): The password for authentication (if required).
            falkor_db (FalkorDB | None): An existing FalkorDB instance to use instead of creating a new one.
            database (str): The name of the database to connect to. Defaults to 'default_db'.
        """
        super().__init__()
        self._database = database
        if falkor_db is not None:
            # If a FalkorDB instance is provided, use it directly
            self.client = falkor_db
        else:
            self.client = FalkorDB(host=host, port=port, username=username, password=password)

        # Schedule the indices and constraints to be built
        try:
            # Try to get the current event loop
            loop = asyncio.get_running_loop()
            # Schedule the build_indices_and_constraints to run
            loop.create_task(self.build_indices_and_constraints())
        except RuntimeError:
            # No event loop running, this will be handled later
            pass

    def _get_graph(self, graph_name: str | None) -> FalkorGraph:
        # FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db"
        if graph_name is None:
            graph_name = self._database
        return self.client.select_graph(graph_name)

    async def execute_query(self, cypher_query_, **kwargs: Any):
        """Run a Cypher query and return (records, header, None), or None on
        an 'already indexed' error."""
        graph = self._get_graph(self._database)

        # Convert datetime objects to ISO strings (FalkorDB does not support datetime objects directly)
        params = convert_datetimes_to_strings(dict(kwargs))

        try:
            result = await graph.query(cypher_query_, params)  # type: ignore[reportUnknownArgumentType]
        except Exception as e:
            if 'already indexed' in str(e):
                # Index creation is idempotent from our perspective; treat as a no-op.
                logger.info(f'Index already exists: {e}')
                return None
            logger.error(f'Error executing FalkorDB query: {e}\n{cypher_query_}\n{params}')
            raise

        # Convert the result header to a list of strings
        header = [h[1] for h in result.header]

        # Convert FalkorDB's result format (list of lists) to the format expected by Graphiti (list of dicts)
        records = []
        for row in result.result_set:
            record = {}
            for i, field_name in enumerate(header):
                if i < len(row):
                    record[field_name] = row[i]
                else:
                    # If there are more fields in header than values in row, set to None
                    record[field_name] = None
            records.append(record)

        return records, header, None

    def session(self, database: str | None = None) -> GraphDriverSession:
        """Open a session bound to the given database (defaults to this driver's)."""
        return FalkorDriverSession(self._get_graph(database))

    async def close(self) -> None:
        """Close the driver connection."""
        if hasattr(self.client, 'aclose'):
            await self.client.aclose()  # type: ignore[reportUnknownMemberType]
        elif hasattr(self.client.connection, 'aclose'):
            await self.client.connection.aclose()
        elif hasattr(self.client.connection, 'close'):
            await self.client.connection.close()

    async def delete_all_indexes(self) -> None:
        """Drop every RANGE and FULLTEXT index reported by `db.indexes()`."""
        result = await self.execute_query('CALL db.indexes()')
        if not result:
            return

        records, _, _ = result
        drop_tasks = []

        for record in records:
            label = record['label']
            entity_type = record['entitytype']

            for field_name, index_type in record['types'].items():
                if 'RANGE' in index_type:
                    drop_tasks.append(self.execute_query(f'DROP INDEX ON :{label}({field_name})'))
                elif 'FULLTEXT' in index_type:
                    if entity_type == 'NODE':
                        drop_tasks.append(
                            self.execute_query(
                                f'DROP FULLTEXT INDEX FOR (n:{label}) ON (n.{field_name})'
                            )
                        )
                    elif entity_type == 'RELATIONSHIP':
                        drop_tasks.append(
                            self.execute_query(
                                f'DROP FULLTEXT INDEX FOR ()-[e:{label}]-() ON (e.{field_name})'
                            )
                        )

        if drop_tasks:
            await asyncio.gather(*drop_tasks)

    async def build_indices_and_constraints(self, delete_existing=False):
        """Create range and fulltext indices; optionally dropping existing ones first."""
        if delete_existing:
            await self.delete_all_indexes()
        index_queries = get_range_indices(self.provider) + get_fulltext_indices(self.provider)
        for query in index_queries:
            await self.execute_query(query)

    def clone(self, database: str) -> 'GraphDriver':
        """
        Returns a shallow copy of this driver with a different default database.
        Reuses the same connection (e.g. FalkorDB, Neo4j).
        """
        if database == self._database:
            cloned = self
        elif database == self.default_group_id:
            cloned = FalkorDriver(falkor_db=self.client)
        else:
            # Create a new instance of FalkorDriver with the same connection but a different database
            cloned = FalkorDriver(falkor_db=self.client, database=database)

        return cloned

    async def health_check(self) -> None:
        """Check FalkorDB connectivity by running a simple query."""
        try:
            await self.execute_query('MATCH (n) RETURN 1 LIMIT 1')
            return None
        except Exception as e:
            # Use the module logger instead of print for consistency with the
            # rest of this module.
            logger.error(f'FalkorDB health check failed: {e}')
            raise

    @staticmethod
    def convert_datetimes_to_strings(obj):
        """Recursively convert datetime objects in dicts/lists/tuples to ISO strings."""
        if isinstance(obj, dict):
            return {k: FalkorDriver.convert_datetimes_to_strings(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [FalkorDriver.convert_datetimes_to_strings(item) for item in obj]
        elif isinstance(obj, tuple):
            return tuple(FalkorDriver.convert_datetimes_to_strings(item) for item in obj)
        elif isinstance(obj, datetime.datetime):
            # BUG FIX: this module does `import datetime`, so the bare name is
            # the *module*; `isinstance(obj, datetime)` raised TypeError.
            return obj.isoformat()
        else:
            return obj

    def sanitize(self, query: str) -> str:
        """
        Replace FalkorDB special characters with whitespace.
        Based on FalkorDB tokenization rules: ,.<>{}[]"':;!@#$%^&*()-+=~
        """
        # FalkorDB separator characters that break text into tokens
        separator_map = str.maketrans(
            {
                ',': ' ',
                '.': ' ',
                '<': ' ',
                '>': ' ',
                '{': ' ',
                '}': ' ',
                '[': ' ',
                ']': ' ',
                '"': ' ',
                "'": ' ',
                ':': ' ',
                ';': ' ',
                '!': ' ',
                '@': ' ',
                '#': ' ',
                '$': ' ',
                '%': ' ',
                '^': ' ',
                '&': ' ',
                '*': ' ',
                '(': ' ',
                ')': ' ',
                '-': ' ',
                '+': ' ',
                '=': ' ',
                '~': ' ',
                '?': ' ',
            }
        )
        sanitized = query.translate(separator_map)
        # Clean up multiple spaces
        sanitized = ' '.join(sanitized.split())
        return sanitized

    def build_fulltext_query(
        self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
    ) -> str:
        """
        Build a fulltext query string for FalkorDB using RedisSearch syntax.
        FalkorDB uses RedisSearch-like syntax where:
        - Field queries use @ prefix: @field:value
        - Multiple values for same field: (@field:value1|value2)
        - Text search doesn't need @ prefix for content fields
        - AND is implicit with space: (@group_id:value) (text)
        - OR uses pipe within parentheses: (@group_id:value1|value2)
        """
        if group_ids is None or len(group_ids) == 0:
            group_filter = ''
        else:
            group_values = '|'.join(group_ids)
            group_filter = f'(@group_id:{group_values})'

        sanitized_query = self.sanitize(query)

        # Remove stopwords from the sanitized query
        query_words = sanitized_query.split()
        filtered_words = [word for word in query_words if word.lower() not in STOPWORDS]
        sanitized_query = ' | '.join(filtered_words)

        # If the query is too long return no query
        # (clarified: `group_ids or []` — the original `or ''` produced the
        # same length but read as a string operation)
        if len(sanitized_query.split(' ')) + len(group_ids or []) >= max_query_length:
            return ''

        full_query = group_filter + ' (' + sanitized_query + ')'

        return full_query
363 |
```
--------------------------------------------------------------------------------
/graphiti_core/models/nodes/node_db_queries.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from typing import Any
18 |
19 | from graphiti_core.driver.driver import GraphProvider
20 |
21 |
def get_episode_node_save_query(provider: GraphProvider) -> str:
    """Return the provider-specific Cypher for upserting a single Episodic node."""
    if provider == GraphProvider.NEPTUNE:
        # Neptune cannot store lists of ids; entity_edges is '|'-joined into a string.
        return """
                MERGE (n:Episodic {uuid: $uuid})
                SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
                entity_edges: join([x IN coalesce($entity_edges, []) | toString(x) ], '|'), created_at: $created_at, valid_at: $valid_at}
                RETURN n.uuid AS uuid
            """
    if provider == GraphProvider.KUZU:
        # Kuzu requires property-by-property SET rather than a map assignment.
        return """
                MERGE (n:Episodic {uuid: $uuid})
                SET
                    n.name = $name,
                    n.group_id = $group_id,
                    n.created_at = $created_at,
                    n.source = $source,
                    n.source_description = $source_description,
                    n.content = $content,
                    n.valid_at = $valid_at,
                    n.entity_edges = $entity_edges
                RETURN n.uuid AS uuid
            """
    if provider == GraphProvider.FALKORDB:
        return """
                MERGE (n:Episodic {uuid: $uuid})
                SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
                entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
                RETURN n.uuid AS uuid
            """
    # Neo4j (default)
    return """
            MERGE (n:Episodic {uuid: $uuid})
            SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
            entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
            RETURN n.uuid AS uuid
        """
59 |
60 |
def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
    """Return the provider-specific Cypher for bulk-upserting Episodic nodes."""
    if provider == GraphProvider.NEPTUNE:
        return """
                UNWIND $episodes AS episode
                MERGE (n:Episodic {uuid: episode.uuid})
                SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
                source: episode.source, content: episode.content,
                entity_edges: join([x IN coalesce(episode.entity_edges, []) | toString(x) ], '|'), created_at: episode.created_at, valid_at: episode.valid_at}
                RETURN n.uuid AS uuid
            """
    if provider == GraphProvider.KUZU:
        # NOTE(review): this variant takes single-episode parameters (no UNWIND) —
        # presumably the caller executes it once per episode; confirm at call site.
        return """
                MERGE (n:Episodic {uuid: $uuid})
                SET
                    n.name = $name,
                    n.group_id = $group_id,
                    n.created_at = $created_at,
                    n.source = $source,
                    n.source_description = $source_description,
                    n.content = $content,
                    n.valid_at = $valid_at,
                    n.entity_edges = $entity_edges
                RETURN n.uuid AS uuid
            """
    if provider == GraphProvider.FALKORDB:
        return """
                UNWIND $episodes AS episode
                MERGE (n:Episodic {uuid: episode.uuid})
                SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content,
                entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
                RETURN n.uuid AS uuid
            """
    # Neo4j (default)
    return """
            UNWIND $episodes AS episode
            MERGE (n:Episodic {uuid: episode.uuid})
            SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content,
            entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
            RETURN n.uuid AS uuid
        """
102 |
103 |
# Shared RETURN fragment for Episodic nodes (alias `e`); used by all providers
# except Neptune.
EPISODIC_NODE_RETURN = """
    e.uuid AS uuid,
    e.name AS name,
    e.group_id AS group_id,
    e.created_at AS created_at,
    e.source AS source,
    e.source_description AS source_description,
    e.content AS content,
    e.valid_at AS valid_at,
    e.entity_edges AS entity_edges
"""

# Neptune stores entity_edges as a single delimited string (see the Neptune
# branches of the save queries above, which build it with join(..., '|')),
# so it must be split back into a list on read.
# BUG FIX: the split delimiter was "," while the save queries join with '|',
# which broke the round-trip of entity_edges on Neptune.
EPISODIC_NODE_RETURN_NEPTUNE = """
    e.content AS content,
    e.created_at AS created_at,
    e.valid_at AS valid_at,
    e.uuid AS uuid,
    e.name AS name,
    e.group_id AS group_id,
    e.source_description AS source_description,
    e.source AS source,
    split(e.entity_edges, "|") AS entity_edges
"""
127 |
128 |
def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: bool = False) -> str:
    """Return the provider-specific Cypher for upserting one Entity node.

    `labels` is a ':'-separated label string interpolated into the query;
    `has_aoss` suppresses the Neo4j vector-property write when embeddings
    live in OpenSearch instead.
    """
    if provider == GraphProvider.FALKORDB:
        return f"""
                MERGE (n:Entity {{uuid: $entity_data.uuid}})
                SET n:{labels}
                SET n = $entity_data
                SET n.name_embedding = vecf32($entity_data.name_embedding)
                RETURN n.uuid AS uuid
            """
    if provider == GraphProvider.KUZU:
        return """
                MERGE (n:Entity {uuid: $uuid})
                SET
                    n.name = $name,
                    n.group_id = $group_id,
                    n.labels = $labels,
                    n.created_at = $created_at,
                    n.name_embedding = $name_embedding,
                    n.summary = $summary,
                    n.attributes = $attributes
                WITH n
                RETURN n.uuid AS uuid
            """
    if provider == GraphProvider.NEPTUNE:
        # Neptune cannot set multiple labels in one clause; emit one SET per label.
        label_subquery = ''.join(f' SET n:{label}\n' for label in labels.split(':'))
        return f"""
                MERGE (n:Entity {{uuid: $entity_data.uuid}})
                {label_subquery}
                SET n = removeKeyFromMap(removeKeyFromMap($entity_data, "labels"), "name_embedding")
                SET n.name_embedding = join([x IN coalesce($entity_data.name_embedding, []) | toString(x) ], ",")
                RETURN n.uuid AS uuid
            """
    # Neo4j (default): embeddings are written via the vector-property procedure
    # unless OpenSearch (AOSS) holds them.
    save_embedding_query = (
        ''
        if has_aoss
        else 'WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)'
    )
    return (
        f"""
            MERGE (n:Entity {{uuid: $entity_data.uuid}})
            SET n:{labels}
            SET n = $entity_data
        """
        + save_embedding_query
        + """
            RETURN n.uuid AS uuid
        """
    )
181 |
182 |
def get_entity_node_save_bulk_query(
    provider: GraphProvider, nodes: list[dict], has_aoss: bool = False
) -> str | Any:
    """Return the provider-specific bulk-upsert for Entity nodes.

    FalkorDB and Neptune return a list of per-node queries (labels cannot be
    parameterized there); Kuzu and Neo4j return a single query string.
    """
    if provider == GraphProvider.FALKORDB:
        # One (query, params) pair per node per label.
        return [
            (
                f"""
                    UNWIND $nodes AS node
                    MERGE (n:Entity {{uuid: node.uuid}})
                    SET n:{label}
                    SET n = node
                    WITH n, node
                    SET n.name_embedding = vecf32(node.name_embedding)
                    RETURN n.uuid AS uuid
                """,
                {'nodes': [node]},
            )
            for node in nodes
            for label in node['labels']
        ]
    if provider == GraphProvider.NEPTUNE:
        neptune_queries = []
        for node in nodes:
            label_sets = ''.join(f' SET n:{label}\n' for label in node['labels'])
            neptune_queries.append(
                f"""
                    UNWIND $nodes AS node
                    MERGE (n:Entity {{uuid: node.uuid}})
                    {label_sets}
                    SET n = removeKeyFromMap(removeKeyFromMap(node, "labels"), "name_embedding")
                    SET n.name_embedding = join([x IN coalesce(node.name_embedding, []) | toString(x) ], ",")
                    RETURN n.uuid AS uuid
                """
            )
        return neptune_queries
    if provider == GraphProvider.KUZU:
        return """
                MERGE (n:Entity {uuid: $uuid})
                SET
                    n.name = $name,
                    n.group_id = $group_id,
                    n.labels = $labels,
                    n.created_at = $created_at,
                    n.name_embedding = $name_embedding,
                    n.summary = $summary,
                    n.attributes = $attributes
                RETURN n.uuid AS uuid
            """
    # Neo4j (default)
    save_embedding_query = (
        ''
        if has_aoss
        else 'WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)'
    )
    return (
        """
            UNWIND $nodes AS node
            MERGE (n:Entity {uuid: node.uuid})
            SET n:$(node.labels)
            SET n = node
        """
        + save_embedding_query
        + """
            RETURN n.uuid AS uuid
        """
    )
254 |
255 |
def get_entity_node_return_query(provider: GraphProvider) -> str:
    """Return the RETURN fragment for Entity nodes (alias `n`).

    `name_embedding` is not returned by default and must be loaded manually
    using `load_name_embedding()`.
    """
    if provider != GraphProvider.KUZU:
        # Providers other than Kuzu expose labels/attributes via graph functions.
        return """
        n.uuid AS uuid,
        n.name AS name,
        n.group_id AS group_id,
        n.created_at AS created_at,
        n.summary AS summary,
        labels(n) AS labels,
        properties(n) AS attributes
    """

    # Kuzu stores labels and attributes as plain node properties.
    return """
            n.uuid AS uuid,
            n.name AS name,
            n.group_id AS group_id,
            n.labels AS labels,
            n.created_at AS created_at,
            n.summary AS summary,
            n.attributes AS attributes
        """
278 |
279 |
def get_community_node_save_query(provider: GraphProvider) -> str:
    """Return the provider-specific Cypher used to upsert a Community node.

    Every variant MERGEs on `uuid` and returns it; they differ only in how
    the `name_embedding` vector must be stored on each backend.
    """
    if provider == GraphProvider.FALKORDB:
        # FalkorDB stores the embedding through its vecf32() vector constructor.
        return """
                MERGE (n:Community {uuid: $uuid})
                SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at, name_embedding: vecf32($name_embedding)}
                RETURN n.uuid AS uuid
            """
    if provider == GraphProvider.NEPTUNE:
        # Neptune persists the embedding as a comma-separated string.
        return """
                MERGE (n:Community {uuid: $uuid})
                SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
                SET n.name_embedding = join([x IN coalesce($name_embedding, []) | toString(x) ], ",")
                RETURN n.uuid AS uuid
            """
    if provider == GraphProvider.KUZU:
        # Kuzu sets each property explicitly; the embedding is a native list.
        return """
                MERGE (n:Community {uuid: $uuid})
                SET
                    n.name = $name,
                    n.group_id = $group_id,
                    n.created_at = $created_at,
                    n.name_embedding = $name_embedding,
                    n.summary = $summary
                RETURN n.uuid AS uuid
            """
    # Default: Neo4j, which writes the vector via a dedicated procedure.
    return """
                MERGE (n:Community {uuid: $uuid})
                SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
                WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
                RETURN n.uuid AS uuid
            """
313 |
314 |
# Shared RETURN projection for Community nodes (node alias `c`). The raw
# `name_embedding` property is returned as stored.
COMMUNITY_NODE_RETURN = """
    c.uuid AS uuid,
    c.name AS name,
    c.group_id AS group_id,
    c.created_at AS created_at,
    c.name_embedding AS name_embedding,
    c.summary AS summary
"""

# Neptune variant (node alias `n`): the embedding is stored as a
# comma-separated string, so it is split and converted back to floats on read.
COMMUNITY_NODE_RETURN_NEPTUNE = """
    n.uuid AS uuid,
    n.name AS name,
    [x IN split(n.name_embedding, ",") | toFloat(x)] AS name_embedding,
    n.group_id AS group_id,
    n.summary AS summary,
    n.created_at AS created_at
"""
332 |
```
--------------------------------------------------------------------------------
/tests/cross_encoder/test_gemini_reranker_client.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | # Running tests: pytest -xvs tests/cross_encoder/test_gemini_reranker_client.py
18 |
19 | from unittest.mock import AsyncMock, MagicMock, patch
20 |
21 | import pytest
22 |
23 | from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient
24 | from graphiti_core.llm_client import LLMConfig, RateLimitError
25 |
26 |
@pytest.fixture
def mock_gemini_client():
    """Yield a patched Google Gemini client with an async generate_content stub."""
    with patch('google.genai.Client') as patched_cls:
        # Build the instance the SDK constructor would hand back, wiring the
        # aio.models.generate_content chain so it can be awaited in tests.
        instance = patched_cls.return_value
        aio = MagicMock()
        aio.models = MagicMock()
        aio.models.generate_content = AsyncMock()
        instance.aio = aio
        yield instance
37 |
38 |
@pytest.fixture
def gemini_reranker_client(mock_gemini_client):
    """Build a GeminiRerankerClient whose SDK client is the mocked one."""
    llm_config = LLMConfig(api_key='test_api_key', model='test-model')
    reranker = GeminiRerankerClient(config=llm_config)
    # Swap in the mocked Gemini client so no real API calls can happen.
    reranker.client = mock_gemini_client
    return reranker
47 |
48 |
def create_mock_response(score_text: str) -> MagicMock:
    """Build a stand-in for a Gemini response whose `.text` equals `score_text`."""
    response = MagicMock()
    response.text = score_text
    return response
54 |
55 |
class TestGeminiRerankerClientInitialization:
    """Tests for GeminiRerankerClient initialization."""

    def test_init_with_config(self):
        """Test initialization with a config object."""
        config = LLMConfig(api_key='test_api_key', model='test-model')
        client = GeminiRerankerClient(config=config)

        # The supplied config must be stored unchanged.
        assert client.config == config

    # Patch the SDK constructor so the default-config path needs no credentials.
    @patch('google.genai.Client')
    def test_init_without_config(self, mock_client):
        """Test initialization without a config uses defaults."""
        client = GeminiRerankerClient()

        assert client.config is not None

    def test_init_with_custom_client(self):
        """Test initialization with a custom client."""
        mock_client = MagicMock()
        client = GeminiRerankerClient(client=mock_client)

        # A caller-provided client must be used as-is, not replaced.
        assert client.client == mock_client
79 |
80 |
class TestGeminiRerankerClientRanking:
    """Tests for GeminiRerankerClient rank method."""

    @pytest.mark.asyncio
    async def test_rank_basic_functionality(self, gemini_reranker_client, mock_gemini_client):
        """Test basic ranking functionality."""
        # Setup mock responses with different scores
        mock_responses = [
            create_mock_response('85'),  # High relevance
            create_mock_response('45'),  # Medium relevance
            create_mock_response('20'),  # Low relevance
        ]
        mock_gemini_client.aio.models.generate_content.side_effect = mock_responses

        # Test data
        query = 'What is the capital of France?'
        passages = [
            'Paris is the capital and most populous city of France.',
            'London is the capital city of England and the United Kingdom.',
            'Berlin is the capital and largest city of Germany.',
        ]

        # Call method
        result = await gemini_reranker_client.rank(query, passages)

        # Assertions
        assert len(result) == 3
        assert all(isinstance(item, tuple) for item in result)
        assert all(
            isinstance(passage, str) and isinstance(score, float) for passage, score in result
        )

        # Check scores are normalized to [0, 1] and sorted in descending order
        scores = [score for _, score in result]
        assert all(0.0 <= score <= 1.0 for score in scores)
        assert scores == sorted(scores, reverse=True)

        # Check that the highest scoring passage is first
        assert result[0][1] == 0.85  # 85/100
        assert result[1][1] == 0.45  # 45/100
        assert result[2][1] == 0.20  # 20/100

    @pytest.mark.asyncio
    async def test_rank_empty_passages(self, gemini_reranker_client):
        """Test ranking with empty passages list."""
        query = 'Test query'
        passages = []

        result = await gemini_reranker_client.rank(query, passages)

        # No passages in, no ranking out — and no API call is needed.
        assert result == []

    @pytest.mark.asyncio
    async def test_rank_single_passage(self, gemini_reranker_client, mock_gemini_client):
        """Test ranking with a single passage."""
        # Setup mock response
        mock_gemini_client.aio.models.generate_content.return_value = create_mock_response('75')

        query = 'Test query'
        passages = ['Single test passage']

        result = await gemini_reranker_client.rank(query, passages)

        assert len(result) == 1
        assert result[0][0] == 'Single test passage'
        assert result[0][1] == 1.0  # Single passage gets full score

    @pytest.mark.asyncio
    async def test_rank_score_extraction_with_regex(
        self, gemini_reranker_client, mock_gemini_client
    ):
        """Test score extraction from various response formats."""
        # Setup mock responses with different formats
        mock_responses = [
            create_mock_response('Score: 90'),  # Contains text before number
            create_mock_response('The relevance is 65 out of 100'),  # Contains text around number
            create_mock_response('8'),  # Just the number
        ]
        mock_gemini_client.aio.models.generate_content.side_effect = mock_responses

        query = 'Test query'
        passages = ['Passage 1', 'Passage 2', 'Passage 3']

        result = await gemini_reranker_client.rank(query, passages)

        # Check that scores were extracted correctly and normalized
        scores = [score for _, score in result]
        assert 0.90 in scores  # 90/100
        assert 0.65 in scores  # 65/100
        assert 0.08 in scores  # 8/100

    @pytest.mark.asyncio
    async def test_rank_invalid_score_handling(self, gemini_reranker_client, mock_gemini_client):
        """Test handling of invalid or non-numeric scores."""
        # Setup mock responses with invalid scores
        mock_responses = [
            create_mock_response('Not a number'),  # Invalid response
            create_mock_response(''),  # Empty response
            create_mock_response('95'),  # Valid response
        ]
        mock_gemini_client.aio.models.generate_content.side_effect = mock_responses

        query = 'Test query'
        passages = ['Passage 1', 'Passage 2', 'Passage 3']

        result = await gemini_reranker_client.rank(query, passages)

        # Check that invalid scores are handled gracefully (assigned 0.0)
        scores = [score for _, score in result]
        assert 0.95 in scores  # Valid score
        assert scores.count(0.0) == 2  # Two invalid scores assigned 0.0

    @pytest.mark.asyncio
    async def test_rank_score_clamping(self, gemini_reranker_client, mock_gemini_client):
        """Test that scores are properly clamped to [0, 1] range."""
        # Setup mock responses with extreme scores
        # Note: regex only matches 1-3 digits, so negative numbers won't match
        mock_responses = [
            create_mock_response('999'),  # Above 100 but within regex range
            create_mock_response('invalid'),  # Invalid response becomes 0.0
            create_mock_response('50'),  # Normal score
        ]
        mock_gemini_client.aio.models.generate_content.side_effect = mock_responses

        query = 'Test query'
        passages = ['Passage 1', 'Passage 2', 'Passage 3']

        result = await gemini_reranker_client.rank(query, passages)

        # Check that scores are normalized and clamped
        scores = [score for _, score in result]
        assert all(0.0 <= score <= 1.0 for score in scores)
        # 999 should be clamped to 1.0 (999/100 = 9.99, clamped to 1.0)
        assert 1.0 in scores
        # Invalid response should be 0.0
        assert 0.0 in scores
        # Normal score should be normalized (50/100 = 0.5)
        assert 0.5 in scores

    # The next four tests check that messages suggesting quota/rate problems
    # are surfaced to callers as RateLimitError rather than a generic Exception.
    @pytest.mark.asyncio
    async def test_rank_rate_limit_error(self, gemini_reranker_client, mock_gemini_client):
        """Test handling of rate limit errors."""
        # Setup mock to raise rate limit error
        mock_gemini_client.aio.models.generate_content.side_effect = Exception(
            'Rate limit exceeded'
        )

        query = 'Test query'
        passages = ['Passage 1', 'Passage 2']

        with pytest.raises(RateLimitError):
            await gemini_reranker_client.rank(query, passages)

    @pytest.mark.asyncio
    async def test_rank_quota_error(self, gemini_reranker_client, mock_gemini_client):
        """Test handling of quota errors."""
        # Setup mock to raise quota error
        mock_gemini_client.aio.models.generate_content.side_effect = Exception('Quota exceeded')

        query = 'Test query'
        passages = ['Passage 1', 'Passage 2']

        with pytest.raises(RateLimitError):
            await gemini_reranker_client.rank(query, passages)

    @pytest.mark.asyncio
    async def test_rank_resource_exhausted_error(self, gemini_reranker_client, mock_gemini_client):
        """Test handling of resource exhausted errors."""
        # Setup mock to raise resource exhausted error
        mock_gemini_client.aio.models.generate_content.side_effect = Exception('resource_exhausted')

        query = 'Test query'
        passages = ['Passage 1', 'Passage 2']

        with pytest.raises(RateLimitError):
            await gemini_reranker_client.rank(query, passages)

    @pytest.mark.asyncio
    async def test_rank_429_error(self, gemini_reranker_client, mock_gemini_client):
        """Test handling of HTTP 429 errors."""
        # Setup mock to raise 429 error
        mock_gemini_client.aio.models.generate_content.side_effect = Exception(
            'HTTP 429 Too Many Requests'
        )

        query = 'Test query'
        passages = ['Passage 1', 'Passage 2']

        with pytest.raises(RateLimitError):
            await gemini_reranker_client.rank(query, passages)

    @pytest.mark.asyncio
    async def test_rank_generic_error(self, gemini_reranker_client, mock_gemini_client):
        """Test handling of generic errors."""
        # Setup mock to raise generic error
        mock_gemini_client.aio.models.generate_content.side_effect = Exception('Generic error')

        query = 'Test query'
        passages = ['Passage 1', 'Passage 2']

        # Non-rate-limit failures must propagate unchanged.
        with pytest.raises(Exception) as exc_info:
            await gemini_reranker_client.rank(query, passages)

        assert 'Generic error' in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_rank_concurrent_requests(self, gemini_reranker_client, mock_gemini_client):
        """Test that multiple passages are scored concurrently."""
        # Setup mock responses
        mock_responses = [
            create_mock_response('80'),
            create_mock_response('60'),
            create_mock_response('40'),
        ]
        mock_gemini_client.aio.models.generate_content.side_effect = mock_responses

        query = 'Test query'
        passages = ['Passage 1', 'Passage 2', 'Passage 3']

        await gemini_reranker_client.rank(query, passages)

        # Verify that generate_content was called for each passage
        assert mock_gemini_client.aio.models.generate_content.call_count == 3

        # Verify that all calls were made with correct parameters
        calls = mock_gemini_client.aio.models.generate_content.call_args_list
        for call in calls:
            args, kwargs = call
            assert kwargs['model'] == gemini_reranker_client.config.model
            # Deterministic scoring: temperature 0 and a tiny token budget.
            assert kwargs['config'].temperature == 0.0
            assert kwargs['config'].max_output_tokens == 3

    @pytest.mark.asyncio
    async def test_rank_response_parsing_error(self, gemini_reranker_client, mock_gemini_client):
        """Test handling of response parsing errors."""
        # Setup mock responses that will trigger ValueError during parsing
        mock_responses = [
            create_mock_response('not a number at all'),  # Will fail regex match
            create_mock_response('also invalid text'),  # Will fail regex match
        ]
        mock_gemini_client.aio.models.generate_content.side_effect = mock_responses

        query = 'Test query'
        # Use multiple passages to avoid the single passage special case
        passages = ['Passage 1', 'Passage 2']

        result = await gemini_reranker_client.rank(query, passages)

        # Should handle the error gracefully and assign 0.0 score to both
        assert len(result) == 2
        assert all(score == 0.0 for _, score in result)

    @pytest.mark.asyncio
    async def test_rank_empty_response_text(self, gemini_reranker_client, mock_gemini_client):
        """Test handling of empty response text."""
        # Setup mock response with empty text
        mock_response = MagicMock()
        mock_response.text = ''  # Empty string instead of None
        mock_gemini_client.aio.models.generate_content.return_value = mock_response

        query = 'Test query'
        # Use multiple passages to avoid the single passage special case
        passages = ['Passage 1', 'Passage 2']

        result = await gemini_reranker_client.rank(query, passages)

        # Should handle empty text gracefully and assign 0.0 score to both
        assert len(result) == 2
        assert all(score == 0.0 for _, score in result)
350 |
351 |
if __name__ == '__main__':
    # Allow running this file directly, outside the pytest CLI.
    pytest.main(['-v', 'test_gemini_reranker_client.py'])
354 |
```
--------------------------------------------------------------------------------
/tests/test_edge_int.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import logging
18 | import sys
19 | from datetime import datetime
20 |
21 | import numpy as np
22 | import pytest
23 |
24 | from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
25 | from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
26 | from tests.helpers_test import get_edge_count, get_node_count, group_id
27 |
28 | pytest_plugins = ('pytest_asyncio',)
29 |
30 |
def setup_logging():
    """Configure the root logger to emit INFO-level records to stdout.

    Returns:
        The root logger.

    Calling this repeatedly used to attach a new StreamHandler each time,
    duplicating every log line; the guard below makes it idempotent.
    """
    # Create a logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)  # Set the logging level to INFO

    # Avoid stacking duplicate handlers on repeated calls.
    if logger.handlers:
        return logger

    # Create console handler and set level to INFO
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)

    # Create formatter and attach it to the console handler
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    console_handler.setFormatter(formatter)

    # Add console handler to logger
    logger.addHandler(console_handler)

    return logger
50 |
51 |
@pytest.mark.asyncio
async def test_episodic_edge(graph_driver, mock_embedder):
    """Exercise the full CRUD lifecycle of an EpisodicEdge: save, lookup by
    uuid/uuids/group id, reverse lookup of the episode from its entity, and
    deletion by uuid and by uuids."""
    now = datetime.now()

    # Create episodic node
    episode_node = EpisodicNode(
        name='test_episode',
        labels=[],
        created_at=now,
        valid_at=now,
        source=EpisodeType.message,
        source_description='conversation message',
        content='Alice likes Bob',
        entity_edges=[],
        group_id=group_id,
    )
    node_count = await get_node_count(graph_driver, [episode_node.uuid])
    assert node_count == 0
    await episode_node.save(graph_driver)
    node_count = await get_node_count(graph_driver, [episode_node.uuid])
    assert node_count == 1

    # Create entity node
    alice_node = EntityNode(
        name='Alice',
        labels=[],
        created_at=now,
        summary='Alice summary',
        group_id=group_id,
    )
    await alice_node.generate_name_embedding(mock_embedder)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    await alice_node.save(graph_driver)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 1

    # Create episodic to entity edge
    episodic_edge = EpisodicEdge(
        source_node_uuid=episode_node.uuid,
        target_node_uuid=alice_node.uuid,
        created_at=now,
        group_id=group_id,
    )
    edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
    assert edge_count == 0
    await episodic_edge.save(graph_driver)
    edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
    assert edge_count == 1

    # Get edge by uuid
    retrieved = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge.uuid)
    assert retrieved.uuid == episodic_edge.uuid
    assert retrieved.source_node_uuid == episode_node.uuid
    assert retrieved.target_node_uuid == alice_node.uuid
    assert retrieved.created_at == now
    assert retrieved.group_id == group_id

    # Get edge by uuids
    retrieved = await EpisodicEdge.get_by_uuids(graph_driver, [episodic_edge.uuid])
    assert len(retrieved) == 1
    assert retrieved[0].uuid == episodic_edge.uuid
    assert retrieved[0].source_node_uuid == episode_node.uuid
    assert retrieved[0].target_node_uuid == alice_node.uuid
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Get edge by group ids
    retrieved = await EpisodicEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
    assert len(retrieved) == 1
    assert retrieved[0].uuid == episodic_edge.uuid
    assert retrieved[0].source_node_uuid == episode_node.uuid
    assert retrieved[0].target_node_uuid == alice_node.uuid
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Get episodic node by entity node uuid
    retrieved = await EpisodicNode.get_by_entity_node_uuid(graph_driver, alice_node.uuid)
    assert len(retrieved) == 1
    assert retrieved[0].uuid == episode_node.uuid
    assert retrieved[0].name == 'test_episode'
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Delete edge by uuid
    await episodic_edge.delete(graph_driver)
    edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
    assert edge_count == 0

    # Delete edge by uuids
    await episodic_edge.save(graph_driver)
    await episodic_edge.delete_by_uuids(graph_driver, [episodic_edge.uuid])
    edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
    assert edge_count == 0

    # Cleanup nodes
    await episode_node.delete(graph_driver)
    node_count = await get_node_count(graph_driver, [episode_node.uuid])
    assert node_count == 0
    await alice_node.delete(graph_driver)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0

    await graph_driver.close()
156 |
157 |
@pytest.mark.asyncio
async def test_entity_edge(graph_driver, mock_embedder):
    """Exercise the full CRUD lifecycle of an EntityEdge, including fact
    embedding round-trip and cascade deletion of the edge when either
    endpoint node is removed (by uuid, by uuids, or by group id)."""
    now = datetime.now()

    # Create entity node
    alice_node = EntityNode(
        name='Alice',
        labels=[],
        created_at=now,
        summary='Alice summary',
        group_id=group_id,
    )
    await alice_node.generate_name_embedding(mock_embedder)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    await alice_node.save(graph_driver)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 1

    # Create entity node
    bob_node = EntityNode(
        name='Bob', labels=[], created_at=now, summary='Bob summary', group_id=group_id
    )
    await bob_node.generate_name_embedding(mock_embedder)
    node_count = await get_node_count(graph_driver, [bob_node.uuid])
    assert node_count == 0
    await bob_node.save(graph_driver)
    node_count = await get_node_count(graph_driver, [bob_node.uuid])
    assert node_count == 1

    # Create entity to entity edge
    entity_edge = EntityEdge(
        source_node_uuid=alice_node.uuid,
        target_node_uuid=bob_node.uuid,
        created_at=now,
        name='likes',
        fact='Alice likes Bob',
        episodes=[],
        expired_at=now,
        valid_at=now,
        invalid_at=now,
        group_id=group_id,
    )
    edge_embedding = await entity_edge.generate_embedding(mock_embedder)
    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
    assert edge_count == 0
    await entity_edge.save(graph_driver)
    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
    assert edge_count == 1

    # Get edge by uuid
    retrieved = await EntityEdge.get_by_uuid(graph_driver, entity_edge.uuid)
    assert retrieved.uuid == entity_edge.uuid
    assert retrieved.source_node_uuid == alice_node.uuid
    assert retrieved.target_node_uuid == bob_node.uuid
    assert retrieved.created_at == now
    assert retrieved.group_id == group_id

    # Get edge by uuids
    retrieved = await EntityEdge.get_by_uuids(graph_driver, [entity_edge.uuid])
    assert len(retrieved) == 1
    assert retrieved[0].uuid == entity_edge.uuid
    assert retrieved[0].source_node_uuid == alice_node.uuid
    assert retrieved[0].target_node_uuid == bob_node.uuid
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Get edge by group ids
    retrieved = await EntityEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
    assert len(retrieved) == 1
    assert retrieved[0].uuid == entity_edge.uuid
    assert retrieved[0].source_node_uuid == alice_node.uuid
    assert retrieved[0].target_node_uuid == bob_node.uuid
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Get edge by node uuid
    retrieved = await EntityEdge.get_by_node_uuid(graph_driver, alice_node.uuid)
    assert len(retrieved) == 1
    assert retrieved[0].uuid == entity_edge.uuid
    assert retrieved[0].source_node_uuid == alice_node.uuid
    assert retrieved[0].target_node_uuid == bob_node.uuid
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Get fact embedding (stored separately; compare to the generated vector)
    await entity_edge.load_fact_embedding(graph_driver)
    assert np.allclose(entity_edge.fact_embedding, edge_embedding)

    # Delete edge by uuid
    await entity_edge.delete(graph_driver)
    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
    assert edge_count == 0

    # Delete edge by uuids
    await entity_edge.save(graph_driver)
    await entity_edge.delete_by_uuids(graph_driver, [entity_edge.uuid])
    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
    assert edge_count == 0

    # Deleting node should delete the edge
    await entity_edge.save(graph_driver)
    await alice_node.delete(graph_driver)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
    assert edge_count == 0

    # Deleting node by uuids should delete the edge
    await alice_node.save(graph_driver)
    await entity_edge.save(graph_driver)
    await alice_node.delete_by_uuids(graph_driver, [alice_node.uuid])
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
    assert edge_count == 0

    # Deleting node by group id should delete the edge
    await alice_node.save(graph_driver)
    await entity_edge.save(graph_driver)
    await alice_node.delete_by_group_id(graph_driver, alice_node.group_id)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
    assert edge_count == 0

    # Cleanup nodes
    await alice_node.delete(graph_driver)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    await bob_node.delete(graph_driver)
    node_count = await get_node_count(graph_driver, [bob_node.uuid])
    assert node_count == 0

    await graph_driver.close()
293 |
294 |
@pytest.mark.asyncio
async def test_community_edge(graph_driver, mock_embedder):
    """Exercise the full CRUD lifecycle of a CommunityEdge between two
    community nodes: save, lookup by uuid/uuids/group id, and deletion."""
    now = datetime.now()

    # Create community node
    community_node_1 = CommunityNode(
        name='test_community_1',
        group_id=group_id,
        summary='Community A summary',
    )
    await community_node_1.generate_name_embedding(mock_embedder)
    node_count = await get_node_count(graph_driver, [community_node_1.uuid])
    assert node_count == 0
    await community_node_1.save(graph_driver)
    node_count = await get_node_count(graph_driver, [community_node_1.uuid])
    assert node_count == 1

    # Create community node
    community_node_2 = CommunityNode(
        name='test_community_2',
        group_id=group_id,
        summary='Community B summary',
    )
    await community_node_2.generate_name_embedding(mock_embedder)
    node_count = await get_node_count(graph_driver, [community_node_2.uuid])
    assert node_count == 0
    await community_node_2.save(graph_driver)
    node_count = await get_node_count(graph_driver, [community_node_2.uuid])
    assert node_count == 1

    # Create entity node
    alice_node = EntityNode(
        name='Alice', labels=[], created_at=now, summary='Alice summary', group_id=group_id
    )
    await alice_node.generate_name_embedding(mock_embedder)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    await alice_node.save(graph_driver)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 1

    # Create community to community edge
    community_edge = CommunityEdge(
        source_node_uuid=community_node_1.uuid,
        target_node_uuid=community_node_2.uuid,
        created_at=now,
        group_id=group_id,
    )
    edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
    assert edge_count == 0
    await community_edge.save(graph_driver)
    edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
    assert edge_count == 1

    # Get edge by uuid
    retrieved = await CommunityEdge.get_by_uuid(graph_driver, community_edge.uuid)
    assert retrieved.uuid == community_edge.uuid
    assert retrieved.source_node_uuid == community_node_1.uuid
    assert retrieved.target_node_uuid == community_node_2.uuid
    assert retrieved.created_at == now
    assert retrieved.group_id == group_id

    # Get edge by uuids
    retrieved = await CommunityEdge.get_by_uuids(graph_driver, [community_edge.uuid])
    assert len(retrieved) == 1
    assert retrieved[0].uuid == community_edge.uuid
    assert retrieved[0].source_node_uuid == community_node_1.uuid
    assert retrieved[0].target_node_uuid == community_node_2.uuid
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Get edge by group ids
    retrieved = await CommunityEdge.get_by_group_ids(graph_driver, [group_id], limit=1)
    assert len(retrieved) == 1
    assert retrieved[0].uuid == community_edge.uuid
    assert retrieved[0].source_node_uuid == community_node_1.uuid
    assert retrieved[0].target_node_uuid == community_node_2.uuid
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Delete edge by uuid
    await community_edge.delete(graph_driver)
    edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
    assert edge_count == 0

    # Delete edge by uuids
    await community_edge.save(graph_driver)
    await community_edge.delete_by_uuids(graph_driver, [community_edge.uuid])
    edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
    assert edge_count == 0

    # Cleanup nodes
    await alice_node.delete(graph_driver)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    await community_node_1.delete(graph_driver)
    node_count = await get_node_count(graph_driver, [community_node_1.uuid])
    assert node_count == 0
    await community_node_2.delete(graph_driver)
    node_count = await get_node_count(graph_driver, [community_node_2.uuid])
    assert node_count == 0

    await graph_driver.close()
398 |
```
--------------------------------------------------------------------------------
/tests/embedder/test_gemini.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | # Running tests: pytest -xvs tests/embedder/test_gemini.py
18 |
19 | from collections.abc import Generator
20 | from typing import Any
21 | from unittest.mock import AsyncMock, MagicMock, patch
22 |
23 | import pytest
24 | from embedder_fixtures import create_embedding_values
25 |
26 | from graphiti_core.embedder.gemini import (
27 | DEFAULT_EMBEDDING_MODEL,
28 | GeminiEmbedder,
29 | GeminiEmbedderConfig,
30 | )
31 |
32 |
def create_gemini_embedding(multiplier: float = 0.1, dimension: int = 1536) -> MagicMock:
    """Build a stand-in for a Gemini embedding object.

    The returned mock exposes only a ``values`` attribute, populated via
    ``create_embedding_values`` with the given multiplier and dimension.
    """
    embedding_stub = MagicMock()
    embedding_stub.values = create_embedding_values(multiplier, dimension)
    return embedding_stub
38 |
39 |
@pytest.fixture
def mock_gemini_response() -> MagicMock:
    """Fixture: a fake Gemini API response carrying a single embedding."""
    response = MagicMock()
    response.embeddings = [create_gemini_embedding()]
    return response
46 |
47 |
@pytest.fixture
def mock_gemini_batch_response() -> MagicMock:
    """Fixture: a fake Gemini API response carrying three embeddings."""
    response = MagicMock()
    # Distinct multipliers make each embedding in the batch distinguishable.
    response.embeddings = [create_gemini_embedding(m) for m in (0.1, 0.2, 0.3)]
    return response
58 |
59 |
@pytest.fixture
def mock_gemini_client() -> Generator[Any, Any, None]:
    """Fixture: a patched google.genai client whose async embed call is an AsyncMock."""
    with patch('google.genai.Client') as patched_client:
        instance = patched_client.return_value
        instance.aio = MagicMock()
        instance.aio.models = MagicMock()
        instance.aio.models.embed_content = AsyncMock()
        yield instance
69 |
70 |
@pytest.fixture
def gemini_embedder(mock_gemini_client: Any) -> GeminiEmbedder:
    """Fixture: a GeminiEmbedder wired up to the mocked client."""
    embedder = GeminiEmbedder(config=GeminiEmbedderConfig(api_key='test_api_key'))
    embedder.client = mock_gemini_client
    return embedder
78 |
79 |
class TestGeminiEmbedderInitialization:
    """Construction-time behavior of GeminiEmbedder."""

    @patch('google.genai.Client')
    def test_init_with_config(self, mock_client):
        """A fully populated config object is stored verbatim on the embedder."""
        config = GeminiEmbedderConfig(
            api_key='test_api_key', embedding_model='custom-model', embedding_dim=768
        )
        embedder = GeminiEmbedder(config=config)

        assert embedder.config == config
        assert embedder.config.api_key == 'test_api_key'
        assert embedder.config.embedding_model == 'custom-model'
        assert embedder.config.embedding_dim == 768

    @patch('google.genai.Client')
    def test_init_without_config(self, mock_client):
        """Omitting the config entirely falls back to library defaults."""
        embedder = GeminiEmbedder()

        assert embedder.config is not None
        assert embedder.config.embedding_model == DEFAULT_EMBEDDING_MODEL

    @patch('google.genai.Client')
    def test_init_with_partial_config(self, mock_client):
        """Fields missing from a partial config are filled with defaults."""
        embedder = GeminiEmbedder(config=GeminiEmbedderConfig(api_key='test_api_key'))

        assert embedder.config.api_key == 'test_api_key'
        assert embedder.config.embedding_model == DEFAULT_EMBEDDING_MODEL
112 |
113 |
class TestGeminiEmbedderCreate:
    """Behavior of GeminiEmbedder.create()."""

    @pytest.mark.asyncio
    async def test_create_calls_api_correctly(
        self,
        gemini_embedder: GeminiEmbedder,
        mock_gemini_client: Any,
        mock_gemini_response: MagicMock,
    ) -> None:
        """create() forwards the input to the API and unwraps the first embedding."""
        embed_mock = mock_gemini_client.aio.models.embed_content
        embed_mock.return_value = mock_gemini_response

        result = await gemini_embedder.create('Test input')

        # Exactly one API call, with the default model and the input wrapped in a list.
        embed_mock.assert_called_once()
        kwargs = embed_mock.call_args.kwargs
        assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL
        assert kwargs['contents'] == ['Test input']

        # The returned vector is the raw values of the first embedding.
        assert result == mock_gemini_response.embeddings[0].values

    @pytest.mark.asyncio
    @patch('google.genai.Client')
    async def test_create_with_custom_model(
        self, mock_client_class, mock_gemini_client: Any, mock_gemini_response: MagicMock
    ) -> None:
        """A model override in the config is passed through to the API call."""
        embedder = GeminiEmbedder(
            config=GeminiEmbedderConfig(api_key='test_api_key', embedding_model='custom-model')
        )
        embedder.client = mock_gemini_client
        mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response

        await embedder.create('Test input')

        kwargs = mock_gemini_client.aio.models.embed_content.call_args.kwargs
        assert kwargs['model'] == 'custom-model'

    @pytest.mark.asyncio
    @patch('google.genai.Client')
    async def test_create_with_custom_dimension(
        self, mock_client_class, mock_gemini_client: Any
    ) -> None:
        """An embedding_dim override reaches the request config and the result size."""
        embedder = GeminiEmbedder(
            config=GeminiEmbedderConfig(api_key='test_api_key', embedding_dim=768)
        )
        embedder.client = mock_gemini_client

        # The mocked API answers with a 768-wide embedding to match the config.
        fake_response = MagicMock()
        fake_response.embeddings = [create_gemini_embedding(0.1, 768)]
        mock_gemini_client.aio.models.embed_content.return_value = fake_response

        result = await embedder.create('Test input')

        kwargs = mock_gemini_client.aio.models.embed_content.call_args.kwargs
        assert kwargs['config'].output_dimensionality == 768
        assert len(result) == 768

    @pytest.mark.asyncio
    async def test_create_with_different_input_types(
        self,
        gemini_embedder: GeminiEmbedder,
        mock_gemini_client: Any,
        mock_gemini_response: MagicMock,
    ) -> None:
        """create() accepts strings, string lists, and integer iterables."""
        mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response

        for payload in ('Test string', ['Test', 'List'], [1, 2, 3]):
            await gemini_embedder.create(payload)

        assert mock_gemini_client.aio.models.embed_content.call_count == 3

    @pytest.mark.asyncio
    async def test_create_no_embeddings_error(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """An empty embeddings list from the API raises ValueError."""
        empty_response = MagicMock()
        empty_response.embeddings = []
        mock_gemini_client.aio.models.embed_content.return_value = empty_response

        with pytest.raises(ValueError) as exc_info:
            await gemini_embedder.create('Test input')

        assert 'No embeddings returned from Gemini API in create()' in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_create_no_values_error(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """An embedding whose values are None raises ValueError."""
        valueless_embedding = MagicMock()
        valueless_embedding.values = None
        response = MagicMock()
        response.embeddings = [valueless_embedding]
        mock_gemini_client.aio.models.embed_content.return_value = response

        with pytest.raises(ValueError) as exc_info:
            await gemini_embedder.create('Test input')

        assert 'No embeddings returned from Gemini API in create()' in str(exc_info.value)
240 |
241 |
class TestGeminiEmbedderCreateBatch:
    """Behavior of GeminiEmbedder.create_batch()."""

    @pytest.mark.asyncio
    async def test_create_batch_processes_multiple_inputs(
        self,
        gemini_embedder: GeminiEmbedder,
        mock_gemini_client: Any,
        mock_gemini_batch_response: MagicMock,
    ) -> None:
        """A batch is sent as a single API call and unwrapped in order."""
        embed_mock = mock_gemini_client.aio.models.embed_content
        embed_mock.return_value = mock_gemini_batch_response
        inputs = ['Input 1', 'Input 2', 'Input 3']

        result = await gemini_embedder.create_batch(inputs)

        # One call for the whole batch, with the original inputs unchanged.
        embed_mock.assert_called_once()
        kwargs = embed_mock.call_args.kwargs
        assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL
        assert kwargs['contents'] == inputs

        # One vector per input, preserving batch order.
        assert len(result) == 3
        assert result == [e.values for e in mock_gemini_batch_response.embeddings]

    @pytest.mark.asyncio
    async def test_create_batch_single_input(
        self,
        gemini_embedder: GeminiEmbedder,
        mock_gemini_client: Any,
        mock_gemini_response: MagicMock,
    ) -> None:
        """A one-element batch yields a one-element result list."""
        mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response

        result = await gemini_embedder.create_batch(['Single input'])

        assert len(result) == 1
        assert result[0] == mock_gemini_response.embeddings[0].values

    @pytest.mark.asyncio
    async def test_create_batch_empty_input(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """An empty batch short-circuits without touching the API."""
        empty_response = MagicMock()
        empty_response.embeddings = []
        mock_gemini_client.aio.models.embed_content.return_value = empty_response

        result = await gemini_embedder.create_batch([])

        assert result == []
        mock_gemini_client.aio.models.embed_content.assert_not_called()

    @pytest.mark.asyncio
    async def test_create_batch_no_embeddings_error(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """A non-empty batch answered with no embeddings raises ValueError."""
        empty_response = MagicMock()
        empty_response.embeddings = []
        mock_gemini_client.aio.models.embed_content.return_value = empty_response

        with pytest.raises(ValueError) as exc_info:
            await gemini_embedder.create_batch(['Input 1', 'Input 2'])

        assert 'No embeddings returned from Gemini API' in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_create_batch_empty_values_error(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """Empty values surfacing during per-item fallback raise ValueError."""
        good_embedding = MagicMock()
        good_embedding.values = [0.1, 0.2, 0.3]
        bad_embedding = MagicMock()
        bad_embedding.values = None  # triggers the failure path

        batch_response = MagicMock()
        batch_response.embeddings = [good_embedding, bad_embedding]

        item_response_1 = MagicMock()
        item_response_1.embeddings = [good_embedding]
        item_response_2 = MagicMock()
        item_response_2.embeddings = [bad_embedding]

        # First call serves the batch; the next two serve the per-item retries.
        mock_gemini_client.aio.models.embed_content.side_effect = [
            batch_response,
            item_response_1,
            item_response_2,
        ]

        with pytest.raises(ValueError) as exc_info:
            await gemini_embedder.create_batch(['Input 1', 'Input 2'])

        assert 'Empty embedding values returned' in str(exc_info.value)

    @pytest.mark.asyncio
    @patch('google.genai.Client')
    async def test_create_batch_with_custom_model_and_dimension(
        self, mock_client_class, mock_gemini_client: Any
    ) -> None:
        """Custom model and dimension overrides apply to batch requests too."""
        embedder = GeminiEmbedder(
            config=GeminiEmbedderConfig(
                api_key='test_api_key', embedding_model='custom-batch-model', embedding_dim=512
            )
        )
        embedder.client = mock_gemini_client

        batch_response = MagicMock()
        batch_response.embeddings = [
            create_gemini_embedding(0.1, 512),
            create_gemini_embedding(0.2, 512),
        ]
        mock_gemini_client.aio.models.embed_content.return_value = batch_response

        result = await embedder.create_batch(['Input 1', 'Input 2'])

        kwargs = mock_gemini_client.aio.models.embed_content.call_args.kwargs
        assert kwargs['model'] == 'custom-batch-model'
        assert kwargs['config'].output_dimensionality == 512

        assert len(result) == 2
        assert all(len(vector) == 512 for vector in result)
392 |
393 |
# Allow running this test module directly (outside a pytest invocation).
if __name__ == '__main__':
    pytest.main(['-xvs', __file__])
396 |
```
--------------------------------------------------------------------------------
/tests/driver/test_falkordb_driver.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import os
18 | import unittest
19 | from datetime import datetime, timezone
20 | from unittest.mock import AsyncMock, MagicMock, patch
21 |
22 | import pytest
23 |
24 | from graphiti_core.driver.driver import GraphProvider
25 |
26 | try:
27 | from graphiti_core.driver.falkordb_driver import FalkorDriver, FalkorDriverSession
28 |
29 | HAS_FALKORDB = True
30 | except ImportError:
31 | FalkorDriver = None
32 | HAS_FALKORDB = False
33 |
34 |
class TestFalkorDriver:
    """Comprehensive test suite for FalkorDB driver."""

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def setup_method(self):
        """Attach a FalkorDriver whose underlying FalkorDB client is a mock."""
        self.client_mock = MagicMock()
        with patch('graphiti_core.driver.falkordb_driver.FalkorDB'):
            self.driver = FalkorDriver()
        self.driver.client = self.client_mock

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_init_with_connection_params(self):
        """Connection parameters are forwarded verbatim to the FalkorDB constructor."""
        with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as falkor_cls:
            driver = FalkorDriver(
                host='test-host', port='1234', username='test-user', password='test-pass'
            )
            assert driver.provider == GraphProvider.FALKORDB
            falkor_cls.assert_called_once_with(
                host='test-host', port='1234', username='test-user', password='test-pass'
            )

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_init_with_falkor_db_instance(self):
        """Passing a ready FalkorDB instance skips constructing a new one."""
        with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as falkor_cls:
            provided_db = MagicMock()
            driver = FalkorDriver(falkor_db=provided_db)
            assert driver.provider == GraphProvider.FALKORDB
            assert driver.client is provided_db
            falkor_cls.assert_not_called()

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_provider(self):
        """The driver identifies itself as the FalkorDB provider."""
        assert self.driver.provider == GraphProvider.FALKORDB

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_get_graph_with_name(self):
        """_get_graph selects the named graph on the client."""
        graph_mock = MagicMock()
        self.client_mock.select_graph.return_value = graph_mock

        selected = self.driver._get_graph('test_graph')

        self.client_mock.select_graph.assert_called_once_with('test_graph')
        assert selected is graph_mock

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_get_graph_with_none_defaults_to_default_database(self):
        """_get_graph(None) falls back to the 'default_db' graph."""
        graph_mock = MagicMock()
        self.client_mock.select_graph.return_value = graph_mock

        selected = self.driver._get_graph(None)

        self.client_mock.select_graph.assert_called_once_with('default_db')
        assert selected is graph_mock

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_execute_query_success(self):
        """execute_query maps result rows into dicts keyed by the header aliases."""
        query_result = MagicMock()
        query_result.header = [('col1', 'column1'), ('col2', 'column2')]
        query_result.result_set = [['row1col1', 'row1col2']]
        graph_mock = MagicMock()
        graph_mock.query = AsyncMock(return_value=query_result)
        self.client_mock.select_graph.return_value = graph_mock

        outcome = await self.driver.execute_query('MATCH (n) RETURN n', param1='value1')

        graph_mock.query.assert_called_once_with('MATCH (n) RETURN n', {'param1': 'value1'})

        rows, header, summary = outcome
        assert rows == [{'column1': 'row1col1', 'column2': 'row1col2'}]
        assert header == ['column1', 'column2']
        assert summary is None

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_execute_query_handles_index_already_exists_error(self):
        """An 'already indexed' failure is logged at info level and swallowed."""
        graph_mock = MagicMock()
        graph_mock.query = AsyncMock(side_effect=Exception('Index already indexed'))
        self.client_mock.select_graph.return_value = graph_mock

        with patch('graphiti_core.driver.falkordb_driver.logger') as logger_mock:
            outcome = await self.driver.execute_query('CREATE INDEX ...')

        logger_mock.info.assert_called_once()
        assert outcome is None

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_execute_query_propagates_other_exceptions(self):
        """Failures other than 'already indexed' are logged as errors and re-raised."""
        graph_mock = MagicMock()
        graph_mock.query = AsyncMock(side_effect=Exception('Other error'))
        self.client_mock.select_graph.return_value = graph_mock

        with patch('graphiti_core.driver.falkordb_driver.logger') as logger_mock:
            with pytest.raises(Exception, match='Other error'):
                await self.driver.execute_query('INVALID QUERY')

            logger_mock.error.assert_called_once()

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_execute_query_converts_datetime_parameters(self):
        """datetime kwargs are serialized to ISO-8601 strings before querying."""
        empty_result = MagicMock()
        empty_result.header = []
        empty_result.result_set = []
        graph_mock = MagicMock()
        graph_mock.query = AsyncMock(return_value=empty_result)
        self.client_mock.select_graph.return_value = graph_mock

        moment = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)

        await self.driver.execute_query(
            'CREATE (n:Node) SET n.created_at = $created_at', created_at=moment
        )

        positional = graph_mock.query.call_args[0]
        assert positional[1]['created_at'] == moment.isoformat()

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_session_creation(self):
        """session() wraps the selected graph in a FalkorDriverSession."""
        graph_mock = MagicMock()
        self.client_mock.select_graph.return_value = graph_mock

        session = self.driver.session()

        assert isinstance(session, FalkorDriverSession)
        assert session.graph is graph_mock

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_session_creation_with_none_uses_default_database(self):
        """session() without a database argument still yields a usable session."""
        self.client_mock.select_graph.return_value = MagicMock()

        session = self.driver.session()

        assert isinstance(session, FalkorDriverSession)

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_close_calls_connection_close(self):
        """close() falls back to connection.close() when no aclose is available."""
        connection_mock = MagicMock()
        connection_mock.close = AsyncMock()
        self.client_mock.connection = connection_mock

        # Drop the auto-created aclose attribute so only close() remains.
        del self.client_mock.aclose

        with patch('builtins.hasattr') as hasattr_mock:
            # Report True only for hasattr(connection, 'close'); every other
            # probe (client.aclose, connection.aclose) reports False.
            hasattr_mock.side_effect = lambda obj, attr: (
                attr == 'close' and obj is connection_mock
            )

            await self.driver.close()

        connection_mock.close.assert_called_once()

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_delete_all_indexes(self):
        """delete_all_indexes issues the db.indexes() procedure call."""
        with patch.object(self.driver, 'execute_query', new_callable=AsyncMock) as execute_mock:
            # None simulates a database that reports no indexes.
            execute_mock.return_value = None

            await self.driver.delete_all_indexes()

            execute_mock.assert_called_once_with('CALL db.indexes()')
218 |
219 |
class TestFalkorDriverSession:
    """Test FalkorDB driver session functionality."""

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def setup_method(self):
        """Wrap a mock graph in a fresh session for every test."""
        self.graph_mock = MagicMock()
        self.session = FalkorDriverSession(self.graph_mock)

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_session_async_context_manager(self):
        """Entering the session as an async context manager yields the session."""
        async with self.session as entered:
            assert entered is self.session

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_close_method(self):
        """close() must complete without raising."""
        await self.session.close()

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_execute_write_passes_session_and_args(self):
        """execute_write hands the session plus args/kwargs to the callback."""

        async def write_fn(session, *args, **kwargs):
            assert session is self.session
            assert args == ('arg1', 'arg2')
            assert kwargs == {'key': 'value'}
            return 'result'

        outcome = await self.session.execute_write(write_fn, 'arg1', 'arg2', key='value')
        assert outcome == 'result'

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_run_single_query_with_parameters(self):
        """run() forwards a single query string with its keyword parameters."""
        self.graph_mock.query = AsyncMock()

        await self.session.run('MATCH (n) RETURN n', param1='value1', param2='value2')

        self.graph_mock.query.assert_called_once_with(
            'MATCH (n) RETURN n', {'param1': 'value1', 'param2': 'value2'}
        )

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_run_multiple_queries_as_list(self):
        """run() executes each (query, params) pair from a list, in order."""
        self.graph_mock.query = AsyncMock()

        batch = [
            ('MATCH (n) RETURN n', {'param1': 'value1'}),
            ('CREATE (n:Node)', {'param2': 'value2'}),
        ]

        await self.session.run(batch)

        recorded = self.graph_mock.query.call_args_list
        assert len(recorded) == 2
        assert recorded[0][0] == ('MATCH (n) RETURN n', {'param1': 'value1'})
        assert recorded[1][0] == ('CREATE (n:Node)', {'param2': 'value2'})

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_run_converts_datetime_objects_to_iso_strings(self):
        """run() serializes datetime parameters to ISO-8601 strings."""
        self.graph_mock.query = AsyncMock()
        moment = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)

        await self.session.run(
            'CREATE (n:Node) SET n.created_at = $created_at', created_at=moment
        )

        self.graph_mock.query.assert_called_once()
        positional = self.graph_mock.query.call_args[0]
        assert positional[1]['created_at'] == moment.isoformat()
300 |
301 |
class TestDatetimeConversion:
    """Unit tests for the convert_datetimes_to_strings utility."""

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_convert_datetime_dict(self):
        """Datetimes are converted recursively inside nested dictionaries."""
        from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings

        moment = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        payload = {
            'string_val': 'test',
            'datetime_val': moment,
            'nested_dict': {'nested_datetime': moment, 'nested_string': 'nested_test'},
        }

        converted = convert_datetimes_to_strings(payload)

        assert converted['string_val'] == 'test'
        assert converted['datetime_val'] == moment.isoformat()
        assert converted['nested_dict']['nested_datetime'] == moment.isoformat()
        assert converted['nested_dict']['nested_string'] == 'nested_test'

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_convert_datetime_list_and_tuple(self):
        """Datetimes are converted inside lists (recursively) and tuples."""
        from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings

        moment = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)

        # Lists, including nested ones, are handled element by element.
        converted_list = convert_datetimes_to_strings(['test', moment, ['nested', moment]])
        assert converted_list[0] == 'test'
        assert converted_list[1] == moment.isoformat()
        assert converted_list[2][1] == moment.isoformat()

        # Tuples stay tuples after conversion.
        converted_tuple = convert_datetimes_to_strings(('test', moment))
        assert isinstance(converted_tuple, tuple)
        assert converted_tuple[0] == 'test'
        assert converted_tuple[1] == moment.isoformat()

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_convert_single_datetime(self):
        """A bare datetime becomes its ISO-8601 string."""
        from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings

        moment = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        assert convert_datetimes_to_strings(moment) == moment.isoformat()

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_convert_other_types_unchanged(self):
        """Non-datetime values pass through untouched."""
        from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings

        assert convert_datetimes_to_strings('string') == 'string'
        assert convert_datetimes_to_strings(123) == 123
        assert convert_datetimes_to_strings(None) is None
        assert convert_datetimes_to_strings(True) is True
363 |
364 |
# Simple integration test
class TestFalkorDriverIntegration:
    """Simple integration test for FalkorDB driver."""

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_basic_integration_with_real_falkordb(self):
        """Round-trip a trivial query against a live FalkorDB, if one is reachable."""
        pytest.importorskip('falkordb')

        host = os.getenv('FALKORDB_HOST', 'localhost')
        port = os.getenv('FALKORDB_PORT', '6379')

        try:
            driver = FalkorDriver(host=host, port=port)

            # A constant query exercises the full request/response path.
            outcome = await driver.execute_query('RETURN 1 as test')
            assert outcome is not None

            rows, header, _summary = outcome
            assert header == ['test']
            assert rows == [{'test': 1}]

            await driver.close()

        except Exception as e:
            pytest.skip(f'FalkorDB not available for integration test: {e}')
393 |
```
--------------------------------------------------------------------------------
/mcp_server/src/services/factories.py:
--------------------------------------------------------------------------------
```python
1 | """Factory classes for creating LLM, Embedder, and Database clients."""
2 |
3 | from openai import AsyncAzureOpenAI
4 |
5 | from config.schema import (
6 | DatabaseConfig,
7 | EmbedderConfig,
8 | LLMConfig,
9 | )
10 |
11 | # Try to import FalkorDriver if available
12 | try:
13 | from graphiti_core.driver.falkordb_driver import FalkorDriver # noqa: F401
14 |
15 | HAS_FALKOR = True
16 | except ImportError:
17 | HAS_FALKOR = False
18 |
19 | # Kuzu support removed - FalkorDB is now the default
20 | from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
21 | from graphiti_core.llm_client import LLMClient, OpenAIClient
22 | from graphiti_core.llm_client.config import LLMConfig as GraphitiLLMConfig
23 |
24 | # Try to import additional providers if available
25 | try:
26 | from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient
27 |
28 | HAS_AZURE_EMBEDDER = True
29 | except ImportError:
30 | HAS_AZURE_EMBEDDER = False
31 |
32 | try:
33 | from graphiti_core.embedder.gemini import GeminiEmbedder
34 |
35 | HAS_GEMINI_EMBEDDER = True
36 | except ImportError:
37 | HAS_GEMINI_EMBEDDER = False
38 |
39 | try:
40 | from graphiti_core.embedder.voyage import VoyageAIEmbedder
41 |
42 | HAS_VOYAGE_EMBEDDER = True
43 | except ImportError:
44 | HAS_VOYAGE_EMBEDDER = False
45 |
46 | try:
47 | from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
48 |
49 | HAS_AZURE_LLM = True
50 | except ImportError:
51 | HAS_AZURE_LLM = False
52 |
53 | try:
54 | from graphiti_core.llm_client.anthropic_client import AnthropicClient
55 |
56 | HAS_ANTHROPIC = True
57 | except ImportError:
58 | HAS_ANTHROPIC = False
59 |
60 | try:
61 | from graphiti_core.llm_client.gemini_client import GeminiClient
62 |
63 | HAS_GEMINI = True
64 | except ImportError:
65 | HAS_GEMINI = False
66 |
67 | try:
68 | from graphiti_core.llm_client.groq_client import GroqClient
69 |
70 | HAS_GROQ = True
71 | except ImportError:
72 | HAS_GROQ = False
73 | from utils.utils import create_azure_credential_token_provider
74 |
75 |
76 | def _validate_api_key(provider_name: str, api_key: str | None, logger) -> str:
77 | """Validate API key is present.
78 |
79 | Args:
80 | provider_name: Name of the provider (e.g., 'OpenAI', 'Anthropic')
81 | api_key: The API key to validate
82 | logger: Logger instance for output
83 |
84 | Returns:
85 | The validated API key
86 |
87 | Raises:
88 | ValueError: If API key is None or empty
89 | """
90 | if not api_key:
91 | raise ValueError(
92 | f'{provider_name} API key is not configured. Please set the appropriate environment variable.'
93 | )
94 |
95 | logger.info(f'Creating {provider_name} client')
96 |
97 | return api_key
98 |
99 |
class LLMClientFactory:
    """Factory for creating LLM clients based on configuration."""

    @staticmethod
    def create(config: LLMConfig) -> LLMClient:
        """Create an LLM client based on the configured provider.

        Dispatches on ``config.provider`` (case-insensitive). Raises
        ValueError when the provider is unknown, its configuration section
        is missing, or the matching graphiti-core extra is not installed
        (checked via the module-level HAS_* flags).
        """
        import logging

        logger = logging.getLogger(__name__)

        provider = config.provider.lower()

        match provider:
            case 'openai':
                if not config.providers.openai:
                    raise ValueError('OpenAI provider configuration not found')

                api_key = config.providers.openai.api_key
                # Raises ValueError on a missing/empty key.
                _validate_api_key('OpenAI', api_key, logger)

                from graphiti_core.llm_client.config import LLMConfig as CoreLLMConfig

                # Determine appropriate small model based on main model type
                is_reasoning_model = (
                    config.model.startswith('gpt-5')
                    or config.model.startswith('o1')
                    or config.model.startswith('o3')
                )
                small_model = (
                    'gpt-5-nano' if is_reasoning_model else 'gpt-4.1-mini'
                )  # Use reasoning model for small tasks if main model is reasoning

                llm_config = CoreLLMConfig(
                    api_key=api_key,
                    model=config.model,
                    small_model=small_model,
                    temperature=config.temperature,
                    max_tokens=config.max_tokens,
                )

                # Only pass reasoning/verbosity parameters for reasoning models (gpt-5 family)
                if is_reasoning_model:
                    return OpenAIClient(config=llm_config, reasoning='minimal', verbosity='low')
                else:
                    # For non-reasoning models, explicitly pass None to disable these parameters
                    return OpenAIClient(config=llm_config, reasoning=None, verbosity=None)

            case 'azure_openai':
                if not HAS_AZURE_LLM:
                    raise ValueError(
                        'Azure OpenAI LLM client not available in current graphiti-core version'
                    )
                if not config.providers.azure_openai:
                    raise ValueError('Azure OpenAI provider configuration not found')
                azure_config = config.providers.azure_openai

                if not azure_config.api_url:
                    raise ValueError('Azure OpenAI API URL is required')

                # Handle Azure AD authentication if enabled
                # With Azure AD, api_key stays None and a token provider is used instead.
                api_key: str | None = None
                azure_ad_token_provider = None
                if azure_config.use_azure_ad:
                    logger.info('Creating Azure OpenAI LLM client with Azure AD authentication')
                    azure_ad_token_provider = create_azure_credential_token_provider()
                else:
                    api_key = azure_config.api_key
                    _validate_api_key('Azure OpenAI', api_key, logger)

                # Create the Azure OpenAI client first
                azure_client = AsyncAzureOpenAI(
                    api_key=api_key,
                    azure_endpoint=azure_config.api_url,
                    api_version=azure_config.api_version,
                    azure_deployment=azure_config.deployment_name,
                    azure_ad_token_provider=azure_ad_token_provider,
                )

                # Then create the LLMConfig
                from graphiti_core.llm_client.config import LLMConfig as CoreLLMConfig

                llm_config = CoreLLMConfig(
                    api_key=api_key,
                    base_url=azure_config.api_url,
                    model=config.model,
                    temperature=config.temperature,
                    max_tokens=config.max_tokens,
                )

                return AzureOpenAILLMClient(
                    azure_client=azure_client,
                    config=llm_config,
                    max_tokens=config.max_tokens,
                )

            case 'anthropic':
                if not HAS_ANTHROPIC:
                    raise ValueError(
                        'Anthropic client not available in current graphiti-core version'
                    )
                if not config.providers.anthropic:
                    raise ValueError('Anthropic provider configuration not found')

                api_key = config.providers.anthropic.api_key
                _validate_api_key('Anthropic', api_key, logger)

                llm_config = GraphitiLLMConfig(
                    api_key=api_key,
                    model=config.model,
                    temperature=config.temperature,
                    max_tokens=config.max_tokens,
                )
                return AnthropicClient(config=llm_config)

            case 'gemini':
                if not HAS_GEMINI:
                    raise ValueError('Gemini client not available in current graphiti-core version')
                if not config.providers.gemini:
                    raise ValueError('Gemini provider configuration not found')

                api_key = config.providers.gemini.api_key
                _validate_api_key('Gemini', api_key, logger)

                llm_config = GraphitiLLMConfig(
                    api_key=api_key,
                    model=config.model,
                    temperature=config.temperature,
                    max_tokens=config.max_tokens,
                )
                return GeminiClient(config=llm_config)

            case 'groq':
                if not HAS_GROQ:
                    raise ValueError('Groq client not available in current graphiti-core version')
                if not config.providers.groq:
                    raise ValueError('Groq provider configuration not found')

                api_key = config.providers.groq.api_key
                _validate_api_key('Groq', api_key, logger)

                # Groq is OpenAI-compatible; its api_url is passed as base_url.
                llm_config = GraphitiLLMConfig(
                    api_key=api_key,
                    base_url=config.providers.groq.api_url,
                    model=config.model,
                    temperature=config.temperature,
                    max_tokens=config.max_tokens,
                )
                return GroqClient(config=llm_config)

            case _:
                raise ValueError(f'Unsupported LLM provider: {provider}')
251 |
252 |
class EmbedderFactory:
    """Factory for creating Embedder clients based on configuration."""

    @staticmethod
    def create(config: EmbedderConfig) -> EmbedderClient:
        """Create an Embedder client based on the configured provider.

        Dispatches on ``config.provider`` (case-insensitive). Raises
        ValueError when the provider is unknown, its configuration section
        is missing, or the matching graphiti-core extra is not installed.
        """
        import logging

        logger = logging.getLogger(__name__)

        provider = config.provider.lower()

        match provider:
            case 'openai':
                if not config.providers.openai:
                    raise ValueError('OpenAI provider configuration not found')

                api_key = config.providers.openai.api_key
                # Raises ValueError on a missing/empty key.
                _validate_api_key('OpenAI Embedder', api_key, logger)

                from graphiti_core.embedder.openai import OpenAIEmbedderConfig

                embedder_config = OpenAIEmbedderConfig(
                    api_key=api_key,
                    embedding_model=config.model,
                )
                return OpenAIEmbedder(config=embedder_config)

            case 'azure_openai':
                if not HAS_AZURE_EMBEDDER:
                    raise ValueError(
                        'Azure OpenAI embedder not available in current graphiti-core version'
                    )
                if not config.providers.azure_openai:
                    raise ValueError('Azure OpenAI provider configuration not found')
                azure_config = config.providers.azure_openai

                if not azure_config.api_url:
                    raise ValueError('Azure OpenAI API URL is required')

                # Handle Azure AD authentication if enabled
                # With Azure AD, api_key stays None and a token provider is used instead.
                api_key: str | None = None
                azure_ad_token_provider = None
                if azure_config.use_azure_ad:
                    logger.info(
                        'Creating Azure OpenAI Embedder client with Azure AD authentication'
                    )
                    azure_ad_token_provider = create_azure_credential_token_provider()
                else:
                    api_key = azure_config.api_key
                    _validate_api_key('Azure OpenAI Embedder', api_key, logger)

                # Create the Azure OpenAI client first
                azure_client = AsyncAzureOpenAI(
                    api_key=api_key,
                    azure_endpoint=azure_config.api_url,
                    api_version=azure_config.api_version,
                    azure_deployment=azure_config.deployment_name,
                    azure_ad_token_provider=azure_ad_token_provider,
                )

                return AzureOpenAIEmbedderClient(
                    azure_client=azure_client,
                    model=config.model or 'text-embedding-3-small',
                )

            case 'gemini':
                if not HAS_GEMINI_EMBEDDER:
                    raise ValueError(
                        'Gemini embedder not available in current graphiti-core version'
                    )
                if not config.providers.gemini:
                    raise ValueError('Gemini provider configuration not found')

                api_key = config.providers.gemini.api_key
                _validate_api_key('Gemini Embedder', api_key, logger)

                from graphiti_core.embedder.gemini import GeminiEmbedderConfig

                gemini_config = GeminiEmbedderConfig(
                    api_key=api_key,
                    embedding_model=config.model or 'models/text-embedding-004',
                    embedding_dim=config.dimensions or 768,
                )
                return GeminiEmbedder(config=gemini_config)

            case 'voyage':
                if not HAS_VOYAGE_EMBEDDER:
                    raise ValueError(
                        'Voyage embedder not available in current graphiti-core version'
                    )
                if not config.providers.voyage:
                    raise ValueError('Voyage provider configuration not found')

                api_key = config.providers.voyage.api_key
                _validate_api_key('Voyage Embedder', api_key, logger)

                from graphiti_core.embedder.voyage import VoyageAIEmbedderConfig

                voyage_config = VoyageAIEmbedderConfig(
                    api_key=api_key,
                    embedding_model=config.model or 'voyage-3',
                    embedding_dim=config.dimensions or 1024,
                )
                return VoyageAIEmbedder(config=voyage_config)

            case _:
                raise ValueError(f'Unsupported Embedder provider: {provider}')
361 |
362 |
363 | class DatabaseDriverFactory:
364 | """Factory for creating Database drivers based on configuration.
365 |
366 | Note: This returns configuration dictionaries that can be passed to Graphiti(),
367 | not driver instances directly, as the drivers require complex initialization.
368 | """
369 |
370 | @staticmethod
371 | def create_config(config: DatabaseConfig) -> dict:
372 | """Create database configuration dictionary based on the configured provider."""
373 | provider = config.provider.lower()
374 |
375 | match provider:
376 | case 'neo4j':
377 | # Use Neo4j config if provided, otherwise use defaults
378 | if config.providers.neo4j:
379 | neo4j_config = config.providers.neo4j
380 | else:
381 | # Create default Neo4j configuration
382 | from config.schema import Neo4jProviderConfig
383 |
384 | neo4j_config = Neo4jProviderConfig()
385 |
386 | # Check for environment variable overrides (for CI/CD compatibility)
387 | import os
388 |
389 | uri = os.environ.get('NEO4J_URI', neo4j_config.uri)
390 | username = os.environ.get('NEO4J_USER', neo4j_config.username)
391 | password = os.environ.get('NEO4J_PASSWORD', neo4j_config.password)
392 |
393 | return {
394 | 'uri': uri,
395 | 'user': username,
396 | 'password': password,
397 | # Note: database and use_parallel_runtime would need to be passed
398 | # to the driver after initialization if supported
399 | }
400 |
401 | case 'falkordb':
402 | if not HAS_FALKOR:
403 | raise ValueError(
404 | 'FalkorDB driver not available in current graphiti-core version'
405 | )
406 |
407 | # Use FalkorDB config if provided, otherwise use defaults
408 | if config.providers.falkordb:
409 | falkor_config = config.providers.falkordb
410 | else:
411 | # Create default FalkorDB configuration
412 | from config.schema import FalkorDBProviderConfig
413 |
414 | falkor_config = FalkorDBProviderConfig()
415 |
416 | # Check for environment variable overrides (for CI/CD compatibility)
417 | import os
418 | from urllib.parse import urlparse
419 |
420 | uri = os.environ.get('FALKORDB_URI', falkor_config.uri)
421 | password = os.environ.get('FALKORDB_PASSWORD', falkor_config.password)
422 |
423 | # Parse the URI to extract host and port
424 | parsed = urlparse(uri)
425 | host = parsed.hostname or 'localhost'
426 | port = parsed.port or 6379
427 |
428 | return {
429 | 'driver': 'falkordb',
430 | 'host': host,
431 | 'port': port,
432 | 'password': password,
433 | 'database': falkor_config.database,
434 | }
435 |
436 | case _:
437 | raise ValueError(f'Unsupported Database provider: {provider}')
438 |
```
--------------------------------------------------------------------------------
/graphiti_core/llm_client/anthropic_client.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import json
18 | import logging
19 | import os
20 | import typing
21 | from json import JSONDecodeError
22 | from typing import TYPE_CHECKING, Literal
23 |
24 | from pydantic import BaseModel, ValidationError
25 |
26 | from ..prompts.models import Message
27 | from .client import LLMClient
28 | from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
29 | from .errors import RateLimitError, RefusalError
30 |
31 | if TYPE_CHECKING:
32 | import anthropic
33 | from anthropic import AsyncAnthropic
34 | from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
35 | else:
36 | try:
37 | import anthropic
38 | from anthropic import AsyncAnthropic
39 | from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
40 | except ImportError:
41 | raise ImportError(
42 | 'anthropic is required for AnthropicClient. '
43 | 'Install it with: pip install graphiti-core[anthropic]'
44 | ) from None
45 |
46 |
47 | logger = logging.getLogger(__name__)
48 |
# Model identifiers accepted by this client. Used for static type checking of
# the `model` attribute; not enforced at runtime.
AnthropicModel = Literal[
    'claude-sonnet-4-5-latest',
    'claude-sonnet-4-5-20250929',
    'claude-haiku-4-5-latest',
    'claude-3-7-sonnet-latest',
    'claude-3-7-sonnet-20250219',
    'claude-3-5-haiku-latest',
    'claude-3-5-haiku-20241022',
    'claude-3-5-sonnet-latest',
    'claude-3-5-sonnet-20241022',
    'claude-3-5-sonnet-20240620',
    'claude-3-opus-latest',
    'claude-3-opus-20240229',
    'claude-3-sonnet-20240229',
    'claude-3-haiku-20240307',
    'claude-2.1',
    'claude-2.0',
]

# Model used when LLMConfig does not specify one.
DEFAULT_MODEL: AnthropicModel = 'claude-haiku-4-5-latest'

# Maximum output tokens for different Anthropic models
# Based on official Anthropic documentation (as of 2025)
# Note: These represent standard limits without beta headers.
# Some models support higher limits with additional configuration (e.g., Claude 3.7 supports
# 128K with 'anthropic-beta: output-128k-2025-02-19' header, but this is not currently implemented).
ANTHROPIC_MODEL_MAX_TOKENS = {
    # Claude 4.5 models - 64K tokens
    'claude-sonnet-4-5-latest': 65536,
    'claude-sonnet-4-5-20250929': 65536,
    'claude-haiku-4-5-latest': 65536,
    # Claude 3.7 models - standard 64K tokens
    'claude-3-7-sonnet-latest': 65536,
    'claude-3-7-sonnet-20250219': 65536,
    # Claude 3.5 models
    'claude-3-5-haiku-latest': 8192,
    'claude-3-5-haiku-20241022': 8192,
    'claude-3-5-sonnet-latest': 8192,
    'claude-3-5-sonnet-20241022': 8192,
    'claude-3-5-sonnet-20240620': 8192,
    # Claude 3 models - 4K tokens
    'claude-3-opus-latest': 4096,
    'claude-3-opus-20240229': 4096,
    'claude-3-sonnet-20240229': 4096,
    'claude-3-haiku-20240307': 4096,
    # Claude 2 models - 4K tokens
    'claude-2.1': 4096,
    'claude-2.0': 4096,
}

# Default max tokens for models not in the mapping
DEFAULT_ANTHROPIC_MAX_TOKENS = 8192
101 |
102 |
class AnthropicClient(LLMClient):
    """
    A client for the Anthropic LLM.

    Args:
        config: A configuration object for the LLM.
        cache: Whether to cache the LLM responses.
        client: An optional client instance to use.
        max_tokens: The maximum number of tokens to generate.

    Methods:
        generate_response: Generate a response from the LLM.

    Notes:
        - If a LLMConfig is not provided, api_key will be pulled from the ANTHROPIC_API_KEY environment
          variable, and all default values will be used for the LLMConfig.

    """

    model: AnthropicModel

    def __init__(
        self,
        config: LLMConfig | None = None,
        cache: bool = False,
        client: AsyncAnthropic | None = None,
        max_tokens: int = DEFAULT_MAX_TOKENS,
    ) -> None:
        if config is None:
            config = LLMConfig()
            config.api_key = os.getenv('ANTHROPIC_API_KEY')
            config.max_tokens = max_tokens

        if config.model is None:
            config.model = DEFAULT_MODEL

        super().__init__(config, cache)
        # Explicitly set the instance model to the config model to prevent type checking errors
        self.model = typing.cast(AnthropicModel, config.model)

        if not client:
            self.client = AsyncAnthropic(
                api_key=config.api_key,
                max_retries=1,
            )
        else:
            self.client = client

    def _extract_json_from_text(self, text: str) -> dict[str, typing.Any]:
        """Extract JSON from text content.

        A helper method to extract JSON from text content, used when tool use fails or
        no response_model is provided.

        Args:
            text: The text to extract JSON from

        Returns:
            Extracted JSON as a dictionary

        Raises:
            ValueError: If JSON cannot be extracted or parsed
        """
        try:
            # Take the widest brace-delimited span; assumes a single top-level object.
            json_start = text.find('{')
            json_end = text.rfind('}') + 1
            if json_start >= 0 and json_end > json_start:
                json_str = text[json_start:json_end]
                return json.loads(json_str)
            else:
                raise ValueError(f'Could not extract JSON from model response: {text}')
        except (JSONDecodeError, ValueError) as e:
            raise ValueError(f'Could not extract JSON from model response: {text}') from e

    def _create_tool(
        self, response_model: type[BaseModel] | None = None
    ) -> tuple[list[ToolUnionParam], ToolChoiceParam]:
        """
        Create a tool definition based on the response_model if provided, or a generic JSON tool if not.

        Args:
            response_model: Optional Pydantic model to use for structured output.

        Returns:
            A list containing a single tool definition and a tool_choice forcing its use.
        """
        if response_model is not None:
            # Use the response_model to define the tool
            model_schema = response_model.model_json_schema()
            tool_name = response_model.__name__
            description = model_schema.get('description', f'Extract {tool_name} information')
        else:
            # Create a generic JSON output tool
            tool_name = 'generic_json_output'
            description = 'Output data in JSON format'
            model_schema = {
                'type': 'object',
                'additionalProperties': True,
                'description': 'Any JSON object containing the requested information',
            }

        tool = {
            'name': tool_name,
            'description': description,
            'input_schema': model_schema,
        }
        tool_list = [tool]
        tool_list_cast = typing.cast(list[ToolUnionParam], tool_list)
        # Force the model to call this specific tool so output is structured.
        tool_choice = {'type': 'tool', 'name': tool_name}
        tool_choice_cast = typing.cast(ToolChoiceParam, tool_choice)
        return tool_list_cast, tool_choice_cast

    def _get_max_tokens_for_model(self, model: str) -> int:
        """Get the maximum output tokens for a specific Anthropic model.

        Args:
            model: The model name to look up

        Returns:
            int: The maximum output tokens for the model
        """
        return ANTHROPIC_MODEL_MAX_TOKENS.get(model, DEFAULT_ANTHROPIC_MAX_TOKENS)

    def _resolve_max_tokens(self, requested_max_tokens: int | None, model: str) -> int:
        """
        Resolve the maximum output tokens to use based on precedence rules.

        Precedence order (highest to lowest):
        1. Explicit max_tokens parameter passed to generate_response()
        2. Instance max_tokens set during client initialization
        3. Model-specific maximum tokens from ANTHROPIC_MODEL_MAX_TOKENS mapping
        4. DEFAULT_ANTHROPIC_MAX_TOKENS as final fallback

        Args:
            requested_max_tokens: The max_tokens parameter passed to generate_response()
            model: The model name to look up model-specific limits

        Returns:
            int: The resolved maximum tokens to use
        """
        # 1. Use explicit parameter if provided
        if requested_max_tokens is not None:
            return requested_max_tokens

        # 2. Use instance max_tokens if set during initialization
        if self.max_tokens is not None:
            return self.max_tokens

        # 3. Use model-specific maximum or return DEFAULT_ANTHROPIC_MAX_TOKENS
        return self._get_max_tokens_for_model(model)

    async def _generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int | None = None,
        model_size: ModelSize = ModelSize.medium,
    ) -> dict[str, typing.Any]:
        """
        Generate a response from the Anthropic LLM using tool-based approach for all requests.

        Args:
            messages: List of message objects to send to the LLM. The first message
                is used as the system prompt; the rest are passed as conversation turns.
            response_model: Optional Pydantic model to use for structured output.
            max_tokens: Maximum number of tokens to generate.

        Returns:
            Dictionary containing the structured response from the LLM.

        Raises:
            RateLimitError: If the rate limit is exceeded.
            RefusalError: If the LLM refuses to respond.
            Exception: If an error occurs during the generation process.
        """
        system_message = messages[0]
        user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]]
        user_messages_cast = typing.cast(list[MessageParam], user_messages)

        # Resolve max_tokens dynamically based on the model's capabilities
        # This allows different models to use their full output capacity
        max_creation_tokens: int = self._resolve_max_tokens(max_tokens, self.model)

        try:
            # Create the appropriate tool based on whether response_model is provided
            tools, tool_choice = self._create_tool(response_model)
            result = await self.client.messages.create(
                system=system_message.content,
                max_tokens=max_creation_tokens,
                temperature=self.temperature,
                messages=user_messages_cast,
                model=self.model,
                tools=tools,
                tool_choice=tool_choice,
            )

            # Extract the tool output from the response
            for content_item in result.content:
                if content_item.type == 'tool_use':
                    if isinstance(content_item.input, dict):
                        tool_args: dict[str, typing.Any] = content_item.input
                    else:
                        tool_args = json.loads(str(content_item.input))
                    return tool_args

            # If we didn't get a proper tool_use response, try to extract from text.
            # BUG FIX: scan *all* content blocks before giving up. The previous
            # version raised as soon as it hit a non-text block, which could skip
            # a text block appearing later in result.content.
            for content_item in result.content:
                if content_item.type == 'text':
                    return self._extract_json_from_text(content_item.text)

            # If we get here, we couldn't parse a structured response
            raise ValueError(
                f'Could not extract structured data from model response: {result.content}'
            )

        except anthropic.RateLimitError as e:
            raise RateLimitError(f'Rate limit exceeded. Please try again later. Error: {e}') from e
        except anthropic.APIError as e:
            # Special case for content policy violations. We convert these to RefusalError
            # to bypass the retry mechanism, as retrying policy-violating content will always fail.
            # This avoids wasting API calls and provides more specific error messaging to the user.
            if 'refused to respond' in str(e).lower():
                raise RefusalError(str(e)) from e
            raise e
        except Exception as e:
            raise e

    async def generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int | None = None,
        model_size: ModelSize = ModelSize.medium,
        group_id: str | None = None,
        prompt_name: str | None = None,
    ) -> dict[str, typing.Any]:
        """
        Generate a response from the LLM.

        Args:
            messages: List of message objects to send to the LLM.
            response_model: Optional Pydantic model to use for structured output.
            max_tokens: Maximum number of tokens to generate.

        Returns:
            Dictionary containing the structured response from the LLM.

        Raises:
            RateLimitError: If the rate limit is exceeded.
            RefusalError: If the LLM refuses to respond.
            Exception: If an error occurs during the generation process.
        """
        if max_tokens is None:
            max_tokens = self.max_tokens

        # Wrap entire operation in tracing span
        with self.tracer.start_span('llm.generate') as span:
            attributes = {
                'llm.provider': 'anthropic',
                'model.size': model_size.value,
                'max_tokens': max_tokens,
            }
            if prompt_name:
                attributes['prompt.name'] = prompt_name
            span.add_attributes(attributes)

            retry_count = 0
            max_retries = 2
            last_error: Exception | None = None

            while retry_count <= max_retries:
                try:
                    response = await self._generate_response(
                        messages, response_model, max_tokens, model_size
                    )

                    # If we have a response_model, attempt to validate the response
                    if response_model is not None:
                        # Validate the response against the response_model
                        model_instance = response_model(**response)
                        return model_instance.model_dump()

                    # If no validation needed, return the response
                    return response

                except (RateLimitError, RefusalError) as e:
                    # These errors should not trigger retries.
                    # BUG FIX: record the *current* exception on the span; the
                    # previous version logged str(last_error), which is 'None'
                    # when the first attempt raises.
                    span.set_status('error', str(e))
                    raise
                except Exception as e:
                    last_error = e

                    if retry_count >= max_retries:
                        if isinstance(e, ValidationError):
                            logger.error(
                                f'Validation error after {retry_count}/{max_retries} attempts: {e}'
                            )
                        else:
                            logger.error(f'Max retries ({max_retries}) exceeded. Last error: {e}')
                        span.set_status('error', str(e))
                        span.record_exception(e)
                        raise e

                    if isinstance(e, ValidationError):
                        response_model_cast = typing.cast(type[BaseModel], response_model)
                        error_context = f'The previous response was invalid. Please provide a valid {response_model_cast.__name__} object. Error: {e}'
                    else:
                        error_context = (
                            f'The previous response attempt was invalid. '
                            f'Error type: {e.__class__.__name__}. '
                            f'Error details: {str(e)}. '
                            f'Please try again with a valid response.'
                        )

                    # Common retry logic
                    # NOTE(review): this appends to the caller's `messages` list in
                    # place — confirm callers do not reuse the list across calls.
                    retry_count += 1
                    messages.append(Message(role='user', content=error_context))
                    logger.warning(
                        f'Retrying after error (attempt {retry_count}/{max_retries}): {e}'
                    )

            # If we somehow get here, raise the last error
            span.set_status('error', str(last_error))
            raise last_error or Exception('Max retries exceeded with no specific error')
430 |
```