#
tokens: 48680/50000 11/234 files (page 6/12)
lines: on (toggle) GitHub
raw markdown copy reset
This is page 6 of 12. Use http://codebase.md/getzep/graphiti?lines=true&page={x} (substituting the desired page number for {x}) to view the full context.

# Directory Structure

```
├── .env.example
├── .github
│   ├── dependabot.yml
│   ├── ISSUE_TEMPLATE
│   │   └── bug_report.md
│   ├── pull_request_template.md
│   ├── secret_scanning.yml
│   └── workflows
│       ├── ai-moderator.yml
│       ├── cla.yml
│       ├── claude-code-review-manual.yml
│       ├── claude-code-review.yml
│       ├── claude.yml
│       ├── codeql.yml
│       ├── daily_issue_maintenance.yml
│       ├── issue-triage.yml
│       ├── lint.yml
│       ├── release-graphiti-core.yml
│       ├── release-mcp-server.yml
│       ├── release-server-container.yml
│       ├── typecheck.yml
│       └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│   ├── azure-openai
│   │   ├── .env.example
│   │   ├── azure_openai_neo4j.py
│   │   └── README.md
│   ├── data
│   │   └── manybirds_products.json
│   ├── ecommerce
│   │   ├── runner.ipynb
│   │   └── runner.py
│   ├── langgraph-agent
│   │   ├── agent.ipynb
│   │   └── tinybirds-jess.png
│   ├── opentelemetry
│   │   ├── .env.example
│   │   ├── otel_stdout_example.py
│   │   ├── pyproject.toml
│   │   ├── README.md
│   │   └── uv.lock
│   ├── podcast
│   │   ├── podcast_runner.py
│   │   ├── podcast_transcript.txt
│   │   └── transcript_parser.py
│   ├── quickstart
│   │   ├── quickstart_falkordb.py
│   │   ├── quickstart_neo4j.py
│   │   ├── quickstart_neptune.py
│   │   ├── README.md
│   │   └── requirements.txt
│   └── wizard_of_oz
│       ├── parser.py
│       ├── runner.py
│       └── woo.txt
├── graphiti_core
│   ├── __init__.py
│   ├── cross_encoder
│   │   ├── __init__.py
│   │   ├── bge_reranker_client.py
│   │   ├── client.py
│   │   ├── gemini_reranker_client.py
│   │   └── openai_reranker_client.py
│   ├── decorators.py
│   ├── driver
│   │   ├── __init__.py
│   │   ├── driver.py
│   │   ├── falkordb_driver.py
│   │   ├── graph_operations
│   │   │   └── graph_operations.py
│   │   ├── kuzu_driver.py
│   │   ├── neo4j_driver.py
│   │   ├── neptune_driver.py
│   │   └── search_interface
│   │       └── search_interface.py
│   ├── edges.py
│   ├── embedder
│   │   ├── __init__.py
│   │   ├── azure_openai.py
│   │   ├── client.py
│   │   ├── gemini.py
│   │   ├── openai.py
│   │   └── voyage.py
│   ├── errors.py
│   ├── graph_queries.py
│   ├── graphiti_types.py
│   ├── graphiti.py
│   ├── helpers.py
│   ├── llm_client
│   │   ├── __init__.py
│   │   ├── anthropic_client.py
│   │   ├── azure_openai_client.py
│   │   ├── client.py
│   │   ├── config.py
│   │   ├── errors.py
│   │   ├── gemini_client.py
│   │   ├── groq_client.py
│   │   ├── openai_base_client.py
│   │   ├── openai_client.py
│   │   ├── openai_generic_client.py
│   │   └── utils.py
│   ├── migrations
│   │   └── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── edges
│   │   │   ├── __init__.py
│   │   │   └── edge_db_queries.py
│   │   └── nodes
│   │       ├── __init__.py
│   │       └── node_db_queries.py
│   ├── nodes.py
│   ├── prompts
│   │   ├── __init__.py
│   │   ├── dedupe_edges.py
│   │   ├── dedupe_nodes.py
│   │   ├── eval.py
│   │   ├── extract_edge_dates.py
│   │   ├── extract_edges.py
│   │   ├── extract_nodes.py
│   │   ├── invalidate_edges.py
│   │   ├── lib.py
│   │   ├── models.py
│   │   ├── prompt_helpers.py
│   │   ├── snippets.py
│   │   └── summarize_nodes.py
│   ├── py.typed
│   ├── search
│   │   ├── __init__.py
│   │   ├── search_config_recipes.py
│   │   ├── search_config.py
│   │   ├── search_filters.py
│   │   ├── search_helpers.py
│   │   ├── search_utils.py
│   │   └── search.py
│   ├── telemetry
│   │   ├── __init__.py
│   │   └── telemetry.py
│   ├── tracer.py
│   └── utils
│       ├── __init__.py
│       ├── bulk_utils.py
│       ├── datetime_utils.py
│       ├── maintenance
│       │   ├── __init__.py
│       │   ├── community_operations.py
│       │   ├── dedup_helpers.py
│       │   ├── edge_operations.py
│       │   ├── graph_data_operations.py
│       │   ├── node_operations.py
│       │   └── temporal_operations.py
│       ├── ontology_utils
│       │   └── entity_types_utils.py
│       └── text_utils.py
├── images
│   ├── arxiv-screenshot.png
│   ├── graphiti-graph-intro.gif
│   ├── graphiti-intro-slides-stock-2.gif
│   └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│   ├── .env.example
│   ├── .python-version
│   ├── config
│   │   ├── config-docker-falkordb-combined.yaml
│   │   ├── config-docker-falkordb.yaml
│   │   ├── config-docker-neo4j.yaml
│   │   ├── config.yaml
│   │   └── mcp_config_stdio_example.json
│   ├── docker
│   │   ├── build-standalone.sh
│   │   ├── build-with-version.sh
│   │   ├── docker-compose-falkordb.yml
│   │   ├── docker-compose-neo4j.yml
│   │   ├── docker-compose.yml
│   │   ├── Dockerfile
│   │   ├── Dockerfile.standalone
│   │   ├── github-actions-example.yml
│   │   ├── README-falkordb-combined.md
│   │   └── README.md
│   ├── docs
│   │   └── cursor_rules.md
│   ├── main.py
│   ├── pyproject.toml
│   ├── pytest.ini
│   ├── README.md
│   ├── src
│   │   ├── __init__.py
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   └── schema.py
│   │   ├── graphiti_mcp_server.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── entity_types.py
│   │   │   └── response_types.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   ├── factories.py
│   │   │   └── queue_service.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── formatting.py
│   │       └── utils.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── conftest.py
│   │   ├── pytest.ini
│   │   ├── README.md
│   │   ├── run_tests.py
│   │   ├── test_async_operations.py
│   │   ├── test_comprehensive_integration.py
│   │   ├── test_configuration.py
│   │   ├── test_falkordb_integration.py
│   │   ├── test_fixtures.py
│   │   ├── test_http_integration.py
│   │   ├── test_integration.py
│   │   ├── test_mcp_integration.py
│   │   ├── test_mcp_transports.py
│   │   ├── test_stdio_simple.py
│   │   └── test_stress_load.py
│   └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│   ├── .env.example
│   ├── graph_service
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── dto
│   │   │   ├── __init__.py
│   │   │   ├── common.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   ├── main.py
│   │   ├── routers
│   │   │   ├── __init__.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   └── zep_graphiti.py
│   ├── Makefile
│   ├── pyproject.toml
│   ├── README.md
│   └── uv.lock
├── signatures
│   └── version1
│       └── cla.json
├── tests
│   ├── cross_encoder
│   │   ├── test_bge_reranker_client_int.py
│   │   └── test_gemini_reranker_client.py
│   ├── driver
│   │   ├── __init__.py
│   │   └── test_falkordb_driver.py
│   ├── embedder
│   │   ├── embedder_fixtures.py
│   │   ├── test_gemini.py
│   │   ├── test_openai.py
│   │   └── test_voyage.py
│   ├── evals
│   │   ├── data
│   │   │   └── longmemeval_data
│   │   │       ├── longmemeval_oracle.json
│   │   │       └── README.md
│   │   ├── eval_cli.py
│   │   ├── eval_e2e_graph_building.py
│   │   ├── pytest.ini
│   │   └── utils.py
│   ├── helpers_test.py
│   ├── llm_client
│   │   ├── test_anthropic_client_int.py
│   │   ├── test_anthropic_client.py
│   │   ├── test_azure_openai_client.py
│   │   ├── test_client.py
│   │   ├── test_errors.py
│   │   └── test_gemini_client.py
│   ├── test_edge_int.py
│   ├── test_entity_exclusion_int.py
│   ├── test_graphiti_int.py
│   ├── test_graphiti_mock.py
│   ├── test_node_int.py
│   ├── test_text_utils.py
│   └── utils
│       ├── maintenance
│       │   ├── test_bulk_utils.py
│       │   ├── test_edge_operations.py
│       │   ├── test_node_operations.py
│       │   └── test_temporal_operations_int.py
│       └── search
│           └── search_utils_test.py
├── uv.lock
└── Zep-CLA.md
```

# Files

--------------------------------------------------------------------------------
/signatures/version1/cla.json:
--------------------------------------------------------------------------------

```json
  1 | {
  2 |   "signedContributors": [
  3 |     {
  4 |       "name": "colombod",
  5 |       "id": 375556,
  6 |       "comment_id": 2761979440,
  7 |       "created_at": "2025-03-28T17:21:29Z",
  8 |       "repoId": 840056306,
  9 |       "pullRequestNo": 310
 10 |     },
 11 |     {
 12 |       "name": "evanmschultz",
 13 |       "id": 3806601,
 14 |       "comment_id": 2813673237,
 15 |       "created_at": "2025-04-17T17:56:24Z",
 16 |       "repoId": 840056306,
 17 |       "pullRequestNo": 372
 18 |     },
 19 |     {
 20 |       "name": "soichisumi",
 21 |       "id": 30210641,
 22 |       "comment_id": 2818469528,
 23 |       "created_at": "2025-04-21T14:02:11Z",
 24 |       "repoId": 840056306,
 25 |       "pullRequestNo": 382
 26 |     },
 27 |     {
 28 |       "name": "drumnation",
 29 |       "id": 18486434,
 30 |       "comment_id": 2822330188,
 31 |       "created_at": "2025-04-22T19:51:09Z",
 32 |       "repoId": 840056306,
 33 |       "pullRequestNo": 389
 34 |     },
 35 |     {
 36 |       "name": "jackaldenryan",
 37 |       "id": 61809814,
 38 |       "comment_id": 2845356793,
 39 |       "created_at": "2025-05-01T17:51:11Z",
 40 |       "repoId": 840056306,
 41 |       "pullRequestNo": 429
 42 |     },
 43 |     {
 44 |       "name": "t41372",
 45 |       "id": 36402030,
 46 |       "comment_id": 2849035400,
 47 |       "created_at": "2025-05-04T06:24:37Z",
 48 |       "repoId": 840056306,
 49 |       "pullRequestNo": 438
 50 |     },
 51 |     {
 52 |       "name": "markalosey",
 53 |       "id": 1949914,
 54 |       "comment_id": 2878173826,
 55 |       "created_at": "2025-05-13T23:27:16Z",
 56 |       "repoId": 840056306,
 57 |       "pullRequestNo": 486
 58 |     },
 59 |     {
 60 |       "name": "adamkatav",
 61 |       "id": 13109136,
 62 |       "comment_id": 2887184706,
 63 |       "created_at": "2025-05-16T16:29:22Z",
 64 |       "repoId": 840056306,
 65 |       "pullRequestNo": 493
 66 |     },
 67 |     {
 68 |       "name": "realugbun",
 69 |       "id": 74101927,
 70 |       "comment_id": 2899731784,
 71 |       "created_at": "2025-05-22T02:36:44Z",
 72 |       "repoId": 840056306,
 73 |       "pullRequestNo": 513
 74 |     },
 75 |     {
 76 |       "name": "dudizimber",
 77 |       "id": 16744955,
 78 |       "comment_id": 2912211548,
 79 |       "created_at": "2025-05-27T11:45:57Z",
 80 |       "repoId": 840056306,
 81 |       "pullRequestNo": 525
 82 |     },
 83 |     {
 84 |       "name": "galshubeli",
 85 |       "id": 124919062,
 86 |       "comment_id": 2912289100,
 87 |       "created_at": "2025-05-27T12:15:03Z",
 88 |       "repoId": 840056306,
 89 |       "pullRequestNo": 525
 90 |     },
 91 |     {
 92 |       "name": "TheEpTic",
 93 |       "id": 326774,
 94 |       "comment_id": 2917970901,
 95 |       "created_at": "2025-05-29T01:26:54Z",
 96 |       "repoId": 840056306,
 97 |       "pullRequestNo": 541
 98 |     },
 99 |     {
100 |       "name": "PrettyWood",
101 |       "id": 18406791,
102 |       "comment_id": 2938495182,
103 |       "created_at": "2025-06-04T04:44:59Z",
104 |       "repoId": 840056306,
105 |       "pullRequestNo": 558
106 |     },
107 |     {
108 |       "name": "denyska",
109 |       "id": 1242726,
110 |       "comment_id": 2957480685,
111 |       "created_at": "2025-06-10T02:08:05Z",
112 |       "repoId": 840056306,
113 |       "pullRequestNo": 574
114 |     },
115 |     {
116 |       "name": "LongPML",
117 |       "id": 59755436,
118 |       "comment_id": 2965391879,
119 |       "created_at": "2025-06-12T07:10:01Z",
120 |       "repoId": 840056306,
121 |       "pullRequestNo": 579
122 |     },
123 |     {
124 |       "name": "karn09",
125 |       "id": 3743119,
126 |       "comment_id": 2973492225,
127 |       "created_at": "2025-06-15T04:45:13Z",
128 |       "repoId": 840056306,
129 |       "pullRequestNo": 584
130 |     },
131 |     {
132 |       "name": "abab-dev",
133 |       "id": 146825408,
134 |       "comment_id": 2975719469,
135 |       "created_at": "2025-06-16T09:12:53Z",
136 |       "repoId": 840056306,
137 |       "pullRequestNo": 588
138 |     },
139 |     {
140 |       "name": "thorchh",
141 |       "id": 75025911,
142 |       "comment_id": 2982990164,
143 |       "created_at": "2025-06-18T07:19:38Z",
144 |       "repoId": 840056306,
145 |       "pullRequestNo": 601
146 |     },
147 |     {
148 |       "name": "robrichardson13",
149 |       "id": 9492530,
150 |       "comment_id": 2989798338,
151 |       "created_at": "2025-06-20T04:59:06Z",
152 |       "repoId": 840056306,
153 |       "pullRequestNo": 611
154 |     },
155 |     {
156 |       "name": "gkorland",
157 |       "id": 753206,
158 |       "comment_id": 2993690025,
159 |       "created_at": "2025-06-21T17:35:37Z",
160 |       "repoId": 840056306,
161 |       "pullRequestNo": 609
162 |     },
163 |     {
164 |       "name": "urmzd",
165 |       "id": 45431570,
166 |       "comment_id": 3027098935,
167 |       "created_at": "2025-07-02T09:16:46Z",
168 |       "repoId": 840056306,
169 |       "pullRequestNo": 661
170 |     },
171 |     {
172 |       "name": "jawwadfirdousi",
173 |       "id": 10913083,
174 |       "comment_id": 3027808026,
175 |       "created_at": "2025-07-02T13:02:22Z",
176 |       "repoId": 840056306,
177 |       "pullRequestNo": 663
178 |     },
179 |     {
180 |       "name": "jamesindeed",
181 |       "id": 60527576,
182 |       "comment_id": 3028293328,
183 |       "created_at": "2025-07-02T15:24:23Z",
184 |       "repoId": 840056306,
185 |       "pullRequestNo": 664
186 |     },
187 |     {
188 |       "name": "dev-mirzabicer",
189 |       "id": 90691873,
190 |       "comment_id": 3035836506,
191 |       "created_at": "2025-07-04T11:47:08Z",
192 |       "repoId": 840056306,
193 |       "pullRequestNo": 672
194 |     },
195 |     {
196 |       "name": "zeroasterisk",
197 |       "id": 23422,
198 |       "comment_id": 3040716245,
199 |       "created_at": "2025-07-06T03:41:19Z",
200 |       "repoId": 840056306,
201 |       "pullRequestNo": 679
202 |     },
203 |     {
204 |       "name": "charlesmcchan",
205 |       "id": 425857,
206 |       "comment_id": 3066732289,
207 |       "created_at": "2025-07-13T08:54:26Z",
208 |       "repoId": 840056306,
209 |       "pullRequestNo": 711
210 |     },
211 |     {
212 |       "name": "soraxas",
213 |       "id": 22362177,
214 |       "comment_id": 3084093750,
215 |       "created_at": "2025-07-17T13:33:25Z",
216 |       "repoId": 840056306,
217 |       "pullRequestNo": 741
218 |     },
219 |     {
220 |       "name": "sdht0",
221 |       "id": 867424,
222 |       "comment_id": 3092540466,
223 |       "created_at": "2025-07-19T19:52:21Z",
224 |       "repoId": 840056306,
225 |       "pullRequestNo": 748
226 |     },
227 |     {
228 |       "name": "Naseem77",
229 |       "id": 34807727,
230 |       "comment_id": 3093746709,
231 |       "created_at": "2025-07-20T07:07:33Z",
232 |       "repoId": 840056306,
233 |       "pullRequestNo": 742
234 |     },
235 |     {
236 |       "name": "kavenGw",
237 |       "id": 3193355,
238 |       "comment_id": 3100620568,
239 |       "created_at": "2025-07-22T02:58:50Z",
240 |       "repoId": 840056306,
241 |       "pullRequestNo": 750
242 |     },
243 |     {
244 |       "name": "paveljakov",
245 |       "id": 45147436,
246 |       "comment_id": 3113955940,
247 |       "created_at": "2025-07-24T15:39:36Z",
248 |       "repoId": 840056306,
249 |       "pullRequestNo": 764
250 |     },
251 |     {
252 |       "name": "gifflet",
253 |       "id": 33522742,
254 |       "comment_id": 3133869379,
255 |       "created_at": "2025-07-29T20:00:27Z",
256 |       "repoId": 840056306,
257 |       "pullRequestNo": 782
258 |     },
259 |     {
260 |       "name": "bechbd",
261 |       "id": 6898505,
262 |       "comment_id": 3140501814,
263 |       "created_at": "2025-07-31T15:58:08Z",
264 |       "repoId": 840056306,
265 |       "pullRequestNo": 793
266 |     },
267 |     {
268 |       "name": "hugo-son",
269 |       "id": 141999572,
270 |       "comment_id": 3155009405,
271 |       "created_at": "2025-08-05T12:27:09Z",
272 |       "repoId": 840056306,
273 |       "pullRequestNo": 805
274 |     },
275 |     {
276 |       "name": "mvanders",
277 |       "id": 758617,
278 |       "comment_id": 3160523661,
279 |       "created_at": "2025-08-06T14:56:21Z",
280 |       "repoId": 840056306,
281 |       "pullRequestNo": 808
282 |     },
283 |     {
284 |       "name": "v-khanna",
285 |       "id": 102773390,
286 |       "comment_id": 3162200130,
287 |       "created_at": "2025-08-07T02:23:09Z",
288 |       "repoId": 840056306,
289 |       "pullRequestNo": 812
290 |     },
291 |     {
292 |       "name": "vjeeva",
293 |       "id": 13189349,
294 |       "comment_id": 3165600173,
295 |       "created_at": "2025-08-07T20:24:08Z",
296 |       "repoId": 840056306,
297 |       "pullRequestNo": 814
298 |     },
299 |     {
300 |       "name": "liebertar",
301 |       "id": 99405438,
302 |       "comment_id": 3166905812,
303 |       "created_at": "2025-08-08T07:52:27Z",
304 |       "repoId": 840056306,
305 |       "pullRequestNo": 816
306 |     },
307 |     {
308 |       "name": "CaroLe-prw",
309 |       "id": 42695882,
310 |       "comment_id": 3187949734,
311 |       "created_at": "2025-08-14T10:29:25Z",
312 |       "repoId": 840056306,
313 |       "pullRequestNo": 833
314 |     },
315 |     {
316 |       "name": "Wizmann",
317 |       "id": 1270921,
318 |       "comment_id": 3196208374,
319 |       "created_at": "2025-08-18T11:09:35Z",
320 |       "repoId": 840056306,
321 |       "pullRequestNo": 842
322 |     },
323 |     {
324 |       "name": "liangyuanpeng",
325 |       "id": 28711504,
326 |       "comment_id": 3205841804,
327 |       "created_at": "2025-08-20T11:35:42Z",
328 |       "repoId": 840056306,
329 |       "pullRequestNo": 847
330 |     },
331 |     {
332 |       "name": "aktek-yazge",
333 |       "id": 218602044,
334 |       "comment_id": 3078757968,
335 |       "created_at": "2025-07-16T14:00:40Z",
336 |       "repoId": 840056306,
337 |       "pullRequestNo": 735
338 |     },
339 |     {
340 |       "name": "Shelvak",
341 |       "id": 873323,
342 |       "comment_id": 3243330690,
343 |       "created_at": "2025-09-01T22:26:32Z",
344 |       "repoId": 840056306,
345 |       "pullRequestNo": 885
346 |     },
347 |     {
348 |       "name": "maskshell",
349 |       "id": 5113279,
350 |       "comment_id": 3244187860,
351 |       "created_at": "2025-09-02T07:48:05Z",
352 |       "repoId": 840056306,
353 |       "pullRequestNo": 886
354 |     },
355 |     {
356 |       "name": "jeanlucthumm",
357 |       "id": 4934853,
358 |       "comment_id": 3255120747,
359 |       "created_at": "2025-09-04T18:49:57Z",
360 |       "repoId": 840056306,
361 |       "pullRequestNo": 892
362 |     },
363 |     {
364 |       "name": "Bit-urd",
365 |       "id": 43745133,
366 |       "comment_id": 3264006888,
367 |       "created_at": "2025-09-07T20:01:08Z",
368 |       "repoId": 840056306,
369 |       "pullRequestNo": 895
370 |     },
371 |     {
372 |       "name": "DavIvek",
373 |       "id": 88043717,
374 |       "comment_id": 3269895491,
375 |       "created_at": "2025-09-09T09:59:47Z",
376 |       "repoId": 840056306,
377 |       "pullRequestNo": 900
378 |     },
379 |     {
380 |       "name": "gsw945",
381 |       "id": 6281968,
382 |       "comment_id": 3270396586,
383 |       "created_at": "2025-09-09T12:05:27Z",
384 |       "repoId": 840056306,
385 |       "pullRequestNo": 901
386 |     },
387 |     {
388 |       "name": "luan122",
389 |       "id": 5606023,
390 |       "comment_id": 3287095238,
391 |       "created_at": "2025-09-12T23:14:21Z",
392 |       "repoId": 840056306,
393 |       "pullRequestNo": 908
394 |     },
395 |     {
396 |       "name": "Brandtweary",
397 |       "id": 7968557,
398 |       "comment_id": 3314191937,
399 |       "created_at": "2025-09-19T23:37:33Z",
400 |       "repoId": 840056306,
401 |       "pullRequestNo": 916
402 |     },
403 |     {
404 |       "name": "clsferguson",
405 |       "id": 48876201,
406 |       "comment_id": 3368715688,
407 |       "created_at": "2025-10-05T03:30:10Z",
408 |       "repoId": 840056306,
409 |       "pullRequestNo": 981
410 |     },
411 |     {
412 |       "name": "ngaiyuc",
413 |       "id": 69293565,
414 |       "comment_id": 3407383300,
415 |       "created_at": "2025-10-15T16:45:10Z",
416 |       "repoId": 840056306,
417 |       "pullRequestNo": 1005
418 |     },
419 |     {
420 |       "name": "0fism",
421 |       "id": 63762457,
422 |       "comment_id": 3407328042,
423 |       "created_at": "2025-10-15T16:29:33Z",
424 |       "repoId": 840056306,
425 |       "pullRequestNo": 1005
426 |     },
427 |     {
428 |       "name": "dontang97",
429 |       "id": 88384441,
430 |       "comment_id": 3431443627,
431 |       "created_at": "2025-10-22T09:52:01Z",
432 |       "repoId": 840056306,
433 |       "pullRequestNo": 1020
434 |     },
435 |     {
436 |       "name": "didier-durand",
437 |       "id": 2927957,
438 |       "comment_id": 3460571645,
439 |       "created_at": "2025-10-29T09:31:25Z",
440 |       "repoId": 840056306,
441 |       "pullRequestNo": 1028
442 |     },
443 |     {
444 |       "name": "anubhavgirdhar1",
445 |       "id": 85768253,
446 |       "comment_id": 3468525446,
447 |       "created_at": "2025-10-30T15:11:58Z",
448 |       "repoId": 840056306,
449 |       "pullRequestNo": 1035
450 |     },
451 |     {
452 |       "name": "Galleons2029",
453 |       "id": 88185941,
454 |       "comment_id": 3495884964,
455 |       "created_at": "2025-11-06T08:39:46Z",
456 |       "repoId": 840056306,
457 |       "pullRequestNo": 1053
458 |     },
459 |     {
460 |       "name": "supmo668",
461 |       "id": 28805779,
462 |       "comment_id": 3550309664,
463 |       "created_at": "2025-11-19T01:56:25Z",
464 |       "repoId": 840056306,
465 |       "pullRequestNo": 1072
466 |     },
467 |     {
468 |       "name": "donbr",
469 |       "id": 7340008,
470 |       "comment_id": 3568970102,
471 |       "created_at": "2025-11-24T05:19:42Z",
472 |       "repoId": 840056306,
473 |       "pullRequestNo": 1081
474 |     },
475 |     {
476 |       "name": "apetti1920",
477 |       "id": 4706645,
478 |       "comment_id": 3572726648,
479 |       "created_at": "2025-11-24T21:07:34Z",
480 |       "repoId": 840056306,
481 |       "pullRequestNo": 1084
482 |     }
483 |   ]
484 | }
```

