# getzep/graphiti
This is page 4 of 12. Use http://codebase.md/getzep/graphiti?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .env.example
├── .github
│   ├── dependabot.yml
│   ├── ISSUE_TEMPLATE
│   │   └── bug_report.md
│   ├── pull_request_template.md
│   ├── secret_scanning.yml
│   └── workflows
│       ├── ai-moderator.yml
│       ├── cla.yml
│       ├── claude-code-review-manual.yml
│       ├── claude-code-review.yml
│       ├── claude.yml
│       ├── codeql.yml
│       ├── daily_issue_maintenance.yml
│       ├── issue-triage.yml
│       ├── lint.yml
│       ├── release-graphiti-core.yml
│       ├── release-mcp-server.yml
│       ├── release-server-container.yml
│       ├── typecheck.yml
│       └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│   ├── azure-openai
│   │   ├── .env.example
│   │   ├── azure_openai_neo4j.py
│   │   └── README.md
│   ├── data
│   │   └── manybirds_products.json
│   ├── ecommerce
│   │   ├── runner.ipynb
│   │   └── runner.py
│   ├── langgraph-agent
│   │   ├── agent.ipynb
│   │   └── tinybirds-jess.png
│   ├── opentelemetry
│   │   ├── .env.example
│   │   ├── otel_stdout_example.py
│   │   ├── pyproject.toml
│   │   ├── README.md
│   │   └── uv.lock
│   ├── podcast
│   │   ├── podcast_runner.py
│   │   ├── podcast_transcript.txt
│   │   └── transcript_parser.py
│   ├── quickstart
│   │   ├── quickstart_falkordb.py
│   │   ├── quickstart_neo4j.py
│   │   ├── quickstart_neptune.py
│   │   ├── README.md
│   │   └── requirements.txt
│   └── wizard_of_oz
│       ├── parser.py
│       ├── runner.py
│       └── woo.txt
├── graphiti_core
│   ├── __init__.py
│   ├── cross_encoder
│   │   ├── __init__.py
│   │   ├── bge_reranker_client.py
│   │   ├── client.py
│   │   ├── gemini_reranker_client.py
│   │   └── openai_reranker_client.py
│   ├── decorators.py
│   ├── driver
│   │   ├── __init__.py
│   │   ├── driver.py
│   │   ├── falkordb_driver.py
│   │   ├── graph_operations
│   │   │   └── graph_operations.py
│   │   ├── kuzu_driver.py
│   │   ├── neo4j_driver.py
│   │   ├── neptune_driver.py
│   │   └── search_interface
│   │       └── search_interface.py
│   ├── edges.py
│   ├── embedder
│   │   ├── __init__.py
│   │   ├── azure_openai.py
│   │   ├── client.py
│   │   ├── gemini.py
│   │   ├── openai.py
│   │   └── voyage.py
│   ├── errors.py
│   ├── graph_queries.py
│   ├── graphiti_types.py
│   ├── graphiti.py
│   ├── helpers.py
│   ├── llm_client
│   │   ├── __init__.py
│   │   ├── anthropic_client.py
│   │   ├── azure_openai_client.py
│   │   ├── client.py
│   │   ├── config.py
│   │   ├── errors.py
│   │   ├── gemini_client.py
│   │   ├── groq_client.py
│   │   ├── openai_base_client.py
│   │   ├── openai_client.py
│   │   ├── openai_generic_client.py
│   │   └── utils.py
│   ├── migrations
│   │   └── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── edges
│   │   │   ├── __init__.py
│   │   │   └── edge_db_queries.py
│   │   └── nodes
│   │       ├── __init__.py
│   │       └── node_db_queries.py
│   ├── nodes.py
│   ├── prompts
│   │   ├── __init__.py
│   │   ├── dedupe_edges.py
│   │   ├── dedupe_nodes.py
│   │   ├── eval.py
│   │   ├── extract_edge_dates.py
│   │   ├── extract_edges.py
│   │   ├── extract_nodes.py
│   │   ├── invalidate_edges.py
│   │   ├── lib.py
│   │   ├── models.py
│   │   ├── prompt_helpers.py
│   │   ├── snippets.py
│   │   └── summarize_nodes.py
│   ├── py.typed
│   ├── search
│   │   ├── __init__.py
│   │   ├── search_config_recipes.py
│   │   ├── search_config.py
│   │   ├── search_filters.py
│   │   ├── search_helpers.py
│   │   ├── search_utils.py
│   │   └── search.py
│   ├── telemetry
│   │   ├── __init__.py
│   │   └── telemetry.py
│   ├── tracer.py
│   └── utils
│       ├── __init__.py
│       ├── bulk_utils.py
│       ├── datetime_utils.py
│       ├── maintenance
│       │   ├── __init__.py
│       │   ├── community_operations.py
│       │   ├── dedup_helpers.py
│       │   ├── edge_operations.py
│       │   ├── graph_data_operations.py
│       │   ├── node_operations.py
│       │   └── temporal_operations.py
│       ├── ontology_utils
│       │   └── entity_types_utils.py
│       └── text_utils.py
├── images
│   ├── arxiv-screenshot.png
│   ├── graphiti-graph-intro.gif
│   ├── graphiti-intro-slides-stock-2.gif
│   └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│   ├── .env.example
│   ├── .python-version
│   ├── config
│   │   ├── config-docker-falkordb-combined.yaml
│   │   ├── config-docker-falkordb.yaml
│   │   ├── config-docker-neo4j.yaml
│   │   ├── config.yaml
│   │   └── mcp_config_stdio_example.json
│   ├── docker
│   │   ├── build-standalone.sh
│   │   ├── build-with-version.sh
│   │   ├── docker-compose-falkordb.yml
│   │   ├── docker-compose-neo4j.yml
│   │   ├── docker-compose.yml
│   │   ├── Dockerfile
│   │   ├── Dockerfile.standalone
│   │   ├── github-actions-example.yml
│   │   ├── README-falkordb-combined.md
│   │   └── README.md
│   ├── docs
│   │   └── cursor_rules.md
│   ├── main.py
│   ├── pyproject.toml
│   ├── pytest.ini
│   ├── README.md
│   ├── src
│   │   ├── __init__.py
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   └── schema.py
│   │   ├── graphiti_mcp_server.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── entity_types.py
│   │   │   └── response_types.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   ├── factories.py
│   │   │   └── queue_service.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── formatting.py
│   │       └── utils.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── conftest.py
│   │   ├── pytest.ini
│   │   ├── README.md
│   │   ├── run_tests.py
│   │   ├── test_async_operations.py
│   │   ├── test_comprehensive_integration.py
│   │   ├── test_configuration.py
│   │   ├── test_falkordb_integration.py
│   │   ├── test_fixtures.py
│   │   ├── test_http_integration.py
│   │   ├── test_integration.py
│   │   ├── test_mcp_integration.py
│   │   ├── test_mcp_transports.py
│   │   ├── test_stdio_simple.py
│   │   └── test_stress_load.py
│   └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│   ├── .env.example
│   ├── graph_service
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── dto
│   │   │   ├── __init__.py
│   │   │   ├── common.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   ├── main.py
│   │   ├── routers
│   │   │   ├── __init__.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   └── zep_graphiti.py
│   ├── Makefile
│   ├── pyproject.toml
│   ├── README.md
│   └── uv.lock
├── signatures
│   └── version1
│       └── cla.json
├── tests
│   ├── cross_encoder
│   │   ├── test_bge_reranker_client_int.py
│   │   └── test_gemini_reranker_client.py
│   ├── driver
│   │   ├── __init__.py
│   │   └── test_falkordb_driver.py
│   ├── embedder
│   │   ├── embedder_fixtures.py
│   │   ├── test_gemini.py
│   │   ├── test_openai.py
│   │   └── test_voyage.py
│   ├── evals
│   │   ├── data
│   │   │   └── longmemeval_data
│   │   │       ├── longmemeval_oracle.json
│   │   │       └── README.md
│   │   ├── eval_cli.py
│   │   ├── eval_e2e_graph_building.py
│   │   ├── pytest.ini
│   │   └── utils.py
│   ├── helpers_test.py
│   ├── llm_client
│   │   ├── test_anthropic_client_int.py
│   │   ├── test_anthropic_client.py
│   │   ├── test_azure_openai_client.py
│   │   ├── test_client.py
│   │   ├── test_errors.py
│   │   └── test_gemini_client.py
│   ├── test_edge_int.py
│   ├── test_entity_exclusion_int.py
│   ├── test_graphiti_int.py
│   ├── test_graphiti_mock.py
│   ├── test_node_int.py
│   ├── test_text_utils.py
│   └── utils
│       ├── maintenance
│       │   ├── test_bulk_utils.py
│       │   ├── test_edge_operations.py
│       │   ├── test_node_operations.py
│       │   └── test_temporal_operations_int.py
│       └── search
│           └── search_utils_test.py
├── uv.lock
└── Zep-CLA.md
```

# Files

--------------------------------------------------------------------------------
/graphiti_core/prompts/extract_edges.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | from typing import Any, Protocol, TypedDict
 18 | 
 19 | from pydantic import BaseModel, Field
 20 | 
 21 | from .models import Message, PromptFunction, PromptVersion
 22 | from .prompt_helpers import to_prompt_json
 23 | 
 24 | 
 25 | class Edge(BaseModel):
 26 |     relation_type: str = Field(..., description='FACT_PREDICATE_IN_SCREAMING_SNAKE_CASE')
 27 |     source_entity_id: int = Field(
 28 |         ..., description='The id of the source entity from the ENTITIES list'
 29 |     )
 30 |     target_entity_id: int = Field(
 31 |         ..., description='The id of the target entity from the ENTITIES list'
 32 |     )
 33 |     fact: str = Field(
 34 |         ...,
 35 |         description='A natural language description of the relationship between the entities, paraphrased from the source text',
 36 |     )
 37 |     valid_at: str | None = Field(
 38 |         None,
 39 |         description='The date and time when the relationship described by the edge fact became true or was established. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SS.SSSSSSZ)',
 40 |     )
 41 |     invalid_at: str | None = Field(
 42 |         None,
 43 |         description='The date and time when the relationship described by the edge fact stopped being true or ended. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SS.SSSSSSZ)',
 44 |     )
 45 | 
 46 | 
 47 | class ExtractedEdges(BaseModel):
 48 |     edges: list[Edge]
 49 | 
 50 | 
 51 | class MissingFacts(BaseModel):
 52 |     missing_facts: list[str] = Field(..., description="facts that weren't extracted")
 53 | 
 54 | 
 55 | class Prompt(Protocol):
 56 |     edge: PromptVersion
 57 |     reflexion: PromptVersion
 58 |     extract_attributes: PromptVersion
 59 | 
 60 | 
 61 | class Versions(TypedDict):
 62 |     edge: PromptFunction
 63 |     reflexion: PromptFunction
 64 |     extract_attributes: PromptFunction
 65 | 
 66 | 
 67 | def edge(context: dict[str, Any]) -> list[Message]:
 68 |     return [
 69 |         Message(
 70 |             role='system',
 71 |             content='You are an expert fact extractor that extracts fact triples from text. '
 72 |             '1. Extracted fact triples should also be extracted with relevant date information. '
 73 |             '2. Treat the CURRENT TIME as the time the CURRENT MESSAGE was sent. All temporal information should be extracted relative to this time.',
 74 |         ),
 75 |         Message(
 76 |             role='user',
 77 |             content=f"""
 78 | <FACT TYPES>
 79 | {context['edge_types']}
 80 | </FACT TYPES>
 81 | 
 82 | <PREVIOUS_MESSAGES>
 83 | {to_prompt_json([ep for ep in context['previous_episodes']])}
 84 | </PREVIOUS_MESSAGES>
 85 | 
 86 | <CURRENT_MESSAGE>
 87 | {context['episode_content']}
 88 | </CURRENT_MESSAGE>
 89 | 
 90 | <ENTITIES>
 91 | {to_prompt_json(context['nodes'])}
 92 | </ENTITIES>
 93 | 
 94 | <REFERENCE_TIME>
 95 | {context['reference_time']}  # ISO 8601 (UTC); used to resolve relative time mentions
 96 | </REFERENCE_TIME>
 97 | 
 98 | # TASK
 99 | Extract all factual relationships between the given ENTITIES based on the CURRENT MESSAGE.
100 | Only extract facts that:
101 | - involve two DISTINCT ENTITIES from the ENTITIES list,
102 | - are clearly stated or unambiguously implied in the CURRENT MESSAGE,
103 |     and can be represented as edges in a knowledge graph.
104 | - Facts should include entity names rather than pronouns whenever possible.
105 | - The FACT TYPES provide a list of the most important types of facts; make sure to extract facts of these types
106 | - The FACT TYPES are not an exhaustive list; extract all facts from the message even if they do not fit into one
107 |     of the FACT TYPES
108 | - The FACT TYPES each contain their fact_type_signature which represents the source and target entity types.
109 | 
110 | You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.
111 | 
112 | 
113 | {context['custom_prompt']}
114 | 
115 | # EXTRACTION RULES
116 | 
117 | 1. **Entity ID Validation**: `source_entity_id` and `target_entity_id` must use only the `id` values from the ENTITIES list provided above.
118 |    - **CRITICAL**: Using IDs not in the list will cause the edge to be rejected
119 | 2. Each fact must involve two **distinct** entities.
120 | 3. Use a SCREAMING_SNAKE_CASE string as the `relation_type` (e.g., FOUNDED, WORKS_AT).
121 | 4. Do not emit duplicate or semantically redundant facts.
122 | 5. The `fact` should closely paraphrase the original source sentence(s). Do not verbatim quote the original text.
123 | 6. Use `REFERENCE_TIME` to resolve vague or relative temporal expressions (e.g., "last week").
124 | 7. Do **not** hallucinate or infer temporal bounds from unrelated events.
125 | 
126 | # DATETIME RULES
127 | 
128 | - Use ISO 8601 with “Z” suffix (UTC) (e.g., 2025-04-30T00:00:00Z).
129 | - If the fact is ongoing (present tense), set `valid_at` to REFERENCE_TIME.
130 | - If a change/termination is expressed, set `invalid_at` to the relevant timestamp.
131 | - Leave both fields `null` if no explicit or resolvable time is stated.
132 | - If only a date is mentioned (no time), assume 00:00:00.
133 | - If only a year is mentioned, use January 1st at 00:00:00.
134 |         """,
135 |         ),
136 |     ]
137 | 
138 | 
139 | def reflexion(context: dict[str, Any]) -> list[Message]:
140 |     sys_prompt = """You are an AI assistant that determines which facts have not been extracted from the given context"""
141 | 
142 |     user_prompt = f"""
143 | <PREVIOUS MESSAGES>
144 | {to_prompt_json([ep for ep in context['previous_episodes']])}
145 | </PREVIOUS MESSAGES>
146 | <CURRENT MESSAGE>
147 | {context['episode_content']}
148 | </CURRENT MESSAGE>
149 | 
150 | <EXTRACTED ENTITIES>
151 | {context['nodes']}
152 | </EXTRACTED ENTITIES>
153 | 
154 | <EXTRACTED FACTS>
155 | {context['extracted_facts']}
156 | </EXTRACTED FACTS>
157 | 
158 | Given the above MESSAGES, the list of EXTRACTED ENTITIES, and the list of EXTRACTED FACTS,
159 | determine if any facts haven't been extracted.
160 | """
161 |     return [
162 |         Message(role='system', content=sys_prompt),
163 |         Message(role='user', content=user_prompt),
164 |     ]
165 | 
166 | 
167 | def extract_attributes(context: dict[str, Any]) -> list[Message]:
168 |     return [
169 |         Message(
170 |             role='system',
171 |             content='You are a helpful assistant that extracts fact properties from the provided text.',
172 |         ),
173 |         Message(
174 |             role='user',
175 |             content=f"""
176 | 
177 |         <MESSAGE>
178 |         {to_prompt_json(context['episode_content'])}
179 |         </MESSAGE>
180 |         <REFERENCE TIME>
181 |         {context['reference_time']}
182 |         </REFERENCE TIME>
183 | 
184 |         Given the above MESSAGE, its REFERENCE TIME, and the following FACT, update any of its attributes based on the information provided
185 |         in MESSAGE. Use the provided attribute descriptions to better understand how each attribute should be determined.
186 | 
187 |         Guidelines:
188 |         1. Do not hallucinate entity property values if they cannot be found in the current context.
189 |         2. Only use the provided MESSAGES and FACT to set attribute values.
190 | 
191 |         <FACT>
192 |         {context['fact']}
193 |         </FACT>
194 |         """,
195 |         ),
196 |     ]
197 | 
198 | 
199 | versions: Versions = {
200 |     'edge': edge,
201 |     'reflexion': reflexion,
202 |     'extract_attributes': extract_attributes,
203 | }
204 | 
```
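
The `versions` registry is this module's public surface: each entry is a `PromptFunction` that maps a context dict to chat messages. A minimal rendering sketch (the context values below are illustrative, not taken from the repo):

```python
from graphiti_core.prompts.extract_edges import versions

# Illustrative context: the keys mirror those read by the edge() f-string above.
context = {
    'edge_types': [{'fact_type_name': 'WORKS_AT', 'fact_type_signature': ['Person', 'Company']}],
    'previous_episodes': [],
    'episode_content': 'Alice mentioned that she joined Acme Corp last month.',
    'nodes': [{'id': 0, 'name': 'Alice'}, {'id': 1, 'name': 'Acme Corp'}],
    'reference_time': '2024-08-15T12:00:00Z',
    'custom_prompt': '',
}

messages = versions['edge'](context)
# messages[0] is the system Message; messages[1] holds the rendered user prompt.
# A structured-output call would then parse the LLM response into ExtractedEdges,
# e.g. {"edges": [{"relation_type": "WORKS_AT", "source_entity_id": 0,
#                  "target_entity_id": 1, "fact": "Alice joined Acme Corp last month",
#                  "valid_at": "2024-07-15T00:00:00Z", "invalid_at": null}]}
```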

--------------------------------------------------------------------------------
/graphiti_core/embedder/gemini.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import logging
 18 | from collections.abc import Iterable
 19 | from typing import TYPE_CHECKING
 20 | 
 21 | if TYPE_CHECKING:
 22 |     from google import genai
 23 |     from google.genai import types
 24 | else:
 25 |     try:
 26 |         from google import genai
 27 |         from google.genai import types
 28 |     except ImportError:
 29 |         raise ImportError(
 30 |             'google-genai is required for GeminiEmbedder. '
 31 |             'Install it with: pip install graphiti-core[google-genai]'
 32 |         ) from None
 33 | 
 34 | from pydantic import Field
 35 | 
 36 | from .client import EmbedderClient, EmbedderConfig
 37 | 
 38 | logger = logging.getLogger(__name__)
 39 | 
 40 | DEFAULT_EMBEDDING_MODEL = 'text-embedding-001'  # gemini-embedding-001 or text-embedding-005
 41 | 
 42 | DEFAULT_BATCH_SIZE = 100
 43 | 
 44 | 
 45 | class GeminiEmbedderConfig(EmbedderConfig):
 46 |     embedding_model: str = Field(default=DEFAULT_EMBEDDING_MODEL)
 47 |     api_key: str | None = None
 48 | 
 49 | 
 50 | class GeminiEmbedder(EmbedderClient):
 51 |     """
 52 |     Google Gemini Embedder Client
 53 |     """
 54 | 
 55 |     def __init__(
 56 |         self,
 57 |         config: GeminiEmbedderConfig | None = None,
 58 |         client: 'genai.Client | None' = None,
 59 |         batch_size: int | None = None,
 60 |     ):
 61 |         """
 62 |         Initialize the GeminiEmbedder with the provided configuration and client.
 63 | 
 64 |         Args:
 65 |             config (GeminiEmbedderConfig | None): The configuration for the GeminiEmbedder, including the API key, embedding model, and embedding dimensions.
 66 |             client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
 67 |             batch_size (int | None): An optional batch size to use. If not provided, the default batch size will be used.
 68 |         """
 69 |         if config is None:
 70 |             config = GeminiEmbedderConfig()
 71 | 
 72 |         self.config = config
 73 | 
 74 |         if client is None:
 75 |             self.client = genai.Client(api_key=config.api_key)
 76 |         else:
 77 |             self.client = client
 78 | 
 79 |         if batch_size is None and self.config.embedding_model == 'gemini-embedding-001':
 80 |             # Gemini API has a limit on the number of instances per request
 81 |             # https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api
 82 |             self.batch_size = 1
 83 |         elif batch_size is None:
 84 |             self.batch_size = DEFAULT_BATCH_SIZE
 85 |         else:
 86 |             self.batch_size = batch_size
 87 | 
 88 |     async def create(
 89 |         self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
 90 |     ) -> list[float]:
 91 |         """
 92 |         Create embeddings for the given input data using Google's Gemini embedding model.
 93 | 
 94 |         Args:
 95 |             input_data: The input data to create embeddings for. Can be a string, list of strings,
 96 |                        or an iterable of integers or iterables of integers.
 97 | 
 98 |         Returns:
 99 |             A list of floats representing the embedding vector.
100 |         """
101 |         # Generate embeddings
102 |         result = await self.client.aio.models.embed_content(
103 |             model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
104 |             contents=[input_data],  # type: ignore[arg-type]  # mypy fails on broad union type
105 |             config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
106 |         )
107 | 
108 |         if not result.embeddings or len(result.embeddings) == 0 or not result.embeddings[0].values:
109 |             raise ValueError('No embeddings returned from Gemini API in create()')
110 | 
111 |         return result.embeddings[0].values
112 | 
113 |     async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
114 |         """
115 |         Create embeddings for a batch of input data using Google's Gemini embedding model.
116 | 
117 |         This method handles batching to respect the Gemini API's limits on the number
118 |         of instances that can be processed in a single request.
119 | 
120 |         Args:
121 |             input_data_list: A list of strings to create embeddings for.
122 | 
123 |         Returns:
124 |             A list of embedding vectors (each vector is a list of floats).
125 |         """
126 |         if not input_data_list:
127 |             return []
128 | 
129 |         batch_size = self.batch_size
130 |         all_embeddings = []
131 | 
132 |         # Process inputs in batches
133 |         for i in range(0, len(input_data_list), batch_size):
134 |             batch = input_data_list[i : i + batch_size]
135 | 
136 |             try:
137 |                 # Generate embeddings for this batch
138 |                 result = await self.client.aio.models.embed_content(
139 |                     model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
140 |                     contents=batch,  # type: ignore[arg-type]  # mypy fails on broad union type
141 |                     config=types.EmbedContentConfig(
142 |                         output_dimensionality=self.config.embedding_dim
143 |                     ),
144 |                 )
145 | 
146 |                 if not result.embeddings or len(result.embeddings) == 0:
147 |                     raise Exception('No embeddings returned')
148 | 
149 |                 # Process embeddings from this batch
150 |                 for embedding in result.embeddings:
151 |                     if not embedding.values:
152 |                         raise ValueError('Empty embedding values returned')
153 |                     all_embeddings.append(embedding.values)
154 | 
155 |             except Exception as e:
156 |                 # If batch processing fails, fall back to individual processing
157 |                 logger.warning(
158 |                     f'Batch embedding failed for batch {i // batch_size + 1}, falling back to individual processing: {e}'
159 |                 )
160 | 
161 |                 for item in batch:
162 |                     try:
163 |                         # Process each item individually
164 |                         result = await self.client.aio.models.embed_content(
165 |                             model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
166 |                             contents=[item],  # type: ignore[arg-type]  # mypy fails on broad union type
167 |                             config=types.EmbedContentConfig(
168 |                                 output_dimensionality=self.config.embedding_dim
169 |                             ),
170 |                         )
171 | 
172 |                         if not result.embeddings or len(result.embeddings) == 0:
173 |                             raise ValueError('No embeddings returned from Gemini API')
174 |                         if not result.embeddings[0].values:
175 |                             raise ValueError('Empty embedding values returned')
176 | 
177 |                         all_embeddings.append(result.embeddings[0].values)
178 | 
179 |                     except Exception as individual_error:
180 |                         logger.error(f'Failed to embed individual item: {individual_error}')
181 |                         raise individual_error
182 | 
183 |         return all_embeddings
184 | 
```
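
A usage sketch for the embedder above, assuming the `google-genai` extra is installed; the API key is a placeholder. Note how `__init__` forces `batch_size=1` when the model is `gemini-embedding-001`:

```python
import asyncio

from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig


async def main() -> None:
    embedder = GeminiEmbedder(
        config=GeminiEmbedderConfig(
            api_key='your_google_api_key',  # placeholder, not a real key
            embedding_model='gemini-embedding-001',
        )
    )
    # create_batch() falls back to per-item requests if a batch call fails
    vectors = await embedder.create_batch(['graph databases', 'temporal knowledge graphs'])
    print(len(vectors), len(vectors[0]))


asyncio.run(main())
```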

--------------------------------------------------------------------------------
/mcp_server/docker/README-falkordb-combined.md:
--------------------------------------------------------------------------------

```markdown
  1 | # FalkorDB + Graphiti MCP Server Combined Image
  2 | 
  3 | This Docker setup bundles FalkorDB (graph database) and the Graphiti MCP Server into a single container image for simplified deployment.
  4 | 
  5 | ## Overview
  6 | 
  7 | The combined image extends the official FalkorDB Docker image to include:
  8 | - **FalkorDB**: Redis-based graph database running on port 6379
  9 | - **FalkorDB Web UI**: Graph visualization interface on port 3000
 10 | - **Graphiti MCP Server**: Knowledge graph API on port 8000
 11 | 
 12 | These services are managed by a single startup script that launches FalkorDB as a daemon, optionally starts the web UI, and runs the MCP server in the foreground.
 13 | 
 14 | ## Quick Start
 15 | 
 16 | ### Using Docker Compose (Recommended)
 17 | 
 18 | 1. Create a `.env` file in the `mcp_server` directory:
 19 | 
 20 | ```bash
 21 | # Required
 22 | OPENAI_API_KEY=your_openai_api_key
 23 | 
 24 | # Optional
 25 | GRAPHITI_GROUP_ID=main
 26 | SEMAPHORE_LIMIT=10
 27 | FALKORDB_PASSWORD=
 28 | ```
 29 | 
 30 | 2. Start the combined service:
 31 | 
 32 | ```bash
 33 | cd mcp_server
 34 | docker compose -f docker/docker-compose-falkordb-combined.yml up
 35 | ```
 36 | 
 37 | 3. Access the services:
 38 |    - MCP Server: http://localhost:8000/mcp/
 39 |    - FalkorDB Web UI: http://localhost:3000
 40 |    - FalkorDB (Redis): localhost:6379
 41 | 
 42 | ### Using Docker Run
 43 | 
 44 | ```bash
 45 | docker run -d \
 46 |   -p 6379:6379 \
 47 |   -p 3000:3000 \
 48 |   -p 8000:8000 \
 49 |   -e OPENAI_API_KEY=your_key \
 50 |   -e GRAPHITI_GROUP_ID=main \
 51 |   -v falkordb_data:/var/lib/falkordb/data \
 52 |   zepai/graphiti-falkordb:latest
 53 | ```
 54 | 
 55 | ## Building the Image
 56 | 
 57 | ### Build with Default Version
 58 | 
 59 | ```bash
 60 | docker compose -f docker/docker-compose-falkordb-combined.yml build
 61 | ```
 62 | 
 63 | ### Build with Specific Graphiti Version
 64 | 
 65 | ```bash
 66 | GRAPHITI_CORE_VERSION=0.22.0 docker compose -f docker/docker-compose-falkordb-combined.yml build
 67 | ```
 68 | 
 69 | ### Build Arguments
 70 | 
 71 | - `GRAPHITI_CORE_VERSION`: Version of graphiti-core package (default: 0.22.0)
 72 | - `MCP_SERVER_VERSION`: MCP server version tag (default: 1.0.0rc0)
 73 | - `BUILD_DATE`: Build timestamp
 74 | - `VCS_REF`: Git commit hash
 75 | 
 76 | ## Configuration
 77 | 
 78 | ### Environment Variables
 79 | 
 80 | All environment variables from the standard MCP server are supported:
 81 | 
 82 | **Required:**
 83 | - `OPENAI_API_KEY`: OpenAI API key for LLM operations
 84 | 
 85 | **Optional:**
 86 | - `BROWSER`: Enable FalkorDB Browser web UI on port 3000 (default: "1", set to "0" to disable)
 87 | - `GRAPHITI_GROUP_ID`: Namespace for graph data (default: "main")
 88 | - `SEMAPHORE_LIMIT`: Concurrency limit for episode processing (default: 10)
 89 | - `FALKORDB_PASSWORD`: Password for FalkorDB (optional)
 90 | - `FALKORDB_DATABASE`: FalkorDB database name (default: "default_db")
 91 | 
 92 | **Other LLM Providers:**
 93 | - `ANTHROPIC_API_KEY`: For Claude models
 94 | - `GOOGLE_API_KEY`: For Gemini models
 95 | - `GROQ_API_KEY`: For Groq models
 96 | 
 97 | ### Volumes
 98 | 
 99 | - `/var/lib/falkordb/data`: Persistent storage for graph data
100 | - `/var/log/graphiti`: MCP server and FalkorDB Browser logs
101 | 
102 | ## Service Management
103 | 
104 | ### View Logs
105 | 
106 | ```bash
107 | # All logs (both services stdout/stderr)
108 | docker compose -f docker/docker-compose-falkordb-combined.yml logs -f
109 | 
110 | # Only container logs
111 | docker compose -f docker/docker-compose-falkordb-combined.yml logs -f graphiti-falkordb
112 | ```
113 | 
114 | ### Restart Services
115 | 
116 | ```bash
117 | # Restart entire container (both services)
118 | docker compose -f docker/docker-compose-falkordb-combined.yml restart
119 | 
120 | # Check FalkorDB status
121 | docker compose -f docker/docker-compose-falkordb-combined.yml exec graphiti-falkordb redis-cli ping
122 | 
123 | # Check MCP server status
124 | curl http://localhost:8000/health
125 | ```
126 | 
127 | ### Disabling the FalkorDB Browser
128 | 
129 | To disable the FalkorDB Browser web UI (port 3000), set the `BROWSER` environment variable to `0`:
130 | 
131 | ```bash
132 | # Using docker run
133 | docker run -d \
134 |   -p 6379:6379 \
135 |   -p 3000:3000 \
136 |   -p 8000:8000 \
137 |   -e BROWSER=0 \
138 |   -e OPENAI_API_KEY=your_key \
139 |   zepai/graphiti-falkordb:latest
140 | 
141 | # Using docker-compose
142 | # Add to your .env file:
143 | BROWSER=0
144 | ```
145 | 
146 | When disabled, only FalkorDB (port 6379) and the MCP server (port 8000) will run.
147 | 
148 | ## Health Checks
149 | 
150 | The container includes a health check that verifies:
151 | 1. FalkorDB is responding to ping
152 | 2. MCP server health endpoint is accessible
153 | 
154 | Check health status:
155 | ```bash
156 | docker compose -f docker/docker-compose-falkordb-combined.yml ps
157 | ```
158 | 
159 | ## Architecture
160 | 
161 | ### Process Structure
162 | ```
163 | start-services.sh (PID 1)
164 | ├── redis-server (FalkorDB daemon)
165 | ├── node server.js (FalkorDB Browser - background, if BROWSER=1)
166 | └── uv run main.py (MCP server - foreground)
167 | ```
168 | 
169 | The startup script launches FalkorDB as a background daemon, waits for it to be ready, optionally starts the FalkorDB Browser (if `BROWSER=1`), then starts the MCP server in the foreground. When the MCP server stops, the container exits.
170 | 
171 | ### Directory Structure
172 | ```
173 | /app/mcp/                    # MCP server application
174 | ├── main.py
175 | ├── src/
176 | ├── config/
177 | │   └── config.yaml          # FalkorDB-specific configuration
178 | └── .graphiti-core-version   # Installed version info
179 | 
180 | /var/lib/falkordb/data/      # Persistent graph storage
181 | /var/lib/falkordb/browser/   # FalkorDB Browser web UI
182 | /var/log/graphiti/           # MCP server and Browser logs
183 | /start-services.sh           # Startup script
184 | ```
185 | 
186 | ## Benefits of Combined Image
187 | 
188 | 1. **Simplified Deployment**: Single container to manage
189 | 2. **Reduced Network Latency**: Localhost communication between services
190 | 3. **Easier Development**: One command to start entire stack
191 | 4. **Unified Logging**: All logs available via docker logs
192 | 5. **Resource Efficiency**: Shared base image and dependencies
193 | 
194 | ## Troubleshooting
195 | 
196 | ### FalkorDB Not Starting
197 | 
198 | Check container logs:
199 | ```bash
200 | docker compose -f docker/docker-compose-falkordb-combined.yml logs graphiti-falkordb
201 | ```
202 | 
203 | ### MCP Server Connection Issues
204 | 
205 | 1. Verify FalkorDB is running:
206 | ```bash
207 | docker compose -f docker/docker-compose-falkordb-combined.yml exec graphiti-falkordb redis-cli ping
208 | ```
209 | 
210 | 2. Check MCP server health:
211 | ```bash
212 | curl http://localhost:8000/health
213 | ```
214 | 
215 | 3. View all container logs:
216 | ```bash
217 | docker compose -f docker/docker-compose-falkordb-combined.yml logs -f
218 | ```
219 | 
220 | ### Port Conflicts
221 | 
222 | If ports 6379, 3000, or 8000 are already in use, modify the port mappings in `docker-compose-falkordb-combined.yml`:
223 | 
224 | ```yaml
225 | ports:
226 |   - "16379:6379"  # Use different external port
227 |   - "13000:3000"
228 |   - "18000:8000"
229 | ```
230 | 
231 | ## Production Considerations
232 | 
233 | 1. **Resource Limits**: Add resource constraints in docker-compose:
234 | ```yaml
235 | deploy:
236 |   resources:
237 |     limits:
238 |       cpus: '2'
239 |       memory: 4G
240 | ```
241 | 
242 | 2. **Persistent Volumes**: Use named volumes or bind mounts for production data
243 | 3. **Monitoring**: Export logs to external monitoring system
244 | 4. **Backups**: Regular backups of `/var/lib/falkordb/data` volume
245 | 5. **Security**: Set `FALKORDB_PASSWORD` in production environments
246 | 
247 | ## Comparison with Separate Containers
248 | 
249 | | Aspect | Combined Image | Separate Containers |
250 | |--------|---------------|---------------------|
251 | | Setup Complexity | Simple (one container) | Moderate (service dependencies) |
252 | | Network Latency | Lower (localhost) | Higher (container network) |
253 | | Resource Usage | Lower (shared base) | Higher (separate images) |
254 | | Scalability | Limited | Better (scale independently) |
255 | | Debugging | Harder (multiple processes) | Easier (isolated services) |
256 | | Production Use | Development/Single-node | Recommended |
257 | 
258 | ## See Also
259 | 
260 | - [Main MCP Server README](../README.md)
261 | - [FalkorDB Documentation](https://docs.falkordb.com/)
262 | - [Docker Compose Documentation](https://docs.docker.com/compose/)
263 | 
```
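
The health checks described in this README can also be probed programmatically. A minimal sketch, assuming the `redis` Python package is installed and the default ports documented above:

```python
import urllib.request

import redis  # pip install redis

# FalkorDB speaks the Redis protocol on port 6379
assert redis.Redis(host='localhost', port=6379).ping()

# MCP server health endpoint (see "Check MCP server health" above)
with urllib.request.urlopen('http://localhost:8000/health', timeout=5) as response:
    print(response.status)  # expect 200 when the server is healthy
```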

--------------------------------------------------------------------------------
/mcp_server/tests/test_falkordb_integration.py:
--------------------------------------------------------------------------------

```python
  1 | #!/usr/bin/env python3
  2 | """
  3 | FalkorDB integration test for the Graphiti MCP Server.
  4 | Tests MCP server functionality with FalkorDB as the graph database backend.
  5 | """
  6 | 
  7 | import asyncio
  8 | import json
  9 | import time
 10 | from typing import Any
 11 | 
 12 | from mcp import ClientSession, StdioServerParameters
 13 | from mcp.client.stdio import stdio_client
 14 | 
 15 | 
 16 | class GraphitiFalkorDBIntegrationTest:
 17 |     """Integration test client for Graphiti MCP Server using FalkorDB backend."""
 18 | 
 19 |     def __init__(self):
 20 |         self.test_group_id = f'falkor_test_group_{int(time.time())}'
 21 |         self.session = None
 22 | 
 23 |     async def __aenter__(self):
 24 |         """Start the MCP client session with FalkorDB configuration."""
 25 |         # Configure server parameters to run with FalkorDB backend
 26 |         server_params = StdioServerParameters(
 27 |             command='uv',
 28 |             args=['run', 'main.py', '--transport', 'stdio', '--database-provider', 'falkordb'],
 29 |             env={
 30 |                 'FALKORDB_URI': 'redis://localhost:6379',
 31 |                 'FALKORDB_PASSWORD': '',  # No password for test instance
 32 |                 'FALKORDB_DATABASE': 'default_db',
 33 |                 'OPENAI_API_KEY': 'dummy_key_for_testing',
 34 |                 'GRAPHITI_GROUP_ID': self.test_group_id,
 35 |             },
 36 |         )
 37 | 
 38 |         # stdio_client yields a (read, write) stream pair; a ClientSession
 39 |         # must be layered on top of it to drive the MCP protocol.
 40 |         self._stdio_ctx = stdio_client(server_params)
 41 |         read_stream, write_stream = await self._stdio_ctx.__aenter__()
 42 |         self._session_ctx = ClientSession(read_stream, write_stream)
 43 |         self.session = await self._session_ctx.__aenter__()
 44 |         await self.session.initialize()
 45 |         print('   📡 Started MCP client session with FalkorDB backend')
 46 |         return self
 47 | 
 48 |     async def __aexit__(self, exc_type, exc_val, exc_tb):
 49 |         """Close the MCP client session, then the stdio transport."""
 50 |         if self.session:
 51 |             await self._session_ctx.__aexit__(exc_type, exc_val, exc_tb)
 52 |         await self._stdio_ctx.__aexit__(exc_type, exc_val, exc_tb)
 53 |         print('   🔌 Closed MCP client session')
 48 | 
 49 |     async def call_mcp_tool(self, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any]:
 50 |         """Call an MCP tool via the stdio client."""
 51 |         try:
 52 |             result = await self.session.call_tool(tool_name, arguments)
 53 |             if hasattr(result, 'content') and result.content:
 54 |                 # Handle different content types
 55 |                 if hasattr(result.content[0], 'text'):
 56 |                     content = result.content[0].text
 57 |                     try:
 58 |                         return json.loads(content)
 59 |                     except json.JSONDecodeError:
 60 |                         return {'raw_response': content}
 61 |                 else:
 62 |                     return {'content': str(result.content[0])}
 63 |             return {'result': 'success', 'content': None}
 64 |         except Exception as e:
 65 |             return {'error': str(e), 'tool': tool_name, 'arguments': arguments}
 66 | 
 67 |     async def test_server_status(self) -> bool:
 68 |         """Test the get_status tool to verify FalkorDB connectivity."""
 69 |         print('   🏥 Testing server status with FalkorDB...')
 70 |         result = await self.call_mcp_tool('get_status', {})
 71 | 
 72 |         if 'error' in result:
 73 |             print(f'   ❌ Status check failed: {result["error"]}')
 74 |             return False
 75 | 
 76 |         # Check if status indicates FalkorDB is working
 77 |         status_text = result.get('raw_response', result.get('content', ''))
 78 |         if 'running' in str(status_text).lower() or 'ready' in str(status_text).lower():
 79 |             print('   ✅ Server status OK with FalkorDB')
 80 |             return True
 81 |         else:
 82 |             print(f'   ⚠️  Status unclear: {status_text}')
 83 |             return True  # Don't fail on unclear status
 84 | 
 85 |     async def test_add_episode(self) -> bool:
 86 |         """Test adding an episode to FalkorDB."""
 87 |         print('   📝 Testing episode addition to FalkorDB...')
 88 | 
 89 |         episode_data = {
 90 |             'name': 'FalkorDB Test Episode',
 91 |             'episode_body': 'This is a test episode to verify FalkorDB integration works correctly.',
 92 |             'source': 'text',
 93 |             'source_description': 'Integration test for FalkorDB backend',
 94 |         }
 95 | 
 96 |         result = await self.call_mcp_tool('add_episode', episode_data)
 97 | 
 98 |         if 'error' in result:
 99 |             print(f'   ❌ Add episode failed: {result["error"]}')
100 |             return False
101 | 
102 |         print('   ✅ Episode added successfully to FalkorDB')
103 |         return True
104 | 
105 |     async def test_search_functionality(self) -> bool:
106 |         """Test search functionality with FalkorDB."""
107 |         print('   🔍 Testing search functionality with FalkorDB...')
108 | 
109 |         # Give some time for episode processing
110 |         await asyncio.sleep(2)
111 | 
112 |         # Test node search
113 |         search_result = await self.call_mcp_tool(
114 |             'search_nodes', {'query': 'FalkorDB test episode', 'limit': 5}
115 |         )
116 | 
117 |         if 'error' in search_result:
118 |             print(f'   ⚠️  Search returned error (may be expected): {search_result["error"]}')
119 |             return True  # Don't fail on search errors in integration test
120 | 
121 |         print('   ✅ Search functionality working with FalkorDB')
122 |         return True
123 | 
124 |     async def test_clear_graph(self) -> bool:
125 |         """Test clearing the graph in FalkorDB."""
126 |         print('   🧹 Testing graph clearing in FalkorDB...')
127 | 
128 |         result = await self.call_mcp_tool('clear_graph', {})
129 | 
130 |         if 'error' in result:
131 |             print(f'   ❌ Clear graph failed: {result["error"]}')
132 |             return False
133 | 
134 |         print('   ✅ Graph cleared successfully in FalkorDB')
135 |         return True
136 | 
137 | 
138 | async def run_falkordb_integration_test() -> bool:
139 |     """Run the complete FalkorDB integration test suite."""
140 |     print('🧪 Starting FalkorDB Integration Test Suite')
141 |     print('=' * 55)
142 | 
143 |     test_results = []
144 | 
145 |     try:
146 |         async with GraphitiFalkorDBIntegrationTest() as test_client:
147 |             print(f'   🎯 Using test group: {test_client.test_group_id}')
148 | 
149 |             # Run test suite
150 |             tests = [
151 |                 ('Server Status', test_client.test_server_status),
152 |                 ('Add Episode', test_client.test_add_episode),
153 |                 ('Search Functionality', test_client.test_search_functionality),
154 |                 ('Clear Graph', test_client.test_clear_graph),
155 |             ]
156 | 
157 |             for test_name, test_func in tests:
158 |                 print(f'\n🔬 Running {test_name} Test...')
159 |                 try:
160 |                     result = await test_func()
161 |                     test_results.append((test_name, result))
162 |                     if result:
163 |                         print(f'   ✅ {test_name}: PASSED')
164 |                     else:
165 |                         print(f'   ❌ {test_name}: FAILED')
166 |                 except Exception as e:
167 |                     print(f'   💥 {test_name}: ERROR - {e}')
168 |                     test_results.append((test_name, False))
169 | 
170 |     except Exception as e:
171 |         print(f'💥 Test setup failed: {e}')
172 |         return False
173 | 
174 |     # Summary
175 |     print('\n' + '=' * 55)
176 |     print('📊 FalkorDB Integration Test Results:')
177 |     print('-' * 30)
178 | 
179 |     passed = sum(1 for _, result in test_results if result)
180 |     total = len(test_results)
181 | 
182 |     for test_name, result in test_results:
183 |         status = '✅ PASS' if result else '❌ FAIL'
184 |         print(f'   {test_name}: {status}')
185 | 
186 |     print(f'\n🎯 Overall: {passed}/{total} tests passed')
187 | 
188 |     if passed == total:
189 |         print('🎉 All FalkorDB integration tests PASSED!')
190 |         return True
191 |     else:
192 |         print('⚠️  Some FalkorDB integration tests failed')
193 |         return passed >= (total * 0.7)  # Pass if 70% of tests pass
194 | 
195 | 
196 | if __name__ == '__main__':
197 |     success = asyncio.run(run_falkordb_integration_test())
198 |     exit(0 if success else 1)
199 | 
```
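
The `__aenter__`/`__aexit__` pair above drives the SDK's context managers by hand so the test class itself can be used with `async with`. For reference, a minimal sketch of the equivalent inline form:

```python
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client


async def get_status(server_params: StdioServerParameters) -> None:
    # stdio_client yields a (read, write) stream pair; ClientSession speaks MCP over it
    async with stdio_client(server_params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            result = await session.call_tool('get_status', {})
            print(result.content)
```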

--------------------------------------------------------------------------------
/mcp_server/tests/test_configuration.py:
--------------------------------------------------------------------------------

```python
  1 | #!/usr/bin/env python3
  2 | """Test script for configuration loading and factory patterns."""
  3 | 
  4 | import asyncio
  5 | import os
  6 | import sys
  7 | from pathlib import Path
  8 | 
  9 | # Add the server's src directory to the import path
 10 | sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))
 11 | 
 12 | from config.schema import GraphitiConfig
 13 | from services.factories import DatabaseDriverFactory, EmbedderFactory, LLMClientFactory
 14 | 
 15 | 
 16 | def test_config_loading():
 17 |     """Test loading configuration from YAML and environment variables."""
 18 |     print('Testing configuration loading...')
 19 | 
 20 |     # Test with default config.yaml
 21 |     config = GraphitiConfig()
 22 | 
 23 |     print('✓ Loaded configuration successfully')
 24 |     print(f'  - Server transport: {config.server.transport}')
 25 |     print(f'  - LLM provider: {config.llm.provider}')
 26 |     print(f'  - LLM model: {config.llm.model}')
 27 |     print(f'  - Embedder provider: {config.embedder.provider}')
 28 |     print(f'  - Database provider: {config.database.provider}')
 29 |     print(f'  - Group ID: {config.graphiti.group_id}')
 30 | 
 31 |     # Test environment variable override
 32 |     os.environ['LLM__PROVIDER'] = 'anthropic'
 33 |     os.environ['LLM__MODEL'] = 'claude-3-opus'
 34 |     config2 = GraphitiConfig()
 35 | 
 36 |     print('\n✓ Environment variable overrides work')
 37 |     print(f'  - LLM provider (overridden): {config2.llm.provider}')
 38 |     print(f'  - LLM model (overridden): {config2.llm.model}')
 39 | 
 40 |     # Clean up env vars
 41 |     del os.environ['LLM__PROVIDER']
 42 |     del os.environ['LLM__MODEL']
 43 | 
 44 |     assert config is not None
 45 |     assert config2 is not None
 46 | 
 47 |     # Return the first config for subsequent tests
 48 |     return config
 49 | 
 50 | 
 51 | def test_llm_factory(config: GraphitiConfig):
 52 |     """Test LLM client factory creation."""
 53 |     print('\nTesting LLM client factory...')
 54 | 
 55 |     # Test OpenAI client creation (if API key is set)
 56 |     if (
 57 |         config.llm.provider == 'openai'
 58 |         and config.llm.providers.openai
 59 |         and config.llm.providers.openai.api_key
 60 |     ):
 61 |         try:
 62 |             client = LLMClientFactory.create(config.llm)
 63 |             print(f'✓ Created {config.llm.provider} LLM client successfully')
 64 |             print(f'  - Model: {client.model}')
 65 |             print(f'  - Temperature: {client.temperature}')
 66 |         except Exception as e:
 67 |             print(f'✗ Failed to create LLM client: {e}')
 68 |     else:
 69 |         print(f'⚠ Skipping LLM factory test (no API key configured for {config.llm.provider})')
 70 | 
 71 |     # Test switching providers
 72 |     test_config = config.llm.model_copy()
 73 |     test_config.provider = 'gemini'
 74 |     if not test_config.providers.gemini:
 75 |         from config.schema import GeminiProviderConfig
 76 | 
 77 |         test_config.providers.gemini = GeminiProviderConfig(api_key='dummy_value_for_testing')
 78 |     else:
 79 |         test_config.providers.gemini.api_key = 'dummy_value_for_testing'
 80 | 
 81 |     try:
 82 |         _ = LLMClientFactory.create(test_config)
 83 |         print('✓ Factory supports provider switching (tested with Gemini)')
 84 |     except Exception as e:
 85 |         print(f'✗ Factory provider switching failed: {e}')
 86 | 
 87 | 
 88 | def test_embedder_factory(config: GraphitiConfig):
 89 |     """Test Embedder client factory creation."""
 90 |     print('\nTesting Embedder client factory...')
 91 | 
 92 |     # Test OpenAI embedder creation (if API key is set)
 93 |     if (
 94 |         config.embedder.provider == 'openai'
 95 |         and config.embedder.providers.openai
 96 |         and config.embedder.providers.openai.api_key
 97 |     ):
 98 |         try:
 99 |             _ = EmbedderFactory.create(config.embedder)
100 |             print(f'✓ Created {config.embedder.provider} Embedder client successfully')
101 |             # The embedder client may not expose model/dimensions as attributes
102 |             print(f'  - Configured model: {config.embedder.model}')
103 |             print(f'  - Configured dimensions: {config.embedder.dimensions}')
104 |         except Exception as e:
105 |             print(f'✗ Failed to create Embedder client: {e}')
106 |     else:
107 |         print(
108 |             f'⚠ Skipping Embedder factory test (no API key configured for {config.embedder.provider})'
109 |         )
110 | 
111 | 
112 | async def test_database_factory(config: GraphitiConfig):
113 |     """Test Database driver factory creation."""
114 |     print('\nTesting Database driver factory...')
115 | 
116 |     # Test Neo4j config creation
117 |     if config.database.provider == 'neo4j' and config.database.providers.neo4j:
118 |         try:
119 |             db_config = DatabaseDriverFactory.create_config(config.database)
120 |             print(f'✓ Created {config.database.provider} configuration successfully')
121 |             print(f'  - URI: {db_config["uri"]}')
122 |             print(f'  - User: {db_config["user"]}')
123 |             print(
124 |                 f'  - Password: {"*" * len(db_config["password"]) if db_config["password"] else "None"}'
125 |             )
126 | 
127 |             # Test actual connection would require initializing Graphiti
128 |             from graphiti_core import Graphiti
129 | 
130 |             try:
131 |                 # This will fail if Neo4j is not running, but tests the config
132 |                 graphiti = Graphiti(
133 |                     uri=db_config['uri'],
134 |                     user=db_config['user'],
135 |                     password=db_config['password'],
136 |                 )
137 |                 await graphiti.driver.client.verify_connectivity()
138 |                 print('  ✓ Successfully connected to Neo4j')
139 |                 await graphiti.driver.client.close()
140 |             except Exception as e:
141 |                 print(f'  ⚠ Could not connect to Neo4j (is it running?): {type(e).__name__}')
142 |         except Exception as e:
143 |             print(f'✗ Failed to create Database configuration: {e}')
144 |     else:
145 |         print(f'⚠ Skipping Database factory test (no configuration for {config.database.provider})')
146 | 
147 | 
148 | def test_cli_override():
149 |     """Test CLI argument override functionality."""
150 |     print('\nTesting CLI argument override...')
151 | 
152 |     # Simulate argparse Namespace
153 |     class Args:
154 |         config = Path('config.yaml')
155 |         transport = 'stdio'
156 |         llm_provider = 'anthropic'
157 |         model = 'claude-3-sonnet'
158 |         temperature = 0.5
159 |         embedder_provider = 'voyage'
160 |         embedder_model = 'voyage-3'
161 |         database_provider = 'falkordb'
162 |         group_id = 'test-group'
163 |         user_id = 'test-user'
164 | 
165 |     config = GraphitiConfig()
166 |     config.apply_cli_overrides(Args())
167 | 
168 |     print('✓ CLI overrides applied successfully')
169 |     print(f'  - Transport: {config.server.transport}')
170 |     print(f'  - LLM provider: {config.llm.provider}')
171 |     print(f'  - LLM model: {config.llm.model}')
172 |     print(f'  - Temperature: {config.llm.temperature}')
173 |     print(f'  - Embedder provider: {config.embedder.provider}')
174 |     print(f'  - Database provider: {config.database.provider}')
175 |     print(f'  - Group ID: {config.graphiti.group_id}')
176 |     print(f'  - User ID: {config.graphiti.user_id}')
177 | 
178 | 
179 | async def main():
180 |     """Run all tests."""
181 |     print('=' * 60)
182 |     print('Configuration and Factory Pattern Test Suite')
183 |     print('=' * 60)
184 | 
185 |     try:
186 |         # Test configuration loading
187 |         config = test_config_loading()
188 | 
189 |         # Test factories
190 |         test_llm_factory(config)
191 |         test_embedder_factory(config)
192 |         await test_database_factory(config)
193 | 
194 |         # Test CLI overrides
195 |         test_cli_override()
196 | 
197 |         print('\n' + '=' * 60)
198 |         print('✓ All tests completed successfully!')
199 |         print('=' * 60)
200 | 
201 |     except Exception as e:
202 |         print(f'\n✗ Test suite failed: {e}')
203 |         sys.exit(1)
204 | 
205 | 
206 | if __name__ == '__main__':
207 |     asyncio.run(main())
208 | 
```
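
The `LLM__PROVIDER` override exercised above follows pydantic-settings' nested-delimiter convention. A self-contained sketch of that mechanism (the class and field names here are hypothetical, not the actual `GraphitiConfig` schema):

```python
import os

from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict


class LLMSection(BaseModel):  # hypothetical section model
    provider: str = 'openai'
    model: str = 'gpt-4.1-mini'


class AppSettings(BaseSettings):  # hypothetical settings root
    model_config = SettingsConfigDict(env_nested_delimiter='__')

    llm: LLMSection = LLMSection()


# SECTION__FIELD env vars override nested fields
os.environ['LLM__PROVIDER'] = 'anthropic'
print(AppSettings().llm.provider)  # -> 'anthropic'
```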

--------------------------------------------------------------------------------
/graphiti_core/search/search_config_recipes.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | from graphiti_core.search.search_config import (
 18 |     CommunityReranker,
 19 |     CommunitySearchConfig,
 20 |     CommunitySearchMethod,
 21 |     EdgeReranker,
 22 |     EdgeSearchConfig,
 23 |     EdgeSearchMethod,
 24 |     EpisodeReranker,
 25 |     EpisodeSearchConfig,
 26 |     EpisodeSearchMethod,
 27 |     NodeReranker,
 28 |     NodeSearchConfig,
 29 |     NodeSearchMethod,
 30 |     SearchConfig,
 31 | )
 32 | 
 33 | # Performs a hybrid search with rrf reranking over edges, nodes, and communities
 34 | COMBINED_HYBRID_SEARCH_RRF = SearchConfig(
 35 |     edge_config=EdgeSearchConfig(
 36 |         search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
 37 |         reranker=EdgeReranker.rrf,
 38 |     ),
 39 |     node_config=NodeSearchConfig(
 40 |         search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
 41 |         reranker=NodeReranker.rrf,
 42 |     ),
 43 |     episode_config=EpisodeSearchConfig(
 44 |         search_methods=[
 45 |             EpisodeSearchMethod.bm25,
 46 |         ],
 47 |         reranker=EpisodeReranker.rrf,
 48 |     ),
 49 |     community_config=CommunitySearchConfig(
 50 |         search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
 51 |         reranker=CommunityReranker.rrf,
 52 |     ),
 53 | )
 54 | 
 55 | # Performs a hybrid search with mmr reranking over edges, nodes, and communities
 56 | COMBINED_HYBRID_SEARCH_MMR = SearchConfig(
 57 |     edge_config=EdgeSearchConfig(
 58 |         search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
 59 |         reranker=EdgeReranker.mmr,
 60 |         mmr_lambda=1,
 61 |     ),
 62 |     node_config=NodeSearchConfig(
 63 |         search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
 64 |         reranker=NodeReranker.mmr,
 65 |         mmr_lambda=1,
 66 |     ),
 67 |     episode_config=EpisodeSearchConfig(
 68 |         search_methods=[
 69 |             EpisodeSearchMethod.bm25,
 70 |         ],
 71 |         reranker=EpisodeReranker.rrf,
 72 |     ),
 73 |     community_config=CommunitySearchConfig(
 74 |         search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
 75 |         reranker=CommunityReranker.mmr,
 76 |         mmr_lambda=1,
 77 |     ),
 78 | )
 79 | 
 80 | # Performs a full-text search, similarity search, and bfs with cross_encoder reranking over edges, nodes, and communities
 81 | COMBINED_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
 82 |     edge_config=EdgeSearchConfig(
 83 |         search_methods=[
 84 |             EdgeSearchMethod.bm25,
 85 |             EdgeSearchMethod.cosine_similarity,
 86 |             EdgeSearchMethod.bfs,
 87 |         ],
 88 |         reranker=EdgeReranker.cross_encoder,
 89 |     ),
 90 |     node_config=NodeSearchConfig(
 91 |         search_methods=[
 92 |             NodeSearchMethod.bm25,
 93 |             NodeSearchMethod.cosine_similarity,
 94 |             NodeSearchMethod.bfs,
 95 |         ],
 96 |         reranker=NodeReranker.cross_encoder,
 97 |     ),
 98 |     episode_config=EpisodeSearchConfig(
 99 |         search_methods=[
100 |             EpisodeSearchMethod.bm25,
101 |         ],
102 |         reranker=EpisodeReranker.cross_encoder,
103 |     ),
104 |     community_config=CommunitySearchConfig(
105 |         search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
106 |         reranker=CommunityReranker.cross_encoder,
107 |     ),
108 | )
109 | 
110 | # performs a hybrid search over edges with rrf reranking
111 | EDGE_HYBRID_SEARCH_RRF = SearchConfig(
112 |     edge_config=EdgeSearchConfig(
113 |         search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
114 |         reranker=EdgeReranker.rrf,
115 |     )
116 | )
117 | 
118 | # performs a hybrid search over edges with mmr reranking
119 | EDGE_HYBRID_SEARCH_MMR = SearchConfig(
120 |     edge_config=EdgeSearchConfig(
121 |         search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
122 |         reranker=EdgeReranker.mmr,
123 |     )
124 | )
125 | 
126 | # performs a hybrid search over edges with node distance reranking
127 | EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
128 |     edge_config=EdgeSearchConfig(
129 |         search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
130 |         reranker=EdgeReranker.node_distance,
131 |     ),
132 | )
133 | 
134 | # performs a hybrid search over edges with episode mention reranking
135 | EDGE_HYBRID_SEARCH_EPISODE_MENTIONS = SearchConfig(
136 |     edge_config=EdgeSearchConfig(
137 |         search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
138 |         reranker=EdgeReranker.episode_mentions,
139 |     )
140 | )
141 | 
142 | # performs a hybrid search over edges with cross encoder reranking
143 | EDGE_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
144 |     edge_config=EdgeSearchConfig(
145 |         search_methods=[
146 |             EdgeSearchMethod.bm25,
147 |             EdgeSearchMethod.cosine_similarity,
148 |             EdgeSearchMethod.bfs,
149 |         ],
150 |         reranker=EdgeReranker.cross_encoder,
151 |     ),
152 |     limit=10,
153 | )
154 | 
155 | # performs a hybrid search over nodes with rrf reranking
156 | NODE_HYBRID_SEARCH_RRF = SearchConfig(
157 |     node_config=NodeSearchConfig(
158 |         search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
159 |         reranker=NodeReranker.rrf,
160 |     )
161 | )
162 | 
163 | # performs a hybrid search over nodes with mmr reranking
164 | NODE_HYBRID_SEARCH_MMR = SearchConfig(
165 |     node_config=NodeSearchConfig(
166 |         search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
167 |         reranker=NodeReranker.mmr,
168 |     )
169 | )
170 | 
171 | # performs a hybrid search over nodes with node distance reranking
172 | NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
173 |     node_config=NodeSearchConfig(
174 |         search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
175 |         reranker=NodeReranker.node_distance,
176 |     )
177 | )
178 | 
179 | # performs a hybrid search over nodes with episode mentions reranking
180 | NODE_HYBRID_SEARCH_EPISODE_MENTIONS = SearchConfig(
181 |     node_config=NodeSearchConfig(
182 |         search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
183 |         reranker=NodeReranker.episode_mentions,
184 |     )
185 | )
186 | 
187 | # performs a hybrid search over nodes with cross encoder reranking
188 | NODE_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
189 |     node_config=NodeSearchConfig(
190 |         search_methods=[
191 |             NodeSearchMethod.bm25,
192 |             NodeSearchMethod.cosine_similarity,
193 |             NodeSearchMethod.bfs,
194 |         ],
195 |         reranker=NodeReranker.cross_encoder,
196 |     ),
197 |     limit=10,
198 | )
199 | 
200 | # performs a hybrid search over communities with rrf reranking
201 | COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig(
202 |     community_config=CommunitySearchConfig(
203 |         search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
204 |         reranker=CommunityReranker.rrf,
205 |     )
206 | )
207 | 
208 | # performs a hybrid search over communities with mmr reranking
209 | COMMUNITY_HYBRID_SEARCH_MMR = SearchConfig(
210 |     community_config=CommunitySearchConfig(
211 |         search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
212 |         reranker=CommunityReranker.mmr,
213 |     )
214 | )
215 | 
216 | # performs a hybrid search over communities with cross encoder reranking
217 | COMMUNITY_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
218 |     community_config=CommunitySearchConfig(
219 |         search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
220 |         reranker=CommunityReranker.cross_encoder,
221 |     ),
222 |     limit=3,
223 | )
224 | 
```
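
These recipe constants are plain `SearchConfig` values, so a caller can pass one to the config-driven search entry point or copy and tweak it. A minimal sketch: the import path is inferred from this repository's layout, `model_copy` assumes `SearchConfig` is a Pydantic model, and the `search_` method name may differ across graphiti-core versions.

```python
# Sketch only: verify the import path and method name against your version.
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF

# Copy a recipe instead of mutating the shared module-level constant.
my_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
my_config.limit = 5

# Assumed config-driven entry point on an initialized Graphiti instance:
# results = await graphiti.search_('Who was the California Attorney General?', config=my_config)
```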

--------------------------------------------------------------------------------
/tests/test_node_int.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | from datetime import datetime, timedelta
 18 | from uuid import uuid4
 19 | 
 20 | import pytest
 21 | 
 22 | from graphiti_core.nodes import (
 23 |     CommunityNode,
 24 |     EntityNode,
 25 |     EpisodeType,
 26 |     EpisodicNode,
 27 | )
 28 | from tests.helpers_test import (
 29 |     assert_community_node_equals,
 30 |     assert_entity_node_equals,
 31 |     assert_episodic_node_equals,
 32 |     get_node_count,
 33 |     group_id,
 34 | )
 35 | 
 36 | created_at = datetime.now()
 37 | deleted_at = created_at + timedelta(days=3)
 38 | valid_at = created_at + timedelta(days=1)
 39 | invalid_at = created_at + timedelta(days=2)
 40 | 
 41 | 
 42 | @pytest.fixture
 43 | def sample_entity_node():
 44 |     return EntityNode(
 45 |         uuid=str(uuid4()),
 46 |         name='Test Entity',
 47 |         group_id=group_id,
 48 |         labels=['Entity', 'Person'],
 49 |         created_at=created_at,
 50 |         name_embedding=[0.5] * 1024,
 51 |         summary='Entity Summary',
 52 |         attributes={
 53 |             'age': 30,
 54 |             'location': 'New York',
 55 |         },
 56 |     )
 57 | 
 58 | 
 59 | @pytest.fixture
 60 | def sample_episodic_node():
 61 |     return EpisodicNode(
 62 |         uuid=str(uuid4()),
 63 |         name='Episode 1',
 64 |         group_id=group_id,
 65 |         created_at=created_at,
 66 |         source=EpisodeType.text,
 67 |         source_description='Test source',
 68 |         content='Some content here',
 69 |         valid_at=valid_at,
 70 |         entity_edges=[],
 71 |     )
 72 | 
 73 | 
 74 | @pytest.fixture
 75 | def sample_community_node():
 76 |     return CommunityNode(
 77 |         uuid=str(uuid4()),
 78 |         name='Community A',
 79 |         group_id=group_id,
 80 |         created_at=created_at,
 81 |         name_embedding=[0.5] * 1024,
 82 |         summary='Community summary',
 83 |     )
 84 | 
 85 | 
 86 | @pytest.mark.asyncio
 87 | async def test_entity_node(sample_entity_node, graph_driver):
 88 |     uuid = sample_entity_node.uuid
 89 | 
 90 |     # Create node
 91 |     node_count = await get_node_count(graph_driver, [uuid])
 92 |     assert node_count == 0
 93 |     await sample_entity_node.save(graph_driver)
 94 |     node_count = await get_node_count(graph_driver, [uuid])
 95 |     assert node_count == 1
 96 | 
 97 |     # Get node by uuid
 98 |     retrieved = await EntityNode.get_by_uuid(graph_driver, sample_entity_node.uuid)
 99 |     await assert_entity_node_equals(graph_driver, retrieved, sample_entity_node)
100 | 
101 |     # Get node by uuids
102 |     retrieved = await EntityNode.get_by_uuids(graph_driver, [sample_entity_node.uuid])
103 |     await assert_entity_node_equals(graph_driver, retrieved[0], sample_entity_node)
104 | 
105 |     # Get node by group ids
106 |     retrieved = await EntityNode.get_by_group_ids(
107 |         graph_driver, [group_id], limit=2, with_embeddings=True
108 |     )
109 |     assert len(retrieved) == 1
110 |     await assert_entity_node_equals(graph_driver, retrieved[0], sample_entity_node)
111 | 
112 |     # Delete node by uuid
113 |     await sample_entity_node.delete(graph_driver)
114 |     node_count = await get_node_count(graph_driver, [uuid])
115 |     assert node_count == 0
116 | 
117 |     # Delete node by uuids
118 |     await sample_entity_node.save(graph_driver)
119 |     node_count = await get_node_count(graph_driver, [uuid])
120 |     assert node_count == 1
121 |     await sample_entity_node.delete_by_uuids(graph_driver, [uuid])
122 |     node_count = await get_node_count(graph_driver, [uuid])
123 |     assert node_count == 0
124 | 
125 |     # Delete node by group id
126 |     await sample_entity_node.save(graph_driver)
127 |     node_count = await get_node_count(graph_driver, [uuid])
128 |     assert node_count == 1
129 |     await sample_entity_node.delete_by_group_id(graph_driver, group_id)
130 |     node_count = await get_node_count(graph_driver, [uuid])
131 |     assert node_count == 0
132 | 
133 |     await graph_driver.close()
134 | 
135 | 
136 | @pytest.mark.asyncio
137 | async def test_community_node(sample_community_node, graph_driver):
138 |     uuid = sample_community_node.uuid
139 | 
140 |     # Create node
141 |     node_count = await get_node_count(graph_driver, [uuid])
142 |     assert node_count == 0
143 |     await sample_community_node.save(graph_driver)
144 |     node_count = await get_node_count(graph_driver, [uuid])
145 |     assert node_count == 1
146 | 
147 |     # Get node by uuid
148 |     retrieved = await CommunityNode.get_by_uuid(graph_driver, sample_community_node.uuid)
149 |     await assert_community_node_equals(graph_driver, retrieved, sample_community_node)
150 | 
151 |     # Get node by uuids
152 |     retrieved = await CommunityNode.get_by_uuids(graph_driver, [sample_community_node.uuid])
153 |     await assert_community_node_equals(graph_driver, retrieved[0], sample_community_node)
154 | 
155 |     # Get node by group ids
156 |     retrieved = await CommunityNode.get_by_group_ids(graph_driver, [group_id], limit=2)
157 |     assert len(retrieved) == 1
158 |     await assert_community_node_equals(graph_driver, retrieved[0], sample_community_node)
159 | 
160 |     # Delete node by uuid
161 |     await sample_community_node.delete(graph_driver)
162 |     node_count = await get_node_count(graph_driver, [uuid])
163 |     assert node_count == 0
164 | 
165 |     # Delete node by uuids
166 |     await sample_community_node.save(graph_driver)
167 |     node_count = await get_node_count(graph_driver, [uuid])
168 |     assert node_count == 1
169 |     await sample_community_node.delete_by_uuids(graph_driver, [uuid])
170 |     node_count = await get_node_count(graph_driver, [uuid])
171 |     assert node_count == 0
172 | 
173 |     # Delete node by group id
174 |     await sample_community_node.save(graph_driver)
175 |     node_count = await get_node_count(graph_driver, [uuid])
176 |     assert node_count == 1
177 |     await sample_community_node.delete_by_group_id(graph_driver, group_id)
178 |     node_count = await get_node_count(graph_driver, [uuid])
179 |     assert node_count == 0
180 | 
181 |     await graph_driver.close()
182 | 
183 | 
184 | @pytest.mark.asyncio
185 | async def test_episodic_node(sample_episodic_node, graph_driver):
186 |     uuid = sample_episodic_node.uuid
187 | 
188 |     # Create node
189 |     node_count = await get_node_count(graph_driver, [uuid])
190 |     assert node_count == 0
191 |     await sample_episodic_node.save(graph_driver)
192 |     node_count = await get_node_count(graph_driver, [uuid])
193 |     assert node_count == 1
194 | 
195 |     # Get node by uuid
196 |     retrieved = await EpisodicNode.get_by_uuid(graph_driver, sample_episodic_node.uuid)
197 |     await assert_episodic_node_equals(retrieved, sample_episodic_node)
198 | 
199 |     # Get node by uuids
200 |     retrieved = await EpisodicNode.get_by_uuids(graph_driver, [sample_episodic_node.uuid])
201 |     await assert_episodic_node_equals(retrieved[0], sample_episodic_node)
202 | 
203 |     # Get node by group ids
204 |     retrieved = await EpisodicNode.get_by_group_ids(graph_driver, [group_id], limit=2)
205 |     assert len(retrieved) == 1
206 |     await assert_episodic_node_equals(retrieved[0], sample_episodic_node)
207 | 
208 |     # Delete node by uuid
209 |     await sample_episodic_node.delete(graph_driver)
210 |     node_count = await get_node_count(graph_driver, [uuid])
211 |     assert node_count == 0
212 | 
213 |     # Delete node by uuids
214 |     await sample_episodic_node.save(graph_driver)
215 |     node_count = await get_node_count(graph_driver, [uuid])
216 |     assert node_count == 1
217 |     await sample_episodic_node.delete_by_uuids(graph_driver, [uuid])
218 |     node_count = await get_node_count(graph_driver, [uuid])
219 |     assert node_count == 0
220 | 
221 |     # Delete node by group id
222 |     await sample_episodic_node.save(graph_driver)
223 |     node_count = await get_node_count(graph_driver, [uuid])
224 |     assert node_count == 1
225 |     await sample_episodic_node.delete_by_group_id(graph_driver, group_id)
226 |     node_count = await get_node_count(graph_driver, [uuid])
227 |     assert node_count == 0
228 | 
229 |     await graph_driver.close()
230 | 
```
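
The CRUD round-trip these tests exercise works the same way in application code. A minimal sketch, assuming a local Neo4j backend, that `Neo4jDriver(uri, user, password)` is the constructor signature, and that fields like `name_embedding` are optional (check `graphiti_core/nodes.py` and `graphiti_core/driver/neo4j_driver.py`):

```python
import asyncio
from datetime import datetime
from uuid import uuid4

from graphiti_core.driver.neo4j_driver import Neo4jDriver
from graphiti_core.nodes import EntityNode


async def main():
    driver = Neo4jDriver('bolt://localhost:7687', 'neo4j', 'password')  # assumed signature
    node = EntityNode(
        uuid=str(uuid4()),
        name='Ada Lovelace',
        group_id='demo',
        labels=['Entity'],
        created_at=datetime.now(),
        summary='Early computing pioneer',
    )
    await node.save(driver)                                    # upsert
    fetched = await EntityNode.get_by_uuid(driver, node.uuid)  # read back
    await fetched.delete(driver)                               # clean up
    await driver.close()


asyncio.run(main())
```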

--------------------------------------------------------------------------------
/tests/utils/maintenance/test_temporal_operations_int.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import os
 18 | from datetime import timedelta
 19 | 
 20 | import pytest
 21 | from dotenv import load_dotenv
 22 | 
 23 | from graphiti_core.edges import EntityEdge
 24 | from graphiti_core.llm_client import LLMConfig, OpenAIClient
 25 | from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
 26 | from graphiti_core.utils.datetime_utils import utc_now
 27 | from graphiti_core.utils.maintenance.temporal_operations import (
 28 |     get_edge_contradictions,
 29 | )
 30 | 
 31 | load_dotenv()
 32 | 
 33 | 
 34 | def setup_llm_client():
 35 |     return OpenAIClient(
 36 |         LLMConfig(
 37 |             api_key=os.getenv('TEST_OPENAI_API_KEY'),
 38 |             model=os.getenv('TEST_OPENAI_MODEL'),
 39 |             base_url='https://api.openai.com/v1',
 40 |         )
 41 |     )
 42 | 
 43 | 
 44 | def create_test_data():
 45 |     now = utc_now()
 46 | 
 47 |     # Create edges
 48 |     existing_edge = EntityEdge(
 49 |         uuid='e1',
 50 |         source_node_uuid='1',
 51 |         target_node_uuid='2',
 52 |         name='LIKES',
 53 |         fact='Alice likes Bob',
 54 |         created_at=now - timedelta(days=1),
 55 |         group_id='1',
 56 |     )
 57 |     new_edge = EntityEdge(
 58 |         uuid='e2',
 59 |         source_node_uuid='1',
 60 |         target_node_uuid='2',
 61 |         name='DISLIKES',
 62 |         fact='Alice dislikes Bob',
 63 |         created_at=now,
 64 |         group_id='1',
 65 |     )
 66 | 
 67 |     # Create current episode
 68 |     current_episode = EpisodicNode(
 69 |         name='Current Episode',
 70 |         content='Alice now dislikes Bob',
 71 |         created_at=now,
 72 |         valid_at=now,
 73 |         source=EpisodeType.message,
 74 |         source_description='Test episode for unit testing',
 75 |         group_id='1',
 76 |     )
 77 | 
 78 |     # Create previous episodes
 79 |     previous_episodes = [
 80 |         EpisodicNode(
 81 |             name='Previous Episode',
 82 |             content='Alice liked Bob',
 83 |             created_at=now - timedelta(days=1),
 84 |             valid_at=now - timedelta(days=1),
 85 |             source=EpisodeType.message,
 86 |             source_description='Test previous episode for unit testing',
 87 |             group_id='1',
 88 |         )
 89 |     ]
 90 | 
 91 |     return existing_edge, new_edge, current_episode, previous_episodes
 92 | 
 93 | 
 94 | @pytest.mark.asyncio
 95 | @pytest.mark.integration
 96 | async def test_get_edge_contradictions():
 97 |     existing_edge, new_edge, current_episode, previous_episodes = create_test_data()
 98 | 
 99 |     invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, [existing_edge])
100 | 
101 |     assert len(invalidated_edges) == 1
102 |     assert invalidated_edges[0].uuid == existing_edge.uuid
103 | 
104 | 
105 | @pytest.mark.asyncio
106 | @pytest.mark.integration
107 | async def test_get_edge_contradictions_no_contradictions():
108 |     _, new_edge, current_episode, previous_episodes = create_test_data()
109 | 
110 |     invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, [])
111 | 
112 |     assert len(invalidated_edges) == 0
113 | 
114 | 
115 | @pytest.mark.skip(reason='Flaky LLM-based test with non-deterministic results')
116 | @pytest.mark.asyncio
117 | @pytest.mark.integration
118 | async def test_get_edge_contradictions_multiple_existing():
119 |     existing_edge1, new_edge, _, _ = create_test_data()
120 |     existing_edge2, _, _, _ = create_test_data()
121 |     existing_edge2.uuid = 'e3'
122 |     existing_edge2.name = 'KNOWS'
123 |     existing_edge2.fact = 'Alice knows Bob'
124 | 
125 |     invalidated_edges = await get_edge_contradictions(
126 |         setup_llm_client(), new_edge, [existing_edge1, existing_edge2]
127 |     )
128 | 
129 |     assert len(invalidated_edges) == 1
130 |     assert invalidated_edges[0].uuid == existing_edge1.uuid
131 | 
132 | 
133 | # Helper function to create more complex test data
134 | def create_complex_test_data():
135 |     now = utc_now()
136 | 
137 |     # Create nodes
138 |     node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now, group_id='1')
139 |     node2 = EntityNode(uuid='2', name='Bob', labels=['Person'], created_at=now, group_id='1')
140 |     node3 = EntityNode(uuid='3', name='Charlie', labels=['Person'], created_at=now, group_id='1')
141 |     node4 = EntityNode(
142 |         uuid='4', name='Company XYZ', labels=['Organization'], created_at=now, group_id='1'
143 |     )
144 | 
145 |     # Create edges
146 |     existing_edge1 = EntityEdge(
147 |         uuid='e1',
148 |         source_node_uuid='1',
149 |         target_node_uuid='2',
150 |         name='LIKES',
151 |         fact='Alice likes Bob',
152 |         group_id='1',
153 |         created_at=now - timedelta(days=5),
154 |     )
155 |     existing_edge2 = EntityEdge(
156 |         uuid='e2',
157 |         source_node_uuid='1',
158 |         target_node_uuid='3',
159 |         name='FRIENDS_WITH',
160 |         fact='Alice is friends with Charlie',
161 |         group_id='1',
162 |         created_at=now - timedelta(days=3),
163 |     )
164 |     existing_edge3 = EntityEdge(
165 |         uuid='e3',
166 |         source_node_uuid='2',
167 |         target_node_uuid='4',
168 |         name='WORKS_FOR',
169 |         fact='Bob works for Company XYZ',
170 |         group_id='1',
171 |         created_at=now - timedelta(days=2),
172 |     )
173 | 
174 |     return [existing_edge1, existing_edge2, existing_edge3], [
175 |         node1,
176 |         node2,
177 |         node3,
178 |         node4,
179 |     ]
180 | 
181 | 
182 | @pytest.mark.asyncio
183 | @pytest.mark.integration
184 | async def test_invalidate_edges_complex():
185 |     existing_edges, nodes = create_complex_test_data()
186 | 
187 |     # Create a new edge that contradicts an existing one
188 |     new_edge = EntityEdge(
189 |         uuid='e4',
190 |         source_node_uuid='1',
191 |         target_node_uuid='2',
192 |         name='DISLIKES',
193 |         fact='Alice dislikes Bob',
194 |         group_id='1',
195 |         created_at=utc_now(),
196 |     )
197 | 
198 |     invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
199 | 
200 |     assert len(invalidated_edges) == 1
201 |     assert invalidated_edges[0].uuid == 'e1'
202 | 
203 | 
204 | @pytest.mark.asyncio
205 | @pytest.mark.integration
206 | async def test_get_edge_contradictions_temporal_update():
207 |     existing_edges, nodes = create_complex_test_data()
208 | 
209 |     # Create a new edge that updates an existing one with new information
210 |     new_edge = EntityEdge(
211 |         uuid='e5',
212 |         source_node_uuid='2',
213 |         target_node_uuid='4',
214 |         name='LEFT_JOB',
215 |         fact='Bob no longer works at Company XYZ',
216 |         group_id='1',
217 |         created_at=utc_now(),
218 |     )
219 | 
220 |     invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
221 | 
222 |     assert len(invalidated_edges) == 1
223 |     assert invalidated_edges[0].uuid == 'e3'
224 | 
225 | 
226 | @pytest.mark.asyncio
227 | @pytest.mark.integration
228 | async def test_get_edge_contradictions_no_effect():
229 |     existing_edges, nodes = create_complex_test_data()
230 | 
231 |     # Create a new edge that doesn't invalidate any existing edges
232 |     new_edge = EntityEdge(
233 |         uuid='e8',
234 |         source_node_uuid='3',
235 |         target_node_uuid='4',
236 |         name='APPLIED_TO',
237 |         fact='Charlie applied to Company XYZ',
238 |         group_id='1',
239 |         created_at=utc_now(),
240 |     )
241 | 
242 |     invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
243 | 
244 |     assert len(invalidated_edges) == 0
245 | 
246 | 
247 | @pytest.mark.skip(reason='Flaky LLM-based test with non-deterministic results')
248 | @pytest.mark.asyncio
249 | @pytest.mark.integration
250 | async def test_invalidate_edges_partial_update():
251 |     existing_edges, nodes = create_complex_test_data()
252 | 
253 |     # Create a new edge that partially updates an existing one
254 |     new_edge = EntityEdge(
255 |         uuid='e9',
256 |         source_node_uuid='2',
257 |         target_node_uuid='4',
258 |         name='CHANGED_POSITION',
259 |         fact='Bob changed his position at Company XYZ',
260 |         group_id='1',
261 |         created_at=utc_now(),
262 |     )
263 | 
264 |     invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
265 | 
266 |     assert len(invalidated_edges) == 0  # The existing edge is not invalidated, just updated
267 | 
268 | 
269 | # Run the tests
270 | if __name__ == '__main__':
271 |     pytest.main([__file__])
272 | 
```
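
Outside of pytest, the same contradiction check can be invoked directly. A hedged sketch mirroring the test helpers above; it assumes `OpenAIClient()` can be constructed with defaults (the underlying SDK then reads `OPENAI_API_KEY` from the environment):

```python
import asyncio

from graphiti_core.edges import EntityEdge
from graphiti_core.llm_client import OpenAIClient
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.temporal_operations import get_edge_contradictions


async def main():
    llm = OpenAIClient()  # assumed: default config picks up OPENAI_API_KEY
    old = EntityEdge(uuid='e1', source_node_uuid='1', target_node_uuid='2',
                     name='LIKES', fact='Alice likes Bob',
                     created_at=utc_now(), group_id='1')
    new = EntityEdge(uuid='e2', source_node_uuid='1', target_node_uuid='2',
                     name='DISLIKES', fact='Alice dislikes Bob',
                     created_at=utc_now(), group_id='1')
    contradicted = await get_edge_contradictions(llm, new, [old])
    print([edge.uuid for edge in contradicted])  # expected: ['e1']


asyncio.run(main())
```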

--------------------------------------------------------------------------------
/graphiti_core/graph_queries.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Database query utilities for different graph database backends.
  3 | 
  4 | This module provides database-agnostic query generation for Neo4j and FalkorDB,
  5 | supporting index creation, fulltext search, and bulk operations.
  6 | """
  7 | 
  8 | from typing_extensions import LiteralString
  9 | 
 10 | from graphiti_core.driver.driver import GraphProvider
 11 | 
 12 | # Mapping from Neo4j fulltext index names to FalkorDB node labels
 13 | NEO4J_TO_FALKORDB_MAPPING = {
 14 |     'node_name_and_summary': 'Entity',
 15 |     'community_name': 'Community',
 16 |     'episode_content': 'Episodic',
 17 |     'edge_name_and_fact': 'RELATES_TO',
 18 | }
 19 | # Mapping from fulltext index names to Kuzu node labels
 20 | INDEX_TO_LABEL_KUZU_MAPPING = {
 21 |     'node_name_and_summary': 'Entity',
 22 |     'community_name': 'Community',
 23 |     'episode_content': 'Episodic',
 24 |     'edge_name_and_fact': 'RelatesToNode_',
 25 | }
 26 | 
 27 | 
 28 | def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
 29 |     if provider == GraphProvider.FALKORDB:
 30 |         return [
 31 |             # Entity node
 32 |             'CREATE INDEX FOR (n:Entity) ON (n.uuid, n.group_id, n.name, n.created_at)',
 33 |             # Episodic node
 34 |             'CREATE INDEX FOR (n:Episodic) ON (n.uuid, n.group_id, n.created_at, n.valid_at)',
 35 |             # Community node
 36 |             'CREATE INDEX FOR (n:Community) ON (n.uuid)',
 37 |             # RELATES_TO edge
 38 |             'CREATE INDEX FOR ()-[e:RELATES_TO]-() ON (e.uuid, e.group_id, e.name, e.created_at, e.expired_at, e.valid_at, e.invalid_at)',
 39 |             # MENTIONS edge
 40 |             'CREATE INDEX FOR ()-[e:MENTIONS]-() ON (e.uuid, e.group_id)',
 41 |             # HAS_MEMBER edge
 42 |             'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
 43 |         ]
 44 | 
 45 |     if provider == GraphProvider.KUZU:
 46 |         return []
 47 | 
 48 |     return [
 49 |         'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
 50 |         'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
 51 |         'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
 52 |         'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
 53 |         'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
 54 |         'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
 55 |         'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
 56 |         'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
 57 |         'CREATE INDEX community_group_id IF NOT EXISTS FOR (n:Community) ON (n.group_id)',
 58 |         'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
 59 |         'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
 60 |         'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
 61 |         'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
 62 |         'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
 63 |         'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
 64 |         'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
 65 |         'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
 66 |         'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
 67 |         'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
 68 |         'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
 69 |     ]
 70 | 
 71 | 
 72 | def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
 73 |     if provider == GraphProvider.FALKORDB:
 74 |         from typing import cast
 75 | 
 76 |         from graphiti_core.driver.falkordb_driver import STOPWORDS
 77 | 
 78 |         # Convert to string representation for embedding in queries
 79 |         stopwords_str = str(STOPWORDS)
 80 | 
 81 |         # Use type: ignore to satisfy LiteralString requirement while maintaining single source of truth
 82 |         return cast(
 83 |             list[LiteralString],
 84 |             [
 85 |                 f"""CALL db.idx.fulltext.createNodeIndex(
 86 |                                                 {{
 87 |                                                     label: 'Episodic',
 88 |                                                     stopwords: {stopwords_str}
 89 |                                                 }},
 90 |                                                 'content', 'source', 'source_description', 'group_id'
 91 |                                                 )""",
 92 |                 f"""CALL db.idx.fulltext.createNodeIndex(
 93 |                                                 {{
 94 |                                                     label: 'Entity',
 95 |                                                     stopwords: {stopwords_str}
 96 |                                                 }},
 97 |                                                 'name', 'summary', 'group_id'
 98 |                                                 )""",
 99 |                 f"""CALL db.idx.fulltext.createNodeIndex(
100 |                                                 {{
101 |                                                     label: 'Community',
102 |                                                     stopwords: {stopwords_str}
103 |                                                 }},
104 |                                                 'name', 'group_id'
105 |                                                 )""",
106 |                 """CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
107 |             ],
108 |         )
109 | 
110 |     if provider == GraphProvider.KUZU:
111 |         return [
112 |             "CALL CREATE_FTS_INDEX('Episodic', 'episode_content', ['content', 'source', 'source_description']);",
113 |             "CALL CREATE_FTS_INDEX('Entity', 'node_name_and_summary', ['name', 'summary']);",
114 |             "CALL CREATE_FTS_INDEX('Community', 'community_name', ['name']);",
115 |             "CALL CREATE_FTS_INDEX('RelatesToNode_', 'edge_name_and_fact', ['name', 'fact']);",
116 |         ]
117 | 
118 |     return [
119 |         """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
120 |         FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
121 |         """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
122 |         FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
123 |         """CREATE FULLTEXT INDEX community_name IF NOT EXISTS
124 |         FOR (n:Community) ON EACH [n.name, n.group_id]""",
125 |         """CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
126 |         FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
127 |     ]
128 | 
129 | 
130 | def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider) -> str:
131 |     if provider == GraphProvider.FALKORDB:
132 |         label = NEO4J_TO_FALKORDB_MAPPING[name]
133 |         return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
134 | 
135 |     if provider == GraphProvider.KUZU:
136 |         label = INDEX_TO_LABEL_KUZU_MAPPING[name]
137 |         return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)"
138 | 
139 |     return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
140 | 
141 | 
142 | def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
143 |     if provider == GraphProvider.FALKORDB:
144 |         # FalkorDB returns cosine distance, so convert it to a similarity normalized to [0, 1] to match Neo4j's normalized cosine similarity
145 |         return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
146 | 
147 |     if provider == GraphProvider.KUZU:
148 |         return f'array_cosine_similarity({vec1}, {vec2})'
149 | 
150 |     return f'vector.similarity.cosine({vec1}, {vec2})'
151 | 
152 | 
153 | def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> str:
154 |     if provider == GraphProvider.FALKORDB:
155 |         label = NEO4J_TO_FALKORDB_MAPPING[name]
156 |         return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
157 | 
158 |     if provider == GraphProvider.KUZU:
159 |         label = INDEX_TO_LABEL_KUZU_MAPPING[name]
160 |         return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)"
161 | 
162 |     return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
163 | 
```
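
To see the provider dispatch in action, the same logical fulltext lookup renders differently per backend. A small illustration using the functions above (assuming `GraphProvider.NEO4J` and `GraphProvider.FALKORDB` are members of the enum in `graphiti_core/driver/driver.py`):

```python
from graphiti_core.driver.driver import GraphProvider
from graphiti_core.graph_queries import get_nodes_query

# Neo4j (default branch): the index name is used directly, limit bound as $limit.
print(get_nodes_query('node_name_and_summary', '$query', 10, GraphProvider.NEO4J))
# CALL db.index.fulltext.queryNodes("node_name_and_summary", $query, {limit: $limit})

# FalkorDB: the index name is first mapped to a node label.
print(get_nodes_query('node_name_and_summary', '$query', 10, GraphProvider.FALKORDB))
# CALL db.idx.fulltext.queryNodes('Entity', $query)
```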

--------------------------------------------------------------------------------
/examples/azure-openai/azure_openai_neo4j.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2025, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import asyncio
 18 | import json
 19 | import logging
 20 | import os
 21 | from datetime import datetime, timezone
 22 | from logging import INFO
 23 | 
 24 | from dotenv import load_dotenv
 25 | from openai import AsyncOpenAI
 26 | 
 27 | from graphiti_core import Graphiti
 28 | from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient
 29 | from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
 30 | from graphiti_core.llm_client.config import LLMConfig
 31 | from graphiti_core.nodes import EpisodeType
 32 | 
 33 | #################################################
 34 | # CONFIGURATION
 35 | #################################################
 36 | # Set up logging and environment variables for
 37 | # connecting to Neo4j database and Azure OpenAI
 38 | #################################################
 39 | 
 40 | # Configure logging
 41 | logging.basicConfig(
 42 |     level=INFO,
 43 |     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
 44 |     datefmt='%Y-%m-%d %H:%M:%S',
 45 | )
 46 | logger = logging.getLogger(__name__)
 47 | 
 48 | load_dotenv()
 49 | 
 50 | # Neo4j connection parameters
 51 | # Make sure Neo4j Desktop is running with a local DBMS started
 52 | neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
 53 | neo4j_user = os.environ.get('NEO4J_USER', 'neo4j')
 54 | neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password')
 55 | 
 56 | # Azure OpenAI connection parameters
 57 | azure_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT')
 58 | azure_api_key = os.environ.get('AZURE_OPENAI_API_KEY')
 59 | azure_deployment = os.environ.get('AZURE_OPENAI_DEPLOYMENT', 'gpt-4.1')
 60 | azure_embedding_deployment = os.environ.get(
 61 |     'AZURE_OPENAI_EMBEDDING_DEPLOYMENT', 'text-embedding-3-small'
 62 | )
 63 | 
 64 | if not azure_endpoint or not azure_api_key:
 65 |     raise ValueError('AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY must be set')
 66 | 
 67 | 
 68 | async def main():
 69 |     #################################################
 70 |     # INITIALIZATION
 71 |     #################################################
 72 |     # Connect to Neo4j and Azure OpenAI, then set up
 73 |     # Graphiti indices. This is required before using
 74 |     # other Graphiti functionality
 75 |     #################################################
 76 | 
 77 |     # Initialize Azure OpenAI client
 78 |     azure_client = AsyncOpenAI(
 79 |         base_url=f'{azure_endpoint}/openai/v1/',
 80 |         api_key=azure_api_key,
 81 |     )
 82 | 
 83 |     # Create LLM and Embedder clients
 84 |     llm_client = AzureOpenAILLMClient(
 85 |         azure_client=azure_client,
 86 |         config=LLMConfig(model=azure_deployment, small_model=azure_deployment),
 87 |     )
 88 |     embedder_client = AzureOpenAIEmbedderClient(
 89 |         azure_client=azure_client, model=azure_embedding_deployment
 90 |     )
 91 | 
 92 |     # Initialize Graphiti with Neo4j connection and Azure OpenAI clients
 93 |     graphiti = Graphiti(
 94 |         neo4j_uri,
 95 |         neo4j_user,
 96 |         neo4j_password,
 97 |         llm_client=llm_client,
 98 |         embedder=embedder_client,
 99 |     )
100 | 
101 |     try:
102 |         #################################################
103 |         # ADDING EPISODES
104 |         #################################################
105 |         # Episodes are the primary units of information
106 |         # in Graphiti. They can be text or structured JSON
107 |         # and are automatically processed to extract entities
108 |         # and relationships.
109 |         #################################################
110 | 
111 |         # Example: Add Episodes
112 |         # Episodes list containing both text and JSON episodes
113 |         episodes = [
114 |             {
115 |                 'content': 'Kamala Harris is the Attorney General of California. She was previously '
116 |                 'the district attorney for San Francisco.',
117 |                 'type': EpisodeType.text,
118 |                 'description': 'podcast transcript',
119 |             },
120 |             {
121 |                 'content': 'As AG, Harris was in office from January 3, 2011 – January 3, 2017',
122 |                 'type': EpisodeType.text,
123 |                 'description': 'podcast transcript',
124 |             },
125 |             {
126 |                 'content': {
127 |                     'name': 'Gavin Newsom',
128 |                     'position': 'Governor',
129 |                     'state': 'California',
130 |                     'previous_role': 'Lieutenant Governor',
131 |                     'previous_location': 'San Francisco',
132 |                 },
133 |                 'type': EpisodeType.json,
134 |                 'description': 'podcast metadata',
135 |             },
136 |         ]
137 | 
138 |         # Add episodes to the graph
139 |         for i, episode in enumerate(episodes):
140 |             await graphiti.add_episode(
141 |                 name=f'California Politics {i}',
142 |                 episode_body=(
143 |                     episode['content']
144 |                     if isinstance(episode['content'], str)
145 |                     else json.dumps(episode['content'])
146 |                 ),
147 |                 source=episode['type'],
148 |                 source_description=episode['description'],
149 |                 reference_time=datetime.now(timezone.utc),
150 |             )
151 |             print(f'Added episode: California Politics {i} ({episode["type"].value})')
152 | 
153 |         #################################################
154 |         # BASIC SEARCH
155 |         #################################################
156 |         # The simplest way to retrieve relationships (edges)
157 |         # from Graphiti is using the search method, which
158 |         # performs a hybrid search combining semantic
159 |         # similarity and BM25 text retrieval.
160 |         #################################################
161 | 
162 |         # Perform a hybrid search combining semantic similarity and BM25 retrieval
163 |         print("\nSearching for: 'Who was the California Attorney General?'")
164 |         results = await graphiti.search('Who was the California Attorney General?')
165 | 
166 |         # Print search results
167 |         print('\nSearch Results:')
168 |         for result in results:
169 |             print(f'UUID: {result.uuid}')
170 |             print(f'Fact: {result.fact}')
171 |             if hasattr(result, 'valid_at') and result.valid_at:
172 |                 print(f'Valid from: {result.valid_at}')
173 |             if hasattr(result, 'invalid_at') and result.invalid_at:
174 |                 print(f'Valid until: {result.invalid_at}')
175 |             print('---')
176 | 
177 |         #################################################
178 |         # CENTER NODE SEARCH
179 |         #################################################
180 |         # For more contextually relevant results, you can
181 |         # use a center node to rerank search results based
182 |         # on their graph distance to a specific node
183 |         #################################################
184 | 
185 |         # Use the top search result's UUID as the center node for reranking
186 |         if results and len(results) > 0:
187 |             # Get the source node UUID from the top result
188 |             center_node_uuid = results[0].source_node_uuid
189 | 
190 |             print('\nReranking search results based on graph distance:')
191 |             print(f'Using center node UUID: {center_node_uuid}')
192 | 
193 |             reranked_results = await graphiti.search(
194 |                 'Who was the California Attorney General?',
195 |                 center_node_uuid=center_node_uuid,
196 |             )
197 | 
198 |             # Print reranked search results
199 |             print('\nReranked Search Results:')
200 |             for result in reranked_results:
201 |                 print(f'UUID: {result.uuid}')
202 |                 print(f'Fact: {result.fact}')
203 |                 if hasattr(result, 'valid_at') and result.valid_at:
204 |                     print(f'Valid from: {result.valid_at}')
205 |                 if hasattr(result, 'invalid_at') and result.invalid_at:
206 |                     print(f'Valid until: {result.invalid_at}')
207 |                 print('---')
208 |         else:
209 |             print('No results found in the initial search to use as center node.')
210 | 
211 |     finally:
212 |         #################################################
213 |         # CLEANUP
214 |         #################################################
215 |         # Always close the connection to Neo4j when
216 |         # finished to properly release resources
217 |         #################################################
218 | 
219 |         # Close the connection
220 |         await graphiti.close()
221 |         print('\nConnection closed')
222 | 
223 | 
224 | if __name__ == '__main__':
225 |     asyncio.run(main())
226 | 
```
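
Before running this example, the guard above requires `AZURE_OPENAI_ENDPOINT` and `AZURE_OPENAI_API_KEY`; the Neo4j and deployment settings fall back to defaults. A tiny pre-flight check using the same variable names:

```python
import os

required = ('AZURE_OPENAI_ENDPOINT', 'AZURE_OPENAI_API_KEY')
missing = [name for name in required if not os.environ.get(name)]
if missing:
    raise SystemExit(f'Missing required settings: {", ".join(missing)}')
```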

--------------------------------------------------------------------------------
/graphiti_core/llm_client/openai_generic_client.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import json
 18 | import logging
 19 | import typing
 20 | from typing import Any, ClassVar
 21 | 
 22 | import openai
 23 | from openai import AsyncOpenAI
 24 | from openai.types.chat import ChatCompletionMessageParam
 25 | from pydantic import BaseModel
 26 | 
 27 | from ..prompts.models import Message
 28 | from .client import LLMClient, get_extraction_language_instruction
 29 | from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
 30 | from .errors import RateLimitError, RefusalError
 31 | 
 32 | logger = logging.getLogger(__name__)
 33 | 
 34 | DEFAULT_MODEL = 'gpt-4.1-mini'
 35 | 
 36 | 
 37 | class OpenAIGenericClient(LLMClient):
 38 |     """
 39 |     OpenAIGenericClient is a client class for interacting with OpenAI-compatible language models.
 40 | 
 41 |     This class extends the LLMClient and provides methods to initialize the client,
 42 |     get an embedder, and generate responses from the language model.
 43 | 
 44 |     Attributes:
 45 |         client (AsyncOpenAI): The OpenAI client used to interact with the API.
 46 |         model (str): The model name to use for generating responses.
 47 |         temperature (float): The temperature to use for generating responses.
 48 |         max_tokens (int): The maximum number of tokens to generate in a response.
 49 | 
 50 |     Methods:
 51 |         __init__(config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None, max_tokens: int = 16384):
 52 |             Initializes the OpenAIGenericClient with the provided configuration, cache setting, and client.
 53 | 
 54 |         _generate_response(messages: list[Message]) -> dict[str, typing.Any]:
 55 |             Generates a response from the language model based on the provided messages.
 56 |     """
 57 | 
 58 |     # Class-level constants
 59 |     MAX_RETRIES: ClassVar[int] = 2
 60 | 
 61 |     def __init__(
 62 |         self,
 63 |         config: LLMConfig | None = None,
 64 |         cache: bool = False,
 65 |         client: typing.Any = None,
 66 |         max_tokens: int = 16384,
 67 |     ):
 68 |         """
 69 |         Initialize the OpenAIGenericClient with the provided configuration, cache setting, and client.
 70 | 
 71 |         Args:
 72 |             config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
 73 |             cache (bool): Whether to use caching for responses. Defaults to False.
 74 |             client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
 75 |             max_tokens (int): The maximum number of tokens to generate. Defaults to 16384 (16K) for better compatibility with local models.
 76 | 
 77 |         """
 78 |         # removed caching to simplify the `generate_response` override
 79 |         if cache:
 80 |             raise NotImplementedError('Caching is not implemented for OpenAI')
 81 | 
 82 |         if config is None:
 83 |             config = LLMConfig()
 84 | 
 85 |         super().__init__(config, cache)
 86 | 
 87 |         # Override max_tokens to support higher limits for local models
 88 |         self.max_tokens = max_tokens
 89 | 
 90 |         if client is None:
 91 |             self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
 92 |         else:
 93 |             self.client = client
 94 | 
 95 |     async def _generate_response(
 96 |         self,
 97 |         messages: list[Message],
 98 |         response_model: type[BaseModel] | None = None,
 99 |         max_tokens: int = DEFAULT_MAX_TOKENS,
100 |         model_size: ModelSize = ModelSize.medium,
101 |     ) -> dict[str, typing.Any]:
102 |         openai_messages: list[ChatCompletionMessageParam] = []
103 |         for m in messages:
104 |             m.content = self._clean_input(m.content)
105 |             if m.role == 'user':
106 |                 openai_messages.append({'role': 'user', 'content': m.content})
107 |             elif m.role == 'system':
108 |                 openai_messages.append({'role': 'system', 'content': m.content})
109 |         try:
110 |             # Prepare response format
111 |             response_format: dict[str, Any] = {'type': 'json_object'}
112 |             if response_model is not None:
113 |                 schema_name = getattr(response_model, '__name__', 'structured_response')
114 |                 json_schema = response_model.model_json_schema()
115 |                 response_format = {
116 |                     'type': 'json_schema',
117 |                     'json_schema': {
118 |                         'name': schema_name,
119 |                         'schema': json_schema,
120 |                     },
121 |                 }
122 | 
123 |             response = await self.client.chat.completions.create(
124 |                 model=self.model or DEFAULT_MODEL,
125 |                 messages=openai_messages,
126 |                 temperature=self.temperature,
127 |                 max_tokens=self.max_tokens,
128 |                 response_format=response_format,  # type: ignore[arg-type]
129 |             )
130 |             result = response.choices[0].message.content or ''
131 |             return json.loads(result)
132 |         except openai.RateLimitError as e:
133 |             raise RateLimitError from e
134 |         except Exception as e:
135 |             logger.error(f'Error in generating LLM response: {e}')
136 |             raise
137 | 
138 |     async def generate_response(
139 |         self,
140 |         messages: list[Message],
141 |         response_model: type[BaseModel] | None = None,
142 |         max_tokens: int | None = None,
143 |         model_size: ModelSize = ModelSize.medium,
144 |         group_id: str | None = None,
145 |         prompt_name: str | None = None,
146 |     ) -> dict[str, typing.Any]:
147 |         if max_tokens is None:
148 |             max_tokens = self.max_tokens
149 | 
150 |         # Add multilingual extraction instructions
151 |         messages[0].content += get_extraction_language_instruction(group_id)
152 | 
153 |         # Wrap entire operation in tracing span
154 |         with self.tracer.start_span('llm.generate') as span:
155 |             attributes = {
156 |                 'llm.provider': 'openai',
157 |                 'model.size': model_size.value,
158 |                 'max_tokens': max_tokens,
159 |             }
160 |             if prompt_name:
161 |                 attributes['prompt.name'] = prompt_name
162 |             span.add_attributes(attributes)
163 | 
164 |             retry_count = 0
165 |             last_error = None
166 | 
167 |             while retry_count <= self.MAX_RETRIES:
168 |                 try:
169 |                     response = await self._generate_response(
170 |                         messages, response_model, max_tokens=max_tokens, model_size=model_size
171 |                     )
172 |                     return response
173 |                 except (RateLimitError, RefusalError):
174 |                     # These errors should not trigger retries
175 |                     span.set_status('error', str(last_error))
176 |                     raise
177 |                 except (
178 |                     openai.APITimeoutError,
179 |                     openai.APIConnectionError,
180 |                     openai.InternalServerError,
181 |                 ):
182 |                     # Let OpenAI's client handle these retries
183 |                     span.set_status('error', str(last_error))
184 |                     raise
185 |                 except Exception as e:
186 |                     last_error = e
187 | 
188 |                     # Don't retry if we've hit the max retries
189 |                     if retry_count >= self.MAX_RETRIES:
190 |                         logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
191 |                         span.set_status('error', str(e))
192 |                         span.record_exception(e)
193 |                         raise
194 | 
195 |                     retry_count += 1
196 | 
197 |                     # Construct a detailed error message for the LLM
198 |                     error_context = (
199 |                         f'The previous response attempt was invalid. '
200 |                         f'Error type: {e.__class__.__name__}. '
201 |                         f'Error details: {str(e)}. '
202 |                         f'Please try again with a valid response, ensuring the output matches '
203 |                         f'the expected format and constraints.'
204 |                     )
205 | 
206 |                     error_message = Message(role='user', content=error_context)
207 |                     messages.append(error_message)
208 |                     logger.warning(
209 |                         f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
210 |                     )
211 | 
212 |             # If we somehow get here, raise the last error
213 |             span.set_status('error', str(last_error))
214 |             raise last_error or Exception('Max retries exceeded with no specific error')
215 | 
```
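
Because this client relies only on the OpenAI-compatible chat-completions surface, it can target local servers as well. A sketch with placeholder values; the URL and model id below are assumptions for whatever OpenAI-compatible server you run (Ollama, for example, exposes one under `/v1`):

```python
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.openai_generic_client import OpenAIGenericClient

client = OpenAIGenericClient(
    config=LLMConfig(
        api_key='not-needed-for-local',        # many local servers ignore the key
        model='llama3.1:8b',                   # placeholder model id
        base_url='http://localhost:11434/v1',  # placeholder OpenAI-compatible endpoint
    ),
    max_tokens=16384,  # constructor default, raised for local-model compatibility
)
```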

--------------------------------------------------------------------------------
/graphiti_core/llm_client/client.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import hashlib
 18 | import json
 19 | import logging
 20 | import typing
 21 | from abc import ABC, abstractmethod
 22 | 
 23 | import httpx
 24 | from diskcache import Cache
 25 | from pydantic import BaseModel
 26 | from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
 27 | 
 28 | from ..prompts.models import Message
 29 | from ..tracer import NoOpTracer, Tracer
 30 | from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
 31 | from .errors import RateLimitError
 32 | 
 33 | DEFAULT_TEMPERATURE = 0
 34 | DEFAULT_CACHE_DIR = './llm_cache'
 35 | 
 36 | 
 37 | def get_extraction_language_instruction(group_id: str | None = None) -> str:
 38 |     """Returns instruction for language extraction behavior.
 39 | 
 40 |     Override this function to customize language extraction:
 41 |     - Return empty string to disable multilingual instructions
 42 |     - Return custom instructions for specific language requirements
 43 |     - Use group_id to provide different instructions per group/partition
 44 | 
 45 |     Args:
 46 |         group_id: Optional partition identifier for the graph
 47 | 
 48 |     Returns:
 49 |         str: Language instruction to append to system messages
 50 |     """
 51 |     return '\n\nAny extracted information should be returned in the same language as it was written in.'
 52 | 
 53 | 
 54 | logger = logging.getLogger(__name__)
 55 | 
 56 | 
 57 | def is_server_or_retry_error(exception):
 58 |     if isinstance(exception, RateLimitError | json.decoder.JSONDecodeError):
 59 |         return True
 60 | 
 61 |     return (
 62 |         isinstance(exception, httpx.HTTPStatusError) and 500 <= exception.response.status_code < 600
 63 |     )
 64 | 
 65 | 
 66 | class LLMClient(ABC):
 67 |     def __init__(self, config: LLMConfig | None, cache: bool = False):
 68 |         if config is None:
 69 |             config = LLMConfig()
 70 | 
 71 |         self.config = config
 72 |         self.model = config.model
 73 |         self.small_model = config.small_model
 74 |         self.temperature = config.temperature
 75 |         self.max_tokens = config.max_tokens
 76 |         self.cache_enabled = cache
 77 |         self.cache_dir = None
 78 |         self.tracer: Tracer = NoOpTracer()
 79 | 
 80 |         # Only create the cache directory if caching is enabled
 81 |         if self.cache_enabled:
 82 |             self.cache_dir = Cache(DEFAULT_CACHE_DIR)
 83 | 
 84 |     def set_tracer(self, tracer: Tracer) -> None:
 85 |         """Set the tracer for this LLM client."""
 86 |         self.tracer = tracer
 87 | 
 88 |     def _clean_input(self, input: str) -> str:
 89 |         """Clean input string of invalid unicode and control characters.
 90 | 
 91 |         Args:
 92 |             input: Raw input string to be cleaned
 93 | 
 94 |         Returns:
 95 |             Cleaned string safe for LLM processing
 96 |         """
 97 |         # Clean any invalid Unicode
 98 |         cleaned = input.encode('utf-8', errors='ignore').decode('utf-8')
 99 | 
100 |         # Remove zero-width characters and other invisible unicode
101 |         zero_width = '\u200b\u200c\u200d\ufeff\u2060'
102 |         for char in zero_width:
103 |             cleaned = cleaned.replace(char, '')
104 | 
105 |         # Remove control characters except newlines, returns, and tabs
106 |         cleaned = ''.join(char for char in cleaned if ord(char) >= 32 or char in '\n\r\t')
107 | 
108 |         return cleaned
109 | 
110 |     @retry(
111 |         stop=stop_after_attempt(4),
112 |         wait=wait_random_exponential(multiplier=10, min=5, max=120),
113 |         retry=retry_if_exception(is_server_or_retry_error),
114 |         after=lambda retry_state: logger.warning(
115 |             f'Retrying {retry_state.fn.__name__ if retry_state.fn else "function"} after {retry_state.attempt_number} attempts...'
116 |         )
117 |         if retry_state.attempt_number > 1
118 |         else None,
119 |         reraise=True,
120 |     )
121 |     async def _generate_response_with_retry(
122 |         self,
123 |         messages: list[Message],
124 |         response_model: type[BaseModel] | None = None,
125 |         max_tokens: int = DEFAULT_MAX_TOKENS,
126 |         model_size: ModelSize = ModelSize.medium,
127 |     ) -> dict[str, typing.Any]:
128 |         try:
129 |             return await self._generate_response(messages, response_model, max_tokens, model_size)
130 |         except (httpx.HTTPStatusError, RateLimitError) as e:
131 |             raise e
132 | 
133 |     @abstractmethod
134 |     async def _generate_response(
135 |         self,
136 |         messages: list[Message],
137 |         response_model: type[BaseModel] | None = None,
138 |         max_tokens: int = DEFAULT_MAX_TOKENS,
139 |         model_size: ModelSize = ModelSize.medium,
140 |     ) -> dict[str, typing.Any]:
141 |         pass
142 | 
143 |     def _get_cache_key(self, messages: list[Message]) -> str:
144 |         # Create a unique cache key based on the messages and model
145 |         message_str = json.dumps([m.model_dump() for m in messages], sort_keys=True)
146 |         key_str = f'{self.model}:{message_str}'
147 |         return hashlib.md5(key_str.encode()).hexdigest()
148 | 
149 |     async def generate_response(
150 |         self,
151 |         messages: list[Message],
152 |         response_model: type[BaseModel] | None = None,
153 |         max_tokens: int | None = None,
154 |         model_size: ModelSize = ModelSize.medium,
155 |         group_id: str | None = None,
156 |         prompt_name: str | None = None,
157 |     ) -> dict[str, typing.Any]:
158 |         if max_tokens is None:
159 |             max_tokens = self.max_tokens
160 | 
161 |         if response_model is not None:
162 |             serialized_model = json.dumps(response_model.model_json_schema())
163 |             messages[
164 |                 -1
165 |             ].content += (
166 |                 f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}'
167 |             )
168 | 
169 |         # Add multilingual extraction instructions
170 |         messages[0].content += get_extraction_language_instruction(group_id)
171 | 
172 |         for message in messages:
173 |             message.content = self._clean_input(message.content)
174 | 
175 |         # Wrap entire operation in tracing span
176 |         with self.tracer.start_span('llm.generate') as span:
177 |             attributes = {
178 |                 'llm.provider': self._get_provider_type(),
179 |                 'model.size': model_size.value,
180 |                 'max_tokens': max_tokens,
181 |                 'cache.enabled': self.cache_enabled,
182 |             }
183 |             if prompt_name:
184 |                 attributes['prompt.name'] = prompt_name
185 |             span.add_attributes(attributes)
186 | 
187 |             # Check cache first
188 |             if self.cache_enabled and self.cache_dir is not None:
189 |                 cache_key = self._get_cache_key(messages)
190 |                 cached_response = self.cache_dir.get(cache_key)
191 |                 if cached_response is not None:
192 |                     logger.debug(f'Cache hit for {cache_key}')
193 |                     span.add_attributes({'cache.hit': True})
194 |                     return cached_response
195 | 
196 |             span.add_attributes({'cache.hit': False})
197 | 
198 |             # Execute LLM call
199 |             try:
200 |                 response = await self._generate_response_with_retry(
201 |                     messages, response_model, max_tokens, model_size
202 |                 )
203 |             except Exception as e:
204 |                 span.set_status('error', str(e))
205 |                 span.record_exception(e)
206 |                 raise
207 | 
208 |             # Cache response if enabled
209 |             if self.cache_enabled and self.cache_dir is not None:
210 |                 cache_key = self._get_cache_key(messages)
211 |                 self.cache_dir.set(cache_key, response)
212 | 
213 |             return response
214 | 
215 |     def _get_provider_type(self) -> str:
216 |         """Get provider type from class name."""
217 |         class_name = self.__class__.__name__.lower()
218 |         if 'openai' in class_name:
219 |             return 'openai'
220 |         elif 'anthropic' in class_name:
221 |             return 'anthropic'
222 |         elif 'gemini' in class_name:
223 |             return 'gemini'
224 |         elif 'groq' in class_name:
225 |             return 'groq'
226 |         else:
227 |             return 'unknown'
228 | 
229 |     def _get_failed_generation_log(self, messages: list[Message], output: str | None) -> str:
230 |         """
231 |         Build a log string with the full input messages and the raw output (if any) for debugging failed generations.
232 |         """
233 |         log = ''
234 |         log += f'Input messages: {json.dumps([m.model_dump() for m in messages], indent=2)}\n'
235 |         if output is not None:
236 |             if len(output) > 4000:
237 |                 log += f'Raw output: {output[:2000]}... (truncated) ...{output[-2000:]}\n'
238 |             else:
239 |                 log += f'Raw output: {output}\n'
240 |         else:
241 |             log += 'No raw output available'
242 |         return log
243 | 
```
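
Subclasses only need to implement `_generate_response`; retries, caching, tracing, and schema injection all come from the base class above. Below is a minimal sketch of a concrete subclass — `EchoLLMClient` and `demo` are hypothetical names, and the canned dict stands in for a real provider call:

```python
import asyncio
import typing

from pydantic import BaseModel

from graphiti_core.llm_client.client import LLMClient
from graphiti_core.llm_client.config import DEFAULT_MAX_TOKENS, ModelSize
from graphiti_core.prompts.models import Message


class EchoLLMClient(LLMClient):
    """Toy client: returns a canned payload instead of calling a provider."""

    async def _generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        model_size: ModelSize = ModelSize.medium,
    ) -> dict[str, typing.Any]:
        # A real implementation would call the provider SDK here
        # and parse its reply into a dict.
        return {'echo': messages[-1].content[:40]}


async def demo() -> None:
    client = EchoLLMClient(config=None)  # config=None falls back to a default LLMConfig
    result = await client.generate_response(
        [
            Message(role='system', content='You are a test.'),
            Message(role='user', content='Say hello.'),
        ]
    )
    print(result)  # {'echo': 'Say hello.'}


if __name__ == '__main__':
    asyncio.run(demo())
```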

--------------------------------------------------------------------------------
/graphiti_core/prompts/dedupe_nodes.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | from typing import Any, Protocol, TypedDict
 18 | 
 19 | from pydantic import BaseModel, Field
 20 | 
 21 | from .models import Message, PromptFunction, PromptVersion
 22 | from .prompt_helpers import to_prompt_json
 23 | 
 24 | 
 25 | class NodeDuplicate(BaseModel):
 26 |     id: int = Field(..., description='integer id of the entity')
 27 |     duplicate_idx: int = Field(
 28 |         ...,
 29 |         description='idx of the duplicate entity. If no duplicate entities are found, default to -1.',
 30 |     )
 31 |     name: str = Field(
 32 |         ...,
 33 |         description='Name of the entity. Should be the most complete and descriptive name of the entity. Do not include any JSON formatting in the Entity name such as {}.',
 34 |     )
 35 |     duplicates: list[int] = Field(
 36 |         ...,
 37 |         description='idx of all entities that are a duplicate of the entity with the above id.',
 38 |     )
 39 | 
 40 | 
 41 | class NodeResolutions(BaseModel):
 42 |     entity_resolutions: list[NodeDuplicate] = Field(..., description='List of resolved nodes')
 43 | 
 44 | 
 45 | class Prompt(Protocol):
 46 |     node: PromptVersion
 47 |     node_list: PromptVersion
 48 |     nodes: PromptVersion
 49 | 
 50 | 
 51 | class Versions(TypedDict):
 52 |     node: PromptFunction
 53 |     node_list: PromptFunction
 54 |     nodes: PromptFunction
 55 | 
 56 | 
 57 | def node(context: dict[str, Any]) -> list[Message]:
 58 |     return [
 59 |         Message(
 60 |             role='system',
 61 |             content='You are a helpful assistant that determines whether or not a NEW ENTITY is a duplicate of any EXISTING ENTITIES.',
 62 |         ),
 63 |         Message(
 64 |             role='user',
 65 |             content=f"""
 66 |         <PREVIOUS MESSAGES>
 67 |         {to_prompt_json([ep for ep in context['previous_episodes']])}
 68 |         </PREVIOUS MESSAGES>
 69 |         <CURRENT MESSAGE>
 70 |         {context['episode_content']}
 71 |         </CURRENT MESSAGE>
 72 |         <NEW ENTITY>
 73 |         {to_prompt_json(context['extracted_node'])}
 74 |         </NEW ENTITY>
 75 |         <ENTITY TYPE DESCRIPTION>
 76 |         {to_prompt_json(context['entity_type_description'])}
 77 |         </ENTITY TYPE DESCRIPTION>
 78 | 
 79 |         <EXISTING ENTITIES>
 80 |         {to_prompt_json(context['existing_nodes'])}
 81 |         </EXISTING ENTITIES>
 82 |         
 83 |         Given the above EXISTING ENTITIES and their attributes, the CURRENT MESSAGE, and the PREVIOUS MESSAGES, determine whether the NEW ENTITY extracted from the conversation
 84 |         is a duplicate of one of the EXISTING ENTITIES.
 85 |         
 86 |         Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
 87 |         Semantic equivalence: if a descriptive label in EXISTING ENTITIES clearly refers to a named entity in the context, treat them as duplicates.
 88 | 
 89 |         Do NOT mark entities as duplicates if:
 90 |         - They are related but distinct.
 91 |         - They have similar names or purposes but refer to separate instances or concepts.
 92 | 
 93 |         TASK:
 94 |         1. Compare the NEW ENTITY against each item in EXISTING ENTITIES.
 95 |         2. If it refers to the same real-world object or concept, collect its idx.
 96 |         3. Let `duplicate_idx` = the smallest collected idx, or -1 if none.
 97 |         4. Let `duplicates` = the sorted list of all collected idx values (empty list if none).
 98 | 
 99 |         Respond with a JSON object containing an "entity_resolutions" array with a single entry:
100 |         {{
101 |             "entity_resolutions": [
102 |                 {{
103 |                     "id": integer id from NEW ENTITY,
104 |                     "name": the best full name for the entity,
105 |                     "duplicate_idx": integer index of the best duplicate in EXISTING ENTITIES, or -1 if none,
106 |                     "duplicates": sorted list of all duplicate indices you collected (deduplicate the list, use [] when none)
107 |                 }}
108 |             ]
109 |         }}
110 | 
111 |         Only reference indices that appear in EXISTING ENTITIES, and return [] / -1 when unsure.
112 |         """,
113 |         ),
114 |     ]
115 | 
116 | 
117 | def nodes(context: dict[str, Any]) -> list[Message]:
118 |     return [
119 |         Message(
120 |             role='system',
121 |             content='You are a helpful assistant that determines whether or not ENTITIES extracted from a conversation are duplicates'
122 |             ' of existing entities.',
123 |         ),
124 |         Message(
125 |             role='user',
126 |             content=f"""
127 |         <PREVIOUS MESSAGES>
128 |         {to_prompt_json([ep for ep in context['previous_episodes']])}
129 |         </PREVIOUS MESSAGES>
130 |         <CURRENT MESSAGE>
131 |         {context['episode_content']}
132 |         </CURRENT MESSAGE>
133 | 
134 | 
135 |         Each of the following ENTITIES was extracted from the CURRENT MESSAGE.
136 |         Each entity in ENTITIES is represented as a JSON object with the following structure:
137 |         {{
138 |             id: integer id of the entity,
139 |             name: "name of the entity",
140 |             entity_type: ["Entity", "<optional additional label>", ...],
141 |             entity_type_description: "Description of what the entity type represents"
142 |         }}
143 | 
144 |         <ENTITIES>
145 |         {to_prompt_json(context['extracted_nodes'])}
146 |         </ENTITIES>
147 | 
148 |         <EXISTING ENTITIES>
149 |         {to_prompt_json(context['existing_nodes'])}
150 |         </EXISTING ENTITIES>
151 | 
152 |         Each entry in EXISTING ENTITIES is an object with the following structure:
153 |         {{
154 |             idx: integer index of the candidate entity (use this when referencing a duplicate),
155 |             name: "name of the candidate entity",
156 |             entity_types: ["Entity", "<optional additional label>", ...],
157 |             ...<additional attributes such as summaries or metadata>
158 |         }}
159 | 
160 |         For each of the above ENTITIES, determine if the entity is a duplicate of any of the EXISTING ENTITIES.
161 | 
162 |         Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
163 | 
164 |         Do NOT mark entities as duplicates if:
165 |         - They are related but distinct.
166 |         - They have similar names or purposes but refer to separate instances or concepts.
167 | 
168 |         Task:
169 |         ENTITIES contains {len(context['extracted_nodes'])} entities with IDs 0 through {len(context['extracted_nodes']) - 1}.
170 |         Your response MUST include EXACTLY {len(context['extracted_nodes'])} resolutions with IDs 0 through {len(context['extracted_nodes']) - 1}. Do not skip or add IDs.
171 | 
172 |         For every entity, return an object with the following keys:
173 |         {{
174 |             "id": integer id from ENTITIES,
175 |             "name": the best full name for the entity (preserve the original name unless a duplicate has a more complete name),
176 |             "duplicate_idx": the idx of the EXISTING ENTITY that is the best duplicate match, or -1 if there is no duplicate,
177 |             "duplicates": a sorted list of all idx values from EXISTING ENTITIES that refer to duplicates (deduplicate the list, use [] when none or unsure)
178 |         }}
179 | 
180 |         - Only use idx values that appear in EXISTING ENTITIES.
181 |         - Set duplicate_idx to the smallest idx you collected for that entity, or -1 if duplicates is empty.
182 |         - Never fabricate entities or indices.
183 |         """,
184 |         ),
185 |     ]
186 | 
187 | 
188 | def node_list(context: dict[str, Any]) -> list[Message]:
189 |     return [
190 |         Message(
191 |             role='system',
192 |             content='You are a helpful assistant that de-duplicates nodes from node lists.',
193 |         ),
194 |         Message(
195 |             role='user',
196 |             content=f"""
197 |         Given the following context, deduplicate a list of nodes:
198 | 
199 |         Nodes:
200 |         {to_prompt_json(context['nodes'])}
201 | 
202 |         Task:
203 |         1. Group nodes together such that all duplicate nodes are in the same list of uuids
204 |         2. All duplicate uuids should be grouped together in the same list
205 |         3. Also return a short summary that synthesizes the summaries of the grouped nodes
206 | 
207 |         Guidelines:
208 |         1. Each uuid from the list of nodes should appear EXACTLY once in your response
209 |         2. If a node has no duplicates, it should appear in the response in a list containing only its own uuid
210 | 
211 |         Respond with a JSON object in the following format:
212 |         {{
213 |             "nodes": [
214 |                 {{
215 |                     "uuids": ["5d643020624c42fa9de13f97b1b3fa39", "node that is a duplicate of 5d643020624c42fa9de13f97b1b3fa39"],
216 |                     "summary": "Brief summary of the node summaries that appear in the list of names."
217 |                 }}
218 |             ]
219 |         }}
220 |         """,
221 |         ),
222 |     ]
223 | 
224 | 
225 | versions: Versions = {'node': node, 'node_list': node_list, 'nodes': nodes}
226 | 
```
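
The `versions` mapping exposes these prompt builders as plain functions from a context dict to a list of `Message`s. A quick rendering sketch for the single-entity `node` prompt; the context field shapes below are inferred from the template above and are illustrative, not a documented contract:

```python
from graphiti_core.prompts.dedupe_nodes import versions

context = {
    'previous_episodes': ['Alice joined Acme in 2020.'],
    'episode_content': 'Alice Smith now leads the Acme platform team.',
    'extracted_node': {'id': 0, 'name': 'Alice Smith', 'entity_type': ['Entity']},
    'entity_type_description': 'A person mentioned in the conversation.',
    'existing_nodes': [
        {'idx': 0, 'name': 'Alice Smith', 'entity_types': ['Entity']},
        {'idx': 1, 'name': 'Acme', 'entity_types': ['Entity']},
    ],
}

# versions['node'] returns a list[Message]: a system message framing the task
# and a user message embedding the serialized context.
for message in versions['node'](context):
    print(f'[{message.role}] {message.content[:80]}...')
```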

--------------------------------------------------------------------------------
/graphiti_core/search/search_filters.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | from datetime import datetime
 18 | from enum import Enum
 19 | from typing import Any
 20 | 
 21 | from pydantic import BaseModel, Field
 22 | 
 23 | from graphiti_core.driver.driver import GraphProvider
 24 | 
 25 | 
 26 | class ComparisonOperator(Enum):
 27 |     equals = '='
 28 |     not_equals = '<>'
 29 |     greater_than = '>'
 30 |     less_than = '<'
 31 |     greater_than_equal = '>='
 32 |     less_than_equal = '<='
 33 |     is_null = 'IS NULL'
 34 |     is_not_null = 'IS NOT NULL'
 35 | 
 36 | 
 37 | class DateFilter(BaseModel):
 38 |     date: datetime | None = Field(description='A datetime to filter on')
 39 |     comparison_operator: ComparisonOperator = Field(
 40 |         description='Comparison operator for date filter'
 41 |     )
 42 | 
 43 | 
 44 | class SearchFilters(BaseModel):
 45 |     node_labels: list[str] | None = Field(
 46 |         default=None, description='List of node labels to filter on'
 47 |     )
 48 |     edge_types: list[str] | None = Field(
 49 |         default=None, description='List of edge types to filter on'
 50 |     )
 51 |     valid_at: list[list[DateFilter]] | None = Field(default=None)
 52 |     invalid_at: list[list[DateFilter]] | None = Field(default=None)
 53 |     created_at: list[list[DateFilter]] | None = Field(default=None)
 54 |     expired_at: list[list[DateFilter]] | None = Field(default=None)
 55 |     edge_uuids: list[str] | None = Field(default=None)
 56 | 
 57 | 
 58 | def cypher_to_opensearch_operator(op: ComparisonOperator) -> str:
 59 |     mapping = {
 60 |         ComparisonOperator.greater_than: 'gt',
 61 |         ComparisonOperator.less_than: 'lt',
 62 |         ComparisonOperator.greater_than_equal: 'gte',
 63 |         ComparisonOperator.less_than_equal: 'lte',
 64 |     }
 65 |     return mapping.get(op, op.value)
 66 | 
 67 | 
 68 | def node_search_filter_query_constructor(
 69 |     filters: SearchFilters,
 70 |     provider: GraphProvider,
 71 | ) -> tuple[list[str], dict[str, Any]]:
 72 |     filter_queries: list[str] = []
 73 |     filter_params: dict[str, Any] = {}
 74 | 
 75 |     if filters.node_labels is not None:
 76 |         if provider == GraphProvider.KUZU:
 77 |             node_label_filter = 'list_has_all(n.labels, $labels)'
 78 |             filter_params['labels'] = filters.node_labels
 79 |         else:
 80 |             node_labels = '|'.join(filters.node_labels)
 81 |             node_label_filter = 'n:' + node_labels
 82 |         filter_queries.append(node_label_filter)
 83 | 
 84 |     return filter_queries, filter_params
 85 | 
 86 | 
 87 | def date_filter_query_constructor(
 88 |     value_name: str, param_name: str, operator: ComparisonOperator
 89 | ) -> str:
 90 |     query = '(' + value_name + ' '
 91 | 
 92 |     if operator == ComparisonOperator.is_null or operator == ComparisonOperator.is_not_null:
 93 |         query += operator.value + ')'
 94 |     else:
 95 |         query += operator.value + ' ' + param_name + ')'
 96 | 
 97 |     return query
 98 | 
 99 | 
100 | def edge_search_filter_query_constructor(
101 |     filters: SearchFilters,
102 |     provider: GraphProvider,
103 | ) -> tuple[list[str], dict[str, Any]]:
104 |     filter_queries: list[str] = []
105 |     filter_params: dict[str, Any] = {}
106 | 
107 |     if filters.edge_types is not None:
108 |         edge_types = filters.edge_types
109 |         filter_queries.append('e.name in $edge_types')
110 |         filter_params['edge_types'] = edge_types
111 | 
112 |     if filters.edge_uuids is not None:
113 |         filter_queries.append('e.uuid in $edge_uuids')
114 |         filter_params['edge_uuids'] = filters.edge_uuids
115 | 
116 |     if filters.node_labels is not None:
117 |         if provider == GraphProvider.KUZU:
118 |             node_label_filter = (
119 |                 'list_has_all(n.labels, $labels) AND list_has_all(m.labels, $labels)'
120 |             )
121 |             filter_params['labels'] = filters.node_labels
122 |         else:
123 |             node_labels = '|'.join(filters.node_labels)
124 |             node_label_filter = 'n:' + node_labels + ' AND m:' + node_labels
125 |         filter_queries.append(node_label_filter)
126 | 
127 |     if filters.valid_at is not None:
128 |         valid_at_filter = '('
129 |         for i, or_list in enumerate(filters.valid_at):
130 |             for j, date_filter in enumerate(or_list):
131 |                 if date_filter.comparison_operator not in [
132 |                     ComparisonOperator.is_null,
133 |                     ComparisonOperator.is_not_null,
134 |                 ]:
135 |                     filter_params[f'valid_at_{i}_{j}'] = date_filter.date
136 | 
137 |             and_filters = [
138 |                 date_filter_query_constructor(
139 |                     'e.valid_at', f'$valid_at_{i}_{j}', date_filter.comparison_operator
140 |                 )
141 |                 for j, date_filter in enumerate(or_list)
142 |             ]
143 |             and_filter_query = ''
144 |             for j, and_filter in enumerate(and_filters):
145 |                 and_filter_query += and_filter
146 |                 if j != len(and_filters) - 1:
147 |                     and_filter_query += ' AND '
148 | 
149 |             valid_at_filter += and_filter_query
150 | 
151 |             if i == len(filters.valid_at) - 1:
152 |                 valid_at_filter += ')'
153 |             else:
154 |                 valid_at_filter += ' OR '
155 | 
156 |         filter_queries.append(valid_at_filter)
157 | 
158 |     if filters.invalid_at is not None:
159 |         invalid_at_filter = '('
160 |         for i, or_list in enumerate(filters.invalid_at):
161 |             for j, date_filter in enumerate(or_list):
162 |                 if date_filter.comparison_operator not in [
163 |                     ComparisonOperator.is_null,
164 |                     ComparisonOperator.is_not_null,
165 |                 ]:
166 |                     filter_params[f'invalid_at_{i}_{j}'] = date_filter.date
167 | 
168 |             and_filters = [
169 |                 date_filter_query_constructor(
170 |                     'e.invalid_at', f'$invalid_at_{i}_{j}', date_filter.comparison_operator
171 |                 )
172 |                 for j, date_filter in enumerate(or_list)
173 |             ]
174 |             and_filter_query = ''
175 |             for j, and_filter in enumerate(and_filters):
176 |                 and_filter_query += and_filter
177 |                 if j != len(and_filters) - 1:
178 |                     and_filter_query += ' AND '
179 | 
180 |             invalid_at_filter += and_filter_query
181 | 
182 |             if i == len(filters.invalid_at) - 1:
183 |                 invalid_at_filter += ')'
184 |             else:
185 |                 invalid_at_filter += ' OR '
186 | 
187 |         filter_queries.append(invalid_at_filter)
188 | 
189 |     if filters.created_at is not None:
190 |         created_at_filter = '('
191 |         for i, or_list in enumerate(filters.created_at):
192 |             for j, date_filter in enumerate(or_list):
193 |                 if date_filter.comparison_operator not in [
194 |                     ComparisonOperator.is_null,
195 |                     ComparisonOperator.is_not_null,
196 |                 ]:
197 |                     filter_params[f'created_at_{i}_{j}'] = date_filter.date
198 | 
199 |             and_filters = [
200 |                 date_filter_query_constructor(
201 |                     'e.created_at', f'$created_at_{i}_{j}', date_filter.comparison_operator
202 |                 )
203 |                 for j, date_filter in enumerate(or_list)
204 |             ]
205 |             and_filter_query = ''
206 |             for j, and_filter in enumerate(and_filters):
207 |                 and_filter_query += and_filter
208 |                 if j != len(and_filters) - 1:
209 |                     and_filter_query += ' AND '
210 | 
211 |             created_at_filter += and_filter_query
212 | 
213 |             if i == len(filters.created_at) - 1:
214 |                 created_at_filter += ')'
215 |             else:
216 |                 created_at_filter += ' OR '
217 | 
218 |         filter_queries.append(created_at_filter)
219 | 
220 |     if filters.expired_at is not None:
221 |         expired_at_filter = '('
222 |         for i, or_list in enumerate(filters.expired_at):
223 |             for j, date_filter in enumerate(or_list):
224 |                 if date_filter.comparison_operator not in [
225 |                     ComparisonOperator.is_null,
226 |                     ComparisonOperator.is_not_null,
227 |                 ]:
228 |                     filter_params[f'expired_at_{i}_{j}'] = date_filter.date
229 | 
230 |             and_filters = [
231 |                 date_filter_query_constructor(
232 |                     'e.expired_at', f'$expired_at_{i}_{j}', date_filter.comparison_operator
233 |                 )
234 |                 for j, date_filter in enumerate(or_list)
235 |             ]
236 |             and_filter_query = ''
237 |             for j, and_filter in enumerate(and_filters):
238 |                 and_filter_query += and_filter
239 |                 if j != len(and_filters) - 1:
240 |                     and_filter_query += ' AND '
241 | 
242 |             expired_at_filter += and_filter_query
243 | 
244 |             if i == len(filters.expired_at) - 1:
245 |                 expired_at_filter += ')'
246 |             else:
247 |                 expired_at_filter += ' OR '
248 | 
249 |         filter_queries.append(expired_at_filter)
250 | 
251 |     return filter_queries, filter_params
252 | 
```
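
A usage sketch for the edge filter constructor. It assumes `GraphProvider.NEO4J` exists alongside the `KUZU` member referenced above; each inner list is an AND group, and the outer list ORs the groups together:

```python
from datetime import datetime, timezone

from graphiti_core.driver.driver import GraphProvider
from graphiti_core.search.search_filters import (
    ComparisonOperator,
    DateFilter,
    SearchFilters,
    edge_search_filter_query_constructor,
)

# Edges created after 2024-01-01 that are still valid (invalid_at IS NULL).
filters = SearchFilters(
    created_at=[
        [
            DateFilter(
                date=datetime(2024, 1, 1, tzinfo=timezone.utc),
                comparison_operator=ComparisonOperator.greater_than,
            )
        ]
    ],
    invalid_at=[[DateFilter(date=None, comparison_operator=ComparisonOperator.is_null)]],
)

queries, params = edge_search_filter_query_constructor(filters, GraphProvider.NEO4J)
print(queries)  # WHERE fragments, e.g. '(e.invalid_at IS NULL)'
print(params)   # bound datetime parameters keyed by filter position
```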

--------------------------------------------------------------------------------
/mcp_server/tests/test_http_integration.py:
--------------------------------------------------------------------------------

```python
  1 | #!/usr/bin/env python3
  2 | """
  3 | Integration test for MCP server using HTTP streaming transport.
  4 | This avoids the stdio subprocess timing issues.
  5 | """
  6 | 
  7 | import asyncio
  8 | import json
  9 | import sys
 10 | import time
 11 | 
 12 | from mcp.client.session import ClientSession
 13 | 
 14 | 
 15 | async def test_http_transport(base_url: str = 'http://localhost:8000'):
 16 |     """Test MCP server with HTTP streaming transport."""
 17 | 
 18 |     # Import the streamable http client
 19 |     try:
 20 |         from mcp.client.streamable_http import streamablehttp_client as http_client
 21 |     except ImportError:
 22 |         print('❌ Streamable HTTP client not available in MCP SDK')
 23 |         return False
 24 | 
 25 |     test_group_id = f'test_http_{int(time.time())}'
 26 | 
 27 |     print('🚀 Testing MCP Server with HTTP streaming transport')
 28 |     print(f'   Server URL: {base_url}')
 29 |     print(f'   Test Group: {test_group_id}')
 30 |     print('=' * 60)
 31 | 
 32 |     try:
 33 |         # Connect to the server via HTTP
 34 |         print('\n🔌 Connecting to server...')
 35 |         async with http_client(base_url) as (read_stream, write_stream, _), \
 36 |                 ClientSession(read_stream, write_stream) as session:
 37 |             await session.initialize()
 38 |             print('✅ Connected successfully')
 39 | 
 40 |             # Test 1: List tools
 41 |             print('\n📋 Test 1: Listing tools...')
 42 |             try:
 43 |                 result = await session.list_tools()
 44 |                 tools = [tool.name for tool in result.tools]
 45 | 
 46 |                 expected = [
 47 |                     'add_memory',
 48 |                     'search_memory_nodes',
 49 |                     'search_memory_facts',
 50 |                     'get_episodes',
 51 |                     'delete_episode',
 52 |                     'clear_graph',
 53 |                 ]
 54 | 
 55 |                 found = [t for t in expected if t in tools]
 56 |                 print(f'   ✅ Found {len(tools)} tools ({len(found)}/{len(expected)} expected)')
 57 |                 for tool in tools[:5]:
 58 |                     print(f'      - {tool}')
 59 | 
 60 |             except Exception as e:
 61 |                 print(f'   ❌ Failed: {e}')
 62 |                 return False
 63 | 
 64 |             # Test 2: Add memory
 65 |             print('\n📝 Test 2: Adding memory...')
 66 |             try:
 67 |                 result = await session.call_tool(
 68 |                     'add_memory',
 69 |                     {
 70 |                         'name': 'Integration Test Episode',
 71 |                         'episode_body': 'This is a test episode created via HTTP transport integration test.',
 72 |                         'group_id': test_group_id,
 73 |                         'source': 'text',
 74 |                         'source_description': 'HTTP Integration Test',
 75 |                     },
 76 |                 )
 77 | 
 78 |                 if result.content and result.content[0].text:
 79 |                     response = result.content[0].text
 80 |                     if 'success' in response.lower() or 'queued' in response.lower():
 81 |                         print('   ✅ Memory added successfully')
 82 |                     else:
 83 |                         print(f'   ❌ Unexpected response: {response[:100]}')
 84 |                 else:
 85 |                     print('   ❌ No content in response')
 86 | 
 87 |             except Exception as e:
 88 |                 print(f'   ❌ Failed: {e}')
 89 | 
 90 |             # Test 3: Search nodes (with delay for processing)
 91 |             print('\n🔍 Test 3: Searching nodes...')
 92 |             await asyncio.sleep(2)  # Wait for async processing
 93 | 
 94 |             try:
 95 |                 result = await session.call_tool(
 96 |                     'search_memory_nodes',
 97 |                     {'query': 'integration test episode', 'group_ids': [test_group_id], 'limit': 5},
 98 |                 )
 99 | 
100 |                 if result.content and result.content[0].text:
101 |                     response = result.content[0].text
102 |                     try:
103 |                         data = json.loads(response)
104 |                         nodes = data.get('nodes', [])
105 |                         print(f'   ✅ Search returned {len(nodes)} nodes')
106 |                     except Exception:
107 |                         print(f'   ✅ Search completed: {response[:100]}')
108 |                 else:
109 |                     print('   ⚠️  No results (may be processing)')
110 | 
111 |             except Exception as e:
112 |                 print(f'   ❌ Failed: {e}')
113 | 
114 |             # Test 4: Get episodes
115 |             print('\n📚 Test 4: Getting episodes...')
116 |             try:
117 |                 result = await session.call_tool(
118 |                     'get_episodes', {'group_ids': [test_group_id], 'limit': 10}
119 |                 )
120 | 
121 |                 if result.content and result.content[0].text:
122 |                     response = result.content[0].text
123 |                     try:
124 |                         data = json.loads(response)
125 |                         episodes = data.get('episodes', [])
126 |                         print(f'   ✅ Found {len(episodes)} episodes')
127 |                     except Exception:
128 |                         print(f'   ✅ Episodes retrieved: {response[:100]}')
129 |                 else:
130 |                     print('   ⚠️  No episodes found')
131 | 
132 |             except Exception as e:
133 |                 print(f'   ❌ Failed: {e}')
134 | 
135 |             # Test 5: Clear graph
136 |             print('\n🧹 Test 5: Clearing graph...')
137 |             try:
138 |                 result = await session.call_tool('clear_graph', {'group_id': test_group_id})
139 | 
140 |                 if result.content and result.content[0].text:
141 |                     response = result.content[0].text
142 |                     if 'success' in response.lower() or 'cleared' in response.lower():
143 |                         print('   ✅ Graph cleared successfully')
144 |                     else:
145 |                         print(f'   ✅ Clear completed: {response[:100]}')
146 |                 else:
147 |                     print('   ❌ No response')
148 | 
149 |             except Exception as e:
150 |                 print(f'   ❌ Failed: {e}')
151 | 
152 |             print('\n' + '=' * 60)
153 |             print('✅ All integration tests completed!')
154 |             return True
155 | 
156 |     except Exception as e:
157 |         print(f'\n❌ Connection failed: {e}')
158 |         return False
159 | 
160 | 
161 | async def test_sse_transport(base_url: str = 'http://localhost:8000'):
162 |     """Test MCP server with SSE transport."""
163 | 
164 |     # Import the SSE client
165 |     try:
166 |         from mcp.client.sse import sse_client
167 |     except ImportError:
168 |         print('❌ SSE client not available in MCP SDK')
169 |         return False
170 | 
171 |     test_group_id = f'test_sse_{int(time.time())}'
172 | 
173 |     print('🚀 Testing MCP Server with SSE transport')
174 |     print(f'   Server URL: {base_url}/sse')
175 |     print(f'   Test Group: {test_group_id}')
176 |     print('=' * 60)
177 | 
178 |     try:
179 |         # Connect to the server via SSE
180 |         print('\n🔌 Connecting to server...')
181 |         async with sse_client(f'{base_url}/sse') as (read_stream, write_stream), \
182 |                 ClientSession(read_stream, write_stream) as session:
183 |             await session.initialize()
184 |             print('✅ Connected successfully')
185 | 
186 |             # Run same tests as HTTP
187 |             print('\n📋 Test 1: Listing tools...')
188 |             try:
189 |                 result = await session.list_tools()
190 |                 tools = [tool.name for tool in result.tools]
191 |                 print(f'   ✅ Found {len(tools)} tools')
192 |                 for tool in tools[:3]:
193 |                     print(f'      - {tool}')
194 |             except Exception as e:
195 |                 print(f'   ❌ Failed: {e}')
196 |                 return False
197 | 
198 |             print('\n' + '=' * 60)
199 |             print('✅ SSE transport test completed!')
200 |             return True
201 | 
202 |     except Exception as e:
203 |         print(f'\n❌ SSE connection failed: {e}')
204 |         return False
205 | 
206 | 
207 | async def main():
208 |     """Run integration tests."""
209 | 
210 |     # Check command line arguments
211 |     if len(sys.argv) < 2:
212 |         print('Usage: python test_http_integration.py <transport> [host] [port]')
213 |         print('  transport: http or sse')
214 |         print('  host: server host (default: localhost)')
215 |         print('  port: server port (default: 8000)')
216 |         sys.exit(1)
217 | 
218 |     transport = sys.argv[1].lower()
219 |     host = sys.argv[2] if len(sys.argv) > 2 else 'localhost'
220 |     port = sys.argv[3] if len(sys.argv) > 3 else '8000'
221 |     base_url = f'http://{host}:{port}'
222 | 
223 |     # Check if server is running
224 |     import httpx
225 | 
226 |     try:
227 |         async with httpx.AsyncClient() as client:
228 |             # Try to connect to the server
229 |             await client.get(base_url, timeout=2.0)
230 |     except Exception:
231 |         print(f'⚠️  Server not responding at {base_url}')
232 |         print('Please start the server with one of these commands:')
233 |         print(f'  uv run main.py --transport http --port {port}')
234 |         print(f'  uv run main.py --transport sse --port {port}')
235 |         sys.exit(1)
236 | 
237 |     # Run the appropriate test
238 |     if transport == 'http':
239 |         success = await test_http_transport(base_url)
240 |     elif transport == 'sse':
241 |         success = await test_sse_transport(base_url)
242 |     else:
243 |         print(f'❌ Unknown transport: {transport}')
244 |         sys.exit(1)
245 | 
246 |     sys.exit(0 if success else 1)
247 | 
248 | 
249 | if __name__ == '__main__':
250 |     asyncio.run(main())
251 | 
```
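
Beyond the CLI entry point, the test coroutines return booleans, so they can also be driven programmatically (assuming a server is already listening on the given URL):

```python
import asyncio

from test_http_integration import test_http_transport

if __name__ == '__main__':
    ok = asyncio.run(test_http_transport('http://localhost:8000'))
    raise SystemExit(0 if ok else 1)
```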

--------------------------------------------------------------------------------
/graphiti_core/utils/maintenance/dedup_helpers.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | from __future__ import annotations
 18 | 
 19 | import math
 20 | import re
 21 | from collections import defaultdict
 22 | from collections.abc import Iterable
 23 | from dataclasses import dataclass, field
 24 | from functools import lru_cache
 25 | from hashlib import blake2b
 26 | from typing import TYPE_CHECKING
 27 | 
 28 | if TYPE_CHECKING:
 29 |     from graphiti_core.nodes import EntityNode
 30 | 
 31 | _NAME_ENTROPY_THRESHOLD = 1.5
 32 | _MIN_NAME_LENGTH = 6
 33 | _MIN_TOKEN_COUNT = 2
 34 | _FUZZY_JACCARD_THRESHOLD = 0.9
 35 | _MINHASH_PERMUTATIONS = 32
 36 | _MINHASH_BAND_SIZE = 4
 37 | 
 38 | 
 39 | def _normalize_string_exact(name: str) -> str:
 40 |     """Lowercase text and collapse whitespace so equal names map to the same key."""
 41 |     normalized = re.sub(r'[\s]+', ' ', name.lower())
 42 |     return normalized.strip()
 43 | 
 44 | 
 45 | def _normalize_name_for_fuzzy(name: str) -> str:
 46 |     """Produce a fuzzier form that keeps alphanumerics and apostrophes for n-gram shingles."""
 47 |     normalized = re.sub(r"[^a-z0-9' ]", ' ', _normalize_string_exact(name))
 48 |     normalized = normalized.strip()
 49 |     return re.sub(r'[\s]+', ' ', normalized)
 50 | 
 51 | 
 52 | def _name_entropy(normalized_name: str) -> float:
 53 |     """Approximate text specificity using Shannon entropy over characters.
 54 | 
 55 |     We strip spaces, count how often each character appears, and sum
 56 |     probability * -log2(probability). Short or repetitive names yield low
 57 |     entropy, which signals we should defer resolution to the LLM instead of
 58 |     trusting fuzzy similarity.
 59 |     """
 60 |     if not normalized_name:
 61 |         return 0.0
 62 | 
 63 |     counts: dict[str, int] = {}
 64 |     for char in normalized_name.replace(' ', ''):
 65 |         counts[char] = counts.get(char, 0) + 1
 66 | 
 67 |     total = sum(counts.values())
 68 |     if total == 0:
 69 |         return 0.0
 70 | 
 71 |     entropy = 0.0
 72 |     for count in counts.values():
 73 |         probability = count / total
 74 |         entropy -= probability * math.log2(probability)
 75 | 
 76 |     return entropy
 77 | 
 78 | 
 79 | def _has_high_entropy(normalized_name: str) -> bool:
 80 |     """Filter out very short or low-entropy names that are unreliable for fuzzy matching."""
 81 |     token_count = len(normalized_name.split())
 82 |     if len(normalized_name) < _MIN_NAME_LENGTH and token_count < _MIN_TOKEN_COUNT:
 83 |         return False
 84 | 
 85 |     return _name_entropy(normalized_name) >= _NAME_ENTROPY_THRESHOLD
 86 | 
 87 | 
 88 | def _shingles(normalized_name: str) -> set[str]:
 89 |     """Create 3-gram shingles from the normalized name for MinHash calculations."""
 90 |     cleaned = normalized_name.replace(' ', '')
 91 |     if len(cleaned) < 3:
 92 |         return {cleaned} if cleaned else set()
 93 | 
 94 |     return {cleaned[i : i + 3] for i in range(len(cleaned) - 2)}
 95 | 
 96 | 
 97 | def _hash_shingle(shingle: str, seed: int) -> int:
 98 |     """Generate a deterministic 64-bit hash for a shingle given the permutation seed."""
 99 |     digest = blake2b(f'{seed}:{shingle}'.encode(), digest_size=8)
100 |     return int.from_bytes(digest.digest(), 'big')
101 | 
102 | 
103 | def _minhash_signature(shingles: Iterable[str]) -> tuple[int, ...]:
104 |     """Compute the MinHash signature for the shingle set across predefined permutations."""
105 |     if not shingles:
106 |         return tuple()
107 | 
108 |     seeds = range(_MINHASH_PERMUTATIONS)
109 |     signature: list[int] = []
110 |     for seed in seeds:
111 |         min_hash = min(_hash_shingle(shingle, seed) for shingle in shingles)
112 |         signature.append(min_hash)
113 | 
114 |     return tuple(signature)
115 | 
116 | 
117 | def _lsh_bands(signature: Iterable[int]) -> list[tuple[int, ...]]:
118 |     """Split the MinHash signature into fixed-size bands for locality-sensitive hashing."""
119 |     signature_list = list(signature)
120 |     if not signature_list:
121 |         return []
122 | 
123 |     bands: list[tuple[int, ...]] = []
124 |     for start in range(0, len(signature_list), _MINHASH_BAND_SIZE):
125 |         band = tuple(signature_list[start : start + _MINHASH_BAND_SIZE])
126 |         if len(band) == _MINHASH_BAND_SIZE:
127 |             bands.append(band)
128 |     return bands
129 | 
130 | 
131 | def _jaccard_similarity(a: set[str], b: set[str]) -> float:
132 |     """Return the Jaccard similarity between two shingle sets, handling empty edge cases."""
133 |     if not a and not b:
134 |         return 1.0
135 |     if not a or not b:
136 |         return 0.0
137 | 
138 |     intersection = len(a.intersection(b))
139 |     union = len(a.union(b))
140 |     return intersection / union if union else 0.0
141 | 
142 | 
143 | @lru_cache(maxsize=512)
144 | def _cached_shingles(name: str) -> set[str]:
145 |     """Cache shingle sets per normalized name to avoid recomputation within a worker."""
146 |     return _shingles(name)
147 | 
148 | 
149 | @dataclass
150 | class DedupCandidateIndexes:
151 |     """Precomputed lookup structures that drive entity deduplication heuristics."""
152 | 
153 |     existing_nodes: list[EntityNode]
154 |     nodes_by_uuid: dict[str, EntityNode]
155 |     normalized_existing: defaultdict[str, list[EntityNode]]
156 |     shingles_by_candidate: dict[str, set[str]]
157 |     lsh_buckets: defaultdict[tuple[int, tuple[int, ...]], list[str]]
158 | 
159 | 
160 | @dataclass
161 | class DedupResolutionState:
162 |     """Mutable resolution bookkeeping shared across deterministic and LLM passes."""
163 | 
164 |     resolved_nodes: list[EntityNode | None]
165 |     uuid_map: dict[str, str]
166 |     unresolved_indices: list[int]
167 |     duplicate_pairs: list[tuple[EntityNode, EntityNode]] = field(default_factory=list)
168 | 
169 | 
170 | def _build_candidate_indexes(existing_nodes: list[EntityNode]) -> DedupCandidateIndexes:
171 |     """Precompute exact and fuzzy lookup structures once per dedupe run."""
172 |     normalized_existing: defaultdict[str, list[EntityNode]] = defaultdict(list)
173 |     nodes_by_uuid: dict[str, EntityNode] = {}
174 |     shingles_by_candidate: dict[str, set[str]] = {}
175 |     lsh_buckets: defaultdict[tuple[int, tuple[int, ...]], list[str]] = defaultdict(list)
176 | 
177 |     for candidate in existing_nodes:
178 |         normalized = _normalize_string_exact(candidate.name)
179 |         normalized_existing[normalized].append(candidate)
180 |         nodes_by_uuid[candidate.uuid] = candidate
181 | 
182 |         shingles = _cached_shingles(_normalize_name_for_fuzzy(candidate.name))
183 |         shingles_by_candidate[candidate.uuid] = shingles
184 | 
185 |         signature = _minhash_signature(shingles)
186 |         for band_index, band in enumerate(_lsh_bands(signature)):
187 |             lsh_buckets[(band_index, band)].append(candidate.uuid)
188 | 
189 |     return DedupCandidateIndexes(
190 |         existing_nodes=existing_nodes,
191 |         nodes_by_uuid=nodes_by_uuid,
192 |         normalized_existing=normalized_existing,
193 |         shingles_by_candidate=shingles_by_candidate,
194 |         lsh_buckets=lsh_buckets,
195 |     )
196 | 
197 | 
198 | def _resolve_with_similarity(
199 |     extracted_nodes: list[EntityNode],
200 |     indexes: DedupCandidateIndexes,
201 |     state: DedupResolutionState,
202 | ) -> None:
203 |     """Attempt deterministic resolution using exact name hits and fuzzy MinHash comparisons."""
204 |     for idx, node in enumerate(extracted_nodes):
205 |         normalized_exact = _normalize_string_exact(node.name)
206 |         normalized_fuzzy = _normalize_name_for_fuzzy(node.name)
207 | 
208 |         if not _has_high_entropy(normalized_fuzzy):
209 |             state.unresolved_indices.append(idx)
210 |             continue
211 | 
212 |         existing_matches = indexes.normalized_existing.get(normalized_exact, [])
213 |         if len(existing_matches) == 1:
214 |             match = existing_matches[0]
215 |             state.resolved_nodes[idx] = match
216 |             state.uuid_map[node.uuid] = match.uuid
217 |             if match.uuid != node.uuid:
218 |                 state.duplicate_pairs.append((node, match))
219 |             continue
220 |         if len(existing_matches) > 1:
221 |             state.unresolved_indices.append(idx)
222 |             continue
223 | 
224 |         shingles = _cached_shingles(normalized_fuzzy)
225 |         signature = _minhash_signature(shingles)
226 |         candidate_ids: set[str] = set()
227 |         for band_index, band in enumerate(_lsh_bands(signature)):
228 |             candidate_ids.update(indexes.lsh_buckets.get((band_index, band), []))
229 | 
230 |         best_candidate: EntityNode | None = None
231 |         best_score = 0.0
232 |         for candidate_id in candidate_ids:
233 |             candidate_shingles = indexes.shingles_by_candidate.get(candidate_id, set())
234 |             score = _jaccard_similarity(shingles, candidate_shingles)
235 |             if score > best_score:
236 |                 best_score = score
237 |                 best_candidate = indexes.nodes_by_uuid.get(candidate_id)
238 | 
239 |         if best_candidate is not None and best_score >= _FUZZY_JACCARD_THRESHOLD:
240 |             state.resolved_nodes[idx] = best_candidate
241 |             state.uuid_map[node.uuid] = best_candidate.uuid
242 |             if best_candidate.uuid != node.uuid:
243 |                 state.duplicate_pairs.append((node, best_candidate))
244 |             continue
245 | 
246 |         state.unresolved_indices.append(idx)
247 | 
248 | 
249 | __all__ = [
250 |     'DedupCandidateIndexes',
251 |     'DedupResolutionState',
252 |     '_normalize_string_exact',
253 |     '_normalize_name_for_fuzzy',
254 |     '_has_high_entropy',
255 |     '_minhash_signature',
256 |     '_lsh_bands',
257 |     '_jaccard_similarity',
258 |     '_cached_shingles',
259 |     '_FUZZY_JACCARD_THRESHOLD',
260 |     '_build_candidate_indexes',
261 |     '_resolve_with_similarity',
262 | ]
263 | 
```
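
A small worked example of the fuzzy path above: normalization collapses case and whitespace, 3-gram shingles feed the MinHash signature, and identical shingle sets score a Jaccard similarity of 1.0, well above `_FUZZY_JACCARD_THRESHOLD`:

```python
from graphiti_core.utils.maintenance.dedup_helpers import (
    _cached_shingles,
    _has_high_entropy,
    _jaccard_similarity,
    _lsh_bands,
    _minhash_signature,
    _normalize_name_for_fuzzy,
)

a = _normalize_name_for_fuzzy('Kamala Harris')
b = _normalize_name_for_fuzzy('Kamala  Harris ')  # extra whitespace collapses away

print(_has_high_entropy(a))  # True: long, varied names pass the entropy gate

shingles_a = _cached_shingles(a)
shingles_b = _cached_shingles(b)
print(_jaccard_similarity(shingles_a, shingles_b))  # 1.0: identical after normalization

# Identical signatures share every LSH band, so both names land in the same buckets.
bands_a = _lsh_bands(_minhash_signature(shingles_a))
bands_b = _lsh_bands(_minhash_signature(shingles_b))
print(sum(x == y for x, y in zip(bands_a, bands_b)), 'of', len(bands_a), 'bands match')
```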

--------------------------------------------------------------------------------
/examples/quickstart/quickstart_neo4j.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2025, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import asyncio
 18 | import json
 19 | import logging
 20 | import os
 21 | from datetime import datetime, timezone
 22 | from logging import INFO
 23 | 
 24 | from dotenv import load_dotenv
 25 | 
 26 | from graphiti_core import Graphiti
 27 | from graphiti_core.nodes import EpisodeType
 28 | from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
 29 | 
 30 | #################################################
 31 | # CONFIGURATION
 32 | #################################################
 33 | # Set up logging and environment variables for
 34 | # connecting to Neo4j database
 35 | #################################################
 36 | 
 37 | # Configure logging
 38 | logging.basicConfig(
 39 |     level=INFO,
 40 |     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
 41 |     datefmt='%Y-%m-%d %H:%M:%S',
 42 | )
 43 | logger = logging.getLogger(__name__)
 44 | 
 45 | load_dotenv()
 46 | 
 47 | # Neo4j connection parameters
 48 | # Make sure Neo4j Desktop is running with a local DBMS started
 49 | neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
 50 | neo4j_user = os.environ.get('NEO4J_USER', 'neo4j')
 51 | neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password')
 52 | 
 53 | if not neo4j_uri or not neo4j_user or not neo4j_password:
 54 |     raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set')
 55 | 
 56 | 
 57 | async def main():
 58 |     #################################################
 59 |     # INITIALIZATION
 60 |     #################################################
 61 |     # Connect to Neo4j and set up Graphiti indices
 62 |     # This is required before using other Graphiti
 63 |     # functionality
 64 |     #################################################
 65 | 
 66 |     # Initialize Graphiti with Neo4j connection
 67 |     graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
 68 | 
 69 |     try:
 70 |         #################################################
 71 |         # ADDING EPISODES
 72 |         #################################################
 73 |         # Episodes are the primary units of information
 74 |         # in Graphiti. They can be text or structured JSON
 75 |         # and are automatically processed to extract entities
 76 |         # and relationships.
 77 |         #################################################
 78 | 
 79 |         # Example: Add Episodes
 80 |         # Episodes list containing both text and JSON episodes
 81 |         episodes = [
 82 |             {
 83 |                 'content': 'Kamala Harris is the Attorney General of California. She was previously '
 84 |                 'the district attorney for San Francisco.',
 85 |                 'type': EpisodeType.text,
 86 |                 'description': 'podcast transcript',
 87 |             },
 88 |             {
 89 |                 'content': 'As AG, Harris was in office from January 3, 2011 – January 3, 2017',
 90 |                 'type': EpisodeType.text,
 91 |                 'description': 'podcast transcript',
 92 |             },
 93 |             {
 94 |                 'content': {
 95 |                     'name': 'Gavin Newsom',
 96 |                     'position': 'Governor',
 97 |                     'state': 'California',
 98 |                     'previous_role': 'Lieutenant Governor',
 99 |                     'previous_location': 'San Francisco',
100 |                 },
101 |                 'type': EpisodeType.json,
102 |                 'description': 'podcast metadata',
103 |             },
104 |             {
105 |                 'content': {
106 |                     'name': 'Gavin Newsom',
107 |                     'position': 'Governor',
108 |                     'term_start': 'January 7, 2019',
109 |                     'term_end': 'Present',
110 |                 },
111 |                 'type': EpisodeType.json,
112 |                 'description': 'podcast metadata',
113 |             },
114 |         ]
115 | 
116 |         # Add episodes to the graph
117 |         for i, episode in enumerate(episodes):
118 |             await graphiti.add_episode(
119 |                 name=f'Freakonomics Radio {i}',
120 |                 episode_body=episode['content']
121 |                 if isinstance(episode['content'], str)
122 |                 else json.dumps(episode['content']),
123 |                 source=episode['type'],
124 |                 source_description=episode['description'],
125 |                 reference_time=datetime.now(timezone.utc),
126 |             )
127 |             print(f'Added episode: Freakonomics Radio {i} ({episode["type"].value})')
128 | 
129 |         #################################################
130 |         # BASIC SEARCH
131 |         #################################################
132 |         # The simplest way to retrieve relationships (edges)
133 |         # from Graphiti is using the search method, which
134 |         # performs a hybrid search combining semantic
135 |         # similarity and BM25 text retrieval.
136 |         #################################################
137 | 
138 |         # Perform a hybrid search combining semantic similarity and BM25 retrieval
139 |         print("\nSearching for: 'Who was the California Attorney General?'")
140 |         results = await graphiti.search('Who was the California Attorney General?')
141 | 
142 |         # Print search results
143 |         print('\nSearch Results:')
144 |         for result in results:
145 |             print(f'UUID: {result.uuid}')
146 |             print(f'Fact: {result.fact}')
147 |             if hasattr(result, 'valid_at') and result.valid_at:
148 |                 print(f'Valid from: {result.valid_at}')
149 |             if hasattr(result, 'invalid_at') and result.invalid_at:
150 |                 print(f'Valid until: {result.invalid_at}')
151 |             print('---')
152 | 
153 |         #################################################
154 |         # CENTER NODE SEARCH
155 |         #################################################
156 |         # For more contextually relevant results, you can
157 |         # use a center node to rerank search results based
158 |         # on their graph distance to a specific node
159 |         #################################################
160 | 
161 |         # Use the top search result's UUID as the center node for reranking
162 |         if results and len(results) > 0:
163 |             # Get the source node UUID from the top result
164 |             center_node_uuid = results[0].source_node_uuid
165 | 
166 |             print('\nReranking search results based on graph distance:')
167 |             print(f'Using center node UUID: {center_node_uuid}')
168 | 
169 |             reranked_results = await graphiti.search(
170 |                 'Who was the California Attorney General?', center_node_uuid=center_node_uuid
171 |             )
172 | 
173 |             # Print reranked search results
174 |             print('\nReranked Search Results:')
175 |             for result in reranked_results:
176 |                 print(f'UUID: {result.uuid}')
177 |                 print(f'Fact: {result.fact}')
178 |                 if hasattr(result, 'valid_at') and result.valid_at:
179 |                     print(f'Valid from: {result.valid_at}')
180 |                 if hasattr(result, 'invalid_at') and result.invalid_at:
181 |                     print(f'Valid until: {result.invalid_at}')
182 |                 print('---')
183 |         else:
184 |             print('No results found in the initial search to use as center node.')
185 | 
186 |         #################################################
187 |         # NODE SEARCH USING SEARCH RECIPES
188 |         #################################################
189 |         # Graphiti provides predefined search recipes
190 |         # optimized for different search scenarios.
191 |         # Here we use NODE_HYBRID_SEARCH_RRF for retrieving
192 |         # nodes directly instead of edges.
193 |         #################################################
194 | 
195 |         # Example: Perform a node search using _search method with standard recipes
196 |         print(
197 |             '\nPerforming node search using _search method with standard recipe NODE_HYBRID_SEARCH_RRF:'
198 |         )
199 | 
200 |         # Use a predefined search configuration recipe and modify its limit
201 |         node_search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
202 |         node_search_config.limit = 5  # Limit to 5 results
203 | 
204 |         # Execute the node search
205 |         node_search_results = await graphiti._search(
206 |             query='California Governor',
207 |             config=node_search_config,
208 |         )
209 | 
210 |         # Print node search results
211 |         print('\nNode Search Results:')
212 |         for node in node_search_results.nodes:
213 |             print(f'Node UUID: {node.uuid}')
214 |             print(f'Node Name: {node.name}')
215 |             node_summary = node.summary[:100] + '...' if len(node.summary) > 100 else node.summary
216 |             print(f'Content Summary: {node_summary}')
217 |             print(f'Node Labels: {", ".join(node.labels)}')
218 |             print(f'Created At: {node.created_at}')
219 |             if hasattr(node, 'attributes') and node.attributes:
220 |                 print('Attributes:')
221 |                 for key, value in node.attributes.items():
222 |                     print(f'  {key}: {value}')
223 |             print('---')
224 | 
225 |     finally:
226 |         #################################################
227 |         # CLEANUP
228 |         #################################################
229 |         # Always close the connection to Neo4j when
230 |         # finished to properly release resources
231 |         #################################################
232 | 
233 |         # Close the connection
234 |         await graphiti.close()
235 |         print('\nConnection closed')
236 | 
237 | 
238 | if __name__ == '__main__':
239 |     asyncio.run(main())
240 | 
```
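
The pattern above, a broad hybrid search whose top hit seeds a second graph-distance-reranked pass, is easy to factor into a helper. A minimal sketch, reusing only the `search` API shown in the example (the `search_with_rerank` name is illustrative, not part of Graphiti):

```python
# Sketch: hybrid edge search followed by graph-distance reranking.
# Assumes an initialized `graphiti` instance as in the quickstart above.
async def search_with_rerank(graphiti, query: str):
    # First pass: hybrid (semantic similarity + BM25) edge search.
    results = await graphiti.search(query)
    if not results:
        return []

    # Second pass: rerank by graph distance to the top hit's source node.
    center_node_uuid = results[0].source_node_uuid
    return await graphiti.search(query, center_node_uuid=center_node_uuid)
```

Called as `await search_with_rerank(graphiti, 'Who was the California Attorney General?')`, it returns the reranked edges, or an empty list when the first pass finds nothing.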

--------------------------------------------------------------------------------
/mcp_server/tests/test_mcp_transports.py:
--------------------------------------------------------------------------------

```python
#!/usr/bin/env python3
"""
Test the MCP server with different transport modes using the MCP SDK.
Covers both SSE and streamable HTTP transports.
"""

import asyncio
import json
import sys
import time
from contextlib import AsyncExitStack

from mcp.client.session import ClientSession
from mcp.client.sse import sse_client


class MCPTransportTester:
    """Test MCP server with different transport modes."""

    def __init__(self, transport: str = 'sse', host: str = 'localhost', port: int = 8000):
        self.transport = transport
        self.host = host
        self.port = port
        self.base_url = f'http://{host}:{port}'
        self.test_group_id = f'test_{transport}_{int(time.time())}'
        self.session = None
        # Holds the transport and session contexts open across test calls;
        # closed in run_tests() once the suite finishes.
        self.exit_stack = AsyncExitStack()

    async def connect_sse(self) -> ClientSession:
        """Connect using SSE transport."""
        print(f'🔌 Connecting to MCP server via SSE at {self.base_url}/sse')

        # Enter the transport and session contexts on the exit stack so the
        # streams outlive this method; returning from inside `async with`
        # would close them before any test could run.
        read_stream, write_stream = await self.exit_stack.enter_async_context(
            sse_client(self.base_url + '/sse')
        )
        self.session = await self.exit_stack.enter_async_context(
            ClientSession(read_stream, write_stream)
        )
        await self.session.initialize()
        return self.session

    async def connect_http(self) -> ClientSession:
        """Connect using streamable HTTP transport."""
        from mcp.client.streamable_http import streamablehttp_client

        print(f'🔌 Connecting to MCP server via HTTP at {self.base_url}/mcp')

        # streamablehttp_client yields (read_stream, write_stream, get_session_id);
        # /mcp is FastMCP's default streamable-HTTP mount path.
        read_stream, write_stream, _ = await self.exit_stack.enter_async_context(
            streamablehttp_client(self.base_url + '/mcp')
        )
        self.session = await self.exit_stack.enter_async_context(
            ClientSession(read_stream, write_stream)
        )
        await self.session.initialize()
        return self.session

    async def test_list_tools(self) -> bool:
        """Test listing available tools."""
        print('\n📋 Testing list_tools...')

        try:
            result = await self.session.list_tools()
            tools = [tool.name for tool in result.tools]

            expected_tools = [
                'add_memory',
                'search_memory_nodes',
                'search_memory_facts',
                'get_episodes',
                'delete_episode',
                'get_entity_edge',
                'delete_entity_edge',
                'clear_graph',
            ]

            print(f'   ✅ Found {len(tools)} tools')
            for tool in tools[:5]:  # Show first 5 tools
                print(f'      - {tool}')

            # Check if we have most expected tools
            found_tools = [t for t in expected_tools if t in tools]
            success = len(found_tools) >= len(expected_tools) * 0.8

            if success:
                print(
                    f'   ✅ Tool discovery successful ({len(found_tools)}/{len(expected_tools)} expected tools)'
                )
            else:
                print(f'   ❌ Missing too many tools ({len(found_tools)}/{len(expected_tools)})')

            return success
        except Exception as e:
            print(f'   ❌ Failed to list tools: {e}')
            return False

    async def test_add_memory(self) -> bool:
        """Test adding a memory."""
        print('\n📝 Testing add_memory...')

        try:
            result = await self.session.call_tool(
                'add_memory',
                {
                    'name': 'Test Episode',
                    'episode_body': 'This is a test episode created by the MCP transport test suite.',
                    'group_id': self.test_group_id,
                    'source': 'text',
                    'source_description': 'Integration test',
                },
            )

            # Check the result
            if result.content:
                content = result.content[0]
                if hasattr(content, 'text'):
                    response = (
                        json.loads(content.text)
                        if content.text.startswith('{')
                        else {'message': content.text}
                    )
                    if 'success' in str(response).lower() or 'queued' in str(response).lower():
                        print(f'   ✅ Memory added successfully: {response.get("message", "OK")}')
                        return True
                    else:
                        print(f'   ❌ Unexpected response: {response}')
                        return False

            print('   ❌ No content in response')
            return False

        except Exception as e:
            print(f'   ❌ Failed to add memory: {e}')
            return False

    async def test_search_nodes(self) -> bool:
        """Test searching for nodes."""
        print('\n🔍 Testing search_memory_nodes...')

        # Wait a bit for the memory to be processed
        await asyncio.sleep(2)

        try:
            result = await self.session.call_tool(
                'search_memory_nodes',
                {'query': 'test episode', 'group_ids': [self.test_group_id], 'limit': 5},
            )

            if result.content:
                content = result.content[0]
                if hasattr(content, 'text'):
                    response = (
                        json.loads(content.text) if content.text.startswith('{') else {'nodes': []}
                    )
                    nodes = response.get('nodes', [])
                    print(f'   ✅ Search returned {len(nodes)} nodes')
                    return True

            print('   ⚠️ No nodes found (this may be expected if processing is async)')
            return True  # Don't fail on empty results

        except Exception as e:
            print(f'   ❌ Failed to search nodes: {e}')
            return False

    async def test_get_episodes(self) -> bool:
        """Test getting episodes."""
        print('\n📚 Testing get_episodes...')

        try:
            result = await self.session.call_tool(
                'get_episodes', {'group_ids': [self.test_group_id], 'limit': 10}
            )

            if result.content:
                content = result.content[0]
                if hasattr(content, 'text'):
                    response = (
                        json.loads(content.text)
                        if content.text.startswith('{')
                        else {'episodes': []}
                    )
                    episodes = response.get('episodes', [])
                    print(f'   ✅ Found {len(episodes)} episodes')
                    return True

            print('   ⚠️ No episodes found')
            return True

        except Exception as e:
            print(f'   ❌ Failed to get episodes: {e}')
            return False

    async def test_clear_graph(self) -> bool:
        """Test clearing the graph."""
        print('\n🧹 Testing clear_graph...')

        try:
            result = await self.session.call_tool('clear_graph', {'group_id': self.test_group_id})

            if result.content:
                content = result.content[0]
                if hasattr(content, 'text'):
                    response = content.text
                    if 'success' in response.lower() or 'cleared' in response.lower():
                        print('   ✅ Graph cleared successfully')
                        return True

            print('   ❌ Failed to clear graph')
            return False

        except Exception as e:
            print(f'   ❌ Failed to clear graph: {e}')
            return False

    async def run_tests(self) -> bool:
        """Run all tests for the configured transport."""
        print(f'\n{"=" * 60}')
        print(f'🚀 Testing MCP Server with {self.transport.upper()} transport')
        print(f'   Server: {self.base_url}')
        print(f'   Test Group: {self.test_group_id}')
        print('=' * 60)

        try:
            # Connect based on transport type
            if self.transport == 'sse':
                await self.connect_sse()
            elif self.transport == 'http':
                await self.connect_http()
            else:
                print(f'❌ Unknown transport: {self.transport}')
                return False

            print(f'✅ Connected via {self.transport.upper()}')

            # Run tests
            results = []
            results.append(await self.test_list_tools())
            results.append(await self.test_add_memory())
            results.append(await self.test_search_nodes())
            results.append(await self.test_get_episodes())
            results.append(await self.test_clear_graph())

            # Summary
            passed = sum(results)
            total = len(results)
            success = passed == total

            print(f'\n{"=" * 60}')
            print(f'📊 Results for {self.transport.upper()} transport:')
            print(f'   Passed: {passed}/{total}')
            print(f'   Status: {"✅ ALL TESTS PASSED" if success else "❌ SOME TESTS FAILED"}')
            print('=' * 60)

            return success

        except Exception as e:
            print(f'❌ Test suite failed: {e}')
            return False
        finally:
            # Close the session and transport contexts opened during connect.
            await self.exit_stack.aclose()


async def main():
    """Run the test suite for the transport selected on the command line."""
    # Parse command line arguments: [transport] [host] [port]
    transport = sys.argv[1] if len(sys.argv) > 1 else 'sse'
    host = sys.argv[2] if len(sys.argv) > 2 else 'localhost'
    port = int(sys.argv[3]) if len(sys.argv) > 3 else 8000

    # Create tester
    tester = MCPTransportTester(transport, host, port)

    # Run tests
    success = await tester.run_tests()

    # Exit with appropriate code
    sys.exit(0 if success else 1)


if __name__ == '__main__':
    asyncio.run(main())

```
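
For a one-off call, the exit-stack plumbing in the tester is unnecessary; the MCP SDK's nested context managers suffice as long as all work happens inside them. A minimal sketch of a single SSE round-trip against the same server (the URL default and tool listing follow the suite above):

```python
import asyncio

from mcp.client.session import ClientSession
from mcp.client.sse import sse_client


async def list_tools_once(url: str = 'http://localhost:8000/sse') -> list[str]:
    # The streams and session are only valid inside these contexts, which is
    # why the tester above parks them on an AsyncExitStack instead.
    async with sse_client(url) as (read_stream, write_stream):
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()
            result = await session.list_tools()
            return [tool.name for tool in result.tools]


if __name__ == '__main__':
    print(asyncio.run(list_tools_once()))
```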

--------------------------------------------------------------------------------
/examples/quickstart/quickstart_neptune.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2025, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import asyncio
 18 | import json
 19 | import logging
 20 | import os
 21 | from datetime import datetime, timezone
 22 | from logging import INFO
 23 | 
 24 | from dotenv import load_dotenv
 25 | 
 26 | from graphiti_core import Graphiti
 27 | from graphiti_core.driver.neptune_driver import NeptuneDriver
 28 | from graphiti_core.nodes import EpisodeType
 29 | from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
 30 | 
 31 | #################################################
 32 | # CONFIGURATION
 33 | #################################################
 34 | # Set up logging and environment variables for
 35 | # connecting to Neptune database
 36 | #################################################
 37 | 
 38 | # Configure logging
 39 | logging.basicConfig(
 40 |     level=INFO,
 41 |     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
 42 |     datefmt='%Y-%m-%d %H:%M:%S',
 43 | )
 44 | logger = logging.getLogger(__name__)
 45 | 
 46 | load_dotenv()
 47 | 
 48 | # Neptune and OpenSearch connection parameters
 49 | neptune_uri = os.environ.get('NEPTUNE_HOST')
 50 | neptune_port = int(os.environ.get('NEPTUNE_PORT', 8182))
 51 | aoss_host = os.environ.get('AOSS_HOST')
 52 | 
 53 | if not neptune_uri:
 54 |     raise ValueError('NEPTUNE_HOST must be set')
 55 | 
 56 | 
 57 | if not aoss_host:
 58 |     raise ValueError('AOSS_HOST must be set')
 59 | 
 60 | 
 61 | async def main():
 62 |     #################################################
 63 |     # INITIALIZATION
 64 |     #################################################
 65 |     # Connect to Neptune and set up Graphiti indices
 66 |     # This is required before using other Graphiti
 67 |     # functionality
 68 |     #################################################
 69 | 
 70 |     # Initialize Graphiti with Neptune connection
 71 |     driver = NeptuneDriver(host=neptune_uri, aoss_host=aoss_host, port=neptune_port)
 72 | 
 73 |     graphiti = Graphiti(graph_driver=driver)
 74 | 
 75 |     try:
 76 |         # Reset the graph for a clean demo run: drop the AOSS indices and all data, then rebuild graphiti's indices and constraints (indices only need to be built once).
 77 |         await driver.delete_aoss_indices()
 78 |         await driver._delete_all_data()
 79 |         await graphiti.build_indices_and_constraints()
 80 | 
 81 |         #################################################
 82 |         # ADDING EPISODES
 83 |         #################################################
 84 |         # Episodes are the primary units of information
 85 |         # in Graphiti. They can be text or structured JSON
 86 |         # and are automatically processed to extract entities
 87 |         # and relationships.
 88 |         #################################################
 89 | 
 90 |         # Example: Add Episodes
 91 |         # Episodes list containing both text and JSON episodes
 92 |         episodes = [
 93 |             {
 94 |                 'content': 'Kamala Harris is the Attorney General of California. She was previously '
 95 |                 'the district attorney for San Francisco.',
 96 |                 'type': EpisodeType.text,
 97 |                 'description': 'podcast transcript',
 98 |             },
 99 |             {
100 |                 'content': 'As AG, Harris was in office from January 3, 2011 – January 3, 2017',
101 |                 'type': EpisodeType.text,
102 |                 'description': 'podcast transcript',
103 |             },
104 |             {
105 |                 'content': {
106 |                     'name': 'Gavin Newsom',
107 |                     'position': 'Governor',
108 |                     'state': 'California',
109 |                     'previous_role': 'Lieutenant Governor',
110 |                     'previous_location': 'San Francisco',
111 |                 },
112 |                 'type': EpisodeType.json,
113 |                 'description': 'podcast metadata',
114 |             },
115 |             {
116 |                 'content': {
117 |                     'name': 'Gavin Newsom',
118 |                     'position': 'Governor',
119 |                     'term_start': 'January 7, 2019',
120 |                     'term_end': 'Present',
121 |                 },
122 |                 'type': EpisodeType.json,
123 |                 'description': 'podcast metadata',
124 |             },
125 |         ]
126 | 
127 |         # Add episodes to the graph
128 |         for i, episode in enumerate(episodes):
129 |             await graphiti.add_episode(
130 |                 name=f'Freakonomics Radio {i}',
131 |                 episode_body=episode['content']
132 |                 if isinstance(episode['content'], str)
133 |                 else json.dumps(episode['content']),
134 |                 source=episode['type'],
135 |                 source_description=episode['description'],
136 |                 reference_time=datetime.now(timezone.utc),
137 |             )
138 |             print(f'Added episode: Freakonomics Radio {i} ({episode["type"].value})')
139 | 
140 |         await graphiti.build_communities()
141 | 
142 |         #################################################
143 |         # BASIC SEARCH
144 |         #################################################
145 |         # The simplest way to retrieve relationships (edges)
146 |         # from Graphiti is using the search method, which
147 |         # performs a hybrid search combining semantic
148 |         # similarity and BM25 text retrieval.
149 |         #################################################
150 | 
151 |         # Perform a hybrid search combining semantic similarity and BM25 retrieval
152 |         print("\nSearching for: 'Who was the California Attorney General?'")
153 |         results = await graphiti.search('Who was the California Attorney General?')
154 | 
155 |         # Print search results
156 |         print('\nSearch Results:')
157 |         for result in results:
158 |             print(f'UUID: {result.uuid}')
159 |             print(f'Fact: {result.fact}')
160 |             if hasattr(result, 'valid_at') and result.valid_at:
161 |                 print(f'Valid from: {result.valid_at}')
162 |             if hasattr(result, 'invalid_at') and result.invalid_at:
163 |                 print(f'Valid until: {result.invalid_at}')
164 |             print('---')
165 | 
166 |         #################################################
167 |         # CENTER NODE SEARCH
168 |         #################################################
169 |         # For more contextually relevant results, you can
170 |         # use a center node to rerank search results based
171 |         # on their graph distance to a specific node
172 |         #################################################
173 | 
174 |         # Use the top search result's UUID as the center node for reranking
175 |         if results and len(results) > 0:
176 |             # Get the source node UUID from the top result
177 |             center_node_uuid = results[0].source_node_uuid
178 | 
179 |             print('\nReranking search results based on graph distance:')
180 |             print(f'Using center node UUID: {center_node_uuid}')
181 | 
182 |             reranked_results = await graphiti.search(
183 |                 'Who was the California Attorney General?', center_node_uuid=center_node_uuid
184 |             )
185 | 
186 |             # Print reranked search results
187 |             print('\nReranked Search Results:')
188 |             for result in reranked_results:
189 |                 print(f'UUID: {result.uuid}')
190 |                 print(f'Fact: {result.fact}')
191 |                 if hasattr(result, 'valid_at') and result.valid_at:
192 |                     print(f'Valid from: {result.valid_at}')
193 |                 if hasattr(result, 'invalid_at') and result.invalid_at:
194 |                     print(f'Valid until: {result.invalid_at}')
195 |                 print('---')
196 |         else:
197 |             print('No results found in the initial search to use as center node.')
198 | 
199 |         #################################################
200 |         # NODE SEARCH USING SEARCH RECIPES
201 |         #################################################
202 |         # Graphiti provides predefined search recipes
203 |         # optimized for different search scenarios.
204 |         # Here we use NODE_HYBRID_SEARCH_RRF for retrieving
205 |         # nodes directly instead of edges.
206 |         #################################################
207 | 
208 |         # Example: Perform a node search using _search method with standard recipes
209 |         print(
210 |             '\nPerforming node search using _search method with standard recipe NODE_HYBRID_SEARCH_RRF:'
211 |         )
212 | 
213 |         # Use a predefined search configuration recipe and modify its limit
214 |         node_search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
215 |         node_search_config.limit = 5  # Limit to 5 results
216 | 
217 |         # Execute the node search
218 |         node_search_results = await graphiti._search(
219 |             query='California Governor',
220 |             config=node_search_config,
221 |         )
222 | 
223 |         # Print node search results
224 |         print('\nNode Search Results:')
225 |         for node in node_search_results.nodes:
226 |             print(f'Node UUID: {node.uuid}')
227 |             print(f'Node Name: {node.name}')
228 |             node_summary = node.summary[:100] + '...' if len(node.summary) > 100 else node.summary
229 |             print(f'Content Summary: {node_summary}')
230 |             print(f'Node Labels: {", ".join(node.labels)}')
231 |             print(f'Created At: {node.created_at}')
232 |             if hasattr(node, 'attributes') and node.attributes:
233 |                 print('Attributes:')
234 |                 for key, value in node.attributes.items():
235 |                     print(f'  {key}: {value}')
236 |             print('---')
237 | 
238 |     finally:
239 |         #################################################
240 |         # CLEANUP
241 |         #################################################
242 |         # Always close the connection to Neptune when
243 |         # finished to properly release resources
244 |         #################################################
245 | 
246 |         # Close the connection
247 |         await graphiti.close()
248 |         print('\nConnection closed')
249 | 
250 | 
251 | if __name__ == '__main__':
252 |     asyncio.run(main())
253 | 
```
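
Both quickstarts inline the same idiom: `episode_body` must be a string, so dict content is serialized with `json.dumps` and tagged `EpisodeType.json`. A minimal sketch of that idiom factored into a helper (the `to_episode_body` name is illustrative, not a Graphiti API):

```python
import json
from typing import Any

from graphiti_core.nodes import EpisodeType


def to_episode_body(content: str | dict[str, Any]) -> tuple[str, EpisodeType]:
    """Normalize episode content to the (body, source) pair add_episode expects."""
    # Text passes through unchanged; structured content is serialized to JSON.
    if isinstance(content, str):
        return content, EpisodeType.text
    return json.dumps(content), EpisodeType.json
```

In the ingestion loop above this would replace the inline conditional: `episode_body, source = to_episode_body(episode['content'])`, with both values passed straight to `add_episode`.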