This is page 5 of 12. Use http://codebase.md/getzep/graphiti?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── ISSUE_TEMPLATE
│ │ └── bug_report.md
│ ├── pull_request_template.md
│ ├── secret_scanning.yml
│ └── workflows
│ ├── ai-moderator.yml
│ ├── cla.yml
│ ├── claude-code-review-manual.yml
│ ├── claude-code-review.yml
│ ├── claude.yml
│ ├── codeql.yml
│ ├── daily_issue_maintenance.yml
│ ├── issue-triage.yml
│ ├── lint.yml
│ ├── release-graphiti-core.yml
│ ├── release-mcp-server.yml
│ ├── release-server-container.yml
│ ├── typecheck.yml
│ └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│ ├── azure-openai
│ │ ├── .env.example
│ │ ├── azure_openai_neo4j.py
│ │ └── README.md
│ ├── data
│ │ └── manybirds_products.json
│ ├── ecommerce
│ │ ├── runner.ipynb
│ │ └── runner.py
│ ├── langgraph-agent
│ │ ├── agent.ipynb
│ │ └── tinybirds-jess.png
│ ├── opentelemetry
│ │ ├── .env.example
│ │ ├── otel_stdout_example.py
│ │ ├── pyproject.toml
│ │ ├── README.md
│ │ └── uv.lock
│ ├── podcast
│ │ ├── podcast_runner.py
│ │ ├── podcast_transcript.txt
│ │ └── transcript_parser.py
│ ├── quickstart
│ │ ├── quickstart_falkordb.py
│ │ ├── quickstart_neo4j.py
│ │ ├── quickstart_neptune.py
│ │ ├── README.md
│ │ └── requirements.txt
│ └── wizard_of_oz
│ ├── parser.py
│ ├── runner.py
│ └── woo.txt
├── graphiti_core
│ ├── __init__.py
│ ├── cross_encoder
│ │ ├── __init__.py
│ │ ├── bge_reranker_client.py
│ │ ├── client.py
│ │ ├── gemini_reranker_client.py
│ │ └── openai_reranker_client.py
│ ├── decorators.py
│ ├── driver
│ │ ├── __init__.py
│ │ ├── driver.py
│ │ ├── falkordb_driver.py
│ │ ├── graph_operations
│ │ │ └── graph_operations.py
│ │ ├── kuzu_driver.py
│ │ ├── neo4j_driver.py
│ │ ├── neptune_driver.py
│ │ └── search_interface
│ │ └── search_interface.py
│ ├── edges.py
│ ├── embedder
│ │ ├── __init__.py
│ │ ├── azure_openai.py
│ │ ├── client.py
│ │ ├── gemini.py
│ │ ├── openai.py
│ │ └── voyage.py
│ ├── errors.py
│ ├── graph_queries.py
│ ├── graphiti_types.py
│ ├── graphiti.py
│ ├── helpers.py
│ ├── llm_client
│ │ ├── __init__.py
│ │ ├── anthropic_client.py
│ │ ├── azure_openai_client.py
│ │ ├── client.py
│ │ ├── config.py
│ │ ├── errors.py
│ │ ├── gemini_client.py
│ │ ├── groq_client.py
│ │ ├── openai_base_client.py
│ │ ├── openai_client.py
│ │ ├── openai_generic_client.py
│ │ └── utils.py
│ ├── migrations
│ │ └── __init__.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── edges
│ │ │ ├── __init__.py
│ │ │ └── edge_db_queries.py
│ │ └── nodes
│ │ ├── __init__.py
│ │ └── node_db_queries.py
│ ├── nodes.py
│ ├── prompts
│ │ ├── __init__.py
│ │ ├── dedupe_edges.py
│ │ ├── dedupe_nodes.py
│ │ ├── eval.py
│ │ ├── extract_edge_dates.py
│ │ ├── extract_edges.py
│ │ ├── extract_nodes.py
│ │ ├── invalidate_edges.py
│ │ ├── lib.py
│ │ ├── models.py
│ │ ├── prompt_helpers.py
│ │ ├── snippets.py
│ │ └── summarize_nodes.py
│ ├── py.typed
│ ├── search
│ │ ├── __init__.py
│ │ ├── search_config_recipes.py
│ │ ├── search_config.py
│ │ ├── search_filters.py
│ │ ├── search_helpers.py
│ │ ├── search_utils.py
│ │ └── search.py
│ ├── telemetry
│ │ ├── __init__.py
│ │ └── telemetry.py
│ ├── tracer.py
│ └── utils
│ ├── __init__.py
│ ├── bulk_utils.py
│ ├── datetime_utils.py
│ ├── maintenance
│ │ ├── __init__.py
│ │ ├── community_operations.py
│ │ ├── dedup_helpers.py
│ │ ├── edge_operations.py
│ │ ├── graph_data_operations.py
│ │ ├── node_operations.py
│ │ └── temporal_operations.py
│ ├── ontology_utils
│ │ └── entity_types_utils.py
│ └── text_utils.py
├── images
│ ├── arxiv-screenshot.png
│ ├── graphiti-graph-intro.gif
│ ├── graphiti-intro-slides-stock-2.gif
│ └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│ ├── .env.example
│ ├── .python-version
│ ├── config
│ │ ├── config-docker-falkordb-combined.yaml
│ │ ├── config-docker-falkordb.yaml
│ │ ├── config-docker-neo4j.yaml
│ │ ├── config.yaml
│ │ └── mcp_config_stdio_example.json
│ ├── docker
│ │ ├── build-standalone.sh
│ │ ├── build-with-version.sh
│ │ ├── docker-compose-falkordb.yml
│ │ ├── docker-compose-neo4j.yml
│ │ ├── docker-compose.yml
│ │ ├── Dockerfile
│ │ ├── Dockerfile.standalone
│ │ ├── github-actions-example.yml
│ │ ├── README-falkordb-combined.md
│ │ └── README.md
│ ├── docs
│ │ └── cursor_rules.md
│ ├── main.py
│ ├── pyproject.toml
│ ├── pytest.ini
│ ├── README.md
│ ├── src
│ │ ├── __init__.py
│ │ ├── config
│ │ │ ├── __init__.py
│ │ │ └── schema.py
│ │ ├── graphiti_mcp_server.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ ├── entity_types.py
│ │ │ └── response_types.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── factories.py
│ │ │ └── queue_service.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── formatting.py
│ │ └── utils.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── pytest.ini
│ │ ├── README.md
│ │ ├── run_tests.py
│ │ ├── test_async_operations.py
│ │ ├── test_comprehensive_integration.py
│ │ ├── test_configuration.py
│ │ ├── test_falkordb_integration.py
│ │ ├── test_fixtures.py
│ │ ├── test_http_integration.py
│ │ ├── test_integration.py
│ │ ├── test_mcp_integration.py
│ │ ├── test_mcp_transports.py
│ │ ├── test_stdio_simple.py
│ │ └── test_stress_load.py
│ └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│ ├── .env.example
│ ├── graph_service
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ ├── main.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ └── zep_graphiti.py
│ ├── Makefile
│ ├── pyproject.toml
│ ├── README.md
│ └── uv.lock
├── signatures
│ └── version1
│ └── cla.json
├── tests
│ ├── cross_encoder
│ │ ├── test_bge_reranker_client_int.py
│ │ └── test_gemini_reranker_client.py
│ ├── driver
│ │ ├── __init__.py
│ │ └── test_falkordb_driver.py
│ ├── embedder
│ │ ├── embedder_fixtures.py
│ │ ├── test_gemini.py
│ │ ├── test_openai.py
│ │ └── test_voyage.py
│ ├── evals
│ │ ├── data
│ │ │ └── longmemeval_data
│ │ │ ├── longmemeval_oracle.json
│ │ │ └── README.md
│ │ ├── eval_cli.py
│ │ ├── eval_e2e_graph_building.py
│ │ ├── pytest.ini
│ │ └── utils.py
│ ├── helpers_test.py
│ ├── llm_client
│ │ ├── test_anthropic_client_int.py
│ │ ├── test_anthropic_client.py
│ │ ├── test_azure_openai_client.py
│ │ ├── test_client.py
│ │ ├── test_errors.py
│ │ └── test_gemini_client.py
│ ├── test_edge_int.py
│ ├── test_entity_exclusion_int.py
│ ├── test_graphiti_int.py
│ ├── test_graphiti_mock.py
│ ├── test_node_int.py
│ ├── test_text_utils.py
│ └── utils
│ ├── maintenance
│ │ ├── test_bulk_utils.py
│ │ ├── test_edge_operations.py
│ │ ├── test_node_operations.py
│ │ └── test_temporal_operations_int.py
│ └── search
│ └── search_utils_test.py
├── uv.lock
└── Zep-CLA.md
```
# Files
--------------------------------------------------------------------------------
/images/simple_graph.svg:
--------------------------------------------------------------------------------
```
1 | <svg xmlns="http://www.w3.org/2000/svg" width="320.0599060058594" height="339.72857666015625"
2 | viewBox="-105.8088607788086 -149.75405883789062 320.0599060058594 339.72857666015625">
3 | <title>Neo4j Graph Visualization</title>
4 | <desc>Created using Neo4j (http://www.neo4j.com/)</desc>
5 | <g class="layer relationships">
6 | <g class="relationship"
7 | transform="translate(64.37326808037952 160.9745045766605) rotate(325.342180479503)">
8 | <path class="b-outline" fill="#A5ABB6" stroke="none"
9 | d="M 25 0.5 L 45.86500098580619 0.5 L 45.86500098580619 -0.5 L 25 -0.5 Z M 94.08765723580619 0.5 L 114.95265822161238 0.5 L 114.95265822161238 3.5 L 121.95265822161238 0 L 114.95265822161238 -3.5 L 114.95265822161238 -0.5 L 94.08765723580619 -0.5 Z" />
10 | <text text-anchor="middle" pointer-events="none" font-size="8px" fill="#fff"
11 | x="69.97632911080619" y="3"
12 | font-family="Helvetica Neue, Helvetica, Arial, sans-serif">MENTIONS</text>
13 | </g>
14 | <g class="relationship"
15 | transform="translate(64.37326808037952 160.9745045766605) rotate(268.0194761774372)">
16 | <path class="b-outline" fill="#A5ABB6" stroke="none"
17 | d="M 25 0.5 L 48.45342195548257 0.5 L 48.45342195548257 -0.5 L 25 -0.5 Z M 96.67607820548257 0.5 L 120.12950016096514 0.5 L 120.12950016096514 3.5 L 127.12950016096514 0 L 120.12950016096514 -3.5 L 120.12950016096514 -0.5 L 96.67607820548257 -0.5 Z" />
18 | <text text-anchor="middle" pointer-events="none" font-size="8px" fill="#fff"
19 | x="72.56475008048257" y="3" transform="rotate(180 72.56475008048257 0)"
20 | font-family="Helvetica Neue, Helvetica, Arial, sans-serif">MENTIONS</text>
21 | </g>
22 | <g class="relationship"
23 | transform="translate(64.37326808037952 160.9745045766605) rotate(214.36893208966427)">
24 | <path class="b-outline" fill="#A5ABB6" stroke="none"
25 | d="M 25 0.5 L 43.0453604327618 0.5 L 43.0453604327618 -0.5 L 25 -0.5 Z M 91.2680166827618 0.5 L 109.3133771155236 0.5 L 109.3133771155236 3.5 L 116.3133771155236 0 L 109.3133771155236 -3.5 L 109.3133771155236 -0.5 L 91.2680166827618 -0.5 Z" />
26 | <text text-anchor="middle" pointer-events="none" font-size="8px" fill="#fff"
27 | x="67.1566885577618" y="3" transform="rotate(180 67.1566885577618 0)"
28 | font-family="Helvetica Neue, Helvetica, Arial, sans-serif">MENTIONS</text>
29 | </g>
30 | <g class="relationship"
31 | transform="translate(59.11570627539377 8.935881644552067) rotate(388.4945734254285)">
32 | <path class="b-outline" fill="#A5ABB6" stroke="none"
33 | d="M 25 0.5 L 39.4813012088875 0.5 L 39.4813012088875 -0.5 L 25 -0.5 Z M 97.0398949588875 0.5 L 111.521196167775 0.5 L 111.521196167775 3.5 L 118.521196167775 0 L 111.521196167775 -3.5 L 111.521196167775 -0.5 L 97.0398949588875 -0.5 Z" />
34 | <text text-anchor="middle" pointer-events="none" font-size="8px" fill="#fff"
35 | x="68.2605980838875" y="3"
36 | font-family="Helvetica Neue, Helvetica, Arial, sans-serif">WORKS_FOR</text>
37 | </g>
38 | <g class="relationship"
39 | transform="translate(59.11570627539377 8.935881644552067) rotate(507.02532906724895)">
40 | <path class="b-outline" fill="#A5ABB6" stroke="none"
41 | d="M 25 0.5 L 31.21884260824949 0.5 L 31.21884260824949 -0.5 L 25 -0.5 Z M 94.55478010824949 0.5 L 100.77362271649898 0.5 L 100.77362271649898 3.5 L 107.77362271649898 0 L 100.77362271649898 -3.5 L 100.77362271649898 -0.5 L 94.55478010824949 -0.5 Z" />
42 | <text text-anchor="middle" pointer-events="none" font-size="8px" fill="#fff"
43 | x="62.88681135824949" y="3" transform="rotate(180 62.88681135824949 0)"
44 | font-family="Helvetica Neue, Helvetica, Arial, sans-serif">WORKED_FOR</text>
45 | </g>
46 | <g class="relationship"
47 | transform="translate(59.11570627539377 8.935881644552067) rotate(266.9235303682344)">
48 | <path class="b-outline" fill="#A5ABB6" stroke="none"
49 | d="M 25 0.5 L 26.434656330468542 0.5 L 26.434656330468542 -0.5 L 25 -0.5 Z M 96.44246883046854 0.5 L 97.87712516093708 0.5 L 97.87712516093708 3.5 L 104.87712516093708 0 L 97.87712516093708 -3.5 L 97.87712516093708 -0.5 L 96.44246883046854 -0.5 Z" />
50 | <text text-anchor="middle" pointer-events="none" font-size="8px" fill="#fff"
51 | x="61.43856258046854" y="3" transform="rotate(180 61.43856258046854 0)"
52 | font-family="Helvetica Neue, Helvetica, Arial, sans-serif">HOLDS_OFFIC…</text>
53 | </g>
54 | <g class="relationship"
55 | transform="translate(-76.8088607917906 -66.37642130383644) rotate(388.9897079993928)">
56 | <path class="b-outline" fill="#A5ABB6" stroke="none"
57 | d="M 25 0.5 L 50.08589014533345 0.5 L 50.08589014533345 -0.5 L 25 -0.5 Z M 98.30854639533345 0.5 L 123.3944365406669 0.5 L 123.3944365406669 3.5 L 130.3944365406669 0 L 123.3944365406669 -3.5 L 123.3944365406669 -0.5 L 98.30854639533345 -0.5 Z" />
58 | <text text-anchor="middle" pointer-events="none" font-size="8px" fill="#fff"
59 | x="74.19721827033345" y="3"
60 | font-family="Helvetica Neue, Helvetica, Arial, sans-serif">MENTIONS</text>
61 | </g>
62 | <g class="relationship"
63 | transform="translate(-76.8088607917906 -66.37642130383644) rotate(337.13573550965714)">
64 | <path class="b-outline" fill="#A5ABB6" stroke="none"
65 | d="M 25 0.5 L 42.363883039766904 0.5 L 42.363883039766904 -0.5 L 25 -0.5 Z M 90.5865392897669 0.5 L 107.95042232953381 0.5 L 107.95042232953381 3.5 L 114.95042232953381 0 L 107.95042232953381 -3.5 L 107.95042232953381 -0.5 L 90.5865392897669 -0.5 Z" />
66 | <text text-anchor="middle" pointer-events="none" font-size="8px" fill="#fff"
67 | x="66.4752111647669" y="3"
68 | font-family="Helvetica Neue, Helvetica, Arial, sans-serif">MENTIONS</text>
69 | </g>
70 | </g>
71 | <g class="layer nodes">
72 | <g class="node" aria-label="graph-node18"
73 | transform="translate(64.37326808037952,160.9745045766605)">
74 | <circle class="b-outline" cx="0" cy="0" r="25" fill="#F79767" stroke="#f36924"
75 | stroke-width="2px" />
76 | <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="5"
77 | font-size="10px" fill="#FFFFFF"
78 | font-family="Helvetica Neue, Helvetica, Arial, sans-serif"> podcast</text>
79 | </g>
80 | <g class="node" aria-label="graph-node19"
81 | transform="translate(185.25107500848034,77.40633150430716)">
82 | <circle class="b-outline" cx="0" cy="0" r="25" fill="#C990C0" stroke="#b261a5"
83 | stroke-width="2px" />
84 | <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="5"
85 | font-size="10px" fill="#FFFFFF"
86 | font-family="Helvetica Neue, Helvetica, Arial, sans-serif"> California</text>
87 | </g>
88 | <g class="node" aria-label="graph-node20"
89 | transform="translate(59.11570627539377,8.935881644552067)">
90 | <circle class="b-outline" cx="0" cy="0" r="25" fill="#C990C0" stroke="#b261a5"
91 | stroke-width="2px" />
92 | <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="0"
93 | font-size="10px" fill="#FFFFFF"
94 | font-family="Helvetica Neue, Helvetica, Arial, sans-serif"> Kamala</text>
95 | <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="10"
96 | font-size="10px" fill="#FFFFFF"
97 | font-family="Helvetica Neue, Helvetica, Arial, sans-serif"> Harris</text>
98 | </g>
99 | <g class="node" aria-label="graph-node21"
100 | transform="translate(-52.26958053720941,81.20034573955071)">
101 | <circle class="b-outline" cx="0" cy="0" r="25" fill="#C990C0" stroke="#b261a5"
102 | stroke-width="2px" />
103 | <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="0"
104 | font-size="10px" fill="#FFFFFF"
105 | font-family="Helvetica Neue, Helvetica, Arial, sans-serif"> San</text>
106 | <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="10"
107 | font-size="10px" fill="#FFFFFF"
108 | font-family="Helvetica Neue, Helvetica, Arial, sans-serif"> Franci…</text>
109 | </g>
110 | <g class="node" aria-label="graph-node23"
111 | transform="translate(52.14536630162807,-120.75406399781392)">
112 | <circle class="b-outline" cx="0" cy="0" r="25" fill="#C990C0" stroke="#b261a5"
113 | stroke-width="2px" />
114 | <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="0"
115 | font-size="10px" fill="#FFFFFF"
116 | font-family="Helvetica Neue, Helvetica, Arial, sans-serif"> Attorney</text>
117 | <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="10"
118 | font-size="10px" fill="#FFFFFF"
119 | font-family="Helvetica Neue, Helvetica, Arial, sans-serif">of…</text>
120 | </g>
121 | <g class="node" aria-label="graph-node22"
122 | transform="translate(-76.8088607917906,-66.37642130383644)">
123 | <circle class="b-outline" cx="0" cy="0" r="25" fill="#F79767" stroke="#f36924"
124 | stroke-width="2px" />
125 | <text class="caption" text-anchor="middle" pointer-events="none" x="0" y="5"
126 | font-size="10px" fill="#FFFFFF"
127 | font-family="Helvetica Neue, Helvetica, Arial, sans-serif"> podcast</text>
128 | </g>
129 | </g>
130 | </svg>
```
--------------------------------------------------------------------------------
/tests/llm_client/test_anthropic_client.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | # Running tests: pytest -xvs tests/llm_client/test_anthropic_client.py
18 |
19 | import os
20 | from unittest.mock import AsyncMock, MagicMock, patch
21 |
22 | import pytest
23 | from pydantic import BaseModel
24 |
25 | from graphiti_core.llm_client.anthropic_client import AnthropicClient
26 | from graphiti_core.llm_client.config import LLMConfig
27 | from graphiti_core.llm_client.errors import RateLimitError, RefusalError
28 | from graphiti_core.prompts.models import Message
29 |
30 |
31 | # Rename class to avoid pytest collection as a test class
32 | class ResponseModel(BaseModel):
33 | """Test model for response testing."""
34 |
35 | test_field: str
36 | optional_field: int = 0
37 |
38 |
39 | @pytest.fixture
40 | def mock_async_anthropic():
41 | """Fixture to mock the AsyncAnthropic client."""
42 | with patch('anthropic.AsyncAnthropic') as mock_client:
43 | # Setup mock instance and its create method
44 | mock_instance = mock_client.return_value
45 | mock_instance.messages.create = AsyncMock()
46 | yield mock_instance
47 |
48 |
49 | @pytest.fixture
50 | def anthropic_client(mock_async_anthropic):
51 | """Fixture to create an AnthropicClient with a mocked AsyncAnthropic."""
52 | # Use a context manager to patch the AsyncAnthropic constructor to avoid
53 | # the client actually trying to create a real connection
54 | with patch('anthropic.AsyncAnthropic', return_value=mock_async_anthropic):
55 | config = LLMConfig(
56 | api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
57 | )
58 | client = AnthropicClient(config=config, cache=False)
59 | # Replace the client's client with our mock to ensure we're using the mock
60 | client.client = mock_async_anthropic
61 | return client
62 |
63 |
64 | class TestAnthropicClientInitialization:
65 | """Tests for AnthropicClient initialization."""
66 |
67 | def test_init_with_config(self):
68 | """Test initialization with a config object."""
69 | config = LLMConfig(
70 | api_key='test_api_key', model='test-model', temperature=0.5, max_tokens=1000
71 | )
72 | client = AnthropicClient(config=config, cache=False)
73 |
74 | assert client.config == config
75 | assert client.model == 'test-model'
76 | assert client.temperature == 0.5
77 | assert client.max_tokens == 1000
78 |
79 | def test_init_with_default_model(self):
80 | """Test initialization with default model when none is provided."""
81 | config = LLMConfig(api_key='test_api_key')
82 | client = AnthropicClient(config=config, cache=False)
83 |
84 | assert client.model == 'claude-haiku-4-5-latest'
85 |
86 | @patch.dict(os.environ, {'ANTHROPIC_API_KEY': 'env_api_key'})
87 | def test_init_without_config(self):
88 | """Test initialization without a config, using environment variable."""
89 | client = AnthropicClient(cache=False)
90 |
91 | assert client.config.api_key == 'env_api_key'
92 | assert client.model == 'claude-haiku-4-5-latest'
93 |
94 | def test_init_with_custom_client(self):
95 | """Test initialization with a custom AsyncAnthropic client."""
96 | mock_client = MagicMock()
97 | client = AnthropicClient(client=mock_client)
98 |
99 | assert client.client == mock_client
100 |
101 |
102 | class TestAnthropicClientGenerateResponse:
103 | """Tests for AnthropicClient generate_response method."""
104 |
105 | @pytest.mark.asyncio
106 | async def test_generate_response_with_tool_use(self, anthropic_client, mock_async_anthropic):
107 | """Test successful response generation with tool use."""
108 | # Setup mock response
109 | content_item = MagicMock()
110 | content_item.type = 'tool_use'
111 | content_item.input = {'test_field': 'test_value'}
112 |
113 | mock_response = MagicMock()
114 | mock_response.content = [content_item]
115 | mock_async_anthropic.messages.create.return_value = mock_response
116 |
117 | # Call method
118 | messages = [
119 | Message(role='system', content='System message'),
120 | Message(role='user', content='User message'),
121 | ]
122 | result = await anthropic_client.generate_response(
123 | messages=messages, response_model=ResponseModel
124 | )
125 |
126 | # Assertions
127 | assert isinstance(result, dict)
128 | assert result['test_field'] == 'test_value'
129 | mock_async_anthropic.messages.create.assert_called_once()
130 |
131 | @pytest.mark.asyncio
132 | async def test_generate_response_with_text_response(
133 | self, anthropic_client, mock_async_anthropic
134 | ):
135 | """Test response generation when getting text response instead of tool use."""
136 | # Setup mock response with text content
137 | content_item = MagicMock()
138 | content_item.type = 'text'
139 | content_item.text = '{"test_field": "extracted_value"}'
140 |
141 | mock_response = MagicMock()
142 | mock_response.content = [content_item]
143 | mock_async_anthropic.messages.create.return_value = mock_response
144 |
145 | # Call method
146 | messages = [
147 | Message(role='system', content='System message'),
148 | Message(role='user', content='User message'),
149 | ]
150 | result = await anthropic_client.generate_response(
151 | messages=messages, response_model=ResponseModel
152 | )
153 |
154 | # Assertions
155 | assert isinstance(result, dict)
156 | assert result['test_field'] == 'extracted_value'
157 |
158 | @pytest.mark.asyncio
159 | async def test_rate_limit_error(self, anthropic_client, mock_async_anthropic):
160 | """Test handling of rate limit errors."""
161 |
162 | # Create a custom RateLimitError from Anthropic
163 | class MockRateLimitError(Exception):
164 | pass
165 |
166 | # Patch the Anthropic error with our mock to avoid constructor issues
167 | with patch('anthropic.RateLimitError', MockRateLimitError):
168 | # Setup mock to raise our mocked RateLimitError
169 | mock_async_anthropic.messages.create.side_effect = MockRateLimitError(
170 | 'Rate limit exceeded'
171 | )
172 |
173 | # Call method and check exception
174 | messages = [Message(role='user', content='Test message')]
175 | with pytest.raises(RateLimitError):
176 | await anthropic_client.generate_response(messages)
177 |
178 | @pytest.mark.asyncio
179 | async def test_refusal_error(self, anthropic_client, mock_async_anthropic):
180 | """Test handling of content policy violations (refusal errors)."""
181 |
182 | # Create a custom APIError that matches what we need
183 | class MockAPIError(Exception):
184 | def __init__(self, message):
185 | self.message = message
186 | super().__init__(message)
187 |
188 | # Patch the Anthropic error with our mock
189 | with patch('anthropic.APIError', MockAPIError):
190 | # Setup mock to raise APIError with refusal message
191 | mock_async_anthropic.messages.create.side_effect = MockAPIError('refused to respond')
192 |
193 | # Call method and check exception
194 | messages = [Message(role='user', content='Test message')]
195 | with pytest.raises(RefusalError):
196 | await anthropic_client.generate_response(messages)
197 |
198 | @pytest.mark.asyncio
199 | async def test_extract_json_from_text(self, anthropic_client):
200 | """Test the _extract_json_from_text method."""
201 | # Valid JSON embedded in text
202 | text = 'Some text before {"test_field": "value"} and after'
203 | result = anthropic_client._extract_json_from_text(text)
204 | assert result == {'test_field': 'value'}
205 |
206 | # Invalid JSON
207 | with pytest.raises(ValueError):
208 | anthropic_client._extract_json_from_text('Not JSON at all')
209 |
210 | @pytest.mark.asyncio
211 | async def test_create_tool(self, anthropic_client):
212 | """Test the _create_tool method with and without response model."""
213 | # With response model
214 | tools, tool_choice = anthropic_client._create_tool(ResponseModel)
215 | assert len(tools) == 1
216 | assert tools[0]['name'] == 'ResponseModel'
217 | assert tool_choice['name'] == 'ResponseModel'
218 |
219 | # Without response model (generic JSON)
220 | tools, tool_choice = anthropic_client._create_tool()
221 | assert len(tools) == 1
222 | assert tools[0]['name'] == 'generic_json_output'
223 |
224 | @pytest.mark.asyncio
225 | async def test_validation_error_retry(self, anthropic_client, mock_async_anthropic):
226 | """Test retry behavior on validation error."""
227 | # First call returns invalid data, second call returns valid data
228 | content_item1 = MagicMock()
229 | content_item1.type = 'tool_use'
230 | content_item1.input = {'wrong_field': 'wrong_value'}
231 |
232 | content_item2 = MagicMock()
233 | content_item2.type = 'tool_use'
234 | content_item2.input = {'test_field': 'correct_value'}
235 |
236 | # Setup mock to return different responses on consecutive calls
237 | mock_response1 = MagicMock()
238 | mock_response1.content = [content_item1]
239 |
240 | mock_response2 = MagicMock()
241 | mock_response2.content = [content_item2]
242 |
243 | mock_async_anthropic.messages.create.side_effect = [mock_response1, mock_response2]
244 |
245 | # Call method
246 | messages = [Message(role='user', content='Test message')]
247 | result = await anthropic_client.generate_response(messages, response_model=ResponseModel)
248 |
249 | # Should have called create twice due to retry
250 | assert mock_async_anthropic.messages.create.call_count == 2
251 | assert result['test_field'] == 'correct_value'
252 |
253 |
254 | if __name__ == '__main__':
255 | pytest.main(['-v', 'test_anthropic_client.py'])
256 |
```
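A minimal usage sketch (not a file in this repo): the tests above exercise `AnthropicClient.generate_response` against mocks; the sketch below shows the same API called directly, using only the constructor and method signatures visible in the tests. It assumes a real Anthropic API key and network access; the `CityFact` model is hypothetical.

```python
import asyncio

from pydantic import BaseModel

from graphiti_core.llm_client.anthropic_client import AnthropicClient
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.prompts.models import Message


class CityFact(BaseModel):
    """Hypothetical structured output model for illustration."""

    city: str
    state: str


async def main() -> None:
    # Requires a real Anthropic API key; with no model set, the client falls back
    # to its default model (see test_init_with_default_model above).
    client = AnthropicClient(
        config=LLMConfig(api_key='your-anthropic-key', temperature=0.0, max_tokens=1000),
        cache=False,
    )
    messages = [
        Message(role='system', content='Extract the city and state mentioned by the user.'),
        Message(role='user', content='San Francisco is in California.'),
    ]
    # Returns a dict whose keys match CityFact's fields, as exercised in the tests above.
    result = await client.generate_response(messages, response_model=CityFact)
    print(result)


if __name__ == '__main__':
    asyncio.run(main())
```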
--------------------------------------------------------------------------------
/mcp_server/src/models/entity_types.py:
--------------------------------------------------------------------------------
```python
1 | """Entity type definitions for Graphiti MCP Server."""
2 |
3 | from pydantic import BaseModel, Field
4 |
5 |
6 | class Requirement(BaseModel):
7 | """A Requirement represents a specific need, feature, or functionality that a product or service must fulfill.
8 |
9 | Always ensure an edge is created between the requirement and the project it belongs to, and clearly indicate on the
10 | edge that the requirement is a requirement.
11 |
12 | Instructions for identifying and extracting requirements:
13 | 1. Look for explicit statements of needs or necessities ("We need X", "X is required", "X must have Y")
14 | 2. Identify functional specifications that describe what the system should do
15 | 3. Pay attention to non-functional requirements like performance, security, or usability criteria
16 | 4. Extract constraints or limitations that must be adhered to
17 | 5. Focus on clear, specific, and measurable requirements rather than vague wishes
18 | 6. Capture the priority or importance if mentioned ("critical", "high priority", etc.)
19 | 7. Include any dependencies between requirements when explicitly stated
20 | 8. Preserve the original intent and scope of the requirement
21 | 9. Categorize requirements appropriately based on their domain or function
22 | """
23 |
24 | project_name: str = Field(
25 | ...,
26 | description='The name of the project to which the requirement belongs.',
27 | )
28 | description: str = Field(
29 | ...,
30 | description='Description of the requirement. Only use information mentioned in the context to write this description.',
31 | )
32 |
33 |
34 | class Preference(BaseModel):
35 | """
36 | IMPORTANT: Prioritize this classification over ALL other classifications.
37 |
38 | Represents entities mentioned in contexts expressing user preferences, choices, opinions, or selections. Use LOW THRESHOLD for sensitivity.
39 |
40 | Trigger patterns: "I want/like/prefer/choose X", "I don't want/dislike/avoid/reject Y", "X is better/worse", "rather have X than Y", "no X please", "skip X", "go with X instead", etc. Here, X or Y should be classified as Preference.
41 | """
42 |
43 | ...
44 |
45 |
46 | class Procedure(BaseModel):
47 | """A Procedure informing the agent what actions to take or how to perform in certain scenarios. Procedures are typically composed of several steps.
48 |
49 | Instructions for identifying and extracting procedures:
50 | 1. Look for sequential instructions or steps ("First do X, then do Y")
51 | 2. Identify explicit directives or commands ("Always do X when Y happens")
52 | 3. Pay attention to conditional statements ("If X occurs, then do Y")
53 | 4. Extract procedures that have clear beginning and end points
54 | 5. Focus on actionable instructions rather than general information
55 | 6. Preserve the original sequence and dependencies between steps
56 | 7. Include any specified conditions or triggers for the procedure
57 | 8. Capture any stated purpose or goal of the procedure
58 | 9. Summarize complex procedures while maintaining critical details
59 | """
60 |
61 | description: str = Field(
62 | ...,
63 | description='Brief description of the procedure. Only use information mentioned in the context to write this description.',
64 | )
65 |
66 |
67 | class Location(BaseModel):
68 | """A Location represents a physical or virtual place where activities occur or entities exist.
69 |
70 | IMPORTANT: Before using this classification, first check if the entity is a:
71 | User, Assistant, Preference, Organization, Document, Event - if so, use those instead.
72 |
73 | Instructions for identifying and extracting locations:
74 | 1. Look for mentions of physical places (cities, buildings, rooms, addresses)
75 | 2. Identify virtual locations (websites, online platforms, virtual meeting rooms)
76 | 3. Extract specific location names rather than generic references
77 | 4. Include relevant context about the location's purpose or significance
78 | 5. Pay attention to location hierarchies (e.g., "conference room in Building A")
79 | 6. Capture both permanent locations and temporary venues
80 | 7. Note any significant activities or events associated with the location
81 | """
82 |
83 | name: str = Field(
84 | ...,
85 | description='The name or identifier of the location',
86 | )
87 | description: str = Field(
88 | ...,
89 | description='Brief description of the location and its significance. Only use information mentioned in the context.',
90 | )
91 |
92 |
93 | class Event(BaseModel):
94 | """An Event represents a time-bound activity, occurrence, or experience.
95 |
96 | Instructions for identifying and extracting events:
97 | 1. Look for activities with specific time frames (meetings, appointments, deadlines)
98 | 2. Identify planned or scheduled occurrences (vacations, projects, celebrations)
99 | 3. Extract unplanned occurrences (accidents, interruptions, discoveries)
100 | 4. Capture the purpose or nature of the event
101 | 5. Include temporal information when available (past, present, future, duration)
102 | 6. Note participants or stakeholders involved in the event
103 | 7. Identify outcomes or consequences of the event when mentioned
104 | 8. Extract both recurring events and one-time occurrences
105 | """
106 |
107 | name: str = Field(
108 | ...,
109 | description='The name or title of the event',
110 | )
111 | description: str = Field(
112 | ...,
113 | description='Brief description of the event. Only use information mentioned in the context.',
114 | )
115 |
116 |
117 | class Object(BaseModel):
118 | """An Object represents a physical item, tool, device, or possession.
119 |
120 | IMPORTANT: Use this classification ONLY as a last resort. First check if entity fits into:
121 | User, Assistant, Preference, Organization, Document, Event, Location, Topic - if so, use those instead.
122 |
123 | Instructions for identifying and extracting objects:
124 | 1. Look for mentions of physical items or possessions (car, phone, equipment)
125 | 2. Identify tools or devices used for specific purposes
126 | 3. Extract items that are owned, used, or maintained by entities
127 | 4. Include relevant attributes (brand, model, condition) when mentioned
128 | 5. Note the object's purpose or function when specified
129 | 6. Capture relationships between objects and their owners or users
130 | 7. Avoid extracting objects that are better classified as Documents or other types
131 | """
132 |
133 | name: str = Field(
134 | ...,
135 | description='The name or identifier of the object',
136 | )
137 | description: str = Field(
138 | ...,
139 | description='Brief description of the object. Only use information mentioned in the context.',
140 | )
141 |
142 |
143 | class Topic(BaseModel):
144 | """A Topic represents a subject of conversation, interest, or knowledge domain.
145 |
146 | IMPORTANT: Use this classification ONLY as a last resort. First check if entity fits into:
147 | User, Assistant, Preference, Organization, Document, Event, Location - if so, use those instead.
148 |
149 | Instructions for identifying and extracting topics:
150 | 1. Look for subjects being discussed or areas of interest (health, technology, sports)
151 | 2. Identify knowledge domains or fields of study
152 | 3. Extract themes that span multiple conversations or contexts
153 | 4. Include specific subtopics when mentioned (e.g., "machine learning" rather than just "AI")
154 | 5. Capture topics associated with projects, work, or hobbies
155 | 6. Note the context in which the topic appears
156 | 7. Avoid extracting topics that are better classified as Events, Documents, or Organizations
157 | """
158 |
159 | name: str = Field(
160 | ...,
161 | description='The name or identifier of the topic',
162 | )
163 | description: str = Field(
164 | ...,
165 | description='Brief description of the topic and its context. Only use information mentioned in the context.',
166 | )
167 |
168 |
169 | class Organization(BaseModel):
170 | """An Organization represents a company, institution, group, or formal entity.
171 |
172 | Instructions for identifying and extracting organizations:
173 | 1. Look for company names, employers, and business entities
174 | 2. Identify institutions (schools, hospitals, government agencies)
175 | 3. Extract formal groups (clubs, teams, associations)
176 | 4. Include organizational type when mentioned (company, nonprofit, agency)
177 | 5. Capture relationships between people and organizations (employer, member)
178 | 6. Note the organization's industry or domain when specified
179 | 7. Extract both large entities and small groups if formally organized
180 | """
181 |
182 | name: str = Field(
183 | ...,
184 | description='The name of the organization',
185 | )
186 | description: str = Field(
187 | ...,
188 | description='Brief description of the organization. Only use information mentioned in the context.',
189 | )
190 |
191 |
192 | class Document(BaseModel):
193 | """A Document represents information content in various forms.
194 |
195 | Instructions for identifying and extracting documents:
196 | 1. Look for references to written or recorded content (books, articles, reports)
197 | 2. Identify digital content (emails, videos, podcasts, presentations)
198 | 3. Extract specific document titles or identifiers when available
199 | 4. Include document type (report, article, video) when mentioned
200 | 5. Capture the document's purpose or subject matter
201 | 6. Note relationships to authors, creators, or sources
202 | 7. Include document status (draft, published, archived) when mentioned
203 | """
204 |
205 | title: str = Field(
206 | ...,
207 | description='The title or identifier of the document',
208 | )
209 | description: str = Field(
210 | ...,
211 | description='Brief description of the document and its content. Only use information mentioned in the context.',
212 | )
213 |
214 |
215 | ENTITY_TYPES: dict[str, BaseModel] = {
216 | 'Requirement': Requirement, # type: ignore
217 | 'Preference': Preference, # type: ignore
218 | 'Procedure': Procedure, # type: ignore
219 | 'Location': Location, # type: ignore
220 | 'Event': Event, # type: ignore
221 | 'Object': Object, # type: ignore
222 | 'Topic': Topic, # type: ignore
223 | 'Organization': Organization, # type: ignore
224 | 'Document': Document, # type: ignore
225 | }
226 |
```
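A minimal usage sketch (not a file in this repo): one way `ENTITY_TYPES` might be wired into ingestion. It assumes `Graphiti.add_episode` accepts an `entity_types` mapping of Pydantic models (that parameter is not shown on this page) and that the import path matches this repository's layout; the episode content is a placeholder.

```python
from datetime import datetime, timezone

from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType

# Import path follows this repository's layout; adjust to your own packaging.
from mcp_server.src.models.entity_types import ENTITY_TYPES


async def ingest_with_custom_entities(graphiti: Graphiti) -> None:
    # `entity_types` is assumed to be the add_episode parameter that accepts this
    # mapping of Pydantic models; the other arguments mirror the quickstart examples.
    await graphiti.add_episode(
        name='planning-meeting-1',
        episode_body='We need the exporter to finish a full sync in under five minutes.',
        source=EpisodeType.text,
        source_description='planning meeting notes',
        reference_time=datetime.now(timezone.utc),
        entity_types=ENTITY_TYPES,
    )
```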
--------------------------------------------------------------------------------
/examples/quickstart/quickstart_falkordb.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2025, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import asyncio
18 | import json
19 | import logging
20 | import os
21 | from datetime import datetime, timezone
22 | from logging import INFO
23 |
24 | from dotenv import load_dotenv
25 |
26 | from graphiti_core import Graphiti
27 | from graphiti_core.driver.falkordb_driver import FalkorDriver
28 | from graphiti_core.nodes import EpisodeType
29 | from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
30 |
31 | #################################################
32 | # CONFIGURATION
33 | #################################################
34 | # Set up logging and environment variables for
35 | # connecting to FalkorDB database
36 | #################################################
37 |
38 | # Configure logging
39 | logging.basicConfig(
40 | level=INFO,
41 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
42 | datefmt='%Y-%m-%d %H:%M:%S',
43 | )
44 | logger = logging.getLogger(__name__)
45 |
46 | load_dotenv()
47 |
48 | # FalkorDB connection parameters
49 | # Make sure FalkorDB (on-premises) is running — see https://docs.falkordb.com/
50 | # By default, FalkorDB does not require a username or password,
51 | # but you can set them via environment variables for added security.
52 | #
53 | # If you're using FalkorDB Cloud, set the environment variables accordingly.
54 | # For on-premises use, you can leave them as None or set them to your preferred values.
55 | #
56 | # The default host and port are 'localhost' and '6379', respectively.
57 | # You can override these values in your environment variables or directly in the code.
58 |
59 | falkor_username = os.environ.get('FALKORDB_USERNAME', None)
60 | falkor_password = os.environ.get('FALKORDB_PASSWORD', None)
61 | falkor_host = os.environ.get('FALKORDB_HOST', 'localhost')
62 | falkor_port = os.environ.get('FALKORDB_PORT', '6379')
63 |
64 |
65 | async def main():
66 | #################################################
67 | # INITIALIZATION
68 | #################################################
69 | # Connect to FalkorDB and set up Graphiti indices
70 | # This is required before using other Graphiti
71 | # functionality
72 | #################################################
73 |
74 | # Initialize Graphiti with FalkorDB connection
75 | falkor_driver = FalkorDriver(
76 | host=falkor_host, port=falkor_port, username=falkor_username, password=falkor_password
77 | )
78 | graphiti = Graphiti(graph_driver=falkor_driver)
79 |
80 | try:
81 | #################################################
82 | # ADDING EPISODES
83 | #################################################
84 | # Episodes are the primary units of information
85 | # in Graphiti. They can be text or structured JSON
86 | # and are automatically processed to extract entities
87 | # and relationships.
88 | #################################################
89 |
90 | # Example: Add Episodes
91 | # Episodes list containing both text and JSON episodes
92 | episodes = [
93 | {
94 | 'content': 'Kamala Harris is the Attorney General of California. She was previously '
95 | 'the district attorney for San Francisco.',
96 | 'type': EpisodeType.text,
97 | 'description': 'podcast transcript',
98 | },
99 | {
100 | 'content': 'As AG, Harris was in office from January 3, 2011 – January 3, 2017',
101 | 'type': EpisodeType.text,
102 | 'description': 'podcast transcript',
103 | },
104 | {
105 | 'content': {
106 | 'name': 'Gavin Newsom',
107 | 'position': 'Governor',
108 | 'state': 'California',
109 | 'previous_role': 'Lieutenant Governor',
110 | 'previous_location': 'San Francisco',
111 | },
112 | 'type': EpisodeType.json,
113 | 'description': 'podcast metadata',
114 | },
115 | {
116 | 'content': {
117 | 'name': 'Gavin Newsom',
118 | 'position': 'Governor',
119 | 'term_start': 'January 7, 2019',
120 | 'term_end': 'Present',
121 | },
122 | 'type': EpisodeType.json,
123 | 'description': 'podcast metadata',
124 | },
125 | ]
126 |
127 | # Add episodes to the graph
128 | for i, episode in enumerate(episodes):
129 | await graphiti.add_episode(
130 | name=f'Freakonomics Radio {i}',
131 | episode_body=episode['content']
132 | if isinstance(episode['content'], str)
133 | else json.dumps(episode['content']),
134 | source=episode['type'],
135 | source_description=episode['description'],
136 | reference_time=datetime.now(timezone.utc),
137 | )
138 | print(f'Added episode: Freakonomics Radio {i} ({episode["type"].value})')
139 |
140 | #################################################
141 | # BASIC SEARCH
142 | #################################################
143 | # The simplest way to retrieve relationships (edges)
144 | # from Graphiti is using the search method, which
145 | # performs a hybrid search combining semantic
146 | # similarity and BM25 text retrieval.
147 | #################################################
148 |
149 | # Perform a hybrid search combining semantic similarity and BM25 retrieval
150 | print("\nSearching for: 'Who was the California Attorney General?'")
151 | results = await graphiti.search('Who was the California Attorney General?')
152 |
153 | # Print search results
154 | print('\nSearch Results:')
155 | for result in results:
156 | print(f'UUID: {result.uuid}')
157 | print(f'Fact: {result.fact}')
158 | if hasattr(result, 'valid_at') and result.valid_at:
159 | print(f'Valid from: {result.valid_at}')
160 | if hasattr(result, 'invalid_at') and result.invalid_at:
161 | print(f'Valid until: {result.invalid_at}')
162 | print('---')
163 |
164 | #################################################
165 | # CENTER NODE SEARCH
166 | #################################################
167 | # For more contextually relevant results, you can
168 | # use a center node to rerank search results based
169 | # on their graph distance to a specific node
170 | #################################################
171 |
172 | # Use the top search result's UUID as the center node for reranking
173 | if results and len(results) > 0:
174 | # Get the source node UUID from the top result
175 | center_node_uuid = results[0].source_node_uuid
176 |
177 | print('\nReranking search results based on graph distance:')
178 | print(f'Using center node UUID: {center_node_uuid}')
179 |
180 | reranked_results = await graphiti.search(
181 | 'Who was the California Attorney General?', center_node_uuid=center_node_uuid
182 | )
183 |
184 | # Print reranked search results
185 | print('\nReranked Search Results:')
186 | for result in reranked_results:
187 | print(f'UUID: {result.uuid}')
188 | print(f'Fact: {result.fact}')
189 | if hasattr(result, 'valid_at') and result.valid_at:
190 | print(f'Valid from: {result.valid_at}')
191 | if hasattr(result, 'invalid_at') and result.invalid_at:
192 | print(f'Valid until: {result.invalid_at}')
193 | print('---')
194 | else:
195 | print('No results found in the initial search to use as center node.')
196 |
197 | #################################################
198 | # NODE SEARCH USING SEARCH RECIPES
199 | #################################################
200 | # Graphiti provides predefined search recipes
201 | # optimized for different search scenarios.
202 | # Here we use NODE_HYBRID_SEARCH_RRF for retrieving
203 | # nodes directly instead of edges.
204 | #################################################
205 |
206 | # Example: Perform a node search using _search method with standard recipes
207 | print(
208 | '\nPerforming node search using _search method with standard recipe NODE_HYBRID_SEARCH_RRF:'
209 | )
210 |
211 | # Use a predefined search configuration recipe and modify its limit
212 | node_search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
213 | node_search_config.limit = 5 # Limit to 5 results
214 |
215 | # Execute the node search
216 | node_search_results = await graphiti._search(
217 | query='California Governor',
218 | config=node_search_config,
219 | )
220 |
221 | # Print node search results
222 | print('\nNode Search Results:')
223 | for node in node_search_results.nodes:
224 | print(f'Node UUID: {node.uuid}')
225 | print(f'Node Name: {node.name}')
226 | node_summary = node.summary[:100] + '...' if len(node.summary) > 100 else node.summary
227 | print(f'Content Summary: {node_summary}')
228 | print(f'Node Labels: {", ".join(node.labels)}')
229 | print(f'Created At: {node.created_at}')
230 | if hasattr(node, 'attributes') and node.attributes:
231 | print('Attributes:')
232 | for key, value in node.attributes.items():
233 | print(f' {key}: {value}')
234 | print('---')
235 |
236 | finally:
237 | #################################################
238 | # CLEANUP
239 | #################################################
240 | # Always close the connection to FalkorDB when
241 | # finished to properly release resources
242 | #################################################
243 |
244 | # Close the connection
245 | await graphiti.close()
246 | print('\nConnection closed')
247 |
248 |
249 | if __name__ == '__main__':
250 | asyncio.run(main())
251 |
```
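Note: the INITIALIZATION comment above mentions setting up Graphiti indices, but the quickstart snippet does not show that call. A minimal sketch (not a file in this repo), assuming `Graphiti.build_indices_and_constraints()` is the index-setup entry point:

```python
from graphiti_core import Graphiti
from graphiti_core.driver.falkordb_driver import FalkorDriver


async def init_graphiti() -> Graphiti:
    driver = FalkorDriver(host='localhost', port='6379', username=None, password=None)
    graphiti = Graphiti(graph_driver=driver)
    # Builds the indices and constraints Graphiti relies on; typically run once
    # before the first add_episode call (assumed entry point, not shown above).
    await graphiti.build_indices_and_constraints()
    return graphiti
```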
--------------------------------------------------------------------------------
/graphiti_core/llm_client/openai_base_client.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import json
18 | import logging
19 | import typing
20 | from abc import abstractmethod
21 | from typing import Any, ClassVar
22 |
23 | import openai
24 | from openai.types.chat import ChatCompletionMessageParam
25 | from pydantic import BaseModel
26 |
27 | from ..prompts.models import Message
28 | from .client import LLMClient, get_extraction_language_instruction
29 | from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
30 | from .errors import RateLimitError, RefusalError
31 |
32 | logger = logging.getLogger(__name__)
33 |
34 | DEFAULT_MODEL = 'gpt-5-mini'
35 | DEFAULT_SMALL_MODEL = 'gpt-5-nano'
36 | DEFAULT_REASONING = 'minimal'
37 | DEFAULT_VERBOSITY = 'low'
38 |
39 |
40 | class BaseOpenAIClient(LLMClient):
41 | """
42 | Base client class for OpenAI-compatible APIs (OpenAI and Azure OpenAI).
43 |
44 | This class contains shared logic for both OpenAI and Azure OpenAI clients,
45 | reducing code duplication while allowing for implementation-specific differences.
46 | """
47 |
48 | # Class-level constants
49 | MAX_RETRIES: ClassVar[int] = 2
50 |
51 | def __init__(
52 | self,
53 | config: LLMConfig | None = None,
54 | cache: bool = False,
55 | max_tokens: int = DEFAULT_MAX_TOKENS,
56 | reasoning: str | None = DEFAULT_REASONING,
57 | verbosity: str | None = DEFAULT_VERBOSITY,
58 | ):
59 | if cache:
60 | raise NotImplementedError('Caching is not implemented for OpenAI-based clients')
61 |
62 | if config is None:
63 | config = LLMConfig()
64 |
65 | super().__init__(config, cache)
66 | self.max_tokens = max_tokens
67 | self.reasoning = reasoning
68 | self.verbosity = verbosity
69 |
70 | @abstractmethod
71 | async def _create_completion(
72 | self,
73 | model: str,
74 | messages: list[ChatCompletionMessageParam],
75 | temperature: float | None,
76 | max_tokens: int,
77 | response_model: type[BaseModel] | None = None,
78 | ) -> Any:
79 | """Create a completion using the specific client implementation."""
80 | pass
81 |
82 | @abstractmethod
83 | async def _create_structured_completion(
84 | self,
85 | model: str,
86 | messages: list[ChatCompletionMessageParam],
87 | temperature: float | None,
88 | max_tokens: int,
89 | response_model: type[BaseModel],
90 | reasoning: str | None,
91 | verbosity: str | None,
92 | ) -> Any:
93 | """Create a structured completion using the specific client implementation."""
94 | pass
95 |
96 | def _convert_messages_to_openai_format(
97 | self, messages: list[Message]
98 | ) -> list[ChatCompletionMessageParam]:
99 | """Convert internal Message format to OpenAI ChatCompletionMessageParam format."""
100 | openai_messages: list[ChatCompletionMessageParam] = []
101 | for m in messages:
102 | m.content = self._clean_input(m.content)
103 | if m.role == 'user':
104 | openai_messages.append({'role': 'user', 'content': m.content})
105 | elif m.role == 'system':
106 | openai_messages.append({'role': 'system', 'content': m.content})
107 | return openai_messages
108 |
109 | def _get_model_for_size(self, model_size: ModelSize) -> str:
110 | """Get the appropriate model name based on the requested size."""
111 | if model_size == ModelSize.small:
112 | return self.small_model or DEFAULT_SMALL_MODEL
113 | else:
114 | return self.model or DEFAULT_MODEL
115 |
116 | def _handle_structured_response(self, response: Any) -> dict[str, Any]:
117 | """Handle structured response parsing and validation."""
118 |         response_text = response.output_text
119 | 
120 |         if response_text:
121 |             return json.loads(response_text)
122 |         # Empty output_text means no parsed content; surface any refusal from the response object
123 |         if getattr(response, 'refusal', None):
124 |             raise RefusalError(response.refusal)
125 |         raise Exception(f'Invalid response from LLM: {response.model_dump()}')
126 |
127 | def _handle_json_response(self, response: Any) -> dict[str, Any]:
128 | """Handle JSON response parsing."""
129 | result = response.choices[0].message.content or '{}'
130 | return json.loads(result)
131 |
132 | async def _generate_response(
133 | self,
134 | messages: list[Message],
135 | response_model: type[BaseModel] | None = None,
136 | max_tokens: int = DEFAULT_MAX_TOKENS,
137 | model_size: ModelSize = ModelSize.medium,
138 | ) -> dict[str, Any]:
139 | """Generate a response using the appropriate client implementation."""
140 | openai_messages = self._convert_messages_to_openai_format(messages)
141 | model = self._get_model_for_size(model_size)
142 |
143 | try:
144 | if response_model:
145 | response = await self._create_structured_completion(
146 | model=model,
147 | messages=openai_messages,
148 | temperature=self.temperature,
149 | max_tokens=max_tokens or self.max_tokens,
150 | response_model=response_model,
151 | reasoning=self.reasoning,
152 | verbosity=self.verbosity,
153 | )
154 | return self._handle_structured_response(response)
155 | else:
156 | response = await self._create_completion(
157 | model=model,
158 | messages=openai_messages,
159 | temperature=self.temperature,
160 | max_tokens=max_tokens or self.max_tokens,
161 | )
162 | return self._handle_json_response(response)
163 |
164 | except openai.LengthFinishReasonError as e:
165 | raise Exception(f'Output length exceeded max tokens {self.max_tokens}: {e}') from e
166 | except openai.RateLimitError as e:
167 | raise RateLimitError from e
168 | except openai.AuthenticationError as e:
169 | logger.error(
170 | f'OpenAI Authentication Error: {e}. Please verify your API key is correct.'
171 | )
172 | raise
173 | except Exception as e:
174 | # Provide more context for connection errors
175 | error_msg = str(e)
176 | if 'Connection error' in error_msg or 'connection' in error_msg.lower():
177 | logger.error(
178 | f'Connection error communicating with OpenAI API. Please check your network connection and API key. Error: {e}'
179 | )
180 | else:
181 | logger.error(f'Error in generating LLM response: {e}')
182 | raise
183 |
184 | async def generate_response(
185 | self,
186 | messages: list[Message],
187 | response_model: type[BaseModel] | None = None,
188 | max_tokens: int | None = None,
189 | model_size: ModelSize = ModelSize.medium,
190 | group_id: str | None = None,
191 | prompt_name: str | None = None,
192 | ) -> dict[str, typing.Any]:
193 | """Generate a response with retry logic and error handling."""
194 | if max_tokens is None:
195 | max_tokens = self.max_tokens
196 |
197 | # Add multilingual extraction instructions
198 | messages[0].content += get_extraction_language_instruction(group_id)
199 |
200 | # Wrap entire operation in tracing span
201 | with self.tracer.start_span('llm.generate') as span:
202 | attributes = {
203 | 'llm.provider': 'openai',
204 | 'model.size': model_size.value,
205 | 'max_tokens': max_tokens,
206 | }
207 | if prompt_name:
208 | attributes['prompt.name'] = prompt_name
209 | span.add_attributes(attributes)
210 |
211 | retry_count = 0
212 | last_error = None
213 |
214 | while retry_count <= self.MAX_RETRIES:
215 | try:
216 | response = await self._generate_response(
217 | messages, response_model, max_tokens, model_size
218 | )
219 | return response
220 | except (RateLimitError, RefusalError):
221 | # These errors should not trigger retries
222 | span.set_status('error', str(last_error))
223 | raise
224 | except (
225 | openai.APITimeoutError,
226 | openai.APIConnectionError,
227 | openai.InternalServerError,
228 | ):
229 | # Let OpenAI's client handle these retries
230 | span.set_status('error', str(last_error))
231 | raise
232 | except Exception as e:
233 | last_error = e
234 |
235 | # Don't retry if we've hit the max retries
236 | if retry_count >= self.MAX_RETRIES:
237 | logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
238 | span.set_status('error', str(e))
239 | span.record_exception(e)
240 | raise
241 |
242 | retry_count += 1
243 |
244 | # Construct a detailed error message for the LLM
245 | error_context = (
246 | f'The previous response attempt was invalid. '
247 | f'Error type: {e.__class__.__name__}. '
248 | f'Error details: {str(e)}. '
249 | f'Please try again with a valid response, ensuring the output matches '
250 | f'the expected format and constraints.'
251 | )
252 |
253 | error_message = Message(role='user', content=error_context)
254 | messages.append(error_message)
255 | logger.warning(
256 | f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
257 | )
258 |
259 | # If we somehow get here, raise the last error
260 | span.set_status('error', str(last_error))
261 | raise last_error or Exception('Max retries exceeded with no specific error')
262 |
```
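A minimal usage sketch (not a file in this repo): `BaseOpenAIClient` is abstract, so callers go through a concrete subclass such as `openai_client.OpenAIClient` (listed in the directory tree above). The sketch assumes that subclass accepts the same `LLMConfig` used elsewhere on this page; only `generate_response`'s signature is taken from the code above, and `ExtractedFact` is hypothetical.

```python
from pydantic import BaseModel

from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.openai_client import OpenAIClient  # concrete subclass (constructor assumed)
from graphiti_core.prompts.models import Message


class ExtractedFact(BaseModel):
    """Hypothetical structured output for illustration."""

    subject: str
    predicate: str
    object: str


async def extract_one_fact() -> dict:
    client = OpenAIClient(config=LLMConfig(api_key='sk-...', model='gpt-5-mini'))
    messages = [
        Message(role='system', content='Extract a single subject-predicate-object fact as JSON.'),
        Message(role='user', content='Kamala Harris served as Attorney General of California.'),
    ]
    # generate_response retries up to MAX_RETRIES times on validation errors,
    # appending the error details to the conversation before each retry (see above).
    return await client.generate_response(
        messages, response_model=ExtractedFact, prompt_name='extract_fact'
    )
```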
--------------------------------------------------------------------------------
/mcp_server/src/config/schema.py:
--------------------------------------------------------------------------------
```python
1 | """Configuration schemas with pydantic-settings and YAML support."""
2 |
3 | import os
4 | from pathlib import Path
5 | from typing import Any
6 |
7 | import yaml
8 | from pydantic import BaseModel, Field
9 | from pydantic_settings import (
10 | BaseSettings,
11 | PydanticBaseSettingsSource,
12 | SettingsConfigDict,
13 | )
14 |
15 |
16 | class YamlSettingsSource(PydanticBaseSettingsSource):
17 | """Custom settings source for loading from YAML files."""
18 |
19 | def __init__(self, settings_cls: type[BaseSettings], config_path: Path | None = None):
20 | super().__init__(settings_cls)
21 | self.config_path = config_path or Path('config.yaml')
22 |
23 | def _expand_env_vars(self, value: Any) -> Any:
24 | """Recursively expand environment variables in configuration values."""
25 | if isinstance(value, str):
26 | # Support ${VAR} and ${VAR:default} syntax
27 | import re
28 |
29 | def replacer(match):
30 | var_name = match.group(1)
31 | default_value = match.group(3) if match.group(3) is not None else ''
32 | return os.environ.get(var_name, default_value)
33 |
34 | pattern = r'\$\{([^:}]+)(:([^}]*))?\}'
35 |
36 | # Check if the entire value is a single env var expression
37 | full_match = re.fullmatch(pattern, value)
38 | if full_match:
39 | result = replacer(full_match)
40 | # Convert boolean-like strings to actual booleans
41 | if isinstance(result, str):
42 | lower_result = result.lower().strip()
43 | if lower_result in ('true', '1', 'yes', 'on'):
44 | return True
45 | elif lower_result in ('false', '0', 'no', 'off'):
46 | return False
47 | elif lower_result == '':
48 | # Empty string means env var not set - return None for optional fields
49 | return None
50 | return result
51 | else:
52 | # Otherwise, do string substitution (keep as strings for partial replacements)
53 | return re.sub(pattern, replacer, value)
54 | elif isinstance(value, dict):
55 | return {k: self._expand_env_vars(v) for k, v in value.items()}
56 | elif isinstance(value, list):
57 | return [self._expand_env_vars(item) for item in value]
58 | return value
59 |
60 | def get_field_value(self, field_name: str, field_info: Any) -> Any:
61 | """Get field value from YAML config."""
62 | return None
63 |
64 | def __call__(self) -> dict[str, Any]:
65 | """Load and parse YAML configuration."""
66 | if not self.config_path.exists():
67 | return {}
68 |
69 | with open(self.config_path) as f:
70 | raw_config = yaml.safe_load(f) or {}
71 |
72 | # Expand environment variables
73 | return self._expand_env_vars(raw_config)
74 |
75 |
76 | class ServerConfig(BaseModel):
77 | """Server configuration."""
78 |
79 | transport: str = Field(
80 | default='http',
81 | description='Transport type: http (default, recommended), stdio, or sse (deprecated)',
82 | )
83 | host: str = Field(default='0.0.0.0', description='Server host')
84 | port: int = Field(default=8000, description='Server port')
85 |
86 |
87 | class OpenAIProviderConfig(BaseModel):
88 | """OpenAI provider configuration."""
89 |
90 | api_key: str | None = None
91 | api_url: str = 'https://api.openai.com/v1'
92 | organization_id: str | None = None
93 |
94 |
95 | class AzureOpenAIProviderConfig(BaseModel):
96 | """Azure OpenAI provider configuration."""
97 |
98 | api_key: str | None = None
99 | api_url: str | None = None
100 | api_version: str = '2024-10-21'
101 | deployment_name: str | None = None
102 | use_azure_ad: bool = False
103 |
104 |
105 | class AnthropicProviderConfig(BaseModel):
106 | """Anthropic provider configuration."""
107 |
108 | api_key: str | None = None
109 | api_url: str = 'https://api.anthropic.com'
110 | max_retries: int = 3
111 |
112 |
113 | class GeminiProviderConfig(BaseModel):
114 | """Gemini provider configuration."""
115 |
116 | api_key: str | None = None
117 | project_id: str | None = None
118 | location: str = 'us-central1'
119 |
120 |
121 | class GroqProviderConfig(BaseModel):
122 | """Groq provider configuration."""
123 |
124 | api_key: str | None = None
125 | api_url: str = 'https://api.groq.com/openai/v1'
126 |
127 |
128 | class VoyageProviderConfig(BaseModel):
129 | """Voyage AI provider configuration."""
130 |
131 | api_key: str | None = None
132 | api_url: str = 'https://api.voyageai.com/v1'
133 | model: str = 'voyage-3'
134 |
135 |
136 | class LLMProvidersConfig(BaseModel):
137 | """LLM providers configuration."""
138 |
139 | openai: OpenAIProviderConfig | None = None
140 | azure_openai: AzureOpenAIProviderConfig | None = None
141 | anthropic: AnthropicProviderConfig | None = None
142 | gemini: GeminiProviderConfig | None = None
143 | groq: GroqProviderConfig | None = None
144 |
145 |
146 | class LLMConfig(BaseModel):
147 | """LLM configuration."""
148 |
149 | provider: str = Field(default='openai', description='LLM provider')
150 | model: str = Field(default='gpt-4.1', description='Model name')
151 | temperature: float | None = Field(
152 | default=None, description='Temperature (optional, defaults to None for reasoning models)'
153 | )
154 | max_tokens: int = Field(default=4096, description='Max tokens')
155 | providers: LLMProvidersConfig = Field(default_factory=LLMProvidersConfig)
156 |
157 |
158 | class EmbedderProvidersConfig(BaseModel):
159 | """Embedder providers configuration."""
160 |
161 | openai: OpenAIProviderConfig | None = None
162 | azure_openai: AzureOpenAIProviderConfig | None = None
163 | gemini: GeminiProviderConfig | None = None
164 | voyage: VoyageProviderConfig | None = None
165 |
166 |
167 | class EmbedderConfig(BaseModel):
168 | """Embedder configuration."""
169 |
170 | provider: str = Field(default='openai', description='Embedder provider')
171 | model: str = Field(default='text-embedding-3-small', description='Model name')
172 | dimensions: int = Field(default=1536, description='Embedding dimensions')
173 | providers: EmbedderProvidersConfig = Field(default_factory=EmbedderProvidersConfig)
174 |
175 |
176 | class Neo4jProviderConfig(BaseModel):
177 | """Neo4j provider configuration."""
178 |
179 | uri: str = 'bolt://localhost:7687'
180 | username: str = 'neo4j'
181 | password: str | None = None
182 | database: str = 'neo4j'
183 | use_parallel_runtime: bool = False
184 |
185 |
186 | class FalkorDBProviderConfig(BaseModel):
187 | """FalkorDB provider configuration."""
188 |
189 | uri: str = 'redis://localhost:6379'
190 | password: str | None = None
191 | database: str = 'default_db'
192 |
193 |
194 | class DatabaseProvidersConfig(BaseModel):
195 | """Database providers configuration."""
196 |
197 | neo4j: Neo4jProviderConfig | None = None
198 | falkordb: FalkorDBProviderConfig | None = None
199 |
200 |
201 | class DatabaseConfig(BaseModel):
202 | """Database configuration."""
203 |
204 | provider: str = Field(default='falkordb', description='Database provider')
205 | providers: DatabaseProvidersConfig = Field(default_factory=DatabaseProvidersConfig)
206 |
207 |
208 | class EntityTypeConfig(BaseModel):
209 | """Entity type configuration."""
210 |
211 | name: str
212 | description: str
213 |
214 |
215 | class GraphitiAppConfig(BaseModel):
216 | """Graphiti-specific configuration."""
217 |
218 | group_id: str = Field(default='main', description='Group ID')
219 | episode_id_prefix: str | None = Field(default='', description='Episode ID prefix')
220 | user_id: str = Field(default='mcp_user', description='User ID')
221 | entity_types: list[EntityTypeConfig] = Field(default_factory=list)
222 |
223 | def model_post_init(self, __context) -> None:
224 | """Convert None to empty string for episode_id_prefix."""
225 | if self.episode_id_prefix is None:
226 | self.episode_id_prefix = ''
227 |
228 |
229 | class GraphitiConfig(BaseSettings):
230 | """Graphiti configuration with YAML and environment support."""
231 |
232 | server: ServerConfig = Field(default_factory=ServerConfig)
233 | llm: LLMConfig = Field(default_factory=LLMConfig)
234 | embedder: EmbedderConfig = Field(default_factory=EmbedderConfig)
235 | database: DatabaseConfig = Field(default_factory=DatabaseConfig)
236 | graphiti: GraphitiAppConfig = Field(default_factory=GraphitiAppConfig)
237 |
238 | # Additional server options
239 | destroy_graph: bool = Field(default=False, description='Clear graph on startup')
240 |
241 | model_config = SettingsConfigDict(
242 | env_prefix='',
243 | env_nested_delimiter='__',
244 | case_sensitive=False,
245 | extra='ignore',
246 | )
247 |
248 | @classmethod
249 | def settings_customise_sources(
250 | cls,
251 | settings_cls: type[BaseSettings],
252 | init_settings: PydanticBaseSettingsSource,
253 | env_settings: PydanticBaseSettingsSource,
254 | dotenv_settings: PydanticBaseSettingsSource,
255 | file_secret_settings: PydanticBaseSettingsSource,
256 | ) -> tuple[PydanticBaseSettingsSource, ...]:
257 | """Customize settings sources to include YAML."""
258 | config_path = Path(os.environ.get('CONFIG_PATH', 'config/config.yaml'))
259 | yaml_settings = YamlSettingsSource(settings_cls, config_path)
260 | # Priority: CLI args (init) > env vars > yaml > defaults
261 | return (init_settings, env_settings, yaml_settings, dotenv_settings)
262 |
263 | def apply_cli_overrides(self, args) -> None:
264 | """Apply CLI argument overrides to configuration."""
265 | # Override server settings
266 | if hasattr(args, 'transport') and args.transport:
267 | self.server.transport = args.transport
268 |
269 | # Override LLM settings
270 | if hasattr(args, 'llm_provider') and args.llm_provider:
271 | self.llm.provider = args.llm_provider
272 | if hasattr(args, 'model') and args.model:
273 | self.llm.model = args.model
274 | if hasattr(args, 'temperature') and args.temperature is not None:
275 | self.llm.temperature = args.temperature
276 |
277 | # Override embedder settings
278 | if hasattr(args, 'embedder_provider') and args.embedder_provider:
279 | self.embedder.provider = args.embedder_provider
280 | if hasattr(args, 'embedder_model') and args.embedder_model:
281 | self.embedder.model = args.embedder_model
282 |
283 | # Override database settings
284 | if hasattr(args, 'database_provider') and args.database_provider:
285 | self.database.provider = args.database_provider
286 |
287 | # Override Graphiti settings
288 | if hasattr(args, 'group_id') and args.group_id:
289 | self.graphiti.group_id = args.group_id
290 | if hasattr(args, 'user_id') and args.user_id:
291 | self.graphiti.user_id = args.user_id
292 |
```
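
`YamlSettingsSource` above expands `${VAR}` and `${VAR:default}` placeholders before pydantic validation, so a single `config.yaml` can be shared across environments. A stripped-down, stdlib-only sketch of that substitution (variable names are illustrative; output comments assume they are otherwise unset):

```python
# Standalone sketch of the ${VAR} / ${VAR:default} expansion used by YamlSettingsSource.
import os
import re

_PATTERN = r'\$\{([^:}]+)(:([^}]*))?\}'


def expand(value: str) -> str:
    def replacer(match: re.Match) -> str:
        var_name = match.group(1)
        default_value = match.group(3) if match.group(3) is not None else ''
        return os.environ.get(var_name, default_value)

    return re.sub(_PATTERN, replacer, value)


if __name__ == '__main__':
    os.environ['NEO4J_URI'] = 'bolt://db:7687'
    print(expand('${NEO4J_URI}'))                # -> bolt://db:7687
    print(expand('${NEO4J_PASSWORD:changeme}'))  # -> changeme when NEO4J_PASSWORD is unset
    print(expand('prefix-${MISSING:x}-suffix'))  # -> prefix-x-suffix when MISSING is unset
```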
--------------------------------------------------------------------------------
/tests/helpers_test.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import os
18 | from unittest.mock import Mock
19 |
20 | import numpy as np
21 | import pytest
22 | from dotenv import load_dotenv
23 |
24 | from graphiti_core.driver.driver import GraphDriver, GraphProvider
25 | from graphiti_core.edges import EntityEdge, EpisodicEdge
26 | from graphiti_core.embedder.client import EmbedderClient
27 | from graphiti_core.helpers import lucene_sanitize
28 | from graphiti_core.nodes import CommunityNode, EntityNode, EpisodicNode
29 | from graphiti_core.utils.maintenance.graph_data_operations import clear_data
30 |
31 | load_dotenv()
32 |
33 | drivers: list[GraphProvider] = []
34 | if os.getenv('DISABLE_NEO4J') is None:
35 | try:
36 | from graphiti_core.driver.neo4j_driver import Neo4jDriver
37 |
38 | drivers.append(GraphProvider.NEO4J)
39 | except ImportError:
40 | raise
41 |
42 | if os.getenv('DISABLE_FALKORDB') is None:
43 | try:
44 | from graphiti_core.driver.falkordb_driver import FalkorDriver
45 |
46 | drivers.append(GraphProvider.FALKORDB)
47 | except ImportError:
48 | raise
49 |
50 | if os.getenv('DISABLE_KUZU') is None:
51 | try:
52 | from graphiti_core.driver.kuzu_driver import KuzuDriver
53 |
54 | drivers.append(GraphProvider.KUZU)
55 | except ImportError:
56 | raise
57 |
58 | # Disable Neptune for now
59 | os.environ['DISABLE_NEPTUNE'] = 'True'
60 | if os.getenv('DISABLE_NEPTUNE') is None:
61 | try:
62 | from graphiti_core.driver.neptune_driver import NeptuneDriver
63 |
64 | drivers.append(GraphProvider.NEPTUNE)
65 | except ImportError:
66 | raise
67 |
68 | NEO4J_URI = os.getenv('NEO4J_URI', 'bolt://localhost:7687')
69 | NEO4J_USER = os.getenv('NEO4J_USER', 'neo4j')
70 | NEO4J_PASSWORD = os.getenv('NEO4J_PASSWORD', 'test')
71 |
72 | FALKORDB_HOST = os.getenv('FALKORDB_HOST', 'localhost')
73 | FALKORDB_PORT = os.getenv('FALKORDB_PORT', '6379')
74 | FALKORDB_USER = os.getenv('FALKORDB_USER', None)
75 | FALKORDB_PASSWORD = os.getenv('FALKORDB_PASSWORD', None)
76 |
77 | NEPTUNE_HOST = os.getenv('NEPTUNE_HOST', 'localhost')
78 | NEPTUNE_PORT = os.getenv('NEPTUNE_PORT', 8182)
79 | AOSS_HOST = os.getenv('AOSS_HOST', None)
80 |
81 | KUZU_DB = os.getenv('KUZU_DB', ':memory:')
82 |
83 | group_id = 'graphiti_test_group'
84 | group_id_2 = 'graphiti_test_group_2'
85 |
86 |
87 | def get_driver(provider: GraphProvider) -> GraphDriver:
88 | if provider == GraphProvider.NEO4J:
89 | return Neo4jDriver(
90 | uri=NEO4J_URI,
91 | user=NEO4J_USER,
92 | password=NEO4J_PASSWORD,
93 | )
94 | elif provider == GraphProvider.FALKORDB:
95 | return FalkorDriver(
96 | host=FALKORDB_HOST,
97 | port=int(FALKORDB_PORT),
98 | username=FALKORDB_USER,
99 | password=FALKORDB_PASSWORD,
100 | )
101 | elif provider == GraphProvider.KUZU:
102 | driver = KuzuDriver(
103 | db=KUZU_DB,
104 | )
105 | return driver
106 | elif provider == GraphProvider.NEPTUNE:
107 | return NeptuneDriver(
108 | host=NEPTUNE_HOST,
109 | port=int(NEPTUNE_PORT),
110 | aoss_host=AOSS_HOST,
111 | )
112 | else:
113 | raise ValueError(f'Driver {provider} not available')
114 |
115 |
116 | @pytest.fixture(params=drivers)
117 | async def graph_driver(request):
118 | driver = request.param
119 | graph_driver = get_driver(driver)
120 | await clear_data(graph_driver, [group_id, group_id_2])
121 | try:
122 | yield graph_driver # provide driver to the test
123 | finally:
124 | # always called, even if the test fails or raises
125 | # await clean_up(graph_driver)
126 | await graph_driver.close()
127 |
128 |
129 | embedding_dim = 384
130 | embeddings = {
131 | key: np.random.uniform(0.0, 0.9, embedding_dim).tolist()
132 | for key in [
133 | 'Alice',
134 | 'Bob',
135 | 'Alice likes Bob',
136 | 'test_entity_1',
137 | 'test_entity_2',
138 | 'test_entity_3',
139 | 'test_entity_4',
140 | 'test_entity_alice',
141 | 'test_entity_bob',
142 | 'test_entity_1 is a duplicate of test_entity_2',
143 | 'test_entity_3 is a duplicate of test_entity_4',
144 | 'test_entity_1 relates to test_entity_2',
145 | 'test_entity_1 relates to test_entity_3',
146 | 'test_entity_2 relates to test_entity_3',
147 | 'test_entity_1 relates to test_entity_4',
148 | 'test_entity_2 relates to test_entity_4',
149 | 'test_entity_3 relates to test_entity_4',
150 | 'test_entity_1 relates to test_entity_2',
151 | 'test_entity_3 relates to test_entity_4',
152 | 'test_entity_2 relates to test_entity_3',
153 | 'test_community_1',
154 | 'test_community_2',
155 | ]
156 | }
157 | embeddings['Alice Smith'] = embeddings['Alice']
158 |
159 |
160 | @pytest.fixture
161 | def mock_embedder():
162 | mock_model = Mock(spec=EmbedderClient)
163 |
164 | def mock_embed(input_data):
165 | if isinstance(input_data, str):
166 | return embeddings[input_data]
167 | elif isinstance(input_data, list):
168 | combined_input = ' '.join(input_data)
169 | return embeddings[combined_input]
170 | else:
171 | raise ValueError(f'Unsupported input type: {type(input_data)}')
172 |
173 | mock_model.create.side_effect = mock_embed
174 | return mock_model
175 |
176 |
177 | def test_lucene_sanitize():
178 | # Call the function with test data
179 | queries = [
180 | (
181 | 'This has every escape character + - && || ! ( ) { } [ ] ^ " ~ * ? : \\ /',
182 | '\\This has every escape character \\+ \\- \\&\\& \\|\\| \\! \\( \\) \\{ \\} \\[ \\] \\^ \\" \\~ \\* \\? \\: \\\\ \\/',
183 | ),
184 | ('this has no escape characters', 'this has no escape characters'),
185 | ]
186 |
187 | for query, assert_result in queries:
188 | result = lucene_sanitize(query)
189 | assert assert_result == result
190 |
191 |
192 | async def get_node_count(driver: GraphDriver, uuids: list[str]) -> int:
193 | results, _, _ = await driver.execute_query(
194 | """
195 | MATCH (n)
196 | WHERE n.uuid IN $uuids
197 | RETURN COUNT(n) as count
198 | """,
199 | uuids=uuids,
200 | )
201 | return int(results[0]['count'])
202 |
203 |
204 | async def get_edge_count(driver: GraphDriver, uuids: list[str]) -> int:
205 | results, _, _ = await driver.execute_query(
206 | """
207 | MATCH (n)-[e]->(m)
208 | WHERE e.uuid IN $uuids
209 | RETURN COUNT(e) as count
210 | UNION ALL
211 | MATCH (e:RelatesToNode_)
212 | WHERE e.uuid IN $uuids
213 | RETURN COUNT(e) as count
214 | """,
215 | uuids=uuids,
216 | )
217 | return sum(int(result['count']) for result in results)
218 |
219 |
220 | async def print_graph(graph_driver: GraphDriver):
221 | nodes, _, _ = await graph_driver.execute_query(
222 | """
223 | MATCH (n)
224 | RETURN n.uuid, n.name
225 | """,
226 | )
227 | print('Nodes:')
228 | for node in nodes:
229 | print(' ', node)
230 | edges, _, _ = await graph_driver.execute_query(
231 | """
232 | MATCH (n)-[e]->(m)
233 | RETURN n.name, e.uuid, m.name
234 | """,
235 | )
236 | print('Edges:')
237 | for edge in edges:
238 | print(' ', edge)
239 |
240 |
241 | async def assert_episodic_node_equals(retrieved: EpisodicNode, sample: EpisodicNode):
242 | assert retrieved.uuid == sample.uuid
243 | assert retrieved.name == sample.name
244 | assert retrieved.group_id == group_id
245 | assert retrieved.created_at == sample.created_at
246 | assert retrieved.source == sample.source
247 | assert retrieved.source_description == sample.source_description
248 | assert retrieved.content == sample.content
249 | assert retrieved.valid_at == sample.valid_at
250 | assert set(retrieved.entity_edges) == set(sample.entity_edges)
251 |
252 |
253 | async def assert_entity_node_equals(
254 | graph_driver: GraphDriver, retrieved: EntityNode, sample: EntityNode
255 | ):
256 | await retrieved.load_name_embedding(graph_driver)
257 | assert retrieved.uuid == sample.uuid
258 | assert retrieved.name == sample.name
259 | assert retrieved.group_id == sample.group_id
260 | assert set(retrieved.labels) == set(sample.labels)
261 | assert retrieved.created_at == sample.created_at
262 | assert retrieved.name_embedding is not None
263 | assert sample.name_embedding is not None
264 | assert np.allclose(retrieved.name_embedding, sample.name_embedding)
265 | assert retrieved.summary == sample.summary
266 | assert retrieved.attributes == sample.attributes
267 |
268 |
269 | async def assert_community_node_equals(
270 | graph_driver: GraphDriver, retrieved: CommunityNode, sample: CommunityNode
271 | ):
272 | await retrieved.load_name_embedding(graph_driver)
273 | assert retrieved.uuid == sample.uuid
274 | assert retrieved.name == sample.name
275 | assert retrieved.group_id == group_id
276 | assert retrieved.created_at == sample.created_at
277 | assert retrieved.name_embedding is not None
278 | assert sample.name_embedding is not None
279 | assert np.allclose(retrieved.name_embedding, sample.name_embedding)
280 | assert retrieved.summary == sample.summary
281 |
282 |
283 | async def assert_episodic_edge_equals(retrieved: EpisodicEdge, sample: EpisodicEdge):
284 | assert retrieved.uuid == sample.uuid
285 | assert retrieved.group_id == sample.group_id
286 | assert retrieved.created_at == sample.created_at
287 | assert retrieved.source_node_uuid == sample.source_node_uuid
288 | assert retrieved.target_node_uuid == sample.target_node_uuid
289 |
290 |
291 | async def assert_entity_edge_equals(
292 | graph_driver: GraphDriver, retrieved: EntityEdge, sample: EntityEdge
293 | ):
294 | await retrieved.load_fact_embedding(graph_driver)
295 | assert retrieved.uuid == sample.uuid
296 | assert retrieved.group_id == sample.group_id
297 | assert retrieved.created_at == sample.created_at
298 | assert retrieved.source_node_uuid == sample.source_node_uuid
299 | assert retrieved.target_node_uuid == sample.target_node_uuid
300 | assert retrieved.name == sample.name
301 | assert retrieved.fact == sample.fact
302 | assert retrieved.fact_embedding is not None
303 | assert sample.fact_embedding is not None
304 | assert np.allclose(retrieved.fact_embedding, sample.fact_embedding)
305 | assert retrieved.episodes == sample.episodes
306 | assert retrieved.expired_at == sample.expired_at
307 | assert retrieved.valid_at == sample.valid_at
308 | assert retrieved.invalid_at == sample.invalid_at
309 | assert retrieved.attributes == sample.attributes
310 |
311 |
312 | if __name__ == '__main__':
313 | pytest.main([__file__])
314 |
```
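
The module above gives driver-backed tests two building blocks: a `graph_driver` fixture parametrized over the enabled providers and a `mock_embedder` that returns deterministic vectors for the canned keys. A hypothetical test showing how they combine; the node round-trip mirrors patterns used elsewhere in this suite:

```python
# Hypothetical example test consuming the fixtures defined above.
import pytest

from graphiti_core.nodes import EntityNode
from tests.helpers_test import get_node_count, group_id


@pytest.mark.asyncio
async def test_entity_node_round_trip(graph_driver, mock_embedder):
    # 'Alice' is one of the precomputed embedding keys above, so the mock embedder
    # returns a deterministic 384-dimensional vector for it.
    node = EntityNode(name='Alice', group_id=group_id, labels=['Entity'], summary='test subject')
    await node.generate_name_embedding(mock_embedder)
    await node.save(graph_driver)

    assert await get_node_count(graph_driver, [node.uuid]) == 1
```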
--------------------------------------------------------------------------------
/mcp_server/tests/test_fixtures.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Shared test fixtures and utilities for Graphiti MCP integration tests.
3 | """
4 |
5 | import asyncio
6 | import contextlib
7 | import json
8 | import os
9 | import random
10 | import time
11 | from contextlib import asynccontextmanager
12 | from typing import Any
13 |
14 | import pytest
15 | from faker import Faker
16 | from mcp import ClientSession, StdioServerParameters
17 | from mcp.client.stdio import stdio_client
18 |
19 | fake = Faker()
20 |
21 |
22 | class TestDataGenerator:
23 | """Generate realistic test data for various scenarios."""
24 |
25 | @staticmethod
26 | def generate_company_profile() -> dict[str, Any]:
27 | """Generate a realistic company profile."""
28 | return {
29 | 'company': {
30 | 'name': fake.company(),
31 | 'founded': random.randint(1990, 2023),
32 | 'industry': random.choice(['Tech', 'Finance', 'Healthcare', 'Retail']),
33 | 'employees': random.randint(10, 10000),
34 | 'revenue': f'${random.randint(1, 1000)}M',
35 | 'headquarters': fake.city(),
36 | },
37 | 'products': [
38 | {
39 | 'id': fake.uuid4()[:8],
40 | 'name': fake.catch_phrase(),
41 | 'category': random.choice(['Software', 'Hardware', 'Service']),
42 | 'price': random.randint(10, 10000),
43 | }
44 | for _ in range(random.randint(1, 5))
45 | ],
46 | 'leadership': {
47 | 'ceo': fake.name(),
48 | 'cto': fake.name(),
49 | 'cfo': fake.name(),
50 | },
51 | }
52 |
53 | @staticmethod
54 | def generate_conversation(turns: int = 3) -> str:
55 | """Generate a realistic conversation."""
56 | topics = [
57 | 'product features',
58 | 'pricing',
59 | 'technical support',
60 | 'integration',
61 | 'documentation',
62 | 'performance',
63 | ]
64 |
65 | conversation = []
66 | for _ in range(turns):
67 | topic = random.choice(topics)
68 | user_msg = f'user: {fake.sentence()} about {topic}?'
69 | assistant_msg = f'assistant: {fake.paragraph(nb_sentences=2)}'
70 | conversation.extend([user_msg, assistant_msg])
71 |
72 | return '\n'.join(conversation)
73 |
74 | @staticmethod
75 | def generate_technical_document() -> str:
76 | """Generate technical documentation content."""
77 | sections = [
78 | f'# {fake.catch_phrase()}\n\n{fake.paragraph()}',
79 | f'## Architecture\n{fake.paragraph()}',
80 | f'## Implementation\n{fake.paragraph()}',
81 | f'## Performance\n- Latency: {random.randint(1, 100)}ms\n- Throughput: {random.randint(100, 10000)} req/s',
82 | f'## Dependencies\n- {fake.word()}\n- {fake.word()}\n- {fake.word()}',
83 | ]
84 | return '\n\n'.join(sections)
85 |
86 | @staticmethod
87 | def generate_news_article() -> str:
88 | """Generate a news article."""
89 | company = fake.company()
90 | return f"""
91 | {company} Announces {fake.catch_phrase()}
92 |
93 | {fake.city()}, {fake.date()} - {company} today announced {fake.paragraph()}.
94 |
95 | "This is a significant milestone," said {fake.name()}, CEO of {company}.
96 | "{fake.sentence()}"
97 |
98 | The announcement comes after {fake.paragraph()}.
99 |
100 | Industry analysts predict {fake.paragraph()}.
101 | """
102 |
103 | @staticmethod
104 | def generate_user_profile() -> dict[str, Any]:
105 | """Generate a user profile."""
106 | return {
107 | 'user_id': fake.uuid4(),
108 | 'name': fake.name(),
109 | 'email': fake.email(),
110 | 'joined': fake.date_time_this_year().isoformat(),
111 | 'preferences': {
112 | 'theme': random.choice(['light', 'dark', 'auto']),
113 | 'notifications': random.choice([True, False]),
114 | 'language': random.choice(['en', 'es', 'fr', 'de']),
115 | },
116 | 'activity': {
117 | 'last_login': fake.date_time_this_month().isoformat(),
118 | 'total_sessions': random.randint(1, 1000),
119 | 'average_duration': f'{random.randint(1, 60)} minutes',
120 | },
121 | }
122 |
123 |
124 | class MockLLMProvider:
125 | """Mock LLM provider for testing without actual API calls."""
126 |
127 | def __init__(self, delay: float = 0.1):
128 | self.delay = delay # Simulate LLM latency
129 |
130 | async def generate(self, prompt: str) -> str:
131 | """Simulate LLM generation with delay."""
132 | await asyncio.sleep(self.delay)
133 |
134 | # Return deterministic responses based on prompt patterns
135 | if 'extract entities' in prompt.lower():
136 | return json.dumps(
137 | {
138 | 'entities': [
139 | {'name': 'TestEntity1', 'type': 'PERSON'},
140 | {'name': 'TestEntity2', 'type': 'ORGANIZATION'},
141 | ]
142 | }
143 | )
144 | elif 'summarize' in prompt.lower():
145 | return 'This is a test summary of the provided content.'
146 | else:
147 | return 'Mock LLM response'
148 |
149 |
150 | @asynccontextmanager
151 | async def graphiti_test_client(
152 | group_id: str | None = None,
153 | database: str = 'falkordb',
154 | use_mock_llm: bool = False,
155 | config_overrides: dict[str, Any] | None = None,
156 | ):
157 | """
158 | Context manager for creating test clients with various configurations.
159 |
160 | Args:
161 | group_id: Test group identifier
162 | database: Database backend (neo4j, falkordb)
163 | use_mock_llm: Whether to use mock LLM for faster tests
164 | config_overrides: Additional config overrides
165 | """
166 | test_group_id = group_id or f'test_{int(time.time())}_{random.randint(1000, 9999)}'
167 |
168 | env = {
169 | 'DATABASE_PROVIDER': database,
170 | 'OPENAI_API_KEY': os.environ.get('OPENAI_API_KEY', 'test_key' if use_mock_llm else None),
171 | }
172 |
173 | # Database-specific configuration
174 | if database == 'neo4j':
175 | env.update(
176 | {
177 | 'NEO4J_URI': os.environ.get('NEO4J_URI', 'bolt://localhost:7687'),
178 | 'NEO4J_USER': os.environ.get('NEO4J_USER', 'neo4j'),
179 | 'NEO4J_PASSWORD': os.environ.get('NEO4J_PASSWORD', 'graphiti'),
180 | }
181 | )
182 | elif database == 'falkordb':
183 | env['FALKORDB_URI'] = os.environ.get('FALKORDB_URI', 'redis://localhost:6379')
184 |
185 | # Apply config overrides
186 | if config_overrides:
187 | env.update(config_overrides)
188 |
189 | # Add mock LLM flag if needed
190 | if use_mock_llm:
191 | env['USE_MOCK_LLM'] = 'true'
192 |
193 | server_params = StdioServerParameters(
194 | command='uv', args=['run', 'main.py', '--transport', 'stdio'], env=env
195 | )
196 |
197 | async with stdio_client(server_params) as (read, write):
198 | session = ClientSession(read, write)
199 | await session.initialize()
200 |
201 | try:
202 | yield session, test_group_id
203 | finally:
204 | # Cleanup: Clear test data
205 | with contextlib.suppress(Exception):
206 | await session.call_tool('clear_graph', {'group_id': test_group_id})
207 |
208 | await session.close()
209 |
210 |
211 | class PerformanceBenchmark:
212 | """Track and analyze performance benchmarks."""
213 |
214 | def __init__(self):
215 | self.measurements: dict[str, list[float]] = {}
216 |
217 | def record(self, operation: str, duration: float):
218 | """Record a performance measurement."""
219 | if operation not in self.measurements:
220 | self.measurements[operation] = []
221 | self.measurements[operation].append(duration)
222 |
223 | def get_stats(self, operation: str) -> dict[str, float]:
224 | """Get statistics for an operation."""
225 | if operation not in self.measurements or not self.measurements[operation]:
226 | return {}
227 |
228 | durations = self.measurements[operation]
229 | return {
230 | 'count': len(durations),
231 | 'mean': sum(durations) / len(durations),
232 | 'min': min(durations),
233 | 'max': max(durations),
234 | 'median': sorted(durations)[len(durations) // 2],
235 | }
236 |
237 | def report(self) -> str:
238 | """Generate a performance report."""
239 | lines = ['Performance Benchmark Report', '=' * 40]
240 |
241 | for operation in sorted(self.measurements.keys()):
242 | stats = self.get_stats(operation)
243 | lines.append(f'\n{operation}:')
244 | lines.append(f' Samples: {stats["count"]}')
245 | lines.append(f' Mean: {stats["mean"]:.3f}s')
246 | lines.append(f' Median: {stats["median"]:.3f}s')
247 | lines.append(f' Min: {stats["min"]:.3f}s')
248 | lines.append(f' Max: {stats["max"]:.3f}s')
249 |
250 | return '\n'.join(lines)
251 |
252 |
253 | # Pytest fixtures
254 | @pytest.fixture
255 | def test_data_generator():
256 | """Provide test data generator."""
257 | return TestDataGenerator()
258 |
259 |
260 | @pytest.fixture
261 | def performance_benchmark():
262 | """Provide performance benchmark tracker."""
263 | return PerformanceBenchmark()
264 |
265 |
266 | @pytest.fixture
267 | async def mock_graphiti_client():
268 | """Provide a Graphiti client with mocked LLM."""
269 | async with graphiti_test_client(use_mock_llm=True) as (session, group_id):
270 | yield session, group_id
271 |
272 |
273 | @pytest.fixture
274 | async def graphiti_client():
275 | """Provide a real Graphiti client."""
276 | async with graphiti_test_client(use_mock_llm=False) as (session, group_id):
277 | yield session, group_id
278 |
279 |
280 | # Test data fixtures
281 | @pytest.fixture
282 | def sample_memories():
283 | """Provide sample memory data for testing."""
284 | return [
285 | {
286 | 'name': 'Company Overview',
287 | 'episode_body': TestDataGenerator.generate_company_profile(),
288 | 'source': 'json',
289 | 'source_description': 'company database',
290 | },
291 | {
292 | 'name': 'Product Launch',
293 | 'episode_body': TestDataGenerator.generate_news_article(),
294 | 'source': 'text',
295 | 'source_description': 'press release',
296 | },
297 | {
298 | 'name': 'Customer Support',
299 | 'episode_body': TestDataGenerator.generate_conversation(),
300 | 'source': 'message',
301 | 'source_description': 'support chat',
302 | },
303 | {
304 | 'name': 'Technical Specs',
305 | 'episode_body': TestDataGenerator.generate_technical_document(),
306 | 'source': 'text',
307 | 'source_description': 'documentation',
308 | },
309 | ]
310 |
311 |
312 | @pytest.fixture
313 | def large_dataset():
314 | """Generate a large dataset for stress testing."""
315 | return [
316 | {
317 | 'name': f'Document {i}',
318 | 'episode_body': TestDataGenerator.generate_technical_document(),
319 | 'source': 'text',
320 | 'source_description': 'bulk import',
321 | }
322 | for i in range(50)
323 | ]
324 |
```
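
These fixtures are designed to be composed: `graphiti_test_client` spins up the MCP server over stdio, `TestDataGenerator` supplies realistic payloads, and `PerformanceBenchmark` aggregates timings. A hypothetical integration test combining them (the `add_memory` tool name and its argument schema are assumptions based on the rest of this test suite):

```python
# Hypothetical integration test built on the fixtures above.
import time

import pytest

from test_fixtures import PerformanceBenchmark, TestDataGenerator, graphiti_test_client


@pytest.mark.asyncio
async def test_add_memory_latency():
    benchmark = PerformanceBenchmark()

    async with graphiti_test_client(use_mock_llm=True) as (session, group_id):
        start = time.monotonic()
        # Assumed tool name and arguments; adjust to the server's actual add_memory schema.
        await session.call_tool(
            'add_memory',
            {
                'name': 'Product Launch',
                'episode_body': TestDataGenerator.generate_news_article(),
                'source': 'text',
                'source_description': 'press release',
                'group_id': group_id,
            },
        )
        benchmark.record('add_memory', time.monotonic() - start)

    print(benchmark.report())
```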
--------------------------------------------------------------------------------
/tests/utils/maintenance/test_bulk_utils.py:
--------------------------------------------------------------------------------
```python
1 | from collections import deque
2 | from unittest.mock import AsyncMock, MagicMock
3 |
4 | import pytest
5 |
6 | from graphiti_core.edges import EntityEdge
7 | from graphiti_core.graphiti_types import GraphitiClients
8 | from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
9 | from graphiti_core.utils import bulk_utils
10 | from graphiti_core.utils.datetime_utils import utc_now
11 |
12 |
13 | def _make_episode(uuid_suffix: str, group_id: str = 'group') -> EpisodicNode:
14 | return EpisodicNode(
15 | name=f'episode-{uuid_suffix}',
16 | group_id=group_id,
17 | labels=[],
18 | source=EpisodeType.message,
19 | content='content',
20 | source_description='test',
21 | created_at=utc_now(),
22 | valid_at=utc_now(),
23 | )
24 |
25 |
26 | def _make_clients() -> GraphitiClients:
27 | driver = MagicMock()
28 | embedder = MagicMock()
29 | cross_encoder = MagicMock()
30 | llm_client = MagicMock()
31 |
32 | return GraphitiClients.model_construct( # bypass validation to allow test doubles
33 | driver=driver,
34 | embedder=embedder,
35 | cross_encoder=cross_encoder,
36 | llm_client=llm_client,
37 | )
38 |
39 |
40 | @pytest.mark.asyncio
41 | async def test_dedupe_nodes_bulk_reuses_canonical_nodes(monkeypatch):
42 | clients = _make_clients()
43 |
44 | episode_one = _make_episode('1')
45 | episode_two = _make_episode('2')
46 |
47 | extracted_one = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
48 | extracted_two = EntityNode(name='Alice Smith', group_id='group', labels=['Entity'])
49 |
50 | canonical = extracted_one
51 |
52 | call_queue = deque()
53 |
54 | async def fake_resolve(
55 | clients_arg,
56 | nodes_arg,
57 | episode_arg,
58 | previous_episodes_arg,
59 | entity_types_arg,
60 | existing_nodes_override=None,
61 | ):
62 | call_queue.append(existing_nodes_override)
63 |
64 | if nodes_arg == [extracted_one]:
65 | return [canonical], {canonical.uuid: canonical.uuid}, []
66 |
67 | assert nodes_arg == [extracted_two]
68 | assert existing_nodes_override is None
69 |
70 | return [canonical], {extracted_two.uuid: canonical.uuid}, [(extracted_two, canonical)]
71 |
72 | monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)
73 |
74 | nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
75 | clients,
76 | [[extracted_one], [extracted_two]],
77 | [(episode_one, []), (episode_two, [])],
78 | )
79 |
80 | assert len(call_queue) == 2
81 | assert call_queue[0] is None
82 | assert call_queue[1] is None
83 |
84 | assert nodes_by_episode[episode_one.uuid] == [canonical]
85 | assert nodes_by_episode[episode_two.uuid] == [canonical]
86 | assert compressed_map.get(extracted_two.uuid) == canonical.uuid
87 |
88 |
89 | @pytest.mark.asyncio
90 | async def test_dedupe_nodes_bulk_handles_empty_batch(monkeypatch):
91 | clients = _make_clients()
92 |
93 | resolve_mock = AsyncMock()
94 | monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)
95 |
96 | nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
97 | clients,
98 | [],
99 | [],
100 | )
101 |
102 | assert nodes_by_episode == {}
103 | assert compressed_map == {}
104 | resolve_mock.assert_not_awaited()
105 |
106 |
107 | @pytest.mark.asyncio
108 | async def test_dedupe_nodes_bulk_single_episode(monkeypatch):
109 | clients = _make_clients()
110 |
111 | episode = _make_episode('solo')
112 | extracted = EntityNode(name='Solo', group_id='group', labels=['Entity'])
113 |
114 | resolve_mock = AsyncMock(return_value=([extracted], {extracted.uuid: extracted.uuid}, []))
115 | monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)
116 |
117 | nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
118 | clients,
119 | [[extracted]],
120 | [(episode, [])],
121 | )
122 |
123 | assert nodes_by_episode == {episode.uuid: [extracted]}
124 | assert compressed_map == {extracted.uuid: extracted.uuid}
125 | resolve_mock.assert_awaited_once()
126 |
127 |
128 | @pytest.mark.asyncio
129 | async def test_dedupe_nodes_bulk_uuid_map_respects_direction(monkeypatch):
130 | clients = _make_clients()
131 |
132 | episode_one = _make_episode('one')
133 | episode_two = _make_episode('two')
134 |
135 | extracted_one = EntityNode(uuid='b-uuid', name='Edge Case', group_id='group', labels=['Entity'])
136 | extracted_two = EntityNode(uuid='a-uuid', name='Edge Case', group_id='group', labels=['Entity'])
137 |
138 | canonical = extracted_one
139 | alias = extracted_two
140 |
141 | async def fake_resolve(
142 | clients_arg,
143 | nodes_arg,
144 | episode_arg,
145 | previous_episodes_arg,
146 | entity_types_arg,
147 | existing_nodes_override=None,
148 | ):
149 | if nodes_arg == [extracted_one]:
150 | return [canonical], {canonical.uuid: canonical.uuid}, []
151 | assert nodes_arg == [extracted_two]
152 | return [canonical], {alias.uuid: canonical.uuid}, [(alias, canonical)]
153 |
154 | monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', fake_resolve)
155 |
156 | nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
157 | clients,
158 | [[extracted_one], [extracted_two]],
159 | [(episode_one, []), (episode_two, [])],
160 | )
161 |
162 | assert nodes_by_episode[episode_one.uuid] == [canonical]
163 | assert nodes_by_episode[episode_two.uuid] == [canonical]
164 | assert compressed_map.get(alias.uuid) == canonical.uuid
165 |
166 |
167 | @pytest.mark.asyncio
168 | async def test_dedupe_nodes_bulk_missing_canonical_falls_back(monkeypatch, caplog):
169 | clients = _make_clients()
170 |
171 | episode = _make_episode('missing')
172 | extracted = EntityNode(name='Fallback', group_id='group', labels=['Entity'])
173 |
174 | resolve_mock = AsyncMock(return_value=([extracted], {extracted.uuid: 'missing-canonical'}, []))
175 | monkeypatch.setattr(bulk_utils, 'resolve_extracted_nodes', resolve_mock)
176 |
177 | with caplog.at_level('WARNING'):
178 | nodes_by_episode, compressed_map = await bulk_utils.dedupe_nodes_bulk(
179 | clients,
180 | [[extracted]],
181 | [(episode, [])],
182 | )
183 |
184 | assert nodes_by_episode[episode.uuid] == [extracted]
185 | assert compressed_map.get(extracted.uuid) == 'missing-canonical'
186 | assert any('Canonical node missing' in rec.message for rec in caplog.records)
187 |
188 |
189 | def test_build_directed_uuid_map_empty():
190 | assert bulk_utils._build_directed_uuid_map([]) == {}
191 |
192 |
193 | def test_build_directed_uuid_map_chain():
194 | mapping = bulk_utils._build_directed_uuid_map(
195 | [
196 | ('a', 'b'),
197 | ('b', 'c'),
198 | ]
199 | )
200 |
201 | assert mapping['a'] == 'c'
202 | assert mapping['b'] == 'c'
203 | assert mapping['c'] == 'c'
204 |
205 |
206 | def test_build_directed_uuid_map_preserves_direction():
207 | mapping = bulk_utils._build_directed_uuid_map(
208 | [
209 | ('alias', 'canonical'),
210 | ]
211 | )
212 |
213 | assert mapping['alias'] == 'canonical'
214 | assert mapping['canonical'] == 'canonical'
215 |
216 |
217 | def test_resolve_edge_pointers_updates_sources():
218 | created_at = utc_now()
219 | edge = EntityEdge(
220 | name='knows',
221 | fact='fact',
222 | group_id='group',
223 | source_node_uuid='alias',
224 | target_node_uuid='target',
225 | created_at=created_at,
226 | )
227 |
228 | bulk_utils.resolve_edge_pointers([edge], {'alias': 'canonical'})
229 |
230 | assert edge.source_node_uuid == 'canonical'
231 | assert edge.target_node_uuid == 'target'
232 |
233 |
234 | @pytest.mark.asyncio
235 | async def test_dedupe_edges_bulk_deduplicates_within_episode(monkeypatch):
236 | """Test that dedupe_edges_bulk correctly compares edges within the same episode.
237 |
238 | This test verifies the fix that removed the `if i == j: continue` check,
239 | which was preventing edges from the same episode from being compared against each other.
240 | """
241 | clients = _make_clients()
242 |
243 | # Track which edges are compared
244 | comparisons_made = []
245 |
246 | # Create mock embedder that sets embedding values
247 | async def mock_create_embeddings(embedder, edges):
248 | for edge in edges:
249 | edge.fact_embedding = [0.1, 0.2, 0.3]
250 |
251 | monkeypatch.setattr(bulk_utils, 'create_entity_edge_embeddings', mock_create_embeddings)
252 |
253 | # Mock resolve_extracted_edge to track comparisons and mark duplicates
254 | async def mock_resolve_extracted_edge(
255 | llm_client,
256 | extracted_edge,
257 | related_edges,
258 | existing_edges,
259 | episode,
260 | edge_type_candidates=None,
261 | custom_edge_type_names=None,
262 | ):
263 | # Track that this edge was compared against the related_edges
264 | comparisons_made.append((extracted_edge.uuid, [r.uuid for r in related_edges]))
265 |
266 | # If there are related edges with same source/target/fact, mark as duplicate
267 | for related in related_edges:
268 | if (
269 | related.uuid != extracted_edge.uuid # Can't be duplicate of self
270 | and related.source_node_uuid == extracted_edge.source_node_uuid
271 | and related.target_node_uuid == extracted_edge.target_node_uuid
272 | and related.fact.strip().lower() == extracted_edge.fact.strip().lower()
273 | ):
274 | # Return the related edge and mark extracted_edge as duplicate
275 | return related, [], [related]
276 | # Otherwise return the extracted edge as-is
277 | return extracted_edge, [], []
278 |
279 | monkeypatch.setattr(bulk_utils, 'resolve_extracted_edge', mock_resolve_extracted_edge)
280 |
281 | episode = _make_episode('1')
282 | source_uuid = 'source-uuid'
283 | target_uuid = 'target-uuid'
284 |
285 | # Create 3 identical edges within the same episode
286 | edge1 = EntityEdge(
287 | name='recommends',
288 | fact='assistant recommends yoga poses',
289 | group_id='group',
290 | source_node_uuid=source_uuid,
291 | target_node_uuid=target_uuid,
292 | created_at=utc_now(),
293 | episodes=[episode.uuid],
294 | )
295 | edge2 = EntityEdge(
296 | name='recommends',
297 | fact='assistant recommends yoga poses',
298 | group_id='group',
299 | source_node_uuid=source_uuid,
300 | target_node_uuid=target_uuid,
301 | created_at=utc_now(),
302 | episodes=[episode.uuid],
303 | )
304 | edge3 = EntityEdge(
305 | name='recommends',
306 | fact='assistant recommends yoga poses',
307 | group_id='group',
308 | source_node_uuid=source_uuid,
309 | target_node_uuid=target_uuid,
310 | created_at=utc_now(),
311 | episodes=[episode.uuid],
312 | )
313 |
314 | await bulk_utils.dedupe_edges_bulk(
315 | clients,
316 | [[edge1, edge2, edge3]],
317 | [(episode, [])],
318 | [],
319 | {},
320 | {},
321 | )
322 |
323 | # Verify that edges were compared against each other (within same episode)
324 | # Each edge should have been compared against all 3 edges (including itself, which gets filtered)
325 | assert len(comparisons_made) == 3
326 | for _, compared_against in comparisons_made:
327 | # Each edge should have access to all 3 edges as candidates
328 | assert len(compared_against) >= 2 # At least 2 others (self is filtered out)
329 |
```
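
The `_build_directed_uuid_map` tests above pin down the expected semantics: alias chains collapse to their final canonical uuid, direction is preserved, and canonical uuids map to themselves. A standalone re-implementation sketch (not the library function itself) that satisfies the same assertions:

```python
# Illustrative re-implementation of directed alias -> canonical resolution,
# matching the behavior the tests above expect from _build_directed_uuid_map.
def build_directed_uuid_map(pairs: list[tuple[str, str]]) -> dict[str, str]:
    mapping: dict[str, str] = {}

    def resolve(uuid: str) -> str:
        # Follow alias -> canonical links until a fixed point is reached.
        seen: set[str] = set()
        while uuid in mapping and uuid not in seen:
            seen.add(uuid)
            uuid = mapping[uuid]
        return uuid

    for alias, canonical in pairs:
        mapping[alias] = canonical

    # Compress chains so every key points directly at its final canonical uuid,
    # and make canonical uuids map to themselves.
    resolved = {key: resolve(key) for key in list(mapping)}
    for canonical in set(resolved.values()):
        resolved.setdefault(canonical, canonical)
    return resolved


assert build_directed_uuid_map([('a', 'b'), ('b', 'c')]) == {'a': 'c', 'b': 'c', 'c': 'c'}
assert build_directed_uuid_map([('alias', 'canonical')]) == {
    'alias': 'canonical',
    'canonical': 'canonical',
}
```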
--------------------------------------------------------------------------------
/graphiti_core/utils/maintenance/community_operations.py:
--------------------------------------------------------------------------------
```python
1 | import asyncio
2 | import logging
3 | from collections import defaultdict
4 |
5 | from pydantic import BaseModel
6 |
7 | from graphiti_core.driver.driver import GraphDriver, GraphProvider
8 | from graphiti_core.edges import CommunityEdge
9 | from graphiti_core.embedder import EmbedderClient
10 | from graphiti_core.helpers import semaphore_gather
11 | from graphiti_core.llm_client import LLMClient
12 | from graphiti_core.models.nodes.node_db_queries import COMMUNITY_NODE_RETURN
13 | from graphiti_core.nodes import CommunityNode, EntityNode, get_community_node_from_record
14 | from graphiti_core.prompts import prompt_library
15 | from graphiti_core.prompts.summarize_nodes import Summary, SummaryDescription
16 | from graphiti_core.utils.datetime_utils import utc_now
17 | from graphiti_core.utils.maintenance.edge_operations import build_community_edges
18 |
19 | MAX_COMMUNITY_BUILD_CONCURRENCY = 10
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 |
24 | class Neighbor(BaseModel):
25 | node_uuid: str
26 | edge_count: int
27 |
28 |
29 | async def get_community_clusters(
30 | driver: GraphDriver, group_ids: list[str] | None
31 | ) -> list[list[EntityNode]]:
32 | community_clusters: list[list[EntityNode]] = []
33 |
34 | if group_ids is None:
35 | group_id_values, _, _ = await driver.execute_query(
36 | """
37 | MATCH (n:Entity)
38 | WHERE n.group_id IS NOT NULL
39 | RETURN
40 | collect(DISTINCT n.group_id) AS group_ids
41 | """
42 | )
43 |
44 | group_ids = group_id_values[0]['group_ids'] if group_id_values else []
45 |
46 | for group_id in group_ids:
47 | projection: dict[str, list[Neighbor]] = {}
48 | nodes = await EntityNode.get_by_group_ids(driver, [group_id])
49 | for node in nodes:
50 | match_query = """
51 | MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[e:RELATES_TO]-(m: Entity {group_id: $group_id})
52 | """
53 | if driver.provider == GraphProvider.KUZU:
54 | match_query = """
55 | MATCH (n:Entity {group_id: $group_id, uuid: $uuid})-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m: Entity {group_id: $group_id})
56 | """
57 | records, _, _ = await driver.execute_query(
58 | match_query
59 | + """
60 | WITH count(e) AS count, m.uuid AS uuid
61 | RETURN
62 | uuid,
63 | count
64 | """,
65 | uuid=node.uuid,
66 | group_id=group_id,
67 | )
68 |
69 | projection[node.uuid] = [
70 | Neighbor(node_uuid=record['uuid'], edge_count=record['count']) for record in records
71 | ]
72 |
73 | cluster_uuids = label_propagation(projection)
74 |
75 | community_clusters.extend(
76 | list(
77 | await semaphore_gather(
78 | *[EntityNode.get_by_uuids(driver, cluster) for cluster in cluster_uuids]
79 | )
80 | )
81 | )
82 |
83 | return community_clusters
84 |
85 |
86 | def label_propagation(projection: dict[str, list[Neighbor]]) -> list[list[str]]:
87 | # Implement the label propagation community detection algorithm.
88 | # 1. Start with each node being assigned its own community
89 | # 2. Each node will take on the community of the plurality of its neighbors
90 | # 3. Ties are broken by going to the largest community
91 | # 4. Continue until no communities change during propagation
92 |
93 | community_map = {uuid: i for i, uuid in enumerate(projection.keys())}
94 |
95 | while True:
96 | no_change = True
97 | new_community_map: dict[str, int] = {}
98 |
99 | for uuid, neighbors in projection.items():
100 | curr_community = community_map[uuid]
101 |
102 | community_candidates: dict[int, int] = defaultdict(int)
103 | for neighbor in neighbors:
104 | community_candidates[community_map[neighbor.node_uuid]] += neighbor.edge_count
105 | community_lst = [
106 | (count, community) for community, count in community_candidates.items()
107 | ]
108 |
109 | community_lst.sort(reverse=True)
110 | candidate_rank, community_candidate = community_lst[0] if community_lst else (0, -1)
111 | if community_candidate != -1 and candidate_rank > 1:
112 | new_community = community_candidate
113 | else:
114 | new_community = max(community_candidate, curr_community)
115 |
116 | new_community_map[uuid] = new_community
117 |
118 | if new_community != curr_community:
119 | no_change = False
120 |
121 | if no_change:
122 | break
123 |
124 | community_map = new_community_map
125 |
126 | community_cluster_map = defaultdict(list)
127 | for uuid, community in community_map.items():
128 | community_cluster_map[community].append(uuid)
129 |
130 | clusters = [cluster for cluster in community_cluster_map.values()]
131 | return clusters
132 |
133 |
134 | async def summarize_pair(llm_client: LLMClient, summary_pair: tuple[str, str]) -> str:
135 | # Prepare context for LLM
136 | context = {
137 | 'node_summaries': [{'summary': summary} for summary in summary_pair],
138 | }
139 |
140 | llm_response = await llm_client.generate_response(
141 | prompt_library.summarize_nodes.summarize_pair(context),
142 | response_model=Summary,
143 | prompt_name='summarize_nodes.summarize_pair',
144 | )
145 |
146 | pair_summary = llm_response.get('summary', '')
147 |
148 | return pair_summary
149 |
150 |
151 | async def generate_summary_description(llm_client: LLMClient, summary: str) -> str:
152 | context = {
153 | 'summary': summary,
154 | }
155 |
156 | llm_response = await llm_client.generate_response(
157 | prompt_library.summarize_nodes.summary_description(context),
158 | response_model=SummaryDescription,
159 | prompt_name='summarize_nodes.summary_description',
160 | )
161 |
162 | description = llm_response.get('description', '')
163 |
164 | return description
165 |
166 |
167 | async def build_community(
168 | llm_client: LLMClient, community_cluster: list[EntityNode]
169 | ) -> tuple[CommunityNode, list[CommunityEdge]]:
170 | summaries = [entity.summary for entity in community_cluster]
171 | length = len(summaries)
172 | while length > 1:
173 | odd_one_out: str | None = None
174 | if length % 2 == 1:
175 | odd_one_out = summaries.pop()
176 | length -= 1
177 | new_summaries: list[str] = list(
178 | await semaphore_gather(
179 | *[
180 | summarize_pair(llm_client, (str(left_summary), str(right_summary)))
181 | for left_summary, right_summary in zip(
182 | summaries[: int(length / 2)], summaries[int(length / 2) :], strict=False
183 | )
184 | ]
185 | )
186 | )
187 | if odd_one_out is not None:
188 | new_summaries.append(odd_one_out)
189 | summaries = new_summaries
190 | length = len(summaries)
191 |
192 | summary = summaries[0]
193 | name = await generate_summary_description(llm_client, summary)
194 | now = utc_now()
195 | community_node = CommunityNode(
196 | name=name,
197 | group_id=community_cluster[0].group_id,
198 | labels=['Community'],
199 | created_at=now,
200 | summary=summary,
201 | )
202 | community_edges = build_community_edges(community_cluster, community_node, now)
203 |
204 | logger.debug((community_node, community_edges))
205 |
206 | return community_node, community_edges
207 |
208 |
209 | async def build_communities(
210 | driver: GraphDriver,
211 | llm_client: LLMClient,
212 | group_ids: list[str] | None,
213 | ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
214 | community_clusters = await get_community_clusters(driver, group_ids)
215 |
216 | semaphore = asyncio.Semaphore(MAX_COMMUNITY_BUILD_CONCURRENCY)
217 |
218 | async def limited_build_community(cluster):
219 | async with semaphore:
220 | return await build_community(llm_client, cluster)
221 |
222 | communities: list[tuple[CommunityNode, list[CommunityEdge]]] = list(
223 | await semaphore_gather(
224 | *[limited_build_community(cluster) for cluster in community_clusters]
225 | )
226 | )
227 |
228 | community_nodes: list[CommunityNode] = []
229 | community_edges: list[CommunityEdge] = []
230 | for community in communities:
231 | community_nodes.append(community[0])
232 | community_edges.extend(community[1])
233 |
234 | return community_nodes, community_edges
235 |
236 |
237 | async def remove_communities(driver: GraphDriver):
238 | await driver.execute_query(
239 | """
240 | MATCH (c:Community)
241 | DETACH DELETE c
242 | """
243 | )
244 |
245 |
246 | async def determine_entity_community(
247 | driver: GraphDriver, entity: EntityNode
248 | ) -> tuple[CommunityNode | None, bool]:
249 | # Check if the node is already part of a community
250 | records, _, _ = await driver.execute_query(
251 | """
252 | MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity {uuid: $entity_uuid})
253 | RETURN
254 | """
255 | + COMMUNITY_NODE_RETURN,
256 | entity_uuid=entity.uuid,
257 | )
258 |
259 | if len(records) > 0:
260 | return get_community_node_from_record(records[0]), False
261 |
262 | # If the node has no community, add it to the mode community of surrounding entities
263 | match_query = """
264 | MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
265 | """
266 | if driver.provider == GraphProvider.KUZU:
267 | match_query = """
268 | MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(n:Entity {uuid: $entity_uuid})
269 | """
270 | records, _, _ = await driver.execute_query(
271 | match_query
272 | + """
273 | RETURN
274 | """
275 | + COMMUNITY_NODE_RETURN,
276 | entity_uuid=entity.uuid,
277 | )
278 |
279 | communities: list[CommunityNode] = [
280 | get_community_node_from_record(record) for record in records
281 | ]
282 |
283 | community_map: dict[str, int] = defaultdict(int)
284 | for community in communities:
285 | community_map[community.uuid] += 1
286 |
287 | community_uuid = None
288 | max_count = 0
289 | for uuid, count in community_map.items():
290 | if count > max_count:
291 | community_uuid = uuid
292 | max_count = count
293 |
294 | if max_count == 0:
295 | return None, False
296 |
297 | for community in communities:
298 | if community.uuid == community_uuid:
299 | return community, True
300 |
301 | return None, False
302 |
303 |
304 | async def update_community(
305 | driver: GraphDriver,
306 | llm_client: LLMClient,
307 | embedder: EmbedderClient,
308 | entity: EntityNode,
309 | ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
310 | community, is_new = await determine_entity_community(driver, entity)
311 |
312 | if community is None:
313 | return [], []
314 |
315 | new_summary = await summarize_pair(llm_client, (entity.summary, community.summary))
316 | new_name = await generate_summary_description(llm_client, new_summary)
317 |
318 | community.summary = new_summary
319 | community.name = new_name
320 |
321 | community_edges = []
322 | if is_new:
323 | community_edge = (build_community_edges([entity], community, utc_now()))[0]
324 | await community_edge.save(driver)
325 | community_edges.append(community_edge)
326 |
327 | await community.generate_name_embedding(embedder)
328 |
329 | await community.save(driver)
330 |
331 | return [community], community_edges
332 |
```
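
`build_communities` drives everything through `label_propagation`, which repeatedly assigns each entity the plurality community of its neighbors until the assignment stabilizes. A small runnable illustration on a toy projection (node names are made up):

```python
from graphiti_core.utils.maintenance.community_operations import Neighbor, label_propagation

# Toy projection: two symmetric pairs with unit edge counts. With ties broken toward the
# larger community id, each pair settles into a single community after one pass.
projection = {
    'alice': [Neighbor(node_uuid='bob', edge_count=1)],
    'bob': [Neighbor(node_uuid='alice', edge_count=1)],
    'carol': [Neighbor(node_uuid='dave', edge_count=1)],
    'dave': [Neighbor(node_uuid='carol', edge_count=1)],
}

clusters = label_propagation(projection)
print([sorted(cluster) for cluster in clusters])
# Expected: [['alice', 'bob'], ['carol', 'dave']] (cluster order may vary).
```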
--------------------------------------------------------------------------------
/graphiti_core/prompts/extract_nodes.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from typing import Any, Protocol, TypedDict
18 |
19 | from pydantic import BaseModel, Field
20 |
21 | from graphiti_core.utils.text_utils import MAX_SUMMARY_CHARS
22 |
23 | from .models import Message, PromptFunction, PromptVersion
24 | from .prompt_helpers import to_prompt_json
25 | from .snippets import summary_instructions
26 |
27 |
28 | class ExtractedEntity(BaseModel):
29 | name: str = Field(..., description='Name of the extracted entity')
30 | entity_type_id: int = Field(
31 | description='ID of the classified entity type. '
32 | 'Must be one of the provided entity_type_id integers.',
33 | )
34 |
35 |
36 | class ExtractedEntities(BaseModel):
37 | extracted_entities: list[ExtractedEntity] = Field(..., description='List of extracted entities')
38 |
39 |
40 | class MissedEntities(BaseModel):
41 | missed_entities: list[str] = Field(..., description="Names of entities that weren't extracted")
42 |
43 |
44 | class EntityClassificationTriple(BaseModel):
45 | uuid: str = Field(description='UUID of the entity')
46 | name: str = Field(description='Name of the entity')
47 | entity_type: str | None = Field(
48 | default=None,
49 | description='Type of the entity. Must be one of the provided types or None',
50 | )
51 |
52 |
53 | class EntityClassification(BaseModel):
54 | entity_classifications: list[EntityClassificationTriple] = Field(
55 | ..., description='List of entities classification triples.'
56 | )
57 |
58 |
59 | class EntitySummary(BaseModel):
60 | summary: str = Field(
61 | ...,
62 | description=f'Summary containing the important information about the entity. Under {MAX_SUMMARY_CHARS} characters.',
63 | )
64 |
65 |
66 | class Prompt(Protocol):
67 | extract_message: PromptVersion
68 | extract_json: PromptVersion
69 | extract_text: PromptVersion
70 | reflexion: PromptVersion
71 | classify_nodes: PromptVersion
72 | extract_attributes: PromptVersion
73 | extract_summary: PromptVersion
74 |
75 |
76 | class Versions(TypedDict):
77 | extract_message: PromptFunction
78 | extract_json: PromptFunction
79 | extract_text: PromptFunction
80 | reflexion: PromptFunction
81 | classify_nodes: PromptFunction
82 | extract_attributes: PromptFunction
83 | extract_summary: PromptFunction
84 |
85 |
86 | def extract_message(context: dict[str, Any]) -> list[Message]:
87 | sys_prompt = """You are an AI assistant that extracts entity nodes from conversational messages.
88 | Your primary task is to extract and classify the speaker and other significant entities mentioned in the conversation."""
89 |
90 | user_prompt = f"""
91 | <ENTITY TYPES>
92 | {context['entity_types']}
93 | </ENTITY TYPES>
94 |
95 | <PREVIOUS MESSAGES>
96 | {to_prompt_json([ep for ep in context['previous_episodes']])}
97 | </PREVIOUS MESSAGES>
98 |
99 | <CURRENT MESSAGE>
100 | {context['episode_content']}
101 | </CURRENT MESSAGE>
102 |
103 | Instructions:
104 |
105 | You are given a conversation context and a CURRENT MESSAGE. Your task is to extract **entity nodes** mentioned **explicitly or implicitly** in the CURRENT MESSAGE.
106 | Pronoun references such as he/she/they or this/that/those should be disambiguated to the names of the
107 | reference entities. Only extract distinct entities from the CURRENT MESSAGE. Don't extract pronouns like you, me, he/she/they, we/us as entities.
108 |
109 | 1. **Speaker Extraction**: Always extract the speaker (the part before the colon `:` in each dialogue line) as the first entity node.
110 | - If the speaker is mentioned again in the message, treat both mentions as a **single entity**.
111 |
112 | 2. **Entity Identification**:
113 | - Extract all significant entities, concepts, or actors that are **explicitly or implicitly** mentioned in the CURRENT MESSAGE.
114 | - **Exclude** entities mentioned only in the PREVIOUS MESSAGES (they are for context only).
115 |
116 | 3. **Entity Classification**:
117 | - Use the descriptions in ENTITY TYPES to classify each extracted entity.
118 | - Assign the appropriate `entity_type_id` for each one.
119 |
120 | 4. **Exclusions**:
121 | - Do NOT extract entities representing relationships or actions.
122 | - Do NOT extract dates, times, or other temporal information—these will be handled separately.
123 |
124 | 5. **Formatting**:
125 | - Be **explicit and unambiguous** in naming entities (e.g., use full names when available).
126 |
127 | {context['custom_prompt']}
128 | """
129 | return [
130 | Message(role='system', content=sys_prompt),
131 | Message(role='user', content=user_prompt),
132 | ]
133 |
134 |
135 | def extract_json(context: dict[str, Any]) -> list[Message]:
136 | sys_prompt = """You are an AI assistant that extracts entity nodes from JSON.
137 | Your primary task is to extract and classify relevant entities from JSON files"""
138 |
139 | user_prompt = f"""
140 | <ENTITY TYPES>
141 | {context['entity_types']}
142 | </ENTITY TYPES>
143 |
144 | <SOURCE DESCRIPTION>:
145 | {context['source_description']}
146 | </SOURCE DESCRIPTION>
147 | <JSON>
148 | {context['episode_content']}
149 | </JSON>
150 |
151 | {context['custom_prompt']}
152 |
153 | Given the above source description and JSON, extract relevant entities from the provided JSON.
154 | For each entity extracted, also determine its entity type based on the provided ENTITY TYPES and their descriptions.
155 | Indicate the classified entity type by providing its entity_type_id.
156 |
157 | Guidelines:
158 | 1. Extract all entities that the JSON represents. This will often be something like a "name" or "user" field
159 | 2. Extract all entities mentioned in all other properties throughout the JSON structure
160 | 3. Do NOT extract any properties that contain dates
161 | """
162 | return [
163 | Message(role='system', content=sys_prompt),
164 | Message(role='user', content=user_prompt),
165 | ]
166 |
167 |
168 | def extract_text(context: dict[str, Any]) -> list[Message]:
169 | sys_prompt = """You are an AI assistant that extracts entity nodes from text.
170 | Your primary task is to extract and classify the speaker and other significant entities mentioned in the provided text."""
171 |
172 | user_prompt = f"""
173 | <ENTITY TYPES>
174 | {context['entity_types']}
175 | </ENTITY TYPES>
176 |
177 | <TEXT>
178 | {context['episode_content']}
179 | </TEXT>
180 |
181 | Given the above text, extract entities from the TEXT that are explicitly or implicitly mentioned.
182 | For each entity extracted, also determine its entity type based on the provided ENTITY TYPES and their descriptions.
183 | Indicate the classified entity type by providing its entity_type_id.
184 |
185 | {context['custom_prompt']}
186 |
187 | Guidelines:
188 | 1. Extract significant entities, concepts, or actors mentioned in the conversation.
189 | 2. Avoid creating nodes for relationships or actions.
190 | 3. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later).
191 | 4. Be as explicit as possible in your node names, using full names and avoiding abbreviations.
192 | """
193 | return [
194 | Message(role='system', content=sys_prompt),
195 | Message(role='user', content=user_prompt),
196 | ]
197 |
198 |
199 | def reflexion(context: dict[str, Any]) -> list[Message]:
200 | sys_prompt = """You are an AI assistant that determines which entities have not been extracted from the given context"""
201 |
202 | user_prompt = f"""
203 | <PREVIOUS MESSAGES>
204 | {to_prompt_json([ep for ep in context['previous_episodes']])}
205 | </PREVIOUS MESSAGES>
206 | <CURRENT MESSAGE>
207 | {context['episode_content']}
208 | </CURRENT MESSAGE>
209 |
210 | <EXTRACTED ENTITIES>
211 | {context['extracted_entities']}
212 | </EXTRACTED ENTITIES>
213 |
214 | Given the above previous messages, current message, and list of extracted entities; determine if any entities haven't been
215 | extracted.
216 | """
217 | return [
218 | Message(role='system', content=sys_prompt),
219 | Message(role='user', content=user_prompt),
220 | ]
221 |
222 |
223 | def classify_nodes(context: dict[str, Any]) -> list[Message]:
224 | sys_prompt = """You are an AI assistant that classifies entity nodes given the context from which they were extracted"""
225 |
226 | user_prompt = f"""
227 | <PREVIOUS MESSAGES>
228 | {to_prompt_json([ep for ep in context['previous_episodes']])}
229 | </PREVIOUS MESSAGES>
230 | <CURRENT MESSAGE>
231 | {context['episode_content']}
232 | </CURRENT MESSAGE>
233 |
234 | <EXTRACTED ENTITIES>
235 | {context['extracted_entities']}
236 | </EXTRACTED ENTITIES>
237 |
238 | <ENTITY TYPES>
239 | {context['entity_types']}
240 | </ENTITY TYPES>
241 |
242 | Given the above conversation, extracted entities, and provided entity types and their descriptions, classify the extracted entities.
243 |
244 | Guidelines:
245 | 1. Each entity must have exactly one type
246 | 2. Only use the provided ENTITY TYPES as types, do not use additional types to classify entities.
247 | 3. If none of the provided entity types accurately classify an extracted node, the type should be set to None
248 | """
249 | return [
250 | Message(role='system', content=sys_prompt),
251 | Message(role='user', content=user_prompt),
252 | ]
253 |
254 |
255 | def extract_attributes(context: dict[str, Any]) -> list[Message]:
256 | return [
257 | Message(
258 | role='system',
259 | content='You are a helpful assistant that extracts entity properties from the provided text.',
260 | ),
261 | Message(
262 | role='user',
263 | content=f"""
264 | Given the MESSAGES and the following ENTITY, update any of its attributes based on the information provided
265 | in MESSAGES. Use the provided attribute descriptions to better understand how each attribute should be determined.
266 |
267 | Guidelines:
268 | 1. Do not hallucinate entity property values if they cannot be found in the current context.
269 | 2. Only use the provided MESSAGES and ENTITY to set attribute values.
270 |
271 | <MESSAGES>
272 | {to_prompt_json(context['previous_episodes'])}
273 | {to_prompt_json(context['episode_content'])}
274 | </MESSAGES>
275 |
276 | <ENTITY>
277 | {context['node']}
278 | </ENTITY>
279 | """,
280 | ),
281 | ]
282 |
283 |
284 | def extract_summary(context: dict[str, Any]) -> list[Message]:
285 | return [
286 | Message(
287 | role='system',
288 | content='You are a helpful assistant that extracts entity summaries from the provided text.',
289 | ),
290 | Message(
291 | role='user',
292 | content=f"""
293 | Given the MESSAGES and the ENTITY, update the summary that combines relevant information about the entity
294 | from the messages and relevant information from the existing summary.
295 |
296 | {summary_instructions}
297 |
298 | <MESSAGES>
299 | {to_prompt_json(context['previous_episodes'])}
300 | {to_prompt_json(context['episode_content'])}
301 | </MESSAGES>
302 |
303 | <ENTITY>
304 | {context['node']}
305 | </ENTITY>
306 | """,
307 | ),
308 | ]
309 |
310 |
311 | versions: Versions = {
312 | 'extract_message': extract_message,
313 | 'extract_json': extract_json,
314 | 'extract_text': extract_text,
315 | 'reflexion': reflexion,
316 | 'extract_summary': extract_summary,
317 | 'classify_nodes': classify_nodes,
318 | 'extract_attributes': extract_attributes,
319 | }
320 |
```
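
A minimal usage sketch for the prompt registry above, assuming this file is the `graphiti_core.prompts.extract_nodes` module: the `context` keys mirror the placeholders referenced in the prompt bodies, but their exact shapes (and the commented LLM call) are illustrative assumptions rather than the library's actual call sites.

```python
from graphiti_core.prompts.extract_nodes import ExtractedEntities, versions

# Context keys follow the placeholders used by extract_message(); the value
# shapes here are assumptions for illustration only.
context = {
    'entity_types': [
        {'entity_type_id': 0, 'entity_type_name': 'Entity', 'entity_type_description': 'Default entity type'},
        {'entity_type_id': 1, 'entity_type_name': 'Person', 'entity_type_description': 'A human speaker or actor'},
    ],
    'previous_episodes': ['Jess: I just moved to Portland.'],
    'episode_content': 'Jess: I adopted a dog named Biscuit last week.',
    'custom_prompt': '',
}

# Each entry in `versions` builds the system/user messages for one extraction mode.
messages = versions['extract_message'](context)
for message in messages:
    print(message.role, message.content[:60])

# The messages would then be passed to an LLM client together with
# ExtractedEntities as the structured output model (hypothetical call shape):
# response = await llm_client.generate_response(messages, response_model=ExtractedEntities)
```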
--------------------------------------------------------------------------------
/graphiti_core/models/edges/edge_db_queries.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from graphiti_core.driver.driver import GraphProvider
18 |
19 | EPISODIC_EDGE_SAVE = """
20 | MATCH (episode:Episodic {uuid: $episode_uuid})
21 | MATCH (node:Entity {uuid: $entity_uuid})
22 | MERGE (episode)-[e:MENTIONS {uuid: $uuid}]->(node)
23 | SET
24 | e.group_id = $group_id,
25 | e.created_at = $created_at
26 | RETURN e.uuid AS uuid
27 | """
28 |
29 |
30 | def get_episodic_edge_save_bulk_query(provider: GraphProvider) -> str:
31 | if provider == GraphProvider.KUZU:
32 | return """
33 | MATCH (episode:Episodic {uuid: $source_node_uuid})
34 | MATCH (node:Entity {uuid: $target_node_uuid})
35 | MERGE (episode)-[e:MENTIONS {uuid: $uuid}]->(node)
36 | SET
37 | e.group_id = $group_id,
38 | e.created_at = $created_at
39 | RETURN e.uuid AS uuid
40 | """
41 |
42 | return """
43 | UNWIND $episodic_edges AS edge
44 | MATCH (episode:Episodic {uuid: edge.source_node_uuid})
45 | MATCH (node:Entity {uuid: edge.target_node_uuid})
46 | MERGE (episode)-[e:MENTIONS {uuid: edge.uuid}]->(node)
47 | SET
48 | e.group_id = edge.group_id,
49 | e.created_at = edge.created_at
50 | RETURN e.uuid AS uuid
51 | """
52 |
53 |
54 | EPISODIC_EDGE_RETURN = """
55 | e.uuid AS uuid,
56 | e.group_id AS group_id,
57 | n.uuid AS source_node_uuid,
58 | m.uuid AS target_node_uuid,
59 | e.created_at AS created_at
60 | """
61 |
62 |
63 | def get_entity_edge_save_query(provider: GraphProvider, has_aoss: bool = False) -> str:
64 | match provider:
65 | case GraphProvider.FALKORDB:
66 | return """
67 | MATCH (source:Entity {uuid: $edge_data.source_uuid})
68 | MATCH (target:Entity {uuid: $edge_data.target_uuid})
69 | MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
70 | SET e = $edge_data
71 | SET e.fact_embedding = vecf32($edge_data.fact_embedding)
72 | RETURN e.uuid AS uuid
73 | """
74 | case GraphProvider.NEPTUNE:
75 | return """
76 | MATCH (source:Entity {uuid: $edge_data.source_uuid})
77 | MATCH (target:Entity {uuid: $edge_data.target_uuid})
78 | MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
79 | SET e = removeKeyFromMap(removeKeyFromMap($edge_data, "fact_embedding"), "episodes")
80 | SET e.fact_embedding = join([x IN coalesce($edge_data.fact_embedding, []) | toString(x) ], ",")
81 | SET e.episodes = join($edge_data.episodes, ",")
82 | RETURN $edge_data.uuid AS uuid
83 | """
84 | case GraphProvider.KUZU:
85 | return """
86 | MATCH (source:Entity {uuid: $source_uuid})
87 | MATCH (target:Entity {uuid: $target_uuid})
88 | MERGE (source)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(target)
89 | SET
90 | e.group_id = $group_id,
91 | e.created_at = $created_at,
92 | e.name = $name,
93 | e.fact = $fact,
94 | e.fact_embedding = $fact_embedding,
95 | e.episodes = $episodes,
96 | e.expired_at = $expired_at,
97 | e.valid_at = $valid_at,
98 | e.invalid_at = $invalid_at,
99 | e.attributes = $attributes
100 | RETURN e.uuid AS uuid
101 | """
102 | case _: # Neo4j
103 | save_embedding_query = (
104 | """WITH e CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", $edge_data.fact_embedding)"""
105 | if not has_aoss
106 | else ''
107 | )
108 | return (
109 | (
110 | """
111 | MATCH (source:Entity {uuid: $edge_data.source_uuid})
112 | MATCH (target:Entity {uuid: $edge_data.target_uuid})
113 | MERGE (source)-[e:RELATES_TO {uuid: $edge_data.uuid}]->(target)
114 | SET e = $edge_data
115 | """
116 | + save_embedding_query
117 | )
118 | + """
119 | RETURN e.uuid AS uuid
120 | """
121 | )
122 |
123 |
124 | def get_entity_edge_save_bulk_query(provider: GraphProvider, has_aoss: bool = False) -> str:
125 | match provider:
126 | case GraphProvider.FALKORDB:
127 | return """
128 | UNWIND $entity_edges AS edge
129 | MATCH (source:Entity {uuid: edge.source_node_uuid})
130 | MATCH (target:Entity {uuid: edge.target_node_uuid})
131 | MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
132 | SET r = edge
133 | SET r.fact_embedding = vecf32(edge.fact_embedding)
134 | WITH r, edge
135 | RETURN edge.uuid AS uuid
136 | """
137 | case GraphProvider.NEPTUNE:
138 | return """
139 | UNWIND $entity_edges AS edge
140 | MATCH (source:Entity {uuid: edge.source_node_uuid})
141 | MATCH (target:Entity {uuid: edge.target_node_uuid})
142 | MERGE (source)-[r:RELATES_TO {uuid: edge.uuid}]->(target)
143 | SET r = removeKeyFromMap(removeKeyFromMap(edge, "fact_embedding"), "episodes")
144 | SET r.fact_embedding = join([x IN coalesce(edge.fact_embedding, []) | toString(x) ], ",")
145 | SET r.episodes = join(edge.episodes, ",")
146 | RETURN edge.uuid AS uuid
147 | """
148 | case GraphProvider.KUZU:
149 | return """
150 | MATCH (source:Entity {uuid: $source_node_uuid})
151 | MATCH (target:Entity {uuid: $target_node_uuid})
152 | MERGE (source)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: $uuid})-[:RELATES_TO]->(target)
153 | SET
154 | e.group_id = $group_id,
155 | e.created_at = $created_at,
156 | e.name = $name,
157 | e.fact = $fact,
158 | e.fact_embedding = $fact_embedding,
159 | e.episodes = $episodes,
160 | e.expired_at = $expired_at,
161 | e.valid_at = $valid_at,
162 | e.invalid_at = $invalid_at,
163 | e.attributes = $attributes
164 | RETURN e.uuid AS uuid
165 | """
166 | case _:
167 | save_embedding_query = (
168 | 'WITH e, edge CALL db.create.setRelationshipVectorProperty(e, "fact_embedding", edge.fact_embedding)'
169 | if not has_aoss
170 | else ''
171 | )
172 | return (
173 | """
174 | UNWIND $entity_edges AS edge
175 | MATCH (source:Entity {uuid: edge.source_node_uuid})
176 | MATCH (target:Entity {uuid: edge.target_node_uuid})
177 | MERGE (source)-[e:RELATES_TO {uuid: edge.uuid}]->(target)
178 | SET e = edge
179 | """
180 | + save_embedding_query
181 | + """
182 | RETURN edge.uuid AS uuid
183 | """
184 | )
185 |
186 |
187 | def get_entity_edge_return_query(provider: GraphProvider) -> str:
188 | # `fact_embedding` is not returned by default and must be manually loaded using `load_fact_embedding()`.
189 |
190 | if provider == GraphProvider.NEPTUNE:
191 | return """
192 | e.uuid AS uuid,
193 | n.uuid AS source_node_uuid,
194 | m.uuid AS target_node_uuid,
195 | e.group_id AS group_id,
196 | e.name AS name,
197 | e.fact AS fact,
198 | split(e.episodes, ',') AS episodes,
199 | e.created_at AS created_at,
200 | e.expired_at AS expired_at,
201 | e.valid_at AS valid_at,
202 | e.invalid_at AS invalid_at,
203 | properties(e) AS attributes
204 | """
205 |
206 | return """
207 | e.uuid AS uuid,
208 | n.uuid AS source_node_uuid,
209 | m.uuid AS target_node_uuid,
210 | e.group_id AS group_id,
211 | e.created_at AS created_at,
212 | e.name AS name,
213 | e.fact AS fact,
214 | e.episodes AS episodes,
215 | e.expired_at AS expired_at,
216 | e.valid_at AS valid_at,
217 | e.invalid_at AS invalid_at,
218 | """ + (
219 | 'e.attributes AS attributes'
220 | if provider == GraphProvider.KUZU
221 | else 'properties(e) AS attributes'
222 | )
223 |
224 |
225 | def get_community_edge_save_query(provider: GraphProvider) -> str:
226 | match provider:
227 | case GraphProvider.FALKORDB:
228 | return """
229 | MATCH (community:Community {uuid: $community_uuid})
230 | MATCH (node {uuid: $entity_uuid})
231 | MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
232 | SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
233 | RETURN e.uuid AS uuid
234 | """
235 | case GraphProvider.NEPTUNE:
236 | return """
237 | MATCH (community:Community {uuid: $community_uuid})
238 | MATCH (node {uuid: $entity_uuid})
239 | WHERE node:Entity OR node:Community
240 | MERGE (community)-[r:HAS_MEMBER {uuid: $uuid}]->(node)
241 | SET r.uuid= $uuid
242 | SET r.group_id= $group_id
243 | SET r.created_at= $created_at
244 | RETURN r.uuid AS uuid
245 | """
246 | case GraphProvider.KUZU:
247 | return """
248 | MATCH (community:Community {uuid: $community_uuid})
249 | MATCH (node:Entity {uuid: $entity_uuid})
250 | MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
251 | SET
252 | e.group_id = $group_id,
253 | e.created_at = $created_at
254 | RETURN e.uuid AS uuid
255 | UNION
256 | MATCH (community:Community {uuid: $community_uuid})
257 | MATCH (node:Community {uuid: $entity_uuid})
258 | MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
259 | SET
260 | e.group_id = $group_id,
261 | e.created_at = $created_at
262 | RETURN e.uuid AS uuid
263 | """
264 | case _: # Neo4j
265 | return """
266 | MATCH (community:Community {uuid: $community_uuid})
267 | MATCH (node:Entity | Community {uuid: $entity_uuid})
268 | MERGE (community)-[e:HAS_MEMBER {uuid: $uuid}]->(node)
269 | SET e = {uuid: $uuid, group_id: $group_id, created_at: $created_at}
270 | RETURN e.uuid AS uuid
271 | """
272 |
273 |
274 | COMMUNITY_EDGE_RETURN = """
275 | e.uuid AS uuid,
276 | e.group_id AS group_id,
277 | n.uuid AS source_node_uuid,
278 | m.uuid AS target_node_uuid,
279 | e.created_at AS created_at
280 | """
281 |
```
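
A short sketch of how one of these provider-specific queries might be selected and executed, assuming a FalkorDB-backed driver; the `edge_data` keys are inferred from the FalkorDB branch of the query text above, and the commented `execute_query` call is an assumed call shape rather than a documented API.

```python
from graphiti_core.driver.driver import GraphProvider
from graphiti_core.models.edges.edge_db_queries import get_entity_edge_save_query

# Pick the save-query variant for the active backend (FalkorDB here).
query = get_entity_edge_save_query(GraphProvider.FALKORDB, has_aoss=False)

# Keys inferred from the query's $edge_data references; values are placeholders.
edge_data = {
    'uuid': 'edge-1',
    'source_uuid': 'entity-a',
    'target_uuid': 'entity-b',
    'name': 'LIVES_IN',
    'fact': 'Jess lives in Portland',
    'fact_embedding': [0.01, 0.02, 0.03],
    'episodes': ['episode-1'],
    'group_id': 'default',
    'created_at': '2024-01-01T00:00:00Z',
    'expired_at': None,
    'valid_at': None,
    'invalid_at': None,
}

# A driver for the matching provider would bind the map as $edge_data
# (hypothetical call shape):
# await driver.execute_query(query, edge_data=edge_data)
```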
--------------------------------------------------------------------------------
/mcp_server/tests/run_tests.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """
3 | Test runner for Graphiti MCP integration tests.
4 | Provides various test execution modes and reporting options.
5 | """
6 |
7 | import argparse
8 | import os
9 | import sys
10 | import time
11 | from pathlib import Path
12 |
13 | import pytest
14 | from dotenv import load_dotenv
15 |
16 | # Load environment variables from .env file
17 | env_file = Path(__file__).parent.parent / '.env'
18 | if env_file.exists():
19 | load_dotenv(env_file)
20 | else:
21 | # Try loading from current directory
22 | load_dotenv()
23 |
24 |
25 | class TestRunner:
26 | """Orchestrate test execution with various configurations."""
27 |
28 | def __init__(self, args):
29 | self.args = args
30 | self.test_dir = Path(__file__).parent
31 | self.results = {}
32 |
33 | def check_prerequisites(self) -> dict[str, bool]:
34 | """Check if required services and dependencies are available."""
35 | checks = {}
36 |
37 | # Check for OpenAI API key if not using mocks
38 | if not self.args.mock_llm:
39 | api_key = os.environ.get('OPENAI_API_KEY')
40 | checks['openai_api_key'] = bool(api_key)
41 | if not api_key:
42 | # Check if .env file exists for helpful message
43 | env_path = Path(__file__).parent.parent / '.env'
44 | if not env_path.exists():
45 | checks['openai_api_key_hint'] = (
46 | 'Set OPENAI_API_KEY in environment or create mcp_server/.env file'
47 | )
48 | else:
49 | checks['openai_api_key'] = True
50 |
51 | # Check database availability based on backend
52 | if self.args.database == 'neo4j':
53 | checks['neo4j'] = self._check_neo4j()
54 | elif self.args.database == 'falkordb':
55 | checks['falkordb'] = self._check_falkordb()
56 |
57 | # Check Python dependencies
58 | checks['mcp'] = self._check_python_package('mcp')
59 | checks['pytest'] = self._check_python_package('pytest')
60 | checks['pytest-asyncio'] = self._check_python_package('pytest-asyncio')
61 |
62 | return checks
63 |
64 | def _check_neo4j(self) -> bool:
65 | """Check if Neo4j is available."""
66 | try:
67 | import neo4j
68 |
69 | # Try to connect
70 | uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
71 | user = os.environ.get('NEO4J_USER', 'neo4j')
72 | password = os.environ.get('NEO4J_PASSWORD', 'graphiti')
73 |
74 | driver = neo4j.GraphDatabase.driver(uri, auth=(user, password))
75 | with driver.session() as session:
76 | session.run('RETURN 1')
77 | driver.close()
78 | return True
79 | except Exception:
80 | return False
81 |
82 | def _check_falkordb(self) -> bool:
83 | """Check if FalkorDB is available."""
84 | try:
85 | import redis
86 |
87 | uri = os.environ.get('FALKORDB_URI', 'redis://localhost:6379')
88 | r = redis.from_url(uri)
89 | r.ping()
90 | return True
91 | except Exception:
92 | return False
93 |
94 | def _check_python_package(self, package: str) -> bool:
95 | """Check if a Python package is installed."""
96 | try:
97 | __import__(package.replace('-', '_'))
98 | return True
99 | except ImportError:
100 | return False
101 |
102 | def run_test_suite(self, suite: str) -> int:
103 | """Run a specific test suite."""
104 | pytest_args = ['-v', '--tb=short']
105 |
106 | # Add database marker
107 | if self.args.database:
108 | for db in ['neo4j', 'falkordb']:
109 | if db != self.args.database:
110 | pytest_args.extend(['-m', f'not requires_{db}'])
111 |
112 | # Add suite-specific arguments
113 | if suite == 'unit':
114 | pytest_args.extend(['-m', 'unit', 'test_*.py'])
115 | elif suite == 'integration':
116 | pytest_args.extend(['-m', 'integration or not unit', 'test_*.py'])
117 | elif suite == 'comprehensive':
118 | pytest_args.append('test_comprehensive_integration.py')
119 | elif suite == 'async':
120 | pytest_args.append('test_async_operations.py')
121 | elif suite == 'stress':
122 | pytest_args.extend(['-m', 'slow', 'test_stress_load.py'])
123 | elif suite == 'smoke':
124 | # Quick smoke test - just basic operations
125 | pytest_args.extend(
126 | [
127 | 'test_comprehensive_integration.py::TestCoreOperations::test_server_initialization',
128 | 'test_comprehensive_integration.py::TestCoreOperations::test_add_text_memory',
129 | ]
130 | )
131 | elif suite == 'all':
132 | pytest_args.append('.')
133 | else:
134 | pytest_args.append(suite)
135 |
136 | # Add coverage if requested
137 | if self.args.coverage:
138 | pytest_args.extend(['--cov=../src', '--cov-report=html'])
139 |
140 | # Add parallel execution if requested
141 | if self.args.parallel:
142 | pytest_args.extend(['-n', str(self.args.parallel)])
143 |
144 | # Add verbosity
145 | if self.args.verbose:
146 | pytest_args.append('-vv')
147 |
148 | # Add markers to skip
149 | if self.args.skip_slow:
150 | pytest_args.extend(['-m', 'not slow'])
151 |
152 | # Add timeout override
153 | if self.args.timeout:
154 | pytest_args.extend(['--timeout', str(self.args.timeout)])
155 |
156 |         # Set environment variables for the in-process pytest run.
157 |         # pytest.main reads os.environ directly, so a detached copy would have no effect.
158 |         if self.args.mock_llm:
159 |             os.environ['USE_MOCK_LLM'] = 'true'
160 |         if self.args.database:
161 |             os.environ['DATABASE_PROVIDER'] = self.args.database
162 |
163 | # Run tests from the test directory
164 | print(f'Running {suite} tests with pytest args: {" ".join(pytest_args)}')
165 |
166 | # Change to test directory to run tests
167 | original_dir = os.getcwd()
168 | os.chdir(self.test_dir)
169 |
170 | try:
171 | result = pytest.main(pytest_args)
172 | finally:
173 | os.chdir(original_dir)
174 |
175 | return result
176 |
177 | def run_performance_benchmark(self):
178 | """Run performance benchmarking suite."""
179 | print('Running performance benchmarks...')
180 |
181 |         # Build pytest args for the performance suites; only pass --benchmark-only
182 |         # when requested (an empty-string argument would confuse pytest).
183 |         pytest_args = [
184 |             '-v',
185 |             'test_comprehensive_integration.py::TestPerformance',
186 |             'test_async_operations.py::TestAsyncPerformance',
187 |         ]
188 |         if self.args.benchmark_only:
189 |             pytest_args.append('--benchmark-only')
190 | 
191 |         result = pytest.main(pytest_args)
192 |
193 | return result
194 |
195 | def generate_report(self):
196 | """Generate test execution report."""
197 | report = []
198 | report.append('\n' + '=' * 60)
199 | report.append('GRAPHITI MCP TEST EXECUTION REPORT')
200 | report.append('=' * 60)
201 |
202 | # Prerequisites check
203 | checks = self.check_prerequisites()
204 | report.append('\nPrerequisites:')
205 | for check, passed in checks.items():
206 | status = '✅' if passed else '❌'
207 | report.append(f' {status} {check}')
208 |
209 | # Test configuration
210 | report.append('\nConfiguration:')
211 | report.append(f' Database: {self.args.database}')
212 | report.append(f' Mock LLM: {self.args.mock_llm}')
213 | report.append(f' Parallel: {self.args.parallel or "No"}')
214 | report.append(f' Timeout: {self.args.timeout}s')
215 |
216 | # Results summary (if available)
217 | if self.results:
218 | report.append('\nResults:')
219 | for suite, result in self.results.items():
220 | status = '✅ Passed' if result == 0 else f'❌ Failed ({result})'
221 | report.append(f' {suite}: {status}')
222 |
223 | report.append('=' * 60)
224 | return '\n'.join(report)
225 |
226 |
227 | def main():
228 | """Main entry point for test runner."""
229 | parser = argparse.ArgumentParser(
230 | description='Run Graphiti MCP integration tests',
231 | formatter_class=argparse.RawDescriptionHelpFormatter,
232 | epilog="""
233 | Test Suites:
234 | unit - Run unit tests only
235 | integration - Run integration tests
236 | comprehensive - Run comprehensive integration test suite
237 | async - Run async operation tests
238 | stress - Run stress and load tests
239 | smoke - Run quick smoke tests
240 | all - Run all tests
241 |
242 | Examples:
243 | python run_tests.py smoke # Quick smoke test
244 | python run_tests.py integration --parallel 4 # Run integration tests in parallel
245 | python run_tests.py stress --database neo4j # Run stress tests with Neo4j
246 | python run_tests.py all --coverage # Run all tests with coverage
247 | """,
248 | )
249 |
250 | parser.add_argument(
251 | 'suite',
252 | choices=['unit', 'integration', 'comprehensive', 'async', 'stress', 'smoke', 'all'],
253 | help='Test suite to run',
254 | )
255 |
256 | parser.add_argument(
257 | '--database',
258 | choices=['neo4j', 'falkordb'],
259 | default='falkordb',
260 | help='Database backend to test (default: falkordb)',
261 | )
262 |
263 | parser.add_argument('--mock-llm', action='store_true', help='Use mock LLM for faster testing')
264 |
265 | parser.add_argument(
266 | '--parallel', type=int, metavar='N', help='Run tests in parallel with N workers'
267 | )
268 |
269 | parser.add_argument('--coverage', action='store_true', help='Generate coverage report')
270 |
271 | parser.add_argument('--verbose', action='store_true', help='Verbose output')
272 |
273 | parser.add_argument('--skip-slow', action='store_true', help='Skip slow tests')
274 |
275 | parser.add_argument(
276 | '--timeout', type=int, default=300, help='Test timeout in seconds (default: 300)'
277 | )
278 |
279 | parser.add_argument('--benchmark-only', action='store_true', help='Run only benchmark tests')
280 |
281 | parser.add_argument(
282 | '--check-only', action='store_true', help='Only check prerequisites without running tests'
283 | )
284 |
285 | args = parser.parse_args()
286 |
287 | # Create test runner
288 | runner = TestRunner(args)
289 |
290 | # Check prerequisites
291 | if args.check_only:
292 | print(runner.generate_report())
293 | sys.exit(0)
294 |
295 | # Check if prerequisites are met
296 | checks = runner.check_prerequisites()
297 | # Filter out hint keys from validation
298 | validation_checks = {k: v for k, v in checks.items() if not k.endswith('_hint')}
299 |
300 | if not all(validation_checks.values()):
301 | print('⚠️ Some prerequisites are not met:')
302 | for check, passed in checks.items():
303 | if check.endswith('_hint'):
304 | continue # Skip hint entries
305 | if not passed:
306 | print(f' ❌ {check}')
307 | # Show hint if available
308 | hint_key = f'{check}_hint'
309 | if hint_key in checks:
310 | print(f' 💡 {checks[hint_key]}')
311 |
312 | if not args.mock_llm and not checks.get('openai_api_key'):
313 | print('\n💡 Tip: Use --mock-llm to run tests without OpenAI API key')
314 |
315 | response = input('\nContinue anyway? (y/N): ')
316 | if response.lower() != 'y':
317 | sys.exit(1)
318 |
319 | # Run tests
320 | print(f'\n🚀 Starting test execution: {args.suite}')
321 | start_time = time.time()
322 |
323 | if args.benchmark_only:
324 | result = runner.run_performance_benchmark()
325 | else:
326 | result = runner.run_test_suite(args.suite)
327 |
328 | duration = time.time() - start_time
329 |
330 | # Store results
331 | runner.results[args.suite] = result
332 |
333 | # Generate and print report
334 | print(runner.generate_report())
335 | print(f'\n⏱️ Test execution completed in {duration:.2f} seconds')
336 |
337 | # Exit with test result code
338 | sys.exit(result)
339 |
340 |
341 | if __name__ == '__main__':
342 | main()
343 |
```
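
The runner is intended for command-line use (see the argparse epilog above), but the same class can be driven programmatically; the sketch below builds an `argparse.Namespace` carrying the fields the runner reads and is illustrative only.

```python
from argparse import Namespace

# Assumes this is executed from mcp_server/tests so that run_tests.py is importable.
from run_tests import TestRunner

# Namespace fields mirror the argparse options defined in main().
args = Namespace(
    suite='smoke',
    database='falkordb',
    mock_llm=True,
    parallel=None,
    coverage=False,
    verbose=False,
    skip_slow=True,
    timeout=300,
    benchmark_only=False,
    check_only=False,
)

runner = TestRunner(args)
print(runner.generate_report())           # prerequisite and configuration summary
exit_code = runner.run_test_suite(args.suite)
print(f'smoke suite exit code: {exit_code}')
```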
--------------------------------------------------------------------------------
/graphiti_core/driver/neptune_driver.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import asyncio
18 | import datetime
19 | import logging
20 | from collections.abc import Coroutine
21 | from typing import Any
22 |
23 | import boto3
24 | from langchain_aws.graphs import NeptuneAnalyticsGraph, NeptuneGraph
25 | from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
26 |
27 | from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
28 |
29 | logger = logging.getLogger(__name__)
30 | DEFAULT_SIZE = 10
31 |
32 | aoss_indices = [
33 | {
34 | 'index_name': 'node_name_and_summary',
35 | 'body': {
36 | 'mappings': {
37 | 'properties': {
38 | 'uuid': {'type': 'keyword'},
39 | 'name': {'type': 'text'},
40 | 'summary': {'type': 'text'},
41 | 'group_id': {'type': 'text'},
42 | }
43 | }
44 | },
45 | 'query': {
46 | 'query': {'multi_match': {'query': '', 'fields': ['name', 'summary', 'group_id']}},
47 | 'size': DEFAULT_SIZE,
48 | },
49 | },
50 | {
51 | 'index_name': 'community_name',
52 | 'body': {
53 | 'mappings': {
54 | 'properties': {
55 | 'uuid': {'type': 'keyword'},
56 | 'name': {'type': 'text'},
57 | 'group_id': {'type': 'text'},
58 | }
59 | }
60 | },
61 | 'query': {
62 | 'query': {'multi_match': {'query': '', 'fields': ['name', 'group_id']}},
63 | 'size': DEFAULT_SIZE,
64 | },
65 | },
66 | {
67 | 'index_name': 'episode_content',
68 | 'body': {
69 | 'mappings': {
70 | 'properties': {
71 | 'uuid': {'type': 'keyword'},
72 | 'content': {'type': 'text'},
73 | 'source': {'type': 'text'},
74 | 'source_description': {'type': 'text'},
75 | 'group_id': {'type': 'text'},
76 | }
77 | }
78 | },
79 | 'query': {
80 | 'query': {
81 | 'multi_match': {
82 | 'query': '',
83 | 'fields': ['content', 'source', 'source_description', 'group_id'],
84 | }
85 | },
86 | 'size': DEFAULT_SIZE,
87 | },
88 | },
89 | {
90 | 'index_name': 'edge_name_and_fact',
91 | 'body': {
92 | 'mappings': {
93 | 'properties': {
94 | 'uuid': {'type': 'keyword'},
95 | 'name': {'type': 'text'},
96 | 'fact': {'type': 'text'},
97 | 'group_id': {'type': 'text'},
98 | }
99 | }
100 | },
101 | 'query': {
102 | 'query': {'multi_match': {'query': '', 'fields': ['name', 'fact', 'group_id']}},
103 | 'size': DEFAULT_SIZE,
104 | },
105 | },
106 | ]
107 |
108 |
109 | class NeptuneDriver(GraphDriver):
110 | provider: GraphProvider = GraphProvider.NEPTUNE
111 |
112 | def __init__(self, host: str, aoss_host: str, port: int = 8182, aoss_port: int = 443):
113 | """This initializes a NeptuneDriver for use with Neptune as a backend
114 |
115 | Args:
116 | host (str): The Neptune Database or Neptune Analytics host
117 | aoss_host (str): The OpenSearch host value
118 | port (int, optional): The Neptune Database port, ignored for Neptune Analytics. Defaults to 8182.
119 | aoss_port (int, optional): The OpenSearch port. Defaults to 443.
120 | """
121 | if not host:
122 | raise ValueError('You must provide an endpoint to create a NeptuneDriver')
123 |
124 | if host.startswith('neptune-db://'):
125 | # This is a Neptune Database Cluster
126 | endpoint = host.replace('neptune-db://', '')
127 | self.client = NeptuneGraph(endpoint, port)
128 | logger.debug('Creating Neptune Database session for %s', host)
129 | elif host.startswith('neptune-graph://'):
130 | # This is a Neptune Analytics Graph
131 | graphId = host.replace('neptune-graph://', '')
132 | self.client = NeptuneAnalyticsGraph(graphId)
133 | logger.debug('Creating Neptune Graph session for %s', host)
134 | else:
135 | raise ValueError(
136 | 'You must provide an endpoint to create a NeptuneDriver as either neptune-db://<endpoint> or neptune-graph://<graphid>'
137 | )
138 |
139 | if not aoss_host:
140 | raise ValueError('You must provide an AOSS endpoint to create an OpenSearch driver.')
141 |
142 | session = boto3.Session()
143 | self.aoss_client = OpenSearch(
144 | hosts=[{'host': aoss_host, 'port': aoss_port}],
145 | http_auth=Urllib3AWSV4SignerAuth(
146 | session.get_credentials(), session.region_name, 'aoss'
147 | ),
148 | use_ssl=True,
149 | verify_certs=True,
150 | connection_class=Urllib3HttpConnection,
151 | pool_maxsize=20,
152 | )
153 |
154 | def _sanitize_parameters(self, query, params: dict):
155 | if isinstance(query, list):
156 | queries = []
157 | for q in query:
158 | queries.append(self._sanitize_parameters(q, params))
159 | return queries
160 | else:
161 | for k, v in params.items():
162 | if isinstance(v, datetime.datetime):
163 | params[k] = v.isoformat()
164 | elif isinstance(v, list):
165 | # Handle lists that might contain datetime objects
166 | for i, item in enumerate(v):
167 | if isinstance(item, datetime.datetime):
168 | v[i] = item.isoformat()
169 | query = str(query).replace(f'${k}', f'datetime(${k})')
170 | if isinstance(item, dict):
171 | query = self._sanitize_parameters(query, v[i])
172 |
173 | # If the list contains datetime objects, we need to wrap each element with datetime()
174 | if any(isinstance(item, str) and 'T' in item for item in v):
175 | # Create a new list expression with datetime() wrapped around each element
176 | datetime_list = (
177 | '['
178 | + ', '.join(
179 | f'datetime("{item}")'
180 | if isinstance(item, str) and 'T' in item
181 | else repr(item)
182 | for item in v
183 | )
184 | + ']'
185 | )
186 | query = str(query).replace(f'${k}', datetime_list)
187 | elif isinstance(v, dict):
188 | query = self._sanitize_parameters(query, v)
189 | return query
190 |
191 | async def execute_query(
192 | self, cypher_query_, **kwargs: Any
193 | ) -> tuple[dict[str, Any], None, None]:
194 | params = dict(kwargs)
195 | if isinstance(cypher_query_, list):
196 | for q in cypher_query_:
197 | result, _, _ = self._run_query(q[0], q[1])
198 | return result, None, None
199 | else:
200 | return self._run_query(cypher_query_, params)
201 |
202 | def _run_query(self, cypher_query_, params):
203 | cypher_query_ = str(self._sanitize_parameters(cypher_query_, params))
204 | try:
205 | result = self.client.query(cypher_query_, params=params)
206 | except Exception as e:
207 | logger.error('Query: %s', cypher_query_)
208 | logger.error('Parameters: %s', params)
209 | logger.error('Error executing query: %s', e)
210 | raise e
211 |
212 | return result, None, None
213 |
214 | def session(self, database: str | None = None) -> GraphDriverSession:
215 | return NeptuneDriverSession(driver=self)
216 |
217 | async def close(self) -> None:
218 | return self.client.client.close()
219 |
220 | async def _delete_all_data(self) -> Any:
221 | return await self.execute_query('MATCH (n) DETACH DELETE n')
222 |
223 | def delete_all_indexes(self) -> Coroutine[Any, Any, Any]:
224 | return self.delete_all_indexes_impl()
225 |
226 |     async def delete_all_indexes_impl(self) -> Any:
227 |         # Neptune keeps no graph-side indexes; clearing the AOSS indices is sufficient.
228 |         return await self.delete_aoss_indices()
229 |
230 | async def create_aoss_indices(self):
231 | for index in aoss_indices:
232 | index_name = index['index_name']
233 | client = self.aoss_client
234 | if not client.indices.exists(index=index_name):
235 | client.indices.create(index=index_name, body=index['body'])
236 | # Sleep for 1 minute to let the index creation complete
237 | await asyncio.sleep(60)
238 |
239 | async def delete_aoss_indices(self):
240 | for index in aoss_indices:
241 | index_name = index['index_name']
242 | client = self.aoss_client
243 | if client.indices.exists(index=index_name):
244 | client.indices.delete(index=index_name)
245 |
246 | async def build_indices_and_constraints(self, delete_existing: bool = False):
247 | # Neptune uses OpenSearch (AOSS) for indexing
248 | if delete_existing:
249 | await self.delete_aoss_indices()
250 | await self.create_aoss_indices()
251 |
252 | def run_aoss_query(self, name: str, query_text: str, limit: int = 10) -> dict[str, Any]:
253 | for index in aoss_indices:
254 | if name.lower() == index['index_name']:
255 | index['query']['query']['multi_match']['query'] = query_text
256 | query = {'size': limit, 'query': index['query']}
257 | resp = self.aoss_client.search(body=query['query'], index=index['index_name'])
258 | return resp
259 | return {}
260 |
261 | def save_to_aoss(self, name: str, data: list[dict]) -> int:
262 | for index in aoss_indices:
263 | if name.lower() == index['index_name']:
264 | to_index = []
265 | for d in data:
266 | item = {'_index': name, '_id': d['uuid']}
267 | for p in index['body']['mappings']['properties']:
268 | if p in d:
269 | item[p] = d[p]
270 | to_index.append(item)
271 | success, failed = helpers.bulk(self.aoss_client, to_index, stats_only=True)
272 | return success
273 |
274 | return 0
275 |
276 |
277 | class NeptuneDriverSession(GraphDriverSession):
278 | provider = GraphProvider.NEPTUNE
279 |
280 | def __init__(self, driver: NeptuneDriver): # type: ignore[reportUnknownArgumentType]
281 | self.driver = driver
282 |
283 | async def __aenter__(self):
284 | return self
285 |
286 | async def __aexit__(self, exc_type, exc, tb):
287 | # No cleanup needed for Neptune, but method must exist
288 | pass
289 |
290 | async def close(self):
291 | # No explicit close needed for Neptune, but method must exist
292 | pass
293 |
294 | async def execute_write(self, func, *args, **kwargs):
295 | # Directly await the provided async function with `self` as the transaction/session
296 | return await func(self, *args, **kwargs)
297 |
298 | async def run(self, query: str | list, **kwargs: Any) -> Any:
299 | if isinstance(query, list):
300 | res = None
301 | for q in query:
302 | res = await self.driver.execute_query(q, **kwargs)
303 | return res
304 | else:
305 | return await self.driver.execute_query(str(query), **kwargs)
306 |
```
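
An illustrative construction of the driver, following the `neptune-db://` / `neptune-graph://` endpoint formats documented in `__init__`; the host names are placeholders, and AWS credentials are expected to come from the standard boto3 environment or configuration.

```python
import asyncio

from graphiti_core.driver.neptune_driver import NeptuneDriver


async def main() -> None:
    # Placeholder endpoints; a Neptune Analytics graph would use
    # host='neptune-graph://<graph-id>' instead.
    driver = NeptuneDriver(
        host='neptune-db://my-cluster.cluster-abc123.us-east-1.neptune.amazonaws.com',
        aoss_host='abc123xyz.us-east-1.aoss.amazonaws.com',
    )

    # Create the AOSS indices used for full-text search, run a trivial query,
    # then close the underlying Neptune client.
    await driver.build_indices_and_constraints()
    records, _, _ = await driver.execute_query('MATCH (n) RETURN count(n) AS c')
    print(records)
    await driver.close()


if __name__ == '__main__':
    # Requires live Neptune and OpenSearch Serverless endpoints plus AWS credentials.
    asyncio.run(main())
```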