This is page 8 of 9. Use http://codebase.md/getzep/graphiti?page={x} to view the full context.

# Directory Structure

```
├── .env.example
├── .github
│   ├── dependabot.yml
│   ├── ISSUE_TEMPLATE
│   │   └── bug_report.md
│   ├── pull_request_template.md
│   ├── secret_scanning.yml
│   └── workflows
│       ├── ai-moderator.yml
│       ├── cla.yml
│       ├── claude-code-review-manual.yml
│       ├── claude-code-review.yml
│       ├── claude.yml
│       ├── codeql.yml
│       ├── lint.yml
│       ├── release-graphiti-core.yml
│       ├── release-mcp-server.yml
│       ├── release-server-container.yml
│       ├── typecheck.yml
│       └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│   ├── azure-openai
│   │   ├── .env.example
│   │   ├── azure_openai_neo4j.py
│   │   └── README.md
│   ├── data
│   │   └── manybirds_products.json
│   ├── ecommerce
│   │   ├── runner.ipynb
│   │   └── runner.py
│   ├── langgraph-agent
│   │   ├── agent.ipynb
│   │   └── tinybirds-jess.png
│   ├── opentelemetry
│   │   ├── .env.example
│   │   ├── otel_stdout_example.py
│   │   ├── pyproject.toml
│   │   ├── README.md
│   │   └── uv.lock
│   ├── podcast
│   │   ├── podcast_runner.py
│   │   ├── podcast_transcript.txt
│   │   └── transcript_parser.py
│   ├── quickstart
│   │   ├── dense_vs_normal_ingestion.py
│   │   ├── quickstart_falkordb.py
│   │   ├── quickstart_neo4j.py
│   │   ├── quickstart_neptune.py
│   │   ├── README.md
│   │   └── requirements.txt
│   └── wizard_of_oz
│       ├── parser.py
│       ├── runner.py
│       └── woo.txt
├── graphiti_core
│   ├── __init__.py
│   ├── cross_encoder
│   │   ├── __init__.py
│   │   ├── bge_reranker_client.py
│   │   ├── client.py
│   │   ├── gemini_reranker_client.py
│   │   └── openai_reranker_client.py
│   ├── decorators.py
│   ├── driver
│   │   ├── __init__.py
│   │   ├── driver.py
│   │   ├── falkordb_driver.py
│   │   ├── graph_operations
│   │   │   └── graph_operations.py
│   │   ├── kuzu_driver.py
│   │   ├── neo4j_driver.py
│   │   ├── neptune_driver.py
│   │   └── search_interface
│   │       └── search_interface.py
│   ├── edges.py
│   ├── embedder
│   │   ├── __init__.py
│   │   ├── azure_openai.py
│   │   ├── client.py
│   │   ├── gemini.py
│   │   ├── openai.py
│   │   └── voyage.py
│   ├── errors.py
│   ├── graph_queries.py
│   ├── graphiti_types.py
│   ├── graphiti.py
│   ├── helpers.py
│   ├── llm_client
│   │   ├── __init__.py
│   │   ├── anthropic_client.py
│   │   ├── azure_openai_client.py
│   │   ├── client.py
│   │   ├── config.py
│   │   ├── errors.py
│   │   ├── gemini_client.py
│   │   ├── groq_client.py
│   │   ├── openai_base_client.py
│   │   ├── openai_client.py
│   │   ├── openai_generic_client.py
│   │   └── utils.py
│   ├── migrations
│   │   └── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── edges
│   │   │   ├── __init__.py
│   │   │   └── edge_db_queries.py
│   │   └── nodes
│   │       ├── __init__.py
│   │       └── node_db_queries.py
│   ├── nodes.py
│   ├── prompts
│   │   ├── __init__.py
│   │   ├── dedupe_edges.py
│   │   ├── dedupe_nodes.py
│   │   ├── eval.py
│   │   ├── extract_edge_dates.py
│   │   ├── extract_edges.py
│   │   ├── extract_nodes.py
│   │   ├── invalidate_edges.py
│   │   ├── lib.py
│   │   ├── models.py
│   │   ├── prompt_helpers.py
│   │   ├── snippets.py
│   │   └── summarize_nodes.py
│   ├── py.typed
│   ├── search
│   │   ├── __init__.py
│   │   ├── search_config_recipes.py
│   │   ├── search_config.py
│   │   ├── search_filters.py
│   │   ├── search_helpers.py
│   │   ├── search_utils.py
│   │   └── search.py
│   ├── telemetry
│   │   ├── __init__.py
│   │   └── telemetry.py
│   ├── tracer.py
│   └── utils
│       ├── __init__.py
│       ├── bulk_utils.py
│       ├── content_chunking.py
│       ├── datetime_utils.py
│       ├── maintenance
│       │   ├── __init__.py
│       │   ├── community_operations.py
│       │   ├── dedup_helpers.py
│       │   ├── edge_operations.py
│       │   ├── graph_data_operations.py
│       │   ├── node_operations.py
│       │   └── temporal_operations.py
│       ├── ontology_utils
│       │   └── entity_types_utils.py
│       └── text_utils.py
├── images
│   ├── arxiv-screenshot.png
│   ├── graphiti-graph-intro.gif
│   ├── graphiti-intro-slides-stock-2.gif
│   └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│   ├── .env.example
│   ├── .python-version
│   ├── config
│   │   ├── config-docker-falkordb-combined.yaml
│   │   ├── config-docker-falkordb.yaml
│   │   ├── config-docker-neo4j.yaml
│   │   ├── config.yaml
│   │   └── mcp_config_stdio_example.json
│   ├── docker
│   │   ├── build-standalone.sh
│   │   ├── build-with-version.sh
│   │   ├── docker-compose-falkordb.yml
│   │   ├── docker-compose-neo4j.yml
│   │   ├── docker-compose.yml
│   │   ├── Dockerfile
│   │   ├── Dockerfile.standalone
│   │   ├── github-actions-example.yml
│   │   ├── README-falkordb-combined.md
│   │   └── README.md
│   ├── docs
│   │   └── cursor_rules.md
│   ├── main.py
│   ├── pyproject.toml
│   ├── pytest.ini
│   ├── README.md
│   ├── src
│   │   ├── __init__.py
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   └── schema.py
│   │   ├── graphiti_mcp_server.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── entity_types.py
│   │   │   └── response_types.py
│   │   ├── services
│   │   │   ├── __init__.py
│   │   │   ├── factories.py
│   │   │   └── queue_service.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── formatting.py
│   │       └── utils.py
│   ├── tests
│   │   ├── __init__.py
│   │   ├── conftest.py
│   │   ├── pytest.ini
│   │   ├── README.md
│   │   ├── run_tests.py
│   │   ├── test_async_operations.py
│   │   ├── test_comprehensive_integration.py
│   │   ├── test_configuration.py
│   │   ├── test_falkordb_integration.py
│   │   ├── test_fixtures.py
│   │   ├── test_http_integration.py
│   │   ├── test_integration.py
│   │   ├── test_mcp_integration.py
│   │   ├── test_mcp_transports.py
│   │   ├── test_stdio_simple.py
│   │   └── test_stress_load.py
│   └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│   ├── .env.example
│   ├── graph_service
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── dto
│   │   │   ├── __init__.py
│   │   │   ├── common.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   ├── main.py
│   │   ├── routers
│   │   │   ├── __init__.py
│   │   │   ├── ingest.py
│   │   │   └── retrieve.py
│   │   └── zep_graphiti.py
│   ├── Makefile
│   ├── pyproject.toml
│   ├── README.md
│   └── uv.lock
├── signatures
│   └── version1
│       └── cla.json
├── tests
│   ├── cross_encoder
│   │   ├── test_bge_reranker_client_int.py
│   │   └── test_gemini_reranker_client.py
│   ├── driver
│   │   ├── __init__.py
│   │   └── test_falkordb_driver.py
│   ├── embedder
│   │   ├── embedder_fixtures.py
│   │   ├── test_gemini.py
│   │   ├── test_openai.py
│   │   └── test_voyage.py
│   ├── evals
│   │   ├── data
│   │   │   └── longmemeval_data
│   │   │       ├── longmemeval_oracle.json
│   │   │       └── README.md
│   │   ├── eval_cli.py
│   │   ├── eval_e2e_graph_building.py
│   │   ├── pytest.ini
│   │   └── utils.py
│   ├── helpers_test.py
│   ├── llm_client
│   │   ├── test_anthropic_client_int.py
│   │   ├── test_anthropic_client.py
│   │   ├── test_azure_openai_client.py
│   │   ├── test_client.py
│   │   ├── test_errors.py
│   │   └── test_gemini_client.py
│   ├── test_edge_int.py
│   ├── test_entity_exclusion_int.py
│   ├── test_graphiti_int.py
│   ├── test_graphiti_mock.py
│   ├── test_node_int.py
│   ├── test_text_utils.py
│   └── utils
│       ├── maintenance
│       │   ├── test_bulk_utils.py
│       │   ├── test_edge_operations.py
│       │   ├── test_entity_extraction.py
│       │   ├── test_node_operations.py
│       │   └── test_temporal_operations_int.py
│       ├── search
│       │   └── search_utils_test.py
│       └── test_content_chunking.py
├── uv.lock
└── Zep-CLA.md
```

# Files

--------------------------------------------------------------------------------
/graphiti_core/graphiti.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
from datetime import datetime
from time import time

from dotenv import load_dotenv
from pydantic import BaseModel
from typing_extensions import LiteralString

from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.decorators import handle_multiple_group_ids
from graphiti_core.driver.driver import GraphDriver
from graphiti_core.driver.neo4j_driver import Neo4jDriver
from graphiti_core.edges import (
    CommunityEdge,
    Edge,
    EntityEdge,
    EpisodicEdge,
    create_entity_edge_embeddings,
)
from graphiti_core.embedder import EmbedderClient, OpenAIEmbedder
from graphiti_core.errors import NodeNotFoundError
from graphiti_core.graphiti_types import GraphitiClients
from graphiti_core.helpers import (
    get_default_group_id,
    semaphore_gather,
    validate_excluded_entity_types,
    validate_group_id,
)
from graphiti_core.llm_client import LLMClient, OpenAIClient
from graphiti_core.nodes import (
    CommunityNode,
    EntityNode,
    EpisodeType,
    EpisodicNode,
    Node,
    create_entity_node_embeddings,
)
from graphiti_core.search.search import SearchConfig, search
from graphiti_core.search.search_config import DEFAULT_SEARCH_LIMIT, SearchResults
from graphiti_core.search.search_config_recipes import (
    COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
    EDGE_HYBRID_SEARCH_NODE_DISTANCE,
    EDGE_HYBRID_SEARCH_RRF,
)
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import (
    RELEVANT_SCHEMA_LIMIT,
    get_mentioned_nodes,
)
from graphiti_core.telemetry import capture_event
from graphiti_core.tracer import Tracer, create_tracer
from graphiti_core.utils.bulk_utils import (
    RawEpisode,
    add_nodes_and_edges_bulk,
    dedupe_edges_bulk,
    dedupe_nodes_bulk,
    extract_nodes_and_edges_bulk,
    resolve_edge_pointers,
    retrieve_previous_episodes_bulk,
)
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.community_operations import (
    build_communities,
    remove_communities,
    update_community,
)
from graphiti_core.utils.maintenance.edge_operations import (
    build_episodic_edges,
    extract_edges,
    resolve_extracted_edge,
    resolve_extracted_edges,
)
from graphiti_core.utils.maintenance.graph_data_operations import (
    EPISODE_WINDOW_LEN,
    retrieve_episodes,
)
from graphiti_core.utils.maintenance.node_operations import (
    extract_attributes_from_nodes,
    extract_nodes,
    resolve_extracted_nodes,
)
from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types

logger = logging.getLogger(__name__)

load_dotenv()


class AddEpisodeResults(BaseModel):
    episode: EpisodicNode
    episodic_edges: list[EpisodicEdge]
    nodes: list[EntityNode]
    edges: list[EntityEdge]
    communities: list[CommunityNode]
    community_edges: list[CommunityEdge]


class AddBulkEpisodeResults(BaseModel):
    episodes: list[EpisodicNode]
    episodic_edges: list[EpisodicEdge]
    nodes: list[EntityNode]
    edges: list[EntityEdge]
    communities: list[CommunityNode]
    community_edges: list[CommunityEdge]


class AddTripletResults(BaseModel):
    nodes: list[EntityNode]
    edges: list[EntityEdge]


class Graphiti:
    def __init__(
        self,
        uri: str | None = None,
        user: str | None = None,
        password: str | None = None,
        llm_client: LLMClient | None = None,
        embedder: EmbedderClient | None = None,
        cross_encoder: CrossEncoderClient | None = None,
        store_raw_episode_content: bool = True,
        graph_driver: GraphDriver | None = None,
        max_coroutines: int | None = None,
        tracer: Tracer | None = None,
        trace_span_prefix: str = 'graphiti',
    ):
        """
        Initialize a Graphiti instance.

        This constructor sets up a connection to a graph database and initializes
        the LLM client for natural language processing tasks.

        Parameters
        ----------
        uri : str | None
            The URI of the Neo4j database. Required when graph_driver is not provided.
        user : str | None
            The username for authenticating with the Neo4j database.
        password : str | None
            The password for authenticating with the Neo4j database.
        llm_client : LLMClient | None, optional
            An instance of LLMClient for natural language processing tasks.
            If not provided, a default OpenAIClient will be initialized.
        embedder : EmbedderClient | None, optional
            An instance of EmbedderClient for embedding tasks.
            If not provided, a default OpenAIEmbedder will be initialized.
        cross_encoder : CrossEncoderClient | None, optional
            An instance of CrossEncoderClient for reranking tasks.
            If not provided, a default OpenAIRerankerClient will be initialized.
        store_raw_episode_content : bool, optional
            Whether to store the raw content of episodes. Defaults to True.
        graph_driver : GraphDriver | None, optional
            An instance of GraphDriver for database operations.
            If not provided, a default Neo4jDriver will be initialized.
        max_coroutines : int | None, optional
            The maximum number of concurrent operations allowed. Overrides SEMAPHORE_LIMIT set in the environment.
            If not set, the Graphiti default is used.
        tracer : Tracer | None, optional
            An OpenTelemetry tracer instance for distributed tracing. If not provided, tracing is disabled (no-op).
        trace_span_prefix : str, optional
            Prefix to prepend to all span names. Defaults to 'graphiti'.

        Returns
        -------
        None

        Notes
        -----
        This method establishes a connection to a graph database (Neo4j by default) using the provided
        credentials. It also sets up the LLM client, either using the provided client
        or by creating a default OpenAIClient.

        The default database name is defined during the driver’s construction. If a different database name
        is required, it should be specified in the URI or set separately after
        initialization.

        The OpenAI API key is expected to be set in the environment variables.
        Make sure to set the OPENAI_API_KEY environment variable before initializing
        Graphiti if you're using the default OpenAIClient.
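
        Example (a minimal sketch; the URI and credentials are placeholders for a
        local Neo4j instance):
            graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password')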
        """

        if graph_driver:
            self.driver = graph_driver
        else:
            if uri is None:
                raise ValueError('uri must be provided when graph_driver is None')
            self.driver = Neo4jDriver(uri, user, password)

        self.store_raw_episode_content = store_raw_episode_content
        self.max_coroutines = max_coroutines
        if llm_client:
            self.llm_client = llm_client
        else:
            self.llm_client = OpenAIClient()
        if embedder:
            self.embedder = embedder
        else:
            self.embedder = OpenAIEmbedder()
        if cross_encoder:
            self.cross_encoder = cross_encoder
        else:
            self.cross_encoder = OpenAIRerankerClient()

        # Initialize tracer
        self.tracer = create_tracer(tracer, trace_span_prefix)

        # Set tracer on clients
        self.llm_client.set_tracer(self.tracer)

        self.clients = GraphitiClients(
            driver=self.driver,
            llm_client=self.llm_client,
            embedder=self.embedder,
            cross_encoder=self.cross_encoder,
            tracer=self.tracer,
        )

        # Capture telemetry event
        self._capture_initialization_telemetry()

    def _capture_initialization_telemetry(self):
        """Capture telemetry event for Graphiti initialization."""
        try:
            # Detect provider types from class names
            llm_provider = self._get_provider_type(self.llm_client)
            embedder_provider = self._get_provider_type(self.embedder)
            reranker_provider = self._get_provider_type(self.cross_encoder)
            database_provider = self._get_provider_type(self.driver)

            properties = {
                'llm_provider': llm_provider,
                'embedder_provider': embedder_provider,
                'reranker_provider': reranker_provider,
                'database_provider': database_provider,
            }

            capture_event('graphiti_initialized', properties)
        except Exception:
            # Silently handle telemetry errors
            pass

    def _get_provider_type(self, client) -> str:
        """Get provider type from client class name."""
        if client is None:
            return 'none'

        class_name = client.__class__.__name__.lower()

        # LLM providers
        if 'openai' in class_name:
            return 'openai'
        elif 'azure' in class_name:
            return 'azure'
        elif 'anthropic' in class_name:
            return 'anthropic'
        elif 'crossencoder' in class_name:
            return 'crossencoder'
        elif 'gemini' in class_name:
            return 'gemini'
        elif 'groq' in class_name:
            return 'groq'
        # Database providers
        elif 'neo4j' in class_name:
            return 'neo4j'
        elif 'falkor' in class_name:
            return 'falkordb'
        # Embedder providers
        elif 'voyage' in class_name:
            return 'voyage'
        else:
            return 'unknown'

    async def close(self):
        """
        Close the connection to the Neo4j database.

        This method safely closes the driver connection to the Neo4j database.
        It should be called when the Graphiti instance is no longer needed or
        when the application is shutting down.

        Parameters
        ----------
        self

        Returns
        -------
        None

        Notes
        -----
        It's important to close the driver connection to release system resources
        and ensure that all pending transactions are completed or rolled back.
        This method should be called as part of a cleanup process, potentially
        in a context manager or a shutdown hook.

        Example:
            graphiti = Graphiti(uri, user, password)
            try:
                # Use graphiti...
            finally:
                await graphiti.close()
        """
        await self.driver.close()

    async def build_indices_and_constraints(self, delete_existing: bool = False):
        """
        Build indices and constraints in the Neo4j database.

        This method sets up the necessary indices and constraints in the Neo4j database
        to optimize query performance and ensure data integrity for the knowledge graph.

        Parameters
        ----------
        self
        delete_existing : bool, optional
            Whether to clear existing indices before creating new ones.


        Returns
        -------
        None

        Notes
        -----
        This method should typically be called once during the initial setup of the
        knowledge graph or when updating the database schema. It uses the
        driver's `build_indices_and_constraints` method to perform
        the actual database operations.

        The specific indices and constraints created depend on the implementation
        of the driver's `build_indices_and_constraints` method. Refer to the specific
        driver documentation for details on the exact database schema modifications.

        Caution: Running this method on a large existing database may take some time
        and could impact database performance during execution.
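
        Example (typically run once during initial setup):
            await graphiti.build_indices_and_constraints()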
        """
        await self.driver.build_indices_and_constraints(delete_existing)

    async def _extract_and_resolve_nodes(
        self,
        episode: EpisodicNode,
        previous_episodes: list[EpisodicNode],
        entity_types: dict[str, type[BaseModel]] | None,
        excluded_entity_types: list[str] | None,
    ) -> tuple[list[EntityNode], dict[str, str], list[tuple[EntityNode, EntityNode]]]:
        """Extract nodes from episode and resolve against existing graph."""
        extracted_nodes = await extract_nodes(
            self.clients, episode, previous_episodes, entity_types, excluded_entity_types
        )

        nodes, uuid_map, duplicates = await resolve_extracted_nodes(
            self.clients,
            extracted_nodes,
            episode,
            previous_episodes,
            entity_types,
        )

        return nodes, uuid_map, duplicates

    async def _extract_and_resolve_edges(
        self,
        episode: EpisodicNode,
        extracted_nodes: list[EntityNode],
        previous_episodes: list[EpisodicNode],
        edge_type_map: dict[tuple[str, str], list[str]],
        group_id: str,
        edge_types: dict[str, type[BaseModel]] | None,
        nodes: list[EntityNode],
        uuid_map: dict[str, str],
        custom_extraction_instructions: str | None = None,
    ) -> tuple[list[EntityEdge], list[EntityEdge]]:
        """Extract edges from episode and resolve against existing graph."""
        extracted_edges = await extract_edges(
            self.clients,
            episode,
            extracted_nodes,
            previous_episodes,
            edge_type_map,
            group_id,
            edge_types,
            custom_extraction_instructions,
        )

        edges = resolve_edge_pointers(extracted_edges, uuid_map)

        resolved_edges, invalidated_edges = await resolve_extracted_edges(
            self.clients,
            edges,
            episode,
            nodes,
            edge_types or {},
            edge_type_map,
        )

        return resolved_edges, invalidated_edges

    async def _process_episode_data(
        self,
        episode: EpisodicNode,
        nodes: list[EntityNode],
        entity_edges: list[EntityEdge],
        now: datetime,
    ) -> tuple[list[EpisodicEdge], EpisodicNode]:
        """Process and save episode data to the graph."""
        episodic_edges = build_episodic_edges(nodes, episode.uuid, now)
        episode.entity_edges = [edge.uuid for edge in entity_edges]

        if not self.store_raw_episode_content:
            episode.content = ''

        await add_nodes_and_edges_bulk(
            self.driver,
            [episode],
            episodic_edges,
            nodes,
            entity_edges,
            self.embedder,
        )

        return episodic_edges, episode

    async def _extract_and_dedupe_nodes_bulk(
        self,
        episode_context: list[tuple[EpisodicNode, list[EpisodicNode]]],
        edge_type_map: dict[tuple[str, str], list[str]],
        edge_types: dict[str, type[BaseModel]] | None,
        entity_types: dict[str, type[BaseModel]] | None,
        excluded_entity_types: list[str] | None,
    ) -> tuple[
        dict[str, list[EntityNode]],
        dict[str, str],
        list[list[EntityEdge]],
    ]:
        """Extract nodes and edges from all episodes and deduplicate."""
        # Extract all nodes and edges for each episode
        extracted_nodes_bulk, extracted_edges_bulk = await extract_nodes_and_edges_bulk(
            self.clients,
            episode_context,
            edge_type_map=edge_type_map,
            edge_types=edge_types,
            entity_types=entity_types,
            excluded_entity_types=excluded_entity_types,
        )

        # Dedupe extracted nodes in memory
        nodes_by_episode, uuid_map = await dedupe_nodes_bulk(
            self.clients, extracted_nodes_bulk, episode_context, entity_types
        )

        return nodes_by_episode, uuid_map, extracted_edges_bulk

    async def _resolve_nodes_and_edges_bulk(
        self,
        nodes_by_episode: dict[str, list[EntityNode]],
        edges_by_episode: dict[str, list[EntityEdge]],
        episode_context: list[tuple[EpisodicNode, list[EpisodicNode]]],
        entity_types: dict[str, type[BaseModel]] | None,
        edge_types: dict[str, type[BaseModel]] | None,
        edge_type_map: dict[tuple[str, str], list[str]],
        episodes: list[EpisodicNode],
    ) -> tuple[list[EntityNode], list[EntityEdge], list[EntityEdge], dict[str, str]]:
        """Resolve nodes and edges against the existing graph."""
        nodes_by_uuid: dict[str, EntityNode] = {
            node.uuid: node for nodes in nodes_by_episode.values() for node in nodes
        }

        # Get unique nodes per episode
        nodes_by_episode_unique: dict[str, list[EntityNode]] = {}
        nodes_uuid_set: set[str] = set()
        for episode, _ in episode_context:
            nodes_by_episode_unique[episode.uuid] = []
            nodes = [nodes_by_uuid[node.uuid] for node in nodes_by_episode[episode.uuid]]
            for node in nodes:
                if node.uuid not in nodes_uuid_set:
                    nodes_by_episode_unique[episode.uuid].append(node)
                    nodes_uuid_set.add(node.uuid)

        # Resolve nodes
        node_results = await semaphore_gather(
            *[
                resolve_extracted_nodes(
                    self.clients,
                    nodes_by_episode_unique[episode.uuid],
                    episode,
                    previous_episodes,
                    entity_types,
                )
                for episode, previous_episodes in episode_context
            ]
        )

        resolved_nodes: list[EntityNode] = []
        uuid_map: dict[str, str] = {}
        for result in node_results:
            resolved_nodes.extend(result[0])
            uuid_map.update(result[1])

        # Update nodes_by_uuid with resolved nodes
        for resolved_node in resolved_nodes:
            nodes_by_uuid[resolved_node.uuid] = resolved_node

        # Update nodes_by_episode_unique with resolved pointers
        for episode_uuid, nodes in nodes_by_episode_unique.items():
            updated_nodes: list[EntityNode] = []
            for node in nodes:
                updated_node_uuid = uuid_map.get(node.uuid, node.uuid)
                updated_node = nodes_by_uuid[updated_node_uuid]
                updated_nodes.append(updated_node)
            nodes_by_episode_unique[episode_uuid] = updated_nodes

        # Extract attributes for resolved nodes
        hydrated_nodes_results: list[list[EntityNode]] = await semaphore_gather(
            *[
                extract_attributes_from_nodes(
                    self.clients,
                    nodes_by_episode_unique[episode.uuid],
                    episode,
                    previous_episodes,
                    entity_types,
                )
                for episode, previous_episodes in episode_context
            ]
        )

        final_hydrated_nodes = [node for nodes in hydrated_nodes_results for node in nodes]

        # Resolve edges with updated pointers
        edges_by_episode_unique: dict[str, list[EntityEdge]] = {}
        edges_uuid_set: set[str] = set()
        for episode_uuid, edges in edges_by_episode.items():
            edges_with_updated_pointers = resolve_edge_pointers(edges, uuid_map)
            edges_by_episode_unique[episode_uuid] = []

            for edge in edges_with_updated_pointers:
                if edge.uuid not in edges_uuid_set:
                    edges_by_episode_unique[episode_uuid].append(edge)
                    edges_uuid_set.add(edge.uuid)

        edge_results = await semaphore_gather(
            *[
                resolve_extracted_edges(
                    self.clients,
                    edges_by_episode_unique[episode.uuid],
                    episode,
                    final_hydrated_nodes,
                    edge_types or {},
                    edge_type_map,
                )
                for episode in episodes
            ]
        )

        resolved_edges: list[EntityEdge] = []
        invalidated_edges: list[EntityEdge] = []
        for result in edge_results:
            resolved_edges.extend(result[0])
            invalidated_edges.extend(result[1])

        return final_hydrated_nodes, resolved_edges, invalidated_edges, uuid_map

    @handle_multiple_group_ids
    async def retrieve_episodes(
        self,
        reference_time: datetime,
        last_n: int = EPISODE_WINDOW_LEN,
        group_ids: list[str] | None = None,
        source: EpisodeType | None = None,
        driver: GraphDriver | None = None,
    ) -> list[EpisodicNode]:
        """
        Retrieve the last n episodic nodes from the graph.

        This method fetches a specified number of the most recent episodic nodes
        from the graph, relative to the given reference time.

        Parameters
        ----------
        reference_time : datetime
            The reference time to retrieve episodes before.
        last_n : int, optional
            The number of episodes to retrieve. Defaults to EPISODE_WINDOW_LEN.
        group_ids : list[str] | None, optional
            The group ids to return data from.
        source : EpisodeType | None, optional
            Optional. Filter retrieved episodes by source type.
        driver : GraphDriver | None, optional
            Optional. Graph driver override; defaults to this instance's driver.

        Returns
        -------
        list[EpisodicNode]
            A list of the most recent EpisodicNode objects.

        Notes
        -----
        The actual retrieval is performed by the `retrieve_episodes` function
        from the `graphiti_core.utils` module.
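
        Example (a sketch; fetches up to 5 episodes that occurred before now):
            episodes = await graphiti.retrieve_episodes(utc_now(), last_n=5)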
        """
        if driver is None:
            driver = self.clients.driver

        return await retrieve_episodes(driver, reference_time, last_n, group_ids, source)

    async def add_episode(
        self,
        name: str,
        episode_body: str,
        source_description: str,
        reference_time: datetime,
        source: EpisodeType = EpisodeType.message,
        group_id: str | None = None,
        uuid: str | None = None,
        update_communities: bool = False,
        entity_types: dict[str, type[BaseModel]] | None = None,
        excluded_entity_types: list[str] | None = None,
        previous_episode_uuids: list[str] | None = None,
        edge_types: dict[str, type[BaseModel]] | None = None,
        edge_type_map: dict[tuple[str, str], list[str]] | None = None,
        custom_extraction_instructions: str | None = None,
    ) -> AddEpisodeResults:
        """
        Process an episode and update the graph.

        This method extracts information from the episode, creates nodes and edges,
        and updates the graph database accordingly.

        Parameters
        ----------
        name : str
            The name of the episode.
        episode_body : str
            The content of the episode.
        source_description : str
            A description of the episode's source.
        reference_time : datetime
            The reference time for the episode.
        source : EpisodeType, optional
            The type of the episode. Defaults to EpisodeType.message.
        group_id : str | None
            An id for the graph partition the episode is a part of.
        uuid : str | None
            Optional uuid of the episode.
        update_communities : bool
            Optional. Whether to update communities with new node information
        entity_types : dict[str, BaseModel] | None
            Optional. Dictionary mapping entity type names to their Pydantic model definitions.
        excluded_entity_types : list[str] | None
            Optional. List of entity type names to exclude from the graph. Entities classified
            into these types will not be added to the graph. Can include 'Entity' to exclude
            the default entity type.
        previous_episode_uuids : list[str] | None
            Optional. List of episode uuids to use as the previous episodes. If not provided,
            the most recent episodes by created_at date will be used.
        edge_types : dict[str, type[BaseModel]] | None
            Optional. Dictionary mapping edge type names to their Pydantic model definitions.
        edge_type_map : dict[tuple[str, str], list[str]] | None
            Optional. Mapping from (source, target) entity type pairs to allowed edge type names.
        custom_extraction_instructions : str | None
            Optional. Custom instructions included in the entity and edge extraction prompts
            to provide additional guidance or context for the extraction process.

        Returns
        -------
        AddEpisodeResults
            The processed episode together with the extracted nodes, edges, and any
            updated communities.

        Notes
        -----
        This method performs several steps including node extraction, edge extraction,
        deduplication, and database updates. It also handles embedding generation
        and edge invalidation.

        It is recommended to run this method as a background process, such as in a queue.
        It's important that each episode is added sequentially and awaited before adding
        the next one. For web applications, consider using FastAPI's background tasks
        or a dedicated task queue like Celery for this purpose.

        Example using FastAPI background tasks (EpisodeData is a placeholder model):
            @app.post("/add_episode")
            async def add_episode_endpoint(episode_data: EpisodeData, background_tasks: BackgroundTasks):
                background_tasks.add_task(graphiti.add_episode, **episode_data.dict())
                return {"message": "Episode processing started"}
        """
        start = time()
        now = utc_now()

        validate_entity_types(entity_types)
        validate_excluded_entity_types(excluded_entity_types, entity_types)

        if group_id is None:
            # if group_id is None, use the provider's default group id
            # and the preset database name
            group_id = get_default_group_id(self.driver.provider)
        else:
            validate_group_id(group_id)
            if group_id != self.driver._database:
                # if group_id is provided, use it as the database name
                self.driver = self.driver.clone(database=group_id)
                self.clients.driver = self.driver

        with self.tracer.start_span('add_episode') as span:
            try:
                # Retrieve previous episodes for context
                previous_episodes = (
                    await self.retrieve_episodes(
                        reference_time,
                        last_n=RELEVANT_SCHEMA_LIMIT,
                        group_ids=[group_id],
                        source=source,
                    )
                    if previous_episode_uuids is None
                    else await EpisodicNode.get_by_uuids(self.driver, previous_episode_uuids)
                )

                # Get or create episode
                episode = (
                    await EpisodicNode.get_by_uuid(self.driver, uuid)
                    if uuid is not None
                    else EpisodicNode(
                        name=name,
                        group_id=group_id,
                        labels=[],
                        source=source,
                        content=episode_body,
                        source_description=source_description,
                        created_at=now,
                        valid_at=reference_time,
                    )
                )

                # Create default edge type map
                edge_type_map_default = (
                    {('Entity', 'Entity'): list(edge_types.keys())}
                    if edge_types is not None
                    else {('Entity', 'Entity'): []}
                )

                # Extract and resolve nodes
                extracted_nodes = await extract_nodes(
                    self.clients,
                    episode,
                    previous_episodes,
                    entity_types,
                    excluded_entity_types,
                    custom_extraction_instructions,
                )

                nodes, uuid_map, _ = await resolve_extracted_nodes(
                    self.clients,
                    extracted_nodes,
                    episode,
                    previous_episodes,
                    entity_types,
                )

                # Extract and resolve edges in parallel with attribute extraction
                resolved_edges, invalidated_edges = await self._extract_and_resolve_edges(
                    episode,
                    extracted_nodes,
                    previous_episodes,
                    edge_type_map or edge_type_map_default,
                    group_id,
                    edge_types,
                    nodes,
                    uuid_map,
                    custom_extraction_instructions,
                )

                # Extract node attributes
                hydrated_nodes = await extract_attributes_from_nodes(
                    self.clients, nodes, episode, previous_episodes, entity_types
                )

                entity_edges = resolved_edges + invalidated_edges

                # Process and save episode data
                episodic_edges, episode = await self._process_episode_data(
                    episode, hydrated_nodes, entity_edges, now
                )

                # Update communities if requested
                communities = []
                community_edges = []
                if update_communities:
                    communities, community_edges = await semaphore_gather(
                        *[
                            update_community(self.driver, self.llm_client, self.embedder, node)
                            for node in nodes
                        ],
                        max_coroutines=self.max_coroutines,
                    )

                end = time()

                # Add span attributes
                span.add_attributes(
                    {
                        'episode.uuid': episode.uuid,
                        'episode.source': source.value,
                        'episode.reference_time': reference_time.isoformat(),
                        'group_id': group_id,
                        'node.count': len(hydrated_nodes),
                        'edge.count': len(entity_edges),
                        'edge.invalidated_count': len(invalidated_edges),
                        'previous_episodes.count': len(previous_episodes),
                        'entity_types.count': len(entity_types) if entity_types else 0,
                        'edge_types.count': len(edge_types) if edge_types else 0,
                        'update_communities': update_communities,
                        'communities.count': len(communities) if update_communities else 0,
                        'duration_ms': (end - start) * 1000,
                    }
                )

                logger.info(f'Completed add_episode in {(end - start) * 1000} ms')

                return AddEpisodeResults(
                    episode=episode,
                    episodic_edges=episodic_edges,
                    nodes=hydrated_nodes,
                    edges=entity_edges,
                    communities=communities,
                    community_edges=community_edges,
                )

            except Exception as e:
                span.set_status('error', str(e))
                span.record_exception(e)
                raise e

    async def add_episode_bulk(
        self,
        bulk_episodes: list[RawEpisode],
        group_id: str | None = None,
        entity_types: dict[str, type[BaseModel]] | None = None,
        excluded_entity_types: list[str] | None = None,
        edge_types: dict[str, type[BaseModel]] | None = None,
        edge_type_map: dict[tuple[str, str], list[str]] | None = None,
    ) -> AddBulkEpisodeResults:
        """
        Process multiple episodes in bulk and update the graph.

        This method extracts information from multiple episodes, creates nodes and edges,
        and updates the graph database accordingly, all in a single batch operation.

        Parameters
        ----------
        bulk_episodes : list[RawEpisode]
            A list of RawEpisode objects to be processed and added to the graph.
        group_id : str | None
            An id for the graph partition the episodes are a part of.
        entity_types : dict[str, type[BaseModel]] | None
            Optional. Dictionary mapping entity type names to their Pydantic model definitions.
        excluded_entity_types : list[str] | None
            Optional. List of entity type names to exclude from the graph.
        edge_types : dict[str, type[BaseModel]] | None
            Optional. Dictionary mapping edge type names to their Pydantic model definitions.
        edge_type_map : dict[tuple[str, str], list[str]] | None
            Optional. Mapping from (source, target) entity type pairs to allowed edge type names.

        Returns
        -------
        AddBulkEpisodeResults

        Notes
        -----
        This method performs several steps including:
        - Saving all episodes to the database
        - Retrieving previous episode context for each new episode
        - Extracting nodes and edges from all episodes
        - Generating embeddings for nodes and edges
        - Deduplicating nodes and edges
        - Saving nodes, episodic edges, and entity edges to the knowledge graph

        This bulk operation is designed for efficiency when processing multiple episodes
        at once. However, it's important to ensure that the bulk operation doesn't
        overwhelm system resources. Consider implementing rate limiting or chunking for
        very large batches of episodes.

        Important: This method does not perform edge invalidation or date extraction steps.
        If these operations are required, use the `add_episode` method instead for each
        individual episode.
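
        Example (a sketch; RawEpisode fields are inferred from their use in this method):
            episodes = [
                RawEpisode(
                    uuid=None,
                    name='episode-1',
                    content='Alice met Bob in Paris.',
                    source=EpisodeType.text,
                    source_description='chat transcript',
                    reference_time=utc_now(),
                )
            ]
            results = await graphiti.add_episode_bulk(episodes)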
        """
        with self.tracer.start_span('add_episode_bulk') as bulk_span:
            bulk_span.add_attributes({'episode.count': len(bulk_episodes)})

            try:
                start = time()
                now = utc_now()

                # if group_id is None, use the default group id by the provider
                if group_id is None:
                    group_id = get_default_group_id(self.driver.provider)
                else:
                    validate_group_id(group_id)
                    if group_id != self.driver._database:
                        # if group_id is provided, use it as the database name
                        self.driver = self.driver.clone(database=group_id)
                        self.clients.driver = self.driver

                # Create default edge type map
                edge_type_map_default = (
                    {('Entity', 'Entity'): list(edge_types.keys())}
                    if edge_types is not None
                    else {('Entity', 'Entity'): []}
                )

                episodes = [
                    await EpisodicNode.get_by_uuid(self.driver, episode.uuid)
                    if episode.uuid is not None
                    else EpisodicNode(
                        name=episode.name,
                        labels=[],
                        source=episode.source,
                        content=episode.content,
                        source_description=episode.source_description,
                        group_id=group_id,
                        created_at=now,
                        valid_at=episode.reference_time,
                    )
                    for episode in bulk_episodes
                ]

                # Save all episodes
                await add_nodes_and_edges_bulk(
                    driver=self.driver,
                    episodic_nodes=episodes,
                    episodic_edges=[],
                    entity_nodes=[],
                    entity_edges=[],
                    embedder=self.embedder,
                )

                # Get previous episode context for each episode
                episode_context = await retrieve_previous_episodes_bulk(self.driver, episodes)

                # Extract and dedupe nodes and edges
                (
                    nodes_by_episode,
                    uuid_map,
                    extracted_edges_bulk,
                ) = await self._extract_and_dedupe_nodes_bulk(
                    episode_context,
                    edge_type_map or edge_type_map_default,
                    edge_types,
                    entity_types,
                    excluded_entity_types,
                )

                # Create Episodic Edges
                episodic_edges: list[EpisodicEdge] = []
                for episode_uuid, nodes in nodes_by_episode.items():
                    episodic_edges.extend(build_episodic_edges(nodes, episode_uuid, now))

                # Re-map edge pointers and dedupe edges
                extracted_edges_bulk_updated: list[list[EntityEdge]] = [
                    resolve_edge_pointers(edges, uuid_map) for edges in extracted_edges_bulk
                ]

                edges_by_episode = await dedupe_edges_bulk(
                    self.clients,
                    extracted_edges_bulk_updated,
                    episode_context,
                    [],
                    edge_types or {},
                    edge_type_map or edge_type_map_default,
                )

                # Resolve nodes and edges against the existing graph
                (
                    final_hydrated_nodes,
                    resolved_edges,
                    invalidated_edges,
                    final_uuid_map,
                ) = await self._resolve_nodes_and_edges_bulk(
                    nodes_by_episode,
                    edges_by_episode,
                    episode_context,
                    entity_types,
                    edge_types,
                    edge_type_map or edge_type_map_default,
                    episodes,
                )

                # Resolve pointers for episodic edges
                resolved_episodic_edges = resolve_edge_pointers(episodic_edges, final_uuid_map)

                # Save data to the knowledge graph
                await add_nodes_and_edges_bulk(
                    self.driver,
                    episodes,
                    resolved_episodic_edges,
                    final_hydrated_nodes,
                    resolved_edges + invalidated_edges,
                    self.embedder,
                )

                end = time()

                # Add span attributes
                bulk_span.add_attributes(
                    {
                        'group_id': group_id,
                        'node.count': len(final_hydrated_nodes),
                        'edge.count': len(resolved_edges + invalidated_edges),
                        'duration_ms': (end - start) * 1000,
                    }
                )

                logger.info(f'Completed add_episode_bulk in {(end - start) * 1000} ms')

                return AddBulkEpisodeResults(
                    episodes=episodes,
                    episodic_edges=resolved_episodic_edges,
                    nodes=final_hydrated_nodes,
                    edges=resolved_edges + invalidated_edges,
                    communities=[],
                    community_edges=[],
                )

            except Exception as e:
                bulk_span.set_status('error', str(e))
                bulk_span.record_exception(e)
                raise e

    @handle_multiple_group_ids
    async def build_communities(
        self, group_ids: list[str] | None = None, driver: GraphDriver | None = None
    ) -> tuple[list[CommunityNode], list[CommunityEdge]]:
        """
        Use a community clustering algorithm to find communities of nodes. Create community nodes summarising
        the content of these communities.
        ----------
        group_ids : list[str] | None
            Optional. Create communities only for the listed group_ids. If blank the entire graph will be used.
        """
        if driver is None:
            driver = self.clients.driver

        # Clear existing communities
        await remove_communities(driver)

        community_nodes, community_edges = await build_communities(
            driver, self.llm_client, group_ids
        )

        await semaphore_gather(
            *[node.generate_name_embedding(self.embedder) for node in community_nodes],
            max_coroutines=self.max_coroutines,
        )

        await semaphore_gather(
            *[node.save(driver) for node in community_nodes],
            max_coroutines=self.max_coroutines,
        )
        await semaphore_gather(
            *[edge.save(driver) for edge in community_edges],
            max_coroutines=self.max_coroutines,
        )

        return community_nodes, community_edges

    @handle_multiple_group_ids
    async def search(
        self,
        query: str,
        center_node_uuid: str | None = None,
        group_ids: list[str] | None = None,
        num_results=DEFAULT_SEARCH_LIMIT,
        search_filter: SearchFilters | None = None,
        driver: GraphDriver | None = None,
    ) -> list[EntityEdge]:
        """
        Perform a hybrid search on the knowledge graph.

        This method executes a search query on the graph, combining vector and
        text-based search techniques to retrieve relevant facts as entity edges.

        This is our basic out-of-the-box search; for more robust results we recommend
        our more advanced search method, graphiti.search_().

        Parameters
        ----------
        query : str
            The search query string.
        center_node_uuid: str, optional
            Facts will be reranked based on proximity to this node
        group_ids : list[str] | None, optional
            The graph partitions to return data from.
        num_results : int, optional
            The maximum number of results to return. Defaults to 10.

        Returns
        -------
        list[EntityEdge]
            A list of EntityEdge objects that are relevant to the search query.

        Notes
        -----
        This method uses a SearchConfig with num_episodes set to 0 and
        num_results set to the provided num_results parameter.

        The search is performed using the current date and time as the reference
        point for temporal relevance.
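
        Example (a sketch):
            edges = await graphiti.search('Who met Bob in Paris?', num_results=5)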
        """
        search_config = (
            EDGE_HYBRID_SEARCH_RRF if center_node_uuid is None else EDGE_HYBRID_SEARCH_NODE_DISTANCE
        )
        search_config.limit = num_results

        edges = (
            await search(
                self.clients,
                query,
                group_ids,
                search_config,
                search_filter if search_filter is not None else SearchFilters(),
                driver=driver,
                center_node_uuid=center_node_uuid,
            )
        ).edges

        return edges

    async def _search(
        self,
        query: str,
        config: SearchConfig,
        group_ids: list[str] | None = None,
        center_node_uuid: str | None = None,
        bfs_origin_node_uuids: list[str] | None = None,
        search_filter: SearchFilters | None = None,
    ) -> SearchResults:
        """DEPRECATED"""
        return await self.search_(
            query, config, group_ids, center_node_uuid, bfs_origin_node_uuids, search_filter
        )

    @handle_multiple_group_ids
    async def search_(
        self,
        query: str,
        config: SearchConfig = COMBINED_HYBRID_SEARCH_CROSS_ENCODER,
        group_ids: list[str] | None = None,
        center_node_uuid: str | None = None,
        bfs_origin_node_uuids: list[str] | None = None,
        search_filter: SearchFilters | None = None,
        driver: GraphDriver | None = None,
    ) -> SearchResults:
        """search_ (replaces _search) is our advanced search method that returns Graph objects (nodes and edges) rather
        than a list of facts. This endpoint allows the end user to utilize more advanced features such as filters and
        different search and reranker methodologies across different layers in the graph.

        For different config recipes refer to search/search_config_recipes.
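
        Example (a sketch, using a recipe already imported in this module):
            results = await graphiti.search_('query', config=EDGE_HYBRID_SEARCH_RRF)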
        """

        return await search(
            self.clients,
            query,
            group_ids,
            config,
            search_filter if search_filter is not None else SearchFilters(),
            center_node_uuid,
            bfs_origin_node_uuids,
            driver=driver,
        )

    async def get_nodes_and_edges_by_episode(self, episode_uuids: list[str]) -> SearchResults:
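        """Return the entity edges referenced by the given episodes and the entity nodes they mention."""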
        episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids)

        edges_list = await semaphore_gather(
            *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes],
            max_coroutines=self.max_coroutines,
        )

        edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst]

        nodes = await get_mentioned_nodes(self.driver, episodes)

        return SearchResults(edges=edges, nodes=nodes)

    async def add_triplet(
        self, source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
    ) -> AddTripletResults:
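        """
        Add a single (source_node, edge, target_node) triplet directly to the graph.

        Missing embeddings are generated, both endpoint nodes are resolved against
        existing entities, and the edge is resolved against related and existing
        edges (potentially invalidating some) before everything is saved.
        """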
        if source_node.name_embedding is None:
            await source_node.generate_name_embedding(self.embedder)
        if target_node.name_embedding is None:
            await target_node.generate_name_embedding(self.embedder)
        if edge.fact_embedding is None:
            await edge.generate_embedding(self.embedder)

        try:
            resolved_source = await EntityNode.get_by_uuid(self.driver, source_node.uuid)
        except NodeNotFoundError:
            resolved_source_nodes, _, _ = await resolve_extracted_nodes(
                self.clients,
                [source_node],
            )
            resolved_source = resolved_source_nodes[0]

        try:
            resolved_target = await EntityNode.get_by_uuid(self.driver, target_node.uuid)
        except NodeNotFoundError:
            resolved_target_nodes, _, _ = await resolve_extracted_nodes(
                self.clients,
                [target_node],
            )
            resolved_target = resolved_target_nodes[0]

        nodes = [resolved_source, resolved_target]

        # Merge user-provided properties from original nodes into resolved nodes (excluding uuid)
        # Update attributes dictionary (merge rather than replace)
        if source_node.attributes:
            resolved_source.attributes.update(source_node.attributes)
        if target_node.attributes:
            resolved_target.attributes.update(target_node.attributes)

        # Update summary if provided by user (non-empty string)
        if source_node.summary:
            resolved_source.summary = source_node.summary
        if target_node.summary:
            resolved_target.summary = target_node.summary

        # Update labels (merge with existing)
        if source_node.labels:
            resolved_source.labels = list(set(resolved_source.labels) | set(source_node.labels))
        if target_node.labels:
            resolved_target.labels = list(set(resolved_target.labels) | set(target_node.labels))

        edge.source_node_uuid = resolved_source.uuid
        edge.target_node_uuid = resolved_target.uuid

        valid_edges = await EntityEdge.get_between_nodes(
            self.driver, edge.source_node_uuid, edge.target_node_uuid
        )

        related_edges = (
            await search(
                self.clients,
                edge.fact,
                group_ids=[edge.group_id],
                config=EDGE_HYBRID_SEARCH_RRF,
                search_filter=SearchFilters(edge_uuids=[edge.uuid for edge in valid_edges]),
            )
        ).edges
        existing_edges = (
            await search(
                self.clients,
                edge.fact,
                group_ids=[edge.group_id],
                config=EDGE_HYBRID_SEARCH_RRF,
                search_filter=SearchFilters(),
            )
        ).edges

        resolved_edge, invalidated_edges, _ = await resolve_extracted_edge(
            self.llm_client,
            edge,
            related_edges,
            existing_edges,
            EpisodicNode(
                name='',
                source=EpisodeType.text,
                source_description='',
                content='',
                valid_at=edge.valid_at or utc_now(),
                entity_edges=[],
                group_id=edge.group_id,
            ),
            None,
            None,
        )

        edges: list[EntityEdge] = [resolved_edge] + invalidated_edges

        await create_entity_edge_embeddings(self.embedder, edges)
        await create_entity_node_embeddings(self.embedder, nodes)

        await add_nodes_and_edges_bulk(self.driver, [], [], nodes, edges, self.embedder)
        return AddTripletResults(edges=edges, nodes=nodes)

    async def remove_episode(self, episode_uuid: str):
        # Find the episode to be deleted
        episode = await EpisodicNode.get_by_uuid(self.driver, episode_uuid)

        # Find edges mentioned by the episode
        edges = await EntityEdge.get_by_uuids(self.driver, episode.entity_edges)

        # We should only delete edges created by this episode, i.e. where it is the first recorded mention
        edges_to_delete: list[EntityEdge] = []
        for edge in edges:
            if edge.episodes and edge.episodes[0] == episode.uuid:
                edges_to_delete.append(edge)

        # Find nodes mentioned by the episode
        nodes = await get_mentioned_nodes(self.driver, [episode])
        # We should delete all nodes that are only mentioned in the deleted episode
        nodes_to_delete: list[EntityNode] = []
        for node in nodes:
            query: LiteralString = 'MATCH (e:Episodic)-[:MENTIONS]->(n:Entity {uuid: $uuid}) RETURN count(*) AS episode_count'
            records, _, _ = await self.driver.execute_query(query, uuid=node.uuid, routing_='r')

            for record in records:
                if record['episode_count'] == 1:
                    nodes_to_delete.append(node)

        await Edge.delete_by_uuids(self.driver, [edge.uuid for edge in edges_to_delete])
        await Node.delete_by_uuids(self.driver, [node.uuid for node in nodes_to_delete])

        await episode.delete(self.driver)

```
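
For orientation, here is a minimal, hedged sketch of how the `add_triplet` and `remove_episode` methods above might be driven. It assumes the `Graphiti(uri, user, password)` constructor and the `add_triplet(source_node, edge, target_node)` signature from earlier pages of this file, default OpenAI-backed clients configured via the environment, and placeholder connection details and group id.

```python
import asyncio
from datetime import datetime, timezone

from graphiti_core.edges import EntityEdge
from graphiti_core.graphiti import Graphiti
from graphiti_core.nodes import EntityNode


async def main() -> None:
    # Placeholder connection details; assumes a running Neo4j instance.
    graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password')

    now = datetime.now(timezone.utc)
    alice = EntityNode(name='Alice', group_id='demo', labels=['Entity'], created_at=now)
    bob = EntityNode(name='Bob', group_id='demo', labels=['Entity'], created_at=now)
    likes = EntityEdge(
        source_node_uuid=alice.uuid,
        target_node_uuid=bob.uuid,
        name='LIKES',
        fact='Alice likes Bob',
        created_at=now,
        group_id='demo',
    )

    # Resolves both endpoints against the graph, dedupes the edge against
    # related facts, invalidates contradicted ones, and persists the result.
    result = await graphiti.add_triplet(alice, likes, bob)
    print(f'{len(result.nodes)} nodes, {len(result.edges)} edges upserted')

    await graphiti.close()


asyncio.run(main())
```

`remove_episode(episode_uuid)` is the inverse of an ingestion, as the method above shows: it deletes only the edges the episode created and the nodes mentioned by no other episode, then removes the episode itself.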

--------------------------------------------------------------------------------
/tests/test_graphiti_mock.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from datetime import datetime, timedelta
from unittest.mock import Mock

import numpy as np
import pytest

from graphiti_core.cross_encoder.client import CrossEncoderClient
from graphiti_core.edges import CommunityEdge, EntityEdge, EpisodicEdge
from graphiti_core.graphiti import Graphiti
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import CommunityNode, EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search_filters import ComparisonOperator, DateFilter, SearchFilters
from graphiti_core.search.search_utils import (
    community_fulltext_search,
    community_similarity_search,
    edge_bfs_search,
    edge_fulltext_search,
    edge_similarity_search,
    episode_fulltext_search,
    episode_mentions_reranker,
    get_communities_by_nodes,
    get_edge_invalidation_candidates,
    get_embeddings_for_communities,
    get_embeddings_for_edges,
    get_embeddings_for_nodes,
    get_mentioned_nodes,
    get_relevant_edges,
    get_relevant_nodes,
    node_bfs_search,
    node_distance_reranker,
    node_fulltext_search,
    node_similarity_search,
)
from graphiti_core.utils.bulk_utils import add_nodes_and_edges_bulk
from graphiti_core.utils.maintenance.community_operations import (
    determine_entity_community,
    get_community_clusters,
    remove_communities,
)
from graphiti_core.utils.maintenance.edge_operations import filter_existing_duplicate_of_edges
from tests.helpers_test import (
    GraphProvider,
    assert_entity_edge_equals,
    assert_entity_node_equals,
    assert_episodic_edge_equals,
    assert_episodic_node_equals,
    get_edge_count,
    get_node_count,
    group_id,
    group_id_2,
)

pytest_plugins = ('pytest_asyncio',)


@pytest.fixture
def mock_llm_client():
    """Create a mock LLM"""
    mock_llm = Mock(spec=LLMClient)
    mock_llm.config = Mock()
    mock_llm.model = 'test-model'
    mock_llm.small_model = 'test-small-model'
    mock_llm.temperature = 0.0
    mock_llm.max_tokens = 1000
    mock_llm.cache_enabled = False
    mock_llm.cache_dir = None

    # Mock the public method that's actually called
    mock_llm.generate_response = Mock()
    mock_llm.generate_response.return_value = {
        'tool_calls': [
            {
                'name': 'extract_entities',
                'arguments': {'entities': [{'entity': 'test_entity', 'entity_type': 'test_type'}]},
            }
        ]
    }

    return mock_llm


@pytest.fixture
def mock_cross_encoder_client():
    """Create a mock CrossEncoderClient"""
    mock_encoder = Mock(spec=CrossEncoderClient)
    mock_encoder.config = Mock()

    # Mock the public method that's actually called
    mock_encoder.rerank = Mock()
    mock_encoder.rerank.return_value = {
        'tool_calls': [
            {
                'name': 'extract_entities',
                'arguments': {'entities': [{'entity': 'test_entity', 'entity_type': 'test_type'}]},
            }
        ]
    }

    return mock_encoder


@pytest.mark.asyncio
async def test_add_bulk(graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client):
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as test fails on FalkorDB')

    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )

    await graphiti.build_indices_and_constraints()

    now = datetime.now()

    # Create episodic nodes
    episode_node_1 = EpisodicNode(
        name='test_episode',
        group_id=group_id,
        labels=[],
        created_at=now,
        source=EpisodeType.message,
        source_description='conversation message',
        content='Alice likes Bob',
        valid_at=now,
        entity_edges=[],  # Filled in later
    )
    episode_node_2 = EpisodicNode(
        name='test_episode_2',
        group_id=group_id,
        labels=[],
        created_at=now,
        source=EpisodeType.message,
        source_description='conversation message',
        content='Bob adores Alice',
        valid_at=now,
        entity_edges=[],  # Filled in later
    )

    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        group_id=group_id,
        labels=['Entity', 'Person'],
        created_at=now,
        summary='test_entity_1 summary',
        attributes={'age': 30, 'location': 'New York'},
    )
    await entity_node_1.generate_name_embedding(mock_embedder)

    entity_node_2 = EntityNode(
        name='test_entity_2',
        group_id=group_id,
        labels=['Entity', 'Person2'],
        created_at=now,
        summary='test_entity_2 summary',
        attributes={'age': 25, 'location': 'Los Angeles'},
    )
    await entity_node_2.generate_name_embedding(mock_embedder)

    entity_node_3 = EntityNode(
        name='test_entity_3',
        group_id=group_id,
        labels=['Entity', 'City', 'Location'],
        created_at=now,
        summary='test_entity_3 summary',
        attributes={'age': 25, 'location': 'Los Angeles'},
    )
    await entity_node_3.generate_name_embedding(mock_embedder)

    entity_node_4 = EntityNode(
        name='test_entity_4',
        group_id=group_id,
        labels=['Entity'],
        created_at=now,
        summary='test_entity_4 summary',
        attributes={'age': 25, 'location': 'Los Angeles'},
    )
    await entity_node_4.generate_name_embedding(mock_embedder)

    # Create entity edges
    entity_edge_1 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        created_at=now,
        name='likes',
        fact='test_entity_1 relates to test_entity_2',
        episodes=[],
        expired_at=now,
        valid_at=now,
        invalid_at=now,
        group_id=group_id,
    )
    await entity_edge_1.generate_embedding(mock_embedder)

    entity_edge_2 = EntityEdge(
        source_node_uuid=entity_node_3.uuid,
        target_node_uuid=entity_node_4.uuid,
        created_at=now,
        name='relates_to',
        fact='test_entity_3 relates to test_entity_4',
        episodes=[],
        expired_at=now,
        valid_at=now,
        invalid_at=now,
        group_id=group_id,
    )
    await entity_edge_2.generate_embedding(mock_embedder)

    # Create episodic to entity edges
    episodic_edge_1 = EpisodicEdge(
        source_node_uuid=episode_node_1.uuid,
        target_node_uuid=entity_node_1.uuid,
        created_at=now,
        group_id=group_id,
    )
    episodic_edge_2 = EpisodicEdge(
        source_node_uuid=episode_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        created_at=now,
        group_id=group_id,
    )
    episodic_edge_3 = EpisodicEdge(
        source_node_uuid=episode_node_2.uuid,
        target_node_uuid=entity_node_3.uuid,
        created_at=now,
        group_id=group_id,
    )
    episodic_edge_4 = EpisodicEdge(
        source_node_uuid=episode_node_2.uuid,
        target_node_uuid=entity_node_4.uuid,
        created_at=now,
        group_id=group_id,
    )

    # Cross reference the ids
    episode_node_1.entity_edges = [entity_edge_1.uuid]
    episode_node_2.entity_edges = [entity_edge_2.uuid]
    entity_edge_1.episodes = [episode_node_1.uuid, episode_node_2.uuid]
    entity_edge_2.episodes = [episode_node_2.uuid]

    # Test add bulk
    await add_nodes_and_edges_bulk(
        graph_driver,
        [episode_node_1, episode_node_2],
        [episodic_edge_1, episodic_edge_2, episodic_edge_3, episodic_edge_4],
        [entity_node_1, entity_node_2, entity_node_3, entity_node_4],
        [entity_edge_1, entity_edge_2],
        mock_embedder,
    )

    node_ids = [
        episode_node_1.uuid,
        episode_node_2.uuid,
        entity_node_1.uuid,
        entity_node_2.uuid,
        entity_node_3.uuid,
        entity_node_4.uuid,
    ]
    edge_ids = [
        episodic_edge_1.uuid,
        episodic_edge_2.uuid,
        episodic_edge_3.uuid,
        episodic_edge_4.uuid,
        entity_edge_1.uuid,
        entity_edge_2.uuid,
    ]
    node_count = await get_node_count(graph_driver, node_ids)
    assert node_count == len(node_ids)
    edge_count = await get_edge_count(graph_driver, edge_ids)
    assert edge_count == len(edge_ids)

    # Test episodic nodes
    retrieved_episode = await EpisodicNode.get_by_uuid(graph_driver, episode_node_1.uuid)
    await assert_episodic_node_equals(retrieved_episode, episode_node_1)

    retrieved_episode = await EpisodicNode.get_by_uuid(graph_driver, episode_node_2.uuid)
    await assert_episodic_node_equals(retrieved_episode, episode_node_2)

    # Test entity nodes
    retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_1.uuid)
    await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_1)

    retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_2.uuid)
    await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_2)

    retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_3.uuid)
    await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_3)

    retrieved_entity_node = await EntityNode.get_by_uuid(graph_driver, entity_node_4.uuid)
    await assert_entity_node_equals(graph_driver, retrieved_entity_node, entity_node_4)

    # Test episodic edges
    retrieved_episode_edge = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge_1.uuid)
    await assert_episodic_edge_equals(retrieved_episode_edge, episodic_edge_1)

    retrieved_episode_edge = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge_2.uuid)
    await assert_episodic_edge_equals(retrieved_episode_edge, episodic_edge_2)

    retrieved_episode_edge = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge_3.uuid)
    await assert_episodic_edge_equals(retrieved_episode_edge, episodic_edge_3)

    retrieved_episode_edge = await EpisodicEdge.get_by_uuid(graph_driver, episodic_edge_4.uuid)
    await assert_episodic_edge_equals(retrieved_episode_edge, episodic_edge_4)

    # Test entity edges
    retrieved_entity_edge = await EntityEdge.get_by_uuid(graph_driver, entity_edge_1.uuid)
    await assert_entity_edge_equals(graph_driver, retrieved_entity_edge, entity_edge_1)

    retrieved_entity_edge = await EntityEdge.get_by_uuid(graph_driver, entity_edge_2.uuid)
    await assert_entity_edge_equals(graph_driver, retrieved_entity_edge, entity_edge_2)


@pytest.mark.asyncio
async def test_remove_episode(
    graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client
):
    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )

    await graphiti.build_indices_and_constraints()

    now = datetime.now()

    # Create episodic nodes
    episode_node = EpisodicNode(
        name='test_episode',
        group_id=group_id,
        labels=[],
        created_at=now,
        source=EpisodeType.message,
        source_description='conversation message',
        content='Alice likes Bob',
        valid_at=now,
        entity_edges=[],  # Filled in later
    )

    # Create entity nodes
    alice_node = EntityNode(
        name='Alice',
        group_id=group_id,
        labels=['Entity', 'Person'],
        created_at=now,
        summary='Alice summary',
        attributes={'age': 30, 'location': 'New York'},
    )
    await alice_node.generate_name_embedding(mock_embedder)

    bob_node = EntityNode(
        name='Bob',
        group_id=group_id,
        labels=['Entity', 'Person2'],
        created_at=now,
        summary='Bob summary',
        attributes={'age': 25, 'location': 'Los Angeles'},
    )
    await bob_node.generate_name_embedding(mock_embedder)

    # Create entity to entity edge
    entity_edge = EntityEdge(
        source_node_uuid=alice_node.uuid,
        target_node_uuid=bob_node.uuid,
        created_at=now,
        name='likes',
        fact='Alice likes Bob',
        episodes=[],
        expired_at=now,
        valid_at=now,
        invalid_at=now,
        group_id=group_id,
    )
    await entity_edge.generate_embedding(mock_embedder)

    # Create episodic to entity edges
    episodic_alice_edge = EpisodicEdge(
        source_node_uuid=episode_node.uuid,
        target_node_uuid=alice_node.uuid,
        created_at=now,
        group_id=group_id,
    )
    episodic_bob_edge = EpisodicEdge(
        source_node_uuid=episode_node.uuid,
        target_node_uuid=bob_node.uuid,
        created_at=now,
        group_id=group_id,
    )

    # Cross reference the ids
    episode_node.entity_edges = [entity_edge.uuid]
    entity_edge.episodes = [episode_node.uuid]

    # Test add bulk
    await add_nodes_and_edges_bulk(
        graph_driver,
        [episode_node],
        [episodic_alice_edge, episodic_bob_edge],
        [alice_node, bob_node],
        [entity_edge],
        mock_embedder,
    )

    node_ids = [episode_node.uuid, alice_node.uuid, bob_node.uuid]
    edge_ids = [episodic_alice_edge.uuid, episodic_bob_edge.uuid, entity_edge.uuid]
    node_count = await get_node_count(graph_driver, node_ids)
    assert node_count == 3
    edge_count = await get_edge_count(graph_driver, edge_ids)
    assert edge_count == 3

    # Test remove episode
    await graphiti.remove_episode(episode_node.uuid)
    node_count = await get_node_count(graph_driver, node_ids)
    assert node_count == 0
    edge_count = await get_edge_count(graph_driver, edge_ids)
    assert edge_count == 0

    # Test add bulk again
    await add_nodes_and_edges_bulk(
        graph_driver,
        [episode_node],
        [episodic_alice_edge, episodic_bob_edge],
        [alice_node, bob_node],
        [entity_edge],
        mock_embedder,
    )
    node_count = await get_node_count(graph_driver, node_ids)
    assert node_count == 3
    edge_count = await get_edge_count(graph_driver, edge_ids)
    assert edge_count == 3


@pytest.mark.asyncio
async def test_graphiti_retrieve_episodes(
    graph_driver, mock_llm_client, mock_embedder, mock_cross_encoder_client
):
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as test fails on FalkorDB')

    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )

    await graphiti.build_indices_and_constraints()

    now = datetime.now()
    valid_at_1 = now - timedelta(days=2)
    valid_at_2 = now - timedelta(days=4)
    valid_at_3 = now - timedelta(days=6)

    # Create episodic nodes
    episode_node_1 = EpisodicNode(
        name='test_episode_1',
        labels=[],
        created_at=now,
        valid_at=valid_at_1,
        source=EpisodeType.message,
        source_description='conversation message',
        content='Test message 1',
        entity_edges=[],
        group_id=group_id,
    )
    episode_node_2 = EpisodicNode(
        name='test_episode_2',
        labels=[],
        created_at=now,
        valid_at=valid_at_2,
        source=EpisodeType.message,
        source_description='conversation message',
        content='Test message 2',
        entity_edges=[],
        group_id=group_id,
    )
    episode_node_3 = EpisodicNode(
        name='test_episode_3',
        labels=[],
        created_at=now,
        valid_at=valid_at_3,
        source=EpisodeType.message,
        source_description='conversation message',
        content='Test message 3',
        entity_edges=[],
        group_id=group_id,
    )

    # Save the nodes
    await episode_node_1.save(graph_driver)
    await episode_node_2.save(graph_driver)
    await episode_node_3.save(graph_driver)

    node_ids = [episode_node_1.uuid, episode_node_2.uuid, episode_node_3.uuid]
    node_count = await get_node_count(graph_driver, node_ids)
    assert node_count == 3

    # Retrieve episodes
    query_time = now - timedelta(days=3)
    episodes = await graphiti.retrieve_episodes(
        query_time, last_n=5, group_ids=[group_id], source=EpisodeType.message
    )
    assert len(episodes) == 2
    assert episodes[0].name == episode_node_3.name
    assert episodes[1].name == episode_node_2.name


@pytest.mark.asyncio
async def test_filter_existing_duplicate_of_edges(graph_driver, mock_embedder):
    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    entity_node_3 = EntityNode(
        name='test_entity_3',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_3.generate_name_embedding(mock_embedder)
    entity_node_4 = EntityNode(
        name='test_entity_4',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_4.generate_name_embedding(mock_embedder)

    # Save the nodes
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_node_3.save(graph_driver)
    await entity_node_4.save(graph_driver)

    node_ids = [entity_node_1.uuid, entity_node_2.uuid, entity_node_3.uuid, entity_node_4.uuid]
    node_count = await get_node_count(graph_driver, node_ids)
    assert node_count == 4

    # Create duplicate entity edge
    entity_edge = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        name='IS_DUPLICATE_OF',
        fact='test_entity_1 is a duplicate of test_entity_2',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_edge.generate_embedding(mock_embedder)
    await entity_edge.save(graph_driver)

    # Filter duplicate entity edges
    duplicate_node_tuples = [
        (entity_node_1, entity_node_2),
        (entity_node_3, entity_node_4),
    ]
    node_tuples = await filter_existing_duplicate_of_edges(graph_driver, duplicate_node_tuples)
    assert len(node_tuples) == 1
    assert [node.name for node in node_tuples[0]] == [entity_node_3.name, entity_node_4.name]


@pytest.mark.asyncio
async def test_determine_entity_community(graph_driver, mock_embedder):
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as test fails on FalkorDB')

    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    entity_node_3 = EntityNode(
        name='test_entity_3',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_3.generate_name_embedding(mock_embedder)
    entity_node_4 = EntityNode(
        name='test_entity_4',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_4.generate_name_embedding(mock_embedder)

    # Create entity edges
    entity_edge_1 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_4.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_4',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_edge_1.generate_embedding(mock_embedder)
    entity_edge_2 = EntityEdge(
        source_node_uuid=entity_node_2.uuid,
        target_node_uuid=entity_node_4.uuid,
        name='RELATES_TO',
        fact='test_entity_2 relates to test_entity_4',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_edge_2.generate_embedding(mock_embedder)
    entity_edge_3 = EntityEdge(
        source_node_uuid=entity_node_3.uuid,
        target_node_uuid=entity_node_4.uuid,
        name='RELATES_TO',
        fact='test_entity_3 relates to test_entity_4',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_edge_3.generate_embedding(mock_embedder)

    # Create community nodes
    community_node_1 = CommunityNode(
        name='test_community_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await community_node_1.generate_name_embedding(mock_embedder)
    community_node_2 = CommunityNode(
        name='test_community_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await community_node_2.generate_name_embedding(mock_embedder)

    # Create community to entity edges
    community_edge_1 = CommunityEdge(
        source_node_uuid=community_node_1.uuid,
        target_node_uuid=entity_node_1.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )
    community_edge_2 = CommunityEdge(
        source_node_uuid=community_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )
    community_edge_3 = CommunityEdge(
        source_node_uuid=community_node_2.uuid,
        target_node_uuid=entity_node_3.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )

    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_node_3.save(graph_driver)
    await entity_node_4.save(graph_driver)
    await community_node_1.save(graph_driver)
    await community_node_2.save(graph_driver)

    await entity_edge_1.save(graph_driver)
    await entity_edge_2.save(graph_driver)
    await entity_edge_3.save(graph_driver)
    await community_edge_1.save(graph_driver)
    await community_edge_2.save(graph_driver)
    await community_edge_3.save(graph_driver)

    node_ids = [
        entity_node_1.uuid,
        entity_node_2.uuid,
        entity_node_3.uuid,
        entity_node_4.uuid,
        community_node_1.uuid,
        community_node_2.uuid,
    ]
    edge_ids = [
        entity_edge_1.uuid,
        entity_edge_2.uuid,
        entity_edge_3.uuid,
        community_edge_1.uuid,
        community_edge_2.uuid,
        community_edge_3.uuid,
    ]
    node_count = await get_node_count(graph_driver, node_ids)
    assert node_count == 6
    edge_count = await get_edge_count(graph_driver, edge_ids)
    assert edge_count == 6

    # Determine entity community
    community, is_new = await determine_entity_community(graph_driver, entity_node_4)
    assert community.name == community_node_1.name
    assert is_new

    # Add entity to community edge
    community_edge_4 = CommunityEdge(
        source_node_uuid=community_node_1.uuid,
        target_node_uuid=entity_node_4.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )
    await community_edge_4.save(graph_driver)

    # Determine entity community again
    community, is_new = await determine_entity_community(graph_driver, entity_node_4)
    assert community.name == community_node_1.name
    assert not is_new

    await remove_communities(graph_driver)
    node_count = await get_node_count(graph_driver, [community_node_1.uuid, community_node_2.uuid])
    assert node_count == 0


@pytest.mark.asyncio
async def test_get_community_clusters(graph_driver, mock_embedder):
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as test fails on FalkorDB')

    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    entity_node_3 = EntityNode(
        name='test_entity_3',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id_2,
    )
    await entity_node_3.generate_name_embedding(mock_embedder)
    entity_node_4 = EntityNode(
        name='test_entity_4',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id_2,
    )
    await entity_node_4.generate_name_embedding(mock_embedder)

    # Create entity edges
    entity_edge_1 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_edge_1.generate_embedding(mock_embedder)
    entity_edge_2 = EntityEdge(
        source_node_uuid=entity_node_3.uuid,
        target_node_uuid=entity_node_4.uuid,
        name='RELATES_TO',
        fact='test_entity_3 relates to test_entity_4',
        created_at=datetime.now(),
        group_id=group_id_2,
    )
    await entity_edge_2.generate_embedding(mock_embedder)

    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_node_3.save(graph_driver)
    await entity_node_4.save(graph_driver)
    await entity_edge_1.save(graph_driver)
    await entity_edge_2.save(graph_driver)

    node_ids = [entity_node_1.uuid, entity_node_2.uuid, entity_node_3.uuid, entity_node_4.uuid]
    edge_ids = [entity_edge_1.uuid, entity_edge_2.uuid]
    node_count = await get_node_count(graph_driver, node_ids)
    assert node_count == 4
    edge_count = await get_edge_count(graph_driver, edge_ids)
    assert edge_count == 2

    # Get community clusters
    clusters = await get_community_clusters(graph_driver, group_ids=None)
    assert len(clusters) == 2
    assert len(clusters[0]) == 2
    assert len(clusters[1]) == 2
    entities_1 = {node.name for node in clusters[0]}
    entities_2 = {node.name for node in clusters[1]}
    assert {'test_entity_1', 'test_entity_2'} in (entities_1, entities_2)
    assert {'test_entity_3', 'test_entity_4'} in (entities_1, entities_2)


@pytest.mark.asyncio
async def test_get_mentioned_nodes(graph_driver, mock_embedder):
    # Create episodic nodes
    episodic_node_1 = EpisodicNode(
        name='test_episodic_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='test_source_description',
        content='test_content',
        valid_at=datetime.now(),
    )
    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)

    # Create episodic to entity edges
    episodic_edge_1 = EpisodicEdge(
        source_node_uuid=episodic_node_1.uuid,
        target_node_uuid=entity_node_1.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )

    # Save the graph
    await episodic_node_1.save(graph_driver)
    await entity_node_1.save(graph_driver)
    await episodic_edge_1.save(graph_driver)

    # Get mentioned nodes
    mentioned_nodes = await get_mentioned_nodes(graph_driver, [episodic_node_1])
    assert len(mentioned_nodes) == 1
    assert mentioned_nodes[0].name == entity_node_1.name


@pytest.mark.asyncio
async def test_get_communities_by_nodes(graph_driver, mock_embedder):
    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)

    # Create community nodes
    community_node_1 = CommunityNode(
        name='test_community_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await community_node_1.generate_name_embedding(mock_embedder)

    # Create community to entity edges
    community_edge_1 = CommunityEdge(
        source_node_uuid=community_node_1.uuid,
        target_node_uuid=entity_node_1.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )

    # Save the graph
    await entity_node_1.save(graph_driver)
    await community_node_1.save(graph_driver)
    await community_edge_1.save(graph_driver)

    # Get communities by nodes
    communities = await get_communities_by_nodes(graph_driver, [entity_node_1])
    assert len(communities) == 1
    assert communities[0].name == community_node_1.name


@pytest.mark.asyncio
async def test_edge_fulltext_search(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    if graph_driver.provider == GraphProvider.KUZU:
        pytest.skip('Skipping as fulltext indexing not supported for Kuzu')

    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()

    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)

    now = datetime.now()
    created_at = now
    expired_at = now + timedelta(days=6)
    valid_at = now + timedelta(days=2)
    invalid_at = now + timedelta(days=4)

    # Create entity edges
    entity_edge_1 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=created_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        expired_at=expired_at,
        group_id=group_id,
    )
    await entity_edge_1.generate_embedding(mock_embedder)

    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_edge_1.save(graph_driver)

    # Search for entity edges
    search_filters = SearchFilters(
        node_labels=['Entity'],
        edge_types=['RELATES_TO'],
        created_at=[
            [DateFilter(date=created_at, comparison_operator=ComparisonOperator.equals)],
        ],
        expired_at=[
            [DateFilter(date=now, comparison_operator=ComparisonOperator.not_equals)],
        ],
        valid_at=[
            [
                DateFilter(
                    date=now + timedelta(days=1),
                    comparison_operator=ComparisonOperator.greater_than_equal,
                )
            ],
            [
                DateFilter(
                    date=now + timedelta(days=3),
                    comparison_operator=ComparisonOperator.less_than_equal,
                )
            ],
        ],
        invalid_at=[
            [
                DateFilter(
                    date=now + timedelta(days=3),
                    comparison_operator=ComparisonOperator.greater_than,
                )
            ],
            [
                DateFilter(
                    date=now + timedelta(days=5), comparison_operator=ComparisonOperator.less_than
                )
            ],
        ],
    )
    edges = await edge_fulltext_search(
        graph_driver, 'test_entity_1 relates to test_entity_2', search_filters, group_ids=[group_id]
    )
    assert len(edges) == 1
    assert edges[0].name == entity_edge_1.name


@pytest.mark.asyncio
async def test_edge_similarity_search(graph_driver, mock_embedder):
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on FalkorDB')

    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)

    now = datetime.now()
    created_at = now
    expired_at = now + timedelta(days=6)
    valid_at = now + timedelta(days=2)
    invalid_at = now + timedelta(days=4)

    # Create entity edges
    entity_edge_1 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=created_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        expired_at=expired_at,
        group_id=group_id,
    )
    await entity_edge_1.generate_embedding(mock_embedder)

    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_edge_1.save(graph_driver)

    # Search for entity edges
    search_filters = SearchFilters(
        node_labels=['Entity'],
        edge_types=['RELATES_TO'],
        created_at=[
            [DateFilter(date=created_at, comparison_operator=ComparisonOperator.equals)],
        ],
        expired_at=[
            [DateFilter(date=now, comparison_operator=ComparisonOperator.not_equals)],
        ],
        valid_at=[
            [
                DateFilter(
                    date=now + timedelta(days=1),
                    comparison_operator=ComparisonOperator.greater_than_equal,
                )
            ],
            [
                DateFilter(
                    date=now + timedelta(days=3),
                    comparison_operator=ComparisonOperator.less_than_equal,
                )
            ],
        ],
        invalid_at=[
            [
                DateFilter(
                    date=now + timedelta(days=3),
                    comparison_operator=ComparisonOperator.greater_than,
                )
            ],
            [
                DateFilter(
                    date=now + timedelta(days=5), comparison_operator=ComparisonOperator.less_than
                )
            ],
        ],
    )
    edges = await edge_similarity_search(
        graph_driver,
        entity_edge_1.fact_embedding,
        entity_node_1.uuid,
        entity_node_2.uuid,
        search_filters,
        group_ids=[group_id],
    )
    assert len(edges) == 1
    assert edges[0].name == entity_edge_1.name


@pytest.mark.asyncio
async def test_edge_bfs_search(graph_driver, mock_embedder):
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on FalkorDB')

    # Create episodic nodes
    episodic_node_1 = EpisodicNode(
        name='test_episodic_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='test_source_description',
        content='test_content',
        valid_at=datetime.now(),
    )

    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    entity_node_3 = EntityNode(
        name='test_entity_3',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_3.generate_name_embedding(mock_embedder)

    now = datetime.now()
    created_at = now
    expired_at = now + timedelta(days=6)
    valid_at = now + timedelta(days=2)
    invalid_at = now + timedelta(days=4)

    # Create entity edges
    entity_edge_1 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=created_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        expired_at=expired_at,
        group_id=group_id,
    )
    await entity_edge_1.generate_embedding(mock_embedder)
    entity_edge_2 = EntityEdge(
        source_node_uuid=entity_node_2.uuid,
        target_node_uuid=entity_node_3.uuid,
        name='RELATES_TO',
        fact='test_entity_2 relates to test_entity_3',
        created_at=created_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        expired_at=expired_at,
        group_id=group_id,
    )
    await entity_edge_2.generate_embedding(mock_embedder)

    # Create episodic to entity edges
    episodic_edge_1 = EpisodicEdge(
        source_node_uuid=episodic_node_1.uuid,
        target_node_uuid=entity_node_1.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )

    # Save the graph
    await episodic_node_1.save(graph_driver)
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_node_3.save(graph_driver)
    await entity_edge_1.save(graph_driver)
    await entity_edge_2.save(graph_driver)
    await episodic_edge_1.save(graph_driver)

    # Search for entity edges
    search_filters = SearchFilters(
        node_labels=['Entity'],
        edge_types=['RELATES_TO'],
        created_at=[
            [DateFilter(date=created_at, comparison_operator=ComparisonOperator.equals)],
        ],
        expired_at=[
            [DateFilter(date=now, comparison_operator=ComparisonOperator.not_equals)],
        ],
        valid_at=[
            [
                DateFilter(
                    date=now + timedelta(days=1),
                    comparison_operator=ComparisonOperator.greater_than_equal,
                )
            ],
            [
                DateFilter(
                    date=now + timedelta(days=3),
                    comparison_operator=ComparisonOperator.less_than_equal,
                )
            ],
        ],
        invalid_at=[
            [
                DateFilter(
                    date=now + timedelta(days=3),
                    comparison_operator=ComparisonOperator.greater_than,
                )
            ],
            [
                DateFilter(
                    date=now + timedelta(days=5), comparison_operator=ComparisonOperator.less_than
                )
            ],
        ],
    )

    # Test bfs from episodic node

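    # At depth 1 the BFS only crosses the MENTIONS edge from the episode to
    # test_entity_1, so no entity-to-entity edges are reachable yet; the first
    # RELATES_TO edge shows up at depth 2.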
    edges = await edge_bfs_search(
        graph_driver,
        [episodic_node_1.uuid],
        1,
        search_filters,
        group_ids=[group_id],
    )
    assert len(edges) == 0

    edges = await edge_bfs_search(
        graph_driver,
        [episodic_node_1.uuid],
        2,
        search_filters,
        group_ids=[group_id],
    )
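    # BFS can return the same edge along multiple paths; keying by uuid first
    # drops those duplicates, then the remaining facts are compared as a set.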
    edges_deduplicated = set({edge.uuid: edge.fact for edge in edges}.values())
    assert len(edges_deduplicated) == 1
    assert edges_deduplicated == {'test_entity_1 relates to test_entity_2'}

    edges = await edge_bfs_search(
        graph_driver,
        [episodic_node_1.uuid],
        3,
        search_filters,
        group_ids=[group_id],
    )
    edges_deduplicated = set({edge.uuid: edge.fact for edge in edges}.values())
    assert len(edges_deduplicated) == 2
    assert edges_deduplicated == {
        'test_entity_1 relates to test_entity_2',
        'test_entity_2 relates to test_entity_3',
    }

    # Test bfs from entity node

    edges = await edge_bfs_search(
        graph_driver,
        [entity_node_1.uuid],
        1,
        search_filters,
        group_ids=[group_id],
    )
    edges_deduplicated = set({edge.uuid: edge.fact for edge in edges}.values())
    assert len(edges_deduplicated) == 1
    assert edges_deduplicated == {'test_entity_1 relates to test_entity_2'}

    edges = await edge_bfs_search(
        graph_driver,
        [entity_node_1.uuid],
        2,
        search_filters,
        group_ids=[group_id],
    )
    edges_deduplicated = set({edge.uuid: edge.fact for edge in edges}.values())
    assert len(edges_deduplicated) == 2
    assert edges_deduplicated == {
        'test_entity_1 relates to test_entity_2',
        'test_entity_2 relates to test_entity_3',
    }


@pytest.mark.asyncio
async def test_node_fulltext_search(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    if graph_driver.provider == GraphProvider.KUZU:
        pytest.skip('Skipping as fulltext indexing not supported for Kuzu')

    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()

    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        summary='Summary about Alice',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        summary='Summary about Bob',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)

    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)

    # Search for entity edges
    search_filters = SearchFilters(node_labels=['Entity'])
    nodes = await node_fulltext_search(
        graph_driver,
        'Alice',
        search_filters,
        group_ids=[group_id],
    )
    assert len(nodes) == 1
    assert nodes[0].name == entity_node_1.name


@pytest.mark.asyncio
async def test_node_similarity_search(graph_driver, mock_embedder):
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on FalkorDB')

    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_alice',
        summary='Summary about Alice',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_bob',
        summary='Summary about Bob',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)

    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)

    # Search for entity edges
    search_filters = SearchFilters(node_labels=['Entity'])
    nodes = await node_similarity_search(
        graph_driver,
        entity_node_1.name_embedding,
        search_filters,
        group_ids=[group_id],
        min_score=0.9,
    )
    assert len(nodes) == 1
    assert nodes[0].name == entity_node_1.name


@pytest.mark.asyncio
async def test_node_bfs_search(graph_driver, mock_embedder):
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on FalkorDB')

    # Create episodic nodes
    episodic_node_1 = EpisodicNode(
        name='test_episodic_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='test_source_description',
        content='test_content',
        valid_at=datetime.now(),
    )

    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    entity_node_3 = EntityNode(
        name='test_entity_3',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_3.generate_name_embedding(mock_embedder)

    # Create entity edges
    entity_edge_1 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_edge_1.generate_embedding(mock_embedder)
    entity_edge_2 = EntityEdge(
        source_node_uuid=entity_node_2.uuid,
        target_node_uuid=entity_node_3.uuid,
        name='RELATES_TO',
        fact='test_entity_2 relates to test_entity_3',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_edge_2.generate_embedding(mock_embedder)

    # Create episodic to entity edges
    episodic_edge_1 = EpisodicEdge(
        source_node_uuid=episodic_node_1.uuid,
        target_node_uuid=entity_node_1.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )

    # Save the graph
    await episodic_node_1.save(graph_driver)
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_node_3.save(graph_driver)
    await entity_edge_1.save(graph_driver)
    await entity_edge_2.save(graph_driver)
    await episodic_edge_1.save(graph_driver)

    # Search for entity nodes
    search_filters = SearchFilters(
        node_labels=['Entity'],
    )

    # Test bfs from episodic node

    nodes = await node_bfs_search(
        graph_driver,
        [episodic_node_1.uuid],
        search_filters,
        1,
        group_ids=[group_id],
    )
    nodes_deduplicated = set({node.uuid: node.name for node in nodes}.values())
    assert len(nodes_deduplicated) == 1
    assert nodes_deduplicated == {'test_entity_1'}

    nodes = await node_bfs_search(
        graph_driver,
        [episodic_node_1.uuid],
        search_filters,
        2,
        group_ids=[group_id],
    )
    nodes_deduplicated = set({node.uuid: node.name for node in nodes}.values())
    assert len(nodes_deduplicated) == 2
    assert nodes_deduplicated == {'test_entity_1', 'test_entity_2'}

    # Test bfs from entity node

    nodes = await node_bfs_search(
        graph_driver,
        [entity_node_1.uuid],
        search_filters,
        1,
        group_ids=[group_id],
    )
    nodes_deduplicated = set({node.uuid: node.name for node in nodes}.values())
    assert len(nodes_deduplicated) == 1
    assert nodes_deduplicated == {'test_entity_2'}


@pytest.mark.asyncio
async def test_episode_fulltext_search(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    if graph_driver.provider == GraphProvider.KUZU:
        pytest.skip('Skipping as fulltext indexing not supported for Kuzu')

    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()

    # Create episodic nodes
    episodic_node_1 = EpisodicNode(
        name='test_episodic_1',
        content='test_content',
        created_at=datetime.now(),
        valid_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='Description about Alice',
    )
    episodic_node_2 = EpisodicNode(
        name='test_episodic_2',
        content='test_content_2',
        created_at=datetime.now(),
        valid_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='Description about Bob',
    )

    # Save the graph
    await episodic_node_1.save(graph_driver)
    await episodic_node_2.save(graph_driver)

    # Search for episodic nodes
    search_filters = SearchFilters(node_labels=['Episodic'])
    nodes = await episode_fulltext_search(
        graph_driver,
        'Alice',
        search_filters,
        group_ids=[group_id],
    )
    assert len(nodes) == 1
    assert nodes[0].name == episodic_node_1.name


@pytest.mark.asyncio
async def test_community_fulltext_search(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    if graph_driver.provider == GraphProvider.KUZU:
        pytest.skip('Skipping as fulltext indexing not supported for Kuzu')

    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()

    # Create community nodes
    community_node_1 = CommunityNode(
        name='Alice',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await community_node_1.generate_name_embedding(mock_embedder)
    community_node_2 = CommunityNode(
        name='Bob',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await community_node_2.generate_name_embedding(mock_embedder)

    # Save the graph
    await community_node_1.save(graph_driver)
    await community_node_2.save(graph_driver)

    # Search for community nodes
    nodes = await community_fulltext_search(
        graph_driver,
        'Alice',
        group_ids=[group_id],
    )
    assert len(nodes) == 1
    assert nodes[0].name == community_node_1.name


@pytest.mark.asyncio
async def test_community_similarity_search(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on FalkorDB')

    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()

    # Create community nodes
    community_node_1 = CommunityNode(
        name='Alice',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await community_node_1.generate_name_embedding(mock_embedder)
    community_node_2 = CommunityNode(
        name='Bob',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await community_node_2.generate_name_embedding(mock_embedder)

    # Save the graph
    await community_node_1.save(graph_driver)
    await community_node_2.save(graph_driver)

    # Search for community nodes
    nodes = await community_similarity_search(
        graph_driver,
        community_node_1.name_embedding,
        group_ids=[group_id],
        min_score=0.9,
    )
    assert len(nodes) == 1
    assert nodes[0].name == community_node_1.name


@pytest.mark.asyncio
async def test_get_relevant_nodes(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on FalkorDB')

    if graph_driver.provider == GraphProvider.KUZU:
        pytest.skip('Skipping as tests fail on Kuzu')

    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()

    # Create entity nodes
    entity_node_1 = EntityNode(
        name='Alice',
        summary='Alice',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='Bob',
        summary='Bob',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    entity_node_3 = EntityNode(
        name='Alice Smith',
        summary='Alice Smith',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_3.generate_name_embedding(mock_embedder)

    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_node_3.save(graph_driver)

    # Search for entity nodes
    search_filters = SearchFilters(node_labels=['Entity'])
    nodes = (
        await get_relevant_nodes(
            graph_driver,
            [entity_node_1],
            search_filters,
            min_score=0.9,
        )
    )[0]
    assert len(nodes) == 2
    assert {node.name for node in nodes} == {entity_node_1.name, entity_node_3.name}


@pytest.mark.asyncio
async def test_get_relevant_edges_and_invalidation_candidates(
    graph_driver, mock_embedder, mock_llm_client, mock_cross_encoder_client
):
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on FalkorDB')

    graphiti = Graphiti(
        graph_driver=graph_driver,
        llm_client=mock_llm_client,
        embedder=mock_embedder,
        cross_encoder=mock_cross_encoder_client,
    )
    await graphiti.build_indices_and_constraints()

    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        summary='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        summary='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    entity_node_3 = EntityNode(
        name='test_entity_3',
        summary='test_entity_3',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_3.generate_name_embedding(mock_embedder)

    now = datetime.now()
    created_at = now
    expired_at = now + timedelta(days=6)
    valid_at = now + timedelta(days=2)
    invalid_at = now + timedelta(days=4)

    # Create entity edges
    entity_edge_1 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        name='RELATES_TO',
        fact='Alice',
        created_at=created_at,
        expired_at=expired_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        group_id=group_id,
    )
    await entity_edge_1.generate_embedding(mock_embedder)
    entity_edge_2 = EntityEdge(
        source_node_uuid=entity_node_2.uuid,
        target_node_uuid=entity_node_3.uuid,
        name='RELATES_TO',
        fact='Bob',
        created_at=created_at,
        expired_at=expired_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        group_id=group_id,
    )
    await entity_edge_2.generate_embedding(mock_embedder)
    entity_edge_3 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_3.uuid,
        name='RELATES_TO',
        fact='Alice',
        created_at=created_at,
        expired_at=expired_at,
        valid_at=valid_at,
        invalid_at=invalid_at,
        group_id=group_id,
    )
    await entity_edge_3.generate_embedding(mock_embedder)

    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_node_3.save(graph_driver)
    await entity_edge_1.save(graph_driver)
    await entity_edge_2.save(graph_driver)
    await entity_edge_3.save(graph_driver)

    # Search for entity edges
    search_filters = SearchFilters(
        node_labels=['Entity'],
        edge_types=['RELATES_TO'],
        created_at=[
            [DateFilter(date=created_at, comparison_operator=ComparisonOperator.equals)],
        ],
        expired_at=[
            [DateFilter(date=now, comparison_operator=ComparisonOperator.not_equals)],
        ],
        valid_at=[
            [
                DateFilter(
                    date=now + timedelta(days=1),
                    comparison_operator=ComparisonOperator.greater_than_equal,
                )
            ],
            [
                DateFilter(
                    date=now + timedelta(days=3),
                    comparison_operator=ComparisonOperator.less_than_equal,
                )
            ],
        ],
        invalid_at=[
            [
                DateFilter(
                    date=now + timedelta(days=3),
                    comparison_operator=ComparisonOperator.greater_than,
                )
            ],
            [
                DateFilter(
                    date=now + timedelta(days=5), comparison_operator=ComparisonOperator.less_than
                )
            ],
        ],
    )
    edges = (
        await get_relevant_edges(
            graph_driver,
            [entity_edge_1],
            search_filters,
            min_score=0.9,
        )
    )[0]
    assert len(edges) == 1
    assert {edge.name for edge in edges} == {entity_edge_1.name}

    edges = (
        await get_edge_invalidation_candidates(
            graph_driver,
            [entity_edge_1],
            search_filters,
            min_score=0.9,
        )
    )[0]
    assert len(edges) == 2
    assert {edge.name for edge in edges} == {entity_edge_1.name, entity_edge_3.name}


@pytest.mark.asyncio
async def test_node_distance_reranker(graph_driver, mock_embedder):
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on FalkorDB')

    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)
    entity_node_3 = EntityNode(
        name='test_entity_3',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_3.generate_name_embedding(mock_embedder)

    # Create entity edges
    entity_edge_1 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_edge_1.generate_embedding(mock_embedder)

    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_node_3.save(graph_driver)
    await entity_edge_1.save(graph_driver)

    # Test reranker
    reranked_uuids, reranked_scores = await node_distance_reranker(
        graph_driver,
        [entity_node_2.uuid, entity_node_3.uuid],
        entity_node_1.uuid,
    )
    uuid_to_name = {
        entity_node_1.uuid: entity_node_1.name,
        entity_node_2.uuid: entity_node_2.name,
        entity_node_3.uuid: entity_node_3.name,
    }
    names = [uuid_to_name[uuid] for uuid in reranked_uuids]
    assert names == [entity_node_2.name, entity_node_3.name]
    assert np.allclose(reranked_scores, [1.0, 0.0])


@pytest.mark.asyncio
async def test_episode_mentions_reranker(graph_driver, mock_embedder):
    if graph_driver.provider == GraphProvider.FALKORDB:
        pytest.skip('Skipping as tests fail on FalkorDB')

    # Create episodic nodes
    episodic_node_1 = EpisodicNode(
        name='test_episodic_1',
        content='test_content',
        created_at=datetime.now(),
        valid_at=datetime.now(),
        group_id=group_id,
        source=EpisodeType.message,
        source_description='Description about Alice',
    )

    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)

    # Create episodic edges
    episodic_edge_1 = EpisodicEdge(
        source_node_uuid=episodic_node_1.uuid,
        target_node_uuid=entity_node_1.uuid,
        created_at=datetime.now(),
        group_id=group_id,
    )

    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await episodic_node_1.save(graph_driver)
    await episodic_edge_1.save(graph_driver)

    # Test reranker
    reranked_uuids, reranked_scores = await episode_mentions_reranker(
        graph_driver,
        [[entity_node_1.uuid, entity_node_2.uuid]],
    )
    uuid_to_name = {entity_node_1.uuid: entity_node_1.name, entity_node_2.uuid: entity_node_2.name}
    names = [uuid_to_name[uuid] for uuid in reranked_uuids]
    assert names == [entity_node_1.name, entity_node_2.name]
    assert np.allclose(reranked_scores, [1.0, float('inf')])


@pytest.mark.asyncio
async def test_get_embeddings_for_edges(graph_driver, mock_embedder):
    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)
    entity_node_2 = EntityNode(
        name='test_entity_2',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_2.generate_name_embedding(mock_embedder)

    # Create entity edges
    entity_edge_1 = EntityEdge(
        source_node_uuid=entity_node_1.uuid,
        target_node_uuid=entity_node_2.uuid,
        name='RELATES_TO',
        fact='test_entity_1 relates to test_entity_2',
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_edge_1.generate_embedding(mock_embedder)

    # Save the graph
    await entity_node_1.save(graph_driver)
    await entity_node_2.save(graph_driver)
    await entity_edge_1.save(graph_driver)

    # Get embeddings for edges
    embeddings = await get_embeddings_for_edges(graph_driver, [entity_edge_1])
    assert len(embeddings) == 1
    assert entity_edge_1.uuid in embeddings
    assert np.allclose(embeddings[entity_edge_1.uuid], entity_edge_1.fact_embedding)


@pytest.mark.asyncio
async def test_get_embeddings_for_nodes(graph_driver, mock_embedder):
    # Create entity nodes
    entity_node_1 = EntityNode(
        name='test_entity_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await entity_node_1.generate_name_embedding(mock_embedder)

    # Save the graph
    await entity_node_1.save(graph_driver)

    # Get embeddings for nodes
    embeddings = await get_embeddings_for_nodes(graph_driver, [entity_node_1])
    assert len(embeddings) == 1
    assert entity_node_1.uuid in embeddings
    assert np.allclose(embeddings[entity_node_1.uuid], entity_node_1.name_embedding)


@pytest.mark.asyncio
async def test_get_embeddings_for_communities(graph_driver, mock_embedder):
    # Create community nodes
    community_node_1 = CommunityNode(
        name='test_community_1',
        labels=[],
        created_at=datetime.now(),
        group_id=group_id,
    )
    await community_node_1.generate_name_embedding(mock_embedder)

    # Save the graph
    await community_node_1.save(graph_driver)

    # Get embeddings for communities
    embeddings = await get_embeddings_for_communities(graph_driver, [community_node_1])
    assert len(embeddings) == 1
    assert community_node_1.uuid in embeddings
    assert np.allclose(embeddings[community_node_1.uuid], community_node_1.name_embedding)

```

--------------------------------------------------------------------------------
/graphiti_core/search/search_utils.py:
--------------------------------------------------------------------------------

```python
"""
Copyright 2024, Zep Software, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
from collections import defaultdict
from time import time
from typing import Any

import numpy as np
from numpy._typing import NDArray
from typing_extensions import LiteralString

from graphiti_core.driver.driver import (
    GraphDriver,
    GraphProvider,
)
from graphiti_core.edges import EntityEdge, get_entity_edge_from_record
from graphiti_core.graph_queries import (
    get_nodes_query,
    get_relationships_query,
    get_vector_cosine_func_query,
)
from graphiti_core.helpers import (
    lucene_sanitize,
    normalize_l2,
    semaphore_gather,
)
from graphiti_core.models.edges.edge_db_queries import get_entity_edge_return_query
from graphiti_core.models.nodes.node_db_queries import (
    COMMUNITY_NODE_RETURN,
    EPISODIC_NODE_RETURN,
    get_entity_node_return_query,
)
from graphiti_core.nodes import (
    CommunityNode,
    EntityNode,
    EpisodicNode,
    get_community_node_from_record,
    get_entity_node_from_record,
    get_episodic_node_from_record,
)
from graphiti_core.search.search_filters import (
    SearchFilters,
    edge_search_filter_query_constructor,
    node_search_filter_query_constructor,
)

logger = logging.getLogger(__name__)

RELEVANT_SCHEMA_LIMIT = 10
DEFAULT_MIN_SCORE = 0.6
DEFAULT_MMR_LAMBDA = 0.5
MAX_SEARCH_DEPTH = 3
MAX_QUERY_LENGTH = 128


def calculate_cosine_similarity(vector1: list[float], vector2: list[float]) -> float:
    """
    Calculates the cosine similarity between two vectors using NumPy.
    """
    dot_product = np.dot(vector1, vector2)
    norm_vector1 = np.linalg.norm(vector1)
    norm_vector2 = np.linalg.norm(vector2)

    if norm_vector1 == 0 or norm_vector2 == 0:
        return 0.0  # Handle cases where one or both vectors are zero vectors

    return dot_product / (norm_vector1 * norm_vector2)
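
# Worked example of the formula above (illustrative only, not used by the
# module): for vector1 = [1.0, 0.0] and vector2 = [1.0, 1.0], the dot product
# is 1.0 and the norms are 1.0 and sqrt(2), so the similarity is
# 1.0 / sqrt(2) ~= 0.7071.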


def fulltext_query(query: str, group_ids: list[str] | None, driver: GraphDriver):
    if driver.provider == GraphProvider.KUZU:
        # Kuzu only supports simple queries.
        if len(query.split(' ')) > MAX_QUERY_LENGTH:
            return ''
        return query
    elif driver.provider == GraphProvider.FALKORDB:
        return driver.build_fulltext_query(query, group_ids, MAX_QUERY_LENGTH)
    group_ids_filter_list = (
        [driver.fulltext_syntax + f'group_id:"{g}"' for g in group_ids]
        if group_ids is not None
        else []
    )
    group_ids_filter = ''
    for f in group_ids_filter_list:
        group_ids_filter += f if not group_ids_filter else f' OR {f}'

    group_ids_filter += ' AND ' if group_ids_filter else ''

    lucene_query = lucene_sanitize(query)
    # If the Lucene query is too long, return no query
    if len(lucene_query.split(' ')) + len(group_ids or []) >= MAX_QUERY_LENGTH:
        return ''

    full_query = group_ids_filter + '(' + lucene_query + ')'

    return full_query
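
# Sketch of the Lucene form produced above (illustrative): for
# group_ids=['g1', 'g2'] and query='Alice Smith', the default branch yields
# roughly 'group_id:"g1" OR group_id:"g2" AND (<sanitized query>)', while the
# Kuzu and FalkorDB branches return earlier with their own formats.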


async def get_episodes_by_mentions(
    driver: GraphDriver,
    nodes: list[EntityNode],
    edges: list[EntityEdge],
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EpisodicNode]:
    episode_uuids: list[str] = []
    for edge in edges:
        episode_uuids.extend(edge.episodes)

    episodes = await EpisodicNode.get_by_uuids(driver, episode_uuids[:limit])

    return episodes


async def get_mentioned_nodes(
    driver: GraphDriver, episodes: list[EpisodicNode]
) -> list[EntityNode]:
    episode_uuids = [episode.uuid for episode in episodes]

    records, _, _ = await driver.execute_query(
        """
        MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity)
        WHERE episode.uuid IN $uuids
        RETURN DISTINCT
        """
        + get_entity_node_return_query(driver.provider),
        uuids=episode_uuids,
        routing_='r',
    )

    nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

    return nodes


async def get_communities_by_nodes(
    driver: GraphDriver, nodes: list[EntityNode]
) -> list[CommunityNode]:
    node_uuids = [node.uuid for node in nodes]

    records, _, _ = await driver.execute_query(
        """
        MATCH (c:Community)-[:HAS_MEMBER]->(m:Entity)
        WHERE m.uuid IN $uuids
        RETURN DISTINCT
        """
        + COMMUNITY_NODE_RETURN,
        uuids=node_uuids,
        routing_='r',
    )

    communities = [get_community_node_from_record(record) for record in records]

    return communities


async def edge_fulltext_search(
    driver: GraphDriver,
    query: str,
    search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
    if driver.search_interface:
        return await driver.search_interface.edge_fulltext_search(
            driver, query, search_filter, group_ids, limit
        )

    # fulltext search over facts
    fuzzy_query = fulltext_query(query, group_ids, driver)

    if fuzzy_query == '':
        return []

    match_query = """
    YIELD relationship AS rel, score
    MATCH (n:Entity)-[e:RELATES_TO {uuid: rel.uuid}]->(m:Entity)
    """
    if driver.provider == GraphProvider.KUZU:
        match_query = """
        YIELD node, score
        MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {uuid: node.uuid})-[:RELATES_TO]->(m:Entity)
        """

    filter_queries, filter_params = edge_search_filter_query_constructor(
        search_filter, driver.provider
    )

    if group_ids is not None:
        filter_queries.append('e.group_id IN $group_ids')
        filter_params['group_ids'] = group_ids

    filter_query = ''
    if filter_queries:
        filter_query = ' WHERE ' + (' AND '.join(filter_queries))

    if driver.provider == GraphProvider.NEPTUNE:
        res = driver.run_aoss_query('edge_name_and_fact', query, limit=limit)  # pyright: ignore reportAttributeAccessIssue
        if res['hits']['total']['value'] > 0:
            input_ids = []
            for r in res['hits']['hits']:
                input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})

            # Match the edge ids and return the values
            query = (
                """
                UNWIND $ids as id
                MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
                WHERE id(e) = id.id
                """
                # filter_query begins with ' WHERE ', so fold it in as extra AND
                # conditions; it already carries the group_id filter when
                # group_ids was provided above.
                + filter_query.replace(' WHERE ', ' AND ', 1)
                + """
                WITH e, id.score as score, startNode(e) AS n, endNode(e) AS m
                RETURN
                    e.uuid AS uuid,
                    e.group_id AS group_id,
                    n.uuid AS source_node_uuid,
                    m.uuid AS target_node_uuid,
                    e.created_at AS created_at,
                    e.name AS name,
                    e.fact AS fact,
                    split(e.episodes, ",") AS episodes,
                    e.expired_at AS expired_at,
                    e.valid_at AS valid_at,
                    e.invalid_at AS invalid_at,
                    properties(e) AS attributes
                ORDER BY score DESC LIMIT $limit
                """
            )

            records, _, _ = await driver.execute_query(
                query,
                query=fuzzy_query,
                ids=input_ids,
                limit=limit,
                routing_='r',
                **filter_params,
            )
        else:
            return []
    else:
        query = (
            get_relationships_query('edge_name_and_fact', limit=limit, provider=driver.provider)
            + match_query
            + filter_query
            + """
            WITH e, score, n, m
            RETURN
            """
            + get_entity_edge_return_query(driver.provider)
            + """
            ORDER BY score DESC
            LIMIT $limit
            """
        )

        records, _, _ = await driver.execute_query(
            query,
            query=fuzzy_query,
            limit=limit,
            routing_='r',
            **filter_params,
        )

    edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

    return edges


async def edge_similarity_search(
    driver: GraphDriver,
    search_vector: list[float],
    source_node_uuid: str | None,
    target_node_uuid: str | None,
    search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
    min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityEdge]:
    if driver.search_interface:
        return await driver.search_interface.edge_similarity_search(
            driver,
            search_vector,
            source_node_uuid,
            target_node_uuid,
            search_filter,
            group_ids,
            limit,
            min_score,
        )

    match_query = """
        MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
    """
    if driver.provider == GraphProvider.KUZU:
        match_query = """
            MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(m:Entity)
        """

    filter_queries, filter_params = edge_search_filter_query_constructor(
        search_filter, driver.provider
    )

    if group_ids is not None:
        filter_queries.append('e.group_id IN $group_ids')
        filter_params['group_ids'] = group_ids

    # Apply the source/target node filters independently of group_ids.
    if source_node_uuid is not None:
        filter_params['source_uuid'] = source_node_uuid
        filter_queries.append('n.uuid = $source_uuid')

    if target_node_uuid is not None:
        filter_params['target_uuid'] = target_node_uuid
        filter_queries.append('m.uuid = $target_uuid')

    filter_query = ''
    if filter_queries:
        filter_query = ' WHERE ' + (' AND '.join(filter_queries))

    search_vector_var = '$search_vector'
    if driver.provider == GraphProvider.KUZU:
        search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'

    if driver.provider == GraphProvider.NEPTUNE:
        query = (
            """
                            MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity)
                            """
            + filter_query
            + """
            RETURN DISTINCT id(e) as id, e.fact_embedding as embedding
            """
        )
        resp, header, _ = await driver.execute_query(
            query,
            search_vector=search_vector,
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )

        if len(resp) > 0:
            # Calculate cosine similarity, then return the edge ids
            input_ids = []
            for r in resp:
                if r['embedding']:
                    score = calculate_cosine_similarity(
                        search_vector, list(map(float, r['embedding'].split(',')))
                    )
                    if score > min_score:
                        input_ids.append({'id': r['id'], 'score': score})

            # Match the edge ids and return the values
            query = """
                UNWIND $ids as i
                MATCH ()-[r]->()
                WHERE id(r) = i.id
                RETURN
                    r.uuid AS uuid,
                    r.group_id AS group_id,
                    startNode(r).uuid AS source_node_uuid,
                    endNode(r).uuid AS target_node_uuid,
                    r.created_at AS created_at,
                    r.name AS name,
                    r.fact AS fact,
                    split(r.episodes, ",") AS episodes,
                    r.expired_at AS expired_at,
                    r.valid_at AS valid_at,
                    r.invalid_at AS invalid_at,
                    properties(r) AS attributes
                ORDER BY i.score DESC
                LIMIT $limit
                    """
            records, _, _ = await driver.execute_query(
                query,
                ids=input_ids,
                search_vector=search_vector,
                limit=limit,
                min_score=min_score,
                routing_='r',
                **filter_params,
            )
        else:
            return []
    else:
        query = (
            match_query
            + filter_query
            + """
            WITH DISTINCT e, n, m, """
            + get_vector_cosine_func_query('e.fact_embedding', search_vector_var, driver.provider)
            + """ AS score
            WHERE score > $min_score
            RETURN
            """
            + get_entity_edge_return_query(driver.provider)
            + """
            ORDER BY score DESC
            LIMIT $limit
            """
        )

        records, _, _ = await driver.execute_query(
            query,
            search_vector=search_vector,
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )

    edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

    return edges


async def edge_bfs_search(
    driver: GraphDriver,
    bfs_origin_node_uuids: list[str] | None,
    bfs_max_depth: int,
    search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
    # breadth-first search over entity edges from the given origin nodes
    if bfs_origin_node_uuids is None or len(bfs_origin_node_uuids) == 0:
        return []

    filter_queries, filter_params = edge_search_filter_query_constructor(
        search_filter, driver.provider
    )

    if group_ids is not None:
        filter_queries.append('e.group_id IN $group_ids')
        filter_params['group_ids'] = group_ids

    filter_query = ''
    if filter_queries:
        filter_query = ' WHERE ' + (' AND '.join(filter_queries))

    if driver.provider == GraphProvider.KUZU:
        # Kuzu stores entity edges twice with an intermediate node, so we need to match them
        # separately for the correct BFS depth.
        depth = bfs_max_depth * 2 - 1
        match_queries = [
            f"""
            UNWIND $bfs_origin_node_uuids AS origin_uuid
            MATCH path = (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
            UNWIND nodes(path) AS relNode
            MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
            """,
        ]
        if bfs_max_depth > 1:
            depth = (bfs_max_depth - 1) * 2 - 1
            match_queries.append(f"""
                UNWIND $bfs_origin_node_uuids AS origin_uuid
                MATCH path = (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*1..{depth}]->(:RelatesToNode_)
                UNWIND nodes(path) AS relNode
                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {{uuid: relNode.uuid}})-[:RELATES_TO]->(m:Entity)
            """)

        records = []
        for match_query in match_queries:
            sub_records, _, _ = await driver.execute_query(
                match_query
                + filter_query
                + """
                RETURN DISTINCT
                """
                + get_entity_edge_return_query(driver.provider)
                + """
                LIMIT $limit
                """,
                bfs_origin_node_uuids=bfs_origin_node_uuids,
                limit=limit,
                routing_='r',
                **filter_params,
            )
            records.extend(sub_records)
    else:
        if driver.provider == GraphProvider.NEPTUNE:
            query = (
                f"""
                UNWIND $bfs_origin_node_uuids AS origin_uuid
                MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS *1..{bfs_max_depth}]->(n:Entity)
                WHERE origin:Entity OR origin:Episodic
                UNWIND relationships(path) AS rel
                MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
                """
                + filter_query
                + """
                RETURN DISTINCT
                    e.uuid AS uuid,
                    e.group_id AS group_id,
                    startNode(e).uuid AS source_node_uuid,
                    endNode(e).uuid AS target_node_uuid,
                    e.created_at AS created_at,
                    e.name AS name,
                    e.fact AS fact,
                    split(e.episodes, ',') AS episodes,
                    e.expired_at AS expired_at,
                    e.valid_at AS valid_at,
                    e.invalid_at AS invalid_at,
                    properties(e) AS attributes
                LIMIT $limit
                """
            )
        else:
            query = (
                f"""
                UNWIND $bfs_origin_node_uuids AS origin_uuid
                MATCH path = (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(:Entity)
                UNWIND relationships(path) AS rel
                MATCH (n:Entity)-[e:RELATES_TO {{uuid: rel.uuid}}]-(m:Entity)
                """
                + filter_query
                + """
                RETURN DISTINCT
                """
                + get_entity_edge_return_query(driver.provider)
                + """
                LIMIT $limit
                """
            )

        records, _, _ = await driver.execute_query(
            query,
            bfs_origin_node_uuids=bfs_origin_node_uuids,
            depth=bfs_max_depth,
            limit=limit,
            routing_='r',
            **filter_params,
        )

    edges = [get_entity_edge_from_record(record, driver.provider) for record in records]

    return edges


async def node_fulltext_search(
    driver: GraphDriver,
    query: str,
    search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
    if driver.search_interface:
        return await driver.search_interface.node_fulltext_search(
            driver, query, search_filter, group_ids, limit
        )

    # BM25 search to get top nodes
    fuzzy_query = fulltext_query(query, group_ids, driver)
    if fuzzy_query == '':
        return []

    filter_queries, filter_params = node_search_filter_query_constructor(
        search_filter, driver.provider
    )

    if group_ids is not None:
        filter_queries.append('n.group_id IN $group_ids')
        filter_params['group_ids'] = group_ids

    filter_query = ''
    if filter_queries:
        filter_query = ' WHERE ' + (' AND '.join(filter_queries))

    yield_query = 'YIELD node AS n, score'
    if driver.provider == GraphProvider.KUZU:
        yield_query = 'WITH node AS n, score'

    if driver.provider == GraphProvider.NEPTUNE:
        res = driver.run_aoss_query('node_name_and_summary', query, limit=limit)  # pyright: ignore reportAttributeAccessIssue
        if res['hits']['total']['value'] > 0:
            input_ids = []
            for r in res['hits']['hits']:
                input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})

            # Match the node ids and return the values
            query = (
                """
                                UNWIND $ids as i
                                MATCH (n:Entity)
                                WHERE n.uuid=i.id
                                RETURN
                                """
                + get_entity_node_return_query(driver.provider)
                + """
                ORDER BY i.score DESC
                LIMIT $limit
                            """
            )
            records, _, _ = await driver.execute_query(
                query,
                ids=input_ids,
                query=fuzzy_query,
                limit=limit,
                routing_='r',
                **filter_params,
            )
        else:
            return []
    else:
        query = (
            get_nodes_query(
                'node_name_and_summary', '$query', limit=limit, provider=driver.provider
            )
            + yield_query
            + filter_query
            + """
            WITH n, score
            ORDER BY score DESC
            LIMIT $limit
            RETURN
            """
            + get_entity_node_return_query(driver.provider)
        )

        records, _, _ = await driver.execute_query(
            query,
            query=fuzzy_query,
            limit=limit,
            routing_='r',
            **filter_params,
        )

    nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

    return nodes


async def node_similarity_search(
    driver: GraphDriver,
    search_vector: list[float],
    search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
    min_score: float = DEFAULT_MIN_SCORE,
) -> list[EntityNode]:
    if driver.search_interface:
        return await driver.search_interface.node_similarity_search(
            driver, search_vector, search_filter, group_ids, limit, min_score
        )

    filter_queries, filter_params = node_search_filter_query_constructor(
        search_filter, driver.provider
    )

    if group_ids is not None:
        filter_queries.append('n.group_id IN $group_ids')
        filter_params['group_ids'] = group_ids

    filter_query = ''
    if filter_queries:
        filter_query = ' WHERE ' + (' AND '.join(filter_queries))

    search_vector_var = '$search_vector'
    if driver.provider == GraphProvider.KUZU:
        search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'

    if driver.provider == GraphProvider.NEPTUNE:
        query = (
            """
            MATCH (n:Entity)
            """
            + filter_query
            + """
            RETURN DISTINCT id(n) as id, n.name_embedding as embedding
            """
        )
        resp, header, _ = await driver.execute_query(
            query,
            search_vector=search_vector,
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )

        if len(resp) > 0:
            # Calculate cosine similarity, then return the node ids
            input_ids = []
            for r in resp:
                if r['embedding']:
                    score = calculate_cosine_similarity(
                        search_vector, list(map(float, r['embedding'].split(',')))
                    )
                    if score > min_score:
                        input_ids.append({'id': r['id'], 'score': score})

            # Match the node ids and return the values
            query = (
                """
                UNWIND $ids as i
                MATCH (n:Entity)
                WHERE id(n)=i.id
                RETURN
                """
                + get_entity_node_return_query(driver.provider)
                + """
                    ORDER BY i.score DESC
                    LIMIT $limit
                """
            )
            records, header, _ = await driver.execute_query(
                query,
                ids=input_ids,
                search_vector=search_vector,
                limit=limit,
                min_score=min_score,
                routing_='r',
                **filter_params,
            )
        else:
            return []
    else:
        query = (
            """
            MATCH (n:Entity)
            """
            + filter_query
            + """
            WITH n, """
            + get_vector_cosine_func_query('n.name_embedding', search_vector_var, driver.provider)
            + """ AS score
            WHERE score > $min_score
            RETURN
            """
            + get_entity_node_return_query(driver.provider)
            + """
            ORDER BY score DESC
            LIMIT $limit
            """
        )

        records, _, _ = await driver.execute_query(
            query,
            search_vector=search_vector,
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )

    nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

    return nodes


async def node_bfs_search(
    driver: GraphDriver,
    bfs_origin_node_uuids: list[str] | None,
    search_filter: SearchFilters,
    bfs_max_depth: int,
    group_ids: list[str] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
    if bfs_origin_node_uuids is None or len(bfs_origin_node_uuids) == 0 or bfs_max_depth < 1:
        return []

    filter_queries, filter_params = node_search_filter_query_constructor(
        search_filter, driver.provider
    )

    if group_ids is not None:
        filter_queries.append('n.group_id IN $group_ids')
        filter_queries.append('origin.group_id IN $group_ids')
        filter_params['group_ids'] = group_ids

    filter_query = ''
    if filter_queries:
        filter_query = ' AND ' + (' AND '.join(filter_queries))

    match_queries = [
        f"""
        UNWIND $bfs_origin_node_uuids AS origin_uuid
        MATCH (origin {{uuid: origin_uuid}})-[:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
        WHERE n.group_id = origin.group_id
        """
    ]

    if driver.provider == GraphProvider.NEPTUNE:
        match_queries = [
            f"""
            UNWIND $bfs_origin_node_uuids AS origin_uuid
            MATCH (origin {{uuid: origin_uuid}})-[e:RELATES_TO|MENTIONS*1..{bfs_max_depth}]->(n:Entity)
            WHERE (origin:Entity OR origin:Episodic)
            AND n.group_id = origin.group_id
            """
        ]

    if driver.provider == GraphProvider.KUZU:
        depth = bfs_max_depth * 2
        match_queries = [
            """
            UNWIND $bfs_origin_node_uuids AS origin_uuid
            MATCH (origin:Episodic {uuid: origin_uuid})-[:MENTIONS]->(n:Entity)
            WHERE n.group_id = origin.group_id
            """,
            f"""
            UNWIND $bfs_origin_node_uuids AS origin_uuid
            MATCH (origin:Entity {{uuid: origin_uuid}})-[:RELATES_TO*2..{depth}]->(n:Entity)
            WHERE n.group_id = origin.group_id
            """,
        ]
        if bfs_max_depth > 1:
            depth = (bfs_max_depth - 1) * 2
            match_queries.append(f"""
                UNWIND $bfs_origin_node_uuids AS origin_uuid
                MATCH (origin:Episodic {{uuid: origin_uuid}})-[:MENTIONS]->(:Entity)-[:RELATES_TO*2..{depth}]->(n:Entity)
                WHERE n.group_id = origin.group_id
            """)

    records = []
    for match_query in match_queries:
        sub_records, _, _ = await driver.execute_query(
            match_query
            + filter_query
            + """
            RETURN
            """
            + get_entity_node_return_query(driver.provider)
            + """
            LIMIT $limit
            """,
            bfs_origin_node_uuids=bfs_origin_node_uuids,
            limit=limit,
            routing_='r',
            **filter_params,
        )
        records.extend(sub_records)

    nodes = [get_entity_node_from_record(record, driver.provider) for record in records]

    return nodes


async def episode_fulltext_search(
    driver: GraphDriver,
    query: str,
    _search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EpisodicNode]:
    if driver.search_interface:
        return await driver.search_interface.episode_fulltext_search(
            driver, query, _search_filter, group_ids, limit
        )

    # BM25 search to get top episodes
    fuzzy_query = fulltext_query(query, group_ids, driver)
    if fuzzy_query == '':
        return []

    filter_params: dict[str, Any] = {}
    group_filter_query: LiteralString = ''
    if group_ids is not None:
        group_filter_query += '\nAND e.group_id IN $group_ids'
        filter_params['group_ids'] = group_ids

    if driver.provider == GraphProvider.NEPTUNE:
        res = driver.run_aoss_query('episode_content', query, limit=limit)  # pyright: ignore reportAttributeAccessIssue
        if res['hits']['total']['value'] > 0:
            input_ids = []
            for r in res['hits']['hits']:
                input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})

            # Match the episode ids and return the values
            query = """
                UNWIND $ids as i
                MATCH (e:Episodic)
                WHERE e.uuid = i.id
                RETURN
                    e.content AS content,
                    e.created_at AS created_at,
                    e.valid_at AS valid_at,
                    e.uuid AS uuid,
                    e.name AS name,
                    e.group_id AS group_id,
                    e.source_description AS source_description,
                    e.source AS source,
                    e.entity_edges AS entity_edges
                ORDER BY i.score DESC
                LIMIT $limit
            """
            records, _, _ = await driver.execute_query(
                query,
                ids=input_ids,
                query=fuzzy_query,
                limit=limit,
                routing_='r',
                **filter_params,
            )
        else:
            return []
    else:
        query = (
            get_nodes_query('episode_content', '$query', limit=limit, provider=driver.provider)
            + """
            YIELD node AS episode, score
            MATCH (e:Episodic)
            WHERE e.uuid = episode.uuid
            """
            + group_filter_query
            + """
            RETURN
            """
            + EPISODIC_NODE_RETURN
            + """
            ORDER BY score DESC
            LIMIT $limit
            """
        )

        records, _, _ = await driver.execute_query(
            query, query=fuzzy_query, limit=limit, routing_='r', **filter_params
        )

    episodes = [get_episodic_node_from_record(record) for record in records]

    return episodes


async def community_fulltext_search(
    driver: GraphDriver,
    query: str,
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
) -> list[CommunityNode]:
    # BM25 search to get top communities
    fuzzy_query = fulltext_query(query, group_ids, driver)
    if fuzzy_query == '':
        return []

    filter_params: dict[str, Any] = {}
    group_filter_query: LiteralString = ''
    if group_ids is not None:
        group_filter_query = 'WHERE c.group_id IN $group_ids'
        filter_params['group_ids'] = group_ids

    yield_query = 'YIELD node AS c, score'
    if driver.provider == GraphProvider.KUZU:
        yield_query = 'WITH node AS c, score'

    if driver.provider == GraphProvider.NEPTUNE:
        res = driver.run_aoss_query('community_name', query, limit=limit)  # pyright: ignore reportAttributeAccessIssue
        if res['hits']['total']['value'] > 0:
            # Collect the community ids and scores returned by the fulltext search
            input_ids = []
            for r in res['hits']['hits']:
                input_ids.append({'id': r['_source']['uuid'], 'score': r['_score']})

            # Match the community ids and return the values
            query = """
                UNWIND $ids as i
                MATCH (comm:Community)
                WHERE comm.uuid=i.id
                RETURN
                    comm.uuid AS uuid,
                    comm.group_id AS group_id,
                    comm.name AS name,
                    comm.created_at AS created_at,
                    comm.summary AS summary,
                    [x IN split(comm.name_embedding, ",") | toFloat(x)] AS name_embedding
                ORDER BY i.score DESC
                LIMIT $limit
            """
            records, _, _ = await driver.execute_query(
                query,
                ids=input_ids,
                query=fuzzy_query,
                limit=limit,
                routing_='r',
                **filter_params,
            )
        else:
            return []
    else:
        query = (
            get_nodes_query('community_name', '$query', limit=limit, provider=driver.provider)
            + yield_query
            + """
            WITH c, score
            """
            + group_filter_query
            + """
            RETURN
            """
            + COMMUNITY_NODE_RETURN
            + """
            ORDER BY score DESC
            LIMIT $limit
            """
        )

        records, _, _ = await driver.execute_query(
            query, query=fuzzy_query, limit=limit, routing_='r', **filter_params
        )

    communities = [get_community_node_from_record(record) for record in records]

    return communities


async def community_similarity_search(
    driver: GraphDriver,
    search_vector: list[float],
    group_ids: list[str] | None = None,
    limit=RELEVANT_SCHEMA_LIMIT,
    min_score=DEFAULT_MIN_SCORE,
) -> list[CommunityNode]:
    # vector similarity search over community names
    query_params: dict[str, Any] = {}

    group_filter_query: LiteralString = ''
    if group_ids is not None:
        group_filter_query += ' WHERE c.group_id IN $group_ids'
        query_params['group_ids'] = group_ids

    if driver.provider == GraphProvider.NEPTUNE:
        query = (
            """
            MATCH (n:Community)
            """
            + group_filter_query
            + """
            RETURN DISTINCT id(n) as id, n.name_embedding as embedding
            """
        )
        resp, header, _ = await driver.execute_query(
            query,
            search_vector=search_vector,
            limit=limit,
            min_score=min_score,
            routing_='r',
            **query_params,
        )

        if len(resp) > 0:
            # Calculate cosine similarity, then return the community ids
            input_ids = []
            for r in resp:
                if r['embedding']:
                    score = calculate_cosine_similarity(
                        search_vector, list(map(float, r['embedding'].split(',')))
                    )
                    if score > min_score:
                        input_ids.append({'id': r['id'], 'score': score})

            # Match the community ids and return the values
            query = """
                    UNWIND $ids as i
                    MATCH (comm:Community)
                    WHERE id(comm)=i.id
                    RETURN
                        comm.uuid AS uuid,
                        comm.group_id AS group_id,
                        comm.name AS name,
                        comm.created_at AS created_at,
                        comm.summary AS summary,
                        comm.name_embedding AS name_embedding
                    ORDER BY i.score DESC
                    LIMIT $limit
                """
            records, header, _ = await driver.execute_query(
                query,
                ids=input_ids,
                search_vector=search_vector,
                limit=limit,
                min_score=min_score,
                routing_='r',
                **query_params,
            )
        else:
            return []
    else:
        search_vector_var = '$search_vector'
        if driver.provider == GraphProvider.KUZU:
            search_vector_var = f'CAST($search_vector AS FLOAT[{len(search_vector)}])'

        query = (
            """
            MATCH (c:Community)
            """
            + group_filter_query
            + """
            WITH c,
            """
            + get_vector_cosine_func_query('c.name_embedding', search_vector_var, driver.provider)
            + """ AS score
            WHERE score > $min_score
            RETURN
            """
            + COMMUNITY_NODE_RETURN
            + """
            ORDER BY score DESC
            LIMIT $limit
            """
        )

        records, _, _ = await driver.execute_query(
            query,
            search_vector=search_vector,
            limit=limit,
            min_score=min_score,
            routing_='r',
            **query_params,
        )

    communities = [get_community_node_from_record(record) for record in records]

    return communities


async def hybrid_node_search(
    queries: list[str],
    embeddings: list[list[float]],
    driver: GraphDriver,
    search_filter: SearchFilters,
    group_ids: list[str] | None = None,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
    """
    Perform a hybrid search for nodes using both text queries and embeddings.

    This method combines fulltext search and vector similarity search to find
    relevant nodes in the graph database. It uses an RRF (reciprocal rank fusion) reranker.

    Parameters
    ----------
    queries : list[str]
        A list of text queries to search for.
    embeddings : list[list[float]]
        A list of embedding vectors corresponding to the queries. If empty, only fulltext search is performed.
    driver : GraphDriver
        The graph driver instance for database operations.
    search_filter : SearchFilters
        The filters to apply to the node searches.
    group_ids : list[str] | None, optional
        The list of group ids to retrieve nodes from.
    limit : int, optional
        The maximum number of results to return per search method. Defaults to RELEVANT_SCHEMA_LIMIT.

    Returns
    -------
    list[EntityNode]
        A list of unique EntityNode objects that match the search criteria.

    Notes
    -----
    This method performs the following steps:
    1. Executes fulltext searches for each query.
    2. Executes vector similarity searches for each embedding.
    3. Combines and deduplicates the results from both search types.
    4. Logs the performance metrics of the search operation.

    The search results are deduplicated based on the node UUIDs to ensure
    uniqueness in the returned list. The 'limit' parameter is applied to each
    individual search method before deduplication. If not specified, a default
    limit (defined in the individual search functions) will be used.
    """

    start = time()
    results: list[list[EntityNode]] = list(
        await semaphore_gather(
            *[
                node_fulltext_search(driver, q, search_filter, group_ids, 2 * limit)
                for q in queries
            ],
            *[
                node_similarity_search(driver, e, search_filter, group_ids, 2 * limit)
                for e in embeddings
            ],
        )
    )

    node_uuid_map: dict[str, EntityNode] = {
        node.uuid: node for result in results for node in result
    }
    result_uuids = [[node.uuid for node in result] for result in results]

    ranked_uuids, _ = rrf(result_uuids)
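
    # RRF sketch (illustrative): reciprocal rank fusion scores each uuid as
    # sum(1 / (k + rank)) across the individual result lists, so nodes that
    # rank highly in several searches come out on top.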

    relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids]

    end = time()
    logger.debug(f'Found relevant nodes: {ranked_uuids} in {(end - start) * 1000} ms')
    return relevant_nodes
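
# Usage sketch (illustrative; assumes an initialized `driver` and a
# pre-computed `query_embedding`, and must run inside an event loop):
#
#   nodes = await hybrid_node_search(
#       queries=['Alice'],
#       embeddings=[query_embedding],
#       driver=driver,
#       search_filter=SearchFilters(),
#       group_ids=['group_1'],
#       limit=10,
#   )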


async def get_relevant_nodes(
    driver: GraphDriver,
    nodes: list[EntityNode],
    search_filter: SearchFilters,
    min_score: float = DEFAULT_MIN_SCORE,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[list[EntityNode]]:
    if len(nodes) == 0:
        return []

    group_id = nodes[0].group_id
    query_nodes = [
        {
            'uuid': node.uuid,
            'name': node.name,
            'name_embedding': node.name_embedding,
            'fulltext_query': fulltext_query(node.name, [node.group_id], driver),
        }
        for node in nodes
    ]

    filter_queries, filter_params = node_search_filter_query_constructor(
        search_filter, driver.provider
    )

    filter_query = ''
    if filter_queries:
        filter_query = 'WHERE ' + (' AND '.join(filter_queries))

    if driver.provider == GraphProvider.KUZU:
        embedding_size = len(nodes[0].name_embedding) if nodes[0].name_embedding is not None else 0
        if embedding_size == 0:
            return []

        # FIXME: Kuzu currently does not support using variables such as `node.fulltext_query` as an input to FTS, which means `get_relevant_nodes()` won't work with Kuzu as the graph driver.
        query = (
            """
            UNWIND $nodes AS node
            MATCH (n:Entity {group_id: $group_id})
            """
            + filter_query
            + """
            WITH node, n, """
            + get_vector_cosine_func_query(
                'n.name_embedding',
                f'CAST(node.name_embedding AS FLOAT[{embedding_size}])',
                driver.provider,
            )
            + """ AS score
            WHERE score > $min_score
            WITH node, collect(n)[:$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
            """
            + get_nodes_query(
                'node_name_and_summary',
                'node.fulltext_query',
                limit=limit,
                provider=driver.provider,
            )
            + """
            WITH node AS m
            WHERE m.group_id = $group_id AND NOT m.uuid IN vector_node_uuids
            WITH node, top_vector_nodes, collect(m) AS fulltext_nodes

            WITH node, list_concat(top_vector_nodes, fulltext_nodes) AS combined_nodes

            UNWIND combined_nodes AS x
            WITH node, collect(DISTINCT {
                uuid: x.uuid,
                name: x.name,
                name_embedding: x.name_embedding,
                group_id: x.group_id,
                created_at: x.created_at,
                summary: x.summary,
                labels: x.labels,
                attributes: x.attributes
            }) AS matches

            RETURN
            node.uuid AS search_node_uuid, matches
            """
        )
    else:
        query = (
            """
            UNWIND $nodes AS node
            MATCH (n:Entity {group_id: $group_id})
            """
            + filter_query
            + """
            WITH node, n, """
            + get_vector_cosine_func_query(
                'n.name_embedding', 'node.name_embedding', driver.provider
            )
            + """ AS score
            WHERE score > $min_score
            WITH node, collect(n)[..$limit] AS top_vector_nodes, collect(n.uuid) AS vector_node_uuids
            """
            + get_nodes_query(
                'node_name_and_summary',
                'node.fulltext_query',
                limit=limit,
                provider=driver.provider,
            )
            + """
            YIELD node AS m
            WHERE m.group_id = $group_id
            WITH node, top_vector_nodes, vector_node_uuids, collect(m) AS fulltext_nodes

            WITH node,
                top_vector_nodes,
                [m IN fulltext_nodes WHERE NOT m.uuid IN vector_node_uuids] AS filtered_fulltext_nodes

            WITH node, top_vector_nodes + filtered_fulltext_nodes AS combined_nodes

            UNWIND combined_nodes AS combined_node
            WITH node, collect(DISTINCT combined_node) AS deduped_nodes

            RETURN
            node.uuid AS search_node_uuid,
            [x IN deduped_nodes | {
                uuid: x.uuid,
                name: x.name,
                name_embedding: x.name_embedding,
                group_id: x.group_id,
                created_at: x.created_at,
                summary: x.summary,
                labels: labels(x),
                attributes: properties(x)
            }] AS matches
            """
        )

    results, _, _ = await driver.execute_query(
        query,
        nodes=query_nodes,
        group_id=group_id,
        limit=limit,
        min_score=min_score,
        routing_='r',
        **filter_params,
    )

    relevant_nodes_dict: dict[str, list[EntityNode]] = {
        result['search_node_uuid']: [
            get_entity_node_from_record(record, driver.provider) for record in result['matches']
        ]
        for result in results
    }

    relevant_nodes = [relevant_nodes_dict.get(node.uuid, []) for node in nodes]

    return relevant_nodes
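
# Shape note (illustrative): get_relevant_nodes returns one candidate list per
# input node, in input order; e.g. result[0] contains the stored entities whose
# names are close to nodes[0] by vector score or fulltext match.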


async def get_relevant_edges(
    driver: GraphDriver,
    edges: list[EntityEdge],
    search_filter: SearchFilters,
    min_score: float = DEFAULT_MIN_SCORE,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[list[EntityEdge]]:
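    """
    For each candidate edge, find existing RELATES_TO edges between the same
    pair of entity nodes whose fact_embedding cosine similarity to the
    candidate exceeds min_score. Returns a list parallel to `edges`, each
    element holding up to `limit` matches for the corresponding input edge.
    """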
    if len(edges) == 0:
        return []

    filter_queries, filter_params = edge_search_filter_query_constructor(
        search_filter, driver.provider
    )

    filter_query = ''
    if filter_queries:
        filter_query = ' WHERE ' + (' AND '.join(filter_queries))

    if driver.provider == GraphProvider.NEPTUNE:
        query = (
            """
            UNWIND $edges AS edge
            MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
            """
            + filter_query
            + """
            WITH e, edge
            RETURN DISTINCT id(e) AS id, e.fact_embedding AS source_embedding, edge.uuid AS search_edge_uuid,
            edge.fact_embedding AS target_embedding
            """
        )
        resp, _, _ = await driver.execute_query(
            query,
            edges=[edge.model_dump() for edge in edges],
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )

        # Neptune stores embeddings as comma-separated strings, so compute
        # cosine similarity client-side and keep edges scoring above min_score
        input_ids = []
        for r in resp:
            score = calculate_cosine_similarity(
                list(map(float, r['source_embedding'].split(','))), r['target_embedding']
            )
            if score > min_score:
                input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})

        # Match the edge ids and return the full edge records
        query = """
        UNWIND $ids AS edge
        MATCH ()-[e]->()
        WHERE id(e) = edge.id
        WITH edge, e
        ORDER BY edge.score DESC
        RETURN edge.uuid AS search_edge_uuid,
            collect({
                uuid: e.uuid,
                source_node_uuid: startNode(e).uuid,
                target_node_uuid: endNode(e).uuid,
                created_at: e.created_at,
                name: e.name,
                group_id: e.group_id,
                fact: e.fact,
                fact_embedding: [x IN split(e.fact_embedding, ",") | toFloat(x)],
                episodes: split(e.episodes, ","),
                expired_at: e.expired_at,
                valid_at: e.valid_at,
                invalid_at: e.invalid_at,
                attributes: properties(e)
            })[..$limit] AS matches
                """

        results, _, _ = await driver.execute_query(
            query,
            ids=input_ids,
            edges=[edge.model_dump() for edge in edges],
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )
    else:
        if driver.provider == GraphProvider.KUZU:
            embedding_size = (
                len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
            )
            if embedding_size == 0:
                return []

            query = (
                """
                UNWIND $edges AS edge
                MATCH (n:Entity {uuid: edge.source_node_uuid})-[:RELATES_TO]-(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]-(m:Entity {uuid: edge.target_node_uuid})
                """
                + filter_query
                + """
                WITH e, edge, n, m, """
                + get_vector_cosine_func_query(
                    'e.fact_embedding',
                    f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
                    driver.provider,
                )
                + """ AS score
                WHERE score > $min_score
                WITH e, edge, n, m, score
                ORDER BY score DESC
                LIMIT $limit
                RETURN
                    edge.uuid AS search_edge_uuid,
                    collect({
                        uuid: e.uuid,
                        source_node_uuid: n.uuid,
                        target_node_uuid: m.uuid,
                        created_at: e.created_at,
                        name: e.name,
                        group_id: e.group_id,
                        fact: e.fact,
                        fact_embedding: e.fact_embedding,
                        episodes: e.episodes,
                        expired_at: e.expired_at,
                        valid_at: e.valid_at,
                        invalid_at: e.invalid_at,
                        attributes: e.attributes
                    }) AS matches
                """
            )
        else:
            query = (
                """
                UNWIND $edges AS edge
                MATCH (n:Entity {uuid: edge.source_node_uuid})-[e:RELATES_TO {group_id: edge.group_id}]-(m:Entity {uuid: edge.target_node_uuid})
                """
                + filter_query
                + """
                WITH e, edge, """
                + get_vector_cosine_func_query(
                    'e.fact_embedding', 'edge.fact_embedding', driver.provider
                )
                + """ AS score
                WHERE score > $min_score
                WITH edge, e, score
                ORDER BY score DESC
                RETURN
                    edge.uuid AS search_edge_uuid,
                    collect({
                        uuid: e.uuid,
                        source_node_uuid: startNode(e).uuid,
                        target_node_uuid: endNode(e).uuid,
                        created_at: e.created_at,
                        name: e.name,
                        group_id: e.group_id,
                        fact: e.fact,
                        fact_embedding: e.fact_embedding,
                        episodes: e.episodes,
                        expired_at: e.expired_at,
                        valid_at: e.valid_at,
                        invalid_at: e.invalid_at,
                        attributes: properties(e)
                    })[..$limit] AS matches
                """
            )

        results, _, _ = await driver.execute_query(
            query,
            edges=[edge.model_dump() for edge in edges],
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )

    relevant_edges_dict: dict[str, list[EntityEdge]] = {
        result['search_edge_uuid']: [
            get_entity_edge_from_record(record, driver.provider) for record in result['matches']
        ]
        for result in results
    }

    relevant_edges = [relevant_edges_dict.get(edge.uuid, []) for edge in edges]

    return relevant_edges


async def get_edge_invalidation_candidates(
    driver: GraphDriver,
    edges: list[EntityEdge],
    search_filter: SearchFilters,
    min_score: float = DEFAULT_MIN_SCORE,
    limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[list[EntityEdge]]:
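    """
    For each candidate edge, find existing edges that share at least one
    endpoint with it and whose fact_embedding cosine similarity exceeds
    min_score, i.e. edges the new fact may supersede or invalidate.
    """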
    if len(edges) == 0:
        return []

    filter_queries, filter_params = edge_search_filter_query_constructor(
        search_filter, driver.provider
    )

    filter_query = ''
    if filter_queries:
        filter_query = ' AND ' + (' AND '.join(filter_queries))

    if driver.provider == GraphProvider.NEPTUNE:
        query = (
            """
            UNWIND $edges AS edge
            MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
            WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
            """
            + filter_query
            + """
            WITH e, edge
            RETURN DISTINCT id(e) AS id, e.fact_embedding AS source_embedding,
            edge.fact_embedding AS target_embedding,
            edge.uuid AS search_edge_uuid
            """
        )
        resp, _, _ = await driver.execute_query(
            query,
            edges=[edge.model_dump() for edge in edges],
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )

        # Neptune stores embeddings as comma-separated strings, so compute
        # cosine similarity client-side and keep edges scoring above min_score
        input_ids = []
        for r in resp:
            score = calculate_cosine_similarity(
                list(map(float, r['source_embedding'].split(','))), r['target_embedding']
            )
            if score > min_score:
                input_ids.append({'id': r['id'], 'score': score, 'uuid': r['search_edge_uuid']})

        # Match the edge ids and return the full edge records
        query = """
        UNWIND $ids AS edge
        MATCH ()-[e]->()
        WHERE id(e) = edge.id
        WITH edge, e
        ORDER BY edge.score DESC
        RETURN edge.uuid AS search_edge_uuid,
            collect({
                uuid: e.uuid,
                source_node_uuid: startNode(e).uuid,
                target_node_uuid: endNode(e).uuid,
                created_at: e.created_at,
                name: e.name,
                group_id: e.group_id,
                fact: e.fact,
                fact_embedding: [x IN split(e.fact_embedding, ",") | toFloat(x)],
                episodes: split(e.episodes, ","),
                expired_at: e.expired_at,
                valid_at: e.valid_at,
                invalid_at: e.invalid_at,
                attributes: properties(e)
            })[..$limit] AS matches
                """
        results, _, _ = await driver.execute_query(
            query,
            ids=input_ids,
            edges=[edge.model_dump() for edge in edges],
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )
    else:
        if driver.provider == GraphProvider.KUZU:
            embedding_size = (
                len(edges[0].fact_embedding) if edges[0].fact_embedding is not None else 0
            )
            if embedding_size == 0:
                return []

            query = (
                """
                UNWIND $edges AS edge
                MATCH (n:Entity)-[:RELATES_TO]->(e:RelatesToNode_ {group_id: edge.group_id})-[:RELATES_TO]->(m:Entity)
                WHERE (n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid])
                """
                + filter_query
                + """
                WITH edge, e, n, m, """
                + get_vector_cosine_func_query(
                    'e.fact_embedding',
                    f'CAST(edge.fact_embedding AS FLOAT[{embedding_size}])',
                    driver.provider,
                )
                + """ AS score
                WHERE score > $min_score
                WITH edge, e, n, m, score
                ORDER BY score DESC
                LIMIT $limit
                RETURN
                    edge.uuid AS search_edge_uuid,
                    collect({
                        uuid: e.uuid,
                        source_node_uuid: n.uuid,
                        target_node_uuid: m.uuid,
                        created_at: e.created_at,
                        name: e.name,
                        group_id: e.group_id,
                        fact: e.fact,
                        fact_embedding: e.fact_embedding,
                        episodes: e.episodes,
                        expired_at: e.expired_at,
                        valid_at: e.valid_at,
                        invalid_at: e.invalid_at,
                        attributes: e.attributes
                    }) AS matches
                """
            )
        else:
            query = (
                """
                UNWIND $edges AS edge
                MATCH (n:Entity)-[e:RELATES_TO {group_id: edge.group_id}]->(m:Entity)
                WHERE n.uuid IN [edge.source_node_uuid, edge.target_node_uuid] OR m.uuid IN [edge.target_node_uuid, edge.source_node_uuid]
                """
                + filter_query
                + """
                WITH edge, e, """
                + get_vector_cosine_func_query(
                    'e.fact_embedding', 'edge.fact_embedding', driver.provider
                )
                + """ AS score
                WHERE score > $min_score
                WITH edge, e, score
                ORDER BY score DESC
                RETURN
                    edge.uuid AS search_edge_uuid,
                    collect({
                        uuid: e.uuid,
                        source_node_uuid: startNode(e).uuid,
                        target_node_uuid: endNode(e).uuid,
                        created_at: e.created_at,
                        name: e.name,
                        group_id: e.group_id,
                        fact: e.fact,
                        fact_embedding: e.fact_embedding,
                        episodes: e.episodes,
                        expired_at: e.expired_at,
                        valid_at: e.valid_at,
                        invalid_at: e.invalid_at,
                        attributes: properties(e)
                    })[..$limit] AS matches
                """
            )

        results, _, _ = await driver.execute_query(
            query,
            edges=[edge.model_dump() for edge in edges],
            limit=limit,
            min_score=min_score,
            routing_='r',
            **filter_params,
        )
    invalidation_edges_dict: dict[str, list[EntityEdge]] = {
        result['search_edge_uuid']: [
            get_entity_edge_from_record(record, driver.provider) for record in result['matches']
        ]
        for result in results
    }

    invalidation_edges = [invalidation_edges_dict.get(edge.uuid, []) for edge in edges]

    return invalidation_edges


# Fuses multiple ranked lists of uuids into a single ranking via Reciprocal Rank Fusion
def rrf(
    results: list[list[str]], rank_const=1, min_score: float = 0
) -> tuple[list[str], list[float]]:
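    """
    Reciprocal Rank Fusion: each uuid scores sum(1 / (rank + rank_const))
    across the input rankings, so items ranked highly in several lists rise
    to the top. E.g. with rank_const=1, a uuid at position 0 in two lists
    scores 1/1 + 1/1 = 2.0, while one at position 1 in a single list scores 0.5.
    """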
    scores: dict[str, float] = defaultdict(float)
    for result in results:
        for i, uuid in enumerate(result):
            scores[uuid] += 1 / (i + rank_const)

    scored_uuids = list(scores.items())
    scored_uuids.sort(reverse=True, key=lambda term: term[1])

    sorted_uuids = [term[0] for term in scored_uuids]

    return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
        scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
    ]


async def node_distance_reranker(
    driver: GraphDriver,
    node_uuids: list[str],
    center_node_uuid: str,
    min_score: float = 0,
) -> tuple[list[str], list[float]]:
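    """
    Rerank node uuids by graph distance from a center node. Only direct
    connections are checked: adjacent nodes get distance 1, everything else
    gets infinity, and the center itself gets 0.1. Returned scores are
    inverse distances, so the center (10.0) and direct neighbors (1.0)
    rank first.
    """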
    # filter the center node uuid out of the candidate list
    filtered_uuids = list(filter(lambda node_uuid: node_uuid != center_node_uuid, node_uuids))
    scores: dict[str, float] = {center_node_uuid: 0.0}

    query = """
    UNWIND $node_uuids AS node_uuid
    MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-(n:Entity {uuid: node_uuid})
    RETURN 1 AS score, node_uuid AS uuid
    """
    if driver.provider == GraphProvider.KUZU:
        query = """
        UNWIND $node_uuids AS node_uuid
        MATCH (center:Entity {uuid: $center_uuid})-[:RELATES_TO]->(e:RelatesToNode_)-[:RELATES_TO]->(n:Entity {uuid: node_uuid})
        RETURN 1 AS score, node_uuid AS uuid
        """

    # Find candidates directly connected to the center node (distance 1)
    results, header, _ = await driver.execute_query(
        query,
        node_uuids=filtered_uuids,
        center_uuid=center_node_uuid,
        routing_='r',
    )
    if driver.provider == GraphProvider.FALKORDB:
        results = [dict(zip(header, row, strict=True)) for row in results]

    for result in results:
        uuid = result['uuid']
        score = result['score']
        scores[uuid] = score

    for uuid in filtered_uuids:
        if uuid not in scores:
            scores[uuid] = float('inf')

    # rerank on shortest distance
    filtered_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])

    # add the center node back at the front if it was among the inputs;
    # its distance of 0.1 inverts to the top score of 10
    if center_node_uuid in node_uuids:
        scores[center_node_uuid] = 0.1
        filtered_uuids = [center_node_uuid] + filtered_uuids

    return [uuid for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score], [
        1 / scores[uuid] for uuid in filtered_uuids if (1 / scores[uuid]) >= min_score
    ]


async def episode_mentions_reranker(
    driver: GraphDriver, node_uuids: list[list[str]], min_score: float = 0
) -> tuple[list[str], list[float]]:
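    """
    Rerank node uuids by how many episodes mention them, most-mentioned
    first. RRF over the input rankings provides the preliminary ordering;
    MENTIONS edge counts then take precedence (ties keep the RRF order).
    """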
    # use rrf as a preliminary ranker
    sorted_uuids, _ = rrf(node_uuids)
    scores: dict[str, float] = {}

    # Count the episodes that mention each candidate node
    results, _, _ = await driver.execute_query(
        """
        UNWIND $node_uuids AS node_uuid
        MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: node_uuid})
        RETURN count(*) AS score, n.uuid AS uuid
        """,
        node_uuids=sorted_uuids,
        routing_='r',
    )

    for result in results:
        scores[result['uuid']] = result['score']

    # nodes that no episode mentions score zero
    for uuid in sorted_uuids:
        if uuid not in scores:
            scores[uuid] = 0.0

    # rerank on mention count, most-mentioned first
    sorted_uuids.sort(reverse=True, key=lambda cur_uuid: scores[cur_uuid])

    return [uuid for uuid in sorted_uuids if scores[uuid] >= min_score], [
        scores[uuid] for uuid in sorted_uuids if scores[uuid] >= min_score
    ]


def maximal_marginal_relevance(
    query_vector: list[float],
    candidates: dict[str, list[float]],
    mmr_lambda: float = DEFAULT_MMR_LAMBDA,
    min_score: float = -2.0,
) -> tuple[list[str], list[float]]:
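    """
    Single-pass Maximal Marginal Relevance: each candidate is scored
    lambda * sim(query, c) - (1 - lambda) * max_sim(c, other candidates),
    balancing query relevance against redundancy. Unlike classic iterative
    MMR, the penalty uses the max similarity to any other candidate rather
    than only to already-selected ones. Candidate embeddings are
    L2-normalized here; the query vector is assumed normalized upstream.
    """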
    start = time()
    query_array = np.array(query_vector)
    candidate_arrays: dict[str, NDArray] = {}
    for uuid, embedding in candidates.items():
        candidate_arrays[uuid] = normalize_l2(embedding)

    uuids: list[str] = list(candidate_arrays.keys())

    similarity_matrix = np.zeros((len(uuids), len(uuids)))

    for i, uuid_1 in enumerate(uuids):
        for j, uuid_2 in enumerate(uuids[:i]):
            u = candidate_arrays[uuid_1]
            v = candidate_arrays[uuid_2]
            similarity = np.dot(u, v)

            similarity_matrix[i, j] = similarity
            similarity_matrix[j, i] = similarity

    mmr_scores: dict[str, float] = {}
    for i, uuid in enumerate(uuids):
        max_sim = np.max(similarity_matrix[i, :])
        mmr = mmr_lambda * np.dot(query_array, candidate_arrays[uuid]) + (mmr_lambda - 1) * max_sim
        mmr_scores[uuid] = mmr

    uuids.sort(reverse=True, key=lambda c: mmr_scores[c])

    end = time()
    logger.debug(f'Completed MMR reranking in {(end - start) * 1000} ms')

    return [uuid for uuid in uuids if mmr_scores[uuid] >= min_score], [
        mmr_scores[uuid] for uuid in uuids if mmr_scores[uuid] >= min_score
    ]


async def get_embeddings_for_nodes(
    driver: GraphDriver, nodes: list[EntityNode]
) -> dict[str, list[float]]:
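    """
    Bulk-load name embeddings for the given entity nodes, keyed by node uuid.
    Prefers the driver's graph_operations_interface when available; Neptune
    stores embeddings as comma-separated strings, so they are split there.
    """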
    if driver.graph_operations_interface:
        return await driver.graph_operations_interface.node_load_embeddings_bulk(driver, nodes)
    elif driver.provider == GraphProvider.NEPTUNE:
        query = """
        MATCH (n:Entity)
        WHERE n.uuid IN $node_uuids
        RETURN DISTINCT
            n.uuid AS uuid,
            split(n.name_embedding, ",") AS name_embedding
        """
    else:
        query = """
        MATCH (n:Entity)
        WHERE n.uuid IN $node_uuids
        RETURN DISTINCT
            n.uuid AS uuid,
            n.name_embedding AS name_embedding
        """
    results, _, _ = await driver.execute_query(
        query,
        node_uuids=[node.uuid for node in nodes],
        routing_='r',
    )

    embeddings_dict: dict[str, list[float]] = {}
    for result in results:
        uuid: str = result.get('uuid')
        embedding: list[float] = result.get('name_embedding')
        if uuid is not None and embedding is not None:
            embeddings_dict[uuid] = embedding

    return embeddings_dict


async def get_embeddings_for_communities(
    driver: GraphDriver, communities: list[CommunityNode]
) -> dict[str, list[float]]:
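    """
    Bulk-load name embeddings for the given community nodes, keyed by uuid.
    """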
    if driver.provider == GraphProvider.NEPTUNE:
        query = """
        MATCH (c:Community)
        WHERE c.uuid IN $community_uuids
        RETURN DISTINCT
            c.uuid AS uuid,
            split(c.name_embedding, ",") AS name_embedding
        """
    else:
        query = """
        MATCH (c:Community)
        WHERE c.uuid IN $community_uuids
        RETURN DISTINCT
            c.uuid AS uuid,
            c.name_embedding AS name_embedding
        """
    results, _, _ = await driver.execute_query(
        query,
        community_uuids=[community.uuid for community in communities],
        routing_='r',
    )

    embeddings_dict: dict[str, list[float]] = {}
    for result in results:
        uuid: str = result.get('uuid')
        embedding: list[float] = result.get('name_embedding')
        if uuid is not None and embedding is not None:
            embeddings_dict[uuid] = embedding

    return embeddings_dict


async def get_embeddings_for_edges(
    driver: GraphDriver, edges: list[EntityEdge]
) -> dict[str, list[float]]:
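    """
    Bulk-load fact embeddings for the given entity edges, keyed by edge uuid.
    Kuzu models RELATES_TO through an intermediate RelatesToNode_ node, so it
    uses a different MATCH pattern.
    """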
    if driver.graph_operations_interface:
        return await driver.graph_operations_interface.edge_load_embeddings_bulk(driver, edges)
    elif driver.provider == GraphProvider.NEPTUNE:
        query = """
        MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
        WHERE e.uuid IN $edge_uuids
        RETURN DISTINCT
            e.uuid AS uuid,
            split(e.fact_embedding, ",") AS fact_embedding
        """
    else:
        match_query = """
            MATCH (n:Entity)-[e:RELATES_TO]-(m:Entity)
        """
        if driver.provider == GraphProvider.KUZU:
            match_query = """
                MATCH (n:Entity)-[:RELATES_TO]-(e:RelatesToNode_)-[:RELATES_TO]-(m:Entity)
            """

        query = (
            match_query
            + """
        WHERE e.uuid IN $edge_uuids
        RETURN DISTINCT
            e.uuid AS uuid,
            e.fact_embedding AS fact_embedding
        """
        )
    results, _, _ = await driver.execute_query(
        query,
        edge_uuids=[edge.uuid for edge in edges],
        routing_='r',
    )

    embeddings_dict: dict[str, list[float]] = {}
    for result in results:
        uuid: str = result.get('uuid')
        embedding: list[float] = result.get('fact_embedding')
        if uuid is not None and embedding is not None:
            embeddings_dict[uuid] = embedding

    return embeddings_dict

```