# getzep/graphiti
This is page 6 of 12. Use http://codebase.md/getzep/graphiti?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .env.example
├── .github
│   ├── dependabot.yml
│   ├── ISSUE_TEMPLATE
│   │   └── bug_report.md
│   ├── pull_request_template.md
│   ├── secret_scanning.yml
│   └── workflows
│       ├── ai-moderator.yml
│       ├── cla.yml
│       ├── claude-code-review-manual.yml
│       ├── claude-code-review.yml
│       ├── claude.yml
│       ├── codeql.yml
│       ├── lint.yml
│       ├── release-graphiti-core.yml
│       ├── release-mcp-server.yml
│       ├── release-server-container.yml
│       ├── typecheck.yml
│       └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│   ├── azure-openai
│   │   ├── .env.example
│   │   ├── azure_openai_neo4j.py
│   │   └── README.md
│   ├── data
│   │   └── manybirds_products.json
│   ├── ecommerce
│   │   ├── runner.ipynb
│   │   └── runner.py
│   ├── langgraph-agent
│   │   ├── agent.ipynb
│   │   └── tinybirds-jess.png
│   ├── opentelemetry
│   │   ├── .env.example
│   │   ├── otel_stdout_example.py
│   │   ├── pyproject.toml
│   │   ├── README.md
│   │   └── uv.lock
│   ├── podcast
│   │   ├── podcast_runner.py
│   │   ├── podcast_transcript.txt
│   │   └── transcript_parser.py
│   ├── quickstart
│   │   ├── dense_vs_normal_ingestion.py
│   │   ├── quickstart_falkordb.py
│   │   ├── quickstart_neo4j.py
│   │   ├── quickstart_neptune.py
│   │   ├── README.md
│   │   └── requirements.txt
│   └── wizard_of_oz
│       ├── parser.py
│       ├── runner.py
│       └── woo.txt
├── graphiti_core
│   ├── __init__.py
│   ├── cross_encoder
│   │   ├── __init__.py
│   │   ├── bge_reranker_client.py
│   │   ├── client.py
│   │   ├── gemini_reranker_client.py
│   │   └── openai_reranker_client.py
│   ├── decorators.py
│   ├── driver
│   │   ├── __init__.py
│   │   ├── driver.py
│   │   ├── falkordb_driver.py
│   │   ├── graph_operations
│   │   │   └── graph_operations.py
│   │   ├── kuzu_driver.py
│   │   ├── neo4j_driver.py
│   │   ├── neptune_driver.py
│   │   └── search_interface
│   │       └── search_interface.py
│   ├── edges.py
│   ├── embedder
│   │   ├── __init__.py
│   │   ├── azure_openai.py
│   │   ├── client.py
│   │   ├── gemini.py
│   │   ├── openai.py
│   │   └── voyage.py
│   ├── errors.py
│   ├── graph_queries.py
│   ├── graphiti_types.py
│   ├── graphiti.py
│   ├── helpers.py
│   ├── llm_client
│   │   ├── __init__.py
│   │   ├── anthropic_client.py
│   │   ├── azure_openai_client.py
│   │   ├── client.py
│   │   ├── config.py
│   │   ├── errors.py
│   │   ├── gemini_client.py
│   │   ├── groq_client.py
│   │   ├── openai_base_client.py
│   │   ├── openai_client.py
│   │   ├── openai_generic_client.py
│   │   └── utils.py
│   ├── migrations
│   │   └── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── edges
│   │   │   ├── __init__.py
│   │   │   └── edge_db_queries.py
│   │   └── nodes
│   │       ├── __init__.py
│   │       └── node_db_queries.py
│   ├── nodes.py
│   ├── prompts
│   │   ├── __init__.py
│   │   ├── dedupe_edges.py
│   │   ├── dedupe_nodes.py
│   │   ├── eval.py
│   │   ├── extract_edge_dates.py
│   │   ├── extract_edges.py
│   │   ├── extract_nodes.py
│   │   ├── invalidate_edges.py
│   │   ├── lib.py
│   │   ├── models.py
│   │   ├── prompt_helpers.py
│   │   ├── snippets.py
│   │   └── summarize_nodes.py
│   ├── py.typed
│   ├── search
│   │   ├── __init__.py
│   │   ├── search_config_recipes.py
│   │   ├── search_config.py
│   │   ├── search_filters.py
│   │   ├── search_helpers.py
│   │   ├── search_utils.py
│   │   └── search.py
│   ├── telemetry
│   │   ├── __init__.py
│   │   └── telemetry.py
│   ├── tracer.py
│   └── utils
│       ├── __init__.py
│       ├── bulk_utils.py
│       ├── content_chunking.py
│       ├── datetime_utils.py
│       ├── maintenance
│       │   ├── __init__.py
│       │   ├── community_operations.py
│       │   ├── dedup_helpers.py
│       │   ├── edge_operations.py
│       │   ├── graph_data_operations.py
│       │   ├── node_operations.py
│       │   └── temporal_operations.py
│       ├── ontology_utils
│       │   └── entity_types_utils.py
│       └── text_utils.py
├── images
│   ├── arxiv-screenshot.png
│   ├── graphiti-graph-intro.gif
│   ├── graphiti-intro-slides-stock-2.gif
│   └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│   ├── .env.example
│   ├── .python-version
│   ├── config
│   │   ├── config-docker-falkordb-combined.yaml
│   │   ├── config-docker-falkordb.yaml
│   │   ├── config-docker-neo4j.yaml
│   │   ├── config.yaml
│   │   └── mcp_config_stdio_example.json
│   ├── docker
│   │   ├── build-standalone.sh
│   │   ├── build-with-version.sh
│   │   ├── docker-compose-falkordb.yml
│   │   ├── docker-compose-neo4j.yml
│   │   ├── docker-compose.yml
│   │   ├── Dockerfile
│   │   ├── Dockerfile.standalone
│   │   ├── github-actions-example.yml
│   │   ├── README-falkordb-combined.md
│   │   └── README.md
│   ├── docs
│   │   └── cursor_rules.md
│   ├── main.py
│   ├── pyproject.toml
│   ├── pytest.ini
│   ├── README.md
│   ├── src
│   │   ├── __init__.py
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   └── schema.py
│   │   ├── graphiti_mcp_server.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── entity_types.py
│   │   │   └── response_types.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   ├── factories.py
│   │   │   └── queue_service.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── formatting.py
│   │       └── utils.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── conftest.py
│   │   ├── pytest.ini
│   │   ├── README.md
│   │   ├── run_tests.py
│   │   ├── test_async_operations.py
│   │   ├── test_comprehensive_integration.py
│   │   ├── test_configuration.py
│   │   ├── test_falkordb_integration.py
│   │   ├── test_fixtures.py
│   │   ├── test_http_integration.py
│   │   ├── test_integration.py
│   │   ├── test_mcp_integration.py
│   │   ├── test_mcp_transports.py
│   │   ├── test_stdio_simple.py
│   │   └── test_stress_load.py
│   └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│   ├── .env.example
│   ├── graph_service
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── dto
│   │   │   ├── __init__.py
│   │   │   ├── common.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   ├── main.py
│   │   ├── routers
│   │   │   ├── __init__.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   └── zep_graphiti.py
│   ├── Makefile
│   ├── pyproject.toml
│   ├── README.md
│   └── uv.lock
├── signatures
│   └── version1
│       └── cla.json
├── tests
│   ├── cross_encoder
│   │   ├── test_bge_reranker_client_int.py
│   │   └── test_gemini_reranker_client.py
│   ├── driver
│   │   ├── __init__.py
│   │   └── test_falkordb_driver.py
│   ├── embedder
│   │   ├── embedder_fixtures.py
│   │   ├── test_gemini.py
│   │   ├── test_openai.py
│   │   └── test_voyage.py
│   ├── evals
│   │   ├── data
│   │   │   └── longmemeval_data
│   │   │       ├── longmemeval_oracle.json
│   │   │       └── README.md
│   │   ├── eval_cli.py
│   │   ├── eval_e2e_graph_building.py
│   │   ├── pytest.ini
│   │   └── utils.py
│   ├── helpers_test.py
│   ├── llm_client
│   │   ├── test_anthropic_client_int.py
│   │   ├── test_anthropic_client.py
│   │   ├── test_azure_openai_client.py
│   │   ├── test_client.py
│   │   ├── test_errors.py
│   │   └── test_gemini_client.py
│   ├── test_edge_int.py
│   ├── test_entity_exclusion_int.py
│   ├── test_graphiti_int.py
│   ├── test_graphiti_mock.py
│   ├── test_node_int.py
│   ├── test_text_utils.py
│   └── utils
│       ├── maintenance
│       │   ├── test_bulk_utils.py
│       │   ├── test_edge_operations.py
│       │   ├── test_entity_extraction.py
│       │   ├── test_node_operations.py
│       │   └── test_temporal_operations_int.py
│       ├── search
│       │   └── search_utils_test.py
│       └── test_content_chunking.py
├── uv.lock
└── Zep-CLA.md
```

# Files

--------------------------------------------------------------------------------
/graphiti_core/driver/falkordb_driver.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import asyncio
 18 | from datetime import datetime
 19 | import logging
 20 | from typing import TYPE_CHECKING, Any
 21 | 
 22 | if TYPE_CHECKING:
 23 |     from falkordb import Graph as FalkorGraph
 24 |     from falkordb.asyncio import FalkorDB
 25 | else:
 26 |     try:
 27 |         from falkordb import Graph as FalkorGraph
 28 |         from falkordb.asyncio import FalkorDB
 29 |     except ImportError:
 30 |         # If falkordb is not installed, raise an ImportError
 31 |         raise ImportError(
 32 |             'falkordb is required for FalkorDriver. '
 33 |             'Install it with: pip install graphiti-core[falkordb]'
 34 |         ) from None
 35 | 
 36 | from graphiti_core.driver.driver import GraphDriver, GraphDriverSession, GraphProvider
 37 | from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices
 38 | from graphiti_core.utils.datetime_utils import convert_datetimes_to_strings
 39 | 
 40 | logger = logging.getLogger(__name__)
 41 | 
 42 | STOPWORDS = [
 43 |     'a',
 44 |     'is',
 45 |     'the',
 46 |     'an',
 47 |     'and',
 48 |     'are',
 49 |     'as',
 50 |     'at',
 51 |     'be',
 52 |     'but',
 53 |     'by',
 54 |     'for',
 55 |     'if',
 56 |     'in',
 57 |     'into',
 58 |     'it',
 59 |     'no',
 60 |     'not',
 61 |     'of',
 62 |     'on',
 63 |     'or',
 64 |     'such',
 65 |     'that',
 66 |     'their',
 67 |     'then',
 68 |     'there',
 69 |     'these',
 70 |     'they',
 71 |     'this',
 72 |     'to',
 73 |     'was',
 74 |     'will',
 75 |     'with',
 76 | ]
 77 | 
 78 | 
 79 | class FalkorDriverSession(GraphDriverSession):
 80 |     provider = GraphProvider.FALKORDB
 81 | 
 82 |     def __init__(self, graph: FalkorGraph):
 83 |         self.graph = graph
 84 | 
 85 |     async def __aenter__(self):
 86 |         return self
 87 | 
 88 |     async def __aexit__(self, exc_type, exc, tb):
 89 |         # No cleanup needed for Falkor, but method must exist
 90 |         pass
 91 | 
 92 |     async def close(self):
 93 |         # No explicit close needed for FalkorDB, but method must exist
 94 |         pass
 95 | 
 96 |     async def execute_write(self, func, *args, **kwargs):
 97 |         # Directly await the provided async function with `self` as the transaction/session
 98 |         return await func(self, *args, **kwargs)
 99 | 
100 |     async def run(self, query: str | list, **kwargs: Any) -> Any:
101 |         # FalkorDB cannot parameterize label sets, so such updates are passed in as a list of (cypher, params) pairs
102 |         if isinstance(query, list):
103 |             for cypher, params in query:
104 |                 params = convert_datetimes_to_strings(params)
105 |                 await self.graph.query(str(cypher), params)  # type: ignore[reportUnknownArgumentType]
106 |         else:
107 |             params = dict(kwargs)
108 |             params = convert_datetimes_to_strings(params)
109 |             await self.graph.query(str(query), params)  # type: ignore[reportUnknownArgumentType]
110 |         # `graph.query` is async on the falkordb.asyncio client and awaited directly above; a sync client would need an executor
111 |         return None
112 | 
113 | 
114 | class FalkorDriver(GraphDriver):
115 |     provider = GraphProvider.FALKORDB
116 |     default_group_id: str = '\\_'
117 |     fulltext_syntax: str = '@'  # FalkorDB uses a redisearch-like syntax for fulltext queries
118 |     aoss_client: None = None
119 | 
120 |     def __init__(
121 |         self,
122 |         host: str = 'localhost',
123 |         port: int = 6379,
124 |         username: str | None = None,
125 |         password: str | None = None,
126 |         falkor_db: FalkorDB | None = None,
127 |         database: str = 'default_db',
128 |     ):
129 |         """
130 |         Initialize the FalkorDB driver.
131 | 
132 |         FalkorDB is a multi-tenant graph database.
133 |         To connect, provide the host and port.
134 |         The default parameters assume a local (on-premises) FalkorDB instance.
135 | 
136 |         Args:
137 |         host (str): The host where FalkorDB is running.
138 |         port (int): The port on which FalkorDB is listening.
139 |         username (str | None): The username for authentication (if required).
140 |         password (str | None): The password for authentication (if required).
141 |         falkor_db (FalkorDB | None): An existing FalkorDB instance to use instead of creating a new one.
142 |         database (str): The name of the database to connect to. Defaults to 'default_db'.
143 |         """
144 |         super().__init__()
145 |         self._database = database
146 |         if falkor_db is not None:
147 |             # If a FalkorDB instance is provided, use it directly
148 |             self.client = falkor_db
149 |         else:
150 |             self.client = FalkorDB(host=host, port=port, username=username, password=password)
151 | 
152 |         # Schedule the indices and constraints to be built
153 |         try:
154 |             # Try to get the current event loop
155 |             loop = asyncio.get_running_loop()
156 |             # Schedule the build_indices_and_constraints to run
157 |             loop.create_task(self.build_indices_and_constraints())
158 |         except RuntimeError:
159 |             # No event loop running, this will be handled later
160 |             pass
161 | 
162 |     def _get_graph(self, graph_name: str | None) -> FalkorGraph:
163 |         # FalkorDB requires a non-None database name for multi-tenant graphs; the default is "default_db"
164 |         if graph_name is None:
165 |             graph_name = self._database
166 |         return self.client.select_graph(graph_name)
167 | 
168 |     async def execute_query(self, cypher_query_, **kwargs: Any):
169 |         graph = self._get_graph(self._database)
170 | 
171 |         # Convert datetime objects to ISO strings (FalkorDB does not support datetime objects directly)
172 |         params = convert_datetimes_to_strings(dict(kwargs))
173 | 
174 |         try:
175 |             result = await graph.query(cypher_query_, params)  # type: ignore[reportUnknownArgumentType]
176 |         except Exception as e:
177 |             if 'already indexed' in str(e):
178 |                 # check if index already exists
179 |                 logger.info(f'Index already exists: {e}')
180 |                 return None
181 |             logger.error(f'Error executing FalkorDB query: {e}\n{cypher_query_}\n{params}')
182 |             raise
183 | 
184 |         # Convert the result header to a list of strings
185 |         header = [h[1] for h in result.header]
186 | 
187 |         # Convert FalkorDB's result format (list of lists) to the format expected by Graphiti (list of dicts)
188 |         records = []
189 |         for row in result.result_set:
190 |             record = {}
191 |             for i, field_name in enumerate(header):
192 |                 if i < len(row):
193 |                     record[field_name] = row[i]
194 |                 else:
195 |                     # If there are more fields in header than values in row, set to None
196 |                     record[field_name] = None
197 |             records.append(record)
198 | 
199 |         return records, header, None
200 | 
201 |     def session(self, database: str | None = None) -> GraphDriverSession:
202 |         return FalkorDriverSession(self._get_graph(database))
203 | 
204 |     async def close(self) -> None:
205 |         """Close the driver connection."""
206 |         if hasattr(self.client, 'aclose'):
207 |             await self.client.aclose()  # type: ignore[reportUnknownMemberType]
208 |         elif hasattr(self.client.connection, 'aclose'):
209 |             await self.client.connection.aclose()
210 |         elif hasattr(self.client.connection, 'close'):
211 |             await self.client.connection.close()
212 | 
213 |     async def delete_all_indexes(self) -> None:
214 |         result = await self.execute_query('CALL db.indexes()')
215 |         if not result:
216 |             return
217 | 
218 |         records, _, _ = result
219 |         drop_tasks = []
220 | 
221 |         for record in records:
222 |             label = record['label']
223 |             entity_type = record['entitytype']
224 | 
225 |             for field_name, index_type in record['types'].items():
226 |                 if 'RANGE' in index_type:
227 |                     drop_tasks.append(self.execute_query(f'DROP INDEX ON :{label}({field_name})'))
228 |                 elif 'FULLTEXT' in index_type:
229 |                     if entity_type == 'NODE':
230 |                         drop_tasks.append(
231 |                             self.execute_query(
232 |                                 f'DROP FULLTEXT INDEX FOR (n:{label}) ON (n.{field_name})'
233 |                             )
234 |                         )
235 |                     elif entity_type == 'RELATIONSHIP':
236 |                         drop_tasks.append(
237 |                             self.execute_query(
238 |                                 f'DROP FULLTEXT INDEX FOR ()-[e:{label}]-() ON (e.{field_name})'
239 |                             )
240 |                         )
241 | 
242 |         if drop_tasks:
243 |             await asyncio.gather(*drop_tasks)
244 | 
245 |     async def build_indices_and_constraints(self, delete_existing=False):
246 |         if delete_existing:
247 |             await self.delete_all_indexes()
248 |         index_queries = get_range_indices(self.provider) + get_fulltext_indices(self.provider)
249 |         for query in index_queries:
250 |             await self.execute_query(query)
251 | 
252 |     def clone(self, database: str) -> 'GraphDriver':
253 |         """
254 |         Returns a shallow copy of this driver with a different default database.
255 |         Reuses the same connection (e.g. FalkorDB, Neo4j).
256 |         """
257 |         if database == self._database:
258 |             cloned = self
259 |         elif database == self.default_group_id:
260 |             cloned = FalkorDriver(falkor_db=self.client)
261 |         else:
262 |             # Create a new instance of FalkorDriver with the same connection but a different database
263 |             cloned = FalkorDriver(falkor_db=self.client, database=database)
264 | 
265 |         return cloned
266 | 
267 |     async def health_check(self) -> None:
268 |         """Check FalkorDB connectivity by running a simple query."""
269 |         try:
270 |             await self.execute_query('MATCH (n) RETURN 1 LIMIT 1')
271 |             return None
272 |         except Exception as e:
273 |             print(f'FalkorDB health check failed: {e}')
274 |             raise
275 | 
276 |     @staticmethod
277 |     def convert_datetimes_to_strings(obj):
278 |         if isinstance(obj, dict):
279 |             return {k: FalkorDriver.convert_datetimes_to_strings(v) for k, v in obj.items()}
280 |         elif isinstance(obj, list):
281 |             return [FalkorDriver.convert_datetimes_to_strings(item) for item in obj]
282 |         elif isinstance(obj, tuple):
283 |             return tuple(FalkorDriver.convert_datetimes_to_strings(item) for item in obj)
284 |         elif isinstance(obj, datetime):
285 |             return obj.isoformat()
286 |         else:
287 |             return obj
288 | 
289 |     def sanitize(self, query: str) -> str:
290 |         """
291 |         Replace FalkorDB special characters with whitespace.
292 |         Based on FalkorDB tokenization rules: ,.<>{}[]"':;!@#$%^&*()-+=~
293 |         """
294 |         # FalkorDB separator characters that break text into tokens
295 |         separator_map = str.maketrans(
296 |             {
297 |                 ',': ' ',
298 |                 '.': ' ',
299 |                 '<': ' ',
300 |                 '>': ' ',
301 |                 '{': ' ',
302 |                 '}': ' ',
303 |                 '[': ' ',
304 |                 ']': ' ',
305 |                 '"': ' ',
306 |                 "'": ' ',
307 |                 ':': ' ',
308 |                 ';': ' ',
309 |                 '!': ' ',
310 |                 '@': ' ',
311 |                 '#': ' ',
312 |                 '$': ' ',
313 |                 '%': ' ',
314 |                 '^': ' ',
315 |                 '&': ' ',
316 |                 '*': ' ',
317 |                 '(': ' ',
318 |                 ')': ' ',
319 |                 '-': ' ',
320 |                 '+': ' ',
321 |                 '=': ' ',
322 |                 '~': ' ',
323 |                 '?': ' ',
324 |             }
325 |         )
326 |         sanitized = query.translate(separator_map)
327 |         # Clean up multiple spaces
328 |         sanitized = ' '.join(sanitized.split())
329 |         return sanitized
330 | 
331 |     def build_fulltext_query(
332 |         self, query: str, group_ids: list[str] | None = None, max_query_length: int = 128
333 |     ) -> str:
334 |         """
335 |         Build a fulltext query string for FalkorDB using RedisSearch syntax.
336 |         FalkorDB uses RedisSearch-like syntax where:
337 |         - Field queries use @ prefix: @field:value
338 |         - Multiple values for same field: (@field:value1|value2)
339 |         - Text search doesn't need @ prefix for content fields
340 |         - AND is implicit with space: (@group_id:value) (text)
341 |         - OR uses pipe within parentheses: (@group_id:value1|value2)
342 |         """
343 |         if group_ids is None or len(group_ids) == 0:
344 |             group_filter = ''
345 |         else:
346 |             group_values = '|'.join(group_ids)
347 |             group_filter = f'(@group_id:{group_values})'
348 | 
349 |         sanitized_query = self.sanitize(query)
350 | 
351 |         # Remove stopwords from the sanitized query
352 |         query_words = sanitized_query.split()
353 |         filtered_words = [word for word in query_words if word.lower() not in STOPWORDS]
354 |         sanitized_query = ' | '.join(filtered_words)
355 | 
356 |         # If the combined query would exceed the maximum length, return an empty string (no fulltext query)
357 |         if len(sanitized_query.split(' ')) + len(group_ids or '') >= max_query_length:
358 |             return ''
359 | 
360 |         full_query = group_filter + ' (' + sanitized_query + ')'
361 | 
362 |         return full_query
363 | 
```
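
Taken together, `FalkorDriver` can be exercised end to end. The sketch below is not part of the repository: it assumes a FalkorDB instance reachable on `localhost:6379`, and the `Entity` label and `g1`/`g2` group ids are illustrative. Only methods defined in the file above are used.

```python
import asyncio

from graphiti_core.driver.falkordb_driver import FalkorDriver


async def main() -> None:
    # Assumption: a local FalkorDB instance on the default port.
    driver = FalkorDriver(host='localhost', port=6379)

    # Raises if the database is unreachable.
    await driver.health_check()

    # execute_query returns (records, header, None); records is a list of
    # dicts keyed by the RETURN aliases.
    result = await driver.execute_query(
        'MATCH (n:Entity) WHERE n.group_id = $group_id RETURN n.uuid AS uuid',
        group_id='g1',  # illustrative group id
    )
    if result is not None:
        records, header, _ = result
        print(header, records)

    # build_fulltext_query sanitizes the text, strips stopwords, and ORs the
    # remaining terms, optionally scoped to group_ids via @group_id.
    q = driver.build_fulltext_query('Is the quick brown fox?', group_ids=['g1', 'g2'])
    assert q == '(@group_id:g1|g2) (quick | brown | fox)'

    await driver.close()


asyncio.run(main())
```

Note that the constructor also schedules `build_indices_and_constraints()` on the running event loop, so index creation happens in the background on first use.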

--------------------------------------------------------------------------------
/graphiti_core/models/nodes/node_db_queries.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | from typing import Any
 18 | 
 19 | from graphiti_core.driver.driver import GraphProvider
 20 | 
 21 | 
 22 | def get_episode_node_save_query(provider: GraphProvider) -> str:
 23 |     match provider:
 24 |         case GraphProvider.NEPTUNE:
 25 |             return """
 26 |                 MERGE (n:Episodic {uuid: $uuid})
 27 |                 SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
 28 |                 entity_edges: join([x IN coalesce($entity_edges, []) | toString(x) ], '|'), created_at: $created_at, valid_at: $valid_at}
 29 |                 RETURN n.uuid AS uuid
 30 |             """
 31 |         case GraphProvider.KUZU:
 32 |             return """
 33 |                 MERGE (n:Episodic {uuid: $uuid})
 34 |                 SET
 35 |                     n.name = $name,
 36 |                     n.group_id = $group_id,
 37 |                     n.created_at = $created_at,
 38 |                     n.source = $source,
 39 |                     n.source_description = $source_description,
 40 |                     n.content = $content,
 41 |                     n.valid_at = $valid_at,
 42 |                     n.entity_edges = $entity_edges
 43 |                 RETURN n.uuid AS uuid
 44 |             """
 45 |         case GraphProvider.FALKORDB:
 46 |             return """
 47 |                 MERGE (n:Episodic {uuid: $uuid})
 48 |                 SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
 49 |                 entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
 50 |                 RETURN n.uuid AS uuid
 51 |             """
 52 |         case _:  # Neo4j
 53 |             return """
 54 |                 MERGE (n:Episodic {uuid: $uuid})
 55 |                 SET n = {uuid: $uuid, name: $name, group_id: $group_id, source_description: $source_description, source: $source, content: $content,
 56 |                 entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at}
 57 |                 RETURN n.uuid AS uuid
 58 |             """
 59 | 
 60 | 
 61 | def get_episode_node_save_bulk_query(provider: GraphProvider) -> str:
 62 |     match provider:
 63 |         case GraphProvider.NEPTUNE:
 64 |             return """
 65 |                 UNWIND $episodes AS episode
 66 |                 MERGE (n:Episodic {uuid: episode.uuid})
 67 |                 SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description,
 68 |                     source: episode.source, content: episode.content,
 69 |                 entity_edges: join([x IN coalesce(episode.entity_edges, []) | toString(x) ], '|'), created_at: episode.created_at, valid_at: episode.valid_at}
 70 |                 RETURN n.uuid AS uuid
 71 |             """
 72 |         case GraphProvider.KUZU:
 73 |             return """
 74 |                 MERGE (n:Episodic {uuid: $uuid})
 75 |                 SET
 76 |                     n.name = $name,
 77 |                     n.group_id = $group_id,
 78 |                     n.created_at = $created_at,
 79 |                     n.source = $source,
 80 |                     n.source_description = $source_description,
 81 |                     n.content = $content,
 82 |                     n.valid_at = $valid_at,
 83 |                     n.entity_edges = $entity_edges
 84 |                 RETURN n.uuid AS uuid
 85 |             """
 86 |         case GraphProvider.FALKORDB:
 87 |             return """
 88 |                 UNWIND $episodes AS episode
 89 |                 MERGE (n:Episodic {uuid: episode.uuid})
 90 |                 SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content, 
 91 |                 entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
 92 |                 RETURN n.uuid AS uuid
 93 |             """
 94 |         case _:  # Neo4j
 95 |             return """
 96 |                 UNWIND $episodes AS episode
 97 |                 MERGE (n:Episodic {uuid: episode.uuid})
 98 |                 SET n = {uuid: episode.uuid, name: episode.name, group_id: episode.group_id, source_description: episode.source_description, source: episode.source, content: episode.content, 
 99 |                 entity_edges: episode.entity_edges, created_at: episode.created_at, valid_at: episode.valid_at}
100 |                 RETURN n.uuid AS uuid
101 |             """
102 | 
103 | 
104 | EPISODIC_NODE_RETURN = """
105 |     e.uuid AS uuid,
106 |     e.name AS name,
107 |     e.group_id AS group_id,
108 |     e.created_at AS created_at,
109 |     e.source AS source,
110 |     e.source_description AS source_description,
111 |     e.content AS content,
112 |     e.valid_at AS valid_at,
113 |     e.entity_edges AS entity_edges
114 | """
115 | 
116 | EPISODIC_NODE_RETURN_NEPTUNE = """
117 |     e.content AS content,
118 |     e.created_at AS created_at,
119 |     e.valid_at AS valid_at,
120 |     e.uuid AS uuid,
121 |     e.name AS name,
122 |     e.group_id AS group_id,
123 |     e.source_description AS source_description,
124 |     e.source AS source,
125 |     split(e.entity_edges, ",") AS entity_edges
126 | """
127 | 
128 | 
129 | def get_entity_node_save_query(provider: GraphProvider, labels: str, has_aoss: bool = False) -> str:
130 |     match provider:
131 |         case GraphProvider.FALKORDB:
132 |             return f"""
133 |                 MERGE (n:Entity {{uuid: $entity_data.uuid}})
134 |                 SET n:{labels}
135 |                 SET n = $entity_data
136 |                 SET n.name_embedding = vecf32($entity_data.name_embedding)
137 |                 RETURN n.uuid AS uuid
138 |             """
139 |         case GraphProvider.KUZU:
140 |             return """
141 |                 MERGE (n:Entity {uuid: $uuid})
142 |                 SET
143 |                     n.name = $name,
144 |                     n.group_id = $group_id,
145 |                     n.labels = $labels,
146 |                     n.created_at = $created_at,
147 |                     n.name_embedding = $name_embedding,
148 |                     n.summary = $summary,
149 |                     n.attributes = $attributes
150 |                 WITH n
151 |                 RETURN n.uuid AS uuid
152 |             """
153 |         case GraphProvider.NEPTUNE:
154 |             label_subquery = ''
155 |             for label in labels.split(':'):
156 |                 label_subquery += f' SET n:{label}\n'
157 |             return f"""
158 |                 MERGE (n:Entity {{uuid: $entity_data.uuid}})
159 |                 {label_subquery}
160 |                 SET n = removeKeyFromMap(removeKeyFromMap($entity_data, "labels"), "name_embedding")
161 |                 SET n.name_embedding = join([x IN coalesce($entity_data.name_embedding, []) | toString(x) ], ",")
162 |                 RETURN n.uuid AS uuid
163 |             """
164 |         case _:
165 |             save_embedding_query = (
166 |                 'WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $entity_data.name_embedding)'
167 |                 if not has_aoss
168 |                 else ''
169 |             )
170 |             return (
171 |                 f"""
172 |                 MERGE (n:Entity {{uuid: $entity_data.uuid}})
173 |                 SET n:{labels}
174 |                 SET n = $entity_data
175 |                 """
176 |                 + save_embedding_query
177 |                 + """
178 |                 RETURN n.uuid AS uuid
179 |             """
180 |             )
181 | 
182 | 
183 | def get_entity_node_save_bulk_query(
184 |     provider: GraphProvider, nodes: list[dict], has_aoss: bool = False
185 | ) -> str | Any:
186 |     match provider:
187 |         case GraphProvider.FALKORDB:
188 |             queries = []
189 |             for node in nodes:
190 |                 for label in node['labels']:
191 |                     queries.append(
192 |                         (
193 |                             f"""
194 |                             UNWIND $nodes AS node
195 |                             MERGE (n:Entity {{uuid: node.uuid}})
196 |                             SET n:{label}
197 |                             SET n = node
198 |                             WITH n, node
199 |                             SET n.name_embedding = vecf32(node.name_embedding)
200 |                             RETURN n.uuid AS uuid
201 |                             """,
202 |                             {'nodes': [node]},
203 |                         )
204 |                     )
205 |             return queries
206 |         case GraphProvider.NEPTUNE:
207 |             queries = []
208 |             for node in nodes:
209 |                 labels = ''
210 |                 for label in node['labels']:
211 |                     labels += f' SET n:{label}\n'
212 |                 queries.append(
213 |                     f"""
214 |                         UNWIND $nodes AS node
215 |                         MERGE (n:Entity {{uuid: node.uuid}})
216 |                         {labels}
217 |                         SET n = removeKeyFromMap(removeKeyFromMap(node, "labels"), "name_embedding")
218 |                         SET n.name_embedding = join([x IN coalesce(node.name_embedding, []) | toString(x) ], ",")
219 |                         RETURN n.uuid AS uuid
220 |                     """
221 |                 )
222 |             return queries
223 |         case GraphProvider.KUZU:
224 |             return """
225 |                 MERGE (n:Entity {uuid: $uuid})
226 |                 SET
227 |                     n.name = $name,
228 |                     n.group_id = $group_id,
229 |                     n.labels = $labels,
230 |                     n.created_at = $created_at,
231 |                     n.name_embedding = $name_embedding,
232 |                     n.summary = $summary,
233 |                     n.attributes = $attributes
234 |                 RETURN n.uuid AS uuid
235 |             """
236 |         case _:  # Neo4j
237 |             save_embedding_query = (
238 |                 'WITH n, node CALL db.create.setNodeVectorProperty(n, "name_embedding", node.name_embedding)'
239 |                 if not has_aoss
240 |                 else ''
241 |             )
242 |             return (
243 |                 """
244 |                     UNWIND $nodes AS node
245 |                     MERGE (n:Entity {uuid: node.uuid})
246 |                     SET n:$(node.labels)
247 |                     SET n = node
248 |                     """
249 |                 + save_embedding_query
250 |                 + """
251 |                 RETURN n.uuid AS uuid
252 |             """
253 |             )
254 | 
255 | 
256 | def get_entity_node_return_query(provider: GraphProvider) -> str:
257 |     # `name_embedding` is not returned by default and must be loaded manually using `load_name_embedding()`.
258 |     if provider == GraphProvider.KUZU:
259 |         return """
260 |             n.uuid AS uuid,
261 |             n.name AS name,
262 |             n.group_id AS group_id,
263 |             n.labels AS labels,
264 |             n.created_at AS created_at,
265 |             n.summary AS summary,
266 |             n.attributes AS attributes
267 |         """
268 | 
269 |     return """
270 |         n.uuid AS uuid,
271 |         n.name AS name,
272 |         n.group_id AS group_id,
273 |         n.created_at AS created_at,
274 |         n.summary AS summary,
275 |         labels(n) AS labels,
276 |         properties(n) AS attributes
277 |     """
278 | 
279 | 
280 | def get_community_node_save_query(provider: GraphProvider) -> str:
281 |     match provider:
282 |         case GraphProvider.FALKORDB:
283 |             return """
284 |                 MERGE (n:Community {uuid: $uuid})
285 |                 SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at, name_embedding: vecf32($name_embedding)}
286 |                 RETURN n.uuid AS uuid
287 |             """
288 |         case GraphProvider.NEPTUNE:
289 |             return """
290 |                 MERGE (n:Community {uuid: $uuid})
291 |                 SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
292 |                 SET n.name_embedding = join([x IN coalesce($name_embedding, []) | toString(x) ], ",")
293 |                 RETURN n.uuid AS uuid
294 |             """
295 |         case GraphProvider.KUZU:
296 |             return """
297 |                 MERGE (n:Community {uuid: $uuid})
298 |                 SET
299 |                     n.name = $name,
300 |                     n.group_id = $group_id,
301 |                     n.created_at = $created_at,
302 |                     n.name_embedding = $name_embedding,
303 |                     n.summary = $summary
304 |                 RETURN n.uuid AS uuid
305 |             """
306 |         case _:  # Neo4j
307 |             return """
308 |                 MERGE (n:Community {uuid: $uuid})
309 |                 SET n = {uuid: $uuid, name: $name, group_id: $group_id, summary: $summary, created_at: $created_at}
310 |                 WITH n CALL db.create.setNodeVectorProperty(n, "name_embedding", $name_embedding)
311 |                 RETURN n.uuid AS uuid
312 |             """
313 | 
314 | 
315 | COMMUNITY_NODE_RETURN = """
316 |     c.uuid AS uuid,
317 |     c.name AS name,
318 |     c.group_id AS group_id,
319 |     c.created_at AS created_at,
320 |     c.name_embedding AS name_embedding,
321 |     c.summary AS summary
322 | """
323 | 
324 | COMMUNITY_NODE_RETURN_NEPTUNE = """
325 |     n.uuid AS uuid,
326 |     n.name AS name,
327 |     [x IN split(n.name_embedding, ",") | toFloat(x)] AS name_embedding,
328 |     n.group_id AS group_id,
329 |     n.summary AS summary,
330 |     n.created_at AS created_at
331 | """
332 | 
333 | 
334 | def get_saga_node_save_query(provider: GraphProvider) -> str:
335 |     match provider:
336 |         case GraphProvider.KUZU:
337 |             return """
338 |                 MERGE (n:Saga {uuid: $uuid})
339 |                 SET
340 |                     n.name = $name,
341 |                     n.group_id = $group_id,
342 |                     n.created_at = $created_at
343 |                 RETURN n.uuid AS uuid
344 |             """
345 |         case _:  # Neo4j, FalkorDB, Neptune
346 |             return """
347 |                 MERGE (n:Saga {uuid: $uuid})
348 |                 SET n = {uuid: $uuid, name: $name, group_id: $group_id, created_at: $created_at}
349 |                 RETURN n.uuid AS uuid
350 |             """
351 | 
352 | 
353 | SAGA_NODE_RETURN = """
354 |     s.uuid AS uuid,
355 |     s.name AS name,
356 |     s.group_id AS group_id,
357 |     s.created_at AS created_at
358 | """
359 | 
360 | SAGA_NODE_RETURN_NEPTUNE = """
361 |     s.uuid AS uuid,
362 |     s.name AS name,
363 |     s.group_id AS group_id,
364 |     s.created_at AS created_at
365 | """
366 | 
```
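
As a quick illustration of how these provider-dispatched queries are consumed, the sketch below (not from the repository; all parameter values are made up) pairs `get_episode_node_save_query` with the parameter names its `$`-placeholders expect. Execution is shown as a comment, since it requires a live driver such as the `FalkorDriver` above.

```python
from datetime import datetime, timezone

from graphiti_core.driver.driver import GraphProvider
from graphiti_core.models.nodes.node_db_queries import get_episode_node_save_query

# Select the MERGE/SET statement for the target backend at runtime.
query = get_episode_node_save_query(GraphProvider.FALKORDB)

# Keys mirror the $-placeholders in the FalkorDB branch above; the values
# here are illustrative only.
params = {
    'uuid': 'ep-123',
    'name': 'episode-1',
    'group_id': 'g1',
    'source': 'message',
    'source_description': 'chat transcript',
    'content': 'Alice met Bob.',
    'entity_edges': [],
    'created_at': datetime.now(timezone.utc),
    'valid_at': datetime.now(timezone.utc),
}

# With any GraphDriver instance (datetimes are converted to ISO strings by
# the FalkorDB driver before the query runs):
#     records = await driver.execute_query(query, **params)
```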

--------------------------------------------------------------------------------
/signatures/version1/cla.json:
--------------------------------------------------------------------------------

```json
  1 | {
  2 |   "signedContributors": [
  3 |     {
  4 |       "name": "colombod",
  5 |       "id": 375556,
  6 |       "comment_id": 2761979440,
  7 |       "created_at": "2025-03-28T17:21:29Z",
  8 |       "repoId": 840056306,
  9 |       "pullRequestNo": 310
 10 |     },
 11 |     {
 12 |       "name": "evanmschultz",
 13 |       "id": 3806601,
 14 |       "comment_id": 2813673237,
 15 |       "created_at": "2025-04-17T17:56:24Z",
 16 |       "repoId": 840056306,
 17 |       "pullRequestNo": 372
 18 |     },
 19 |     {
 20 |       "name": "soichisumi",
 21 |       "id": 30210641,
 22 |       "comment_id": 2818469528,
 23 |       "created_at": "2025-04-21T14:02:11Z",
 24 |       "repoId": 840056306,
 25 |       "pullRequestNo": 382
 26 |     },
 27 |     {
 28 |       "name": "drumnation",
 29 |       "id": 18486434,
 30 |       "comment_id": 2822330188,
 31 |       "created_at": "2025-04-22T19:51:09Z",
 32 |       "repoId": 840056306,
 33 |       "pullRequestNo": 389
 34 |     },
 35 |     {
 36 |       "name": "jackaldenryan",
 37 |       "id": 61809814,
 38 |       "comment_id": 2845356793,
 39 |       "created_at": "2025-05-01T17:51:11Z",
 40 |       "repoId": 840056306,
 41 |       "pullRequestNo": 429
 42 |     },
 43 |     {
 44 |       "name": "t41372",
 45 |       "id": 36402030,
 46 |       "comment_id": 2849035400,
 47 |       "created_at": "2025-05-04T06:24:37Z",
 48 |       "repoId": 840056306,
 49 |       "pullRequestNo": 438
 50 |     },
 51 |     {
 52 |       "name": "markalosey",
 53 |       "id": 1949914,
 54 |       "comment_id": 2878173826,
 55 |       "created_at": "2025-05-13T23:27:16Z",
 56 |       "repoId": 840056306,
 57 |       "pullRequestNo": 486
 58 |     },
 59 |     {
 60 |       "name": "adamkatav",
 61 |       "id": 13109136,
 62 |       "comment_id": 2887184706,
 63 |       "created_at": "2025-05-16T16:29:22Z",
 64 |       "repoId": 840056306,
 65 |       "pullRequestNo": 493
 66 |     },
 67 |     {
 68 |       "name": "realugbun",
 69 |       "id": 74101927,
 70 |       "comment_id": 2899731784,
 71 |       "created_at": "2025-05-22T02:36:44Z",
 72 |       "repoId": 840056306,
 73 |       "pullRequestNo": 513
 74 |     },
 75 |     {
 76 |       "name": "dudizimber",
 77 |       "id": 16744955,
 78 |       "comment_id": 2912211548,
 79 |       "created_at": "2025-05-27T11:45:57Z",
 80 |       "repoId": 840056306,
 81 |       "pullRequestNo": 525
 82 |     },
 83 |     {
 84 |       "name": "galshubeli",
 85 |       "id": 124919062,
 86 |       "comment_id": 2912289100,
 87 |       "created_at": "2025-05-27T12:15:03Z",
 88 |       "repoId": 840056306,
 89 |       "pullRequestNo": 525
 90 |     },
 91 |     {
 92 |       "name": "TheEpTic",
 93 |       "id": 326774,
 94 |       "comment_id": 2917970901,
 95 |       "created_at": "2025-05-29T01:26:54Z",
 96 |       "repoId": 840056306,
 97 |       "pullRequestNo": 541
 98 |     },
 99 |     {
100 |       "name": "PrettyWood",
101 |       "id": 18406791,
102 |       "comment_id": 2938495182,
103 |       "created_at": "2025-06-04T04:44:59Z",
104 |       "repoId": 840056306,
105 |       "pullRequestNo": 558
106 |     },
107 |     {
108 |       "name": "denyska",
109 |       "id": 1242726,
110 |       "comment_id": 2957480685,
111 |       "created_at": "2025-06-10T02:08:05Z",
112 |       "repoId": 840056306,
113 |       "pullRequestNo": 574
114 |     },
115 |     {
116 |       "name": "LongPML",
117 |       "id": 59755436,
118 |       "comment_id": 2965391879,
119 |       "created_at": "2025-06-12T07:10:01Z",
120 |       "repoId": 840056306,
121 |       "pullRequestNo": 579
122 |     },
123 |     {
124 |       "name": "karn09",
125 |       "id": 3743119,
126 |       "comment_id": 2973492225,
127 |       "created_at": "2025-06-15T04:45:13Z",
128 |       "repoId": 840056306,
129 |       "pullRequestNo": 584
130 |     },
131 |     {
132 |       "name": "abab-dev",
133 |       "id": 146825408,
134 |       "comment_id": 2975719469,
135 |       "created_at": "2025-06-16T09:12:53Z",
136 |       "repoId": 840056306,
137 |       "pullRequestNo": 588
138 |     },
139 |     {
140 |       "name": "thorchh",
141 |       "id": 75025911,
142 |       "comment_id": 2982990164,
143 |       "created_at": "2025-06-18T07:19:38Z",
144 |       "repoId": 840056306,
145 |       "pullRequestNo": 601
146 |     },
147 |     {
148 |       "name": "robrichardson13",
149 |       "id": 9492530,
150 |       "comment_id": 2989798338,
151 |       "created_at": "2025-06-20T04:59:06Z",
152 |       "repoId": 840056306,
153 |       "pullRequestNo": 611
154 |     },
155 |     {
156 |       "name": "gkorland",
157 |       "id": 753206,
158 |       "comment_id": 2993690025,
159 |       "created_at": "2025-06-21T17:35:37Z",
160 |       "repoId": 840056306,
161 |       "pullRequestNo": 609
162 |     },
163 |     {
164 |       "name": "urmzd",
165 |       "id": 45431570,
166 |       "comment_id": 3027098935,
167 |       "created_at": "2025-07-02T09:16:46Z",
168 |       "repoId": 840056306,
169 |       "pullRequestNo": 661
170 |     },
171 |     {
172 |       "name": "jawwadfirdousi",
173 |       "id": 10913083,
174 |       "comment_id": 3027808026,
175 |       "created_at": "2025-07-02T13:02:22Z",
176 |       "repoId": 840056306,
177 |       "pullRequestNo": 663
178 |     },
179 |     {
180 |       "name": "jamesindeed",
181 |       "id": 60527576,
182 |       "comment_id": 3028293328,
183 |       "created_at": "2025-07-02T15:24:23Z",
184 |       "repoId": 840056306,
185 |       "pullRequestNo": 664
186 |     },
187 |     {
188 |       "name": "dev-mirzabicer",
189 |       "id": 90691873,
190 |       "comment_id": 3035836506,
191 |       "created_at": "2025-07-04T11:47:08Z",
192 |       "repoId": 840056306,
193 |       "pullRequestNo": 672
194 |     },
195 |     {
196 |       "name": "zeroasterisk",
197 |       "id": 23422,
198 |       "comment_id": 3040716245,
199 |       "created_at": "2025-07-06T03:41:19Z",
200 |       "repoId": 840056306,
201 |       "pullRequestNo": 679
202 |     },
203 |     {
204 |       "name": "charlesmcchan",
205 |       "id": 425857,
206 |       "comment_id": 3066732289,
207 |       "created_at": "2025-07-13T08:54:26Z",
208 |       "repoId": 840056306,
209 |       "pullRequestNo": 711
210 |     },
211 |     {
212 |       "name": "soraxas",
213 |       "id": 22362177,
214 |       "comment_id": 3084093750,
215 |       "created_at": "2025-07-17T13:33:25Z",
216 |       "repoId": 840056306,
217 |       "pullRequestNo": 741
218 |     },
219 |     {
220 |       "name": "sdht0",
221 |       "id": 867424,
222 |       "comment_id": 3092540466,
223 |       "created_at": "2025-07-19T19:52:21Z",
224 |       "repoId": 840056306,
225 |       "pullRequestNo": 748
226 |     },
227 |     {
228 |       "name": "Naseem77",
229 |       "id": 34807727,
230 |       "comment_id": 3093746709,
231 |       "created_at": "2025-07-20T07:07:33Z",
232 |       "repoId": 840056306,
233 |       "pullRequestNo": 742
234 |     },
235 |     {
236 |       "name": "kavenGw",
237 |       "id": 3193355,
238 |       "comment_id": 3100620568,
239 |       "created_at": "2025-07-22T02:58:50Z",
240 |       "repoId": 840056306,
241 |       "pullRequestNo": 750
242 |     },
243 |     {
244 |       "name": "paveljakov",
245 |       "id": 45147436,
246 |       "comment_id": 3113955940,
247 |       "created_at": "2025-07-24T15:39:36Z",
248 |       "repoId": 840056306,
249 |       "pullRequestNo": 764
250 |     },
251 |     {
252 |       "name": "gifflet",
253 |       "id": 33522742,
254 |       "comment_id": 3133869379,
255 |       "created_at": "2025-07-29T20:00:27Z",
256 |       "repoId": 840056306,
257 |       "pullRequestNo": 782
258 |     },
259 |     {
260 |       "name": "bechbd",
261 |       "id": 6898505,
262 |       "comment_id": 3140501814,
263 |       "created_at": "2025-07-31T15:58:08Z",
264 |       "repoId": 840056306,
265 |       "pullRequestNo": 793
266 |     },
267 |     {
268 |       "name": "hugo-son",
269 |       "id": 141999572,
270 |       "comment_id": 3155009405,
271 |       "created_at": "2025-08-05T12:27:09Z",
272 |       "repoId": 840056306,
273 |       "pullRequestNo": 805
274 |     },
275 |     {
276 |       "name": "mvanders",
277 |       "id": 758617,
278 |       "comment_id": 3160523661,
279 |       "created_at": "2025-08-06T14:56:21Z",
280 |       "repoId": 840056306,
281 |       "pullRequestNo": 808
282 |     },
283 |     {
284 |       "name": "v-khanna",
285 |       "id": 102773390,
286 |       "comment_id": 3162200130,
287 |       "created_at": "2025-08-07T02:23:09Z",
288 |       "repoId": 840056306,
289 |       "pullRequestNo": 812
290 |     },
291 |     {
292 |       "name": "vjeeva",
293 |       "id": 13189349,
294 |       "comment_id": 3165600173,
295 |       "created_at": "2025-08-07T20:24:08Z",
296 |       "repoId": 840056306,
297 |       "pullRequestNo": 814
298 |     },
299 |     {
300 |       "name": "liebertar",
301 |       "id": 99405438,
302 |       "comment_id": 3166905812,
303 |       "created_at": "2025-08-08T07:52:27Z",
304 |       "repoId": 840056306,
305 |       "pullRequestNo": 816
306 |     },
307 |     {
308 |       "name": "CaroLe-prw",
309 |       "id": 42695882,
310 |       "comment_id": 3187949734,
311 |       "created_at": "2025-08-14T10:29:25Z",
312 |       "repoId": 840056306,
313 |       "pullRequestNo": 833
314 |     },
315 |     {
316 |       "name": "Wizmann",
317 |       "id": 1270921,
318 |       "comment_id": 3196208374,
319 |       "created_at": "2025-08-18T11:09:35Z",
320 |       "repoId": 840056306,
321 |       "pullRequestNo": 842
322 |     },
323 |     {
324 |       "name": "liangyuanpeng",
325 |       "id": 28711504,
326 |       "comment_id": 3205841804,
327 |       "created_at": "2025-08-20T11:35:42Z",
328 |       "repoId": 840056306,
329 |       "pullRequestNo": 847
330 |     },
331 |     {
332 |       "name": "aktek-yazge",
333 |       "id": 218602044,
334 |       "comment_id": 3078757968,
335 |       "created_at": "2025-07-16T14:00:40Z",
336 |       "repoId": 840056306,
337 |       "pullRequestNo": 735
338 |     },
339 |     {
340 |       "name": "Shelvak",
341 |       "id": 873323,
342 |       "comment_id": 3243330690,
343 |       "created_at": "2025-09-01T22:26:32Z",
344 |       "repoId": 840056306,
345 |       "pullRequestNo": 885
346 |     },
347 |     {
348 |       "name": "maskshell",
349 |       "id": 5113279,
350 |       "comment_id": 3244187860,
351 |       "created_at": "2025-09-02T07:48:05Z",
352 |       "repoId": 840056306,
353 |       "pullRequestNo": 886
354 |     },
355 |     {
356 |       "name": "jeanlucthumm",
357 |       "id": 4934853,
358 |       "comment_id": 3255120747,
359 |       "created_at": "2025-09-04T18:49:57Z",
360 |       "repoId": 840056306,
361 |       "pullRequestNo": 892
362 |     },
363 |     {
364 |       "name": "Bit-urd",
365 |       "id": 43745133,
366 |       "comment_id": 3264006888,
367 |       "created_at": "2025-09-07T20:01:08Z",
368 |       "repoId": 840056306,
369 |       "pullRequestNo": 895
370 |     },
371 |     {
372 |       "name": "DavIvek",
373 |       "id": 88043717,
374 |       "comment_id": 3269895491,
375 |       "created_at": "2025-09-09T09:59:47Z",
376 |       "repoId": 840056306,
377 |       "pullRequestNo": 900
378 |     },
379 |     {
380 |       "name": "gsw945",
381 |       "id": 6281968,
382 |       "comment_id": 3270396586,
383 |       "created_at": "2025-09-09T12:05:27Z",
384 |       "repoId": 840056306,
385 |       "pullRequestNo": 901
386 |     },
387 |     {
388 |       "name": "luan122",
389 |       "id": 5606023,
390 |       "comment_id": 3287095238,
391 |       "created_at": "2025-09-12T23:14:21Z",
392 |       "repoId": 840056306,
393 |       "pullRequestNo": 908
394 |     },
395 |     {
396 |       "name": "Brandtweary",
397 |       "id": 7968557,
398 |       "comment_id": 3314191937,
399 |       "created_at": "2025-09-19T23:37:33Z",
400 |       "repoId": 840056306,
401 |       "pullRequestNo": 916
402 |     },
403 |     {
404 |       "name": "clsferguson",
405 |       "id": 48876201,
406 |       "comment_id": 3368715688,
407 |       "created_at": "2025-10-05T03:30:10Z",
408 |       "repoId": 840056306,
409 |       "pullRequestNo": 981
410 |     },
411 |     {
412 |       "name": "ngaiyuc",
413 |       "id": 69293565,
414 |       "comment_id": 3407383300,
415 |       "created_at": "2025-10-15T16:45:10Z",
416 |       "repoId": 840056306,
417 |       "pullRequestNo": 1005
418 |     },
419 |     {
420 |       "name": "0fism",
421 |       "id": 63762457,
422 |       "comment_id": 3407328042,
423 |       "created_at": "2025-10-15T16:29:33Z",
424 |       "repoId": 840056306,
425 |       "pullRequestNo": 1005
426 |     },
427 |     {
428 |       "name": "dontang97",
429 |       "id": 88384441,
430 |       "comment_id": 3431443627,
431 |       "created_at": "2025-10-22T09:52:01Z",
432 |       "repoId": 840056306,
433 |       "pullRequestNo": 1020
434 |     },
435 |     {
436 |       "name": "didier-durand",
437 |       "id": 2927957,
438 |       "comment_id": 3460571645,
439 |       "created_at": "2025-10-29T09:31:25Z",
440 |       "repoId": 840056306,
441 |       "pullRequestNo": 1028
442 |     },
443 |     {
444 |       "name": "anubhavgirdhar1",
445 |       "id": 85768253,
446 |       "comment_id": 3468525446,
447 |       "created_at": "2025-10-30T15:11:58Z",
448 |       "repoId": 840056306,
449 |       "pullRequestNo": 1035
450 |     },
451 |     {
452 |       "name": "Galleons2029",
453 |       "id": 88185941,
454 |       "comment_id": 3495884964,
455 |       "created_at": "2025-11-06T08:39:46Z",
456 |       "repoId": 840056306,
457 |       "pullRequestNo": 1053
458 |     },
459 |     {
460 |       "name": "supmo668",
461 |       "id": 28805779,
462 |       "comment_id": 3550309664,
463 |       "created_at": "2025-11-19T01:56:25Z",
464 |       "repoId": 840056306,
465 |       "pullRequestNo": 1072
466 |     },
467 |     {
468 |       "name": "donbr",
469 |       "id": 7340008,
470 |       "comment_id": 3568970102,
471 |       "created_at": "2025-11-24T05:19:42Z",
472 |       "repoId": 840056306,
473 |       "pullRequestNo": 1081
474 |     },
475 |     {
476 |       "name": "apetti1920",
477 |       "id": 4706645,
478 |       "comment_id": 3572726648,
479 |       "created_at": "2025-11-24T21:07:34Z",
480 |       "repoId": 840056306,
481 |       "pullRequestNo": 1084
482 |     },
483 |     {
484 |       "name": "ZLBillShaw",
485 |       "id": 55940186,
486 |       "comment_id": 3583997833,
487 |       "created_at": "2025-11-27T02:45:53Z",
488 |       "repoId": 840056306,
489 |       "pullRequestNo": 1085
490 |     },
491 |     {
492 |       "name": "ronaldmego",
493 |       "id": 17481958,
494 |       "comment_id": 3617267429,
495 |       "created_at": "2025-12-05T14:59:42Z",
496 |       "repoId": 840056306,
497 |       "pullRequestNo": 1094
498 |     },
499 |     {
500 |       "name": "NShumway",
501 |       "id": 29358113,
502 |       "comment_id": 3634967978,
503 |       "created_at": "2025-12-10T01:26:49Z",
504 |       "repoId": 840056306,
505 |       "pullRequestNo": 1102
506 |     },
507 |     {
508 |       "name": "husniadil",
509 |       "id": 10581130,
510 |       "comment_id": 3650156180,
511 |       "created_at": "2025-12-14T03:37:59Z",
512 |       "repoId": 840056306,
513 |       "pullRequestNo": 1105
514 |     },
515 |     {
516 |       "name": "yulongbai-nov",
517 |       "id": 177719410,
518 |       "comment_id": 3654653668,
519 |       "created_at": "2025-12-15T09:34:02Z",
520 |       "repoId": 840056306,
521 |       "pullRequestNo": 1106
522 |     },
523 |     {
524 |       "name": "AlonsoDeCosio",
525 |       "id": 11743394,
526 |       "comment_id": 3661133466,
527 |       "created_at": "2025-12-16T15:29:32Z",
528 |       "repoId": 840056306,
529 |       "pullRequestNo": 1107
530 |     },
531 |     {
532 |       "name": "Ataxia123",
533 |       "id": 22284759,
534 |       "comment_id": 3665072009,
535 |       "created_at": "2025-12-17T12:13:09Z",
536 |       "repoId": 840056306,
537 |       "pullRequestNo": 1109
538 |     },
539 |     {
540 |       "name": "david-morales",
541 |       "id": 7139121,
542 |       "comment_id": 3678178733,
543 |       "created_at": "2025-12-20T22:43:57Z",
544 |       "repoId": 840056306,
545 |       "pullRequestNo": 1117
546 |     },
547 |     {
548 |       "name": "lehcode",
549 |       "id": 53556648,
550 |       "comment_id": 3681728685,
551 |       "created_at": "2025-12-22T11:49:38Z",
552 |       "repoId": 840056306,
553 |       "pullRequestNo": 1120
554 |     },
555 |     {
556 |       "name": "Parteeksachdeva",
557 |       "id": 51407683,
558 |       "comment_id": 3702001948,
559 |       "created_at": "2025-12-31T11:14:17Z",
560 |       "repoId": 840056306,
561 |       "pullRequestNo": 1130
562 |     },
563 |     {
564 |       "name": "JohannesBin",
565 |       "id": 190308091,
566 |       "comment_id": 3704209742,
567 |       "created_at": "2026-01-01T23:03:17Z",
568 |       "repoId": 840056306,
569 |       "pullRequestNo": 1131
570 |     },
571 |     {
572 |       "name": "LongSunnyDay",
573 |       "id": 45385863,
574 |       "comment_id": 3719233680,
575 |       "created_at": "2026-01-07T14:51:46Z",
576 |       "repoId": 840056306,
577 |       "pullRequestNo": 1137
578 |     }
579 |   ]
580 | }
```

--------------------------------------------------------------------------------
/tests/cross_encoder/test_gemini_reranker_client.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | # Running tests: pytest -xvs tests/cross_encoder/test_gemini_reranker_client.py
 18 | 
 19 | from unittest.mock import AsyncMock, MagicMock, patch
 20 | 
 21 | import pytest
 22 | 
 23 | from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient
 24 | from graphiti_core.llm_client import LLMConfig, RateLimitError
 25 | 
 26 | 
 27 | @pytest.fixture
 28 | def mock_gemini_client():
 29 |     """Fixture to mock the Google Gemini client."""
 30 |     with patch('google.genai.Client') as mock_client:
 31 |         # Setup mock instance and its methods
 32 |         mock_instance = mock_client.return_value
 33 |         mock_instance.aio = MagicMock()
 34 |         mock_instance.aio.models = MagicMock()
 35 |         mock_instance.aio.models.generate_content = AsyncMock()
 36 |         yield mock_instance
 37 | 
 38 | 
 39 | @pytest.fixture
 40 | def gemini_reranker_client(mock_gemini_client):
 41 |     """Fixture to create a GeminiRerankerClient with a mocked client."""
 42 |     config = LLMConfig(api_key='test_api_key', model='test-model')
 43 |     client = GeminiRerankerClient(config=config)
 44 |     # Replace the client's client with our mock to ensure we're using the mock
 45 |     client.client = mock_gemini_client
 46 |     return client
 47 | 
 48 | 
 49 | def create_mock_response(score_text: str) -> MagicMock:
 50 |     """Helper function to create a mock Gemini response."""
 51 |     mock_response = MagicMock()
 52 |     mock_response.text = score_text
 53 |     return mock_response
 54 | 
 55 | 
 56 | class TestGeminiRerankerClientInitialization:
 57 |     """Tests for GeminiRerankerClient initialization."""
 58 | 
 59 |     def test_init_with_config(self):
 60 |         """Test initialization with a config object."""
 61 |         config = LLMConfig(api_key='test_api_key', model='test-model')
 62 |         client = GeminiRerankerClient(config=config)
 63 | 
 64 |         assert client.config == config
 65 | 
 66 |     @patch('google.genai.Client')
 67 |     def test_init_without_config(self, mock_client):
 68 |         """Test initialization without a config uses defaults."""
 69 |         client = GeminiRerankerClient()
 70 | 
 71 |         assert client.config is not None
 72 | 
 73 |     def test_init_with_custom_client(self):
 74 |         """Test initialization with a custom client."""
 75 |         mock_client = MagicMock()
 76 |         client = GeminiRerankerClient(client=mock_client)
 77 | 
 78 |         assert client.client == mock_client
 79 | 
 80 | 
 81 | class TestGeminiRerankerClientRanking:
 82 |     """Tests for GeminiRerankerClient rank method."""
 83 | 
 84 |     @pytest.mark.asyncio
 85 |     async def test_rank_basic_functionality(self, gemini_reranker_client, mock_gemini_client):
 86 |         """Test basic ranking functionality."""
 87 |         # Setup mock responses with different scores
 88 |         mock_responses = [
 89 |             create_mock_response('85'),  # High relevance
 90 |             create_mock_response('45'),  # Medium relevance
 91 |             create_mock_response('20'),  # Low relevance
 92 |         ]
 93 |         mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
 94 | 
 95 |         # Test data
 96 |         query = 'What is the capital of France?'
 97 |         passages = [
 98 |             'Paris is the capital and most populous city of France.',
 99 |             'London is the capital city of England and the United Kingdom.',
100 |             'Berlin is the capital and largest city of Germany.',
101 |         ]
102 | 
103 |         # Call method
104 |         result = await gemini_reranker_client.rank(query, passages)
105 | 
106 |         # Assertions
107 |         assert len(result) == 3
108 |         assert all(isinstance(item, tuple) for item in result)
109 |         assert all(
110 |             isinstance(passage, str) and isinstance(score, float) for passage, score in result
111 |         )
112 | 
113 |         # Check scores are normalized to [0, 1] and sorted in descending order
114 |         scores = [score for _, score in result]
115 |         assert all(0.0 <= score <= 1.0 for score in scores)
116 |         assert scores == sorted(scores, reverse=True)
117 | 
118 |         # Check that the highest scoring passage is first
119 |         assert result[0][1] == 0.85  # 85/100
120 |         assert result[1][1] == 0.45  # 45/100
121 |         assert result[2][1] == 0.20  # 20/100
122 | 
123 |     @pytest.mark.asyncio
124 |     async def test_rank_empty_passages(self, gemini_reranker_client):
125 |         """Test ranking with empty passages list."""
126 |         query = 'Test query'
127 |         passages = []
128 | 
129 |         result = await gemini_reranker_client.rank(query, passages)
130 | 
131 |         assert result == []
132 | 
133 |     @pytest.mark.asyncio
134 |     async def test_rank_single_passage(self, gemini_reranker_client, mock_gemini_client):
135 |         """Test ranking with a single passage."""
136 |         # Setup mock response
137 |         mock_gemini_client.aio.models.generate_content.return_value = create_mock_response('75')
138 | 
139 |         query = 'Test query'
140 |         passages = ['Single test passage']
141 | 
142 |         result = await gemini_reranker_client.rank(query, passages)
143 | 
144 |         assert len(result) == 1
145 |         assert result[0][0] == 'Single test passage'
146 |         assert result[0][1] == 1.0  # Single passage gets full score
147 | 
148 |     @pytest.mark.asyncio
149 |     async def test_rank_score_extraction_with_regex(
150 |         self, gemini_reranker_client, mock_gemini_client
151 |     ):
152 |         """Test score extraction from various response formats."""
153 |         # Setup mock responses with different formats
154 |         mock_responses = [
155 |             create_mock_response('Score: 90'),  # Contains text before number
156 |             create_mock_response('The relevance is 65 out of 100'),  # Contains text around number
157 |             create_mock_response('8'),  # Just the number
158 |         ]
159 |         mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
160 | 
161 |         query = 'Test query'
162 |         passages = ['Passage 1', 'Passage 2', 'Passage 3']
163 | 
164 |         result = await gemini_reranker_client.rank(query, passages)
165 | 
166 |         # Check that scores were extracted correctly and normalized
167 |         scores = [score for _, score in result]
168 |         assert 0.90 in scores  # 90/100
169 |         assert 0.65 in scores  # 65/100
170 |         assert 0.08 in scores  # 8/100
171 | 
172 |     @pytest.mark.asyncio
173 |     async def test_rank_invalid_score_handling(self, gemini_reranker_client, mock_gemini_client):
174 |         """Test handling of invalid or non-numeric scores."""
175 |         # Setup mock responses with invalid scores
176 |         mock_responses = [
177 |             create_mock_response('Not a number'),  # Invalid response
178 |             create_mock_response(''),  # Empty response
179 |             create_mock_response('95'),  # Valid response
180 |         ]
181 |         mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
182 | 
183 |         query = 'Test query'
184 |         passages = ['Passage 1', 'Passage 2', 'Passage 3']
185 | 
186 |         result = await gemini_reranker_client.rank(query, passages)
187 | 
188 |         # Check that invalid scores are handled gracefully (assigned 0.0)
189 |         scores = [score for _, score in result]
190 |         assert 0.95 in scores  # Valid score
191 |         assert scores.count(0.0) == 2  # Two invalid scores assigned 0.0
192 | 
193 |     @pytest.mark.asyncio
194 |     async def test_rank_score_clamping(self, gemini_reranker_client, mock_gemini_client):
195 |         """Test that scores are properly clamped to [0, 1] range."""
196 |         # Setup mock responses with extreme scores
197 |         # Note: regex only matches 1-3 digits, so negative numbers won't match
198 |         mock_responses = [
199 |             create_mock_response('999'),  # Above 100 but within regex range
200 |             create_mock_response('invalid'),  # Invalid response becomes 0.0
201 |             create_mock_response('50'),  # Normal score
202 |         ]
203 |         mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
204 | 
205 |         query = 'Test query'
206 |         passages = ['Passage 1', 'Passage 2', 'Passage 3']
207 | 
208 |         result = await gemini_reranker_client.rank(query, passages)
209 | 
210 |         # Check that scores are normalized and clamped
211 |         scores = [score for _, score in result]
212 |         assert all(0.0 <= score <= 1.0 for score in scores)
213 |         # 999 should be clamped to 1.0 (999/100 = 9.99, clamped to 1.0)
214 |         assert 1.0 in scores
215 |         # Invalid response should be 0.0
216 |         assert 0.0 in scores
217 |         # Normal score should be normalized (50/100 = 0.5)
218 |         assert 0.5 in scores
219 | 
220 |     @pytest.mark.asyncio
221 |     async def test_rank_rate_limit_error(self, gemini_reranker_client, mock_gemini_client):
222 |         """Test handling of rate limit errors."""
223 |         # Setup mock to raise rate limit error
224 |         mock_gemini_client.aio.models.generate_content.side_effect = Exception(
225 |             'Rate limit exceeded'
226 |         )
227 | 
228 |         query = 'Test query'
229 |         passages = ['Passage 1', 'Passage 2']
230 | 
231 |         with pytest.raises(RateLimitError):
232 |             await gemini_reranker_client.rank(query, passages)
233 | 
234 |     @pytest.mark.asyncio
235 |     async def test_rank_quota_error(self, gemini_reranker_client, mock_gemini_client):
236 |         """Test handling of quota errors."""
237 |         # Setup mock to raise quota error
238 |         mock_gemini_client.aio.models.generate_content.side_effect = Exception('Quota exceeded')
239 | 
240 |         query = 'Test query'
241 |         passages = ['Passage 1', 'Passage 2']
242 | 
243 |         with pytest.raises(RateLimitError):
244 |             await gemini_reranker_client.rank(query, passages)
245 | 
246 |     @pytest.mark.asyncio
247 |     async def test_rank_resource_exhausted_error(self, gemini_reranker_client, mock_gemini_client):
248 |         """Test handling of resource exhausted errors."""
249 |         # Setup mock to raise resource exhausted error
250 |         mock_gemini_client.aio.models.generate_content.side_effect = Exception('resource_exhausted')
251 | 
252 |         query = 'Test query'
253 |         passages = ['Passage 1', 'Passage 2']
254 | 
255 |         with pytest.raises(RateLimitError):
256 |             await gemini_reranker_client.rank(query, passages)
257 | 
258 |     @pytest.mark.asyncio
259 |     async def test_rank_429_error(self, gemini_reranker_client, mock_gemini_client):
260 |         """Test handling of HTTP 429 errors."""
261 |         # Setup mock to raise 429 error
262 |         mock_gemini_client.aio.models.generate_content.side_effect = Exception(
263 |             'HTTP 429 Too Many Requests'
264 |         )
265 | 
266 |         query = 'Test query'
267 |         passages = ['Passage 1', 'Passage 2']
268 | 
269 |         with pytest.raises(RateLimitError):
270 |             await gemini_reranker_client.rank(query, passages)
271 | 
272 |     @pytest.mark.asyncio
273 |     async def test_rank_generic_error(self, gemini_reranker_client, mock_gemini_client):
274 |         """Test handling of generic errors."""
275 |         # Setup mock to raise generic error
276 |         mock_gemini_client.aio.models.generate_content.side_effect = Exception('Generic error')
277 | 
278 |         query = 'Test query'
279 |         passages = ['Passage 1', 'Passage 2']
280 | 
281 |         with pytest.raises(Exception) as exc_info:
282 |             await gemini_reranker_client.rank(query, passages)
283 | 
284 |         assert 'Generic error' in str(exc_info.value)
285 | 
286 |     @pytest.mark.asyncio
287 |     async def test_rank_concurrent_requests(self, gemini_reranker_client, mock_gemini_client):
288 |         """Test that multiple passages are scored concurrently."""
289 |         # Setup mock responses
290 |         mock_responses = [
291 |             create_mock_response('80'),
292 |             create_mock_response('60'),
293 |             create_mock_response('40'),
294 |         ]
295 |         mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
296 | 
297 |         query = 'Test query'
298 |         passages = ['Passage 1', 'Passage 2', 'Passage 3']
299 | 
300 |         await gemini_reranker_client.rank(query, passages)
301 | 
302 |         # Verify that generate_content was called for each passage
303 |         assert mock_gemini_client.aio.models.generate_content.call_count == 3
304 | 
305 |         # Verify that all calls were made with correct parameters
306 |         calls = mock_gemini_client.aio.models.generate_content.call_args_list
307 |         for call in calls:
308 |             args, kwargs = call
309 |             assert kwargs['model'] == gemini_reranker_client.config.model
310 |             assert kwargs['config'].temperature == 0.0
311 |             assert kwargs['config'].max_output_tokens == 3
312 | 
313 |     @pytest.mark.asyncio
314 |     async def test_rank_response_parsing_error(self, gemini_reranker_client, mock_gemini_client):
315 |         """Test handling of response parsing errors."""
316 |         # Setup mock responses that will trigger ValueError during parsing
317 |         mock_responses = [
318 |             create_mock_response('not a number at all'),  # Will fail regex match
319 |             create_mock_response('also invalid text'),  # Will fail regex match
320 |         ]
321 |         mock_gemini_client.aio.models.generate_content.side_effect = mock_responses
322 | 
323 |         query = 'Test query'
324 |         # Use multiple passages to avoid the single passage special case
325 |         passages = ['Passage 1', 'Passage 2']
326 | 
327 |         result = await gemini_reranker_client.rank(query, passages)
328 | 
329 |         # Should handle the error gracefully and assign 0.0 score to both
330 |         assert len(result) == 2
331 |         assert all(score == 0.0 for _, score in result)
332 | 
333 |     @pytest.mark.asyncio
334 |     async def test_rank_empty_response_text(self, gemini_reranker_client, mock_gemini_client):
335 |         """Test handling of empty response text."""
336 |         # Setup mock response with empty text
337 |         mock_response = MagicMock()
338 |         mock_response.text = ''  # Empty string instead of None
339 |         mock_gemini_client.aio.models.generate_content.return_value = mock_response
340 | 
341 |         query = 'Test query'
342 |         # Use multiple passages to avoid the single passage special case
343 |         passages = ['Passage 1', 'Passage 2']
344 | 
345 |         result = await gemini_reranker_client.rank(query, passages)
346 | 
347 |         # Should handle empty text gracefully and assign 0.0 score to both
348 |         assert len(result) == 2
349 |         assert all(score == 0.0 for _, score in result)
350 | 
351 | 
352 | if __name__ == '__main__':
353 |     pytest.main(['-v', __file__])  # resolve this file regardless of working directory
354 | 
```
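
The behavior these tests pin down — extract a 0-100 integer from the model's reply, normalize, clamp, and fall back to 0.0 on garbage — can be summarized in a few lines. The following is an illustrative sketch, not the library's actual implementation: `parse_relevance_score` is a hypothetical name, and the 1-3 digit regex mirrors what the test comments describe.

```python
import re


def parse_relevance_score(text: str) -> float:
    """Hypothetical helper mirroring the parsing the tests above assert."""
    match = re.search(r'\b(\d{1,3})\b', text or '')
    if match is None:
        return 0.0  # non-numeric or empty replies score 0.0
    # Normalize 0-100 to [0, 1]; out-of-range values are clamped.
    return min(1.0, max(0.0, int(match.group(1)) / 100))


assert parse_relevance_score('Score: 90') == 0.9
assert parse_relevance_score('999') == 1.0  # clamped, per the clamping test
assert parse_relevance_score('not a number') == 0.0
```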

--------------------------------------------------------------------------------
/examples/quickstart/dense_vs_normal_ingestion.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2025, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | 
 16 | Dense vs Normal Episode Ingestion Example
 17 | -----------------------------------------
 18 | This example demonstrates how Graphiti handles different types of content:
 19 | 
 20 | 1. Normal Content (prose, narrative, conversations):
 21 |    - Lower entity density (few entities per token)
 22 |    - Processed in a single LLM call
 23 |    - Examples: meeting transcripts, news articles, documentation
 24 | 
 25 | 2. Dense Content (structured data with many entities):
 26 |    - High entity density (many entities per token)
 27 |    - Automatically chunked for reliable extraction
 28 |    - Examples: bulk data imports, cost reports, entity-dense JSON
 29 | 
 30 | The chunking behavior is controlled by environment variables:
 31 | - CHUNK_MIN_TOKENS: Minimum tokens before considering chunking (default: 1000)
 32 | - CHUNK_DENSITY_THRESHOLD: Entity density threshold (default: 0.15)
 33 | - CHUNK_TOKEN_SIZE: Target size per chunk (default: 3000)
 34 | - CHUNK_OVERLAP_TOKENS: Overlap between chunks (default: 200)
 35 | """
 36 | 
 37 | import asyncio
 38 | import json
 39 | import logging
 40 | import os
 41 | from datetime import datetime, timezone
 42 | from logging import INFO
 43 | 
 44 | from dotenv import load_dotenv
 45 | 
 46 | from graphiti_core import Graphiti
 47 | from graphiti_core.nodes import EpisodeType
 48 | 
 49 | #################################################
 50 | # CONFIGURATION
 51 | #################################################
 52 | 
 53 | logging.basicConfig(
 54 |     level=INFO,
 55 |     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
 56 |     datefmt='%Y-%m-%d %H:%M:%S',
 57 | )
 58 | logger = logging.getLogger(__name__)
 59 | 
 60 | load_dotenv()
 61 | 
 62 | neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
 63 | neo4j_user = os.environ.get('NEO4J_USER', 'neo4j')
 64 | neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password')
 65 | 
 66 | if not neo4j_uri or not neo4j_user or not neo4j_password:
 67 |     raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set')
 68 | 
 69 | 
 70 | #################################################
 71 | # EXAMPLE DATA
 72 | #################################################
 73 | 
 74 | # Normal content: A meeting transcript (low entity density)
 75 | # This is prose/narrative content with few entities per token.
 76 | # It will NOT trigger chunking - processed in a single LLM call.
 77 | NORMAL_EPISODE_CONTENT = """
 78 | Meeting Notes - Q4 Planning Session
 79 | 
 80 | Alice opened the meeting by reviewing our progress on the mobile app redesign.
 81 | She mentioned that the user research phase went well and highlighted key findings
 82 | from the customer interviews conducted last month.
 83 | 
 84 | Bob then presented the engineering timeline. He explained that the backend API
 85 | refactoring is about 60% complete and should be finished by end of November.
 86 | The team has resolved most of the performance issues identified in the load tests.
 87 | 
 88 | Carol raised concerns about the holiday freeze period affecting our deployment
 89 | schedule. She suggested we move the beta launch to early December to give the
 90 | QA team enough time for regression testing before the code freeze.
 91 | 
 92 | David agreed with Carol's assessment and proposed allocating two additional
 93 | engineers from the platform team to help with the testing effort. He also
 94 | mentioned that the documentation needs to be updated before the release.
 95 | 
 96 | Action items:
 97 | - Alice will finalize the design specs by Friday
 98 | - Bob will coordinate with the platform team on resource allocation
 99 | - Carol will update the project timeline in Jira
100 | - David will schedule a follow-up meeting for next Tuesday
101 | 
102 | The meeting concluded at 3:30 PM with agreement to reconvene next week.
103 | """
104 | 
105 | # Dense content: AWS cost data (high entity density)
106 | # This is structured data with many entities per token.
107 | # It WILL trigger chunking - processed in multiple LLM calls.
108 | DENSE_EPISODE_CONTENT = {
109 |     'report_type': 'AWS Cost Breakdown',
110 |     'months': [
111 |         {
112 |             'period': '2025-01',
113 |             'services': [
114 |                 {'name': 'Amazon S3', 'cost': 2487.97},
115 |                 {'name': 'Amazon RDS', 'cost': 1071.74},
116 |                 {'name': 'Amazon ECS', 'cost': 853.74},
117 |                 {'name': 'Amazon OpenSearch', 'cost': 389.74},
118 |                 {'name': 'AWS Secrets Manager', 'cost': 265.77},
119 |                 {'name': 'CloudWatch', 'cost': 232.34},
120 |                 {'name': 'Amazon VPC', 'cost': 238.39},
121 |                 {'name': 'EC2 Other', 'cost': 226.82},
122 |                 {'name': 'Amazon EC2 Compute', 'cost': 78.27},
123 |                 {'name': 'Amazon DocumentDB', 'cost': 65.40},
124 |                 {'name': 'Amazon ECR', 'cost': 29.00},
125 |                 {'name': 'Amazon ELB', 'cost': 37.53},
126 |             ],
127 |         },
128 |         {
129 |             'period': '2025-02',
130 |             'services': [
131 |                 {'name': 'Amazon S3', 'cost': 2721.04},
132 |                 {'name': 'Amazon RDS', 'cost': 1035.77},
133 |                 {'name': 'Amazon ECS', 'cost': 779.49},
134 |                 {'name': 'Amazon OpenSearch', 'cost': 357.90},
135 |                 {'name': 'AWS Secrets Manager', 'cost': 268.57},
136 |                 {'name': 'CloudWatch', 'cost': 224.57},
137 |                 {'name': 'Amazon VPC', 'cost': 215.15},
138 |                 {'name': 'EC2 Other', 'cost': 213.86},
139 |                 {'name': 'Amazon EC2 Compute', 'cost': 70.70},
140 |                 {'name': 'Amazon DocumentDB', 'cost': 59.07},
141 |                 {'name': 'Amazon ECR', 'cost': 33.92},
142 |                 {'name': 'Amazon ELB', 'cost': 33.89},
143 |             ],
144 |         },
145 |         {
146 |             'period': '2025-03',
147 |             'services': [
148 |                 {'name': 'Amazon S3', 'cost': 2952.31},
149 |                 {'name': 'Amazon RDS', 'cost': 1198.79},
150 |                 {'name': 'Amazon ECS', 'cost': 869.78},
151 |                 {'name': 'Amazon OpenSearch', 'cost': 389.75},
152 |                 {'name': 'AWS Secrets Manager', 'cost': 271.33},
153 |                 {'name': 'CloudWatch', 'cost': 233.00},
154 |                 {'name': 'Amazon VPC', 'cost': 238.31},
155 |                 {'name': 'EC2 Other', 'cost': 227.78},
156 |                 {'name': 'Amazon EC2 Compute', 'cost': 78.21},
157 |                 {'name': 'Amazon DocumentDB', 'cost': 65.40},
158 |                 {'name': 'Amazon ECR', 'cost': 33.75},
159 |                 {'name': 'Amazon ELB', 'cost': 37.54},
160 |             ],
161 |         },
162 |         {
163 |             'period': '2025-04',
164 |             'services': [
165 |                 {'name': 'Amazon S3', 'cost': 3189.62},
166 |                 {'name': 'Amazon RDS', 'cost': 1102.30},
167 |                 {'name': 'Amazon ECS', 'cost': 848.19},
168 |                 {'name': 'Amazon OpenSearch', 'cost': 379.14},
169 |                 {'name': 'AWS Secrets Manager', 'cost': 270.89},
170 |                 {'name': 'CloudWatch', 'cost': 230.64},
171 |                 {'name': 'Amazon VPC', 'cost': 230.54},
172 |                 {'name': 'EC2 Other', 'cost': 220.18},
173 |                 {'name': 'Amazon EC2 Compute', 'cost': 75.70},
174 |                 {'name': 'Amazon DocumentDB', 'cost': 63.29},
175 |                 {'name': 'Amazon ECR', 'cost': 35.21},
176 |                 {'name': 'Amazon ELB', 'cost': 36.30},
177 |             ],
178 |         },
179 |         {
180 |             'period': '2025-05',
181 |             'services': [
182 |                 {'name': 'Amazon S3', 'cost': 3423.07},
183 |                 {'name': 'Amazon RDS', 'cost': 1014.50},
184 |                 {'name': 'Amazon ECS', 'cost': 874.75},
185 |                 {'name': 'Amazon OpenSearch', 'cost': 389.71},
186 |                 {'name': 'AWS Secrets Manager', 'cost': 274.91},
187 |                 {'name': 'CloudWatch', 'cost': 233.28},
188 |                 {'name': 'Amazon VPC', 'cost': 238.53},
189 |                 {'name': 'EC2 Other', 'cost': 227.27},
190 |                 {'name': 'Amazon EC2 Compute', 'cost': 78.27},
191 |                 {'name': 'Amazon DocumentDB', 'cost': 65.40},
192 |                 {'name': 'Amazon ECR', 'cost': 37.42},
193 |                 {'name': 'Amazon ELB', 'cost': 37.52},
194 |             ],
195 |         },
196 |         {
197 |             'period': '2025-06',
198 |             'services': [
199 |                 {'name': 'Amazon S3', 'cost': 3658.14},
200 |                 {'name': 'Amazon RDS', 'cost': 963.60},
201 |                 {'name': 'Amazon ECS', 'cost': 942.45},
202 |                 {'name': 'Amazon OpenSearch', 'cost': 379.06},
203 |                 {'name': 'AWS Secrets Manager', 'cost': 282.41},
204 |                 {'name': 'CloudWatch', 'cost': 230.69},
205 |                 {'name': 'Amazon VPC', 'cost': 230.64},
206 |                 {'name': 'EC2 Other', 'cost': 220.11},
207 |                 {'name': 'Amazon EC2 Compute', 'cost': 75.74},
208 |                 {'name': 'Amazon DocumentDB', 'cost': 63.30},
209 |                 {'name': 'Amazon ECR', 'cost': 39.49},
210 |                 {'name': 'Amazon ELB', 'cost': 36.30},
211 |             ],
212 |         },
213 |     ],
214 | }
215 | 
216 | # Message content: A conversation (moderate entity density)
217 | # This demonstrates message-type episodes with speaker patterns.
218 | MESSAGE_EPISODE_CONTENT = """User: I need to set up a new development environment for the analytics project.
219 | Assistant: I can help with that. Which cloud provider are you using - AWS, GCP, or Azure?
220 | User: We're using AWS. The project requires PostgreSQL, Redis, and Elasticsearch.
221 | Assistant: Great choice. For PostgreSQL, I recommend Amazon RDS with the db.r5.large instance type. For Redis, you can use Amazon ElastiCache. For Elasticsearch, Amazon OpenSearch Service would be ideal.
222 | User: What about the compute layer? We need to run Python data pipelines.
223 | Assistant: For Python data pipelines, I suggest Amazon ECS with Fargate for serverless containers, or AWS Lambda for smaller, event-driven tasks. If you need more control, EC2 instances with Auto Scaling would work well.
224 | User: Let's go with ECS Fargate. Can you also recommend a CI/CD setup?
225 | Assistant: For CI/CD with ECS Fargate, I recommend AWS CodePipeline integrated with CodeBuild for building container images, and CodeDeploy for blue-green deployments. You can store your container images in Amazon ECR.
226 | """
227 | 
228 | 
229 | async def main():
230 |     graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
231 | 
232 |     try:
233 |         #################################################
234 |         # EXAMPLE 1: Normal Content (No Chunking)
235 |         #################################################
236 |         # This prose content has low entity density.
237 |         # Graphiti will process it in a single LLM call.
238 |         #################################################
239 | 
240 |         print('=' * 60)
241 |         print('EXAMPLE 1: Normal Content (Meeting Transcript)')
242 |         print('=' * 60)
243 |         print(f'Content length: {len(NORMAL_EPISODE_CONTENT)} characters')
244 |         print(f'Estimated tokens: ~{len(NORMAL_EPISODE_CONTENT) // 4}')
245 |         print('Expected behavior: Single LLM call (no chunking)')
246 |         print()
247 | 
248 |         await graphiti.add_episode(
249 |             name='Q4 Planning Meeting',
250 |             episode_body=NORMAL_EPISODE_CONTENT,
251 |             source=EpisodeType.text,
252 |             source_description='Meeting transcript',
253 |             reference_time=datetime.now(timezone.utc),
254 |         )
255 |         print('Successfully added normal episode\n')
256 | 
257 |         #################################################
258 |         # EXAMPLE 2: Dense Content (Chunking Triggered)
259 |         #################################################
260 |         # This structured data has high entity density.
261 |         # Graphiti will automatically chunk it for
262 |         # reliable extraction across multiple LLM calls.
263 |         #################################################
264 | 
265 |         print('=' * 60)
266 |         print('EXAMPLE 2: Dense Content (AWS Cost Report)')
267 |         print('=' * 60)
268 |         dense_json = json.dumps(DENSE_EPISODE_CONTENT)
269 |         print(f'Content length: {len(dense_json)} characters')
270 |         print(f'Estimated tokens: ~{len(dense_json) // 4}')
271 |         print('Expected behavior: Multiple LLM calls (chunking enabled)')
272 |         print()
273 | 
274 |         await graphiti.add_episode(
275 |             name='AWS Cost Report 2025 H1',
276 |             episode_body=dense_json,
277 |             source=EpisodeType.json,
278 |             source_description='AWS cost breakdown by service',
279 |             reference_time=datetime.now(timezone.utc),
280 |         )
281 |         print('Successfully added dense episode\n')
282 | 
283 |         #################################################
284 |         # EXAMPLE 3: Message Content
285 |         #################################################
286 |         # Conversation content with speaker patterns.
287 |         # Chunking preserves message boundaries.
288 |         #################################################
289 | 
290 |         print('=' * 60)
291 |         print('EXAMPLE 3: Message Content (Conversation)')
292 |         print('=' * 60)
293 |         print(f'Content length: {len(MESSAGE_EPISODE_CONTENT)} characters')
294 |         print(f'Estimated tokens: ~{len(MESSAGE_EPISODE_CONTENT) // 4}')
295 |         print('Expected behavior: Depends on density threshold')
296 |         print()
297 | 
298 |         await graphiti.add_episode(
299 |             name='Dev Environment Setup Chat',
300 |             episode_body=MESSAGE_EPISODE_CONTENT,
301 |             source=EpisodeType.message,
302 |             source_description='Support conversation',
303 |             reference_time=datetime.now(timezone.utc),
304 |         )
305 |         print('Successfully added message episode\n')
306 | 
307 |         #################################################
308 |         # SEARCH RESULTS
309 |         #################################################
310 | 
311 |         print('=' * 60)
312 |         print('SEARCH: Verifying extracted entities')
313 |         print('=' * 60)
314 | 
315 |         # Search for entities from normal content
316 |         print("\nSearching for: 'Q4 planning meeting participants'")
317 |         results = await graphiti.search('Q4 planning meeting participants')
318 |         print(f'Found {len(results)} results')
319 |         for r in results[:3]:
320 |             print(f'  - {r.fact}')
321 | 
322 |         # Search for entities from dense content
323 |         print("\nSearching for: 'AWS S3 costs'")
324 |         results = await graphiti.search('AWS S3 costs')
325 |         print(f'Found {len(results)} results')
326 |         for r in results[:3]:
327 |             print(f'  - {r.fact}')
328 | 
329 |         # Search for entities from message content
330 |         print("\nSearching for: 'ECS Fargate recommendations'")
331 |         results = await graphiti.search('ECS Fargate recommendations')
332 |         print(f'Found {len(results)} results')
333 |         for r in results[:3]:
334 |             print(f'  - {r.fact}')
335 | 
336 |     finally:
337 |         await graphiti.close()
338 |         print('\nConnection closed')
339 | 
340 | 
341 | if __name__ == '__main__':
342 |     asyncio.run(main())
343 | 
```
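
For readers tuning their own thresholds, here is how the four environment variables from the module docstring combine conceptually. This is a hedged sketch of the decision, not Graphiti's internal code: `should_chunk` and `entity_estimate` are illustrative names, and the token count reuses the ~4 characters/token heuristic the example prints.

```python
import os


def should_chunk(text: str, entity_estimate: int) -> bool:
    """Illustrative chunking decision under the documented defaults."""
    min_tokens = int(os.environ.get('CHUNK_MIN_TOKENS', '1000'))
    density_threshold = float(os.environ.get('CHUNK_DENSITY_THRESHOLD', '0.15'))

    tokens = max(1, len(text) // 4)  # rough estimate, ~4 chars per token
    if tokens < min_tokens:
        return False  # short episodes are processed in a single LLM call
    # Dense content: many entities per token triggers chunked extraction,
    # with chunks of ~CHUNK_TOKEN_SIZE tokens and CHUNK_OVERLAP_TOKENS overlap.
    return entity_estimate / tokens >= density_threshold
```

Under the defaults, the meeting transcript falls below CHUNK_MIN_TOKENS and stays in one call, while the serialized cost report is expected to clear both the size and density bars, per the example's printed expectations.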

--------------------------------------------------------------------------------
/tests/test_edge_int.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import logging
 18 | import sys
 19 | from datetime import datetime
 20 | 
 21 | import numpy as np
 22 | import pytest
 23 | 
 24 | from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
 25 | from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
 26 | from tests.helpers_test import get_edge_count, get_node_count, group_id
 27 | 
 28 | pytest_plugins = ('pytest_asyncio',)
 29 | 
 30 | 
 31 | def setup_logging():
 32 |     # Create a logger
 33 |     logger = logging.getLogger()
 34 |     logger.setLevel(logging.INFO)  # Set the logging level to INFO
 35 | 
 36 |     # Create console handler and set level to INFO
 37 |     console_handler = logging.StreamHandler(sys.stdout)
 38 |     console_handler.setLevel(logging.INFO)
 39 | 
 40 |     # Create formatter
 41 |     formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 42 | 
 43 |     # Add formatter to console handler
 44 |     console_handler.setFormatter(formatter)
 45 | 
 46 |     # Add console handler to logger
 47 |     logger.addHandler(console_handler)
 48 | 
 49 |     return logger
 50 | 
 51 | 
 52 | @pytest.mark.asyncio
 53 | async def test_episodic_edge(graph_driver, mock_embedder):
 54 |     now = datetime.now()
 55 | 
 56 |     # Create episodic node
 57 |     episode_node = EpisodicNode(
 58 |         name='test_episode',
 59 |         labels=[],
 60 |         created_at=now,
 61 |         valid_at=now,
 62 |         source=EpisodeType.message,
 63 |         source_description='conversation message',
 64 |         content='Alice likes Bob',
 65 |         entity_edges=[],
 66 |         group_id=group_id,
 67 |     )
 68 |     node_count = await get_node_count(graph_driver, [episode_node.uuid])
 69 |     assert node_count == 0
 70 |     await episode_node.save(graph_driver)
 71 |     node_count = await get_node_count(graph_driver, [episode_node.uuid])
 72 |     assert node_count == 1
 73 | 
 74 |     # Create entity node
 75 |     alice_node = EntityNode(
 76 |         name='Alice',
 77 |         labels=[],
 78 |         created_at=now,
 79 |         summary='Alice summary',
 80 |         group_id=group_id,
 81 |     )
 82 |     await alice_node.generate_name_embedding(mock_embedder)
 83 |     node_count = await get_node_count(graph_driver, [alice_node.uuid])
 84 |     assert node_count == 0
 85 |     await alice_node.save(graph_driver)
 86 |     node_count = await get_node_count(graph_driver, [alice_node.uuid])
 87 |     assert node_count == 1
 88 | 
 89 |     # Create episodic to entity edge
 90 |     episodic_edge = EpisodicEdge(
 91 |         source_node_uuid=episode_node.uuid,
 92 |         target_node_uuid=alice_node.uuid,
 93 |         created_at=now,
 94 |         group_id=group_id,
 95 |     )
 96 |     edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
 97 |     assert edge_count == 0
 98 |     await episodic_edge.save(graph_driver)
 99 |     edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
100 |     assert edge_count == 1
101 | 
102 |     # Get edge by uuid
103 |     retrieved = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge.uuid)
104 |     assert retrieved.uuid == episodic_edge.uuid
105 |     assert retrieved.source_node_uuid == episode_node.uuid
106 |     assert retrieved.target_node_uuid == alice_node.uuid
107 |     assert retrieved.created_at == now
108 |     assert retrieved.group_id == group_id
109 | 
110 |     # Get edge by uuids
111 |     retrieved = await EpisodicEdge.get_by_uuids(graph_driver, [episodic_edge.uuid])
112 |     assert len(retrieved) == 1
113 |     assert retrieved[0].uuid == episodic_edge.uuid
114 |     assert retrieved[0].source_node_uuid == episode_node.uuid
115 |     assert retrieved[0].target_node_uuid == alice_node.uuid
116 |     assert retrieved[0].created_at == now
117 |     assert retrieved[0].group_id == group_id
118 | 
119 |     # Get edge by group ids
120 |     retrieved = await EpisodicEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
121 |     assert len(retrieved) == 1
122 |     assert retrieved[0].uuid == episodic_edge.uuid
123 |     assert retrieved[0].source_node_uuid == episode_node.uuid
124 |     assert retrieved[0].target_node_uuid == alice_node.uuid
125 |     assert retrieved[0].created_at == now
126 |     assert retrieved[0].group_id == group_id
127 | 
128 |     # Get episodic node by entity node uuid
129 |     retrieved = await EpisodicNode.get_by_entity_node_uuid(graph_driver, alice_node.uuid)
130 |     assert len(retrieved) == 1
131 |     assert retrieved[0].uuid == episode_node.uuid
132 |     assert retrieved[0].name == 'test_episode'
133 |     assert retrieved[0].created_at == now
134 |     assert retrieved[0].group_id == group_id
135 | 
136 |     # Delete edge by uuid
137 |     await episodic_edge.delete(graph_driver)
138 |     edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
139 |     assert edge_count == 0
140 | 
141 |     # Delete edge by uuids
142 |     await episodic_edge.save(graph_driver)
143 |     await episodic_edge.delete_by_uuids(graph_driver, [episodic_edge.uuid])
144 |     edge_count = await get_edge_count(graph_driver, [episodic_edge.uuid])
145 |     assert edge_count == 0
146 | 
147 |     # Cleanup nodes
148 |     await episode_node.delete(graph_driver)
149 |     node_count = await get_node_count(graph_driver, [episode_node.uuid])
150 |     assert node_count == 0
151 |     await alice_node.delete(graph_driver)
152 |     node_count = await get_node_count(graph_driver, [alice_node.uuid])
153 |     assert node_count == 0
154 | 
155 |     await graph_driver.close()
156 | 
157 | 
158 | @pytest.mark.asyncio
159 | async def test_entity_edge(graph_driver, mock_embedder):
160 |     now = datetime.now()
161 | 
162 |     # Create entity node
163 |     alice_node = EntityNode(
164 |         name='Alice',
165 |         labels=[],
166 |         created_at=now,
167 |         summary='Alice summary',
168 |         group_id=group_id,
169 |     )
170 |     await alice_node.generate_name_embedding(mock_embedder)
171 |     node_count = await get_node_count(graph_driver, [alice_node.uuid])
172 |     assert node_count == 0
173 |     await alice_node.save(graph_driver)
174 |     node_count = await get_node_count(graph_driver, [alice_node.uuid])
175 |     assert node_count == 1
176 | 
177 |     # Create entity node
178 |     bob_node = EntityNode(
179 |         name='Bob', labels=[], created_at=now, summary='Bob summary', group_id=group_id
180 |     )
181 |     await bob_node.generate_name_embedding(mock_embedder)
182 |     node_count = await get_node_count(graph_driver, [bob_node.uuid])
183 |     assert node_count == 0
184 |     await bob_node.save(graph_driver)
185 |     node_count = await get_node_count(graph_driver, [bob_node.uuid])
186 |     assert node_count == 1
187 | 
188 |     # Create entity to entity edge
189 |     entity_edge = EntityEdge(
190 |         source_node_uuid=alice_node.uuid,
191 |         target_node_uuid=bob_node.uuid,
192 |         created_at=now,
193 |         name='likes',
194 |         fact='Alice likes Bob',
195 |         episodes=[],
196 |         expired_at=now,
197 |         valid_at=now,
198 |         invalid_at=now,
199 |         group_id=group_id,
200 |     )
201 |     edge_embedding = await entity_edge.generate_embedding(mock_embedder)
202 |     edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
203 |     assert edge_count == 0
204 |     await entity_edge.save(graph_driver)
205 |     edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
206 |     assert edge_count == 1
207 | 
208 |     # Get edge by uuid
209 |     retrieved = await EntityEdge.get_by_uuid(graph_driver, entity_edge.uuid)
210 |     assert retrieved.uuid == entity_edge.uuid
211 |     assert retrieved.source_node_uuid == alice_node.uuid
212 |     assert retrieved.target_node_uuid == bob_node.uuid
213 |     assert retrieved.created_at == now
214 |     assert retrieved.group_id == group_id
215 | 
216 |     # Get edge by uuids
217 |     retrieved = await EntityEdge.get_by_uuids(graph_driver, [entity_edge.uuid])
218 |     assert len(retrieved) == 1
219 |     assert retrieved[0].uuid == entity_edge.uuid
220 |     assert retrieved[0].source_node_uuid == alice_node.uuid
221 |     assert retrieved[0].target_node_uuid == bob_node.uuid
222 |     assert retrieved[0].created_at == now
223 |     assert retrieved[0].group_id == group_id
224 | 
225 |     # Get edge by group ids
226 |     retrieved = await EntityEdge.get_by_group_ids(graph_driver, [group_id], limit=2)
227 |     assert len(retrieved) == 1
228 |     assert retrieved[0].uuid == entity_edge.uuid
229 |     assert retrieved[0].source_node_uuid == alice_node.uuid
230 |     assert retrieved[0].target_node_uuid == bob_node.uuid
231 |     assert retrieved[0].created_at == now
232 |     assert retrieved[0].group_id == group_id
233 | 
234 |     # Get edge by node uuid
235 |     retrieved = await EntityEdge.get_by_node_uuid(graph_driver, alice_node.uuid)
236 |     assert len(retrieved) == 1
237 |     assert retrieved[0].uuid == entity_edge.uuid
238 |     assert retrieved[0].source_node_uuid == alice_node.uuid
239 |     assert retrieved[0].target_node_uuid == bob_node.uuid
240 |     assert retrieved[0].created_at == now
241 |     assert retrieved[0].group_id == group_id
242 | 
243 |     # Get fact embedding
244 |     await entity_edge.load_fact_embedding(graph_driver)
245 |     assert np.allclose(entity_edge.fact_embedding, edge_embedding)
246 | 
247 |     # Delete edge by uuid
248 |     await entity_edge.delete(graph_driver)
249 |     edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
250 |     assert edge_count == 0
251 | 
252 |     # Delete edge by uuids
253 |     await entity_edge.save(graph_driver)
254 |     await entity_edge.delete_by_uuids(graph_driver, [entity_edge.uuid])
255 |     edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
256 |     assert edge_count == 0
257 | 
258 |     # Deleting node should delete the edge
259 |     await entity_edge.save(graph_driver)
260 |     await alice_node.delete(graph_driver)
261 |     node_count = await get_node_count(graph_driver, [alice_node.uuid])
262 |     assert node_count == 0
263 |     edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
264 |     assert edge_count == 0
265 | 
266 |     # Deleting node by uuids should delete the edge
267 |     await alice_node.save(graph_driver)
268 |     await entity_edge.save(graph_driver)
269 |     await alice_node.delete_by_uuids(graph_driver, [alice_node.uuid])
270 |     node_count = await get_node_count(graph_driver, [alice_node.uuid])
271 |     assert node_count == 0
272 |     edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
273 |     assert edge_count == 0
274 | 
275 |     # Deleting node by group id should delete the edge
276 |     await alice_node.save(graph_driver)
277 |     await entity_edge.save(graph_driver)
278 |     await alice_node.delete_by_group_id(graph_driver, alice_node.group_id)
279 |     node_count = await get_node_count(graph_driver, [alice_node.uuid])
280 |     assert node_count == 0
281 |     edge_count = await get_edge_count(graph_driver, [entity_edge.uuid])
282 |     assert edge_count == 0
283 | 
284 |     # Cleanup nodes
285 |     await alice_node.delete(graph_driver)
286 |     node_count = await get_node_count(graph_driver, [alice_node.uuid])
287 |     assert node_count == 0
288 |     await bob_node.delete(graph_driver)
289 |     node_count = await get_node_count(graph_driver, [bob_node.uuid])
290 |     assert node_count == 0
291 | 
292 |     await graph_driver.close()
293 | 
294 | 
295 | @pytest.mark.asyncio
296 | async def test_community_edge(graph_driver, mock_embedder):
297 |     now = datetime.now()
298 | 
299 |     # Create community node
300 |     community_node_1 = CommunityNode(
301 |         name='test_community_1',
302 |         group_id=group_id,
303 |         summary='Community A summary',
304 |     )
305 |     await community_node_1.generate_name_embedding(mock_embedder)
306 |     node_count = await get_node_count(graph_driver, [community_node_1.uuid])
307 |     assert node_count == 0
308 |     await community_node_1.save(graph_driver)
309 |     node_count = await get_node_count(graph_driver, [community_node_1.uuid])
310 |     assert node_count == 1
311 | 
312 |     # Create community node
313 |     community_node_2 = CommunityNode(
314 |         name='test_community_2',
315 |         group_id=group_id,
316 |         summary='Community B summary',
317 |     )
318 |     await community_node_2.generate_name_embedding(mock_embedder)
319 |     node_count = await get_node_count(graph_driver, [community_node_2.uuid])
320 |     assert node_count == 0
321 |     await community_node_2.save(graph_driver)
322 |     node_count = await get_node_count(graph_driver, [community_node_2.uuid])
323 |     assert node_count == 1
324 | 
325 |     # Create entity node
326 |     alice_node = EntityNode(
327 |         name='Alice', labels=[], created_at=now, summary='Alice summary', group_id=group_id
328 |     )
329 |     await alice_node.generate_name_embedding(mock_embedder)
330 |     node_count = await get_node_count(graph_driver, [alice_node.uuid])
331 |     assert node_count == 0
332 |     await alice_node.save(graph_driver)
333 |     node_count = await get_node_count(graph_driver, [alice_node.uuid])
334 |     assert node_count == 1
335 | 
336 |     # Create community to community edge
337 |     community_edge = CommunityEdge(
338 |         source_node_uuid=community_node_1.uuid,
339 |         target_node_uuid=community_node_2.uuid,
340 |         created_at=now,
341 |         group_id=group_id,
342 |     )
343 |     edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
344 |     assert edge_count == 0
345 |     await community_edge.save(graph_driver)
346 |     edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
347 |     assert edge_count == 1
348 | 
349 |     # Get edge by uuid
350 |     retrieved = await CommunityEdge.get_by_uuid(graph_driver, community_edge.uuid)
351 |     assert retrieved.uuid == community_edge.uuid
352 |     assert retrieved.source_node_uuid == community_node_1.uuid
353 |     assert retrieved.target_node_uuid == community_node_2.uuid
354 |     assert retrieved.created_at == now
355 |     assert retrieved.group_id == group_id
356 | 
357 |     # Get edge by uuids
358 |     retrieved = await CommunityEdge.get_by_uuids(graph_driver, [community_edge.uuid])
359 |     assert len(retrieved) == 1
360 |     assert retrieved[0].uuid == community_edge.uuid
361 |     assert retrieved[0].source_node_uuid == community_node_1.uuid
362 |     assert retrieved[0].target_node_uuid == community_node_2.uuid
363 |     assert retrieved[0].created_at == now
364 |     assert retrieved[0].group_id == group_id
365 | 
366 |     # Get edge by group ids
367 |     retrieved = await CommunityEdge.get_by_group_ids(graph_driver, [group_id], limit=1)
368 |     assert len(retrieved) == 1
369 |     assert retrieved[0].uuid == community_edge.uuid
370 |     assert retrieved[0].source_node_uuid == community_node_1.uuid
371 |     assert retrieved[0].target_node_uuid == community_node_2.uuid
372 |     assert retrieved[0].created_at == now
373 |     assert retrieved[0].group_id == group_id
374 | 
375 |     # Delete edge by uuid
376 |     await community_edge.delete(graph_driver)
377 |     edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
378 |     assert edge_count == 0
379 | 
380 |     # Delete edge by uuids
381 |     await community_edge.save(graph_driver)
382 |     await community_edge.delete_by_uuids(graph_driver, [community_edge.uuid])
383 |     edge_count = await get_edge_count(graph_driver, [community_edge.uuid])
384 |     assert edge_count == 0
385 | 
386 |     # Cleanup nodes
387 |     await alice_node.delete(graph_driver)
388 |     node_count = await get_node_count(graph_driver, [alice_node.uuid])
389 |     assert node_count == 0
390 |     await community_node_1.delete(graph_driver)
391 |     node_count = await get_node_count(graph_driver, [community_node_1.uuid])
392 |     assert node_count == 0
393 |     await community_node_2.delete(graph_driver)
394 |     node_count = await get_node_count(graph_driver, [community_node_2.uuid])
395 |     assert node_count == 0
396 | 
397 |     await graph_driver.close()
398 | 
```
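
Each of the three tests above follows the same persist/fetch/delete round-trip; distilled, it looks like the sketch below. It assumes a `graph_driver` and embedder like the fixtures the tests receive, and `edge_round_trip` is an illustrative helper, not part of the test suite.

```python
from datetime import datetime, timezone

from graphiti_core.edges import EntityEdge


async def edge_round_trip(graph_driver, embedder, source_uuid, target_uuid, group_id):
    """Illustrative: the save -> get_by_uuid -> delete cycle the tests exercise."""
    edge = EntityEdge(
        source_node_uuid=source_uuid,
        target_node_uuid=target_uuid,
        created_at=datetime.now(timezone.utc),
        name='likes',
        fact='Alice likes Bob',
        episodes=[],
        group_id=group_id,
    )
    await edge.generate_embedding(embedder)  # embed the fact, as the tests do
    await edge.save(graph_driver)            # persist
    fetched = await EntityEdge.get_by_uuid(graph_driver, edge.uuid)
    assert fetched.uuid == edge.uuid
    await edge.delete(graph_driver)          # clean up
```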

--------------------------------------------------------------------------------
/tests/embedder/test_gemini.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | # Running tests: pytest -xvs tests/embedder/test_gemini.py
 18 | 
 19 | from collections.abc import Generator
 20 | from typing import Any
 21 | from unittest.mock import AsyncMock, MagicMock, patch
 22 | 
 23 | import pytest
 24 | from embedder_fixtures import create_embedding_values
 25 | 
 26 | from graphiti_core.embedder.gemini import (
 27 |     DEFAULT_EMBEDDING_MODEL,
 28 |     GeminiEmbedder,
 29 |     GeminiEmbedderConfig,
 30 | )
 31 | 
 32 | 
 33 | def create_gemini_embedding(multiplier: float = 0.1, dimension: int = 1536) -> MagicMock:
 34 |     """Create a mock Gemini embedding with specified value multiplier and dimension."""
 35 |     mock_embedding = MagicMock()
 36 |     mock_embedding.values = create_embedding_values(multiplier, dimension)
 37 |     return mock_embedding
 38 | 
 39 | 
 40 | @pytest.fixture
 41 | def mock_gemini_response() -> MagicMock:
 42 |     """Create a mock Gemini embeddings response."""
 43 |     mock_result = MagicMock()
 44 |     mock_result.embeddings = [create_gemini_embedding()]
 45 |     return mock_result
 46 | 
 47 | 
 48 | @pytest.fixture
 49 | def mock_gemini_batch_response() -> MagicMock:
 50 |     """Create a mock Gemini batch embeddings response."""
 51 |     mock_result = MagicMock()
 52 |     mock_result.embeddings = [
 53 |         create_gemini_embedding(0.1),
 54 |         create_gemini_embedding(0.2),
 55 |         create_gemini_embedding(0.3),
 56 |     ]
 57 |     return mock_result
 58 | 
 59 | 
 60 | @pytest.fixture
 61 | def mock_gemini_client() -> Generator[Any, Any, None]:
 62 |     """Create a mocked Gemini client."""
 63 |     with patch('google.genai.Client') as mock_client:
 64 |         mock_instance = mock_client.return_value
 65 |         mock_instance.aio = MagicMock()
 66 |         mock_instance.aio.models = MagicMock()
 67 |         mock_instance.aio.models.embed_content = AsyncMock()
 68 |         yield mock_instance
 69 | 
 70 | 
 71 | @pytest.fixture
 72 | def gemini_embedder(mock_gemini_client: Any) -> GeminiEmbedder:
 73 |     """Create a GeminiEmbedder with a mocked client."""
 74 |     config = GeminiEmbedderConfig(api_key='test_api_key')
 75 |     embedder = GeminiEmbedder(config=config)
 76 |     embedder.client = mock_gemini_client
 77 |     return embedder
 78 | 
 79 | 
 80 | class TestGeminiEmbedderInitialization:
 81 |     """Tests for GeminiEmbedder initialization."""
 82 | 
 83 |     @patch('google.genai.Client')
 84 |     def test_init_with_config(self, mock_client):
 85 |         """Test initialization with a config object."""
 86 |         config = GeminiEmbedderConfig(
 87 |             api_key='test_api_key', embedding_model='custom-model', embedding_dim=768
 88 |         )
 89 |         embedder = GeminiEmbedder(config=config)
 90 | 
 91 |         assert embedder.config == config
 92 |         assert embedder.config.embedding_model == 'custom-model'
 93 |         assert embedder.config.api_key == 'test_api_key'
 94 |         assert embedder.config.embedding_dim == 768
 95 | 
 96 |     @patch('google.genai.Client')
 97 |     def test_init_without_config(self, mock_client):
 98 |         """Test initialization without a config uses defaults."""
 99 |         embedder = GeminiEmbedder()
100 | 
101 |         assert embedder.config is not None
102 |         assert embedder.config.embedding_model == DEFAULT_EMBEDDING_MODEL
103 | 
104 |     @patch('google.genai.Client')
105 |     def test_init_with_partial_config(self, mock_client):
106 |         """Test initialization with partial config."""
107 |         config = GeminiEmbedderConfig(api_key='test_api_key')
108 |         embedder = GeminiEmbedder(config=config)
109 | 
110 |         assert embedder.config.api_key == 'test_api_key'
111 |         assert embedder.config.embedding_model == DEFAULT_EMBEDDING_MODEL
112 | 
113 | 
114 | class TestGeminiEmbedderCreate:
115 |     """Tests for GeminiEmbedder create method."""
116 | 
117 |     @pytest.mark.asyncio
118 |     async def test_create_calls_api_correctly(
119 |         self,
120 |         gemini_embedder: GeminiEmbedder,
121 |         mock_gemini_client: Any,
122 |         mock_gemini_response: MagicMock,
123 |     ) -> None:
124 |         """Test that create method correctly calls the API and processes the response."""
125 |         # Setup
126 |         mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
127 | 
128 |         # Call method
129 |         result = await gemini_embedder.create('Test input')
130 | 
131 |         # Verify API is called with correct parameters
132 |         mock_gemini_client.aio.models.embed_content.assert_called_once()
133 |         _, kwargs = mock_gemini_client.aio.models.embed_content.call_args
134 |         assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL
135 |         assert kwargs['contents'] == ['Test input']
136 | 
137 |         # Verify result is processed correctly
138 |         assert result == mock_gemini_response.embeddings[0].values
139 | 
140 |     @pytest.mark.asyncio
141 |     @patch('google.genai.Client')
142 |     async def test_create_with_custom_model(
143 |         self, mock_client_class, mock_gemini_client: Any, mock_gemini_response: MagicMock
144 |     ) -> None:
145 |         """Test create method with custom embedding model."""
146 |         # Setup embedder with custom model
147 |         config = GeminiEmbedderConfig(api_key='test_api_key', embedding_model='custom-model')
148 |         embedder = GeminiEmbedder(config=config)
149 |         embedder.client = mock_gemini_client
150 |         mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
151 | 
152 |         # Call method
153 |         await embedder.create('Test input')
154 | 
155 |         # Verify custom model is used
156 |         _, kwargs = mock_gemini_client.aio.models.embed_content.call_args
157 |         assert kwargs['model'] == 'custom-model'
158 | 
159 |     @pytest.mark.asyncio
160 |     @patch('google.genai.Client')
161 |     async def test_create_with_custom_dimension(
162 |         self, mock_client_class, mock_gemini_client: Any
163 |     ) -> None:
164 |         """Test create method with custom embedding dimension."""
165 |         # Setup embedder with custom dimension
166 |         config = GeminiEmbedderConfig(api_key='test_api_key', embedding_dim=768)
167 |         embedder = GeminiEmbedder(config=config)
168 |         embedder.client = mock_gemini_client
169 | 
170 |         # Setup mock response with custom dimension
171 |         mock_response = MagicMock()
172 |         mock_response.embeddings = [create_gemini_embedding(0.1, 768)]
173 |         mock_gemini_client.aio.models.embed_content.return_value = mock_response
174 | 
175 |         # Call method
176 |         result = await embedder.create('Test input')
177 | 
178 |         # Verify custom dimension is used in config
179 |         _, kwargs = mock_gemini_client.aio.models.embed_content.call_args
180 |         assert kwargs['config'].output_dimensionality == 768
181 | 
182 |         # Verify result has correct dimension
183 |         assert len(result) == 768
184 | 
185 |     @pytest.mark.asyncio
186 |     async def test_create_with_different_input_types(
187 |         self,
188 |         gemini_embedder: GeminiEmbedder,
189 |         mock_gemini_client: Any,
190 |         mock_gemini_response: MagicMock,
191 |     ) -> None:
192 |         """Test create method with different input types."""
193 |         mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
194 | 
195 |         # Test with string
196 |         await gemini_embedder.create('Test string')
197 | 
198 |         # Test with list of strings
199 |         await gemini_embedder.create(['Test', 'List'])
200 | 
201 |         # Test with iterable of integers
202 |         await gemini_embedder.create([1, 2, 3])
203 | 
204 |         # Verify all calls were made
205 |         assert mock_gemini_client.aio.models.embed_content.call_count == 3
206 | 
207 |     @pytest.mark.asyncio
208 |     async def test_create_no_embeddings_error(
209 |         self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
210 |     ) -> None:
211 |         """Test create method handling of no embeddings response."""
212 |         # Setup mock response with no embeddings
213 |         mock_response = MagicMock()
214 |         mock_response.embeddings = []
215 |         mock_gemini_client.aio.models.embed_content.return_value = mock_response
216 | 
217 |         # Call method and expect exception
218 |         with pytest.raises(ValueError) as exc_info:
219 |             await gemini_embedder.create('Test input')
220 | 
221 |         assert 'No embeddings returned from Gemini API in create()' in str(exc_info.value)
222 | 
223 |     @pytest.mark.asyncio
224 |     async def test_create_no_values_error(
225 |         self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
226 |     ) -> None:
227 |         """Test create method handling of embeddings with no values."""
228 |         # Setup mock response with embedding but no values
229 |         mock_embedding = MagicMock()
230 |         mock_embedding.values = None
231 |         mock_response = MagicMock()
232 |         mock_response.embeddings = [mock_embedding]
233 |         mock_gemini_client.aio.models.embed_content.return_value = mock_response
234 | 
235 |         # Call method and expect exception
236 |         with pytest.raises(ValueError) as exc_info:
237 |             await gemini_embedder.create('Test input')
238 | 
239 |         assert 'No embeddings returned from Gemini API in create()' in str(exc_info.value)
240 | 
241 | 
242 | class TestGeminiEmbedderCreateBatch:
243 |     """Tests for GeminiEmbedder create_batch method."""
244 | 
245 |     @pytest.mark.asyncio
246 |     async def test_create_batch_processes_multiple_inputs(
247 |         self,
248 |         gemini_embedder: GeminiEmbedder,
249 |         mock_gemini_client: Any,
250 |         mock_gemini_batch_response: MagicMock,
251 |     ) -> None:
252 |         """Test that create_batch method correctly processes multiple inputs."""
253 |         # Setup
254 |         mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_batch_response
255 |         input_batch = ['Input 1', 'Input 2', 'Input 3']
256 | 
257 |         # Call method
258 |         result = await gemini_embedder.create_batch(input_batch)
259 | 
260 |         # Verify API is called with correct parameters
261 |         mock_gemini_client.aio.models.embed_content.assert_called_once()
262 |         _, kwargs = mock_gemini_client.aio.models.embed_content.call_args
263 |         assert kwargs['model'] == DEFAULT_EMBEDDING_MODEL
264 |         assert kwargs['contents'] == input_batch
265 | 
266 |         # Verify all results are processed correctly
267 |         assert len(result) == 3
268 |         assert result == [
269 |             mock_gemini_batch_response.embeddings[0].values,
270 |             mock_gemini_batch_response.embeddings[1].values,
271 |             mock_gemini_batch_response.embeddings[2].values,
272 |         ]
273 | 
274 |     @pytest.mark.asyncio
275 |     async def test_create_batch_single_input(
276 |         self,
277 |         gemini_embedder: GeminiEmbedder,
278 |         mock_gemini_client: Any,
279 |         mock_gemini_response: MagicMock,
280 |     ) -> None:
281 |         """Test create_batch method with single input."""
282 |         mock_gemini_client.aio.models.embed_content.return_value = mock_gemini_response
283 |         input_batch = ['Single input']
284 | 
285 |         result = await gemini_embedder.create_batch(input_batch)
286 | 
287 |         assert len(result) == 1
288 |         assert result[0] == mock_gemini_response.embeddings[0].values
289 | 
290 |     @pytest.mark.asyncio
291 |     async def test_create_batch_empty_input(
292 |         self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
293 |     ) -> None:
294 |         """Test create_batch method with empty input."""
295 |         # Setup mock response with no embeddings
296 |         mock_response = MagicMock()
297 |         mock_response.embeddings = []
298 |         mock_gemini_client.aio.models.embed_content.return_value = mock_response
299 | 
300 |         input_batch = []
301 | 
302 |         result = await gemini_embedder.create_batch(input_batch)
303 |         assert result == []
304 |         mock_gemini_client.aio.models.embed_content.assert_not_called()
305 | 
306 |     @pytest.mark.asyncio
307 |     async def test_create_batch_no_embeddings_error(
308 |         self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
309 |     ) -> None:
310 |         """Test create_batch method handling of no embeddings response."""
311 |         # Setup mock response with no embeddings
312 |         mock_response = MagicMock()
313 |         mock_response.embeddings = []
314 |         mock_gemini_client.aio.models.embed_content.return_value = mock_response
315 | 
316 |         input_batch = ['Input 1', 'Input 2']
317 | 
318 |         with pytest.raises(ValueError) as exc_info:
319 |             await gemini_embedder.create_batch(input_batch)
320 | 
321 |         assert 'No embeddings returned from Gemini API' in str(exc_info.value)
322 | 
323 |     @pytest.mark.asyncio
324 |     async def test_create_batch_empty_values_error(
325 |         self, gemini_embedder: GeminiEmbedder, mock_gemini_client: Any
326 |     ) -> None:
327 |         """Test create_batch method handling of embeddings with empty values."""
328 |         # Setup mock response with embeddings but empty values
329 |         mock_embedding1 = MagicMock()
330 |         mock_embedding1.values = [0.1, 0.2, 0.3]  # Valid values
331 |         mock_embedding2 = MagicMock()
 32 |         mock_embedding2.values = None  # Missing values (None), treated as empty
333 | 
334 |         # Mock response for the initial batch call
335 |         mock_batch_response = MagicMock()
336 |         mock_batch_response.embeddings = [mock_embedding1, mock_embedding2]
337 | 
338 |         # Mock response for individual processing of 'Input 1'
339 |         mock_individual_response_1 = MagicMock()
340 |         mock_individual_response_1.embeddings = [mock_embedding1]
341 | 
342 |         # Mock response for individual processing of 'Input 2' (which has empty values)
343 |         mock_individual_response_2 = MagicMock()
344 |         mock_individual_response_2.embeddings = [mock_embedding2]
345 | 
346 |         # Set side_effect for embed_content to control return values for each call
347 |         mock_gemini_client.aio.models.embed_content.side_effect = [
348 |             mock_batch_response,  # First call for the batch
349 |             mock_individual_response_1,  # Second call for individual item 1
350 |             mock_individual_response_2,  # Third call for individual item 2
351 |         ]
352 | 
353 |         input_batch = ['Input 1', 'Input 2']
354 | 
355 |         with pytest.raises(ValueError) as exc_info:
356 |             await gemini_embedder.create_batch(input_batch)
357 | 
358 |         assert 'Empty embedding values returned' in str(exc_info.value)
359 | 
360 |     @pytest.mark.asyncio
361 |     @patch('google.genai.Client')
362 |     async def test_create_batch_with_custom_model_and_dimension(
363 |         self, mock_client_class, mock_gemini_client: Any
364 |     ) -> None:
365 |         """Test create_batch method with custom model and dimension."""
366 |         # Setup embedder with custom settings
367 |         config = GeminiEmbedderConfig(
368 |             api_key='test_api_key', embedding_model='custom-batch-model', embedding_dim=512
369 |         )
370 |         embedder = GeminiEmbedder(config=config)
371 |         embedder.client = mock_gemini_client
372 | 
373 |         # Setup mock response
374 |         mock_response = MagicMock()
375 |         mock_response.embeddings = [
376 |             create_gemini_embedding(0.1, 512),
377 |             create_gemini_embedding(0.2, 512),
378 |         ]
379 |         mock_gemini_client.aio.models.embed_content.return_value = mock_response
380 | 
381 |         input_batch = ['Input 1', 'Input 2']
382 |         result = await embedder.create_batch(input_batch)
383 | 
384 |         # Verify custom settings are used
385 |         _, kwargs = mock_gemini_client.aio.models.embed_content.call_args
386 |         assert kwargs['model'] == 'custom-batch-model'
387 |         assert kwargs['config'].output_dimensionality == 512
388 | 
389 |         # Verify results have correct dimension
390 |         assert len(result) == 2
391 |         assert all(len(embedding) == 512 for embedding in result)
392 | 
393 | 
394 | if __name__ == '__main__':
395 |     pytest.main(['-xvs', __file__])
396 | 
```
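
Taken together, these tests pin down the embedder's public contract: `create()` embeds one input and returns a single vector, `create_batch()` returns one vector per input in order (short-circuiting on an empty batch), and both raise `ValueError` when the API returns no usable embeddings. Below is a minimal usage sketch under those assumptions; the API key and dimension are placeholders, and the import path is assumed to match `graphiti_core.embedder.gemini`.

```python
import asyncio

from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig


async def main() -> None:
    # Placeholder credentials and dimension; substitute real values.
    config = GeminiEmbedderConfig(api_key='YOUR_API_KEY', embedding_dim=768)
    embedder = GeminiEmbedder(config=config)

    # create() embeds a single input and returns one vector (list[float]).
    vector = await embedder.create('Hello, world')
    assert len(vector) == 768

    # create_batch() embeds several inputs and returns one vector per input,
    # in order; an empty batch returns [] without calling the API.
    vectors = await embedder.create_batch(['First input', 'Second input'])
    assert len(vectors) == 2


asyncio.run(main())
```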

--------------------------------------------------------------------------------
/tests/utils/maintenance/test_entity_extraction.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import json
 18 | from unittest.mock import AsyncMock, MagicMock
 19 | 
 20 | import pytest
 21 | 
 22 | from graphiti_core.graphiti_types import GraphitiClients
 23 | from graphiti_core.nodes import EpisodeType, EpisodicNode
 24 | from graphiti_core.prompts.extract_nodes import ExtractedEntity
 25 | from graphiti_core.utils import content_chunking
 26 | from graphiti_core.utils.datetime_utils import utc_now
 27 | from graphiti_core.utils.maintenance import node_operations
 28 | from graphiti_core.utils.maintenance.node_operations import (
 29 |     _build_entity_types_context,
 30 |     _merge_extracted_entities,
 31 |     extract_nodes,
 32 | )
 33 | 
 34 | 
 35 | def _make_clients():
 36 |     """Create mock GraphitiClients for testing."""
 37 |     driver = MagicMock()
 38 |     embedder = MagicMock()
 39 |     cross_encoder = MagicMock()
 40 |     llm_client = MagicMock()
 41 |     llm_generate = AsyncMock()
 42 |     llm_client.generate_response = llm_generate
 43 | 
 44 |     clients = GraphitiClients.model_construct(  # bypass validation to allow test doubles
 45 |         driver=driver,
 46 |         embedder=embedder,
 47 |         cross_encoder=cross_encoder,
 48 |         llm_client=llm_client,
 49 |     )
 50 | 
 51 |     return clients, llm_generate
 52 | 
 53 | 
 54 | def _make_episode(
 55 |     content: str = 'Test content',
 56 |     source: EpisodeType = EpisodeType.text,
 57 |     group_id: str = 'group',
 58 | ) -> EpisodicNode:
 59 |     """Create a test episode node."""
 60 |     return EpisodicNode(
 61 |         name='test_episode',
 62 |         group_id=group_id,
 63 |         source=source,
 64 |         source_description='test',
 65 |         content=content,
 66 |         valid_at=utc_now(),
 67 |     )
 68 | 
 69 | 
 70 | class TestExtractNodesSmallInput:
 71 |     @pytest.mark.asyncio
 72 |     async def test_small_input_single_llm_call(self, monkeypatch):
 73 |         """Small inputs should use a single LLM call without chunking."""
 74 |         clients, llm_generate = _make_clients()
 75 | 
 76 |         # Mock LLM response
 77 |         llm_generate.return_value = {
 78 |             'extracted_entities': [
 79 |                 {'name': 'Alice', 'entity_type_id': 0},
 80 |                 {'name': 'Bob', 'entity_type_id': 0},
 81 |             ]
 82 |         }
 83 | 
 84 |         # Small content (below threshold)
 85 |         episode = _make_episode(content='Alice talked to Bob.')
 86 | 
 87 |         nodes = await extract_nodes(
 88 |             clients,
 89 |             episode,
 90 |             previous_episodes=[],
 91 |         )
 92 | 
 93 |         # Verify results
 94 |         assert len(nodes) == 2
 95 |         assert {n.name for n in nodes} == {'Alice', 'Bob'}
 96 | 
 97 |         # LLM should be called exactly once
 98 |         llm_generate.assert_awaited_once()
 99 | 
100 |     @pytest.mark.asyncio
101 |     async def test_extracts_entity_types(self, monkeypatch):
102 |         """Entity type classification should work correctly."""
103 |         clients, llm_generate = _make_clients()
104 | 
105 |         from pydantic import BaseModel
106 | 
107 |         class Person(BaseModel):
108 |             """A human person."""
109 | 
110 |             pass
111 | 
112 |         llm_generate.return_value = {
113 |             'extracted_entities': [
114 |                 {'name': 'Alice', 'entity_type_id': 1},  # Person
115 |                 {'name': 'Acme Corp', 'entity_type_id': 0},  # Default Entity
116 |             ]
117 |         }
118 | 
119 |         episode = _make_episode(content='Alice works at Acme Corp.')
120 | 
121 |         nodes = await extract_nodes(
122 |             clients,
123 |             episode,
124 |             previous_episodes=[],
125 |             entity_types={'Person': Person},
126 |         )
127 | 
128 |         # Alice should have Person label
129 |         alice = next(n for n in nodes if n.name == 'Alice')
130 |         assert 'Person' in alice.labels
131 | 
132 |         # Acme should have Entity label
133 |         acme = next(n for n in nodes if n.name == 'Acme Corp')
134 |         assert 'Entity' in acme.labels
135 | 
136 |     @pytest.mark.asyncio
137 |     async def test_excludes_entity_types(self, monkeypatch):
138 |         """Excluded entity types should not appear in results."""
139 |         clients, llm_generate = _make_clients()
140 | 
141 |         from pydantic import BaseModel
142 | 
143 |         class User(BaseModel):
144 |             """A user of the system."""
145 | 
146 |             pass
147 | 
148 |         llm_generate.return_value = {
149 |             'extracted_entities': [
150 |                 {'name': 'Alice', 'entity_type_id': 1},  # User (excluded)
151 |                 {'name': 'Project X', 'entity_type_id': 0},  # Entity
152 |             ]
153 |         }
154 | 
155 |         episode = _make_episode(content='Alice created Project X.')
156 | 
157 |         nodes = await extract_nodes(
158 |             clients,
159 |             episode,
160 |             previous_episodes=[],
161 |             entity_types={'User': User},
162 |             excluded_entity_types=['User'],
163 |         )
164 | 
165 |         # Alice should be excluded
166 |         assert len(nodes) == 1
167 |         assert nodes[0].name == 'Project X'
168 | 
169 |     @pytest.mark.asyncio
170 |     async def test_filters_empty_names(self, monkeypatch):
171 |         """Entities with empty names should be filtered out."""
172 |         clients, llm_generate = _make_clients()
173 | 
174 |         llm_generate.return_value = {
175 |             'extracted_entities': [
176 |                 {'name': 'Alice', 'entity_type_id': 0},
177 |                 {'name': '', 'entity_type_id': 0},
178 |                 {'name': '   ', 'entity_type_id': 0},
179 |             ]
180 |         }
181 | 
182 |         episode = _make_episode(content='Alice is here.')
183 | 
184 |         nodes = await extract_nodes(
185 |             clients,
186 |             episode,
187 |             previous_episodes=[],
188 |         )
189 | 
190 |         assert len(nodes) == 1
191 |         assert nodes[0].name == 'Alice'
192 | 
193 | 
194 | class TestExtractNodesChunking:
195 |     @pytest.mark.asyncio
196 |     async def test_large_input_triggers_chunking(self, monkeypatch):
197 |         """Large inputs should be chunked and processed in parallel."""
198 |         clients, llm_generate = _make_clients()
199 | 
200 |         # Track number of LLM calls
201 |         call_count = 0
202 | 
203 |         async def mock_generate(*args, **kwargs):
204 |             nonlocal call_count
205 |             call_count += 1
206 |             return {
207 |                 'extracted_entities': [
208 |                     {'name': f'Entity{call_count}', 'entity_type_id': 0},
209 |                 ]
210 |             }
211 | 
212 |         llm_generate.side_effect = mock_generate
213 | 
214 |         # Patch should_chunk where it's imported in node_operations
215 |         monkeypatch.setattr(node_operations, 'should_chunk', lambda content, ep_type: True)
216 |         monkeypatch.setattr(content_chunking, 'CHUNK_TOKEN_SIZE', 50)  # Small chunk size
217 | 
218 |         # Large content that exceeds threshold
219 |         large_content = 'word ' * 1000
220 |         episode = _make_episode(content=large_content)
221 | 
222 |         await extract_nodes(
223 |             clients,
224 |             episode,
225 |             previous_episodes=[],
226 |         )
227 | 
228 |         # Multiple LLM calls should have been made
229 |         assert call_count > 1
230 | 
231 |     @pytest.mark.asyncio
232 |     async def test_json_content_uses_json_chunking(self, monkeypatch):
233 |         """JSON episodes should use JSON-aware chunking."""
234 |         clients, llm_generate = _make_clients()
235 | 
236 |         llm_generate.return_value = {
237 |             'extracted_entities': [
238 |                 {'name': 'Service1', 'entity_type_id': 0},
239 |             ]
240 |         }
241 | 
242 |         # Patch should_chunk where it's imported in node_operations
243 |         monkeypatch.setattr(node_operations, 'should_chunk', lambda content, ep_type: True)
244 |         monkeypatch.setattr(content_chunking, 'CHUNK_TOKEN_SIZE', 50)  # Small chunk size
245 | 
246 |         # JSON content
247 |         json_data = [{'service': f'Service{i}'} for i in range(50)]
248 |         episode = _make_episode(
249 |             content=json.dumps(json_data),
250 |             source=EpisodeType.json,
251 |         )
252 | 
253 |         await extract_nodes(
254 |             clients,
255 |             episode,
256 |             previous_episodes=[],
257 |         )
258 | 
259 |         # Verify JSON chunking was used (LLM called multiple times)
260 |         assert llm_generate.await_count > 1
261 | 
262 |     @pytest.mark.asyncio
263 |     async def test_message_content_uses_message_chunking(self, monkeypatch):
264 |         """Message episodes should use message-aware chunking."""
265 |         clients, llm_generate = _make_clients()
266 | 
267 |         llm_generate.return_value = {
268 |             'extracted_entities': [
269 |                 {'name': 'Speaker', 'entity_type_id': 0},
270 |             ]
271 |         }
272 | 
273 |         # Patch should_chunk where it's imported in node_operations
274 |         monkeypatch.setattr(node_operations, 'should_chunk', lambda content, ep_type: True)
275 |         monkeypatch.setattr(content_chunking, 'CHUNK_TOKEN_SIZE', 50)  # Small chunk size
276 | 
277 |         # Conversation content
278 |         messages = [f'Speaker{i}: Hello from speaker {i}!' for i in range(50)]
279 |         episode = _make_episode(
280 |             content='\n'.join(messages),
281 |             source=EpisodeType.message,
282 |         )
283 | 
284 |         await extract_nodes(
285 |             clients,
286 |             episode,
287 |             previous_episodes=[],
288 |         )
289 | 
290 |         assert llm_generate.await_count > 1
291 | 
292 |     @pytest.mark.asyncio
293 |     async def test_deduplicates_across_chunks(self, monkeypatch):
294 |         """Entities appearing in multiple chunks should be deduplicated."""
295 |         clients, llm_generate = _make_clients()
296 | 
297 |         # Simulate same entity appearing in multiple chunks
298 |         call_count = 0
299 | 
300 |         async def mock_generate(*args, **kwargs):
301 |             nonlocal call_count
302 |             call_count += 1
303 |             # Return 'Alice' in every chunk
304 |             return {
305 |                 'extracted_entities': [
306 |                     {'name': 'Alice', 'entity_type_id': 0},
307 |                     {'name': f'Entity{call_count}', 'entity_type_id': 0},
308 |                 ]
309 |             }
310 | 
311 |         llm_generate.side_effect = mock_generate
312 | 
313 |         # Patch should_chunk where it's imported in node_operations
314 |         monkeypatch.setattr(node_operations, 'should_chunk', lambda content, ep_type: True)
315 |         monkeypatch.setattr(content_chunking, 'CHUNK_TOKEN_SIZE', 50)  # Small chunk size
316 | 
317 |         large_content = 'word ' * 1000
318 |         episode = _make_episode(content=large_content)
319 | 
320 |         nodes = await extract_nodes(
321 |             clients,
322 |             episode,
323 |             previous_episodes=[],
324 |         )
325 | 
326 |         # Alice should appear only once despite being in every chunk
327 |         alice_count = sum(1 for n in nodes if n.name == 'Alice')
328 |         assert alice_count == 1
329 | 
330 |     @pytest.mark.asyncio
331 |     async def test_deduplication_case_insensitive(self, monkeypatch):
332 |         """Deduplication should be case-insensitive."""
333 |         clients, llm_generate = _make_clients()
334 | 
335 |         call_count = 0
336 | 
337 |         async def mock_generate(*args, **kwargs):
338 |             nonlocal call_count
339 |             call_count += 1
340 |             if call_count == 1:
341 |                 return {'extracted_entities': [{'name': 'alice', 'entity_type_id': 0}]}
342 |             return {'extracted_entities': [{'name': 'Alice', 'entity_type_id': 0}]}
343 | 
344 |         llm_generate.side_effect = mock_generate
345 | 
346 |         # Patch should_chunk where it's imported in node_operations
347 |         monkeypatch.setattr(node_operations, 'should_chunk', lambda content, ep_type: True)
348 |         monkeypatch.setattr(content_chunking, 'CHUNK_TOKEN_SIZE', 50)  # Small chunk size
349 | 
350 |         large_content = 'word ' * 1000
351 |         episode = _make_episode(content=large_content)
352 | 
353 |         nodes = await extract_nodes(
354 |             clients,
355 |             episode,
356 |             previous_episodes=[],
357 |         )
358 | 
359 |         # Should have only one Alice (case-insensitive dedup)
360 |         alice_variants = [n for n in nodes if n.name.lower() == 'alice']
361 |         assert len(alice_variants) == 1
362 | 
363 | 
364 | class TestExtractNodesPromptSelection:
365 |     @pytest.mark.asyncio
366 |     async def test_uses_text_prompt_for_text_episodes(self, monkeypatch):
367 |         """Text episodes should use extract_text prompt."""
368 |         clients, llm_generate = _make_clients()
369 |         llm_generate.return_value = {'extracted_entities': []}
370 | 
371 |         episode = _make_episode(source=EpisodeType.text)
372 | 
373 |         await extract_nodes(clients, episode, previous_episodes=[])
374 | 
375 |         # Check prompt_name parameter
376 |         call_kwargs = llm_generate.call_args[1]
377 |         assert call_kwargs.get('prompt_name') == 'extract_nodes.extract_text'
378 | 
379 |     @pytest.mark.asyncio
380 |     async def test_uses_json_prompt_for_json_episodes(self, monkeypatch):
381 |         """JSON episodes should use extract_json prompt."""
382 |         clients, llm_generate = _make_clients()
383 |         llm_generate.return_value = {'extracted_entities': []}
384 | 
385 |         episode = _make_episode(content='{}', source=EpisodeType.json)
386 | 
387 |         await extract_nodes(clients, episode, previous_episodes=[])
388 | 
389 |         call_kwargs = llm_generate.call_args[1]
390 |         assert call_kwargs.get('prompt_name') == 'extract_nodes.extract_json'
391 | 
392 |     @pytest.mark.asyncio
393 |     async def test_uses_message_prompt_for_message_episodes(self, monkeypatch):
394 |         """Message episodes should use extract_message prompt."""
395 |         clients, llm_generate = _make_clients()
396 |         llm_generate.return_value = {'extracted_entities': []}
397 | 
398 |         episode = _make_episode(source=EpisodeType.message)
399 | 
400 |         await extract_nodes(clients, episode, previous_episodes=[])
401 | 
402 |         call_kwargs = llm_generate.call_args[1]
403 |         assert call_kwargs.get('prompt_name') == 'extract_nodes.extract_message'
404 | 
405 | 
406 | class TestBuildEntityTypesContext:
407 |     def test_default_entity_type_always_included(self):
408 |         """Default Entity type should always be at index 0."""
409 |         context = _build_entity_types_context(None)
410 | 
411 |         assert len(context) == 1
412 |         assert context[0]['entity_type_id'] == 0
413 |         assert context[0]['entity_type_name'] == 'Entity'
414 | 
415 |     def test_custom_types_added_after_default(self):
416 |         """Custom entity types should be added with sequential IDs."""
417 |         from pydantic import BaseModel
418 | 
419 |         class Person(BaseModel):
420 |             """A human person."""
421 | 
422 |             pass
423 | 
424 |         class Organization(BaseModel):
425 |             """A business or organization."""
426 | 
427 |             pass
428 | 
429 |         context = _build_entity_types_context(
430 |             {
431 |                 'Person': Person,
432 |                 'Organization': Organization,
433 |             }
434 |         )
435 | 
436 |         assert len(context) == 3
437 |         assert context[0]['entity_type_name'] == 'Entity'
438 |         assert context[1]['entity_type_name'] == 'Person'
439 |         assert context[1]['entity_type_id'] == 1
440 |         assert context[2]['entity_type_name'] == 'Organization'
441 |         assert context[2]['entity_type_id'] == 2
442 | 
443 | 
444 | class TestMergeExtractedEntities:
445 |     def test_merge_deduplicates_by_name(self):
446 |         """Entities with same name should be deduplicated."""
447 |         chunk_results = [
448 |             [
449 |                 ExtractedEntity(name='Alice', entity_type_id=0),
450 |                 ExtractedEntity(name='Bob', entity_type_id=0),
451 |             ],
452 |             [
453 |                 ExtractedEntity(name='Alice', entity_type_id=0),  # Duplicate
454 |                 ExtractedEntity(name='Charlie', entity_type_id=0),
455 |             ],
456 |         ]
457 | 
458 |         merged = _merge_extracted_entities(chunk_results)
459 | 
460 |         assert len(merged) == 3
461 |         names = {e.name for e in merged}
462 |         assert names == {'Alice', 'Bob', 'Charlie'}
463 | 
464 |     def test_merge_prefers_first_occurrence(self):
465 |         """When duplicates exist, first occurrence should be preferred."""
466 |         chunk_results = [
467 |             [ExtractedEntity(name='Alice', entity_type_id=1)],  # First: type 1
468 |             [ExtractedEntity(name='Alice', entity_type_id=2)],  # Later: type 2
469 |         ]
470 | 
471 |         merged = _merge_extracted_entities(chunk_results)
472 | 
473 |         assert len(merged) == 1
474 |         assert merged[0].entity_type_id == 1  # First occurrence wins
475 | 
```
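
The last two test classes exercise two underscore-prefixed helpers directly, so their contracts are easy to state in isolation. A minimal sketch using only behavior asserted above (these are internal APIs and may change without notice):

```python
from pydantic import BaseModel

from graphiti_core.prompts.extract_nodes import ExtractedEntity
from graphiti_core.utils.maintenance.node_operations import (
    _build_entity_types_context,
    _merge_extracted_entities,
)


class Person(BaseModel):
    """A human person."""


# Index 0 is always the default 'Entity' type; custom types get sequential IDs.
context = _build_entity_types_context({'Person': Person})
assert context[0]['entity_type_name'] == 'Entity'
assert context[1]['entity_type_name'] == 'Person'
assert context[1]['entity_type_id'] == 1

# Duplicate names across chunk results collapse to the first occurrence.
merged = _merge_extracted_entities(
    [
        [ExtractedEntity(name='Alice', entity_type_id=1)],
        [ExtractedEntity(name='Alice', entity_type_id=2)],  # dropped
    ]
)
assert len(merged) == 1
assert merged[0].entity_type_id == 1  # first occurrence wins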

--------------------------------------------------------------------------------
/tests/driver/test_falkordb_driver.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Copyright 2024, Zep Software, Inc.
  3 | 
  4 | Licensed under the Apache License, Version 2.0 (the "License");
  5 | you may not use this file except in compliance with the License.
  6 | You may obtain a copy of the License at
  7 | 
  8 |     http://www.apache.org/licenses/LICENSE-2.0
  9 | 
 10 | Unless required by applicable law or agreed to in writing, software
 11 | distributed under the License is distributed on an "AS IS" BASIS,
 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | See the License for the specific language governing permissions and
 14 | limitations under the License.
 15 | """
 16 | 
 17 | import os
 18 | import unittest
 19 | from datetime import datetime, timezone
 20 | from unittest.mock import AsyncMock, MagicMock, patch
 21 | 
 22 | import pytest
 23 | 
 24 | from graphiti_core.driver.driver import GraphProvider
 25 | 
 26 | try:
 27 |     from graphiti_core.driver.falkordb_driver import FalkorDriver, FalkorDriverSession
 28 | 
 29 |     HAS_FALKORDB = True
 30 | except ImportError:
 31 |     FalkorDriver = None
 32 |     HAS_FALKORDB = False
 33 | 
 34 | 
 35 | class TestFalkorDriver:
 36 |     """Comprehensive test suite for FalkorDB driver."""
 37 | 
 38 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
 39 |     def setup_method(self):
 40 |         """Set up test fixtures."""
 41 |         self.mock_client = MagicMock()
 42 |         with patch('graphiti_core.driver.falkordb_driver.FalkorDB'):
 43 |             self.driver = FalkorDriver()
 44 |         self.driver.client = self.mock_client
 45 | 
 46 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
 47 |     def test_init_with_connection_params(self):
 48 |         """Test initialization with connection parameters."""
 49 |         with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as mock_falkor_db:
 50 |             driver = FalkorDriver(
 51 |                 host='test-host', port='1234', username='test-user', password='test-pass'
 52 |             )
 53 |             assert driver.provider == GraphProvider.FALKORDB
 54 |             mock_falkor_db.assert_called_once_with(
 55 |                 host='test-host', port='1234', username='test-user', password='test-pass'
 56 |             )
 57 | 
 58 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
 59 |     def test_init_with_falkor_db_instance(self):
 60 |         """Test initialization with a FalkorDB instance."""
 61 |         with patch('graphiti_core.driver.falkordb_driver.FalkorDB') as mock_falkor_db_class:
 62 |             mock_falkor_db = MagicMock()
 63 |             driver = FalkorDriver(falkor_db=mock_falkor_db)
 64 |             assert driver.provider == GraphProvider.FALKORDB
 65 |             assert driver.client is mock_falkor_db
 66 |             mock_falkor_db_class.assert_not_called()
 67 | 
 68 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
 69 |     def test_provider(self):
 70 |         """Test driver provider identification."""
 71 |         assert self.driver.provider == GraphProvider.FALKORDB
 72 | 
 73 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
 74 |     def test_get_graph_with_name(self):
 75 |         """Test _get_graph with specific graph name."""
 76 |         mock_graph = MagicMock()
 77 |         self.mock_client.select_graph.return_value = mock_graph
 78 | 
 79 |         result = self.driver._get_graph('test_graph')
 80 | 
 81 |         self.mock_client.select_graph.assert_called_once_with('test_graph')
 82 |         assert result is mock_graph
 83 | 
 84 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
 85 |     def test_get_graph_with_none_defaults_to_default_database(self):
 86 |         """Test _get_graph with None defaults to default_db."""
 87 |         mock_graph = MagicMock()
 88 |         self.mock_client.select_graph.return_value = mock_graph
 89 | 
 90 |         result = self.driver._get_graph(None)
 91 | 
 92 |         self.mock_client.select_graph.assert_called_once_with('default_db')
 93 |         assert result is mock_graph
 94 | 
 95 |     @pytest.mark.asyncio
 96 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
 97 |     async def test_execute_query_success(self):
 98 |         """Test successful query execution."""
 99 |         mock_graph = MagicMock()
100 |         mock_result = MagicMock()
101 |         mock_result.header = [('col1', 'column1'), ('col2', 'column2')]
102 |         mock_result.result_set = [['row1col1', 'row1col2']]
103 |         mock_graph.query = AsyncMock(return_value=mock_result)
104 |         self.mock_client.select_graph.return_value = mock_graph
105 | 
106 |         result = await self.driver.execute_query('MATCH (n) RETURN n', param1='value1')
107 | 
108 |         mock_graph.query.assert_called_once_with('MATCH (n) RETURN n', {'param1': 'value1'})
109 | 
110 |         result_set, header, summary = result
111 |         assert result_set == [{'column1': 'row1col1', 'column2': 'row1col2'}]
112 |         assert header == ['column1', 'column2']
113 |         assert summary is None
114 | 
115 |     @pytest.mark.asyncio
116 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
117 |     async def test_execute_query_handles_index_already_exists_error(self):
118 |         """Test handling of 'already indexed' error."""
119 |         mock_graph = MagicMock()
120 |         mock_graph.query = AsyncMock(side_effect=Exception('Index already indexed'))
121 |         self.mock_client.select_graph.return_value = mock_graph
122 | 
123 |         with patch('graphiti_core.driver.falkordb_driver.logger') as mock_logger:
124 |             result = await self.driver.execute_query('CREATE INDEX ...')
125 | 
126 |             mock_logger.info.assert_called_once()
127 |             assert result is None
128 | 
129 |     @pytest.mark.asyncio
130 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
131 |     async def test_execute_query_propagates_other_exceptions(self):
132 |         """Test that other exceptions are properly propagated."""
133 |         mock_graph = MagicMock()
134 |         mock_graph.query = AsyncMock(side_effect=Exception('Other error'))
135 |         self.mock_client.select_graph.return_value = mock_graph
136 | 
137 |         with patch('graphiti_core.driver.falkordb_driver.logger') as mock_logger:
138 |             with pytest.raises(Exception, match='Other error'):
139 |                 await self.driver.execute_query('INVALID QUERY')
140 | 
141 |             mock_logger.error.assert_called_once()
142 | 
143 |     @pytest.mark.asyncio
144 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
145 |     async def test_execute_query_converts_datetime_parameters(self):
146 |         """Test that datetime objects in kwargs are converted to ISO strings."""
147 |         mock_graph = MagicMock()
148 |         mock_result = MagicMock()
149 |         mock_result.header = []
150 |         mock_result.result_set = []
151 |         mock_graph.query = AsyncMock(return_value=mock_result)
152 |         self.mock_client.select_graph.return_value = mock_graph
153 | 
154 |         test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
155 | 
156 |         await self.driver.execute_query(
157 |             'CREATE (n:Node) SET n.created_at = $created_at', created_at=test_datetime
158 |         )
159 | 
160 |         call_args = mock_graph.query.call_args[0]
161 |         assert call_args[1]['created_at'] == test_datetime.isoformat()
162 | 
163 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
164 |     def test_session_creation(self):
165 |         """Test session creation with specific database."""
166 |         mock_graph = MagicMock()
167 |         self.mock_client.select_graph.return_value = mock_graph
168 | 
169 |         session = self.driver.session()
170 | 
171 |         assert isinstance(session, FalkorDriverSession)
172 |         assert session.graph is mock_graph
173 | 
174 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
175 |     def test_session_creation_with_none_uses_default_database(self):
176 |         """Test session creation with None uses default database."""
177 |         mock_graph = MagicMock()
178 |         self.mock_client.select_graph.return_value = mock_graph
179 | 
180 |         session = self.driver.session(None)
181 | 
182 |         assert isinstance(session, FalkorDriverSession)
183 | 
184 |     @pytest.mark.asyncio
185 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
186 |     async def test_close_calls_connection_close(self):
187 |         """Test driver close method calls connection close."""
188 |         mock_connection = MagicMock()
189 |         mock_connection.close = AsyncMock()
190 |         self.mock_client.connection = mock_connection
191 | 
192 |         # Ensure hasattr checks work correctly
193 |         del self.mock_client.aclose  # Remove aclose if it exists
194 | 
195 |         with patch('builtins.hasattr') as mock_hasattr:
196 |             # hasattr(self.client, 'aclose') returns False
197 |             # hasattr(self.client.connection, 'aclose') returns False
198 |             # hasattr(self.client.connection, 'close') returns True
199 |             mock_hasattr.side_effect = lambda obj, attr: (
200 |                 attr == 'close' and obj is mock_connection
201 |             )
202 | 
203 |             await self.driver.close()
204 | 
205 |         mock_connection.close.assert_called_once()
206 | 
207 |     @pytest.mark.asyncio
208 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
209 |     async def test_delete_all_indexes(self):
210 |         """Test delete_all_indexes method."""
211 |         with patch.object(self.driver, 'execute_query', new_callable=AsyncMock) as mock_execute:
212 |             # Return None to simulate no indexes found
213 |             mock_execute.return_value = None
214 | 
215 |             await self.driver.delete_all_indexes()
216 | 
217 |             mock_execute.assert_called_once_with('CALL db.indexes()')
218 | 
219 | 
220 | class TestFalkorDriverSession:
221 |     """Test FalkorDB driver session functionality."""
222 | 
223 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
224 |     def setup_method(self):
225 |         """Set up test fixtures."""
226 |         self.mock_graph = MagicMock()
227 |         self.session = FalkorDriverSession(self.mock_graph)
228 | 
229 |     @pytest.mark.asyncio
230 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
231 |     async def test_session_async_context_manager(self):
232 |         """Test session can be used as async context manager."""
233 |         async with self.session as s:
234 |             assert s is self.session
235 | 
236 |     @pytest.mark.asyncio
237 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
238 |     async def test_close_method(self):
239 |         """Test session close method doesn't raise exceptions."""
240 |         await self.session.close()  # Should not raise
241 | 
242 |     @pytest.mark.asyncio
243 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
244 |     async def test_execute_write_passes_session_and_args(self):
245 |         """Test execute_write method passes session and arguments correctly."""
246 | 
247 |         async def test_func(session, *args, **kwargs):
248 |             assert session is self.session
249 |             assert args == ('arg1', 'arg2')
250 |             assert kwargs == {'key': 'value'}
251 |             return 'result'
252 | 
253 |         result = await self.session.execute_write(test_func, 'arg1', 'arg2', key='value')
254 |         assert result == 'result'
255 | 
256 |     @pytest.mark.asyncio
257 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
258 |     async def test_run_single_query_with_parameters(self):
259 |         """Test running a single query with parameters."""
260 |         self.mock_graph.query = AsyncMock()
261 | 
262 |         await self.session.run('MATCH (n) RETURN n', param1='value1', param2='value2')
263 | 
264 |         self.mock_graph.query.assert_called_once_with(
265 |             'MATCH (n) RETURN n', {'param1': 'value1', 'param2': 'value2'}
266 |         )
267 | 
268 |     @pytest.mark.asyncio
269 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
270 |     async def test_run_multiple_queries_as_list(self):
271 |         """Test running multiple queries passed as list."""
272 |         self.mock_graph.query = AsyncMock()
273 | 
274 |         queries = [
275 |             ('MATCH (n) RETURN n', {'param1': 'value1'}),
276 |             ('CREATE (n:Node)', {'param2': 'value2'}),
277 |         ]
278 | 
279 |         await self.session.run(queries)
280 | 
281 |         assert self.mock_graph.query.call_count == 2
282 |         calls = self.mock_graph.query.call_args_list
283 |         assert calls[0][0] == ('MATCH (n) RETURN n', {'param1': 'value1'})
284 |         assert calls[1][0] == ('CREATE (n:Node)', {'param2': 'value2'})
285 | 
286 |     @pytest.mark.asyncio
287 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
288 |     async def test_run_converts_datetime_objects_to_iso_strings(self):
289 |         """Test that datetime objects are converted to ISO strings."""
290 |         self.mock_graph.query = AsyncMock()
291 |         test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
292 | 
293 |         await self.session.run(
294 |             'CREATE (n:Node) SET n.created_at = $created_at', created_at=test_datetime
295 |         )
296 | 
297 |         self.mock_graph.query.assert_called_once()
298 |         call_args = self.mock_graph.query.call_args[0]
299 |         assert call_args[1]['created_at'] == test_datetime.isoformat()
300 | 
301 | 
302 | class TestDatetimeConversion:
303 |     """Test datetime conversion utility function."""
304 | 
305 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
306 |     def test_convert_datetime_dict(self):
307 |         """Test datetime conversion in nested dictionary."""
308 |         from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings
309 | 
310 |         test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
311 |         input_dict = {
312 |             'string_val': 'test',
313 |             'datetime_val': test_datetime,
314 |             'nested_dict': {'nested_datetime': test_datetime, 'nested_string': 'nested_test'},
315 |         }
316 | 
317 |         result = convert_datetimes_to_strings(input_dict)
318 | 
319 |         assert result['string_val'] == 'test'
320 |         assert result['datetime_val'] == test_datetime.isoformat()
321 |         assert result['nested_dict']['nested_datetime'] == test_datetime.isoformat()
322 |         assert result['nested_dict']['nested_string'] == 'nested_test'
323 | 
324 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
325 |     def test_convert_datetime_list_and_tuple(self):
326 |         """Test datetime conversion in lists and tuples."""
327 |         from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings
328 | 
329 |         test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
330 | 
331 |         # Test list
332 |         input_list = ['test', test_datetime, ['nested', test_datetime]]
333 |         result_list = convert_datetimes_to_strings(input_list)
334 |         assert result_list[0] == 'test'
335 |         assert result_list[1] == test_datetime.isoformat()
336 |         assert result_list[2][1] == test_datetime.isoformat()
337 | 
338 |         # Test tuple
339 |         input_tuple = ('test', test_datetime)
340 |         result_tuple = convert_datetimes_to_strings(input_tuple)
341 |         assert isinstance(result_tuple, tuple)
342 |         assert result_tuple[0] == 'test'
343 |         assert result_tuple[1] == test_datetime.isoformat()
344 | 
345 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
346 |     def test_convert_single_datetime(self):
347 |         """Test datetime conversion for single datetime object."""
348 |         from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings
349 | 
350 |         test_datetime = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
351 |         result = convert_datetimes_to_strings(test_datetime)
352 |         assert result == test_datetime.isoformat()
353 | 
354 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
355 |     def test_convert_other_types_unchanged(self):
356 |         """Test that non-datetime types are returned unchanged."""
357 |         from graphiti_core.driver.falkordb_driver import convert_datetimes_to_strings
358 | 
359 |         assert convert_datetimes_to_strings('string') == 'string'
360 |         assert convert_datetimes_to_strings(123) == 123
361 |         assert convert_datetimes_to_strings(None) is None
362 |         assert convert_datetimes_to_strings(True) is True
363 | 
364 | 
365 | # Simple integration test
366 | class TestFalkorDriverIntegration:
367 |     """Simple integration test for FalkorDB driver."""
368 | 
369 |     @pytest.mark.asyncio
370 |     @unittest.skipIf(not HAS_FALKORDB, 'FalkorDB is not installed')
371 |     async def test_basic_integration_with_real_falkordb(self):
372 |         """Basic integration test with real FalkorDB instance."""
373 |         pytest.importorskip('falkordb')
374 | 
375 |         falkor_host = os.getenv('FALKORDB_HOST', 'localhost')
376 |         falkor_port = os.getenv('FALKORDB_PORT', '6379')
377 | 
378 |         try:
379 |             driver = FalkorDriver(host=falkor_host, port=falkor_port)
380 | 
381 |             # Test basic query execution
382 |             result = await driver.execute_query('RETURN 1 as test')
383 |             assert result is not None
384 | 
385 |             result_set, header, summary = result
386 |             assert header == ['test']
387 |             assert result_set == [{'test': 1}]
388 | 
389 |             await driver.close()
390 | 
391 |         except Exception as e:
392 |             pytest.skip(f'FalkorDB not available for integration test: {e}')
393 | 
```
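
For reference, the integration test at the bottom implies the following minimal driver usage. This is a sketch, not canonical documentation: host and port are placeholders, and a reachable FalkorDB instance is assumed.

```python
import asyncio

from graphiti_core.driver.falkordb_driver import FalkorDriver


async def main() -> None:
    # Placeholder connection details; defaults mirror the integration test.
    driver = FalkorDriver(host='localhost', port='6379')

    # execute_query() returns a (result_set, header, summary) triple, where
    # result_set is a list of dicts keyed by column alias and summary is None.
    result_set, header, summary = await driver.execute_query('RETURN 1 AS test')
    assert header == ['test']
    assert result_set == [{'test': 1}]

    await driver.close()


asyncio.run(main())
```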

--------------------------------------------------------------------------------
/mcp_server/src/services/factories.py:
--------------------------------------------------------------------------------

```python
  1 | """Factory classes for creating LLM, Embedder, and Database clients."""
  2 | 
  3 | from openai import AsyncAzureOpenAI
  4 | 
  5 | from config.schema import (
  6 |     DatabaseConfig,
  7 |     EmbedderConfig,
  8 |     LLMConfig,
  9 | )
 10 | 
 11 | # Try to import FalkorDriver if available
 12 | try:
 13 |     from graphiti_core.driver.falkordb_driver import FalkorDriver  # noqa: F401
 14 | 
 15 |     HAS_FALKOR = True
 16 | except ImportError:
 17 |     HAS_FALKOR = False
 18 | 
 19 | # Kuzu support removed - FalkorDB is now the default
 20 | from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
 21 | from graphiti_core.llm_client import LLMClient, OpenAIClient
 22 | from graphiti_core.llm_client.config import LLMConfig as GraphitiLLMConfig
 23 | 
 24 | # Try to import additional providers if available
 25 | try:
 26 |     from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient
 27 | 
 28 |     HAS_AZURE_EMBEDDER = True
 29 | except ImportError:
 30 |     HAS_AZURE_EMBEDDER = False
 31 | 
 32 | try:
 33 |     from graphiti_core.embedder.gemini import GeminiEmbedder
 34 | 
 35 |     HAS_GEMINI_EMBEDDER = True
 36 | except ImportError:
 37 |     HAS_GEMINI_EMBEDDER = False
 38 | 
 39 | try:
 40 |     from graphiti_core.embedder.voyage import VoyageAIEmbedder
 41 | 
 42 |     HAS_VOYAGE_EMBEDDER = True
 43 | except ImportError:
 44 |     HAS_VOYAGE_EMBEDDER = False
 45 | 
 46 | try:
 47 |     from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
 48 | 
 49 |     HAS_AZURE_LLM = True
 50 | except ImportError:
 51 |     HAS_AZURE_LLM = False
 52 | 
 53 | try:
 54 |     from graphiti_core.llm_client.anthropic_client import AnthropicClient
 55 | 
 56 |     HAS_ANTHROPIC = True
 57 | except ImportError:
 58 |     HAS_ANTHROPIC = False
 59 | 
 60 | try:
 61 |     from graphiti_core.llm_client.gemini_client import GeminiClient
 62 | 
 63 |     HAS_GEMINI = True
 64 | except ImportError:
 65 |     HAS_GEMINI = False
 66 | 
 67 | try:
 68 |     from graphiti_core.llm_client.groq_client import GroqClient
 69 | 
 70 |     HAS_GROQ = True
 71 | except ImportError:
 72 |     HAS_GROQ = False
 73 | from utils.utils import create_azure_credential_token_provider
 74 | 
 75 | 
 76 | def _validate_api_key(provider_name: str, api_key: str | None, logger) -> str:
 77 |     """Validate API key is present.
 78 | 
 79 |     Args:
 80 |         provider_name: Name of the provider (e.g., 'OpenAI', 'Anthropic')
 81 |         api_key: The API key to validate
 82 |         logger: Logger instance for output
 83 | 
 84 |     Returns:
 85 |         The validated API key
 86 | 
 87 |     Raises:
 88 |         ValueError: If API key is None or empty
 89 |     """
 90 |     if not api_key:
 91 |         raise ValueError(
 92 |             f'{provider_name} API key is not configured. Please set the appropriate environment variable.'
 93 |         )
 94 | 
 95 |     logger.info(f'Creating {provider_name} client')
 96 | 
 97 |     return api_key
 98 | 
 99 | 
100 | class LLMClientFactory:
101 |     """Factory for creating LLM clients based on configuration."""
102 | 
103 |     @staticmethod
104 |     def create(config: LLMConfig) -> LLMClient:
105 |         """Create an LLM client based on the configured provider."""
106 |         import logging
107 | 
108 |         logger = logging.getLogger(__name__)
109 | 
110 |         provider = config.provider.lower()
111 | 
112 |         match provider:
113 |             case 'openai':
114 |                 if not config.providers.openai:
115 |                     raise ValueError('OpenAI provider configuration not found')
116 | 
117 |                 api_key = config.providers.openai.api_key
118 |                 _validate_api_key('OpenAI', api_key, logger)
119 | 
120 |                 from graphiti_core.llm_client.config import LLMConfig as CoreLLMConfig
121 | 
122 |                 # Determine appropriate small model based on main model type
123 |                 is_reasoning_model = (
124 |                     config.model.startswith('gpt-5')
125 |                     or config.model.startswith('o1')
126 |                     or config.model.startswith('o3')
127 |                 )
128 |                 small_model = (
129 |                     'gpt-5-nano' if is_reasoning_model else 'gpt-4.1-mini'
130 |                 )  # Use reasoning model for small tasks if main model is reasoning
131 | 
132 |                 llm_config = CoreLLMConfig(
133 |                     api_key=api_key,
134 |                     model=config.model,
135 |                     small_model=small_model,
136 |                     temperature=config.temperature,
137 |                     max_tokens=config.max_tokens,
138 |                 )
139 | 
140 |                 # Only pass reasoning/verbosity parameters for reasoning models (gpt-5 family)
141 |                 if is_reasoning_model:
142 |                     return OpenAIClient(config=llm_config, reasoning='minimal', verbosity='low')
143 |                 else:
144 |                     # For non-reasoning models, explicitly pass None to disable these parameters
145 |                     return OpenAIClient(config=llm_config, reasoning=None, verbosity=None)
146 | 
            case 'azure_openai':
                if not HAS_AZURE_LLM:
                    raise ValueError(
                        'Azure OpenAI LLM client not available in current graphiti-core version'
                    )
                if not config.providers.azure_openai:
                    raise ValueError('Azure OpenAI provider configuration not found')
                azure_config = config.providers.azure_openai

                if not azure_config.api_url:
                    raise ValueError('Azure OpenAI API URL is required')

                # Handle Azure AD authentication if enabled
                api_key: str | None = None
                azure_ad_token_provider = None
                if azure_config.use_azure_ad:
                    logger.info('Creating Azure OpenAI LLM client with Azure AD authentication')
                    azure_ad_token_provider = create_azure_credential_token_provider()
                else:
                    api_key = azure_config.api_key
                    _validate_api_key('Azure OpenAI', api_key, logger)

                # Create the Azure OpenAI client first
                azure_client = AsyncAzureOpenAI(
                    api_key=api_key,
                    azure_endpoint=azure_config.api_url,
                    api_version=azure_config.api_version,
                    azure_deployment=azure_config.deployment_name,
                    azure_ad_token_provider=azure_ad_token_provider,
                )

                # Then create the LLMConfig
                from graphiti_core.llm_client.config import LLMConfig as CoreLLMConfig

                llm_config = CoreLLMConfig(
                    api_key=api_key,
                    base_url=azure_config.api_url,
                    model=config.model,
                    temperature=config.temperature,
                    max_tokens=config.max_tokens,
                )

                return AzureOpenAILLMClient(
                    azure_client=azure_client,
                    config=llm_config,
                    max_tokens=config.max_tokens,
                )

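            # Note: create_azure_credential_token_provider() is defined elsewhere in
            # this module; it presumably wraps azure.identity's DefaultAzureCredential
            # via get_bearer_token_provider, so no API key is required under Azure AD.
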
            case 'anthropic':
                if not HAS_ANTHROPIC:
                    raise ValueError(
                        'Anthropic client not available in current graphiti-core version'
                    )
                if not config.providers.anthropic:
                    raise ValueError('Anthropic provider configuration not found')

                api_key = config.providers.anthropic.api_key
                _validate_api_key('Anthropic', api_key, logger)

                llm_config = GraphitiLLMConfig(
                    api_key=api_key,
                    model=config.model,
                    temperature=config.temperature,
                    max_tokens=config.max_tokens,
                )
                return AnthropicClient(config=llm_config)

            case 'gemini':
                if not HAS_GEMINI:
                    raise ValueError('Gemini client not available in current graphiti-core version')
                if not config.providers.gemini:
                    raise ValueError('Gemini provider configuration not found')

                api_key = config.providers.gemini.api_key
                _validate_api_key('Gemini', api_key, logger)

                llm_config = GraphitiLLMConfig(
                    api_key=api_key,
                    model=config.model,
                    temperature=config.temperature,
                    max_tokens=config.max_tokens,
                )
                return GeminiClient(config=llm_config)

            case 'groq':
                if not HAS_GROQ:
                    raise ValueError('Groq client not available in current graphiti-core version')
                if not config.providers.groq:
                    raise ValueError('Groq provider configuration not found')

                api_key = config.providers.groq.api_key
                _validate_api_key('Groq', api_key, logger)

                llm_config = GraphitiLLMConfig(
                    api_key=api_key,
                    base_url=config.providers.groq.api_url,
                    model=config.model,
                    temperature=config.temperature,
                    max_tokens=config.max_tokens,
                )
                return GroqClient(config=llm_config)

            case _:
                raise ValueError(f'Unsupported LLM provider: {provider}')

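# Hedged usage sketch for the LLM factory above (the enclosing factory class name
# is an assumption; it is defined earlier in this file):
#
#     llm_client = LLMClientFactory.create(config.llm)
#     # -> OpenAIClient | AzureOpenAILLMClient | AnthropicClient | GeminiClient | GroqClient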

class EmbedderFactory:
    """Factory for creating Embedder clients based on configuration."""

    @staticmethod
    def create(config: EmbedderConfig) -> EmbedderClient:
        """Create an Embedder client based on the configured provider."""
        import logging

        logger = logging.getLogger(__name__)

        provider = config.provider.lower()

        match provider:
            case 'openai':
                if not config.providers.openai:
                    raise ValueError('OpenAI provider configuration not found')

                api_key = config.providers.openai.api_key
                _validate_api_key('OpenAI Embedder', api_key, logger)

                from graphiti_core.embedder.openai import OpenAIEmbedderConfig

                embedder_config = OpenAIEmbedderConfig(
                    api_key=api_key,
                    embedding_model=config.model,
                )
                return OpenAIEmbedder(config=embedder_config)

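            # Note: unlike the cases below, embedding_model may be None here; this
            # defers to OpenAIEmbedderConfig's own default model (an assumption
            # about graphiti-core's config behavior).
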
            case 'azure_openai':
                if not HAS_AZURE_EMBEDDER:
                    raise ValueError(
                        'Azure OpenAI embedder not available in current graphiti-core version'
                    )
                if not config.providers.azure_openai:
                    raise ValueError('Azure OpenAI provider configuration not found')
                azure_config = config.providers.azure_openai

                if not azure_config.api_url:
                    raise ValueError('Azure OpenAI API URL is required')

                # Handle Azure AD authentication if enabled
                api_key: str | None = None
                azure_ad_token_provider = None
                if azure_config.use_azure_ad:
                    logger.info(
                        'Creating Azure OpenAI Embedder client with Azure AD authentication'
                    )
                    azure_ad_token_provider = create_azure_credential_token_provider()
                else:
                    api_key = azure_config.api_key
                    _validate_api_key('Azure OpenAI Embedder', api_key, logger)

                # Create the Azure OpenAI client first
                azure_client = AsyncAzureOpenAI(
                    api_key=api_key,
                    azure_endpoint=azure_config.api_url,
                    api_version=azure_config.api_version,
                    azure_deployment=azure_config.deployment_name,
                    azure_ad_token_provider=azure_ad_token_provider,
                )

                return AzureOpenAIEmbedderClient(
                    azure_client=azure_client,
                    model=config.model or 'text-embedding-3-small',
                )

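            # Note: with Azure OpenAI, the client-level azure_deployment typically
            # determines which deployment serves requests; the model name here acts
            # as the per-call identifier (hedged reading of the Azure client API).
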
            case 'gemini':
                if not HAS_GEMINI_EMBEDDER:
                    raise ValueError(
                        'Gemini embedder not available in current graphiti-core version'
                    )
                if not config.providers.gemini:
                    raise ValueError('Gemini provider configuration not found')

                api_key = config.providers.gemini.api_key
                _validate_api_key('Gemini Embedder', api_key, logger)

                from graphiti_core.embedder.gemini import GeminiEmbedderConfig

                gemini_config = GeminiEmbedderConfig(
                    api_key=api_key,
                    embedding_model=config.model or 'models/text-embedding-004',
                    embedding_dim=config.dimensions or 768,
                )
                return GeminiEmbedder(config=gemini_config)

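            # The defaults above match Gemini's text-embedding-004 model, which
            # returns 768-dimensional vectors.
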
            case 'voyage':
                if not HAS_VOYAGE_EMBEDDER:
                    raise ValueError(
                        'Voyage embedder not available in current graphiti-core version'
                    )
                if not config.providers.voyage:
                    raise ValueError('Voyage provider configuration not found')

                api_key = config.providers.voyage.api_key
                _validate_api_key('Voyage Embedder', api_key, logger)

                from graphiti_core.embedder.voyage import VoyageAIEmbedderConfig

                voyage_config = VoyageAIEmbedderConfig(
                    api_key=api_key,
                    embedding_model=config.model or 'voyage-3',
                    embedding_dim=config.dimensions or 1024,
                )
                return VoyageAIEmbedder(config=voyage_config)

            case _:
                raise ValueError(f'Unsupported Embedder provider: {provider}')

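# Hedged usage sketch (the awaited create() signature follows
# graphiti_core.embedder.client.EmbedderClient and is an assumption here):
#
#     embedder = EmbedderFactory.create(config.embedder)
#     vector = await embedder.create(input_data=['graph memory'])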

class DatabaseDriverFactory:
    """Factory for creating Database drivers based on configuration.

    Note: This returns configuration dictionaries that can be passed to Graphiti(),
    not driver instances directly, as the drivers require complex initialization.
    """

    @staticmethod
    def create_config(config: DatabaseConfig) -> dict:
        """Create a database configuration dictionary based on the configured provider."""
        provider = config.provider.lower()

        match provider:
            case 'neo4j':
                # Use the Neo4j config if provided, otherwise fall back to defaults
                if config.providers.neo4j:
                    neo4j_config = config.providers.neo4j
                else:
                    # Create a default Neo4j configuration
                    from config.schema import Neo4jProviderConfig

                    neo4j_config = Neo4jProviderConfig()

                # Check for environment variable overrides (for CI/CD compatibility)
                import os

                uri = os.environ.get('NEO4J_URI', neo4j_config.uri)
                username = os.environ.get('NEO4J_USER', neo4j_config.username)
                password = os.environ.get('NEO4J_PASSWORD', neo4j_config.password)

                return {
                    'uri': uri,
                    'user': username,
                    'password': password,
                    # Note: database and use_parallel_runtime would need to be passed
                    # to the driver after initialization if supported
                }

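            # Example: with NEO4J_URI=bolt://db.internal:7687 exported, the env value
            # wins over neo4j_config.uri and the returned dict becomes
            # {'uri': 'bolt://db.internal:7687', 'user': ..., 'password': ...}.
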
            case 'falkordb':
                if not HAS_FALKOR:
                    raise ValueError(
                        'FalkorDB driver not available in current graphiti-core version'
                    )

                # Use the FalkorDB config if provided, otherwise fall back to defaults
                if config.providers.falkordb:
                    falkor_config = config.providers.falkordb
                else:
                    # Create a default FalkorDB configuration
                    from config.schema import FalkorDBProviderConfig

                    falkor_config = FalkorDBProviderConfig()

                # Check for environment variable overrides (for CI/CD compatibility)
                import os
                from urllib.parse import urlparse

                uri = os.environ.get('FALKORDB_URI', falkor_config.uri)
                password = os.environ.get('FALKORDB_PASSWORD', falkor_config.password)

                # Parse the URI to extract host and port
                parsed = urlparse(uri)
                host = parsed.hostname or 'localhost'
                port = parsed.port or 6379

                return {
                    'driver': 'falkordb',
                    'host': host,
                    'port': port,
                    'password': password,
                    'database': falkor_config.database,
                }

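            # Example: uri='redis://cache.internal:6380' parses to
            # host='cache.internal', port=6380; a bare 'redis://localhost' has no
            # explicit port, so the `parsed.port or 6379` fallback applies.
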
            case _:
                raise ValueError(f'Unsupported Database provider: {provider}')
```
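
A hedged end-to-end sketch of how these factories could compose when bootstrapping the server. `EmbedderFactory` and `DatabaseDriverFactory` appear verbatim above; the `LLMClientFactory` name, the `config` object's field names, and the `Graphiti` constructor arguments are assumptions based on the surrounding code, not a confirmed API:

```python
# Illustrative wiring only -- names flagged above as assumptions may differ.
from graphiti_core import Graphiti


async def build_graphiti(config) -> Graphiti:
    llm_client = LLMClientFactory.create(config.llm)  # assumed factory name
    embedder = EmbedderFactory.create(config.embedder)
    db = DatabaseDriverFactory.create_config(config.database)

    # The neo4j branch returns plain uri/user/password keys, which map onto
    # Graphiti's connection arguments.
    return Graphiti(
        db['uri'],
        db['user'],
        db['password'],
        llm_client=llm_client,
        embedder=embedder,
    )
```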