--------------------------------------------------------------------------------
/tests/test_entity_exclusion_int.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | from datetime import datetime, timezone
 18 | 
 19 | import pytest
 20 | from pydantic import BaseModel, Field
 21 | 
 22 | from graphiti_core.graphiti import Graphiti
 23 | from graphiti_core.helpers import validate_excluded_entity_types
 24 | from tests.helpers_test import drivers, get_driver
 25 | 
 26 | pytestmark = pytest.mark.integration
 27 | pytest_plugins = ('pytest_asyncio',)
 28 | 
 29 | 
# Test entity type definitions
class Person(BaseModel):
    """A human person mentioned in the conversation."""

    # All attributes are optional: extraction may not recover every detail.
    # NOTE(review): the class docstring and Field descriptions appear to be
    # surfaced as the entity-type schema during extraction — confirm against
    # graphiti_core's entity-type handling before changing their wording.
    first_name: str | None = Field(None, description='First name of the person')
    last_name: str | None = Field(None, description='Last name of the person')
    occupation: str | None = Field(None, description='Job or profession of the person')
 37 | 
 38 | 
class Organization(BaseModel):
    """A company, institution, or organized group."""

    # Optional attributes; populated only when present in the episode text.
    organization_type: str | None = Field(
        None, description='Type of organization (company, NGO, etc.)'
    )
    industry: str | None = Field(
        None, description='Industry or sector the organization operates in'
    )
 48 | 
 49 | 
class Location(BaseModel):
    """A geographic location, place, or address."""

    # Optional attributes; populated only when present in the episode text.
    location_type: str | None = Field(
        None, description='Type of location (city, country, building, etc.)'
    )
    coordinates: str | None = Field(None, description='Geographic coordinates if available')
 57 | 
 58 | 
 59 | @pytest.mark.asyncio
 60 | @pytest.mark.parametrize(
 61 |     'driver',
 62 |     drivers,
 63 | )
 64 | async def test_exclude_default_entity_type(driver):
 65 |     """Test excluding the default 'Entity' type while keeping custom types."""
 66 |     graphiti = Graphiti(graph_driver=get_driver(driver))
 67 | 
 68 |     try:
 69 |         await graphiti.build_indices_and_constraints()
 70 | 
 71 |         # Define entity types but exclude the default 'Entity' type
 72 |         entity_types = {
 73 |             'Person': Person,
 74 |             'Organization': Organization,
 75 |         }
 76 | 
 77 |         # Add an episode that would normally create both Entity and custom type entities
 78 |         episode_content = (
 79 |             'John Smith works at Acme Corporation in New York. The weather is nice today.'
 80 |         )
 81 | 
 82 |         result = await graphiti.add_episode(
 83 |             name='Business Meeting',
 84 |             episode_body=episode_content,
 85 |             source_description='Meeting notes',
 86 |             reference_time=datetime.now(timezone.utc),
 87 |             entity_types=entity_types,
 88 |             excluded_entity_types=['Entity'],  # Exclude default type
 89 |             group_id='test_exclude_default',
 90 |         )
 91 | 
 92 |         # Verify that nodes were created (custom types should still work)
 93 |         assert result is not None
 94 | 
 95 |         # Search for nodes to verify only custom types were created
 96 |         search_results = await graphiti.search_(
 97 |             query='John Smith Acme Corporation', group_ids=['test_exclude_default']
 98 |         )
 99 | 
100 |         # Check that entities were created but with specific types, not default 'Entity'
101 |         found_nodes = search_results.nodes
102 |         for node in found_nodes:
103 |             assert 'Entity' in node.labels  # All nodes should have Entity label
104 |             # But they should also have specific type labels
105 |             assert any(label in ['Person', 'Organization'] for label in node.labels), (
106 |                 f'Node {node.name} should have a specific type label, got: {node.labels}'
107 |             )
108 | 
109 |         # Clean up
110 |         await _cleanup_test_nodes(graphiti, 'test_exclude_default')
111 | 
112 |     finally:
113 |         await graphiti.close()
114 | 
115 | 
@pytest.mark.asyncio
@pytest.mark.parametrize('driver', drivers)
async def test_exclude_specific_custom_types(driver):
    """Test excluding specific custom entity types while keeping others."""
    graphiti = Graphiti(graph_driver=get_driver(driver))

    try:
        await graphiti.build_indices_and_constraints()

        # Register three custom types; two of them are excluded below.
        entity_types = {
            'Person': Person,
            'Organization': Organization,
            'Location': Location,
        }

        # Content that would normally produce entities of all three types.
        episode_content = (
            'Sarah Johnson from Google visited the San Francisco office to discuss the new project.'
        )

        result = await graphiti.add_episode(
            name='Office Visit',
            episode_body=episode_content,
            source_description='Visit report',
            reference_time=datetime.now(timezone.utc),
            entity_types=entity_types,
            excluded_entity_types=['Organization', 'Location'],  # Exclude these types
            group_id='test_exclude_custom',
        )

        assert result is not None

        search_results = await graphiti.search_(
            query='Sarah Johnson Google San Francisco', group_ids=['test_exclude_custom']
        )

        found_nodes = search_results.nodes

        # No node may carry an excluded label; all keep the base Entity label.
        for node in found_nodes:
            assert 'Entity' in node.labels
            assert 'Organization' not in node.labels, (
                f'Found excluded Organization in node: {node.name}'
            )
            assert 'Location' not in node.labels, f'Found excluded Location in node: {node.name}'

        # The non-excluded Person type should still be extracted (Sarah Johnson).
        assert any('Person' in n.labels for n in found_nodes), (
            'Should have found at least one Person entity'
        )

        # Clean up
        await _cleanup_test_nodes(graphiti, 'test_exclude_custom')

    finally:
        await graphiti.close()
177 | 
178 | 
@pytest.mark.asyncio
@pytest.mark.parametrize('driver', drivers)
async def test_exclude_all_types(driver):
    """Test excluding all entity types (edge case)."""
    graphiti = Graphiti(graph_driver=get_driver(driver))

    try:
        await graphiti.build_indices_and_constraints()

        entity_types = {
            'Person': Person,
            'Organization': Organization,
        }

        # Exclude every type, including the default 'Entity'.
        result = await graphiti.add_episode(
            name='No Entities',
            episode_body='This text mentions John and Microsoft but no entities should be created.',
            source_description='Test content',
            reference_time=datetime.now(timezone.utc),
            entity_types=entity_types,
            excluded_entity_types=['Entity', 'Person', 'Organization'],  # Exclude everything
            group_id='test_exclude_all',
        )

        assert result is not None

        # With everything excluded, this episode should produce no entity nodes.
        search_results = await graphiti.search_(
            query='John Microsoft', group_ids=['test_exclude_all']
        )

        found_nodes = search_results.nodes
        assert not found_nodes, (
            f'Expected no entities, but found: {[n.name for n in found_nodes]}'
        )

        # Clean up
        await _cleanup_test_nodes(graphiti, 'test_exclude_all')

    finally:
        await graphiti.close()
225 | 
226 | 
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'driver',
    drivers,
)
async def test_exclude_no_types(driver):
    """Test normal behavior when no types are excluded (baseline test)."""
    graphiti = Graphiti(graph_driver=get_driver(driver))

    try:
        await graphiti.build_indices_and_constraints()

        # Baseline: pass excluded_entity_types=None and expect both custom
        # types to be extracted as usual.
        result = await graphiti.add_episode(
            name='Normal Behavior',
            episode_body='Alice Smith works at TechCorp.',
            source_description='Normal test',
            reference_time=datetime.now(timezone.utc),
            entity_types={'Person': Person, 'Organization': Organization},
            excluded_entity_types=None,  # No exclusions
            group_id='test_exclude_none',
        )
        assert result is not None

        search_results = await graphiti.search_(
            query='Alice Smith TechCorp', group_ids=['test_exclude_none']
        )
        found_nodes = search_results.nodes
        assert len(found_nodes) > 0, 'Should have found some entities'

        # Both custom labels should be represented in the results.
        person_nodes = [n for n in found_nodes if 'Person' in n.labels]
        org_nodes = [n for n in found_nodes if 'Organization' in n.labels]
        assert len(person_nodes) > 0, 'Should have found Person entities'
        assert len(org_nodes) > 0, 'Should have found Organization entities'

        # Clean up
        await _cleanup_test_nodes(graphiti, 'test_exclude_none')

    finally:
        await graphiti.close()
277 | 
278 | 
def test_validation_valid_excluded_types():
    """Test validation function with valid excluded types."""
    entity_types = {
        'Person': Person,
        'Organization': Organization,
    }

    # Every combination of registered names (plus None and []) is accepted.
    valid_exclusions = (
        ['Entity'],
        ['Person'],
        ['Entity', 'Person'],
        None,
        [],
    )
    for excluded in valid_exclusions:
        assert validate_excluded_entity_types(excluded, entity_types) is True
292 | 
293 | 
def test_validation_invalid_excluded_types():
    """Test validation function with invalid excluded types."""
    entity_types = {
        'Person': Person,
        'Organization': Organization,
    }

    # Any unregistered name (even alongside valid ones) must raise ValueError.
    for invalid in (['InvalidType'], ['Person', 'NonExistentType']):
        with pytest.raises(ValueError, match='Invalid excluded entity types'):
            validate_excluded_entity_types(invalid, entity_types)
307 | 
308 | 
@pytest.mark.asyncio
@pytest.mark.parametrize(
    'driver',
    drivers,
)
async def test_excluded_types_parameter_validation_in_add_episode(driver):
    """Test that add_episode validates excluded_entity_types parameter."""
    graphiti = Graphiti(graph_driver=get_driver(driver))

    try:
        # Excluding a type that was never registered must fail fast.
        with pytest.raises(ValueError, match='Invalid excluded entity types'):
            await graphiti.add_episode(
                name='Invalid Test',
                episode_body='Test content',
                source_description='Test',
                reference_time=datetime.now(timezone.utc),
                entity_types={'Person': Person},
                excluded_entity_types=['NonExistentType'],
                group_id='test_validation',
            )

    finally:
        await graphiti.close()
337 | 
338 | 
async def _cleanup_test_nodes(graphiti: Graphiti, group_id: str):
    """Helper function to clean up test nodes."""
    try:
        # Fetch everything in the test group, then remove each node.
        results = await graphiti.search_(query='*', group_ids=[group_id])
        for node in results.nodes:
            await node.delete(graphiti.driver)
    except Exception as e:
        # Best-effort cleanup: report the problem but never fail the test run.
        print(f'Warning: Failed to clean up test nodes for group {group_id}: {e}')
352 | 
```

--------------------------------------------------------------------------------
/mcp_server/tests/test_integration.py:
--------------------------------------------------------------------------------

```python
  1 | #!/usr/bin/env python3
  2 | """
  3 | HTTP/SSE Integration test for the refactored Graphiti MCP Server.
  4 | Tests server functionality when running in SSE (Server-Sent Events) mode over HTTP.
  5 | Note: This test requires the server to be running with --transport sse.
  6 | """
  7 | 
  8 | import asyncio
  9 | import json
 10 | import time
 11 | from typing import Any
 12 | 
 13 | import httpx
 14 | 
 15 | 
 16 | class MCPIntegrationTest:
 17 |     """Integration test client for Graphiti MCP Server."""
 18 | 
 19 |     def __init__(self, base_url: str = 'http://localhost:8000'):
 20 |         self.base_url = base_url
 21 |         self.client = httpx.AsyncClient(timeout=30.0)
 22 |         self.test_group_id = f'test_group_{int(time.time())}'
 23 | 
 24 |     async def __aenter__(self):
 25 |         return self
 26 | 
 27 |     async def __aexit__(self, exc_type, exc_val, exc_tb):
 28 |         await self.client.aclose()
 29 | 
 30 |     async def call_mcp_tool(self, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any]:
 31 |         """Call an MCP tool via the SSE endpoint."""
 32 |         # MCP protocol message structure
 33 |         message = {
 34 |             'jsonrpc': '2.0',
 35 |             'id': int(time.time() * 1000),
 36 |             'method': 'tools/call',
 37 |             'params': {'name': tool_name, 'arguments': arguments},
 38 |         }
 39 | 
 40 |         try:
 41 |             response = await self.client.post(
 42 |                 f'{self.base_url}/message',
 43 |                 json=message,
 44 |                 headers={'Content-Type': 'application/json'},
 45 |             )
 46 | 
 47 |             if response.status_code != 200:
 48 |                 return {'error': f'HTTP {response.status_code}: {response.text}'}
 49 | 
 50 |             result = response.json()
 51 |             return result.get('result', result)
 52 | 
 53 |         except Exception as e:
 54 |             return {'error': str(e)}
 55 | 
 56 |     async def test_server_status(self) -> bool:
 57 |         """Test the get_status resource."""
 58 |         print('🔍 Testing server status...')
 59 | 
 60 |         try:
 61 |             response = await self.client.get(f'{self.base_url}/resources/http://graphiti/status')
 62 |             if response.status_code == 200:
 63 |                 status = response.json()
 64 |                 print(f'   ✅ Server status: {status.get("status", "unknown")}')
 65 |                 return status.get('status') == 'ok'
 66 |             else:
 67 |                 print(f'   ❌ Status check failed: HTTP {response.status_code}')
 68 |                 return False
 69 |         except Exception as e:
 70 |             print(f'   ❌ Status check failed: {e}')
 71 |             return False
 72 | 
 73 |     async def test_add_memory(self) -> dict[str, str]:
 74 |         """Test adding various types of memory episodes."""
 75 |         print('📝 Testing add_memory functionality...')
 76 | 
 77 |         episode_results = {}
 78 | 
 79 |         # Test 1: Add text episode
 80 |         print('   Testing text episode...')
 81 |         result = await self.call_mcp_tool(
 82 |             'add_memory',
 83 |             {
 84 |                 'name': 'Test Company News',
 85 |                 'episode_body': 'Acme Corp announced a revolutionary new AI product that will transform the industry. The CEO mentioned this is their biggest launch since 2020.',
 86 |                 'source': 'text',
 87 |                 'source_description': 'news article',
 88 |                 'group_id': self.test_group_id,
 89 |             },
 90 |         )
 91 | 
 92 |         if 'error' in result:
 93 |             print(f'   ❌ Text episode failed: {result["error"]}')
 94 |         else:
 95 |             print(f'   ✅ Text episode queued: {result.get("message", "Success")}')
 96 |             episode_results['text'] = 'success'
 97 | 
 98 |         # Test 2: Add JSON episode
 99 |         print('   Testing JSON episode...')
