This is page 3 of 9. Use http://codebase.md/getzep/graphiti?lines=false&page={x} to view the full context.
# Directory Structure
```
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── ISSUE_TEMPLATE
│ │ └── bug_report.md
│ ├── pull_request_template.md
│ ├── secret_scanning.yml
│ └── workflows
│ ├── ai-moderator.yml
│ ├── cla.yml
│ ├── claude-code-review-manual.yml
│ ├── claude-code-review.yml
│ ├── claude.yml
│ ├── codeql.yml
│ ├── daily_issue_maintenance.yml
│ ├── issue-triage.yml
│ ├── lint.yml
│ ├── release-graphiti-core.yml
│ ├── release-mcp-server.yml
│ ├── release-server-container.yml
│ ├── typecheck.yml
│ └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│ ├── azure-openai
│ │ ├── .env.example
│ │ ├── azure_openai_neo4j.py
│ │ └── README.md
│ ├── data
│ │ └── manybirds_products.json
│ ├── ecommerce
│ │ ├── runner.ipynb
│ │ └── runner.py
│ ├── langgraph-agent
│ │ ├── agent.ipynb
│ │ └── tinybirds-jess.png
│ ├── opentelemetry
│ │ ├── .env.example
│ │ ├── otel_stdout_example.py
│ │ ├── pyproject.toml
│ │ ├── README.md
│ │ └── uv.lock
│ ├── podcast
│ │ ├── podcast_runner.py
│ │ ├── podcast_transcript.txt
│ │ └── transcript_parser.py
│ ├── quickstart
│ │ ├── quickstart_falkordb.py
│ │ ├── quickstart_neo4j.py
│ │ ├── quickstart_neptune.py
│ │ ├── README.md
│ │ └── requirements.txt
│ └── wizard_of_oz
│ ├── parser.py
│ ├── runner.py
│ └── woo.txt
├── graphiti_core
│ ├── __init__.py
│ ├── cross_encoder
│ │ ├── __init__.py
│ │ ├── bge_reranker_client.py
│ │ ├── client.py
│ │ ├── gemini_reranker_client.py
│ │ └── openai_reranker_client.py
│ ├── decorators.py
│ ├── driver
│ │ ├── __init__.py
│ │ ├── driver.py
│ │ ├── falkordb_driver.py
│ │ ├── graph_operations
│ │ │ └── graph_operations.py
│ │ ├── kuzu_driver.py
│ │ ├── neo4j_driver.py
│ │ ├── neptune_driver.py
│ │ └── search_interface
│ │ └── search_interface.py
│ ├── edges.py
│ ├── embedder
│ │ ├── __init__.py
│ │ ├── azure_openai.py
│ │ ├── client.py
│ │ ├── gemini.py
│ │ ├── openai.py
│ │ └── voyage.py
│ ├── errors.py
│ ├── graph_queries.py
│ ├── graphiti_types.py
│ ├── graphiti.py
│ ├── helpers.py
│ ├── llm_client
│ │ ├── __init__.py
│ │ ├── anthropic_client.py
│ │ ├── azure_openai_client.py
│ │ ├── client.py
│ │ ├── config.py
│ │ ├── errors.py
│ │ ├── gemini_client.py
│ │ ├── groq_client.py
│ │ ├── openai_base_client.py
│ │ ├── openai_client.py
│ │ ├── openai_generic_client.py
│ │ └── utils.py
│ ├── migrations
│ │ └── __init__.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── edges
│ │ │ ├── __init__.py
│ │ │ └── edge_db_queries.py
│ │ └── nodes
│ │ ├── __init__.py
│ │ └── node_db_queries.py
│ ├── nodes.py
│ ├── prompts
│ │ ├── __init__.py
│ │ ├── dedupe_edges.py
│ │ ├── dedupe_nodes.py
│ │ ├── eval.py
│ │ ├── extract_edge_dates.py
│ │ ├── extract_edges.py
│ │ ├── extract_nodes.py
│ │ ├── invalidate_edges.py
│ │ ├── lib.py
│ │ ├── models.py
│ │ ├── prompt_helpers.py
│ │ ├── snippets.py
│ │ └── summarize_nodes.py
│ ├── py.typed
│ ├── search
│ │ ├── __init__.py
│ │ ├── search_config_recipes.py
│ │ ├── search_config.py
│ │ ├── search_filters.py
│ │ ├── search_helpers.py
│ │ ├── search_utils.py
│ │ └── search.py
│ ├── telemetry
│ │ ├── __init__.py
│ │ └── telemetry.py
│ ├── tracer.py
│ └── utils
│ ├── __init__.py
│ ├── bulk_utils.py
│ ├── datetime_utils.py
│ ├── maintenance
│ │ ├── __init__.py
│ │ ├── community_operations.py
│ │ ├── dedup_helpers.py
│ │ ├── edge_operations.py
│ │ ├── graph_data_operations.py
│ │ ├── node_operations.py
│ │ └── temporal_operations.py
│ ├── ontology_utils
│ │ └── entity_types_utils.py
│ └── text_utils.py
├── images
│ ├── arxiv-screenshot.png
│ ├── graphiti-graph-intro.gif
│ ├── graphiti-intro-slides-stock-2.gif
│ └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│ ├── .env.example
│ ├── .python-version
│ ├── config
│ │ ├── config-docker-falkordb-combined.yaml
│ │ ├── config-docker-falkordb.yaml
│ │ ├── config-docker-neo4j.yaml
│ │ ├── config.yaml
│ │ └── mcp_config_stdio_example.json
│ ├── docker
│ │ ├── build-standalone.sh
│ │ ├── build-with-version.sh
│ │ ├── docker-compose-falkordb.yml
│ │ ├── docker-compose-neo4j.yml
│ │ ├── docker-compose.yml
│ │ ├── Dockerfile
│ │ ├── Dockerfile.standalone
│ │ ├── github-actions-example.yml
│ │ ├── README-falkordb-combined.md
│ │ └── README.md
│ ├── docs
│ │ └── cursor_rules.md
│ ├── main.py
│ ├── pyproject.toml
│ ├── pytest.ini
│ ├── README.md
│ ├── src
│ │ ├── __init__.py
│ │ ├── config
│ │ │ ├── __init__.py
│ │ │ └── schema.py
│ │ ├── graphiti_mcp_server.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ ├── entity_types.py
│ │ │ └── response_types.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── factories.py
│ │ │ └── queue_service.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── formatting.py
│ │ └── utils.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── pytest.ini
│ │ ├── README.md
│ │ ├── run_tests.py
│ │ ├── test_async_operations.py
│ │ ├── test_comprehensive_integration.py
│ │ ├── test_configuration.py
│ │ ├── test_falkordb_integration.py
│ │ ├── test_fixtures.py
│ │ ├── test_http_integration.py
│ │ ├── test_integration.py
│ │ ├── test_mcp_integration.py
│ │ ├── test_mcp_transports.py
│ │ ├── test_stdio_simple.py
│ │ └── test_stress_load.py
│ └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│ ├── .env.example
│ ├── graph_service
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ ├── main.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ └── zep_graphiti.py
│ ├── Makefile
│ ├── pyproject.toml
│ ├── README.md
│ └── uv.lock
├── signatures
│ └── version1
│ └── cla.json
├── tests
│ ├── cross_encoder
│ │ ├── test_bge_reranker_client_int.py
│ │ └── test_gemini_reranker_client.py
│ ├── driver
│ │ ├── __init__.py
│ │ └── test_falkordb_driver.py
│ ├── embedder
│ │ ├── embedder_fixtures.py
│ │ ├── test_gemini.py
│ │ ├── test_openai.py
│ │ └── test_voyage.py
│ ├── evals
│ │ ├── data
│ │ │ └── longmemeval_data
│ │ │ ├── longmemeval_oracle.json
│ │ │ └── README.md
│ │ ├── eval_cli.py
│ │ ├── eval_e2e_graph_building.py
│ │ ├── pytest.ini
│ │ └── utils.py
│ ├── helpers_test.py
│ ├── llm_client
│ │ ├── test_anthropic_client_int.py
│ │ ├── test_anthropic_client.py
│ │ ├── test_azure_openai_client.py
│ │ ├── test_client.py
│ │ ├── test_errors.py
│ │ └── test_gemini_client.py
│ ├── test_edge_int.py
│ ├── test_entity_exclusion_int.py
│ ├── test_graphiti_int.py
│ ├── test_graphiti_mock.py
│ ├── test_node_int.py
│ ├── test_text_utils.py
│ └── utils
│ ├── maintenance
│ │ ├── test_bulk_utils.py
│ │ ├── test_edge_operations.py
│ │ ├── test_node_operations.py
│ │ └── test_temporal_operations_int.py
│ └── search
│ └── search_utils_test.py
├── uv.lock
└── Zep-CLA.md
```
# Files
--------------------------------------------------------------------------------
/.github/workflows/release-mcp-server.yml:
--------------------------------------------------------------------------------
```yaml
name: Release MCP Server
on:
push:
tags: ["mcp-v*.*.*"]
workflow_dispatch:
inputs:
tag:
description: 'Existing tag to release (e.g., mcp-v1.0.0) - tag must exist in repo'
required: true
type: string
env:
REGISTRY: docker.io
IMAGE_NAME: zepai/knowledge-graph-mcp
jobs:
release:
runs-on: depot-ubuntu-24.04-small
permissions:
contents: write
id-token: write
environment:
name: release
strategy:
matrix:
variant:
- name: standalone
dockerfile: docker/Dockerfile.standalone
image_suffix: "-standalone"
tag_latest: "standalone"
title: "Graphiti MCP Server (Standalone)"
description: "Standalone Graphiti MCP server for external Neo4j or FalkorDB"
- name: combined
dockerfile: docker/Dockerfile
image_suffix: ""
tag_latest: "latest"
title: "FalkorDB + Graphiti MCP Server"
description: "Combined FalkorDB graph database with Graphiti MCP server"
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
ref: ${{ inputs.tag || github.ref }}
- name: Set up Python 3.11
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Extract and validate version
id: version
run: |
# Extract tag from either push event or manual workflow_dispatch input
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
TAG_FULL="${{ inputs.tag }}"
TAG_VERSION=${TAG_FULL#mcp-v}
else
TAG_VERSION=${GITHUB_REF#refs/tags/mcp-v}
fi
# Validate semantic versioning format
if ! [[ $TAG_VERSION =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
echo "Error: Tag must follow semantic versioning: mcp-vX.Y.Z (e.g., mcp-v1.0.0)"
echo "Received: mcp-v$TAG_VERSION"
exit 1
fi
# Validate against pyproject.toml version
PROJECT_VERSION=$(python -c "import tomllib; print(tomllib.load(open('mcp_server/pyproject.toml', 'rb'))['project']['version'])")
if [ "$TAG_VERSION" != "$PROJECT_VERSION" ]; then
echo "Error: Tag version mcp-v$TAG_VERSION does not match mcp_server/pyproject.toml version $PROJECT_VERSION"
exit 1
fi
echo "version=$PROJECT_VERSION" >> $GITHUB_OUTPUT
- name: Log in to Docker Hub
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Depot CLI
uses: depot/setup-action@v1
- name: Get latest graphiti-core version from PyPI
id: graphiti
run: |
# Query PyPI for the latest graphiti-core version with error handling
set -eo pipefail
if ! GRAPHITI_VERSION=$(curl -sf https://pypi.org/pypi/graphiti-core/json | python -c "import sys, json; data=json.load(sys.stdin); print(data['info']['version'])"); then
echo "Error: Failed to fetch graphiti-core version from PyPI"
exit 1
fi
if [ -z "$GRAPHITI_VERSION" ]; then
echo "Error: Empty version returned from PyPI"
exit 1
fi
echo "graphiti_version=${GRAPHITI_VERSION}" >> $GITHUB_OUTPUT
echo "Latest Graphiti Core version from PyPI: ${GRAPHITI_VERSION}"
- name: Extract metadata
id: meta
run: |
# Get build date
echo "build_date=$(date -u +%Y-%m-%dT%H:%M:%SZ)" >> $GITHUB_OUTPUT
- name: Generate Docker metadata
id: docker_meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=raw,value=${{ steps.version.outputs.version }}${{ matrix.variant.image_suffix }}
type=raw,value=${{ steps.version.outputs.version }}-graphiti-${{ steps.graphiti.outputs.graphiti_version }}${{ matrix.variant.image_suffix }}
type=raw,value=${{ matrix.variant.tag_latest }}
labels: |
org.opencontainers.image.title=${{ matrix.variant.title }}
org.opencontainers.image.description=${{ matrix.variant.description }}
org.opencontainers.image.version=${{ steps.version.outputs.version }}
org.opencontainers.image.vendor=Zep AI
graphiti.core.version=${{ steps.graphiti.outputs.graphiti_version }}
- name: Build and push Docker image (${{ matrix.variant.name }})
uses: depot/build-push-action@v1
with:
project: v9jv1mlpwc
context: ./mcp_server
file: ./mcp_server/${{ matrix.variant.dockerfile }}
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.docker_meta.outputs.tags }}
labels: ${{ steps.docker_meta.outputs.labels }}
build-args: |
MCP_SERVER_VERSION=${{ steps.version.outputs.version }}
GRAPHITI_CORE_VERSION=${{ steps.graphiti.outputs.graphiti_version }}
BUILD_DATE=${{ steps.meta.outputs.build_date }}
VCS_REF=${{ steps.version.outputs.version }}
- name: Create release summary
run: |
{
echo "## MCP Server Release Summary - ${{ matrix.variant.title }}"
echo ""
echo "**MCP Server Version:** ${{ steps.version.outputs.version }}"
echo "**Graphiti Core Version:** ${{ steps.graphiti.outputs.graphiti_version }}"
echo "**Build Date:** ${{ steps.meta.outputs.build_date }}"
echo ""
echo "### Docker Image Tags"
echo "${{ steps.docker_meta.outputs.tags }}" | tr ',' '\n' | sed 's/^/- /'
echo ""
} >> $GITHUB_STEP_SUMMARY
```
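The version-validation step above can be read as plain Python; here is a minimal standalone sketch of the same check, assuming Python 3.11+ (for `tomllib`) and a working directory at the repository root. The tag value is illustrative, not read from the workflow environment.
```python
# Standalone sketch of the tag-vs-pyproject version check in the workflow above.
import re
import sys
import tomllib

tag = "mcp-v1.0.0"  # hypothetical tag, e.g. taken from an environment variable
tag_version = tag.removeprefix("mcp-v")

# Enforce semantic versioning on the tag, mirroring the workflow's regex
if not re.fullmatch(r"\d+\.\d+\.\d+", tag_version):
    sys.exit(f"Tag must follow semantic versioning: mcp-vX.Y.Z (got {tag})")

# Compare against the version declared in mcp_server/pyproject.toml
with open("mcp_server/pyproject.toml", "rb") as f:
    project_version = tomllib.load(f)["project"]["version"]

if tag_version != project_version:
    sys.exit(f"Tag {tag_version} does not match pyproject.toml version {project_version}")

print(f"version={project_version}")
```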
--------------------------------------------------------------------------------
/.github/workflows/release-server-container.yml:
--------------------------------------------------------------------------------
```yaml
name: Release Server Container
on:
workflow_run:
workflows: ["Release to PyPI"]
types: [completed]
branches: [main]
workflow_dispatch:
inputs:
version:
description: 'Graphiti core version to build (e.g., 0.22.1)'
required: false
env:
REGISTRY: docker.io
IMAGE_NAME: zepai/graphiti
jobs:
build-and-push:
runs-on: depot-ubuntu-24.04-small
if: ${{ github.event.workflow_run.conclusion == 'success' || github.event_name == 'workflow_dispatch' }}
permissions:
contents: write
id-token: write
environment:
name: release
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ github.event.workflow_run.head_sha || github.ref }}
- name: Set up Python 3.11
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install uv
uses: astral-sh/setup-uv@v3
with:
version: "latest"
- name: Extract version
id: version
run: |
if [ "${{ github.event_name }}" == "workflow_dispatch" ] && [ -n "${{ github.event.inputs.version }}" ]; then
VERSION="${{ github.event.inputs.version }}"
echo "Using manual input version: $VERSION"
else
# When triggered by workflow_run, get the tag that triggered the PyPI release
# The PyPI workflow is triggered by tags matching v*.*.*
VERSION=$(git tag --points-at HEAD | grep '^v[0-9]' | head -1 | sed 's/^v//')
if [ -z "$VERSION" ]; then
# Fallback: check pyproject.toml version
VERSION=$(uv run python -c "import tomllib; print(tomllib.load(open('pyproject.toml', 'rb'))['project']['version'])")
echo "Version from pyproject.toml: $VERSION"
else
echo "Version from git tag: $VERSION"
fi
if [ -z "$VERSION" ]; then
echo "Could not determine version"
exit 1
fi
fi
# Validate it's a stable release - catch all Python pre-release patterns
# Matches: pre, rc, alpha, beta, a1, b2, dev0, etc.
if [[ $VERSION =~ (pre|rc|alpha|beta|a[0-9]+|b[0-9]+|\.dev[0-9]*) ]]; then
echo "Skipping pre-release version: $VERSION"
echo "skip=true" >> $GITHUB_OUTPUT
exit 0
fi
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "skip=false" >> $GITHUB_OUTPUT
- name: Wait for PyPI availability
if: steps.version.outputs.skip != 'true'
run: |
VERSION="${{ steps.version.outputs.version }}"
echo "Checking PyPI for graphiti-core version $VERSION..."
MAX_ATTEMPTS=10
SLEEP_TIME=30
for i in $(seq 1 $MAX_ATTEMPTS); do
HTTP_CODE=$(curl -s -o /dev/null -w "%{http_code}" "https://pypi.org/pypi/graphiti-core/$VERSION/json")
if [ "$HTTP_CODE" == "200" ]; then
echo "✓ graphiti-core $VERSION is available on PyPI"
exit 0
fi
echo "Attempt $i/$MAX_ATTEMPTS: graphiti-core $VERSION not yet available (HTTP $HTTP_CODE)"
if [ $i -lt $MAX_ATTEMPTS ]; then
echo "Waiting ${SLEEP_TIME}s before retry..."
sleep $SLEEP_TIME
fi
done
echo "ERROR: graphiti-core $VERSION not available on PyPI after $MAX_ATTEMPTS attempts"
exit 1
- name: Log in to Docker Hub
if: steps.version.outputs.skip != 'true'
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Set up Depot CLI
if: steps.version.outputs.skip != 'true'
uses: depot/setup-action@v1
- name: Extract metadata
if: steps.version.outputs.skip != 'true'
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=raw,value=${{ steps.version.outputs.version }}
type=raw,value=latest
labels: |
org.opencontainers.image.title=Graphiti FastAPI Server
org.opencontainers.image.description=FastAPI server for Graphiti temporal knowledge graphs
org.opencontainers.image.version=${{ steps.version.outputs.version }}
io.graphiti.core.version=${{ steps.version.outputs.version }}
- name: Build and push Docker image
if: steps.version.outputs.skip != 'true'
uses: depot/build-push-action@v1
with:
project: v9jv1mlpwc
context: .
file: ./Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
build-args: |
GRAPHITI_VERSION=${{ steps.version.outputs.version }}
BUILD_DATE=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.created'] }}
VCS_REF=${{ github.sha }}
- name: Summary
if: steps.version.outputs.skip != 'true'
run: |
echo "## 🚀 Server Container Released" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "- **Version**: ${{ steps.version.outputs.version }}" >> $GITHUB_STEP_SUMMARY
echo "- **Image**: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}" >> $GITHUB_STEP_SUMMARY
echo "- **Tags**: ${{ steps.version.outputs.version }}, latest" >> $GITHUB_STEP_SUMMARY
echo "- **Platforms**: linux/amd64, linux/arm64" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "### Pull the image:" >> $GITHUB_STEP_SUMMARY
echo '```bash' >> $GITHUB_STEP_SUMMARY
echo "docker pull ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.version.outputs.version }}" >> $GITHUB_STEP_SUMMARY
echo '```' >> $GITHUB_STEP_SUMMARY
```
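The "Wait for PyPI availability" step is a simple polling loop; below is a hedged Python sketch of the same idea using only the standard library. The attempt count and sleep interval mirror the workflow, while the version argument is a placeholder.
```python
# Sketch of polling PyPI until a given graphiti-core release becomes available.
import time
import urllib.error
import urllib.request


def wait_for_pypi(version: str, attempts: int = 10, sleep_s: int = 30) -> bool:
    url = f"https://pypi.org/pypi/graphiti-core/{version}/json"
    for i in range(1, attempts + 1):
        try:
            with urllib.request.urlopen(url) as resp:
                if resp.status == 200:
                    print(f"graphiti-core {version} is available on PyPI")
                    return True
        except urllib.error.HTTPError as e:
            # PyPI returns 404 until the release is published
            print(f"Attempt {i}/{attempts}: not yet available (HTTP {e.code})")
        if i < attempts:
            time.sleep(sleep_s)
    return False
```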
--------------------------------------------------------------------------------
/graphiti_core/tracer.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from abc import ABC, abstractmethod
from collections.abc import Generator
from contextlib import AbstractContextManager, contextmanager, suppress
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from opentelemetry.trace import Span, StatusCode
try:
from opentelemetry.trace import Span, StatusCode
OTEL_AVAILABLE = True
except ImportError:
OTEL_AVAILABLE = False
class TracerSpan(ABC):
"""Abstract base class for tracer spans."""
@abstractmethod
def add_attributes(self, attributes: dict[str, Any]) -> None:
"""Add attributes to the span."""
pass
@abstractmethod
def set_status(self, status: str, description: str | None = None) -> None:
"""Set the status of the span."""
pass
@abstractmethod
def record_exception(self, exception: Exception) -> None:
"""Record an exception in the span."""
pass
class Tracer(ABC):
"""Abstract base class for tracers."""
@abstractmethod
def start_span(self, name: str) -> AbstractContextManager[TracerSpan]:
"""Start a new span with the given name."""
pass
class NoOpSpan(TracerSpan):
"""No-op span implementation that does nothing."""
def add_attributes(self, attributes: dict[str, Any]) -> None:
pass
def set_status(self, status: str, description: str | None = None) -> None:
pass
def record_exception(self, exception: Exception) -> None:
pass
class NoOpTracer(Tracer):
"""No-op tracer implementation that does nothing."""
@contextmanager
def start_span(self, name: str) -> Generator[NoOpSpan, None, None]:
"""Return a no-op span."""
yield NoOpSpan()
class OpenTelemetrySpan(TracerSpan):
"""Wrapper for OpenTelemetry span."""
def __init__(self, span: 'Span'):
self._span = span
def add_attributes(self, attributes: dict[str, Any]) -> None:
"""Add attributes to the OpenTelemetry span."""
try:
# Filter out None values and convert all values to appropriate types
filtered_attrs = {}
for key, value in attributes.items():
if value is not None:
# Convert to string if not a primitive type
if isinstance(value, str | int | float | bool):
filtered_attrs[key] = value
else:
filtered_attrs[key] = str(value)
if filtered_attrs:
self._span.set_attributes(filtered_attrs)
except Exception:
# Silently ignore tracing errors
pass
def set_status(self, status: str, description: str | None = None) -> None:
"""Set the status of the OpenTelemetry span."""
try:
if OTEL_AVAILABLE:
if status == 'error':
self._span.set_status(StatusCode.ERROR, description)
elif status == 'ok':
self._span.set_status(StatusCode.OK, description)
except Exception:
# Silently ignore tracing errors
pass
def record_exception(self, exception: Exception) -> None:
"""Record an exception in the OpenTelemetry span."""
with suppress(Exception):
self._span.record_exception(exception)
class OpenTelemetryTracer(Tracer):
"""Wrapper for OpenTelemetry tracer with configurable span name prefix."""
def __init__(self, tracer: Any, span_prefix: str = 'graphiti'):
"""
Initialize the OpenTelemetry tracer wrapper.
Parameters
----------
tracer : opentelemetry.trace.Tracer
The OpenTelemetry tracer instance.
span_prefix : str, optional
Prefix to prepend to all span names. Defaults to 'graphiti'.
"""
if not OTEL_AVAILABLE:
raise ImportError(
'OpenTelemetry is not installed. Install it with: pip install opentelemetry-api'
)
self._tracer = tracer
self._span_prefix = span_prefix.rstrip('.')
@contextmanager
def start_span(self, name: str) -> Generator[OpenTelemetrySpan | NoOpSpan, None, None]:
"""Start a new OpenTelemetry span with the configured prefix."""
try:
full_name = f'{self._span_prefix}.{name}'
with self._tracer.start_as_current_span(full_name) as span:
yield OpenTelemetrySpan(span)
except Exception:
# If tracing fails, yield a no-op span to prevent breaking the operation
yield NoOpSpan()
def create_tracer(otel_tracer: Any | None = None, span_prefix: str = 'graphiti') -> Tracer:
"""
Create a tracer instance.
Parameters
----------
otel_tracer : opentelemetry.trace.Tracer | None, optional
An OpenTelemetry tracer instance. If None, a no-op tracer is returned.
span_prefix : str, optional
Prefix to prepend to all span names. Defaults to 'graphiti'.
Returns
-------
Tracer
A tracer instance (either OpenTelemetryTracer or NoOpTracer).
Examples
--------
Using with OpenTelemetry:
>>> from opentelemetry import trace
>>> otel_tracer = trace.get_tracer(__name__)
>>> tracer = create_tracer(otel_tracer, span_prefix='myapp.graphiti')
Using no-op tracer:
>>> tracer = create_tracer() # Returns NoOpTracer
"""
if otel_tracer is None:
return NoOpTracer()
if not OTEL_AVAILABLE:
return NoOpTracer()
return OpenTelemetryTracer(otel_tracer, span_prefix)
```
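A short usage sketch based on the `create_tracer` docstring above; the span name and attributes are illustrative, and `opentelemetry-api` must be installed for a real tracer to be returned (otherwise a `NoOpTracer` is used).
```python
# Usage sketch: wrap an OpenTelemetry tracer and emit a prefixed span.
from opentelemetry import trace

from graphiti_core.tracer import create_tracer

otel_tracer = trace.get_tracer(__name__)
tracer = create_tracer(otel_tracer, span_prefix='myapp.graphiti')

# Span is named 'myapp.graphiti.add_episode'; attribute keys are illustrative.
with tracer.start_span('add_episode') as span:
    span.add_attributes({'group_id': 'demo', 'episode_count': 1})
    span.set_status('ok')
```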
--------------------------------------------------------------------------------
/graphiti_core/cross_encoder/gemini_reranker_client.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
import re
from typing import TYPE_CHECKING
from ..helpers import semaphore_gather
from ..llm_client import LLMConfig, RateLimitError
from .client import CrossEncoderClient
if TYPE_CHECKING:
from google import genai
from google.genai import types
else:
try:
from google import genai
from google.genai import types
except ImportError:
raise ImportError(
'google-genai is required for GeminiRerankerClient. '
'Install it with: pip install graphiti-core[google-genai]'
) from None
logger = logging.getLogger(__name__)
DEFAULT_MODEL = 'gemini-2.5-flash-lite'
class GeminiRerankerClient(CrossEncoderClient):
"""
Google Gemini Reranker Client
"""
def __init__(
self,
config: LLMConfig | None = None,
client: 'genai.Client | None' = None,
):
"""
Initialize the GeminiRerankerClient with the provided configuration and client.
The Gemini Developer API does not yet support logprobs. Unlike the OpenAI reranker,
this reranker uses the Gemini API to perform direct relevance scoring of passages.
Each passage is scored individually on a 0-100 scale.
Args:
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
"""
if config is None:
config = LLMConfig()
self.config = config
if client is None:
self.client = genai.Client(api_key=config.api_key)
else:
self.client = client
async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]:
"""
Rank passages based on their relevance to the query using direct scoring.
Each passage is scored individually on a 0-100 scale, then normalized to [0,1].
"""
if len(passages) <= 1:
return [(passage, 1.0) for passage in passages]
# Generate scoring prompts for each passage
scoring_prompts = []
for passage in passages:
prompt = f"""Rate how well this passage answers or relates to the query. Use a scale from 0 to 100.
Query: {query}
Passage: {passage}
Provide only a number between 0 and 100 (no explanation, just the number):"""
scoring_prompts.append(
[
types.Content(
role='user',
parts=[types.Part.from_text(text=prompt)],
),
]
)
try:
# Execute all scoring requests concurrently - O(n) API calls
responses = await semaphore_gather(
*[
self.client.aio.models.generate_content(
model=self.config.model or DEFAULT_MODEL,
contents=prompt_messages, # type: ignore
config=types.GenerateContentConfig(
system_instruction='You are an expert at rating passage relevance. Respond with only a number from 0-100.',
temperature=0.0,
max_output_tokens=3,
),
)
for prompt_messages in scoring_prompts
]
)
# Extract scores and create results
results = []
for passage, response in zip(passages, responses, strict=True):
try:
if hasattr(response, 'text') and response.text:
# Extract numeric score from response
score_text = response.text.strip()
# Handle cases where model might return non-numeric text
score_match = re.search(r'\b(\d{1,3})\b', score_text)
if score_match:
score = float(score_match.group(1))
# Normalize to [0, 1] range and clamp to valid range
normalized_score = max(0.0, min(1.0, score / 100.0))
results.append((passage, normalized_score))
else:
logger.warning(
f'Could not extract numeric score from response: {score_text}'
)
results.append((passage, 0.0))
else:
logger.warning('Empty response from Gemini for passage scoring')
results.append((passage, 0.0))
except (ValueError, AttributeError) as e:
logger.warning(f'Error parsing score from Gemini response: {e}')
results.append((passage, 0.0))
# Sort by score in descending order (highest relevance first)
results.sort(reverse=True, key=lambda x: x[1])
return results
except Exception as e:
# Check if it's a rate limit error based on Gemini API error codes
error_message = str(e).lower()
if (
'rate limit' in error_message
or 'quota' in error_message
or 'resource_exhausted' in error_message
or '429' in str(e)
):
raise RateLimitError from e
logger.error(f'Error in generating LLM response: {e}')
raise
```
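A hedged usage sketch for this client; the API key, query, and passages are placeholders, and the `google-genai` extra must be installed (`pip install graphiti-core[google-genai]`). The call issues one Gemini request per passage, as described above.
```python
# Usage sketch: rank a handful of passages against a query.
import asyncio

from graphiti_core.cross_encoder.gemini_reranker_client import GeminiRerankerClient
from graphiti_core.llm_client import LLMConfig


async def main() -> None:
    reranker = GeminiRerankerClient(config=LLMConfig(api_key='YOUR_GEMINI_API_KEY'))
    ranked = await reranker.rank(
        query='Who founded Zep?',
        passages=[
            'Zep builds memory infrastructure for AI agents.',
            'Graphiti constructs temporal knowledge graphs.',
        ],
    )
    # Results are sorted by score, highest first, normalized to [0, 1]
    for passage, score in ranked:
        print(f'{score:.2f}  {passage}')


asyncio.run(main())
```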
--------------------------------------------------------------------------------
/graphiti_core/prompts/dedupe_edges.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Any, Protocol, TypedDict
from pydantic import BaseModel, Field
from .models import Message, PromptFunction, PromptVersion
from .prompt_helpers import to_prompt_json
class EdgeDuplicate(BaseModel):
duplicate_facts: list[int] = Field(
...,
description='List of idx values of any duplicate facts. If no duplicate facts are found, default to empty list.',
)
contradicted_facts: list[int] = Field(
...,
description='List of idx values of facts that should be invalidated. If no facts should be invalidated, the list should be empty.',
)
fact_type: str = Field(..., description='One of the provided fact types or DEFAULT')
class UniqueFact(BaseModel):
uuid: str = Field(..., description='unique identifier of the fact')
fact: str = Field(..., description='fact of a unique edge')
class UniqueFacts(BaseModel):
unique_facts: list[UniqueFact]
class Prompt(Protocol):
edge: PromptVersion
edge_list: PromptVersion
resolve_edge: PromptVersion
class Versions(TypedDict):
edge: PromptFunction
edge_list: PromptFunction
resolve_edge: PromptFunction
def edge(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
content='You are a helpful assistant that de-duplicates edges from edge lists.',
),
Message(
role='user',
content=f"""
Given the following context, determine whether the New Edge represents any of the edges in the list of Existing Edges.
<EXISTING EDGES>
{to_prompt_json(context['related_edges'])}
</EXISTING EDGES>
<NEW EDGE>
{to_prompt_json(context['extracted_edges'])}
</NEW EDGE>
Task:
If the New Edge represents the same factual information as any edge in Existing Edges, return the id of the duplicate fact
as part of the list of duplicate_facts.
If the NEW EDGE is not a duplicate of any of the EXISTING EDGES, return an empty list.
Guidelines:
1. The facts do not need to be completely identical to be duplicates; they just need to express the same information.
""",
),
]
def edge_list(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
content='You are a helpful assistant that de-duplicates edges from edge lists.',
),
Message(
role='user',
content=f"""
Given the following context, find all of the duplicates in a list of facts:
Facts:
{to_prompt_json(context['edges'])}
Task:
If any fact in Facts is a duplicate of another fact, return a new fact with one of their uuids.
Guidelines:
1. identical or near identical facts are duplicates
2. Facts are also duplicates if they are represented by similar sentences
3. Facts will often discuss the same or similar relation between identical entities
4. The final list should have only unique facts. If 3 facts are all duplicates of each other, only one of them
should be in the response
""",
),
]
def resolve_edge(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
content='You are a helpful assistant that de-duplicates facts from fact lists and determines which existing '
'facts are contradicted by the new fact.',
),
Message(
role='user',
content=f"""
Task:
You will receive TWO separate lists of facts. Each list uses 'idx' as its index field, starting from 0.
1. DUPLICATE DETECTION:
- If the NEW FACT represents identical factual information as any fact in EXISTING FACTS, return those idx values in duplicate_facts.
- Facts with similar information that contain key differences should NOT be marked as duplicates.
- Return idx values from EXISTING FACTS.
- If no duplicates, return an empty list for duplicate_facts.
2. FACT TYPE CLASSIFICATION:
- Given the predefined FACT TYPES, determine if the NEW FACT should be classified as one of these types.
- Return the fact type as fact_type or DEFAULT if NEW FACT is not one of the FACT TYPES.
3. CONTRADICTION DETECTION:
- Based on FACT INVALIDATION CANDIDATES and NEW FACT, determine which facts the new fact contradicts.
- Return idx values from FACT INVALIDATION CANDIDATES.
- If no contradictions, return an empty list for contradicted_facts.
IMPORTANT:
- duplicate_facts: Use ONLY 'idx' values from EXISTING FACTS
- contradicted_facts: Use ONLY 'idx' values from FACT INVALIDATION CANDIDATES
- These are two separate lists with independent idx ranges starting from 0
Guidelines:
1. Some facts may be very similar but will have key differences, particularly around numeric values in the facts.
Do not mark these facts as duplicates.
<FACT TYPES>
{context['edge_types']}
</FACT TYPES>
<EXISTING FACTS>
{context['existing_edges']}
</EXISTING FACTS>
<FACT INVALIDATION CANDIDATES>
{context['edge_invalidation_candidates']}
</FACT INVALIDATION CANDIDATES>
<NEW FACT>
{context['new_edge']}
</NEW FACT>
""",
),
]
versions: Versions = {'edge': edge, 'edge_list': edge_list, 'resolve_edge': resolve_edge}
```
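A small illustration of how these prompt functions might be exercised; the context values are made up and the LLM call itself is omitted. The context keys follow the f-string in `resolve_edge` above.
```python
# Sketch: render the resolve_edge prompt and show the expected response model.
from graphiti_core.prompts.dedupe_edges import EdgeDuplicate, resolve_edge

context = {
    'edge_types': 'WORKS_AT: employment relationship',
    'existing_edges': [{'idx': 0, 'fact': 'Alice works at Acme'}],
    'edge_invalidation_candidates': [{'idx': 0, 'fact': 'Alice works at Initech'}],
    'new_edge': {'fact': 'Alice works at Acme Corporation'},
}

messages = resolve_edge(context)
for message in messages:
    print(message.role, message.content[:80])

# The LLM response would then be validated against EdgeDuplicate, e.g.
# EdgeDuplicate(duplicate_facts=[0], contradicted_facts=[0], fact_type='WORKS_AT').
```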
--------------------------------------------------------------------------------
/.github/workflows/issue-triage.yml:
--------------------------------------------------------------------------------
```yaml
name: Issue Triage and Deduplication
on:
issues:
types: [opened]
jobs:
triage:
runs-on: ubuntu-latest
timeout-minutes: 10
permissions:
contents: read
issues: write
id-token: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 1
- name: Run Claude Code for Issue Triage
uses: anthropics/claude-code-action@v1
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
allowed_non_write_users: "*"
github_token: ${{ secrets.GITHUB_TOKEN }}
prompt: |
You're an issue triage assistant for GitHub issues. Your task is to analyze the issue and select appropriate labels from the provided list.
IMPORTANT: Don't post any comments or messages to the issue. Your only action should be to apply labels. DO NOT check for duplicates - that's handled by a separate job.
Issue Information:
- REPO: ${{ github.repository }}
- ISSUE_NUMBER: ${{ github.event.issue.number }}
TASK OVERVIEW:
1. First, fetch the list of labels available in this repository by running: `gh label list`. Run exactly this command with nothing else.
2. Next, use gh commands to get context about the issue:
- Use `gh issue view ${{ github.event.issue.number }}` to retrieve the current issue's details
- Use `gh search issues` to find similar issues that might provide context for proper categorization
- You have access to these Bash commands:
- Bash(gh label list:*) - to get available labels
- Bash(gh issue view:*) - to view issue details
- Bash(gh issue edit:*) - to apply labels to the issue
- Bash(gh search:*) - to search for similar issues
3. Analyze the issue content, considering:
- The issue title and description
- The type of issue (bug report, feature request, question, etc.)
- Technical areas mentioned
- Database mentions (neo4j, falkordb, neptune, etc.)
- LLM providers mentioned (openai, anthropic, gemini, groq, etc.)
- Components affected (embeddings, search, prompts, server, mcp, etc.)
4. Select appropriate labels from the available labels list:
- Choose labels that accurately reflect the issue's nature
- Be specific but comprehensive
- Add database-specific labels if mentioned: neo4j, falkordb, neptune
- Add component labels if applicable
- DO NOT add priority labels (P1, P2, P3)
- DO NOT add duplicate label - that's handled by the deduplication job
5. Apply the selected labels:
- Use `gh issue edit ${{ github.event.issue.number }} --add-label "label1,label2,label3"` to apply your selected labels
- DO NOT post any comments explaining your decision
- DO NOT communicate directly with users
- If no labels are clearly applicable, do not apply any labels
IMPORTANT GUIDELINES:
- Be thorough in your analysis
- Only select labels from the provided list
- DO NOT post any comments to the issue
- Your ONLY action should be to apply labels using gh issue edit
- It's okay to not add any labels if none are clearly applicable
- DO NOT check for duplicates
claude_args: |
--allowedTools "Bash(gh label list:*),Bash(gh issue view:*),Bash(gh issue edit:*),Bash(gh search:*)"
--model claude-sonnet-4-5-20250929
deduplicate:
runs-on: ubuntu-latest
timeout-minutes: 10
needs: triage
permissions:
contents: read
issues: write
id-token: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
fetch-depth: 1
- name: Check for duplicate issues
uses: anthropics/claude-code-action@v1
with:
allowed_non_write_users: "*"
prompt: |
Analyze this new issue and check if it's a duplicate of existing issues in the repository.
Issue: #${{ github.event.issue.number }}
Repository: ${{ github.repository }}
Your task:
1. Use mcp__github__get_issue to get details of the current issue (#${{ github.event.issue.number }})
2. Search for similar existing OPEN issues using mcp__github__search_issues with relevant keywords from the issue title and body
3. Compare the new issue with existing ones to identify potential duplicates
Criteria for duplicates:
- Same bug or error being reported
- Same feature request (even if worded differently)
- Same question being asked
- Issues describing the same root problem
If you find duplicates:
- Add a comment on the new issue linking to the original issue(s)
- Apply the "duplicate" label to the new issue
- Be polite and explain why it's a duplicate
- Suggest the user follow the original issue for updates
If it's NOT a duplicate:
- Don't add any comments
- Don't modify labels
Use these tools:
- mcp__github__get_issue: Get issue details
- mcp__github__search_issues: Search for similar issues (use state:open)
- mcp__github__list_issues: List recent issues if needed
- mcp__github__create_issue_comment: Add a comment if duplicate found
- mcp__github__update_issue: Add "duplicate" label
Be thorough but efficient. Focus on finding true duplicates, not just similar issues.
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
claude_args: |
--allowedTools "mcp__github__get_issue,mcp__github__search_issues,mcp__github__list_issues,mcp__github__create_issue_comment,mcp__github__update_issue,mcp__github__get_issue_comments"
--model claude-sonnet-4-5-20250929
```
--------------------------------------------------------------------------------
/tests/utils/search/search_utils_test.py:
--------------------------------------------------------------------------------
```python
from unittest.mock import AsyncMock, patch
import pytest
from graphiti_core.nodes import EntityNode
from graphiti_core.search.search_filters import SearchFilters
from graphiti_core.search.search_utils import hybrid_node_search
@pytest.mark.asyncio
async def test_hybrid_node_search_deduplication():
# Mock the database driver
mock_driver = AsyncMock()
# Mock the node_fulltext_search and entity_similarity_search functions
with (
patch('graphiti_core.search.search_utils.node_fulltext_search') as mock_fulltext_search,
patch('graphiti_core.search.search_utils.node_similarity_search') as mock_similarity_search,
):
# Set up mock return values
mock_fulltext_search.side_effect = [
[EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')],
[EntityNode(uuid='2', name='Bob', labels=['Entity'], group_id='1')],
]
mock_similarity_search.side_effect = [
[EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')],
[EntityNode(uuid='3', name='Charlie', labels=['Entity'], group_id='1')],
]
# Call the function with test data
queries = ['Alice', 'Bob']
embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
results = await hybrid_node_search(queries, embeddings, mock_driver, SearchFilters())
# Assertions
assert len(results) == 3
assert set(node.uuid for node in results) == {'1', '2', '3'}
assert set(node.name for node in results) == {'Alice', 'Bob', 'Charlie'}
# Verify that the mock functions were called correctly
assert mock_fulltext_search.call_count == 2
assert mock_similarity_search.call_count == 2
@pytest.mark.asyncio
async def test_hybrid_node_search_empty_results():
mock_driver = AsyncMock()
with (
patch('graphiti_core.search.search_utils.node_fulltext_search') as mock_fulltext_search,
patch('graphiti_core.search.search_utils.node_similarity_search') as mock_similarity_search,
):
mock_fulltext_search.return_value = []
mock_similarity_search.return_value = []
queries = ['NonExistent']
embeddings = [[0.1, 0.2, 0.3]]
results = await hybrid_node_search(queries, embeddings, mock_driver, SearchFilters())
assert len(results) == 0
@pytest.mark.asyncio
async def test_hybrid_node_search_only_fulltext():
mock_driver = AsyncMock()
with (
patch('graphiti_core.search.search_utils.node_fulltext_search') as mock_fulltext_search,
patch('graphiti_core.search.search_utils.node_similarity_search') as mock_similarity_search,
):
mock_fulltext_search.return_value = [
EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')
]
mock_similarity_search.return_value = []
queries = ['Alice']
embeddings = []
results = await hybrid_node_search(queries, embeddings, mock_driver, SearchFilters())
assert len(results) == 1
assert results[0].name == 'Alice'
assert mock_fulltext_search.call_count == 1
assert mock_similarity_search.call_count == 0
@pytest.mark.asyncio
async def test_hybrid_node_search_with_limit():
mock_driver = AsyncMock()
with (
patch('graphiti_core.search.search_utils.node_fulltext_search') as mock_fulltext_search,
patch('graphiti_core.search.search_utils.node_similarity_search') as mock_similarity_search,
):
mock_fulltext_search.return_value = [
EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'),
EntityNode(uuid='2', name='Bob', labels=['Entity'], group_id='1'),
]
mock_similarity_search.return_value = [
EntityNode(uuid='3', name='Charlie', labels=['Entity'], group_id='1'),
EntityNode(
uuid='4',
name='David',
labels=['Entity'],
group_id='1',
),
]
queries = ['Test']
embeddings = [[0.1, 0.2, 0.3]]
limit = 1
results = await hybrid_node_search(
queries, embeddings, mock_driver, SearchFilters(), ['1'], limit
)
# We expect 4 results because the limit is applied per search method
# before deduplication, and we're not actually limiting the results
# in the hybrid_node_search function itself
assert len(results) == 4
assert mock_fulltext_search.call_count == 1
assert mock_similarity_search.call_count == 1
# Verify that the limit was passed to the search functions
mock_fulltext_search.assert_called_with(mock_driver, 'Test', SearchFilters(), ['1'], 2)
mock_similarity_search.assert_called_with(
mock_driver, [0.1, 0.2, 0.3], SearchFilters(), ['1'], 2
)
@pytest.mark.asyncio
async def test_hybrid_node_search_with_limit_and_duplicates():
mock_driver = AsyncMock()
with (
patch('graphiti_core.search.search_utils.node_fulltext_search') as mock_fulltext_search,
patch('graphiti_core.search.search_utils.node_similarity_search') as mock_similarity_search,
):
mock_fulltext_search.return_value = [
EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'),
EntityNode(uuid='2', name='Bob', labels=['Entity'], group_id='1'),
]
mock_similarity_search.return_value = [
EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'), # Duplicate
EntityNode(uuid='3', name='Charlie', labels=['Entity'], group_id='1'),
]
queries = ['Test']
embeddings = [[0.1, 0.2, 0.3]]
limit = 2
results = await hybrid_node_search(
queries, embeddings, mock_driver, SearchFilters(), ['1'], limit
)
# We expect 3 results because:
# 1. The limit of 2 is applied to each search method
# 2. We get 2 results from fulltext and 2 from similarity
# 3. One result is a duplicate (Alice), so it's only included once
assert len(results) == 3
assert set(node.name for node in results) == {'Alice', 'Bob', 'Charlie'}
assert mock_fulltext_search.call_count == 1
assert mock_similarity_search.call_count == 1
mock_fulltext_search.assert_called_with(mock_driver, 'Test', SearchFilters(), ['1'], 4)
mock_similarity_search.assert_called_with(
mock_driver, [0.1, 0.2, 0.3], SearchFilters(), ['1'], 4
)
```
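The limit assertions above depend on `hybrid_node_search` requesting twice the limit from each search method and then deduplicating by uuid; the following is a hedged sketch of that merge pattern, an illustration of the tested behavior rather than the library implementation.
```python
# Sketch of merging per-method results and deduplicating by uuid.
from graphiti_core.nodes import EntityNode


def merge_by_uuid(*result_lists: list[EntityNode]) -> list[EntityNode]:
    seen: dict[str, EntityNode] = {}
    for results in result_lists:
        for node in results:
            seen.setdefault(node.uuid, node)  # keep only the first occurrence
    return list(seen.values())


fulltext = [EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1')]
similarity = [
    EntityNode(uuid='1', name='Alice', labels=['Entity'], group_id='1'),  # duplicate
    EntityNode(uuid='3', name='Charlie', labels=['Entity'], group_id='1'),
]
print([n.name for n in merge_by_uuid(fulltext, similarity)])  # ['Alice', 'Charlie']
```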
--------------------------------------------------------------------------------
/tests/evals/eval_e2e_graph_building.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
from datetime import datetime, timezone
import pandas as pd
from graphiti_core import Graphiti
from graphiti_core.graphiti import AddEpisodeResults
from graphiti_core.helpers import semaphore_gather
from graphiti_core.llm_client import LLMConfig, OpenAIClient
from graphiti_core.nodes import EpisodeType
from graphiti_core.prompts import prompt_library
from graphiti_core.prompts.eval import EvalAddEpisodeResults
from tests.test_graphiti_int import NEO4J_URI, NEO4j_PASSWORD, NEO4j_USER
async def build_subgraph(
graphiti: Graphiti,
user_id: str,
multi_session,
multi_session_dates,
session_length: int,
group_id_suffix: str,
) -> tuple[str, list[AddEpisodeResults], list[str]]:
add_episode_results: list[AddEpisodeResults] = []
add_episode_context: list[str] = []
message_count = 0
for session_idx, session in enumerate(multi_session):
for _, msg in enumerate(session):
if message_count >= session_length:
continue
message_count += 1
date = multi_session_dates[session_idx] + ' UTC'
date_format = '%Y/%m/%d (%a) %H:%M UTC'
date_string = datetime.strptime(date, date_format).replace(tzinfo=timezone.utc)
episode_body = f'{msg["role"]}: {msg["content"]}'
results = await graphiti.add_episode(
name='',
episode_body=episode_body,
reference_time=date_string,
source=EpisodeType.message,
source_description='',
group_id=user_id + '_' + group_id_suffix,
)
for node in results.nodes:
node.name_embedding = None
for edge in results.edges:
edge.fact_embedding = None
add_episode_results.append(results)
add_episode_context.append(msg['content'])
return user_id, add_episode_results, add_episode_context
async def build_graph(
group_id_suffix: str, multi_session_count: int, session_length: int, graphiti: Graphiti
) -> tuple[dict[str, list[AddEpisodeResults]], dict[str, list[str]]]:
# Get longmemeval dataset
lme_dataset_option = (
'data/longmemeval_data/longmemeval_oracle.json' # Can be _oracle, _s, or _m
)
lme_dataset_df = pd.read_json(lme_dataset_option)
add_episode_results: dict[str, list[AddEpisodeResults]] = {}
add_episode_context: dict[str, list[str]] = {}
subgraph_results: list[tuple[str, list[AddEpisodeResults], list[str]]] = await semaphore_gather(
*[
build_subgraph(
graphiti,
user_id='lme_oracle_experiment_user_' + str(multi_session_idx),
multi_session=lme_dataset_df['haystack_sessions'].iloc[multi_session_idx],
multi_session_dates=lme_dataset_df['haystack_dates'].iloc[multi_session_idx],
session_length=session_length,
group_id_suffix=group_id_suffix,
)
for multi_session_idx in range(multi_session_count)
]
)
for user_id, episode_results, episode_context in subgraph_results:
add_episode_results[user_id] = episode_results
add_episode_context[user_id] = episode_context
return add_episode_results, add_episode_context
async def build_baseline_graph(multi_session_count: int, session_length: int):
# Use gpt-4.1-mini for graph building baseline
llm_client = OpenAIClient(config=LLMConfig(model='gpt-4.1-mini'))
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client)
add_episode_results, _ = await build_graph(
'baseline', multi_session_count, session_length, graphiti
)
filename = 'baseline_graph_results.json'
serializable_baseline_graph_results = {
key: [item.model_dump(mode='json') for item in value]
for key, value in add_episode_results.items()
}
with open(filename, 'w') as file:
json.dump(serializable_baseline_graph_results, file, indent=4, default=str)
async def eval_graph(multi_session_count: int, session_length: int, llm_client=None) -> float:
if llm_client is None:
llm_client = OpenAIClient(config=LLMConfig(model='gpt-4.1-mini'))
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, llm_client=llm_client)
with open('baseline_graph_results.json') as file:
baseline_results_raw = json.load(file)
baseline_results: dict[str, list[AddEpisodeResults]] = {
key: [AddEpisodeResults(**item) for item in value]
for key, value in baseline_results_raw.items()
}
add_episode_results, add_episode_context = await build_graph(
'candidate', multi_session_count, session_length, graphiti
)
filename = 'candidate_graph_results.json'
candidate_baseline_graph_results = {
key: [item.model_dump(mode='json') for item in value]
for key, value in add_episode_results.items()
}
with open(filename, 'w') as file:
json.dump(candidate_baseline_graph_results, file, indent=4, default=str)
raw_score = 0
user_count = 0
for user_id in add_episode_results:
user_count += 1
user_raw_score = 0
for baseline_result, add_episode_result, episodes in zip(
baseline_results[user_id],
add_episode_results[user_id],
add_episode_context[user_id],
strict=False,
):
context = {
'baseline': baseline_result,
'candidate': add_episode_result,
'message': episodes[0],
'previous_messages': episodes[1:],
}
llm_response = await llm_client.generate_response(
prompt_library.eval.eval_add_episode_results(context),
response_model=EvalAddEpisodeResults,
)
candidate_is_worse = llm_response.get('candidate_is_worse', False)
user_raw_score += 0 if candidate_is_worse else 1
print('llm_response:', llm_response)
user_score = user_raw_score / len(add_episode_results[user_id])
raw_score += user_score
score = raw_score / user_count
return score
```
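`eval_graph` averages, for each user, the fraction of episodes where the candidate graph was not judged worse than the baseline, then averages those per-user scores. A small worked sketch of that arithmetic with made-up judgments:
```python
# Worked example of the eval_graph scoring; True means the LLM judged
# the candidate worse than the baseline for that episode.
judgments = {
    'user_0': [False, False, True],   # 2/3 episodes not worse -> 0.667
    'user_1': [False, False, False],  # 3/3 episodes not worse -> 1.0
}

user_scores = [
    sum(0 if worse else 1 for worse in episodes) / len(episodes)
    for episodes in judgments.values()
]
score = sum(user_scores) / len(user_scores)
print(round(score, 3))  # 0.833
```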
--------------------------------------------------------------------------------
/graphiti_core/prompts/extract_edges.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Any, Protocol, TypedDict
from pydantic import BaseModel, Field
from .models import Message, PromptFunction, PromptVersion
from .prompt_helpers import to_prompt_json
class Edge(BaseModel):
relation_type: str = Field(..., description='FACT_PREDICATE_IN_SCREAMING_SNAKE_CASE')
source_entity_id: int = Field(
..., description='The id of the source entity from the ENTITIES list'
)
target_entity_id: int = Field(
..., description='The id of the target entity from the ENTITIES list'
)
fact: str = Field(
...,
description='A natural language description of the relationship between the entities, paraphrased from the source text',
)
valid_at: str | None = Field(
None,
description='The date and time when the relationship described by the edge fact became true or was established. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SS.SSSSSSZ)',
)
invalid_at: str | None = Field(
None,
description='The date and time when the relationship described by the edge fact stopped being true or ended. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SS.SSSSSSZ)',
)
class ExtractedEdges(BaseModel):
edges: list[Edge]
class MissingFacts(BaseModel):
missing_facts: list[str] = Field(..., description="facts that weren't extracted")
class Prompt(Protocol):
edge: PromptVersion
reflexion: PromptVersion
extract_attributes: PromptVersion
class Versions(TypedDict):
edge: PromptFunction
reflexion: PromptFunction
extract_attributes: PromptFunction
def edge(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
content='You are an expert fact extractor that extracts fact triples from text. '
'1. Extracted fact triples should also be extracted with relevant date information.'
'2. Treat the CURRENT TIME as the time the CURRENT MESSAGE was sent. All temporal information should be extracted relative to this time.',
),
Message(
role='user',
content=f"""
<FACT TYPES>
{context['edge_types']}
</FACT TYPES>
<PREVIOUS_MESSAGES>
{to_prompt_json([ep for ep in context['previous_episodes']])}
</PREVIOUS_MESSAGES>
<CURRENT_MESSAGE>
{context['episode_content']}
</CURRENT_MESSAGE>
<ENTITIES>
{to_prompt_json(context['nodes'])}
</ENTITIES>
<REFERENCE_TIME>
{context['reference_time']} # ISO 8601 (UTC); used to resolve relative time mentions
</REFERENCE_TIME>
# TASK
Extract all factual relationships between the given ENTITIES based on the CURRENT MESSAGE.
Only extract facts that:
- involve two DISTINCT ENTITIES from the ENTITIES list,
- are clearly stated or unambiguously implied in the CURRENT MESSAGE,
and can be represented as edges in a knowledge graph.
- Facts should include entity names rather than pronouns whenever possible.
- The FACT TYPES provide a list of the most important types of facts; make sure to extract facts of these types.
- The FACT TYPES are not an exhaustive list; extract all facts from the message even if they do not fit into one
of the FACT TYPES
- The FACT TYPES each contain their fact_type_signature which represents the source and target entity types.
You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.
{context['custom_prompt']}
# EXTRACTION RULES
1. **Entity ID Validation**: `source_entity_id` and `target_entity_id` must use only the `id` values from the ENTITIES list provided above.
- **CRITICAL**: Using IDs not in the list will cause the edge to be rejected
2. Each fact must involve two **distinct** entities.
3. Use a SCREAMING_SNAKE_CASE string as the `relation_type` (e.g., FOUNDED, WORKS_AT).
4. Do not emit duplicate or semantically redundant facts.
5. The `fact` should closely paraphrase the original source sentence(s). Do not verbatim quote the original text.
6. Use `REFERENCE_TIME` to resolve vague or relative temporal expressions (e.g., "last week").
7. Do **not** hallucinate or infer temporal bounds from unrelated events.
# DATETIME RULES
- Use ISO 8601 with “Z” suffix (UTC) (e.g., 2025-04-30T00:00:00Z).
- If the fact is ongoing (present tense), set `valid_at` to REFERENCE_TIME.
- If a change/termination is expressed, set `invalid_at` to the relevant timestamp.
- Leave both fields `null` if no explicit or resolvable time is stated.
- If only a date is mentioned (no time), assume 00:00:00.
- If only a year is mentioned, use January 1st at 00:00:00.
""",
),
]
def reflexion(context: dict[str, Any]) -> list[Message]:
sys_prompt = """You are an AI assistant that determines which facts have not been extracted from the given context"""
user_prompt = f"""
<PREVIOUS MESSAGES>
{to_prompt_json([ep for ep in context['previous_episodes']])}
</PREVIOUS MESSAGES>
<CURRENT MESSAGE>
{context['episode_content']}
</CURRENT MESSAGE>
<EXTRACTED ENTITIES>
{context['nodes']}
</EXTRACTED ENTITIES>
<EXTRACTED FACTS>
{context['extracted_facts']}
</EXTRACTED FACTS>
Given the above MESSAGES, list of EXTRACTED ENTITIES, and list of EXTRACTED FACTS,
determine if any facts haven't been extracted.
"""
return [
Message(role='system', content=sys_prompt),
Message(role='user', content=user_prompt),
]
def extract_attributes(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
content='You are a helpful assistant that extracts fact properties from the provided text.',
),
Message(
role='user',
content=f"""
<MESSAGE>
{to_prompt_json(context['episode_content'])}
</MESSAGE>
<REFERENCE TIME>
{context['reference_time']}
</REFERENCE TIME>
Given the above MESSAGE, its REFERENCE TIME, and the following FACT, update any of its attributes based on the information provided
in MESSAGE. Use the provided attribute descriptions to better understand how each attribute should be determined.
Guidelines:
1. Do not hallucinate entity property values if they cannot be found in the current context.
2. Only use the provided MESSAGES and FACT to set attribute values.
<FACT>
{context['fact']}
</FACT>
""",
),
]
versions: Versions = {
'edge': edge,
'reflexion': reflexion,
'extract_attributes': extract_attributes,
}
```
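An illustrative sketch of rendering the edge-extraction prompt above; the context values are placeholders and the LLM call itself is omitted. The context keys follow the f-string in `edge`.
```python
# Sketch: build the edge-extraction messages and note the expected output model.
from datetime import datetime, timezone

from graphiti_core.prompts.extract_edges import ExtractedEdges, edge

context = {
    'edge_types': 'WORKS_AT(Person -> Company)',
    'previous_episodes': ['Jess: I just started a new job.'],
    'episode_content': 'Jess: I work at Tinybirds now.',
    'nodes': [{'id': 0, 'name': 'Jess'}, {'id': 1, 'name': 'Tinybirds'}],
    'reference_time': datetime.now(timezone.utc).isoformat(),
    'custom_prompt': '',
}

messages = edge(context)
# messages[0] is the system prompt, messages[1] the user prompt; the response
# would be parsed against ExtractedEdges, e.g. an Edge with
# relation_type='WORKS_AT', source_entity_id=0, target_entity_id=1,
# fact='Jess works at Tinybirds', valid_at=context['reference_time'].
```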
--------------------------------------------------------------------------------
/graphiti_core/embedder/gemini.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
from collections.abc import Iterable
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from google import genai
from google.genai import types
else:
try:
from google import genai
from google.genai import types
except ImportError:
raise ImportError(
'google-genai is required for GeminiEmbedder. '
'Install it with: pip install graphiti-core[google-genai]'
) from None
from pydantic import Field
from .client import EmbedderClient, EmbedderConfig
logger = logging.getLogger(__name__)
DEFAULT_EMBEDDING_MODEL = 'text-embedding-001' # gemini-embedding-001 or text-embedding-005
DEFAULT_BATCH_SIZE = 100
class GeminiEmbedderConfig(EmbedderConfig):
embedding_model: str = Field(default=DEFAULT_EMBEDDING_MODEL)
api_key: str | None = None
class GeminiEmbedder(EmbedderClient):
"""
Google Gemini Embedder Client
"""
def __init__(
self,
config: GeminiEmbedderConfig | None = None,
client: 'genai.Client | None' = None,
batch_size: int | None = None,
):
"""
Initialize the GeminiEmbedder with the provided configuration and client.
Args:
config (GeminiEmbedderConfig | None): The configuration for the GeminiEmbedder, including the API key and embedding model.
client (genai.Client | None): An optional async client instance to use. If not provided, a new genai.Client is created.
batch_size (int | None): An optional batch size to use. If not provided, the default batch size will be used.
"""
if config is None:
config = GeminiEmbedderConfig()
self.config = config
if client is None:
self.client = genai.Client(api_key=config.api_key)
else:
self.client = client
if batch_size is None and self.config.embedding_model == 'gemini-embedding-001':
# Gemini API has a limit on the number of instances per request
# https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api
self.batch_size = 1
elif batch_size is None:
self.batch_size = DEFAULT_BATCH_SIZE
else:
self.batch_size = batch_size
async def create(
self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
) -> list[float]:
"""
Create embeddings for the given input data using Google's Gemini embedding model.
Args:
input_data: The input data to create embeddings for. Can be a string, list of strings,
or an iterable of integers or iterables of integers.
Returns:
A list of floats representing the embedding vector.
"""
# Generate embeddings
result = await self.client.aio.models.embed_content(
model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
contents=[input_data], # type: ignore[arg-type] # mypy fails on broad union type
config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
)
if not result.embeddings or len(result.embeddings) == 0 or not result.embeddings[0].values:
raise ValueError('No embeddings returned from Gemini API in create()')
return result.embeddings[0].values
async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
"""
Create embeddings for a batch of input data using Google's Gemini embedding model.
This method handles batching to respect the Gemini API's limits on the number
of instances that can be processed in a single request.
Args:
input_data_list: A list of strings to create embeddings for.
Returns:
A list of embedding vectors (each vector is a list of floats).
"""
if not input_data_list:
return []
batch_size = self.batch_size
all_embeddings = []
# Process inputs in batches
for i in range(0, len(input_data_list), batch_size):
batch = input_data_list[i : i + batch_size]
try:
# Generate embeddings for this batch
result = await self.client.aio.models.embed_content(
model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
contents=batch, # type: ignore[arg-type] # mypy fails on broad union type
config=types.EmbedContentConfig(
output_dimensionality=self.config.embedding_dim
),
)
if not result.embeddings or len(result.embeddings) == 0:
raise Exception('No embeddings returned')
# Process embeddings from this batch
for embedding in result.embeddings:
if not embedding.values:
raise ValueError('Empty embedding values returned')
all_embeddings.append(embedding.values)
except Exception as e:
# If batch processing fails, fall back to individual processing
logger.warning(
f'Batch embedding failed for batch {i // batch_size + 1}, falling back to individual processing: {e}'
)
for item in batch:
try:
# Process each item individually
result = await self.client.aio.models.embed_content(
model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
contents=[item], # type: ignore[arg-type] # mypy fails on broad union type
config=types.EmbedContentConfig(
output_dimensionality=self.config.embedding_dim
),
)
if not result.embeddings or len(result.embeddings) == 0:
raise ValueError('No embeddings returned from Gemini API')
if not result.embeddings[0].values:
raise ValueError('Empty embedding values returned')
all_embeddings.append(result.embeddings[0].values)
except Exception as individual_error:
logger.error(f'Failed to embed individual item: {individual_error}')
raise individual_error
return all_embeddings
```
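A minimal usage sketch for the embedder above, assuming valid Google API credentials; the key and input texts are placeholders.
```python
# Hedged sketch: using GeminiEmbedder for single and batch embeddings.
import asyncio

from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig


async def main() -> None:
    # The API key is a placeholder; supply real credentials in practice.
    embedder = GeminiEmbedder(config=GeminiEmbedderConfig(api_key='your-google-api-key'))

    # Single input -> one embedding vector (list of floats).
    vector = await embedder.create('Graphiti builds temporally-aware knowledge graphs.')
    print(len(vector))

    # Batch input -> one vector per string; failed batches fall back to per-item calls.
    vectors = await embedder.create_batch(['first document', 'second document'])
    print(len(vectors))


asyncio.run(main())
```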
--------------------------------------------------------------------------------
/mcp_server/docker/README-falkordb-combined.md:
--------------------------------------------------------------------------------
```markdown
# FalkorDB + Graphiti MCP Server Combined Image
This Docker setup bundles FalkorDB (graph database) and the Graphiti MCP Server into a single container image for simplified deployment.
## Overview
The combined image extends the official FalkorDB Docker image to include:
- **FalkorDB**: Redis-based graph database running on port 6379
- **FalkorDB Web UI**: Graph visualization interface on port 3000
- **Graphiti MCP Server**: Knowledge graph API on port 8000
These services are managed by a single startup script that launches FalkorDB as a daemon, optionally starts the FalkorDB Web UI, and runs the MCP server in the foreground.
## Quick Start
### Using Docker Compose (Recommended)
1. Create a `.env` file in the `mcp_server` directory:
```bash
# Required
OPENAI_API_KEY=your_openai_api_key
# Optional
GRAPHITI_GROUP_ID=main
SEMAPHORE_LIMIT=10
FALKORDB_PASSWORD=
```
2. Start the combined service:
```bash
cd mcp_server
docker compose -f docker/docker-compose-falkordb-combined.yml up
```
3. Access the services:
- MCP Server: http://localhost:8000/mcp/
- FalkorDB Web UI: http://localhost:3000
- FalkorDB (Redis): localhost:6379
### Using Docker Run
```bash
docker run -d \
-p 6379:6379 \
-p 3000:3000 \
-p 8000:8000 \
-e OPENAI_API_KEY=your_key \
-e GRAPHITI_GROUP_ID=main \
-v falkordb_data:/var/lib/falkordb/data \
zepai/graphiti-falkordb:latest
```
## Building the Image
### Build with Default Version
```bash
docker compose -f docker/docker-compose-falkordb-combined.yml build
```
### Build with Specific Graphiti Version
```bash
GRAPHITI_CORE_VERSION=0.22.0 docker compose -f docker/docker-compose-falkordb-combined.yml build
```
### Build Arguments
- `GRAPHITI_CORE_VERSION`: Version of graphiti-core package (default: 0.22.0)
- `MCP_SERVER_VERSION`: MCP server version tag (default: 1.0.0rc0)
- `BUILD_DATE`: Build timestamp
- `VCS_REF`: Git commit hash
## Configuration
### Environment Variables
All environment variables from the standard MCP server are supported:
**Required:**
- `OPENAI_API_KEY`: OpenAI API key for LLM operations
**Optional:**
- `BROWSER`: Enable FalkorDB Browser web UI on port 3000 (default: "1", set to "0" to disable)
- `GRAPHITI_GROUP_ID`: Namespace for graph data (default: "main")
- `SEMAPHORE_LIMIT`: Concurrency limit for episode processing (default: 10)
- `FALKORDB_PASSWORD`: Password for FalkorDB (optional)
- `FALKORDB_DATABASE`: FalkorDB database name (default: "default_db")
**Other LLM Providers:**
- `ANTHROPIC_API_KEY`: For Claude models
- `GOOGLE_API_KEY`: For Gemini models
- `GROQ_API_KEY`: For Groq models
### Volumes
- `/var/lib/falkordb/data`: Persistent storage for graph data
- `/var/log/graphiti`: MCP server and FalkorDB Browser logs
## Service Management
### View Logs
```bash
# All logs (both services stdout/stderr)
docker compose -f docker/docker-compose-falkordb-combined.yml logs -f
# Only container logs
docker compose -f docker/docker-compose-falkordb-combined.yml logs -f graphiti-falkordb
```
### Restart Services
```bash
# Restart entire container (both services)
docker compose -f docker/docker-compose-falkordb-combined.yml restart
# Check FalkorDB status
docker compose -f docker/docker-compose-falkordb-combined.yml exec graphiti-falkordb redis-cli ping
# Check MCP server status
curl http://localhost:8000/health
```
### Disabling the FalkorDB Browser
To disable the FalkorDB Browser web UI (port 3000), set the `BROWSER` environment variable to `0`:
```bash
# Using docker run
docker run -d \
-p 6379:6379 \
-p 3000:3000 \
-p 8000:8000 \
-e BROWSER=0 \
-e OPENAI_API_KEY=your_key \
zepai/graphiti-falkordb:latest
# Using docker-compose
# Add to your .env file:
BROWSER=0
```
When disabled, only FalkorDB (port 6379) and the MCP server (port 8000) will run.
## Health Checks
The container includes a health check that verifies:
1. FalkorDB is responding to ping
2. MCP server health endpoint is accessible
Check health status:
```bash
docker compose -f docker/docker-compose-falkordb-combined.yml ps
```
## Architecture
### Process Structure
```
start-services.sh (PID 1)
├── redis-server (FalkorDB daemon)
├── node server.js (FalkorDB Browser - background, if BROWSER=1)
└── uv run main.py (MCP server - foreground)
```
The startup script launches FalkorDB as a background daemon, waits for it to be ready, optionally starts the FalkorDB Browser (if `BROWSER=1`), then starts the MCP server in the foreground. When the MCP server stops, the container exits.
### Directory Structure
```
/app/mcp/ # MCP server application
├── main.py
├── src/
├── config/
│ └── config.yaml # FalkorDB-specific configuration
└── .graphiti-core-version # Installed version info
/var/lib/falkordb/data/ # Persistent graph storage
/var/lib/falkordb/browser/ # FalkorDB Browser web UI
/var/log/graphiti/ # MCP server and Browser logs
/start-services.sh # Startup script
```
## Benefits of Combined Image
1. **Simplified Deployment**: Single container to manage
2. **Reduced Network Latency**: Localhost communication between services
3. **Easier Development**: One command to start entire stack
4. **Unified Logging**: All logs available via docker logs
5. **Resource Efficiency**: Shared base image and dependencies
## Troubleshooting
### FalkorDB Not Starting
Check container logs:
```bash
docker compose -f docker/docker-compose-falkordb-combined.yml logs graphiti-falkordb
```
### MCP Server Connection Issues
1. Verify FalkorDB is running:
```bash
docker compose -f docker/docker-compose-falkordb-combined.yml exec graphiti-falkordb redis-cli ping
```
2. Check MCP server health:
```bash
curl http://localhost:8000/health
```
3. View all container logs:
```bash
docker compose -f docker/docker-compose-falkordb-combined.yml logs -f
```
### Port Conflicts
If ports 6379, 3000, or 8000 are already in use, modify the port mappings in `docker-compose-falkordb-combined.yml`:
```yaml
ports:
- "16379:6379" # Use different external port
- "13000:3000"
- "18000:8000"
```
## Production Considerations
1. **Resource Limits**: Add resource constraints in docker-compose:
```yaml
deploy:
resources:
limits:
cpus: '2'
memory: 4G
```
2. **Persistent Volumes**: Use named volumes or bind mounts for production data
3. **Monitoring**: Export logs to external monitoring system
4. **Backups**: Regular backups of `/var/lib/falkordb/data` volume
5. **Security**: Set `FALKORDB_PASSWORD` in production environments
## Comparison with Separate Containers
| Aspect | Combined Image | Separate Containers |
|--------|---------------|---------------------|
| Setup Complexity | Simple (one container) | Moderate (service dependencies) |
| Network Latency | Lower (localhost) | Higher (container network) |
| Resource Usage | Lower (shared base) | Higher (separate images) |
| Scalability | Limited | Better (scale independently) |
| Debugging | Harder (multiple processes) | Easier (isolated services) |
| Production Use | Development/Single-node | Recommended |
## See Also
- [Main MCP Server README](../README.md)
- [FalkorDB Documentation](https://docs.falkordb.com/)
- [Docker Compose Documentation](https://docs.docker.com/compose/)
```
--------------------------------------------------------------------------------
/mcp_server/tests/test_falkordb_integration.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
FalkorDB integration test for the Graphiti MCP Server.
Tests MCP server functionality with FalkorDB as the graph database backend.
"""
import asyncio
import json
import time
from typing import Any
from mcp import StdioServerParameters
from mcp.client.stdio import stdio_client
class GraphitiFalkorDBIntegrationTest:
"""Integration test client for Graphiti MCP Server using FalkorDB backend."""
def __init__(self):
self.test_group_id = f'falkor_test_group_{int(time.time())}'
self.session = None
async def __aenter__(self):
"""Start the MCP client session with FalkorDB configuration."""
# Configure server parameters to run with FalkorDB backend
server_params = StdioServerParameters(
command='uv',
args=['run', 'main.py', '--transport', 'stdio', '--database-provider', 'falkordb'],
env={
'FALKORDB_URI': 'redis://localhost:6379',
'FALKORDB_PASSWORD': '', # No password for test instance
'FALKORDB_DATABASE': 'default_db',
'OPENAI_API_KEY': 'dummy_key_for_testing',
'GRAPHITI_GROUP_ID': self.test_group_id,
},
)
# Start the stdio client
self.session = await stdio_client(server_params).__aenter__()
print(' 📡 Started MCP client session with FalkorDB backend')
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Clean up the MCP client session."""
if self.session:
await self.session.close()
print(' 🔌 Closed MCP client session')
async def call_mcp_tool(self, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any]:
"""Call an MCP tool via the stdio client."""
try:
result = await self.session.call_tool(tool_name, arguments)
if hasattr(result, 'content') and result.content:
# Handle different content types
if hasattr(result.content[0], 'text'):
content = result.content[0].text
try:
return json.loads(content)
except json.JSONDecodeError:
return {'raw_response': content}
else:
return {'content': str(result.content[0])}
return {'result': 'success', 'content': None}
except Exception as e:
return {'error': str(e), 'tool': tool_name, 'arguments': arguments}
async def test_server_status(self) -> bool:
"""Test the get_status tool to verify FalkorDB connectivity."""
print(' 🏥 Testing server status with FalkorDB...')
result = await self.call_mcp_tool('get_status', {})
if 'error' in result:
print(f' ❌ Status check failed: {result["error"]}')
return False
# Check if status indicates FalkorDB is working
status_text = result.get('raw_response', result.get('content', ''))
if 'running' in str(status_text).lower() or 'ready' in str(status_text).lower():
print(' ✅ Server status OK with FalkorDB')
return True
else:
print(f' ⚠️ Status unclear: {status_text}')
return True # Don't fail on unclear status
async def test_add_episode(self) -> bool:
"""Test adding an episode to FalkorDB."""
print(' 📝 Testing episode addition to FalkorDB...')
episode_data = {
'name': 'FalkorDB Test Episode',
'episode_body': 'This is a test episode to verify FalkorDB integration works correctly.',
'source': 'text',
'source_description': 'Integration test for FalkorDB backend',
}
result = await self.call_mcp_tool('add_episode', episode_data)
if 'error' in result:
print(f' ❌ Add episode failed: {result["error"]}')
return False
print(' ✅ Episode added successfully to FalkorDB')
return True
async def test_search_functionality(self) -> bool:
"""Test search functionality with FalkorDB."""
print(' 🔍 Testing search functionality with FalkorDB...')
# Give some time for episode processing
await asyncio.sleep(2)
# Test node search
search_result = await self.call_mcp_tool(
'search_nodes', {'query': 'FalkorDB test episode', 'limit': 5}
)
if 'error' in search_result:
print(f' ⚠️ Search returned error (may be expected): {search_result["error"]}')
return True # Don't fail on search errors in integration test
print(' ✅ Search functionality working with FalkorDB')
return True
async def test_clear_graph(self) -> bool:
"""Test clearing the graph in FalkorDB."""
print(' 🧹 Testing graph clearing in FalkorDB...')
result = await self.call_mcp_tool('clear_graph', {})
if 'error' in result:
print(f' ❌ Clear graph failed: {result["error"]}')
return False
print(' ✅ Graph cleared successfully in FalkorDB')
return True
async def run_falkordb_integration_test() -> bool:
"""Run the complete FalkorDB integration test suite."""
print('🧪 Starting FalkorDB Integration Test Suite')
print('=' * 55)
test_results = []
try:
async with GraphitiFalkorDBIntegrationTest() as test_client:
print(f' 🎯 Using test group: {test_client.test_group_id}')
# Run test suite
tests = [
('Server Status', test_client.test_server_status),
('Add Episode', test_client.test_add_episode),
('Search Functionality', test_client.test_search_functionality),
('Clear Graph', test_client.test_clear_graph),
]
for test_name, test_func in tests:
print(f'\n🔬 Running {test_name} Test...')
try:
result = await test_func()
test_results.append((test_name, result))
if result:
print(f' ✅ {test_name}: PASSED')
else:
print(f' ❌ {test_name}: FAILED')
except Exception as e:
print(f' 💥 {test_name}: ERROR - {e}')
test_results.append((test_name, False))
except Exception as e:
print(f'💥 Test setup failed: {e}')
return False
# Summary
print('\n' + '=' * 55)
print('📊 FalkorDB Integration Test Results:')
print('-' * 30)
passed = sum(1 for _, result in test_results if result)
total = len(test_results)
for test_name, result in test_results:
status = '✅ PASS' if result else '❌ FAIL'
print(f' {test_name}: {status}')
print(f'\n🎯 Overall: {passed}/{total} tests passed')
if passed == total:
print('🎉 All FalkorDB integration tests PASSED!')
return True
else:
print('⚠️ Some FalkorDB integration tests failed')
return passed >= (total * 0.7) # Pass if 70% of tests pass
if __name__ == '__main__':
success = asyncio.run(run_falkordb_integration_test())
exit(0 if success else 1)
```
--------------------------------------------------------------------------------
/mcp_server/tests/test_configuration.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""Test script for configuration loading and factory patterns."""
import asyncio
import os
import sys
from pathlib import Path
# Add the current directory to the path
sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))
from config.schema import GraphitiConfig
from services.factories import DatabaseDriverFactory, EmbedderFactory, LLMClientFactory
def test_config_loading():
"""Test loading configuration from YAML and environment variables."""
print('Testing configuration loading...')
# Test with default config.yaml
config = GraphitiConfig()
print('✓ Loaded configuration successfully')
print(f' - Server transport: {config.server.transport}')
print(f' - LLM provider: {config.llm.provider}')
print(f' - LLM model: {config.llm.model}')
print(f' - Embedder provider: {config.embedder.provider}')
print(f' - Database provider: {config.database.provider}')
print(f' - Group ID: {config.graphiti.group_id}')
# Test environment variable override
os.environ['LLM__PROVIDER'] = 'anthropic'
os.environ['LLM__MODEL'] = 'claude-3-opus'
config2 = GraphitiConfig()
print('\n✓ Environment variable overrides work')
print(f' - LLM provider (overridden): {config2.llm.provider}')
print(f' - LLM model (overridden): {config2.llm.model}')
# Clean up env vars
del os.environ['LLM__PROVIDER']
del os.environ['LLM__MODEL']
assert config is not None
assert config2 is not None
# Return the first config for subsequent tests
return config
def test_llm_factory(config: GraphitiConfig):
"""Test LLM client factory creation."""
print('\nTesting LLM client factory...')
# Test OpenAI client creation (if API key is set)
if (
config.llm.provider == 'openai'
and config.llm.providers.openai
and config.llm.providers.openai.api_key
):
try:
client = LLMClientFactory.create(config.llm)
print(f'✓ Created {config.llm.provider} LLM client successfully')
print(f' - Model: {client.model}')
print(f' - Temperature: {client.temperature}')
except Exception as e:
print(f'✗ Failed to create LLM client: {e}')
else:
print(f'⚠ Skipping LLM factory test (no API key configured for {config.llm.provider})')
# Test switching providers
test_config = config.llm.model_copy()
test_config.provider = 'gemini'
if not test_config.providers.gemini:
from config.schema import GeminiProviderConfig
test_config.providers.gemini = GeminiProviderConfig(api_key='dummy_value_for_testing')
else:
test_config.providers.gemini.api_key = 'dummy_value_for_testing'
try:
client = LLMClientFactory.create(test_config)
print('✓ Factory supports provider switching (tested with Gemini)')
except Exception as e:
print(f'✗ Factory provider switching failed: {e}')
def test_embedder_factory(config: GraphitiConfig):
"""Test Embedder client factory creation."""
print('\nTesting Embedder client factory...')
# Test OpenAI embedder creation (if API key is set)
if (
config.embedder.provider == 'openai'
and config.embedder.providers.openai
and config.embedder.providers.openai.api_key
):
try:
_ = EmbedderFactory.create(config.embedder)
print(f'✓ Created {config.embedder.provider} Embedder client successfully')
# The embedder client may not expose model/dimensions as attributes
print(f' - Configured model: {config.embedder.model}')
print(f' - Configured dimensions: {config.embedder.dimensions}')
except Exception as e:
print(f'✗ Failed to create Embedder client: {e}')
else:
print(
f'⚠ Skipping Embedder factory test (no API key configured for {config.embedder.provider})'
)
async def test_database_factory(config: GraphitiConfig):
"""Test Database driver factory creation."""
print('\nTesting Database driver factory...')
# Test Neo4j config creation
if config.database.provider == 'neo4j' and config.database.providers.neo4j:
try:
db_config = DatabaseDriverFactory.create_config(config.database)
print(f'✓ Created {config.database.provider} configuration successfully')
print(f' - URI: {db_config["uri"]}')
print(f' - User: {db_config["user"]}')
print(
f' - Password: {"*" * len(db_config["password"]) if db_config["password"] else "None"}'
)
# Test actual connection would require initializing Graphiti
from graphiti_core import Graphiti
try:
# This will fail if Neo4j is not running, but tests the config
graphiti = Graphiti(
uri=db_config['uri'],
user=db_config['user'],
password=db_config['password'],
)
await graphiti.driver.client.verify_connectivity()
print(' ✓ Successfully connected to Neo4j')
await graphiti.driver.client.close()
except Exception as e:
print(f' ⚠ Could not connect to Neo4j (is it running?): {type(e).__name__}')
except Exception as e:
print(f'✗ Failed to create Database configuration: {e}')
else:
print(f'⚠ Skipping Database factory test (no configuration for {config.database.provider})')
def test_cli_override():
"""Test CLI argument override functionality."""
print('\nTesting CLI argument override...')
# Simulate argparse Namespace
class Args:
config = Path('config.yaml')
transport = 'stdio'
llm_provider = 'anthropic'
model = 'claude-3-sonnet'
temperature = 0.5
embedder_provider = 'voyage'
embedder_model = 'voyage-3'
database_provider = 'falkordb'
group_id = 'test-group'
user_id = 'test-user'
config = GraphitiConfig()
config.apply_cli_overrides(Args())
print('✓ CLI overrides applied successfully')
print(f' - Transport: {config.server.transport}')
print(f' - LLM provider: {config.llm.provider}')
print(f' - LLM model: {config.llm.model}')
print(f' - Temperature: {config.llm.temperature}')
print(f' - Embedder provider: {config.embedder.provider}')
print(f' - Database provider: {config.database.provider}')
print(f' - Group ID: {config.graphiti.group_id}')
print(f' - User ID: {config.graphiti.user_id}')
async def main():
"""Run all tests."""
print('=' * 60)
print('Configuration and Factory Pattern Test Suite')
print('=' * 60)
try:
# Test configuration loading
config = test_config_loading()
# Test factories
test_llm_factory(config)
test_embedder_factory(config)
await test_database_factory(config)
# Test CLI overrides
test_cli_override()
print('\n' + '=' * 60)
print('✓ All tests completed successfully!')
print('=' * 60)
except Exception as e:
print(f'\n✗ Test suite failed: {e}')
sys.exit(1)
if __name__ == '__main__':
asyncio.run(main())
```
--------------------------------------------------------------------------------
/graphiti_core/search/search_config_recipes.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from graphiti_core.search.search_config import (
CommunityReranker,
CommunitySearchConfig,
CommunitySearchMethod,
EdgeReranker,
EdgeSearchConfig,
EdgeSearchMethod,
EpisodeReranker,
EpisodeSearchConfig,
EpisodeSearchMethod,
NodeReranker,
NodeSearchConfig,
NodeSearchMethod,
SearchConfig,
)
# Performs a hybrid search with rrf reranking over edges, nodes, and communities
COMBINED_HYBRID_SEARCH_RRF = SearchConfig(
edge_config=EdgeSearchConfig(
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
reranker=EdgeReranker.rrf,
),
node_config=NodeSearchConfig(
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
reranker=NodeReranker.rrf,
),
episode_config=EpisodeSearchConfig(
search_methods=[
EpisodeSearchMethod.bm25,
],
reranker=EpisodeReranker.rrf,
),
community_config=CommunitySearchConfig(
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
reranker=CommunityReranker.rrf,
),
)
# Performs a hybrid search with mmr reranking over edges, nodes, and communities
COMBINED_HYBRID_SEARCH_MMR = SearchConfig(
edge_config=EdgeSearchConfig(
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
reranker=EdgeReranker.mmr,
mmr_lambda=1,
),
node_config=NodeSearchConfig(
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
reranker=NodeReranker.mmr,
mmr_lambda=1,
),
episode_config=EpisodeSearchConfig(
search_methods=[
EpisodeSearchMethod.bm25,
],
reranker=EpisodeReranker.rrf,
),
community_config=CommunitySearchConfig(
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
reranker=CommunityReranker.mmr,
mmr_lambda=1,
),
)
# Performs a full-text search, similarity search, and bfs with cross_encoder reranking over edges, nodes, and communities
COMBINED_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
edge_config=EdgeSearchConfig(
search_methods=[
EdgeSearchMethod.bm25,
EdgeSearchMethod.cosine_similarity,
EdgeSearchMethod.bfs,
],
reranker=EdgeReranker.cross_encoder,
),
node_config=NodeSearchConfig(
search_methods=[
NodeSearchMethod.bm25,
NodeSearchMethod.cosine_similarity,
NodeSearchMethod.bfs,
],
reranker=NodeReranker.cross_encoder,
),
episode_config=EpisodeSearchConfig(
search_methods=[
EpisodeSearchMethod.bm25,
],
reranker=EpisodeReranker.cross_encoder,
),
community_config=CommunitySearchConfig(
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
reranker=CommunityReranker.cross_encoder,
),
)
# performs a hybrid search over edges with rrf reranking
EDGE_HYBRID_SEARCH_RRF = SearchConfig(
edge_config=EdgeSearchConfig(
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
reranker=EdgeReranker.rrf,
)
)
# performs a hybrid search over edges with mmr reranking
EDGE_HYBRID_SEARCH_MMR = SearchConfig(
edge_config=EdgeSearchConfig(
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
reranker=EdgeReranker.mmr,
)
)
# performs a hybrid search over edges with node distance reranking
EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
edge_config=EdgeSearchConfig(
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
reranker=EdgeReranker.node_distance,
),
)
# performs a hybrid search over edges with episode mention reranking
EDGE_HYBRID_SEARCH_EPISODE_MENTIONS = SearchConfig(
edge_config=EdgeSearchConfig(
search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
reranker=EdgeReranker.episode_mentions,
)
)
# performs a hybrid search over edges with cross encoder reranking
EDGE_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
edge_config=EdgeSearchConfig(
search_methods=[
EdgeSearchMethod.bm25,
EdgeSearchMethod.cosine_similarity,
EdgeSearchMethod.bfs,
],
reranker=EdgeReranker.cross_encoder,
),
limit=10,
)
# performs a hybrid search over nodes with rrf reranking
NODE_HYBRID_SEARCH_RRF = SearchConfig(
node_config=NodeSearchConfig(
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
reranker=NodeReranker.rrf,
)
)
# performs a hybrid search over nodes with mmr reranking
NODE_HYBRID_SEARCH_MMR = SearchConfig(
node_config=NodeSearchConfig(
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
reranker=NodeReranker.mmr,
)
)
# performs a hybrid search over nodes with node distance reranking
NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
node_config=NodeSearchConfig(
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
reranker=NodeReranker.node_distance,
)
)
# performs a hybrid search over nodes with episode mentions reranking
NODE_HYBRID_SEARCH_EPISODE_MENTIONS = SearchConfig(
node_config=NodeSearchConfig(
search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
reranker=NodeReranker.episode_mentions,
)
)
# performs a hybrid search over nodes with cross encoder reranking
NODE_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
node_config=NodeSearchConfig(
search_methods=[
NodeSearchMethod.bm25,
NodeSearchMethod.cosine_similarity,
NodeSearchMethod.bfs,
],
reranker=NodeReranker.cross_encoder,
),
limit=10,
)
# performs a hybrid search over communities with rrf reranking
COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig(
community_config=CommunitySearchConfig(
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
reranker=CommunityReranker.rrf,
)
)
# performs a hybrid search over communities with mmr reranking
COMMUNITY_HYBRID_SEARCH_MMR = SearchConfig(
community_config=CommunitySearchConfig(
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
reranker=CommunityReranker.mmr,
)
)
# performs a hybrid search over communities with cross encoder reranking
COMMUNITY_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
community_config=CommunitySearchConfig(
search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
reranker=CommunityReranker.cross_encoder,
),
limit=3,
)
```
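The recipes above can also serve as templates for custom configurations. Below is a small sketch composing a new `SearchConfig` from the same building blocks; how the config is handed to Graphiti's config-driven search entry point depends on the graphiti-core version, so that step is only noted in a comment.
```python
# Hedged sketch: a custom recipe built from the same primitives as above
# (edge + node hybrid search with cross-encoder reranking and a small limit).
from graphiti_core.search.search_config import (
    EdgeReranker,
    EdgeSearchConfig,
    EdgeSearchMethod,
    NodeReranker,
    NodeSearchConfig,
    NodeSearchMethod,
    SearchConfig,
)

CUSTOM_EDGE_AND_NODE_CROSS_ENCODER = SearchConfig(
    edge_config=EdgeSearchConfig(
        search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
        reranker=EdgeReranker.cross_encoder,
    ),
    node_config=NodeSearchConfig(
        search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
        reranker=NodeReranker.cross_encoder,
    ),
    limit=5,
)

# Pass the config to Graphiti's config-driven search method
# (exact method name and signature depend on the installed graphiti-core version).
```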
--------------------------------------------------------------------------------
/tests/test_node_int.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from datetime import datetime, timedelta
from uuid import uuid4
import pytest
from graphiti_core.nodes import (
CommunityNode,
EntityNode,
EpisodeType,
EpisodicNode,
)
from tests.helpers_test import (
assert_community_node_equals,
assert_entity_node_equals,
assert_episodic_node_equals,
get_node_count,
group_id,
)
created_at = datetime.now()
deleted_at = created_at + timedelta(days=3)
valid_at = created_at + timedelta(days=1)
invalid_at = created_at + timedelta(days=2)
@pytest.fixture
def sample_entity_node():
return EntityNode(
uuid=str(uuid4()),
name='Test Entity',
group_id=group_id,
labels=['Entity', 'Person'],
created_at=created_at,
name_embedding=[0.5] * 1024,
summary='Entity Summary',
attributes={
'age': 30,
'location': 'New York',
},
)
@pytest.fixture
def sample_episodic_node():
return EpisodicNode(
uuid=str(uuid4()),
name='Episode 1',
group_id=group_id,
created_at=created_at,
source=EpisodeType.text,
source_description='Test source',
content='Some content here',
valid_at=valid_at,
entity_edges=[],
)
@pytest.fixture
def sample_community_node():
return CommunityNode(
uuid=str(uuid4()),
name='Community A',
group_id=group_id,
created_at=created_at,
name_embedding=[0.5] * 1024,
summary='Community summary',
)
@pytest.mark.asyncio
async def test_entity_node(sample_entity_node, graph_driver):
uuid = sample_entity_node.uuid
# Create node
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
await sample_entity_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
# Get node by uuid
retrieved = await EntityNode.get_by_uuid(graph_driver, sample_entity_node.uuid)
await assert_entity_node_equals(graph_driver, retrieved, sample_entity_node)
# Get node by uuids
retrieved = await EntityNode.get_by_uuids(graph_driver, [sample_entity_node.uuid])
await assert_entity_node_equals(graph_driver, retrieved[0], sample_entity_node)
# Get node by group ids
retrieved = await EntityNode.get_by_group_ids(
graph_driver, [group_id], limit=2, with_embeddings=True
)
assert len(retrieved) == 1
await assert_entity_node_equals(graph_driver, retrieved[0], sample_entity_node)
# Delete node by uuid
await sample_entity_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
# Delete node by uuids
await sample_entity_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
await sample_entity_node.delete_by_uuids(graph_driver, [uuid])
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
# Delete node by group id
await sample_entity_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
await sample_entity_node.delete_by_group_id(graph_driver, group_id)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
await graph_driver.close()
@pytest.mark.asyncio
async def test_community_node(sample_community_node, graph_driver):
uuid = sample_community_node.uuid
# Create node
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
await sample_community_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
# Get node by uuid
retrieved = await CommunityNode.get_by_uuid(graph_driver, sample_community_node.uuid)
await assert_community_node_equals(graph_driver, retrieved, sample_community_node)
# Get node by uuids
retrieved = await CommunityNode.get_by_uuids(graph_driver, [sample_community_node.uuid])
await assert_community_node_equals(graph_driver, retrieved[0], sample_community_node)
# Get node by group ids
retrieved = await CommunityNode.get_by_group_ids(graph_driver, [group_id], limit=2)
assert len(retrieved) == 1
await assert_community_node_equals(graph_driver, retrieved[0], sample_community_node)
# Delete node by uuid
await sample_community_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
# Delete node by uuids
await sample_community_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
await sample_community_node.delete_by_uuids(graph_driver, [uuid])
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
# Delete node by group id
await sample_community_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
await sample_community_node.delete_by_group_id(graph_driver, group_id)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
await graph_driver.close()
@pytest.mark.asyncio
async def test_episodic_node(sample_episodic_node, graph_driver):
uuid = sample_episodic_node.uuid
# Create node
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
await sample_episodic_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
# Get node by uuid
retrieved = await EpisodicNode.get_by_uuid(graph_driver, sample_episodic_node.uuid)
await assert_episodic_node_equals(retrieved, sample_episodic_node)
# Get node by uuids
retrieved = await EpisodicNode.get_by_uuids(graph_driver, [sample_episodic_node.uuid])
await assert_episodic_node_equals(retrieved[0], sample_episodic_node)
# Get node by group ids
retrieved = await EpisodicNode.get_by_group_ids(graph_driver, [group_id], limit=2)
assert len(retrieved) == 1
await assert_episodic_node_equals(retrieved[0], sample_episodic_node)
# Delete node by uuid
await sample_episodic_node.delete(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
# Delete node by uuids
await sample_episodic_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
await sample_episodic_node.delete_by_uuids(graph_driver, [uuid])
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
# Delete node by group id
await sample_episodic_node.save(graph_driver)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 1
await sample_episodic_node.delete_by_group_id(graph_driver, group_id)
node_count = await get_node_count(graph_driver, [uuid])
assert node_count == 0
await graph_driver.close()
```
--------------------------------------------------------------------------------
/tests/utils/maintenance/test_temporal_operations_int.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from datetime import timedelta
import pytest
from dotenv import load_dotenv
from graphiti_core.edges import EntityEdge
from graphiti_core.llm_client import LLMConfig, OpenAIClient
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.temporal_operations import (
get_edge_contradictions,
)
load_dotenv()
def setup_llm_client():
return OpenAIClient(
LLMConfig(
api_key=os.getenv('TEST_OPENAI_API_KEY'),
model=os.getenv('TEST_OPENAI_MODEL'),
base_url='https://api.openai.com/v1',
)
)
def create_test_data():
now = utc_now()
# Create edges
existing_edge = EntityEdge(
uuid='e1',
source_node_uuid='1',
target_node_uuid='2',
name='LIKES',
fact='Alice likes Bob',
created_at=now - timedelta(days=1),
group_id='1',
)
new_edge = EntityEdge(
uuid='e2',
source_node_uuid='1',
target_node_uuid='2',
name='DISLIKES',
fact='Alice dislikes Bob',
created_at=now,
group_id='1',
)
# Create current episode
current_episode = EpisodicNode(
name='Current Episode',
content='Alice now dislikes Bob',
created_at=now,
valid_at=now,
source=EpisodeType.message,
source_description='Test episode for unit testing',
group_id='1',
)
# Create previous episodes
previous_episodes = [
EpisodicNode(
name='Previous Episode',
content='Alice liked Bob',
created_at=now - timedelta(days=1),
valid_at=now - timedelta(days=1),
source=EpisodeType.message,
source_description='Test previous episode for unit testing',
group_id='1',
)
]
return existing_edge, new_edge, current_episode, previous_episodes
@pytest.mark.asyncio
@pytest.mark.integration
async def test_get_edge_contradictions():
existing_edge, new_edge, current_episode, previous_episodes = create_test_data()
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, [existing_edge])
assert len(invalidated_edges) == 1
assert invalidated_edges[0].uuid == existing_edge.uuid
@pytest.mark.asyncio
@pytest.mark.integration
async def test_get_edge_contradictions_no_contradictions():
_, new_edge, current_episode, previous_episodes = create_test_data()
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, [])
assert len(invalidated_edges) == 0
@pytest.mark.skip(reason='Flaky LLM-based test with non-deterministic results')
@pytest.mark.asyncio
@pytest.mark.integration
async def test_get_edge_contradictions_multiple_existing():
existing_edge1, new_edge, _, _ = create_test_data()
existing_edge2, _, _, _ = create_test_data()
existing_edge2.uuid = 'e3'
existing_edge2.name = 'KNOWS'
existing_edge2.fact = 'Alice knows Bob'
invalidated_edges = await get_edge_contradictions(
setup_llm_client(), new_edge, [existing_edge1, existing_edge2]
)
assert len(invalidated_edges) == 1
assert invalidated_edges[0].uuid == existing_edge1.uuid
# Helper function to create more complex test data
def create_complex_test_data():
now = utc_now()
# Create nodes
node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now, group_id='1')
node2 = EntityNode(uuid='2', name='Bob', labels=['Person'], created_at=now, group_id='1')
node3 = EntityNode(uuid='3', name='Charlie', labels=['Person'], created_at=now, group_id='1')
node4 = EntityNode(
uuid='4', name='Company XYZ', labels=['Organization'], created_at=now, group_id='1'
)
# Create edges
existing_edge1 = EntityEdge(
uuid='e1',
source_node_uuid='1',
target_node_uuid='2',
name='LIKES',
fact='Alice likes Bob',
group_id='1',
created_at=now - timedelta(days=5),
)
existing_edge2 = EntityEdge(
uuid='e2',
source_node_uuid='1',
target_node_uuid='3',
name='FRIENDS_WITH',
fact='Alice is friends with Charlie',
group_id='1',
created_at=now - timedelta(days=3),
)
existing_edge3 = EntityEdge(
uuid='e3',
source_node_uuid='2',
target_node_uuid='4',
name='WORKS_FOR',
fact='Bob works for Company XYZ',
group_id='1',
created_at=now - timedelta(days=2),
)
return [existing_edge1, existing_edge2, existing_edge3], [
node1,
node2,
node3,
node4,
]
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_complex():
existing_edges, nodes = create_complex_test_data()
# Create a new edge that contradicts an existing one
new_edge = EntityEdge(
uuid='e4',
source_node_uuid='1',
target_node_uuid='2',
name='DISLIKES',
fact='Alice dislikes Bob',
group_id='1',
created_at=utc_now(),
)
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
assert len(invalidated_edges) == 1
assert invalidated_edges[0].uuid == 'e1'
@pytest.mark.asyncio
@pytest.mark.integration
async def test_get_edge_contradictions_temporal_update():
existing_edges, nodes = create_complex_test_data()
# Create a new edge that updates an existing one with new information
new_edge = EntityEdge(
uuid='e5',
source_node_uuid='2',
target_node_uuid='4',
name='LEFT_JOB',
fact='Bob no longer works at Company XYZ',
group_id='1',
created_at=utc_now(),
)
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
assert len(invalidated_edges) == 1
assert invalidated_edges[0].uuid == 'e3'
@pytest.mark.asyncio
@pytest.mark.integration
async def test_get_edge_contradictions_no_effect():
existing_edges, nodes = create_complex_test_data()
# Create a new edge that doesn't invalidate any existing edges
new_edge = EntityEdge(
uuid='e8',
source_node_uuid='3',
target_node_uuid='4',
name='APPLIED_TO',
fact='Charlie applied to Company XYZ',
group_id='1',
created_at=utc_now(),
)
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
assert len(invalidated_edges) == 0
@pytest.mark.skip(reason='Flaky LLM-based test with non-deterministic results')
@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_partial_update():
existing_edges, nodes = create_complex_test_data()
# Create a new edge that partially updates an existing one
new_edge = EntityEdge(
uuid='e9',
source_node_uuid='2',
target_node_uuid='4',
name='CHANGED_POSITION',
fact='Bob changed his position at Company XYZ',
group_id='1',
created_at=utc_now(),
)
invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
assert len(invalidated_edges) == 0 # The existing edge is not invalidated, just updated
# Run the tests
if __name__ == '__main__':
pytest.main([__file__])
```
--------------------------------------------------------------------------------
/graphiti_core/graph_queries.py:
--------------------------------------------------------------------------------
```python
"""
Database query utilities for different graph database backends.
This module provides database-agnostic query generation for Neo4j and FalkorDB,
supporting index creation, fulltext search, and bulk operations.
"""
from typing_extensions import LiteralString
from graphiti_core.driver.driver import GraphProvider
# Mapping from Neo4j fulltext index names to FalkorDB node labels
NEO4J_TO_FALKORDB_MAPPING = {
'node_name_and_summary': 'Entity',
'community_name': 'Community',
'episode_content': 'Episodic',
'edge_name_and_fact': 'RELATES_TO',
}
# Mapping from fulltext index names to Kuzu node labels
INDEX_TO_LABEL_KUZU_MAPPING = {
'node_name_and_summary': 'Entity',
'community_name': 'Community',
'episode_content': 'Episodic',
'edge_name_and_fact': 'RelatesToNode_',
}
def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
if provider == GraphProvider.FALKORDB:
return [
# Entity node
'CREATE INDEX FOR (n:Entity) ON (n.uuid, n.group_id, n.name, n.created_at)',
# Episodic node
'CREATE INDEX FOR (n:Episodic) ON (n.uuid, n.group_id, n.created_at, n.valid_at)',
# Community node
'CREATE INDEX FOR (n:Community) ON (n.uuid)',
# RELATES_TO edge
'CREATE INDEX FOR ()-[e:RELATES_TO]-() ON (e.uuid, e.group_id, e.name, e.created_at, e.expired_at, e.valid_at, e.invalid_at)',
# MENTIONS edge
'CREATE INDEX FOR ()-[e:MENTIONS]-() ON (e.uuid, e.group_id)',
# HAS_MEMBER edge
'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
]
if provider == GraphProvider.KUZU:
return []
return [
'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
'CREATE INDEX community_group_id IF NOT EXISTS FOR (n:Community) ON (n.group_id)',
'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
]
def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
if provider == GraphProvider.FALKORDB:
from typing import cast
from graphiti_core.driver.falkordb_driver import STOPWORDS
# Convert to string representation for embedding in queries
stopwords_str = str(STOPWORDS)
# Use type: ignore to satisfy LiteralString requirement while maintaining single source of truth
return cast(
list[LiteralString],
[
f"""CALL db.idx.fulltext.createNodeIndex(
{{
label: 'Episodic',
stopwords: {stopwords_str}
}},
'content', 'source', 'source_description', 'group_id'
)""",
f"""CALL db.idx.fulltext.createNodeIndex(
{{
label: 'Entity',
stopwords: {stopwords_str}
}},
'name', 'summary', 'group_id'
)""",
f"""CALL db.idx.fulltext.createNodeIndex(
{{
label: 'Community',
stopwords: {stopwords_str}
}},
'name', 'group_id'
)""",
"""CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
],
)
if provider == GraphProvider.KUZU:
return [
"CALL CREATE_FTS_INDEX('Episodic', 'episode_content', ['content', 'source', 'source_description']);",
"CALL CREATE_FTS_INDEX('Entity', 'node_name_and_summary', ['name', 'summary']);",
"CALL CREATE_FTS_INDEX('Community', 'community_name', ['name']);",
"CALL CREATE_FTS_INDEX('RelatesToNode_', 'edge_name_and_fact', ['name', 'fact']);",
]
return [
"""CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
"""CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
"""CREATE FULLTEXT INDEX community_name IF NOT EXISTS
FOR (n:Community) ON EACH [n.name, n.group_id]""",
"""CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
]
def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider) -> str:
if provider == GraphProvider.FALKORDB:
label = NEO4J_TO_FALKORDB_MAPPING[name]
return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
if provider == GraphProvider.KUZU:
label = INDEX_TO_LABEL_KUZU_MAPPING[name]
return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)"
return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
if provider == GraphProvider.FALKORDB:
# FalkorDB returns cosine distance, so convert it to a normalized similarity to match Neo4j's normalized cosine similarity
return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
if provider == GraphProvider.KUZU:
return f'array_cosine_similarity({vec1}, {vec2})'
return f'vector.similarity.cosine({vec1}, {vec2})'
def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> str:
if provider == GraphProvider.FALKORDB:
label = NEO4J_TO_FALKORDB_MAPPING[name]
return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
if provider == GraphProvider.KUZU:
label = INDEX_TO_LABEL_KUZU_MAPPING[name]
return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)"
return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
```
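To illustrate what these helpers emit, here is a short sketch that prints the provider-specific query strings; only the FalkorDB and Kuzu enum members referenced above are exercised, and the `$query`/`$limit` placeholders are bound at execution time by the caller.
```python
# Hedged sketch: inspecting the generated query text per provider.
from graphiti_core.driver.driver import GraphProvider
from graphiti_core.graph_queries import get_nodes_query, get_vector_cosine_func_query

# Fulltext node search against the entity name/summary index.
print(get_nodes_query('node_name_and_summary', '$query', limit=10, provider=GraphProvider.FALKORDB))
print(get_nodes_query('node_name_and_summary', '$query', limit=10, provider=GraphProvider.KUZU))

# Cosine-similarity expressions embedded in larger Cypher statements.
print(get_vector_cosine_func_query('n.name_embedding', '$search_vector', GraphProvider.FALKORDB))
print(get_vector_cosine_func_query('n.name_embedding', '$search_vector', GraphProvider.KUZU))
```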
--------------------------------------------------------------------------------
/examples/azure-openai/azure_openai_neo4j.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2025, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import json
import logging
import os
from datetime import datetime, timezone
from logging import INFO
from dotenv import load_dotenv
from openai import AsyncOpenAI
from graphiti_core import Graphiti
from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient
from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.nodes import EpisodeType
#################################################
# CONFIGURATION
#################################################
# Set up logging and environment variables for
# connecting to Neo4j database and Azure OpenAI
#################################################
# Configure logging
logging.basicConfig(
level=INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)
load_dotenv()
# Neo4j connection parameters
# Make sure Neo4j Desktop is running with a local DBMS started
neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
neo4j_user = os.environ.get('NEO4J_USER', 'neo4j')
neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password')
# Azure OpenAI connection parameters
azure_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT')
azure_api_key = os.environ.get('AZURE_OPENAI_API_KEY')
azure_deployment = os.environ.get('AZURE_OPENAI_DEPLOYMENT', 'gpt-4.1')
azure_embedding_deployment = os.environ.get(
'AZURE_OPENAI_EMBEDDING_DEPLOYMENT', 'text-embedding-3-small'
)
if not azure_endpoint or not azure_api_key:
raise ValueError('AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY must be set')
async def main():
#################################################
# INITIALIZATION
#################################################
# Connect to Neo4j and Azure OpenAI, then set up
# Graphiti indices. This is required before using
# other Graphiti functionality
#################################################
# Initialize Azure OpenAI client
azure_client = AsyncOpenAI(
base_url=f'{azure_endpoint}/openai/v1/',
api_key=azure_api_key,
)
# Create LLM and Embedder clients
llm_client = AzureOpenAILLMClient(
azure_client=azure_client,
config=LLMConfig(model=azure_deployment, small_model=azure_deployment),
)
embedder_client = AzureOpenAIEmbedderClient(
azure_client=azure_client, model=azure_embedding_deployment
)
# Initialize Graphiti with Neo4j connection and Azure OpenAI clients
graphiti = Graphiti(
neo4j_uri,
neo4j_user,
neo4j_password,
llm_client=llm_client,
embedder=embedder_client,
)
try:
#################################################
# ADDING EPISODES
#################################################
# Episodes are the primary units of information
# in Graphiti. They can be text or structured JSON
# and are automatically processed to extract entities
# and relationships.
#################################################
# Example: Add Episodes
# Episodes list containing both text and JSON episodes
episodes = [
{
'content': 'Kamala Harris is the Attorney General of California. She was previously '
'the district attorney for San Francisco.',
'type': EpisodeType.text,
'description': 'podcast transcript',
},
{
'content': 'As AG, Harris was in office from January 3, 2011 – January 3, 2017',
'type': EpisodeType.text,
'description': 'podcast transcript',
},
{
'content': {
'name': 'Gavin Newsom',
'position': 'Governor',
'state': 'California',
'previous_role': 'Lieutenant Governor',
'previous_location': 'San Francisco',
},
'type': EpisodeType.json,
'description': 'podcast metadata',
},
]
# Add episodes to the graph
for i, episode in enumerate(episodes):
await graphiti.add_episode(
name=f'California Politics {i}',
episode_body=(
episode['content']
if isinstance(episode['content'], str)
else json.dumps(episode['content'])
),
source=episode['type'],
source_description=episode['description'],
reference_time=datetime.now(timezone.utc),
)
print(f'Added episode: California Politics {i} ({episode["type"].value})')
#################################################
# BASIC SEARCH
#################################################
# The simplest way to retrieve relationships (edges)
# from Graphiti is using the search method, which
# performs a hybrid search combining semantic
# similarity and BM25 text retrieval.
#################################################
# Perform a hybrid search combining semantic similarity and BM25 retrieval
print("\nSearching for: 'Who was the California Attorney General?'")
results = await graphiti.search('Who was the California Attorney General?')
# Print search results
print('\nSearch Results:')
for result in results:
print(f'UUID: {result.uuid}')
print(f'Fact: {result.fact}')
if hasattr(result, 'valid_at') and result.valid_at:
print(f'Valid from: {result.valid_at}')
if hasattr(result, 'invalid_at') and result.invalid_at:
print(f'Valid until: {result.invalid_at}')
print('---')
#################################################
# CENTER NODE SEARCH
#################################################
# For more contextually relevant results, you can
# use a center node to rerank search results based
# on their graph distance to a specific node
#################################################
# Use the top search result's UUID as the center node for reranking
if results and len(results) > 0:
# Get the source node UUID from the top result
center_node_uuid = results[0].source_node_uuid
print('\nReranking search results based on graph distance:')
print(f'Using center node UUID: {center_node_uuid}')
reranked_results = await graphiti.search(
'Who was the California Attorney General?',
center_node_uuid=center_node_uuid,
)
# Print reranked search results
print('\nReranked Search Results:')
for result in reranked_results:
print(f'UUID: {result.uuid}')
print(f'Fact: {result.fact}')
if hasattr(result, 'valid_at') and result.valid_at:
print(f'Valid from: {result.valid_at}')
if hasattr(result, 'invalid_at') and result.invalid_at:
print(f'Valid until: {result.invalid_at}')
print('---')
else:
print('No results found in the initial search to use as center node.')
finally:
#################################################
# CLEANUP
#################################################
# Always close the connection to Neo4j when
# finished to properly release resources
#################################################
# Close the connection
await graphiti.close()
print('\nConnection closed')
if __name__ == '__main__':
asyncio.run(main())
```
--------------------------------------------------------------------------------
/graphiti_core/llm_client/openai_generic_client.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
import typing
from typing import Any, ClassVar
import openai
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam
from pydantic import BaseModel
from ..prompts.models import Message
from .client import LLMClient, get_extraction_language_instruction
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
from .errors import RateLimitError, RefusalError
logger = logging.getLogger(__name__)
DEFAULT_MODEL = 'gpt-4.1-mini'
class OpenAIGenericClient(LLMClient):
"""
OpenAIGenericClient is a client class for interacting with OpenAI-compatible language model APIs.
This class extends LLMClient and provides methods to initialize the client
and generate responses from the language model.
Attributes:
client (AsyncOpenAI): The OpenAI client used to interact with the API.
model (str): The model name to use for generating responses.
temperature (float): The temperature to use for generating responses.
max_tokens (int): The maximum number of tokens to generate in a response.
Methods:
__init__(config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None):
Initializes the OpenAIGenericClient with the provided configuration, cache setting, and client.
_generate_response(messages: list[Message]) -> dict[str, typing.Any]:
Generates a response from the language model based on the provided messages.
"""
# Class-level constants
MAX_RETRIES: ClassVar[int] = 2
def __init__(
self,
config: LLMConfig | None = None,
cache: bool = False,
client: typing.Any = None,
max_tokens: int = 16384,
):
"""
Initialize the OpenAIGenericClient with the provided configuration, cache setting, and client.
Args:
config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
cache (bool): Whether to use caching for responses. Defaults to False.
client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
max_tokens (int): The maximum number of tokens to generate. Defaults to 16384 (16K) for better compatibility with local models.
"""
# removed caching to simplify the `generate_response` override
if cache:
raise NotImplementedError('Caching is not implemented for OpenAI')
if config is None:
config = LLMConfig()
super().__init__(config, cache)
# Override max_tokens to support higher limits for local models
self.max_tokens = max_tokens
if client is None:
self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
else:
self.client = client
async def _generate_response(
self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int = DEFAULT_MAX_TOKENS,
model_size: ModelSize = ModelSize.medium,
) -> dict[str, typing.Any]:
openai_messages: list[ChatCompletionMessageParam] = []
for m in messages:
m.content = self._clean_input(m.content)
if m.role == 'user':
openai_messages.append({'role': 'user', 'content': m.content})
elif m.role == 'system':
openai_messages.append({'role': 'system', 'content': m.content})
try:
# Prepare response format
response_format: dict[str, Any] = {'type': 'json_object'}
if response_model is not None:
schema_name = getattr(response_model, '__name__', 'structured_response')
json_schema = response_model.model_json_schema()
response_format = {
'type': 'json_schema',
'json_schema': {
'name': schema_name,
'schema': json_schema,
},
}
response = await self.client.chat.completions.create(
model=self.model or DEFAULT_MODEL,
messages=openai_messages,
temperature=self.temperature,
max_tokens=self.max_tokens,
response_format=response_format, # type: ignore[arg-type]
)
result = response.choices[0].message.content or ''
return json.loads(result)
except openai.RateLimitError as e:
raise RateLimitError from e
except Exception as e:
logger.error(f'Error in generating LLM response: {e}')
raise
async def generate_response(
self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int | None = None,
model_size: ModelSize = ModelSize.medium,
group_id: str | None = None,
prompt_name: str | None = None,
) -> dict[str, typing.Any]:
if max_tokens is None:
max_tokens = self.max_tokens
# Add multilingual extraction instructions
messages[0].content += get_extraction_language_instruction(group_id)
# Wrap entire operation in tracing span
with self.tracer.start_span('llm.generate') as span:
attributes = {
'llm.provider': 'openai',
'model.size': model_size.value,
'max_tokens': max_tokens,
}
if prompt_name:
attributes['prompt.name'] = prompt_name
span.add_attributes(attributes)
retry_count = 0
last_error = None
while retry_count <= self.MAX_RETRIES:
try:
response = await self._generate_response(
messages, response_model, max_tokens=max_tokens, model_size=model_size
)
return response
except (RateLimitError, RefusalError):
# These errors should not trigger retries
span.set_status('error', str(last_error))
raise
except (
openai.APITimeoutError,
openai.APIConnectionError,
openai.InternalServerError,
):
# Let OpenAI's client handle these retries
span.set_status('error', str(last_error))
raise
except Exception as e:
last_error = e
# Don't retry if we've hit the max retries
if retry_count >= self.MAX_RETRIES:
logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
span.set_status('error', str(e))
span.record_exception(e)
raise
retry_count += 1
# Construct a detailed error message for the LLM
error_context = (
f'The previous response attempt was invalid. '
f'Error type: {e.__class__.__name__}. '
f'Error details: {str(e)}. '
f'Please try again with a valid response, ensuring the output matches '
f'the expected format and constraints.'
)
error_message = Message(role='user', content=error_context)
messages.append(error_message)
logger.warning(
f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
)
# If we somehow get here, raise the last error
span.set_status('error', str(last_error))
raise last_error or Exception('Max retries exceeded with no specific error')
```
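A hedged usage sketch for the client above: wiring OpenAIGenericClient to a local OpenAI-compatible endpoint. The base URL, model names, and API key below are illustrative placeholders, not values taken from this repository.
```python
# Illustrative only: point OpenAIGenericClient at a local OpenAI-compatible
# server. base_url, model, small_model, and api_key are assumed placeholder values.
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.openai_generic_client import OpenAIGenericClient

config = LLMConfig(
    api_key='local-key',                    # many local servers ignore the key
    base_url='http://localhost:11434/v1',   # assumed local endpoint
    model='llama3.1',                       # assumed model name
    small_model='llama3.1',
)
llm_client = OpenAIGenericClient(config=config, max_tokens=16384)
```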
--------------------------------------------------------------------------------
/graphiti_core/llm_client/client.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import hashlib
import json
import logging
import typing
from abc import ABC, abstractmethod
import httpx
from diskcache import Cache
from pydantic import BaseModel
from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
from ..prompts.models import Message
from ..tracer import NoOpTracer, Tracer
from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
from .errors import RateLimitError
DEFAULT_TEMPERATURE = 0
DEFAULT_CACHE_DIR = './llm_cache'
def get_extraction_language_instruction(group_id: str | None = None) -> str:
"""Returns instruction for language extraction behavior.
Override this function to customize language extraction:
- Return empty string to disable multilingual instructions
- Return custom instructions for specific language requirements
- Use group_id to provide different instructions per group/partition
Args:
group_id: Optional partition identifier for the graph
Returns:
str: Language instruction to append to system messages
"""
return '\n\nAny extracted information should be returned in the same language as it was written in.'
logger = logging.getLogger(__name__)
def is_server_or_retry_error(exception):
if isinstance(exception, RateLimitError | json.decoder.JSONDecodeError):
return True
return (
isinstance(exception, httpx.HTTPStatusError) and 500 <= exception.response.status_code < 600
)
class LLMClient(ABC):
def __init__(self, config: LLMConfig | None, cache: bool = False):
if config is None:
config = LLMConfig()
self.config = config
self.model = config.model
self.small_model = config.small_model
self.temperature = config.temperature
self.max_tokens = config.max_tokens
self.cache_enabled = cache
self.cache_dir = None
self.tracer: Tracer = NoOpTracer()
# Only create the cache directory if caching is enabled
if self.cache_enabled:
self.cache_dir = Cache(DEFAULT_CACHE_DIR)
def set_tracer(self, tracer: Tracer) -> None:
"""Set the tracer for this LLM client."""
self.tracer = tracer
def _clean_input(self, input: str) -> str:
"""Clean input string of invalid unicode and control characters.
Args:
input: Raw input string to be cleaned
Returns:
Cleaned string safe for LLM processing
"""
# Clean any invalid Unicode
cleaned = input.encode('utf-8', errors='ignore').decode('utf-8')
# Remove zero-width characters and other invisible unicode
zero_width = '\u200b\u200c\u200d\ufeff\u2060'
for char in zero_width:
cleaned = cleaned.replace(char, '')
# Remove control characters except newlines, returns, and tabs
cleaned = ''.join(char for char in cleaned if ord(char) >= 32 or char in '\n\r\t')
return cleaned
@retry(
stop=stop_after_attempt(4),
wait=wait_random_exponential(multiplier=10, min=5, max=120),
retry=retry_if_exception(is_server_or_retry_error),
after=lambda retry_state: logger.warning(
f'Retrying {retry_state.fn.__name__ if retry_state.fn else "function"} after {retry_state.attempt_number} attempts...'
)
if retry_state.attempt_number > 1
else None,
reraise=True,
)
async def _generate_response_with_retry(
self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int = DEFAULT_MAX_TOKENS,
model_size: ModelSize = ModelSize.medium,
) -> dict[str, typing.Any]:
try:
return await self._generate_response(messages, response_model, max_tokens, model_size)
except (httpx.HTTPStatusError, RateLimitError) as e:
raise e
@abstractmethod
async def _generate_response(
self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int = DEFAULT_MAX_TOKENS,
model_size: ModelSize = ModelSize.medium,
) -> dict[str, typing.Any]:
pass
def _get_cache_key(self, messages: list[Message]) -> str:
# Create a unique cache key based on the messages and model
message_str = json.dumps([m.model_dump() for m in messages], sort_keys=True)
key_str = f'{self.model}:{message_str}'
return hashlib.md5(key_str.encode()).hexdigest()
async def generate_response(
self,
messages: list[Message],
response_model: type[BaseModel] | None = None,
max_tokens: int | None = None,
model_size: ModelSize = ModelSize.medium,
group_id: str | None = None,
prompt_name: str | None = None,
) -> dict[str, typing.Any]:
if max_tokens is None:
max_tokens = self.max_tokens
if response_model is not None:
serialized_model = json.dumps(response_model.model_json_schema())
messages[
-1
].content += (
f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}'
)
# Add multilingual extraction instructions
messages[0].content += get_extraction_language_instruction(group_id)
for message in messages:
message.content = self._clean_input(message.content)
# Wrap entire operation in tracing span
with self.tracer.start_span('llm.generate') as span:
attributes = {
'llm.provider': self._get_provider_type(),
'model.size': model_size.value,
'max_tokens': max_tokens,
'cache.enabled': self.cache_enabled,
}
if prompt_name:
attributes['prompt.name'] = prompt_name
span.add_attributes(attributes)
# Check cache first
if self.cache_enabled and self.cache_dir is not None:
cache_key = self._get_cache_key(messages)
cached_response = self.cache_dir.get(cache_key)
if cached_response is not None:
logger.debug(f'Cache hit for {cache_key}')
span.add_attributes({'cache.hit': True})
return cached_response
span.add_attributes({'cache.hit': False})
# Execute LLM call
try:
response = await self._generate_response_with_retry(
messages, response_model, max_tokens, model_size
)
except Exception as e:
span.set_status('error', str(e))
span.record_exception(e)
raise
# Cache response if enabled
if self.cache_enabled and self.cache_dir is not None:
cache_key = self._get_cache_key(messages)
self.cache_dir.set(cache_key, response)
return response
def _get_provider_type(self) -> str:
"""Get provider type from class name."""
class_name = self.__class__.__name__.lower()
if 'openai' in class_name:
return 'openai'
elif 'anthropic' in class_name:
return 'anthropic'
elif 'gemini' in class_name:
return 'gemini'
elif 'groq' in class_name:
return 'groq'
else:
return 'unknown'
def _get_failed_generation_log(self, messages: list[Message], output: str | None) -> str:
"""
Log the full input messages, the raw output (if any), and the exception for debugging failed generations.
"""
log = ''
log += f'Input messages: {json.dumps([m.model_dump() for m in messages], indent=2)}\n'
if output is not None:
if len(output) > 4000:
log += f'Raw output: {output[:2000]}... (truncated) ...{output[-2000:]}\n'
else:
log += f'Raw output: {output}\n'
else:
log += 'No raw output available'
return log
```
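Since `_generate_response` is the only abstract method on `LLMClient`, a provider integration can be sketched in a few lines. The toy client below is a minimal, hypothetical example (it calls no real API) that inherits caching, retries, and tracing from the base class.
```python
# Minimal sketch of a custom LLMClient subclass; the provider behavior here is
# fabricated (it just echoes the last message) to show the required signature.
import typing

from pydantic import BaseModel

from graphiti_core.llm_client.client import LLMClient
from graphiti_core.llm_client.config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
from graphiti_core.prompts.models import Message


class EchoLLMClient(LLMClient):
    async def _generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        model_size: ModelSize = ModelSize.medium,
    ) -> dict[str, typing.Any]:
        # A real implementation would call a provider API and return parsed JSON.
        return {'echo': messages[-1].content}


client = EchoLLMClient(config=LLMConfig(model='echo-model'), cache=False)
```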
--------------------------------------------------------------------------------
/graphiti_core/prompts/dedupe_nodes.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Any, Protocol, TypedDict
from pydantic import BaseModel, Field
from .models import Message, PromptFunction, PromptVersion
from .prompt_helpers import to_prompt_json
class NodeDuplicate(BaseModel):
id: int = Field(..., description='integer id of the entity')
duplicate_idx: int = Field(
...,
description='idx of the duplicate entity. If no duplicate entities are found, default to -1.',
)
name: str = Field(
...,
description='Name of the entity. Should be the most complete and descriptive name of the entity. Do not include any JSON formatting in the Entity name such as {}.',
)
duplicates: list[int] = Field(
...,
description='idx of all entities that are a duplicate of the entity with the above id.',
)
class NodeResolutions(BaseModel):
entity_resolutions: list[NodeDuplicate] = Field(..., description='List of resolved nodes')
class Prompt(Protocol):
node: PromptVersion
node_list: PromptVersion
nodes: PromptVersion
class Versions(TypedDict):
node: PromptFunction
node_list: PromptFunction
nodes: PromptFunction
def node(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
content='You are a helpful assistant that determines whether or not a NEW ENTITY is a duplicate of any EXISTING ENTITIES.',
),
Message(
role='user',
content=f"""
<PREVIOUS MESSAGES>
{to_prompt_json([ep for ep in context['previous_episodes']])}
</PREVIOUS MESSAGES>
<CURRENT MESSAGE>
{context['episode_content']}
</CURRENT MESSAGE>
<NEW ENTITY>
{to_prompt_json(context['extracted_node'])}
</NEW ENTITY>
<ENTITY TYPE DESCRIPTION>
{to_prompt_json(context['entity_type_description'])}
</ENTITY TYPE DESCRIPTION>
<EXISTING ENTITIES>
{to_prompt_json(context['existing_nodes'])}
</EXISTING ENTITIES>
Given the above EXISTING ENTITIES and their attributes, the CURRENT MESSAGE, and the PREVIOUS MESSAGES, determine whether the NEW ENTITY extracted from the conversation
is a duplicate entity of one of the EXISTING ENTITIES.
Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
Semantic Equivalence: if a descriptive label in existing_entities clearly refers to a named entity in context, treat them as duplicates.
Do NOT mark entities as duplicates if:
- They are related but distinct.
- They have similar names or purposes but refer to separate instances or concepts.
TASK:
1. Compare `new_entity` against each item in `existing_entities`.
2. If it refers to the same real-world object or concept, collect its index.
3. Let `duplicate_idx` = the smallest collected index, or -1 if none.
4. Let `duplicates` = the sorted list of all collected indices (empty list if none).
Respond with a JSON object containing an "entity_resolutions" array with a single entry:
{{
"entity_resolutions": [
{{
"id": integer id from NEW ENTITY,
"name": the best full name for the entity,
"duplicate_idx": integer index of the best duplicate in EXISTING ENTITIES, or -1 if none,
"duplicates": sorted list of all duplicate indices you collected (deduplicate the list, use [] when none)
}}
]
}}
Only reference indices that appear in EXISTING ENTITIES, and return [] / -1 when unsure.
""",
),
]
def nodes(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
content='You are a helpful assistant that determines whether or not ENTITIES extracted from a conversation are duplicates'
' of existing entities.',
),
Message(
role='user',
content=f"""
<PREVIOUS MESSAGES>
{to_prompt_json([ep for ep in context['previous_episodes']])}
</PREVIOUS MESSAGES>
<CURRENT MESSAGE>
{context['episode_content']}
</CURRENT MESSAGE>
Each of the following ENTITIES were extracted from the CURRENT MESSAGE.
Each entity in ENTITIES is represented as a JSON object with the following structure:
{{
id: integer id of the entity,
name: "name of the entity",
entity_type: ["Entity", "<optional additional label>", ...],
entity_type_description: "Description of what the entity type represents"
}}
<ENTITIES>
{to_prompt_json(context['extracted_nodes'])}
</ENTITIES>
<EXISTING ENTITIES>
{to_prompt_json(context['existing_nodes'])}
</EXISTING ENTITIES>
Each entry in EXISTING ENTITIES is an object with the following structure:
{{
idx: integer index of the candidate entity (use this when referencing a duplicate),
name: "name of the candidate entity",
entity_types: ["Entity", "<optional additional label>", ...],
...<additional attributes such as summaries or metadata>
}}
For each of the above ENTITIES, determine if the entity is a duplicate of any of the EXISTING ENTITIES.
Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
Do NOT mark entities as duplicates if:
- They are related but distinct.
- They have similar names or purposes but refer to separate instances or concepts.
Task:
ENTITIES contains {len(context['extracted_nodes'])} entities with IDs 0 through {len(context['extracted_nodes']) - 1}.
Your response MUST include EXACTLY {len(context['extracted_nodes'])} resolutions with IDs 0 through {len(context['extracted_nodes']) - 1}. Do not skip or add IDs.
For every entity, return an object with the following keys:
{{
"id": integer id from ENTITIES,
"name": the best full name for the entity (preserve the original name unless a duplicate has a more complete name),
"duplicate_idx": the idx of the EXISTING ENTITY that is the best duplicate match, or -1 if there is no duplicate,
"duplicates": a sorted list of all idx values from EXISTING ENTITIES that refer to duplicates (deduplicate the list, use [] when none or unsure)
}}
- Only use idx values that appear in EXISTING ENTITIES.
- Set duplicate_idx to the smallest idx you collected for that entity, or -1 if duplicates is empty.
- Never fabricate entities or indices.
""",
),
]
def node_list(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
content='You are a helpful assistant that de-duplicates nodes from node lists.',
),
Message(
role='user',
content=f"""
Given the following context, deduplicate a list of nodes:
Nodes:
{to_prompt_json(context['nodes'])}
Task:
1. Group nodes together such that all duplicate nodes are in the same list of uuids
2. All duplicate uuids should be grouped together in the same list
3. Also return a new short summary that synthesizes the summaries of the grouped nodes
Guidelines:
1. Each uuid from the list of nodes should appear EXACTLY once in your response
2. If a node has no duplicates, it should appear in the response in a list of only one uuid
Respond with a JSON object in the following format:
{{
"nodes": [
{{
"uuids": ["5d643020624c42fa9de13f97b1b3fa39", "node that is a duplicate of 5d643020624c42fa9de13f97b1b3fa39"],
"summary": "Brief summary of the node summaries that appear in the list of names."
}}
]
}}
""",
),
]
versions: Versions = {'node': node, 'node_list': node_list, 'nodes': nodes}
```
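A hedged sketch of rendering the single-entity dedupe prompt by hand: the context keys mirror the lookups in `node()` above, and the sample values are fabricated.
```python
# Illustrative only: build a context dict with the keys node() reads and render
# the prompt messages. All values below are made-up sample data.
from graphiti_core.prompts.dedupe_nodes import versions

context = {
    'previous_episodes': ['Alice joined Acme Corp in 2020.'],
    'episode_content': 'Alice Smith now leads the Acme data platform team.',
    'extracted_node': {'id': 0, 'name': 'Alice Smith'},
    'entity_type_description': {'Entity': 'A person, organization, or concept'},
    'existing_nodes': [{'idx': 0, 'name': 'Alice', 'entity_types': ['Entity']}],
}

messages = versions['node'](context)  # [system Message, user Message]
print(messages[1].content[:200])
```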
--------------------------------------------------------------------------------
/graphiti_core/search/search_filters.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from datetime import datetime
from enum import Enum
from typing import Any
from pydantic import BaseModel, Field
from graphiti_core.driver.driver import GraphProvider
class ComparisonOperator(Enum):
equals = '='
not_equals = '<>'
greater_than = '>'
less_than = '<'
greater_than_equal = '>='
less_than_equal = '<='
is_null = 'IS NULL'
is_not_null = 'IS NOT NULL'
class DateFilter(BaseModel):
date: datetime | None = Field(description='A datetime to filter on')
comparison_operator: ComparisonOperator = Field(
description='Comparison operator for date filter'
)
class SearchFilters(BaseModel):
node_labels: list[str] | None = Field(
default=None, description='List of node labels to filter on'
)
edge_types: list[str] | None = Field(
default=None, description='List of edge types to filter on'
)
valid_at: list[list[DateFilter]] | None = Field(default=None)
invalid_at: list[list[DateFilter]] | None = Field(default=None)
created_at: list[list[DateFilter]] | None = Field(default=None)
expired_at: list[list[DateFilter]] | None = Field(default=None)
edge_uuids: list[str] | None = Field(default=None)
def cypher_to_opensearch_operator(op: ComparisonOperator) -> str:
mapping = {
ComparisonOperator.greater_than: 'gt',
ComparisonOperator.less_than: 'lt',
ComparisonOperator.greater_than_equal: 'gte',
ComparisonOperator.less_than_equal: 'lte',
}
return mapping.get(op, op.value)
def node_search_filter_query_constructor(
filters: SearchFilters,
provider: GraphProvider,
) -> tuple[list[str], dict[str, Any]]:
filter_queries: list[str] = []
filter_params: dict[str, Any] = {}
if filters.node_labels is not None:
if provider == GraphProvider.KUZU:
node_label_filter = 'list_has_all(n.labels, $labels)'
filter_params['labels'] = filters.node_labels
else:
node_labels = '|'.join(filters.node_labels)
node_label_filter = 'n:' + node_labels
filter_queries.append(node_label_filter)
return filter_queries, filter_params
def date_filter_query_constructor(
value_name: str, param_name: str, operator: ComparisonOperator
) -> str:
query = '(' + value_name + ' '
if operator == ComparisonOperator.is_null or operator == ComparisonOperator.is_not_null:
query += operator.value + ')'
else:
query += operator.value + ' ' + param_name + ')'
return query
def edge_search_filter_query_constructor(
filters: SearchFilters,
provider: GraphProvider,
) -> tuple[list[str], dict[str, Any]]:
filter_queries: list[str] = []
filter_params: dict[str, Any] = {}
if filters.edge_types is not None:
edge_types = filters.edge_types
filter_queries.append('e.name in $edge_types')
filter_params['edge_types'] = edge_types
if filters.edge_uuids is not None:
filter_queries.append('e.uuid in $edge_uuids')
filter_params['edge_uuids'] = filters.edge_uuids
if filters.node_labels is not None:
if provider == GraphProvider.KUZU:
node_label_filter = (
'list_has_all(n.labels, $labels) AND list_has_all(m.labels, $labels)'
)
filter_params['labels'] = filters.node_labels
else:
node_labels = '|'.join(filters.node_labels)
node_label_filter = 'n:' + node_labels + ' AND m:' + node_labels
filter_queries.append(node_label_filter)
if filters.valid_at is not None:
valid_at_filter = '('
for i, or_list in enumerate(filters.valid_at):
for j, date_filter in enumerate(or_list):
if date_filter.comparison_operator not in [
ComparisonOperator.is_null,
ComparisonOperator.is_not_null,
]:
filter_params['valid_at_' + str(j)] = date_filter.date
and_filters = [
date_filter_query_constructor(
'e.valid_at', f'$valid_at_{j}', date_filter.comparison_operator
)
for j, date_filter in enumerate(or_list)
]
and_filter_query = ''
for j, and_filter in enumerate(and_filters):
and_filter_query += and_filter
if j != len(and_filters) - 1:
and_filter_query += ' AND '
valid_at_filter += and_filter_query
if i == len(filters.valid_at) - 1:
valid_at_filter += ')'
else:
valid_at_filter += ' OR '
filter_queries.append(valid_at_filter)
if filters.invalid_at is not None:
invalid_at_filter = '('
for i, or_list in enumerate(filters.invalid_at):
for j, date_filter in enumerate(or_list):
if date_filter.comparison_operator not in [
ComparisonOperator.is_null,
ComparisonOperator.is_not_null,
]:
filter_params['invalid_at_' + str(j)] = date_filter.date
and_filters = [
date_filter_query_constructor(
'e.invalid_at', f'$invalid_at_{j}', date_filter.comparison_operator
)
for j, date_filter in enumerate(or_list)
]
and_filter_query = ''
for j, and_filter in enumerate(and_filters):
and_filter_query += and_filter
if j != len(and_filters) - 1:
and_filter_query += ' AND '
invalid_at_filter += and_filter_query
if i == len(filters.invalid_at) - 1:
invalid_at_filter += ')'
else:
invalid_at_filter += ' OR '
filter_queries.append(invalid_at_filter)
if filters.created_at is not None:
created_at_filter = '('
for i, or_list in enumerate(filters.created_at):
for j, date_filter in enumerate(or_list):
if date_filter.comparison_operator not in [
ComparisonOperator.is_null,
ComparisonOperator.is_not_null,
]:
filter_params['created_at_' + str(j)] = date_filter.date
and_filters = [
date_filter_query_constructor(
'e.created_at', f'$created_at_{j}', date_filter.comparison_operator
)
for j, date_filter in enumerate(or_list)
]
and_filter_query = ''
for j, and_filter in enumerate(and_filters):
and_filter_query += and_filter
if j != len(and_filters) - 1:
and_filter_query += ' AND '
created_at_filter += and_filter_query
if i == len(filters.created_at) - 1:
created_at_filter += ')'
else:
created_at_filter += ' OR '
filter_queries.append(created_at_filter)
if filters.expired_at is not None:
expired_at_filter = '('
for i, or_list in enumerate(filters.expired_at):
for j, date_filter in enumerate(or_list):
if date_filter.comparison_operator not in [
ComparisonOperator.is_null,
ComparisonOperator.is_not_null,
]:
filter_params['expired_at_' + str(j)] = date_filter.date
and_filters = [
date_filter_query_constructor(
'e.expired_at', f'$expired_at_{j}', date_filter.comparison_operator
)
for j, date_filter in enumerate(or_list)
]
and_filter_query = ''
for j, and_filter in enumerate(and_filters):
and_filter_query += and_filter
if j != len(and_filters) - 1:
and_filter_query += ' AND '
expired_at_filter += and_filter_query
if i == len(filters.expired_at) - 1:
expired_at_filter += ')'
else:
expired_at_filter += ' OR '
filter_queries.append(expired_at_filter)
return filter_queries, filter_params
```
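A hedged example of composing `SearchFilters`: per the constructor loops above, each outer list is an OR group and each inner list is an AND group of date conditions. The edge type name is an assumption, not one defined by this repository.
```python
# Illustrative only: keep edges of an assumed type that are valid on/after
# 2011-01-01 OR have no valid_at value at all (outer list = OR, inner = AND).
from datetime import datetime, timezone

from graphiti_core.search.search_filters import (
    ComparisonOperator,
    DateFilter,
    SearchFilters,
)

filters = SearchFilters(
    edge_types=['HELD_OFFICE'],  # assumed edge type name
    valid_at=[
        [
            DateFilter(
                date=datetime(2011, 1, 1, tzinfo=timezone.utc),
                comparison_operator=ComparisonOperator.greater_than_equal,
            )
        ],
        [DateFilter(date=None, comparison_operator=ComparisonOperator.is_null)],
    ],
)
```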
--------------------------------------------------------------------------------
/mcp_server/tests/test_http_integration.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Integration test for MCP server using HTTP streaming transport.
This avoids the stdio subprocess timing issues.
"""
import asyncio
import json
import sys
import time
from mcp.client.session import ClientSession
async def test_http_transport(base_url: str = 'http://localhost:8000'):
"""Test MCP server with HTTP streaming transport."""
# Import the streamable http client
try:
from mcp.client.streamable_http import streamablehttp_client as http_client
except ImportError:
print('❌ Streamable HTTP client not available in MCP SDK')
return False
test_group_id = f'test_http_{int(time.time())}'
print('🚀 Testing MCP Server with HTTP streaming transport')
print(f' Server URL: {base_url}')
print(f' Test Group: {test_group_id}')
print('=' * 60)
try:
# Connect to the server via HTTP
print('\n🔌 Connecting to server...')
async with http_client(base_url) as (read_stream, write_stream):
session = ClientSession(read_stream, write_stream)
await session.initialize()
print('✅ Connected successfully')
# Test 1: List tools
print('\n📋 Test 1: Listing tools...')
try:
result = await session.list_tools()
tools = [tool.name for tool in result.tools]
expected = [
'add_memory',
'search_memory_nodes',
'search_memory_facts',
'get_episodes',
'delete_episode',
'clear_graph',
]
found = [t for t in expected if t in tools]
print(f' ✅ Found {len(tools)} tools ({len(found)}/{len(expected)} expected)')
for tool in tools[:5]:
print(f' - {tool}')
except Exception as e:
print(f' ❌ Failed: {e}')
return False
# Test 2: Add memory
print('\n📝 Test 2: Adding memory...')
try:
result = await session.call_tool(
'add_memory',
{
'name': 'Integration Test Episode',
'episode_body': 'This is a test episode created via HTTP transport integration test.',
'group_id': test_group_id,
'source': 'text',
'source_description': 'HTTP Integration Test',
},
)
if result.content and result.content[0].text:
response = result.content[0].text
if 'success' in response.lower() or 'queued' in response.lower():
print(' ✅ Memory added successfully')
else:
print(f' ❌ Unexpected response: {response[:100]}')
else:
print(' ❌ No content in response')
except Exception as e:
print(f' ❌ Failed: {e}')
# Test 3: Search nodes (with delay for processing)
print('\n🔍 Test 3: Searching nodes...')
await asyncio.sleep(2) # Wait for async processing
try:
result = await session.call_tool(
'search_memory_nodes',
{'query': 'integration test episode', 'group_ids': [test_group_id], 'limit': 5},
)
if result.content and result.content[0].text:
response = result.content[0].text
try:
data = json.loads(response)
nodes = data.get('nodes', [])
print(f' ✅ Search returned {len(nodes)} nodes')
except Exception: # noqa: E722
print(f' ✅ Search completed: {response[:100]}')
else:
print(' ⚠️ No results (may be processing)')
except Exception as e:
print(f' ❌ Failed: {e}')
# Test 4: Get episodes
print('\n📚 Test 4: Getting episodes...')
try:
result = await session.call_tool(
'get_episodes', {'group_ids': [test_group_id], 'limit': 10}
)
if result.content and result.content[0].text:
response = result.content[0].text
try:
data = json.loads(response)
episodes = data.get('episodes', [])
print(f' ✅ Found {len(episodes)} episodes')
except Exception: # noqa: E722
print(f' ✅ Episodes retrieved: {response[:100]}')
else:
print(' ⚠️ No episodes found')
except Exception as e:
print(f' ❌ Failed: {e}')
# Test 5: Clear graph
print('\n🧹 Test 5: Clearing graph...')
try:
result = await session.call_tool('clear_graph', {'group_id': test_group_id})
if result.content and result.content[0].text:
response = result.content[0].text
if 'success' in response.lower() or 'cleared' in response.lower():
print(' ✅ Graph cleared successfully')
else:
print(f' ✅ Clear completed: {response[:100]}')
else:
print(' ❌ No response')
except Exception as e:
print(f' ❌ Failed: {e}')
print('\n' + '=' * 60)
print('✅ All integration tests completed!')
return True
except Exception as e:
print(f'\n❌ Connection failed: {e}')
return False
async def test_sse_transport(base_url: str = 'http://localhost:8000'):
"""Test MCP server with SSE transport."""
# Import the SSE client
try:
from mcp.client.sse import sse_client
except ImportError:
print('❌ SSE client not available in MCP SDK')
return False
test_group_id = f'test_sse_{int(time.time())}'
print('🚀 Testing MCP Server with SSE transport')
print(f' Server URL: {base_url}/sse')
print(f' Test Group: {test_group_id}')
print('=' * 60)
try:
# Connect to the server via SSE
print('\n🔌 Connecting to server...')
async with sse_client(f'{base_url}/sse') as (read_stream, write_stream):
session = ClientSession(read_stream, write_stream)
await session.initialize()
print('✅ Connected successfully')
# Run same tests as HTTP
print('\n📋 Test 1: Listing tools...')
try:
result = await session.list_tools()
tools = [tool.name for tool in result.tools]
print(f' ✅ Found {len(tools)} tools')
for tool in tools[:3]:
print(f' - {tool}')
except Exception as e:
print(f' ❌ Failed: {e}')
return False
print('\n' + '=' * 60)
print('✅ SSE transport test completed!')
return True
except Exception as e:
print(f'\n❌ SSE connection failed: {e}')
return False
async def main():
"""Run integration tests."""
# Check command line arguments
if len(sys.argv) < 2:
print('Usage: python test_http_integration.py <transport> [host] [port]')
print(' transport: http or sse')
print(' host: server host (default: localhost)')
print(' port: server port (default: 8000)')
sys.exit(1)
transport = sys.argv[1].lower()
host = sys.argv[2] if len(sys.argv) > 2 else 'localhost'
port = sys.argv[3] if len(sys.argv) > 3 else '8000'
base_url = f'http://{host}:{port}'
# Check if server is running
import httpx
try:
async with httpx.AsyncClient() as client:
# Try to connect to the server
await client.get(base_url, timeout=2.0)
except Exception: # noqa: E722
print(f'⚠️ Server not responding at {base_url}')
print('Please start the server with one of these commands:')
print(f' uv run main.py --transport http --port {port}')
print(f' uv run main.py --transport sse --port {port}')
sys.exit(1)
# Run the appropriate test
if transport == 'http':
success = await test_http_transport(base_url)
elif transport == 'sse':
success = await test_sse_transport(base_url)
else:
print(f'❌ Unknown transport: {transport}')
sys.exit(1)
sys.exit(0 if success else 1)
if __name__ == '__main__':
asyncio.run(main())
```
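The script above is normally driven from the command line, but the transport checks can also be invoked programmatically. The snippet below is a hedged sketch that assumes an MCP server is already listening on port 8000 and that it is run from the mcp_server/tests directory.
```python
# Illustrative only: run the HTTP-transport checks directly, assuming the MCP
# server was started separately (e.g. `uv run main.py --transport http --port 8000`).
import asyncio

from test_http_integration import test_http_transport

ok = asyncio.run(test_http_transport('http://localhost:8000'))
print('passed' if ok else 'failed')
```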
--------------------------------------------------------------------------------
/graphiti_core/utils/maintenance/dedup_helpers.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2024, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from __future__ import annotations
import math
import re
from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass, field
from functools import lru_cache
from hashlib import blake2b
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from graphiti_core.nodes import EntityNode
_NAME_ENTROPY_THRESHOLD = 1.5
_MIN_NAME_LENGTH = 6
_MIN_TOKEN_COUNT = 2
_FUZZY_JACCARD_THRESHOLD = 0.9
_MINHASH_PERMUTATIONS = 32
_MINHASH_BAND_SIZE = 4
def _normalize_string_exact(name: str) -> str:
"""Lowercase text and collapse whitespace so equal names map to the same key."""
normalized = re.sub(r'[\s]+', ' ', name.lower())
return normalized.strip()
def _normalize_name_for_fuzzy(name: str) -> str:
"""Produce a fuzzier form that keeps alphanumerics and apostrophes for n-gram shingles."""
normalized = re.sub(r"[^a-z0-9' ]", ' ', _normalize_string_exact(name))
normalized = normalized.strip()
return re.sub(r'[\s]+', ' ', normalized)
def _name_entropy(normalized_name: str) -> float:
"""Approximate text specificity using Shannon entropy over characters.
We strip spaces, count how often each character appears, and sum
probability * -log2(probability). Short or repetitive names yield low
entropy, which signals we should defer resolution to the LLM instead of
trusting fuzzy similarity.
"""
if not normalized_name:
return 0.0
counts: dict[str, int] = {}
for char in normalized_name.replace(' ', ''):
counts[char] = counts.get(char, 0) + 1
total = sum(counts.values())
if total == 0:
return 0.0
entropy = 0.0
for count in counts.values():
probability = count / total
entropy -= probability * math.log2(probability)
return entropy
def _has_high_entropy(normalized_name: str) -> bool:
"""Filter out very short or low-entropy names that are unreliable for fuzzy matching."""
token_count = len(normalized_name.split())
if len(normalized_name) < _MIN_NAME_LENGTH and token_count < _MIN_TOKEN_COUNT:
return False
return _name_entropy(normalized_name) >= _NAME_ENTROPY_THRESHOLD
def _shingles(normalized_name: str) -> set[str]:
"""Create 3-gram shingles from the normalized name for MinHash calculations."""
cleaned = normalized_name.replace(' ', '')
if len(cleaned) < 2:
return {cleaned} if cleaned else set()
return {cleaned[i : i + 3] for i in range(len(cleaned) - 2)}
def _hash_shingle(shingle: str, seed: int) -> int:
"""Generate a deterministic 64-bit hash for a shingle given the permutation seed."""
digest = blake2b(f'{seed}:{shingle}'.encode(), digest_size=8)
return int.from_bytes(digest.digest(), 'big')
def _minhash_signature(shingles: Iterable[str]) -> tuple[int, ...]:
"""Compute the MinHash signature for the shingle set across predefined permutations."""
if not shingles:
return tuple()
seeds = range(_MINHASH_PERMUTATIONS)
signature: list[int] = []
for seed in seeds:
min_hash = min(_hash_shingle(shingle, seed) for shingle in shingles)
signature.append(min_hash)
return tuple(signature)
def _lsh_bands(signature: Iterable[int]) -> list[tuple[int, ...]]:
"""Split the MinHash signature into fixed-size bands for locality-sensitive hashing."""
signature_list = list(signature)
if not signature_list:
return []
bands: list[tuple[int, ...]] = []
for start in range(0, len(signature_list), _MINHASH_BAND_SIZE):
band = tuple(signature_list[start : start + _MINHASH_BAND_SIZE])
if len(band) == _MINHASH_BAND_SIZE:
bands.append(band)
return bands
def _jaccard_similarity(a: set[str], b: set[str]) -> float:
"""Return the Jaccard similarity between two shingle sets, handling empty edge cases."""
if not a and not b:
return 1.0
if not a or not b:
return 0.0
intersection = len(a.intersection(b))
union = len(a.union(b))
return intersection / union if union else 0.0
@lru_cache(maxsize=512)
def _cached_shingles(name: str) -> set[str]:
"""Cache shingle sets per normalized name to avoid recomputation within a worker."""
return _shingles(name)
@dataclass
class DedupCandidateIndexes:
"""Precomputed lookup structures that drive entity deduplication heuristics."""
existing_nodes: list[EntityNode]
nodes_by_uuid: dict[str, EntityNode]
normalized_existing: defaultdict[str, list[EntityNode]]
shingles_by_candidate: dict[str, set[str]]
lsh_buckets: defaultdict[tuple[int, tuple[int, ...]], list[str]]
@dataclass
class DedupResolutionState:
"""Mutable resolution bookkeeping shared across deterministic and LLM passes."""
resolved_nodes: list[EntityNode | None]
uuid_map: dict[str, str]
unresolved_indices: list[int]
duplicate_pairs: list[tuple[EntityNode, EntityNode]] = field(default_factory=list)
def _build_candidate_indexes(existing_nodes: list[EntityNode]) -> DedupCandidateIndexes:
"""Precompute exact and fuzzy lookup structures once per dedupe run."""
normalized_existing: defaultdict[str, list[EntityNode]] = defaultdict(list)
nodes_by_uuid: dict[str, EntityNode] = {}
shingles_by_candidate: dict[str, set[str]] = {}
lsh_buckets: defaultdict[tuple[int, tuple[int, ...]], list[str]] = defaultdict(list)
for candidate in existing_nodes:
normalized = _normalize_string_exact(candidate.name)
normalized_existing[normalized].append(candidate)
nodes_by_uuid[candidate.uuid] = candidate
shingles = _cached_shingles(_normalize_name_for_fuzzy(candidate.name))
shingles_by_candidate[candidate.uuid] = shingles
signature = _minhash_signature(shingles)
for band_index, band in enumerate(_lsh_bands(signature)):
lsh_buckets[(band_index, band)].append(candidate.uuid)
return DedupCandidateIndexes(
existing_nodes=existing_nodes,
nodes_by_uuid=nodes_by_uuid,
normalized_existing=normalized_existing,
shingles_by_candidate=shingles_by_candidate,
lsh_buckets=lsh_buckets,
)
def _resolve_with_similarity(
extracted_nodes: list[EntityNode],
indexes: DedupCandidateIndexes,
state: DedupResolutionState,
) -> None:
"""Attempt deterministic resolution using exact name hits and fuzzy MinHash comparisons."""
for idx, node in enumerate(extracted_nodes):
normalized_exact = _normalize_string_exact(node.name)
normalized_fuzzy = _normalize_name_for_fuzzy(node.name)
if not _has_high_entropy(normalized_fuzzy):
state.unresolved_indices.append(idx)
continue
existing_matches = indexes.normalized_existing.get(normalized_exact, [])
if len(existing_matches) == 1:
match = existing_matches[0]
state.resolved_nodes[idx] = match
state.uuid_map[node.uuid] = match.uuid
if match.uuid != node.uuid:
state.duplicate_pairs.append((node, match))
continue
if len(existing_matches) > 1:
state.unresolved_indices.append(idx)
continue
shingles = _cached_shingles(normalized_fuzzy)
signature = _minhash_signature(shingles)
candidate_ids: set[str] = set()
for band_index, band in enumerate(_lsh_bands(signature)):
candidate_ids.update(indexes.lsh_buckets.get((band_index, band), []))
best_candidate: EntityNode | None = None
best_score = 0.0
for candidate_id in candidate_ids:
candidate_shingles = indexes.shingles_by_candidate.get(candidate_id, set())
score = _jaccard_similarity(shingles, candidate_shingles)
if score > best_score:
best_score = score
best_candidate = indexes.nodes_by_uuid.get(candidate_id)
if best_candidate is not None and best_score >= _FUZZY_JACCARD_THRESHOLD:
state.resolved_nodes[idx] = best_candidate
state.uuid_map[node.uuid] = best_candidate.uuid
if best_candidate.uuid != node.uuid:
state.duplicate_pairs.append((node, best_candidate))
continue
state.unresolved_indices.append(idx)
__all__ = [
'DedupCandidateIndexes',
'DedupResolutionState',
'_normalize_string_exact',
'_normalize_name_for_fuzzy',
'_has_high_entropy',
'_minhash_signature',
'_lsh_bands',
'_jaccard_similarity',
'_cached_shingles',
'_FUZZY_JACCARD_THRESHOLD',
'_build_candidate_indexes',
'_resolve_with_similarity',
]
```
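A quick, hedged illustration of the fuzzy-matching helpers above: two spellings of the same name normalize to identical shingle sets, so their Jaccard similarity clears `_FUZZY_JACCARD_THRESHOLD`.
```python
# Illustrative only: compare two name variants with the module's own helpers.
from graphiti_core.utils.maintenance.dedup_helpers import (
    _FUZZY_JACCARD_THRESHOLD,
    _cached_shingles,
    _jaccard_similarity,
    _normalize_name_for_fuzzy,
)

a = _cached_shingles(_normalize_name_for_fuzzy('Kamala Harris'))
b = _cached_shingles(_normalize_name_for_fuzzy('Kamala  harris'))  # extra space, different case
print(_jaccard_similarity(a, b) >= _FUZZY_JACCARD_THRESHOLD)  # True
```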
--------------------------------------------------------------------------------
/examples/quickstart/quickstart_neo4j.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2025, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import json
import logging
import os
from datetime import datetime, timezone
from logging import INFO
from dotenv import load_dotenv
from graphiti_core import Graphiti
from graphiti_core.nodes import EpisodeType
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
#################################################
# CONFIGURATION
#################################################
# Set up logging and environment variables for
# connecting to Neo4j database
#################################################
# Configure logging
logging.basicConfig(
level=INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)
load_dotenv()
# Neo4j connection parameters
# Make sure Neo4j Desktop is running with a local DBMS started
neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
neo4j_user = os.environ.get('NEO4J_USER', 'neo4j')
neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password')
if not neo4j_uri or not neo4j_user or not neo4j_password:
raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set')
async def main():
#################################################
# INITIALIZATION
#################################################
# Connect to Neo4j and set up Graphiti indices
# This is required before using other Graphiti
# functionality
#################################################
# Initialize Graphiti with Neo4j connection
graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
try:
#################################################
# ADDING EPISODES
#################################################
# Episodes are the primary units of information
# in Graphiti. They can be text or structured JSON
# and are automatically processed to extract entities
# and relationships.
#################################################
# Example: Add Episodes
# Episodes list containing both text and JSON episodes
episodes = [
{
'content': 'Kamala Harris is the Attorney General of California. She was previously '
'the district attorney for San Francisco.',
'type': EpisodeType.text,
'description': 'podcast transcript',
},
{
'content': 'As AG, Harris was in office from January 3, 2011 – January 3, 2017',
'type': EpisodeType.text,
'description': 'podcast transcript',
},
{
'content': {
'name': 'Gavin Newsom',
'position': 'Governor',
'state': 'California',
'previous_role': 'Lieutenant Governor',
'previous_location': 'San Francisco',
},
'type': EpisodeType.json,
'description': 'podcast metadata',
},
{
'content': {
'name': 'Gavin Newsom',
'position': 'Governor',
'term_start': 'January 7, 2019',
'term_end': 'Present',
},
'type': EpisodeType.json,
'description': 'podcast metadata',
},
]
# Add episodes to the graph
for i, episode in enumerate(episodes):
await graphiti.add_episode(
name=f'Freakonomics Radio {i}',
episode_body=episode['content']
if isinstance(episode['content'], str)
else json.dumps(episode['content']),
source=episode['type'],
source_description=episode['description'],
reference_time=datetime.now(timezone.utc),
)
print(f'Added episode: Freakonomics Radio {i} ({episode["type"].value})')
#################################################
# BASIC SEARCH
#################################################
# The simplest way to retrieve relationships (edges)
# from Graphiti is using the search method, which
# performs a hybrid search combining semantic
# similarity and BM25 text retrieval.
#################################################
# Perform a hybrid search combining semantic similarity and BM25 retrieval
print("\nSearching for: 'Who was the California Attorney General?'")
results = await graphiti.search('Who was the California Attorney General?')
# Print search results
print('\nSearch Results:')
for result in results:
print(f'UUID: {result.uuid}')
print(f'Fact: {result.fact}')
if hasattr(result, 'valid_at') and result.valid_at:
print(f'Valid from: {result.valid_at}')
if hasattr(result, 'invalid_at') and result.invalid_at:
print(f'Valid until: {result.invalid_at}')
print('---')
#################################################
# CENTER NODE SEARCH
#################################################
# For more contextually relevant results, you can
# use a center node to rerank search results based
# on their graph distance to a specific node
#################################################
# Use the top search result's UUID as the center node for reranking
if results and len(results) > 0:
# Get the source node UUID from the top result
center_node_uuid = results[0].source_node_uuid
print('\nReranking search results based on graph distance:')
print(f'Using center node UUID: {center_node_uuid}')
reranked_results = await graphiti.search(
'Who was the California Attorney General?', center_node_uuid=center_node_uuid
)
# Print reranked search results
print('\nReranked Search Results:')
for result in reranked_results:
print(f'UUID: {result.uuid}')
print(f'Fact: {result.fact}')
if hasattr(result, 'valid_at') and result.valid_at:
print(f'Valid from: {result.valid_at}')
if hasattr(result, 'invalid_at') and result.invalid_at:
print(f'Valid until: {result.invalid_at}')
print('---')
else:
print('No results found in the initial search to use as center node.')
#################################################
# NODE SEARCH USING SEARCH RECIPES
#################################################
# Graphiti provides predefined search recipes
# optimized for different search scenarios.
# Here we use NODE_HYBRID_SEARCH_RRF for retrieving
# nodes directly instead of edges.
#################################################
# Example: Perform a node search using _search method with standard recipes
print(
'\nPerforming node search using _search method with standard recipe NODE_HYBRID_SEARCH_RRF:'
)
# Use a predefined search configuration recipe and modify its limit
node_search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
node_search_config.limit = 5 # Limit to 5 results
# Execute the node search
node_search_results = await graphiti._search(
query='California Governor',
config=node_search_config,
)
# Print node search results
print('\nNode Search Results:')
for node in node_search_results.nodes:
print(f'Node UUID: {node.uuid}')
print(f'Node Name: {node.name}')
node_summary = node.summary[:100] + '...' if len(node.summary) > 100 else node.summary
print(f'Content Summary: {node_summary}')
print(f'Node Labels: {", ".join(node.labels)}')
print(f'Created At: {node.created_at}')
if hasattr(node, 'attributes') and node.attributes:
print('Attributes:')
for key, value in node.attributes.items():
print(f' {key}: {value}')
print('---')
finally:
#################################################
# CLEANUP
#################################################
# Always close the connection to Neo4j when
# finished to properly release resources
#################################################
# Close the connection
await graphiti.close()
print('\nConnection closed')
if __name__ == '__main__':
asyncio.run(main())
```
--------------------------------------------------------------------------------
/mcp_server/tests/test_mcp_transports.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Test MCP server with different transport modes using the MCP SDK.
Tests both SSE and streaming HTTP transports.
"""
import asyncio
import json
import sys
import time
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
class MCPTransportTester:
"""Test MCP server with different transport modes."""
def __init__(self, transport: str = 'sse', host: str = 'localhost', port: int = 8000):
self.transport = transport
self.host = host
self.port = port
self.base_url = f'http://{host}:{port}'
self.test_group_id = f'test_{transport}_{int(time.time())}'
self.session = None
async def connect_sse(self) -> ClientSession:
"""Connect using SSE transport."""
print(f'🔌 Connecting to MCP server via SSE at {self.base_url}/sse')
# Use the sse_client to connect
async with sse_client(self.base_url + '/sse') as (read_stream, write_stream):
self.session = ClientSession(read_stream, write_stream)
await self.session.initialize()
return self.session
async def connect_http(self) -> ClientSession:
"""Connect using streaming HTTP transport."""
from mcp.client.http import http_client
print(f'🔌 Connecting to MCP server via HTTP at {self.base_url}')
# Use the http_client to connect
async with http_client(self.base_url) as (read_stream, write_stream):
self.session = ClientSession(read_stream, write_stream)
await self.session.initialize()
return self.session
async def test_list_tools(self) -> bool:
"""Test listing available tools."""
print('\n📋 Testing list_tools...')
try:
result = await self.session.list_tools()
tools = [tool.name for tool in result.tools]
expected_tools = [
'add_memory',
'search_memory_nodes',
'search_memory_facts',
'get_episodes',
'delete_episode',
'get_entity_edge',
'delete_entity_edge',
'clear_graph',
]
print(f' ✅ Found {len(tools)} tools')
for tool in tools[:5]: # Show first 5 tools
print(f' - {tool}')
# Check if we have most expected tools
found_tools = [t for t in expected_tools if t in tools]
success = len(found_tools) >= len(expected_tools) * 0.8
if success:
print(
f' ✅ Tool discovery successful ({len(found_tools)}/{len(expected_tools)} expected tools)'
)
else:
print(f' ❌ Missing too many tools ({len(found_tools)}/{len(expected_tools)})')
return success
except Exception as e:
print(f' ❌ Failed to list tools: {e}')
return False
async def test_add_memory(self) -> bool:
"""Test adding a memory."""
print('\n📝 Testing add_memory...')
try:
result = await self.session.call_tool(
'add_memory',
{
'name': 'Test Episode',
'episode_body': 'This is a test episode created by the MCP transport test suite.',
'group_id': self.test_group_id,
'source': 'text',
'source_description': 'Integration test',
},
)
# Check the result
if result.content:
content = result.content[0]
if hasattr(content, 'text'):
response = (
json.loads(content.text)
if content.text.startswith('{')
else {'message': content.text}
)
if 'success' in str(response).lower() or 'queued' in str(response).lower():
print(f' ✅ Memory added successfully: {response.get("message", "OK")}')
return True
else:
print(f' ❌ Unexpected response: {response}')
return False
print(' ❌ No content in response')
return False
except Exception as e:
print(f' ❌ Failed to add memory: {e}')
return False
async def test_search_nodes(self) -> bool:
"""Test searching for nodes."""
print('\n🔍 Testing search_memory_nodes...')
# Wait a bit for the memory to be processed
await asyncio.sleep(2)
try:
result = await self.session.call_tool(
'search_memory_nodes',
{'query': 'test episode', 'group_ids': [self.test_group_id], 'limit': 5},
)
if result.content:
content = result.content[0]
if hasattr(content, 'text'):
response = (
json.loads(content.text) if content.text.startswith('{') else {'nodes': []}
)
nodes = response.get('nodes', [])
print(f' ✅ Search returned {len(nodes)} nodes')
return True
print(' ⚠️ No nodes found (this may be expected if processing is async)')
return True # Don't fail on empty results
except Exception as e:
print(f' ❌ Failed to search nodes: {e}')
return False
async def test_get_episodes(self) -> bool:
"""Test getting episodes."""
print('\n📚 Testing get_episodes...')
try:
result = await self.session.call_tool(
'get_episodes', {'group_ids': [self.test_group_id], 'limit': 10}
)
if result.content:
content = result.content[0]
if hasattr(content, 'text'):
response = (
json.loads(content.text)
if content.text.startswith('{')
else {'episodes': []}
)
episodes = response.get('episodes', [])
print(f' ✅ Found {len(episodes)} episodes')
return True
print(' ⚠️ No episodes found')
return True
except Exception as e:
print(f' ❌ Failed to get episodes: {e}')
return False
async def test_clear_graph(self) -> bool:
"""Test clearing the graph."""
print('\n🧹 Testing clear_graph...')
try:
result = await self.session.call_tool('clear_graph', {'group_id': self.test_group_id})
if result.content:
content = result.content[0]
if hasattr(content, 'text'):
response = content.text
if 'success' in response.lower() or 'cleared' in response.lower():
print(' ✅ Graph cleared successfully')
return True
print(' ❌ Failed to clear graph')
return False
except Exception as e:
print(f' ❌ Failed to clear graph: {e}')
return False
async def run_tests(self) -> bool:
"""Run all tests for the configured transport."""
print(f'\n{"=" * 60}')
print(f'🚀 Testing MCP Server with {self.transport.upper()} transport')
print(f' Server: {self.base_url}')
print(f' Test Group: {self.test_group_id}')
print('=' * 60)
try:
# Connect based on transport type
if self.transport == 'sse':
await self.connect_sse()
elif self.transport == 'http':
await self.connect_http()
else:
print(f'❌ Unknown transport: {self.transport}')
return False
print(f'✅ Connected via {self.transport.upper()}')
# Run tests
results = []
results.append(await self.test_list_tools())
results.append(await self.test_add_memory())
results.append(await self.test_search_nodes())
results.append(await self.test_get_episodes())
results.append(await self.test_clear_graph())
# Summary
passed = sum(results)
total = len(results)
success = passed == total
print(f'\n{"=" * 60}')
print(f'📊 Results for {self.transport.upper()} transport:')
print(f' Passed: {passed}/{total}')
print(f' Status: {"✅ ALL TESTS PASSED" if success else "❌ SOME TESTS FAILED"}')
print('=' * 60)
return success
except Exception as e:
print(f'❌ Test suite failed: {e}')
return False
        finally:
            # Close any streams and sessions opened on the exit stack
            await self._exit_stack.aclose()
async def main():
"""Run tests for both transports."""
# Parse command line arguments
transport = sys.argv[1] if len(sys.argv) > 1 else 'sse'
host = sys.argv[2] if len(sys.argv) > 2 else 'localhost'
port = int(sys.argv[3]) if len(sys.argv) > 3 else 8000
# Create tester
tester = MCPTransportTester(transport, host, port)
# Run tests
success = await tester.run_tests()
# Exit with appropriate code
exit(0 if success else 1)
if __name__ == '__main__':
asyncio.run(main())
```
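The script selects a single transport from `sys.argv` (for example, `python test_mcp_transports.py sse localhost 8000`). A small wrapper can exercise both transports in one run; the following is a minimal sketch, assuming the module is importable as `test_mcp_transports` and that an MCP server is already listening on the given host and port.

```python
# Minimal sketch (not part of the test suite): run both transports back to back.
# Assumes test_mcp_transports.py is on the import path and a server is running.
import asyncio
import sys

from test_mcp_transports import MCPTransportTester


async def run_all(host: str = 'localhost', port: int = 8000) -> bool:
    outcomes: dict[str, bool] = {}
    for transport in ('sse', 'http'):
        tester = MCPTransportTester(transport=transport, host=host, port=port)
        outcomes[transport] = await tester.run_tests()
    print(f'\nOverall results: {outcomes}')
    return all(outcomes.values())


if __name__ == '__main__':
    sys.exit(0 if asyncio.run(run_all()) else 1)
```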
--------------------------------------------------------------------------------
/examples/quickstart/quickstart_neptune.py:
--------------------------------------------------------------------------------
```python
"""
Copyright 2025, Zep Software, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import asyncio
import json
import logging
import os
from datetime import datetime, timezone
from logging import INFO
from dotenv import load_dotenv
from graphiti_core import Graphiti
from graphiti_core.driver.neptune_driver import NeptuneDriver
from graphiti_core.nodes import EpisodeType
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
#################################################
# CONFIGURATION
#################################################
# Set up logging and environment variables for
# connecting to Neptune database
#################################################
# Configure logging
logging.basicConfig(
level=INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
)
logger = logging.getLogger(__name__)
load_dotenv()
# Neptune and OpenSearch connection parameters
neptune_uri = os.environ.get('NEPTUNE_HOST')
neptune_port = int(os.environ.get('NEPTUNE_PORT', 8182))
aoss_host = os.environ.get('AOSS_HOST')
if not neptune_uri:
raise ValueError('NEPTUNE_HOST must be set')
if not aoss_host:
raise ValueError('AOSS_HOST must be set')
async def main():
#################################################
# INITIALIZATION
#################################################
# Connect to Neptune and set up Graphiti indices
# This is required before using other Graphiti
# functionality
#################################################
# Initialize Graphiti with Neptune connection
driver = NeptuneDriver(host=neptune_uri, aoss_host=aoss_host, port=neptune_port)
graphiti = Graphiti(graph_driver=driver)
try:
        # Clear any existing OpenSearch indices and graph data, then build
        # Graphiti's indices and constraints. The index setup only needs to be
        # done once.
        await driver.delete_aoss_indices()
        await driver._delete_all_data()
        await graphiti.build_indices_and_constraints()
#################################################
# ADDING EPISODES
#################################################
# Episodes are the primary units of information
# in Graphiti. They can be text or structured JSON
# and are automatically processed to extract entities
# and relationships.
#################################################
# Example: Add Episodes
# Episodes list containing both text and JSON episodes
episodes = [
{
'content': 'Kamala Harris is the Attorney General of California. She was previously '
'the district attorney for San Francisco.',
'type': EpisodeType.text,
'description': 'podcast transcript',
},
{
'content': 'As AG, Harris was in office from January 3, 2011 – January 3, 2017',
'type': EpisodeType.text,
'description': 'podcast transcript',
},
{
'content': {
'name': 'Gavin Newsom',
'position': 'Governor',
'state': 'California',
'previous_role': 'Lieutenant Governor',
'previous_location': 'San Francisco',
},
'type': EpisodeType.json,
'description': 'podcast metadata',
},
{
'content': {
'name': 'Gavin Newsom',
'position': 'Governor',
'term_start': 'January 7, 2019',
'term_end': 'Present',
},
'type': EpisodeType.json,
'description': 'podcast metadata',
},
]
# Add episodes to the graph
for i, episode in enumerate(episodes):
await graphiti.add_episode(
name=f'Freakonomics Radio {i}',
episode_body=episode['content']
if isinstance(episode['content'], str)
else json.dumps(episode['content']),
source=episode['type'],
source_description=episode['description'],
reference_time=datetime.now(timezone.utc),
)
print(f'Added episode: Freakonomics Radio {i} ({episode["type"].value})')
        # Optionally cluster related entities into community nodes with summaries
        await graphiti.build_communities()
#################################################
# BASIC SEARCH
#################################################
# The simplest way to retrieve relationships (edges)
# from Graphiti is using the search method, which
# performs a hybrid search combining semantic
# similarity and BM25 text retrieval.
#################################################
# Perform a hybrid search combining semantic similarity and BM25 retrieval
print("\nSearching for: 'Who was the California Attorney General?'")
results = await graphiti.search('Who was the California Attorney General?')
# Print search results
print('\nSearch Results:')
for result in results:
print(f'UUID: {result.uuid}')
print(f'Fact: {result.fact}')
if hasattr(result, 'valid_at') and result.valid_at:
print(f'Valid from: {result.valid_at}')
if hasattr(result, 'invalid_at') and result.invalid_at:
print(f'Valid until: {result.invalid_at}')
print('---')
#################################################
# CENTER NODE SEARCH
#################################################
# For more contextually relevant results, you can
# use a center node to rerank search results based
# on their graph distance to a specific node
#################################################
# Use the top search result's UUID as the center node for reranking
if results and len(results) > 0:
# Get the source node UUID from the top result
center_node_uuid = results[0].source_node_uuid
print('\nReranking search results based on graph distance:')
print(f'Using center node UUID: {center_node_uuid}')
reranked_results = await graphiti.search(
'Who was the California Attorney General?', center_node_uuid=center_node_uuid
)
# Print reranked search results
print('\nReranked Search Results:')
for result in reranked_results:
print(f'UUID: {result.uuid}')
print(f'Fact: {result.fact}')
if hasattr(result, 'valid_at') and result.valid_at:
print(f'Valid from: {result.valid_at}')
if hasattr(result, 'invalid_at') and result.invalid_at:
print(f'Valid until: {result.invalid_at}')
print('---')
else:
print('No results found in the initial search to use as center node.')
#################################################
# NODE SEARCH USING SEARCH RECIPES
#################################################
# Graphiti provides predefined search recipes
# optimized for different search scenarios.
# Here we use NODE_HYBRID_SEARCH_RRF for retrieving
# nodes directly instead of edges.
#################################################
# Example: Perform a node search using _search method with standard recipes
print(
'\nPerforming node search using _search method with standard recipe NODE_HYBRID_SEARCH_RRF:'
)
# Use a predefined search configuration recipe and modify its limit
node_search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
node_search_config.limit = 5 # Limit to 5 results
# Execute the node search
node_search_results = await graphiti._search(
query='California Governor',
config=node_search_config,
)
# Print node search results
print('\nNode Search Results:')
for node in node_search_results.nodes:
print(f'Node UUID: {node.uuid}')
print(f'Node Name: {node.name}')
node_summary = node.summary[:100] + '...' if len(node.summary) > 100 else node.summary
print(f'Content Summary: {node_summary}')
print(f'Node Labels: {", ".join(node.labels)}')
print(f'Created At: {node.created_at}')
if hasattr(node, 'attributes') and node.attributes:
print('Attributes:')
for key, value in node.attributes.items():
print(f' {key}: {value}')
print('---')
finally:
#################################################
# CLEANUP
#################################################
# Always close the connection to Neptune when
# finished to properly release resources
#################################################
# Close the connection
await graphiti.close()
print('\nConnection closed')
if __name__ == '__main__':
asyncio.run(main())
```
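The recipe-based node search above follows a small, repeatable pattern: copy a recipe with `model_copy(deep=True)`, adjust its `limit`, and pass it to `_search`. Below is a minimal sketch of wrapping that pattern in a reusable helper, assuming a `Graphiti` instance configured as in the quickstart; the `top_nodes` name is illustrative.

```python
# Minimal sketch (not part of the quickstart): reusable recipe-based node search.
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF


async def top_nodes(graphiti, query: str, limit: int = 5) -> list[tuple[str, str]]:
    config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
    config.limit = limit
    results = await graphiti._search(query=query, config=config)
    # Return (name, summary) pairs for the matched entity nodes
    return [(node.name, node.summary) for node in results.nodes]


# Example usage inside main(), after the episodes have been added:
# for name, summary in await top_nodes(graphiti, 'California Governor', limit=3):
#     print(f'{name}: {summary[:80]}')
```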