#
tokens: 23512/50000 1/234 files (page 11/12)
lines: on (toggle) GitHub
raw markdown copy reset
This is page 11 of 12. To view the full context, open http://codebase.md/getzep/graphiti?lines=true&page={x}, replacing `{x}` with a page number (1–12).

# Directory Structure

```
├── .env.example
├── .github
│   ├── dependabot.yml
│   ├── ISSUE_TEMPLATE
│   │   └── bug_report.md
│   ├── pull_request_template.md
│   ├── secret_scanning.yml
│   └── workflows
│       ├── ai-moderator.yml
│       ├── cla.yml
│       ├── claude-code-review-manual.yml
│       ├── claude-code-review.yml
│       ├── claude.yml
│       ├── codeql.yml
│       ├── daily_issue_maintenance.yml
│       ├── issue-triage.yml
│       ├── lint.yml
│       ├── release-graphiti-core.yml
│       ├── release-mcp-server.yml
│       ├── release-server-container.yml
│       ├── typecheck.yml
│       └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│   ├── azure-openai
│   │   ├── .env.example
│   │   ├── azure_openai_neo4j.py
│   │   └── README.md
│   ├── data
│   │   └── manybirds_products.json
│   ├── ecommerce
│   │   ├── runner.ipynb
│   │   └── runner.py
│   ├── langgraph-agent
│   │   ├── agent.ipynb
│   │   └── tinybirds-jess.png
│   ├── opentelemetry
│   │   ├── .env.example
│   │   ├── otel_stdout_example.py
│   │   ├── pyproject.toml
│   │   ├── README.md
│   │   └── uv.lock
│   ├── podcast
│   │   ├── podcast_runner.py
│   │   ├── podcast_transcript.txt
│   │   └── transcript_parser.py
│   ├── quickstart
│   │   ├── quickstart_falkordb.py
│   │   ├── quickstart_neo4j.py
│   │   ├── quickstart_neptune.py
│   │   ├── README.md
│   │   └── requirements.txt
│   └── wizard_of_oz
│       ├── parser.py
│       ├── runner.py
│       └── woo.txt
├── graphiti_core
│   ├── __init__.py
│   ├── cross_encoder
│   │   ├── __init__.py
│   │   ├── bge_reranker_client.py
│   │   ├── client.py
│   │   ├── gemini_reranker_client.py
│   │   └── openai_reranker_client.py
│   ├── decorators.py
│   ├── driver
│   │   ├── __init__.py
│   │   ├── driver.py
│   │   ├── falkordb_driver.py
│   │   ├── graph_operations
│   │   │   └── graph_operations.py
│   │   ├── kuzu_driver.py
│   │   ├── neo4j_driver.py
│   │   ├── neptune_driver.py
│   │   └── search_interface
│   │       └── search_interface.py
│   ├── edges.py
│   ├── embedder
│   │   ├── __init__.py
│   │   ├── azure_openai.py
│   │   ├── client.py
│   │   ├── gemini.py
│   │   ├── openai.py
│   │   └── voyage.py
│   ├── errors.py
│   ├── graph_queries.py
│   ├── graphiti_types.py
│   ├── graphiti.py
│   ├── helpers.py
│   ├── llm_client
│   │   ├── __init__.py
│   │   ├── anthropic_client.py
│   │   ├── azure_openai_client.py
│   │   ├── client.py
│   │   ├── config.py
│   │   ├── errors.py
│   │   ├── gemini_client.py
│   │   ├── groq_client.py
│   │   ├── openai_base_client.py
│   │   ├── openai_client.py
│   │   ├── openai_generic_client.py
│   │   └── utils.py
│   ├── migrations
│   │   └── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── edges
│   │   │   ├── __init__.py
│   │   │   └── edge_db_queries.py
│   │   └── nodes
│   │       ├── __init__.py
│   │       └── node_db_queries.py
│   ├── nodes.py
│   ├── prompts
│   │   ├── __init__.py
│   │   ├── dedupe_edges.py
│   │   ├── dedupe_nodes.py
│   │   ├── eval.py
│   │   ├── extract_edge_dates.py
│   │   ├── extract_edges.py
│   │   ├── extract_nodes.py
│   │   ├── invalidate_edges.py
│   │   ├── lib.py
│   │   ├── models.py
│   │   ├── prompt_helpers.py
│   │   ├── snippets.py
│   │   └── summarize_nodes.py
│   ├── py.typed
│   ├── search
│   │   ├── __init__.py
│   │   ├── search_config_recipes.py
│   │   ├── search_config.py
│   │   ├── search_filters.py
│   │   ├── search_helpers.py
│   │   ├── search_utils.py
│   │   └── search.py
│   ├── telemetry
│   │   ├── __init__.py
│   │   └── telemetry.py
│   ├── tracer.py
│   └── utils
│       ├── __init__.py
│       ├── bulk_utils.py
│       ├── datetime_utils.py
│       ├── maintenance
│       │   ├── __init__.py
│       │   ├── community_operations.py
│       │   ├── dedup_helpers.py
│       │   ├── edge_operations.py
│       │   ├── graph_data_operations.py
│       │   ├── node_operations.py
│       │   └── temporal_operations.py
│       ├── ontology_utils
│       │   └── entity_types_utils.py
│       └── text_utils.py
├── images
│   ├── arxiv-screenshot.png
│   ├── graphiti-graph-intro.gif
│   ├── graphiti-intro-slides-stock-2.gif
│   └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│   ├── .env.example
│   ├── .python-version
│   ├── config
│   │   ├── config-docker-falkordb-combined.yaml
│   │   ├── config-docker-falkordb.yaml
│   │   ├── config-docker-neo4j.yaml
│   │   ├── config.yaml
│   │   └── mcp_config_stdio_example.json
│   ├── docker
│   │   ├── build-standalone.sh
│   │   ├── build-with-version.sh
│   │   ├── docker-compose-falkordb.yml
│   │   ├── docker-compose-neo4j.yml
│   │   ├── docker-compose.yml
│   │   ├── Dockerfile
│   │   ├── Dockerfile.standalone
│   │   ├── github-actions-example.yml
│   │   ├── README-falkordb-combined.md
│   │   └── README.md
│   ├── docs
│   │   └── cursor_rules.md
│   ├── main.py
│   ├── pyproject.toml
│   ├── pytest.ini
│   ├── README.md
│   ├── src
│   │   ├── __init__.py
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   └── schema.py
│   │   ├── graphiti_mcp_server.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── entity_types.py
│   │   │   └── response_types.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   ├── factories.py
│   │   │   └── queue_service.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── formatting.py
│   │       └── utils.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── conftest.py
│   │   ├── pytest.ini
│   │   ├── README.md
│   │   ├── run_tests.py
│   │   ├── test_async_operations.py
│   │   ├── test_comprehensive_integration.py
│   │   ├── test_configuration.py
│   │   ├── test_falkordb_integration.py
│   │   ├── test_fixtures.py
│   │   ├── test_http_integration.py
│   │   ├── test_integration.py
│   │   ├── test_mcp_integration.py
│   │   ├── test_mcp_transports.py
│   │   ├── test_stdio_simple.py
│   │   └── test_stress_load.py
│   └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│   ├── .env.example
│   ├── graph_service
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── dto
│   │   │   ├── __init__.py
│   │   │   ├── common.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   ├── main.py
│   │   ├── routers
│   │   │   ├── __init__.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   └── zep_graphiti.py
│   ├── Makefile
│   ├── pyproject.toml
│   ├── README.md
│   └── uv.lock
├── signatures
│   └── version1
│       └── cla.json
├── tests
│   ├── cross_encoder
│   │   ├── test_bge_reranker_client_int.py
│   │   └── test_gemini_reranker_client.py
│   ├── driver
│   │   ├── __init__.py
│   │   └── test_falkordb_driver.py
│   ├── embedder
│   │   ├── embedder_fixtures.py
│   │   ├── test_gemini.py
│   │   ├── test_openai.py
│   │   └── test_voyage.py
│   ├── evals
│   │   ├── data
│   │   │   └── longmemeval_data
│   │   │       ├── longmemeval_oracle.json
│   │   │       └── README.md
│   │   ├── eval_cli.py
│   │   ├── eval_e2e_graph_building.py
│   │   ├── pytest.ini
│   │   └── utils.py
│   ├── helpers_test.py
│   ├── llm_client
│   │   ├── test_anthropic_client_int.py
│   │   ├── test_anthropic_client.py
│   │   ├── test_azure_openai_client.py
│   │   ├── test_client.py
│   │   ├── test_errors.py
│   │   └── test_gemini_client.py
│   ├── test_edge_int.py
│   ├── test_entity_exclusion_int.py
│   ├── test_graphiti_int.py
│   ├── test_graphiti_mock.py
│   ├── test_node_int.py
│   ├── test_text_utils.py
│   └── utils
│       ├── maintenance
│       │   ├── test_bulk_utils.py
│       │   ├── test_edge_operations.py
│       │   ├── test_node_operations.py
│       │   └── test_temporal_operations_int.py
│       └── search
│           └── search_utils_test.py
├── uv.lock
└── Zep-CLA.md
```

# Files

--------------------------------------------------------------------------------
/graphiti_core/search/search_utils.py:
--------------------------------------------------------------------------------