100 |         json_data = {
101 |             'company': {'name': 'TechCorp', 'founded': 2010},
102 |             'products': [
103 |                 {'id': 'P001', 'name': 'CloudSync', 'category': 'software'},
104 |                 {'id': 'P002', 'name': 'DataMiner', 'category': 'analytics'},
105 |             ],
106 |             'employees': 150,
107 |         }
108 | 
109 |         result = await self.call_mcp_tool(
110 |             'add_memory',
111 |             {
112 |                 'name': 'Company Profile',
113 |                 'episode_body': json.dumps(json_data),
114 |                 'source': 'json',
115 |                 'source_description': 'CRM data',
116 |                 'group_id': self.test_group_id,
117 |             },
118 |         )
119 | 
120 |         if 'error' in result:
121 |             print(f'   ❌ JSON episode failed: {result["error"]}')
122 |         else:
123 |             print(f'   ✅ JSON episode queued: {result.get("message", "Success")}')
124 |             episode_results['json'] = 'success'
125 | 
126 |         # Test 3: Add message episode
127 |         print('   Testing message episode...')
128 |         result = await self.call_mcp_tool(
129 |             'add_memory',
130 |             {
131 |                 'name': 'Customer Support Chat',
132 |                 'episode_body': "user: What's your return policy?\nassistant: You can return items within 30 days of purchase with receipt.\nuser: Thanks!",
133 |                 'source': 'message',
134 |                 'source_description': 'support chat log',
135 |                 'group_id': self.test_group_id,
136 |             },
137 |         )
138 | 
139 |         if 'error' in result:
140 |             print(f'   ❌ Message episode failed: {result["error"]}')
141 |         else:
142 |             print(f'   ✅ Message episode queued: {result.get("message", "Success")}')
143 |             episode_results['message'] = 'success'
144 | 
145 |         return episode_results
146 | 
147 |     async def wait_for_processing(self, max_wait: int = 30) -> None:
148 |         """Wait for episode processing to complete."""
149 |         print(f'⏳ Waiting up to {max_wait} seconds for episode processing...')
150 | 
151 |         for i in range(max_wait):
152 |             await asyncio.sleep(1)
153 | 
154 |             # Check if we have any episodes
155 |             result = await self.call_mcp_tool(
156 |                 'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
157 |             )
158 | 
159 |             if not isinstance(result, dict) or 'error' in result:
160 |                 continue
161 | 
162 |             if isinstance(result, list) and len(result) > 0:
163 |                 print(f'   ✅ Found {len(result)} processed episodes after {i + 1} seconds')
164 |                 return
165 | 
166 |         print(f'   ⚠️  Still waiting after {max_wait} seconds...')
167 | 
168 |     async def test_search_functions(self) -> dict[str, bool]:
169 |         """Test search functionality."""
170 |         print('🔍 Testing search functions...')
171 | 
172 |         results = {}
173 | 
174 |         # Test search_memory_nodes
175 |         print('   Testing search_memory_nodes...')
176 |         result = await self.call_mcp_tool(
177 |             'search_memory_nodes',
178 |             {
179 |                 'query': 'Acme Corp product launch',
180 |                 'group_ids': [self.test_group_id],
181 |                 'max_nodes': 5,
182 |             },
183 |         )
184 | 
185 |         if 'error' in result:
186 |             print(f'   ❌ Node search failed: {result["error"]}')
187 |             results['nodes'] = False
188 |         else:
189 |             nodes = result.get('nodes', [])
190 |             print(f'   ✅ Node search returned {len(nodes)} nodes')
191 |             results['nodes'] = True
192 | 
193 |         # Test search_memory_facts
194 |         print('   Testing search_memory_facts...')
195 |         result = await self.call_mcp_tool(
196 |             'search_memory_facts',
197 |             {
198 |                 'query': 'company products software',
199 |                 'group_ids': [self.test_group_id],
200 |                 'max_facts': 5,
201 |             },
202 |         )
203 | 
204 |         if 'error' in result:
205 |             print(f'   ❌ Fact search failed: {result["error"]}')
206 |             results['facts'] = False
207 |         else:
208 |             facts = result.get('facts', [])
209 |             print(f'   ✅ Fact search returned {len(facts)} facts')
210 |             results['facts'] = True
211 | 
212 |         return results
213 | 
214 |     async def test_episode_retrieval(self) -> bool:
215 |         """Test episode retrieval."""
216 |         print('📚 Testing episode retrieval...')
217 | 
218 |         result = await self.call_mcp_tool(
219 |             'get_episodes', {'group_id': self.test_group_id, 'last_n': 10}
220 |         )
221 | 
222 |         if 'error' in result:
223 |             print(f'   ❌ Episode retrieval failed: {result["error"]}')
224 |             return False
225 | 
226 |         if isinstance(result, list):
227 |             print(f'   ✅ Retrieved {len(result)} episodes')
228 | 
229 |             # Print episode details
230 |             for i, episode in enumerate(result[:3]):  # Show first 3
231 |                 name = episode.get('name', 'Unknown')
232 |                 source = episode.get('source', 'unknown')
233 |                 print(f'     Episode {i + 1}: {name} (source: {source})')
234 | 
235 |             return len(result) > 0
236 |         else:
237 |             print(f'   ❌ Unexpected result format: {type(result)}')
238 |             return False
239 | 
240 |     async def test_edge_cases(self) -> dict[str, bool]:
241 |         """Test edge cases and error handling."""
242 |         print('🧪 Testing edge cases...')
243 | 
244 |         results = {}
245 | 
246 |         # Test with invalid group_id
247 |         print('   Testing invalid group_id...')
248 |         result = await self.call_mcp_tool(
249 |             'search_memory_nodes',
250 |             {'query': 'nonexistent data', 'group_ids': ['nonexistent_group'], 'max_nodes': 5},
251 |         )
252 | 
253 |         # Should not error, just return empty results
254 |         if 'error' not in result:
255 |             nodes = result.get('nodes', [])
256 |             print(f'   ✅ Invalid group_id handled gracefully (returned {len(nodes)} nodes)')
257 |             results['invalid_group'] = True
258 |         else:
259 |             print(f'   ❌ Invalid group_id caused error: {result["error"]}')
260 |             results['invalid_group'] = False
261 | 
262 |         # Test empty query
263 |         print('   Testing empty query...')
264 |         result = await self.call_mcp_tool(
265 |             'search_memory_nodes', {'query': '', 'group_ids': [self.test_group_id], 'max_nodes': 5}
266 |         )
267 | 
268 |         if 'error' not in result:
269 |             print('   ✅ Empty query handled gracefully')
270 |             results['empty_query'] = True
271 |         else:
272 |             print(f'   ❌ Empty query caused error: {result["error"]}')
273 |             results['empty_query'] = False
274 | 
275 |         return results
276 | 
277 |     async def run_full_test_suite(self) -> dict[str, Any]:
278 |         """Run the complete integration test suite."""
279 |         print('🚀 Starting Graphiti MCP Server Integration Test')
280 |         print(f'   Test group ID: {self.test_group_id}')
281 |         print('=' * 60)
282 | 
283 |         results = {
284 |             'server_status': False,
285 |             'add_memory': {},
286 |             'search': {},
287 |             'episodes': False,
288 |             'edge_cases': {},
289 |             'overall_success': False,
290 |         }
291 | 
292 |         # Test 1: Server Status
293 |         results['server_status'] = await self.test_server_status()
294 |         if not results['server_status']:
295 |             print('❌ Server not responding, aborting tests')
296 |             return results
297 | 
298 |         print()
299 | 
300 |         # Test 2: Add Memory
301 |         results['add_memory'] = await self.test_add_memory()
302 |         print()
303 | 
304 |         # Test 3: Wait for processing
305 |         await self.wait_for_processing()
306 |         print()
307 | 
308 |         # Test 4: Search Functions
309 |         results['search'] = await self.test_search_functions()
310 |         print()
311 | 
312 |         # Test 5: Episode Retrieval
313 |         results['episodes'] = await self.test_episode_retrieval()
314 |         print()
315 | 
316 |         # Test 6: Edge Cases
317 |         results['edge_cases'] = await self.test_edge_cases()
318 |         print()
319 | 
320 |         # Calculate overall success
321 |         memory_success = len(results['add_memory']) > 0
322 |         search_success = any(results['search'].values())
323 |         edge_case_success = any(results['edge_cases'].values())
324 | 
325 |         results['overall_success'] = (
326 |             results['server_status']
327 |             and memory_success
328 |             and results['episodes']
329 |             and (search_success or edge_case_success)  # At least some functionality working
330 |         )
331 | 
332 |         # Print summary
333 |         print('=' * 60)
334 |         print('📊 TEST SUMMARY')
335 |         print(f'   Server Status: {"✅" if results["server_status"] else "❌"}')
336 |         print(
337 |             f'   Memory Operations: {"✅" if memory_success else "❌"} ({len(results["add_memory"])} types)'
338 |         )
339 |         print(f'   Search Functions: {"✅" if search_success else "❌"}')
340 |         print(f'   Episode Retrieval: {"✅" if results["episodes"] else "❌"}')
341 |         print(f'   Edge Cases: {"✅" if edge_case_success else "❌"}')
342 |         print()
343 |         print(f'🎯 OVERALL: {"✅ SUCCESS" if results["overall_success"] else "❌ FAILED"}')
344 | 
345 |         if results['overall_success']:
346 |             print('   The refactored MCP server is working correctly!')
347 |         else:
348 |             print('   Some issues detected. Check individual test results above.')
349 | 
350 |         return results
351 | 
352 | 
async def main():
    """Run the integration test and exit the process with a status code.

    Exits 0 when the full suite succeeds, 1 otherwise.
    """
    import sys

    async with MCPIntegrationTest() as test:
        results = await test.run_full_test_suite()

        # Use sys.exit rather than the site-injected exit() builtin, which is
        # not guaranteed to exist in all execution contexts.
        sys.exit(0 if results['overall_success'] else 1)
361 | 
362 | 
# Allow running this integration test directly as a standalone script.
if __name__ == '__main__':
    asyncio.run(main())
365 | 
```

--------------------------------------------------------------------------------
/graphiti_core/driver/falkordb_driver.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import asyncio
 18 | import datetime
 19 | import logging
 20 | from typing import TYPE_CHECKING, Any
 21 | 
 22 | if TYPE_CHECKING:
 23 |     from falkordb import Graph as FalkorGraph
 24 |     from falkordb.asyncio import FalkorDB
 25 | else:
 26 |     try:
 27 |         from falkordb import Graph as FalkorGraph
 28 |         from falkordb.asyncio import FalkorDB
 29 |     except ImportError:
 30 |         # If falkordb is not installed, raise an ImportError
 31 |         raise ImportError(
 32 |             'falkordb is required for FalkorDriver. '
 33 |             'Install it with: pip install graphiti-core[falkordb]'
 34 |         ) from None
 35 | 
 36 | from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
 37 | from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
 38 | from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
 39 | 
 40 | logger = logging.getLogger(__name__)
 41 | 
 42 | STOPWORDS = [
 43 |     'a',
 44 |     'is',
 45 |     'the',
 46 |     'an',
 47 |     'and',
 48 |     'are',
 49 |     'as',
 50 |     'at',
 51 |     'be',
 52 |     'but',
 53 |     'by',
 54 |     'for',
 55 |     'if',
 56 |     'in',
 57 |     'into',
 58 |     'it',
 59 |     'no',
 60 |     'not',
 61 |     'of',
 62 |     'on',
 63 |     'or',
 64 |     'such',
 65 |     'that',
 66 |     'their',
 67 |     'then',
 68 |     'there',
 69 |     'these',
 70 |     'they',
 71 |     'this',
 72 |     'to',
 73 |     'was',
 74 |     'will',
 75 |     'with',
 76 | ]
 77 | 
 78 | 
 79 | class FalkorDriverSession(GraphDriverSession):
 80 |     provider = GraphProvider.FALKORDB
 81 | 
 82 |     def __init__(self, graph: FalkorGraph):
 83 |         self.graph = graph
 84 | 
 85 |     async def __aenter__(self):
 86 |         return self
 87 | 
 88 |     async def __aexit__(self, exc_type, exc, tb):
 89 |         # No cleanup needed for Falkor, but method must exist
 90 |         pass
 91 | 
 92 |     async def close(self):
 93 |         # No explicit close needed for FalkorDB, but method must exist
 94 |         pass
 95 | 
 96 |     async def execute_write(self, func, *args, **kwargs):
 97 |         # Directly await the provided async function with `self` as the transaction/session
 98 |         return await func(self, *args, **kwargs)
 99 | 
100 |     async def run(self, query: str | list, **kwargs: Any) -> Any:
101 |         # FalkorDB does not support argument for Label Set, so it's converted into an array of queries
102 |         if isinstance(query, list):
103 |             for cypher, params in query:
104 |                 params = convert_datetimes_to_strings(params)
105 |                 await self.graph.query(str(cypher), params)  # type: ignore[reportUnknownArgumentType]
106 |         else:
107 |             params = dict(kwargs)
108 |             params = convert_datetimes_to_strings(params)
109 |             await self.graph.query(str(query), params)  # type: ignore[reportUnknownArgumentType]
110 |         # Assuming `graph.query` is async (ideal); otherwise, wrap in executor
111 |         return None
112 | 
113 | 
class FalkorDriver(GraphDriver):
    provider = GraphProvider.FALKORDB
    # NOTE(review): '\\_' appears to serve as a placeholder group id for the
    # default graph — confirm against FalkorDB graph-naming rules.
    default_group_id: str = '\\_'
    fulltext_syntax: str = '@'  # FalkorDB uses a redisearch-like syntax for fulltext queries
    # Unused by this driver; presumably kept for interface compatibility with
    # drivers that use an AWS OpenSearch client — TODO confirm.
    aoss_client: None = None
119 | 
120 |     def __init__(
121 |         self,
122 |         host: str = 'localhost',
123 |         port: int = 6379,
124 |         username: str | None = None,
125 |         password: str | None = None,
126 |         falkor_db: FalkorDB | None = None,
127 |         database: str = 'default_db',
128 |     ):
129 |         """
130 |         Initialize the FalkorDB driver.
131 | 
132 |         FalkorDB is a multi-tenant graph database.
133 |         To connect, provide the host and port.
134 |         The default parameters assume a local (on-premises) FalkorDB instance.
135 | 
136 |         Args:
137 |         host (str): The host where FalkorDB is running.
138 |         port (int): The port on which FalkorDB is listening.
139 |         username (str | None): The username for authentication (if required).
140 |         password (str | None): The password for authentication (if required).
141 |         falkor_db (FalkorDB | None): An existing FalkorDB instance to use instead of creating a new one.
142 |         database (str): The name of the database to connect to. Defaults to 'default_db'.
143 |         """
144 |         super().__init__()
145 |         self._database = database
146 |         if falkor_db is not None:
147 |             # If a FalkorDB instance is provided, use it directly
148 |             self.client = falkor_db
149 |         else:
150 |             self.client = FalkorDB(host=host, port=port, username=username, password=password)
151 | 
152 |         # Schedule the indices and constraints to be built
153 |         try:
154 |             # Try to get the current event loop
155 |             loop = asyncio.get_running_loop()
156 |             # Schedule the build_indices_and_constraints to run
157 |             loop.create_task(self.build_indices_and_constraints())
158 |         except RuntimeError:
159 |             # No event loop running, this will be handled later
160 |             pass
161 | 
162 |     def _get_graph(self, graph_name: str | None) -> FalkorGraph:
163 |         # FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db"
164 |         if graph_name is None:
165 |             graph_name = self._database
166 |         return self.client.select_graph(graph_name)
167 | 
168 |     async def execute_query(self, cypher_query_, **kwargs: Any):
169 |         graph = self._get_graph(self._database)
170 | 
171 |         # Convert datetime objects to ISO strings (FalkorDB does not support datetime objects directly)
172 |         params = convert_datetimes_to_strings(dict(kwargs))
173 | 
174 |         try:
175 |             result = await graph.query(cypher_query_, params)  # type: ignore[reportUnknownArgumentType]
176 |         except Exception as e:
177 |             if 'already indexed' in str(e):
178 |                 # check if index already exists
179 |                 logger.info(f'Index already exists: {e}')
180 |                 return None
181 |             logger.error(f'Error executing FalkorDB query: {e}\n{cypher_query_}\n{params}')
182 |             raise
183 | 
184 |         # Convert the result header to a list of strings
185 |         header = [h[1] for h in result.header]
186 | 
187 |         # Convert FalkorDB's result format (list of lists) to the format expected by Graphiti (list of dicts)
188 |         records = []
189 |         for row in result.result_set:
190 |             record = {}
191 |             for i, field_name in enumerate(header):
192 |                 if i < len(row):
193 |                     record[field_name] = row[i]
194 |                 else:
195 |                     # If there are more fields in header than values in row, set to None
196 |                     record[field_name] = None
197 |             records.append(record)
198 | 
199 |         return records, header, None
200 | 
201 |     def session(self, database: str | None = None) -> GraphDriverSession:
202 |         return FalkorDriverSession(self._get_graph(database))
203 | 
204 |     async def close(self) -> None:
205 |         """Close the driver connection."""
206 |         if hasattr(self.client, 'aclose'):
207 |             await self.client.aclose()  # type: ignore[reportUnknownMemberType]
208 |         elif hasattr(self.client.connection, 'aclose'):
209 |             await self.client.connection.aclose()
210 |         elif hasattr(self.client.connection, 'close'):
211 |             await self.client.connection.close()
212 | 
213 |     async def delete_all_indexes(self) -> None:
214 |         result = await self.execute_query('CALL db.indexes()')
215 |         if not result:
216 |             return
217 | 
218 |         records, _, _ = result
219 |         drop_tasks = []
220 | 
221 |         for record in records:
222 |             label = record['label']
223 |             entity_type = record['entitytype']
224 | 
225 |             for field_name, index_type in record['types'].items():
226 |                 if 'RANGE' in index_type:
227 |                     drop_tasks.append(self.execute_query(f'DROP INDEX ON :{label}({field_name})'))
228 |                 elif 'FULLTEXT' in index_type:
229 |                     if entity_type == 'NODE':
230 |                         drop_tasks.append(
231 |                             self.execute_query(
232 |                                 f'DROP FULLTEXT INDEX FOR (n:{label}) ON (n.{field_name})'
233 |                             )
234 |                         )
235 |                     elif entity_type == 'RELATIONSHIP':
236 |                         drop_tasks.append(
237 |                             self.execute_query(
238 |                                 f'DROP FULLTEXT INDEX FOR ()-[e:{label}]-() ON (e.{field_name})'
239 |                             )
240 |                         )
241 | 
242 |         if drop_tasks:
243 |             await asyncio.gather(*drop_tasks)
244 | 
245 |     async def build_indices_and_constraints(self, delete_existing=False):
246 |         if delete_existing:
247 |             await self.delete_all_indexes()
248 |         index_queries = get_range_indices(self.provider) + get_fulltext_indices(self.provider)
249 |         for query in index_queries:
250 |             await self.execute_query(query)
251 | 
252 |     def clone(self, database: str) -> 'GraphDriver':
253 |         """
254 |         Returns a shallow copy of this driver with a different default database.
255 |         Reuses the same connection (e.g. FalkorDB, Neo4j).
256 |         """
257 |         if database == self._database:
258 |             cloned = self
259 |         elif database == self.default_group_id:
260 |             cloned = FalkorDriver(falkor_db=self.client)
261 |         else:
262 |             # Create a new instance of FalkorDriver with the same connection but a different database
263 |             cloned = FalkorDriver(falkor_db=self.client, database=database)
264 | 
265 |         return cloned
266 | 
267 |     async def health_check(self) -> None:
268 |         """Check FalkorDB connectivity by running a simple query."""
269 |         try:
270 |             await self.execute_query('MATCH (n) RETURN 1 LIMIT 1')
271 |             return None
272 |         except Exception as e:
273 |             print(f'FalkorDB health check failed: {e}')
274 |             raise
275 | 
276 |     @staticmethod
277 |     def convert_datetimes_to_strings(obj):
278 |         if isinstance(obj, dict):
279 |             return {k: FalkorDriver.convert_datetimes_to_strings(v) for k, v in obj.items()}
280 |         elif isinstance(obj, list):
281 |             return [FalkorDriver.convert_datetimes_to_strings(item) for item in obj]
282 |         elif isinstance(obj, tuple):
283 |             return tuple(FalkorDriver.convert_datetimes_to_strings(item) for item in obj)
284 |         elif isinstance(obj, datetime):
285 |             return obj.isoformat()
286 |         else:
287 |             return obj
288 | 
289 |     def sanitize(self, query: str) -> str:
290 |         """
291 |         Replace FalkorDB special characters with whitespace.
292 |         Based on FalkorDB tokenization rules: ,.<>{}[]"':;!@#$%^&*()-+=~
293 |         """
294 |         # FalkorDB separator characters that break text into tokens
295 |         separator_map = str.maketrans(
296 |             {
297 |                 ',': ' ',
298 |                 '.': ' ',
299 |                 '<': ' ',
300 |                 '>': ' ',
301 |                 '{': ' ',
302 |                 '}': ' ',
303 |                 '[': ' ',
304 |                 ']': ' ',
305 |                 '"': ' ',
306 |                 "'": ' ',
307 |                 ':': ' ',
308 |                 ';': ' ',
309 |                 '!': ' ',
310 |                 '@': ' ',
311 |                 '#': ' ',
312 |                 '$': ' ',
313 |                 '%': ' ',
314 |                 '^': ' ',
315 |                 '&': ' ',
316 |                 '*': ' ',
317 |                 '(': ' ',
318 |                 ')': ' ',
319 |                 '-': ' ',
320 |                 '+': ' ',
321 |                 '=': ' ',
322 |                 '~': ' ',
323 |                 '?': ' ',
324 |             }
325 |         )
326 |         sanitized = query.translate(separator_map)
327 |         # Clean up multiple spaces
328 |         sanitized = ' '.join(sanitized.split())
329 |         return sanitized
330 | 
331 |     def build_fulltext_query(
332 |         self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
333 |     ) -> str:
334 |         """
335 |         Build a fulltext query string for FalkorDB using RedisSearch syntax.
336 |         FalkorDB uses RedisSearch-like syntax where:
337 |         - Field queries use @ prefix: @field:value
338 |         - Multiple values for same field: (@field:value1|value2)
339 |         - Text search doesn't need @ prefix for content fields
340 |         - AND is implicit with space: (@group_id:value) (text)
341 |         - OR uses pipe within parentheses: (@group_id:value1|value2)
342 |         """
343 |         if group_ids is None or len(group_ids) == 0:
344 |             group_filter = ''
345 |         else:
346 |             group_values = '|'.join(group_ids)
347 |             group_filter = f'(@group_id:{group_values})'
348 | 
349 |         sanitized_query = self.sanitize(query)
350 | 
351 |         # Remove stopwords from the sanitized query
352 |         query_words = sanitized_query.split()
353 |         filtered_words = [word for word in query_words if word.lower() not in STOPWORDS]
354 |         sanitized_query = ' | '.join(filtered_words)
355 | 
356 |         # If the query is too long return no query
357 |         if len(sanitized_query.split(' ')) + len(group_ids or '') >= max_query_length:
358 |             return ''
359 | 
360 |         full_query = group_filter + ' (' + sanitized_query + ')'
361 | 
362 |         return full_query
363 | 
```

