This is page 4 of 12. Use http://codebase.md/getzep/graphiti?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .env.example
├── .github
│ ├── dependabot.yml
│ ├── ISSUE_TEMPLATE
│ │ └── bug_report.md
│ ├── pull_request_template.md
│ ├── secret_scanning.yml
│ └── workflows
│ ├── ai-moderator.yml
│ ├── cla.yml
│ ├── claude-code-review-manual.yml
│ ├── claude-code-review.yml
│ ├── claude.yml
│ ├── codeql.yml
│ ├── daily_issue_maintenance.yml
│ ├── issue-triage.yml
│ ├── lint.yml
│ ├── release-graphiti-core.yml
│ ├── release-mcp-server.yml
│ ├── release-server-container.yml
│ ├── typecheck.yml
│ └── unit_tests.yml
├── .gitignore
├── AGENTS.md
├── CLAUDE.md
├── CODE_OF_CONDUCT.md
├── conftest.py
├── CONTRIBUTING.md
├── depot.json
├── docker-compose.test.yml
├── docker-compose.yml
├── Dockerfile
├── ellipsis.yaml
├── examples
│ ├── azure-openai
│ │ ├── .env.example
│ │ ├── azure_openai_neo4j.py
│ │ └── README.md
│ ├── data
│ │ └── manybirds_products.json
│ ├── ecommerce
│ │ ├── runner.ipynb
│ │ └── runner.py
│ ├── langgraph-agent
│ │ ├── agent.ipynb
│ │ └── tinybirds-jess.png
│ ├── opentelemetry
│ │ ├── .env.example
│ │ ├── otel_stdout_example.py
│ │ ├── pyproject.toml
│ │ ├── README.md
│ │ └── uv.lock
│ ├── podcast
│ │ ├── podcast_runner.py
│ │ ├── podcast_transcript.txt
│ │ └── transcript_parser.py
│ ├── quickstart
│ │ ├── quickstart_falkordb.py
│ │ ├── quickstart_neo4j.py
│ │ ├── quickstart_neptune.py
│ │ ├── README.md
│ │ └── requirements.txt
│ └── wizard_of_oz
│ ├── parser.py
│ ├── runner.py
│ └── woo.txt
├── graphiti_core
│ ├── __init__.py
│ ├── cross_encoder
│ │ ├── __init__.py
│ │ ├── bge_reranker_client.py
│ │ ├── client.py
│ │ ├── gemini_reranker_client.py
│ │ └── openai_reranker_client.py
│ ├── decorators.py
│ ├── driver
│ │ ├── __init__.py
│ │ ├── driver.py
│ │ ├── falkordb_driver.py
│ │ ├── graph_operations
│ │ │ └── graph_operations.py
│ │ ├── kuzu_driver.py
│ │ ├── neo4j_driver.py
│ │ ├── neptune_driver.py
│ │ └── search_interface
│ │ └── search_interface.py
│ ├── edges.py
│ ├── embedder
│ │ ├── __init__.py
│ │ ├── azure_openai.py
│ │ ├── client.py
│ │ ├── gemini.py
│ │ ├── openai.py
│ │ └── voyage.py
│ ├── errors.py
│ ├── graph_queries.py
│ ├── graphiti_types.py
│ ├── graphiti.py
│ ├── helpers.py
│ ├── llm_client
│ │ ├── __init__.py
│ │ ├── anthropic_client.py
│ │ ├── azure_openai_client.py
│ │ ├── client.py
│ │ ├── config.py
│ │ ├── errors.py
│ │ ├── gemini_client.py
│ │ ├── groq_client.py
│ │ ├── openai_base_client.py
│ │ ├── openai_client.py
│ │ ├── openai_generic_client.py
│ │ └── utils.py
│ ├── migrations
│ │ └── __init__.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── edges
│ │ │ ├── __init__.py
│ │ │ └── edge_db_queries.py
│ │ └── nodes
│ │ ├── __init__.py
│ │ └── node_db_queries.py
│ ├── nodes.py
│ ├── prompts
│ │ ├── __init__.py
│ │ ├── dedupe_edges.py
│ │ ├── dedupe_nodes.py
│ │ ├── eval.py
│ │ ├── extract_edge_dates.py
│ │ ├── extract_edges.py
│ │ ├── extract_nodes.py
│ │ ├── invalidate_edges.py
│ │ ├── lib.py
│ │ ├── models.py
│ │ ├── prompt_helpers.py
│ │ ├── snippets.py
│ │ └── summarize_nodes.py
│ ├── py.typed
│ ├── search
│ │ ├── __init__.py
│ │ ├── search_config_recipes.py
│ │ ├── search_config.py
│ │ ├── search_filters.py
│ │ ├── search_helpers.py
│ │ ├── search_utils.py
│ │ └── search.py
│ ├── telemetry
│ │ ├── __init__.py
│ │ └── telemetry.py
│ ├── tracer.py
│ └── utils
│ ├── __init__.py
│ ├── bulk_utils.py
│ ├── datetime_utils.py
│ ├── maintenance
│ │ ├── __init__.py
│ │ ├── community_operations.py
│ │ ├── dedup_helpers.py
│ │ ├── edge_operations.py
│ │ ├── graph_data_operations.py
│ │ ├── node_operations.py
│ │ └── temporal_operations.py
│ ├── ontology_utils
│ │ └── entity_types_utils.py
│ └── text_utils.py
├── images
│ ├── arxiv-screenshot.png
│ ├── graphiti-graph-intro.gif
│ ├── graphiti-intro-slides-stock-2.gif
│ └── simple_graph.svg
├── LICENSE
├── Makefile
├── mcp_server
│ ├── .env.example
│ ├── .python-version
│ ├── config
│ │ ├── config-docker-falkordb-combined.yaml
│ │ ├── config-docker-falkordb.yaml
│ │ ├── config-docker-neo4j.yaml
│ │ ├── config.yaml
│ │ └── mcp_config_stdio_example.json
│ ├── docker
│ │ ├── build-standalone.sh
│ │ ├── build-with-version.sh
│ │ ├── docker-compose-falkordb.yml
│ │ ├── docker-compose-neo4j.yml
│ │ ├── docker-compose.yml
│ │ ├── Dockerfile
│ │ ├── Dockerfile.standalone
│ │ ├── github-actions-example.yml
│ │ ├── README-falkordb-combined.md
│ │ └── README.md
│ ├── docs
│ │ └── cursor_rules.md
│ ├── main.py
│ ├── pyproject.toml
│ ├── pytest.ini
│ ├── README.md
│ ├── src
│ │ ├── __init__.py
│ │ ├── config
│ │ │ ├── __init__.py
│ │ │ └── schema.py
│ │ ├── graphiti_mcp_server.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ ├── entity_types.py
│ │ │ └── response_types.py
│ │ ├── services
│ │ │ ├── __init__.py
│ │ │ ├── factories.py
│ │ │ └── queue_service.py
│ │ └── utils
│ │ ├── __init__.py
│ │ ├── formatting.py
│ │ └── utils.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── pytest.ini
│ │ ├── README.md
│ │ ├── run_tests.py
│ │ ├── test_async_operations.py
│ │ ├── test_comprehensive_integration.py
│ │ ├── test_configuration.py
│ │ ├── test_falkordb_integration.py
│ │ ├── test_fixtures.py
│ │ ├── test_http_integration.py
│ │ ├── test_integration.py
│ │ ├── test_mcp_integration.py
│ │ ├── test_mcp_transports.py
│ │ ├── test_stdio_simple.py
│ │ └── test_stress_load.py
│ └── uv.lock
├── OTEL_TRACING.md
├── py.typed
├── pyproject.toml
├── pytest.ini
├── README.md
├── SECURITY.md
├── server
│ ├── .env.example
│ ├── graph_service
│ │ ├── __init__.py
│ │ ├── config.py
│ │ ├── dto
│ │ │ ├── __init__.py
│ │ │ ├── common.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ ├── main.py
│ │ ├── routers
│ │ │ ├── __init__.py
│ │ │ ├── ingest.py
│ │ │ └── retrieve.py
│ │ └── zep_graphiti.py
│ ├── Makefile
│ ├── pyproject.toml
│ ├── README.md
│ └── uv.lock
├── signatures
│ └── version1
│ └── cla.json
├── tests
│ ├── cross_encoder
│ │ ├── test_bge_reranker_client_int.py
│ │ └── test_gemini_reranker_client.py
│ ├── driver
│ │ ├── __init__.py
│ │ └── test_falkordb_driver.py
│ ├── embedder
│ │ ├── embedder_fixtures.py
│ │ ├── test_gemini.py
│ │ ├── test_openai.py
│ │ └── test_voyage.py
│ ├── evals
│ │ ├── data
│ │ │ └── longmemeval_data
│ │ │ ├── longmemeval_oracle.json
│ │ │ └── README.md
│ │ ├── eval_cli.py
│ │ ├── eval_e2e_graph_building.py
│ │ ├── pytest.ini
│ │ └── utils.py
│ ├── helpers_test.py
│ ├── llm_client
│ │ ├── test_anthropic_client_int.py
│ │ ├── test_anthropic_client.py
│ │ ├── test_azure_openai_client.py
│ │ ├── test_client.py
│ │ ├── test_errors.py
│ │ └── test_gemini_client.py
│ ├── test_edge_int.py
│ ├── test_entity_exclusion_int.py
│ ├── test_graphiti_int.py
│ ├── test_graphiti_mock.py
│ ├── test_node_int.py
│ ├── test_text_utils.py
│ └── utils
│ ├── maintenance
│ │ ├── test_bulk_utils.py
│ │ ├── test_edge_operations.py
│ │ ├── test_node_operations.py
│ │ └── test_temporal_operations_int.py
│ └── search
│ └── search_utils_test.py
├── uv.lock
└── Zep-CLA.md
```
# Files
--------------------------------------------------------------------------------
/graphiti_core/prompts/extract_edges.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from typing import Any, Protocol, TypedDict
18 |
19 | from pydantic import BaseModel, Field
20 |
21 | from .models import Message, PromptFunction, PromptVersion
22 | from .prompt_helpers import to_prompt_json
23 |
24 |
25 | class Edge(BaseModel):
26 | relation_type: str = Field(..., description='FACT_PREDICATE_IN_SCREAMING_SNAKE_CASE')
27 | source_entity_id: int = Field(
28 | ..., description='The id of the source entity from the ENTITIES list'
29 | )
30 | target_entity_id: int = Field(
31 | ..., description='The id of the target entity from the ENTITIES list'
32 | )
33 | fact: str = Field(
34 | ...,
35 | description='A natural language description of the relationship between the entities, paraphrased from the source text',
36 | )
37 | valid_at: str | None = Field(
38 | None,
39 | description='The date and time when the relationship described by the edge fact became true or was established. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SS.SSSSSSZ)',
40 | )
41 | invalid_at: str | None = Field(
42 | None,
43 | description='The date and time when the relationship described by the edge fact stopped being true or ended. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SS.SSSSSSZ)',
44 | )
45 |
46 |
47 | class ExtractedEdges(BaseModel):
48 | edges: list[Edge]
49 |
50 |
51 | class MissingFacts(BaseModel):
52 | missing_facts: list[str] = Field(..., description="facts that weren't extracted")
53 |
54 |
55 | class Prompt(Protocol):
56 | edge: PromptVersion
57 | reflexion: PromptVersion
58 | extract_attributes: PromptVersion
59 |
60 |
61 | class Versions(TypedDict):
62 | edge: PromptFunction
63 | reflexion: PromptFunction
64 | extract_attributes: PromptFunction
65 |
66 |
67 | def edge(context: dict[str, Any]) -> list[Message]:
68 | return [
69 | Message(
70 | role='system',
71 | content='You are an expert fact extractor that extracts fact triples from text. '
 72 |             '1. Extracted fact triples should also be extracted with relevant date information. '
 73 |             '2. Treat the CURRENT TIME as the time the CURRENT MESSAGE was sent. All temporal information should be extracted relative to this time.',
74 | ),
75 | Message(
76 | role='user',
77 | content=f"""
78 | <FACT TYPES>
79 | {context['edge_types']}
80 | </FACT TYPES>
81 |
82 | <PREVIOUS_MESSAGES>
83 | {to_prompt_json([ep for ep in context['previous_episodes']])}
84 | </PREVIOUS_MESSAGES>
85 |
86 | <CURRENT_MESSAGE>
87 | {context['episode_content']}
88 | </CURRENT_MESSAGE>
89 |
90 | <ENTITIES>
91 | {to_prompt_json(context['nodes'])}
92 | </ENTITIES>
93 |
94 | <REFERENCE_TIME>
95 | {context['reference_time']} # ISO 8601 (UTC); used to resolve relative time mentions
96 | </REFERENCE_TIME>
97 |
98 | # TASK
99 | Extract all factual relationships between the given ENTITIES based on the CURRENT MESSAGE.
100 | Only extract facts that:
101 | - involve two DISTINCT ENTITIES from the ENTITIES list,
102 | - are clearly stated or unambiguously implied in the CURRENT MESSAGE,
103 | and can be represented as edges in a knowledge graph.
104 | - Facts should include entity names rather than pronouns whenever possible.
105 | - The FACT TYPES provide a list of the most important types of facts; make sure to extract facts of these types.
106 | - The FACT TYPES are not an exhaustive list; extract all facts from the message even if they do not fit into one
107 |   of the FACT TYPES.
108 | - Each FACT TYPE contains a fact_type_signature, which represents the source and target entity types.
109 |
110 | You may use information from the PREVIOUS MESSAGES only to disambiguate references or support continuity.
111 |
112 |
113 | {context['custom_prompt']}
114 |
115 | # EXTRACTION RULES
116 |
117 | 1. **Entity ID Validation**: `source_entity_id` and `target_entity_id` must use only the `id` values from the ENTITIES list provided above.
118 | - **CRITICAL**: Using IDs not in the list will cause the edge to be rejected
119 | 2. Each fact must involve two **distinct** entities.
120 | 3. Use a SCREAMING_SNAKE_CASE string as the `relation_type` (e.g., FOUNDED, WORKS_AT).
121 | 4. Do not emit duplicate or semantically redundant facts.
122 | 5. The `fact` should closely paraphrase the original source sentence(s). Do not verbatim quote the original text.
123 | 6. Use `REFERENCE_TIME` to resolve vague or relative temporal expressions (e.g., "last week").
124 | 7. Do **not** hallucinate or infer temporal bounds from unrelated events.
125 |
126 | # DATETIME RULES
127 |
128 | - Use ISO 8601 with “Z” suffix (UTC) (e.g., 2025-04-30T00:00:00Z).
129 | - If the fact is ongoing (present tense), set `valid_at` to REFERENCE_TIME.
130 | - If a change/termination is expressed, set `invalid_at` to the relevant timestamp.
131 | - Leave both fields `null` if no explicit or resolvable time is stated.
132 | - If only a date is mentioned (no time), assume 00:00:00.
133 | - If only a year is mentioned, use January 1st at 00:00:00.
134 | """,
135 | ),
136 | ]
137 |
138 |
139 | def reflexion(context: dict[str, Any]) -> list[Message]:
140 | sys_prompt = """You are an AI assistant that determines which facts have not been extracted from the given context"""
141 |
142 | user_prompt = f"""
143 | <PREVIOUS MESSAGES>
144 | {to_prompt_json([ep for ep in context['previous_episodes']])}
145 | </PREVIOUS MESSAGES>
146 | <CURRENT MESSAGE>
147 | {context['episode_content']}
148 | </CURRENT MESSAGE>
149 |
150 | <EXTRACTED ENTITIES>
151 | {context['nodes']}
152 | </EXTRACTED ENTITIES>
153 |
154 | <EXTRACTED FACTS>
155 | {context['extracted_facts']}
156 | </EXTRACTED FACTS>
157 |
158 | Given the above MESSAGES, the list of EXTRACTED ENTITIES, and the list of EXTRACTED FACTS,
159 | determine if any facts haven't been extracted.
160 | """
161 | return [
162 | Message(role='system', content=sys_prompt),
163 | Message(role='user', content=user_prompt),
164 | ]
165 |
166 |
167 | def extract_attributes(context: dict[str, Any]) -> list[Message]:
168 | return [
169 | Message(
170 | role='system',
171 | content='You are a helpful assistant that extracts fact properties from the provided text.',
172 | ),
173 | Message(
174 | role='user',
175 | content=f"""
176 |
177 | <MESSAGE>
178 | {to_prompt_json(context['episode_content'])}
179 | </MESSAGE>
180 | <REFERENCE TIME>
181 | {context['reference_time']}
182 | </REFERENCE TIME>
183 |
184 | Given the above MESSAGE, its REFERENCE TIME, and the following FACT, update any of the FACT's attributes based on the information provided
185 | in the MESSAGE. Use the provided attribute descriptions to better understand how each attribute should be determined.
186 |
187 | Guidelines:
188 | 1. Do not hallucinate entity property values if they cannot be found in the current context.
189 | 2. Only use the provided MESSAGE and FACT to set attribute values.
190 |
191 | <FACT>
192 | {context['fact']}
193 | </FACT>
194 | """,
195 | ),
196 | ]
197 |
198 |
199 | versions: Versions = {
200 | 'edge': edge,
201 | 'reflexion': reflexion,
202 | 'extract_attributes': extract_attributes,
203 | }
204 |
```
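The `versions` registry above maps prompt names to builder functions, so a caller renders a prompt by passing a context dict containing the keys the builder interpolates. A minimal sketch, assuming graphiti-core is installed; the `edge_types` and `nodes` payloads here are illustrative placeholders chosen for this example, not a documented schema:

```python
# Sketch: render the edge-extraction prompt from this module directly.
# The context keys mirror those referenced in edge() above.
from datetime import datetime, timezone

from graphiti_core.prompts.extract_edges import versions

context = {
    'edge_types': [{'fact_type_name': 'WORKS_AT', 'fact_type_signature': ('Person', 'Company')}],
    'previous_episodes': [],
    'episode_content': 'Alice joined Acme Corp last month.',
    'nodes': [{'id': 0, 'name': 'Alice'}, {'id': 1, 'name': 'Acme Corp'}],
    'reference_time': datetime.now(timezone.utc).isoformat(),
    'custom_prompt': '',
}

# versions['edge'] returns [Message(role='system', ...), Message(role='user', ...)]
for message in versions['edge'](context):
    print(message.role.upper())
    print(message.content[:200], '...')
```

The same pattern applies to `versions['reflexion']` and `versions['extract_attributes']`, each with its own expected context keys.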
--------------------------------------------------------------------------------
/graphiti_core/embedder/gemini.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import logging
18 | from collections.abc import Iterable
19 | from typing import TYPE_CHECKING
20 |
21 | if TYPE_CHECKING:
22 | from google import genai
23 | from google.genai import types
24 | else:
25 | try:
26 | from google import genai
27 | from google.genai import types
28 | except ImportError:
29 | raise ImportError(
30 | 'google-genai is required for GeminiEmbedder. '
31 | 'Install it with: pip install graphiti-core[google-genai]'
32 | ) from None
33 |
34 | from pydantic import Field
35 |
36 | from .client import EmbedderClient, EmbedderConfig
37 |
38 | logger = logging.getLogger(__name__)
39 |
40 | DEFAULT_EMBEDDING_MODEL = 'text-embedding-001' # gemini-embedding-001 or text-embedding-005
41 |
42 | DEFAULT_BATCH_SIZE = 100
43 |
44 |
45 | class GeminiEmbedderConfig(EmbedderConfig):
46 | embedding_model: str = Field(default=DEFAULT_EMBEDDING_MODEL)
47 | api_key: str | None = None
48 |
49 |
50 | class GeminiEmbedder(EmbedderClient):
51 | """
52 | Google Gemini Embedder Client
53 | """
54 |
55 | def __init__(
56 | self,
57 | config: GeminiEmbedderConfig | None = None,
58 | client: 'genai.Client | None' = None,
59 | batch_size: int | None = None,
60 | ):
61 | """
62 | Initialize the GeminiEmbedder with the provided configuration and client.
63 |
64 | Args:
 65 |             config (GeminiEmbedderConfig | None): The configuration for the GeminiEmbedder, including the API key, embedding model, and embedding dimension.
 66 |             client (genai.Client | None): An optional genai.Client instance to use. If not provided, a new client is created.
67 | batch_size (int | None): An optional batch size to use. If not provided, the default batch size will be used.
68 | """
69 | if config is None:
70 | config = GeminiEmbedderConfig()
71 |
72 | self.config = config
73 |
74 | if client is None:
75 | self.client = genai.Client(api_key=config.api_key)
76 | else:
77 | self.client = client
78 |
79 | if batch_size is None and self.config.embedding_model == 'gemini-embedding-001':
80 | # Gemini API has a limit on the number of instances per request
81 | # https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api
82 | self.batch_size = 1
83 | elif batch_size is None:
84 | self.batch_size = DEFAULT_BATCH_SIZE
85 | else:
86 | self.batch_size = batch_size
87 |
88 | async def create(
89 | self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
90 | ) -> list[float]:
91 | """
92 | Create embeddings for the given input data using Google's Gemini embedding model.
93 |
94 | Args:
95 | input_data: The input data to create embeddings for. Can be a string, list of strings,
96 | or an iterable of integers or iterables of integers.
97 |
98 | Returns:
99 | A list of floats representing the embedding vector.
100 | """
101 | # Generate embeddings
102 | result = await self.client.aio.models.embed_content(
103 | model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
104 | contents=[input_data], # type: ignore[arg-type] # mypy fails on broad union type
105 | config=types.EmbedContentConfig(output_dimensionality=self.config.embedding_dim),
106 | )
107 |
108 | if not result.embeddings or len(result.embeddings) == 0 or not result.embeddings[0].values:
109 | raise ValueError('No embeddings returned from Gemini API in create()')
110 |
111 | return result.embeddings[0].values
112 |
113 | async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
114 | """
115 | Create embeddings for a batch of input data using Google's Gemini embedding model.
116 |
117 | This method handles batching to respect the Gemini API's limits on the number
118 | of instances that can be processed in a single request.
119 |
120 | Args:
121 | input_data_list: A list of strings to create embeddings for.
122 |
123 | Returns:
124 | A list of embedding vectors (each vector is a list of floats).
125 | """
126 | if not input_data_list:
127 | return []
128 |
129 | batch_size = self.batch_size
130 | all_embeddings = []
131 |
132 | # Process inputs in batches
133 | for i in range(0, len(input_data_list), batch_size):
134 | batch = input_data_list[i : i + batch_size]
135 |
136 | try:
137 | # Generate embeddings for this batch
138 | result = await self.client.aio.models.embed_content(
139 | model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
140 | contents=batch, # type: ignore[arg-type] # mypy fails on broad union type
141 | config=types.EmbedContentConfig(
142 | output_dimensionality=self.config.embedding_dim
143 | ),
144 | )
145 |
146 | if not result.embeddings or len(result.embeddings) == 0:
147 |                     raise ValueError('No embeddings returned from Gemini API')
148 |
149 | # Process embeddings from this batch
150 | for embedding in result.embeddings:
151 | if not embedding.values:
152 | raise ValueError('Empty embedding values returned')
153 | all_embeddings.append(embedding.values)
154 |
155 | except Exception as e:
156 | # If batch processing fails, fall back to individual processing
157 | logger.warning(
158 | f'Batch embedding failed for batch {i // batch_size + 1}, falling back to individual processing: {e}'
159 | )
160 |
161 | for item in batch:
162 | try:
163 | # Process each item individually
164 | result = await self.client.aio.models.embed_content(
165 | model=self.config.embedding_model or DEFAULT_EMBEDDING_MODEL,
166 | contents=[item], # type: ignore[arg-type] # mypy fails on broad union type
167 | config=types.EmbedContentConfig(
168 | output_dimensionality=self.config.embedding_dim
169 | ),
170 | )
171 |
172 | if not result.embeddings or len(result.embeddings) == 0:
173 | raise ValueError('No embeddings returned from Gemini API')
174 | if not result.embeddings[0].values:
175 | raise ValueError('Empty embedding values returned')
176 |
177 | all_embeddings.append(result.embeddings[0].values)
178 |
179 | except Exception as individual_error:
180 | logger.error(f'Failed to embed individual item: {individual_error}')
181 | raise individual_error
182 |
183 | return all_embeddings
184 |
```
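A minimal usage sketch for the embedder above, assuming `google-genai` is installed and a Gemini API key is supplied via an environment variable (the variable name is this example's choice; the config simply takes `api_key` directly):

```python
# Sketch: embedding strings with GeminiEmbedder, as defined above.
import asyncio
import os

from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig


async def main() -> None:
    embedder = GeminiEmbedder(
        config=GeminiEmbedderConfig(api_key=os.environ['GOOGLE_API_KEY']),
        # batch_size is forced to 1 automatically when the model is 'gemini-embedding-001'
    )

    vector = await embedder.create('graph databases store relationships')
    print(len(vector))  # embedding dimensionality

    # One vector per input; on a batch error this falls back to per-item requests
    vectors = await embedder.create_batch(['first text', 'second text'])
    print(len(vectors))


asyncio.run(main())
```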
--------------------------------------------------------------------------------
/mcp_server/docker/README-falkordb-combined.md:
--------------------------------------------------------------------------------
```markdown
1 | # FalkorDB + Graphiti MCP Server Combined Image
2 |
3 | This Docker setup bundles FalkorDB (graph database) and the Graphiti MCP Server into a single container image for simplified deployment.
4 |
5 | ## Overview
6 |
7 | The combined image extends the official FalkorDB Docker image to include:
8 | - **FalkorDB**: Redis-based graph database running on port 6379
9 | - **FalkorDB Web UI**: Graph visualization interface on port 3000
10 | - **Graphiti MCP Server**: Knowledge graph API on port 8000
11 |
12 | All of these services are managed by a startup script that launches FalkorDB as a daemon, optionally starts the FalkorDB Browser, and runs the MCP server in the foreground.
13 |
14 | ## Quick Start
15 |
16 | ### Using Docker Compose (Recommended)
17 |
18 | 1. Create a `.env` file in the `mcp_server` directory:
19 |
20 | ```bash
21 | # Required
22 | OPENAI_API_KEY=your_openai_api_key
23 |
24 | # Optional
25 | GRAPHITI_GROUP_ID=main
26 | SEMAPHORE_LIMIT=10
27 | FALKORDB_PASSWORD=
28 | ```
29 |
30 | 2. Start the combined service:
31 |
32 | ```bash
33 | cd mcp_server
34 | docker compose -f docker/docker-compose-falkordb-combined.yml up
35 | ```
36 |
37 | 3. Access the services:
38 | - MCP Server: http://localhost:8000/mcp/
39 | - FalkorDB Web UI: http://localhost:3000
40 | - FalkorDB (Redis): localhost:6379
41 |
42 | ### Using Docker Run
43 |
44 | ```bash
45 | docker run -d \
46 | -p 6379:6379 \
47 | -p 3000:3000 \
48 | -p 8000:8000 \
49 | -e OPENAI_API_KEY=your_key \
50 | -e GRAPHITI_GROUP_ID=main \
51 | -v falkordb_data:/var/lib/falkordb/data \
52 | zepai/graphiti-falkordb:latest
53 | ```
54 |
55 | ## Building the Image
56 |
57 | ### Build with Default Version
58 |
59 | ```bash
60 | docker compose -f docker/docker-compose-falkordb-combined.yml build
61 | ```
62 |
63 | ### Build with Specific Graphiti Version
64 |
65 | ```bash
66 | GRAPHITI_CORE_VERSION=0.22.0 docker compose -f docker/docker-compose-falkordb-combined.yml build
67 | ```
68 |
69 | ### Build Arguments
70 |
71 | - `GRAPHITI_CORE_VERSION`: Version of graphiti-core package (default: 0.22.0)
72 | - `MCP_SERVER_VERSION`: MCP server version tag (default: 1.0.0rc0)
73 | - `BUILD_DATE`: Build timestamp
74 | - `VCS_REF`: Git commit hash
75 |
76 | ## Configuration
77 |
78 | ### Environment Variables
79 |
80 | All environment variables from the standard MCP server are supported:
81 |
82 | **Required:**
83 | - `OPENAI_API_KEY`: OpenAI API key for LLM operations
84 |
85 | **Optional:**
86 | - `BROWSER`: Enable FalkorDB Browser web UI on port 3000 (default: "1", set to "0" to disable)
87 | - `GRAPHITI_GROUP_ID`: Namespace for graph data (default: "main")
88 | - `SEMAPHORE_LIMIT`: Concurrency limit for episode processing (default: 10)
89 | - `FALKORDB_PASSWORD`: Password for FalkorDB (optional)
90 | - `FALKORDB_DATABASE`: FalkorDB database name (default: "default_db")
91 |
92 | **Other LLM Providers:**
93 | - `ANTHROPIC_API_KEY`: For Claude models
94 | - `GOOGLE_API_KEY`: For Gemini models
95 | - `GROQ_API_KEY`: For Groq models
96 |
97 | ### Volumes
98 |
99 | - `/var/lib/falkordb/data`: Persistent storage for graph data
100 | - `/var/log/graphiti`: MCP server and FalkorDB Browser logs
101 |
102 | ## Service Management
103 |
104 | ### View Logs
105 |
106 | ```bash
107 | # All logs (both services stdout/stderr)
108 | docker compose -f docker/docker-compose-falkordb-combined.yml logs -f
109 |
110 | # Only container logs
111 | docker compose -f docker/docker-compose-falkordb-combined.yml logs -f graphiti-falkordb
112 | ```
113 |
114 | ### Restart Services
115 |
116 | ```bash
117 | # Restart entire container (both services)
118 | docker compose -f docker/docker-compose-falkordb-combined.yml restart
119 |
120 | # Check FalkorDB status
121 | docker compose -f docker/docker-compose-falkordb-combined.yml exec graphiti-falkordb redis-cli ping
122 |
123 | # Check MCP server status
124 | curl http://localhost:8000/health
125 | ```
126 |
127 | ### Disabling the FalkorDB Browser
128 |
129 | To disable the FalkorDB Browser web UI (port 3000), set the `BROWSER` environment variable to `0`:
130 |
131 | ```bash
132 | # Using docker run
133 | docker run -d \
134 | -p 6379:6379 \
135 | -p 3000:3000 \
136 | -p 8000:8000 \
137 | -e BROWSER=0 \
138 | -e OPENAI_API_KEY=your_key \
139 | zepai/graphiti-falkordb:latest
140 |
141 | # Using docker-compose
142 | # Add to your .env file:
143 | BROWSER=0
144 | ```
145 |
146 | When disabled, only FalkorDB (port 6379) and the MCP server (port 8000) will run.
147 |
148 | ## Health Checks
149 |
150 | The container includes a health check that verifies:
151 | 1. FalkorDB is responding to ping
152 | 2. MCP server health endpoint is accessible
153 |
154 | Check health status:
155 | ```bash
156 | docker compose -f docker/docker-compose-falkordb-combined.yml ps
157 | ```
158 |
159 | ## Architecture
160 |
161 | ### Process Structure
162 | ```
163 | start-services.sh (PID 1)
164 | ├── redis-server (FalkorDB daemon)
165 | ├── node server.js (FalkorDB Browser - background, if BROWSER=1)
166 | └── uv run main.py (MCP server - foreground)
167 | ```
168 |
169 | The startup script launches FalkorDB as a background daemon, waits for it to be ready, optionally starts the FalkorDB Browser (if `BROWSER=1`), then starts the MCP server in the foreground. When the MCP server stops, the container exits.
170 |
171 | ### Directory Structure
172 | ```
173 | /app/mcp/ # MCP server application
174 | ├── main.py
175 | ├── src/
176 | ├── config/
177 | │ └── config.yaml # FalkorDB-specific configuration
178 | └── .graphiti-core-version # Installed version info
179 |
180 | /var/lib/falkordb/data/ # Persistent graph storage
181 | /var/lib/falkordb/browser/ # FalkorDB Browser web UI
182 | /var/log/graphiti/ # MCP server and Browser logs
183 | /start-services.sh # Startup script
184 | ```
185 |
186 | ## Benefits of Combined Image
187 |
188 | 1. **Simplified Deployment**: Single container to manage
189 | 2. **Reduced Network Latency**: Localhost communication between services
190 | 3. **Easier Development**: One command to start the entire stack
191 | 4. **Unified Logging**: All logs available via docker logs
192 | 5. **Resource Efficiency**: Shared base image and dependencies
193 |
194 | ## Troubleshooting
195 |
196 | ### FalkorDB Not Starting
197 |
198 | Check container logs:
199 | ```bash
200 | docker compose -f docker/docker-compose-falkordb-combined.yml logs graphiti-falkordb
201 | ```
202 |
203 | ### MCP Server Connection Issues
204 |
205 | 1. Verify FalkorDB is running:
206 | ```bash
207 | docker compose -f docker/docker-compose-falkordb-combined.yml exec graphiti-falkordb redis-cli ping
208 | ```
209 |
210 | 2. Check MCP server health:
211 | ```bash
212 | curl http://localhost:8000/health
213 | ```
214 |
215 | 3. View all container logs:
216 | ```bash
217 | docker compose -f docker/docker-compose-falkordb-combined.yml logs -f
218 | ```
219 |
220 | ### Port Conflicts
221 |
222 | If ports 6379, 3000, or 8000 are already in use, modify the port mappings in `docker-compose-falkordb-combined.yml`:
223 |
224 | ```yaml
225 | ports:
226 | - "16379:6379" # Use different external port
227 | - "13000:3000"
228 | - "18000:8000"
229 | ```
230 |
231 | ## Production Considerations
232 |
233 | 1. **Resource Limits**: Add resource constraints in docker-compose:
234 | ```yaml
235 | deploy:
236 | resources:
237 | limits:
238 | cpus: '2'
239 | memory: 4G
240 | ```
241 |
242 | 2. **Persistent Volumes**: Use named volumes or bind mounts for production data
243 | 3. **Monitoring**: Export logs to external monitoring system
244 | 4. **Backups**: Regular backups of `/var/lib/falkordb/data` volume
245 | 5. **Security**: Set `FALKORDB_PASSWORD` in production environments
246 |
247 | ## Comparison with Separate Containers
248 |
249 | | Aspect | Combined Image | Separate Containers |
250 | |--------|---------------|---------------------|
251 | | Setup Complexity | Simple (one container) | Moderate (service dependencies) |
252 | | Network Latency | Lower (localhost) | Higher (container network) |
253 | | Resource Usage | Lower (shared base) | Higher (separate images) |
254 | | Scalability | Limited | Better (scale independently) |
255 | | Debugging | Harder (multiple processes) | Easier (isolated services) |
256 | | Production Use | Development/Single-node | Recommended |
257 |
258 | ## See Also
259 |
260 | - [Main MCP Server README](../README.md)
261 | - [FalkorDB Documentation](https://docs.falkordb.com/)
262 | - [Docker Compose Documentation](https://docs.docker.com/compose/)
263 |
```
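The health checks and troubleshooting steps above can also be scripted. A stdlib-only sketch that mirrors them, assuming the default ports and an unauthenticated FalkorDB (if `FALKORDB_PASSWORD` is set, the inline PING below would need a preceding AUTH command):

```python
# Sketch: probe both services the same way the container health check does:
# a Redis inline PING against FalkorDB and a GET on the MCP /health endpoint.
import socket
import urllib.request


def falkordb_alive(host: str = 'localhost', port: int = 6379) -> bool:
    with socket.create_connection((host, port), timeout=2) as sock:
        sock.sendall(b'PING\r\n')  # Redis inline command; FalkorDB speaks RESP
        return sock.recv(16).startswith(b'+PONG')


def mcp_alive(url: str = 'http://localhost:8000/health') -> bool:
    with urllib.request.urlopen(url, timeout=2) as resp:
        return resp.status == 200


if __name__ == '__main__':
    print('falkordb:', falkordb_alive())
    print('mcp server:', mcp_alive())
```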
--------------------------------------------------------------------------------
/mcp_server/tests/test_falkordb_integration.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """
3 | FalkorDB integration test for the Graphiti MCP Server.
4 | Tests MCP server functionality with FalkorDB as the graph database backend.
5 | """
6 |
7 | import asyncio
8 | import json
9 | import time
10 | from typing import Any
11 |
12 | from mcp import StdioServerParameters
13 | from mcp.client.stdio import stdio_client
14 |
15 |
16 | class GraphitiFalkorDBIntegrationTest:
17 | """Integration test client for Graphiti MCP Server using FalkorDB backend."""
18 |
19 | def __init__(self):
20 | self.test_group_id = f'falkor_test_group_{int(time.time())}'
21 | self.session = None
22 |
23 | async def __aenter__(self):
24 | """Start the MCP client session with FalkorDB configuration."""
25 | # Configure server parameters to run with FalkorDB backend
26 | server_params = StdioServerParameters(
27 | command='uv',
28 | args=['run', 'main.py', '--transport', 'stdio', '--database-provider', 'falkordb'],
29 | env={
30 | 'FALKORDB_URI': 'redis://localhost:6379',
31 | 'FALKORDB_PASSWORD': '', # No password for test instance
32 | 'FALKORDB_DATABASE': 'default_db',
33 | 'OPENAI_API_KEY': 'dummy_key_for_testing',
34 | 'GRAPHITI_GROUP_ID': self.test_group_id,
35 | },
36 | )
37 |
38 | # Start the stdio client
39 | self.session = await stdio_client(server_params).__aenter__()
40 | print(' 📡 Started MCP client session with FalkorDB backend')
41 | return self
42 |
43 | async def __aexit__(self, exc_type, exc_val, exc_tb):
44 | """Clean up the MCP client session."""
45 | if self.session:
46 | await self.session.close()
47 | print(' 🔌 Closed MCP client session')
48 |
49 | async def call_mcp_tool(self, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any]:
50 | """Call an MCP tool via the stdio client."""
51 | try:
52 | result = await self.session.call_tool(tool_name, arguments)
53 | if hasattr(result, 'content') and result.content:
54 | # Handle different content types
55 | if hasattr(result.content[0], 'text'):
56 | content = result.content[0].text
57 | try:
58 | return json.loads(content)
59 | except json.JSONDecodeError:
60 | return {'raw_response': content}
61 | else:
62 | return {'content': str(result.content[0])}
63 | return {'result': 'success', 'content': None}
64 | except Exception as e:
65 | return {'error': str(e), 'tool': tool_name, 'arguments': arguments}
66 |
67 | async def test_server_status(self) -> bool:
68 | """Test the get_status tool to verify FalkorDB connectivity."""
69 | print(' 🏥 Testing server status with FalkorDB...')
70 | result = await self.call_mcp_tool('get_status', {})
71 |
72 | if 'error' in result:
73 | print(f' ❌ Status check failed: {result["error"]}')
74 | return False
75 |
76 | # Check if status indicates FalkorDB is working
77 | status_text = result.get('raw_response', result.get('content', ''))
78 | if 'running' in str(status_text).lower() or 'ready' in str(status_text).lower():
79 | print(' ✅ Server status OK with FalkorDB')
80 | return True
81 | else:
82 | print(f' ⚠️ Status unclear: {status_text}')
83 | return True # Don't fail on unclear status
84 |
85 | async def test_add_episode(self) -> bool:
86 | """Test adding an episode to FalkorDB."""
87 | print(' 📝 Testing episode addition to FalkorDB...')
88 |
89 | episode_data = {
90 | 'name': 'FalkorDB Test Episode',
91 | 'episode_body': 'This is a test episode to verify FalkorDB integration works correctly.',
92 | 'source': 'text',
93 | 'source_description': 'Integration test for FalkorDB backend',
94 | }
95 |
96 | result = await self.call_mcp_tool('add_episode', episode_data)
97 |
98 | if 'error' in result:
99 | print(f' ❌ Add episode failed: {result["error"]}')
100 | return False
101 |
102 | print(' ✅ Episode added successfully to FalkorDB')
103 | return True
104 |
105 | async def test_search_functionality(self) -> bool:
106 | """Test search functionality with FalkorDB."""
107 | print(' 🔍 Testing search functionality with FalkorDB...')
108 |
109 | # Give some time for episode processing
110 | await asyncio.sleep(2)
111 |
112 | # Test node search
113 | search_result = await self.call_mcp_tool(
114 | 'search_nodes', {'query': 'FalkorDB test episode', 'limit': 5}
115 | )
116 |
117 | if 'error' in search_result:
118 | print(f' ⚠️ Search returned error (may be expected): {search_result["error"]}')
119 | return True # Don't fail on search errors in integration test
120 |
121 | print(' ✅ Search functionality working with FalkorDB')
122 | return True
123 |
124 | async def test_clear_graph(self) -> bool:
125 | """Test clearing the graph in FalkorDB."""
126 | print(' 🧹 Testing graph clearing in FalkorDB...')
127 |
128 | result = await self.call_mcp_tool('clear_graph', {})
129 |
130 | if 'error' in result:
131 | print(f' ❌ Clear graph failed: {result["error"]}')
132 | return False
133 |
134 | print(' ✅ Graph cleared successfully in FalkorDB')
135 | return True
136 |
137 |
138 | async def run_falkordb_integration_test() -> bool:
139 | """Run the complete FalkorDB integration test suite."""
140 | print('🧪 Starting FalkorDB Integration Test Suite')
141 | print('=' * 55)
142 |
143 | test_results = []
144 |
145 | try:
146 | async with GraphitiFalkorDBIntegrationTest() as test_client:
147 | print(f' 🎯 Using test group: {test_client.test_group_id}')
148 |
149 | # Run test suite
150 | tests = [
151 | ('Server Status', test_client.test_server_status),
152 | ('Add Episode', test_client.test_add_episode),
153 | ('Search Functionality', test_client.test_search_functionality),
154 | ('Clear Graph', test_client.test_clear_graph),
155 | ]
156 |
157 | for test_name, test_func in tests:
158 | print(f'\n🔬 Running {test_name} Test...')
159 | try:
160 | result = await test_func()
161 | test_results.append((test_name, result))
162 | if result:
163 | print(f' ✅ {test_name}: PASSED')
164 | else:
165 | print(f' ❌ {test_name}: FAILED')
166 | except Exception as e:
167 | print(f' 💥 {test_name}: ERROR - {e}')
168 | test_results.append((test_name, False))
169 |
170 | except Exception as e:
171 | print(f'💥 Test setup failed: {e}')
172 | return False
173 |
174 | # Summary
175 | print('\n' + '=' * 55)
176 | print('📊 FalkorDB Integration Test Results:')
177 | print('-' * 30)
178 |
179 | passed = sum(1 for _, result in test_results if result)
180 | total = len(test_results)
181 |
182 | for test_name, result in test_results:
183 | status = '✅ PASS' if result else '❌ FAIL'
184 | print(f' {test_name}: {status}')
185 |
186 | print(f'\n🎯 Overall: {passed}/{total} tests passed')
187 |
188 | if passed == total:
189 | print('🎉 All FalkorDB integration tests PASSED!')
190 | return True
191 | else:
192 | print('⚠️ Some FalkorDB integration tests failed')
193 | return passed >= (total * 0.7) # Pass if 70% of tests pass
194 |
195 |
196 | if __name__ == '__main__':
197 | success = asyncio.run(run_falkordb_integration_test())
198 | exit(0 if success else 1)
199 |
```
--------------------------------------------------------------------------------
/mcp_server/tests/test_configuration.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """Test script for configuration loading and factory patterns."""
3 |
4 | import asyncio
5 | import os
6 | import sys
7 | from pathlib import Path
8 |
 9 | # Add the src directory to the path
10 | sys.path.insert(0, str(Path(__file__).parent.parent / 'src'))
11 |
12 | from config.schema import GraphitiConfig
13 | from services.factories import DatabaseDriverFactory, EmbedderFactory, LLMClientFactory
14 |
15 |
16 | def test_config_loading():
17 | """Test loading configuration from YAML and environment variables."""
18 | print('Testing configuration loading...')
19 |
20 | # Test with default config.yaml
21 | config = GraphitiConfig()
22 |
23 | print('✓ Loaded configuration successfully')
24 | print(f' - Server transport: {config.server.transport}')
25 | print(f' - LLM provider: {config.llm.provider}')
26 | print(f' - LLM model: {config.llm.model}')
27 | print(f' - Embedder provider: {config.embedder.provider}')
28 | print(f' - Database provider: {config.database.provider}')
29 | print(f' - Group ID: {config.graphiti.group_id}')
30 |
31 | # Test environment variable override
32 | os.environ['LLM__PROVIDER'] = 'anthropic'
33 | os.environ['LLM__MODEL'] = 'claude-3-opus'
34 | config2 = GraphitiConfig()
35 |
36 | print('\n✓ Environment variable overrides work')
37 | print(f' - LLM provider (overridden): {config2.llm.provider}')
38 | print(f' - LLM model (overridden): {config2.llm.model}')
39 |
40 | # Clean up env vars
41 | del os.environ['LLM__PROVIDER']
42 | del os.environ['LLM__MODEL']
43 |
44 | assert config is not None
45 | assert config2 is not None
46 |
47 | # Return the first config for subsequent tests
48 | return config
49 |
50 |
51 | def test_llm_factory(config: GraphitiConfig):
52 | """Test LLM client factory creation."""
53 | print('\nTesting LLM client factory...')
54 |
55 | # Test OpenAI client creation (if API key is set)
56 | if (
57 | config.llm.provider == 'openai'
58 | and config.llm.providers.openai
59 | and config.llm.providers.openai.api_key
60 | ):
61 | try:
62 | client = LLMClientFactory.create(config.llm)
63 | print(f'✓ Created {config.llm.provider} LLM client successfully')
64 | print(f' - Model: {client.model}')
65 | print(f' - Temperature: {client.temperature}')
66 | except Exception as e:
67 | print(f'✗ Failed to create LLM client: {e}')
68 | else:
69 | print(f'⚠ Skipping LLM factory test (no API key configured for {config.llm.provider})')
70 |
71 | # Test switching providers
72 | test_config = config.llm.model_copy()
73 | test_config.provider = 'gemini'
74 | if not test_config.providers.gemini:
75 | from config.schema import GeminiProviderConfig
76 |
77 | test_config.providers.gemini = GeminiProviderConfig(api_key='dummy_value_for_testing')
78 | else:
79 | test_config.providers.gemini.api_key = 'dummy_value_for_testing'
80 |
81 | try:
82 | client = LLMClientFactory.create(test_config)
83 | print('✓ Factory supports provider switching (tested with Gemini)')
84 | except Exception as e:
85 | print(f'✗ Factory provider switching failed: {e}')
86 |
87 |
88 | def test_embedder_factory(config: GraphitiConfig):
89 | """Test Embedder client factory creation."""
90 | print('\nTesting Embedder client factory...')
91 |
92 | # Test OpenAI embedder creation (if API key is set)
93 | if (
94 | config.embedder.provider == 'openai'
95 | and config.embedder.providers.openai
96 | and config.embedder.providers.openai.api_key
97 | ):
98 | try:
99 | _ = EmbedderFactory.create(config.embedder)
100 | print(f'✓ Created {config.embedder.provider} Embedder client successfully')
101 | # The embedder client may not expose model/dimensions as attributes
102 | print(f' - Configured model: {config.embedder.model}')
103 | print(f' - Configured dimensions: {config.embedder.dimensions}')
104 | except Exception as e:
105 | print(f'✗ Failed to create Embedder client: {e}')
106 | else:
107 | print(
108 | f'⚠ Skipping Embedder factory test (no API key configured for {config.embedder.provider})'
109 | )
110 |
111 |
112 | async def test_database_factory(config: GraphitiConfig):
113 | """Test Database driver factory creation."""
114 | print('\nTesting Database driver factory...')
115 |
116 | # Test Neo4j config creation
117 | if config.database.provider == 'neo4j' and config.database.providers.neo4j:
118 | try:
119 | db_config = DatabaseDriverFactory.create_config(config.database)
120 | print(f'✓ Created {config.database.provider} configuration successfully')
121 | print(f' - URI: {db_config["uri"]}')
122 | print(f' - User: {db_config["user"]}')
123 | print(
124 | f' - Password: {"*" * len(db_config["password"]) if db_config["password"] else "None"}'
125 | )
126 |
127 | # Test actual connection would require initializing Graphiti
128 | from graphiti_core import Graphiti
129 |
130 | try:
131 | # This will fail if Neo4j is not running, but tests the config
132 | graphiti = Graphiti(
133 | uri=db_config['uri'],
134 | user=db_config['user'],
135 | password=db_config['password'],
136 | )
137 | await graphiti.driver.client.verify_connectivity()
138 | print(' ✓ Successfully connected to Neo4j')
139 | await graphiti.driver.client.close()
140 | except Exception as e:
141 | print(f' ⚠ Could not connect to Neo4j (is it running?): {type(e).__name__}')
142 | except Exception as e:
143 | print(f'✗ Failed to create Database configuration: {e}')
144 | else:
145 | print(f'⚠ Skipping Database factory test (no configuration for {config.database.provider})')
146 |
147 |
148 | def test_cli_override():
149 | """Test CLI argument override functionality."""
150 | print('\nTesting CLI argument override...')
151 |
152 | # Simulate argparse Namespace
153 | class Args:
154 | config = Path('config.yaml')
155 | transport = 'stdio'
156 | llm_provider = 'anthropic'
157 | model = 'claude-3-sonnet'
158 | temperature = 0.5
159 | embedder_provider = 'voyage'
160 | embedder_model = 'voyage-3'
161 | database_provider = 'falkordb'
162 | group_id = 'test-group'
163 | user_id = 'test-user'
164 |
165 | config = GraphitiConfig()
166 | config.apply_cli_overrides(Args())
167 |
168 | print('✓ CLI overrides applied successfully')
169 | print(f' - Transport: {config.server.transport}')
170 | print(f' - LLM provider: {config.llm.provider}')
171 | print(f' - LLM model: {config.llm.model}')
172 | print(f' - Temperature: {config.llm.temperature}')
173 | print(f' - Embedder provider: {config.embedder.provider}')
174 | print(f' - Database provider: {config.database.provider}')
175 | print(f' - Group ID: {config.graphiti.group_id}')
176 | print(f' - User ID: {config.graphiti.user_id}')
177 |
178 |
179 | async def main():
180 | """Run all tests."""
181 | print('=' * 60)
182 | print('Configuration and Factory Pattern Test Suite')
183 | print('=' * 60)
184 |
185 | try:
186 | # Test configuration loading
187 | config = test_config_loading()
188 |
189 | # Test factories
190 | test_llm_factory(config)
191 | test_embedder_factory(config)
192 | await test_database_factory(config)
193 |
194 | # Test CLI overrides
195 | test_cli_override()
196 |
197 | print('\n' + '=' * 60)
198 | print('✓ All tests completed successfully!')
199 | print('=' * 60)
200 |
201 | except Exception as e:
202 | print(f'\n✗ Test suite failed: {e}')
203 | sys.exit(1)
204 |
205 |
206 | if __name__ == '__main__':
207 | asyncio.run(main())
208 |
```
--------------------------------------------------------------------------------
/graphiti_core/search/search_config_recipes.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from graphiti_core.search.search_config import (
18 | CommunityReranker,
19 | CommunitySearchConfig,
20 | CommunitySearchMethod,
21 | EdgeReranker,
22 | EdgeSearchConfig,
23 | EdgeSearchMethod,
24 | EpisodeReranker,
25 | EpisodeSearchConfig,
26 | EpisodeSearchMethod,
27 | NodeReranker,
28 | NodeSearchConfig,
29 | NodeSearchMethod,
30 | SearchConfig,
31 | )
32 |
33 | # Performs a hybrid search with rrf reranking over edges, nodes, and communities
34 | COMBINED_HYBRID_SEARCH_RRF = SearchConfig(
35 | edge_config=EdgeSearchConfig(
36 | search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
37 | reranker=EdgeReranker.rrf,
38 | ),
39 | node_config=NodeSearchConfig(
40 | search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
41 | reranker=NodeReranker.rrf,
42 | ),
43 | episode_config=EpisodeSearchConfig(
44 | search_methods=[
45 | EpisodeSearchMethod.bm25,
46 | ],
47 | reranker=EpisodeReranker.rrf,
48 | ),
49 | community_config=CommunitySearchConfig(
50 | search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
51 | reranker=CommunityReranker.rrf,
52 | ),
53 | )
54 |
55 | # Performs a hybrid search with mmr reranking over edges, nodes, and communities
56 | COMBINED_HYBRID_SEARCH_MMR = SearchConfig(
57 | edge_config=EdgeSearchConfig(
58 | search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
59 | reranker=EdgeReranker.mmr,
60 | mmr_lambda=1,
61 | ),
62 | node_config=NodeSearchConfig(
63 | search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
64 | reranker=NodeReranker.mmr,
65 | mmr_lambda=1,
66 | ),
67 | episode_config=EpisodeSearchConfig(
68 | search_methods=[
69 | EpisodeSearchMethod.bm25,
70 | ],
71 | reranker=EpisodeReranker.rrf,
72 | ),
73 | community_config=CommunitySearchConfig(
74 | search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
75 | reranker=CommunityReranker.mmr,
76 | mmr_lambda=1,
77 | ),
78 | )
79 |
80 | # Performs a full-text search, similarity search, and bfs with cross_encoder reranking over edges, nodes, and communities
81 | COMBINED_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
82 | edge_config=EdgeSearchConfig(
83 | search_methods=[
84 | EdgeSearchMethod.bm25,
85 | EdgeSearchMethod.cosine_similarity,
86 | EdgeSearchMethod.bfs,
87 | ],
88 | reranker=EdgeReranker.cross_encoder,
89 | ),
90 | node_config=NodeSearchConfig(
91 | search_methods=[
92 | NodeSearchMethod.bm25,
93 | NodeSearchMethod.cosine_similarity,
94 | NodeSearchMethod.bfs,
95 | ],
96 | reranker=NodeReranker.cross_encoder,
97 | ),
98 | episode_config=EpisodeSearchConfig(
99 | search_methods=[
100 | EpisodeSearchMethod.bm25,
101 | ],
102 | reranker=EpisodeReranker.cross_encoder,
103 | ),
104 | community_config=CommunitySearchConfig(
105 | search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
106 | reranker=CommunityReranker.cross_encoder,
107 | ),
108 | )
109 |
110 | # performs a hybrid search over edges with rrf reranking
111 | EDGE_HYBRID_SEARCH_RRF = SearchConfig(
112 | edge_config=EdgeSearchConfig(
113 | search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
114 | reranker=EdgeReranker.rrf,
115 | )
116 | )
117 |
118 | # performs a hybrid search over edges with mmr reranking
119 | EDGE_HYBRID_SEARCH_MMR = SearchConfig(
120 | edge_config=EdgeSearchConfig(
121 | search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
122 | reranker=EdgeReranker.mmr,
123 | )
124 | )
125 |
126 | # performs a hybrid search over edges with node distance reranking
127 | EDGE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
128 | edge_config=EdgeSearchConfig(
129 | search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
130 | reranker=EdgeReranker.node_distance,
131 | ),
132 | )
133 |
134 | # performs a hybrid search over edges with episode mention reranking
135 | EDGE_HYBRID_SEARCH_EPISODE_MENTIONS = SearchConfig(
136 | edge_config=EdgeSearchConfig(
137 | search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity],
138 | reranker=EdgeReranker.episode_mentions,
139 | )
140 | )
141 |
142 | # performs a hybrid search over edges with cross encoder reranking
143 | EDGE_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
144 | edge_config=EdgeSearchConfig(
145 | search_methods=[
146 | EdgeSearchMethod.bm25,
147 | EdgeSearchMethod.cosine_similarity,
148 | EdgeSearchMethod.bfs,
149 | ],
150 | reranker=EdgeReranker.cross_encoder,
151 | ),
152 | limit=10,
153 | )
154 |
155 | # performs a hybrid search over nodes with rrf reranking
156 | NODE_HYBRID_SEARCH_RRF = SearchConfig(
157 | node_config=NodeSearchConfig(
158 | search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
159 | reranker=NodeReranker.rrf,
160 | )
161 | )
162 |
163 | # performs a hybrid search over nodes with mmr reranking
164 | NODE_HYBRID_SEARCH_MMR = SearchConfig(
165 | node_config=NodeSearchConfig(
166 | search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
167 | reranker=NodeReranker.mmr,
168 | )
169 | )
170 |
171 | # performs a hybrid search over nodes with node distance reranking
172 | NODE_HYBRID_SEARCH_NODE_DISTANCE = SearchConfig(
173 | node_config=NodeSearchConfig(
174 | search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
175 | reranker=NodeReranker.node_distance,
176 | )
177 | )
178 |
179 | # performs a hybrid search over nodes with episode mentions reranking
180 | NODE_HYBRID_SEARCH_EPISODE_MENTIONS = SearchConfig(
181 | node_config=NodeSearchConfig(
182 | search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity],
183 | reranker=NodeReranker.episode_mentions,
184 | )
185 | )
186 |
187 | # performs a hybrid search over nodes with cross encoder reranking
188 | NODE_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
189 | node_config=NodeSearchConfig(
190 | search_methods=[
191 | NodeSearchMethod.bm25,
192 | NodeSearchMethod.cosine_similarity,
193 | NodeSearchMethod.bfs,
194 | ],
195 | reranker=NodeReranker.cross_encoder,
196 | ),
197 | limit=10,
198 | )
199 |
200 | # performs a hybrid search over communities with rrf reranking
201 | COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig(
202 | community_config=CommunitySearchConfig(
203 | search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
204 | reranker=CommunityReranker.rrf,
205 | )
206 | )
207 |
208 | # performs a hybrid search over communities with mmr reranking
209 | COMMUNITY_HYBRID_SEARCH_MMR = SearchConfig(
210 | community_config=CommunitySearchConfig(
211 | search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
212 | reranker=CommunityReranker.mmr,
213 | )
214 | )
215 |
216 | # performs a hybrid search over communities with cross encoder reranking
217 | COMMUNITY_HYBRID_SEARCH_CROSS_ENCODER = SearchConfig(
218 | community_config=CommunitySearchConfig(
219 | search_methods=[CommunitySearchMethod.bm25, CommunitySearchMethod.cosine_similarity],
220 | reranker=CommunityReranker.cross_encoder,
221 | ),
222 | limit=3,
223 | )
224 |
```
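Each recipe above is a module-level `SearchConfig` constant, so customizing one is just a deep pydantic copy plus field overrides. A short sketch; the `limit` override uses the same field the cross-encoder recipes set:

```python
# Sketch: start from a recipe and adjust it without mutating the shared constant.
from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF

config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)  # pydantic v2 deep copy
config.limit = 25  # raise the result limit for this call only

print(config.node_config.search_methods)  # [bm25, cosine_similarity]
```

The modified copy can then be passed wherever graphiti_core accepts a `SearchConfig`.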
--------------------------------------------------------------------------------
/tests/test_node_int.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from datetime import datetime, timedelta
18 | from uuid import uuid4
19 |
20 | import pytest
21 |
22 | from graphiti_core.nodes import (
23 | CommunityNode,
24 | EntityNode,
25 | EpisodeType,
26 | EpisodicNode,
27 | )
28 | from tests.helpers_test import (
29 | assert_community_node_equals,
30 | assert_entity_node_equals,
31 | assert_episodic_node_equals,
32 | get_node_count,
33 | group_id,
34 | )
35 |
36 | created_at = datetime.now()
37 | deleted_at = created_at + timedelta(days=3)
38 | valid_at = created_at + timedelta(days=1)
39 | invalid_at = created_at + timedelta(days=2)
40 |
41 |
42 | @pytest.fixture
43 | def sample_entity_node():
44 | return EntityNode(
45 | uuid=str(uuid4()),
46 | name='Test Entity',
47 | group_id=group_id,
48 | labels=['Entity', 'Person'],
49 | created_at=created_at,
50 | name_embedding=[0.5] * 1024,
51 | summary='Entity Summary',
52 | attributes={
53 | 'age': 30,
54 | 'location': 'New York',
55 | },
56 | )
57 |
58 |
59 | @pytest.fixture
60 | def sample_episodic_node():
61 | return EpisodicNode(
62 | uuid=str(uuid4()),
63 | name='Episode 1',
64 | group_id=group_id,
65 | created_at=created_at,
66 | source=EpisodeType.text,
67 | source_description='Test source',
68 | content='Some content here',
69 | valid_at=valid_at,
70 | entity_edges=[],
71 | )
72 |
73 |
74 | @pytest.fixture
75 | def sample_community_node():
76 | return CommunityNode(
77 | uuid=str(uuid4()),
78 | name='Community A',
79 | group_id=group_id,
80 | created_at=created_at,
81 | name_embedding=[0.5] * 1024,
82 | summary='Community summary',
83 | )
84 |
85 |
86 | @pytest.mark.asyncio
87 | async def test_entity_node(sample_entity_node, graph_driver):
88 | uuid = sample_entity_node.uuid
89 |
90 | # Create node
91 | node_count = await get_node_count(graph_driver, [uuid])
92 | assert node_count == 0
93 | await sample_entity_node.save(graph_driver)
94 | node_count = await get_node_count(graph_driver, [uuid])
95 | assert node_count == 1
96 |
97 | # Get node by uuid
98 | retrieved = await EntityNode.get_by_uuid(graph_driver, sample_entity_node.uuid)
99 | await assert_entity_node_equals(graph_driver, retrieved, sample_entity_node)
100 |
101 | # Get node by uuids
102 | retrieved = await EntityNode.get_by_uuids(graph_driver, [sample_entity_node.uuid])
103 | await assert_entity_node_equals(graph_driver, retrieved[0], sample_entity_node)
104 |
105 | # Get node by group ids
106 | retrieved = await EntityNode.get_by_group_ids(
107 | graph_driver, [group_id], limit=2, with_embeddings=True
108 | )
109 | assert len(retrieved) == 1
110 | await assert_entity_node_equals(graph_driver, retrieved[0], sample_entity_node)
111 |
112 | # Delete node by uuid
113 | await sample_entity_node.delete(graph_driver)
114 | node_count = await get_node_count(graph_driver, [uuid])
115 | assert node_count == 0
116 |
117 | # Delete node by uuids
118 | await sample_entity_node.save(graph_driver)
119 | node_count = await get_node_count(graph_driver, [uuid])
120 | assert node_count == 1
121 | await sample_entity_node.delete_by_uuids(graph_driver, [uuid])
122 | node_count = await get_node_count(graph_driver, [uuid])
123 | assert node_count == 0
124 |
125 | # Delete node by group id
126 | await sample_entity_node.save(graph_driver)
127 | node_count = await get_node_count(graph_driver, [uuid])
128 | assert node_count == 1
129 | await sample_entity_node.delete_by_group_id(graph_driver, group_id)
130 | node_count = await get_node_count(graph_driver, [uuid])
131 | assert node_count == 0
132 |
133 | await graph_driver.close()
134 |
135 |
136 | @pytest.mark.asyncio
137 | async def test_community_node(sample_community_node, graph_driver):
138 | uuid = sample_community_node.uuid
139 |
140 | # Create node
141 | node_count = await get_node_count(graph_driver, [uuid])
142 | assert node_count == 0
143 | await sample_community_node.save(graph_driver)
144 | node_count = await get_node_count(graph_driver, [uuid])
145 | assert node_count == 1
146 |
147 | # Get node by uuid
148 | retrieved = await CommunityNode.get_by_uuid(graph_driver, sample_community_node.uuid)
149 | await assert_community_node_equals(graph_driver, retrieved, sample_community_node)
150 |
151 | # Get node by uuids
152 | retrieved = await CommunityNode.get_by_uuids(graph_driver, [sample_community_node.uuid])
153 | await assert_community_node_equals(graph_driver, retrieved[0], sample_community_node)
154 |
155 | # Get node by group ids
156 | retrieved = await CommunityNode.get_by_group_ids(graph_driver, [group_id], limit=2)
157 | assert len(retrieved) == 1
158 | await assert_community_node_equals(graph_driver, retrieved[0], sample_community_node)
159 |
160 | # Delete node by uuid
161 | await sample_community_node.delete(graph_driver)
162 | node_count = await get_node_count(graph_driver, [uuid])
163 | assert node_count == 0
164 |
165 | # Delete node by uuids
166 | await sample_community_node.save(graph_driver)
167 | node_count = await get_node_count(graph_driver, [uuid])
168 | assert node_count == 1
169 | await sample_community_node.delete_by_uuids(graph_driver, [uuid])
170 | node_count = await get_node_count(graph_driver, [uuid])
171 | assert node_count == 0
172 |
173 | # Delete node by group id
174 | await sample_community_node.save(graph_driver)
175 | node_count = await get_node_count(graph_driver, [uuid])
176 | assert node_count == 1
177 | await sample_community_node.delete_by_group_id(graph_driver, group_id)
178 | node_count = await get_node_count(graph_driver, [uuid])
179 | assert node_count == 0
180 |
181 | await graph_driver.close()
182 |
183 |
184 | @pytest.mark.asyncio
185 | async def test_episodic_node(sample_episodic_node, graph_driver):
186 | uuid = sample_episodic_node.uuid
187 |
188 | # Create node
189 | node_count = await get_node_count(graph_driver, [uuid])
190 | assert node_count == 0
191 | await sample_episodic_node.save(graph_driver)
192 | node_count = await get_node_count(graph_driver, [uuid])
193 | assert node_count == 1
194 |
195 | # Get node by uuid
196 | retrieved = await EpisodicNode.get_by_uuid(graph_driver, sample_episodic_node.uuid)
197 | await assert_episodic_node_equals(retrieved, sample_episodic_node)
198 |
199 | # Get node by uuids
200 | retrieved = await EpisodicNode.get_by_uuids(graph_driver, [sample_episodic_node.uuid])
201 | await assert_episodic_node_equals(retrieved[0], sample_episodic_node)
202 |
203 | # Get node by group ids
204 | retrieved = await EpisodicNode.get_by_group_ids(graph_driver, [group_id], limit=2)
205 | assert len(retrieved) == 1
206 | await assert_episodic_node_equals(retrieved[0], sample_episodic_node)
207 |
208 | # Delete node by uuid
209 | await sample_episodic_node.delete(graph_driver)
210 | node_count = await get_node_count(graph_driver, [uuid])
211 | assert node_count == 0
212 |
213 | # Delete node by uuids
214 | await sample_episodic_node.save(graph_driver)
215 | node_count = await get_node_count(graph_driver, [uuid])
216 | assert node_count == 1
217 | await sample_episodic_node.delete_by_uuids(graph_driver, [uuid])
218 | node_count = await get_node_count(graph_driver, [uuid])
219 | assert node_count == 0
220 |
221 | # Delete node by group id
222 | await sample_episodic_node.save(graph_driver)
223 | node_count = await get_node_count(graph_driver, [uuid])
224 | assert node_count == 1
225 | await sample_episodic_node.delete_by_group_id(graph_driver, group_id)
226 | node_count = await get_node_count(graph_driver, [uuid])
227 | assert node_count == 0
228 |
229 | await graph_driver.close()
230 |
```
--------------------------------------------------------------------------------
/tests/utils/maintenance/test_temporal_operations_int.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import os
18 | from datetime import timedelta
19 |
20 | import pytest
21 | from dotenv import load_dotenv
22 |
23 | from graphiti_core.edges import EntityEdge
24 | from graphiti_core.llm_client import LLMConfig, OpenAIClient
25 | from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
26 | from graphiti_core.utils.datetime_utils import utc_now
27 | from graphiti_core.utils.maintenance.temporal_operations import (
28 | get_edge_contradictions,
29 | )
30 |
31 | load_dotenv()
32 |
33 |
34 | def setup_llm_client():
35 | return OpenAIClient(
36 | LLMConfig(
37 | api_key=os.getenv('TEST_OPENAI_API_KEY'),
38 | model=os.getenv('TEST_OPENAI_MODEL'),
39 | base_url='https://api.openai.com/v1',
40 | )
41 | )
42 |
43 |
44 | def create_test_data():
45 | now = utc_now()
46 |
47 | # Create edges
48 | existing_edge = EntityEdge(
49 | uuid='e1',
50 | source_node_uuid='1',
51 | target_node_uuid='2',
52 | name='LIKES',
53 | fact='Alice likes Bob',
54 | created_at=now - timedelta(days=1),
55 | group_id='1',
56 | )
57 | new_edge = EntityEdge(
58 | uuid='e2',
59 | source_node_uuid='1',
60 | target_node_uuid='2',
61 | name='DISLIKES',
62 | fact='Alice dislikes Bob',
63 | created_at=now,
64 | group_id='1',
65 | )
66 |
67 | # Create current episode
68 | current_episode = EpisodicNode(
69 | name='Current Episode',
70 | content='Alice now dislikes Bob',
71 | created_at=now,
72 | valid_at=now,
73 | source=EpisodeType.message,
74 | source_description='Test episode for unit testing',
75 | group_id='1',
76 | )
77 |
78 | # Create previous episodes
79 | previous_episodes = [
80 | EpisodicNode(
81 | name='Previous Episode',
82 | content='Alice liked Bob',
83 | created_at=now - timedelta(days=1),
84 | valid_at=now - timedelta(days=1),
85 | source=EpisodeType.message,
86 | source_description='Test previous episode for unit testing',
87 | group_id='1',
88 | )
89 | ]
90 |
91 | return existing_edge, new_edge, current_episode, previous_episodes
92 |
93 |
94 | @pytest.mark.asyncio
95 | @pytest.mark.integration
96 | async def test_get_edge_contradictions():
97 | existing_edge, new_edge, current_episode, previous_episodes = create_test_data()
98 |
99 | invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, [existing_edge])
100 |
101 | assert len(invalidated_edges) == 1
102 | assert invalidated_edges[0].uuid == existing_edge.uuid
103 |
104 |
105 | @pytest.mark.asyncio
106 | @pytest.mark.integration
107 | async def test_get_edge_contradictions_no_contradictions():
108 | _, new_edge, current_episode, previous_episodes = create_test_data()
109 |
110 | invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, [])
111 |
112 | assert len(invalidated_edges) == 0
113 |
114 |
115 | @pytest.mark.skip(reason='Flaky LLM-based test with non-deterministic results')
116 | @pytest.mark.asyncio
117 | @pytest.mark.integration
118 | async def test_get_edge_contradictions_multiple_existing():
119 | existing_edge1, new_edge, _, _ = create_test_data()
120 | existing_edge2, _, _, _ = create_test_data()
121 | existing_edge2.uuid = 'e3'
122 | existing_edge2.name = 'KNOWS'
123 | existing_edge2.fact = 'Alice knows Bob'
124 |
125 | invalidated_edges = await get_edge_contradictions(
126 | setup_llm_client(), new_edge, [existing_edge1, existing_edge2]
127 | )
128 |
129 | assert len(invalidated_edges) == 1
130 | assert invalidated_edges[0].uuid == existing_edge1.uuid
131 |
132 |
133 | # Helper function to create more complex test data
134 | def create_complex_test_data():
135 | now = utc_now()
136 |
137 | # Create nodes
138 | node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now, group_id='1')
139 | node2 = EntityNode(uuid='2', name='Bob', labels=['Person'], created_at=now, group_id='1')
140 | node3 = EntityNode(uuid='3', name='Charlie', labels=['Person'], created_at=now, group_id='1')
141 | node4 = EntityNode(
142 | uuid='4', name='Company XYZ', labels=['Organization'], created_at=now, group_id='1'
143 | )
144 |
145 | # Create edges
146 | existing_edge1 = EntityEdge(
147 | uuid='e1',
148 | source_node_uuid='1',
149 | target_node_uuid='2',
150 | name='LIKES',
151 | fact='Alice likes Bob',
152 | group_id='1',
153 | created_at=now - timedelta(days=5),
154 | )
155 | existing_edge2 = EntityEdge(
156 | uuid='e2',
157 | source_node_uuid='1',
158 | target_node_uuid='3',
159 | name='FRIENDS_WITH',
160 | fact='Alice is friends with Charlie',
161 | group_id='1',
162 | created_at=now - timedelta(days=3),
163 | )
164 | existing_edge3 = EntityEdge(
165 | uuid='e3',
166 | source_node_uuid='2',
167 | target_node_uuid='4',
168 | name='WORKS_FOR',
169 | fact='Bob works for Company XYZ',
170 | group_id='1',
171 | created_at=now - timedelta(days=2),
172 | )
173 |
174 | return [existing_edge1, existing_edge2, existing_edge3], [
175 | node1,
176 | node2,
177 | node3,
178 | node4,
179 | ]
180 |
181 |
182 | @pytest.mark.asyncio
183 | @pytest.mark.integration
184 | async def test_invalidate_edges_complex():
185 | existing_edges, nodes = create_complex_test_data()
186 |
187 | # Create a new edge that contradicts an existing one
188 | new_edge = EntityEdge(
189 | uuid='e4',
190 | source_node_uuid='1',
191 | target_node_uuid='2',
192 | name='DISLIKES',
193 | fact='Alice dislikes Bob',
194 | group_id='1',
195 | created_at=utc_now(),
196 | )
197 |
198 | invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
199 |
200 | assert len(invalidated_edges) == 1
201 | assert invalidated_edges[0].uuid == 'e1'
202 |
203 |
204 | @pytest.mark.asyncio
205 | @pytest.mark.integration
206 | async def test_get_edge_contradictions_temporal_update():
207 | existing_edges, nodes = create_complex_test_data()
208 |
209 | # Create a new edge that updates an existing one with new information
210 | new_edge = EntityEdge(
211 | uuid='e5',
212 | source_node_uuid='2',
213 | target_node_uuid='4',
214 | name='LEFT_JOB',
215 |         fact='Bob no longer works at Company XYZ',
216 | group_id='1',
217 | created_at=utc_now(),
218 | )
219 |
220 | invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
221 |
222 | assert len(invalidated_edges) == 1
223 | assert invalidated_edges[0].uuid == 'e3'
224 |
225 |
226 | @pytest.mark.asyncio
227 | @pytest.mark.integration
228 | async def test_get_edge_contradictions_no_effect():
229 | existing_edges, nodes = create_complex_test_data()
230 |
231 | # Create a new edge that doesn't invalidate any existing edges
232 | new_edge = EntityEdge(
233 | uuid='e8',
234 | source_node_uuid='3',
235 | target_node_uuid='4',
236 | name='APPLIED_TO',
237 | fact='Charlie applied to Company XYZ',
238 | group_id='1',
239 | created_at=utc_now(),
240 | )
241 |
242 | invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
243 |
244 | assert len(invalidated_edges) == 0
245 |
246 |
247 | @pytest.mark.skip(reason='Flaky LLM-based test with non-deterministic results')
248 | @pytest.mark.asyncio
249 | @pytest.mark.integration
250 | async def test_invalidate_edges_partial_update():
251 | existing_edges, nodes = create_complex_test_data()
252 |
253 | # Create a new edge that partially updates an existing one
254 | new_edge = EntityEdge(
255 | uuid='e9',
256 | source_node_uuid='2',
257 | target_node_uuid='4',
258 | name='CHANGED_POSITION',
259 | fact='Bob changed his position at Company XYZ',
260 | group_id='1',
261 | created_at=utc_now(),
262 | )
263 |
264 | invalidated_edges = await get_edge_contradictions(setup_llm_client(), new_edge, existing_edges)
265 |
266 | assert len(invalidated_edges) == 0 # The existing edge is not invalidated, just updated
267 |
268 |
269 | # Run the tests
270 | if __name__ == '__main__':
271 | pytest.main([__file__])
272 |
```
--------------------------------------------------------------------------------
/graphiti_core/graph_queries.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Database query utilities for different graph database backends.
3 |
4 | This module provides database-agnostic query generation for Neo4j, FalkorDB, and Kuzu,
5 | supporting index creation, fulltext search, and bulk operations.
6 | """
7 |
8 | from typing_extensions import LiteralString
9 |
10 | from graphiti_core.driver.driver import GraphProvider
11 |
12 | # Mapping from Neo4j fulltext index names to FalkorDB node labels
13 | NEO4J_TO_FALKORDB_MAPPING = {
14 | 'node_name_and_summary': 'Entity',
15 | 'community_name': 'Community',
16 | 'episode_content': 'Episodic',
17 | 'edge_name_and_fact': 'RELATES_TO',
18 | }
19 | # Mapping from fulltext index names to Kuzu node labels
20 | INDEX_TO_LABEL_KUZU_MAPPING = {
21 | 'node_name_and_summary': 'Entity',
22 | 'community_name': 'Community',
23 | 'episode_content': 'Episodic',
24 | 'edge_name_and_fact': 'RelatesToNode_',
25 | }
26 |
27 |
28 | def get_range_indices(provider: GraphProvider) -> list[LiteralString]:
29 | if provider == GraphProvider.FALKORDB:
30 | return [
31 | # Entity node
32 | 'CREATE INDEX FOR (n:Entity) ON (n.uuid, n.group_id, n.name, n.created_at)',
33 | # Episodic node
34 | 'CREATE INDEX FOR (n:Episodic) ON (n.uuid, n.group_id, n.created_at, n.valid_at)',
35 | # Community node
36 | 'CREATE INDEX FOR (n:Community) ON (n.uuid)',
37 | # RELATES_TO edge
38 | 'CREATE INDEX FOR ()-[e:RELATES_TO]-() ON (e.uuid, e.group_id, e.name, e.created_at, e.expired_at, e.valid_at, e.invalid_at)',
39 | # MENTIONS edge
40 | 'CREATE INDEX FOR ()-[e:MENTIONS]-() ON (e.uuid, e.group_id)',
41 | # HAS_MEMBER edge
42 | 'CREATE INDEX FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
43 | ]
44 |
45 | if provider == GraphProvider.KUZU:
46 | return []
47 |
48 | return [
49 | 'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)',
50 | 'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)',
51 | 'CREATE INDEX community_uuid IF NOT EXISTS FOR (n:Community) ON (n.uuid)',
52 | 'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)',
53 | 'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)',
54 | 'CREATE INDEX has_member_uuid IF NOT EXISTS FOR ()-[e:HAS_MEMBER]-() ON (e.uuid)',
55 | 'CREATE INDEX entity_group_id IF NOT EXISTS FOR (n:Entity) ON (n.group_id)',
56 | 'CREATE INDEX episode_group_id IF NOT EXISTS FOR (n:Episodic) ON (n.group_id)',
57 | 'CREATE INDEX community_group_id IF NOT EXISTS FOR (n:Community) ON (n.group_id)',
58 | 'CREATE INDEX relation_group_id IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.group_id)',
59 | 'CREATE INDEX mention_group_id IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.group_id)',
60 | 'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)',
61 | 'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)',
62 | 'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)',
63 | 'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)',
64 | 'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)',
65 | 'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)',
66 | 'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)',
67 | 'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)',
68 | 'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)',
69 | ]
70 |
71 |
72 | def get_fulltext_indices(provider: GraphProvider) -> list[LiteralString]:
73 | if provider == GraphProvider.FALKORDB:
74 | from typing import cast
75 |
76 | from graphiti_core.driver.falkordb_driver import STOPWORDS
77 |
78 | # Convert to string representation for embedding in queries
79 | stopwords_str = str(STOPWORDS)
80 |
81 |         # Cast to satisfy the LiteralString requirement while maintaining a single source of truth
82 | return cast(
83 | list[LiteralString],
84 | [
85 | f"""CALL db.idx.fulltext.createNodeIndex(
86 | {{
87 | label: 'Episodic',
88 | stopwords: {stopwords_str}
89 | }},
90 | 'content', 'source', 'source_description', 'group_id'
91 | )""",
92 | f"""CALL db.idx.fulltext.createNodeIndex(
93 | {{
94 | label: 'Entity',
95 | stopwords: {stopwords_str}
96 | }},
97 | 'name', 'summary', 'group_id'
98 | )""",
99 | f"""CALL db.idx.fulltext.createNodeIndex(
100 | {{
101 | label: 'Community',
102 | stopwords: {stopwords_str}
103 | }},
104 | 'name', 'group_id'
105 | )""",
106 | """CREATE FULLTEXT INDEX FOR ()-[e:RELATES_TO]-() ON (e.name, e.fact, e.group_id)""",
107 | ],
108 | )
109 |
110 | if provider == GraphProvider.KUZU:
111 | return [
112 | "CALL CREATE_FTS_INDEX('Episodic', 'episode_content', ['content', 'source', 'source_description']);",
113 | "CALL CREATE_FTS_INDEX('Entity', 'node_name_and_summary', ['name', 'summary']);",
114 | "CALL CREATE_FTS_INDEX('Community', 'community_name', ['name']);",
115 | "CALL CREATE_FTS_INDEX('RelatesToNode_', 'edge_name_and_fact', ['name', 'fact']);",
116 | ]
117 |
118 | return [
119 | """CREATE FULLTEXT INDEX episode_content IF NOT EXISTS
120 | FOR (e:Episodic) ON EACH [e.content, e.source, e.source_description, e.group_id]""",
121 | """CREATE FULLTEXT INDEX node_name_and_summary IF NOT EXISTS
122 | FOR (n:Entity) ON EACH [n.name, n.summary, n.group_id]""",
123 | """CREATE FULLTEXT INDEX community_name IF NOT EXISTS
124 | FOR (n:Community) ON EACH [n.name, n.group_id]""",
125 | """CREATE FULLTEXT INDEX edge_name_and_fact IF NOT EXISTS
126 | FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact, e.group_id]""",
127 | ]
128 |
129 |
130 | def get_nodes_query(name: str, query: str, limit: int, provider: GraphProvider) -> str:
131 | if provider == GraphProvider.FALKORDB:
132 | label = NEO4J_TO_FALKORDB_MAPPING[name]
133 | return f"CALL db.idx.fulltext.queryNodes('{label}', {query})"
134 |
135 | if provider == GraphProvider.KUZU:
136 | label = INDEX_TO_LABEL_KUZU_MAPPING[name]
137 | return f"CALL QUERY_FTS_INDEX('{label}', '{name}', {query}, TOP := $limit)"
138 |
139 | return f'CALL db.index.fulltext.queryNodes("{name}", {query}, {{limit: $limit}})'
140 |
141 |
142 | def get_vector_cosine_func_query(vec1, vec2, provider: GraphProvider) -> str:
143 | if provider == GraphProvider.FALKORDB:
144 |         # FalkorDB returns cosine distance, so rescale it to the normalized [0, 1] similarity that Neo4j's vector.similarity.cosine produces
145 | return f'(2 - vec.cosineDistance({vec1}, vecf32({vec2})))/2'
146 |
147 | if provider == GraphProvider.KUZU:
148 | return f'array_cosine_similarity({vec1}, {vec2})'
149 |
150 | return f'vector.similarity.cosine({vec1}, {vec2})'
151 |
152 |
153 | def get_relationships_query(name: str, limit: int, provider: GraphProvider) -> str:
154 | if provider == GraphProvider.FALKORDB:
155 | label = NEO4J_TO_FALKORDB_MAPPING[name]
156 | return f"CALL db.idx.fulltext.queryRelationships('{label}', $query)"
157 |
158 | if provider == GraphProvider.KUZU:
159 | label = INDEX_TO_LABEL_KUZU_MAPPING[name]
160 | return f"CALL QUERY_FTS_INDEX('{label}', '{name}', cast($query AS STRING), TOP := $limit)"
161 |
162 | return f'CALL db.index.fulltext.queryRelationships("{name}", $query, {{limit: $limit}})'
163 |
```
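
A minimal usage sketch (not part of the repository): the helpers above return plain query strings, so wiring them into a backend is a loop over statements. This assumes the driver exposes an async `execute_query` method.

```python
from graphiti_core.driver.driver import GraphProvider
from graphiti_core.graph_queries import get_fulltext_indices, get_range_indices


async def build_indices(driver, provider: GraphProvider) -> None:
    # Range indices first, then fulltext. Neo4j statements are guarded with
    # IF NOT EXISTS; FalkorDB may raise if an index already exists, so a
    # per-statement try/except could be added for re-runs.
    for statement in get_range_indices(provider) + get_fulltext_indices(provider):
        await driver.execute_query(statement)
```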
--------------------------------------------------------------------------------
/examples/azure-openai/azure_openai_neo4j.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2025, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import asyncio
18 | import json
19 | import logging
20 | import os
21 | from datetime import datetime, timezone
22 | from logging import INFO
23 |
24 | from dotenv import load_dotenv
25 | from openai import AsyncOpenAI
26 |
27 | from graphiti_core import Graphiti
28 | from graphiti_core.embedder.azure_openai import AzureOpenAIEmbedderClient
29 | from graphiti_core.llm_client.azure_openai_client import AzureOpenAILLMClient
30 | from graphiti_core.llm_client.config import LLMConfig
31 | from graphiti_core.nodes import EpisodeType
32 |
33 | #################################################
34 | # CONFIGURATION
35 | #################################################
36 | # Set up logging and environment variables for
37 | # connecting to Neo4j database and Azure OpenAI
38 | #################################################
39 |
40 | # Configure logging
41 | logging.basicConfig(
42 | level=INFO,
43 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
44 | datefmt='%Y-%m-%d %H:%M:%S',
45 | )
46 | logger = logging.getLogger(__name__)
47 |
48 | load_dotenv()
49 |
50 | # Neo4j connection parameters
51 | # Make sure Neo4j Desktop is running with a local DBMS started
52 | neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
53 | neo4j_user = os.environ.get('NEO4J_USER', 'neo4j')
54 | neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password')
55 |
56 | # Azure OpenAI connection parameters
57 | azure_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT')
58 | azure_api_key = os.environ.get('AZURE_OPENAI_API_KEY')
59 | azure_deployment = os.environ.get('AZURE_OPENAI_DEPLOYMENT', 'gpt-4.1')
60 | azure_embedding_deployment = os.environ.get(
61 | 'AZURE_OPENAI_EMBEDDING_DEPLOYMENT', 'text-embedding-3-small'
62 | )
63 |
64 | if not azure_endpoint or not azure_api_key:
65 | raise ValueError('AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY must be set')
66 |
67 |
68 | async def main():
69 | #################################################
70 | # INITIALIZATION
71 | #################################################
72 | # Connect to Neo4j and Azure OpenAI, then set up
73 | # Graphiti indices. This is required before using
74 | # other Graphiti functionality
75 | #################################################
76 |
77 | # Initialize Azure OpenAI client
78 | azure_client = AsyncOpenAI(
79 | base_url=f'{azure_endpoint}/openai/v1/',
80 | api_key=azure_api_key,
81 | )
82 |
83 | # Create LLM and Embedder clients
84 | llm_client = AzureOpenAILLMClient(
85 | azure_client=azure_client,
86 | config=LLMConfig(model=azure_deployment, small_model=azure_deployment),
87 | )
88 | embedder_client = AzureOpenAIEmbedderClient(
89 | azure_client=azure_client, model=azure_embedding_deployment
90 | )
91 |
92 | # Initialize Graphiti with Neo4j connection and Azure OpenAI clients
93 | graphiti = Graphiti(
94 | neo4j_uri,
95 | neo4j_user,
96 | neo4j_password,
97 | llm_client=llm_client,
98 | embedder=embedder_client,
99 | )
100 |
101 | try:
102 | #################################################
103 | # ADDING EPISODES
104 | #################################################
105 | # Episodes are the primary units of information
106 | # in Graphiti. They can be text or structured JSON
107 | # and are automatically processed to extract entities
108 | # and relationships.
109 | #################################################
110 |
111 | # Example: Add Episodes
112 | # Episodes list containing both text and JSON episodes
113 | episodes = [
114 | {
115 | 'content': 'Kamala Harris is the Attorney General of California. She was previously '
116 | 'the district attorney for San Francisco.',
117 | 'type': EpisodeType.text,
118 | 'description': 'podcast transcript',
119 | },
120 | {
121 | 'content': 'As AG, Harris was in office from January 3, 2011 – January 3, 2017',
122 | 'type': EpisodeType.text,
123 | 'description': 'podcast transcript',
124 | },
125 | {
126 | 'content': {
127 | 'name': 'Gavin Newsom',
128 | 'position': 'Governor',
129 | 'state': 'California',
130 | 'previous_role': 'Lieutenant Governor',
131 | 'previous_location': 'San Francisco',
132 | },
133 | 'type': EpisodeType.json,
134 | 'description': 'podcast metadata',
135 | },
136 | ]
137 |
138 | # Add episodes to the graph
139 | for i, episode in enumerate(episodes):
140 | await graphiti.add_episode(
141 | name=f'California Politics {i}',
142 | episode_body=(
143 | episode['content']
144 | if isinstance(episode['content'], str)
145 | else json.dumps(episode['content'])
146 | ),
147 | source=episode['type'],
148 | source_description=episode['description'],
149 | reference_time=datetime.now(timezone.utc),
150 | )
151 | print(f'Added episode: California Politics {i} ({episode["type"].value})')
152 |
153 | #################################################
154 | # BASIC SEARCH
155 | #################################################
156 | # The simplest way to retrieve relationships (edges)
157 | # from Graphiti is using the search method, which
158 | # performs a hybrid search combining semantic
159 | # similarity and BM25 text retrieval.
160 | #################################################
161 |
162 | # Perform a hybrid search combining semantic similarity and BM25 retrieval
163 | print("\nSearching for: 'Who was the California Attorney General?'")
164 | results = await graphiti.search('Who was the California Attorney General?')
165 |
166 | # Print search results
167 | print('\nSearch Results:')
168 | for result in results:
169 | print(f'UUID: {result.uuid}')
170 | print(f'Fact: {result.fact}')
171 | if hasattr(result, 'valid_at') and result.valid_at:
172 | print(f'Valid from: {result.valid_at}')
173 | if hasattr(result, 'invalid_at') and result.invalid_at:
174 | print(f'Valid until: {result.invalid_at}')
175 | print('---')
176 |
177 | #################################################
178 | # CENTER NODE SEARCH
179 | #################################################
180 | # For more contextually relevant results, you can
181 | # use a center node to rerank search results based
182 | # on their graph distance to a specific node
183 | #################################################
184 |
185 | # Use the top search result's UUID as the center node for reranking
186 | if results and len(results) > 0:
187 | # Get the source node UUID from the top result
188 | center_node_uuid = results[0].source_node_uuid
189 |
190 | print('\nReranking search results based on graph distance:')
191 | print(f'Using center node UUID: {center_node_uuid}')
192 |
193 | reranked_results = await graphiti.search(
194 | 'Who was the California Attorney General?',
195 | center_node_uuid=center_node_uuid,
196 | )
197 |
198 | # Print reranked search results
199 | print('\nReranked Search Results:')
200 | for result in reranked_results:
201 | print(f'UUID: {result.uuid}')
202 | print(f'Fact: {result.fact}')
203 | if hasattr(result, 'valid_at') and result.valid_at:
204 | print(f'Valid from: {result.valid_at}')
205 | if hasattr(result, 'invalid_at') and result.invalid_at:
206 | print(f'Valid until: {result.invalid_at}')
207 | print('---')
208 | else:
209 | print('No results found in the initial search to use as center node.')
210 |
211 | finally:
212 | #################################################
213 | # CLEANUP
214 | #################################################
215 | # Always close the connection to Neo4j when
216 | # finished to properly release resources
217 | #################################################
218 |
219 | # Close the connection
220 | await graphiti.close()
221 | print('\nConnection closed')
222 |
223 |
224 | if __name__ == '__main__':
225 | asyncio.run(main())
226 |
```
--------------------------------------------------------------------------------
/graphiti_core/llm_client/openai_generic_client.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import json
18 | import logging
19 | import typing
20 | from typing import Any, ClassVar
21 |
22 | import openai
23 | from openai import AsyncOpenAI
24 | from openai.types.chat import ChatCompletionMessageParam
25 | from pydantic import BaseModel
26 |
27 | from ..prompts.models import Message
28 | from .client import LLMClient, get_extraction_language_instruction
29 | from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
30 | from .errors import RateLimitError, RefusalError
31 |
32 | logger = logging.getLogger(__name__)
33 |
34 | DEFAULT_MODEL = 'gpt-4.1-mini'
35 |
36 |
37 | class OpenAIGenericClient(LLMClient):
38 | """
39 |     OpenAIGenericClient is a client class for interacting with OpenAI-compatible language model APIs.
40 |
41 | This class extends the LLMClient and provides methods to initialize the client,
42 | get an embedder, and generate responses from the language model.
43 |
44 | Attributes:
45 | client (AsyncOpenAI): The OpenAI client used to interact with the API.
46 | model (str): The model name to use for generating responses.
47 | temperature (float): The temperature to use for generating responses.
48 | max_tokens (int): The maximum number of tokens to generate in a response.
49 |
50 | Methods:
51 | __init__(config: LLMConfig | None = None, cache: bool = False, client: typing.Any = None):
52 |         Initializes the OpenAIGenericClient with the provided configuration, cache setting, and client.
53 |
54 | _generate_response(messages: list[Message]) -> dict[str, typing.Any]:
55 | Generates a response from the language model based on the provided messages.
56 | """
57 |
58 | # Class-level constants
59 | MAX_RETRIES: ClassVar[int] = 2
60 |
61 | def __init__(
62 | self,
63 | config: LLMConfig | None = None,
64 | cache: bool = False,
65 | client: typing.Any = None,
66 | max_tokens: int = 16384,
67 | ):
68 | """
69 | Initialize the OpenAIGenericClient with the provided configuration, cache setting, and client.
70 |
71 | Args:
72 | config (LLMConfig | None): The configuration for the LLM client, including API key, model, base URL, temperature, and max tokens.
73 | cache (bool): Whether to use caching for responses. Defaults to False.
74 | client (Any | None): An optional async client instance to use. If not provided, a new AsyncOpenAI client is created.
75 | max_tokens (int): The maximum number of tokens to generate. Defaults to 16384 (16K) for better compatibility with local models.
76 |
77 | """
78 | # removed caching to simplify the `generate_response` override
79 | if cache:
80 | raise NotImplementedError('Caching is not implemented for OpenAI')
81 |
82 | if config is None:
83 | config = LLMConfig()
84 |
85 | super().__init__(config, cache)
86 |
87 | # Override max_tokens to support higher limits for local models
88 | self.max_tokens = max_tokens
89 |
90 | if client is None:
91 | self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
92 | else:
93 | self.client = client
94 |
95 | async def _generate_response(
96 | self,
97 | messages: list[Message],
98 | response_model: type[BaseModel] | None = None,
99 | max_tokens: int = DEFAULT_MAX_TOKENS,
100 | model_size: ModelSize = ModelSize.medium,
101 | ) -> dict[str, typing.Any]:
102 | openai_messages: list[ChatCompletionMessageParam] = []
103 | for m in messages:
104 | m.content = self._clean_input(m.content)
105 | if m.role == 'user':
106 | openai_messages.append({'role': 'user', 'content': m.content})
107 | elif m.role == 'system':
108 | openai_messages.append({'role': 'system', 'content': m.content})
109 | try:
110 | # Prepare response format
111 | response_format: dict[str, Any] = {'type': 'json_object'}
112 | if response_model is not None:
113 | schema_name = getattr(response_model, '__name__', 'structured_response')
114 | json_schema = response_model.model_json_schema()
115 | response_format = {
116 | 'type': 'json_schema',
117 | 'json_schema': {
118 | 'name': schema_name,
119 | 'schema': json_schema,
120 | },
121 | }
122 |
123 | response = await self.client.chat.completions.create(
124 | model=self.model or DEFAULT_MODEL,
125 | messages=openai_messages,
126 | temperature=self.temperature,
127 | max_tokens=self.max_tokens,
128 | response_format=response_format, # type: ignore[arg-type]
129 | )
130 | result = response.choices[0].message.content or ''
131 | return json.loads(result)
132 | except openai.RateLimitError as e:
133 | raise RateLimitError from e
134 | except Exception as e:
135 | logger.error(f'Error in generating LLM response: {e}')
136 | raise
137 |
138 | async def generate_response(
139 | self,
140 | messages: list[Message],
141 | response_model: type[BaseModel] | None = None,
142 | max_tokens: int | None = None,
143 | model_size: ModelSize = ModelSize.medium,
144 | group_id: str | None = None,
145 | prompt_name: str | None = None,
146 | ) -> dict[str, typing.Any]:
147 | if max_tokens is None:
148 | max_tokens = self.max_tokens
149 |
150 | # Add multilingual extraction instructions
151 | messages[0].content += get_extraction_language_instruction(group_id)
152 |
153 | # Wrap entire operation in tracing span
154 | with self.tracer.start_span('llm.generate') as span:
155 | attributes = {
156 | 'llm.provider': 'openai',
157 | 'model.size': model_size.value,
158 | 'max_tokens': max_tokens,
159 | }
160 | if prompt_name:
161 | attributes['prompt.name'] = prompt_name
162 | span.add_attributes(attributes)
163 |
164 | retry_count = 0
165 | last_error = None
166 |
167 | while retry_count <= self.MAX_RETRIES:
168 | try:
169 | response = await self._generate_response(
170 | messages, response_model, max_tokens=max_tokens, model_size=model_size
171 | )
172 | return response
173 |             except (RateLimitError, RefusalError) as e:
174 |                 # These errors should not trigger retries
175 |                 span.set_status('error', str(e))
176 |                 raise
177 |             except (
178 |                 openai.APITimeoutError,
179 |                 openai.APIConnectionError,
180 |                 openai.InternalServerError,
181 |             ) as e:
182 |                 # Let OpenAI's client handle these retries
183 |                 span.set_status('error', str(e))
184 |                 raise
185 | except Exception as e:
186 | last_error = e
187 |
188 | # Don't retry if we've hit the max retries
189 | if retry_count >= self.MAX_RETRIES:
190 | logger.error(f'Max retries ({self.MAX_RETRIES}) exceeded. Last error: {e}')
191 | span.set_status('error', str(e))
192 | span.record_exception(e)
193 | raise
194 |
195 | retry_count += 1
196 |
197 | # Construct a detailed error message for the LLM
198 | error_context = (
199 | f'The previous response attempt was invalid. '
200 | f'Error type: {e.__class__.__name__}. '
201 | f'Error details: {str(e)}. '
202 | f'Please try again with a valid response, ensuring the output matches '
203 | f'the expected format and constraints.'
204 | )
205 |
206 | error_message = Message(role='user', content=error_context)
207 | messages.append(error_message)
208 | logger.warning(
209 | f'Retrying after application error (attempt {retry_count}/{self.MAX_RETRIES}): {e}'
210 | )
211 |
212 | # If we somehow get here, raise the last error
213 | span.set_status('error', str(last_error))
214 | raise last_error or Exception('Max retries exceeded with no specific error')
215 |
```
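
Because `OpenAIGenericClient` only assumes an OpenAI-compatible chat-completions endpoint that can emit JSON, it can also target local servers. A hedged sketch (the base URL, key, and model name below are placeholders, not repo defaults):

```python
import asyncio

from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.openai_generic_client import OpenAIGenericClient
from graphiti_core.prompts.models import Message


async def main() -> None:
    client = OpenAIGenericClient(
        config=LLMConfig(
            api_key='unused-for-local',  # many local servers ignore the key
            model='llama3.1:8b',  # placeholder model name
            base_url='http://localhost:11434/v1',  # e.g. an Ollama-style endpoint
        ),
        max_tokens=16384,  # mirrors the constructor default for local models
    )
    result = await client.generate_response(
        [
            Message(role='system', content='You extract facts and reply in JSON.'),
            Message(role='user', content='Alice moved to Paris. Respond with JSON.'),
        ]
    )
    print(result)


asyncio.run(main())
```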
--------------------------------------------------------------------------------
/graphiti_core/llm_client/client.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import hashlib
18 | import json
19 | import logging
20 | import typing
21 | from abc import ABC, abstractmethod
22 |
23 | import httpx
24 | from diskcache import Cache
25 | from pydantic import BaseModel
26 | from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random_exponential
27 |
28 | from ..prompts.models import Message
29 | from ..tracer import NoOpTracer, Tracer
30 | from .config import DEFAULT_MAX_TOKENS, LLMConfig, ModelSize
31 | from .errors import RateLimitError
32 |
33 | DEFAULT_TEMPERATURE = 0
34 | DEFAULT_CACHE_DIR = './llm_cache'
35 |
36 |
37 | def get_extraction_language_instruction(group_id: str | None = None) -> str:
38 | """Returns instruction for language extraction behavior.
39 |
40 | Override this function to customize language extraction:
41 | - Return empty string to disable multilingual instructions
42 | - Return custom instructions for specific language requirements
43 | - Use group_id to provide different instructions per group/partition
44 |
45 | Args:
46 | group_id: Optional partition identifier for the graph
47 |
48 | Returns:
49 | str: Language instruction to append to system messages
50 | """
51 | return '\n\nAny extracted information should be returned in the same language as it was written in.'
52 |
53 |
54 | logger = logging.getLogger(__name__)
55 |
56 |
57 | def is_server_or_retry_error(exception):
58 | if isinstance(exception, RateLimitError | json.decoder.JSONDecodeError):
59 | return True
60 |
61 | return (
62 | isinstance(exception, httpx.HTTPStatusError) and 500 <= exception.response.status_code < 600
63 | )
64 |
65 |
66 | class LLMClient(ABC):
67 | def __init__(self, config: LLMConfig | None, cache: bool = False):
68 | if config is None:
69 | config = LLMConfig()
70 |
71 | self.config = config
72 | self.model = config.model
73 | self.small_model = config.small_model
74 | self.temperature = config.temperature
75 | self.max_tokens = config.max_tokens
76 | self.cache_enabled = cache
77 | self.cache_dir = None
78 | self.tracer: Tracer = NoOpTracer()
79 |
80 | # Only create the cache directory if caching is enabled
81 | if self.cache_enabled:
82 | self.cache_dir = Cache(DEFAULT_CACHE_DIR)
83 |
84 | def set_tracer(self, tracer: Tracer) -> None:
85 | """Set the tracer for this LLM client."""
86 | self.tracer = tracer
87 |
88 | def _clean_input(self, input: str) -> str:
89 | """Clean input string of invalid unicode and control characters.
90 |
91 | Args:
92 | input: Raw input string to be cleaned
93 |
94 | Returns:
95 | Cleaned string safe for LLM processing
96 | """
97 | # Clean any invalid Unicode
98 | cleaned = input.encode('utf-8', errors='ignore').decode('utf-8')
99 |
100 | # Remove zero-width characters and other invisible unicode
101 | zero_width = '\u200b\u200c\u200d\ufeff\u2060'
102 | for char in zero_width:
103 | cleaned = cleaned.replace(char, '')
104 |
105 | # Remove control characters except newlines, returns, and tabs
106 | cleaned = ''.join(char for char in cleaned if ord(char) >= 32 or char in '\n\r\t')
107 |
108 | return cleaned
109 |
110 | @retry(
111 | stop=stop_after_attempt(4),
112 | wait=wait_random_exponential(multiplier=10, min=5, max=120),
113 | retry=retry_if_exception(is_server_or_retry_error),
114 | after=lambda retry_state: logger.warning(
115 | f'Retrying {retry_state.fn.__name__ if retry_state.fn else "function"} after {retry_state.attempt_number} attempts...'
116 | )
117 | if retry_state.attempt_number > 1
118 | else None,
119 | reraise=True,
120 | )
121 | async def _generate_response_with_retry(
122 | self,
123 | messages: list[Message],
124 | response_model: type[BaseModel] | None = None,
125 | max_tokens: int = DEFAULT_MAX_TOKENS,
126 | model_size: ModelSize = ModelSize.medium,
127 | ) -> dict[str, typing.Any]:
128 | try:
129 | return await self._generate_response(messages, response_model, max_tokens, model_size)
130 | except (httpx.HTTPStatusError, RateLimitError) as e:
131 | raise e
132 |
133 | @abstractmethod
134 | async def _generate_response(
135 | self,
136 | messages: list[Message],
137 | response_model: type[BaseModel] | None = None,
138 | max_tokens: int = DEFAULT_MAX_TOKENS,
139 | model_size: ModelSize = ModelSize.medium,
140 | ) -> dict[str, typing.Any]:
141 | pass
142 |
143 | def _get_cache_key(self, messages: list[Message]) -> str:
144 | # Create a unique cache key based on the messages and model
145 | message_str = json.dumps([m.model_dump() for m in messages], sort_keys=True)
146 | key_str = f'{self.model}:{message_str}'
147 | return hashlib.md5(key_str.encode()).hexdigest()
148 |
149 | async def generate_response(
150 | self,
151 | messages: list[Message],
152 | response_model: type[BaseModel] | None = None,
153 | max_tokens: int | None = None,
154 | model_size: ModelSize = ModelSize.medium,
155 | group_id: str | None = None,
156 | prompt_name: str | None = None,
157 | ) -> dict[str, typing.Any]:
158 | if max_tokens is None:
159 | max_tokens = self.max_tokens
160 |
161 | if response_model is not None:
162 | serialized_model = json.dumps(response_model.model_json_schema())
163 | messages[
164 | -1
165 | ].content += (
166 | f'\n\nRespond with a JSON object in the following format:\n\n{serialized_model}'
167 | )
168 |
169 | # Add multilingual extraction instructions
170 | messages[0].content += get_extraction_language_instruction(group_id)
171 |
172 | for message in messages:
173 | message.content = self._clean_input(message.content)
174 |
175 | # Wrap entire operation in tracing span
176 | with self.tracer.start_span('llm.generate') as span:
177 | attributes = {
178 | 'llm.provider': self._get_provider_type(),
179 | 'model.size': model_size.value,
180 | 'max_tokens': max_tokens,
181 | 'cache.enabled': self.cache_enabled,
182 | }
183 | if prompt_name:
184 | attributes['prompt.name'] = prompt_name
185 | span.add_attributes(attributes)
186 |
187 | # Check cache first
188 | if self.cache_enabled and self.cache_dir is not None:
189 | cache_key = self._get_cache_key(messages)
190 | cached_response = self.cache_dir.get(cache_key)
191 | if cached_response is not None:
192 | logger.debug(f'Cache hit for {cache_key}')
193 | span.add_attributes({'cache.hit': True})
194 | return cached_response
195 |
196 | span.add_attributes({'cache.hit': False})
197 |
198 | # Execute LLM call
199 | try:
200 | response = await self._generate_response_with_retry(
201 | messages, response_model, max_tokens, model_size
202 | )
203 | except Exception as e:
204 | span.set_status('error', str(e))
205 | span.record_exception(e)
206 | raise
207 |
208 | # Cache response if enabled
209 | if self.cache_enabled and self.cache_dir is not None:
210 | cache_key = self._get_cache_key(messages)
211 | self.cache_dir.set(cache_key, response)
212 |
213 | return response
214 |
215 | def _get_provider_type(self) -> str:
216 | """Get provider type from class name."""
217 | class_name = self.__class__.__name__.lower()
218 | if 'openai' in class_name:
219 | return 'openai'
220 | elif 'anthropic' in class_name:
221 | return 'anthropic'
222 | elif 'gemini' in class_name:
223 | return 'gemini'
224 | elif 'groq' in class_name:
225 | return 'groq'
226 | else:
227 | return 'unknown'
228 |
229 | def _get_failed_generation_log(self, messages: list[Message], output: str | None) -> str:
230 | """
231 |         Build a debug log string of the full input messages and the raw output (if any) for failed generations.
232 | """
233 | log = ''
234 | log += f'Input messages: {json.dumps([m.model_dump() for m in messages], indent=2)}\n'
235 | if output is not None:
236 | if len(output) > 4000:
237 | log += f'Raw output: {output[:2000]}... (truncated) ...{output[-2000:]}\n'
238 | else:
239 | log += f'Raw output: {output}\n'
240 | else:
241 | log += 'No raw output available'
242 | return log
243 |
```
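
Only `_generate_response` is abstract, so a new backend inherits input cleaning, caching, tracing spans, and retry logic from `generate_response`. A minimal sketch with a hypothetical `EchoClient` (illustration only, not a real backend):

```python
import typing

from pydantic import BaseModel

from graphiti_core.llm_client.client import LLMClient
from graphiti_core.llm_client.config import DEFAULT_MAX_TOKENS, ModelSize
from graphiti_core.prompts.models import Message


class EchoClient(LLMClient):
    """Hypothetical backend that parrots the last message back as JSON."""

    async def _generate_response(
        self,
        messages: list[Message],
        response_model: type[BaseModel] | None = None,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        model_size: ModelSize = ModelSize.medium,
    ) -> dict[str, typing.Any]:
        # A real backend would call its model API here; the base class wraps
        # this method with retries, optional disk caching, and tracing.
        return {'content': messages[-1].content}


client = EchoClient(None)  # None falls back to LLMConfig() defaults in the base class
```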
--------------------------------------------------------------------------------
/graphiti_core/prompts/dedupe_nodes.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from typing import Any, Protocol, TypedDict
18 |
19 | from pydantic import BaseModel, Field
20 |
21 | from .models import Message, PromptFunction, PromptVersion
22 | from .prompt_helpers import to_prompt_json
23 |
24 |
25 | class NodeDuplicate(BaseModel):
26 | id: int = Field(..., description='integer id of the entity')
27 | duplicate_idx: int = Field(
28 | ...,
29 | description='idx of the duplicate entity. If no duplicate entities are found, default to -1.',
30 | )
31 | name: str = Field(
32 | ...,
33 | description='Name of the entity. Should be the most complete and descriptive name of the entity. Do not include any JSON formatting in the Entity name such as {}.',
34 | )
35 | duplicates: list[int] = Field(
36 | ...,
37 | description='idx of all entities that are a duplicate of the entity with the above id.',
38 | )
39 |
40 |
41 | class NodeResolutions(BaseModel):
42 | entity_resolutions: list[NodeDuplicate] = Field(..., description='List of resolved nodes')
43 |
44 |
45 | class Prompt(Protocol):
46 | node: PromptVersion
47 | node_list: PromptVersion
48 | nodes: PromptVersion
49 |
50 |
51 | class Versions(TypedDict):
52 | node: PromptFunction
53 | node_list: PromptFunction
54 | nodes: PromptFunction
55 |
56 |
57 | def node(context: dict[str, Any]) -> list[Message]:
58 | return [
59 | Message(
60 | role='system',
61 | content='You are a helpful assistant that determines whether or not a NEW ENTITY is a duplicate of any EXISTING ENTITIES.',
62 | ),
63 | Message(
64 | role='user',
65 | content=f"""
66 | <PREVIOUS MESSAGES>
67 | {to_prompt_json([ep for ep in context['previous_episodes']])}
68 | </PREVIOUS MESSAGES>
69 | <CURRENT MESSAGE>
70 | {context['episode_content']}
71 | </CURRENT MESSAGE>
72 | <NEW ENTITY>
73 | {to_prompt_json(context['extracted_node'])}
74 | </NEW ENTITY>
75 | <ENTITY TYPE DESCRIPTION>
76 | {to_prompt_json(context['entity_type_description'])}
77 | </ENTITY TYPE DESCRIPTION>
78 |
79 | <EXISTING ENTITIES>
80 | {to_prompt_json(context['existing_nodes'])}
81 | </EXISTING ENTITIES>
82 |
83 | Given the above EXISTING ENTITIES and their attributes, MESSAGE, and PREVIOUS MESSAGES, determine if the NEW ENTITY extracted from the conversation
84 | is a duplicate entity of one of the EXISTING ENTITIES.
85 |
86 | Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
87 | Semantic Equivalence: if a descriptive label in existing_entities clearly refers to a named entity in context, treat them as duplicates.
88 |
89 | Do NOT mark entities as duplicates if:
90 | - They are related but distinct.
91 | - They have similar names or purposes but refer to separate instances or concepts.
92 |
93 | TASK:
94 | 1. Compare `new_entity` against each item in `existing_entities`.
95 | 2. If it refers to the same real-world object or concept, collect its index.
96 | 3. Let `duplicate_idx` = the smallest collected index, or -1 if none.
97 | 4. Let `duplicates` = the sorted list of all collected indices (empty list if none).
98 |
99 | Respond with a JSON object containing an "entity_resolutions" array with a single entry:
100 | {{
101 | "entity_resolutions": [
102 | {{
103 | "id": integer id from NEW ENTITY,
104 | "name": the best full name for the entity,
105 | "duplicate_idx": integer index of the best duplicate in EXISTING ENTITIES, or -1 if none,
106 | "duplicates": sorted list of all duplicate indices you collected (deduplicate the list, use [] when none)
107 | }}
108 | ]
109 | }}
110 |
111 | Only reference indices that appear in EXISTING ENTITIES, and return [] / -1 when unsure.
112 | """,
113 | ),
114 | ]
115 |
116 |
117 | def nodes(context: dict[str, Any]) -> list[Message]:
118 | return [
119 | Message(
120 | role='system',
121 | content='You are a helpful assistant that determines whether or not ENTITIES extracted from a conversation are duplicates'
122 | ' of existing entities.',
123 | ),
124 | Message(
125 | role='user',
126 | content=f"""
127 | <PREVIOUS MESSAGES>
128 | {to_prompt_json([ep for ep in context['previous_episodes']])}
129 | </PREVIOUS MESSAGES>
130 | <CURRENT MESSAGE>
131 | {context['episode_content']}
132 | </CURRENT MESSAGE>
133 |
134 |
135 | Each of the following ENTITIES were extracted from the CURRENT MESSAGE.
136 | Each entity in ENTITIES is represented as a JSON object with the following structure:
137 | {{
138 | id: integer id of the entity,
139 | name: "name of the entity",
140 | entity_type: ["Entity", "<optional additional label>", ...],
141 | entity_type_description: "Description of what the entity type represents"
142 | }}
143 |
144 | <ENTITIES>
145 | {to_prompt_json(context['extracted_nodes'])}
146 | </ENTITIES>
147 |
148 | <EXISTING ENTITIES>
149 | {to_prompt_json(context['existing_nodes'])}
150 | </EXISTING ENTITIES>
151 |
152 | Each entry in EXISTING ENTITIES is an object with the following structure:
153 | {{
154 | idx: integer index of the candidate entity (use this when referencing a duplicate),
155 | name: "name of the candidate entity",
156 | entity_types: ["Entity", "<optional additional label>", ...],
157 | ...<additional attributes such as summaries or metadata>
158 | }}
159 |
160 | For each of the above ENTITIES, determine if the entity is a duplicate of any of the EXISTING ENTITIES.
161 |
162 | Entities should only be considered duplicates if they refer to the *same real-world object or concept*.
163 |
164 | Do NOT mark entities as duplicates if:
165 | - They are related but distinct.
166 | - They have similar names or purposes but refer to separate instances or concepts.
167 |
168 | Task:
169 | ENTITIES contains {len(context['extracted_nodes'])} entities with IDs 0 through {len(context['extracted_nodes']) - 1}.
170 | Your response MUST include EXACTLY {len(context['extracted_nodes'])} resolutions with IDs 0 through {len(context['extracted_nodes']) - 1}. Do not skip or add IDs.
171 |
172 | For every entity, return an object with the following keys:
173 | {{
174 | "id": integer id from ENTITIES,
175 | "name": the best full name for the entity (preserve the original name unless a duplicate has a more complete name),
176 | "duplicate_idx": the idx of the EXISTING ENTITY that is the best duplicate match, or -1 if there is no duplicate,
177 | "duplicates": a sorted list of all idx values from EXISTING ENTITIES that refer to duplicates (deduplicate the list, use [] when none or unsure)
178 | }}
179 |
180 | - Only use idx values that appear in EXISTING ENTITIES.
181 | - Set duplicate_idx to the smallest idx you collected for that entity, or -1 if duplicates is empty.
182 | - Never fabricate entities or indices.
183 | """,
184 | ),
185 | ]
186 |
187 |
188 | def node_list(context: dict[str, Any]) -> list[Message]:
189 | return [
190 | Message(
191 | role='system',
192 | content='You are a helpful assistant that de-duplicates nodes from node lists.',
193 | ),
194 | Message(
195 | role='user',
196 | content=f"""
197 | Given the following context, deduplicate a list of nodes:
198 |
199 | Nodes:
200 | {to_prompt_json(context['nodes'])}
201 |
202 | Task:
203 | 1. Group nodes together such that all duplicate nodes are in the same list of uuids
204 | 2. All duplicate uuids should be grouped together in the same list
205 | 3. Also return a short summary that synthesizes the summaries of the grouped nodes
206 |
207 | Guidelines:
208 | 1. Each uuid from the list of nodes should appear EXACTLY once in your response
209 | 2. If a node has no duplicates, it should appear in the response in a list of only one uuid
210 |
211 | Respond with a JSON object in the following format:
212 | {{
213 | "nodes": [
214 | {{
215 | "uuids": ["5d643020624c42fa9de13f97b1b3fa39", "node that is a duplicate of 5d643020624c42fa9de13f97b1b3fa39"],
216 |             "summary": "Brief summary synthesized from the summaries of the nodes in this group."
217 | }}
218 | ]
219 | }}
220 | """,
221 | ),
222 | ]
223 |
224 |
225 | versions: Versions = {'node': node, 'node_list': node_list, 'nodes': nodes}
226 |
```
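
The response contract spelled out in these prompts maps directly onto the Pydantic models at the top of the file, so a model reply can be validated in one call. A hedged sketch with invented context values:

```python
from graphiti_core.prompts.dedupe_nodes import NodeResolutions, versions

# Invented example context; real values come from episode processing upstream.
context = {
    'previous_episodes': [],
    'episode_content': 'Alice met Bob at Acme Corp.',
    'extracted_nodes': [
        {
            'id': 0,
            'name': 'Alice',
            'entity_type': ['Entity'],
            'entity_type_description': 'A person mentioned in the episode',
        }
    ],
    'existing_nodes': [{'idx': 0, 'name': 'Alice Smith', 'entity_types': ['Entity']}],
}
messages = versions['nodes'](context)  # list[Message], ready for an LLM client

# Validate a (mocked) model reply against the expected schema.
raw_reply = {
    'entity_resolutions': [
        {'id': 0, 'name': 'Alice Smith', 'duplicate_idx': 0, 'duplicates': [0]}
    ]
}
resolutions = NodeResolutions.model_validate(raw_reply)
assert resolutions.entity_resolutions[0].duplicate_idx == 0
```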
--------------------------------------------------------------------------------
/graphiti_core/search/search_filters.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from datetime import datetime
18 | from enum import Enum
19 | from typing import Any
20 |
21 | from pydantic import BaseModel, Field
22 |
23 | from graphiti_core.driver.driver import GraphProvider
24 |
25 |
26 | class ComparisonOperator(Enum):
27 | equals = '='
28 | not_equals = '<>'
29 | greater_than = '>'
30 | less_than = '<'
31 | greater_than_equal = '>='
32 | less_than_equal = '<='
33 | is_null = 'IS NULL'
34 | is_not_null = 'IS NOT NULL'
35 |
36 |
37 | class DateFilter(BaseModel):
38 | date: datetime | None = Field(description='A datetime to filter on')
39 | comparison_operator: ComparisonOperator = Field(
40 | description='Comparison operator for date filter'
41 | )
42 |
43 |
44 | class SearchFilters(BaseModel):
45 | node_labels: list[str] | None = Field(
46 | default=None, description='List of node labels to filter on'
47 | )
48 | edge_types: list[str] | None = Field(
49 | default=None, description='List of edge types to filter on'
50 | )
51 | valid_at: list[list[DateFilter]] | None = Field(default=None)
52 | invalid_at: list[list[DateFilter]] | None = Field(default=None)
53 | created_at: list[list[DateFilter]] | None = Field(default=None)
54 | expired_at: list[list[DateFilter]] | None = Field(default=None)
55 | edge_uuids: list[str] | None = Field(default=None)
56 |
57 |
58 | def cypher_to_opensearch_operator(op: ComparisonOperator) -> str:
59 | mapping = {
60 | ComparisonOperator.greater_than: 'gt',
61 | ComparisonOperator.less_than: 'lt',
62 | ComparisonOperator.greater_than_equal: 'gte',
63 | ComparisonOperator.less_than_equal: 'lte',
64 | }
65 | return mapping.get(op, op.value)
66 |
67 |
68 | def node_search_filter_query_constructor(
69 | filters: SearchFilters,
70 | provider: GraphProvider,
71 | ) -> tuple[list[str], dict[str, Any]]:
72 | filter_queries: list[str] = []
73 | filter_params: dict[str, Any] = {}
74 |
75 | if filters.node_labels is not None:
76 | if provider == GraphProvider.KUZU:
77 | node_label_filter = 'list_has_all(n.labels, $labels)'
78 | filter_params['labels'] = filters.node_labels
79 | else:
80 | node_labels = '|'.join(filters.node_labels)
81 | node_label_filter = 'n:' + node_labels
82 | filter_queries.append(node_label_filter)
83 |
84 | return filter_queries, filter_params
85 |
86 |
87 | def date_filter_query_constructor(
88 | value_name: str, param_name: str, operator: ComparisonOperator
89 | ) -> str:
90 | query = '(' + value_name + ' '
91 |
92 | if operator == ComparisonOperator.is_null or operator == ComparisonOperator.is_not_null:
93 | query += operator.value + ')'
94 | else:
95 | query += operator.value + ' ' + param_name + ')'
96 |
97 | return query
98 |
99 |
100 | def edge_search_filter_query_constructor(
101 | filters: SearchFilters,
102 | provider: GraphProvider,
103 | ) -> tuple[list[str], dict[str, Any]]:
104 | filter_queries: list[str] = []
105 | filter_params: dict[str, Any] = {}
106 |
107 | if filters.edge_types is not None:
108 | edge_types = filters.edge_types
109 | filter_queries.append('e.name in $edge_types')
110 | filter_params['edge_types'] = edge_types
111 |
112 | if filters.edge_uuids is not None:
113 | filter_queries.append('e.uuid in $edge_uuids')
114 | filter_params['edge_uuids'] = filters.edge_uuids
115 |
116 | if filters.node_labels is not None:
117 | if provider == GraphProvider.KUZU:
118 | node_label_filter = (
119 | 'list_has_all(n.labels, $labels) AND list_has_all(m.labels, $labels)'
120 | )
121 | filter_params['labels'] = filters.node_labels
122 | else:
123 | node_labels = '|'.join(filters.node_labels)
124 | node_label_filter = 'n:' + node_labels + ' AND m:' + node_labels
125 | filter_queries.append(node_label_filter)
126 |
127 | if filters.valid_at is not None:
128 | valid_at_filter = '('
129 | for i, or_list in enumerate(filters.valid_at):
130 | for j, date_filter in enumerate(or_list):
131 | if date_filter.comparison_operator not in [
132 | ComparisonOperator.is_null,
133 | ComparisonOperator.is_not_null,
134 | ]:
135 |                     filter_params[f'valid_at_{i}_{j}'] = date_filter.date  # key on (i, j) so OR groups don't overwrite each other's params
136 |
137 | and_filters = [
138 | date_filter_query_constructor(
139 |                         'e.valid_at', f'$valid_at_{i}_{j}', date_filter.comparison_operator
140 | )
141 | for j, date_filter in enumerate(or_list)
142 | ]
143 | and_filter_query = ''
144 | for j, and_filter in enumerate(and_filters):
145 | and_filter_query += and_filter
146 | if j != len(and_filters) - 1:
147 | and_filter_query += ' AND '
148 |
149 | valid_at_filter += and_filter_query
150 |
151 | if i == len(filters.valid_at) - 1:
152 | valid_at_filter += ')'
153 | else:
154 | valid_at_filter += ' OR '
155 |
156 | filter_queries.append(valid_at_filter)
157 |
158 | if filters.invalid_at is not None:
159 | invalid_at_filter = '('
160 | for i, or_list in enumerate(filters.invalid_at):
161 | for j, date_filter in enumerate(or_list):
162 | if date_filter.comparison_operator not in [
163 | ComparisonOperator.is_null,
164 | ComparisonOperator.is_not_null,
165 | ]:
166 |                     filter_params[f'invalid_at_{i}_{j}'] = date_filter.date
167 |
168 | and_filters = [
169 | date_filter_query_constructor(
170 |                         'e.invalid_at', f'$invalid_at_{i}_{j}', date_filter.comparison_operator
171 | )
172 | for j, date_filter in enumerate(or_list)
173 | ]
174 | and_filter_query = ''
175 | for j, and_filter in enumerate(and_filters):
176 | and_filter_query += and_filter
177 | if j != len(and_filters) - 1:
178 | and_filter_query += ' AND '
179 |
180 | invalid_at_filter += and_filter_query
181 |
182 | if i == len(filters.invalid_at) - 1:
183 | invalid_at_filter += ')'
184 | else:
185 | invalid_at_filter += ' OR '
186 |
187 | filter_queries.append(invalid_at_filter)
188 |
189 | if filters.created_at is not None:
190 | created_at_filter = '('
191 | for i, or_list in enumerate(filters.created_at):
192 | for j, date_filter in enumerate(or_list):
193 | if date_filter.comparison_operator not in [
194 | ComparisonOperator.is_null,
195 | ComparisonOperator.is_not_null,
196 | ]:
197 |                     filter_params[f'created_at_{i}_{j}'] = date_filter.date
198 |
199 | and_filters = [
200 | date_filter_query_constructor(
201 |                         'e.created_at', f'$created_at_{i}_{j}', date_filter.comparison_operator
202 | )
203 | for j, date_filter in enumerate(or_list)
204 | ]
205 | and_filter_query = ''
206 | for j, and_filter in enumerate(and_filters):
207 | and_filter_query += and_filter
208 | if j != len(and_filters) - 1:
209 | and_filter_query += ' AND '
210 |
211 | created_at_filter += and_filter_query
212 |
213 | if i == len(filters.created_at) - 1:
214 | created_at_filter += ')'
215 | else:
216 | created_at_filter += ' OR '
217 |
218 | filter_queries.append(created_at_filter)
219 |
220 | if filters.expired_at is not None:
221 | expired_at_filter = '('
222 | for i, or_list in enumerate(filters.expired_at):
223 | for j, date_filter in enumerate(or_list):
224 | if date_filter.comparison_operator not in [
225 | ComparisonOperator.is_null,
226 | ComparisonOperator.is_not_null,
227 | ]:
228 |                     filter_params[f'expired_at_{i}_{j}'] = date_filter.date
229 |
230 | and_filters = [
231 | date_filter_query_constructor(
232 |                         'e.expired_at', f'$expired_at_{i}_{j}', date_filter.comparison_operator
233 | )
234 | for j, date_filter in enumerate(or_list)
235 | ]
236 | and_filter_query = ''
237 | for j, and_filter in enumerate(and_filters):
238 | and_filter_query += and_filter
239 | if j != len(and_filters) - 1:
240 | and_filter_query += ' AND '
241 |
242 | expired_at_filter += and_filter_query
243 |
244 | if i == len(filters.expired_at) - 1:
245 | expired_at_filter += ')'
246 | else:
247 | expired_at_filter += ' OR '
248 |
249 | filter_queries.append(expired_at_filter)
250 |
251 | return filter_queries, filter_params
252 |
```
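
A minimal usage sketch for the constructors above. It assumes `GraphProvider.NEO4J` is a member of the enum alongside `KUZU`, and that the returned fragments are ANDed into the WHERE clause of a `(n)-[e]->(m)` match, which is how the search code consumes them; the filter values themselves are illustrative.

```python
from datetime import datetime, timezone

from graphiti_core.driver.driver import GraphProvider
from graphiti_core.search.search_filters import (
    ComparisonOperator,
    DateFilter,
    SearchFilters,
    edge_search_filter_query_constructor,
)

filters = SearchFilters(
    edge_types=['WORKS_AT'],
    node_labels=['Person'],
    # Outer list entries are ORed together; inner DateFilters are ANDed
    valid_at=[
        [
            DateFilter(
                date=datetime(2024, 1, 1, tzinfo=timezone.utc),
                comparison_operator=ComparisonOperator.greater_than_equal,
            )
        ]
    ],
)

queries, params = edge_search_filter_query_constructor(filters, GraphProvider.NEO4J)
where_clause = ' AND '.join(queries)
print(where_clause)  # e.name in $edge_types AND n:Person AND m:Person AND ((e.valid_at >= $valid_at_0_0))
print(params)  # {'edge_types': ['WORKS_AT'], 'valid_at_0_0': datetime(...)}
```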
--------------------------------------------------------------------------------
/mcp_server/tests/test_http_integration.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """
3 | Integration test for MCP server using HTTP streaming transport.
4 | This avoids the stdio subprocess timing issues.
5 | """
6 |
7 | import asyncio
8 | import json
9 | import sys
10 | import time
11 |
12 | from mcp.client.session import ClientSession
13 |
14 |
15 | async def test_http_transport(base_url: str = 'http://localhost:8000'):
16 | """Test MCP server with HTTP streaming transport."""
17 |
18 | # Import the streamable http client
19 | try:
20 | from mcp.client.streamable_http import streamablehttp_client as http_client
21 | except ImportError:
22 | print('❌ Streamable HTTP client not available in MCP SDK')
23 | return False
24 |
25 | test_group_id = f'test_http_{int(time.time())}'
26 |
27 | print('🚀 Testing MCP Server with HTTP streaming transport')
28 | print(f' Server URL: {base_url}')
29 | print(f' Test Group: {test_group_id}')
30 | print('=' * 60)
31 |
32 | try:
33 | # Connect to the server via HTTP
34 | print('\n🔌 Connecting to server...')
35 |         async with http_client(base_url) as (read_stream, write_stream, _):  # third item is the get_session_id callback
36 | session = ClientSession(read_stream, write_stream)
37 | await session.initialize()
38 | print('✅ Connected successfully')
39 |
40 | # Test 1: List tools
41 | print('\n📋 Test 1: Listing tools...')
42 | try:
43 | result = await session.list_tools()
44 | tools = [tool.name for tool in result.tools]
45 |
46 | expected = [
47 | 'add_memory',
48 | 'search_memory_nodes',
49 | 'search_memory_facts',
50 | 'get_episodes',
51 | 'delete_episode',
52 | 'clear_graph',
53 | ]
54 |
55 | found = [t for t in expected if t in tools]
56 | print(f' ✅ Found {len(tools)} tools ({len(found)}/{len(expected)} expected)')
57 | for tool in tools[:5]:
58 | print(f' - {tool}')
59 |
60 | except Exception as e:
61 | print(f' ❌ Failed: {e}')
62 | return False
63 |
64 | # Test 2: Add memory
65 | print('\n📝 Test 2: Adding memory...')
66 | try:
67 | result = await session.call_tool(
68 | 'add_memory',
69 | {
70 | 'name': 'Integration Test Episode',
71 | 'episode_body': 'This is a test episode created via HTTP transport integration test.',
72 | 'group_id': test_group_id,
73 | 'source': 'text',
74 | 'source_description': 'HTTP Integration Test',
75 | },
76 | )
77 |
78 | if result.content and result.content[0].text:
79 | response = result.content[0].text
80 | if 'success' in response.lower() or 'queued' in response.lower():
81 | print(' ✅ Memory added successfully')
82 | else:
83 | print(f' ❌ Unexpected response: {response[:100]}')
84 | else:
85 | print(' ❌ No content in response')
86 |
87 | except Exception as e:
88 | print(f' ❌ Failed: {e}')
89 |
90 | # Test 3: Search nodes (with delay for processing)
91 | print('\n🔍 Test 3: Searching nodes...')
92 | await asyncio.sleep(2) # Wait for async processing
93 |
94 | try:
95 | result = await session.call_tool(
96 | 'search_memory_nodes',
97 | {'query': 'integration test episode', 'group_ids': [test_group_id], 'limit': 5},
98 | )
99 |
100 | if result.content and result.content[0].text:
101 | response = result.content[0].text
102 | try:
103 | data = json.loads(response)
104 | nodes = data.get('nodes', [])
105 | print(f' ✅ Search returned {len(nodes)} nodes')
106 |                 except json.JSONDecodeError:
107 | print(f' ✅ Search completed: {response[:100]}')
108 | else:
109 | print(' ⚠️ No results (may be processing)')
110 |
111 | except Exception as e:
112 | print(f' ❌ Failed: {e}')
113 |
114 | # Test 4: Get episodes
115 | print('\n📚 Test 4: Getting episodes...')
116 | try:
117 | result = await session.call_tool(
118 | 'get_episodes', {'group_ids': [test_group_id], 'limit': 10}
119 | )
120 |
121 | if result.content and result.content[0].text:
122 | response = result.content[0].text
123 | try:
124 | data = json.loads(response)
125 | episodes = data.get('episodes', [])
126 | print(f' ✅ Found {len(episodes)} episodes')
127 |                 except json.JSONDecodeError:
128 | print(f' ✅ Episodes retrieved: {response[:100]}')
129 | else:
130 | print(' ⚠️ No episodes found')
131 |
132 | except Exception as e:
133 | print(f' ❌ Failed: {e}')
134 |
135 | # Test 5: Clear graph
136 | print('\n🧹 Test 5: Clearing graph...')
137 | try:
138 | result = await session.call_tool('clear_graph', {'group_id': test_group_id})
139 |
140 | if result.content and result.content[0].text:
141 | response = result.content[0].text
142 | if 'success' in response.lower() or 'cleared' in response.lower():
143 | print(' ✅ Graph cleared successfully')
144 | else:
145 | print(f' ✅ Clear completed: {response[:100]}')
146 | else:
147 | print(' ❌ No response')
148 |
149 | except Exception as e:
150 | print(f' ❌ Failed: {e}')
151 |
152 | print('\n' + '=' * 60)
153 | print('✅ All integration tests completed!')
154 | return True
155 |
156 | except Exception as e:
157 | print(f'\n❌ Connection failed: {e}')
158 | return False
159 |
160 |
161 | async def test_sse_transport(base_url: str = 'http://localhost:8000'):
162 | """Test MCP server with SSE transport."""
163 |
164 | # Import the SSE client
165 | try:
166 | from mcp.client.sse import sse_client
167 | except ImportError:
168 | print('❌ SSE client not available in MCP SDK')
169 | return False
170 |
171 | test_group_id = f'test_sse_{int(time.time())}'
172 |
173 | print('🚀 Testing MCP Server with SSE transport')
174 | print(f' Server URL: {base_url}/sse')
175 | print(f' Test Group: {test_group_id}')
176 | print('=' * 60)
177 |
178 | try:
179 | # Connect to the server via SSE
180 | print('\n🔌 Connecting to server...')
181 | async with sse_client(f'{base_url}/sse') as (read_stream, write_stream):
182 | session = ClientSession(read_stream, write_stream)
183 | await session.initialize()
184 | print('✅ Connected successfully')
185 |
186 | # Run same tests as HTTP
187 | print('\n📋 Test 1: Listing tools...')
188 | try:
189 | result = await session.list_tools()
190 | tools = [tool.name for tool in result.tools]
191 | print(f' ✅ Found {len(tools)} tools')
192 | for tool in tools[:3]:
193 | print(f' - {tool}')
194 | except Exception as e:
195 | print(f' ❌ Failed: {e}')
196 | return False
197 |
198 | print('\n' + '=' * 60)
199 | print('✅ SSE transport test completed!')
200 | return True
201 |
202 | except Exception as e:
203 | print(f'\n❌ SSE connection failed: {e}')
204 | return False
205 |
206 |
207 | async def main():
208 | """Run integration tests."""
209 |
210 | # Check command line arguments
211 | if len(sys.argv) < 2:
212 | print('Usage: python test_http_integration.py <transport> [host] [port]')
213 | print(' transport: http or sse')
214 | print(' host: server host (default: localhost)')
215 | print(' port: server port (default: 8000)')
216 | sys.exit(1)
217 |
218 | transport = sys.argv[1].lower()
219 | host = sys.argv[2] if len(sys.argv) > 2 else 'localhost'
220 | port = sys.argv[3] if len(sys.argv) > 3 else '8000'
221 | base_url = f'http://{host}:{port}'
222 |
223 | # Check if server is running
224 | import httpx
225 |
226 | try:
227 | async with httpx.AsyncClient() as client:
228 | # Try to connect to the server
229 | await client.get(base_url, timeout=2.0)
230 |     except Exception:
231 | print(f'⚠️ Server not responding at {base_url}')
232 | print('Please start the server with one of these commands:')
233 | print(f' uv run main.py --transport http --port {port}')
234 | print(f' uv run main.py --transport sse --port {port}')
235 | sys.exit(1)
236 |
237 | # Run the appropriate test
238 | if transport == 'http':
239 | success = await test_http_transport(base_url)
240 | elif transport == 'sse':
241 | success = await test_sse_transport(base_url)
242 | else:
243 | print(f'❌ Unknown transport: {transport}')
244 | sys.exit(1)
245 |
246 | sys.exit(0 if success else 1)
247 |
248 |
249 | if __name__ == '__main__':
250 | asyncio.run(main())
251 |
```
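
The script above is driven from the command line (`python test_http_integration.py http [host] [port]`), but `test_http_transport` can also be imported. A minimal sketch of a programmatic runner; the server URL, and the assumption that a server is already listening there, are the only things not taken from the module:

```python
import asyncio

from test_http_integration import test_http_transport

# Assumes an MCP server was started first, e.g.: uv run main.py --transport http --port 8000
ok = asyncio.run(test_http_transport('http://localhost:8000'))
raise SystemExit(0 if ok else 1)
```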
--------------------------------------------------------------------------------
/graphiti_core/utils/maintenance/dedup_helpers.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2024, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | from __future__ import annotations
18 |
19 | import math
20 | import re
21 | from collections import defaultdict
22 | from collections.abc import Iterable
23 | from dataclasses import dataclass, field
24 | from functools import lru_cache
25 | from hashlib import blake2b
26 | from typing import TYPE_CHECKING
27 |
28 | if TYPE_CHECKING:
29 | from graphiti_core.nodes import EntityNode
30 |
31 | _NAME_ENTROPY_THRESHOLD = 1.5
32 | _MIN_NAME_LENGTH = 6
33 | _MIN_TOKEN_COUNT = 2
34 | _FUZZY_JACCARD_THRESHOLD = 0.9
35 | _MINHASH_PERMUTATIONS = 32
36 | _MINHASH_BAND_SIZE = 4
37 |
38 |
39 | def _normalize_string_exact(name: str) -> str:
40 | """Lowercase text and collapse whitespace so equal names map to the same key."""
41 | normalized = re.sub(r'[\s]+', ' ', name.lower())
42 | return normalized.strip()
43 |
44 |
45 | def _normalize_name_for_fuzzy(name: str) -> str:
46 | """Produce a fuzzier form that keeps alphanumerics and apostrophes for n-gram shingles."""
47 | normalized = re.sub(r"[^a-z0-9' ]", ' ', _normalize_string_exact(name))
48 | normalized = normalized.strip()
49 | return re.sub(r'[\s]+', ' ', normalized)
50 |
51 |
52 | def _name_entropy(normalized_name: str) -> float:
53 | """Approximate text specificity using Shannon entropy over characters.
54 |
55 | We strip spaces, count how often each character appears, and sum
56 | probability * -log2(probability). Short or repetitive names yield low
57 | entropy, which signals we should defer resolution to the LLM instead of
58 | trusting fuzzy similarity.
59 | """
60 | if not normalized_name:
61 | return 0.0
62 |
63 | counts: dict[str, int] = {}
64 | for char in normalized_name.replace(' ', ''):
65 | counts[char] = counts.get(char, 0) + 1
66 |
67 | total = sum(counts.values())
68 | if total == 0:
69 | return 0.0
70 |
71 | entropy = 0.0
72 | for count in counts.values():
73 | probability = count / total
74 | entropy -= probability * math.log2(probability)
75 |
76 | return entropy
77 |
78 |
79 | def _has_high_entropy(normalized_name: str) -> bool:
80 | """Filter out very short or low-entropy names that are unreliable for fuzzy matching."""
81 | token_count = len(normalized_name.split())
82 | if len(normalized_name) < _MIN_NAME_LENGTH and token_count < _MIN_TOKEN_COUNT:
83 | return False
84 |
85 | return _name_entropy(normalized_name) >= _NAME_ENTROPY_THRESHOLD
86 |
87 |
88 | def _shingles(normalized_name: str) -> set[str]:
89 | """Create 3-gram shingles from the normalized name for MinHash calculations."""
90 | cleaned = normalized_name.replace(' ', '')
91 |     if len(cleaned) < 3:  # too short for a single 3-gram; fall back to the whole string
92 | return {cleaned} if cleaned else set()
93 |
94 | return {cleaned[i : i + 3] for i in range(len(cleaned) - 2)}
95 |
96 |
97 | def _hash_shingle(shingle: str, seed: int) -> int:
98 | """Generate a deterministic 64-bit hash for a shingle given the permutation seed."""
99 | digest = blake2b(f'{seed}:{shingle}'.encode(), digest_size=8)
100 | return int.from_bytes(digest.digest(), 'big')
101 |
102 |
103 | def _minhash_signature(shingles: Iterable[str]) -> tuple[int, ...]:
104 | """Compute the MinHash signature for the shingle set across predefined permutations."""
105 | if not shingles:
106 | return tuple()
107 |
108 | seeds = range(_MINHASH_PERMUTATIONS)
109 | signature: list[int] = []
110 | for seed in seeds:
111 | min_hash = min(_hash_shingle(shingle, seed) for shingle in shingles)
112 | signature.append(min_hash)
113 |
114 | return tuple(signature)
115 |
116 |
117 | def _lsh_bands(signature: Iterable[int]) -> list[tuple[int, ...]]:
118 | """Split the MinHash signature into fixed-size bands for locality-sensitive hashing."""
119 | signature_list = list(signature)
120 | if not signature_list:
121 | return []
122 |
123 | bands: list[tuple[int, ...]] = []
124 | for start in range(0, len(signature_list), _MINHASH_BAND_SIZE):
125 | band = tuple(signature_list[start : start + _MINHASH_BAND_SIZE])
126 | if len(band) == _MINHASH_BAND_SIZE:
127 | bands.append(band)
128 | return bands
129 |
130 |
131 | def _jaccard_similarity(a: set[str], b: set[str]) -> float:
132 | """Return the Jaccard similarity between two shingle sets, handling empty edge cases."""
133 | if not a and not b:
134 | return 1.0
135 | if not a or not b:
136 | return 0.0
137 |
138 | intersection = len(a.intersection(b))
139 | union = len(a.union(b))
140 | return intersection / union if union else 0.0
141 |
142 |
143 | @lru_cache(maxsize=512)
144 | def _cached_shingles(name: str) -> set[str]:
145 | """Cache shingle sets per normalized name to avoid recomputation within a worker."""
146 | return _shingles(name)
147 |
148 |
149 | @dataclass
150 | class DedupCandidateIndexes:
151 | """Precomputed lookup structures that drive entity deduplication heuristics."""
152 |
153 | existing_nodes: list[EntityNode]
154 | nodes_by_uuid: dict[str, EntityNode]
155 | normalized_existing: defaultdict[str, list[EntityNode]]
156 | shingles_by_candidate: dict[str, set[str]]
157 | lsh_buckets: defaultdict[tuple[int, tuple[int, ...]], list[str]]
158 |
159 |
160 | @dataclass
161 | class DedupResolutionState:
162 | """Mutable resolution bookkeeping shared across deterministic and LLM passes."""
163 |
164 | resolved_nodes: list[EntityNode | None]
165 | uuid_map: dict[str, str]
166 | unresolved_indices: list[int]
167 | duplicate_pairs: list[tuple[EntityNode, EntityNode]] = field(default_factory=list)
168 |
169 |
170 | def _build_candidate_indexes(existing_nodes: list[EntityNode]) -> DedupCandidateIndexes:
171 | """Precompute exact and fuzzy lookup structures once per dedupe run."""
172 | normalized_existing: defaultdict[str, list[EntityNode]] = defaultdict(list)
173 | nodes_by_uuid: dict[str, EntityNode] = {}
174 | shingles_by_candidate: dict[str, set[str]] = {}
175 | lsh_buckets: defaultdict[tuple[int, tuple[int, ...]], list[str]] = defaultdict(list)
176 |
177 | for candidate in existing_nodes:
178 | normalized = _normalize_string_exact(candidate.name)
179 | normalized_existing[normalized].append(candidate)
180 | nodes_by_uuid[candidate.uuid] = candidate
181 |
182 | shingles = _cached_shingles(_normalize_name_for_fuzzy(candidate.name))
183 | shingles_by_candidate[candidate.uuid] = shingles
184 |
185 | signature = _minhash_signature(shingles)
186 | for band_index, band in enumerate(_lsh_bands(signature)):
187 | lsh_buckets[(band_index, band)].append(candidate.uuid)
188 |
189 | return DedupCandidateIndexes(
190 | existing_nodes=existing_nodes,
191 | nodes_by_uuid=nodes_by_uuid,
192 | normalized_existing=normalized_existing,
193 | shingles_by_candidate=shingles_by_candidate,
194 | lsh_buckets=lsh_buckets,
195 | )
196 |
197 |
198 | def _resolve_with_similarity(
199 | extracted_nodes: list[EntityNode],
200 | indexes: DedupCandidateIndexes,
201 | state: DedupResolutionState,
202 | ) -> None:
203 | """Attempt deterministic resolution using exact name hits and fuzzy MinHash comparisons."""
204 | for idx, node in enumerate(extracted_nodes):
205 | normalized_exact = _normalize_string_exact(node.name)
206 | normalized_fuzzy = _normalize_name_for_fuzzy(node.name)
207 |
208 | if not _has_high_entropy(normalized_fuzzy):
209 | state.unresolved_indices.append(idx)
210 | continue
211 |
212 | existing_matches = indexes.normalized_existing.get(normalized_exact, [])
213 | if len(existing_matches) == 1:
214 | match = existing_matches[0]
215 | state.resolved_nodes[idx] = match
216 | state.uuid_map[node.uuid] = match.uuid
217 | if match.uuid != node.uuid:
218 | state.duplicate_pairs.append((node, match))
219 | continue
220 | if len(existing_matches) > 1:
221 | state.unresolved_indices.append(idx)
222 | continue
223 |
224 | shingles = _cached_shingles(normalized_fuzzy)
225 | signature = _minhash_signature(shingles)
226 | candidate_ids: set[str] = set()
227 | for band_index, band in enumerate(_lsh_bands(signature)):
228 | candidate_ids.update(indexes.lsh_buckets.get((band_index, band), []))
229 |
230 | best_candidate: EntityNode | None = None
231 | best_score = 0.0
232 | for candidate_id in candidate_ids:
233 | candidate_shingles = indexes.shingles_by_candidate.get(candidate_id, set())
234 | score = _jaccard_similarity(shingles, candidate_shingles)
235 | if score > best_score:
236 | best_score = score
237 | best_candidate = indexes.nodes_by_uuid.get(candidate_id)
238 |
239 | if best_candidate is not None and best_score >= _FUZZY_JACCARD_THRESHOLD:
240 | state.resolved_nodes[idx] = best_candidate
241 | state.uuid_map[node.uuid] = best_candidate.uuid
242 | if best_candidate.uuid != node.uuid:
243 | state.duplicate_pairs.append((node, best_candidate))
244 | continue
245 |
246 | state.unresolved_indices.append(idx)
247 |
248 |
249 | __all__ = [
250 | 'DedupCandidateIndexes',
251 | 'DedupResolutionState',
252 | '_normalize_string_exact',
253 | '_normalize_name_for_fuzzy',
254 | '_has_high_entropy',
255 | '_minhash_signature',
256 | '_lsh_bands',
257 | '_jaccard_similarity',
258 | '_cached_shingles',
259 | '_FUZZY_JACCARD_THRESHOLD',
260 | '_build_candidate_indexes',
261 | '_resolve_with_similarity',
262 | ]
263 |
```
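
A minimal sketch of how the fuzzy-matching primitives above fit together, using only names exported in `__all__`; the input strings are illustrative. Normalization collapses case and whitespace, so the two spellings below shingle identically and score a Jaccard of 1.0, above `_FUZZY_JACCARD_THRESHOLD`:

```python
from graphiti_core.utils.maintenance.dedup_helpers import (
    _FUZZY_JACCARD_THRESHOLD,
    _cached_shingles,
    _jaccard_similarity,
    _lsh_bands,
    _minhash_signature,
    _normalize_name_for_fuzzy,
)

a = _cached_shingles(_normalize_name_for_fuzzy('Kamala Harris'))
b = _cached_shingles(_normalize_name_for_fuzzy('  kamala   HARRIS '))

score = _jaccard_similarity(a, b)
assert score >= _FUZZY_JACCARD_THRESHOLD  # identical after normalization -> 1.0

# 32 permutations split into bands of 4 -> 8 LSH bucket keys per name
bands = _lsh_bands(_minhash_signature(a))
print(len(bands), 'bands')
```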
--------------------------------------------------------------------------------
/examples/quickstart/quickstart_neo4j.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2025, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import asyncio
18 | import json
19 | import logging
20 | import os
21 | from datetime import datetime, timezone
22 | from logging import INFO
23 |
24 | from dotenv import load_dotenv
25 |
26 | from graphiti_core import Graphiti
27 | from graphiti_core.nodes import EpisodeType
28 | from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
29 |
30 | #################################################
31 | # CONFIGURATION
32 | #################################################
33 | # Set up logging and environment variables for
34 | # connecting to Neo4j database
35 | #################################################
36 |
37 | # Configure logging
38 | logging.basicConfig(
39 | level=INFO,
40 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
41 | datefmt='%Y-%m-%d %H:%M:%S',
42 | )
43 | logger = logging.getLogger(__name__)
44 |
45 | load_dotenv()
46 |
47 | # Neo4j connection parameters
48 | # Make sure Neo4j Desktop is running with a local DBMS started
49 | neo4j_uri = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
50 | neo4j_user = os.environ.get('NEO4J_USER', 'neo4j')
51 | neo4j_password = os.environ.get('NEO4J_PASSWORD', 'password')
52 |
53 | if not neo4j_uri or not neo4j_user or not neo4j_password:
54 | raise ValueError('NEO4J_URI, NEO4J_USER, and NEO4J_PASSWORD must be set')
55 |
56 |
57 | async def main():
58 | #################################################
59 | # INITIALIZATION
60 | #################################################
61 | # Connect to Neo4j and set up Graphiti indices
62 | # This is required before using other Graphiti
63 | # functionality
64 | #################################################
65 |
66 | # Initialize Graphiti with Neo4j connection
67 | graphiti = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
68 |
69 | try:
70 | #################################################
71 | # ADDING EPISODES
72 | #################################################
73 | # Episodes are the primary units of information
74 | # in Graphiti. They can be text or structured JSON
75 | # and are automatically processed to extract entities
76 | # and relationships.
77 | #################################################
78 |
79 | # Example: Add Episodes
80 | # Episodes list containing both text and JSON episodes
81 | episodes = [
82 | {
83 | 'content': 'Kamala Harris is the Attorney General of California. She was previously '
84 | 'the district attorney for San Francisco.',
85 | 'type': EpisodeType.text,
86 | 'description': 'podcast transcript',
87 | },
88 | {
89 | 'content': 'As AG, Harris was in office from January 3, 2011 – January 3, 2017',
90 | 'type': EpisodeType.text,
91 | 'description': 'podcast transcript',
92 | },
93 | {
94 | 'content': {
95 | 'name': 'Gavin Newsom',
96 | 'position': 'Governor',
97 | 'state': 'California',
98 | 'previous_role': 'Lieutenant Governor',
99 | 'previous_location': 'San Francisco',
100 | },
101 | 'type': EpisodeType.json,
102 | 'description': 'podcast metadata',
103 | },
104 | {
105 | 'content': {
106 | 'name': 'Gavin Newsom',
107 | 'position': 'Governor',
108 | 'term_start': 'January 7, 2019',
109 | 'term_end': 'Present',
110 | },
111 | 'type': EpisodeType.json,
112 | 'description': 'podcast metadata',
113 | },
114 | ]
115 |
116 | # Add episodes to the graph
117 | for i, episode in enumerate(episodes):
118 | await graphiti.add_episode(
119 | name=f'Freakonomics Radio {i}',
120 | episode_body=episode['content']
121 | if isinstance(episode['content'], str)
122 | else json.dumps(episode['content']),
123 | source=episode['type'],
124 | source_description=episode['description'],
125 | reference_time=datetime.now(timezone.utc),
126 | )
127 | print(f'Added episode: Freakonomics Radio {i} ({episode["type"].value})')
128 |
129 | #################################################
130 | # BASIC SEARCH
131 | #################################################
132 | # The simplest way to retrieve relationships (edges)
133 | # from Graphiti is using the search method, which
134 | # performs a hybrid search combining semantic
135 | # similarity and BM25 text retrieval.
136 | #################################################
137 |
138 | # Perform a hybrid search combining semantic similarity and BM25 retrieval
139 | print("\nSearching for: 'Who was the California Attorney General?'")
140 | results = await graphiti.search('Who was the California Attorney General?')
141 |
142 | # Print search results
143 | print('\nSearch Results:')
144 | for result in results:
145 | print(f'UUID: {result.uuid}')
146 | print(f'Fact: {result.fact}')
147 | if hasattr(result, 'valid_at') and result.valid_at:
148 | print(f'Valid from: {result.valid_at}')
149 | if hasattr(result, 'invalid_at') and result.invalid_at:
150 | print(f'Valid until: {result.invalid_at}')
151 | print('---')
152 |
153 | #################################################
154 | # CENTER NODE SEARCH
155 | #################################################
156 | # For more contextually relevant results, you can
157 | # use a center node to rerank search results based
158 | # on their graph distance to a specific node
159 | #################################################
160 |
161 | # Use the top search result's UUID as the center node for reranking
162 | if results and len(results) > 0:
163 | # Get the source node UUID from the top result
164 | center_node_uuid = results[0].source_node_uuid
165 |
166 | print('\nReranking search results based on graph distance:')
167 | print(f'Using center node UUID: {center_node_uuid}')
168 |
169 | reranked_results = await graphiti.search(
170 | 'Who was the California Attorney General?', center_node_uuid=center_node_uuid
171 | )
172 |
173 | # Print reranked search results
174 | print('\nReranked Search Results:')
175 | for result in reranked_results:
176 | print(f'UUID: {result.uuid}')
177 | print(f'Fact: {result.fact}')
178 | if hasattr(result, 'valid_at') and result.valid_at:
179 | print(f'Valid from: {result.valid_at}')
180 | if hasattr(result, 'invalid_at') and result.invalid_at:
181 | print(f'Valid until: {result.invalid_at}')
182 | print('---')
183 | else:
184 | print('No results found in the initial search to use as center node.')
185 |
186 | #################################################
187 | # NODE SEARCH USING SEARCH RECIPES
188 | #################################################
189 | # Graphiti provides predefined search recipes
190 | # optimized for different search scenarios.
191 | # Here we use NODE_HYBRID_SEARCH_RRF for retrieving
192 | # nodes directly instead of edges.
193 | #################################################
194 |
195 | # Example: Perform a node search using _search method with standard recipes
196 | print(
197 | '\nPerforming node search using _search method with standard recipe NODE_HYBRID_SEARCH_RRF:'
198 | )
199 |
200 | # Use a predefined search configuration recipe and modify its limit
201 | node_search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
202 | node_search_config.limit = 5 # Limit to 5 results
203 |
204 | # Execute the node search
205 | node_search_results = await graphiti._search(
206 | query='California Governor',
207 | config=node_search_config,
208 | )
209 |
210 | # Print node search results
211 | print('\nNode Search Results:')
212 | for node in node_search_results.nodes:
213 | print(f'Node UUID: {node.uuid}')
214 | print(f'Node Name: {node.name}')
215 | node_summary = node.summary[:100] + '...' if len(node.summary) > 100 else node.summary
216 | print(f'Content Summary: {node_summary}')
217 | print(f'Node Labels: {", ".join(node.labels)}')
218 | print(f'Created At: {node.created_at}')
219 | if hasattr(node, 'attributes') and node.attributes:
220 | print('Attributes:')
221 | for key, value in node.attributes.items():
222 | print(f' {key}: {value}')
223 | print('---')
224 |
225 | finally:
226 | #################################################
227 | # CLEANUP
228 | #################################################
229 | # Always close the connection to Neo4j when
230 | # finished to properly release resources
231 | #################################################
232 |
233 | # Close the connection
234 | await graphiti.close()
235 | print('\nConnection closed')
236 |
237 |
238 | if __name__ == '__main__':
239 | asyncio.run(main())
240 |
```
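
One thing the quickstart does not spell out: besides the Neo4j variables, Graphiti's default LLM and embedder clients are OpenAI-backed, so an API key must also be present. A minimal environment sketch with placeholder values (set these in `.env` or the shell before running):

```python
import os

# Placeholder values; match them to your local Neo4j Desktop DBMS
os.environ.setdefault('NEO4J_URI', 'bolt://localhost:7687')
os.environ.setdefault('NEO4J_USER', 'neo4j')
os.environ.setdefault('NEO4J_PASSWORD', 'password')
# Assumed: the default OpenAI-backed LLM/embedder clients read this key
os.environ.setdefault('OPENAI_API_KEY', 'sk-placeholder')
```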
--------------------------------------------------------------------------------
/mcp_server/tests/test_mcp_transports.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """
3 | Test MCP server with different transport modes using the MCP SDK.
4 | Tests both SSE and streaming HTTP transports.
5 | """
6 |
7 | import asyncio
8 | import json
9 | import sys
10 | import time
11 | from contextlib import AsyncExitStack
12 | from mcp.client.session import ClientSession
13 | from mcp.client.sse import sse_client
14 |
15 |
16 | class MCPTransportTester:
17 | """Test MCP server with different transport modes."""
18 |
19 |     def __init__(self, transport: str = 'sse', host: str = 'localhost', port: int = 8000):
20 |         self.transport = transport
21 |         self.host, self.port = host, port
22 |         self.base_url = f'http://{host}:{port}'
23 |         self.test_group_id = f'test_{transport}_{int(time.time())}'
24 |         self.session = None
25 |         self._stack = AsyncExitStack()  # keeps transport/session contexts alive between methods
26 |
27 |     async def connect_sse(self) -> ClientSession:
28 |         """Connect using SSE transport."""
29 |         print(f'🔌 Connecting to MCP server via SSE at {self.base_url}/sse')
30 | 
31 |         # Enter the transport on the exit stack so the streams outlive this method
32 |         read_stream, write_stream = await self._stack.enter_async_context(sse_client(self.base_url + '/sse'))
33 |         self.session = await self._stack.enter_async_context(ClientSession(read_stream, write_stream))
34 |         await self.session.initialize()
35 |         return self.session
36 |
37 |     async def connect_http(self) -> ClientSession:
38 |         """Connect using streamable HTTP transport."""
39 |         from mcp.client.streamable_http import streamablehttp_client as http_client
40 | 
41 |         print(f'🔌 Connecting to MCP server via HTTP at {self.base_url}')
42 | 
43 |         # streamablehttp_client yields (read_stream, write_stream, get_session_id)
44 |         read_stream, write_stream, _ = await self._stack.enter_async_context(http_client(self.base_url))
45 |         self.session = await self._stack.enter_async_context(ClientSession(read_stream, write_stream))
46 |         await self.session.initialize()
47 |         return self.session
48 |
49 | async def test_list_tools(self) -> bool:
50 | """Test listing available tools."""
51 | print('\n📋 Testing list_tools...')
52 |
53 | try:
54 | result = await self.session.list_tools()
55 | tools = [tool.name for tool in result.tools]
56 |
57 | expected_tools = [
58 | 'add_memory',
59 | 'search_memory_nodes',
60 | 'search_memory_facts',
61 | 'get_episodes',
62 | 'delete_episode',
63 | 'get_entity_edge',
64 | 'delete_entity_edge',
65 | 'clear_graph',
66 | ]
67 |
68 | print(f' ✅ Found {len(tools)} tools')
69 | for tool in tools[:5]: # Show first 5 tools
70 | print(f' - {tool}')
71 |
72 | # Check if we have most expected tools
73 | found_tools = [t for t in expected_tools if t in tools]
74 | success = len(found_tools) >= len(expected_tools) * 0.8
75 |
76 | if success:
77 | print(
78 | f' ✅ Tool discovery successful ({len(found_tools)}/{len(expected_tools)} expected tools)'
79 | )
80 | else:
81 | print(f' ❌ Missing too many tools ({len(found_tools)}/{len(expected_tools)})')
82 |
83 | return success
84 | except Exception as e:
85 | print(f' ❌ Failed to list tools: {e}')
86 | return False
87 |
88 | async def test_add_memory(self) -> bool:
89 | """Test adding a memory."""
90 | print('\n📝 Testing add_memory...')
91 |
92 | try:
93 | result = await self.session.call_tool(
94 | 'add_memory',
95 | {
96 | 'name': 'Test Episode',
97 | 'episode_body': 'This is a test episode created by the MCP transport test suite.',
98 | 'group_id': self.test_group_id,
99 | 'source': 'text',
100 | 'source_description': 'Integration test',
101 | },
102 | )
103 |
104 | # Check the result
105 | if result.content:
106 | content = result.content[0]
107 | if hasattr(content, 'text'):
108 | response = (
109 | json.loads(content.text)
110 | if content.text.startswith('{')
111 | else {'message': content.text}
112 | )
113 | if 'success' in str(response).lower() or 'queued' in str(response).lower():
114 | print(f' ✅ Memory added successfully: {response.get("message", "OK")}')
115 | return True
116 | else:
117 | print(f' ❌ Unexpected response: {response}')
118 | return False
119 |
120 | print(' ❌ No content in response')
121 | return False
122 |
123 | except Exception as e:
124 | print(f' ❌ Failed to add memory: {e}')
125 | return False
126 |
127 | async def test_search_nodes(self) -> bool:
128 | """Test searching for nodes."""
129 | print('\n🔍 Testing search_memory_nodes...')
130 |
131 | # Wait a bit for the memory to be processed
132 | await asyncio.sleep(2)
133 |
134 | try:
135 | result = await self.session.call_tool(
136 | 'search_memory_nodes',
137 | {'query': 'test episode', 'group_ids': [self.test_group_id], 'limit': 5},
138 | )
139 |
140 | if result.content:
141 | content = result.content[0]
142 | if hasattr(content, 'text'):
143 | response = (
144 | json.loads(content.text) if content.text.startswith('{') else {'nodes': []}
145 | )
146 | nodes = response.get('nodes', [])
147 | print(f' ✅ Search returned {len(nodes)} nodes')
148 | return True
149 |
150 | print(' ⚠️ No nodes found (this may be expected if processing is async)')
151 | return True # Don't fail on empty results
152 |
153 | except Exception as e:
154 | print(f' ❌ Failed to search nodes: {e}')
155 | return False
156 |
157 | async def test_get_episodes(self) -> bool:
158 | """Test getting episodes."""
159 | print('\n📚 Testing get_episodes...')
160 |
161 | try:
162 | result = await self.session.call_tool(
163 | 'get_episodes', {'group_ids': [self.test_group_id], 'limit': 10}
164 | )
165 |
166 | if result.content:
167 | content = result.content[0]
168 | if hasattr(content, 'text'):
169 | response = (
170 | json.loads(content.text)
171 | if content.text.startswith('{')
172 | else {'episodes': []}
173 | )
174 | episodes = response.get('episodes', [])
175 | print(f' ✅ Found {len(episodes)} episodes')
176 | return True
177 |
178 | print(' ⚠️ No episodes found')
179 | return True
180 |
181 | except Exception as e:
182 | print(f' ❌ Failed to get episodes: {e}')
183 | return False
184 |
185 | async def test_clear_graph(self) -> bool:
186 | """Test clearing the graph."""
187 | print('\n🧹 Testing clear_graph...')
188 |
189 | try:
190 | result = await self.session.call_tool('clear_graph', {'group_id': self.test_group_id})
191 |
192 | if result.content:
193 | content = result.content[0]
194 | if hasattr(content, 'text'):
195 | response = content.text
196 | if 'success' in response.lower() or 'cleared' in response.lower():
197 | print(' ✅ Graph cleared successfully')
198 | return True
199 |
200 | print(' ❌ Failed to clear graph')
201 | return False
202 |
203 | except Exception as e:
204 | print(f' ❌ Failed to clear graph: {e}')
205 | return False
206 |
207 | async def run_tests(self) -> bool:
208 | """Run all tests for the configured transport."""
209 | print(f'\n{"=" * 60}')
210 | print(f'🚀 Testing MCP Server with {self.transport.upper()} transport')
211 | print(f' Server: {self.base_url}')
212 | print(f' Test Group: {self.test_group_id}')
213 | print('=' * 60)
214 |
215 | try:
216 | # Connect based on transport type
217 | if self.transport == 'sse':
218 | await self.connect_sse()
219 | elif self.transport == 'http':
220 | await self.connect_http()
221 | else:
222 | print(f'❌ Unknown transport: {self.transport}')
223 | return False
224 |
225 | print(f'✅ Connected via {self.transport.upper()}')
226 |
227 | # Run tests
228 | results = []
229 | results.append(await self.test_list_tools())
230 | results.append(await self.test_add_memory())
231 | results.append(await self.test_search_nodes())
232 | results.append(await self.test_get_episodes())
233 | results.append(await self.test_clear_graph())
234 |
235 | # Summary
236 | passed = sum(results)
237 | total = len(results)
238 | success = passed == total
239 |
240 | print(f'\n{"=" * 60}')
241 | print(f'📊 Results for {self.transport.upper()} transport:')
242 | print(f' Passed: {passed}/{total}')
243 | print(f' Status: {"✅ ALL TESTS PASSED" if success else "❌ SOME TESTS FAILED"}')
244 | print('=' * 60)
245 |
246 | return success
247 |
248 | except Exception as e:
249 | print(f'❌ Test suite failed: {e}')
250 | return False
251 |         finally:
252 |             # Unwind the session and transport contexts opened in connect_*
253 |             await self._stack.aclose()
254 |
255 |
256 | async def main():
257 |     """Run the test suite for the transport selected on the command line."""
258 | # Parse command line arguments
259 | transport = sys.argv[1] if len(sys.argv) > 1 else 'sse'
260 | host = sys.argv[2] if len(sys.argv) > 2 else 'localhost'
261 | port = int(sys.argv[3]) if len(sys.argv) > 3 else 8000
262 |
263 | # Create tester
264 | tester = MCPTransportTester(transport, host, port)
265 |
266 | # Run tests
267 | success = await tester.run_tests()
268 |
269 | # Exit with appropriate code
270 |     sys.exit(0 if success else 1)
271 |
272 |
273 | if __name__ == '__main__':
274 | asyncio.run(main())
275 |
```
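
Unlike `main()` above, which tests one transport chosen on the command line, both transports can be exercised in sequence. A minimal sketch, assuming a server is available for each transport at localhost:8000 (in practice, restart it with the matching `--transport` flag between runs):

```python
import asyncio

from test_mcp_transports import MCPTransportTester


async def run_both() -> bool:
    results = []
    for transport in ('sse', 'http'):
        results.append(await MCPTransportTester(transport).run_tests())
    return all(results)


raise SystemExit(0 if asyncio.run(run_both()) else 1)
```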
--------------------------------------------------------------------------------
/examples/quickstart/quickstart_neptune.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Copyright 2025, Zep Software, Inc.
3 |
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 |
8 | http://www.apache.org/licenses/LICENSE-2.0
9 |
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License.
15 | """
16 |
17 | import asyncio
18 | import json
19 | import logging
20 | import os
21 | from datetime import datetime, timezone
22 | from logging import INFO
23 |
24 | from dotenv import load_dotenv
25 |
26 | from graphiti_core import Graphiti
27 | from graphiti_core.driver.neptune_driver import NeptuneDriver
28 | from graphiti_core.nodes import EpisodeType
29 | from graphiti_core.search.search_config_recipes import NODE_HYBRID_SEARCH_RRF
30 |
31 | #################################################
32 | # CONFIGURATION
33 | #################################################
34 | # Set up logging and environment variables for
35 | # connecting to Neptune database
36 | #################################################
37 |
38 | # Configure logging
39 | logging.basicConfig(
40 | level=INFO,
41 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
42 | datefmt='%Y-%m-%d %H:%M:%S',
43 | )
44 | logger = logging.getLogger(__name__)
45 |
46 | load_dotenv()
47 |
48 | # Neptune and OpenSearch connection parameters
49 | neptune_uri = os.environ.get('NEPTUNE_HOST')
50 | neptune_port = int(os.environ.get('NEPTUNE_PORT', 8182))
51 | aoss_host = os.environ.get('AOSS_HOST')
52 |
53 | if not neptune_uri:
54 | raise ValueError('NEPTUNE_HOST must be set')
55 |
56 |
57 | if not aoss_host:
58 | raise ValueError('AOSS_HOST must be set')
59 |
60 |
61 | async def main():
62 | #################################################
63 | # INITIALIZATION
64 | #################################################
65 | # Connect to Neptune and set up Graphiti indices
66 | # This is required before using other Graphiti
67 | # functionality
68 | #################################################
69 |
70 | # Initialize Graphiti with Neptune connection
71 | driver = NeptuneDriver(host=neptune_uri, aoss_host=aoss_host, port=neptune_port)
72 |
73 | graphiti = Graphiti(graph_driver=driver)
74 |
75 | try:
76 |         # Reset the example database (destructive: drops the AOSS indices and all graph data), then build graphiti's indices and constraints. Only needed once per fresh setup.
77 | await driver.delete_aoss_indices()
78 | await driver._delete_all_data()
79 | await graphiti.build_indices_and_constraints()
80 |
81 | #################################################
82 | # ADDING EPISODES
83 | #################################################
84 | # Episodes are the primary units of information
85 | # in Graphiti. They can be text or structured JSON
86 | # and are automatically processed to extract entities
87 | # and relationships.
88 | #################################################
89 |
90 | # Example: Add Episodes
91 | # Episodes list containing both text and JSON episodes
92 | episodes = [
93 | {
94 | 'content': 'Kamala Harris is the Attorney General of California. She was previously '
95 | 'the district attorney for San Francisco.',
96 | 'type': EpisodeType.text,
97 | 'description': 'podcast transcript',
98 | },
99 | {
100 | 'content': 'As AG, Harris was in office from January 3, 2011 – January 3, 2017',
101 | 'type': EpisodeType.text,
102 | 'description': 'podcast transcript',
103 | },
104 | {
105 | 'content': {
106 | 'name': 'Gavin Newsom',
107 | 'position': 'Governor',
108 | 'state': 'California',
109 | 'previous_role': 'Lieutenant Governor',
110 | 'previous_location': 'San Francisco',
111 | },
112 | 'type': EpisodeType.json,
113 | 'description': 'podcast metadata',
114 | },
115 | {
116 | 'content': {
117 | 'name': 'Gavin Newsom',
118 | 'position': 'Governor',
119 | 'term_start': 'January 7, 2019',
120 | 'term_end': 'Present',
121 | },
122 | 'type': EpisodeType.json,
123 | 'description': 'podcast metadata',
124 | },
125 | ]
126 |
127 | # Add episodes to the graph
128 | for i, episode in enumerate(episodes):
129 | await graphiti.add_episode(
130 | name=f'Freakonomics Radio {i}',
131 | episode_body=episode['content']
132 | if isinstance(episode['content'], str)
133 | else json.dumps(episode['content']),
134 | source=episode['type'],
135 | source_description=episode['description'],
136 | reference_time=datetime.now(timezone.utc),
137 | )
138 | print(f'Added episode: Freakonomics Radio {i} ({episode["type"].value})')
139 |
140 | await graphiti.build_communities()
141 |
142 | #################################################
143 | # BASIC SEARCH
144 | #################################################
145 | # The simplest way to retrieve relationships (edges)
146 | # from Graphiti is using the search method, which
147 | # performs a hybrid search combining semantic
148 | # similarity and BM25 text retrieval.
149 | #################################################
150 |
151 | # Perform a hybrid search combining semantic similarity and BM25 retrieval
152 | print("\nSearching for: 'Who was the California Attorney General?'")
153 | results = await graphiti.search('Who was the California Attorney General?')
154 |
155 | # Print search results
156 | print('\nSearch Results:')
157 | for result in results:
158 | print(f'UUID: {result.uuid}')
159 | print(f'Fact: {result.fact}')
160 | if hasattr(result, 'valid_at') and result.valid_at:
161 | print(f'Valid from: {result.valid_at}')
162 | if hasattr(result, 'invalid_at') and result.invalid_at:
163 | print(f'Valid until: {result.invalid_at}')
164 | print('---')
165 |
166 | #################################################
167 | # CENTER NODE SEARCH
168 | #################################################
169 | # For more contextually relevant results, you can
170 | # use a center node to rerank search results based
171 | # on their graph distance to a specific node
172 | #################################################
173 |
174 | # Use the top search result's UUID as the center node for reranking
175 | if results and len(results) > 0:
176 | # Get the source node UUID from the top result
177 | center_node_uuid = results[0].source_node_uuid
178 |
179 | print('\nReranking search results based on graph distance:')
180 | print(f'Using center node UUID: {center_node_uuid}')
181 |
182 | reranked_results = await graphiti.search(
183 | 'Who was the California Attorney General?', center_node_uuid=center_node_uuid
184 | )
185 |
186 | # Print reranked search results
187 | print('\nReranked Search Results:')
188 | for result in reranked_results:
189 | print(f'UUID: {result.uuid}')
190 | print(f'Fact: {result.fact}')
191 | if hasattr(result, 'valid_at') and result.valid_at:
192 | print(f'Valid from: {result.valid_at}')
193 | if hasattr(result, 'invalid_at') and result.invalid_at:
194 | print(f'Valid until: {result.invalid_at}')
195 | print('---')
196 | else:
197 | print('No results found in the initial search to use as center node.')
198 |
199 | #################################################
200 | # NODE SEARCH USING SEARCH RECIPES
201 | #################################################
202 | # Graphiti provides predefined search recipes
203 | # optimized for different search scenarios.
204 | # Here we use NODE_HYBRID_SEARCH_RRF for retrieving
205 | # nodes directly instead of edges.
206 | #################################################
207 |
208 | # Example: Perform a node search using _search method with standard recipes
209 | print(
210 | '\nPerforming node search using _search method with standard recipe NODE_HYBRID_SEARCH_RRF:'
211 | )
212 |
213 | # Use a predefined search configuration recipe and modify its limit
214 | node_search_config = NODE_HYBRID_SEARCH_RRF.model_copy(deep=True)
215 | node_search_config.limit = 5 # Limit to 5 results
216 |
217 | # Execute the node search
218 | node_search_results = await graphiti._search(
219 | query='California Governor',
220 | config=node_search_config,
221 | )
222 |
223 | # Print node search results
224 | print('\nNode Search Results:')
225 | for node in node_search_results.nodes:
226 | print(f'Node UUID: {node.uuid}')
227 | print(f'Node Name: {node.name}')
228 | node_summary = node.summary[:100] + '...' if len(node.summary) > 100 else node.summary
229 | print(f'Content Summary: {node_summary}')
230 | print(f'Node Labels: {", ".join(node.labels)}')
231 | print(f'Created At: {node.created_at}')
232 | if hasattr(node, 'attributes') and node.attributes:
233 | print('Attributes:')
234 | for key, value in node.attributes.items():
235 | print(f' {key}: {value}')
236 | print('---')
237 |
238 | finally:
239 | #################################################
240 | # CLEANUP
241 | #################################################
242 | # Always close the connection to Neptune when
243 | # finished to properly release resources
244 | #################################################
245 |
246 | # Close the connection
247 | await graphiti.close()
248 | print('\nConnection closed')
249 |
250 |
251 | if __name__ == '__main__':
252 | asyncio.run(main())
253 |
```