This is page 11 of 12. Use http://codebase.md/getzep/graphiti?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── ISSUE_TEMPLATE
│ │ └── bug_report.md
│ ├── pull_request_template.md
│ ├── secret_scanning.yml
│ └── workflows
│ ├── ai-moderator.yml
│ ├── cla.yml
│ ├── claude-code-review-manual.yml
│ ├── claude-code-review.yml
│ ├── claude.yml
│ ├── codeql.yml
│ ├── daily_issue_maintenance.yml
│ ├── issue-triage.yml
│ ├── lint.yml
│ ├── release-graphiti-core.yml
│ ├── release-mcp-server.yml
│ ├── release-server-container.yml
│ ├── typecheck.yml
│ └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│ ├── azure-openai
│ │ ├── .env.example
│ │ ├── azure_openai_neo4j.py
│ │ └── README.md
│ ├── data
│ │ └── manybirds_products.json
│ ├── ecommerce
│ │ ├── runner.ipynb
│ │ └── runner.py
│ ├── langgraph-agent
│ │ ├── agent.ipynb
│ │ └── tinybirds-jess.png
│ ├── opentelemetry
│ │ ├── .env.example
│ │ ├── otel_stdout_example.py
│ │ ├── pyproject.toml
│ │ ├── README.md
│ │ └── uv.lock
│ ├── podcast
│ │ ├── podcast_runner.py
│ │ ├── podcast_transcript.txt
│ │ └── transcript_parser.py
│ ├── quickstart
│ │ ├── quickstart_falkordb.py
│ │ ├── quickstart_neo4j.py
│ │ ├── quickstart_neptune.py
│ │ ├── README.md
│ │ └── requirements.txt
│ └── wizard_of_oz
│ ├── parser.py
│ ├── runner.py
│ └── woo.txt
├── graphiti_core
│ ├── __init__.py
│ ├── cross_encoder
│ │ ├── __init__.py
│ │ ├── bge_reranker_client.py
│ │ ├── client.py
│ │ ├── gemini_reranker_client.py
│ │ └── openai_reranker_client.py
│ ├── decorators.py
│ ├── driver
│ │ ├── __init__.py
│ │ ├── driver.py
│ │ ├── falkordb_driver.py
│ │ ├── graph_operations
│ │ │ └── graph_operations.py
│ │ ├── kuzu_driver.py
│ │ ├── neo4j_driver.py
│ │ ├── neptune_driver.py
│ │ └── search_interface
│ │ └── search_interface.py
│ ├── edges.py
│ ├── embedder
│ │ ├── __init__.py
│ │ ├── azure_openai.py
│ │ ├── client.py
│ │ ├── gemini.py
│ │ ├── openai.py
│ │ └── voyage.py
│ ├── errors.py
│ ├── graph_queries.py
│ ├── graphiti_types.py
│ ├── graphiti.py
│ ├── helpers.py
│ ├── llm_client
│ │ ├── __init__.py
│ │ ├── anthropic_client.py
│ │ ├── azure_openai_client.py
│ │ ├── client.py
│ │ ├── config.py
│ │ ├── errors.py
│ │ ├── gemini_client.py
│ │ ├── groq_client.py
│ │ ├── openai_base_client.py
│ │ ├── openai_client.py
│ │ ├── openai_generic_client.py
│ │ └── utils.py
│ ├── migrations
│ │ └── __init__.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── edges
│ │ │ ├── __init__.py
│ │ │ └── edge_db_queries.py
│ │ └── nodes
│ │ ├── __init__.py
│ │ └── node_db_queries.py
│ ├── nodes.py
│ ├── prompts
│ │ ├── __init__.py
│ │ ├── dedupe_edges.py
│ │ ├── dedupe_nodes.py
│ │ ├── eval.py
│ │ ├── extract_edge_dates.py
│ │ ├── extract_edges.py
│ │ ├── extract_nodes.py
│ │ ├── invalidate_edges.py
│ │ ├── lib.py
│ │ ├── models.py
│ │ ├── prompt_helpers.py
│ │ ├── snippets.py
│ │ └── summarize_nodes.py
│ ├── py.typed
│ ├── search
│ │ ├── __init__.py
│ │ ├── search_config_recipes.py
│ │ ├── search_config.py
│ │ ├── search_filters.py
│ │ ├── search_helpers.py
│ │ ├── search_utils.py
│ │ └── search.py
│ ├── telemetry
│ │ ├── __init__.py
│ │ └── telemetry.py
│ ├── tracer.py
│ └── utils
│ ├── __init__.py
│ ├── bulk_utils.py
│ ├── datetime_utils.py
│ ├── maintenance
│ │ ├── __init__.py
│ │ ├── community_operations.py
│ │ ├── dedup_helpers.py
│ │ ├── edge_operations.py
│ │ ├── graph_data_operations.py
│ │ ├── node_operations.py
│ │ └── temporal_operations.py
│ ├── ontology_utils
│ │ └── entity_types_utils.py
│ └── text_utils.py
├── images
│ ├── arxiv-screenshot.png
│ ├── graphiti-graph-intro.gif
│ ├── graphiti-intro-slides-stock-2.gif
│ └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│ ├── .env.example
│ ├── .python-version
│ ├── config
│ │ ├── config-docker-falkordb-combined.yaml
│ │ ├── config-docker-falkordb.yaml
│ │ ├── config-docker-neo4j.yaml
│ │ ├── config.yaml
│ │ └── mcp_config_stdio_example.json
│ ├── docker
│ │ ├── build-standalone.sh
│ │ ├── build-with-version.sh
│ │ ├── docker-compose-falkordb.yml
│ │ ├── docker-compose-neo4j.yml
│ │ ├── docker-compose.yml
│ │ ├── Dockerfile
│ │ ├── Dockerfile.standalone
│ │ ├── github-actions-example.yml
│ │ ├── README-falkordb-combined.md
│ │ └── README.md
│ ├── docs
│ │ └── cursor_rules.md
│ ├── main.py
│ ├── pyproject.toml
│ ├── pytest.ini
│ ├── README.md
│ ├── src
│ │ ├── __init__.py
│ │ ├── config
│ │ │ ├── __init__.py
│ │ │ └── schema.py
│ │ ├── graphiti_mcp_server.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ ├── entity_types.py
│ │ │ └── response_types.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── factories.py
│ │ │ └── queue_service.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── formatting.py
│ │ └── utils.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── pytest.ini
│ │ ├── README.md
│ │ ├── run_tests.py
│ │ ├── test_async_operations.py
│ │ ├── test_comprehensive_integration.py
│ │ ├── test_configuration.py
│ │ ├── test_falkordb_integration.py
│ │ ├── test_fixtures.py
│ │ ├── test_http_integration.py
│ │ ├── test_integration.py
│ │ ├── test_mcp_integration.py
│ │ ├── test_mcp_transports.py
│ │ ├── test_stdio_simple.py
│ │ └── test_stress_load.py
│ └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│ ├── .env.example
│ ├── graph_service
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ ├── main.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ └── zep_graphiti.py
│ ├── Makefile
│ ├── pyproject.toml
│ ├── README.md
│ └── uv.lock
├── signatures
│ └── version1
│ └── cla.json
├── tests
│ ├── cross_encoder
│ │ ├── test_bge_reranker_client_int.py
│ │ └── test_gemini_reranker_client.py
│ ├── driver
│ │ ├── __init__.py
│ │ └── test_falkordb_driver.py
│ ├── embedder
│ │ ├── embedder_fixtures.py
│ │ ├── test_gemini.py
│ │ ├── test_openai.py
│ │ └── test_voyage.py
│ ├── evals
│ │ ├── data
│ │ │ └── longmemeval_data
│ │ │ ├── longmemeval_oracle.json
│ │ │ └── README.md
│ │ ├── eval_cli.py
│ │ ├── eval_e2e_graph_building.py
│ │ ├── pytest.ini
│ │ └── utils.py
│ ├── helpers_test.py
│ ├── llm_client
│ │ ├── test_anthropic_client_int.py
│ │ ├── test_anthropic_client.py
│ │ ├── test_azure_openai_client.py
│ │ ├── test_client.py
│ │ ├── test_errors.py
│ │ └── test_gemini_client.py
│ ├── test_edge_int.py
│ ├── test_entity_exclusion_int.py
│ ├── test_graphiti_int.py
│ ├── test_graphiti_mock.py
│ ├── test_node_int.py
│ ├── test_text_utils.py
│ └── utils
│ ├── maintenance
│ │ ├── test_bulk_utils.py
│ │ ├── test_edge_operations.py
│ │ ├── test_node_operations.py
│ │ └── test_temporal_operations_int.py
│ └── search
│ └── search_utils_test.py
├── uv.lock
└── Zep-CLA.md
```
# Files
--------------------------------------------------------------------------------
/graphiti_core/search/search_utils.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import logging
18 | from collections import defaultdict
19 | from time import time
20 | from typing import Any
21 |
22 | import numpy as np
23 | from numpy._typing import NDArray
24 | from typing_extensions import LiteralString
25 |
26 | from graphiti_core.driver.driver import (
27 | GraphDriver,
28 | GraphProvider,
29 | )
30 | from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
31 | from graphiti_core.graph_queries import (
32 | get_nodes_query,
33 | get_relationships_query,
34 | get_vector_cosine_func_query,
35 | )
36 | from graphiti_core.helpers import (
37 | lucene_sanitize,
38 | normalize_l2,
39 | semaphore_gather,
40 | )
41 | from graphiti_core.models.edges.edge_db_queries import get_entity_edge_return_query
42 | from graphiti_core.models.nodes.node_db_queries import (
43 | COMMUNITY_NODE_RETURN,
44 | EPISODIC_NODE_RETURN,
45 | get_entity_node_return_query,
46 | )
47 | from graphiti_core.nodes import (
48 | CommunityNode,
49 | EntityNode,
50 | EpisodicNode,
51 | get_community_node_from_record,
52 | get_entity_node_from_record,
53 | get_episodic_node_from_record,
54 | )
55 | from graphiti_core.search.search_filters import (
56 | SearchFilters,
57 | edge_search_filter_query_constructor,
58 | node_search_filter_query_constructor,
59 | )
60 |
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Default `limit` used by the search helpers in this module.
RELEVANT_SCHEMA_LIMIT = 10
# Default `min_score` threshold used by the cosine-similarity searches in this module.
DEFAULT_MIN_SCORE = 0.6
# NOTE(review): not referenced in the visible part of this module; presumably the
# default lambda for an MMR reranker defined further down — confirm.
DEFAULT_MMR_LAMBDA = 0.5
# NOTE(review): not referenced in the visible part of this module; presumably a
# default/maximum BFS depth — confirm against callers.
MAX_SEARCH_DEPTH = 3
# Size bound used by fulltext_query: queries whose space-separated term count
# reaches this bound are replaced by an empty query.
MAX_QUERY_LENGTH = 128
68 |
69 |
def calculate_cosine_similarity(vector1: list[float], vector2: list[float]) -> float:
    """
    Calculates the cosine similarity between two vectors using NumPy.

    The result is always a builtin ``float`` (never an ``int`` or a NumPy
    scalar), matching the declared return type.

    Returns:
        ``dot(v1, v2) / (|v1| * |v2|)``, or ``0.0`` when either vector has
        zero magnitude (the similarity is undefined in that case).
    """
    dot_product = np.dot(vector1, vector2)
    norm_vector1 = np.linalg.norm(vector1)
    norm_vector2 = np.linalg.norm(vector2)

    if norm_vector1 == 0 or norm_vector2 == 0:
        # Handle cases where one or both vectors are zero vectors
        return 0.0

    # Convert the NumPy scalar so callers get a plain Python float.
    return float(dot_product / (norm_vector1 * norm_vector2))
82 |
83 |
def fulltext_query(query: str, group_ids: list[str] | None, driver: GraphDriver):
    """
    Build the provider-specific fulltext query string for ``query``.

    - Kuzu: the raw query is returned unchanged, or ``''`` when it has more
      than ``MAX_QUERY_LENGTH`` space-separated terms.
    - FalkorDB: delegated to ``driver.build_fulltext_query``.
    - Otherwise: the sanitized query wrapped in parentheses, prefixed with an
      OR-joined list of ``group_id`` clauses followed by ``AND`` when
      ``group_ids`` is non-empty.

    An empty string means "no query should be run".
    """
    provider = driver.provider

    if provider == GraphProvider.KUZU:
        # Kuzu only supports simple queries.
        return '' if len(query.split(' ')) > MAX_QUERY_LENGTH else query

    if provider == GraphProvider.FALKORDB:
        return driver.build_fulltext_query(query, group_ids, MAX_QUERY_LENGTH)

    group_clauses = (
        []
        if group_ids is None
        else [driver.fulltext_syntax + f'group_id:"{g}"' for g in group_ids]
    )
    group_prefix = ' OR '.join(group_clauses)
    if group_prefix:
        group_prefix += ' AND '

    sanitized = lucene_sanitize(query)
    # If the lucene query is too long return no query
    term_count = len(sanitized.split(' ')) + len(group_ids or '')
    if term_count >= MAX_QUERY_LENGTH:
        return ''

    return group_prefix + '(' + sanitized + ')'
111 |
112 |
async def get_episodes_by_mentions(
    driver: GraphDriver,
    nodes: list[EntityNode],
    edges: list[EntityEdge],
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EpisodicNode]:
    """
    Load the episodes referenced by ``edges``.

    The episode uuids of every edge are concatenated in edge order and only
    the first ``limit`` of them are fetched. ``nodes`` is accepted but not
    used by this function.
    """
    referenced_uuids: list[str] = [
        episode_uuid for edge in edges for episode_uuid in edge.episodes
    ]

    return await EpisodicNode.get_by_uuids(driver, referenced_uuids[:limit])
126 |
127 |
async def get_mentioned_nodes(
    driver: GraphDriver, episodes: list[EpisodicNode]
) -> list[EntityNode]:
    """
    Return the distinct Entity nodes reached by a MENTIONS relationship from
    any of the given episodes.
    """
    uuids = [episode.uuid for episode in episodes]

    cypher = (
        """
        MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity)
        WHERE episode.uuid IN $uuids
        RETURN DISTINCT
        """
        + get_entity_node_return_query(driver.provider)
    )

    rows, _, _ = await driver.execute_query(cypher, uuids=uuids, routing_='r')

    return [get_entity_node_from_record(row, driver.provider) for row in rows]
147 |
148 |
async def get_communities_by_nodes(
    driver: GraphDriver, nodes: list[EntityNode]
) -> list[CommunityNode]:
    """
    Return the distinct Community nodes that have any of ``nodes`` as a
    member (via a HAS_MEMBER relationship).
    """
    member_uuids = [member.uuid for member in nodes]

    cypher = (
        """
        MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)
        WHERE m.uuid IN $uuids
        RETURN DISTINCT
        """
        + COMMUNITY_NODE_RETURN
    )

    rows, _, _ = await driver.execute_query(cypher, uuids=member_uuids, routing_='r')

    return [get_community_node_from_record(row) for row in rows]
168 |
169 |
async def edge_fulltext_search(
    driver: GraphDriver,
    query: str,
    search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
    """
    Fulltext search over entity edges (index ``edge_name_and_fact``).

    Delegates to ``driver.search_interface`` when the driver provides one.
    Otherwise the query text is turned into a provider-specific fulltext
    query, combined with the predicates produced from ``search_filter`` and
    an optional ``group_ids`` restriction, and the best-scoring edges are
    returned.

    Args:
        driver: Graph driver used to run the queries.
        query: Raw user query text.
        search_filter: Additional edge filters.
        group_ids: When not ``None``, only edges whose ``group_id`` is in
            this list are returned.
        limit: Maximum number of edges to return.

    Returns:
        Matching edges ordered by descending score; ``[]`` when the fulltext
        query is empty or (Neptune) the text search has no hits.
    """
    if driver.search_interface:
        return await driver.search_interface.edge_fulltext_search(
            driver, query, search_filter, group_ids, limit
        )

    # fulltext search over facts
    fuzzy_query = fulltext_query(query, group_ids, driver)

    if fuzzy_query == '':
        return []

    match_query = """
    YIELD relationship AS rel, score
    MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
    """
    if driver.provider == GraphProvider.KUZU:
        match_query = """
        YIELD node, score
        MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: node.uuid})-[:RELATES_TO]->(m:Entity)
        """

    filter_queries, filter_params = edge_search_filter_query_constructor(
        search_filter, driver.provider
    )

    if group_ids is not None:
        filter_queries.append('e.group_id IN $group_ids')
        filter_params['group_ids'] = group_ids

    filter_query = ''
    if filter_queries:
        filter_query = ' WHERE ' + (' AND '.join(filter_queries))

    if driver.provider == GraphProvider.NEPTUNE:
        # `limit` is forwarded for consistency with the node and episode fulltext searches.
        res = driver.run_aoss_query('edge_name_and_fact', query, limit=limit)  # pyright: ignore reportAttributeAccessIssue
        if res['hits']['total']['value'] <= 0:
            return []

        input_ids = [
            {'id': hit['_source']['uuid'], 'score': hit['_score']}
            for hit in res['hits']['hits']
        ]

        # A single WHERE clause: the hit ids carry the edge uuid, so match on
        # e.uuid, then AND in whatever filters apply. The group restriction
        # is already part of filter_queries when group_ids is given, so
        # $group_ids is only referenced when it is actually bound.
        neptune_conditions = ['e.uuid = i.id', *filter_queries]

        # Match the edge ids and return the values
        query = (
            """
            UNWIND $ids AS i
            MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
            WHERE """
            + ' AND '.join(neptune_conditions)
            + """
            WITH e, i.score AS score, n, m
            RETURN
                e.uuid AS uuid,
                e.group_id AS group_id,
                n.uuid AS source_node_uuid,
                m.uuid AS target_node_uuid,
                e.created_at AS created_at,
                e.name AS name,
                e.fact AS fact,
                split(e.episodes, ",") AS episodes,
                e.expired_at AS expired_at,
                e.valid_at AS valid_at,
                e.invalid_at AS invalid_at,
                properties(e) AS attributes
            ORDER BY score DESC LIMIT $limit
            """
        )

        records, _, _ = await driver.execute_query(
            query,
            query=fuzzy_query,
            ids=input_ids,
            limit=limit,
            routing_='r',
            **filter_params,
        )
    else:
        query = (
            get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider)
            + match_query
            + filter_query
            + """
            WITH e, score, n, m
            RETURN
            """
            + get_entity_edge_return_query(driver.provider)
            + """
            ORDER BY score DESC
            LIMIT $limit
            """
        )

        records, _, _ = await driver.execute_query(
            query,
            query=fuzzy_query,
            limit=limit,
            routing_='r',
            **filter_params,
        )

    edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

    return edges
283 |
284 |
async def edge_similarity_search(
    driver: GraphDriver,
    search_vector: list[float],
    source_node_uuid: str | None,
    target_node_uuid: str | None,
    search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
    min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityEdge]:
    """
    Find entity edges whose ``fact_embedding`` is cosine-similar to
    ``search_vector``.

    Delegates to ``driver.search_interface`` when the driver provides one.
    Candidates can be narrowed by ``search_filter``, ``group_ids`` and the
    source / target node uuid. Only edges scoring strictly above
    ``min_score`` are returned, best first, at most ``limit`` of them.

    On Neptune the similarity is computed in Python: candidate edge ids and
    embeddings are fetched first, scored locally, and the surviving ids are
    resolved in a second query.
    """
    if driver.search_interface:
        return await driver.search_interface.edge_similarity_search(
            driver,
            search_vector,
            source_node_uuid,
            target_node_uuid,
            search_filter,
            group_ids,
            limit,
            min_score,
        )

    conditions, query_params = edge_search_filter_query_constructor(
        search_filter, driver.provider
    )

    if group_ids is not None:
        conditions.append('e.group_id IN $group_ids')
        query_params['group_ids'] = group_ids

    if source_node_uuid is not None:
        query_params['source_uuid'] = source_node_uuid
        conditions.append('n.uuid = $source_uuid')

    if target_node_uuid is not None:
        query_params['target_uuid'] = target_node_uuid
        conditions.append('m.uuid = $target_uuid')

    where_clause = (' WHERE ' + ' AND '.join(conditions)) if conditions else ''

    if driver.provider == GraphProvider.NEPTUNE:
        candidate_query = (
            """
            MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
            """
            + where_clause
            + """
            RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
            """
        )
        candidates, _, _ = await driver.execute_query(
            candidate_query,
            search_vector=search_vector,
            limit=limit,
            min_score=min_score,
            routing_='r',
            **query_params,
        )

        if len(candidates) == 0:
            return []

        # Calculate cosine similarity locally and keep the edge ids above the threshold
        scored_ids = []
        for candidate in candidates:
            if not candidate['embedding']:
                continue
            similarity = calculate_cosine_similarity(
                search_vector, [float(part) for part in candidate['embedding'].split(',')]
            )
            if similarity > min_score:
                scored_ids.append({'id': candidate['id'], 'score': similarity})

        # Match the edge ids and return the values
        resolve_query = """
            UNWIND $ids as i
            MATCH ()-[r]->()
            WHERE id(r) = i.id
            RETURN
                r.uuid AS uuid,
                r.group_id AS group_id,
                startNode(r).uuid AS source_node_uuid,
                endNode(r).uuid AS target_node_uuid,
                r.created_at AS created_at,
                r.name AS name,
                r.fact AS fact,
                split(r.episodes, ",") AS episodes,
                r.expired_at AS expired_at,
                r.valid_at AS valid_at,
                r.invalid_at AS invalid_at,
                properties(r) AS attributes
            ORDER BY i.score DESC
            LIMIT $limit
        """
        rows, _, _ = await driver.execute_query(
            resolve_query,
            ids=scored_ids,
            search_vector=search_vector,
            limit=limit,
            min_score=min_score,
            routing_='r',
            **query_params,
        )
        return [get_entity_edge_from_record(row, driver.provider) for row in rows]

    # Non-Neptune providers compute the similarity inside the database.
    if driver.provider == GraphProvider.KUZU:
        edge_pattern = """
        MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
        """
        # Kuzu gets the parameter cast to a fixed-size FLOAT array.
        vector_expr = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'
    else:
        edge_pattern = """
        MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
        """
        vector_expr = '$search_vector'

    similarity_query = (
        edge_pattern
        + where_clause
        + """
        WITH DISTINCT e, n, m, """
        + get_vector_cosine_func_query('e.fact_embedding', vector_expr, driver.provider)
        + """ AS score
        WHERE score > $min_score
        RETURN
        """
        + get_entity_edge_return_query(driver.provider)
        + """
        ORDER BY score DESC
        LIMIT $limit
        """
    )

    rows, _, _ = await driver.execute_query(
        similarity_query,
        search_vector=search_vector,
        limit=limit,
        min_score=min_score,
        routing_='r',
        **query_params,
    )

    return [get_entity_edge_from_record(row, driver.provider) for row in rows]
431 |
432 |
async def edge_bfs_search(
    driver: GraphDriver,
    bfs_origin_node_uuids: list[str] | None,
    bfs_max_depth: int,
    search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
    """
    Collect entity edges reachable from the origin nodes by a bounded
    breadth-first traversal over RELATES_TO / MENTIONS relationships.

    Args:
        driver: Graph driver used to run the queries.
        bfs_origin_node_uuids: uuids of the nodes the traversal starts from.
        bfs_max_depth: Maximum number of hops to traverse.
        search_filter: Additional edge filters.
        group_ids: When not ``None``, only edges whose ``group_id`` is in
            this list are returned.
        limit: ``LIMIT`` applied to each query that is run (Kuzu may run two
            queries, so it can return up to twice as many records).

    Returns:
        The edges found; ``[]`` when there are no origins or when
        ``bfs_max_depth`` is below 1 (same guard as ``node_bfs_search``; a
        depth below 1 would otherwise produce an empty or invalid
        variable-length range in the generated query).
    """
    # breadth-first search over edges starting from the origin nodes
    if not bfs_origin_node_uuids or bfs_max_depth < 1:
        return []

    filter_queries, filter_params = edge_search_filter_query_constructor(
        search_filter, driver.provider
    )

    if group_ids is not None:
        filter_queries.append('e.group_id IN $group_ids')
        filter_params['group_ids'] = group_ids

    filter_query = ''
    if filter_queries:
        filter_query = ' WHERE ' + (' AND '.join(filter_queries))

    if driver.provider == GraphProvider.KUZU:
        # Kuzu stores entity edges twice with an intermediate node, so we need to match them
        # separately for the correct BFS depth.
        depth = bfs_max_depth * 2 - 1
        match_queries = [
            f"""
            UNWIND $bfs_origin_node_uuids AS origin_uuid
            MATCH path = (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
            UNWIND nodes(path) AS relNode
            MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
            """,
        ]
        if bfs_max_depth > 1:
            # Episodic origins spend one hop on the MENTIONS relationship.
            depth = (bfs_max_depth - 1) * 2 - 1
            match_queries.append(f"""
            UNWIND $bfs_origin_node_uuids AS origin_uuid
            MATCH path = (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
            UNWIND nodes(path) AS relNode
            MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
            """)

        records = []
        for match_query in match_queries:
            sub_records, _, _ = await driver.execute_query(
                match_query
                + filter_query
                + """
                RETURN DISTINCT
                """
                + get_entity_edge_return_query(driver.provider)
                + """
                LIMIT $limit
                """,
                bfs_origin_node_uuids=bfs_origin_node_uuids,
                limit=limit,
                routing_='r',
                **filter_params,
            )
            records.extend(sub_records)
    else:
        if driver.provider == GraphProvider.NEPTUNE:
            query = (
                f"""
                UNWIND $bfs_origin_node_uuids AS origin_uuid
                MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS *1..{bfs_max_depth}]->(n:Entity)
                WHERE origin:Entity OR origin:Episodic
                UNWIND relationships(path) AS rel
                MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
                """
                + filter_query
                + """
                RETURN DISTINCT
                    e.uuid AS uuid,
                    e.group_id AS group_id,
                    startNode(e).uuid AS source_node_uuid,
                    endNode(e).uuid AS target_node_uuid,
                    e.created_at AS created_at,
                    e.name AS name,
                    e.fact AS fact,
                    split(e.episodes, ',') AS episodes,
                    e.expired_at AS expired_at,
                    e.valid_at AS valid_at,
                    e.invalid_at AS invalid_at,
                    properties(e) AS attributes
                LIMIT $limit
                """
            )
        else:
            query = (
                f"""
                UNWIND $bfs_origin_node_uuids AS origin_uuid
                MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
                UNWIND relationships(path) AS rel
                MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
                """
                + filter_query
                + """
                RETURN DISTINCT
                """
                + get_entity_edge_return_query(driver.provider)
                + """
                LIMIT $limit
                """
            )

        # `depth` is still passed as a parameter although the query text above
        # interpolates the depth directly and does not reference $depth.
        records, _, _ = await driver.execute_query(
            query,
            bfs_origin_node_uuids=bfs_origin_node_uuids,
            depth=bfs_max_depth,
            limit=limit,
            routing_='r',
            **filter_params,
        )

    edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

    return edges
554 |
555 |
async def node_fulltext_search(
    driver: GraphDriver,
    query: str,
    search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
    """
    Fulltext search over entity nodes (index ``node_name_and_summary``).

    Delegates to ``driver.search_interface`` when the driver provides one.
    Otherwise the query text is turned into a provider-specific fulltext
    query, and the predicates produced from ``search_filter`` plus the
    optional ``group_ids`` restriction are applied to the matched nodes on
    every provider, including Neptune.

    Args:
        driver: Graph driver used to run the queries.
        query: Raw user query text.
        search_filter: Additional node filters.
        group_ids: When not ``None``, only nodes whose ``group_id`` is in
            this list are returned.
        limit: Maximum number of nodes to return.

    Returns:
        Matching nodes ordered by descending score; ``[]`` when the fulltext
        query is empty or (Neptune) the text search has no hits.
    """
    if driver.search_interface:
        return await driver.search_interface.node_fulltext_search(
            driver, query, search_filter, group_ids, limit
        )

    # BM25 search to get top nodes
    fuzzy_query = fulltext_query(query, group_ids, driver)
    if fuzzy_query == '':
        return []

    filter_queries, filter_params = node_search_filter_query_constructor(
        search_filter, driver.provider
    )

    if group_ids is not None:
        filter_queries.append('n.group_id IN $group_ids')
        filter_params['group_ids'] = group_ids

    filter_query = ''
    if filter_queries:
        filter_query = ' WHERE ' + (' AND '.join(filter_queries))

    yield_query = 'YIELD node AS n, score'
    if driver.provider == GraphProvider.KUZU:
        yield_query = 'WITH node AS n, score'

    if driver.provider == GraphProvider.NEPTUNE:
        res = driver.run_aoss_query('node_name_and_summary', query, limit=limit)  # pyright: ignore reportAttributeAccessIssue
        if res['hits']['total']['value'] <= 0:
            return []

        input_ids = [
            {'id': hit['_source']['uuid'], 'score': hit['_score']}
            for hit in res['hits']['hits']
        ]

        # The text search call above is not given the filters, so they are
        # applied here when resolving the hits; otherwise search_filter and
        # the group_ids restriction would be silently ignored on Neptune.
        neptune_conditions = ['n.uuid=i.id', *filter_queries]

        # Match the node ids and return the values
        query = (
            """
            UNWIND $ids as i
            MATCH (n:Entity)
            WHERE """
            + ' AND '.join(neptune_conditions)
            + """
            RETURN
            """
            + get_entity_node_return_query(driver.provider)
            + """
            ORDER BY i.score DESC
            LIMIT $limit
            """
        )
        records, _, _ = await driver.execute_query(
            query,
            ids=input_ids,
            query=fuzzy_query,
            limit=limit,
            routing_='r',
            **filter_params,
        )
    else:
        query = (
            get_nodes_query(
                'node_name_and_summary', '$query', limit=limit, provider=driver.provider
            )
            + yield_query
            + filter_query
            + """
            WITH n, score
            ORDER BY score DESC
            LIMIT $limit
            RETURN
            """
            + get_entity_node_return_query(driver.provider)
        )

        records, _, _ = await driver.execute_query(
            query,
            query=fuzzy_query,
            limit=limit,
            routing_='r',
            **filter_params,
        )

    nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

    return nodes
647 |
648 |
async def node_similarity_search(
    driver: GraphDriver,
    search_vector: list[float],
    search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
    min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityNode]:
    """
    Find entity nodes whose ``name_embedding`` is cosine-similar to
    ``search_vector``.

    Delegates to ``driver.search_interface`` when the driver provides one.
    Candidates can be narrowed by ``search_filter`` and ``group_ids``. Only
    nodes scoring strictly above ``min_score`` are returned, best first, at
    most ``limit`` of them.

    On Neptune the similarity is computed in Python: candidate node ids and
    embeddings are fetched first, scored locally, and the surviving ids are
    resolved in a second query.
    """
    if driver.search_interface:
        return await driver.search_interface.node_similarity_search(
            driver, search_vector, search_filter, group_ids, limit, min_score
        )

    filter_queries, filter_params = node_search_filter_query_constructor(
        search_filter, driver.provider
    )

    if group_ids is not None:
        filter_queries.append('n.group_id IN $group_ids')
        filter_params['group_ids'] = group_ids

    filter_query = ''
    if filter_queries:
        filter_query = ' WHERE ' + (' AND '.join(filter_queries))

    search_vector_var = '$search_vector'
    if driver.provider == GraphProvider.KUZU:
        search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'

    if driver.provider == GraphProvider.NEPTUNE:
        query = (
            """
            MATCH (n:Entity)
            """
            + filter_query
            + """
            RETURN DISTINCT id(n) as id, n.name_embedding as embedding
            """
        )
        # The filter parameters are spread as keyword arguments, like every
        # other execute_query call in this module, so the $-placeholders in
        # filter_query are bound as top-level query parameters.
        resp, _, _ = await driver.execute_query(
            query,
            search_vector=search_vector,
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )

        if len(resp) > 0:
            # Calculate cosine similarity then keep the node ids above the threshold
            input_ids = []
            for r in resp:
                if r['embedding']:
                    score = calculate_cosine_similarity(
                        search_vector, list(map(float, r['embedding'].split(',')))
                    )
                    if score > min_score:
                        input_ids.append({'id': r['id'], 'score': score})

            # Match the node ids and return the values
            query = (
                """
                UNWIND $ids as i
                MATCH (n:Entity)
                WHERE id(n)=i.id
                RETURN
                """
                + get_entity_node_return_query(driver.provider)
                + """
                ORDER BY i.score DESC
                LIMIT $limit
                """
            )
            records, _, _ = await driver.execute_query(
                query,
                ids=input_ids,
                search_vector=search_vector,
                limit=limit,
                min_score=min_score,
                routing_='r',
                **filter_params,
            )
        else:
            return []
    else:
        query = (
            """
            MATCH (n:Entity)
            """
            + filter_query
            + """
            WITH n, """
            + get_vector_cosine_func_query('n.name_embedding', search_vector_var, driver.provider)
            + """ AS score
            WHERE score > $min_score
            RETURN
            """
            + get_entity_node_return_query(driver.provider)
            + """
            ORDER BY score DESC
            LIMIT $limit
            """
        )

        records, _, _ = await driver.execute_query(
            query,
            search_vector=search_vector,
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )

    nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

    return nodes
765 |
766 |
async def node_bfs_search(
    driver: GraphDriver,
    bfs_origin_node_uuids: list[str] | None,
    search_filter: SearchFilters,
    bfs_max_depth: int,
    group_ids: list[str] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
    """
    Collect entity nodes reachable from the origin nodes by a bounded
    breadth-first traversal over RELATES_TO / MENTIONS relationships.

    Only nodes in the same group as their origin are returned. When
    ``group_ids`` is given, both the origin and the found node must belong
    to one of those groups.

    Args:
        driver: Graph driver used to run the queries.
        bfs_origin_node_uuids: uuids of the nodes the traversal starts from.
        search_filter: Additional node filters.
        bfs_max_depth: Maximum number of hops to traverse.
        group_ids: Optional group restriction (see above).
        limit: ``LIMIT`` applied to each query that is run (Kuzu runs
            several queries, so it can return more than ``limit`` records).

    Returns:
        The nodes found; ``[]`` when there are no origins or
        ``bfs_max_depth`` is below 1.
    """
    if bfs_origin_node_uuids is None or len(bfs_origin_node_uuids) == 0 or bfs_max_depth < 1:
        return []

    filter_queries, filter_params = node_search_filter_query_constructor(
        search_filter, driver.provider
    )

    if group_ids is not None:
        filter_queries.append('n.group_id IN $group_ids')
        filter_queries.append('origin.group_id IN $group_ids')
        filter_params['group_ids'] = group_ids

    # Every match query below already ends in a WHERE clause, so the extra
    # predicates are appended with AND.
    filter_query = ''
    if filter_queries:
        filter_query = ' AND ' + (' AND '.join(filter_queries))

    match_queries = [
        f"""
        UNWIND $bfs_origin_node_uuids AS origin_uuid
        MATCH (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
        WHERE n.group_id = origin.group_id
        """
    ]

    if driver.provider == GraphProvider.NEPTUNE:
        # The label test is parenthesised so that the group predicate and the
        # appended filters apply to both Entity and Episodic origins (AND
        # binds tighter than OR), and Episodic is tested as a label, matching
        # the Neptune query in edge_bfs_search.
        match_queries = [
            f"""
            UNWIND $bfs_origin_node_uuids AS origin_uuid
            MATCH (origin {{uuid: origin_uuid}})-[e:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
            WHERE (origin:Entity OR origin:Episodic)
            AND n.group_id = origin.group_id
            """
        ]

    if driver.provider == GraphProvider.KUZU:
        # The RELATES_TO ranges are doubled for Kuzu (2 .. 2 * depth).
        depth = bfs_max_depth * 2
        match_queries = [
            """
            UNWIND $bfs_origin_node_uuids AS origin_uuid
            MATCH (origin:Episodic {uuid: origin_uuid})-[:MENTIONS]->(n:Entity)
            WHERE n.group_id = origin.group_id
            """,
            f"""
            UNWIND $bfs_origin_node_uuids AS origin_uuid
            MATCH (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*2..{depth}]->(n:Entity)
            WHERE n.group_id = origin.group_id
            """,
        ]
        if bfs_max_depth > 1:
            # Episodic origins spend one hop on the MENTIONS relationship.
            depth = (bfs_max_depth - 1) * 2
            match_queries.append(f"""
            UNWIND $bfs_origin_node_uuids AS origin_uuid
            MATCH (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*2..{depth}]->(n:Entity)
            WHERE n.group_id = origin.group_id
            """)

    records = []
    for match_query in match_queries:
        sub_records, _, _ = await driver.execute_query(
            match_query
            + filter_query
            + """
            RETURN
            """
            + get_entity_node_return_query(driver.provider)
            + """
            LIMIT $limit
            """,
            bfs_origin_node_uuids=bfs_origin_node_uuids,
            limit=limit,
            routing_='r',
            **filter_params,
        )
        records.extend(sub_records)

    nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

    return nodes
853 |
854 |
async def episode_fulltext_search(
    driver: GraphDriver,
    query: str,
    _search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EpisodicNode]:
    """
    Run a fulltext (BM25) search over episode content.

    Parameters
    ----------
    driver : GraphDriver
        Graph driver used to execute the query. When ``driver.search_interface``
        is set, the whole search is delegated to it.
    query : str
        Raw text to search for.
    _search_filter : SearchFilters
        Only forwarded to ``driver.search_interface``; the built-in query paths
        do not use it.
    group_ids : list[str] | None, optional
        When given, only episodes whose ``group_id`` is in this list are returned.
    limit : int, optional
        Maximum number of episodes to return.

    Returns
    -------
    list[EpisodicNode]
        Matching episodes ordered by descending fulltext score. Empty when the
        constructed fulltext query is empty or (on Neptune) the index has no hits.
    """
    if driver.search_interface:
        return await driver.search_interface.episode_fulltext_search(
            driver, query, _search_filter, group_ids, limit
        )

    # BM25 search to get top episodes
    fuzzy_query = fulltext_query(query, group_ids, driver)
    if fuzzy_query == '':
        return []

    filter_params: dict[str, Any] = {}
    group_filter_query: LiteralString = ''
    if group_ids is not None:
        group_filter_query += '\nAND e.group_id IN $group_ids'
        filter_params['group_ids'] = group_ids

    if driver.provider == GraphProvider.NEPTUNE:
        res = driver.run_aoss_query('episode_content', query, limit=limit)  # pyright: ignore reportAttributeAccessIssue
        if res['hits']['total']['value'] <= 0:
            return []

        # Each hit becomes a {id, score} map that the graph query joins against.
        input_ids = [
            {'id': hit['_source']['uuid'], 'score': hit['_score']} for hit in res['hits']['hits']
        ]

        # Match the episode uuids and return the values.
        # The maps above expose the uuid under the key `id`, so the join must use
        # `i.id` (joining on `i.uuid` compares against null and matches nothing).
        # The group filter is applied here as well because the index lookup above
        # is issued with the raw query text and no group restriction.
        query = (
            """
            UNWIND $ids as i
            MATCH (e:Episodic)
            WHERE e.uuid = i.id
            """
            + group_filter_query
            + """
            RETURN
                e.content AS content,
                e.created_at AS created_at,
                e.valid_at AS valid_at,
                e.uuid AS uuid,
                e.name AS name,
                e.group_id AS group_id,
                e.source_description AS source_description,
                e.source AS source,
                e.entity_edges AS entity_edges
            ORDER BY i.score DESC
            LIMIT $limit
            """
        )
        records, _, _ = await driver.execute_query(
            query,
            ids=input_ids,
            query=fuzzy_query,
            limit=limit,
            routing_='r',
            **filter_params,
        )
    else:
        query = (
            get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider)
            + """
            YIELD node AS episode, score
            MATCH (e:Episodic)
            WHERE e.uuid = episode.uuid
            """
            + group_filter_query
            + """
            RETURN
            """
            + EPISODIC_NODE_RETURN
            + """
            ORDER BY score DESC
            LIMIT $limit
            """
        )

        records, _, _ = await driver.execute_query(
            query, query=fuzzy_query, limit=limit, routing_='r', **filter_params
        )

    episodes = [get_episodic_node_from_record(record) for record in records]

    return episodes
939 |
940 |
async def community_fulltext_search(
    driver: GraphDriver,
    query: str,
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[CommunityNode]:
    """
    Run a fulltext (BM25) search over community names.

    Parameters
    ----------
    driver : GraphDriver
        Graph driver used to execute the query.
    query : str
        Raw text to search for.
    group_ids : list[str] | None, optional
        When given, only communities whose ``group_id`` is in this list are returned.
    limit : int, optional
        Maximum number of communities to return.

    Returns
    -------
    list[CommunityNode]
        Matching communities ordered by descending fulltext score. Empty when the
        constructed fulltext query is empty or (on Neptune) the index has no hits.
    """
    # BM25 search to get top communities
    fuzzy_query = fulltext_query(query, group_ids, driver)
    if fuzzy_query == '':
        return []

    filter_params: dict[str, Any] = {}
    group_filter_query: LiteralString = ''
    # The Neptune query binds the community as `comm`, so it needs its own clause.
    neptune_group_filter_query: LiteralString = ''
    if group_ids is not None:
        group_filter_query = 'WHERE c.group_id IN $group_ids'
        neptune_group_filter_query = '\nAND comm.group_id IN $group_ids'
        filter_params['group_ids'] = group_ids

    # Kuzu does not take a YIELD clause here; it re-projects with WITH instead.
    yield_query = 'YIELD node AS c, score'
    if driver.provider == GraphProvider.KUZU:
        yield_query = 'WITH node AS c, score'

    if driver.provider == GraphProvider.NEPTUNE:
        res = driver.run_aoss_query('community_name', query, limit=limit)  # pyright: ignore reportAttributeAccessIssue
        if res['hits']['total']['value'] <= 0:
            return []

        # Each hit becomes a {id, score} map that the graph query joins against.
        input_ids = [
            {'id': hit['_source']['uuid'], 'score': hit['_score']} for hit in res['hits']['hits']
        ]

        # Match the community uuids and return the values. The group filter is
        # applied here because the index lookup above is issued with the raw
        # query text and no group restriction.
        query = (
            """
            UNWIND $ids as i
            MATCH (comm:Community)
            WHERE comm.uuid=i.id
            """
            + neptune_group_filter_query
            + """
            RETURN
                comm.uuid AS uuid,
                comm.group_id AS group_id,
                comm.name AS name,
                comm.created_at AS created_at,
                comm.summary AS summary,
                [x IN split(comm.name_embedding, ",") | toFloat(x)]AS name_embedding
            ORDER BY i.score DESC
            LIMIT $limit
            """
        )
        records, _, _ = await driver.execute_query(
            query,
            ids=input_ids,
            query=fuzzy_query,
            limit=limit,
            routing_='r',
            **filter_params,
        )
    else:
        query = (
            get_nodes_query('community_name', '$query', limit=limit, provider=driver.provider)
            + yield_query
            + """
            WITH c, score
            """
            + group_filter_query
            + """
            RETURN
            """
            + COMMUNITY_NODE_RETURN
            + """
            ORDER BY score DESC
            LIMIT $limit
            """
        )

        records, _, _ = await driver.execute_query(
            query, query=fuzzy_query, limit=limit, routing_='r', **filter_params
        )

    communities = [get_community_node_from_record(record) for record in records]

    return communities
1020 |
1021 |
async def community_similarity_search(
    driver: GraphDriver,
    search_vector: list[float],
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
    min_score=DEFAULT_MIN_SCORE,
) -> list[CommunityNode]:
    """
    Find communities whose name embedding is similar to ``search_vector``.

    Parameters
    ----------
    driver : GraphDriver
        Graph driver used to execute the query.
    search_vector : list[float]
        Query embedding to compare against each community's ``name_embedding``.
    group_ids : list[str] | None, optional
        When given, only communities whose ``group_id`` is in this list are considered.
    limit : int, optional
        Maximum number of communities to return.
    min_score : float, optional
        Only communities with cosine similarity strictly greater than this are returned.

    Returns
    -------
    list[CommunityNode]
        Matching communities ordered by descending similarity.
    """
    # vector similarity search over community names
    query_params: dict[str, Any] = {}

    group_filter_query: LiteralString = ''
    if group_ids is not None:
        group_filter_query += ' WHERE c.group_id IN $group_ids'
        query_params['group_ids'] = group_ids

    if driver.provider == GraphProvider.NEPTUNE:
        # The community is bound as `c` so that `group_filter_query`, which
        # references `c.group_id`, resolves against this MATCH.
        query = (
            """
            MATCH (c:Community)
            """
            + group_filter_query
            + """
            RETURN DISTINCT id(c) as id, c.name_embedding as embedding
            """
        )
        resp, _, _ = await driver.execute_query(
            query,
            search_vector=search_vector,
            limit=limit,
            min_score=min_score,
            routing_='r',
            **query_params,
        )

        if len(resp) == 0:
            return []

        # Cosine similarity is computed client-side; embeddings come back as
        # comma-separated strings and are parsed to floats here.
        input_ids = []
        for r in resp:
            if r['embedding']:
                score = calculate_cosine_similarity(
                    search_vector, list(map(float, r['embedding'].split(',')))
                )
                if score > min_score:
                    input_ids.append({'id': r['id'], 'score': score})

        # Match the community ids and return the values. The stored embedding is
        # the same comma-separated string parsed above, so it is converted to a
        # float list in the query.
        query = """
            UNWIND $ids as i
            MATCH (comm:Community)
            WHERE id(comm)=i.id
            RETURN
                comm.uuid As uuid,
                comm.group_id AS group_id,
                comm.name AS name,
                comm.created_at AS created_at,
                comm.summary AS summary,
                [x IN split(comm.name_embedding, ",") | toFloat(x)] AS name_embedding
            ORDER BY i.score DESC
            LIMIT $limit
        """
        records, _, _ = await driver.execute_query(
            query,
            ids=input_ids,
            search_vector=search_vector,
            limit=limit,
            min_score=min_score,
            routing_='r',
            **query_params,
        )
    else:
        search_vector_var = '$search_vector'
        if driver.provider == GraphProvider.KUZU:
            # Kuzu needs the parameter cast to a fixed-size float array.
            search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'

        query = (
            """
            MATCH (c:Community)
            """
            + group_filter_query
            + """
            WITH c,
            """
            + get_vector_cosine_func_query('c.name_embedding', search_vector_var, driver.provider)
            + """ AS score
            WHERE score > $min_score
            RETURN
            """
            + COMMUNITY_NODE_RETURN
            + """
            ORDER BY score DESC
            LIMIT $limit
            """
        )

        records, _, _ = await driver.execute_query(
            query,
            search_vector=search_vector,
            limit=limit,
            min_score=min_score,
            routing_='r',
            **query_params,
        )

    communities = [get_community_node_from_record(record) for record in records]

    return communities
1130 |
1131 |
async def hybrid_node_search(
    queries: list[str],
    embeddings: list[list[float]],
    driver: GraphDriver,
    search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
    """
    Search for entity nodes with fulltext and vector similarity, fused with RRF.

    One fulltext search is issued per text query and one similarity search per
    embedding. All searches run through ``semaphore_gather`` and their ranked
    uuid lists are merged with reciprocal rank fusion.

    Parameters
    ----------
    queries : list[str]
        Text queries; each one produces a fulltext search.
    embeddings : list[list[float]]
        Embedding vectors; each one produces a similarity search. May be empty,
        in which case only fulltext searches run.
    driver : GraphDriver
        Graph driver used for all searches.
    search_filter : SearchFilters
        Filters forwarded to every individual search.
    group_ids : list[str] | None, optional
        Group ids forwarded to every individual search.
    limit : int, optional
        Each individual search is asked for ``2 * limit`` results.

    Returns
    -------
    list[EntityNode]
        Unique nodes (keyed by uuid) in fused rank order.
    """

    started_at = time()

    fulltext_searches = [
        node_fulltext_search(driver, text, search_filter, group_ids, 2 * limit)
        for text in queries
    ]
    similarity_searches = [
        node_similarity_search(driver, vector, search_filter, group_ids, 2 * limit)
        for vector in embeddings
    ]
    results: list[list[EntityNode]] = list(
        await semaphore_gather(*fulltext_searches, *similarity_searches)
    )

    # Index every hit by uuid (a later hit replaces an earlier one) and keep the
    # per-search uuid rankings for fusion.
    node_uuid_map: dict[str, EntityNode] = {}
    result_uuids: list[list[str]] = []
    for hits in results:
        ranking: list[str] = []
        for hit in hits:
            node_uuid_map[hit.uuid] = hit
            ranking.append(hit.uuid)
        result_uuids.append(ranking)

    ranked_uuids, _ = rrf(result_uuids)
    relevant_nodes: list[EntityNode] = [node_uuid_map[ranked] for ranked in ranked_uuids]

    elapsed_ms = (time() - started_at) * 1000
    logger.debug(f'Found relevant nodes: {ranked_uuids} in {elapsed_ms} ms')
    return relevant_nodes
1204 |
1205 |
async def get_relevant_nodes(
    driver: GraphDriver,
    nodes: list[EntityNode],
    search_filter: SearchFilters,
    min_score: float = DEFAULT_MIN_SCORE,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[list[EntityNode]]:
    """
    For each input node, find existing entity nodes that look related to it.

    A single query combines, per input node, a cosine-similarity match on
    ``name_embedding`` (score strictly above ``min_score``) with a fulltext match
    on the ``node_name_and_summary`` index, and returns the de-duplicated union.

    Parameters
    ----------
    driver : GraphDriver
        Graph driver used to execute the query.
    nodes : list[EntityNode]
        Nodes to find matches for.
    search_filter : SearchFilters
        Filters turned into WHERE clauses by ``node_search_filter_query_constructor``.
    min_score : float, optional
        Minimum (exclusive) cosine similarity for the vector match.
    limit : int, optional
        Cap on vector matches per input node; also passed to the fulltext query.

    Returns
    -------
    list[list[EntityNode]]
        One list per input node, in the same order as ``nodes``. A node with no
        result row gets an empty list. Returns ``[]`` when ``nodes`` is empty, or
        on Kuzu when the first node has no (or an empty) name embedding.
    """
    if len(nodes) == 0:
        return []

    # NOTE(review): the group of the first node is used for the whole batch, so
    # this presumably expects all nodes to share one group_id — confirm with callers.
    group_id = nodes[0].group_id
    query_nodes = [
        {
            'uuid': node.uuid,
            'name': node.name,
            'name_embedding': node.name_embedding,
            'fulltext_query': fulltext_query(node.name, [node.group_id], driver),
        }
        for node in nodes
    ]

    filter_queries, filter_params = node_search_filter_query_constructor(
        search_filter, driver.provider
    )

    filter_query = ''
    if filter_queries:
        filter_query = 'WHERE ' + (' AND '.join(filter_queries))

    if driver.provider == GraphProvider.KUZU:
        # The CAST below needs a fixed array size, taken from the first node's embedding.
        embedding_size = len(nodes[0].name_embedding) if nodes[0].name_embedding is not None else 0
        if embedding_size == 0:
            return []

        # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
        query = (
            """
            UNWIND $nodes AS node
            MATCH (n:Entity {group_id: $group_id})
            """
            + filter_query
            + """
            WITH node, n, """
            + get_vector_cosine_func_query(
                'n.name_embedding',
                f'CAST(node.name_embedding AS FLOAT[{embedding_size}])',
                driver.provider,
            )
            + """ AS score
            WHERE score > $min_score
            WITH node, collect(n)[:$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
            """
            + get_nodes_query(
                'node_name_and_summary',
                'node.fulltext_query',
                limit=limit,
                provider=driver.provider,
            )
            + """
            WITH node AS m
            WHERE m.group_id = $group_id AND NOT m.uuid IN vector_node_uuids
            WITH node, top_vector_nodes, collect(m) AS fulltext_nodes

            WITH node, list_concat(top_vector_nodes, fulltext_nodes) AS combined_nodes

            UNWIND combined_nodes AS x
            WITH node, collect(DISTINCT {
                uuid: x.uuid,
                name: x.name,
                name_embedding: x.name_embedding,
                group_id: x.group_id,
                created_at: x.created_at,
                summary: x.summary,
                labels: x.labels,
                attributes: x.attributes
            }) AS matches

            RETURN
            node.uuid AS search_node_uuid, matches
            """
        )
    else:
        # Stages of the query: vector matches above min_score, then fulltext
        # matches in the same group, then fulltext hits already found by the
        # vector stage are dropped before the two lists are concatenated.
        query = (
            """
            UNWIND $nodes AS node
            MATCH (n:Entity {group_id: $group_id})
            """
            + filter_query
            + """
            WITH node, n, """
            + get_vector_cosine_func_query(
                'n.name_embedding', 'node.name_embedding', driver.provider
            )
            + """ AS score
            WHERE score > $min_score
            WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
            """
            + get_nodes_query(
                'node_name_and_summary',
                'node.fulltext_query',
                limit=limit,
                provider=driver.provider,
            )
            + """
            YIELD node AS m
            WHERE m.group_id = $group_id
            WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes

            WITH node,
                top_vector_nodes,
                [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes

            WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes

            UNWIND combined_nodes AS combined_node
            WITH node, collect(DISTINCT combined_node) AS deduped_nodes

            RETURN
            node.uuid AS search_node_uuid,
            [x IN deduped_nodes | {
                uuid: x.uuid,
                name: x.name,
                name_embedding: x.name_embedding,
                group_id: x.group_id,
                created_at: x.created_at,
                summary: x.summary,
                labels: labels(x),
                attributes: properties(x)
            }] AS matches
            """
        )

    results, _, _ = await driver.execute_query(
        query,
        nodes=query_nodes,
        group_id=group_id,
        limit=limit,
        min_score=min_score,
        routing_='r',
        **filter_params,
    )

    # Key the matches by the uuid of the input node that produced them.
    relevant_nodes_dict: dict[str, list[EntityNode]] = {
        result['search_node_uuid']: [
            get_entity_node_from_record(record, driver.provider) for record in result['matches']
        ]
        for result in results
    }

    # Re-align with the input order; nodes without a result row map to [].
    relevant_nodes = [relevant_nodes_dict.get(node.uuid, []) for node in nodes]

    return relevant_nodes
1358 |
1359 |
async def get_relevant_edges(
    driver: GraphDriver,
    edges: list[EntityEdge],
    search_filter: SearchFilters,
    min_score: float = DEFAULT_MIN_SCORE,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[list[EntityEdge]]:
    """
    For each input edge, find existing edges between the same two nodes with a similar fact.

    Candidates are RELATES_TO edges (either direction) in the input edge's group
    that connect the input edge's source and target nodes, and whose
    ``fact_embedding`` has cosine similarity strictly above ``min_score`` with
    the input edge's ``fact_embedding``.

    Parameters
    ----------
    driver : GraphDriver
        Graph driver used to execute the query.
    edges : list[EntityEdge]
        Edges to find matches for.
    search_filter : SearchFilters
        Filters turned into WHERE clauses by ``edge_search_filter_query_constructor``.
    min_score : float, optional
        Minimum (exclusive) cosine similarity.
    limit : int, optional
        Cap on the number of matches.

    Returns
    -------
    list[list[EntityEdge]]
        One list per input edge, in the same order as ``edges``. An edge with no
        result row gets an empty list. Returns ``[]`` when ``edges`` is empty, or
        on Kuzu when the first edge has no (or an empty) fact embedding.
    """
    if len(edges) == 0:
        return []

    filter_queries, filter_params = edge_search_filter_query_constructor(
        search_filter, driver.provider
    )

    filter_query = ''
    if filter_queries:
        filter_query = ' WHERE ' + (' AND '.join(filter_queries))

    if driver.provider == GraphProvider.NEPTUNE:
        # Neptune path: fetch candidate ids and embeddings, score them
        # client-side, then load the surviving edges in a second query.
        query = (
            """
            UNWIND $edges AS edge
            MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
            """
            + filter_query
            + """
            WITH e, edge
            RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding, edge.uuid as search_edge_uuid,
            edge.fact_embedding as target_embedding
            """
        )
        resp, _, _ = await driver.execute_query(
            query,
            edges=[edge.model_dump() for edge in edges],
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )

        # Calculate Cosine similarity then return the edge ids
        input_ids = []
        for r in resp:
            # The stored embedding is a comma-separated string; the input edge's
            # embedding is passed through as-is.
            score = calculate_cosine_similarity(
                list(map(float, r['source_embedding'].split(','))), r['target_embedding']
            )
            if score > min_score:
                input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})

        # Match the edge ids and return the values
        query = """
        UNWIND $ids AS edge
        MATCH ()-[e]->()
        WHERE id(e) = edge.id
        WITH edge, e
        ORDER BY edge.score DESC
        RETURN edge.uuid AS search_edge_uuid,
            collect({
                uuid: e.uuid,
                source_node_uuid: startNode(e).uuid,
                target_node_uuid: endNode(e).uuid,
                created_at: e.created_at,
                name: e.name,
                group_id: e.group_id,
                fact: e.fact,
                fact_embedding: [x IN split(e.fact_embedding, ",") | toFloat(x)],
                episodes: split(e.episodes, ","),
                expired_at: e.expired_at,
                valid_at: e.valid_at,
                invalid_at: e.invalid_at,
                attributes: properties(e)
            })[..$limit] AS matches
        """

        results, _, _ = await driver.execute_query(
            query,
            ids=input_ids,
            edges=[edge.model_dump() for edge in edges],
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )
    else:
        if driver.provider == GraphProvider.KUZU:
            # The CAST below needs a fixed array size, taken from the first edge's embedding.
            embedding_size = (
                len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
            )
            if embedding_size == 0:
                return []

            # On Kuzu the relationship is matched through an intermediate
            # `RelatesToNode_` node that carries the edge properties.
            # NOTE(review): `LIMIT $limit` sits before `collect`, so it appears to
            # cap the total rows across all input edges rather than per edge,
            # unlike the `[..$limit]` slice used on the other providers — confirm.
            query = (
                """
                UNWIND $edges AS edge
                MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
                """
                + filter_query
                + """
                WITH e, edge, n, m, """
                + get_vector_cosine_func_query(
                    'e.fact_embedding',
                    f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
                    driver.provider,
                )
                + """ AS score
                WHERE score > $min_score
                WITH e, edge, n, m, score
                ORDER BY score DESC
                LIMIT $limit
                RETURN
                    edge.uuid AS search_edge_uuid,
                    collect({
                        uuid: e.uuid,
                        source_node_uuid: n.uuid,
                        target_node_uuid: m.uuid,
                        created_at: e.created_at,
                        name: e.name,
                        group_id: e.group_id,
                        fact: e.fact,
                        fact_embedding: e.fact_embedding,
                        episodes: e.episodes,
                        expired_at: e.expired_at,
                        valid_at: e.valid_at,
                        invalid_at: e.invalid_at,
                        attributes: e.attributes
                    }) AS matches
                """
            )
        else:
            query = (
                """
                UNWIND $edges AS edge
                MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
                """
                + filter_query
                + """
                WITH e, edge, """
                + get_vector_cosine_func_query(
                    'e.fact_embedding', 'edge.fact_embedding', driver.provider
                )
                + """ AS score
                WHERE score > $min_score
                WITH edge, e, score
                ORDER BY score DESC
                RETURN
                    edge.uuid AS search_edge_uuid,
                    collect({
                        uuid: e.uuid,
                        source_node_uuid: startNode(e).uuid,
                        target_node_uuid: endNode(e).uuid,
                        created_at: e.created_at,
                        name: e.name,
                        group_id: e.group_id,
                        fact: e.fact,
                        fact_embedding: e.fact_embedding,
                        episodes: e.episodes,
                        expired_at: e.expired_at,
                        valid_at: e.valid_at,
                        invalid_at: e.invalid_at,
                        attributes: properties(e)
                    })[..$limit] AS matches
                """
            )

        results, _, _ = await driver.execute_query(
            query,
            edges=[edge.model_dump() for edge in edges],
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )

    # Key the matches by the uuid of the input edge that produced them.
    relevant_edges_dict: dict[str, list[EntityEdge]] = {
        result['search_edge_uuid']: [
            get_entity_edge_from_record(record, driver.provider) for record in result['matches']
        ]
        for result in results
    }

    # Re-align with the input order; edges without a result row map to [].
    relevant_edges = [relevant_edges_dict.get(edge.uuid, []) for edge in edges]

    return relevant_edges
1543 |
1544 |
async def get_edge_invalidation_candidates(
    driver: GraphDriver,
    edges: list[EntityEdge],
    search_filter: SearchFilters,
    min_score: float = DEFAULT_MIN_SCORE,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[list[EntityEdge]]:
    """
    For each input edge, find existing edges that it might invalidate.

    Candidates are RELATES_TO edges in the input edge's group that touch either
    endpoint of the input edge (as source or as target) and whose
    ``fact_embedding`` has cosine similarity strictly above ``min_score`` with
    the input edge's ``fact_embedding``.

    Parameters
    ----------
    driver : GraphDriver
        Graph driver used to execute the query.
    edges : list[EntityEdge]
        Edges to find invalidation candidates for.
    search_filter : SearchFilters
        Filters turned into WHERE clauses by ``edge_search_filter_query_constructor``.
    min_score : float, optional
        Minimum (exclusive) cosine similarity.
    limit : int, optional
        Cap on the number of candidates.

    Returns
    -------
    list[list[EntityEdge]]
        One list per input edge, in the same order as ``edges``. An edge with no
        result row gets an empty list. Returns ``[]`` when ``edges`` is empty, or
        on Kuzu when the first edge has no (or an empty) fact embedding.
    """
    if len(edges) == 0:
        return []

    filter_queries, filter_params = edge_search_filter_query_constructor(
        search_filter, driver.provider
    )

    # The filters are appended with AND to an endpoint clause of the form
    # `A OR B`. That clause must be parenthesised in every query below;
    # otherwise AND binds tighter than OR and the filters only constrain `B`.
    filter_query = ''
    if filter_queries:
        filter_query = ' AND ' + (' AND '.join(filter_queries))

    if driver.provider == GraphProvider.NEPTUNE:
        # Neptune path: fetch candidate ids and embeddings, score them
        # client-side, then load the surviving edges in a second query.
        query = (
            """
            UNWIND $edges AS edge
            MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
            WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
            """
            + filter_query
            + """
            WITH e, edge
            RETURN DISTINCT id(e) as id, e.fact_embedding as source_embedding,
                    edge.fact_embedding as target_embedding,
                    edge.uuid as search_edge_uuid
            """
        )
        resp, _, _ = await driver.execute_query(
            query,
            edges=[edge.model_dump() for edge in edges],
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )

        # Calculate Cosine similarity then return the edge ids
        input_ids = []
        for r in resp:
            # The stored embedding is a comma-separated string; the input edge's
            # embedding is passed through as-is.
            score = calculate_cosine_similarity(
                list(map(float, r['source_embedding'].split(','))), r['target_embedding']
            )
            if score > min_score:
                input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})

        # Match the edge ids and return the values
        query = """
        UNWIND $ids AS edge
        MATCH ()-[e]->()
        WHERE id(e) = edge.id
        WITH edge, e
        ORDER BY edge.score DESC
        RETURN edge.uuid AS search_edge_uuid,
            collect({
                uuid: e.uuid,
                source_node_uuid: startNode(e).uuid,
                target_node_uuid: endNode(e).uuid,
                created_at: e.created_at,
                name: e.name,
                group_id: e.group_id,
                fact: e.fact,
                fact_embedding: [x IN split(e.fact_embedding, ",") | toFloat(x)],
                episodes: split(e.episodes, ","),
                expired_at: e.expired_at,
                valid_at: e.valid_at,
                invalid_at: e.invalid_at,
                attributes: properties(e)
            })[..$limit] AS matches
        """
        results, _, _ = await driver.execute_query(
            query,
            ids=input_ids,
            edges=[edge.model_dump() for edge in edges],
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )
    else:
        if driver.provider == GraphProvider.KUZU:
            # The CAST below needs a fixed array size, taken from the first edge's embedding.
            embedding_size = (
                len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
            )
            if embedding_size == 0:
                return []

            # On Kuzu the relationship is matched through an intermediate
            # `RelatesToNode_` node that carries the edge properties.
            query = (
                """
                UNWIND $edges AS edge
                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
                WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
                """
                + filter_query
                + """
                WITH edge, e, n, m, """
                + get_vector_cosine_func_query(
                    'e.fact_embedding',
                    f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
                    driver.provider,
                )
                + """ AS score
                WHERE score > $min_score
                WITH edge, e, n, m, score
                ORDER BY score DESC
                LIMIT $limit
                RETURN
                    edge.uuid AS search_edge_uuid,
                    collect({
                        uuid: e.uuid,
                        source_node_uuid: n.uuid,
                        target_node_uuid: m.uuid,
                        created_at: e.created_at,
                        name: e.name,
                        group_id: e.group_id,
                        fact: e.fact,
                        fact_embedding: e.fact_embedding,
                        episodes: e.episodes,
                        expired_at: e.expired_at,
                        valid_at: e.valid_at,
                        invalid_at: e.invalid_at,
                        attributes: e.attributes
                    }) AS matches
                """
            )
        else:
            query = (
                """
                UNWIND $edges AS edge
                MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
                WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
                """
                + filter_query
                + """
                WITH edge, e, """
                + get_vector_cosine_func_query(
                    'e.fact_embedding', 'edge.fact_embedding', driver.provider
                )
                + """ AS score
                WHERE score > $min_score
                WITH edge, e, score
                ORDER BY score DESC
                RETURN
                    edge.uuid AS search_edge_uuid,
                    collect({
                        uuid: e.uuid,
                        source_node_uuid: startNode(e).uuid,
                        target_node_uuid: endNode(e).uuid,
                        created_at: e.created_at,
                        name: e.name,
                        group_id: e.group_id,
                        fact: e.fact,
                        fact_embedding: e.fact_embedding,
                        episodes: e.episodes,
                        expired_at: e.expired_at,
                        valid_at: e.valid_at,
                        invalid_at: e.invalid_at,
                        attributes: properties(e)
                    })[..$limit] AS matches
                """
            )

        results, _, _ = await driver.execute_query(
            query,
            edges=[edge.model_dump() for edge in edges],
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )

    # Key the candidates by the uuid of the input edge that produced them.
    invalidation_edges_dict: dict[str, list[EntityEdge]] = {
        result['search_edge_uuid']: [
            get_entity_edge_from_record(record, driver.provider) for record in result['matches']
        ]
        for result in results
    }

    # Re-align with the input order; edges without a result row map to [].
    invalidation_edges = [invalidation_edges_dict.get(edge.uuid, []) for edge in edges]

    return invalidation_edges
1730 |
1731 |
1732 | # takes in a list of rankings of uuids
def rrf(
    results: list[list[str]], rank_const=1, min_score: float = 0
) -> tuple[list[str], list[float]]:
    """
    Fuse several uuid rankings with reciprocal rank fusion.

    A uuid at zero-based position ``i`` of a ranking contributes
    ``1 / (i + rank_const)``; contributions are summed across rankings.

    Returns the uuids ordered by descending fused score (ties keep first-seen
    order) and the parallel list of scores, both restricted to entries whose
    score is at least ``min_score``.
    """
    fused: dict[str, float] = defaultdict(float)
    for ranking in results:
        position = 0
        for uuid in ranking:
            fused[uuid] += 1 / (position + rank_const)
            position += 1

    # sorted() is stable, so equal scores stay in insertion order.
    ordered = sorted(fused.items(), key=lambda pair: pair[1], reverse=True)

    kept_uuids: list[str] = []
    kept_scores: list[float] = []
    for uuid, score in ordered:
        if score >= min_score:
            kept_uuids.append(uuid)
            kept_scores.append(score)

    return kept_uuids, kept_scores
1749 |
1750 |
async def node_distance_reranker(
    driver: GraphDriver,
    node_uuids: list[str],
    center_node_uuid: str,
    min_score: float = 0,
) -> tuple[list[str], list[float]]:
    """
    Rerank node uuids by whether they are directly connected to a center node.

    Nodes with a RELATES_TO connection to the center get distance 1, all other
    nodes get an infinite distance, and the center itself (when present in
    ``node_uuids``) is put first with distance 0.1. The reported score of each
    uuid is ``1 / distance``; only uuids scoring at least ``min_score`` are
    returned, together with the parallel list of scores.
    """
    # The center node is ranked separately, so leave it out of the lookup.
    candidates = [candidate for candidate in node_uuids if candidate != center_node_uuid]
    scores: dict[str, float] = {center_node_uuid: 0.0}

    if driver.provider == GraphProvider.KUZU:
        query = """
        UNWIND $node_uuids AS node_uuid
        MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(n:Entity {uuid: node_uuid})
        RETURN 1 AS score, node_uuid AS uuid
        """
    else:
        query = """
        UNWIND $node_uuids AS node_uuid
        MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
        RETURN 1 AS score, node_uuid AS uuid
        """

    # Look up which candidates are direct neighbours of the center node.
    rows, header, _ = await driver.execute_query(
        query,
        node_uuids=candidates,
        center_uuid=center_node_uuid,
        routing_='r',
    )
    if driver.provider == GraphProvider.FALKORDB:
        # FalkorDB rows are positional; pair them with the header to get mappings.
        rows = [dict(zip(header, row, strict=True)) for row in rows]

    scores.update({row['uuid']: row['score'] for row in rows})

    # Candidates the query did not return are treated as infinitely far away.
    for candidate in candidates:
        scores.setdefault(candidate, float('inf'))

    # Closest first.
    candidates.sort(key=scores.__getitem__)

    # Put the center node back at the front if the caller passed it in.
    if center_node_uuid in node_uuids:
        scores[center_node_uuid] = 0.1
        candidates = [center_node_uuid, *candidates]

    kept_uuids: list[str] = []
    kept_scores: list[float] = []
    for candidate in candidates:
        inverse_distance = 1 / scores[candidate]
        if inverse_distance >= min_score:
            kept_uuids.append(candidate)
            kept_scores.append(inverse_distance)

    return kept_uuids, kept_scores
1803 |
1804 |
async def episode_mentions_reranker(
    driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0
) -> tuple[list[str], list[float]]:
    """
    Rerank node uuids by how many episodes mention each node.

    Parameters
    ----------
    driver : GraphDriver
        Graph driver used to count MENTIONS edges.
    node_uuids : list[list[str]]
        One ranked uuid list per search method; fused with ``rrf`` first so
        that ties in mention count keep the fused order.
    min_score : float, optional
        Minimum mention count (inclusive) a node needs to be returned.

    Returns
    -------
    tuple[list[str], list[float]]
        Uuids ordered from most to fewest mentions, and the parallel list of
        mention counts. Nodes with no MENTIONS edge score 0.
    """
    # use rrf as a preliminary ranker
    sorted_uuids, _ = rrf(node_uuids)

    # Count the episodes that mention each candidate node
    results, _, _ = await driver.execute_query(
        """
        UNWIND $node_uuids AS node_uuid
        MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
        RETURN count(*) AS score, n.uuid AS uuid
        """,
        node_uuids=sorted_uuids,
        routing_='r',
    )

    scores: dict[str, float] = {result['uuid']: float(result['score']) for result in results}

    # A node the query returned no row for has no mentions. It must score 0
    # (not infinity) so that it ranks last and can be cut by `min_score`.
    for uuid in sorted_uuids:
        scores.setdefault(uuid, 0.0)

    # Most-mentioned first. The sort is stable, so equal counts keep rrf order.
    sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid], reverse=True)

    return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
        scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
    ]
1836 |
1837 |
def maximal_marginal_relevance(
    query_vector: list[float],
    candidates: dict[str, list[float]],
    mmr_lambda: float = DEFAULT_MMR_LAMBDA,
    min_score: float = -2.0,
) -> tuple[list[str], list[float]]:
    """
    Rank candidate embeddings by maximal marginal relevance against a query vector.

    Every candidate embedding is passed through ``normalize_l2``. A candidate's
    score is ``mmr_lambda * dot(query, candidate) + (mmr_lambda - 1) * max_sim``,
    where ``max_sim`` is the largest entry in that candidate's row of the
    pairwise dot-product matrix (whose diagonal is left at zero).

    Returns the uuids ordered by descending score and the parallel list of
    scores, both restricted to entries scoring at least ``min_score``.
    """
    started_at = time()
    query_array = np.array(query_vector)

    candidate_arrays: dict[str, NDArray] = {
        uuid: normalize_l2(embedding) for uuid, embedding in candidates.items()
    }
    uuids: list[str] = list(candidate_arrays)
    vectors = [candidate_arrays[uuid] for uuid in uuids]

    # Fill the symmetric similarity matrix from its lower triangle.
    similarity_matrix = np.zeros((len(uuids), len(uuids)))
    for row, row_vector in enumerate(vectors):
        for col, col_vector in enumerate(vectors[:row]):
            pair_similarity = np.dot(row_vector, col_vector)
            similarity_matrix[row, col] = pair_similarity
            similarity_matrix[col, row] = pair_similarity

    # Relevance to the query, penalised by similarity to the closest other candidate.
    mmr_scores: dict[str, float] = {}
    for row, uuid in enumerate(uuids):
        closest = np.max(similarity_matrix[row, :])
        relevance = np.dot(query_array, vectors[row])
        mmr_scores[uuid] = mmr_lambda * relevance + (mmr_lambda - 1) * closest

    ranked = sorted(uuids, key=mmr_scores.__getitem__, reverse=True)

    elapsed_ms = (time() - started_at) * 1000
    logger.debug(f'Completed MMR reranking in {elapsed_ms} ms')

    kept_uuids: list[str] = []
    kept_scores: list[float] = []
    for uuid in ranked:
        score = mmr_scores[uuid]
        if score >= min_score:
            kept_uuids.append(uuid)
            kept_scores.append(score)

    return kept_uuids, kept_scores
1877 |
1878 |
async def get_embeddings_for_nodes(
    driver: GraphDriver, nodes: list[EntityNode]
) -> dict[str, list[float]]:
    """Load the name embeddings of the given entity nodes.

    When the driver exposes a ``graph_operations_interface`` the lookup is
    delegated to its ``node_load_embeddings_bulk``; otherwise a Cypher query
    is issued directly.

    Args:
        driver: Graph driver used to run the query.
        nodes: Entity nodes whose embeddings should be fetched.

    Returns:
        Mapping of node UUID to its name embedding. Nodes that are not found,
        or whose embedding is null, are omitted.
    """
    if driver.graph_operations_interface:
        return await driver.graph_operations_interface.node_load_embeddings_bulk(driver, nodes)

    # Nothing to look up: skip the database round trip.
    if not nodes:
        return {}

    is_neptune = driver.provider == GraphProvider.NEPTUNE
    if is_neptune:
        # The Neptune query splits a comma-separated string property, so the
        # values come back as strings and are converted to floats below.
        query = """
            MATCH (n:Entity)
            WHERE n.uuid IN $node_uuids
            RETURN DISTINCT
                n.uuid AS uuid,
                split(n.name_embedding, ",") AS name_embedding
        """
    else:
        query = """
            MATCH (n:Entity)
            WHERE n.uuid IN $node_uuids
            RETURN DISTINCT
                n.uuid AS uuid,
                n.name_embedding AS name_embedding
        """
    results, _, _ = await driver.execute_query(
        query,
        node_uuids=[node.uuid for node in nodes],
        routing_='r',
    )

    embeddings_dict: dict[str, list[float]] = {}
    for result in results:
        uuid = result.get('uuid')
        embedding = result.get('name_embedding')
        if uuid is None or embedding is None:
            continue
        # split() yields strings; honour the declared list[float] contract.
        embeddings_dict[uuid] = [float(value) for value in embedding] if is_neptune else embedding

    return embeddings_dict
1914 |
1915 |
async def get_embeddings_for_communities(
    driver: GraphDriver, communities: list[CommunityNode]
) -> dict[str, list[float]]:
    """Load the name embeddings of the given community nodes.

    Args:
        driver: Graph driver used to run the query.
        communities: Community nodes whose embeddings should be fetched.

    Returns:
        Mapping of community UUID to its name embedding. Communities that are
        not found, or whose embedding is null, are omitted.
    """
    # Nothing to look up: skip the database round trip.
    if not communities:
        return {}

    is_neptune = driver.provider == GraphProvider.NEPTUNE
    if is_neptune:
        # The Neptune query splits a comma-separated string property, so the
        # values come back as strings and are converted to floats below.
        query = """
            MATCH (c:Community)
            WHERE c.uuid IN $community_uuids
            RETURN DISTINCT
                c.uuid AS uuid,
                split(c.name_embedding, ",") AS name_embedding
        """
    else:
        query = """
            MATCH (c:Community)
            WHERE c.uuid IN $community_uuids
            RETURN DISTINCT
                c.uuid AS uuid,
                c.name_embedding AS name_embedding
        """
    results, _, _ = await driver.execute_query(
        query,
        community_uuids=[community.uuid for community in communities],
        routing_='r',
    )

    embeddings_dict: dict[str, list[float]] = {}
    for result in results:
        uuid = result.get('uuid')
        embedding = result.get('name_embedding')
        if uuid is None or embedding is None:
            continue
        # split() yields strings; honour the declared list[float] contract.
        embeddings_dict[uuid] = [float(value) for value in embedding] if is_neptune else embedding

    return embeddings_dict
1949 |
1950 |
async def get_embeddings_for_edges(
    driver: GraphDriver, edges: list[EntityEdge]
) -> dict[str, list[float]]:
    """Load the fact embeddings of the given entity edges.

    When the driver exposes a ``graph_operations_interface`` the lookup is
    delegated to its ``edge_load_embeddings_bulk``; otherwise a Cypher query
    is issued directly. Kuzu uses a different MATCH pattern in which the edge
    is matched as an intermediate ``RelatesToNode_`` node.

    Args:
        driver: Graph driver used to run the query.
        edges: Entity edges whose embeddings should be fetched.

    Returns:
        Mapping of edge UUID to its fact embedding. Edges that are not found,
        or whose embedding is null, are omitted.
    """
    if driver.graph_operations_interface:
        return await driver.graph_operations_interface.edge_load_embeddings_bulk(driver, edges)

    # Nothing to look up: skip the database round trip.
    if not edges:
        return {}

    is_neptune = driver.provider == GraphProvider.NEPTUNE
    if is_neptune:
        # The Neptune query splits a comma-separated string property, so the
        # values come back as strings and are converted to floats below.
        query = """
            MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
            WHERE e.uuid IN $edge_uuids
            RETURN DISTINCT
                e.uuid AS uuid,
                split(e.fact_embedding, ",") AS fact_embedding
        """
    else:
        match_query = """
            MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
        """
        if driver.provider == GraphProvider.KUZU:
            match_query = """
                MATCH (n:Entity)-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m:Entity)
            """

        query = (
            match_query
            + """
            WHERE e.uuid IN $edge_uuids
            RETURN DISTINCT
                e.uuid AS uuid,
                e.fact_embedding AS fact_embedding
        """
        )
    results, _, _ = await driver.execute_query(
        query,
        edge_uuids=[edge.uuid for edge in edges],
        routing_='r',
    )

    embeddings_dict: dict[str, list[float]] = {}
    for result in results:
        uuid = result.get('uuid')
        embedding = result.get('fact_embedding')
        if uuid is None or embedding is None:
            continue
        # split() yields strings; honour the declared list[float] contract.
        embeddings_dict[uuid] = [float(value) for value in embedding] if is_neptune else embedding

    return embeddings_dict
1996 |
```