--------------------------------------------------------------------------------
/graphiti_core/models/nodes/node_db_queries.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | from typing import Any
 18 | 
 19 | from graphiti_core.driver.driver import GraphProvider
 20 | 
 21 | 
 22 | def get_episode_node_save_query(provider: GraphProvider) -> str:
 23 |     match provider:
 24 |         case GraphProvider.NEPTUNE:
 25 |             return """
 26 |                 MERGE (n:Episodic {uuid: $uuid})
 27 |                 SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
 28 |                 entity_edges: join([x IN coalesce($entity_edges, []) | toString(x) ], '|'), created_at: $created_at, valid_at: $valid_at}
 29 |                 RETURN n.uuid AS uuid
 30 |             """
 31 |         case GraphProvider.KUZU:
 32 |             return """
 33 |                 MERGE (n:Episodic {uuid: $uuid})
 34 |                 SET
 35 |                     n.name = $name,
 36 |                     n.group_id = $group_id,
 37 |                     n.created_at = $created_at,
 38 |                     n.source = $source,
 39 |                     n.source_description = $source_description,
 40 |                     n.content = $content,
 41 |                     n.valid_at = $valid_at,
 42 |                     n.entity_edges = $entity_edges
 43 |                 RETURN n.uuid AS uuid
 44 |             """
 45 |         case GraphProvider.FALKORDB:
 46 |             return """
 47 |                 MERGE (n:Episodic {uuid: $uuid})
 48 |                 SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
 49 |                 entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
 50 |                 RETURN n.uuid AS uuid
 51 |             """
 52 |         case _:  # Neo4j
 53 |             return """
 54 |                 MERGE (n:Episodic {uuid: $uuid})
 55 |                 SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
 56 |                 entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
 57 |                 RETURN n.uuid AS uuid
 58 |             """
 59 | 
 60 | 
 61 | def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
 62 |     match provider:
 63 |         case GraphProvider.NEPTUNE:
 64 |             return """
 65 |                 UNWIND $episodes AS episode
 66 |                 MERGE (n:Episodic {uuid: episode.uuid})
 67 |                 SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
 68 |                     source: episode.source, content: episode.content,
 69 |                 entity_edges: join([x IN coalesce(episode.entity_edges, []) | toString(x) ], '|'), created_at: episode.created_at, valid_at: episode.valid_at}
 70 |                 RETURN n.uuid AS uuid
 71 |             """
 72 |         case GraphProvider.KUZU:
 73 |             return """
 74 |                 MERGE (n:Episodic {uuid: $uuid})
 75 |                 SET
 76 |                     n.name = $name,
 77 |                     n.group_id = $group_id,
 78 |                     n.created_at = $created_at,
 79 |                     n.source = $source,
 80 |                     n.source_description = $source_description,
 81 |                     n.content = $content,
 82 |                     n.valid_at = $valid_at,
 83 |                     n.entity_edges = $entity_edges
 84 |                 RETURN n.uuid AS uuid
 85 |             """
 86 |         case GraphProvider.FALKORDB:
 87 |             return """
 88 |                 UNWIND $episodes AS episode
 89 |                 MERGE (n:Episodic {uuid: episode.uuid})
 90 |                 SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content, 
 91 |                 entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
 92 |                 RETURN n.uuid AS uuid
 93 |             """
 94 |         case _:  # Neo4j
 95 |             return """
 96 |                 UNWIND $episodes AS episode
 97 |                 MERGE (n:Episodic {uuid: episode.uuid})
 98 |                 SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content, 
 99 |                 entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
100 |                 RETURN n.uuid AS uuid
101 |             """
102 | 
103 | 
# Shared RETURN clause for Episodic nodes (alias `e`), for providers that
# store `entity_edges` as a native list property.
EPISODIC_NODE_RETURN = """
    e.uuid AS uuid,
    e.name AS name,
    e.group_id AS group_id,
    e.created_at AS created_at,
    e.source AS source,
    e.source_description AS source_description,
    e.content AS content,
    e.valid_at AS valid_at,
    e.entity_edges AS entity_edges
"""

# Neptune variant: `entity_edges` is persisted as a delimited string, so it
# is split back into a list on the way out.
EPISODIC_NODE_RETURN_NEPTUNE = """
    e.content AS content,
    e.created_at AS created_at,
    e.valid_at AS valid_at,
    e.uuid AS uuid,
    e.name AS name,
    e.group_id AS group_id,
    e.source_description AS source_description,
    e.source AS source,
    split(e.entity_edges, ",") AS entity_edges
"""
127 | 
128 | 
def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: bool = False) -> str:
    """Return the provider-specific query that upserts a single Entity node.

    Args:
        provider: Target graph backend.
        labels: Colon-separated label string (e.g. "Entity:Person") applied to the node.
        has_aoss: Neo4j only — when True, the embedding write via
            db.create.setNodeVectorProperty is skipped (presumably because an
            external vector store holds embeddings; confirm with callers).

    Returns:
        A Cypher query string; most variants take a `$entity_data` map parameter,
        while Kuzu takes individual parameters.
    """
    match provider:
        case GraphProvider.FALKORDB:
            # FalkorDB stores the embedding via its vecf32() vector constructor.
            return f"""
                MERGE (n:Entity {{uuid: $entity_data.uuid}})
                SET n:{labels}
                SET n = $entity_data
                SET n.name_embedding = vecf32($entity_data.name_embedding)
                RETURN n.uuid AS uuid
            """
        case GraphProvider.KUZU:
            # Kuzu sets each property individually; labels are stored as a property.
            return """
                MERGE (n:Entity {uuid: $uuid})
                SET
                    n.name = $name,
                    n.group_id = $group_id,
                    n.labels = $labels,
                    n.created_at = $created_at,
                    n.name_embedding = $name_embedding,
                    n.summary = $summary,
                    n.attributes = $attributes
                WITH n
                RETURN n.uuid AS uuid
            """
        case GraphProvider.NEPTUNE:
            # Neptune applies each label with its own SET clause, strips the
            # `labels`/`name_embedding` keys from the map, and stores the
            # embedding as a comma-delimited string.
            label_subquery = ''
            for label in labels.split(':'):
                label_subquery += f' SET n:{label}\n'
            return f"""
                MERGE (n:Entity {{uuid: $entity_data.uuid}})
                {label_subquery}
                SET n = removeKeyFromMap(removeKeyFromMap($entity_data, "labels"), "name_embedding")
                SET n.name_embedding = join([x IN coalesce($entity_data.name_embedding, []) | toString(x) ], ",")
                RETURN n.uuid AS uuid
            """
        case _:
            # Neo4j: write the embedding through the vector-property procedure
            # unless an external vector index (AOSS) is in use.
            save_embedding_query = (
                'WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)'
                if not has_aoss
                else ''
            )
            return (
                f"""
                MERGE (n:Entity {{uuid: $entity_data.uuid}})
                SET n:{labels}
                SET n = $entity_data
                """
                + save_embedding_query
                + """
                RETURN n.uuid AS uuid
            """
            )
181 | 
182 | 
def get_entity_node_save_bulk_query(
    provider: GraphProvider, nodes: list[dict], has_aoss: bool = False
) -> str | Any:
    """Return the query (or queries) used to save entity nodes in bulk.

    Return shape varies by provider:
        - FalkorDB: list of (query, params) tuples — one per node per label.
        - Neptune: list of query strings — one per node.
        - Kuzu: a single per-node query string (callers execute it per node).
        - Neo4j (default): a single UNWIND-based query string.

    `has_aoss` (Neo4j only) skips the embedding write when True.
    """
    match provider:
        case GraphProvider.FALKORDB:
            queries = []
            for node in nodes:
                # One query per label: FalkorDB sets labels individually.
                for label in node['labels']:
                    queries.append(
                        (
                            f"""
                            UNWIND $nodes AS node
                            MERGE (n:Entity {{uuid: node.uuid}})
                            SET n:{label}
                            SET n = node
                            WITH n, node
                            SET n.name_embedding = vecf32(node.name_embedding)
                            RETURN n.uuid AS uuid
                            """,
                            {'nodes': [node]},
                        )
                    )
            return queries
        case GraphProvider.NEPTUNE:
            queries = []
            for node in nodes:
                labels = ''
                for label in node['labels']:
                    labels += f' SET n:{label}\n'
                # Embedding is stored as a comma-delimited string on Neptune.
                queries.append(
                    f"""
                        UNWIND $nodes AS node
                        MERGE (n:Entity {{uuid: node.uuid}})
                        {labels}
                        SET n = removeKeyFromMap(removeKeyFromMap(node, "labels"), "name_embedding")
                        SET n.name_embedding = join([x IN coalesce(node.name_embedding, []) | toString(x) ], ",")
                        RETURN n.uuid AS uuid
                    """
                )
            return queries
        case GraphProvider.KUZU:
            # Not a true bulk query: parameters are per-node.
            return """
                MERGE (n:Entity {uuid: $uuid})
                SET
                    n.name = $name,
                    n.group_id = $group_id,
                    n.labels = $labels,
                    n.created_at = $created_at,
                    n.name_embedding = $name_embedding,
                    n.summary = $summary,
                    n.attributes = $attributes
                RETURN n.uuid AS uuid
            """
        case _:  # Neo4j
            save_embedding_query = (
                'WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)'
                if not has_aoss
                else ''
            )
            return (
                """
                    UNWIND $nodes AS node
                    MERGE (n:Entity {uuid: node.uuid})
                    SET n:$(node.labels)
                    SET n = node
                    """
                + save_embedding_query
                + """
                RETURN n.uuid AS uuid
            """
            )
254 | 
255 | 
def get_entity_node_return_query(provider: GraphProvider) -> str:
    """Return the RETURN clause for Entity nodes (alias `n`).

    `name_embedding` is not returned by default and must be loaded manually
    using `load_name_embedding()`.
    """
    if provider != GraphProvider.KUZU:
        # Labels and attributes come from the graph engine's labels()
        # and properties() functions.
        return """
        n.uuid AS uuid,
        n.name AS name,
        n.group_id AS group_id,
        n.created_at AS created_at,
        n.summary AS summary,
        labels(n) AS labels,
        properties(n) AS attributes
    """

    # Kuzu keeps labels and attributes as stored node properties.
    return """
            n.uuid AS uuid,
            n.name AS name,
            n.group_id AS group_id,
            n.labels AS labels,
            n.created_at AS created_at,
            n.summary AS summary,
            n.attributes AS attributes
        """
278 | 
279 | 
def get_community_node_save_query(provider: GraphProvider) -> str:
    """Return the provider-specific query that upserts a single Community node.

    All variants MERGE on `uuid`; they differ only in how the name embedding
    is stored (vecf32 for FalkorDB, a comma-delimited string for Neptune,
    plain property for Kuzu, and the vector-property procedure for Neo4j).
    """
    match provider:
        case GraphProvider.FALKORDB:
            return """
                MERGE (n:Community {uuid: $uuid})
                SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at, name_embedding: vecf32($name_embedding)}
                RETURN n.uuid AS uuid
            """
        case GraphProvider.NEPTUNE:
            return """
                MERGE (n:Community {uuid: $uuid})
                SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
                SET n.name_embedding = join([x IN coalesce($name_embedding, []) | toString(x) ], ",")
                RETURN n.uuid AS uuid
            """
        case GraphProvider.KUZU:
            return """
                MERGE (n:Community {uuid: $uuid})
                SET
                    n.name = $name,
                    n.group_id = $group_id,
                    n.created_at = $created_at,
                    n.name_embedding = $name_embedding,
                    n.summary = $summary
                RETURN n.uuid AS uuid
            """
        case _:  # Neo4j
            return """
                MERGE (n:Community {uuid: $uuid})
                SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
                WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
                RETURN n.uuid AS uuid
            """
313 | 
314 | 
# Shared RETURN clause for Community nodes (alias `c`).
COMMUNITY_NODE_RETURN = """
    c.uuid AS uuid,
    c.name AS name,
    c.group_id AS group_id,
    c.created_at AS created_at,
    c.name_embedding AS name_embedding,
    c.summary AS summary
"""

