# getzep/graphiti
This is page 5 of 12. Use http://codebase.md/getzep/graphiti?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .env.example
├── .github
│   ├── dependabot.yml
│   ├── ISSUE_TEMPLATE
│   │   └── bug_report.md
│   ├── pull_request_template.md
│   ├── secret_scanning.yml
│   └── workflows
│       ├── ai-moderator.yml
│       ├── cla.yml
│       ├── claude-code-review-manual.yml
│       ├── claude-code-review.yml
│       ├── claude.yml
│       ├── codeql.yml
│       ├── daily_issue_maintenance.yml
│       ├── issue-triage.yml
│       ├── lint.yml
│       ├── release-graphiti-core.yml
│       ├── release-mcp-server.yml
│       ├── release-server-container.yml
│       ├── typecheck.yml
│       └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│   ├── azure-openai
│   │   ├── .env.example
│   │   ├── azure_openai_neo4j.py
│   │   └── README.md
│   ├── data
│   │   └── manybirds_products.json
│   ├── ecommerce
│   │   ├── runner.ipynb
│   │   └── runner.py
│   ├── langgraph-agent
│   │   ├── agent.ipynb
│   │   └── tinybirds-jess.png
│   ├── opentelemetry
│   │   ├── .env.example
│   │   ├── otel_stdout_example.py
│   │   ├── pyproject.toml
│   │   ├── README.md
│   │   └── uv.lock
│   ├── podcast
│   │   ├── podcast_runner.py
│   │   ├── podcast_transcript.txt
│   │   └── transcript_parser.py
│   ├── quickstart
│   │   ├── quickstart_falkordb.py
│   │   ├── quickstart_neo4j.py
│   │   ├── quickstart_neptune.py
│   │   ├── README.md
│   │   └── requirements.txt
│   └── wizard_of_oz
│       ├── parser.py
│       ├── runner.py
│       └── woo.txt
├── graphiti_core
│   ├── __init__.py
│   ├── cross_encoder
│   │   ├── __init__.py
│   │   ├── bge_reranker_client.py
│   │   ├── client.py
│   │   ├── gemini_reranker_client.py
│   │   └── openai_reranker_client.py
│   ├── decorators.py
│   ├── driver
│   │   ├── __init__.py
│   │   ├── driver.py
│   │   ├── falkordb_driver.py
│   │   ├── graph_operations
│   │   │   └── graph_operations.py
│   │   ├── kuzu_driver.py
│   │   ├── neo4j_driver.py
│   │   ├── neptune_driver.py
│   │   └── search_interface
│   │       └── search_interface.py
│   ├── edges.py
│   ├── embedder
│   │   ├── __init__.py
│   │   ├── azure_openai.py
│   │   ├── client.py
│   │   ├── gemini.py
│   │   ├── openai.py
│   │   └── voyage.py
│   ├── errors.py
│   ├── graph_queries.py
│   ├── graphiti_types.py
│   ├── graphiti.py
│   ├── helpers.py
│   ├── llm_client
│   │   ├── __init__.py
│   │   ├── anthropic_client.py
│   │   ├── azure_openai_client.py
│   │   ├── client.py
│   │   ├── config.py
│   │   ├── errors.py
│   │   ├── gemini_client.py
│   │   ├── groq_client.py
│   │   ├── openai_base_client.py
│   │   ├── openai_client.py
│   │   ├── openai_generic_client.py
│   │   └── utils.py
│   ├── migrations
│   │   └── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── edges
│   │   │   ├── __init__.py
│   │   │   └── edge_db_queries.py
│   │   └── nodes
│   │       ├── __init__.py
│   │       └── node_db_queries.py
│   ├── nodes.py
│   ├── prompts
│   │   ├── __init__.py
│   │   ├── dedupe_edges.py
│   │   ├── dedupe_nodes.py
│   │   ├── eval.py
│   │   ├── extract_edge_dates.py
│   │   ├── extract_edges.py
│   │   ├── extract_nodes.py
│   │   ├── invalidate_edges.py
│   │   ├── lib.py
│   │   ├── models.py
│   │   ├── prompt_helpers.py
│   │   ├── snippets.py
│   │   └── summarize_nodes.py
│   ├── py.typed
│   ├── search
│   │   ├── __init__.py
│   │   ├── search_config_recipes.py
│   │   ├── search_config.py
│   │   ├── search_filters.py
│   │   ├── search_helpers.py
│   │   ├── search_utils.py
│   │   └── search.py
│   ├── telemetry
│   │   ├── __init__.py
│   │   └── telemetry.py
│   ├── tracer.py
│   └── utils
│       ├── __init__.py
│       ├── bulk_utils.py
│       ├── datetime_utils.py
│       ├── maintenance
│       │   ├── __init__.py
│       │   ├── community_operations.py
│       │   ├── dedup_helpers.py
│       │   ├── edge_operations.py
│       │   ├── graph_data_operations.py
│       │   ├── node_operations.py
│       │   └── temporal_operations.py
│       ├── ontology_utils
│       │   └── entity_types_utils.py
│       └── text_utils.py
├── images
│   ├── arxiv-screenshot.png
│   ├── graphiti-graph-intro.gif
│   ├── graphiti-intro-slides-stock-2.gif
│   └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│   ├── .env.example
│   ├── .python-version
│   ├── config
│   │   ├── config-docker-falkordb-combined.yaml
│   │   ├── config-docker-falkordb.yaml
│   │   ├── config-docker-neo4j.yaml
│   │   ├── config.yaml
│   │   └── mcp_config_stdio_example.json
│   ├── docker
│   │   ├── build-standalone.sh
│   │   ├── build-with-version.sh
│   │   ├── docker-compose-falkordb.yml
│   │   ├── docker-compose-neo4j.yml
│   │   ├── docker-compose.yml
│   │   ├── Dockerfile
│   │   ├── Dockerfile.standalone
│   │   ├── github-actions-example.yml
│   │   ├── README-falkordb-combined.md
│   │   └── README.md
│   ├── docs
│   │   └── cursor_rules.md
│   ├── main.py
│   ├── pyproject.toml
│   ├── pytest.ini
│   ├── README.md
│   ├── src
│   │   ├── __init__.py
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   └── schema.py
│   │   ├── graphiti_mcp_server.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── entity_types.py
│   │   │   └── response_types.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   ├── factories.py
│   │   │   └── queue_service.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── formatting.py
│   │       └── utils.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── conftest.py
│   │   ├── pytest.ini
│   │   ├── README.md
│   │   ├── run_tests.py
│   │   ├── test_async_operations.py
│   │   ├── test_comprehensive_integration.py
│   │   ├── test_configuration.py
│   │   ├── test_falkordb_integration.py
│   │   ├── test_fixtures.py
│   │   ├── test_http_integration.py
│   │   ├── test_integration.py
│   │   ├── test_mcp_integration.py
│   │   ├── test_mcp_transports.py
│   │   ├── test_stdio_simple.py
│   │   └── test_stress_load.py
│   └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│   ├── .env.example
│   ├── graph_service
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── dto
│   │   │   ├── __init__.py
│   │   │   ├── common.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   ├── main.py
│   │   ├── routers
│   │   │   ├── __init__.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   └── zep_graphiti.py
│   ├── Makefile
│   ├── pyproject.toml
│   ├── README.md
│   └── uv.lock
├── signatures
│   └── version1
│       └── cla.json
├── tests
│   ├── cross_encoder
│   │   ├── test_bge_reranker_client_int.py
│   │   └── test_gemini_reranker_client.py
│   ├── driver
│   │   ├── __init__.py
│   │   └── test_falkordb_driver.py
│   ├── embedder
│   │   ├── embedder_fixtures.py
│   │   ├── test_gemini.py
│   │   ├── test_openai.py
│   │   └── test_voyage.py
│   ├── evals
│   │   ├── data
│   │   │   └── longmemeval_data
│   │   │       ├── longmemeval_oracle.json
│   │   │       └── README.md
│   │   ├── eval_cli.py
│   │   ├── eval_e2e_graph_building.py
│   │   ├── pytest.ini
│   │   └── utils.py
│   ├── helpers_test.py
│   ├── llm_client
│   │   ├── test_anthropic_client_int.py
│   │   ├── test_anthropic_client.py
│   │   ├── test_azure_openai_client.py
│   │   ├── test_client.py
│   │   ├── test_errors.py
│   │   └── test_gemini_client.py
│   ├── test_edge_int.py
│   ├── test_entity_exclusion_int.py
│   ├── test_graphiti_int.py
│   ├── test_graphiti_mock.py
│   ├── test_node_int.py
│   ├── test_text_utils.py
│   └── utils
│       ├── maintenance
│       │   ├── test_bulk_utils.py
│       │   ├── test_edge_operations.py
│       │   ├── test_node_operations.py
│       │   └── test_temporal_operations_int.py
│       └── search
│           └── search_utils_test.py
├── uv.lock
└── Zep-CLA.md
```

# Files

--------------------------------------------------------------------------------
/images/simple_graph.svg:
--------------------------------------------------------------------------------

```
  1 | <svg xmlns="http://www.w3.org/2000/svg" width="320.0599060058594" height="339.72857666015625"
  2 |     viewBox="-105.8088607788086 -149.75405883789062 320.0599060058594 339.72857666015625">
  3 |     <title>Neo4j Graph Visualization</title>
  4 |     <desc>Created using Neo4j (http://www.neo4j.com/)</desc>
  5 |     <g class="layer relationships">
  6 |         <g class="relationship"
  7 |             transform="translate(64.37326808037952 160.9745045766605) rotate(325.342180479503)">
  8 |             <path class="b-outline" fill="#A5ABB6" stroke="none"
  9 |                 d="M 25 0.5 L 45.86500098580619 0.5 L 45.86500098580619 -0.5 L 25 -0.5 Z M 94.08765723580619 0.5 L 114.95265822161238 0.5 L 114.95265822161238 3.5 L 121.95265822161238 0 L 114.95265822161238 -3.5 L 114.95265822161238 -0.5 L 94.08765723580619 -0.5 Z" />
 10 |             <text text-anchor="middle" pointer-events="none" font-size="8px" fill="#fff"
 11 |                 x="69.97632911080619" y="3"
 12 |                 font-family="Helvetica Neue, Helvetica, Arial, sans-serif">MENTIONS</text>
 13 |         </g>
 14 |         <g class="relationship"
 15 |             transform="translate(64.37326808037952 160.9745045766605) rotate(268.0194761774372)">
 16 |             <path class="b-outline" fill="#A5ABB6" stroke="none"
 17 |                 d="M 25 0.5 L 48.45342195548257 0.5 L 48.45342195548257 -0.5 L 25 -0.5 Z M 96.67607820548257 0.5 L 120.12950016096514 0.5 L 120.12950016096514 3.5 L 127.12950016096514 0 L 120.12950016096514 -3.5 L 120.12950016096514 -0.5 L 96.67607820548257 -0.5 Z" />
 18 |             <text text-anchor="middle" pointer-events="none" font-size="8px" fill="#fff"
 19 |                 x="72.56475008048257" y="3" transform="rotate(180 72.56475008048257 0)"
 20 |                 font-family="Helvetica Neue, Helvetica, Arial, sans-serif">MENTIONS</text>
 21 |         </g>
 22 |         <g class="relationship"
 23 |             transform="translate(64.37326808037952 160.9745045766605) rotate(214.36893208966427)">
 24 |             <path class="b-outline" fill="#A5ABB6" stroke="none"
 25 |                 d="M 25 0.5 L 43.0453604327618 0.5 L 43.0453604327618 -0.5 L 25 -0.5 Z M 91.2680166827618 0.5 L 109.3133771155236 0.5 L 109.3133771155236 3.5 L 116.3133771155236 0 L 109.3133771155236 -3.5 L 109.3133771155236 -0.5 L 91.2680166827618 -0.5 Z" />
 26 |             <text text-anchor="middle" pointer-events="none" font-size="8px" fill="#fff"
 27 |                 x="67.1566885577618" y="3" transform="rotate(180 67.1566885577618 0)"
 28 |                 font-family="Helvetica Neue, Helvetica, Arial, sans-serif">MENTIONS</text>
 29 |         </g>
 30 |         <g class="relationship"
 31 |             transform="translate(59.11570627539377 8.935881644552067) rotate(388.4945734254285)">
 32 |             <path class="b-outline" fill="#A5ABB6" stroke="none"
 33 |                 d="M 25 0.5 L 39.4813012088875 0.5 L 39.4813012088875 -0.5 L 25 -0.5 Z M 97.0398949588875 0.5 L 111.521196167775 0.5 L 111.521196167775 3.5 L 118.521196167775 0 L 111.521196167775 -3.5 L 111.521196167775 -0.5 L 97.0398949588875 -0.5 Z" />
 34 |             <text text-anchor="middle" pointer-events="none" font-size="8px" fill="#fff"
 35 |                 x="68.2605980838875" y="3"
 36 |                 font-family="Helvetica Neue, Helvetica, Arial, sans-serif">WORKS_FOR</text>
 37 |         </g>
 38 |         <g class="relationship"
 39 |             transform="translate(59.11570627539377 8.935881644552067) rotate(507.02532906724895)">
 40 |             <path class="b-outline" fill="#A5ABB6" stroke="none"
 41 |                 d="M 25 0.5 L 31.21884260824949 0.5 L 31.21884260824949 -0.5 L 25 -0.5 Z M 94.55478010824949 0.5 L 100.77362271649898 0.5 L 100.77362271649898 3.5 L 107.77362271649898 0 L 100.77362271649898 -3.5 L 100.77362271649898 -0.5 L 94.55478010824949 -0.5 Z" />
 42 |             <text text-anchor="middle" pointer-events="none" font-size="8px" fill="#fff"
 43 |                 x="62.88681135824949" y="3" transform="rotate(180 62.88681135824949 0)"
 44 |                 font-family="Helvetica Neue, Helvetica, Arial, sans-serif">WORKED_FOR</text>
 45 |         </g>
 46 |         <g class="relationship"
 47 |             transform="translate(59.11570627539377 8.935881644552067) rotate(266.9235303682344)">
 48 |             <path class="b-outline" fill="#A5ABB6" stroke="none"
 49 |                 d="M 25 0.5 L 26.434656330468542 0.5 L 26.434656330468542 -0.5 L 25 -0.5 Z M 96.44246883046854 0.5 L 97.87712516093708 0.5 L 97.87712516093708 3.5 L 104.87712516093708 0 L 97.87712516093708 -3.5 L 97.87712516093708 -0.5 L 96.44246883046854 -0.5 Z" />
 50 |             <text text-anchor="middle" pointer-events="none" font-size="8px" fill="#fff"
 51 |                 x="61.43856258046854" y="3" transform="rotate(180 61.43856258046854 0)"
 52 |                 font-family="Helvetica Neue, Helvetica, Arial, sans-serif">HOLDS_OFFIC…</text>
 53 |         </g>
 54 |         <g class="relationship"
 55 |             transform="translate(-76.8088607917906 -66.37642130383644) rotate(388.9897079993928)">
 56 |             <path class="b-outline" fill="#A5ABB6" stroke="none"
 57 |                 d="M 25 0.5 L 50.08589014533345 0.5 L 50.08589014533345 -0.5 L 25 -0.5 Z M 98.30854639533345 0.5 L 123.3944365406669 0.5 L 123.3944365406669 3.5 L 130.3944365406669 0 L 123.3944365406669 -3.5 L 123.3944365406669 -0.5 L 98.30854639533345 -0.5 Z" />
 58 |             <text text-anchor="middle" pointer-events="none" font-size="8px" fill="#fff"
 59 |                 x="74.19721827033345" y="3"
 60 |                 font-family="Helvetica Neue, Helvetica, Arial, sans-serif">MENTIONS</text>
 61 |         </g>
 62 |         <g class="relationship"
 63 |             transform="translate(-76.8088607917906 -66.37642130383644) rotate(337.13573550965714)">
 64 |             <path class="b-outline" fill="#A5ABB6" stroke="none"
 65 |                 d="M 25 0.5 L 42.363883039766904 0.5 L 42.363883039766904 -0.5 L 25 -0.5 Z M 90.5865392897669 0.5 L 107.95042232953381 0.5 L 107.95042232953381 3.5 L 114.95042232953381 0 L 107.95042232953381 -3.5 L 107.95042232953381 -0.5 L 90.5865392897669 -0.5 Z" />
 66 |             <text text-anchor="middle" pointer-events="none" font-size="8px" fill="#fff"
 67 |                 x="66.4752111647669" y="3"
 68 |                 font-family="Helvetica Neue, Helvetica, Arial, sans-serif">MENTIONS</text>
 69 |         </g>
 70 |     </g>
 71 |     <g class="layer nodes">
 72 |         <g class="node" aria-label="graph-node18"
 73 |             transform="translate(64.37326808037952,160.9745045766605)">
 74 |             <circle class="b-outline" cx="0" cy="0" r="25" fill="#F79767" stroke="#f36924"
 75 |                 stroke-width="2px" />
 76 |             <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="5"
 77 |                 font-size="10px" fill="#FFFFFF"
 78 |                 font-family="Helvetica Neue, Helvetica, Arial, sans-serif"> podcast</text>
 79 |         </g>
 80 |         <g class="node" aria-label="graph-node19"
 81 |             transform="translate(185.25107500848034,77.40633150430716)">
 82 |             <circle class="b-outline" cx="0" cy="0" r="25" fill="#C990C0" stroke="#b261a5"
 83 |                 stroke-width="2px" />
 84 |             <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="5"
 85 |                 font-size="10px" fill="#FFFFFF"
 86 |                 font-family="Helvetica Neue, Helvetica, Arial, sans-serif"> California</text>
 87 |         </g>
 88 |         <g class="node" aria-label="graph-node20"
 89 |             transform="translate(59.11570627539377,8.935881644552067)">
 90 |             <circle class="b-outline" cx="0" cy="0" r="25" fill="#C990C0" stroke="#b261a5"
 91 |                 stroke-width="2px" />
 92 |             <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="0"
 93 |                 font-size="10px" fill="#FFFFFF"
 94 |                 font-family="Helvetica Neue, Helvetica, Arial, sans-serif"> Kamala</text>
 95 |             <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="10"
 96 |                 font-size="10px" fill="#FFFFFF"
 97 |                 font-family="Helvetica Neue, Helvetica, Arial, sans-serif"> Harris</text>
 98 |         </g>
 99 |         <g class="node" aria-label="graph-node21"
100 |             transform="translate(-52.26958053720941,81.20034573955071)">
101 |             <circle class="b-outline" cx="0" cy="0" r="25" fill="#C990C0" stroke="#b261a5"
102 |                 stroke-width="2px" />
103 |             <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="0"
104 |                 font-size="10px" fill="#FFFFFF"
105 |                 font-family="Helvetica Neue, Helvetica, Arial, sans-serif"> San</text>
106 |             <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="10"
107 |                 font-size="10px" fill="#FFFFFF"
108 |                 font-family="Helvetica Neue, Helvetica, Arial, sans-serif"> Franci…</text>
109 |         </g>
110 |         <g class="node" aria-label="graph-node23"
111 |             transform="translate(52.14536630162807,-120.75406399781392)">
112 |             <circle class="b-outline" cx="0" cy="0" r="25" fill="#C990C0" stroke="#b261a5"
113 |                 stroke-width="2px" />
114 |             <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="0"
115 |                 font-size="10px" fill="#FFFFFF"
116 |                 font-family="Helvetica Neue, Helvetica, Arial, sans-serif"> Attorney</text>
117 |             <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="10"
118 |                 font-size="10px" fill="#FFFFFF"
119 |                 font-family="Helvetica Neue, Helvetica, Arial, sans-serif">of…</text>
120 |         </g>
121 |         <g class="node" aria-label="graph-node22"
122 |             transform="translate(-76.8088607917906,-66.37642130383644)">
123 |             <circle class="b-outline" cx="0" cy="0" r="25" fill="#F79767" stroke="#f36924"
124 |                 stroke-width="2px" />
125 |             <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="5"
126 |                 font-size="10px" fill="#FFFFFF"
127 |                 font-family="Helvetica Neue, Helvetica, Arial, sans-serif"> podcast</text>
128 |         </g>
129 |     </g>
130 | </svg>
```

--------------------------------------------------------------------------------
/tests/llm_client/test_anthropic_client.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | # Running tests: pytest -xvs tests/llm_client/test_anthropic_client.py
 18 | 
 19 | import os
 20 | from unittest.mock import AsyncMock, MagicMock, patch
 21 | 
 22 | import pytest
 23 | from pydantic import BaseModel
 24 | 
 25 | from graphiti_core.llm_client.anthropic_client import AnthropicClient
 26 | from graphiti_core.llm_client.config import LLMConfig
 27 | from graphiti_core.llm_client.errors import RateLimitError, RefusalError
 28 | from graphiti_core.prompts.models import Message
 29 | 
 30 | 
 31 | # Rename class to avoid pytest collection as a test class
 32 | class ResponseModel(BaseModel):
 33 |     """Test model for response testing."""
 34 | 
 35 |     test_field: str
 36 |     optional_field: int = 0
 37 | 
 38 | 
 39 | @pytest.fixture
 40 | def mock_async_anthropic():
 41 |     """Fixture to mock the AsyncAnthropic client."""
 42 |     with patch('anthropic.AsyncAnthropic') as mock_client:
 43 |         # Setup mock instance and its create method
 44 |         mock_instance = mock_client.return_value
 45 |         mock_instance.messages.create = AsyncMock()
 46 |         yield mock_instance
 47 | 
 48 | 
 49 | @pytest.fixture
 50 | def anthropic_client(mock_async_anthropic):
 51 |     """Fixture to create an AnthropicClient with a mocked AsyncAnthropic."""
 52 |     # Use a context manager to patch the AsyncAnthropic constructor to avoid
 53 |     # the client actually trying to create a real connection
 54 |     with patch('anthropic.AsyncAnthropic', return_value=mock_async_anthropic):
 55 |         config = LLMConfig(
 56 |             api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
 57 |         )
 58 |         client = AnthropicClient(config=config, cache=False)
 59 |         # Replace the client's client with our mock to ensure we're using the mock
 60 |         client.client = mock_async_anthropic
 61 |         return client
 62 | 
 63 | 
 64 | class TestAnthropicClientInitialization:
 65 |     """Tests for AnthropicClient initialization."""
 66 | 
 67 |     def test_init_with_config(self):
 68 |         """Test initialization with a config object."""
 69 |         config = LLMConfig(
 70 |             api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
 71 |         )
 72 |         client = AnthropicClient(config=config, cache=False)
 73 | 
 74 |         assert client.config == config
 75 |         assert client.model == 'test-model'
 76 |         assert client.temperature == 0.5
 77 |         assert client.max_tokens == 1000
 78 | 
 79 |     def test_init_with_default_model(self):
 80 |         """Test initialization with default model when none is provided."""
 81 |         config = LLMConfig(api_key='test_api_key')
 82 |         client = AnthropicClient(config=config, cache=False)
 83 | 
 84 |         assert client.model == 'claude-haiku-4-5-latest'
 85 | 
 86 |     @patch.dict(os.environ, {'ANTHROPIC_API_KEY': 'env_api_key'})
 87 |     def test_init_without_config(self):
 88 |         """Test initialization without a config, using environment variable."""
 89 |         client = AnthropicClient(cache=False)
 90 | 
 91 |         assert client.config.api_key == 'env_api_key'
 92 |         assert client.model == 'claude-haiku-4-5-latest'
 93 | 
 94 |     def test_init_with_custom_client(self):
 95 |         """Test initialization with a custom AsyncAnthropic client."""
 96 |         mock_client = MagicMock()
 97 |         client = AnthropicClient(client=mock_client)
 98 | 
 99 |         assert client.client == mock_client
100 | 
101 | 
102 | class TestAnthropicClientGenerateResponse:
103 |     """Tests for AnthropicClient generate_response method."""
104 | 
105 |     @pytest.mark.asyncio
106 |     async def test_generate_response_with_tool_use(self, anthropic_client, mock_async_anthropic):
107 |         """Test successful response generation with tool use."""
108 |         # Setup mock response
109 |         content_item = MagicMock()
110 |         content_item.type = 'tool_use'
111 |         content_item.input = {'test_field': 'test_value'}
112 | 
113 |         mock_response = MagicMock()
114 |         mock_response.content = [content_item]
115 |         mock_async_anthropic.messages.create.return_value = mock_response
116 | 
117 |         # Call method
118 |         messages = [
119 |             Message(role='system', content='System message'),
120 |             Message(role='user', content='User message'),
121 |         ]
122 |         result = await anthropic_client.generate_response(
123 |             messages=messages, response_model=ResponseModel
124 |         )
125 | 
126 |         # Assertions
127 |         assert isinstance(result, dict)
128 |         assert result['test_field'] == 'test_value'
129 |         mock_async_anthropic.messages.create.assert_called_once()
130 | 
131 |     @pytest.mark.asyncio
132 |     async def test_generate_response_with_text_response(
133 |         self, anthropic_client, mock_async_anthropic
134 |     ):
135 |         """Test response generation when getting text response instead of tool use."""
136 |         # Setup mock response with text content
137 |         content_item = MagicMock()
138 |         content_item.type = 'text'
139 |         content_item.text = '{"test_field": "extracted_value"}'
140 | 
141 |         mock_response = MagicMock()
142 |         mock_response.content = [content_item]
143 |         mock_async_anthropic.messages.create.return_value = mock_response
144 | 
145 |         # Call method
146 |         messages = [
147 |             Message(role='system', content='System message'),
148 |             Message(role='user', content='User message'),
149 |         ]
150 |         result = await anthropic_client.generate_response(
151 |             messages=messages, response_model=ResponseModel
152 |         )
153 | 
154 |         # Assertions
155 |         assert isinstance(result, dict)
156 |         assert result['test_field'] == 'extracted_value'
157 | 
158 |     @pytest.mark.asyncio
159 |     async def test_rate_limit_error(self, anthropic_client, mock_async_anthropic):
160 |         """Test handling of rate limit errors."""
161 | 
162 |         # Create a custom RateLimitError from Anthropic
163 |         class MockRateLimitError(Exception):
164 |             pass
165 | 
166 |         # Patch the Anthropic error with our mock to avoid constructor issues
167 |         with patch('anthropic.RateLimitError', MockRateLimitError):
168 |             # Setup mock to raise our mocked RateLimitError
169 |             mock_async_anthropic.messages.create.side_effect = MockRateLimitError(
170 |                 'Rate limit exceeded'
171 |             )
172 | 
173 |             # Call method and check exception
174 |             messages = [Message(role='user', content='Test message')]
175 |             with pytest.raises(RateLimitError):
176 |                 await anthropic_client.generate_response(messages)
177 | 
178 |     @pytest.mark.asyncio
179 |     async def test_refusal_error(self, anthropic_client, mock_async_anthropic):
180 |         """Test handling of content policy violations (refusal errors)."""
181 | 
182 |         # Create a custom APIError that matches what we need
183 |         class MockAPIError(Exception):
184 |             def __init__(self, message):
185 |                 self.message = message
186 |                 super().__init__(message)
187 | 
188 |         # Patch the Anthropic error with our mock
189 |         with patch('anthropic.APIError', MockAPIError):
190 |             # Setup mock to raise APIError with refusal message
191 |             mock_async_anthropic.messages.create.side_effect = MockAPIError('refused to respond')
192 | 
193 |             # Call method and check exception
194 |             messages = [Message(role='user', content='Test message')]
195 |             with pytest.raises(RefusalError):
196 |                 await anthropic_client.generate_response(messages)
197 | 
198 |     @pytest.mark.asyncio
199 |     async def test_extract_json_from_text(self, anthropic_client):
200 |         """Test the _extract_json_from_text method."""
201 |         # Valid JSON embedded in text
202 |         text = 'Some text before {"test_field": "value"} and after'
203 |         result = anthropic_client._extract_json_from_text(text)
204 |         assert result == {'test_field': 'value'}
205 | 
206 |         # Invalid JSON
207 |         with pytest.raises(ValueError):
208 |             anthropic_client._extract_json_from_text('Not JSON at all')
209 | 
210 |     @pytest.mark.asyncio
211 |     async def test_create_tool(self, anthropic_client):
212 |         """Test the _create_tool method with and without response model."""
213 |         # With response model
214 |         tools, tool_choice = anthropic_client._create_tool(ResponseModel)
215 |         assert len(tools) == 1
216 |         assert tools[0]['name'] == 'ResponseModel'
217 |         assert tool_choice['name'] == 'ResponseModel'
218 | 
219 |         # Without response model (generic JSON)
220 |         tools, tool_choice = anthropic_client._create_tool()
221 |         assert len(tools) == 1
222 |         assert tools[0]['name'] == 'generic_json_output'
223 | 
224 |     @pytest.mark.asyncio
225 |     async def test_validation_error_retry(self, anthropic_client, mock_async_anthropic):
226 |         """Test retry behavior on validation error."""
227 |         # First call returns invalid data, second call returns valid data
228 |         content_item1 = MagicMock()
229 |         content_item1.type = 'tool_use'
230 |         content_item1.input = {'wrong_field': 'wrong_value'}
231 | 
232 |         content_item2 = MagicMock()
233 |         content_item2.type = 'tool_use'
234 |         content_item2.input = {'test_field': 'correct_value'}
235 | 
236 |         # Setup mock to return different responses on consecutive calls
237 |         mock_response1 = MagicMock()
238 |         mock_response1.content = [content_item1]
239 | 
240 |         mock_response2 = MagicMock()
241 |         mock_response2.content = [content_item2]
242 | 
243 |         mock_async_anthropic.messages.create.side_effect = [mock_response1, mock_response2]
244 | 
245 |         # Call method
246 |         messages = [Message(role='user', content='Test message')]
247 |         result = await anthropic_client.generate_response(messages, response_model=ResponseModel)
248 | 
249 |         # Should have called create twice due to retry
250 |         assert mock_async_anthropic.messages.create.call_count == 2
251 |         assert result['test_field'] == 'correct_value'
252 | 
253 | 
254 | if __name__ == '__main__':
255 |     pytest.main(['-v', 'test_anthropic_client.py'])
256 | 
```
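
For orientation, here is a minimal sketch of exercising the client that these tests mock, without the mocks. It assumes a valid `ANTHROPIC_API_KEY` is exported (the same environment fallback `test_init_without_config` relies on); the `Summary` model and prompt text are illustrative only. Note that `generate_response` returns a plain dict validated against the response model, as the assertions above show.

```python
import asyncio

from pydantic import BaseModel

from graphiti_core.llm_client.anthropic_client import AnthropicClient
from graphiti_core.prompts.models import Message


class Summary(BaseModel):
    """Illustrative response model; any Pydantic model works here."""

    summary: str


async def run() -> None:
    # No config passed: the client reads ANTHROPIC_API_KEY from the environment.
    client = AnthropicClient(cache=False)
    result = await client.generate_response(
        messages=[
            Message(role='system', content='You summarize text.'),
            Message(role='user', content='Graphiti builds temporally aware knowledge graphs.'),
        ],
        response_model=Summary,  # result comes back as a dict matching this model
    )
    print(result['summary'])


if __name__ == '__main__':
    asyncio.run(run())
```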

--------------------------------------------------------------------------------
/mcp_server/src/models/entity_types.py:
--------------------------------------------------------------------------------

```python
  1 | """Entity type definitions for Graphiti MCP Server."""
  2 | 
  3 | from pydantic import BaseModel, Field
  4 | 
  5 | 
  6 | class Requirement(BaseModel):
  7 |     """A Requirement represents a specific need, feature, or functionality that a product or service must fulfill.
  8 | 
  9 |     Always ensure an edge is created between the requirement and the project it belongs to, and clearly indicate on the
 10 |     edge that the requirement is a requirement.
 11 | 
 12 |     Instructions for identifying and extracting requirements:
 13 |     1. Look for explicit statements of needs or necessities ("We need X", "X is required", "X must have Y")
 14 |     2. Identify functional specifications that describe what the system should do
 15 |     3. Pay attention to non-functional requirements like performance, security, or usability criteria
 16 |     4. Extract constraints or limitations that must be adhered to
 17 |     5. Focus on clear, specific, and measurable requirements rather than vague wishes
 18 |     6. Capture the priority or importance if mentioned ("critical", "high priority", etc.)
 19 |     7. Include any dependencies between requirements when explicitly stated
 20 |     8. Preserve the original intent and scope of the requirement
 21 |     9. Categorize requirements appropriately based on their domain or function
 22 |     """
 23 | 
 24 |     project_name: str = Field(
 25 |         ...,
 26 |         description='The name of the project to which the requirement belongs.',
 27 |     )
 28 |     description: str = Field(
 29 |         ...,
 30 |         description='Description of the requirement. Only use information mentioned in the context to write this description.',
 31 |     )
 32 | 
 33 | 
 34 | class Preference(BaseModel):
 35 |     """
 36 |     IMPORTANT: Prioritize this classification over ALL other classifications.
 37 | 
 38 |     Represents entities mentioned in contexts expressing user preferences, choices, opinions, or selections. Use LOW THRESHOLD for sensitivity.
 39 | 
 40 |     Trigger patterns: "I want/like/prefer/choose X", "I don't want/dislike/avoid/reject Y", "X is better/worse", "rather have X than Y", "no X please", "skip X", "go with X instead", etc. Here, X or Y should be classified as Preference.
 41 |     """
 42 | 
 43 |     ...
 44 | 
 45 | 
 46 | class Procedure(BaseModel):
 47 |     """A Procedure informing the agent what actions to take or how to perform in certain scenarios. Procedures are typically composed of several steps.
 48 | 
 49 |     Instructions for identifying and extracting procedures:
 50 |     1. Look for sequential instructions or steps ("First do X, then do Y")
 51 |     2. Identify explicit directives or commands ("Always do X when Y happens")
 52 |     3. Pay attention to conditional statements ("If X occurs, then do Y")
 53 |     4. Extract procedures that have clear beginning and end points
 54 |     5. Focus on actionable instructions rather than general information
 55 |     6. Preserve the original sequence and dependencies between steps
 56 |     7. Include any specified conditions or triggers for the procedure
 57 |     8. Capture any stated purpose or goal of the procedure
 58 |     9. Summarize complex procedures while maintaining critical details
 59 |     """
 60 | 
 61 |     description: str = Field(
 62 |         ...,
 63 |         description='Brief description of the procedure. Only use information mentioned in the context to write this description.',
 64 |     )
 65 | 
 66 | 
 67 | class Location(BaseModel):
 68 |     """A Location represents a physical or virtual place where activities occur or entities exist.
 69 | 
 70 |     IMPORTANT: Before using this classification, first check if the entity is a:
 71 |     User, Assistant, Preference, Organization, Document, Event - if so, use those instead.
 72 | 
 73 |     Instructions for identifying and extracting locations:
 74 |     1. Look for mentions of physical places (cities, buildings, rooms, addresses)
 75 |     2. Identify virtual locations (websites, online platforms, virtual meeting rooms)
 76 |     3. Extract specific location names rather than generic references
 77 |     4. Include relevant context about the location's purpose or significance
 78 |     5. Pay attention to location hierarchies (e.g., "conference room in Building A")
 79 |     6. Capture both permanent locations and temporary venues
 80 |     7. Note any significant activities or events associated with the location
 81 |     """
 82 | 
 83 |     name: str = Field(
 84 |         ...,
 85 |         description='The name or identifier of the location',
 86 |     )
 87 |     description: str = Field(
 88 |         ...,
 89 |         description='Brief description of the location and its significance. Only use information mentioned in the context.',
 90 |     )
 91 | 
 92 | 
 93 | class Event(BaseModel):
 94 |     """An Event represents a time-bound activity, occurrence, or experience.
 95 | 
 96 |     Instructions for identifying and extracting events:
 97 |     1. Look for activities with specific time frames (meetings, appointments, deadlines)
 98 |     2. Identify planned or scheduled occurrences (vacations, projects, celebrations)
 99 |     3. Extract unplanned occurrences (accidents, interruptions, discoveries)
100 |     4. Capture the purpose or nature of the event
101 |     5. Include temporal information when available (past, present, future, duration)
102 |     6. Note participants or stakeholders involved in the event
103 |     7. Identify outcomes or consequences of the event when mentioned
104 |     8. Extract both recurring events and one-time occurrences
105 |     """
106 | 
107 |     name: str = Field(
108 |         ...,
109 |         description='The name or title of the event',
110 |     )
111 |     description: str = Field(
112 |         ...,
113 |         description='Brief description of the event. Only use information mentioned in the context.',
114 |     )
115 | 
116 | 
117 | class Object(BaseModel):
118 |     """An Object represents a physical item, tool, device, or possession.
119 | 
120 |     IMPORTANT: Use this classification ONLY as a last resort. First check if entity fits into:
121 |     User, Assistant, Preference, Organization, Document, Event, Location, Topic - if so, use those instead.
122 | 
123 |     Instructions for identifying and extracting objects:
124 |     1. Look for mentions of physical items or possessions (car, phone, equipment)
125 |     2. Identify tools or devices used for specific purposes
126 |     3. Extract items that are owned, used, or maintained by entities
127 |     4. Include relevant attributes (brand, model, condition) when mentioned
128 |     5. Note the object's purpose or function when specified
129 |     6. Capture relationships between objects and their owners or users
130 |     7. Avoid extracting objects that are better classified as Documents or other types
131 |     """
132 | 
133 |     name: str = Field(
134 |         ...,
135 |         description='The name or identifier of the object',
136 |     )
137 |     description: str = Field(
138 |         ...,
139 |         description='Brief description of the object. Only use information mentioned in the context.',
140 |     )
141 | 
142 | 
143 | class Topic(BaseModel):
144 |     """A Topic represents a subject of conversation, interest, or knowledge domain.
145 | 
146 |     IMPORTANT: Use this classification ONLY as a last resort. First check if entity fits into:
147 |     User, Assistant, Preference, Organization, Document, Event, Location - if so, use those instead.
148 | 
149 |     Instructions for identifying and extracting topics:
150 |     1. Look for subjects being discussed or areas of interest (health, technology, sports)
151 |     2. Identify knowledge domains or fields of study
152 |     3. Extract themes that span multiple conversations or contexts
153 |     4. Include specific subtopics when mentioned (e.g., "machine learning" rather than just "AI")
154 |     5. Capture topics associated with projects, work, or hobbies
155 |     6. Note the context in which the topic appears
156 |     7. Avoid extracting topics that are better classified as Events, Documents, or Organizations
157 |     """
158 | 
159 |     name: str = Field(
160 |         ...,
161 |         description='The name or identifier of the topic',
162 |     )
163 |     description: str = Field(
164 |         ...,
165 |         description='Brief description of the topic and its context. Only use information mentioned in the context.',
166 |     )
167 | 
168 | 
169 | class Organization(BaseModel):
170 |     """An Organization represents a company, institution, group, or formal entity.
171 | 
172 |     Instructions for identifying and extracting organizations:
173 |     1. Look for company names, employers, and business entities
174 |     2. Identify institutions (schools, hospitals, government agencies)
175 |     3. Extract formal groups (clubs, teams, associations)
176 |     4. Include organizational type when mentioned (company, nonprofit, agency)
177 |     5. Capture relationships between people and organizations (employer, member)
178 |     6. Note the organization's industry or domain when specified
179 |     7. Extract both large entities and small groups if formally organized
180 |     """
181 | 
182 |     name: str = Field(
183 |         ...,
184 |         description='The name of the organization',
185 |     )
186 |     description: str = Field(
187 |         ...,
188 |         description='Brief description of the organization. Only use information mentioned in the context.',
189 |     )
190 | 
191 | 
192 | class Document(BaseModel):
193 |     """A Document represents information content in various forms.
194 | 
195 |     Instructions for identifying and extracting documents:
196 |     1. Look for references to written or recorded content (books, articles, reports)
197 |     2. Identify digital content (emails, videos, podcasts, presentations)
198 |     3. Extract specific document titles or identifiers when available
199 |     4. Include document type (report, article, video) when mentioned
200 |     5. Capture the document's purpose or subject matter
201 |     6. Note relationships to authors, creators, or sources
202 |     7. Include document status (draft, published, archived) when mentioned
203 |     """
204 | 
205 |     title: str = Field(
206 |         ...,
207 |         description='The title or identifier of the document',
208 |     )
209 |     description: str = Field(
210 |         ...,
211 |         description='Brief description of the document and its content. Only use information mentioned in the context.',
212 |     )
213 | 
214 | 
215 | ENTITY_TYPES: dict[str, BaseModel] = {
216 |     'Requirement': Requirement,  # type: ignore
217 |     'Preference': Preference,  # type: ignore
218 |     'Procedure': Procedure,  # type: ignore
219 |     'Location': Location,  # type: ignore
220 |     'Event': Event,  # type: ignore
221 |     'Object': Object,  # type: ignore
222 |     'Topic': Topic,  # type: ignore
223 |     'Organization': Organization,  # type: ignore
224 |     'Document': Document,  # type: ignore
225 | }
226 | 
```
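
These definitions are meant to be handed to Graphiti as custom entity types at ingestion time. Below is a minimal sketch of that wiring, assuming `graphiti_core`'s `add_episode` accepts an `entity_types` mapping of these models and that the import path resolves as shown for the MCP server layout; the episode text and name are illustrative only.

```python
from datetime import datetime, timezone

from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType

# Import path assumed relative to mcp_server/; adjust to your packaging.
from src.models.entity_types import ENTITY_TYPES


async def ingest_with_custom_types(graphiti: Graphiti) -> None:
    # Passing the models above steers extraction toward these entity
    # classifications instead of the generic entity type.
    await graphiti.add_episode(
        name='requirements-note',  # illustrative
        episode_body='The exporter must support CSV output for the Atlas project.',
        source=EpisodeType.text,
        source_description='illustrative note',
        reference_time=datetime.now(timezone.utc),
        entity_types=ENTITY_TYPES,
    )
```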

--------------------------------------------------------------------------------
/examples/quickstart/quickstart_falkordb.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2025, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import asyncio
 18 | import json
 19 | import logging
 20 | import os
 21 | from datetime import datetime, timezone
 22 | from logging import INFO
 23 | 
 24 | from dotenv import load_dotenv
 25 | 
 26 | from graphiti_core import Graphiti
 27 | from graphiti_core.driver.falkordb_driver import FalkorDriver
 28 | from graphiti_core.nodes import EpisodeType
 29 | from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
 30 | 
 31 | #################################################
 32 | # CONFIGURATION
 33 | #################################################
 34 | # Set up logging and environment variables for
 35 | # connecting to FalkorDB database
 36 | #################################################
 37 | 
 38 | # Configure logging
 39 | logging.basicConfig(
 40 |     level=INFO,
 41 |     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
 42 |     datefmt='%Y-%m-%d %H:%M:%S',
 43 | )
 44 | logger = logging.getLogger(__name__)
 45 | 
 46 | load_dotenv()
 47 | 
 48 | # FalkorDB connection parameters
 49 | # Make sure FalkorDB (on-premises) is running — see https://docs.falkordb.com/
 50 | # By default, FalkorDB does not require a username or password,
 51 | # but you can set them via environment variables for added security.
 52 | #
 53 | # If you're using FalkorDB Cloud, set the environment variables accordingly.
 54 | # For on-premises use, you can leave them as None or set them to your preferred values.
 55 | #
 56 | # The default host and port are 'localhost' and '6379', respectively.
 57 | # You can override these values in your environment variables or directly in the code.
 58 | 
 59 | falkor_username = os.environ.get('FALKORDB_USERNAME', None)
 60 | falkor_password = os.environ.get('FALKORDB_PASSWORD', None)
 61 | falkor_host = os.environ.get('FALKORDB_HOST', 'localhost')
 62 | falkor_port = os.environ.get('FALKORDB_PORT', '6379')
 63 | 
 64 | 
 65 | async def main():
 66 |     #################################################
 67 |     # INITIALIZATION
 68 |     #################################################
 69 |     # Connect to FalkorDB and set up Graphiti indices
 70 |     # This is required before using other Graphiti
 71 |     # functionality
 72 |     #################################################
 73 | 
 74 |     # Initialize Graphiti with FalkorDB connection
 75 |     falkor_driver = FalkorDriver(
 76 |         host=falkor_host, port=falkor_port, username=falkor_username, password=falkor_password
 77 |     )
 78 |     graphiti = Graphiti(graph_driver=falkor_driver)
 79 | 
 80 |     try:
 81 |         #################################################
 82 |         # ADDING EPISODES
 83 |         #################################################
 84 |         # Episodes are the primary units of information
 85 |         # in Graphiti. They can be text or structured JSON
 86 |         # and are automatically processed to extract entities
 87 |         # and relationships.
 88 |         #################################################
 89 | 
 90 |         # Example: Add Episodes
 91 |         # Episodes list containing both text and JSON episodes
 92 |         episodes = [
 93 |             {
 94 |                 'content': 'Kamala Harris is the Attorney General of California. She was previously '
 95 |                 'the district attorney for San Francisco.',
 96 |                 'type': EpisodeType.text,
 97 |                 'description': 'podcast transcript',
 98 |             },
 99 |             {
100 |                 'content': 'As AG, Harris was in office from January 3, 2011 – January 3, 2017',
101 |                 'type': EpisodeType.text,
102 |                 'description': 'podcast transcript',
103 |             },
104 |             {
105 |                 'content': {
106 |                     'name': 'Gavin Newsom',
107 |                     'position': 'Governor',
108 |                     'state': 'California',
109 |                     'previous_role': 'Lieutenant Governor',
110 |                     'previous_location': 'San Francisco',
111 |                 },
112 |                 'type': EpisodeType.json,
113 |                 'description': 'podcast metadata',
114 |             },
115 |             {
116 |                 'content': {
117 |                     'name': 'Gavin Newsom',
118 |                     'position': 'Governor',
119 |                     'term_start': 'January 7, 2019',
120 |                     'term_end': 'Present',
121 |                 },
122 |                 'type': EpisodeType.json,
123 |                 'description': 'podcast metadata',
124 |             },
125 |         ]
126 | 
127 |         # Add episodes to the graph
128 |         for i, episode in enumerate(episodes):
129 |             await graphiti.add_episode(
130 |                 name=f'Freakonomics Radio {i}',
131 |                 episode_body=episode['content']
132 |                 if isinstance(episode['content'], str)
133 |                 else json.dumps(episode['content']),
134 |                 source=episode['type'],
135 |                 source_description=episode['description'],
136 |                 reference_time=datetime.now(timezone.utc),
137 |             )
138 |             print(f'Added episode: Freakonomics Radio {i} ({episode["type"].value})')
139 | 
140 |         #################################################
141 |         # BASIC SEARCH
142 |         #################################################
143 |         # The simplest way to retrieve relationships (edges)
144 |         # from Graphiti is using the search method, which
145 |         # performs a hybrid search combining semantic
146 |         # similarity and BM25 text retrieval.
147 |         #################################################
148 | 
149 |         # Perform a hybrid search combining semantic similarity and BM25 retrieval
150 |         print("\nSearching for: 'Who was the California Attorney General?'")
151 |         results = await graphiti.search('Who was the California Attorney General?')
152 | 
153 |         # Print search results
154 |         print('\nSearch Results:')
155 |         for result in results:
156 |             print(f'UUID: {result.uuid}')
157 |             print(f'Fact: {result.fact}')
158 |             if hasattr(result, 'valid_at') and result.valid_at:
159 |                 print(f'Valid from: {result.valid_at}')
160 |             if hasattr(result, 'invalid_at') and result.invalid_at:
161 |                 print(f'Valid until: {result.invalid_at}')
162 |             print('---')
163 | 
164 |         #################################################
165 |         # CENTER NODE SEARCH
166 |         #################################################
167 |         # For more contextually relevant results, you can
168 |         # use a center node to rerank search results based
169 |         # on their graph distance to a specific node
170 |         #################################################
171 | 
172 |         # Use the top search result's UUID as the center node for reranking
173 |         if results and len(results) > 0:
174 |             # Get the source node UUID from the top result
175 |             center_node_uuid = results[0].source_node_uuid
176 | 
177 |             print('\nReranking search results based on graph distance:')
178 |             print(f'Using center node UUID: {center_node_uuid}')
179 | 
180 |             reranked_results = await graphiti.search(
181 |                 'Who was the California Attorney General?', center_node_uuid=center_node_uuid
182 |             )
183 | 
184 |             # Print reranked search results
185 |             print('\nReranked Search Results:')
186 |             for result in reranked_results:
187 |                 print(f'UUID: {result.uuid}')
188 |                 print(f'Fact: {result.fact}')
189 |                 if hasattr(result, 'valid_at') and result.valid_at:
190 |                     print(f'Valid from: {result.valid_at}')
191 |                 if hasattr(result, 'invalid_at') and result.invalid_at:
192 |                     print(f'Valid until: {result.invalid_at}')
193 |                 print('---')
194 |         else:
195 |             print('No results found in the initial search to use as center node.')
196 | 
197 |         #################################################
198 |         # NODE SEARCH USING SEARCH RECIPES
199 |         #################################################
200 |         # Graphiti provides predefined search recipes
201 |         # optimized for different search scenarios.
202 |         # Here we use NODE_HYBRID_SEARCH_RRF for retrieving
203 |         # nodes directly instead of edges.
204 |         #################################################
205 | 
206 |         # Example: Perform a node search using _search method with standard recipes
207 |         print(
208 |             '\nPerforming node search using _search method with standard recipe NODE_HYBRID_SEARCH_RRF:'
209 |         )
210 | 
211 |         # Use a predefined search configuration recipe and modify its limit
212 |         node_search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
213 |         node_search_config.limit = 5  # Limit to 5 results
214 | 
215 |         # Execute the node search
216 |         node_search_results = await graphiti._search(
217 |             query='California Governor',
218 |             config=node_search_config,
219 |         )
220 | 
221 |         # Print node search results
222 |         print('\nNode Search Results:')
223 |         for node in node_search_results.nodes:
224 |             print(f'Node UUID: {node.uuid}')
225 |             print(f'Node Name: {node.name}')
226 |             node_summary = node.summary[:100] + '...' if len(node.summary) > 100 else node.summary
227 |             print(f'Content Summary: {node_summary}')
228 |             print(f'Node Labels: {", ".join(node.labels)}')
229 |             print(f'Created At: {node.created_at}')
230 |             if hasattr(node, 'attributes') and node.attributes:
231 |                 print('Attributes:')
232 |                 for key, value in node.attributes.items():
233 |                     print(f'  {key}: {value}')
234 |             print('---')
235 | 
236 |     finally:
237 |         #################################################
238 |         # CLEANUP
239 |         #################################################
240 |         # Always close the connection to FalkorDB when
241 |         # finished to properly release resources
242 |         #################################################
243 | 
244 |         # Close the connection
245 |         await graphiti.close()
246 |         print('\nConnection closed')
247 | 
248 | 
249 | if __name__ == '__main__':
250 |     asyncio.run(main())
251 | 
```
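
Worth noting when running this example: the default Graphiti configuration uses OpenAI for extraction and embeddings, so an `OPENAI_API_KEY` is expected alongside the FalkorDB settings above. The sketch below shows a one-time setup step using `build_indices_and_constraints()`, assuming it applies to FalkorDB as it does in the other quickstarts and can be run before the first episode is added.

```python
import asyncio

from graphiti_core import Graphiti
from graphiti_core.driver.falkordb_driver import FalkorDriver


async def setup() -> Graphiti:
    # Defaults match the quickstart above (localhost:6379, no credentials).
    driver = FalkorDriver(host='localhost', port='6379')
    graphiti = Graphiti(graph_driver=driver)
    # One-time index/constraint creation; assumed safe to run on an empty graph.
    await graphiti.build_indices_and_constraints()
    return graphiti


if __name__ == '__main__':
    asyncio.run(setup())
```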

--------------------------------------------------------------------------------
/graphiti_core/llm_client/openai_base_client.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import json
 18 | import logging
 19 | import typing
 20 | from abc import abstractmethod
 21 | from typing import Any, ClassVar
 22 | 
 23 | import openai
 24 | from openai.types.chat import ChatCompletionMessageParam
 25 | from pydantic import BaseModel
 26 | 
 27 | from ..prompts.models import Message
 28 | from .client import LLMClient, get_extraction_language_instruction
 29 | from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
 30 | from .errors import RateLimitError, RefusalError
 31 | 
 32 | logger = logging.getLogger(__name__)
 33 | 
 34 | DEFAULT_MODEL = 'gpt-5-mini'
 35 | DEFAULT_SMALL_MODEL = 'gpt-5-nano'
 36 | DEFAULT_REASONING = 'minimal'
 37 | DEFAULT_VERBOSITY = 'low'
 38 | 
 39 | 
 40 | class BaseOpenAIClient(LLMClient):
 41 |     """
 42 |     Base client class for OpenAI-compatible APIs (OpenAI and Azure OpenAI).
 43 | 
 44 |     This class contains shared logic for both OpenAI and Azure OpenAI clients,
 45 |     reducing code duplication while allowing for implementation-specific differences.
 46 |     """
 47 | 
 48 |     # Class-level constants
 49 |     MAX_RETRIES: ClassVar[int] = 2
 50 | 
 51 |     def __init__(
 52 |         self,
 53 |         config: LLMConfig | None = None,
 54 |         cache: bool = False,
 55 |         max_tokens: int = DEFAULT_MAX_TOKENS,
 56 |         reasoning: str | None = DEFAULT_REASONING,
 57 |         verbosity: str | None = DEFAULT_VERBOSITY,
 58 |     ):
 59 |         if cache:
 60 |             raise NotImplementedError('Caching is not implemented for OpenAI-based clients')
 61 | 
 62 |         if config is None:
 63 |             config = LLMConfig()
 64 | 
 65 |         super().__init__(config, cache)
 66 |         self.max_tokens = max_tokens
 67 |         self.reasoning = reasoning
 68 |         self.verbosity = verbosity
 69 | 
 70 |     @abstractmethod
 71 |     async def _create_completion(
 72 |         self,
 73 |         model: str,
 74 |         messages: list[ChatCompletionMessageParam],
 75 |         temperature: float | None,
 76 |         max_tokens: int,
 77 |         response_model: type[BaseModel] | None = None,
 78 |     ) -> Any:
 79 |         """Create a completion using the specific client implementation."""
 80 |         pass
 81 | 
 82 |     @abstractmethod
 83 |     async def _create_structured_completion(
 84 |         self,
 85 |         model: str,
 86 |         messages: list[ChatCompletionMessageParam],
 87 |         temperature: float | None,
 88 |         max_tokens: int,
 89 |         response_model: type[BaseModel],
 90 |         reasoning: str | None,
 91 |         verbosity: str | None,
 92 |     ) -> Any:
 93 |         """Create a structured completion using the specific client implementation."""
 94 |         pass
 95 | 
 96 |     def _convert_messages_to_openai_format(
 97 |         self, messages: list[Message]
 98 |     ) -> list[ChatCompletionMessageParam]:
 99 |         """Convert internal Message format to OpenAI ChatCompletionMessageParam format."""
100 |         openai_messages: list[ChatCompletionMessageParam] = []
101 |         for m in messages:
102 |             m.content = self._clean_input(m.content)
103 |             if m.role == 'user':
104 |                 openai_messages.append({'role': 'user', 'content': m.content})
105 |             elif m.role == 'system':
106 |                 openai_messages.append({'role': 'system', 'content': m.content})
107 |         return openai_messages
108 | 
109 |     def _get_model_for_size(self, model_size: ModelSize) -> str:
110 |         """Get the appropriate model name based on the requested size."""
111 |         if model_size == ModelSize.small:
112 |             return self.small_model or DEFAULT_SMALL_MODEL
113 |         else:
114 |             return self.model or DEFAULT_MODEL
115 | 
116 |     def _handle_structured_response(self, response: Any) -> dict[str, Any]:
117 |         """Handle structured response parsing and validation."""
118 |         response_object = response.output_text
119 | 
120 |         if response_object:
121 |             return json.loads(response_object)
122 |         # output_text is a plain string; check the response itself for a refusal
123 |         if getattr(response, 'refusal', None):
124 |             raise RefusalError(response.refusal)
125 |         raise Exception(f'Invalid response from LLM: {response}')
126 | 
127 |     def _handle_json_response(self, response: Any) -> dict[str, Any]:
128 |         """Handle JSON response parsing."""
129 |         result = response.choices[0].message.content or '{}'
130 |         return json.loads(result)
131 | 
132 |     async def _generate_response(
133 |         self,
134 |         messages: list[Message],
135 |         response_model: type[BaseModel] | None = None,
136 |         max_tokens: int = DEFAULT_MAX_TOKENS,
137 |         model_size: ModelSize = ModelSize.medium,
138 |     ) -> dict[str, Any]:
139 |         """Generate a response using the appropriate client implementation."""
140 |         openai_messages = self._convert_messages_to_openai_format(messages)
141 |         model = self._get_model_for_size(model_size)
142 | 
143 |         try:
144 |             if response_model:
145 |                 response = await self._create_structured_completion(
146 |                     model=model,
147 |                     messages=openai_messages,
148 |                     temperature=self.temperature,
149 |                     max_tokens=max_tokens or self.max_tokens,
150 |                     response_model=response_model,
151 |                     reasoning=self.reasoning,
152 |                     verbosity=self.verbosity,
153 |                 )
154 |                 return self._handle_structured_response(response)
155 |             else:
156 |                 response = await self._create_completion(
157 |                     model=model,
158 |                     messages=openai_messages,
159 |                     temperature=self.temperature,
160 |                     max_tokens=max_tokens or self.max_tokens,
161 |                 )
162 |                 return self._handle_json_response(response)
163 | 
164 |         except openai.LengthFinishReasonError as e:
165 |             raise Exception(f'Output length exceeded max tokens {self.max_tokens}: {e}') from e
166 |         except openai.RateLimitError as e:
167 |             raise RateLimitError from e
168 |         except openai.AuthenticationError as e:
169 |             logger.error(
170 |                 f'OpenAI Authentication Error: {e}. Please verify your API key is correct.'
171 |             )
172 |             raise
173 |         except Exception as e:
174 |             # Provide more context for connection errors
175 |             error_msg = str(e)
176 |             if 'Connection error' in error_msg or 'connection' in error_msg.lower():
177 |                 logger.error(
178 |                     f'Connection error communicating with OpenAI API. Please check your network connection and API key. Error: {e}'
179 |                 )
180 |             else:
181 |                 logger.error(f'Error in generating LLM response: {e}')
182 |             raise
183 | 
184 |     async def generate_response(
185 |         self,
186 |         messages: list[Message],
187 |         response_model: type[BaseModel] | None = None,
188 |         max_tokens: int | None = None,
189 |         model_size: ModelSize = ModelSize.medium,
190 |         group_id: str | None = None,
191 |         prompt_name: str | None = None,
192 |     ) -> dict[str, typing.Any]:
193 |         """Generate a response with retry logic and error handling."""
194 |         if max_tokens is None:
195 |             max_tokens = self.max_tokens
196 | 
197 |         # Add multilingual extraction instructions
198 |         messages[0].content += get_extraction_language_instruction(group_id)
199 | 
200 |         # Wrap entire operation in tracing span
201 |         with self.tracer.start_span('llm.generate') as span:
202 |             attributes = {
203 |                 'llm.provider': 'openai',
204 |                 'model.size': model_size.value,
205 |                 'max_tokens': max_tokens,
206 |             }
207 |             if prompt_name:
208 |                 attributes['prompt.name'] = prompt_name
209 |             span.add_attributes(attributes)
210 | 
211 |             retry_count = 0
212 |             last_error = None
213 | 
214 |             while retry_count <= self.MAX_RETRIES:
215 |                 try:
216 |                     response = await self._generate_response(
217 |                         messages, response_model, max_tokens, model_size
218 |                     )
219 |                     return response
220 |                 except (RateLimitError, RefusalError) as e:
221 |                     # These errors should not trigger retries
222 |                     span.set_status('error', str(e))
223 |                     raise
224 |                 except (
225 |                     openai.APITimeoutError,
226 |                     openai.APIConnectionError,
227 |                     openai.InternalServerError,
228 |                 ) as e:
229 |                     # Let OpenAI's client handle these retries
230 |                     span.set_status('error', str(e))
231 |                     raise
232 |                 except Exception as e:
233 |                     last_error = e
234 | 
235 |                     # Don't retry if we've hit the max retries
236 |                     if retry_count >= self.MAX_RETRIES:
237 |                         logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
238 |                         span.set_status('error', str(e))
239 |                         span.record_exception(e)
240 |                         raise
241 | 
242 |                     retry_count += 1
243 | 
244 |                     # Construct a detailed error message for the LLM
245 |                     error_context = (
246 |                         f'The previous response attempt was invalid. '
247 |                         f'Error type: {e.__class__.__name__}. '
248 |                         f'Error details: {str(e)}. '
249 |                         f'Please try again with a valid response, ensuring the output matches '
250 |                         f'the expected format and constraints.'
251 |                     )
252 | 
253 |                     error_message = Message(role='user', content=error_context)
254 |                     messages.append(error_message)
255 |                     logger.warning(
256 |                         f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
257 |                     )
258 | 
259 |             # If we somehow get here, raise the last error
260 |             span.set_status('error', str(last_error))
261 |             raise last_error or Exception('Max retries exceeded with no specific error')
262 | 
```
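
The two abstract hooks are the whole transport contract: `_create_structured_completion` must return an object whose `output_text` attribute is a JSON string (consumed by `_handle_structured_response`), while `_create_completion` must return a chat-style object exposing `choices[0].message.content` (consumed by `_handle_json_response`). The sketch below is not one of the shipped clients; it is a hypothetical in-memory stub (`StubOpenAIClient` and `Entities` are illustrative names, there are no network calls, and it assumes the base `LLMClient` supplies its default tracer) showing how `generate_response` converts messages, picks a model by size, and parses the structured payload.

```python
import asyncio
from types import SimpleNamespace

from pydantic import BaseModel

from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.openai_base_client import BaseOpenAIClient
from graphiti_core.prompts.models import Message


class Entities(BaseModel):
    names: list[str]


class StubOpenAIClient(BaseOpenAIClient):
    """Hypothetical test double: returns canned payloads instead of calling an API."""

    async def _create_completion(self, model, messages, temperature, max_tokens, response_model=None):
        # Shape expected by _handle_json_response: choices[0].message.content
        return SimpleNamespace(
            choices=[SimpleNamespace(message=SimpleNamespace(content='{"ok": true}'))]
        )

    async def _create_structured_completion(
        self, model, messages, temperature, max_tokens, response_model, reasoning, verbosity
    ):
        # Shape expected by _handle_structured_response: output_text holds a JSON string
        return SimpleNamespace(output_text='{"names": ["Alice", "Bob"]}')


async def demo():
    client = StubOpenAIClient(config=LLMConfig())
    result = await client.generate_response(
        [
            Message(role='system', content='Extract the people mentioned.'),
            Message(role='user', content='Alice met Bob for coffee.'),
        ],
        response_model=Entities,
    )
    print(result)  # {'names': ['Alice', 'Bob']}


if __name__ == '__main__':
    asyncio.run(demo())
```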

--------------------------------------------------------------------------------
/mcp_server/src/config/schema.py:
--------------------------------------------------------------------------------

```python
  1 | """Configuration schemas with pydantic-settings and YAML support."""
  2 | 
  3 | import os
  4 | from pathlib import Path
  5 | from typing import Any
  6 | 
  7 | import yaml
  8 | from pydantic import BaseModel, Field
  9 | from pydantic_settings import (
 10 |     BaseSettings,
 11 |     PydanticBaseSettingsSource,
 12 |     SettingsConfigDict,
 13 | )
 14 | 
 15 | 
 16 | class YamlSettingsSource(PydanticBaseSettingsSource):
 17 |     """Custom settings source for loading from YAML files."""
 18 | 
 19 |     def __init__(self, settings_cls: type[BaseSettings], config_path: Path | None = None):
 20 |         super().__init__(settings_cls)
 21 |         self.config_path = config_path or Path('config.yaml')
 22 | 
 23 |     def _expand_env_vars(self, value: Any) -> Any:
 24 |         """Recursively expand environment variables in configuration values."""
 25 |         if isinstance(value, str):
 26 |             # Support ${VAR} and ${VAR:default} syntax
 27 |             import re
 28 | 
 29 |             def replacer(match):
 30 |                 var_name = match.group(1)
 31 |                 default_value = match.group(3) if match.group(3) is not None else ''
 32 |                 return os.environ.get(var_name, default_value)
 33 | 
 34 |             pattern = r'\$\{([^:}]+)(:([^}]*))?\}'
 35 | 
 36 |             # Check if the entire value is a single env var expression
 37 |             full_match = re.fullmatch(pattern, value)
 38 |             if full_match:
 39 |                 result = replacer(full_match)
 40 |                 # Convert boolean-like strings to actual booleans
 41 |                 if isinstance(result, str):
 42 |                     lower_result = result.lower().strip()
 43 |                     if lower_result in ('true', '1', 'yes', 'on'):
 44 |                         return True
 45 |                     elif lower_result in ('false', '0', 'no', 'off'):
 46 |                         return False
 47 |                     elif lower_result == '':
 48 |                         # Empty string means env var not set - return None for optional fields
 49 |                         return None
 50 |                 return result
 51 |             else:
 52 |                 # Otherwise, do string substitution (keep as strings for partial replacements)
 53 |                 return re.sub(pattern, replacer, value)
 54 |         elif isinstance(value, dict):
 55 |             return {k: self._expand_env_vars(v) for k, v in value.items()}
 56 |         elif isinstance(value, list):
 57 |             return [self._expand_env_vars(item) for item in value]
 58 |         return value
 59 | 
 60 |     def get_field_value(self, field_name: str, field_info: Any) -> Any:
 61 |         """Get field value from YAML config."""
 62 |         return None
 63 | 
 64 |     def __call__(self) -> dict[str, Any]:
 65 |         """Load and parse YAML configuration."""
 66 |         if not self.config_path.exists():
 67 |             return {}
 68 | 
 69 |         with open(self.config_path) as f:
 70 |             raw_config = yaml.safe_load(f) or {}
 71 | 
 72 |         # Expand environment variables
 73 |         return self._expand_env_vars(raw_config)
 74 | 
 75 | 
 76 | class ServerConfig(BaseModel):
 77 |     """Server configuration."""
 78 | 
 79 |     transport: str = Field(
 80 |         default='http',
 81 |         description='Transport type: http (default, recommended), stdio, or sse (deprecated)',
 82 |     )
 83 |     host: str = Field(default='0.0.0.0', description='Server host')
 84 |     port: int = Field(default=8000, description='Server port')
 85 | 
 86 | 
 87 | class OpenAIProviderConfig(BaseModel):
 88 |     """OpenAI provider configuration."""
 89 | 
 90 |     api_key: str | None = None
 91 |     api_url: str = 'https://api.openai.com/v1'
 92 |     organization_id: str | None = None
 93 | 
 94 | 
 95 | class AzureOpenAIProviderConfig(BaseModel):
 96 |     """Azure OpenAI provider configuration."""
 97 | 
 98 |     api_key: str | None = None
 99 |     api_url: str | None = None
100 |     api_version: str = '2024-10-21'
101 |     deployment_name: str | None = None
102 |     use_azure_ad: bool = False
103 | 
104 | 
105 | class AnthropicProviderConfig(BaseModel):
106 |     """Anthropic provider configuration."""
107 | 
108 |     api_key: str | None = None
109 |     api_url: str = 'https://api.anthropic.com'
110 |     max_retries: int = 3
111 | 
112 | 
113 | class GeminiProviderConfig(BaseModel):
114 |     """Gemini provider configuration."""
115 | 
116 |     api_key: str | None = None
117 |     project_id: str | None = None
118 |     location: str = 'us-central1'
119 | 
120 | 
121 | class GroqProviderConfig(BaseModel):
122 |     """Groq provider configuration."""
123 | 
124 |     api_key: str | None = None
125 |     api_url: str = 'https://api.groq.com/openai/v1'
126 | 
127 | 
128 | class VoyageProviderConfig(BaseModel):
129 |     """Voyage AI provider configuration."""
130 | 
131 |     api_key: str | None = None
132 |     api_url: str = 'https://api.voyageai.com/v1'
133 |     model: str = 'voyage-3'
134 | 
135 | 
136 | class LLMProvidersConfig(BaseModel):
137 |     """LLM providers configuration."""
138 | 
139 |     openai: OpenAIProviderConfig | None = None
140 |     azure_openai: AzureOpenAIProviderConfig | None = None
141 |     anthropic: AnthropicProviderConfig | None = None
142 |     gemini: GeminiProviderConfig | None = None
143 |     groq: GroqProviderConfig | None = None
144 | 
145 | 
146 | class LLMConfig(BaseModel):
147 |     """LLM configuration."""
148 | 
149 |     provider: str = Field(default='openai', description='LLM provider')
150 |     model: str = Field(default='gpt-4.1', description='Model name')
151 |     temperature: float | None = Field(
152 |         default=None, description='Temperature (optional, defaults to None for reasoning models)'
153 |     )
154 |     max_tokens: int = Field(default=4096, description='Max tokens')
155 |     providers: LLMProvidersConfig = Field(default_factory=LLMProvidersConfig)
156 | 
157 | 
158 | class EmbedderProvidersConfig(BaseModel):
159 |     """Embedder providers configuration."""
160 | 
161 |     openai: OpenAIProviderConfig | None = None
162 |     azure_openai: AzureOpenAIProviderConfig | None = None
163 |     gemini: GeminiProviderConfig | None = None
164 |     voyage: VoyageProviderConfig | None = None
165 | 
166 | 
167 | class EmbedderConfig(BaseModel):
168 |     """Embedder configuration."""
169 | 
170 |     provider: str = Field(default='openai', description='Embedder provider')
171 |     model: str = Field(default='text-embedding-3-small', description='Model name')
172 |     dimensions: int = Field(default=1536, description='Embedding dimensions')
173 |     providers: EmbedderProvidersConfig = Field(default_factory=EmbedderProvidersConfig)
174 | 
175 | 
176 | class Neo4jProviderConfig(BaseModel):
177 |     """Neo4j provider configuration."""
178 | 
179 |     uri: str = 'bolt://localhost:7687'
180 |     username: str = 'neo4j'
181 |     password: str | None = None
182 |     database: str = 'neo4j'
183 |     use_parallel_runtime: bool = False
184 | 
185 | 
186 | class FalkorDBProviderConfig(BaseModel):
187 |     """FalkorDB provider configuration."""
188 | 
189 |     uri: str = 'redis://localhost:6379'
190 |     password: str | None = None
191 |     database: str = 'default_db'
192 | 
193 | 
194 | class DatabaseProvidersConfig(BaseModel):
195 |     """Database providers configuration."""
196 | 
197 |     neo4j: Neo4jProviderConfig | None = None
198 |     falkordb: FalkorDBProviderConfig | None = None
199 | 
200 | 
201 | class DatabaseConfig(BaseModel):
202 |     """Database configuration."""
203 | 
204 |     provider: str = Field(default='falkordb', description='Database provider')
205 |     providers: DatabaseProvidersConfig = Field(default_factory=DatabaseProvidersConfig)
206 | 
207 | 
208 | class EntityTypeConfig(BaseModel):
209 |     """Entity type configuration."""
210 | 
211 |     name: str
212 |     description: str
213 | 
214 | 
215 | class GraphitiAppConfig(BaseModel):
216 |     """Graphiti-specific configuration."""
217 | 
218 |     group_id: str = Field(default='main', description='Group ID')
219 |     episode_id_prefix: str | None = Field(default='', description='Episode ID prefix')
220 |     user_id: str = Field(default='mcp_user', description='User ID')
221 |     entity_types: list[EntityTypeConfig] = Field(default_factory=list)
222 | 
223 |     def model_post_init(self, __context) -> None:
224 |         """Convert None to empty string for episode_id_prefix."""
225 |         if self.episode_id_prefix is None:
226 |             self.episode_id_prefix = ''
227 | 
228 | 
229 | class GraphitiConfig(BaseSettings):
230 |     """Graphiti configuration with YAML and environment support."""
231 | 
232 |     server: ServerConfig = Field(default_factory=ServerConfig)
233 |     llm: LLMConfig = Field(default_factory=LLMConfig)
234 |     embedder: EmbedderConfig = Field(default_factory=EmbedderConfig)
235 |     database: DatabaseConfig = Field(default_factory=DatabaseConfig)
236 |     graphiti: GraphitiAppConfig = Field(default_factory=GraphitiAppConfig)
237 | 
238 |     # Additional server options
239 |     destroy_graph: bool = Field(default=False, description='Clear graph on startup')
240 | 
241 |     model_config = SettingsConfigDict(
242 |         env_prefix='',
243 |         env_nested_delimiter='__',
244 |         case_sensitive=False,
245 |         extra='ignore',
246 |     )
247 | 
248 |     @classmethod
249 |     def settings_customise_sources(
250 |         cls,
251 |         settings_cls: type[BaseSettings],
252 |         init_settings: PydanticBaseSettingsSource,
253 |         env_settings: PydanticBaseSettingsSource,
254 |         dotenv_settings: PydanticBaseSettingsSource,
255 |         file_secret_settings: PydanticBaseSettingsSource,
256 |     ) -> tuple[PydanticBaseSettingsSource, ...]:
257 |         """Customize settings sources to include YAML."""
258 |         config_path = Path(os.environ.get('CONFIG_PATH', 'config/config.yaml'))
259 |         yaml_settings = YamlSettingsSource(settings_cls, config_path)
260 |         # Priority: CLI args (init) > env vars > yaml > defaults
261 |         return (init_settings, env_settings, yaml_settings, dotenv_settings)
262 | 
263 |     def apply_cli_overrides(self, args) -> None:
264 |         """Apply CLI argument overrides to configuration."""
265 |         # Override server settings
266 |         if hasattr(args, 'transport') and args.transport:
267 |             self.server.transport = args.transport
268 | 
269 |         # Override LLM settings
270 |         if hasattr(args, 'llm_provider') and args.llm_provider:
271 |             self.llm.provider = args.llm_provider
272 |         if hasattr(args, 'model') and args.model:
273 |             self.llm.model = args.model
274 |         if hasattr(args, 'temperature') and args.temperature is not None:
275 |             self.llm.temperature = args.temperature
276 | 
277 |         # Override embedder settings
278 |         if hasattr(args, 'embedder_provider') and args.embedder_provider:
279 |             self.embedder.provider = args.embedder_provider
280 |         if hasattr(args, 'embedder_model') and args.embedder_model:
281 |             self.embedder.model = args.embedder_model
282 | 
283 |         # Override database settings
284 |         if hasattr(args, 'database_provider') and args.database_provider:
285 |             self.database.provider = args.database_provider
286 | 
287 |         # Override Graphiti settings
288 |         if hasattr(args, 'group_id') and args.group_id:
289 |             self.graphiti.group_id = args.group_id
290 |         if hasattr(args, 'user_id') and args.user_id:
291 |             self.graphiti.user_id = args.user_id
292 | 
```
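
`YamlSettingsSource` can be exercised on its own to see the `${VAR}` / `${VAR:default}` expansion rules: a value that is entirely one expression is type-coerced (boolean-like strings become booleans, an unset variable with no default becomes `None`), while partial substitutions stay plain strings. A minimal sketch follows, with hypothetical variable names and a temporary config file; the import assumes the MCP server's `src` directory is on the path so the module resolves as `config.schema`.

```python
import os
import tempfile
from pathlib import Path

from config.schema import GraphitiConfig, YamlSettingsSource  # adjust to your layout

yaml_text = """
server:
  transport: ${MCP_TRANSPORT:http}
llm:
  model: ${LLM_MODEL:gpt-4.1}
database:
  provider: falkordb
  providers:
    falkordb:
      uri: redis://${FALKORDB_HOST:localhost}:6379
"""

os.environ['MCP_TRANSPORT'] = 'stdio'  # simulate an environment override

with tempfile.TemporaryDirectory() as tmp:
    config_path = Path(tmp) / 'config.yaml'
    config_path.write_text(yaml_text)

    source = YamlSettingsSource(GraphitiConfig, config_path)
    expanded = source()

    print(expanded['server']['transport'])  # 'stdio' -- env var wins over the default
    print(expanded['llm']['model'])  # 'gpt-4.1' -- default used, variable unset
    print(expanded['database']['providers']['falkordb']['uri'])  # 'redis://localhost:6379' -- partial substitution stays a string
```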

--------------------------------------------------------------------------------
/tests/helpers_test.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import os
 18 | from unittest.mock import Mock
 19 | 
 20 | import numpy as np
 21 | import pytest
 22 | from dotenv import load_dotenv
 23 | 
 24 | from graphiti_core.driver.driver import GraphDriver, GraphProvider
 25 | from graphiti_core.edges import EntityEdge, EpisodicEdge
 26 | from graphiti_core.embedder.client import EmbedderClient
 27 | from graphiti_core.helpers import lucene_sanitize
 28 | from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
 29 | from graphiti_core.utils.maintenance.graph_data_operations import clear_data
 30 | 
 31 | load_dotenv()
 32 | 
 33 | drivers: list[GraphProvider] = []
 34 | if os.getenv('DISABLE_NEO4J') is None:
 35 |     try:
 36 |         from graphiti_core.driver.neo4j_driver import Neo4jDriver
 37 | 
 38 |         drivers.append(GraphProvider.NEO4J)
 39 |     except ImportError:
 40 |         raise
 41 | 
 42 | if os.getenv('DISABLE_FALKORDB') is None:
 43 |     try:
 44 |         from graphiti_core.driver.falkordb_driver import FalkorDriver
 45 | 
 46 |         drivers.append(GraphProvider.FALKORDB)
 47 |     except ImportError:
 48 |         raise
 49 | 
 50 | if os.getenv('DISABLE_KUZU') is None:
 51 |     try:
 52 |         from graphiti_core.driver.kuzu_driver import KuzuDriver
 53 | 
 54 |         drivers.append(GraphProvider.KUZU)
 55 |     except ImportError:
 56 |         raise
 57 | 
 58 | # Disable Neptune for now
 59 | os.environ['DISABLE_NEPTUNE'] = 'True'
 60 | if os.getenv('DISABLE_NEPTUNE') is None:
 61 |     try:
 62 |         from graphiti_core.driver.neptune_driver import NeptuneDriver
 63 | 
 64 |         drivers.append(GraphProvider.NEPTUNE)
 65 |     except ImportError:
 66 |         raise
 67 | 
 68 | NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
 69 | NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j')
 70 | NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test')
 71 | 
 72 | FALKORDB_HOST = os.getenv('FALKORDB_HOST', 'localhost')
 73 | FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
 74 | FALKORDB_USER = os.getenv('FALKORDB_USER', None)
 75 | FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)
 76 | 
 77 | NEPTUNE_HOST = os.getenv('NEPTUNE_HOST', 'localhost')
 78 | NEPTUNE_PORT = os.getenv('NEPTUNE_PORT', 8182)
 79 | AOSS_HOST = os.getenv('AOSS_HOST', None)
 80 | 
 81 | KUZU_DB = os.getenv('KUZU_DB', ':memory:')
 82 | 
 83 | group_id = 'graphiti_test_group'
 84 | group_id_2 = 'graphiti_test_group_2'
 85 | 
 86 | 
 87 | def get_driver(provider: GraphProvider) -> GraphDriver:
 88 |     if provider == GraphProvider.NEO4J:
 89 |         return Neo4jDriver(
 90 |             uri=NEO4J_URI,
 91 |             user=NEO4J_USER,
 92 |             password=NEO4J_PASSWORD,
 93 |         )
 94 |     elif provider == GraphProvider.FALKORDB:
 95 |         return FalkorDriver(
 96 |             host=FALKORDB_HOST,
 97 |             port=int(FALKORDB_PORT),
 98 |             username=FALKORDB_USER,
 99 |             password=FALKORDB_PASSWORD,
100 |         )
101 |     elif provider == GraphProvider.KUZU:
102 |         driver = KuzuDriver(
103 |             db=KUZU_DB,
104 |         )
105 |         return driver
106 |     elif provider == GraphProvider.NEPTUNE:
107 |         return NeptuneDriver(
108 |             host=NEPTUNE_HOST,
109 |             port=int(NEPTUNE_PORT),
110 |             aoss_host=AOSS_HOST,
111 |         )
112 |     else:
113 |         raise ValueError(f'Driver {provider} not available')
114 | 
115 | 
116 | @pytest.fixture(params=drivers)
117 | async def graph_driver(request):
118 |     driver = request.param
119 |     graph_driver = get_driver(driver)
120 |     await clear_data(graph_driver, [group_id, group_id_2])
121 |     try:
122 |         yield graph_driver  # provide driver to the test
123 |     finally:
124 |         # always called, even if the test fails or raises
125 |         # await clean_up(graph_driver)
126 |         await graph_driver.close()
127 | 
128 | 
129 | embedding_dim = 384
130 | embeddings = {
131 |     key: np.random.uniform(0.0, 0.9, embedding_dim).tolist()
132 |     for key in [
133 |         'Alice',
134 |         'Bob',
135 |         'Alice likes Bob',
136 |         'test_entity_1',
137 |         'test_entity_2',
138 |         'test_entity_3',
139 |         'test_entity_4',
140 |         'test_entity_alice',
141 |         'test_entity_bob',
142 |         'test_entity_1 is a duplicate of test_entity_2',
143 |         'test_entity_3 is a duplicate of test_entity_4',
144 |         'test_entity_1 relates to test_entity_2',
145 |         'test_entity_1 relates to test_entity_3',
146 |         'test_entity_2 relates to test_entity_3',
147 |         'test_entity_1 relates to test_entity_4',
148 |         'test_entity_2 relates to test_entity_4',
149 |         'test_entity_3 relates to test_entity_4',
150 |         'test_entity_1 relates to test_entity_2',
151 |         'test_entity_3 relates to test_entity_4',
152 |         'test_entity_2 relates to test_entity_3',
153 |         'test_community_1',
154 |         'test_community_2',
155 |     ]
156 | }
157 | embeddings['Alice Smith'] = embeddings['Alice']
158 | 
159 | 
160 | @pytest.fixture
161 | def mock_embedder():
162 |     mock_model = Mock(spec=EmbedderClient)
163 | 
164 |     def mock_embed(input_data):
165 |         if isinstance(input_data, str):
166 |             return embeddings[input_data]
167 |         elif isinstance(input_data, list):
168 |             combined_input = ' '.join(input_data)
169 |             return embeddings[combined_input]
170 |         else:
171 |             raise ValueError(f'Unsupported input type: {type(input_data)}')
172 | 
173 |     mock_model.create.side_effect = mock_embed
174 |     return mock_model
175 | 
176 | 
177 | def test_lucene_sanitize():
178 |     # Call the function with test data
179 |     queries = [
180 |         (
181 |             'This has every escape character + - && || ! ( ) { } [ ] ^ " ~ * ? : \\ /',
182 |             '\\This has every escape character \\+ \\- \\&\\& \\|\\| \\! \\( \\) \\{ \\} \\[ \\] \\^ \\" \\~ \\* \\? \\: \\\\ \\/',
183 |         ),
184 |         ('this has no escape characters', 'this has no escape characters'),
185 |     ]
186 | 
187 |     for query, assert_result in queries:
188 |         result = lucene_sanitize(query)
189 |         assert assert_result == result
190 | 
191 | 
192 | async def get_node_count(driver: GraphDriver, uuids: list[str]) -> int:
193 |     results, _, _ = await driver.execute_query(
194 |         """
195 |         MATCH (n)
196 |         WHERE n.uuid IN $uuids
197 |         RETURN COUNT(n) as count
198 |         """,
199 |         uuids=uuids,
200 |     )
201 |     return int(results[0]['count'])
202 | 
203 | 
204 | async def get_edge_count(driver: GraphDriver, uuids: list[str]) -> int:
205 |     results, _, _ = await driver.execute_query(
206 |         """
207 |         MATCH (n)-[e]->(m)
208 |         WHERE e.uuid IN $uuids
209 |         RETURN COUNT(e) as count
210 |         UNION ALL
211 |         MATCH (e:RelatesToNode_)
212 |         WHERE e.uuid IN $uuids
213 |         RETURN COUNT(e) as count
214 |         """,
215 |         uuids=uuids,
216 |     )
217 |     return sum(int(result['count']) for result in results)
218 | 
219 | 
220 | async def print_graph(graph_driver: GraphDriver):
221 |     nodes, _, _ = await graph_driver.execute_query(
222 |         """
223 |         MATCH (n)
224 |         RETURN n.uuid, n.name
225 |         """,
226 |     )
227 |     print('Nodes:')
228 |     for node in nodes:
229 |         print('  ', node)
230 |     edges, _, _ = await graph_driver.execute_query(
231 |         """
232 |         MATCH (n)-[e]->(m)
233 |         RETURN n.name, e.uuid, m.name
234 |         """,
235 |     )
236 |     print('Edges:')
237 |     for edge in edges:
238 |         print('  ', edge)
239 | 
240 | 
241 | async def assert_episodic_node_equals(retrieved: EpisodicNode, sample: EpisodicNode):
242 |     assert retrieved.uuid == sample.uuid
243 |     assert retrieved.name == sample.name
244 |     assert retrieved.group_id == group_id
245 |     assert retrieved.created_at == sample.created_at
246 |     assert retrieved.source == sample.source
247 |     assert retrieved.source_description == sample.source_description
248 |     assert retrieved.content == sample.content
249 |     assert retrieved.valid_at == sample.valid_at
250 |     assert set(retrieved.entity_edges) == set(sample.entity_edges)
251 | 
252 | 
253 | async def assert_entity_node_equals(
254 |     graph_driver: GraphDriver, retrieved: EntityNode, sample: EntityNode
255 | ):
256 |     await retrieved.load_name_embedding(graph_driver)
257 |     assert retrieved.uuid == sample.uuid
258 |     assert retrieved.name == sample.name
259 |     assert retrieved.group_id == sample.group_id
260 |     assert set(retrieved.labels) == set(sample.labels)
261 |     assert retrieved.created_at == sample.created_at
262 |     assert retrieved.name_embedding is not None
263 |     assert sample.name_embedding is not None
264 |     assert np.allclose(retrieved.name_embedding, sample.name_embedding)
265 |     assert retrieved.summary == sample.summary
266 |     assert retrieved.attributes == sample.attributes
267 | 
268 | 
269 | async def assert_community_node_equals(
270 |     graph_driver: GraphDriver, retrieved: CommunityNode, sample: CommunityNode
271 | ):
272 |     await retrieved.load_name_embedding(graph_driver)
273 |     assert retrieved.uuid == sample.uuid
274 |     assert retrieved.name == sample.name
275 |     assert retrieved.group_id == group_id
276 |     assert retrieved.created_at == sample.created_at
277 |     assert retrieved.name_embedding is not None
278 |     assert sample.name_embedding is not None
279 |     assert np.allclose(retrieved.name_embedding, sample.name_embedding)
280 |     assert retrieved.summary == sample.summary
281 | 
282 | 
283 | async def assert_episodic_edge_equals(retrieved: EpisodicEdge, sample: EpisodicEdge):
284 |     assert retrieved.uuid == sample.uuid
285 |     assert retrieved.group_id == sample.group_id
286 |     assert retrieved.created_at == sample.created_at
287 |     assert retrieved.source_node_uuid == sample.source_node_uuid
288 |     assert retrieved.target_node_uuid == sample.target_node_uuid
289 | 
290 | 
291 | async def assert_entity_edge_equals(
292 |     graph_driver: GraphDriver, retrieved: EntityEdge, sample: EntityEdge
293 | ):
294 |     await retrieved.load_fact_embedding(graph_driver)
295 |     assert retrieved.uuid == sample.uuid
296 |     assert retrieved.group_id == sample.group_id
297 |     assert retrieved.created_at == sample.created_at
298 |     assert retrieved.source_node_uuid == sample.source_node_uuid
299 |     assert retrieved.target_node_uuid == sample.target_node_uuid
300 |     assert retrieved.name == sample.name
301 |     assert retrieved.fact == sample.fact
302 |     assert retrieved.fact_embedding is not None
303 |     assert sample.fact_embedding is not None
304 |     assert np.allclose(retrieved.fact_embedding, sample.fact_embedding)
305 |     assert retrieved.episodes == sample.episodes
306 |     assert retrieved.expired_at == sample.expired_at
307 |     assert retrieved.valid_at == sample.valid_at
308 |     assert retrieved.invalid_at == sample.invalid_at
309 |     assert retrieved.attributes == sample.attributes
310 | 
311 | 
312 | if __name__ == '__main__':
313 |     pytest.main([__file__])
314 | 
```
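
The `mock_embedder` fixture pins every embedding to a precomputed 384-dimension vector keyed on the exact input text, so similarity comparisons in tests are deterministic and require no network. Below is a minimal sketch of a consuming test (the test name is hypothetical, and it assumes `tests.helpers_test` is importable from the project root so the fixture and the `embeddings` table can be reused); `create` mirrors the async `EmbedderClient.create`, hence the `await`.

```python
import numpy as np
import pytest

from tests.helpers_test import embeddings, mock_embedder  # noqa: F401  # registers the fixture


@pytest.mark.asyncio
async def test_mock_embedder_is_deterministic(mock_embedder):
    # A single string resolves to the precomputed 384-dim vector for that key
    alice = await mock_embedder.create('Alice')
    assert len(alice) == 384
    assert np.allclose(alice, embeddings['Alice'])

    # 'Alice Smith' was aliased to Alice's vector above, so the two compare equal
    smith = await mock_embedder.create('Alice Smith')
    assert np.allclose(smith, alice)

    # List input is joined with spaces before the lookup
    fact = await mock_embedder.create(['Alice likes Bob'])
    assert np.allclose(fact, embeddings['Alice likes Bob'])
```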

--------------------------------------------------------------------------------
/mcp_server/tests/test_fixtures.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Shared test fixtures and utilities for Graphiti MCP integration tests.
  3 | """
  4 | 
  5 | import asyncio
  6 | import contextlib
  7 | import json
  8 | import os
  9 | import random
 10 | import time
 11 | from contextlib import asynccontextmanager
 12 | from typing import Any
 13 | 
 14 | import pytest
 15 | from faker import Faker
 16 | from mcp import ClientSession, StdioServerParameters
 17 | from mcp.client.stdio import stdio_client
 18 | 
 19 | fake = Faker()
 20 | 
 21 | 
 22 | class TestDataGenerator:
 23 |     """Generate realistic test data for various scenarios."""
 24 | 
 25 |     @staticmethod
 26 |     def generate_company_profile() -> dict[str, Any]:
 27 |         """Generate a realistic company profile."""
 28 |         return {
 29 |             'company': {
 30 |                 'name': fake.company(),
 31 |                 'founded': random.randint(1990, 2023),
 32 |                 'industry': random.choice(['Tech', 'Finance', 'Healthcare', 'Retail']),
 33 |                 'employees': random.randint(10, 10000),
 34 |                 'revenue': f'${random.randint(1, 1000)}M',
 35 |                 'headquarters': fake.city(),
 36 |             },
 37 |             'products': [
 38 |                 {
 39 |                     'id': fake.uuid4()[:8],
 40 |                     'name': fake.catch_phrase(),
 41 |                     'category': random.choice(['Software', 'Hardware', 'Service']),
 42 |                     'price': random.randint(10, 10000),
 43 |                 }
 44 |                 for _ in range(random.randint(1, 5))
 45 |             ],
 46 |             'leadership': {
 47 |                 'ceo': fake.name(),
 48 |                 'cto': fake.name(),
 49 |                 'cfo': fake.name(),
 50 |             },
 51 |         }
 52 | 
 53 |     @staticmethod
 54 |     def generate_conversation(turns: int = 3) -> str:
 55 |         """Generate a realistic conversation."""
 56 |         topics = [
 57 |             'product features',
 58 |             'pricing',
 59 |             'technical support',
 60 |             'integration',
 61 |             'documentation',
 62 |             'performance',
 63 |         ]
 64 | 
 65 |         conversation = []
 66 |         for _ in range(turns):
 67 |             topic = random.choice(topics)
 68 |             user_msg = f'user: {fake.sentence()} about {topic}?'
 69 |             assistant_msg = f'assistant: {fake.paragraph(nb_sentences=2)}'
 70 |             conversation.extend([user_msg, assistant_msg])
 71 | 
 72 |         return '\n'.join(conversation)
 73 | 
 74 |     @staticmethod
 75 |     def generate_technical_document() -> str:
 76 |         """Generate technical documentation content."""
 77 |         sections = [
 78 |             f'# {fake.catch_phrase()}\n\n{fake.paragraph()}',
 79 |             f'## Architecture\n{fake.paragraph()}',
 80 |             f'## Implementation\n{fake.paragraph()}',
 81 |             f'## Performance\n- Latency: {random.randint(1, 100)}ms\n- Throughput: {random.randint(100, 10000)} req/s',
 82 |             f'## Dependencies\n- {fake.word()}\n- {fake.word()}\n- {fake.word()}',
 83 |         ]
 84 |         return '\n\n'.join(sections)
 85 | 
 86 |     @staticmethod
 87 |     def generate_news_article() -> str:
 88 |         """Generate a news article."""
 89 |         company = fake.company()
 90 |         return f"""
 91 |         {company} Announces {fake.catch_phrase()}
 92 | 
 93 |         {fake.city()}, {fake.date()} - {company} today announced {fake.paragraph()}.
 94 | 
 95 |         "This is a significant milestone," said {fake.name()}, CEO of {company}.
 96 |         "{fake.sentence()}"
 97 | 
 98 |         The announcement comes after {fake.paragraph()}.
 99 | 
100 |         Industry analysts predict {fake.paragraph()}.
101 |         """
102 | 
103 |     @staticmethod
104 |     def generate_user_profile() -> dict[str, Any]:
105 |         """Generate a user profile."""
106 |         return {
107 |             'user_id': fake.uuid4(),
108 |             'name': fake.name(),
109 |             'email': fake.email(),
110 |             'joined': fake.date_time_this_year().isoformat(),
111 |             'preferences': {
112 |                 'theme': random.choice(['light', 'dark', 'auto']),
113 |                 'notifications': random.choice([True, False]),
114 |                 'language': random.choice(['en', 'es', 'fr', 'de']),
115 |             },
116 |             'activity': {
117 |                 'last_login': fake.date_time_this_month().isoformat(),
118 |                 'total_sessions': random.randint(1, 1000),
119 |                 'average_duration': f'{random.randint(1, 60)} minutes',
120 |             },
121 |         }
122 | 
123 | 
124 | class MockLLMProvider:
125 |     """Mock LLM provider for testing without actual API calls."""
126 | 
127 |     def __init__(self, delay: float = 0.1):
128 |         self.delay = delay  # Simulate LLM latency
129 | 
130 |     async def generate(self, prompt: str) -> str:
131 |         """Simulate LLM generation with delay."""
132 |         await asyncio.sleep(self.delay)
133 | 
134 |         # Return deterministic responses based on prompt patterns
135 |         if 'extract entities' in prompt.lower():
136 |             return json.dumps(
137 |                 {
138 |                     'entities': [
139 |                         {'name': 'TestEntity1', 'type': 'PERSON'},
140 |                         {'name': 'TestEntity2', 'type': 'ORGANIZATION'},
141 |                     ]
142 |                 }
143 |             )
144 |         elif 'summarize' in prompt.lower():
145 |             return 'This is a test summary of the provided content.'
146 |         else:
147 |             return 'Mock LLM response'
148 | 
149 | 
150 | @asynccontextmanager
151 | async def graphiti_test_client(
152 |     group_id: str | None = None,
153 |     database: str = 'falkordb',
154 |     use_mock_llm: bool = False,
155 |     config_overrides: dict[str, Any] | None = None,
156 | ):
157 |     """
158 |     Context manager for creating test clients with various configurations.
159 | 
160 |     Args:
161 |         group_id: Test group identifier
162 |         database: Database backend (neo4j, falkordb)
163 |         use_mock_llm: Whether to use mock LLM for faster tests
164 |         config_overrides: Additional config overrides
165 |     """
166 |     test_group_id = group_id or f'test_{int(time.time())}_{random.randint(1000, 9999)}'
167 | 
168 |     env = {
169 |         'DATABASE_PROVIDER': database,
170 |         'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY', 'test_key' if use_mock_llm else None),
171 |     }
172 | 
173 |     # Database-specific configuration
174 |     if database == 'neo4j':
175 |         env.update(
176 |             {
177 |                 'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
178 |                 'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
179 |                 'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
180 |             }
181 |         )
182 |     elif database == 'falkordb':
183 |         env['FALKORDB_URI'] = os.environ.get('FALKORDB_URI', 'redis://localhost:6379')
184 | 
185 |     # Apply config overrides
186 |     if config_overrides:
187 |         env.update(config_overrides)
188 | 
189 |     # Add mock LLM flag if needed
190 |     if use_mock_llm:
191 |         env['USE_MOCK_LLM'] = 'true'
192 | 
193 |     server_params = StdioServerParameters(
194 |         command='uv', args=['run', 'main.py', '--transport', 'stdio'], env=env
195 |     )
196 | 
197 |     async with stdio_client(server_params) as (read, write):
198 |         session = ClientSession(read, write)
199 |         await session.initialize()
200 | 
201 |         try:
202 |             yield session, test_group_id
203 |         finally:
204 |             # Cleanup: Clear test data
205 |             with contextlib.suppress(Exception):
206 |                 await session.call_tool('clear_graph', {'group_id': test_group_id})
207 | 
208 |             await session.close()
209 | 
210 | 
211 | class PerformanceBenchmark:
212 |     """Track and analyze performance benchmarks."""
213 | 
214 |     def __init__(self):
215 |         self.measurements: dict[str, list[float]] = {}
216 | 
217 |     def record(self, operation: str, duration: float):
218 |         """Record a performance measurement."""
219 |         if operation not in self.measurements:
220 |             self.measurements[operation] = []
221 |         self.measurements[operation].append(duration)
222 | 
223 |     def get_stats(self, operation: str) -> dict[str, float]:
224 |         """Get statistics for an operation."""
225 |         if operation not in self.measurements or not self.measurements[operation]:
226 |             return {}
227 | 
228 |         durations = self.measurements[operation]
229 |         return {
230 |             'count': len(durations),
231 |             'mean': sum(durations) / len(durations),
232 |             'min': min(durations),
233 |             'max': max(durations),
234 |             'median': sorted(durations)[len(durations) // 2],
235 |         }
236 | 
237 |     def report(self) -> str:
238 |         """Generate a performance report."""
239 |         lines = ['Performance Benchmark Report', '=' * 40]
240 | 
241 |         for operation in sorted(self.measurements.keys()):
242 |             stats = self.get_stats(operation)
243 |             lines.append(f'\n{operation}:')
244 |             lines.append(f'  Samples: {stats["count"]}')
245 |             lines.append(f'  Mean: {stats["mean"]:.3f}s')
246 |             lines.append(f'  Median: {stats["median"]:.3f}s')
247 |             lines.append(f'  Min: {stats["min"]:.3f}s')
248 |             lines.append(f'  Max: {stats["max"]:.3f}s')
249 | 
250 |         return '\n'.join(lines)
251 | 
252 | 
253 | # Pytest fixtures
254 | @pytest.fixture
255 | def test_data_generator():
256 |     """Provide test data generator."""
257 |     return TestDataGenerator()
258 | 
259 | 
260 | @pytest.fixture
261 | def performance_benchmark():
262 |     """Provide performance benchmark tracker."""
263 |     return PerformanceBenchmark()
264 | 
265 | 
266 | @pytest.fixture
267 | async def mock_graphiti_client():
268 |     """Provide a Graphiti client with mocked LLM."""
269 |     async with graphiti_test_client(use_mock_llm=True) as (session, group_id):
270 |         yield session, group_id
271 | 
272 | 
273 | @pytest.fixture
274 | async def graphiti_client():
275 |     """Provide a real Graphiti client."""
276 |     async with graphiti_test_client(use_mock_llm=False) as (session, group_id):
277 |         yield session, group_id
278 | 
279 | 
280 | # Test data fixtures
281 | @pytest.fixture
282 | def sample_memories():
283 |     """Provide sample memory data for testing."""
284 |     return [
285 |         {
286 |             'name': 'Company Overview',
287 |             'episode_body': TestDataGenerator.generate_company_profile(),
288 |             'source': 'json',
289 |             'source_description': 'company database',
290 |         },
291 |         {
292 |             'name': 'Product Launch',
293 |             'episode_body': TestDataGenerator.generate_news_article(),
294 |             'source': 'text',
295 |             'source_description': 'press release',
296 |         },
297 |         {
298 |             'name': 'Customer Support',
299 |             'episode_body': TestDataGenerator.generate_conversation(),
300 |             'source': 'message',
301 |             'source_description': 'support chat',
302 |         },
303 |         {
304 |             'name': 'Technical Specs',
305 |             'episode_body': TestDataGenerator.generate_technical_document(),
306 |             'source': 'text',
307 |             'source_description': 'documentation',
308 |         },
309 |     ]
310 | 
311 | 
312 | @pytest.fixture
313 | def large_dataset():
314 |     """Generate a large dataset for stress testing."""
315 |     return [
316 |         {
317 |             'name': f'Document {i}',
318 |             'episode_body': TestDataGenerator.generate_technical_document(),
319 |             'source': 'text',
320 |             'source_description': 'bulk import',
321 |         }
322 |         for i in range(50)
323 |     ]
324 | 
```
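
`PerformanceBenchmark` is a plain accumulator with no pytest dependency, so it can wrap any timed call. A minimal sketch (the operation and timings are placeholders; the import assumes you run from the `mcp_server/tests` directory) of recording a few durations and printing the report:

```python
import time

from test_fixtures import PerformanceBenchmark  # adjust the import to your test layout


def slow_operation() -> None:
    time.sleep(0.01)  # stand-in for a real call such as add_memory or search


benchmark = PerformanceBenchmark()

for _ in range(5):
    start = time.monotonic()
    slow_operation()
    benchmark.record('slow_operation', time.monotonic() - start)

stats = benchmark.get_stats('slow_operation')
assert stats['count'] == 5
assert stats['min'] <= stats['median'] <= stats['max']

print(benchmark.report())
```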

--------------------------------------------------------------------------------
/tests/utils/maintenance/test_bulk_utils.py:
--------------------------------------------------------------------------------

```python
  1 | from collections import deque
  2 | from unittest.mock import AsyncMock, MagicMock
  3 | 
  4 | import pytest
  5 | 
  6 | from graphiti_core.edges import EntityEdge
  7 | from graphiti_core.graphiti_types import GraphitiClients
  8 | from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
  9 | from graphiti_core.utils import bulk_utils
 10 | from graphiti_core.utils.datetime_utils import utc_now
 11 | 
 12 | 
 13 | def _make_episode(uuid_suffix: str, group_id: str = 'group') -> EpisodicNode:
 14 |     return EpisodicNode(
 15 |         name=f'episode-{uuid_suffix}',
 16 |         group_id=group_id,
 17 |         labels=[],
 18 |         source=EpisodeType.message,
 19 |         content='content',
 20 |         source_description='test',
 21 |         created_at=utc_now(),
 22 |         valid_at=utc_now(),
 23 |     )
 24 | 
 25 | 
 26 | def _make_clients() -> GraphitiClients:
 27 |     driver = MagicMock()
 28 |     embedder = MagicMock()
 29 |     cross_encoder = MagicMock()
 30 |     llm_client = MagicMock()
 31 | 
 32 |     return GraphitiClients.model_construct(  # bypass validation to allow test doubles
 33 |         driver=driver,
 34 |         embedder=embedder,
 35 |         cross_encoder=cross_encoder,
 36 |         llm_client=llm_client,
 37 |     )
 38 | 
 39 | 
 40 | @pytest.mark.asyncio
 41 | async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch):
 42 |     clients = _make_clients()
 43 | 
 44 |     episode_one = _make_episode('1')
 45 |     episode_two = _make_episode('2')
 46 | 
 47 |     extracted_one = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
 48 |     extracted_two = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
 49 | 
 50 |     canonical = extracted_one
 51 | 
 52 |     call_queue = deque()
 53 | 
 54 |     async def fake_resolve(
 55 |         clients_arg,
 56 |         nodes_arg,
 57 |         episode_arg,
 58 |         previous_episodes_arg,
 59 |         entity_types_arg,
 60 |         existing_nodes_override=None,
 61 |     ):
 62 |         call_queue.append(existing_nodes_override)
 63 | 
 64 |         if nodes_arg == [extracted_one]:
 65 |             return [canonical], {canonical.uuid: canonical.uuid}, []
 66 | 
 67 |         assert nodes_arg == [extracted_two]
 68 |         assert existing_nodes_override is None
 69 | 
 70 |         return [canonical], {extracted_two.uuid: canonical.uuid}, [(extracted_two, canonical)]
 71 | 
 72 |     monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)
 73 | 
 74 |     nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
 75 |         clients,
 76 |         [[extracted_one], [extracted_two]],
 77 |         [(episode_one, []), (episode_two, [])],
 78 |     )
 79 | 
 80 |     assert len(call_queue) == 2
 81 |     assert call_queue[0] is None
 82 |     assert call_queue[1] is None
 83 | 
 84 |     assert nodes_by_episode[episode_one.uuid] == [canonical]
 85 |     assert nodes_by_episode[episode_two.uuid] == [canonical]
 86 |     assert compressed_map.get(extracted_two.uuid) == canonical.uuid
 87 | 
 88 | 
 89 | @pytest.mark.asyncio
 90 | async def test_dedupe_nodes_bulk_handles_empty_batch(monkeypatch):
 91 |     clients = _make_clients()
 92 | 
 93 |     resolve_mock = AsyncMock()
 94 |     monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)
 95 | 
 96 |     nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
 97 |         clients,
 98 |         [],
 99 |         [],
100 |     )
101 | 
102 |     assert nodes_by_episode == {}
103 |     assert compressed_map == {}
104 |     resolve_mock.assert_not_awaited()
105 | 
106 | 
107 | @pytest.mark.asyncio
108 | async def test_dedupe_nodes_bulk_single_episode(monkeypatch):
109 |     clients = _make_clients()
110 | 
111 |     episode = _make_episode('solo')
112 |     extracted = EntityNode(name='Solo', group_id='group', labels=['Entity'])
113 | 
114 |     resolve_mock = AsyncMock(return_value=([extracted], {extracted.uuid: extracted.uuid}, []))
115 |     monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)
116 | 
117 |     nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
118 |         clients,
119 |         [[extracted]],
120 |         [(episode, [])],
121 |     )
122 | 
123 |     assert nodes_by_episode == {episode.uuid: [extracted]}
124 |     assert compressed_map == {extracted.uuid: extracted.uuid}
125 |     resolve_mock.assert_awaited_once()
126 | 
127 | 
128 | @pytest.mark.asyncio
129 | async def test_dedupe_nodes_bulk_uuid_map_respects_direction(monkeypatch):
130 |     clients = _make_clients()
131 | 
132 |     episode_one = _make_episode('one')
133 |     episode_two = _make_episode('two')
134 | 
135 |     extracted_one = EntityNode(uuid='b-uuid', name='Edge Case', group_id='group', labels=['Entity'])
136 |     extracted_two = EntityNode(uuid='a-uuid', name='Edge Case', group_id='group', labels=['Entity'])
137 | 
138 |     canonical = extracted_one
139 |     alias = extracted_two
140 | 
141 |     async def fake_resolve(
142 |         clients_arg,
143 |         nodes_arg,
144 |         episode_arg,
145 |         previous_episodes_arg,
146 |         entity_types_arg,
147 |         existing_nodes_override=None,
148 |     ):
149 |         if nodes_arg == [extracted_one]:
150 |             return [canonical], {canonical.uuid: canonical.uuid}, []
151 |         assert nodes_arg == [extracted_two]
152 |         return [canonical], {alias.uuid: canonical.uuid}, [(alias, canonical)]
153 | 
154 |     monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)
155 | 
156 |     nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
157 |         clients,
158 |         [[extracted_one], [extracted_two]],
159 |         [(episode_one, []), (episode_two, [])],
160 |     )
161 | 
162 |     assert nodes_by_episode[episode_one.uuid] == [canonical]
163 |     assert nodes_by_episode[episode_two.uuid] == [canonical]
164 |     assert compressed_map.get(alias.uuid) == canonical.uuid
165 | 
166 | 
167 | @pytest.mark.asyncio
168 | async def test_dedupe_nodes_bulk_missing_canonical_falls_back(monkeypatch, caplog):
169 |     clients = _make_clients()
170 | 
171 |     episode = _make_episode('missing')
172 |     extracted = EntityNode(name='Fallback', group_id='group', labels=['Entity'])
173 | 
174 |     resolve_mock = AsyncMock(return_value=([extracted], {extracted.uuid: 'missing-canonical'}, []))
175 |     monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)
176 | 
177 |     with caplog.at_level('WARNING'):
178 |         nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
179 |             clients,
180 |             [[extracted]],
181 |             [(episode, [])],
182 |         )
183 | 
184 |     assert nodes_by_episode[episode.uuid] == [extracted]
185 |     assert compressed_map.get(extracted.uuid) == 'missing-canonical'
186 |     assert any('Canonical node missing' in rec.message for rec in caplog.records)
187 | 
188 | 
189 | def test_build_directed_uuid_map_empty():
190 |     assert bulk_utils._build_directed_uuid_map([]) == {}
191 | 
192 | 
193 | def test_build_directed_uuid_map_chain():
194 |     mapping = bulk_utils._build_directed_uuid_map(
195 |         [
196 |             ('a', 'b'),
197 |             ('b', 'c'),
198 |         ]
199 |     )
200 | 
201 |     assert mapping['a'] == 'c'
202 |     assert mapping['b'] == 'c'
203 |     assert mapping['c'] == 'c'
204 | 
205 | 
206 | def test_build_directed_uuid_map_preserves_direction():
207 |     mapping = bulk_utils._build_directed_uuid_map(
208 |         [
209 |             ('alias', 'canonical'),
210 |         ]
211 |     )
212 | 
213 |     assert mapping['alias'] == 'canonical'
214 |     assert mapping['canonical'] == 'canonical'
215 | 
216 | 
217 | def test_resolve_edge_pointers_updates_sources():
218 |     created_at = utc_now()
219 |     edge = EntityEdge(
220 |         name='knows',
221 |         fact='fact',
222 |         group_id='group',
223 |         source_node_uuid='alias',
224 |         target_node_uuid='target',
225 |         created_at=created_at,
226 |     )
227 | 
228 |     bulk_utils.resolve_edge_pointers([edge], {'alias': 'canonical'})
229 | 
230 |     assert edge.source_node_uuid == 'canonical'
231 |     assert edge.target_node_uuid == 'target'
232 | 
233 | 
234 | @pytest.mark.asyncio
235 | async def test_dedupe_edges_bulk_deduplicates_within_episode(monkeypatch):
236 |     """Test that dedupe_edges_bulk correctly compares edges within the same episode.
237 | 
238 |     This test verifies the fix that removed the `if i == j: continue` check,
239 |     which was preventing edges from the same episode from being compared against each other.
240 |     """
241 |     clients = _make_clients()
242 | 
243 |     # Track which edges are compared
244 |     comparisons_made = []
245 | 
246 |     # Create mock embedder that sets embedding values
247 |     async def mock_create_embeddings(embedder, edges):
248 |         for edge in edges:
249 |             edge.fact_embedding = [0.1, 0.2, 0.3]
250 | 
251 |     monkeypatch.setattr(bulk_utils, 'create_entity_edge_embeddings', mock_create_embeddings)
252 | 
253 |     # Mock resolve_extracted_edge to track comparisons and mark duplicates
254 |     async def mock_resolve_extracted_edge(
255 |         llm_client,
256 |         extracted_edge,
257 |         related_edges,
258 |         existing_edges,
259 |         episode,
260 |         edge_type_candidates=None,
261 |         custom_edge_type_names=None,
262 |     ):
263 |         # Track that this edge was compared against the related_edges
264 |         comparisons_made.append((extracted_edge.uuid, [r.uuid for r in related_edges]))
265 | 
266 |         # If there are related edges with same source/target/fact, mark as duplicate
267 |         for related in related_edges:
268 |             if (
269 |                 related.uuid != extracted_edge.uuid  # Can't be duplicate of self
270 |                 and related.source_node_uuid == extracted_edge.source_node_uuid
271 |                 and related.target_node_uuid == extracted_edge.target_node_uuid
272 |                 and related.fact.strip().lower() == extracted_edge.fact.strip().lower()
273 |             ):
274 |                 # Return the related edge and mark extracted_edge as duplicate
275 |                 return related, [], [related]
276 |         # Otherwise return the extracted edge as-is
277 |         return extracted_edge, [], []
278 | 
279 |     monkeypatch.setattr(bulk_utils, 'resolve_extracted_edge', mock_resolve_extracted_edge)
280 | 
281 |     episode = _make_episode('1')
282 |     source_uuid = 'source-uuid'
283 |     target_uuid = 'target-uuid'
284 | 
285 |     # Create 3 identical edges within the same episode
286 |     edge1 = EntityEdge(
287 |         name='recommends',
288 |         fact='assistant recommends yoga poses',
289 |         group_id='group',
290 |         source_node_uuid=source_uuid,
291 |         target_node_uuid=target_uuid,
292 |         created_at=utc_now(),
293 |         episodes=[episode.uuid],
294 |     )
295 |     edge2 = EntityEdge(
296 |         name='recommends',
297 |         fact='assistant recommends yoga poses',
298 |         group_id='group',
299 |         source_node_uuid=source_uuid,
300 |         target_node_uuid=target_uuid,
301 |         created_at=utc_now(),
302 |         episodes=[episode.uuid],
303 |     )
304 |     edge3 = EntityEdge(
305 |         name='recommends',
306 |         fact='assistant recommends yoga poses',
307 |         group_id='group',
308 |         source_node_uuid=source_uuid,
309 |         target_node_uuid=target_uuid,
310 |         created_at=utc_now(),
311 |         episodes=[episode.uuid],
312 |     )
313 | 
314 |     await bulk_utils.dedupe_edges_bulk(
315 |         clients,
316 |         [[edge1, edge2, edge3]],
317 |         [(episode, [])],
318 |         [],
319 |         {},
320 |         {},
321 |     )
322 | 
323 |     # Verify that edges were compared against each other (within same episode)
324 |     # Each edge should have been compared against all 3 edges (including itself, which gets filtered)
325 |     assert len(comparisons_made) == 3
326 |     for _, compared_against in comparisons_made:
327 |         # Each edge should have access to all 3 edges as candidates
328 |         assert len(compared_against) >= 2  # At least 2 others (self is filtered out)
329 | 
```
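
A minimal sketch of the chain-following behavior the `_build_directed_uuid_map` tests above exercise. The helper below is hypothetical and only mirrors the asserted expectations (every uuid resolves to the end of its alias chain, and direction is preserved); it is not the actual `bulk_utils` implementation.

```python
def build_directed_uuid_map_sketch(pairs: list[tuple[str, str]]) -> dict[str, str]:
    # Record each alias -> canonical pair, making sure canonicals map to themselves.
    mapping: dict[str, str] = {}
    for alias, canonical in pairs:
        mapping[alias] = canonical
        mapping.setdefault(canonical, canonical)

    def resolve(uuid: str) -> str:
        # Follow the chain until it terminates, guarding against cycles.
        seen: set[str] = set()
        while mapping[uuid] != uuid and uuid not in seen:
            seen.add(uuid)
            uuid = mapping[uuid]
        return uuid

    return {uuid: resolve(uuid) for uuid in mapping}


# Matches the chain test above: a -> b -> c collapses so everything points at 'c'.
assert build_directed_uuid_map_sketch([('a', 'b'), ('b', 'c')]) == {'a': 'c', 'b': 'c', 'c': 'c'}
```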

--------------------------------------------------------------------------------
/graphiti_core/utils/maintenance/community_operations.py:
--------------------------------------------------------------------------------

```python
  1 | import asyncio
  2 | import logging
  3 | from collections import defaultdict
  4 | 
  5 | from pydantic import BaseModel
  6 | 
  7 | from graphiti_core.driver.driver import GraphDriver, GraphProvider
  8 | from graphiti_core.edges import CommunityEdge
  9 | from graphiti_core.embedder import EmbedderClient
 10 | from graphiti_core.helpers import semaphore_gather
 11 | from graphiti_core.llm_client import LLMClient
 12 | from graphiti_core.models.nodes.node_db_queries import COMMUNITY_NODE_RETURN
 13 | from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record
 14 | from graphiti_core.prompts import prompt_library
 15 | from graphiti_core.prompts.summarize_nodes import Summary, SummaryDescription
 16 | from graphiti_core.utils.datetime_utils import utc_now
 17 | from graphiti_core.utils.maintenance.edge_operations import build_community_edges
 18 | 
 19 | MAX_COMMUNITY_BUILD_CONCURRENCY = 10
 20 | 
 21 | logger = logging.getLogger(__name__)
 22 | 
 23 | 
 24 | class Neighbor(BaseModel):
 25 |     node_uuid: str
 26 |     edge_count: int
 27 | 
 28 | 
 29 | async def get_community_clusters(
 30 |     driver: GraphDriver, group_ids: list[str] | None
 31 | ) -> list[list[EntityNode]]:
 32 |     community_clusters: list[list[EntityNode]] = []
 33 | 
 34 |     if group_ids is None:
 35 |         group_id_values, _, _ = await driver.execute_query(
 36 |             """
 37 |             MATCH (n:Entity)
 38 |             WHERE n.group_id IS NOT NULL
 39 |             RETURN
 40 |                 collect(DISTINCT n.group_id) AS group_ids
 41 |             """
 42 |         )
 43 | 
 44 |         group_ids = group_id_values[0]['group_ids'] if group_id_values else []
 45 | 
 46 |     for group_id in group_ids:
 47 |         projection: dict[str, list[Neighbor]] = {}
 48 |         nodes = await EntityNode.get_by_group_ids(driver, [group_id])
 49 |         for node in nodes:
 50 |             match_query = """
 51 |                 MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[e:RELATES_TO]-(m: Entity {group_id: $group_id})
 52 |             """
 53 |             if driver.provider == GraphProvider.KUZU:
 54 |                 match_query = """
 55 |                 MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m: Entity {group_id: $group_id})
 56 |                 """
 57 |             records, _, _ = await driver.execute_query(
 58 |                 match_query
 59 |                 + """
 60 |                 WITH count(e) AS count, m.uuid AS uuid
 61 |                 RETURN
 62 |                     uuid,
 63 |                     count
 64 |                 """,
 65 |                 uuid=node.uuid,
 66 |                 group_id=group_id,
 67 |             )
 68 | 
 69 |             projection[node.uuid] = [
 70 |                 Neighbor(node_uuid=record['uuid'], edge_count=record['count']) for record in records
 71 |             ]
 72 | 
 73 |         cluster_uuids = label_propagation(projection)
 74 | 
 75 |         community_clusters.extend(
 76 |             list(
 77 |                 await semaphore_gather(
 78 |                     *[EntityNode.get_by_uuids(driver, cluster) for cluster in cluster_uuids]
 79 |                 )
 80 |             )
 81 |         )
 82 | 
 83 |     return community_clusters
 84 | 
 85 | 
 86 | def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]:
 87 |     # Implement the label propagation community detection algorithm.
 88 |     # 1. Start with each node being assigned its own community
 89 |     # 2. Each node will take on the community of the plurality of its neighbors
 90 |     # 3. Ties are broken by going to the largest community
 91 |     # 4. Continue until no communities change during propagation
 92 | 
 93 |     community_map = {uuid: i for i, uuid in enumerate(projection.keys())}
 94 | 
 95 |     while True:
 96 |         no_change = True
 97 |         new_community_map: dict[str, int] = {}
 98 | 
 99 |         for uuid, neighbors in projection.items():
100 |             curr_community = community_map[uuid]
101 | 
102 |             community_candidates: dict[int, int] = defaultdict(int)
103 |             for neighbor in neighbors:
104 |                 community_candidates[community_map[neighbor.node_uuid]] += neighbor.edge_count
105 |             community_lst = [
106 |                 (count, community) for community, count in community_candidates.items()
107 |             ]
108 | 
109 |             community_lst.sort(reverse=True)
110 |             candidate_rank, community_candidate = community_lst[0] if community_lst else (0, -1)
111 |             if community_candidate != -1 and candidate_rank > 1:
112 |                 new_community = community_candidate
113 |             else:
114 |                 new_community = max(community_candidate, curr_community)
115 | 
116 |             new_community_map[uuid] = new_community
117 | 
118 |             if new_community != curr_community:
119 |                 no_change = False
120 | 
121 |         if no_change:
122 |             break
123 | 
124 |         community_map = new_community_map
125 | 
126 |     community_cluster_map = defaultdict(list)
127 |     for uuid, community in community_map.items():
128 |         community_cluster_map[community].append(uuid)
129 | 
130 |     clusters = [cluster for cluster in community_cluster_map.values()]
131 |     return clusters
132 | 
133 | 
134 | async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -> str:
135 |     # Prepare context for LLM
136 |     context = {
137 |         'node_summaries': [{'summary': summary} for summary in summary_pair],
138 |     }
139 | 
140 |     llm_response = await llm_client.generate_response(
141 |         prompt_library.summarize_nodes.summarize_pair(context),
142 |         response_model=Summary,
143 |         prompt_name='summarize_nodes.summarize_pair',
144 |     )
145 | 
146 |     pair_summary = llm_response.get('summary', '')
147 | 
148 |     return pair_summary
149 | 
150 | 
151 | async def generate_summary_description(llm_client: LLMClient, summary: str) -> str:
152 |     context = {
153 |         'summary': summary,
154 |     }
155 | 
156 |     llm_response = await llm_client.generate_response(
157 |         prompt_library.summarize_nodes.summary_description(context),
158 |         response_model=SummaryDescription,
159 |         prompt_name='summarize_nodes.summary_description',
160 |     )
161 | 
162 |     description = llm_response.get('description', '')
163 | 
164 |     return description
165 | 
166 | 
167 | async def build_community(
168 |     llm_client: LLMClient, community_cluster: list[EntityNode]
169 | ) -> tuple[CommunityNode, list[CommunityEdge]]:
170 |     summaries = [entity.summary for entity in community_cluster]
171 |     length = len(summaries)
172 |     while length > 1:
173 |         odd_one_out: str | None = None
174 |         if length % 2 == 1:
175 |             odd_one_out = summaries.pop()
176 |             length -= 1
177 |         new_summaries: list[str] = list(
178 |             await semaphore_gather(
179 |                 *[
180 |                     summarize_pair(llm_client, (str(left_summary), str(right_summary)))
181 |                     for left_summary, right_summary in zip(
182 |                         summaries[: int(length / 2)], summaries[int(length / 2) :], strict=False
183 |                     )
184 |                 ]
185 |             )
186 |         )
187 |         if odd_one_out is not None:
188 |             new_summaries.append(odd_one_out)
189 |         summaries = new_summaries
190 |         length = len(summaries)
191 | 
192 |     summary = summaries[0]
193 |     name = await generate_summary_description(llm_client, summary)
194 |     now = utc_now()
195 |     community_node = CommunityNode(
196 |         name=name,
197 |         group_id=community_cluster[0].group_id,
198 |         labels=['Community'],
199 |         created_at=now,
200 |         summary=summary,
201 |     )
202 |     community_edges = build_community_edges(community_cluster, community_node, now)
203 | 
204 |     logger.debug((community_node, community_edges))
205 | 
206 |     return community_node, community_edges
207 | 
208 | 
209 | async def build_communities(
210 |     driver: GraphDriver,
211 |     llm_client: LLMClient,
212 |     group_ids: list[str] | None,
213 | ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
214 |     community_clusters = await get_community_clusters(driver, group_ids)
215 | 
216 |     semaphore = asyncio.Semaphore(MAX_COMMUNITY_BUILD_CONCURRENCY)
217 | 
218 |     async def limited_build_community(cluster):
219 |         async with semaphore:
220 |             return await build_community(llm_client, cluster)
221 | 
222 |     communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
223 |         await semaphore_gather(
224 |             *[limited_build_community(cluster) for cluster in community_clusters]
225 |         )
226 |     )
227 | 
228 |     community_nodes: list[CommunityNode] = []
229 |     community_edges: list[CommunityEdge] = []
230 |     for community in communities:
231 |         community_nodes.append(community[0])
232 |         community_edges.extend(community[1])
233 | 
234 |     return community_nodes, community_edges
235 | 
236 | 
237 | async def remove_communities(driver: GraphDriver):
238 |     await driver.execute_query(
239 |         """
240 |         MATCH (c:Community)
241 |         DETACH DELETE c
242 |         """
243 |     )
244 | 
245 | 
246 | async def determine_entity_community(
247 |     driver: GraphDriver, entity: EntityNode
248 | ) -> tuple[CommunityNode | None, bool]:
249 |     # Check if the node is already part of a community
250 |     records, _, _ = await driver.execute_query(
251 |         """
252 |         MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
253 |         RETURN
254 |         """
255 |         + COMMUNITY_NODE_RETURN,
256 |         entity_uuid=entity.uuid,
257 |     )
258 | 
259 |     if len(records) > 0:
260 |         return get_community_node_from_record(records[0]), False
261 | 
262 |     # If the node has no community, add it to the most common (mode) community among surrounding entities
263 |     match_query = """
264 |         MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
265 |     """
266 |     if driver.provider == GraphProvider.KUZU:
267 |         match_query = """
268 |             MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
269 |         """
270 |     records, _, _ = await driver.execute_query(
271 |         match_query
272 |         + """
273 |         RETURN
274 |         """
275 |         + COMMUNITY_NODE_RETURN,
276 |         entity_uuid=entity.uuid,
277 |     )
278 | 
279 |     communities: list[CommunityNode] = [
280 |         get_community_node_from_record(record) for record in records
281 |     ]
282 | 
283 |     community_map: dict[str, int] = defaultdict(int)
284 |     for community in communities:
285 |         community_map[community.uuid] += 1
286 | 
287 |     community_uuid = None
288 |     max_count = 0
289 |     for uuid, count in community_map.items():
290 |         if count > max_count:
291 |             community_uuid = uuid
292 |             max_count = count
293 | 
294 |     if max_count == 0:
295 |         return None, False
296 | 
297 |     for community in communities:
298 |         if community.uuid == community_uuid:
299 |             return community, True
300 | 
301 |     return None, False
302 | 
303 | 
304 | async def update_community(
305 |     driver: GraphDriver,
306 |     llm_client: LLMClient,
307 |     embedder: EmbedderClient,
308 |     entity: EntityNode,
309 | ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
310 |     community, is_new = await determine_entity_community(driver, entity)
311 | 
312 |     if community is None:
313 |         return [], []
314 | 
315 |     new_summary = await summarize_pair(llm_client, (entity.summary, community.summary))
316 |     new_name = await generate_summary_description(llm_client, new_summary)
317 | 
318 |     community.summary = new_summary
319 |     community.name = new_name
320 | 
321 |     community_edges = []
322 |     if is_new:
323 |         community_edge = (build_community_edges([entity], community, utc_now()))[0]
324 |         await community_edge.save(driver)
325 |         community_edges.append(community_edge)
326 | 
327 |     await community.generate_name_embedding(embedder)
328 | 
329 |     await community.save(driver)
330 | 
331 |     return [community], community_edges
332 | 
```
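
As a quick illustration of the `label_propagation` helper above, here is a hedged sketch that runs it on a toy projection. The uuids and edge counts are invented for the example; only the `Neighbor` model and `label_propagation` function come from the module.

```python
from graphiti_core.utils.maintenance.community_operations import Neighbor, label_propagation

# A tightly connected triangle (a, b, c) plus an isolated node d.
projection = {
    'a': [Neighbor(node_uuid='b', edge_count=1), Neighbor(node_uuid='c', edge_count=1)],
    'b': [Neighbor(node_uuid='a', edge_count=1), Neighbor(node_uuid='c', edge_count=1)],
    'c': [Neighbor(node_uuid='a', edge_count=1), Neighbor(node_uuid='b', edge_count=1)],
    'd': [],
}

clusters = label_propagation(projection)
# The triangle converges to a single community while the isolated node keeps its
# own, e.g. [['a', 'b', 'c'], ['d']] (cluster order follows dict insertion order).
print(clusters)
```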

--------------------------------------------------------------------------------
/graphiti_core/prompts/extract_nodes.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | from typing import Any, Protocol, TypedDict
 18 | 
 19 | from pydantic import BaseModel, Field
 20 | 
 21 | from graphiti_core.utils.text_utils import MAX_SUMMARY_CHARS
 22 | 
 23 | from .models import Message, PromptFunction, PromptVersion
 24 | from .prompt_helpers import to_prompt_json
 25 | from .snippets import summary_instructions
 26 | 
 27 | 
 28 | class ExtractedEntity(BaseModel):
 29 |     name: str = Field(..., description='Name of the extracted entity')
 30 |     entity_type_id: int = Field(
 31 |         description='ID of the classified entity type. '
 32 |         'Must be one of the provided entity_type_id integers.',
 33 |     )
 34 | 
 35 | 
 36 | class ExtractedEntities(BaseModel):
 37 |     extracted_entities: list[ExtractedEntity] = Field(..., description='List of extracted entities')
 38 | 
 39 | 
 40 | class MissedEntities(BaseModel):
 41 |     missed_entities: list[str] = Field(..., description="Names of entities that weren't extracted")
 42 | 
 43 | 
 44 | class EntityClassificationTriple(BaseModel):
 45 |     uuid: str = Field(description='UUID of the entity')
 46 |     name: str = Field(description='Name of the entity')
 47 |     entity_type: str | None = Field(
 48 |         default=None,
 49 |         description='Type of the entity. Must be one of the provided types or None',
 50 |     )
 51 | 
 52 | 
 53 | class EntityClassification(BaseModel):
 54 |     entity_classifications: list[EntityClassificationTriple] = Field(
 55 |         ..., description='List of entity classification triples.'
 56 |     )
 57 | 
 58 | 
 59 | class EntitySummary(BaseModel):
 60 |     summary: str = Field(
 61 |         ...,
 62 |         description=f'Summary containing the important information about the entity. Under {MAX_SUMMARY_CHARS} characters.',
 63 |     )
 64 | 
 65 | 
 66 | class Prompt(Protocol):
 67 |     extract_message: PromptVersion
 68 |     extract_json: PromptVersion
 69 |     extract_text: PromptVersion
 70 |     reflexion: PromptVersion
 71 |     classify_nodes: PromptVersion
 72 |     extract_attributes: PromptVersion
 73 |     extract_summary: PromptVersion
 74 | 
 75 | 
 76 | class Versions(TypedDict):
 77 |     extract_message: PromptFunction
 78 |     extract_json: PromptFunction
 79 |     extract_text: PromptFunction
 80 |     reflexion: PromptFunction
 81 |     classify_nodes: PromptFunction
 82 |     extract_attributes: PromptFunction
 83 |     extract_summary: PromptFunction
 84 | 
 85 | 
 86 | def extract_message(context: dict[str, Any]) -> list[Message]:
 87 |     sys_prompt = """You are an AI assistant that extracts entity nodes from conversational messages. 
 88 |     Your primary task is to extract and classify the speaker and other significant entities mentioned in the conversation."""
 89 | 
 90 |     user_prompt = f"""
 91 | <ENTITY TYPES>
 92 | {context['entity_types']}
 93 | </ENTITY TYPES>
 94 | 
 95 | <PREVIOUS MESSAGES>
 96 | {to_prompt_json([ep for ep in context['previous_episodes']])}
 97 | </PREVIOUS MESSAGES>
 98 | 
 99 | <CURRENT MESSAGE>
100 | {context['episode_content']}
101 | </CURRENT MESSAGE>
102 | 
103 | Instructions:
104 | 
105 | You are given a conversation context and a CURRENT MESSAGE. Your task is to extract **entity nodes** mentioned **explicitly or implicitly** in the CURRENT MESSAGE.
106 | Pronoun references such as he/she/they or this/that/those should be disambiguated to the names of the 
107 | referenced entities. Only extract distinct entities from the CURRENT MESSAGE. Don't extract pronouns like you, me, he/she/they, we/us as entities.
108 | 
109 | 1. **Speaker Extraction**: Always extract the speaker (the part before the colon `:` in each dialogue line) as the first entity node.
110 |    - If the speaker is mentioned again in the message, treat both mentions as a **single entity**.
111 | 
112 | 2. **Entity Identification**:
113 |    - Extract all significant entities, concepts, or actors that are **explicitly or implicitly** mentioned in the CURRENT MESSAGE.
114 |    - **Exclude** entities mentioned only in the PREVIOUS MESSAGES (they are for context only).
115 | 
116 | 3. **Entity Classification**:
117 |    - Use the descriptions in ENTITY TYPES to classify each extracted entity.
118 |    - Assign the appropriate `entity_type_id` for each one.
119 | 
120 | 4. **Exclusions**:
121 |    - Do NOT extract entities representing relationships or actions.
122 |    - Do NOT extract dates, times, or other temporal information—these will be handled separately.
123 | 
124 | 5. **Formatting**:
125 |    - Be **explicit and unambiguous** in naming entities (e.g., use full names when available).
126 | 
127 | {context['custom_prompt']}
128 | """
129 |     return [
130 |         Message(role='system', content=sys_prompt),
131 |         Message(role='user', content=user_prompt),
132 |     ]
133 | 
134 | 
135 | def extract_json(context: dict[str, Any]) -> list[Message]:
136 |     sys_prompt = """You are an AI assistant that extracts entity nodes from JSON. 
137 |     Your primary task is to extract and classify relevant entities from JSON files"""
138 | 
139 |     user_prompt = f"""
140 | <ENTITY TYPES>
141 | {context['entity_types']}
142 | </ENTITY TYPES>
143 | 
144 | <SOURCE DESCRIPTION>:
145 | {context['source_description']}
146 | </SOURCE DESCRIPTION>
147 | <JSON>
148 | {context['episode_content']}
149 | </JSON>
150 | 
151 | {context['custom_prompt']}
152 | 
153 | Given the above source description and JSON, extract relevant entities from the provided JSON.
154 | For each entity extracted, also determine its entity type based on the provided ENTITY TYPES and their descriptions.
155 | Indicate the classified entity type by providing its entity_type_id.
156 | 
157 | Guidelines:
158 | 1. Extract all entities that the JSON represents. This will often be something like a "name" or "user" field
159 | 2. Extract all entities mentioned in all other properties throughout the JSON structure
160 | 3. Do NOT extract any properties that contain dates
161 | """
162 |     return [
163 |         Message(role='system', content=sys_prompt),
164 |         Message(role='user', content=user_prompt),
165 |     ]
166 | 
167 | 
168 | def extract_text(context: dict[str, Any]) -> list[Message]:
169 |     sys_prompt = """You are an AI assistant that extracts entity nodes from text. 
170 |     Your primary task is to extract and classify the speaker and other significant entities mentioned in the provided text."""
171 | 
172 |     user_prompt = f"""
173 | <ENTITY TYPES>
174 | {context['entity_types']}
175 | </ENTITY TYPES>
176 | 
177 | <TEXT>
178 | {context['episode_content']}
179 | </TEXT>
180 | 
181 | Given the above text, extract entities from the TEXT that are explicitly or implicitly mentioned.
182 | For each entity extracted, also determine its entity type based on the provided ENTITY TYPES and their descriptions.
183 | Indicate the classified entity type by providing its entity_type_id.
184 | 
185 | {context['custom_prompt']}
186 | 
187 | Guidelines:
188 | 1. Extract significant entities, concepts, or actors mentioned in the TEXT.
189 | 2. Avoid creating nodes for relationships or actions.
190 | 3. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later).
191 | 4. Be as explicit as possible in your node names, using full names and avoiding abbreviations.
192 | """
193 |     return [
194 |         Message(role='system', content=sys_prompt),
195 |         Message(role='user', content=user_prompt),
196 |     ]
197 | 
198 | 
199 | def reflexion(context: dict[str, Any]) -> list[Message]:
200 |     sys_prompt = """You are an AI assistant that determines which entities have not been extracted from the given context"""
201 | 
202 |     user_prompt = f"""
203 | <PREVIOUS MESSAGES>
204 | {to_prompt_json([ep for ep in context['previous_episodes']])}
205 | </PREVIOUS MESSAGES>
206 | <CURRENT MESSAGE>
207 | {context['episode_content']}
208 | </CURRENT MESSAGE>
209 | 
210 | <EXTRACTED ENTITIES>
211 | {context['extracted_entities']}
212 | </EXTRACTED ENTITIES>
213 | 
214 | Given the above previous messages, current message, and list of extracted entities, determine if any entities haven't been
215 | extracted.
216 | """
217 |     return [
218 |         Message(role='system', content=sys_prompt),
219 |         Message(role='user', content=user_prompt),
220 |     ]
221 | 
222 | 
223 | def classify_nodes(context: dict[str, Any]) -> list[Message]:
224 |     sys_prompt = """You are an AI assistant that classifies entity nodes given the context from which they were extracted"""
225 | 
226 |     user_prompt = f"""
227 |     <PREVIOUS MESSAGES>
228 |     {to_prompt_json([ep for ep in context['previous_episodes']])}
229 |     </PREVIOUS MESSAGES>
230 |     <CURRENT MESSAGE>
231 |     {context['episode_content']}
232 |     </CURRENT MESSAGE>
233 | 
234 |     <EXTRACTED ENTITIES>
235 |     {context['extracted_entities']}
236 |     </EXTRACTED ENTITIES>
237 | 
238 |     <ENTITY TYPES>
239 |     {context['entity_types']}
240 |     </ENTITY TYPES>
241 | 
242 |     Given the above conversation, extracted entities, and provided entity types and their descriptions, classify the extracted entities.
243 | 
244 |     Guidelines:
245 |     1. Each entity must have exactly one type
246 |     2. Only use the provided ENTITY TYPES as types; do not use additional types to classify entities.
247 |     3. If none of the provided entity types accurately classify an extracted node, the type should be set to None
248 | """
249 |     return [
250 |         Message(role='system', content=sys_prompt),
251 |         Message(role='user', content=user_prompt),
252 |     ]
253 | 
254 | 
255 | def extract_attributes(context: dict[str, Any]) -> list[Message]:
256 |     return [
257 |         Message(
258 |             role='system',
259 |             content='You are a helpful assistant that extracts entity properties from the provided text.',
260 |         ),
261 |         Message(
262 |             role='user',
263 |             content=f"""
264 |         Given the MESSAGES and the following ENTITY, update any of its attributes based on the information provided
265 |         in MESSAGES. Use the provided attribute descriptions to better understand how each attribute should be determined.
266 | 
267 |         Guidelines:
268 |         1. Do not hallucinate entity property values if they cannot be found in the current context.
269 |         2. Only use the provided MESSAGES and ENTITY to set attribute values.
270 | 
271 |         <MESSAGES>
272 |         {to_prompt_json(context['previous_episodes'])}
273 |         {to_prompt_json(context['episode_content'])}
274 |         </MESSAGES>
275 | 
276 |         <ENTITY>
277 |         {context['node']}
278 |         </ENTITY>
279 |         """,
280 |         ),
281 |     ]
282 | 
283 | 
284 | def extract_summary(context: dict[str, Any]) -> list[Message]:
285 |     return [
286 |         Message(
287 |             role='system',
288 |             content='You are a helpful assistant that extracts entity summaries from the provided text.',
289 |         ),
290 |         Message(
291 |             role='user',
292 |             content=f"""
293 |         Given the MESSAGES and the ENTITY, produce an updated summary that combines relevant information about the entity
294 |         from the MESSAGES with relevant information from the existing summary.
295 | 
296 |         {summary_instructions}
297 | 
298 |         <MESSAGES>
299 |         {to_prompt_json(context['previous_episodes'])}
300 |         {to_prompt_json(context['episode_content'])}
301 |         </MESSAGES>
302 | 
303 |         <ENTITY>
304 |         {context['node']}
305 |         </ENTITY>
306 |         """,
307 |         ),
308 |     ]
309 | 
310 | 
311 | versions: Versions = {
312 |     'extract_message': extract_message,
313 |     'extract_json': extract_json,
314 |     'extract_text': extract_text,
315 |     'reflexion': reflexion,
316 |     'extract_summary': extract_summary,
317 |     'classify_nodes': classify_nodes,
318 |     'extract_attributes': extract_attributes,
319 | }
320 | 
```
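
For reference, the prompt functions above are plain callables that take a context dict and return `Message` objects. The sketch below shows the context keys `extract_message` reads; the values themselves are invented for illustration.

```python
from graphiti_core.prompts.extract_nodes import extract_message

context = {
    'entity_types': '0: Person\n1: Organization',
    'previous_episodes': ['Jess: I started a new job last week.'],
    'episode_content': 'Jess: I joined Acme Corp as a data engineer.',
    'custom_prompt': '',
}

messages = extract_message(context)
for message in messages:
    # Each Message carries a role ('system' or 'user') and the rendered prompt content.
    print(message.role, message.content[:80])
```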

--------------------------------------------------------------------------------
/graphiti_core/models/edges/edge_db_queries.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | from graphiti_core.driver.driver import GraphProvider
 18 | 
 19 | EPISODIC_EDGE_SAVE = """
 20 |     MATCH (episode:Episodic {uuid: $episode_uuid})
 21 |     MATCH (node:Entity {uuid: $entity_uuid})
 22 |     MERGE (episode)-[e:MENTIONS {uuid: $uuid}]->(node)
 23 |     SET
 24 |         e.group_id = $group_id,
 25 |         e.created_at = $created_at
 26 |     RETURN e.uuid AS uuid
 27 | """
 28 | 
 29 | 
 30 | def get_episodic_edge_save_bulk_query(provider: GraphProvider) -> str:
 31 |     if provider == GraphProvider.KUZU:
 32 |         return """
 33 |             MATCH (episode:Episodic {uuid: $source_node_uuid})
 34 |             MATCH (node:Entity {uuid: $target_node_uuid})
 35 |             MERGE (episode)-[e:MENTIONS {uuid: $uuid}]->(node)
 36 |             SET
 37 |                 e.group_id = $group_id,
 38 |                 e.created_at = $created_at
 39 |             RETURN e.uuid AS uuid
 40 |         """
 41 | 
 42 |     return """
 43 |         UNWIND $episodic_edges AS edge
 44 |         MATCH (episode:Episodic {uuid: edge.source_node_uuid})
 45 |         MATCH (node:Entity {uuid: edge.target_node_uuid})
 46 |         MERGE (episode)-[e:MENTIONS {uuid: edge.uuid}]->(node)
 47 |         SET
 48 |             e.group_id = edge.group_id,
 49 |             e.created_at = edge.created_at
 50 |         RETURN e.uuid AS uuid
 51 |     """
 52 | 
 53 | 
 54 | EPISODIC_EDGE_RETURN = """
 55 |     e.uuid AS uuid,
 56 |     e.group_id AS group_id,
 57 |     n.uuid AS source_node_uuid,
 58 |     m.uuid AS target_node_uuid,
 59 |     e.created_at AS created_at
 60 | """
 61 | 
 62 | 
 63 | def get_entity_edge_save_query(provider: GraphProvider, has_aoss: bool = False) -> str:
 64 |     match provider:
 65 |         case GraphProvider.FALKORDB:
 66 |             return """
 67 |                 MATCH (source:Entity {uuid: $edge_data.source_uuid})
 68 |                 MATCH (target:Entity {uuid: $edge_data.target_uuid})
 69 |                 MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
 70 |                 SET e = $edge_data
 71 |                 SET e.fact_embedding = vecf32($edge_data.fact_embedding)
 72 |                 RETURN e.uuid AS uuid
 73 |             """
 74 |         case GraphProvider.NEPTUNE:
 75 |             return """
 76 |                 MATCH (source:Entity {uuid: $edge_data.source_uuid})
 77 |                 MATCH (target:Entity {uuid: $edge_data.target_uuid})
 78 |                 MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
 79 |                 SET e = removeKeyFromMap(removeKeyFromMap($edge_data, "fact_embedding"), "episodes")
 80 |                 SET e.fact_embedding = join([x IN coalesce($edge_data.fact_embedding, []) | toString(x) ], ",")
 81 |                 SET e.episodes = join($edge_data.episodes, ",")
 82 |                 RETURN $edge_data.uuid AS uuid
 83 |             """
 84 |         case GraphProvider.KUZU:
 85 |             return """
 86 |                 MATCH (source:Entity {uuid: $source_uuid})
 87 |                 MATCH (target:Entity {uuid: $target_uuid})
 88 |                 MERGE (source)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(target)
 89 |                 SET
 90 |                     e.group_id = $group_id,
 91 |                     e.created_at = $created_at,
 92 |                     e.name = $name,
 93 |                     e.fact = $fact,
 94 |                     e.fact_embedding = $fact_embedding,
 95 |                     e.episodes = $episodes,
 96 |                     e.expired_at = $expired_at,
 97 |                     e.valid_at = $valid_at,
 98 |                     e.invalid_at = $invalid_at,
 99 |                     e.attributes = $attributes
100 |                 RETURN e.uuid AS uuid
101 |             """
102 |         case _:  # Neo4j
103 |             save_embedding_query = (
104 |                 """WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)"""
105 |                 if not has_aoss
106 |                 else ''
107 |             )
108 |             return (
109 |                 (
110 |                     """
111 |                         MATCH (source:Entity {uuid: $edge_data.source_uuid})
112 |                         MATCH (target:Entity {uuid: $edge_data.target_uuid})
113 |                         MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
114 |                         SET e = $edge_data
115 |                         """
116 |                     + save_embedding_query
117 |                 )
118 |                 + """
119 |                 RETURN e.uuid AS uuid
120 |                 """
121 |             )
122 | 
123 | 
124 | def get_entity_edge_save_bulk_query(provider: GraphProvider, has_aoss: bool = False) -> str:
125 |     match provider:
126 |         case GraphProvider.FALKORDB:
127 |             return """
128 |                 UNWIND $entity_edges AS edge
129 |                 MATCH (source:Entity {uuid: edge.source_node_uuid})
130 |                 MATCH (target:Entity {uuid: edge.target_node_uuid})
131 |                 MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
132 |                 SET r = edge
133 |                 SET r.fact_embedding = vecf32(edge.fact_embedding)
134 |                 WITH r, edge
135 |                 RETURN edge.uuid AS uuid
136 |             """
137 |         case GraphProvider.NEPTUNE:
138 |             return """
139 |                 UNWIND $entity_edges AS edge
140 |                 MATCH (source:Entity {uuid: edge.source_node_uuid})
141 |                 MATCH (target:Entity {uuid: edge.target_node_uuid})
142 |                 MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
143 |                 SET r = removeKeyFromMap(removeKeyFromMap(edge, "fact_embedding"), "episodes")
144 |                 SET r.fact_embedding = join([x IN coalesce(edge.fact_embedding, []) | toString(x) ], ",")
145 |                 SET r.episodes = join(edge.episodes, ",")
146 |                 RETURN edge.uuid AS uuid
147 |             """
148 |         case GraphProvider.KUZU:
149 |             return """
150 |                 MATCH (source:Entity {uuid: $source_node_uuid})
151 |                 MATCH (target:Entity {uuid: $target_node_uuid})
152 |                 MERGE (source)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(target)
153 |                 SET
154 |                     e.group_id = $group_id,
155 |                     e.created_at = $created_at,
156 |                     e.name = $name,
157 |                     e.fact = $fact,
158 |                     e.fact_embedding = $fact_embedding,
159 |                     e.episodes = $episodes,
160 |                     e.expired_at = $expired_at,
161 |                     e.valid_at = $valid_at,
162 |                     e.invalid_at = $invalid_at,
163 |                     e.attributes = $attributes
164 |                 RETURN e.uuid AS uuid
165 |             """
166 |         case _:
167 |             save_embedding_query = (
168 |                 'WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)'
169 |                 if not has_aoss
170 |                 else ''
171 |             )
172 |             return (
173 |                 """
174 |                     UNWIND $entity_edges AS edge
175 |                     MATCH (source:Entity {uuid: edge.source_node_uuid})
176 |                     MATCH (target:Entity {uuid: edge.target_node_uuid})
177 |                     MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
178 |                     SET e = edge
179 |                     """
180 |                 + save_embedding_query
181 |                 + """
182 |                 RETURN edge.uuid AS uuid
183 |             """
184 |             )
185 | 
186 | 
187 | def get_entity_edge_return_query(provider: GraphProvider) -> str:
188 |     # `fact_embedding` is not returned by default and must be manually loaded using `load_fact_embedding()`.
189 | 
190 |     if provider == GraphProvider.NEPTUNE:
191 |         return """
192 |         e.uuid AS uuid,
193 |         n.uuid AS source_node_uuid,
194 |         m.uuid AS target_node_uuid,
195 |         e.group_id AS group_id,
196 |         e.name AS name,
197 |         e.fact AS fact,
198 |         split(e.episodes, ',') AS episodes,
199 |         e.created_at AS created_at,
200 |         e.expired_at AS expired_at,
201 |         e.valid_at AS valid_at,
202 |         e.invalid_at AS invalid_at,
203 |         properties(e) AS attributes
204 |     """
205 | 
206 |     return """
207 |         e.uuid AS uuid,
208 |         n.uuid AS source_node_uuid,
209 |         m.uuid AS target_node_uuid,
210 |         e.group_id AS group_id,
211 |         e.created_at AS created_at,
212 |         e.name AS name,
213 |         e.fact AS fact,
214 |         e.episodes AS episodes,
215 |         e.expired_at AS expired_at,
216 |         e.valid_at AS valid_at,
217 |         e.invalid_at AS invalid_at,
218 |     """ + (
219 |         'e.attributes AS attributes'
220 |         if provider == GraphProvider.KUZU
221 |         else 'properties(e) AS attributes'
222 |     )
223 | 
224 | 
225 | def get_community_edge_save_query(provider: GraphProvider) -> str:
226 |     match provider:
227 |         case GraphProvider.FALKORDB:
228 |             return """
229 |                 MATCH (community:Community {uuid: $community_uuid})
230 |                 MATCH (node {uuid: $entity_uuid})
231 |                 MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
232 |                 SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
233 |                 RETURN e.uuid AS uuid
234 |             """
235 |         case GraphProvider.NEPTUNE:
236 |             return """
237 |                 MATCH (community:Community {uuid: $community_uuid})
238 |                 MATCH (node {uuid: $entity_uuid})
239 |                 WHERE node:Entity OR node:Community
240 |                 MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
241 |                 SET r.uuid= $uuid
242 |                 SET r.group_id= $group_id
243 |                 SET r.created_at= $created_at
244 |                 RETURN r.uuid AS uuid
245 |             """
246 |         case GraphProvider.KUZU:
247 |             return """
248 |                 MATCH (community:Community {uuid: $community_uuid})
249 |                 MATCH (node:Entity {uuid: $entity_uuid})
250 |                 MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
251 |                 SET
252 |                     e.group_id = $group_id,
253 |                     e.created_at = $created_at
254 |                 RETURN e.uuid AS uuid
255 |                 UNION
256 |                 MATCH (community:Community {uuid: $community_uuid})
257 |                 MATCH (node:Community {uuid: $entity_uuid})
258 |                 MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
259 |                 SET
260 |                     e.group_id = $group_id,
261 |                     e.created_at = $created_at
262 |                 RETURN e.uuid AS uuid
263 |             """
264 |         case _:  # Neo4j
265 |             return """
266 |                 MATCH (community:Community {uuid: $community_uuid})
267 |                 MATCH (node:Entity | Community {uuid: $entity_uuid})
268 |                 MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
269 |                 SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
270 |                 RETURN e.uuid AS uuid
271 |             """
272 | 
273 | 
274 | COMMUNITY_EDGE_RETURN = """
275 |     e.uuid AS uuid,
276 |     e.group_id AS group_id,
277 |     n.uuid AS source_node_uuid,
278 |     m.uuid AS target_node_uuid,
279 |     e.created_at AS created_at
280 | """
281 | 
```
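
The query builders above are selected per backend at call time. A small sketch, assuming only the imports shown in this file, of how a caller would pick the provider-specific Cypher (FalkorDB is just an example value):

```python
from graphiti_core.driver.driver import GraphProvider
from graphiti_core.models.edges.edge_db_queries import (
    get_entity_edge_return_query,
    get_entity_edge_save_query,
)

# The returned strings are handed to the driver along with the edge parameters
# that the chosen dialect expects.
save_query = get_entity_edge_save_query(GraphProvider.FALKORDB)
return_clause = get_entity_edge_return_query(GraphProvider.FALKORDB)
print(save_query)
print(return_clause)
```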

--------------------------------------------------------------------------------
/mcp_server/tests/run_tests.py:
--------------------------------------------------------------------------------

```python
  1 | #!/usr/bin/env python3
  2 | """
  3 | Test runner for Graphiti MCP integration tests.
  4 | Provides various test execution modes and reporting options.
  5 | """
  6 | 
  7 | import argparse
  8 | import os
  9 | import sys
 10 | import time
 11 | from pathlib import Path
 12 | 
 13 | import pytest
 14 | from dotenv import load_dotenv
 15 | 
 16 | # Load environment variables from .env file
 17 | env_file = Path(__file__).parent.parent / '.env'
 18 | if env_file.exists():
 19 |     load_dotenv(env_file)
 20 | else:
 21 |     # Try loading from current directory
 22 |     load_dotenv()
 23 | 
 24 | 
 25 | class TestRunner:
 26 |     """Orchestrate test execution with various configurations."""
 27 | 
 28 |     def __init__(self, args):
 29 |         self.args = args
 30 |         self.test_dir = Path(__file__).parent
 31 |         self.results = {}
 32 | 
 33 |     def check_prerequisites(self) -> dict[str, bool]:
 34 |         """Check if required services and dependencies are available."""
 35 |         checks = {}
 36 | 
 37 |         # Check for OpenAI API key if not using mocks
 38 |         if not self.args.mock_llm:
 39 |             api_key = os.environ.get('OPENAI_API_KEY')
 40 |             checks['openai_api_key'] = bool(api_key)
 41 |             if not api_key:
 42 |                 # Check if .env file exists for helpful message
 43 |                 env_path = Path(__file__).parent.parent / '.env'
 44 |                 if not env_path.exists():
 45 |                     checks['openai_api_key_hint'] = (
 46 |                         'Set OPENAI_API_KEY in environment or create mcp_server/.env file'
 47 |                     )
 48 |         else:
 49 |             checks['openai_api_key'] = True
 50 | 
 51 |         # Check database availability based on backend
 52 |         if self.args.database == 'neo4j':
 53 |             checks['neo4j'] = self._check_neo4j()
 54 |         elif self.args.database == 'falkordb':
 55 |             checks['falkordb'] = self._check_falkordb()
 56 | 
 57 |         # Check Python dependencies
 58 |         checks['mcp'] = self._check_python_package('mcp')
 59 |         checks['pytest'] = self._check_python_package('pytest')
 60 |         checks['pytest-asyncio'] = self._check_python_package('pytest-asyncio')
 61 | 
 62 |         return checks
 63 | 
 64 |     def _check_neo4j(self) -> bool:
 65 |         """Check if Neo4j is available."""
 66 |         try:
 67 |             import neo4j
 68 | 
 69 |             # Try to connect
 70 |             uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
 71 |             user = os.environ.get('NEO4J_USER', 'neo4j')
 72 |             password = os.environ.get('NEO4J_PASSWORD', 'graphiti')
 73 | 
 74 |             driver = neo4j.GraphDatabase.driver(uri, auth=(user, password))
 75 |             with driver.session() as session:
 76 |                 session.run('RETURN 1')
 77 |             driver.close()
 78 |             return True
 79 |         except Exception:
 80 |             return False
 81 | 
 82 |     def _check_falkordb(self) -> bool:
 83 |         """Check if FalkorDB is available."""
 84 |         try:
 85 |             import redis
 86 | 
 87 |             uri = os.environ.get('FALKORDB_URI', 'redis://localhost:6379')
 88 |             r = redis.from_url(uri)
 89 |             r.ping()
 90 |             return True
 91 |         except Exception:
 92 |             return False
 93 | 
 94 |     def _check_python_package(self, package: str) -> bool:
 95 |         """Check if a Python package is installed."""
 96 |         try:
 97 |             __import__(package.replace('-', '_'))
 98 |             return True
 99 |         except ImportError:
100 |             return False
101 | 
102 |     def run_test_suite(self, suite: str) -> int:
103 |         """Run a specific test suite."""
104 |         pytest_args = ['-v', '--tb=short']
105 | 
106 |         # Add database marker
107 |         if self.args.database:
108 |             for db in ['neo4j', 'falkordb']:
109 |                 if db != self.args.database:
110 |                     pytest_args.extend(['-m', f'not requires_{db}'])
111 | 
112 |         # Add suite-specific arguments
113 |         if suite == 'unit':
114 |             pytest_args.extend(['-m', 'unit', 'test_*.py'])
115 |         elif suite == 'integration':
116 |             pytest_args.extend(['-m', 'integration or not unit', 'test_*.py'])
117 |         elif suite == 'comprehensive':
118 |             pytest_args.append('test_comprehensive_integration.py')
119 |         elif suite == 'async':
120 |             pytest_args.append('test_async_operations.py')
121 |         elif suite == 'stress':
122 |             pytest_args.extend(['-m', 'slow', 'test_stress_load.py'])
123 |         elif suite == 'smoke':
124 |             # Quick smoke test - just basic operations
125 |             pytest_args.extend(
126 |                 [
127 |                     'test_comprehensive_integration.py::TestCoreOperations::test_server_initialization',
128 |                     'test_comprehensive_integration.py::TestCoreOperations::test_add_text_memory',
129 |                 ]
130 |             )
131 |         elif suite == 'all':
132 |             pytest_args.append('.')
133 |         else:
134 |             pytest_args.append(suite)
135 | 
136 |         # Add coverage if requested
137 |         if self.args.coverage:
138 |             pytest_args.extend(['--cov=../src', '--cov-report=html'])
139 | 
140 |         # Add parallel execution if requested
141 |         if self.args.parallel:
142 |             pytest_args.extend(['-n', str(self.args.parallel)])
143 | 
144 |         # Add verbosity
145 |         if self.args.verbose:
146 |             pytest_args.append('-vv')
147 | 
148 |         # Add markers to skip
149 |         if self.args.skip_slow:
150 |             pytest_args.extend(['-m', 'not slow'])
151 | 
152 |         # Add timeout override
153 |         if self.args.timeout:
154 |             pytest_args.extend(['--timeout', str(self.args.timeout)])
155 | 
156 |         # Add environment variables
157 |         env = os.environ.copy()
158 |         if self.args.mock_llm:
159 |             env['USE_MOCK_LLM'] = 'true'
160 |         if self.args.database:
161 |             env['DATABASE_PROVIDER'] = self.args.database
162 | 
163 |         # Run tests from the test directory
164 |         print(f'Running {suite} tests with pytest args: {" ".join(pytest_args)}')
165 | 
166 |         # Change to test directory to run tests
167 |         original_dir = os.getcwd()
168 |         os.chdir(self.test_dir)
169 | 
170 |         try:
171 |             result = pytest.main(pytest_args)
172 |         finally:
173 |             os.chdir(original_dir)
174 | 
175 |         return result
176 | 
177 |     def run_performance_benchmark(self):
178 |         """Run performance benchmarking suite."""
179 |         print('Running performance benchmarks...')
180 | 
181 |         # Import test modules
182 | 
183 |         # Run performance tests
184 |         result = pytest.main(
185 |             [
186 |                 '-v',
187 |                 'test_comprehensive_integration.py::TestPerformance',
188 |                 'test_async_operations.py::TestAsyncPerformance',
189 |                 '--benchmark-only' if self.args.benchmark_only else '',
190 |             ]
191 |         )
192 | 
193 |         return result
194 | 
195 |     def generate_report(self):
196 |         """Generate test execution report."""
197 |         report = []
198 |         report.append('\n' + '=' * 60)
199 |         report.append('GRAPHITI MCP TEST EXECUTION REPORT')
200 |         report.append('=' * 60)
201 | 
202 |         # Prerequisites check
203 |         checks = self.check_prerequisites()
204 |         report.append('\nPrerequisites:')
205 |         for check, passed in checks.items():
206 |             status = '✅' if passed else '❌'
207 |             report.append(f'  {status} {check}')
208 | 
209 |         # Test configuration
210 |         report.append('\nConfiguration:')
211 |         report.append(f'  Database: {self.args.database}')
212 |         report.append(f'  Mock LLM: {self.args.mock_llm}')
213 |         report.append(f'  Parallel: {self.args.parallel or "No"}')
214 |         report.append(f'  Timeout: {self.args.timeout}s')
215 | 
216 |         # Results summary (if available)
217 |         if self.results:
218 |             report.append('\nResults:')
219 |             for suite, result in self.results.items():
220 |                 status = '✅ Passed' if result == 0 else f'❌ Failed ({result})'
221 |                 report.append(f'  {suite}: {status}')
222 | 
223 |         report.append('=' * 60)
224 |         return '\n'.join(report)
225 | 
226 | 
227 | def main():
228 |     """Main entry point for test runner."""
229 |     parser = argparse.ArgumentParser(
230 |         description='Run Graphiti MCP integration tests',
231 |         formatter_class=argparse.RawDescriptionHelpFormatter,
232 |         epilog="""
233 | Test Suites:
234 |   unit          - Run unit tests only
235 |   integration   - Run integration tests
236 |   comprehensive - Run comprehensive integration test suite
237 |   async         - Run async operation tests
238 |   stress        - Run stress and load tests
239 |   smoke         - Run quick smoke tests
240 |   all           - Run all tests
241 | 
242 | Examples:
243 |   python run_tests.py smoke                    # Quick smoke test
244 |   python run_tests.py integration --parallel 4 # Run integration tests in parallel
245 |   python run_tests.py stress --database neo4j  # Run stress tests with Neo4j
246 |   python run_tests.py all --coverage          # Run all tests with coverage
247 |         """,
248 |     )
249 | 
250 |     parser.add_argument(
251 |         'suite',
252 |         choices=['unit', 'integration', 'comprehensive', 'async', 'stress', 'smoke', 'all'],
253 |         help='Test suite to run',
254 |     )
255 | 
256 |     parser.add_argument(
257 |         '--database',
258 |         choices=['neo4j', 'falkordb'],
259 |         default='falkordb',
260 |         help='Database backend to test (default: falkordb)',
261 |     )
262 | 
263 |     parser.add_argument('--mock-llm', action='store_true', help='Use mock LLM for faster testing')
264 | 
265 |     parser.add_argument(
266 |         '--parallel', type=int, metavar='N', help='Run tests in parallel with N workers'
267 |     )
268 | 
269 |     parser.add_argument('--coverage', action='store_true', help='Generate coverage report')
270 | 
271 |     parser.add_argument('--verbose', action='store_true', help='Verbose output')
272 | 
273 |     parser.add_argument('--skip-slow', action='store_true', help='Skip slow tests')
274 | 
275 |     parser.add_argument(
276 |         '--timeout', type=int, default=300, help='Test timeout in seconds (default: 300)'
277 |     )
278 | 
279 |     parser.add_argument('--benchmark-only', action='store_true', help='Run only benchmark tests')
280 | 
281 |     parser.add_argument(
282 |         '--check-only', action='store_true', help='Only check prerequisites without running tests'
283 |     )
284 | 
285 |     args = parser.parse_args()
286 | 
287 |     # Create test runner
288 |     runner = TestRunner(args)
289 | 
290 |     # Check prerequisites
291 |     if args.check_only:
292 |         print(runner.generate_report())
293 |         sys.exit(0)
294 | 
295 |     # Check if prerequisites are met
296 |     checks = runner.check_prerequisites()
297 |     # Filter out hint keys from validation
298 |     validation_checks = {k: v for k, v in checks.items() if not k.endswith('_hint')}
299 | 
300 |     if not all(validation_checks.values()):
301 |         print('⚠️  Some prerequisites are not met:')
302 |         for check, passed in checks.items():
303 |             if check.endswith('_hint'):
304 |                 continue  # Skip hint entries
305 |             if not passed:
306 |                 print(f'  ❌ {check}')
307 |                 # Show hint if available
308 |                 hint_key = f'{check}_hint'
309 |                 if hint_key in checks:
310 |                     print(f'     💡 {checks[hint_key]}')
311 | 
312 |         if not args.mock_llm and not checks.get('openai_api_key'):
313 |             print('\n💡 Tip: Use --mock-llm to run tests without OpenAI API key')
314 | 
315 |         response = input('\nContinue anyway? (y/N): ')
316 |         if response.lower() != 'y':
317 |             sys.exit(1)
318 | 
319 |     # Run tests
320 |     print(f'\n🚀 Starting test execution: {args.suite}')
321 |     start_time = time.time()
322 | 
323 |     if args.benchmark_only:
324 |         result = runner.run_performance_benchmark()
325 |     else:
326 |         result = runner.run_test_suite(args.suite)
327 | 
328 |     duration = time.time() - start_time
329 | 
330 |     # Store results
331 |     runner.results[args.suite] = result
332 | 
333 |     # Generate and print report
334 |     print(runner.generate_report())
335 |     print(f'\n⏱️  Test execution completed in {duration:.2f} seconds')
336 | 
337 |     # Exit with test result code
338 |     sys.exit(result)
339 | 
340 | 
341 | if __name__ == '__main__':
342 |     main()
343 | 
```
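
A hedged usage sketch for the runner above; the suite name and flags mirror those defined in `main()`, and shelling out via `subprocess` is just one reasonable way to drive it.

```python
import subprocess
import sys

# Equivalent to: python run_tests.py smoke --mock-llm --check-only
# (prints the prerequisites report without executing any tests).
result = subprocess.run(
    [sys.executable, 'run_tests.py', 'smoke', '--mock-llm', '--check-only'],
    check=False,
)
print('runner exit code:', result.returncode)
```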

--------------------------------------------------------------------------------
/graphiti_core/driver/neptune_driver.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import asyncio
 18 | import datetime
 19 | import logging
 20 | from collections.abc import Coroutine
 21 | from typing import Any
 22 | 
 23 | import boto3
 24 | from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
 25 | from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
 26 | 
 27 | from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
 28 | 
 29 | logger = logging.getLogger(__name__)
 30 | DEFAULT_SIZE = 10
 31 | 
 32 | aoss_indices = [
 33 |     {
 34 |         'index_name': 'node_name_and_summary',
 35 |         'body': {
 36 |             'mappings': {
 37 |                 'properties': {
 38 |                     'uuid': {'type': 'keyword'},
 39 |                     'name': {'type': 'text'},
 40 |                     'summary': {'type': 'text'},
 41 |                     'group_id': {'type': 'text'},
 42 |                 }
 43 |             }
 44 |         },
 45 |         'query': {
 46 |             'query': {'multi_match': {'query': '', 'fields': ['name', 'summary', 'group_id']}},
 47 |             'size': DEFAULT_SIZE,
 48 |         },
 49 |     },
 50 |     {
 51 |         'index_name': 'community_name',
 52 |         'body': {
 53 |             'mappings': {
 54 |                 'properties': {
 55 |                     'uuid': {'type': 'keyword'},
 56 |                     'name': {'type': 'text'},
 57 |                     'group_id': {'type': 'text'},
 58 |                 }
 59 |             }
 60 |         },
 61 |         'query': {
 62 |             'query': {'multi_match': {'query': '', 'fields': ['name', 'group_id']}},
 63 |             'size': DEFAULT_SIZE,
 64 |         },
 65 |     },
 66 |     {
 67 |         'index_name': 'episode_content',
 68 |         'body': {
 69 |             'mappings': {
 70 |                 'properties': {
 71 |                     'uuid': {'type': 'keyword'},
 72 |                     'content': {'type': 'text'},
 73 |                     'source': {'type': 'text'},
 74 |                     'source_description': {'type': 'text'},
 75 |                     'group_id': {'type': 'text'},
 76 |                 }
 77 |             }
 78 |         },
 79 |         'query': {
 80 |             'query': {
 81 |                 'multi_match': {
 82 |                     'query': '',
 83 |                     'fields': ['content', 'source', 'source_description', 'group_id'],
 84 |                 }
 85 |             },
 86 |             'size': DEFAULT_SIZE,
 87 |         },
 88 |     },
 89 |     {
 90 |         'index_name': 'edge_name_and_fact',
 91 |         'body': {
 92 |             'mappings': {
 93 |                 'properties': {
 94 |                     'uuid': {'type': 'keyword'},
 95 |                     'name': {'type': 'text'},
 96 |                     'fact': {'type': 'text'},
 97 |                     'group_id': {'type': 'text'},
 98 |                 }
 99 |             }
100 |         },
101 |         'query': {
102 |             'query': {'multi_match': {'query': '', 'fields': ['name', 'fact', 'group_id']}},
103 |             'size': DEFAULT_SIZE,
104 |         },
105 |     },
106 | ]
107 | 
108 | 
109 | class NeptuneDriver(GraphDriver):
110 |     provider: GraphProvider = GraphProvider.NEPTUNE
111 | 
112 |     def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = 443):
113 |         """This initializes a NeptuneDriver for use with Neptune as a backend
114 | 
115 |         Args:
116 |             host (str): The Neptune Database or Neptune Analytics host
117 |             aoss_host (str): The OpenSearch host value
118 |             port (int, optional): The Neptune Database port, ignored for Neptune Analytics. Defaults to 8182.
119 |             aoss_port (int, optional): The OpenSearch port. Defaults to 443.
120 |         """
121 |         if not host:
122 |             raise ValueError('You must provide an endpoint to create a NeptuneDriver')
123 | 
124 |         if host.startswith('neptune-db://'):
125 |             # This is a Neptune Database Cluster
126 |             endpoint = host.replace('neptune-db://', '')
127 |             self.client = NeptuneGraph(endpoint, port)
128 |             logger.debug('Creating Neptune Database session for %s', host)
129 |         elif host.startswith('neptune-graph://'):
130 |             # This is a Neptune Analytics Graph
131 |             graphId = host.replace('neptune-graph://', '')
132 |             self.client = NeptuneAnalyticsGraph(graphId)
133 |             logger.debug('Creating Neptune Graph session for %s', host)
134 |         else:
135 |             raise ValueError(
136 |                 'You must provide an endpoint to create a NeptuneDriver as either neptune-db://<endpoint> or neptune-graph://<graphid>'
137 |             )
138 | 
139 |         if not aoss_host:
140 |             raise ValueError('You must provide an AOSS endpoint to create an OpenSearch driver.')
141 | 
142 |         session = boto3.Session()
143 |         self.aoss_client = OpenSearch(
144 |             hosts=[{'host': aoss_host, 'port': aoss_port}],
145 |             http_auth=Urllib3AWSV4SignerAuth(
146 |                 session.get_credentials(), session.region_name, 'aoss'
147 |             ),
148 |             use_ssl=True,
149 |             verify_certs=True,
150 |             connection_class=Urllib3HttpConnection,
151 |             pool_maxsize=20,
152 |         )
153 | 
154 |     def _sanitize_parameters(self, query, params: dict):
155 |         if isinstance(query, list):
156 |             queries = []
157 |             for q in query:
158 |                 queries.append(self._sanitize_parameters(q, params))
159 |             return queries
160 |         else:
161 |             for k, v in params.items():
162 |                 if isinstance(v, datetime.datetime):
163 |                     params[k] = v.isoformat()
164 |                 elif isinstance(v, list):
165 |                     # Handle lists that might contain datetime objects
166 |                     for i, item in enumerate(v):
167 |                         if isinstance(item, datetime.datetime):
168 |                             v[i] = item.isoformat()
169 |                             query = str(query).replace(f'${k}', f'datetime(${k})')
170 |                         if isinstance(item, dict):
171 |                             query = self._sanitize_parameters(query, v[i])
172 | 
173 |                     # If the list contains datetime objects, we need to wrap each element with datetime()
174 |                     if any(isinstance(item, str) and 'T' in item for item in v):
175 |                         # Create a new list expression with datetime() wrapped around each element
176 |                         datetime_list = (
177 |                             '['
178 |                             + ', '.join(
179 |                                 f'datetime("{item}")'
180 |                                 if isinstance(item, str) and 'T' in item
181 |                                 else repr(item)
182 |                                 for item in v
183 |                             )
184 |                             + ']'
185 |                         )
186 |                         query = str(query).replace(f'${k}', datetime_list)
187 |                 elif isinstance(v, dict):
188 |                     query = self._sanitize_parameters(query, v)
189 |             return query
190 | 
191 |     async def execute_query(
192 |         self, cypher_query_, **kwargs: Any
193 |     ) -> tuple[dict[str, Any], None, None]:
194 |         params = dict(kwargs)
195 |         if isinstance(cypher_query_, list):
196 |             for q in cypher_query_:
197 |                 result, _, _ = self._run_query(q[0], q[1])
198 |             return result, None, None
199 |         else:
200 |             return self._run_query(cypher_query_, params)
201 | 
202 |     def _run_query(self, cypher_query_, params):
203 |         cypher_query_ = str(self._sanitize_parameters(cypher_query_, params))
204 |         try:
205 |             result = self.client.query(cypher_query_, params=params)
206 |         except Exception as e:
207 |             logger.error('Query: %s', cypher_query_)
208 |             logger.error('Parameters: %s', params)
209 |             logger.error('Error executing query: %s', e)
210 |             raise e
211 | 
212 |         return result, None, None
213 | 
214 |     def session(self, database: str | None = None) -> GraphDriverSession:
215 |         return NeptuneDriverSession(driver=self)
216 | 
217 |     async def close(self) -> None:
218 |         return self.client.client.close()
219 | 
220 |     async def _delete_all_data(self) -> Any:
221 |         return await self.execute_query('MATCH (n) DETACH DELETE n')
222 | 
223 |     def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
224 |         return self.delete_all_indexes_impl()
225 | 
226 |     async def delete_all_indexes_impl(self) -> Coroutine[Any, Any, Any]:
227 |         # Neptune has no native graph indexes to manage here; deleting the AOSS indices is the only cleanup needed
228 |         return self.delete_aoss_indices()
229 | 
230 |     async def create_aoss_indices(self):
231 |         for index in aoss_indices:
232 |             index_name = index['index_name']
233 |             client = self.aoss_client
234 |             if not client.indices.exists(index=index_name):
235 |                 client.indices.create(index=index_name, body=index['body'])
236 |         # Sleep for 1 minute to let the index creation complete
237 |         await asyncio.sleep(60)
238 | 
239 |     async def delete_aoss_indices(self):
240 |         for index in aoss_indices:
241 |             index_name = index['index_name']
242 |             client = self.aoss_client
243 |             if client.indices.exists(index=index_name):
244 |                 client.indices.delete(index=index_name)
245 | 
246 |     async def build_indices_and_constraints(self, delete_existing: bool = False):
247 |         # Neptune uses OpenSearch (AOSS) for indexing
248 |         if delete_existing:
249 |             await self.delete_aoss_indices()
250 |         await self.create_aoss_indices()
251 | 
252 |     def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
253 |         for index in aoss_indices:
254 |             if name.lower() == index['index_name']:
255 |                 index['query']['query']['multi_match']['query'] = query_text
256 |                 query = {'size': limit, 'query': index['query']}
257 |                 resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
258 |                 return resp
259 |         return {}
260 | 
261 |     def save_to_aoss(self, name: str, data: list[dict]) -> int:
262 |         for index in aoss_indices:
263 |             if name.lower() == index['index_name']:
264 |                 to_index = []
265 |                 for d in data:
266 |                     item = {'_index': name, '_id': d['uuid']}
267 |                     for p in index['body']['mappings']['properties']:
268 |                         if p in d:
269 |                             item[p] = d[p]
270 |                     to_index.append(item)
271 |                 success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
272 |                 return success
273 | 
274 |         return 0
275 | 
276 | 
277 | class NeptuneDriverSession(GraphDriverSession):
278 |     provider = GraphProvider.NEPTUNE
279 | 
280 |     def __init__(self, driver: NeptuneDriver):  # type: ignore[reportUnknownArgumentType]
281 |         self.driver = driver
282 | 
283 |     async def __aenter__(self):
284 |         return self
285 | 
286 |     async def __aexit__(self, exc_type, exc, tb):
287 |         # No cleanup needed for Neptune, but method must exist
288 |         pass
289 | 
290 |     async def close(self):
291 |         # No explicit close needed for Neptune, but method must exist
292 |         pass
293 | 
294 |     async def execute_write(self, func, *args, **kwargs):
295 |         # Directly await the provided async function with `self` as the transaction/session
296 |         return await func(self, *args, **kwargs)
297 | 
298 |     async def run(self, query: str | list, **kwargs: Any) -> Any:
299 |         if isinstance(query, list):
300 |             res = None
301 |             for q in query:
302 |                 res = await self.driver.execute_query(q, **kwargs)
303 |             return res
304 |         else:
305 |             return await self.driver.execute_query(str(query), **kwargs)
306 | 
```
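
A minimal usage sketch of this driver, assuming AWS credentials are resolvable by the default `boto3` session and that the endpoint strings below are placeholders for a real Neptune cluster (or `neptune-graph://` graph ID) and OpenSearch Serverless collection:

```python
import asyncio

from graphiti_core.driver.neptune_driver import NeptuneDriver


async def main():
    # Placeholder endpoints; host must use the neptune-db:// or neptune-graph:// scheme.
    driver = NeptuneDriver(
        host='neptune-db://my-cluster.cluster-xxxxxxxx.us-east-1.neptune.amazonaws.com',
        aoss_host='xxxxxxxx.us-east-1.aoss.amazonaws.com',
    )

    # Create the AOSS full-text indices (create_aoss_indices sleeps ~60s afterwards).
    await driver.build_indices_and_constraints(delete_existing=False)

    # Cypher against Neptune; execute_query returns a (result, None, None) tuple.
    records, _, _ = await driver.execute_query('MATCH (n) RETURN count(n) AS n_count')
    print(records)

    # Index a document into AOSS, then search it back by index name.
    driver.save_to_aoss(
        'node_name_and_summary',
        [{'uuid': 'node-1', 'name': 'Alice', 'summary': 'Example node', 'group_id': 'g1'}],
    )
    print(driver.run_aoss_query('node_name_and_summary', 'Alice', limit=5))

    await driver.close()


asyncio.run(main())
```

The `(result, None, None)` shape appears intended to match the tuple returned by the other `GraphDriver` implementations, so callers can unpack query results uniformly regardless of backend.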