```python
   1 | """
   2 | Copyright 2024, Zep Software, Inc.
   3 | 
   4 | Licensed under the Apache License, Version 2.0 (the "License");
   5 | you may not use this file except in compliance with the License.
   6 | You may obtain a copy of the License at
   7 | 
   8 |     http://www.apache.org/licenses/LICENSE-2.0
   9 | 
  10 | Unless required by applicable law or agreed to in writing, software
  11 | distributed under the License is distributed on an "AS IS" BASIS,
  12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13 | See the License for the specific language governing permissions and
  14 | limitations under the License.
  15 | """
  16 | 
  17 | import logging
  18 | from collections import defaultdict
  19 | from time import time
  20 | from typing import Any
  21 | 
  22 | import numpy as np
  23 | from numpy._typing import NDArray
  24 | from typing_extensions import LiteralString
  25 | 
  26 | from graphiti_core.driver.driver import (
  27 |     GraphDriver,
  28 |     GraphProvider,
  29 | )
  30 | from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
  31 | from graphiti_core.graph_queries import (
  32 |     get_nodes_query,
  33 |     get_relationships_query,
  34 |     get_vector_cosine_func_query,
  35 | )
  36 | from graphiti_core.helpers import (
  37 |     lucene_sanitize,
  38 |     normalize_l2,
  39 |     semaphore_gather,
  40 | )
  41 | from graphiti_core.models.edges.edge_db_queries import get_entity_edge_return_query
  42 | from graphiti_core.models.nodes.node_db_queries import (
  43 |     COMMUNITY_NODE_RETURN,
  44 |     EPISODIC_NODE_RETURN,
  45 |     get_entity_node_return_query,
  46 | )
  47 | from graphiti_core.nodes import (
  48 |     CommunityNode,
  49 |     EntityNode,
  50 |     EpisodicNode,
  51 |     get_community_node_from_record,
  52 |     get_entity_node_from_record,
  53 |     get_episodic_node_from_record,
  54 | )
  55 | from graphiti_core.search.search_filters import (
  56 |     SearchFilters,
  57 |     edge_search_filter_query_constructor,
  58 |     node_search_filter_query_constructor,
  59 | )
  60 | 
logger = logging.getLogger(__name__)

# Default `limit` (maximum number of results) for the search helpers in this module.
RELEVANT_SCHEMA_LIMIT = 10
# Default `min_score` for vector similarity searches; only results scoring strictly above it are kept.
DEFAULT_MIN_SCORE = 0.6
# NOTE(review): presumably the default lambda for MMR reranking; not referenced in the visible portion of this file.
DEFAULT_MMR_LAMBDA = 0.5
# NOTE(review): presumably an upper bound on BFS depth; not referenced in the visible portion of this file.
MAX_SEARCH_DEPTH = 3
# Upper bound on fulltext query size (space-separated terms, plus one per group_id on the
# Lucene path); `fulltext_query` returns '' (skip the search) for queries that are too long.
MAX_QUERY_LENGTH = 128
  68 | 
  69 | 
  70 | def calculate_cosine_similarity(vector1: list[float], vector2: list[float]) -> float:
  71 |     """
  72 |     Calculates the cosine similarity between two vectors using NumPy.
  73 |     """
  74 |     dot_product = np.dot(vector1, vector2)
  75 |     norm_vector1 = np.linalg.norm(vector1)
  76 |     norm_vector2 = np.linalg.norm(vector2)
  77 | 
  78 |     if norm_vector1 == 0 or norm_vector2 == 0:
  79 |         return 0  # Handle cases where one or both vectors are zero vectors
  80 | 
  81 |     return dot_product / (norm_vector1 * norm_vector2)
  82 | 
  83 | 
  84 | def fulltext_query(query: str, group_ids: list[str] | None, driver: GraphDriver):
  85 |     if driver.provider == GraphProvider.KUZU:
  86 |         # Kuzu only supports simple queries.
  87 |         if len(query.split(' ')) > MAX_QUERY_LENGTH:
  88 |             return ''
  89 |         return query
  90 |     elif driver.provider == GraphProvider.FALKORDB:
  91 |         return driver.build_fulltext_query(query, group_ids, MAX_QUERY_LENGTH)
  92 |     group_ids_filter_list = (
  93 |         [driver.fulltext_syntax + f'group_id:"{g}"' for g in group_ids]
  94 |         if group_ids is not None
  95 |         else []
  96 |     )
  97 |     group_ids_filter = ''
  98 |     for f in group_ids_filter_list:
  99 |         group_ids_filter += f if not group_ids_filter else f' OR {f}'
 100 | 
 101 |     group_ids_filter += ' AND ' if group_ids_filter else ''
 102 | 
 103 |     lucene_query = lucene_sanitize(query)
 104 |     # If the lucene query is too long return no query
 105 |     if len(lucene_query.split(' ')) + len(group_ids or '') >= MAX_QUERY_LENGTH:
 106 |         return ''
 107 | 
 108 |     full_query = group_ids_filter + '(' + lucene_query + ')'
 109 | 
 110 |     return full_query
 111 | 
 112 | 
 113 | async def get_episodes_by_mentions(
 114 |     driver: GraphDriver,
 115 |     nodes: list[EntityNode],
 116 |     edges: list[EntityEdge],
 117 |     limit: int = RELEVANT_SCHEMA_LIMIT,
 118 | ) -> list[EpisodicNode]:
 119 |     episode_uuids: list[str] = []
 120 |     for edge in edges:
 121 |         episode_uuids.extend(edge.episodes)
 122 | 
 123 |     episodes = await EpisodicNode.get_by_uuids(driver, episode_uuids[:limit])
 124 | 
 125 |     return episodes
 126 | 
 127 | 
 128 | async def get_mentioned_nodes(
 129 |     driver: GraphDriver, episodes: list[EpisodicNode]
 130 | ) -> list[EntityNode]:
 131 |     episode_uuids = [episode.uuid for episode in episodes]
 132 | 
 133 |     records, _, _ = await driver.execute_query(
 134 |         """
 135 |         MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity)
 136 |         WHERE episode.uuid IN $uuids
 137 |         RETURN DISTINCT
 138 |         """
 139 |         + get_entity_node_return_query(driver.provider),
 140 |         uuids=episode_uuids,
 141 |         routing_='r',
 142 |     )
 143 | 
 144 |     nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
 145 | 
 146 |     return nodes
 147 | 
 148 | 
 149 | async def get_communities_by_nodes(
 150 |     driver: GraphDriver, nodes: list[EntityNode]
 151 | ) -> list[CommunityNode]:
 152 |     node_uuids = [node.uuid for node in nodes]
 153 | 
 154 |     records, _, _ = await driver.execute_query(
 155 |         """
 156 |         MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)
 157 |         WHERE m.uuid IN $uuids
 158 |         RETURN DISTINCT
 159 |         """
 160 |         + COMMUNITY_NODE_RETURN,
 161 |         uuids=node_uuids,
 162 |         routing_='r',
 163 |     )
 164 | 
 165 |     communities = [get_community_node_from_record(record) for record in records]
 166 | 
 167 |     return communities
 168 | 
 169 | 
 170 | async def edge_fulltext_search(
 171 |     driver: GraphDriver,
 172 |     query: str,
 173 |     search_filter: SearchFilters,
 174 |     group_ids: list[str] | None = None,
 175 |     limit=RELEVANT_SCHEMA_LIMIT,
 176 | ) -> list[EntityEdge]:
 177 |     if driver.search_interface:
 178 |         return await driver.search_interface.edge_fulltext_search(
 179 |             driver, query, search_filter, group_ids, limit
 180 |         )
 181 | 
 182 |     # fulltext search over facts
 183 |     fuzzy_query = fulltext_query(query, group_ids, driver)
 184 | 
 185 |     if fuzzy_query == '':
 186 |         return []
 187 | 
 188 |     match_query = """
 189 |     YIELD relationship AS rel, score
 190 |     MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
 191 |     """
 192 |     if driver.provider == GraphProvider.KUZU:
 193 |         match_query = """
 194 |         YIELD node, score
 195 |         MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: node.uuid})-[:RELATES_TO]->(m:Entity)
 196 |         """
 197 | 
 198 |     filter_queries, filter_params = edge_search_filter_query_constructor(
 199 |         search_filter, driver.provider
 200 |     )
 201 | 
 202 |     if group_ids is not None:
 203 |         filter_queries.append('e.group_id IN $group_ids')
 204 |         filter_params['group_ids'] = group_ids
 205 | 
 206 |     filter_query = ''
 207 |     if filter_queries:
 208 |         filter_query = ' WHERE ' + (' AND '.join(filter_queries))
 209 | 
 210 |     if driver.provider == GraphProvider.NEPTUNE:
 211 |         res = driver.run_aoss_query('edge_name_and_fact', query)  # pyright: ignore reportAttributeAccessIssue
 212 |         if res['hits']['total']['value'] > 0:
 213 |             input_ids = []
 214 |             for r in res['hits']['hits']:
 215 |                 input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
 216 | 
 217 |             # Match the edge ids and return the values
 218 |             query = (
 219 |                 """
 220 |                                 UNWIND $ids as id
 221 |                                 MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
 222 |                                 WHERE e.group_id IN $group_ids 
 223 |                                 AND id(e)=id 
 224 |                                 """
 225 |                 + filter_query
 226 |                 + """
 227 |                 AND id(e)=id
 228 |                 WITH e, id.score as score, startNode(e) AS n, endNode(e) AS m
 229 |                 RETURN
 230 |                     e.uuid AS uuid,
 231 |                     e.group_id AS group_id,
 232 |                     n.uuid AS source_node_uuid,
 233 |                     m.uuid AS target_node_uuid,
 234 |                     e.created_at AS created_at,
 235 |                     e.name AS name,
 236 |                     e.fact AS fact,
 237 |                     split(e.episodes, ",") AS episodes,
 238 |                     e.expired_at AS expired_at,
 239 |                     e.valid_at AS valid_at,
 240 |                     e.invalid_at AS invalid_at,
 241 |                     properties(e) AS attributes
 242 |                 ORDER BY score DESC LIMIT $limit
 243 |                             """
 244 |             )
 245 | 
 246 |             records, _, _ = await driver.execute_query(
 247 |                 query,
 248 |                 query=fuzzy_query,
 249 |                 ids=input_ids,
 250 |                 limit=limit,
 251 |                 routing_='r',
 252 |                 **filter_params,
 253 |             )
 254 |         else:
 255 |             return []
 256 |     else:
 257 |         query = (
 258 |             get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider)
 259 |             + match_query
 260 |             + filter_query
 261 |             + """
 262 |             WITH e, score, n, m
 263 |             RETURN
 264 |             """
 265 |             + get_entity_edge_return_query(driver.provider)
 266 |             + """
 267 |             ORDER BY score DESC
 268 |             LIMIT $limit
 269 |             """
 270 |         )
 271 | 
 272 |         records, _, _ = await driver.execute_query(
 273 |             query,
 274 |             query=fuzzy_query,
 275 |             limit=limit,
 276 |             routing_='r',
 277 |             **filter_params,
 278 |         )
 279 | 
 280 |     edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
 281 | 
 282 |     return edges
 283 | 
 284 | 
 285 | async def edge_similarity_search(
 286 |     driver: GraphDriver,
 287 |     search_vector: list[float],
 288 |     source_node_uuid: str | None,
 289 |     target_node_uuid: str | None,
 290 |     search_filter: SearchFilters,
 291 |     group_ids: list[str] | None = None,
 292 |     limit: int = RELEVANT_SCHEMA_LIMIT,
 293 |     min_score: float = DEFAULT_MIN_SCORE,
 294 | ) -> list[EntityEdge]:
 295 |     if driver.search_interface:
 296 |         return await driver.search_interface.edge_similarity_search(
 297 |             driver,
 298 |             search_vector,
 299 |             source_node_uuid,
 300 |             target_node_uuid,
 301 |             search_filter,
 302 |             group_ids,
 303 |             limit,
 304 |             min_score,
 305 |         )
 306 | 
 307 |     match_query = """
 308 |         MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
 309 |     """
 310 |     if driver.provider == GraphProvider.KUZU:
 311 |         match_query = """
 312 |             MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
 313 |         """
 314 | 
 315 |     filter_queries, filter_params = edge_search_filter_query_constructor(
 316 |         search_filter, driver.provider
 317 |     )
 318 | 
 319 |     if group_ids is not None:
 320 |         filter_queries.append('e.group_id IN $group_ids')
 321 |         filter_params['group_ids'] = group_ids
 322 | 
 323 |         if source_node_uuid is not None:
 324 |             filter_params['source_uuid'] = source_node_uuid
 325 |             filter_queries.append('n.uuid = $source_uuid')
 326 | 
 327 |         if target_node_uuid is not None:
 328 |             filter_params['target_uuid'] = target_node_uuid
 329 |             filter_queries.append('m.uuid = $target_uuid')
 330 | 
 331 |     filter_query = ''
 332 |     if filter_queries:
 333 |         filter_query = ' WHERE ' + (' AND '.join(filter_queries))
 334 | 
 335 |     search_vector_var = '$search_vector'
 336 |     if driver.provider == GraphProvider.KUZU:
 337 |         search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
 338 | 
 339 |     if driver.provider == GraphProvider.NEPTUNE:
 340 |         query = (
 341 |             """
 342 |                             MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
 343 |                             """
 344 |             + filter_query
 345 |             + """
 346 |             RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
 347 |             """
 348 |         )
 349 |         resp, header, _ = await driver.execute_query(
 350 |             query,
 351 |             search_vector=search_vector,
 352 |             limit=limit,
 353 |             min_score=min_score,
 354 |             routing_='r',
 355 |             **filter_params,
 356 |         )
 357 | 
 358 |         if len(resp) > 0:
 359 |             # Calculate Cosine similarity then return the edge ids
 360 |             input_ids = []
 361 |             for r in resp:
 362 |                 if r['embedding']:
 363 |                     score = calculate_cosine_similarity(
 364 |                         search_vector, list(map(float, r['embedding'].split(',')))
 365 |                     )
 366 |                     if score > min_score:
 367 |                         input_ids.append({'id': r['id'], 'score': score})
 368 | 
 369 |             # Match the edge ides and return the values
 370 |             query = """
 371 |                 UNWIND $ids as i
 372 |                 MATCH ()-[r]->()
 373 |                 WHERE id(r) = i.id
 374 |                 RETURN
 375 |                     r.uuid AS uuid,
 376 |                     r.group_id AS group_id,
 377 |                     startNode(r).uuid AS source_node_uuid,
 378 |                     endNode(r).uuid AS target_node_uuid,
 379 |                     r.created_at AS created_at,
 380 |                     r.name AS name,
 381 |                     r.fact AS fact,
 382 |                     split(r.episodes, ",") AS episodes,
 383 |                     r.expired_at AS expired_at,
 384 |                     r.valid_at AS valid_at,
 385 |                     r.invalid_at AS invalid_at,
 386 |                     properties(r) AS attributes
 387 |                 ORDER BY i.score DESC
 388 |                 LIMIT $limit
 389 |                     """
 390 |             records, _, _ = await driver.execute_query(
 391 |                 query,
 392 |                 ids=input_ids,
 393 |                 search_vector=search_vector,
 394 |                 limit=limit,
 395 |                 min_score=min_score,
 396 |                 routing_='r',
 397 |                 **filter_params,
 398 |             )
 399 |         else:
 400 |             return []
 401 |     else:
 402 |         query = (
 403 |             match_query
 404 |             + filter_query
 405 |             + """
 406 |             WITH DISTINCT e, n, m, """
 407 |             + get_vector_cosine_func_query('e.fact_embedding', search_vector_var, driver.provider)
 408 |             + """ AS score
 409 |             WHERE score > $min_score
 410 |             RETURN
 411 |             """
 412 |             + get_entity_edge_return_query(driver.provider)
 413 |             + """
 414 |             ORDER BY score DESC
 415 |             LIMIT $limit
 416 |             """
 417 |         )
 418 | 
 419 |         records, _, _ = await driver.execute_query(
 420 |             query,
 421 |             search_vector=search_vector,
 422 |             limit=limit,
 423 |             min_score=min_score,
 424 |             routing_='r',
 425 |             **filter_params,
 426 |         )
 427 | 
 428 |     edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
 429 | 
 430 |     return edges
 431 | 
 432 | 
 433 | async def edge_bfs_search(
 434 |     driver: GraphDriver,
 435 |     bfs_origin_node_uuids: list[str] | None,
 436 |     bfs_max_depth: int,
 437 |     search_filter: SearchFilters,
 438 |     group_ids: list[str] | None = None,
 439 |     limit: int = RELEVANT_SCHEMA_LIMIT,
 440 | ) -> list[EntityEdge]:
 441 |     # vector similarity search over embedded facts
 442 |     if bfs_origin_node_uuids is None or len(bfs_origin_node_uuids) == 0:
 443 |         return []
 444 | 
 445 |     filter_queries, filter_params = edge_search_filter_query_constructor(
 446 |         search_filter, driver.provider
 447 |     )
 448 | 
 449 |     if group_ids is not None:
 450 |         filter_queries.append('e.group_id IN $group_ids')
 451 |         filter_params['group_ids'] = group_ids
 452 | 
 453 |     filter_query = ''
 454 |     if filter_queries:
 455 |         filter_query = ' WHERE ' + (' AND '.join(filter_queries))
 456 | 
 457 |     if driver.provider == GraphProvider.KUZU:
 458 |         # Kuzu stores entity edges twice with an intermediate node, so we need to match them
 459 |         # separately for the correct BFS depth.
 460 |         depth = bfs_max_depth * 2 - 1
 461 |         match_queries = [
 462 |             f"""
 463 |             UNWIND $bfs_origin_node_uuids AS origin_uuid
 464 |             MATCH path = (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
 465 |             UNWIND nodes(path) AS relNode
 466 |             MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
 467 |             """,
 468 |         ]
 469 |         if bfs_max_depth > 1:
 470 |             depth = (bfs_max_depth - 1) * 2 - 1
 471 |             match_queries.append(f"""
 472 |                 UNWIND $bfs_origin_node_uuids AS origin_uuid
 473 |                 MATCH path = (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
 474 |                 UNWIND nodes(path) AS relNode
 475 |                 MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
 476 |             """)
 477 | 
 478 |         records = []
 479 |         for match_query in match_queries:
 480 |             sub_records, _, _ = await driver.execute_query(
 481 |                 match_query
 482 |                 + filter_query
 483 |                 + """
 484 |                 RETURN DISTINCT
 485 |                 """
 486 |                 + get_entity_edge_return_query(driver.provider)
 487 |                 + """
 488 |                 LIMIT $limit
 489 |                 """,
 490 |                 bfs_origin_node_uuids=bfs_origin_node_uuids,
 491 |                 limit=limit,
 492 |                 routing_='r',
 493 |                 **filter_params,
 494 |             )
 495 |             records.extend(sub_records)
 496 |     else:
 497 |         if driver.provider == GraphProvider.NEPTUNE:
 498 |             query = (
 499 |                 f"""
 500 |                 UNWIND $bfs_origin_node_uuids AS origin_uuid
 501 |                 MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS *1..{bfs_max_depth}]->(n:Entity)
 502 |                 WHERE origin:Entity OR origin:Episodic
 503 |                 UNWIND relationships(path) AS rel
 504 |                 MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
 505 |                 """
 506 |                 + filter_query
 507 |                 + """
 508 |                 RETURN DISTINCT
 509 |                     e.uuid AS uuid,
 510 |                     e.group_id AS group_id,
 511 |                     startNode(e).uuid AS source_node_uuid,
 512 |                     endNode(e).uuid AS target_node_uuid,
 513 |                     e.created_at AS created_at,
 514 |                     e.name AS name,
 515 |                     e.fact AS fact,
 516 |                     split(e.episodes, ',') AS episodes,
 517 |                     e.expired_at AS expired_at,
 518 |                     e.valid_at AS valid_at,
 519 |                     e.invalid_at AS invalid_at,
 520 |                     properties(e) AS attributes
 521 |                 LIMIT $limit
 522 |                 """
 523 |             )
 524 |         else:
 525 |             query = (
 526 |                 f"""
 527 |                 UNWIND $bfs_origin_node_uuids AS origin_uuid
 528 |                 MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
 529 |                 UNWIND relationships(path) AS rel
 530 |                 MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
 531 |                 """
 532 |                 + filter_query
 533 |                 + """
 534 |                 RETURN DISTINCT
 535 |                 """
 536 |                 + get_entity_edge_return_query(driver.provider)
 537 |                 + """
 538 |                 LIMIT $limit
 539 |                 """
 540 |             )
 541 | 
 542 |         records, _, _ = await driver.execute_query(
 543 |             query,
 544 |             bfs_origin_node_uuids=bfs_origin_node_uuids,
 545 |             depth=bfs_max_depth,
 546 |             limit=limit,
 547 |             routing_='r',
 548 |             **filter_params,
 549 |         )
 550 | 
 551 |     edges = [get_entity_edge_from_record(record, driver.provider) for record in records]
 552 | 
 553 |     return edges
 554 | 
 555 | 
 556 | async def node_fulltext_search(
 557 |     driver: GraphDriver,
 558 |     query: str,
 559 |     search_filter: SearchFilters,
 560 |     group_ids: list[str] | None = None,
 561 |     limit=RELEVANT_SCHEMA_LIMIT,
 562 | ) -> list[EntityNode]:
 563 |     if driver.search_interface:
 564 |         return await driver.search_interface.node_fulltext_search(
 565 |             driver, query, search_filter, group_ids, limit
 566 |         )
 567 | 
 568 |     # BM25 search to get top nodes
 569 |     fuzzy_query = fulltext_query(query, group_ids, driver)
 570 |     if fuzzy_query == '':
 571 |         return []
 572 | 
 573 |     filter_queries, filter_params = node_search_filter_query_constructor(
 574 |         search_filter, driver.provider
 575 |     )
 576 | 
 577 |     if group_ids is not None:
 578 |         filter_queries.append('n.group_id IN $group_ids')
 579 |         filter_params['group_ids'] = group_ids
 580 | 
 581 |     filter_query = ''
 582 |     if filter_queries:
 583 |         filter_query = ' WHERE ' + (' AND '.join(filter_queries))
 584 | 
 585 |     yield_query = 'YIELD node AS n, score'
 586 |     if driver.provider == GraphProvider.KUZU:
 587 |         yield_query = 'WITH node AS n, score'
 588 | 
 589 |     if driver.provider == GraphProvider.NEPTUNE:
 590 |         res = driver.run_aoss_query('node_name_and_summary', query, limit=limit)  # pyright: ignore reportAttributeAccessIssue
 591 |         if res['hits']['total']['value'] > 0:
 592 |             input_ids = []
 593 |             for r in res['hits']['hits']:
 594 |                 input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
 595 | 
 596 |             # Match the edge ides and return the values
 597 |             query = (
 598 |                 """
 599 |                                 UNWIND $ids as i
 600 |                                 MATCH (n:Entity)
 601 |                                 WHERE n.uuid=i.id
 602 |                                 RETURN
 603 |                                 """
 604 |                 + get_entity_node_return_query(driver.provider)
 605 |                 + """
 606 |                 ORDER BY i.score DESC
 607 |                 LIMIT $limit
 608 |                             """
 609 |             )
 610 |             records, _, _ = await driver.execute_query(
 611 |                 query,
 612 |                 ids=input_ids,
 613 |                 query=fuzzy_query,
 614 |                 limit=limit,
 615 |                 routing_='r',
 616 |                 **filter_params,
 617 |             )
 618 |         else:
 619 |             return []
 620 |     else:
 621 |         query = (
 622 |             get_nodes_query(
 623 |                 'node_name_and_summary', '$query', limit=limit, provider=driver.provider
 624 |             )
 625 |             + yield_query
 626 |             + filter_query
 627 |             + """
 628 |             WITH n, score
 629 |             ORDER BY score DESC
 630 |             LIMIT $limit
 631 |             RETURN
 632 |             """
 633 |             + get_entity_node_return_query(driver.provider)
 634 |         )
 635 | 
 636 |         records, _, _ = await driver.execute_query(
 637 |             query,
 638 |             query=fuzzy_query,
 639 |             limit=limit,
 640 |             routing_='r',
 641 |             **filter_params,
 642 |         )
 643 | 
 644 |     nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
 645 | 
 646 |     return nodes
 647 | 
 648 | 
 649 | async def node_similarity_search(
 650 |     driver: GraphDriver,
 651 |     search_vector: list[float],
 652 |     search_filter: SearchFilters,
 653 |     group_ids: list[str] | None = None,
 654 |     limit=RELEVANT_SCHEMA_LIMIT,
 655 |     min_score: float = DEFAULT_MIN_SCORE,
 656 | ) -> list[EntityNode]:
 657 |     if driver.search_interface:
 658 |         return await driver.search_interface.node_similarity_search(
 659 |             driver, search_vector, search_filter, group_ids, limit, min_score
 660 |         )
 661 | 
 662 |     filter_queries, filter_params = node_search_filter_query_constructor(
 663 |         search_filter, driver.provider
 664 |     )
 665 | 
 666 |     if group_ids is not None:
 667 |         filter_queries.append('n.group_id IN $group_ids')
 668 |         filter_params['group_ids'] = group_ids
 669 | 
 670 |     filter_query = ''
 671 |     if filter_queries:
 672 |         filter_query = ' WHERE ' + (' AND '.join(filter_queries))
 673 | 
 674 |     search_vector_var = '$search_vector'
 675 |     if driver.provider == GraphProvider.KUZU:
 676 |         search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
 677 | 
 678 |     if driver.provider == GraphProvider.NEPTUNE:
 679 |         query = (
 680 |             """
 681 |                                                                                                                                     MATCH (n:Entity)
 682 |                                                                                                                                     """
 683 |             + filter_query
 684 |             + """
 685 |             RETURN DISTINCT id(n) as id, n.name_embedding as embedding
 686 |             """
 687 |         )
 688 |         resp, header, _ = await driver.execute_query(
 689 |             query,
 690 |             params=filter_params,
 691 |             search_vector=search_vector,
 692 |             limit=limit,
 693 |             min_score=min_score,
 694 |             routing_='r',
 695 |         )
 696 | 
 697 |         if len(resp) > 0:
 698 |             # Calculate Cosine similarity then return the edge ids
 699 |             input_ids = []
 700 |             for r in resp:
 701 |                 if r['embedding']:
 702 |                     score = calculate_cosine_similarity(
 703 |                         search_vector, list(map(float, r['embedding'].split(',')))
 704 |                     )
 705 |                     if score > min_score:
 706 |                         input_ids.append({'id': r['id'], 'score': score})
 707 | 
 708 |             # Match the edge ides and return the values
 709 |             query = (
 710 |                 """
 711 |                                                                                                                                                                 UNWIND $ids as i
 712 |                                                                                                                                                                 MATCH (n:Entity)
 713 |                                                                                                                                                                 WHERE id(n)=i.id
 714 |                                                                                                                                                                 RETURN 
 715 |                                                                                                                                                                 """
 716 |                 + get_entity_node_return_query(driver.provider)
 717 |                 + """
 718 |                     ORDER BY i.score DESC
 719 |                     LIMIT $limit
 720 |                 """
 721 |             )
 722 |             records, header, _ = await driver.execute_query(
 723 |                 query,
 724 |                 ids=input_ids,
 725 |                 search_vector=search_vector,
 726 |                 limit=limit,
 727 |                 min_score=min_score,
 728 |                 routing_='r',
 729 |                 **filter_params,
 730 |             )
 731 |         else:
 732 |             return []
 733 |     else:
 734 |         query = (
 735 |             """
 736 |                                                                                                                                     MATCH (n:Entity)
 737 |                                                                                                                                     """
 738 |             + filter_query
 739 |             + """
 740 |             WITH n, """
 741 |             + get_vector_cosine_func_query('n.name_embedding', search_vector_var, driver.provider)
 742 |             + """ AS score
 743 |             WHERE score > $min_score
 744 |             RETURN
 745 |             """
 746 |             + get_entity_node_return_query(driver.provider)
 747 |             + """
 748 |             ORDER BY score DESC
 749 |             LIMIT $limit
 750 |             """
 751 |         )
 752 | 
 753 |         records, _, _ = await driver.execute_query(
 754 |             query,
 755 |             search_vector=search_vector,
 756 |             limit=limit,
 757 |             min_score=min_score,
 758 |             routing_='r',
 759 |             **filter_params,
 760 |         )
 761 | 
 762 |     nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
 763 | 
 764 |     return nodes
 765 | 
 766 | 
 767 | async def node_bfs_search(
 768 |     driver: GraphDriver,
 769 |     bfs_origin_node_uuids: list[str] | None,
 770 |     search_filter: SearchFilters,
 771 |     bfs_max_depth: int,
 772 |     group_ids: list[str] | None = None,
 773 |     limit: int = RELEVANT_SCHEMA_LIMIT,
 774 | ) -> list[EntityNode]:
 775 |     if bfs_origin_node_uuids is None or len(bfs_origin_node_uuids) == 0 or bfs_max_depth < 1:
 776 |         return []
 777 | 
 778 |     filter_queries, filter_params = node_search_filter_query_constructor(
 779 |         search_filter, driver.provider
 780 |     )
 781 | 
 782 |     if group_ids is not None:
 783 |         filter_queries.append('n.group_id IN $group_ids')
 784 |         filter_queries.append('origin.group_id IN $group_ids')
 785 |         filter_params['group_ids'] = group_ids
 786 | 
 787 |     filter_query = ''
 788 |     if filter_queries:
 789 |         filter_query = ' AND ' + (' AND '.join(filter_queries))
 790 | 
 791 |     match_queries = [
 792 |         f"""
 793 |         UNWIND $bfs_origin_node_uuids AS origin_uuid
 794 |         MATCH (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
 795 |         WHERE n.group_id = origin.group_id
 796 |         """
 797 |     ]
 798 | 
 799 |     if driver.provider == GraphProvider.NEPTUNE:
 800 |         match_queries = [
 801 |             f"""
 802 |             UNWIND $bfs_origin_node_uuids AS origin_uuid
 803 |             MATCH (origin {{uuid: origin_uuid}})-[e:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
 804 |             WHERE origin:Entity OR origin.Episode
 805 |             AND n.group_id = origin.group_id
 806 |             """
 807 |         ]
 808 | 
 809 |     if driver.provider == GraphProvider.KUZU:
 810 |         depth = bfs_max_depth * 2
 811 |         match_queries = [
 812 |             """
 813 |             UNWIND $bfs_origin_node_uuids AS origin_uuid
 814 |             MATCH (origin:Episodic {uuid: origin_uuid})-[:MENTIONS]->(n:Entity)
 815 |             WHERE n.group_id = origin.group_id
 816 |             """,
 817 |             f"""
 818 |             UNWIND $bfs_origin_node_uuids AS origin_uuid
 819 |             MATCH (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*2..{depth}]->(n:Entity)
 820 |             WHERE n.group_id = origin.group_id
 821 |             """,
 822 |         ]
 823 |         if bfs_max_depth > 1:
 824 |             depth = (bfs_max_depth - 1) * 2
 825 |             match_queries.append(f"""
 826 |                 UNWIND $bfs_origin_node_uuids AS origin_uuid
 827 |                 MATCH (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*2..{depth}]->(n:Entity)
 828 |                 WHERE n.group_id = origin.group_id
 829 |             """)
 830 | 
 831 |     records = []
 832 |     for match_query in match_queries:
 833 |         sub_records, _, _ = await driver.execute_query(
 834 |             match_query
 835 |             + filter_query
 836 |             + """
 837 |             RETURN
 838 |             """
 839 |             + get_entity_node_return_query(driver.provider)
 840 |             + """
 841 |             LIMIT $limit
 842 |             """,
 843 |             bfs_origin_node_uuids=bfs_origin_node_uuids,
 844 |             limit=limit,
 845 |             routing_='r',
 846 |             **filter_params,
 847 |         )
 848 |         records.extend(sub_records)
 849 | 
 850 |     nodes = [get_entity_node_from_record(record, driver.provider) for record in records]
 851 | 
 852 |     return nodes
 853 | 
 854 | 
 855 | async def episode_fulltext_search(
 856 |     driver: GraphDriver,
 857 |     query: str,
 858 |     _search_filter: SearchFilters,
 859 |     group_ids: list[str] | None = None,
 860 |     limit=RELEVANT_SCHEMA_LIMIT,
 861 | ) -> list[EpisodicNode]:
 862 |     if driver.search_interface:
 863 |         return await driver.search_interface.episode_fulltext_search(
 864 |             driver, query, _search_filter, group_ids, limit
 865 |         )
 866 | 
 867 |     # BM25 search to get top episodes
 868 |     fuzzy_query = fulltext_query(query, group_ids, driver)
 869 |     if fuzzy_query == '':
 870 |         return []
 871 | 
 872 |     filter_params: dict[str, Any] = {}
 873 |     group_filter_query: LiteralString = ''
 874 |     if group_ids is not None:
 875 |         group_filter_query += '\nAND e.group_id IN $group_ids'
 876 |         filter_params['group_ids'] = group_ids
 877 | 
 878 |     if driver.provider == GraphProvider.NEPTUNE:
 879 |         res = driver.run_aoss_query('episode_content', query, limit=limit)  # pyright: ignore reportAttributeAccessIssue
 880 |         if res['hits']['total']['value'] > 0:
 881 |             input_ids = []
 882 |             for r in res['hits']['hits']:
 883 |                 input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
 884 | 
 885 |             # Match the edge ides and return the values
 886 |             query = """
 887 |                 UNWIND $ids as i
 888 |                 MATCH (e:Episodic)
 889 |                 WHERE e.uuid=i.uuid
 890 |             RETURN
 891 |                     e.content AS content,
 892 |                     e.created_at AS created_at,
 893 |                     e.valid_at AS valid_at,
 894 |                     e.uuid AS uuid,
 895 |                     e.name AS name,
 896 |                     e.group_id AS group_id,
 897 |                     e.source_description AS source_description,
 898 |                     e.source AS source,
 899 |                     e.entity_edges AS entity_edges
 900 |                 ORDER BY i.score DESC
 901 |                 LIMIT $limit
 902 |             """
 903 |             records, _, _ = await driver.execute_query(
 904 |                 query,
 905 |                 ids=input_ids,
 906 |                 query=fuzzy_query,
 907 |                 limit=limit,
 908 |                 routing_='r',
 909 |                 **filter_params,
 910 |             )
 911 |         else:
 912 |             return []
 913 |     else:
 914 |         query = (
 915 |             get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider)
 916 |             + """
 917 |             YIELD node AS episode, score
 918 |             MATCH (e:Episodic)
 919 |             WHERE e.uuid = episode.uuid
 920 |             """
 921 |             + group_filter_query
 922 |             + """
 923 |             RETURN
 924 |             """
 925 |             + EPISODIC_NODE_RETURN
 926 |             + """
 927 |             ORDER BY score DESC
 928 |             LIMIT $limit
 929 |             """
 930 |         )
 931 | 
 932 |         records, _, _ = await driver.execute_query(
 933 |             query, query=fuzzy_query, limit=limit, routing_='r', **filter_params
 934 |         )
 935 | 
 936 |     episodes = [get_episodic_node_from_record(record) for record in records]
 937 | 
 938 |     return episodes
 939 | 
 940 | 
 941 | async def community_fulltext_search(
 942 |     driver: GraphDriver,
 943 |     query: str,
 944 |     group_ids: list[str] | None = None,
 945 |     limit=RELEVANT_SCHEMA_LIMIT,
 946 | ) -> list[CommunityNode]:
 947 |     # BM25 search to get top communities
 948 |     fuzzy_query = fulltext_query(query, group_ids, driver)
 949 |     if fuzzy_query == '':
 950 |         return []
 951 | 
 952 |     filter_params: dict[str, Any] = {}
 953 |     group_filter_query: LiteralString = ''
 954 |     if group_ids is not None:
 955 |         group_filter_query = 'WHERE c.group_id IN $group_ids'
 956 |         filter_params['group_ids'] = group_ids
 957 | 
 958 |     yield_query = 'YIELD node AS c, score'
 959 |     if driver.provider == GraphProvider.KUZU:
 960 |         yield_query = 'WITH node AS c, score'
 961 | 
 962 |     if driver.provider == GraphProvider.NEPTUNE:
 963 |         res = driver.run_aoss_query('community_name', query, limit=limit)  # pyright: ignore reportAttributeAccessIssue
 964 |         if res['hits']['total']['value'] > 0:
 965 |             # Calculate Cosine similarity then return the edge ids
 966 |             input_ids = []
 967 |             for r in res['hits']['hits']:
 968 |                 input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})
 969 | 
 970 |             # Match the edge ides and return the values
 971 |             query = """
 972 |                 UNWIND $ids as i
 973 |                 MATCH (comm:Community)
 974 |                 WHERE comm.uuid=i.id
 975 |                 RETURN
 976 |                     comm.uuid AS uuid,
 977 |                     comm.group_id AS group_id,
 978 |                     comm.name AS name,
 979 |                     comm.created_at AS created_at,
 980 |                     comm.summary AS summary,
 981 |                     [x IN split(comm.name_embedding, ",") | toFloat(x)]AS name_embedding
 982 |                 ORDER BY i.score DESC
 983 |                 LIMIT $limit
 984 |             """
 985 |             records, _, _ = await driver.execute_query(
 986 |                 query,
 987 |                 ids=input_ids,
 988 |                 query=fuzzy_query,
 989 |                 limit=limit,
 990 |                 routing_='r',
 991 |                 **filter_params,
 992 |             )
 993 |         else:
 994 |             return []
 995 |     else:
 996 |         query = (
 997 |             get_nodes_query('community_name', '$query', limit=limit, provider=driver.provider)
 998 |             + yield_query
 999 |             + """
1000 |             WITH c, score
1001 |             """
1002 |             + group_filter_query
1003 |             + """
1004 |             RETURN
1005 |             """
1006 |             + COMMUNITY_NODE_RETURN
1007 |             + """
1008 |             ORDER BY score DESC
1009 |             LIMIT $limit
1010 |             """
1011 |         )
1012 | 
1013 |         records, _, _ = await driver.execute_query(
1014 |             query, query=fuzzy_query, limit=limit, routing_='r', **filter_params
1015 |         )
1016 | 
1017 |     communities = [get_community_node_from_record(record) for record in records]
1018 | 
1019 |     return communities
1020 | 
1021 | 
async def community_similarity_search(
    driver: GraphDriver,
    search_vector: list[float],
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
    min_score=DEFAULT_MIN_SCORE,
) -> list[CommunityNode]:
    """Return communities whose ``name_embedding`` is cosine-similar to ``search_vector``.

    On Neptune the candidate embeddings are fetched from the graph and scored
    client-side with ``calculate_cosine_similarity``; other providers score inside
    the query.

    Args:
        driver: Graph driver used to run the queries.
        search_vector: Query embedding.
        group_ids: When given, only communities whose ``group_id`` is in this list are considered.
        limit: Maximum number of communities to return.
        min_score: Communities must score strictly above this similarity.

    Returns:
        Matching community nodes ordered by descending similarity.
    """
    # vector similarity search over community names
    query_params: dict[str, Any] = {}

    group_filter_query: LiteralString = ''
    if group_ids is not None:
        group_filter_query += ' WHERE c.group_id IN $group_ids'
        query_params['group_ids'] = group_ids

    if driver.provider == GraphProvider.NEPTUNE:
        # The community is bound as `c` so the shared group filter applies.
        candidate_query = (
            """
            MATCH (c:Community)
            """
            + group_filter_query
            + """
            RETURN DISTINCT id(c) as id, c.name_embedding as embedding
            """
        )
        candidates, _, _ = await driver.execute_query(
            candidate_query,
            search_vector=search_vector,
            limit=limit,
            min_score=min_score,
            routing_='r',
            **query_params,
        )
        if len(candidates) == 0:
            return []

        # Score every candidate client-side; embeddings are parsed from a
        # comma-separated string.
        input_ids = []
        for candidate in candidates:
            if candidate['embedding']:
                score = calculate_cosine_similarity(
                    search_vector, list(map(float, candidate['embedding'].split(',')))
                )
                if score > min_score:
                    input_ids.append({'id': candidate['id'], 'score': score})

        if not input_ids:
            return []

        # Resolve the scored community ids best-first, converting the stored
        # comma-separated embedding into floats as community_fulltext_search does.
        resolve_query = """
            UNWIND $ids as i
            MATCH (comm:Community)
            WHERE id(comm)=i.id
            RETURN
                comm.uuid As uuid,
                comm.group_id AS group_id,
                comm.name AS name,
                comm.created_at AS created_at,
                comm.summary AS summary,
                [x IN split(comm.name_embedding, ",") | toFloat(x)] AS name_embedding
            ORDER BY i.score DESC
            LIMIT $limit
            """
        records, _, _ = await driver.execute_query(
            resolve_query,
            ids=input_ids,
            search_vector=search_vector,
            limit=limit,
            min_score=min_score,
            routing_='r',
            **query_params,
        )
    else:
        search_vector_var = '$search_vector'
        if driver.provider == GraphProvider.KUZU:
            search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'

        cypher = (
            """
            MATCH (c:Community)
            """
            + group_filter_query
            + """
            WITH c,
            """
            + get_vector_cosine_func_query('c.name_embedding', search_vector_var, driver.provider)
            + """ AS score
            WHERE score > $min_score
            RETURN
            """
            + COMMUNITY_NODE_RETURN
            + """
            ORDER BY score DESC
            LIMIT $limit
            """
        )

        records, _, _ = await driver.execute_query(
            cypher,
            search_vector=search_vector,
            limit=limit,
            min_score=min_score,
            routing_='r',
            **query_params,
        )

    return [get_community_node_from_record(record) for record in records]
1130 | 
1131 | 
async def hybrid_node_search(
    queries: list[str],
    embeddings: list[list[float]],
    driver: GraphDriver,
    search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
    """
    Perform a hybrid search for nodes using both text queries and embeddings.

    Runs one fulltext search per text query and one vector similarity search per
    embedding concurrently, then fuses the ranked result lists with reciprocal
    rank fusion (``rrf``).

    Parameters
    ----------
    queries : list[str]
        Text queries for fulltext search. May be empty.
    embeddings : list[list[float]]
        Query embeddings for vector similarity search. If empty, only fulltext
        search is performed.
    driver : GraphDriver
        Graph driver used for database operations.
    search_filter : SearchFilters
        Node filters applied by every underlying search.
    group_ids : list[str] | None, optional
        The list of group ids to retrieve nodes from.
    limit : int, optional
        Base result limit. Each individual search is asked for ``2 * limit``
        results; the fused list is not truncated afterwards.

    Returns
    -------
    list[EntityNode]
        Unique EntityNode objects ordered by their fused rrf rank.

    Notes
    -----
    This method performs the following steps:
    1. Executes fulltext searches for each query.
    2. Executes vector similarity searches for each embedding.
    3. Fuses and deduplicates the results by node uuid.
    4. Logs the elapsed time at debug level.
    """

    start = time()
    results: list[list[EntityNode]] = list(
        await semaphore_gather(
            *[
                node_fulltext_search(driver, q, search_filter, group_ids, 2 * limit)
                for q in queries
            ],
            *[
                node_similarity_search(driver, e, search_filter, group_ids, 2 * limit)
                for e in embeddings
            ],
        )
    )

    # Any search may return a node, so keep one instance per uuid for lookup.
    node_uuid_map: dict[str, EntityNode] = {
        node.uuid: node for result in results for node in result
    }
    result_uuids = [[node.uuid for node in result] for result in results]

    ranked_uuids, _ = rrf(result_uuids)

    relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids]

    end = time()
    # Lazy %-style arguments avoid formatting the uuid list when debug is disabled.
    logger.debug('Found relevant nodes: %s in %s ms', ranked_uuids, (end - start) * 1000)
    return relevant_nodes
1204 | 
1205 | 
async def get_relevant_nodes(
    driver: GraphDriver,
    nodes: list[EntityNode],
    search_filter: SearchFilters,
    min_score: float = DEFAULT_MIN_SCORE,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[list[EntityNode]]:
    """For each input node, find existing entity nodes that may be relevant to it.

    Combines, per input node, the most name-embedding-similar entities in the
    group of ``nodes[0]`` with fulltext matches on the node's name, in a single
    batched query.

    Args:
        driver: Graph driver used to run the query.
        nodes: Nodes to find candidates for. The group of ``nodes[0]`` is used to
            scope the search for all of them.
        search_filter: Node filters, translated into Cypher predicates on ``n``.
        min_score: Vector matches must score strictly above this similarity.
        limit: Maximum number of vector matches kept per input node; also passed
            to the fulltext search.

    Returns:
        One list of candidate nodes per input node, in the same order as ``nodes``.
    """
    if len(nodes) == 0:
        return []

    group_id = nodes[0].group_id
    query_nodes = [
        {
            'uuid': node.uuid,
            'name': node.name,
            'name_embedding': node.name_embedding,
            'fulltext_query': fulltext_query(node.name, [node.group_id], driver),
        }
        for node in nodes
    ]

    filter_queries, filter_params = node_search_filter_query_constructor(
        search_filter, driver.provider
    )

    filter_query = ''
    if filter_queries:
        filter_query = 'WHERE ' + (' AND '.join(filter_queries))

    if driver.provider == GraphProvider.KUZU:
        embedding_size = len(nodes[0].name_embedding) if nodes[0].name_embedding is not None else 0
        if embedding_size == 0:
            # Keep the one-list-per-input-node contract of the normal return path.
            return [[] for _ in nodes]

        # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
        query = (
            """
            UNWIND $nodes AS node
            MATCH (n:Entity {group_id: $group_id})
            """
            + filter_query
            + """
            WITH node, n, """
            + get_vector_cosine_func_query(
                'n.name_embedding',
                f'CAST(node.name_embedding AS FLOAT[{embedding_size}])',
                driver.provider,
            )
            + """ AS score
            WHERE score > $min_score
            WITH node, collect(n)[:$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
            """
            + get_nodes_query(
                'node_name_and_summary',
                'node.fulltext_query',
                limit=limit,
                provider=driver.provider,
            )
            + """
            WITH node AS m
            WHERE m.group_id = $group_id AND NOT m.uuid IN vector_node_uuids
            WITH node, top_vector_nodes, collect(m) AS fulltext_nodes

            WITH node, list_concat(top_vector_nodes, fulltext_nodes) AS combined_nodes

            UNWIND combined_nodes AS x
            WITH node, collect(DISTINCT {
                uuid: x.uuid,
                name: x.name,
                name_embedding: x.name_embedding,
                group_id: x.group_id,
                created_at: x.created_at,
                summary: x.summary,
                labels: x.labels,
                attributes: x.attributes
            }) AS matches

            RETURN
            node.uuid AS search_node_uuid, matches
            """
        )
    else:
        # Rows are ordered by score before collect() so that the [..$limit] slice
        # keeps the most similar nodes rather than an arbitrary subset.
        query = (
            """
            UNWIND $nodes AS node
            MATCH (n:Entity {group_id: $group_id})
            """
            + filter_query
            + """
            WITH node, n, """
            + get_vector_cosine_func_query(
                'n.name_embedding', 'node.name_embedding', driver.provider
            )
            + """ AS score
            WHERE score > $min_score
            WITH node, n, score
            ORDER BY score DESC
            WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
            """
            + get_nodes_query(
                'node_name_and_summary',
                'node.fulltext_query',
                limit=limit,
                provider=driver.provider,
            )
            + """
            YIELD node AS m
            WHERE m.group_id = $group_id
            WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes

            WITH node,
                top_vector_nodes,
                [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes

            WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes

            UNWIND combined_nodes AS combined_node
            WITH node, collect(DISTINCT combined_node) AS deduped_nodes

            RETURN
            node.uuid AS search_node_uuid,
            [x IN deduped_nodes | {
                uuid: x.uuid,
                name: x.name,
                name_embedding: x.name_embedding,
                group_id: x.group_id,
                created_at: x.created_at,
                summary: x.summary,
                labels: labels(x),
                attributes: properties(x)
            }] AS matches
            """
        )

    results, _, _ = await driver.execute_query(
        query,
        nodes=query_nodes,
        group_id=group_id,
        limit=limit,
        min_score=min_score,
        routing_='r',
        **filter_params,
    )

    relevant_nodes_dict: dict[str, list[EntityNode]] = {
        result['search_node_uuid']: [
            get_entity_node_from_record(record, driver.provider) for record in result['matches']
        ]
        for result in results
    }

    # Input nodes with no result row get an empty candidate list.
    return [relevant_nodes_dict.get(node.uuid, []) for node in nodes]
1358 | 
1359 | 
async def get_relevant_edges(
    driver: GraphDriver,
    edges: list[EntityEdge],
    search_filter: SearchFilters,
    min_score: float = DEFAULT_MIN_SCORE,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[list[EntityEdge]]:
    """Find stored edges between the same endpoints that are semantically similar.

    For each input edge, ``RELATES_TO`` edges that connect the same source and target
    entity nodes (matched in either direction) and share the input edge's ``group_id``
    are scored by cosine similarity between their ``fact_embedding`` and the input
    edge's ``fact_embedding``. Candidates scoring strictly above ``min_score`` are kept,
    highest score first, at most ``limit`` per input edge.

    Args:
        driver: Graph driver used to run the queries.
        edges: Edges to find similar stored edges for.
        search_filter: Extra filters applied to the candidate edges.
        min_score: Exclusive lower bound on the cosine similarity.
        limit: Maximum number of candidates returned per input edge.

    Returns:
        One list per input edge, aligned with ``edges`` (an empty list when nothing
        matched). Returns ``[]`` when ``edges`` is empty, and on Kuzu also when the first
        edge has no ``fact_embedding``.
    """
    if len(edges) == 0:
        return []

    filter_queries, filter_params = edge_search_filter_query_constructor(
        search_filter, driver.provider
    )

    filter_query = ''
    if filter_queries:
        filter_query = ' WHERE ' + (' AND '.join(filter_queries))

    if driver.provider == GraphProvider.NEPTUNE:
        # Neptune: fetch raw embeddings, score them in Python, then re-fetch the
        # surviving edges by internal id ordered by score.
        query = (
            """
            UNWIND $edges AS edge
            MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
            """
            + filter_query
            + """
            WITH e, edge
            RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding, edge.uuid as search_edge_uuid,
            edge.fact_embedding as target_embedding
            """
        )
        resp, _, _ = await driver.execute_query(
            query,
            edges=[edge.model_dump() for edge in edges],
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )

        # The stored embedding comes back as a comma-separated string, so parse it
        # before computing cosine similarity; keep only candidates above min_score.
        input_ids = []
        for r in resp:
            score = calculate_cosine_similarity(
                list(map(float, r['source_embedding'].split(','))), r['target_embedding']
            )
            if score > min_score:
                input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})

        # Re-fetch the surviving edges by id, best score first, capped per input edge.
        query = """
        UNWIND $ids AS edge
        MATCH ()-[e]->()
        WHERE id(e) = edge.id
        WITH edge, e
        ORDER BY edge.score DESC
        RETURN edge.uuid AS search_edge_uuid,
            collect({
                uuid: e.uuid,
                source_node_uuid: startNode(e).uuid,
                target_node_uuid: endNode(e).uuid,
                created_at: e.created_at,
                name: e.name,
                group_id: e.group_id,
                fact: e.fact,
                fact_embedding: [x IN split(e.fact_embedding, ",") | toFloat(x)],
                episodes: split(e.episodes, ","),
                expired_at: e.expired_at,
                valid_at: e.valid_at,
                invalid_at: e.invalid_at,
                attributes: properties(e)
            })[..$limit] AS matches
        """

        results, _, _ = await driver.execute_query(
            query,
            ids=input_ids,
            edges=[edge.model_dump() for edge in edges],
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )
    else:
        limit_param: dict[str, int]
        if driver.provider == GraphProvider.KUZU:
            # Kuzu needs the query embedding cast to a fixed-size FLOAT array; the size
            # is taken from the first input edge.
            embedding_size = (
                len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
            )
            if embedding_size == 0:
                return []

            # No Cypher LIMIT here: after UNWIND it would cap the rows across *all*
            # input edges instead of per edge. The per-edge cap is applied below, and
            # since this query has no $limit placeholder the parameter is not sent.
            limit_param = {}
            query = (
                """
                UNWIND $edges AS edge
                MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
                """
                + filter_query
                + """
                WITH e, edge, n, m, """
                + get_vector_cosine_func_query(
                    'e.fact_embedding',
                    f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
                    driver.provider,
                )
                + """ AS score
                WHERE score > $min_score
                WITH e, edge, n, m, score
                ORDER BY score DESC
                RETURN
                    edge.uuid AS search_edge_uuid,
                    collect({
                        uuid: e.uuid,
                        source_node_uuid: n.uuid,
                        target_node_uuid: m.uuid,
                        created_at: e.created_at,
                        name: e.name,
                        group_id: e.group_id,
                        fact: e.fact,
                        fact_embedding: e.fact_embedding,
                        episodes: e.episodes,
                        expired_at: e.expired_at,
                        valid_at: e.valid_at,
                        invalid_at: e.invalid_at,
                        attributes: e.attributes
                    }) AS matches
                """
            )
        else:
            limit_param = {'limit': limit}
            query = (
                """
                UNWIND $edges AS edge
                MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
                """
                + filter_query
                + """
                WITH e, edge, """
                + get_vector_cosine_func_query(
                    'e.fact_embedding', 'edge.fact_embedding', driver.provider
                )
                + """ AS score
                WHERE score > $min_score
                WITH edge, e, score
                ORDER BY score DESC
                RETURN
                    edge.uuid AS search_edge_uuid,
                    collect({
                        uuid: e.uuid,
                        source_node_uuid: startNode(e).uuid,
                        target_node_uuid: endNode(e).uuid,
                        created_at: e.created_at,
                        name: e.name,
                        group_id: e.group_id,
                        fact: e.fact,
                        fact_embedding: e.fact_embedding,
                        episodes: e.episodes,
                        expired_at: e.expired_at,
                        valid_at: e.valid_at,
                        invalid_at: e.invalid_at,
                        attributes: properties(e)
                    })[..$limit] AS matches
                """
            )

        results, _, _ = await driver.execute_query(
            query,
            edges=[edge.model_dump() for edge in edges],
            min_score=min_score,
            routing_='r',
            **limit_param,
            **filter_params,
        )

    # Cap candidates per input edge. The Neptune and default queries already slice
    # with [..$limit], so this only matters for Kuzu; it assumes collect() keeps the
    # ORDER BY order, as the other queries already rely on.
    relevant_edges_dict: dict[str, list[EntityEdge]] = {
        result['search_edge_uuid']: [
            get_entity_edge_from_record(record, driver.provider)
            for record in result['matches'][:limit]
        ]
        for result in results
    }

    relevant_edges = [relevant_edges_dict.get(edge.uuid, []) for edge in edges]

    return relevant_edges
1543 | 
1544 | 
async def get_edge_invalidation_candidates(
    driver: GraphDriver,
    edges: list[EntityEdge],
    search_filter: SearchFilters,
    min_score: float = DEFAULT_MIN_SCORE,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[list[EntityEdge]]:
    """Find stored edges that touch either endpoint of each input edge and are similar to it.

    Candidates are ``RELATES_TO`` edges in the input edge's ``group_id`` whose source or
    target node is one of the input edge's two endpoints. They are scored by cosine
    similarity between their ``fact_embedding`` and the input edge's ``fact_embedding``.
    Candidates scoring strictly above ``min_score`` are kept, highest score first, at
    most ``limit`` per input edge. ``search_filter`` conditions must hold for every
    candidate.

    Args:
        driver: Graph driver used to run the queries.
        edges: Edges whose potential invalidation candidates are looked up.
        search_filter: Extra filters applied to the candidate edges.
        min_score: Exclusive lower bound on the cosine similarity.
        limit: Maximum number of candidates returned per input edge.

    Returns:
        One list per input edge, aligned with ``edges`` (an empty list when nothing
        matched). Returns ``[]`` when ``edges`` is empty, and on Kuzu also when the first
        edge has no ``fact_embedding``.
    """
    if len(edges) == 0:
        return []

    filter_queries, filter_params = edge_search_filter_query_constructor(
        search_filter, driver.provider
    )

    # Appended after the endpoint condition. That condition is parenthesised in every
    # query below; AND binds tighter than OR, so otherwise the filters would only
    # apply to the second half of the OR.
    filter_query = ''
    if filter_queries:
        filter_query = ' AND ' + (' AND '.join(filter_queries))

    if driver.provider == GraphProvider.NEPTUNE:
        # Neptune: fetch raw embeddings, score them in Python, then re-fetch the
        # surviving edges by internal id ordered by score.
        query = (
            """
            UNWIND $edges AS edge
            MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
            WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
            """
            + filter_query
            + """
            WITH e, edge
            RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding,
            edge.fact_embedding as target_embedding,
            edge.uuid as search_edge_uuid
            """
        )
        resp, _, _ = await driver.execute_query(
            query,
            edges=[edge.model_dump() for edge in edges],
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )

        # The stored embedding comes back as a comma-separated string, so parse it
        # before computing cosine similarity; keep only candidates above min_score.
        input_ids = []
        for r in resp:
            score = calculate_cosine_similarity(
                list(map(float, r['source_embedding'].split(','))), r['target_embedding']
            )
            if score > min_score:
                input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})

        # Re-fetch the surviving edges by id, best score first, capped per input edge.
        query = """
        UNWIND $ids AS edge
        MATCH ()-[e]->()
        WHERE id(e) = edge.id
        WITH edge, e
        ORDER BY edge.score DESC
        RETURN edge.uuid AS search_edge_uuid,
            collect({
                uuid: e.uuid,
                source_node_uuid: startNode(e).uuid,
                target_node_uuid: endNode(e).uuid,
                created_at: e.created_at,
                name: e.name,
                group_id: e.group_id,
                fact: e.fact,
                fact_embedding: [x IN split(e.fact_embedding, ",") | toFloat(x)],
                episodes: split(e.episodes, ","),
                expired_at: e.expired_at,
                valid_at: e.valid_at,
                invalid_at: e.invalid_at,
                attributes: properties(e)
            })[..$limit] AS matches
        """
        results, _, _ = await driver.execute_query(
            query,
            ids=input_ids,
            edges=[edge.model_dump() for edge in edges],
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )
    else:
        limit_param: dict[str, int]
        if driver.provider == GraphProvider.KUZU:
            # Kuzu needs the query embedding cast to a fixed-size FLOAT array; the size
            # is taken from the first input edge.
            embedding_size = (
                len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
            )
            if embedding_size == 0:
                return []

            # No Cypher LIMIT here: after UNWIND it would cap the rows across *all*
            # input edges instead of per edge. The per-edge cap is applied below, and
            # since this query has no $limit placeholder the parameter is not sent.
            limit_param = {}
            query = (
                """
                UNWIND $edges AS edge
                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
                WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
                """
                + filter_query
                + """
                WITH edge, e, n, m, """
                + get_vector_cosine_func_query(
                    'e.fact_embedding',
                    f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
                    driver.provider,
                )
                + """ AS score
                WHERE score > $min_score
                WITH edge, e, n, m, score
                ORDER BY score DESC
                RETURN
                    edge.uuid AS search_edge_uuid,
                    collect({
                        uuid: e.uuid,
                        source_node_uuid: n.uuid,
                        target_node_uuid: m.uuid,
                        created_at: e.created_at,
                        name: e.name,
                        group_id: e.group_id,
                        fact: e.fact,
                        fact_embedding: e.fact_embedding,
                        episodes: e.episodes,
                        expired_at: e.expired_at,
                        valid_at: e.valid_at,
                        invalid_at: e.invalid_at,
                        attributes: e.attributes
                    }) AS matches
                """
            )
        else:
            limit_param = {'limit': limit}
            query = (
                """
                UNWIND $edges AS edge
                MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
                WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
                """
                + filter_query
                + """
                WITH edge, e, """
                + get_vector_cosine_func_query(
                    'e.fact_embedding', 'edge.fact_embedding', driver.provider
                )
                + """ AS score
                WHERE score > $min_score
                WITH edge, e, score
                ORDER BY score DESC
                RETURN
                    edge.uuid AS search_edge_uuid,
                    collect({
                        uuid: e.uuid,
                        source_node_uuid: startNode(e).uuid,
                        target_node_uuid: endNode(e).uuid,
                        created_at: e.created_at,
                        name: e.name,
                        group_id: e.group_id,
                        fact: e.fact,
                        fact_embedding: e.fact_embedding,
                        episodes: e.episodes,
                        expired_at: e.expired_at,
                        valid_at: e.valid_at,
                        invalid_at: e.invalid_at,
                        attributes: properties(e)
                    })[..$limit] AS matches
                """
            )

        results, _, _ = await driver.execute_query(
            query,
            edges=[edge.model_dump() for edge in edges],
            min_score=min_score,
            routing_='r',
            **limit_param,
            **filter_params,
        )

    # Cap candidates per input edge. The Neptune and default queries already slice
    # with [..$limit], so this only matters for Kuzu; it assumes collect() keeps the
    # ORDER BY order, as the other queries already rely on.
    invalidation_edges_dict: dict[str, list[EntityEdge]] = {
        result['search_edge_uuid']: [
            get_entity_edge_from_record(record, driver.provider)
            for record in result['matches'][:limit]
        ]
        for result in results
    }

    invalidation_edges = [invalidation_edges_dict.get(edge.uuid, []) for edge in edges]

    return invalidation_edges
1730 | 
1731 | 
def rrf(
    results: list[list[str]], rank_const=1, min_score: float = 0
) -> tuple[list[str], list[float]]:
    """Fuse several rankings of uuids with reciprocal rank fusion.

    Each uuid at 0-based position ``i`` in a ranking contributes
    ``1 / (i + rank_const)`` to its total score; totals are summed across rankings.

    Args:
        results: Rankings of uuids, each ordered best first.
        rank_const: Constant added to the 0-based rank in the denominator.
        min_score: Inclusive lower bound on the fused score.

    Returns:
        (uuids, scores) sorted by descending fused score. Ties keep the order in which
        the uuids were first seen. Only entries with score >= ``min_score`` are kept.
    """
    scores: dict[str, float] = defaultdict(float)
    for result in results:
        for i, uuid in enumerate(result):
            scores[uuid] += 1 / (i + rank_const)

    # sorted() is stable, so equal scores keep first-seen order.
    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    kept = [(uuid, score) for uuid, score in ranked if score >= min_score]

    return [uuid for uuid, _ in kept], [score for _, score in kept]
1749 | 
1750 | 
async def node_distance_reranker(
    driver: GraphDriver,
    node_uuids: list[str],
    center_node_uuid: str,
    min_score: float = 0,
) -> tuple[list[str], list[float]]:
    """Rerank nodes by whether they are direct neighbours of a center node.

    Nodes joined to the center by a ``RELATES_TO`` edge (in either direction) get
    distance 1, i.e. score 1.0. Every other node gets infinite distance, i.e. score 0.0.
    If the center node itself appears in ``node_uuids``, it is placed first with
    score 10.0. Otherwise, ties keep their input order.

    Args:
        driver: Graph driver used to run the query.
        node_uuids: Candidate node uuids to rerank.
        center_node_uuid: uuid of the node distances are measured from.
        min_score: Inclusive lower bound on the returned score ``1 / distance``.

    Returns:
        (uuids, scores) ordered by ascending distance, so scores are descending.
    """
    # The center node is scored separately below.
    filtered_uuids = [node_uuid for node_uuid in node_uuids if node_uuid != center_node_uuid]
    scores: dict[str, float] = {center_node_uuid: 0.0}

    # Only one-hop neighbours are detected; no multi-hop path search is performed.
    query = """
    UNWIND $node_uuids AS node_uuid
    MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
    RETURN 1 AS score, node_uuid AS uuid
    """
    if driver.provider == GraphProvider.KUZU:
        # Kuzu routes each edge through an intermediate RelatesToNode_. Match it
        # undirected, like the default query, so neighbours on either side of an edge
        # count.
        query = """
        UNWIND $node_uuids AS node_uuid
        MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
        RETURN 1 AS score, node_uuid AS uuid
        """

    results, header, _ = await driver.execute_query(
        query,
        node_uuids=filtered_uuids,
        center_uuid=center_node_uuid,
        routing_='r',
    )
    if driver.provider == GraphProvider.FALKORDB:
        # FalkorDB rows are re-keyed by column name using the returned header.
        results = [dict(zip(header, row, strict=True)) for row in results]

    for result in results:
        scores[result['uuid']] = result['score']

    # Nodes not found as neighbours are treated as unreachable.
    for uuid in filtered_uuids:
        if uuid not in scores:
            scores[uuid] = float('inf')

    # Closest first; sort is stable so equal distances keep input order.
    filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])

    # Re-insert the center node first; distance 0.1 gives it a score of 10.
    if center_node_uuid in node_uuids:
        scores[center_node_uuid] = 0.1
        filtered_uuids = [center_node_uuid] + filtered_uuids

    return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score], [
        1 / scores[uuid] for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score
    ]
1803 | 
1804 | 
async def episode_mentions_reranker(
    driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0
) -> tuple[list[str], list[float]]:
    """Rerank nodes by how many episodes mention them, most-mentioned first.

    The input rankings are first fused with ``rrf``. Each node is then scored by the
    number of ``MENTIONS`` relationships from ``Episodic`` nodes to it; nodes with no
    mentions score 0. Nodes with equal counts keep their RRF order.

    Args:
        driver: Graph driver used to run the query.
        node_uuids: Rankings of node uuids to fuse and rerank.
        min_score: Inclusive lower bound on the mention count.

    Returns:
        (uuids, scores) sorted by descending mention count, keeping only entries whose
        count is >= ``min_score``.
    """
    # RRF gives a preliminary order that acts as the tie-breaker below.
    sorted_uuids, _ = rrf(node_uuids)
    scores: dict[str, float] = {}

    # Count episode mentions for every candidate node.
    # NOTE(review): node_distance_reranker re-keys FalkorDB rows by header but this
    # query does not — confirm the driver returns dict rows for FalkorDB here.
    results, _, _ = await driver.execute_query(
        """
        UNWIND $node_uuids AS node_uuid
        MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
        RETURN count(*) AS score, n.uuid AS uuid
        """,
        node_uuids=sorted_uuids,
        routing_='r',
    )

    for result in results:
        scores[result['uuid']] = result['score']

    # Nodes absent from the query results have no mentions.
    for uuid in sorted_uuids:
        if uuid not in scores:
            scores[uuid] = 0.0

    # Most-mentioned first; the sort is stable, so equal counts keep RRF order.
    sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid], reverse=True)

    return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
        scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
    ]
1836 | 
1837 | 
def maximal_marginal_relevance(
    query_vector: list[float],
    candidates: dict[str, list[float]],
    mmr_lambda: float = DEFAULT_MMR_LAMBDA,
    min_score: float = -2.0,
) -> tuple[list[str], list[float]]:
    """Rerank candidates with a single-pass maximal-marginal-relevance score.

    Each candidate embedding is passed through ``normalize_l2``. A candidate's score is
    ``mmr_lambda * dot(query, candidate) + (mmr_lambda - 1) * max_sim``. Here
    ``max_sim`` is the largest dot product with any *other* candidate. The
    self-similarity slot is zero, so ``max_sim`` is 0 for a lone candidate and never
    below 0. Scores are computed once for all candidates rather than by iterative
    greedy MMR selection. The query vector is used as given and is not normalized.

    Args:
        query_vector: Query embedding.
        candidates: Mapping of uuid to candidate embedding.
        mmr_lambda: Trade-off between relevance (1.0) and diversity (0.0).
        min_score: Inclusive lower bound on the returned MMR score.

    Returns:
        (uuids, scores) sorted by descending MMR score, keeping only scores >= min_score.
    """
    start = time()
    query_array = np.array(query_vector)
    candidate_arrays: dict[str, NDArray] = {
        uuid: normalize_l2(embedding) for uuid, embedding in candidates.items()
    }

    uuids: list[str] = list(candidate_arrays.keys())

    mmr_scores: dict[str, float] = {}
    if uuids:
        # One row per candidate; a single matrix product replaces the pairwise
        # Python loop over every candidate pair.
        candidate_matrix = np.stack([candidate_arrays[uuid] for uuid in uuids])
        similarity_matrix = candidate_matrix @ candidate_matrix.T
        # Exclude self-similarity: the diagonal contributes 0 to each row's max.
        np.fill_diagonal(similarity_matrix, 0.0)
        max_sims = similarity_matrix.max(axis=1)
        relevance = candidate_matrix @ query_array
        for uuid, rel, max_sim in zip(uuids, relevance, max_sims, strict=True):
            mmr_scores[uuid] = mmr_lambda * rel + (mmr_lambda - 1) * max_sim

    uuids.sort(reverse=True, key=lambda c: mmr_scores[c])

    end = time()
    logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms')

    return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score], [
        mmr_scores[uuid] for uuid in uuids if mmr_scores[uuid] >= min_score
    ]
1877 | 
1878 | 
async def get_embeddings_for_nodes(
    driver: GraphDriver, nodes: list[EntityNode]
) -> dict[str, list[float]]:
    """Load the ``name_embedding`` of each given entity node, keyed by node uuid.

    Delegates to ``driver.graph_operations_interface`` when one is configured. Nodes
    with no stored embedding are omitted from the result.

    Args:
        driver: Graph driver used to run the query.
        nodes: Entity nodes whose embeddings should be loaded.

    Returns:
        Mapping of node uuid to its name embedding as a list of floats.
    """
    if driver.graph_operations_interface:
        return await driver.graph_operations_interface.node_load_embeddings_bulk(driver, nodes)
    elif driver.provider == GraphProvider.NEPTUNE:
        # The embedding is stored as a comma-separated string here; convert each
        # component to a float so callers get list[float] like the other providers.
        query = """
        MATCH (n:Entity)
        WHERE n.uuid IN $node_uuids
        RETURN DISTINCT
            n.uuid AS uuid,
            [x IN split(n.name_embedding, ",") | toFloat(x)] AS name_embedding
        """
    else:
        query = """
        MATCH (n:Entity)
        WHERE n.uuid IN $node_uuids
        RETURN DISTINCT
            n.uuid AS uuid,
            n.name_embedding AS name_embedding
        """
    results, _, _ = await driver.execute_query(
        query,
        node_uuids=[node.uuid for node in nodes],
        routing_='r',
    )

    embeddings_dict: dict[str, list[float]] = {}
    for result in results:
        uuid: str | None = result.get('uuid')
        embedding: list[float] | None = result.get('name_embedding')
        if uuid is not None and embedding is not None:
            embeddings_dict[uuid] = embedding

    return embeddings_dict
1914 | 
1915 | 
async def get_embeddings_for_communities(
    driver: GraphDriver, communities: list[CommunityNode]
) -> dict[str, list[float]]:
    """Load the ``name_embedding`` of each given community node, keyed by uuid.

    Communities with no stored embedding are omitted from the result.

    Args:
        driver: Graph driver used to run the query.
        communities: Community nodes whose embeddings should be loaded.

    Returns:
        Mapping of community uuid to its name embedding as a list of floats.
    """
    if driver.provider == GraphProvider.NEPTUNE:
        # The embedding is stored as a comma-separated string here; convert each
        # component to a float so callers get list[float] like the other providers.
        query = """
        MATCH (c:Community)
        WHERE c.uuid IN $community_uuids
        RETURN DISTINCT
            c.uuid AS uuid,
            [x IN split(c.name_embedding, ",") | toFloat(x)] AS name_embedding
        """
    else:
        query = """
        MATCH (c:Community)
        WHERE c.uuid IN $community_uuids
        RETURN DISTINCT
            c.uuid AS uuid,
            c.name_embedding AS name_embedding
        """
    results, _, _ = await driver.execute_query(
        query,
        community_uuids=[community.uuid for community in communities],
        routing_='r',
    )

    embeddings_dict: dict[str, list[float]] = {}
    for result in results:
        uuid: str | None = result.get('uuid')
        embedding: list[float] | None = result.get('name_embedding')
        if uuid is not None and embedding is not None:
            embeddings_dict[uuid] = embedding

    return embeddings_dict
1949 | 
1950 | 
async def get_embeddings_for_edges(
    driver: GraphDriver, edges: list[EntityEdge]
) -> dict[str, list[float]]:
    """Load the ``fact_embedding`` of each given entity edge, keyed by edge uuid.

    Delegates to ``driver.graph_operations_interface`` when one is configured. Edges
    with no stored embedding are omitted from the result.

    Args:
        driver: Graph driver used to run the query.
        edges: Entity edges whose embeddings should be loaded.

    Returns:
        Mapping of edge uuid to its fact embedding as a list of floats.
    """
    if driver.graph_operations_interface:
        return await driver.graph_operations_interface.edge_load_embeddings_bulk(driver, edges)
    elif driver.provider == GraphProvider.NEPTUNE:
        # The embedding is stored as a comma-separated string here; convert each
        # component to a float so callers get list[float] like the other providers.
        query = """
        MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
        WHERE e.uuid IN $edge_uuids
        RETURN DISTINCT
            e.uuid AS uuid,
            [x IN split(e.fact_embedding, ",") | toFloat(x)] AS fact_embedding
        """
    else:
        match_query = """
            MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
        """
        if driver.provider == GraphProvider.KUZU:
            # Kuzu routes each edge through an intermediate RelatesToNode_.
            match_query = """
                MATCH (n:Entity)-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m:Entity)
            """

        query = (
            match_query
            + """
        WHERE e.uuid IN $edge_uuids
        RETURN DISTINCT
            e.uuid AS uuid,
            e.fact_embedding AS fact_embedding
        """
        )
    results, _, _ = await driver.execute_query(
        query,
        edge_uuids=[edge.uuid for edge in edges],
        routing_='r',
    )

    embeddings_dict: dict[str, list[float]] = {}
    for result in results:
        uuid: str | None = result.get('uuid')
        embedding: list[float] | None = result.get('fact_embedding')
        if uuid is not None and embedding is not None:
            embeddings_dict[uuid] = embedding

    return embeddings_dict
1996 | 
```
Page 11/12FirstPrevNextLast