# Neptune variant (alias `n`): the embedding is stored as a comma-delimited
# string and converted back into a list of floats on return.
COMMUNITY_NODE_RETURN_NEPTUNE = """
    n.uuid AS uuid,
    n.name AS name,
    [x IN split(n.name_embedding, ",") | toFloat(x)] AS name_embedding,
    n.group_id AS group_id,
    n.summary AS summary,
    n.created_at AS created_at
"""
332 | 
```

--------------------------------------------------------------------------------
/tests/cross_encoder/test_gemini_reranker_client.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | # Running tests: pytest -xvs tests/cross_encoder/test_gemini_reranker_client.py
 18 | 
 19 | from unittest.mock import AsyncMock, MagicMock, patch
 20 | 
 21 | import pytest
 22 | 
 23 | from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient
 24 | from graphiti_core.llm_client import LLMConfig, RateLimitError
 25 | 
 26 | 
 27 | @pytest.fixture
 28 | def mock_gemini_client():
 29 |     """Fixture to mock the Google Gemini client."""
 30 |     with patch('google.genai.Client') as mock_client:
 31 |         # Setup mock instance and its methods
 32 |         mock_instance = mock_client.return_value
 33 |         mock_instance.aio = MagicMock()
 34 |         mock_instance.aio.models = MagicMock()
 35 |         mock_instance.aio.models.generate_content = AsyncMock()
 36 |         yield mock_instance
 37 | 
 38 | 
 39 | @pytest.fixture
 40 | def gemini_reranker_client(mock_gemini_client):
 41 |     """Fixture to create a GeminiRerankerClient with a mocked client."""
 42 |     config = LLMConfig(api_key='test_api_key', model='test-model')
 43 |     client = GeminiRerankerClient(config=config)
 44 |     # Replace the client's client with our mock to ensure we're using the mock
 45 |     client.client = mock_gemini_client
 46 |     return client
 47 | 
 48 | 
 49 | def create_mock_response(score_text: str) -> MagicMock:
 50 |     """Helper function to create a mock Gemini response."""
 51 |     mock_response = MagicMock()
 52 |     mock_response.text = score_text
 53 |     return mock_response
 54 | 
 55 | 
 56 | class TestGeminiRerankerClientInitialization:
 57 |     """Tests for GeminiRerankerClient initialization."""
 58 | 
 59 |     def test_init_with_config(self):
 60 |         """Test initialization with a config object."""
 61 |         config = LLMConfig(api_key='test_api_key', model='test-model')
 62 |         client = GeminiRerankerClient(config=config)
 63 | 
 64 |         assert client.config == config
 65 | 
 66 |     @patch('google.genai.Client')
 67 |     def test_init_without_config(self, mock_client):
 68 |         """Test initialization without a config uses defaults."""
 69 |         client = GeminiRerankerClient()
 70 | 
 71 |         assert client.config is not None
 72 | 
 73 |     def test_init_with_custom_client(self):
 74 |         """Test initialization with a custom client."""
 75 |         mock_client = MagicMock()
 76 |         client = GeminiRerankerClient(client=mock_client)
 77 | 
 78 |         assert client.client == mock_client
 79 | 
 80 | 
 81 | class TestGeminiRerankerClientRanking:
 82 |     """Tests for GeminiRerankerClient rank method."""
 83 | 
    @pytest.mark.asyncio
    async def test_rank_basic_functionality(self, gemini_reranker_client, mock_gemini_client):
        """Test basic ranking functionality."""
        # Setup mock responses with different scores
        mock_responses = [
            create_mock_response('85'),  # High relevance
            create_mock_response('45'),  # Medium relevance
            create_mock_response('20'),  # Low relevance
        ]
        mock_gemini_client.aio.models.generate_content.side_effect = mock_responses

        # Test data
        query = 'What is the capital of France?'
        passages = [
            'Paris is the capital and most populous city of France.',
            'London is the capital city of England and the United Kingdom.',
            'Berlin is the capital and largest city of Germany.',
        ]

        # Call method
        result = await gemini_reranker_client.rank(query, passages)

        # Result is a list of (passage, score) tuples
        assert len(result) == 3
        assert all(isinstance(item, tuple) for item in result)
        assert all(
            isinstance(passage, str) and isinstance(score, float) for passage, score in result
        )

        # Check scores are normalized to [0, 1] and sorted in descending order
        scores = [score for _, score in result]
        assert all(0.0 <= score <= 1.0 for score in scores)
        assert scores == sorted(scores, reverse=True)

        # Raw 0-100 scores are normalized by dividing by 100; highest first
        assert result[0][1] == 0.85  # 85/100
        assert result[1][1] == 0.45  # 45/100
        assert result[2][1] == 0.20  # 20/100
122 | 
123 |     @pytest.mark.asyncio
124 |     async def test_rank_empty_passages(self, gemini_reranker_client):
125 |         """Test ranking with empty passages list."""
126 |         query = 'Test query'
127 |         passages = []
128 | 
129 |         result = await gemini_reranker_client.rank(query, passages)
130 | 
131 |         assert result == []
132 | 
133 |     @pytest.mark.asyncio
134 |     async def test_rank_single_passage(self, gemini_reranker_client, mock_gemini_client):
135 |         """Test ranking with a single passage."""
136 |         # Setup mock response
137 |         mock_gemini_client.aio.models.generate_content.return_value = create_mock_response('75')
138 | 
139 |         query = 'Test query'
140 |         passages = ['Single test passage']
141 | 
142 |         result = await gemini_reranker_client.rank(query, passages)
143 | 
144 |         assert len(result) == 1
145 |         assert result[0][0] == 'Single test passage'
146 |         assert result[0][1] == 1.0  # Single passage gets full score
147 | 
    @pytest.mark.asyncio
    async def test_rank_score_extraction_with_regex(
        self, gemini_reranker_client, mock_gemini_client
    ):
        """Test score extraction from various response formats."""
        # Setup mock responses with different formats; the client extracts
        # the numeric score from surrounding text via regex.
        mock_responses = [
            create_mock_response('Score: 90'),  # Contains text before number
            create_mock_response('The relevance is 65 out of 100'),  # Contains text around number
            create_mock_response('8'),  # Just the number
        ]
        mock_gemini_client.aio.models.generate_content.side_effect = mock_responses

        query = 'Test query'
        passages = ['Passage 1', 'Passage 2', 'Passage 3']

        result = await gemini_reranker_client.rank(query, passages)

        # Check that scores were extracted correctly and normalized (score/100)
        scores = [score for _, score in result]
        assert 0.90 in scores  # 90/100
        assert 0.65 in scores  # 65/100
        assert 0.08 in scores  # 8/100
171 | 
    @pytest.mark.asyncio
    async def test_rank_invalid_score_handling(self, gemini_reranker_client, mock_gemini_client):
        """Test handling of invalid or non-numeric scores."""
        # Setup mock responses with invalid scores
        mock_responses = [
            create_mock_response('Not a number'),  # Invalid response
            create_mock_response(''),  # Empty response
            create_mock_response('95'),  # Valid response
        ]
        mock_gemini_client.aio.models.generate_content.side_effect = mock_responses

        query = 'Test query'
        passages = ['Passage 1', 'Passage 2', 'Passage 3']

        result = await gemini_reranker_client.rank(query, passages)

        # Check that invalid scores are handled gracefully (assigned 0.0)
        # rather than raising, while the valid score is still normalized.
        scores = [score for _, score in result]
        assert 0.95 in scores  # Valid score
        assert scores.count(0.0) == 2  # Two invalid scores assigned 0.0
192 | 
    @pytest.mark.asyncio
    async def test_rank_score_clamping(self, gemini_reranker_client, mock_gemini_client):
        """Test that scores are properly clamped to [0, 1] range."""
        # Setup mock responses with extreme scores
        # Note: regex only matches 1-3 digits, so negative numbers won't match
        mock_responses = [
            create_mock_response('999'),  # Above 100 but within regex range
            create_mock_response('invalid'),  # Invalid response becomes 0.0
            create_mock_response('50'),  # Normal score
        ]
        mock_gemini_client.aio.models.generate_content.side_effect = mock_responses

        query = 'Test query'
        passages = ['Passage 1', 'Passage 2', 'Passage 3']

        result = await gemini_reranker_client.rank(query, passages)

        # Check that scores are normalized and clamped
        scores = [score for _, score in result]
        assert all(0.0 <= score <= 1.0 for score in scores)
        # 999 should be clamped to 1.0 (999/100 = 9.99, clamped to 1.0)
        assert 1.0 in scores
        # Invalid response should be 0.0
        assert 0.0 in scores
        # Normal score should be normalized (50/100 = 0.5)
        assert 0.5 in scores
219 | 
220 |     @pytest.mark.asyncio
221 |     async def test_rank_rate_limit_error(self, gemini_reranker_client, mock_gemini_client):
222 |         """Test handling of rate limit errors."""
223 |         # Setup mock to raise rate limit error
224 |         mock_gemini_client.aio.models.generate_content.side_effect = Exception(
225 |             'Rate limit exceeded'
226 |         )
227 | 
228 |         query = 'Test query'
229 |         passages = ['Passage 1', 'Passage 2']
230 | 
231 |         with pytest.raises(RateLimitError):
232 |             await gemini_reranker_client.rank(query, passages)
233 | 
234 |     @pytest.mark.asyncio
235 |     async def test_rank_quota_error(self, gemini_reranker_client, mock_gemini_client):
236 |         """Test handling of quota errors."""
237 |         # Setup mock to raise quota error
238 |         mock_gemini_client.aio.models.generate_content.side_effect = Exception('Quota exceeded')
239 | 
240 |         query = 'Test query'
241 |         passages = ['Passage 1', 'Passage 2']
242 | 
243 |         with pytest.raises(RateLimitError):
244 |             await gemini_reranker_client.rank(query, passages)
245 | 
246 |     @pytest.mark.asyncio
247 |     async def test_rank_resource_exhausted_error(self, gemini_reranker_client, mock_gemini_client):
248 |         """Test handling of resource exhausted errors."""
249 |         # Setup mock to raise resource exhausted error
250 |         mock_gemini_client.aio.models.generate_content.side_effect = Exception('resource_exhausted')
251 | 
252 |         query = 'Test query'
253 |         passages = ['Passage 1', 'Passage 2']
254 | 
255 |         with pytest.raises(RateLimitError):
256 |             await gemini_reranker_client.rank(query, passages)
257 | 
258 |     @pytest.mark.asyncio
259 |     async def test_rank_429_error(self, gemini_reranker_client, mock_gemini_client):
260 |         """Test handling of HTTP 429 errors."""
261 |         # Setup mock to raise 429 error
262 |         mock_gemini_client.aio.models.generate_content.side_effect = Exception(
263 |             'HTTP 429 Too Many Requests'
264 |         )
265 | 
266 |         query = 'Test query'
267 |         passages = ['Passage 1', 'Passage 2']
268 | 
269 |         with pytest.raises(RateLimitError):
270 |             await gemini_reranker_client.rank(query, passages)
271 | 
272 |     @pytest.mark.asyncio
273 |     async def test_rank_generic_error(self, gemini_reranker_client, mock_gemini_client):
274 |         """Test handling of generic errors."""
275 |         # Setup mock to raise generic error
276 |         mock_gemini_client.aio.models.generate_content.side_effect = Exception('Generic error')
277 | 
278 |         query = 'Test query'
279 |         passages = ['Passage 1', 'Passage 2']
280 | 
281 |         with pytest.raises(Exception) as exc_info:
282 |             await gemini_reranker_client.rank(query, passages)
283 | 
284 |         assert 'Generic error' in str(exc_info.value)
285 | 
    @pytest.mark.asyncio
    async def test_rank_concurrent_requests(self, gemini_reranker_client, mock_gemini_client):
        """Test that multiple passages are scored concurrently."""
        # Setup mock responses
        mock_responses = [
            create_mock_response('80'),
            create_mock_response('60'),
            create_mock_response('40'),
        ]
        mock_gemini_client.aio.models.generate_content.side_effect = mock_responses

        query = 'Test query'
        passages = ['Passage 1', 'Passage 2', 'Passage 3']

        await gemini_reranker_client.rank(query, passages)

        # Verify that generate_content was called once per passage
        assert mock_gemini_client.aio.models.generate_content.call_count == 3

        # Verify that all calls used the configured model with deterministic,
        # tightly-bounded generation settings (temperature 0, 3 output tokens)
        calls = mock_gemini_client.aio.models.generate_content.call_args_list
        for call in calls:
            args, kwargs = call
            assert kwargs['model'] == gemini_reranker_client.config.model
            assert kwargs['config'].temperature == 0.0
            assert kwargs['config'].max_output_tokens == 3
312 | 
313 |     @pytest.mark.asyncio
314 |     async def test_rank_response_parsing_error(self, gemini_reranker_client, mock_gemini_client):
315 |         """Test handling of response parsing errors."""
316 |         # Setup mock responses that will trigger ValueError during parsing
317 |         mock_responses = [
318 |             create_mock_response('not a number at all'),  # Will fail regex match
319 |             create_mock_response('also invalid text'),  # Will fail regex match
320 |         ]
321 |         mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
322 | 
323 |         query = 'Test query'
324 |         # Use multiple passages to avoid the single passage special case
325 |         passages = ['Passage 1', 'Passage 2']
326 | 
327 |         result = await gemini_reranker_client.rank(query, passages)
328 | 
329 |         # Should handle the error gracefully and assign 0.0 score to both
330 |         assert len(result) == 2
331 |         assert all(score == 0.0 for _, score in result)
332 | 
333 |     @pytest.mark.asyncio
334 |     async def test_rank_empty_response_text(self, gemini_reranker_client, mock_gemini_client):
335 |         """Test handling of empty response text."""
336 |         # Setup mock response with empty text
337 |         mock_response = MagicMock()
338 |         mock_response.text = ''  # Empty string instead of None
339 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
340 | 
341 |         query = 'Test query'
342 |         # Use multiple passages to avoid the single passage special case
343 |         passages = ['Passage 1', 'Passage 2']
344 | 
345 |         result = await gemini_reranker_client.rank(query, passages)
346 | 
347 |         # Should handle empty text gracefully and assign 0.0 score to both
348 |         assert len(result) == 2
349 |         assert all(score == 0.0 for _, score in result)
350 | 
351 | 
# Allow running this module directly (e.g. `python test_gemini_reranker_client.py`)
# instead of through the pytest CLI.
if __name__ == '__main__':
    pytest.main(['-v', 'test_gemini_reranker_client.py'])
354 | 
```

--------------------------------------------------------------------------------
/tests/test_edge_int.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import logging
 18 | import sys
 19 | from datetime import datetime
 20 | 
 21 | import numpy as np
 22 | import pytest
 23 | 
 24 | from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
 25 | from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
 26 | from tests.helpers_test import get_edge_count, get_node_count, group_id
 27 | 
 28 | pytest_plugins = ('pytest_asyncio',)
 29 | 
 30 | 
 31 | def setup_logging():
 32 |     # Create a logger
 33 |     logger = logging.getLogger()
 34 |     logger.setLevel(logging.INFO)  # Set the logging level to INFO
 35 | 
 36 |     # Create console handler and set level to INFO
 37 |     console_handler = logging.StreamHandler(sys.stdout)
 38 |     console_handler.setLevel(logging.INFO)
 39 | 
 40 |     # Create formatter
 41 |     formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 42 | 
 43 |     # Add formatter to console handler
 44 |     console_handler.setFormatter(formatter)
 45 | 
 46 |     # Add console handler to logger
 47 |     logger.addHandler(console_handler)
 48 | 
 49 |     return logger
 50 | 
 51 | 
@pytest.mark.asyncio
async def test_episodic_edge(graph_driver, mock_embedder):
    """Integration test: full CRUD lifecycle of an EpisodicEdge.

    Saves an episodic node and an entity node, links them with an
    EpisodicEdge, exercises every retrieval path (by uuid, by uuids, by
    group ids, and episode lookup via entity node), then deletes the edge
    (by uuid and by uuids) and cleans up both nodes.
    """
    now = datetime.now()

    # Create episodic node
    episode_node = EpisodicNode(
        name='test_episode',
        labels=[],
        created_at=now,
        valid_at=now,
        source=EpisodeType.message,
        source_description='conversation message',
        content='Alice likes Bob',
        entity_edges=[],
        group_id=group_id,
    )
    node_count = await get_node_count(graph_driver, [episode_node.uuid])
    assert node_count == 0
    await episode_node.save(graph_driver)
    node_count = await get_node_count(graph_driver, [episode_node.uuid])
    assert node_count == 1

    # Create entity node
    alice_node = EntityNode(
        name='Alice',
        labels=[],
        created_at=now,
        summary='Alice summary',
        group_id=group_id,
    )
    await alice_node.generate_name_embedding(mock_embedder)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    await alice_node.save(graph_driver)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 1

    # Create episodic to entity edge
    episodic_edge = EpisodicEdge(
        source_node_uuid=episode_node.uuid,
        target_node_uuid=alice_node.uuid,
        created_at=now,
        group_id=group_id,
    )
    edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
    assert edge_count == 0
    await episodic_edge.save(graph_driver)
    edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
    assert edge_count == 1

    # Get edge by uuid
    retrieved = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge.uuid)
    assert retrieved.uuid == episodic_edge.uuid
    assert retrieved.source_node_uuid == episode_node.uuid
    assert retrieved.target_node_uuid == alice_node.uuid
    assert retrieved.created_at == now
    assert retrieved.group_id == group_id

    # Get edge by uuids
    retrieved = await EpisodicEdge.get_by_uuids(graph_driver, [episodic_edge.uuid])
    assert len(retrieved) == 1
    assert retrieved[0].uuid == episodic_edge.uuid
    assert retrieved[0].source_node_uuid == episode_node.uuid
    assert retrieved[0].target_node_uuid == alice_node.uuid
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Get edge by group ids (limit=2 but only one edge exists in this group)
    retrieved = await EpisodicEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
    assert len(retrieved) == 1
    assert retrieved[0].uuid == episodic_edge.uuid
    assert retrieved[0].source_node_uuid == episode_node.uuid
    assert retrieved[0].target_node_uuid == alice_node.uuid
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Get episodic node by entity node uuid (reverse traversal of the edge)
    retrieved = await EpisodicNode.get_by_entity_node_uuid(graph_driver, alice_node.uuid)
    assert len(retrieved) == 1
    assert retrieved[0].uuid == episode_node.uuid
    assert retrieved[0].name == 'test_episode'
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Delete edge by uuid
    await episodic_edge.delete(graph_driver)
    edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
    assert edge_count == 0

    # Delete edge by uuids (re-save first so there is something to delete)
    await episodic_edge.save(graph_driver)
    await episodic_edge.delete_by_uuids(graph_driver, [episodic_edge.uuid])
    edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
    assert edge_count == 0

    # Cleanup nodes
    await episode_node.delete(graph_driver)
    node_count = await get_node_count(graph_driver, [episode_node.uuid])
    assert node_count == 0
    await alice_node.delete(graph_driver)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0

    await graph_driver.close()
156 | 
157 | 
@pytest.mark.asyncio
async def test_entity_edge(graph_driver, mock_embedder):
    """Integration test: full CRUD lifecycle of an EntityEdge.

    Creates two entity nodes connected by an EntityEdge, exercises all
    retrieval paths (uuid, uuids, group ids, node uuid), verifies the fact
    embedding round-trips through the store, and checks that deleting a
    node (by uuid, by uuids, or by group id) cascades to its edges.
    """
    now = datetime.now()

    # Create entity node
    alice_node = EntityNode(
        name='Alice',
        labels=[],
        created_at=now,
        summary='Alice summary',
        group_id=group_id,
    )
    await alice_node.generate_name_embedding(mock_embedder)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    await alice_node.save(graph_driver)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 1

    # Create entity node
    bob_node = EntityNode(
        name='Bob', labels=[], created_at=now, summary='Bob summary', group_id=group_id
    )
    await bob_node.generate_name_embedding(mock_embedder)
    node_count = await get_node_count(graph_driver, [bob_node.uuid])
    assert node_count == 0
    await bob_node.save(graph_driver)
    node_count = await get_node_count(graph_driver, [bob_node.uuid])
    assert node_count == 1

    # Create entity to entity edge
    entity_edge = EntityEdge(
        source_node_uuid=alice_node.uuid,
        target_node_uuid=bob_node.uuid,
        created_at=now,
        name='likes',
        fact='Alice likes Bob',
        episodes=[],
        expired_at=now,
        valid_at=now,
        invalid_at=now,
        group_id=group_id,
    )
    # Keep the embedding returned here to compare after a round-trip below
    edge_embedding = await entity_edge.generate_embedding(mock_embedder)
    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
    assert edge_count == 0
    await entity_edge.save(graph_driver)
    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
    assert edge_count == 1

    # Get edge by uuid
    retrieved = await EntityEdge.get_by_uuid(graph_driver, entity_edge.uuid)
    assert retrieved.uuid == entity_edge.uuid
    assert retrieved.source_node_uuid == alice_node.uuid
    assert retrieved.target_node_uuid == bob_node.uuid
    assert retrieved.created_at == now
    assert retrieved.group_id == group_id

    # Get edge by uuids
    retrieved = await EntityEdge.get_by_uuids(graph_driver, [entity_edge.uuid])
    assert len(retrieved) == 1
    assert retrieved[0].uuid == entity_edge.uuid
    assert retrieved[0].source_node_uuid == alice_node.uuid
    assert retrieved[0].target_node_uuid == bob_node.uuid
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Get edge by group ids
    retrieved = await EntityEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
    assert len(retrieved) == 1
    assert retrieved[0].uuid == entity_edge.uuid
    assert retrieved[0].source_node_uuid == alice_node.uuid
    assert retrieved[0].target_node_uuid == bob_node.uuid
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Get edge by node uuid
    retrieved = await EntityEdge.get_by_node_uuid(graph_driver, alice_node.uuid)
    assert len(retrieved) == 1
    assert retrieved[0].uuid == entity_edge.uuid
    assert retrieved[0].source_node_uuid == alice_node.uuid
    assert retrieved[0].target_node_uuid == bob_node.uuid
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Get fact embedding: stored embedding must match what was generated
    await entity_edge.load_fact_embedding(graph_driver)
    assert np.allclose(entity_edge.fact_embedding, edge_embedding)

    # Delete edge by uuid
    await entity_edge.delete(graph_driver)
    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
    assert edge_count == 0

    # Delete edge by uuids (re-save first so there is something to delete)
    await entity_edge.save(graph_driver)
    await entity_edge.delete_by_uuids(graph_driver, [entity_edge.uuid])
    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
    assert edge_count == 0

    # Deleting node should delete the edge (cascade on node delete)
    await entity_edge.save(graph_driver)
    await alice_node.delete(graph_driver)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
    assert edge_count == 0

    # Deleting node by uuids should delete the edge
    await alice_node.save(graph_driver)
    await entity_edge.save(graph_driver)
    await alice_node.delete_by_uuids(graph_driver, [alice_node.uuid])
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
    assert edge_count == 0

    # Deleting node by group id should delete the edge
    await alice_node.save(graph_driver)
    await entity_edge.save(graph_driver)
    await alice_node.delete_by_group_id(graph_driver, alice_node.group_id)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
    assert edge_count == 0

    # Cleanup nodes
    await alice_node.delete(graph_driver)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    await bob_node.delete(graph_driver)
    node_count = await get_node_count(graph_driver, [bob_node.uuid])
    assert node_count == 0

    await graph_driver.close()
293 | 
294 | 
@pytest.mark.asyncio
async def test_community_edge(graph_driver, mock_embedder):
    """Integration test: full CRUD lifecycle of a CommunityEdge.

    Creates two community nodes (plus an entity node for cleanup coverage),
    links the communities with a CommunityEdge, exercises retrieval by uuid,
    uuids, and group ids, then deletes the edge and all nodes.
    """
    now = datetime.now()

    # Create community node
    community_node_1 = CommunityNode(
        name='test_community_1',
        group_id=group_id,
        summary='Community A summary',
    )
    await community_node_1.generate_name_embedding(mock_embedder)
    node_count = await get_node_count(graph_driver, [community_node_1.uuid])
    assert node_count == 0
    await community_node_1.save(graph_driver)
    node_count = await get_node_count(graph_driver, [community_node_1.uuid])
    assert node_count == 1

    # Create community node
    community_node_2 = CommunityNode(
        name='test_community_2',
        group_id=group_id,
        summary='Community B summary',
    )
    await community_node_2.generate_name_embedding(mock_embedder)
    node_count = await get_node_count(graph_driver, [community_node_2.uuid])
    assert node_count == 0
    await community_node_2.save(graph_driver)
    node_count = await get_node_count(graph_driver, [community_node_2.uuid])
    assert node_count == 1

    # Create entity node
    alice_node = EntityNode(
        name='Alice', labels=[], created_at=now, summary='Alice summary', group_id=group_id
    )
    await alice_node.generate_name_embedding(mock_embedder)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    await alice_node.save(graph_driver)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 1

    # Create community to community edge
    community_edge = CommunityEdge(
        source_node_uuid=community_node_1.uuid,
        target_node_uuid=community_node_2.uuid,
        created_at=now,
        group_id=group_id,
    )
    edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
    assert edge_count == 0
    await community_edge.save(graph_driver)
    edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
    assert edge_count == 1

    # Get edge by uuid
    retrieved = await CommunityEdge.get_by_uuid(graph_driver, community_edge.uuid)
    assert retrieved.uuid == community_edge.uuid
    assert retrieved.source_node_uuid == community_node_1.uuid
    assert retrieved.target_node_uuid == community_node_2.uuid
    assert retrieved.created_at == now
    assert retrieved.group_id == group_id

    # Get edge by uuids
    retrieved = await CommunityEdge.get_by_uuids(graph_driver, [community_edge.uuid])
    assert len(retrieved) == 1
    assert retrieved[0].uuid == community_edge.uuid
    assert retrieved[0].source_node_uuid == community_node_1.uuid
    assert retrieved[0].target_node_uuid == community_node_2.uuid
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Get edge by group ids
    retrieved = await CommunityEdge.get_by_group_ids(graph_driver, [group_id], limit=1)
    assert len(retrieved) == 1
    assert retrieved[0].uuid == community_edge.uuid
    assert retrieved[0].source_node_uuid == community_node_1.uuid
    assert retrieved[0].target_node_uuid == community_node_2.uuid
    assert retrieved[0].created_at == now
    assert retrieved[0].group_id == group_id

    # Delete edge by uuid
    await community_edge.delete(graph_driver)
    edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
    assert edge_count == 0

    # Delete edge by uuids (re-save first so there is something to delete)
    await community_edge.save(graph_driver)
    await community_edge.delete_by_uuids(graph_driver, [community_edge.uuid])
    edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
    assert edge_count == 0

    # Cleanup nodes
    await alice_node.delete(graph_driver)
    node_count = await get_node_count(graph_driver, [alice_node.uuid])
    assert node_count == 0
    await community_node_1.delete(graph_driver)
    node_count = await get_node_count(graph_driver, [community_node_1.uuid])
    assert node_count == 0
    await community_node_2.delete(graph_driver)
    node_count = await get_node_count(graph_driver, [community_node_2.uuid])
    assert node_count == 0

    await graph_driver.close()
398 | 
```

--------------------------------------------------------------------------------
/tests/embedder/test_gemini.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | # Running tests: pytest -xvs tests/embedder/test_gemini.py
 18 | 
 19 | from collections.abc import Generator
 20 | from typing import Any
 21 | from unittest.mock import AsyncMock, MagicMock, patch
 22 | 
 23 | import pytest
 24 | from embedder_fixtures import create_embedding_values
 25 | 
 26 | from graphiti_core.embedder.gemini import (
 27 |     DEFAULT_EMBEDDING_MODEL,
 28 |     GeminiEmbedder,
 29 |     GeminiEmbedderConfig,
 30 | )
 31 | 
 32 | 
 33 | def create_gemini_embedding(multiplier: float = 0.1, dimension: int = 1536) -> MagicMock:
 34 |     """Create a mock Gemini embedding with specified value multiplier and dimension."""
 35 |     mock_embedding = MagicMock()
 36 |     mock_embedding.values = create_embedding_values(multiplier, dimension)
 37 |     return mock_embedding
 38 | 
 39 | 
 40 | @pytest.fixture
 41 | def mock_gemini_response() -> MagicMock:
 42 |     """Create a mock Gemini embeddings response."""
 43 |     mock_result = MagicMock()
 44 |     mock_result.embeddings = [create_gemini_embedding()]
 45 |     return mock_result
 46 | 
 47 | 
 48 | @pytest.fixture
 49 | def mock_gemini_batch_response() -> MagicMock:
 50 |     """Create a mock Gemini batch embeddings response."""
 51 |     mock_result = MagicMock()
 52 |     mock_result.embeddings = [
 53 |         create_gemini_embedding(0.1),
 54 |         create_gemini_embedding(0.2),
 55 |         create_gemini_embedding(0.3),
 56 |     ]
 57 |     return mock_result
 58 | 
 59 | 
 60 | @pytest.fixture
 61 | def mock_gemini_client() -> Generator[Any, Any, None]:
 62 |     """Create a mocked Gemini client."""
 63 |     with patch('google.genai.Client') as mock_client:
 64 |         mock_instance = mock_client.return_value
 65 |         mock_instance.aio = MagicMock()
 66 |         mock_instance.aio.models = MagicMock()
 67 |         mock_instance.aio.models.embed_content = AsyncMock()
 68 |         yield mock_instance
 69 | 
 70 | 
 71 | @pytest.fixture
 72 | def gemini_embedder(mock_gemini_client: Any) -> GeminiEmbedder:
 73 |     """Create a GeminiEmbedder with a mocked client."""
 74 |     config = GeminiEmbedderConfig(api_key='test_api_key')
 75 |     client = GeminiEmbedder(config=config)
 76 |     client.client = mock_gemini_client
 77 |     return client
 78 | 
 79 | 
 80 | class TestGeminiEmbedderInitialization:
 81 |     """Tests for GeminiEmbedder initialization."""
 82 | 
 83 |     @patch('google.genai.Client')
 84 |     def test_init_with_config(self, mock_client):
 85 |         """Test initialization with a config object."""
 86 |         config = GeminiEmbedderConfig(
 87 |             api_key='test_api_key', embedding_model='custom-model', embedding_dim=768
 88 |         )
 89 |         embedder = GeminiEmbedder(config=config)
 90 | 
 91 |         assert embedder.config == config
 92 |         assert embedder.config.embedding_model == 'custom-model'
 93 |         assert embedder.config.api_key == 'test_api_key'
 94 |         assert embedder.config.embedding_dim == 768
 95 | 
 96 |     @patch('google.genai.Client')
 97 |     def test_init_without_config(self, mock_client):
 98 |         """Test initialization without a config uses defaults."""
 99 |         embedder = GeminiEmbedder()
100 | 
101 |         assert embedder.config is not None
102 |         assert embedder.config.embedding_model == DEFAULT_EMBEDDING_MODEL
103 | 
104 |     @patch('google.genai.Client')
105 |     def test_init_with_partial_config(self, mock_client):
106 |         """Test initialization with partial config."""
107 |         config = GeminiEmbedderConfig(api_key='test_api_key')
108 |         embedder = GeminiEmbedder(config=config)
109 | 
110 |         assert embedder.config.api_key == 'test_api_key'
111 |         assert embedder.config.embedding_model == DEFAULT_EMBEDDING_MODEL
112 | 
113 | 
class TestGeminiEmbedderCreate:
    """Tests for GeminiEmbedder create method."""

    @pytest.mark.asyncio
    async def test_create_calls_api_correctly(
        self,
        gemini_embedder: GeminiEmbedder,
        mock_gemini_client: Any,
        mock_gemini_response: MagicMock,
    ) -> None:
        """Test that create method correctly calls the API and processes the response."""
        # Setup
        mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response

        # Call method
        result = await gemini_embedder.create('Test input')

        # Verify API is called with correct parameters
        mock_gemini_client.aio.models.embed_content.assert_called_once()
        _, kwargs = mock_gemini_client.aio.models.embed_content.call_args
        assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL
        # NOTE: a single string input is wrapped in a one-element list by create()
        assert kwargs['contents'] == ['Test input']

        # Verify result is processed correctly (first embedding's values)
        assert result == mock_gemini_response.embeddings[0].values

    @pytest.mark.asyncio
    @patch('google.genai.Client')
    async def test_create_with_custom_model(
        self, mock_client_class, mock_gemini_client: Any, mock_gemini_response: MagicMock
    ) -> None:
        """Test create method with custom embedding model."""
        # Setup embedder with custom model
        config = GeminiEmbedderConfig(api_key='test_api_key', embedding_model='custom-model')
        embedder = GeminiEmbedder(config=config)
        embedder.client = mock_gemini_client
        mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response

        # Call method
        await embedder.create('Test input')

        # Verify custom model is used in the API call
        _, kwargs = mock_gemini_client.aio.models.embed_content.call_args
        assert kwargs['model'] == 'custom-model'

    @pytest.mark.asyncio
    @patch('google.genai.Client')
    async def test_create_with_custom_dimension(
        self, mock_client_class, mock_gemini_client: Any
    ) -> None:
        """Test create method with custom embedding dimension."""
        # Setup embedder with custom dimension
        config = GeminiEmbedderConfig(api_key='test_api_key', embedding_dim=768)
        embedder = GeminiEmbedder(config=config)
        embedder.client = mock_gemini_client

        # Setup mock response with custom dimension
        mock_response = MagicMock()
        mock_response.embeddings = [create_gemini_embedding(0.1, 768)]
        mock_gemini_client.aio.models.embed_content.return_value = mock_response

        # Call method
        result = await embedder.create('Test input')

        # Verify custom dimension is passed through the request config
        _, kwargs = mock_gemini_client.aio.models.embed_content.call_args
        assert kwargs['config'].output_dimensionality == 768

        # Verify result has correct dimension
        assert len(result) == 768

    @pytest.mark.asyncio
    async def test_create_with_different_input_types(
        self,
        gemini_embedder: GeminiEmbedder,
        mock_gemini_client: Any,
        mock_gemini_response: MagicMock,
    ) -> None:
        """Test create method with different input types."""
        mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response

        # Test with string
        await gemini_embedder.create('Test string')

        # Test with list of strings
        await gemini_embedder.create(['Test', 'List'])

        # Test with iterable of integers
        await gemini_embedder.create([1, 2, 3])

        # Verify all calls were made (one per input above)
        assert mock_gemini_client.aio.models.embed_content.call_count == 3

    @pytest.mark.asyncio
    async def test_create_no_embeddings_error(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """Test create method handling of no embeddings response."""
        # Setup mock response with no embeddings
        mock_response = MagicMock()
        mock_response.embeddings = []
        mock_gemini_client.aio.models.embed_content.return_value = mock_response

        # Call method and expect exception
        with pytest.raises(ValueError) as exc_info:
            await gemini_embedder.create('Test input')

        assert 'No embeddings returned from Gemini API in create()' in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_create_no_values_error(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """Test create method handling of embeddings with no values."""
        # Setup mock response with embedding but no values — must raise the
        # same error as a missing embedding
        mock_embedding = MagicMock()
        mock_embedding.values = None
        mock_response = MagicMock()
        mock_response.embeddings = [mock_embedding]
        mock_gemini_client.aio.models.embed_content.return_value = mock_response

        # Call method and expect exception
        with pytest.raises(ValueError) as exc_info:
            await gemini_embedder.create('Test input')

        assert 'No embeddings returned from Gemini API in create()' in str(exc_info.value)
240 | 
241 | 
class TestGeminiEmbedderCreateBatch:
    """Unit tests covering GeminiEmbedder.create_batch."""

    @pytest.mark.asyncio
    async def test_create_batch_processes_multiple_inputs(
        self,
        gemini_embedder: GeminiEmbedder,
        mock_gemini_client: Any,
        mock_gemini_batch_response: MagicMock,
    ) -> None:
        """create_batch should forward the whole batch and return one vector per input."""
        mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_batch_response
        batch = ['Input 1', 'Input 2', 'Input 3']

        embeddings = await gemini_embedder.create_batch(batch)

        # The API should receive every input in a single call.
        mock_gemini_client.aio.models.embed_content.assert_called_once()
        _, call_kwargs = mock_gemini_client.aio.models.embed_content.call_args
        assert call_kwargs['model'] == DEFAULT_EMBEDDING_MODEL
        assert call_kwargs['contents'] == batch

        # Each mock embedding's values should come back, in order.
        expected = [emb.values for emb in mock_gemini_batch_response.embeddings[:3]]
        assert len(embeddings) == 3
        assert embeddings == expected

    @pytest.mark.asyncio
    async def test_create_batch_single_input(
        self,
        gemini_embedder: GeminiEmbedder,
        mock_gemini_client: Any,
        mock_gemini_response: MagicMock,
    ) -> None:
        """A one-element batch should yield exactly one embedding."""
        mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response

        embeddings = await gemini_embedder.create_batch(['Single input'])

        assert len(embeddings) == 1
        assert embeddings[0] == mock_gemini_response.embeddings[0].values

    @pytest.mark.asyncio
    async def test_create_batch_empty_input(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """An empty batch short-circuits: no API call, empty result."""
        empty_response = MagicMock()
        empty_response.embeddings = []
        mock_gemini_client.aio.models.embed_content.return_value = empty_response

        assert await gemini_embedder.create_batch([]) == []
        mock_gemini_client.aio.models.embed_content.assert_not_called()

    @pytest.mark.asyncio
    async def test_create_batch_no_embeddings_error(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """A response carrying no embeddings should surface as a ValueError."""
        empty_response = MagicMock()
        empty_response.embeddings = []
        mock_gemini_client.aio.models.embed_content.return_value = empty_response

        with pytest.raises(ValueError) as exc_info:
            await gemini_embedder.create_batch(['Input 1', 'Input 2'])

        assert 'No embeddings returned from Gemini API' in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_create_batch_empty_values_error(
        self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
    ) -> None:
        """An embedding with empty values should raise during per-item reprocessing."""
        good_embedding = MagicMock()
        good_embedding.values = [0.1, 0.2, 0.3]
        bad_embedding = MagicMock()
        bad_embedding.values = None  # simulates a missing/empty vector

        # Response served for the initial batch request.
        batch_response = MagicMock()
        batch_response.embeddings = [good_embedding, bad_embedding]

        # Responses served when the embedder retries each item individually.
        retry_response_ok = MagicMock()
        retry_response_ok.embeddings = [good_embedding]
        retry_response_bad = MagicMock()
        retry_response_bad.embeddings = [bad_embedding]

        # First call handles the batch; the next two handle the individual items.
        mock_gemini_client.aio.models.embed_content.side_effect = [
            batch_response,
            retry_response_ok,
            retry_response_bad,
        ]

        with pytest.raises(ValueError) as exc_info:
            await gemini_embedder.create_batch(['Input 1', 'Input 2'])

        assert 'Empty embedding values returned' in str(exc_info.value)

    @pytest.mark.asyncio
    @patch('google.genai.Client')
    async def test_create_batch_with_custom_model_and_dimension(
        self, mock_client_class, mock_gemini_client: Any
    ) -> None:
        """Custom model name and output dimensionality should reach the API call."""
        embedder = GeminiEmbedder(
            config=GeminiEmbedderConfig(
                api_key='test_api_key', embedding_model='custom-batch-model', embedding_dim=512
            )
        )
        embedder.client = mock_gemini_client

        response = MagicMock()
        response.embeddings = [
            create_gemini_embedding(0.1, 512),
            create_gemini_embedding(0.2, 512),
        ]
        mock_gemini_client.aio.models.embed_content.return_value = response

        embeddings = await embedder.create_batch(['Input 1', 'Input 2'])

        # The overridden model and dimension must be forwarded.
        _, call_kwargs = mock_gemini_client.aio.models.embed_content.call_args
        assert call_kwargs['model'] == 'custom-batch-model'
        assert call_kwargs['config'].output_dimensionality == 512

        # Every returned vector should have the requested dimensionality.
        assert len(embeddings) == 2
        assert all(len(vector) == 512 for vector in embeddings)
392 | 
393 | 
394 | if __name__ == '__main__':  # allow running this test module directly with `python`
395 |     pytest.main(['-xvs', __file__])  # -x stop on first failure, -v verbose, -s no capture
396 | 
```

--------------------------------------------------------------------------------
/tests/driver/test_falkordb_driver.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import os
 18 | import unittest
 19 | from datetime import datetime, timezone
 20 | from unittest.mock import AsyncMock, MagicMock, patch
 21 | 
 22 | import pytest
 23 | 
 24 | from graphiti_core.driver.driver import GraphProvider
 25 | 
 26 | try:
 27 |     from graphiti_core.driver.falkordb_driver import FalkorDriver, FalkorDriverSession
 28 | 
 29 |     HAS_FALKORDB = True
 30 | except ImportError:
 31 |     FalkorDriver = None
 32 |     HAS_FALKORDB = False
 33 | 
 34 | 
 35 | class TestFalkorDriver:
 36 |     """Comprehensive test suite for FalkorDB driver."""
 37 | 
 38 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
 39 |     def setup_method(self):
 40 |         """Set up test fixtures."""
 41 |         self.mock_client = MagicMock()
 42 |         with patch('graphiti_core.driver.falkordb_driver.FalkorDB'):
 43 |             self.driver = FalkorDriver()
 44 |         self.driver.client = self.mock_client
 45 | 
 46 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
 47 |     def test_init_with_connection_params(self):
 48 |         """Test initialization with connection parameters."""
 49 |         with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as mock_falkor_db:
 50 |             driver = FalkorDriver(
 51 |                 host='test-host', port='1234', username='test-user', password='test-pass'
 52 |             )
 53 |             assert driver.provider == GraphProvider.FALKORDB
 54 |             mock_falkor_db.assert_called_once_with(
 55 |                 host='test-host', port='1234', username='test-user', password='test-pass'
 56 |             )
 57 | 
 58 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
 59 |     def test_init_with_falkor_db_instance(self):
 60 |         """Test initialization with a FalkorDB instance."""
 61 |         with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as mock_falkor_db_class:
 62 |             mock_falkor_db = MagicMock()
 63 |             driver = FalkorDriver(falkor_db=mock_falkor_db)
 64 |             assert driver.provider == GraphProvider.FALKORDB
 65 |             assert driver.client is mock_falkor_db
 66 |             mock_falkor_db_class.assert_not_called()
 67 | 
 68 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
 69 |     def test_provider(self):
 70 |         """Test driver provider identification."""
 71 |         assert self.driver.provider == GraphProvider.FALKORDB
 72 | 
 73 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
 74 |     def test_get_graph_with_name(self):
 75 |         """Test _get_graph with specific graph name."""
 76 |         mock_graph = MagicMock()
 77 |         self.mock_client.select_graph.return_value = mock_graph
 78 | 
 79 |         result = self.driver._get_graph('test_graph')
 80 | 
 81 |         self.mock_client.select_graph.assert_called_once_with('test_graph')
 82 |         assert result is mock_graph
 83 | 
 84 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
 85 |     def test_get_graph_with_none_defaults_to_default_database(self):
 86 |         """Test _get_graph with None defaults to default_db."""
 87 |         mock_graph = MagicMock()
 88 |         self.mock_client.select_graph.return_value = mock_graph
 89 | 
 90 |         result = self.driver._get_graph(None)
 91 | 
 92 |         self.mock_client.select_graph.assert_called_once_with('default_db')
 93 |         assert result is mock_graph
 94 | 
 95 |     @pytest.mark.asyncio
 96 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
 97 |     async def test_execute_query_success(self):
 98 |         """Test successful query execution."""
 99 |         mock_graph = MagicMock()
100 |         mock_result = MagicMock()
101 |         mock_result.header = [('col1', 'column1'), ('col2', 'column2')]
102 |         mock_result.result_set = [['row1col1', 'row1col2']]
103 |         mock_graph.query = AsyncMock(return_value=mock_result)
104 |         self.mock_client.select_graph.return_value = mock_graph
105 | 
106 |         result = await self.driver.execute_query('MATCH (n) RETURN n', param1='value1')
107 | 
108 |         mock_graph.query.assert_called_once_with('MATCH (n) RETURN n', {'param1': 'value1'})
109 | 
110 |         result_set, header, summary = result
111 |         assert result_set == [{'column1': 'row1col1', 'column2': 'row1col2'}]
112 |         assert header == ['column1', 'column2']
113 |         assert summary is None
114 | 
115 |     @pytest.mark.asyncio
116 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
117 |     async def test_execute_query_handles_index_already_exists_error(self):
118 |         """Test handling of 'already indexed' error."""
119 |         mock_graph = MagicMock()
120 |         mock_graph.query = AsyncMock(side_effect=Exception('Index already indexed'))
121 |         self.mock_client.select_graph.return_value = mock_graph
122 | 
123 |         with patch('graphiti_core.driver.falkordb_driver.logger') as mock_logger:
124 |             result = await self.driver.execute_query('CREATE INDEX ...')
125 | 
126 |             mock_logger.info.assert_called_once()
127 |             assert result is None
128 | 
129 |     @pytest.mark.asyncio
130 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
131 |     async def test_execute_query_propagates_other_exceptions(self):
132 |         """Test that other exceptions are properly propagated."""
133 |         mock_graph = MagicMock()
134 |         mock_graph.query = AsyncMock(side_effect=Exception('Other error'))
135 |         self.mock_client.select_graph.return_value = mock_graph
136 | 
137 |         with patch('graphiti_core.driver.falkordb_driver.logger') as mock_logger:
138 |             with pytest.raises(Exception, match='Other error'):
139 |                 await self.driver.execute_query('INVALID QUERY')
140 | 
141 |             mock_logger.error.assert_called_once()
142 | 
143 |     @pytest.mark.asyncio
144 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
145 |     async def test_execute_query_converts_datetime_parameters(self):
146 |         """Test that datetime objects in kwargs are converted to ISO strings."""
147 |         mock_graph = MagicMock()
148 |         mock_result = MagicMock()
149 |         mock_result.header = []
150 |         mock_result.result_set = []
151 |         mock_graph.query = AsyncMock(return_value=mock_result)
152 |         self.mock_client.select_graph.return_value = mock_graph
153 | 
154 |         test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
155 | 
156 |         await self.driver.execute_query(
157 |             'CREATE (n:Node) SET n.created_at = $created_at', created_at=test_datetime
158 |         )
159 | 
160 |         call_args = mock_graph.query.call_args[0]
161 |         assert call_args[1]['created_at'] == test_datetime.isoformat()
162 | 
163 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
164 |     def test_session_creation(self):
165 |         """Test session creation with specific database."""
166 |         mock_graph = MagicMock()
167 |         self.mock_client.select_graph.return_value = mock_graph
168 | 
169 |         session = self.driver.session()
170 | 
171 |         assert isinstance(session, FalkorDriverSession)
172 |         assert session.graph is mock_graph
173 | 
174 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
175 |     def test_session_creation_with_none_uses_default_database(self):
176 |         """Test session creation with None uses default database."""
177 |         mock_graph = MagicMock()
178 |         self.mock_client.select_graph.return_value = mock_graph
179 | 
180 |         session = self.driver.session()
181 | 
182 |         assert isinstance(session, FalkorDriverSession)
183 | 
184 |     @pytest.mark.asyncio
185 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
186 |     async def test_close_calls_connection_close(self):
187 |         """Test driver close method calls connection close."""
188 |         mock_connection = MagicMock()
189 |         mock_connection.close = AsyncMock()
190 |         self.mock_client.connection = mock_connection
191 | 
192 |         # Ensure hasattr checks work correctly
193 |         del self.mock_client.aclose  # Remove aclose if it exists
194 | 
195 |         with patch('builtins.hasattr') as mock_hasattr:
196 |             # hasattr(self.client, 'aclose') returns False
197 |             # hasattr(self.client.connection, 'aclose') returns False
198 |             # hasattr(self.client.connection, 'close') returns True
199 |             mock_hasattr.side_effect = lambda obj, attr: (
200 |                 attr == 'close' and obj is mock_connection
201 |             )
202 | 
203 |             await self.driver.close()
204 | 
205 |         mock_connection.close.assert_called_once()
206 | 
207 |     @pytest.mark.asyncio
208 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
209 |     async def test_delete_all_indexes(self):
210 |         """Test delete_all_indexes method."""
211 |         with patch.object(self.driver, 'execute_query', new_callable=AsyncMock) as mock_execute:
212 |             # Return None to simulate no indexes found
213 |             mock_execute.return_value = None
214 | 
215 |             await self.driver.delete_all_indexes()
216 | 
217 |             mock_execute.assert_called_once_with('CALL db.indexes()')
218 | 
219 | 
class TestFalkorDriverSession:
    """Unit tests for FalkorDriverSession behavior against a mocked graph."""

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def setup_method(self):
        """Create a session bound to a mock graph."""
        self.mock_graph = MagicMock()
        self.session = FalkorDriverSession(self.mock_graph)

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_session_async_context_manager(self):
        """Entering the session as an async context manager yields itself."""
        async with self.session as entered:
            assert entered is self.session

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_close_method(self):
        """close() must be callable without raising."""
        await self.session.close()

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_execute_write_passes_session_and_args(self):
        """execute_write forwards the session plus positional and keyword args."""

        async def probe(session, *args, **kwargs):
            assert session is self.session
            assert args == ('arg1', 'arg2')
            assert kwargs == {'key': 'value'}
            return 'result'

        outcome = await self.session.execute_write(probe, 'arg1', 'arg2', key='value')
        assert outcome == 'result'

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_run_single_query_with_parameters(self):
        """run() with a single query string passes the kwargs as parameters."""
        self.mock_graph.query = AsyncMock()

        await self.session.run('MATCH (n) RETURN n', param1='value1', param2='value2')

        self.mock_graph.query.assert_called_once_with(
            'MATCH (n) RETURN n', {'param1': 'value1', 'param2': 'value2'}
        )

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_run_multiple_queries_as_list(self):
        """run() with a list of (query, params) pairs issues one call per entry."""
        self.mock_graph.query = AsyncMock()

        await self.session.run(
            [
                ('MATCH (n) RETURN n', {'param1': 'value1'}),
                ('CREATE (n:Node)', {'param2': 'value2'}),
            ]
        )

        assert self.mock_graph.query.call_count == 2
        recorded = self.mock_graph.query.call_args_list
        assert recorded[0][0] == ('MATCH (n) RETURN n', {'param1': 'value1'})
        assert recorded[1][0] == ('CREATE (n:Node)', {'param2': 'value2'})

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_run_converts_datetime_objects_to_iso_strings(self):
        """run() serializes datetime parameters to ISO-8601 strings."""
        self.mock_graph.query = AsyncMock()
        moment = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)

        await self.session.run(
            'CREATE (n:Node) SET n.created_at = $created_at', created_at=moment
        )

        self.mock_graph.query.assert_called_once()
        passed_params = self.mock_graph.query.call_args[0][1]
        assert passed_params['created_at'] == moment.isoformat()
300 | 
301 | 
class TestDatetimeConversion:
    """Tests for the convert_datetimes_to_strings helper."""

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_convert_datetime_dict(self):
        """datetimes nested inside dictionaries are converted recursively."""
        from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings

        moment = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        payload = {
            'string_val': 'test',
            'datetime_val': moment,
            'nested_dict': {'nested_datetime': moment, 'nested_string': 'nested_test'},
        }

        converted = convert_datetimes_to_strings(payload)

        assert converted['string_val'] == 'test'
        assert converted['datetime_val'] == moment.isoformat()
        inner = converted['nested_dict']
        assert inner['nested_datetime'] == moment.isoformat()
        assert inner['nested_string'] == 'nested_test'

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_convert_datetime_list_and_tuple(self):
        """Sequences keep their container type while datetimes are converted."""
        from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings

        moment = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)

        # Lists are converted element-wise, including nested lists.
        converted_list = convert_datetimes_to_strings(['test', moment, ['nested', moment]])
        assert converted_list[0] == 'test'
        assert converted_list[1] == moment.isoformat()
        assert converted_list[2][1] == moment.isoformat()

        # Tuples stay tuples after conversion.
        converted_tuple = convert_datetimes_to_strings(('test', moment))
        assert isinstance(converted_tuple, tuple)
        assert converted_tuple[0] == 'test'
        assert converted_tuple[1] == moment.isoformat()

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_convert_single_datetime(self):
        """A bare datetime is converted to its ISO-8601 string."""
        from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings

        moment = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
        assert convert_datetimes_to_strings(moment) == moment.isoformat()

    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    def test_convert_other_types_unchanged(self):
        """Values that are not datetimes pass through untouched."""
        from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings

        assert convert_datetimes_to_strings('string') == 'string'
        assert convert_datetimes_to_strings(123) == 123
        assert convert_datetimes_to_strings(None) is None
        assert convert_datetimes_to_strings(True) is True
363 | 
364 | 
# Lightweight integration check; skipped whenever a live FalkorDB is unreachable.
class TestFalkorDriverIntegration:
    """Smoke test that exercises the driver against a real FalkorDB server."""

    @pytest.mark.asyncio
    @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
    async def test_basic_integration_with_real_falkordb(self):
        """Round-trip a trivial query through a real FalkorDB instance."""
        pytest.importorskip('falkordb')

        # Connection target is configurable via environment variables.
        host = os.getenv('FALKORDB_HOST', 'localhost')
        port = os.getenv('FALKORDB_PORT', '6379')

        try:
            driver = FalkorDriver(host=host, port=port)

            outcome = await driver.execute_query('RETURN 1 as test')
            assert outcome is not None

            records, header, summary = outcome
            assert header == ['test']
            assert records == [{'test': 1}]

            await driver.close()

        except Exception as e:
            pytest.skip(f'FalkorDB not available for integration test: {e}')
393 | 
```

--------------------------------------------------------------------------------
/mcp_server/src/services/factories.py:
--------------------------------------------------------------------------------

```python
  1 | """Factory classes for creating LLM, Embedder, and Database clients."""
  2 | 
  3 | from openai import AsyncAzureOpenAI
  4 | 
  5 | from config.schema import (
  6 |     DatabaseConfig,
  7 |     EmbedderConfig,
  8 |     LLMConfig,
  9 | )
 10 | 
 11 | # Try to import FalkorDriver if available
 12 | try:
 13 |     from graphiti_core.driver.falkordb_driver import FalkorDriver  # noqa: F401
 14 | 
 15 |     HAS_FALKOR = True
 16 | except ImportError:
 17 |     HAS_FALKOR = False
 18 | 
 19 | # Kuzu support removed - FalkorDB is now the default
 20 | from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
 21 | from graphiti_core.llm_client import LLMClient, OpenAIClient
 22 | from graphiti_core.llm_client.config import LLMConfig as GraphitiLLMConfig
 23 | 
 24 | # Try to import additional providers if available
 25 | try:
 26 |     from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient
 27 | 
 28 |     HAS_AZURE_EMBEDDER = True
 29 | except ImportError:
 30 |     HAS_AZURE_EMBEDDER = False
 31 | 
 32 | try:
 33 |     from graphiti_core.embedder.gemini import GeminiEmbedder
 34 | 
 35 |     HAS_GEMINI_EMBEDDER = True
 36 | except ImportError:
 37 |     HAS_GEMINI_EMBEDDER = False
 38 | 
 39 | try:
 40 |     from graphiti_core.embedder.voyage import VoyageAIEmbedder
 41 | 
 42 |     HAS_VOYAGE_EMBEDDER = True
 43 | except ImportError:
 44 |     HAS_VOYAGE_EMBEDDER = False
 45 | 
 46 | try:
 47 |     from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
 48 | 
 49 |     HAS_AZURE_LLM = True
 50 | except ImportError:
 51 |     HAS_AZURE_LLM = False
 52 | 
 53 | try:
 54 |     from graphiti_core.llm_client.anthropic_client import AnthropicClient
 55 | 
 56 |     HAS_ANTHROPIC = True
 57 | except ImportError:
 58 |     HAS_ANTHROPIC = False
 59 | 
 60 | try:
 61 |     from graphiti_core.llm_client.gemini_client import GeminiClient
 62 | 
 63 |     HAS_GEMINI = True
 64 | except ImportError:
 65 |     HAS_GEMINI = False
 66 | 
 67 | try:
 68 |     from graphiti_core.llm_client.groq_client import GroqClient
 69 | 
 70 |     HAS_GROQ = True
 71 | except ImportError:
 72 |     HAS_GROQ = False
 73 | from utils.utils import create_azure_credential_token_provider
 74 | 
 75 | 
 76 | def _validate_api_key(provider_name: str, api_key: str | None, logger) -> str:
 77 |     """Validate API key is present.
 78 | 
 79 |     Args:
 80 |         provider_name: Name of the provider (e.g., 'OpenAI', 'Anthropic')
 81 |         api_key: The API key to validate
 82 |         logger: Logger instance for output
 83 | 
 84 |     Returns:
 85 |         The validated API key
 86 | 
 87 |     Raises:
 88 |         ValueError: If API key is None or empty
 89 |     """
 90 |     if not api_key:
 91 |         raise ValueError(
 92 |             f'{provider_name} API key is not configured. Please set the appropriate environment variable.'
 93 |         )
 94 | 
 95 |     logger.info(f'Creating {provider_name} client')
 96 | 
 97 |     return api_key
 98 | 
 99 | 
100 | class LLMClientFactory:
101 |     """Factory for creating LLM clients based on configuration."""
102 | 
103 |     @staticmethod
104 |     def create(config: LLMConfig) -> LLMClient:
105 |         """Create an LLM client based on the configured provider."""
106 |         import logging
107 | 
108 |         logger = logging.getLogger(__name__)
109 | 
110 |         provider = config.provider.lower()
111 | 
112 |         match provider:
113 |             case 'openai':
114 |                 if not config.providers.openai:
115 |                     raise ValueError('OpenAI provider configuration not found')
116 | 
117 |                 api_key = config.providers.openai.api_key
118 |                 _validate_api_key('OpenAI', api_key, logger)
119 | 
120 |                 from graphiti_core.llm_client.config import LLMConfig as CoreLLMConfig
121 | 
122 |                 # Determine appropriate small model based on main model type
123 |                 is_reasoning_model = (
124 |                     config.model.startswith('gpt-5')
125 |                     or config.model.startswith('o1')
126 |                     or config.model.startswith('o3')
127 |                 )
128 |                 small_model = (
129 |                     'gpt-5-nano' if is_reasoning_model else 'gpt-4.1-mini'
130 |                 )  # Use reasoning model for small tasks if main model is reasoning
131 | 
132 |                 llm_config = CoreLLMConfig(
133 |                     api_key=api_key,
134 |                     model=config.model,
135 |                     small_model=small_model,
136 |                     temperature=config.temperature,
137 |                     max_tokens=config.max_tokens,
138 |                 )
139 | 
140 |                 # Only pass reasoning/verbosity parameters for reasoning models (gpt-5 family)
141 |                 if is_reasoning_model:
142 |                     return OpenAIClient(config=llm_config, reasoning='minimal', verbosity='low')
143 |                 else:
144 |                     # For non-reasoning models, explicitly pass None to disable these parameters
145 |                     return OpenAIClient(config=llm_config, reasoning=None, verbosity=None)
146 | 
147 |             case 'azure_openai':
148 |                 if not HAS_AZURE_LLM:
149 |                     raise ValueError(
150 |                         'Azure OpenAI LLM client not available in current graphiti-core version'
151 |                     )
152 |                 if not config.providers.azure_openai:
153 |                     raise ValueError('Azure OpenAI provider configuration not found')
154 |                 azure_config = config.providers.azure_openai
155 | 
156 |                 if not azure_config.api_url:
157 |                     raise ValueError('Azure OpenAI API URL is required')
158 | 
159 |                 # Handle Azure AD authentication if enabled
160 |                 api_key: str | None = None
161 |                 azure_ad_token_provider = None
162 |                 if azure_config.use_azure_ad:
163 |                     logger.info('Creating Azure OpenAI LLM client with Azure AD authentication')
164 |                     azure_ad_token_provider = create_azure_credential_token_provider()
165 |                 else:
166 |                     api_key = azure_config.api_key
167 |                     _validate_api_key('Azure OpenAI', api_key, logger)
168 | 
169 |                 # Create the Azure OpenAI client first
170 |                 azure_client = AsyncAzureOpenAI(
171 |                     api_key=api_key,
172 |                     azure_endpoint=azure_config.api_url,
173 |                     api_version=azure_config.api_version,
174 |                     azure_deployment=azure_config.deployment_name,
175 |                     azure_ad_token_provider=azure_ad_token_provider,
176 |                 )
177 | 
178 |                 # Then create the LLMConfig
179 |                 from graphiti_core.llm_client.config import LLMConfig as CoreLLMConfig
180 | 
181 |                 llm_config = CoreLLMConfig(
182 |                     api_key=api_key,
183 |                     base_url=azure_config.api_url,
184 |                     model=config.model,
185 |                     temperature=config.temperature,
186 |                     max_tokens=config.max_tokens,
187 |                 )
188 | 
189 |                 return AzureOpenAILLMClient(
190 |                     azure_client=azure_client,
191 |                     config=llm_config,
192 |                     max_tokens=config.max_tokens,
193 |                 )
194 | 
195 |             case 'anthropic':
196 |                 if not HAS_ANTHROPIC:
197 |                     raise ValueError(
198 |                         'Anthropic client not available in current graphiti-core version'
199 |                     )
200 |                 if not config.providers.anthropic:
201 |                     raise ValueError('Anthropic provider configuration not found')
202 | 
203 |                 api_key = config.providers.anthropic.api_key
204 |                 _validate_api_key('Anthropic', api_key, logger)
205 | 
206 |                 llm_config = GraphitiLLMConfig(
207 |                     api_key=api_key,
208 |                     model=config.model,
209 |                     temperature=config.temperature,
210 |                     max_tokens=config.max_tokens,
211 |                 )
212 |                 return AnthropicClient(config=llm_config)
213 | 
214 |             case 'gemini':
215 |                 if not HAS_GEMINI:
216 |                     raise ValueError('Gemini client not available in current graphiti-core version')
217 |                 if not config.providers.gemini:
218 |                     raise ValueError('Gemini provider configuration not found')
219 | 
220 |                 api_key = config.providers.gemini.api_key
221 |                 _validate_api_key('Gemini', api_key, logger)
222 | 
223 |                 llm_config = GraphitiLLMConfig(
224 |                     api_key=api_key,
225 |                     model=config.model,
226 |                     temperature=config.temperature,
227 |                     max_tokens=config.max_tokens,
228 |                 )
229 |                 return GeminiClient(config=llm_config)
230 | 
231 |             case 'groq':
232 |                 if not HAS_GROQ:
233 |                     raise ValueError('Groq client not available in current graphiti-core version')
234 |                 if not config.providers.groq:
235 |                     raise ValueError('Groq provider configuration not found')
236 | 
237 |                 api_key = config.providers.groq.api_key
238 |                 _validate_api_key('Groq', api_key, logger)
239 | 
240 |                 llm_config = GraphitiLLMConfig(
241 |                     api_key=api_key,
242 |                     base_url=config.providers.groq.api_url,
243 |                     model=config.model,
244 |                     temperature=config.temperature,
245 |                     max_tokens=config.max_tokens,
246 |                 )
247 |                 return GroqClient(config=llm_config)
248 | 
249 |             case _:
250 |                 raise ValueError(f'Unsupported LLM provider: {provider}')
251 | 
252 | 
class EmbedderFactory:
    """Factory that builds EmbedderClient instances from an EmbedderConfig."""

    @staticmethod
    def create(config: EmbedderConfig) -> EmbedderClient:
        """Instantiate the embedder implementation selected by ``config.provider``.

        Raises:
            ValueError: if the provider is unknown, its configuration section is
                missing, or the required client class is unavailable in the
                installed graphiti-core version.
        """
        import logging

        logger = logging.getLogger(__name__)

        provider = config.provider.lower()

        if provider == 'openai':
            openai_settings = config.providers.openai
            if not openai_settings:
                raise ValueError('OpenAI provider configuration not found')

            api_key = openai_settings.api_key
            _validate_api_key('OpenAI Embedder', api_key, logger)

            from graphiti_core.embedder.openai import OpenAIEmbedderConfig

            return OpenAIEmbedder(
                config=OpenAIEmbedderConfig(
                    api_key=api_key,
                    embedding_model=config.model,
                )
            )

        if provider == 'azure_openai':
            if not HAS_AZURE_EMBEDDER:
                raise ValueError(
                    'Azure OpenAI embedder not available in current graphiti-core version'
                )
            azure_settings = config.providers.azure_openai
            if not azure_settings:
                raise ValueError('Azure OpenAI provider configuration not found')
            if not azure_settings.api_url:
                raise ValueError('Azure OpenAI API URL is required')

            # Azure AD (Entra ID) tokens and static API keys are mutually exclusive.
            api_key: str | None = None
            token_provider = None
            if azure_settings.use_azure_ad:
                logger.info(
                    'Creating Azure OpenAI Embedder client with Azure AD authentication'
                )
                token_provider = create_azure_credential_token_provider()
            else:
                api_key = azure_settings.api_key
                _validate_api_key('Azure OpenAI Embedder', api_key, logger)

            # Build the SDK client first, then wrap it in the graphiti adapter.
            sdk_client = AsyncAzureOpenAI(
                api_key=api_key,
                azure_endpoint=azure_settings.api_url,
                api_version=azure_settings.api_version,
                azure_deployment=azure_settings.deployment_name,
                azure_ad_token_provider=token_provider,
            )
            return AzureOpenAIEmbedderClient(
                azure_client=sdk_client,
                model=config.model or 'text-embedding-3-small',
            )

        if provider == 'gemini':
            if not HAS_GEMINI_EMBEDDER:
                raise ValueError(
                    'Gemini embedder not available in current graphiti-core version'
                )
            gemini_settings = config.providers.gemini
            if not gemini_settings:
                raise ValueError('Gemini provider configuration not found')

            api_key = gemini_settings.api_key
            _validate_api_key('Gemini Embedder', api_key, logger)

            from graphiti_core.embedder.gemini import GeminiEmbedderConfig

            return GeminiEmbedder(
                config=GeminiEmbedderConfig(
                    api_key=api_key,
                    embedding_model=config.model or 'models/text-embedding-004',
                    embedding_dim=config.dimensions or 768,
                )
            )

        if provider == 'voyage':
            if not HAS_VOYAGE_EMBEDDER:
                raise ValueError(
                    'Voyage embedder not available in current graphiti-core version'
                )
            voyage_settings = config.providers.voyage
            if not voyage_settings:
                raise ValueError('Voyage provider configuration not found')

            api_key = voyage_settings.api_key
            _validate_api_key('Voyage Embedder', api_key, logger)

            from graphiti_core.embedder.voyage import VoyageAIEmbedderConfig

            return VoyageAIEmbedder(
                config=VoyageAIEmbedderConfig(
                    api_key=api_key,
                    embedding_model=config.model or 'voyage-3',
                    embedding_dim=config.dimensions or 1024,
                )
            )

        raise ValueError(f'Unsupported Embedder provider: {provider}')
361 | 
362 | 
class DatabaseDriverFactory:
    """Factory producing database connection settings from a DatabaseConfig.

    Note: This returns configuration dictionaries that can be passed to Graphiti(),
    not driver instances directly, as the drivers require complex initialization.
    """

    @staticmethod
    def create_config(config: DatabaseConfig) -> dict:
        """Build the connection-settings dict for the configured database provider.

        Raises:
            ValueError: if the provider is unknown or the FalkorDB driver is
                unavailable in the installed graphiti-core version.
        """
        import os

        provider = config.provider.lower()

        if provider == 'neo4j':
            neo4j_settings = config.providers.neo4j
            if not neo4j_settings:
                # No Neo4j section supplied: fall back to library defaults.
                from config.schema import Neo4jProviderConfig

                neo4j_settings = Neo4jProviderConfig()

            # Environment variables take precedence over file configuration
            # (keeps CI/CD setups working without editing config files).
            return {
                'uri': os.environ.get('NEO4J_URI', neo4j_settings.uri),
                'user': os.environ.get('NEO4J_USER', neo4j_settings.username),
                'password': os.environ.get('NEO4J_PASSWORD', neo4j_settings.password),
                # Note: database and use_parallel_runtime would need to be passed
                # to the driver after initialization if supported
            }

        if provider == 'falkordb':
            if not HAS_FALKOR:
                raise ValueError(
                    'FalkorDB driver not available in current graphiti-core version'
                )

            falkor_settings = config.providers.falkordb
            if not falkor_settings:
                # No FalkorDB section supplied: fall back to library defaults.
                from config.schema import FalkorDBProviderConfig

                falkor_settings = FalkorDBProviderConfig()

            from urllib.parse import urlparse

            # Environment variables take precedence over file configuration
            # (keeps CI/CD setups working without editing config files).
            uri = os.environ.get('FALKORDB_URI', falkor_settings.uri)
            password = os.environ.get('FALKORDB_PASSWORD', falkor_settings.password)

            # FalkorDB connects with discrete host/port, so split the URI apart.
            parsed_uri = urlparse(uri)
            return {
                'driver': 'falkordb',
                'host': parsed_uri.hostname or 'localhost',
                'port': parsed_uri.port or 6379,
                'password': password,
                'database': falkor_settings.database,
            }

        raise ValueError(f'Unsupported Database provider: {provider}')
438 | 
```

--------------------------------------------------------------------------------
/graphiti_core/llm_client/anthropic_client.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import json
 18 | import logging
 19 | import os
 20 | import typing
 21 | from json import JSONDecodeError
 22 | from typing import TYPE_CHECKING, Literal
 23 | 
 24 | from pydantic import BaseModel, ValidationError
 25 | 
 26 | from ..prompts.models import Message
 27 | from .client import LLMClient
 28 | from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
 29 | from .errors import RateLimitError, RefusalError
 30 | 
 31 | if TYPE_CHECKING:
 32 |     import anthropic
 33 |     from anthropic import AsyncAnthropic
 34 |     from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
 35 | else:
 36 |     try:
 37 |         import anthropic
 38 |         from anthropic import AsyncAnthropic
 39 |         from anthropic.types import MessageParam, ToolChoiceParam, ToolUnionParam
 40 |     except ImportError:
 41 |         raise ImportError(
 42 |             'anthropic is required for AnthropicClient. '
 43 |             'Install it with: pip install graphiti-core[anthropic]'
 44 |         ) from None
 45 | 
 46 | 
# Module-level logger for this client.
logger = logging.getLogger(__name__)

# Closed set of model identifiers this client recognizes; used only for static
# type checking of `AnthropicClient.model` (unknown strings still work at runtime).
AnthropicModel = Literal[
    'claude-sonnet-4-5-latest',
    'claude-sonnet-4-5-20250929',
    'claude-haiku-4-5-latest',
    'claude-3-7-sonnet-latest',
    'claude-3-7-sonnet-20250219',
    'claude-3-5-haiku-latest',
    'claude-3-5-haiku-20241022',
    'claude-3-5-sonnet-latest',
    'claude-3-5-sonnet-20241022',
    'claude-3-5-sonnet-20240620',
    'claude-3-opus-latest',
    'claude-3-opus-20240229',
    'claude-3-sonnet-20240229',
    'claude-3-haiku-20240307',
    'claude-2.1',
    'claude-2.0',
]

# Model used when the caller's LLMConfig does not specify one.
DEFAULT_MODEL: AnthropicModel = 'claude-haiku-4-5-latest'

# Maximum output tokens for different Anthropic models
# Based on official Anthropic documentation (as of 2025)
# Note: These represent standard limits without beta headers.
# Some models support higher limits with additional configuration (e.g., Claude 3.7 supports
# 128K with 'anthropic-beta: output-128k-2025-02-19' header, but this is not currently implemented).
ANTHROPIC_MODEL_MAX_TOKENS = {
    # Claude 4.5 models - 64K tokens
    'claude-sonnet-4-5-latest': 65536,
    'claude-sonnet-4-5-20250929': 65536,
    'claude-haiku-4-5-latest': 65536,
    # Claude 3.7 models - standard 64K tokens
    'claude-3-7-sonnet-latest': 65536,
    'claude-3-7-sonnet-20250219': 65536,
    # Claude 3.5 models
    'claude-3-5-haiku-latest': 8192,
    'claude-3-5-haiku-20241022': 8192,
    'claude-3-5-sonnet-latest': 8192,
    'claude-3-5-sonnet-20241022': 8192,
    'claude-3-5-sonnet-20240620': 8192,
    # Claude 3 models - 4K tokens
    'claude-3-opus-latest': 4096,
    'claude-3-opus-20240229': 4096,
    'claude-3-sonnet-20240229': 4096,
    'claude-3-haiku-20240307': 4096,
    # Claude 2 models - 4K tokens
    'claude-2.1': 4096,
    'claude-2.0': 4096,
}

# Default max tokens for models not in the mapping
DEFAULT_ANTHROPIC_MAX_TOKENS = 8192
101 | 
102 | 
class AnthropicClient(LLMClient):
    """
    A client for the Anthropic LLM.

    Args:
        config: A configuration object for the LLM.
        cache: Whether to cache the LLM responses.
        client: An optional client instance to use.
        max_tokens: The maximum number of tokens to generate.

    Methods:
        generate_response: Generate a response from the LLM.

    Notes:
        - If a LLMConfig is not provided, api_key will be pulled from the ANTHROPIC_API_KEY environment
            variable, and all default values will be used for the LLMConfig.

    """

    model: AnthropicModel

    def __init__(
        self,
        config: LLMConfig | None = None,
        cache: bool = False,
        client: AsyncAnthropic | None = None,
        max_tokens: int = DEFAULT_MAX_TOKENS,
    ) -> None:
        if config is None:
            config = LLMConfig()
            config.api_key = os.getenv('ANTHROPIC_API_KEY')
            config.max_tokens = max_tokens

        if config.model is None:
            config.model = DEFAULT_MODEL

        super().__init__(config, cache)
        # Explicitly set the instance model to the config model to prevent type checking errors
        self.model = typing.cast(AnthropicModel, config.model)

        if not client:
            # Keep SDK-level retries minimal; this class has its own retry loop.
            self.client = AsyncAnthropic(
                api_key=config.api_key,
                max_retries=1,
            )
        else:
            self.client = client

    def _extract_json_from_text(self, text: str) -> dict[str, typing.Any]:
        """Extract JSON from text content.

        A helper method to extract JSON from text content, used when tool use fails or
        no response_model is provided.

        Args:
            text: The text to extract JSON from

        Returns:
            Extracted JSON as a dictionary

        Raises:
            ValueError: If JSON cannot be extracted or parsed
        """
        try:
            # Take the outermost {...} span; assumes a single top-level JSON object.
            json_start = text.find('{')
            json_end = text.rfind('}') + 1
            if json_start >= 0 and json_end > json_start:
                json_str = text[json_start:json_end]
                return json.loads(json_str)
            else:
                raise ValueError(f'Could not extract JSON from model response: {text}')
        except (JSONDecodeError, ValueError) as e:
            raise ValueError(f'Could not extract JSON from model response: {text}') from e

    def _create_tool(
        self, response_model: type[BaseModel] | None = None
    ) -> tuple[list[ToolUnionParam], ToolChoiceParam]:
        """
        Create a tool definition based on the response_model if provided, or a generic JSON tool if not.

        Args:
            response_model: Optional Pydantic model to use for structured output.

        Returns:
            A tuple of (tool list, forced tool choice) for use with the Anthropic API.
        """
        if response_model is not None:
            # Use the response_model to define the tool
            model_schema = response_model.model_json_schema()
            tool_name = response_model.__name__
            description = model_schema.get('description', f'Extract {tool_name} information')
        else:
            # Create a generic JSON output tool
            tool_name = 'generic_json_output'
            description = 'Output data in JSON format'
            model_schema = {
                'type': 'object',
                'additionalProperties': True,
                'description': 'Any JSON object containing the requested information',
            }

        tool = {
            'name': tool_name,
            'description': description,
            'input_schema': model_schema,
        }
        tool_list = [tool]
        tool_list_cast = typing.cast(list[ToolUnionParam], tool_list)
        # Force the model to call this specific tool so output is structured.
        tool_choice = {'type': 'tool', 'name': tool_name}
        tool_choice_cast = typing.cast(ToolChoiceParam, tool_choice)
        return tool_list_cast, tool_choice_cast

    def _get_max_tokens_for_model(self, model: str) -> int:
        """Get the maximum output tokens for a specific Anthropic model.

        Args:
            model: The model name to look up

        Returns:
            int: The maximum output tokens for the model
        """
        return ANTHROPIC_MODEL_MAX_TOKENS.get(model, DEFAULT_ANTHROPIC_MAX_TOKENS)

    def _resolve_max_tokens(self, requested_max_tokens: int | None, model: str) -> int:
        """
        Resolve the maximum output tokens to use based on precedence rules.

        Precedence order (highest to lowest):
        1. Explicit max_tokens parameter passed to generate_response()
        2. Instance max_tokens set during client initialization
        3. Model-specific maximum tokens from ANTHROPIC_MODEL_MAX_TOKENS mapping
        4. DEFAULT_ANTHROPIC_MAX_TOKENS as final fallback

        Args:
            requested_max_tokens: The max_tokens parameter passed to generate_response()
            model: The model name to look up model-specific limits

        Returns:
            int: The resolved maximum tokens to use
        """
        # 1. Use explicit parameter if provided
        if requested_max_tokens is not None:
            return requested_max_tokens

        # 2. Use instance max_tokens if set during initialization
        if self.max_tokens is not None:
            return self.max_tokens

        # 3. Use model-specific maximum or return DEFAULT_ANTHROPIC_MAX_TOKENS
        return self._get_max_tokens_for_model(model)

    async def _generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int | None = None,
        model_size: ModelSize = ModelSize.medium,
    ) -> dict[str, typing.Any]:
        """
        Generate a response from the Anthropic LLM using tool-based approach for all requests.

        Args:
            messages: List of message objects to send to the LLM. The first
                message is treated as the system prompt.
            response_model: Optional Pydantic model to use for structured output.
            max_tokens: Maximum number of tokens to generate.

        Returns:
            Dictionary containing the structured response from the LLM.

        Raises:
            RateLimitError: If the rate limit is exceeded.
            RefusalError: If the LLM refuses to respond.
            Exception: If an error occurs during the generation process.
        """
        system_message = messages[0]
        user_messages = [{'role': m.role, 'content': m.content} for m in messages[1:]]
        user_messages_cast = typing.cast(list[MessageParam], user_messages)

        # Resolve max_tokens dynamically based on the model's capabilities
        # This allows different models to use their full output capacity
        max_creation_tokens: int = self._resolve_max_tokens(max_tokens, self.model)

        try:
            # Create the appropriate tool based on whether response_model is provided
            tools, tool_choice = self._create_tool(response_model)
            result = await self.client.messages.create(
                system=system_message.content,
                max_tokens=max_creation_tokens,
                temperature=self.temperature,
                messages=user_messages_cast,
                model=self.model,
                tools=tools,
                tool_choice=tool_choice,
            )

            # Extract the tool output from the response
            for content_item in result.content:
                if content_item.type == 'tool_use':
                    if isinstance(content_item.input, dict):
                        tool_args: dict[str, typing.Any] = content_item.input
                    else:
                        tool_args = json.loads(str(content_item.input))
                    return tool_args

            # No tool_use block: fall back to the first text block, scanning ALL
            # content items. (A previous version raised on the first non-text
            # block, which skipped text that followed e.g. a thinking block.)
            for content_item in result.content:
                if content_item.type == 'text':
                    return self._extract_json_from_text(content_item.text)

            # If we get here, we couldn't parse a structured response
            raise ValueError(
                f'Could not extract structured data from model response: {result.content}'
            )

        except anthropic.RateLimitError as e:
            raise RateLimitError(f'Rate limit exceeded. Please try again later. Error: {e}') from e
        except anthropic.APIError as e:
            # Special case for content policy violations. We convert these to RefusalError
            # to bypass the retry mechanism, as retrying policy-violating content will always fail.
            # This avoids wasting API calls and provides more specific error messaging to the user.
            if 'refused to respond' in str(e).lower():
                raise RefusalError(str(e)) from e
            raise

    async def generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int | None = None,
        model_size: ModelSize = ModelSize.medium,
        group_id: str | None = None,
        prompt_name: str | None = None,
    ) -> dict[str, typing.Any]:
        """
        Generate a response from the LLM, retrying up to twice on validation errors.

        Args:
            messages: List of message objects to send to the LLM. The caller's
                list is not modified; retry context is appended to a local copy.
            response_model: Optional Pydantic model to use for structured output.
            max_tokens: Maximum number of tokens to generate.

        Returns:
            Dictionary containing the structured response from the LLM.

        Raises:
            RateLimitError: If the rate limit is exceeded.
            RefusalError: If the LLM refuses to respond.
            Exception: If an error occurs during the generation process.
        """
        if max_tokens is None:
            max_tokens = self.max_tokens

        # Work on a copy so retry-context messages are never appended to the
        # caller's list (mutating a shared argument surprised callers).
        convo = list(messages)

        # Wrap entire operation in tracing span
        with self.tracer.start_span('llm.generate') as span:
            attributes = {
                'llm.provider': 'anthropic',
                'model.size': model_size.value,
                'max_tokens': max_tokens,
            }
            if prompt_name:
                attributes['prompt.name'] = prompt_name
            span.add_attributes(attributes)

            retry_count = 0
            max_retries = 2
            last_error: Exception | None = None

            while retry_count <= max_retries:
                try:
                    response = await self._generate_response(
                        convo, response_model, max_tokens, model_size
                    )

                    # If we have a response_model, attempt to validate the response
                    if response_model is not None:
                        # Validate the response against the response_model
                        model_instance = response_model(**response)
                        return model_instance.model_dump()

                    # If no validation needed, return the response
                    return response

                except (RateLimitError, RefusalError) as e:
                    # These errors should not trigger retries. Record the actual
                    # exception (previously `last_error` was used here, which is
                    # still None on a first-attempt failure).
                    span.set_status('error', str(e))
                    raise
                except Exception as e:
                    last_error = e

                    if retry_count >= max_retries:
                        if isinstance(e, ValidationError):
                            logger.error(
                                f'Validation error after {retry_count}/{max_retries} attempts: {e}'
                            )
                        else:
                            logger.error(f'Max retries ({max_retries}) exceeded. Last error: {e}')
                        span.set_status('error', str(e))
                        span.record_exception(e)
                        raise e

                    if isinstance(e, ValidationError):
                        response_model_cast = typing.cast(type[BaseModel], response_model)
                        error_context = f'The previous response was invalid. Please provide a valid {response_model_cast.__name__} object. Error: {e}'
                    else:
                        error_context = (
                            f'The previous response attempt was invalid. '
                            f'Error type: {e.__class__.__name__}. '
                            f'Error details: {str(e)}. '
                            f'Please try again with a valid response.'
                        )

                    # Common retry logic: feed the error back to the model.
                    retry_count += 1
                    convo.append(Message(role='user', content=error_context))
                    logger.warning(
                        f'Retrying after error (attempt {retry_count}/{max_retries}): {e}'
                    )

            # If we somehow get here, raise the last error
            span.set_status('error', str(last_error))
            raise last_error or Exception('Max retries exceeded with no specific error')
430 | 
```
Page 6/12FirstPrevNextLast