# Directory Structure
```
├── .flake8
├── .gitignore
├── CHANGELOG.md
├── CLAUDE.md
├── CONTRIBUTING.md
├── data
│ ├── bitcoin-whitepaper.json
│ ├── bitcoin-whitepaper.md
│ └── README.md
├── docker-compose.yml
├── LICENSE
├── poetry.lock
├── pyproject.toml
├── QUICKSTART.md
├── README.md
├── scripts
│ ├── check_solr.py
│ ├── create_test_collection.py
│ ├── create_unified_collection.py
│ ├── demo_hybrid_search.py
│ ├── demo_search.py
│ ├── diagnose_search.py
│ ├── direct_mcp_test.py
│ ├── format.py
│ ├── index_documents.py
│ ├── lint.py
│ ├── prepare_data.py
│ ├── process_markdown.py
│ ├── README.md
│ ├── setup.sh
│ ├── simple_index.py
│ ├── simple_mcp_test.py
│ ├── simple_search.py
│ ├── unified_index.py
│ ├── unified_search.py
│ ├── vector_index_simple.py
│ ├── vector_index.py
│ └── vector_search.py
├── solr_config
│ └── unified
│ └── conf
│ ├── schema.xml
│ ├── solrconfig.xml
│ ├── stopwords.txt
│ └── synonyms.txt
├── solr_mcp
│ ├── __init__.py
│ ├── server.py
│ ├── solr
│ │ ├── __init__.py
│ │ ├── client.py
│ │ ├── collections.py
│ │ ├── config.py
│ │ ├── constants.py
│ │ ├── exceptions.py
│ │ ├── interfaces.py
│ │ ├── query
│ │ │ ├── __init__.py
│ │ │ ├── builder.py
│ │ │ ├── executor.py
│ │ │ ├── parser.py
│ │ │ └── validator.py
│ │ ├── response.py
│ │ ├── schema
│ │ │ ├── __init__.py
│ │ │ ├── cache.py
│ │ │ └── fields.py
│ │ ├── utils
│ │ │ ├── __init__.py
│ │ │ └── formatting.py
│ │ ├── vector
│ │ │ ├── __init__.py
│ │ │ ├── manager.py
│ │ │ └── results.py
│ │ └── zookeeper.py
│ ├── tools
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── solr_default_vectorizer.py
│ │ ├── solr_list_collections.py
│ │ ├── solr_list_fields.py
│ │ ├── solr_select.py
│ │ ├── solr_semantic_select.py
│ │ ├── solr_vector_select.py
│ │ └── tool_decorator.py
│ ├── utils.py
│ └── vector_provider
│ ├── __init__.py
│ ├── clients
│ │ ├── __init__.py
│ │ └── ollama.py
│ ├── constants.py
│ ├── exceptions.py
│ └── interfaces.py
├── solr.Dockerfile
└── tests
├── __init__.py
├── integration
│ ├── __init__.py
│ └── test_direct_solr.py
└── unit
├── __init__.py
├── conftest.py
├── fixtures
│ ├── __init__.py
│ ├── common.py
│ ├── config_fixtures.py
│ ├── http_fixtures.py
│ ├── server_fixtures.py
│ ├── solr_fixtures.py
│ ├── time_fixtures.py
│ ├── vector_fixtures.py
│ └── zookeeper_fixtures.py
├── solr
│ ├── schema
│ │ └── test_fields.py
│ ├── test_client.py
│ ├── test_config.py
│ ├── utils
│ │ └── test_formatting.py
│ └── vector
│ └── test_results.py
├── test_cache.py
├── test_client.py
├── test_config.py
├── test_formatting.py
├── test_interfaces.py
├── test_parser.py
├── test_query.py
├── test_schema.py
├── test_utils.py
├── test_validator.py
├── test_vector.py
├── test_zookeeper.py
├── tools
│ ├── test_base.py
│ ├── test_init.py
│ ├── test_solr_default_vectorizer.py
│ ├── test_solr_list_collections.py
│ ├── test_solr_list_fields.py
│ ├── test_tool_decorator.py
│ └── test_tools.py
└── vector_provider
├── test_constants.py
├── test_exceptions.py
├── test_interfaces.py
└── test_ollama.py
```
# Files
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
```
[flake8]
max-line-length = 88
extend-ignore = E203
exclude = .venv,.git,__pycache__,build,dist
```
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
```
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
*.bak
*.un~
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc
```
--------------------------------------------------------------------------------
/scripts/README.md:
--------------------------------------------------------------------------------
```markdown
# Utility Scripts for Solr MCP
This directory contains utility scripts for working with the Solr MCP server.
## Scripts
### demo_search.py
Demonstrates how to use the MCP client to search for information using both text search and vector search.
**Usage:**
```bash
# Text search
python demo_search.py "bitcoin mining" --collection vectors
# Vector (semantic) search
python demo_search.py "How does Bitcoin prevent double-spending?" --vector --collection vectors
# Specify number of results
python demo_search.py "blockchain" --results 10
```
The script shows how to connect to the MCP server, perform different types of searches, and display the results.
### process_markdown.py
Splits markdown files into sections based on headings and converts them to JSON documents ready for Solr indexing.
**Usage:**
```bash
# Process a markdown file and output to stdout
python process_markdown.py data/document.md
# Process a markdown file and save to a JSON file
python process_markdown.py data/document.md --output data/processed/document_sections.json
```
The script supports markdown files with YAML frontmatter. The frontmatter metadata will be added to each section document.
### index_documents.py
Indexes documents from a JSON file into Solr with vector embeddings generated using Ollama's nomic-embed-text model.
**Usage:**
```bash
# Index documents into the default collection
python index_documents.py data/processed/document_sections.json
# Index documents into a specific collection
python index_documents.py data/processed/document_sections.json --collection my_collection
# Index documents without committing (useful for batch indexing)
python index_documents.py data/processed/document_sections.json --no-commit
```
## Workflow Example
1. Process a markdown file:
```bash
python process_markdown.py data/document.md --output data/processed/document_sections.json
```
2. Start the Docker containers (if not already running):
```bash
docker-compose up -d
```
3. Index the documents with vector embeddings:
```bash
python index_documents.py data/processed/document_sections.json --collection vectors
```
4. Use the MCP server to search the documents:
```bash
# Configure Claude Desktop to use the MCP server
# Then ask questions about the document
```
```
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
```markdown
# Data Examples for Solr MCP
This directory contains example data for testing and demonstrating the Solr MCP server.
## Bitcoin Whitepaper Example
The Bitcoin whitepaper by Satoshi Nakamoto is included as an example document for testing semantic search capabilities.
### Files
- `bitcoin-whitepaper.md`: The original Bitcoin whitepaper in markdown format
- `processed/bitcoin_sections.json`: The whitepaper split into sections, ready for indexing
- `processed/bitcoin_metadata.md`: Example with proper YAML frontmatter metadata
- `processed/bitcoin_metadata.json`: Processed version with metadata included
### Using the Bitcoin Whitepaper Example
1. **Process the whitepaper into sections** (already done):
```bash
python scripts/process_markdown.py data/bitcoin-whitepaper.md --output data/processed/bitcoin_sections.json
```
2. **Start the Docker containers**:
```bash
docker-compose up -d
```
3. **Index the sections with vector embeddings**:
```bash
python scripts/index_documents.py data/processed/bitcoin_sections.json --collection vectors
```
4. **Search using Claude Desktop**:
Configure Claude Desktop to use your MCP server, then ask questions like:
- "How does Bitcoin solve the double-spending problem?"
- "Explain Bitcoin's proof-of-work system"
- "What is the incentive for nodes to support the network?"
The MCP server will find the most semantically relevant sections from the whitepaper and return them to Claude.
## Adding Your Own Documents
You can add your own documents to this directory and process them using the same workflow:
1. Add markdown documents to the `data/` directory
2. Process them into sections:
```bash
python scripts/process_markdown.py data/your-document.md --output data/processed/your-document_sections.json
```
3. Index them into Solr:
```bash
python scripts/index_documents.py data/processed/your-document_sections.json --collection vectors
```
### YAML Frontmatter
For better document organization, add YAML frontmatter to your markdown files:
```markdown
---
title: "Document Title"
author: "Author Name"
date: "2023-01-01"
tags:
- tag1
- tag2
categories:
- category1
- category2
---
# Your Document Content
...
```
This metadata will be included in the indexed documents and can be used for filtering searches.
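For instance, a hypothetical metadata filter might look like the following (the field names depend on how the schema maps frontmatter keys; they are illustrative only):
```python
# Hypothetical metadata filter; exact field names depend on the collection schema.
params = {
    "q": 'content:"double spend"~5',
    "fq": 'tags:tag1 AND author:"Author Name"',  # filter on frontmatter metadata
    "wt": "json",
}
# e.g. httpx.get("http://localhost:8983/solr/vectors/select", params=params)
```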
```
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
```markdown
# Solr MCP
A Python package for accessing Apache Solr indexes via Model Context Protocol (MCP). This integration allows AI assistants like Claude to perform powerful search queries against your Solr indexes, combining both keyword and vector search capabilities.
## Features
- **MCP Server**: Implements the Model Context Protocol for integration with AI assistants
- **Hybrid Search**: Combines keyword search precision with vector search semantic understanding
- **Vector Embeddings**: Generates embeddings for documents using Ollama with nomic-embed-text
- **Unified Collections**: Store both document content and vector embeddings in the same collection
- **Docker Integration**: Easy setup with Docker and docker-compose
- **Optimized Vector Search**: Handles combined vector and SQL queries efficiently by pushing SQL filters down to the vector search stage, keeping performance predictable even with large result sets and pagination
## Architecture
### Vector Search Optimization
The system employs an important optimization for combined vector and SQL queries. When executing a query that includes both vector similarity search and SQL filters:
1. SQL filters (WHERE clauses) are pushed down to the vector search stage
2. This ensures that vector similarity calculations are only performed on documents that will match the final SQL criteria
3. This significantly improves performance for queries with:
- Selective WHERE clauses
- Pagination (LIMIT/OFFSET)
- Large result sets
This optimization reduces computational overhead and network transfer by minimizing the number of vector similarity calculations needed.
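As a hypothetical illustration (the collection and field names below are examples only), both the WHERE filter and the LIMIT in a query like this are applied during the vector similarity stage rather than after it:
```python
# Example semantic query; "unified" and "category" are placeholder names.
query = """
SELECT id, title, score
FROM unified
WHERE category = 'whitepaper'
LIMIT 10
"""
# With the pushdown, vector similarity is computed only for documents
# matching category = 'whitepaper', instead of for the whole collection.
```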
## Quick Start
1. Clone this repository
2. Start SolrCloud with Docker:
```bash
docker-compose up -d
```
3. Install dependencies:
```bash
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
pip install poetry
poetry install
```
4. Process and index the sample document:
```bash
python scripts/process_markdown.py data/bitcoin-whitepaper.md --output data/processed/bitcoin_sections.json
python scripts/create_unified_collection.py unified
python scripts/unified_index.py data/processed/bitcoin_sections.json --collection unified
```
5. Run the MCP server:
```bash
poetry run python -m solr_mcp.server
```
For more detailed setup and usage instructions, see the [QUICKSTART.md](QUICKSTART.md) guide.
## Requirements
- Python 3.10 or higher
- Docker and Docker Compose
- SolrCloud 9.x
- Ollama (for embedding generation)
## License
This project is licensed under the MIT License - see the LICENSE file for details.
## Contributing
Contributions are welcome! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines.
```
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
```markdown
# Contributing to Solr MCP
Thank you for your interest in contributing to the Solr MCP project! This document provides guidelines and instructions for contributing.
## Getting Started
1. Fork the repository on GitHub
2. Clone your fork locally
3. Set up the development environment as described in the README
4. Create a new branch for your changes
## Development Workflow
1. Make your changes in your branch
2. Write or update tests for your changes
3. Ensure all tests pass
4. Format your code using Black and isort
5. Submit a pull request
## Code Style Guidelines
- Follow PEP 8 style guide with 88-char line length (Black formatter)
- Use type hints consistently (Python 3.9+ typing)
- Group imports: stdlib → third-party → local
- Document functions, classes, and tools with docstrings
## Testing
Run the test suite with:
```bash
poetry run pytest
```
For test coverage:
```bash
poetry run pytest --cov=solr_mcp
```
## Submitting Pull Requests
1. Update the README.md with details of changes if appropriate
2. Update the CHANGELOG.md following the Keep a Changelog format
3. The version will be updated according to Semantic Versioning by the maintainers
4. Once you have the sign-off of a maintainer, your PR will be merged
## License
By contributing to this project, you agree that your contributions will be licensed under the project's MIT License.
```
--------------------------------------------------------------------------------
/CLAUDE.md:
--------------------------------------------------------------------------------
```markdown
# CLAUDE.md - Solr MCP Server Guide (Python)
## IMPORTANT NOTE
Before using the search tools, make sure the Bitcoin whitepaper content is properly indexed in the unified collection!
If search queries like "double spend" return no results, you may need to reindex the content:
```bash
python scripts/process_markdown.py data/bitcoin-whitepaper.md --output data/processed/bitcoin_sections.json
python scripts/unified_index.py data/processed/bitcoin_sections.json --collection unified
```
## Project Structure
- Python-based MCP server integrating with SolrCloud
- Uses MCP 1.4.1 framework for protocol implementation
- Provides document search and knowledge retrieval for AI systems
- Supports SolrCloud collections and distributed search
- Vector search/KNN capabilities for semantic search
## Environment Setup
- Python 3.10: `python3.10 -m venv venv`
- Activate: `source venv/bin/activate` (Unix) or `venv\Scripts\activate` (Windows)
- Install Poetry: `pip install poetry`
## Build Commands
- Install all deps: `poetry install`
- Run server: `poetry run python -m solr_mcp.server`
- Debug mode: `poetry run python -m solr_mcp.server --debug`
- Package: `poetry build`
## Test Commands
- Run tests: `poetry run pytest`
- Single test: `poetry run pytest tests/test_file.py::test_function`
- Coverage: `poetry run pytest --cov=solr_mcp`
- Lint: `poetry run flake8 solr_mcp tests`
- Type check: `poetry run mypy solr_mcp tests`
- Format code: `poetry run black solr_mcp tests`
- Sort imports: `poetry run isort solr_mcp tests`
- Run all formatting: `poetry run black solr_mcp tests && poetry run isort solr_mcp tests`
## Docker Commands
- Start SolrCloud: `docker-compose up -d`
- Check logs: `docker-compose logs -f`
- Solr UI: http://localhost:8983/solr/
- Stop SolrCloud: `docker-compose down`
- Cleanup volumes: `docker-compose down -v`
## SolrCloud Integration
- Connection via pysolr with ZooKeeper ensemble
- Support for collection management and configuration
- Handle distributed search with configurable shards and replicas
- Vector search using dense_vector fields and KNN
- Hybrid search combining keyword and vector search capabilities
- Embedding generation via Ollama using nomic-embed-text (768D vectors)
- Unified collections storing both text content and vector embeddings
- Implement retry and fallback logic for resilience
## Code Style Guidelines
- Follow PEP 8 style guide with 88-char line length (Black formatter)
- Use type hints consistently (Python 3.9+ typing)
- Group imports: stdlib → third-party → local
- Document functions, classes and tools with docstrings
- Handle Solr connection errors with appropriate retries
- Log operations with structured logging (JSON format)
- Return well-formatted errors following JSON-RPC 2.0 spec
## Technical Details
Key implementation details:
- Uses MCP 1.4.1 framework for protocol implementation
```
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
```python
"""Tests package."""
```
--------------------------------------------------------------------------------
/tests/unit/__init__.py:
--------------------------------------------------------------------------------
```python
"""Unit tests package."""
```
--------------------------------------------------------------------------------
/tests/integration/__init__.py:
--------------------------------------------------------------------------------
```python
"""Integration tests package."""
```
--------------------------------------------------------------------------------
/solr_mcp/__init__.py:
--------------------------------------------------------------------------------
```python
"""Solr MCP Server - Model Context Protocol server for SolrCloud integration."""
__version__ = "0.1.0"
```
--------------------------------------------------------------------------------
/tests/unit/fixtures/__init__.py:
--------------------------------------------------------------------------------
```python
"""Fixture package for unit tests.
This package contains various fixtures categorized by functionality.
"""
```
--------------------------------------------------------------------------------
/solr_mcp/vector_provider/clients/__init__.py:
--------------------------------------------------------------------------------
```python
"""Vector provider client implementations."""
from .ollama import OllamaVectorProvider
__all__ = ["OllamaVectorProvider"]
```
--------------------------------------------------------------------------------
/solr_mcp/solr/query/__init__.py:
--------------------------------------------------------------------------------
```python
"""Query building and validation package for SolrCloud client."""
from solr_mcp.solr.query.builder import QueryBuilder
__all__ = ["QueryBuilder"]
```
--------------------------------------------------------------------------------
/solr_mcp/vector_provider/__init__.py:
--------------------------------------------------------------------------------
```python
"""Vector provider implementations."""
from solr_mcp.vector_provider.clients.ollama import OllamaVectorProvider
__all__ = ["OllamaVectorProvider"]
```
--------------------------------------------------------------------------------
/solr_config/unified/conf/stopwords.txt:
--------------------------------------------------------------------------------
```
# Standard stop words
a
an
and
are
as
at
be
but
by
for
if
in
into
is
it
no
not
of
on
or
such
that
the
their
then
there
these
they
this
to
was
will
with
```
--------------------------------------------------------------------------------
/solr_mcp/solr/schema/__init__.py:
--------------------------------------------------------------------------------
```python
"""Schema management package for SolrCloud client."""
from solr_mcp.solr.schema.cache import FieldCache
from solr_mcp.solr.schema.fields import FieldManager
__all__ = ["FieldManager", "FieldCache"]
```
--------------------------------------------------------------------------------
/solr_mcp/solr/utils/__init__.py:
--------------------------------------------------------------------------------
```python
"""Solr utilities package."""
from solr_mcp.solr.utils.formatting import (
format_error_response,
format_search_results,
format_sql_response,
)
__all__ = ["format_search_results", "format_sql_response", "format_error_response"]
```
--------------------------------------------------------------------------------
/solr_mcp/solr/vector/__init__.py:
--------------------------------------------------------------------------------
```python
"""Vector search functionality."""
from solr_mcp.solr.vector.manager import VectorManager
from solr_mcp.solr.vector.results import VectorSearchResult, VectorSearchResults
__all__ = ["VectorManager", "VectorSearchResult", "VectorSearchResults"]
```
--------------------------------------------------------------------------------
/solr_config/unified/conf/synonyms.txt:
--------------------------------------------------------------------------------
```
# Synonym mappings
bitcoin, btc
blockchain, distributed ledger
cryptocurrency, crypto
double spend, double spending, double-spend, double-spending, doublespend, doublespending
consensus, agreement
transaction, tx
block, blocks
mining, miner, miners
peer to peer, peer-to-peer, p2p
cryptographic, cryptography, crypto
distributed ledger, blockchain
proof of work, pow
hash, hashing
```
--------------------------------------------------------------------------------
/solr_mcp/tools/solr_list_collections.py:
--------------------------------------------------------------------------------
```python
"""Tool for listing Solr collections."""
from typing import Dict, List
from solr_mcp.tools.tool_decorator import tool
@tool()
async def execute_list_collections(mcp) -> List[str]:
"""List all available Solr collections.
Lists all collections available in the Solr cluster.
Args:
mcp: SolrMCPServer instance
Returns:
List of collection names
"""
solr_client = mcp.solr_client
return await solr_client.list_collections()
```
--------------------------------------------------------------------------------
/solr_mcp/vector_provider/exceptions.py:
--------------------------------------------------------------------------------
```python
"""Exceptions for vector provider module."""
class VectorError(Exception):
"""Base exception for vector-related errors."""
pass
class VectorGenerationError(VectorError):
"""Raised when vector generation fails."""
pass
class VectorConfigError(VectorError):
"""Raised when there is an error in vector provider configuration."""
pass
class VectorConnectionError(VectorError):
"""Raised when connection to vector service fails."""
pass
```
--------------------------------------------------------------------------------
/solr_mcp/solr/__init__.py:
--------------------------------------------------------------------------------
```python
"""SolrCloud client package."""
from solr_mcp.solr.client import SolrClient
from solr_mcp.solr.config import SolrConfig
from solr_mcp.solr.constants import FIELD_TYPE_MAPPING, SYNTHETIC_SORT_FIELDS
from solr_mcp.solr.exceptions import (
ConfigurationError,
ConnectionError,
QueryError,
SchemaError,
SolrError,
)
__all__ = [
"SolrConfig",
"SolrClient",
"SolrError",
"ConfigurationError",
"ConnectionError",
"QueryError",
"SchemaError",
"FIELD_TYPE_MAPPING",
"SYNTHETIC_SORT_FIELDS",
]
```
--------------------------------------------------------------------------------
/solr_mcp/vector_provider/constants.py:
--------------------------------------------------------------------------------
```python
"""Constants for vector module."""
from typing import Any, Dict
# Default configuration for vector providers
DEFAULT_OLLAMA_CONFIG: Dict[str, Any] = {
"base_url": "http://localhost:11434",
"model": "nomic-embed-text",
"timeout": 30, # seconds
"retries": 3,
}
# Environment variable names
ENV_OLLAMA_BASE_URL = "OLLAMA_BASE_URL"
ENV_OLLAMA_MODEL = "OLLAMA_MODEL"
# HTTP endpoints
OLLAMA_EMBEDDINGS_PATH = "/api/embeddings"
# Model-specific constants
MODEL_DIMENSIONS = {"nomic-embed-text": 768} # 768-dimensional vectors
```
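A sketch of how these constants might be combined with environment overrides (the merge logic here is illustrative, not the library's actual code):
```python
import os

from solr_mcp.vector_provider.constants import (
    DEFAULT_OLLAMA_CONFIG,
    ENV_OLLAMA_BASE_URL,
    ENV_OLLAMA_MODEL,
)

# Start from the defaults, then let environment variables take precedence.
config = dict(DEFAULT_OLLAMA_CONFIG)
config["base_url"] = os.environ.get(ENV_OLLAMA_BASE_URL, config["base_url"])
config["model"] = os.environ.get(ENV_OLLAMA_MODEL, config["model"])
```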
--------------------------------------------------------------------------------
/tests/unit/conftest.py:
--------------------------------------------------------------------------------
```python
"""Test configuration and fixtures.
This module imports and re-exports all fixtures from the fixtures directory,
making them available to all tests without explicit imports.
"""
# Import and expose all fixtures
from unittest.mock import mock_open
from .fixtures.common import *
from .fixtures.config_fixtures import *
from .fixtures.http_fixtures import *
from .fixtures.server_fixtures import *
from .fixtures.solr_fixtures import *
from .fixtures.time_fixtures import *
from .fixtures.vector_fixtures import *
from .fixtures.zookeeper_fixtures import *
```
--------------------------------------------------------------------------------
/tests/unit/fixtures/time_fixtures.py:
--------------------------------------------------------------------------------
```python
"""Time-related fixtures for unit tests."""
import time
from unittest.mock import patch
import pytest
@pytest.fixture
def mock_time(request):
"""Parameterized time.time mock.
Args:
request: Pytest request object that can contain parameters:
- minutes_offset: Minutes to add to current time
"""
# Get parameters or use defaults
minutes_offset = getattr(request, "param", {}).get("minutes_offset", 0)
seconds_offset = minutes_offset * 60
with patch("time.time", return_value=time.time() + seconds_offset) as mock:
yield mock
```
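A usage sketch: the offset is supplied through pytest's indirect parametrization (the test body is hypothetical):
```python
import pytest

@pytest.mark.parametrize("mock_time", [{"minutes_offset": 10}], indirect=True)
def test_cache_expiry(mock_time):
    # Inside this test, time.time() reports a value ten minutes in the
    # future, which is useful for exercising TTL/expiry logic.
    ...
```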
--------------------------------------------------------------------------------
/solr_mcp/solr/constants.py:
--------------------------------------------------------------------------------
```python
"""Constants for SolrCloud client."""
# Field type mapping for sorting
FIELD_TYPE_MAPPING = {
"string": "string",
"text_general": "text",
"text_en": "text",
"int": "numeric",
"long": "numeric",
"float": "numeric",
"double": "numeric",
"date": "date",
"boolean": "boolean",
}
# Synthetic fields that can be used for sorting
SYNTHETIC_SORT_FIELDS = {
"score": {
"type": "numeric",
"directions": ["asc", "desc"],
"default_direction": "desc",
"searchable": True,
},
"_docid_": {
"type": "numeric",
"directions": ["asc", "desc"],
"default_direction": "asc",
"searchable": False,
"warning": "Internal Lucene document ID. Not stable across restarts or reindexing.",
},
}
```
--------------------------------------------------------------------------------
/solr_mcp/tools/__init__.py:
--------------------------------------------------------------------------------
```python
"""Tool definitions for Solr MCP server."""
import inspect
import sys
from .solr_default_vectorizer import get_default_text_vectorizer
from .solr_list_collections import execute_list_collections
from .solr_list_fields import execute_list_fields
from .solr_select import execute_select_query
from .solr_semantic_select import execute_semantic_select_query
from .solr_vector_select import execute_vector_select_query
from .tool_decorator import get_schema, tool
__all__ = [
"execute_list_collections",
"execute_list_fields",
"execute_select_query",
"execute_vector_select_query",
"execute_semantic_select_query",
"get_default_text_vectorizer",
]
TOOLS_DEFINITION = [
obj
for name, obj in inspect.getmembers(sys.modules[__name__])
if inspect.isfunction(obj) and hasattr(obj, "_is_tool") and obj._is_tool
]
```
--------------------------------------------------------------------------------
/scripts/format.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Format script to run all code formatters on the project.
"""
import subprocess
import sys
from typing import List
def run_command(command: List[str]) -> bool:
"""Run a command and return True if successful, False otherwise."""
print(f"Running: {' '.join(command)}")
result = subprocess.run(command, capture_output=True, text=True)
if result.returncode != 0:
print(f"Command failed with exit code {result.returncode}")
print(result.stdout)
print(result.stderr)
return False
print(result.stdout)
return True
def main() -> int:
"""Run all code formatters."""
print("Running code formatters...")
success = True
# Run black
if not run_command(["black", "solr_mcp", "tests"]):
success = False
# Run isort
if not run_command(["isort", "solr_mcp", "tests"]):
success = False
if success:
print("All formatting completed successfully!")
return 0
else:
print("Some formatting commands failed.")
return 1
if __name__ == "__main__":
sys.exit(main())
```
--------------------------------------------------------------------------------
/scripts/lint.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Lint script to run all linting tools on the project.
"""
import subprocess
import sys
from typing import List
def run_command(command: List[str]) -> bool:
"""Run a command and return True if successful, False otherwise."""
print(f"Running: {' '.join(command)}")
result = subprocess.run(command, capture_output=True, text=True)
if result.returncode != 0:
print(f"Command failed with exit code {result.returncode}")
print(result.stdout)
print(result.stderr)
return False
print(result.stdout)
return True
def main() -> int:
"""Run all linting tools."""
print("Running full linting checks...")
success = True
# Run flake8 with all checks
if not run_command(["flake8", "solr_mcp", "tests"]):
success = False
# Run mypy type checking
if not run_command(["mypy", "solr_mcp", "tests"]):
success = False
if success:
print("All linting checks passed!")
return 0
else:
print("Some linting checks failed.")
return 1
if __name__ == "__main__":
sys.exit(main())
```
--------------------------------------------------------------------------------
/solr_mcp/tools/base.py:
--------------------------------------------------------------------------------
```python
"""Base tool definitions and decorators."""
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Union
def tool(
name: Optional[str] = None,
description: Optional[str] = None,
parameters: Optional[Dict[str, Any]] = None,
) -> Callable:
"""Decorator to mark a function as an MCP tool.
Args:
name: Tool name. Defaults to function name if not provided.
description: Tool description. Defaults to function docstring if not provided.
        parameters: Tool parameters. Defaults to an empty dict if not provided.
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs) -> List[Dict[str, str]]:
result = func(*args, **kwargs)
if not isinstance(result, list):
result = [{"type": "text", "text": str(result)}]
return result
# Mark as tool
wrapper._is_tool = True
# Set tool metadata
wrapper._tool_name = name or func.__name__
wrapper._tool_description = description or func.__doc__ or ""
wrapper._tool_parameters = parameters or {}
return wrapper
return decorator
```
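A minimal usage sketch of this decorator (the `echo` tool is hypothetical):
```python
from solr_mcp.tools.base import tool

@tool(name="echo", description="Echo the input back as MCP text content.")
def echo(text: str):
    return text  # the wrapper converts plain values to MCP text content

assert echo._is_tool
assert echo("hi") == [{"type": "text", "text": "hi"}]
```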
--------------------------------------------------------------------------------
/tests/unit/tools/test_init.py:
--------------------------------------------------------------------------------
```python
"""Test tools initialization."""
import pytest
from solr_mcp.tools import (
TOOLS_DEFINITION,
execute_list_collections,
execute_list_fields,
execute_select_query,
execute_semantic_select_query,
execute_vector_select_query,
get_default_text_vectorizer,
)
def test_tools_definition():
"""Test that TOOLS_DEFINITION contains all expected tools."""
# All tools should be in TOOLS_DEFINITION
tools = {
"solr_list_collections": execute_list_collections,
"solr_list_fields": execute_list_fields,
"solr_select": execute_select_query,
"solr_vector_select": execute_vector_select_query,
"solr_semantic_select": execute_semantic_select_query,
"get_default_text_vectorizer": get_default_text_vectorizer,
}
assert len(TOOLS_DEFINITION) == len(tools)
for tool_name, tool_func in tools.items():
assert tool_func in TOOLS_DEFINITION
def test_tools_exports():
"""Test that __all__ exports all tools."""
from solr_mcp.tools import __all__
expected = {
"execute_list_collections",
"execute_list_fields",
"execute_select_query",
"execute_vector_select_query",
"execute_semantic_select_query",
"get_default_text_vectorizer",
}
assert set(__all__) == expected
```
--------------------------------------------------------------------------------
/tests/unit/vector_provider/test_constants.py:
--------------------------------------------------------------------------------
```python
"""Tests for vector provider constants."""
from solr_mcp.vector_provider.constants import (
DEFAULT_OLLAMA_CONFIG,
ENV_OLLAMA_BASE_URL,
ENV_OLLAMA_MODEL,
MODEL_DIMENSIONS,
OLLAMA_EMBEDDINGS_PATH,
)
def test_default_ollama_config():
"""Test default Ollama configuration values."""
assert isinstance(DEFAULT_OLLAMA_CONFIG, dict)
assert "base_url" in DEFAULT_OLLAMA_CONFIG
assert "model" in DEFAULT_OLLAMA_CONFIG
assert "timeout" in DEFAULT_OLLAMA_CONFIG
assert "retries" in DEFAULT_OLLAMA_CONFIG
assert DEFAULT_OLLAMA_CONFIG["base_url"] == "http://localhost:11434"
assert DEFAULT_OLLAMA_CONFIG["model"] == "nomic-embed-text"
assert DEFAULT_OLLAMA_CONFIG["timeout"] == 30
assert DEFAULT_OLLAMA_CONFIG["retries"] == 3
def test_environment_variables():
"""Test environment variable names."""
assert ENV_OLLAMA_BASE_URL == "OLLAMA_BASE_URL"
assert ENV_OLLAMA_MODEL == "OLLAMA_MODEL"
def test_api_endpoints():
"""Test API endpoint paths."""
assert OLLAMA_EMBEDDINGS_PATH == "/api/embeddings"
def test_model_dimensions():
"""Test model dimension mappings."""
assert isinstance(MODEL_DIMENSIONS, dict)
assert "nomic-embed-text" in MODEL_DIMENSIONS
assert MODEL_DIMENSIONS["nomic-embed-text"] == 768 # 768-dimensional embeddings
```
--------------------------------------------------------------------------------
/solr.Dockerfile:
--------------------------------------------------------------------------------
```dockerfile
FROM solr:9.5
USER root
# Install SQL dependencies
RUN apt-get update && \
apt-get install -y wget unzip && \
mkdir -p /opt/solr/contrib/sql && \
cd /opt/solr/contrib/sql && \
wget https://repo1.maven.org/maven2/org/apache/solr/solr-sql/9.5.0/solr-sql-9.5.0.jar && \
wget https://repo1.maven.org/maven2/org/apache/calcite/calcite-core/1.35.0/calcite-core-1.35.0.jar && \
wget https://repo1.maven.org/maven2/org/apache/calcite/calcite-linq4j/1.35.0/calcite-linq4j-1.35.0.jar && \
wget https://repo1.maven.org/maven2/org/apache/calcite/avatica/avatica-core/1.23.0/avatica-core-1.23.0.jar && \
wget https://repo1.maven.org/maven2/com/google/protobuf/protobuf-java/3.21.7/protobuf-java-3.21.7.jar && \
wget https://repo1.maven.org/maven2/org/apache/calcite/avatica/avatica-metrics/1.23.0/avatica-metrics-1.23.0.jar && \
wget https://repo1.maven.org/maven2/org/locationtech/jts/jts-core/1.19.0/jts-core-1.19.0.jar && \
wget https://repo1.maven.org/maven2/org/codehaus/janino/janino/3.1.9/janino-3.1.9.jar && \
wget https://repo1.maven.org/maven2/org/codehaus/janino/commons-compiler/3.1.9/commons-compiler-3.1.9.jar && \
cp *.jar /opt/solr/server/solr-webapp/webapp/WEB-INF/lib/ && \
chown -R solr:solr /opt/solr/contrib/sql /opt/solr/server/solr-webapp/webapp/WEB-INF/lib/*.jar
USER solr
```
--------------------------------------------------------------------------------
/tests/unit/vector_provider/test_exceptions.py:
--------------------------------------------------------------------------------
```python
"""Tests for vector provider exceptions."""
from solr_mcp.vector_provider.exceptions import (
VectorConfigError,
VectorConnectionError,
VectorError,
VectorGenerationError,
)
def test_vector_error():
"""Test base VectorError exception."""
error = VectorError("Test error")
assert str(error) == "Test error"
assert isinstance(error, Exception)
def test_vector_generation_error():
"""Test VectorGenerationError exception."""
error = VectorGenerationError("Generation failed")
assert str(error) == "Generation failed"
assert isinstance(error, VectorError)
assert isinstance(error, Exception)
def test_vector_config_error():
"""Test VectorConfigError exception."""
error = VectorConfigError("Invalid config")
assert str(error) == "Invalid config"
assert isinstance(error, VectorError)
assert isinstance(error, Exception)
def test_vector_connection_error():
"""Test VectorConnectionError exception."""
error = VectorConnectionError("Connection failed")
assert str(error) == "Connection failed"
assert isinstance(error, VectorError)
assert isinstance(error, Exception)
def test_error_inheritance():
"""Test exception inheritance hierarchy."""
assert issubclass(VectorGenerationError, VectorError)
assert issubclass(VectorConfigError, VectorError)
assert issubclass(VectorConnectionError, VectorError)
assert issubclass(VectorError, Exception)
```
--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
```markdown
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
### Added
- Initial project structure
- MCP server implementation
- Solr client with search, vector search, and hybrid search capabilities
- Embedding generation via Ollama using nomic-embed-text
- Docker configuration for SolrCloud and ZooKeeper
- Demo scripts and utilities for testing
- Bitcoin whitepaper as sample document
- Documentation (README, QUICKSTART, CONTRIBUTING)
### Fixed
- Improved search query transformation for better results
- Fixed phrase proximity searches with `~5` operator
- Proper field naming for Solr compatibility
- Enhanced text analysis for hyphenated terms like "double-spending"
- Improved synonym handling in Solr configuration
- Fixed vector search configuration to use built-in capabilities
- Improved error handling in Ollama embedding client with retries
- Added proper timeout and fallback mechanisms for embedding generation
- Fixed Solr schema URL paths in client implementation
- Enhanced Docker healthcheck for Ollama service
### Changed
- Migrated from FastMCP to MCP 1.4.1
## [0.1.0] - 2024-03-17
### Added
- Initial release
- MCP server implementation
- Integration with SolrCloud
- Support for basic search operations
- Vector search capabilities
- Hybrid search functionality
- Embedding generation and indexing
```
--------------------------------------------------------------------------------
/solr_mcp/tools/solr_select.py:
--------------------------------------------------------------------------------
```python
"""Tool for executing SQL SELECT queries against Solr."""
from typing import Dict
from solr_mcp.tools.tool_decorator import tool
@tool()
async def execute_select_query(mcp, query: str) -> Dict:
"""Execute SQL queries against Solr collections.
Executes SQL queries against Solr collections with the following Solr-specific behaviors:
Collection/Field Rules:
- Collections are used as table names (case-insensitive)
- Field names are case-sensitive and must exist in Solr schema
- SELECT * only allowed with LIMIT clause
- Unlimited queries restricted to docValues-enabled fields
- Reserved words must be backtick-escaped
WHERE Clause Differences:
- Field must be on one side of predicate
- No comparing two constants or two fields
- No subqueries
- Solr syntax in values:
- '[0 TO 100]' for ranges
- '(term1 term2)' for non-phrase OR search
- String literals use single-quotes
Supported Features:
- Operators: =, <>, >, >=, <, <=, IN, LIKE (wildcards), BETWEEN, IS [NOT] NULL
- Functions: COUNT(*), COUNT(DISTINCT), MIN, MAX, SUM, AVG
- GROUP BY: Uses faceting (fast) for low cardinality, map_reduce (slow) for high cardinality
- ORDER BY: Requires docValues-enabled fields
- LIMIT/OFFSET: Use 'OFFSET x FETCH NEXT y ROWS ONLY' syntax
- Performance of OFFSET degrades beyond 10k docs per shard
Args:
mcp: SolrMCPServer instance
query: SQL query to execute
Returns:
Query results
"""
solr_client = mcp.solr_client
return await solr_client.execute_select_query(query)
```
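A few hypothetical queries that follow these rules (the `unified` collection and its fields are examples only):
```python
# SELECT * must be paired with LIMIT:
q1 = "SELECT * FROM unified LIMIT 10"

# Solr range syntax goes inside a single-quoted value:
q2 = "SELECT id, title FROM unified WHERE views = '[0 TO 100]' LIMIT 20"

# Pagination uses OFFSET/FETCH (assumes `id` is docValues-enabled for sorting):
q3 = "SELECT id FROM unified ORDER BY id ASC OFFSET 20 FETCH NEXT 10 ROWS ONLY"
```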
--------------------------------------------------------------------------------
/solr_mcp/tools/solr_vector_select.py:
--------------------------------------------------------------------------------
```python
"""Tool for executing vector search queries against Solr collections."""
from typing import Dict, List, Optional
from solr_mcp.tools.tool_decorator import tool
@tool()
async def execute_vector_select_query(
mcp, query: str, vector: List[float], field: Optional[str] = None
) -> Dict:
"""Execute vector search queries against Solr collections.
Extends solr_select tool with vector search capabilities.
Additional Parameters:
    - vector: Query vector matched against the collection's vector field for similarity ranking.
- field: Name of the vector field to search against (optional, will auto-detect if not specified)
The query results will be ranked based on distance to the provided vector. Therefore, ORDER BY is not allowed.
Collection/Field Rules:
- Vector field must be a dense_vector or knn_vector field type
- The specified field must exist in the collection schema
- The input vector dimensionality must match the field's vector dimensionality
Supported Features:
- All standard SELECT query features except ORDER BY
- Results are ordered by vector distance
- Hybrid search combining keyword (SQL WHERE clauses) and vector distance (vector parameter)
Args:
mcp: SolrMCPServer instance
query: SQL query to execute
vector: Query vector for similarity search
field: Name of the vector field to search against (optional, auto-detected if not specified)
Returns:
Query results
"""
solr_client = mcp.solr_client
return await solr_client.execute_vector_select_query(query, vector, field)
```
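A sketch of invoking this tool directly (the `server` instance is a placeholder, and the zero vector stands in for a real embedding; with nomic-embed-text the query vector would be 768-dimensional):
```python
from solr_mcp.tools.solr_vector_select import execute_vector_select_query

async def demo(server):  # `server` is a SolrMCPServer instance (placeholder)
    query_vector = [0.0] * 768  # placeholder; nomic-embed-text yields 768 dims
    return await execute_vector_select_query(
        server,
        "SELECT id, title FROM unified WHERE category = 'whitepaper' LIMIT 10",
        query_vector,
    )
```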
--------------------------------------------------------------------------------
/solr_mcp/solr/response.py:
--------------------------------------------------------------------------------
```python
"""Response formatters for Solr results."""
import logging
from typing import Any, Dict, List, Optional, Union
import pysolr
from solr_mcp.solr.utils.formatting import format_search_results, format_sql_response
logger = logging.getLogger(__name__)
class ResponseFormatter:
"""Formats Solr responses for client consumption."""
@staticmethod
def format_search_results(
results: pysolr.Results, start: int = 0
) -> Dict[str, Any]:
"""Format Solr search results for client consumption.
Args:
results: Solr search results
start: Starting index of results
Returns:
Formatted search results
"""
return format_search_results(results, start)
@staticmethod
def format_sql_response(response: Dict[str, Any]) -> Dict[str, Any]:
"""Format Solr SQL response for client consumption.
Args:
response: Solr SQL response
Returns:
Formatted SQL response
"""
return format_sql_response(response)
@staticmethod
def format_vector_search_results(
results: Dict[str, Any], top_k: int
) -> Dict[str, Any]:
"""Format vector search results.
Args:
results: Vector search results
top_k: Number of top results
Returns:
Formatted vector search results
"""
from solr_mcp.solr.vector import VectorSearchResults
vector_results = VectorSearchResults.from_solr_response(
response=results, top_k=top_k
)
return vector_results.to_dict()
```
--------------------------------------------------------------------------------
/scripts/simple_mcp_test.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Simple MCP client test script.
"""
import sys
import os
import json
import asyncio
import httpx
# Add the project root to your path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from solr_mcp.solr.client import SolrClient
async def direct_solr_test():
"""Test direct Solr connection."""
client = SolrClient()
# Test standard search with different query formats
print("\n=== Testing direct Solr client search with different query formats ===")
results1 = await client.search("double spend", collection="unified")
print(f"Simple search results: {results1}")
results2 = await client.search("content:double content:spend", collection="unified")
print(f"Field-specific search results: {results2}")
results3 = await client.search("content:\"double spend\"~5", collection="unified")
print(f"Phrase search results: {results3}")
# Test with HTTP client
print("\n=== Testing direct HTTP search ===")
async with httpx.AsyncClient() as http_client:
response = await http_client.get(
'http://localhost:8983/solr/unified/select',
params={
'q': 'content:"double spend"~5',
'wt': 'json'
}
)
print(f"HTTP search results: {response.text}")
# Check solr config details
print("\n=== Solr client configuration ===")
print(f"Default collection: {client.config.default_collection}")
print(f"Collections available: {client.list_collections()}")
async def main():
await direct_solr_test()
if __name__ == "__main__":
asyncio.run(main())
```
--------------------------------------------------------------------------------
/solr_mcp/tools/solr_list_fields.py:
--------------------------------------------------------------------------------
```python
"""Tool for listing fields in a Solr collection."""
from typing import Any, Dict
from solr_mcp.tools.tool_decorator import tool
@tool()
async def execute_list_fields(mcp, collection: str) -> Dict[str, Any]:
"""List all fields in a Solr collection.
This tool provides detailed information about each field in a Solr collection,
including how fields are related through copyField directives. Pay special
attention to fields that have 'copies_from' data - these are aggregate fields
that combine content from multiple source fields.
For example, the '_text_' field is typically an aggregate field that combines
content from many text fields to provide a unified search experience. When you
see a field with 'copies_from' data, it means that field contains a copy of
the content from all the listed source fields.
Args:
        mcp: SolrMCPServer instance
collection: Name of the collection to get fields from
Returns:
Dictionary containing:
- fields: List of field definitions with their properties including:
- name: Field name
- type: Field type (text_general, string, etc)
- indexed: Whether the field is indexed for searching
- stored: Whether the field values are stored
- docValues: Whether the field can be used for sorting/faceting
- multiValued: Whether the field can contain multiple values
- copies_from: List of source fields that copy their content to this field
- collection: Name of the collection queried
"""
fields = await mcp.solr_client.list_fields(collection)
return {"fields": fields, "collection": collection}
```
--------------------------------------------------------------------------------
/solr_mcp/vector_provider/interfaces.py:
--------------------------------------------------------------------------------
```python
"""Interfaces for vector providers."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
class VectorProvider(ABC):
"""Interface for generating vectors for semantic search."""
@abstractmethod
async def get_vector(self, text: str, model: Optional[str] = None) -> List[float]:
"""Get vector for a single text.
Args:
text: Text to generate vector for
model: Optional model name to use (overrides default)
Returns:
List of floats representing the vector
Raises:
VectorGenerationError: If vector generation fails
VectorConnectionError: If connection to service fails
"""
pass
@abstractmethod
async def get_vectors(
self, texts: List[str], model: Optional[str] = None
) -> List[List[float]]:
"""Get vectors for multiple texts.
Args:
texts: List of texts to generate vectors for
model: Optional model name to use (overrides default)
Returns:
List of vectors (list of floats)
Raises:
VectorGenerationError: If vector generation fails
VectorConnectionError: If connection to service fails
"""
pass
@property
@abstractmethod
def vector_dimension(self) -> int:
"""Get the dimension of vectors produced by this provider.
Returns:
Integer dimension of the vectors
Raises:
VectorConfigError: If unable to determine vector dimension
"""
pass
@property
@abstractmethod
def model_name(self) -> str:
"""Get the name of the model used by this provider.
Returns:
String name of the model
"""
pass
```
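A minimal conforming implementation, for example as a test stub (entirely hypothetical; the real client lives in `solr_mcp/vector_provider/clients/ollama.py`):
```python
from typing import List, Optional

from solr_mcp.vector_provider.interfaces import VectorProvider

class StaticVectorProvider(VectorProvider):
    """Toy provider that returns zero vectors; illustrates the interface only."""

    def __init__(self, dimension: int = 768, model: str = "static-test"):
        self._dimension = dimension
        self._model = model

    async def get_vector(self, text: str, model: Optional[str] = None) -> List[float]:
        return [0.0] * self._dimension

    async def get_vectors(
        self, texts: List[str], model: Optional[str] = None
    ) -> List[List[float]]:
        return [[0.0] * self._dimension for _ in texts]

    @property
    def vector_dimension(self) -> int:
        return self._dimension

    @property
    def model_name(self) -> str:
        return self._model
```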
--------------------------------------------------------------------------------
/scripts/simple_index.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Simple indexing script to demonstrate adding documents to Solr without embeddings.
"""
import argparse
import json
import os
import sys
import time
import pysolr
from typing import Dict, List, Any
# Add the project root to the path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
def index_documents(json_file: str, collection: str = "documents", commit: bool = True):
"""
Index documents from a JSON file into Solr without vector embeddings.
Args:
json_file: Path to the JSON file containing documents
collection: Solr collection name
commit: Whether to commit after indexing
"""
# Load documents
with open(json_file, 'r', encoding='utf-8') as f:
documents = json.load(f)
# Initialize Solr client directly
solr_url = f"http://localhost:8983/solr/{collection}"
solr = pysolr.Solr(solr_url, always_commit=commit)
print(f"Indexing {len(documents)} documents to {collection} collection...")
try:
# Add documents to Solr
solr.add(documents)
print(f"Successfully indexed {len(documents)} documents in collection '{collection}'")
except Exception as e:
print(f"Error indexing documents: {e}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Index documents in Solr without vector embeddings")
parser.add_argument("json_file", help="Path to the JSON file containing documents")
parser.add_argument("--collection", "-c", default="documents", help="Solr collection name")
parser.add_argument("--no-commit", dest="commit", action="store_false", help="Don't commit after indexing")
args = parser.parse_args()
index_documents(args.json_file, args.collection, args.commit)
```
--------------------------------------------------------------------------------
/solr_mcp/solr/interfaces.py:
--------------------------------------------------------------------------------
```python
"""Interfaces for Solr client components."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
class CollectionProvider(ABC):
"""Interface for providing collection information."""
@abstractmethod
async def list_collections(self) -> List[str]:
"""List all available collections.
Returns:
List of collection names
Raises:
ConnectionError: If unable to retrieve collections
"""
pass
@abstractmethod
async def collection_exists(self, collection: str) -> bool:
"""Check if a collection exists.
Args:
collection: Name of the collection to check
Returns:
True if the collection exists, False otherwise
Raises:
ConnectionError: If unable to check collection existence
"""
pass
class VectorSearchProvider(ABC):
"""Interface for vector search operations."""
@abstractmethod
def execute_vector_search(
self, client: Any, vector: List[float], field: str, top_k: Optional[int] = None
) -> Dict[str, Any]:
"""Execute a vector similarity search.
Args:
client: Solr client instance
vector: Dense vector for similarity search
field: DenseVector field to search against
top_k: Number of top results to return
Returns:
Search results as a dictionary
Raises:
SolrError: If vector search fails
"""
pass
@abstractmethod
async def get_vector(self, text: str) -> List[float]:
"""Get vector for text.
Args:
text: Text to convert to vector
Returns:
Vector as list of floats
Raises:
SolrError: If vector generation fails
"""
pass
```
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
```toml
[tool.poetry]
name = "solr-mcp"
version = "0.1.0"
description = "A Python package for accessing Solr indexes via Model Context Protocol (MCP)"
authors = ["Allen Day <[email protected]>"]
readme = "README.md"
license = "MIT"
repository = "https://github.com/allenday/solr-mcp"
packages = [{include = "solr_mcp"}]
[tool.poetry.scripts]
solr-mcp = "solr_mcp.server:main"
lint = "scripts.lint:main"
format = "scripts.format:main"
[tool.poetry.dependencies]
python = "^3.10"
pysolr = "^3.9.0"
mcp = "^1.4.1"
httpx = "^0.27.0"
pydantic = "^2.6.1"
numpy = "^1.26.3"
markdown = "^3.5.2"
fastapi = "^0.109.2"
uvicorn = "^0.27.1"
python-frontmatter = "^1.1.0"
loguru = "^0.7.3"
kazoo = "^2.10.0"
sqlglot = "^26.11.1"
pytest-mock = "^3.14.0"
[tool.poetry.group.dev.dependencies]
pytest = "^8.0.0"
mypy = "^1.8.0"
flake8 = "^7.0.0"
black = "^24.2.0"
isort = "^5.13.2"
pytest-cov = "^6.0.0"
pytest-asyncio = "^0.25.3"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
asyncio_mode = "strict"
asyncio_default_fixture_loop_scope = "function"
markers = [
"integration: marks tests that require external services (deselect with '-m \"not integration\"')"
]
[tool.mypy]
python_version = "3.10"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = true
disallow_incomplete_defs = true
check_untyped_defs = true
disallow_untyped_decorators = true
no_implicit_optional = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_no_return = true
warn_unreachable = true
[[tool.mypy.overrides]]
module = "tests.*"
disallow_untyped_defs = false
disallow_incomplete_defs = false
[tool.black]
line-length = 88
target-version = ['py310']
include = '\.pyi?$'
[tool.isort]
profile = "black"
line_length = 88
multi_line_output = 3
[tool.flake8]
max-line-length = 88
extend-ignore = ["E203"]
exclude = [".venv", ".git", "__pycache__", "build", "dist"]
```
--------------------------------------------------------------------------------
/tests/unit/tools/test_solr_list_collections.py:
--------------------------------------------------------------------------------
```python
"""Tests for Solr list collections tool."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from mcp.server.fastmcp.exceptions import ToolError
from solr_mcp.server import SolrMCPServer
from solr_mcp.tools.solr_list_collections import execute_list_collections
@pytest.mark.asyncio
class TestListCollectionsTool:
"""Test list collections tool."""
async def test_execute_list_collections_requires_server_instance(self):
"""Test that execute_list_collections requires a proper server instance."""
# Test with string instead of server instance
with pytest.raises(
AttributeError, match="'str' object has no attribute 'solr_client'"
):
await execute_list_collections("server")
async def test_execute_list_collections_success(self):
"""Test successful list collections execution."""
# Create mock server instance with solr_client
mock_server = MagicMock(spec=SolrMCPServer)
mock_solr_client = AsyncMock()
mock_solr_client.list_collections.return_value = ["unified", "collection2"]
mock_server.solr_client = mock_solr_client
# Execute tool
result = await execute_list_collections(mock_server)
# Verify result
assert isinstance(result, list)
assert "unified" in result
assert len(result) == 2
mock_solr_client.list_collections.assert_called_once()
async def test_execute_list_collections_error(self):
"""Test list collections error handling."""
# Create mock server instance with failing solr_client
mock_server = MagicMock(spec=SolrMCPServer)
mock_solr_client = AsyncMock()
mock_solr_client.list_collections.side_effect = Exception(
"Failed to list collections"
)
mock_server.solr_client = mock_solr_client
# Execute tool and verify error is propagated
with pytest.raises(Exception, match="Failed to list collections"):
await execute_list_collections(mock_server)
```
--------------------------------------------------------------------------------
/solr_mcp/tools/solr_default_vectorizer.py:
--------------------------------------------------------------------------------
```python
"""Tool for getting information about the default vector provider."""
from typing import Any, Dict
from urllib.parse import urlparse
from solr_mcp.tools.tool_decorator import tool
from solr_mcp.vector_provider.constants import DEFAULT_OLLAMA_CONFIG, MODEL_DIMENSIONS
@tool()
async def get_default_text_vectorizer(mcp) -> Dict[str, Any]:
"""Get information about the default vector provider used for semantic search.
Returns information about the default vector provider configuration used for semantic search,
including the model name, vector dimensionality, host, and port.
This information is useful for ensuring that your vector fields in Solr have
the correct dimensionality to match the vector provider model.
Returns:
Dictionary containing:
- vector_provider_model: The name of the default vector provider model
- vector_provider_dimension: The dimensionality of vectors produced by this model
- vector_provider_host: The host of the vector provider service
- vector_provider_port: The port of the vector provider service
- vector_provider_url: The full URL of the vector provider service
"""
    if hasattr(mcp, "solr_client") and hasattr(mcp.solr_client, "vector_manager"):
        vector_manager = mcp.solr_client.vector_manager
        model_name = vector_manager.client.model
        base_url = vector_manager.client.base_url
    else:
        # Fall back to defaults
        model_name = DEFAULT_OLLAMA_CONFIG["model"]
        base_url = DEFAULT_OLLAMA_CONFIG["base_url"]
    # Shared by both branches; 768 is the fallback dimension
    dimension = MODEL_DIMENSIONS.get(model_name, 768)
# Parse URL to extract host and port
parsed_url = urlparse(base_url)
host = parsed_url.hostname or "localhost"
port = parsed_url.port or 11434 # Default Ollama port
# Format as "model@host:port" for easy use with vector_provider parameter
formatted_spec = f"{model_name}@{host}:{port}"
return {
"vector_provider_model": model_name,
"vector_provider_dimension": dimension,
"vector_provider_host": host,
"vector_provider_port": port,
"vector_provider_url": base_url,
"vector_provider_spec": formatted_spec,
}
```
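With the fall-back defaults in the code above, the tool would return something like the following (assuming the stock Ollama setup with `nomic-embed-text`, whose dimension resolves to 768 here):
```python
{
    "vector_provider_model": "nomic-embed-text",
    "vector_provider_dimension": 768,
    "vector_provider_host": "localhost",
    "vector_provider_port": 11434,
    "vector_provider_url": "http://localhost:11434",
    "vector_provider_spec": "nomic-embed-text@localhost:11434",
}
```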
--------------------------------------------------------------------------------
/scripts/index_documents.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Script to index documents in Solr with vector embeddings
generated using Ollama's nomic-embed-text model.
"""
import argparse
import asyncio
import json
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from solr_mcp.solr.client import SolrClient
async def index_documents(json_file: str, collection: str = "vectors", commit: bool = True):
"""
Index documents from a JSON file into Solr with vector embeddings.
Args:
json_file: Path to the JSON file containing documents
collection: Solr collection name
commit: Whether to commit after indexing
"""
# Load documents
with open(json_file, 'r', encoding='utf-8') as f:
documents = json.load(f)
    # Initialize Solr client
    solr_client = SolrClient()
    # Check if the target collection exists (list_collections is async)
    collections = await solr_client.list_collections()
if collection not in collections:
print(f"Warning: Collection '{collection}' not found in Solr. Available collections: {collections}")
response = input("Do you want to continue with the default collection? (y/N): ")
if response.lower() != 'y':
print("Aborting.")
return
collection = solr_client.config.default_collection
# Index documents with embeddings
print(f"Indexing {len(documents)} documents with embeddings...")
try:
success = await solr_client.batch_index_with_generated_embeddings(
documents=documents,
collection=collection,
commit=commit
)
if success:
print(f"Successfully indexed {len(documents)} documents in collection '{collection}'")
else:
print("Indexing failed")
except Exception as e:
print(f"Error indexing documents: {e}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Index documents in Solr with vector embeddings")
parser.add_argument("json_file", help="Path to the JSON file containing documents")
parser.add_argument("--collection", "-c", default="vectors", help="Solr collection name")
parser.add_argument("--no-commit", dest="commit", action="store_false", help="Don't commit after indexing")
args = parser.parse_args()
asyncio.run(index_documents(args.json_file, args.collection, args.commit))
```
--------------------------------------------------------------------------------
/tests/unit/fixtures/server_fixtures.py:
--------------------------------------------------------------------------------
```python
"""Server fixtures for unit tests."""
import json
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from solr_mcp.server import SolrMCPServer
from .common import MOCK_RESPONSES
@pytest.fixture
def mock_server(mock_solr_client, mock_config):
"""Create a mock SolrMCPServer for testing."""
server = SolrMCPServer(
solr_base_url=mock_config.solr_base_url,
zookeeper_hosts=mock_config.zookeeper_hosts,
connection_timeout=mock_config.connection_timeout,
)
server.solr_client = mock_solr_client
return server
@pytest.fixture
def mock_server_instance():
"""Create a mock FastMCP server instance for testing."""
mock_server = MagicMock()
# Mock list collections response
async def mock_list_collections(*args, **kwargs):
return [{"type": "text", "text": json.dumps(MOCK_RESPONSES["collections"])}]
mock_server.list_collections = AsyncMock(side_effect=mock_list_collections)
# Mock select query response
async def mock_select(*args, **kwargs):
return [
{
"type": "text",
"text": json.dumps({"rows": [{"id": "1", "title": "Test Doc"}]}),
}
]
mock_server.select = AsyncMock(side_effect=mock_select)
# Mock vector select response
async def mock_vector_select(*args, **kwargs):
return [
{
"type": "text",
"text": json.dumps({"rows": [{"id": "1", "title": "Test Doc"}]}),
}
]
mock_server.vector_select = AsyncMock(side_effect=mock_vector_select)
# Mock semantic select response
async def mock_semantic_select(*args, **kwargs):
return [
{
"type": "text",
"text": json.dumps({"rows": [{"id": "1", "title": "Test Doc"}]}),
}
]
mock_server.semantic_select = AsyncMock(side_effect=mock_semantic_select)
return mock_server
@pytest.fixture
def mock_singleton_server():
"""Mock SolrMCPServer for singleton pattern testing."""
# Create a mock class to avoid affecting real singleton
MockServer = Mock(spec=SolrMCPServer)
MockServer._instance = None
# Create a proper classmethod mock
def get_instance():
return MockServer._instance
MockServer.get_instance = classmethod(get_instance)
# Create two different instances
server1 = Mock(spec=SolrMCPServer)
server2 = Mock(spec=SolrMCPServer)
with patch("solr_mcp.server.SolrMCPServer", MockServer):
yield {"MockServer": MockServer, "server1": server1, "server2": server2}
```
--------------------------------------------------------------------------------
/tests/unit/test_config.py:
--------------------------------------------------------------------------------
```python
"""Tests for Solr configuration."""
import json
from unittest.mock import mock_open, patch
import pytest
from solr_mcp.solr.config import SolrConfig
from solr_mcp.solr.exceptions import ConfigurationError
def test_config_defaults():
"""Test default configuration values."""
config = SolrConfig(
solr_base_url="http://test:8983/solr", zookeeper_hosts=["test:2181"]
)
assert config.solr_base_url == "http://test:8983/solr"
assert config.zookeeper_hosts == ["test:2181"]
assert config.connection_timeout == 10
def test_config_custom_values():
"""Test custom configuration values."""
config = SolrConfig(
solr_base_url="http://custom:8983/solr",
zookeeper_hosts=["custom:2181"],
connection_timeout=20,
)
assert config.solr_base_url == "http://custom:8983/solr"
assert config.zookeeper_hosts == ["custom:2181"]
assert config.connection_timeout == 20
def test_config_validation():
"""Test configuration validation."""
with pytest.raises(ConfigurationError, match="solr_base_url is required"):
SolrConfig(zookeeper_hosts=["test:2181"])
with pytest.raises(ConfigurationError, match="zookeeper_hosts is required"):
SolrConfig(solr_base_url="http://test:8983/solr")
with pytest.raises(ConfigurationError, match="connection_timeout must be positive"):
SolrConfig(
solr_base_url="http://test:8983/solr",
zookeeper_hosts=["test:2181"],
connection_timeout=0,
)
def test_load_from_file():
"""Test loading configuration from file."""
config_data = {
"solr_base_url": "http://test:8983/solr",
"zookeeper_hosts": ["test:2181"],
"connection_timeout": 20,
}
with patch("builtins.open", mock_open(read_data=json.dumps(config_data))):
config = SolrConfig.load("config.json")
assert config.solr_base_url == "http://test:8983/solr"
assert config.zookeeper_hosts == ["test:2181"]
assert config.connection_timeout == 20
def test_load_invalid_json():
"""Test loading invalid JSON."""
with patch("builtins.open", mock_open(read_data="invalid json")):
with pytest.raises(
ConfigurationError, match="Invalid JSON in configuration file"
):
SolrConfig.load("config.json")
def test_load_missing_required_field():
"""Test loading config with missing required field."""
config_data = {
"solr_base_url": "http://test:8983/solr"
# Missing zookeeper_hosts
}
with patch("builtins.open", mock_open(read_data=json.dumps(config_data))):
with pytest.raises(ConfigurationError, match="zookeeper_hosts is required"):
SolrConfig.load("config.json")
```
--------------------------------------------------------------------------------
/scripts/direct_mcp_test.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Direct MCP server test script.
Tests the raw JSON-RPC interface that Claude uses to communicate with MCP servers.
"""
import sys
import os
import json
import subprocess
import time
# Add the project root to sys.path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
# First clean up any existing MCP servers
os.system("pkill -f 'python -m solr_mcp.server'")
time.sleep(1) # Let them shut down
def write_to_stdin(process, data):
"""Write data to the stdin of a process and flush."""
process.stdin.write(data)
process.stdin.flush()
def read_from_stdout(process):
"""Read a JSON-RPC message from stdout of a process."""
line = process.stdout.readline().strip()
if not line:
return None
try:
return json.loads(line)
except json.JSONDecodeError:
print(f"Error decoding JSON: {line}")
return None
# Start a new MCP server process
cmd = ["python", "-m", "solr_mcp.server"]
server_process = subprocess.Popen(
cmd,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
bufsize=1, # Line buffered
)
print("MCP server started.")
time.sleep(2) # Give it time to initialize
# Test search methods
def test_search(query):
print(f"\n\nTesting search for: '{query}'")
# Try a standard search
request = {
"jsonrpc": "2.0",
"id": "1",
"method": "execute_tool",
"params": {
"name": "solr_search",
"arguments": {
"query": query
}
}
}
print("\nSending search request:", json.dumps(request, indent=2))
write_to_stdin(server_process, json.dumps(request) + "\n")
response = read_from_stdout(server_process)
print("\nGot response:", json.dumps(response, indent=2) if response else "No response")
# Try a hybrid search
request = {
"jsonrpc": "2.0",
"id": "2",
"method": "execute_tool",
"params": {
"name": "solr_hybrid_search",
"arguments": {
"query": query,
"blend_factor": 0.5
}
}
}
print("\nSending hybrid search request:", json.dumps(request, indent=2))
write_to_stdin(server_process, json.dumps(request) + "\n")
response = read_from_stdout(server_process)
print("\nGot hybrid response:", json.dumps(response, indent=2) if response else "No response")
# Test with a query we know exists
test_search("double spend")
# Test with another query
test_search("blockchain")
# Clean up
print("\nCleaning up...")
server_process.terminate()
server_process.wait()
print("Done!")
```
--------------------------------------------------------------------------------
/tests/unit/tools/test_base.py:
--------------------------------------------------------------------------------
```python
"""Tests for base tool decorator."""
from typing import Dict, List
import pytest
from solr_mcp.tools.base import tool
def test_tool_decorator_default_values():
"""Test tool decorator with default values."""
@tool()
def sample_tool() -> str:
"""Sample tool docstring."""
return "test"
assert hasattr(sample_tool, "_is_tool")
assert sample_tool._is_tool is True
assert sample_tool._tool_name == "sample_tool"
assert "Sample tool docstring" in sample_tool._tool_description
assert sample_tool._tool_parameters == {}
def test_tool_decorator_custom_values():
"""Test tool decorator with custom values."""
@tool(
name="custom_name",
description="Custom description",
parameters={"param": "description"},
)
def sample_tool() -> str:
return "test"
assert sample_tool._is_tool is True
assert sample_tool._tool_name == "custom_name"
assert sample_tool._tool_description == "Custom description"
assert sample_tool._tool_parameters == {"param": "description"}
def test_tool_decorator_result_wrapping():
"""Test that tool decorator properly wraps results."""
@tool()
def string_tool() -> str:
return "test"
@tool()
def dict_tool() -> Dict[str, str]:
return {"key": "value"}
@tool()
def list_tool() -> List[Dict[str, str]]:
return [{"type": "text", "text": "test"}]
# String result should be wrapped
result = string_tool()
assert isinstance(result, list)
assert len(result) == 1
assert result[0]["type"] == "text"
assert result[0]["text"] == "test"
# Dict result should be wrapped
result = dict_tool()
assert isinstance(result, list)
assert len(result) == 1
assert result[0]["type"] == "text"
assert result[0]["text"] == "{'key': 'value'}"
# List result should be returned as is
result = list_tool()
assert isinstance(result, list)
assert len(result) == 1
assert result[0]["type"] == "text"
assert result[0]["text"] == "test"
def test_tool_decorator_preserves_function_metadata():
"""Test that tool decorator preserves function metadata."""
@tool()
def sample_tool(param1: str, param2: int = 0) -> str:
"""Sample tool docstring."""
return f"{param1} {param2}"
assert sample_tool.__name__ == "sample_tool"
assert "Sample tool docstring" in sample_tool.__doc__
# Check that the function signature is preserved
import inspect
sig = inspect.signature(sample_tool)
assert list(sig.parameters.keys()) == ["param1", "param2"]
assert sig.parameters["param1"].annotation == str
assert sig.parameters["param2"].annotation == int
assert sig.parameters["param2"].default == 0
```
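For reference, a minimal decorator consistent with the behavior these tests pin down; the actual `solr_mcp.tools.base.tool` may differ in detail:
```python
import functools
from typing import Any, Callable, Dict, Optional


def tool(
    name: Optional[str] = None,
    description: Optional[str] = None,
    parameters: Optional[Dict[str, str]] = None,
) -> Callable:
    """Sketch of a tool decorator satisfying the tests above."""

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            result = func(*args, **kwargs)
            if isinstance(result, list):
                return result  # already a list of content blocks
            text = result if isinstance(result, str) else str(result)
            return [{"type": "text", "text": text}]

        wrapper._is_tool = True
        wrapper._tool_name = name or func.__name__
        wrapper._tool_description = description or (func.__doc__ or "")
        wrapper._tool_parameters = parameters or {}
        return wrapper

    return decorator
```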
--------------------------------------------------------------------------------
/tests/unit/fixtures/vector_fixtures.py:
--------------------------------------------------------------------------------
```python
"""Vector search fixtures for unit tests."""
from unittest.mock import Mock
import pytest
from solr_mcp.solr.interfaces import VectorSearchProvider
from solr_mcp.solr.vector.manager import VectorManager
from solr_mcp.vector_provider.clients.ollama import OllamaVectorProvider
@pytest.fixture
def mock_ollama(request):
"""Parameterized mock for Ollama client.
Args:
request: Pytest request object that can contain parameters:
- vector_dim: Dimension of returned vectors
- error: Whether to simulate an error
"""
# Get parameters or use defaults
vector_dim = getattr(request, "param", {}).get("vector_dim", 3)
error = getattr(request, "param", {}).get("error", False)
provider = Mock(spec=OllamaVectorProvider)
if error:
provider.get_vector.side_effect = Exception("Ollama API error")
else:
provider.get_vector.return_value = [0.1] * vector_dim
return provider
@pytest.fixture
def mock_vector_provider(request):
"""Parameterized mock for vector provider.
Args:
request: Pytest request object that can contain parameters:
- vector_dim: Dimension of returned vectors
- error: Whether to simulate an error
"""
# Get parameters or use defaults
vector_dim = getattr(request, "param", {}).get("vector_dim", 768)
error = getattr(request, "param", {}).get("error", False)
provider = Mock(spec=VectorSearchProvider)
if error:
provider.get_vector.side_effect = Exception("Vector API error")
else:
provider.get_vector.return_value = [0.1] * vector_dim
return provider
@pytest.fixture
def mock_vector_manager(request):
"""Parameterized mock VectorManager.
Args:
request: Pytest request object that can contain parameters:
- vector_dim: Dimension of returned vectors
- error: Whether to simulate an error
"""
# Get parameters or use defaults
vector_dim = getattr(request, "param", {}).get("vector_dim", 3)
error = getattr(request, "param", {}).get("error", False)
manager = Mock(spec=VectorManager)
if error:
manager.get_vector.side_effect = Exception("Vector generation error")
else:
manager.get_vector.return_value = [0.1] * vector_dim
return manager
@pytest.fixture
def mock_ollama_response(request):
"""Parameterized mock Ollama API response.
Args:
request: Pytest request object that can contain parameters:
- vector_dim: Dimension of returned vectors
- model: Model name to include in response
"""
# Get parameters or use defaults
vector_dim = getattr(request, "param", {}).get("vector_dim", 5)
model = getattr(request, "param", {}).get("model", "nomic-embed-text")
return {"embedding": [0.1] * vector_dim, "model": model}
```
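These fixtures read `request.param`, so tests supply parameters through pytest's indirect parametrization. A sketch of how a consumer might look (assuming `get_vector` is async, as in the provider interfaces, so the spec'd mock returns an awaitable):
```python
import pytest


@pytest.mark.asyncio
@pytest.mark.parametrize("mock_ollama", [{"vector_dim": 5}], indirect=True)
async def test_vector_dimension(mock_ollama):
    assert await mock_ollama.get_vector("hello") == [0.1] * 5


@pytest.mark.asyncio
@pytest.mark.parametrize("mock_ollama", [{"error": True}], indirect=True)
async def test_vector_error(mock_ollama):
    with pytest.raises(Exception, match="Ollama API error"):
        await mock_ollama.get_vector("hello")
```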
--------------------------------------------------------------------------------
/tests/unit/vector_provider/test_interfaces.py:
--------------------------------------------------------------------------------
```python
"""Tests for vector provider interfaces."""
from typing import List
import pytest
from solr_mcp.vector_provider.exceptions import (
VectorConfigError,
VectorConnectionError,
VectorGenerationError,
)
from solr_mcp.vector_provider.interfaces import VectorProvider
class MockVectorProvider(VectorProvider):
"""Mock implementation of VectorProvider for testing."""
def __init__(self, dimension: int = 768):
self._dimension = dimension
self._model = "mock-model"
async def get_vector(self, text: str) -> List[float]:
if text == "error":
raise VectorGenerationError("Test error")
return [0.1] * self._dimension
async def get_vectors(self, texts: List[str]) -> List[List[float]]:
if any(t == "error" for t in texts):
raise VectorGenerationError("Test error")
return [[0.1] * self._dimension for _ in texts]
@property
def vector_dimension(self) -> int:
return self._dimension
@property
def model_name(self) -> str:
return self._model
def test_vector_provider_is_abstract():
"""Test that VectorProvider cannot be instantiated directly."""
with pytest.raises(TypeError):
VectorProvider()
def test_vector_provider_requires_methods():
"""Test that implementing class must define all abstract methods."""
class IncompleteProvider(VectorProvider):
pass
with pytest.raises(TypeError):
IncompleteProvider()
@pytest.mark.asyncio
async def test_mock_provider_get_vector():
"""Test get_vector implementation."""
provider = MockVectorProvider()
result = await provider.get_vector("test")
assert len(result) == 768
assert all(x == 0.1 for x in result)
@pytest.mark.asyncio
async def test_mock_provider_get_vector_error():
"""Test get_vector error handling."""
provider = MockVectorProvider()
with pytest.raises(VectorGenerationError):
await provider.get_vector("error")
@pytest.mark.asyncio
async def test_mock_provider_get_vectors():
"""Test get_vectors implementation."""
provider = MockVectorProvider()
texts = ["test1", "test2"]
result = await provider.get_vectors(texts)
assert len(result) == 2
assert all(len(v) == 768 for v in result)
assert all(all(x == 0.1 for x in v) for v in result)
@pytest.mark.asyncio
async def test_mock_provider_get_vectors_error():
"""Test get_vectors error handling."""
provider = MockVectorProvider()
with pytest.raises(VectorGenerationError):
await provider.get_vectors(["test", "error"])
def test_mock_provider_vector_dimension():
"""Test vector_dimension property."""
provider = MockVectorProvider(dimension=512)
assert provider.vector_dimension == 512
def test_mock_provider_model_name():
"""Test model_name property."""
provider = MockVectorProvider()
assert provider.model_name == "mock-model"
```
--------------------------------------------------------------------------------
/tests/unit/tools/test_solr_list_fields.py:
--------------------------------------------------------------------------------
```python
"""Tests for the list fields tool."""
import pytest
from solr_mcp.solr.exceptions import SolrError
from solr_mcp.tools.solr_list_fields import execute_list_fields
# Sample field data for testing
FIELD_DATA = {
"fields": [
{"name": "id", "type": "string", "indexed": True, "stored": True},
{
"name": "_text_",
"type": "text_general",
"indexed": True,
"stored": False,
"copies_from": ["title", "content"],
},
]
}
@pytest.mark.asyncio
@pytest.mark.parametrize(
"collection,custom_fields",
[
("test_collection", None),
(
"custom_collection",
[
{
"name": "custom_id",
"type": "string",
"indexed": True,
"stored": True,
},
{
"name": "custom_text",
"type": "text_general",
"indexed": True,
"stored": False,
},
],
),
],
)
async def test_execute_list_fields_success(mock_server, collection, custom_fields):
"""Test successful execution of list_fields tool with different collections and field sets."""
# Use default fields or custom fields based on parameter
fields = custom_fields or FIELD_DATA["fields"]
mock_server.solr_client.list_fields.return_value = fields
# Execute the tool
result = await execute_list_fields(mock_server, collection)
# Verify the result
assert result["collection"] == collection
assert len(result["fields"]) == len(fields)
assert result["fields"][0]["name"] == fields[0]["name"]
# Check for copies_from in the default test case
if custom_fields is None and "copies_from" in fields[1]:
assert "copies_from" in result["fields"][1]
# Verify the correct collection was used
mock_server.solr_client.list_fields.assert_called_once_with(collection)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"error_message",
["Failed to list fields", "Collection not found", "Connection error"],
)
async def test_execute_list_fields_error(mock_solr_client, mock_config, error_message):
"""Test error handling in list_fields tool with different error messages."""
# Create a server with a parameterized error client
error_client = mock_solr_client(param={"error": True})
from solr_mcp.server import SolrMCPServer
server = SolrMCPServer(
solr_base_url=mock_config.solr_base_url,
zookeeper_hosts=mock_config.zookeeper_hosts,
connection_timeout=mock_config.connection_timeout,
)
server.solr_client = error_client
# Override the exception message
error_client.list_fields.side_effect = SolrError(error_message)
# Verify the exception is raised with the correct message
with pytest.raises(SolrError, match=error_message):
await execute_list_fields(server, "test_collection")
```
--------------------------------------------------------------------------------
/solr_mcp/solr/schema/cache.py:
--------------------------------------------------------------------------------
```python
"""Field caching for SolrCloud client."""
import logging
import time
from typing import Any, Dict, Optional
from solr_mcp.solr.constants import SYNTHETIC_SORT_FIELDS
logger = logging.getLogger(__name__)
class FieldCache:
"""Caches field information for Solr collections."""
def __init__(self):
"""Initialize the FieldCache."""
self._cache: Dict[str, Dict[str, Any]] = {}
def get(self, collection: str) -> Optional[Dict[str, Any]]:
"""Get cached field information for a collection.
Args:
collection: Collection name
Returns:
Dict containing field information or None if not cached
"""
if collection in self._cache:
return self._cache[collection]
return None
def set(self, collection: str, field_info: Dict[str, Any]) -> None:
"""Cache field information for a collection.
Args:
collection: Collection name
field_info: Field information to cache
"""
self._cache[collection] = {**field_info, "last_updated": time.time()}
def is_stale(self, collection: str, max_age: float = 300.0) -> bool:
"""Check if cached field information is stale.
Args:
collection: Collection name
max_age: Maximum age in seconds before cache is considered stale
Returns:
True if cache is stale or missing, False otherwise
"""
if collection not in self._cache:
return True
last_updated = self._cache[collection].get("last_updated", 0)
return (time.time() - last_updated) > max_age
def get_or_default(self, collection: str) -> Dict[str, Any]:
"""Get cached field information or return defaults.
Args:
collection: Collection name
Returns:
Dict containing field information (cached or default)
"""
if collection in self._cache:
return self._cache[collection]
# Return safe defaults
return {
"searchable_fields": ["_text_"],
"sortable_fields": {"score": SYNTHETIC_SORT_FIELDS["score"]},
"last_updated": time.time(),
}
def clear(self, collection: Optional[str] = None) -> None:
"""Clear cached field information.
Args:
collection: Collection name to clear, or None to clear all
"""
if collection:
self._cache.pop(collection, None)
else:
self._cache.clear()
def update(self, collection: str, field_info: Dict[str, Any]) -> None:
"""Update cached field information.
Args:
collection: Collection name
field_info: Field information to update
"""
if collection in self._cache:
self._cache[collection].update(field_info)
self._cache[collection]["last_updated"] = time.time()
else:
self.set(collection, field_info)
```
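A short usage round-trip for `FieldCache` (collection and field names are illustrative):
```python
cache = FieldCache()
cache.set("unified", {"searchable_fields": ["title", "content"]})

info = cache.get("unified")  # includes the "last_updated" timestamp
assert info is not None and "title" in info["searchable_fields"]

assert cache.is_stale("unified") is False  # fresh within the 300s default
assert cache.is_stale("missing") is True   # unknown collections are stale

cache.clear("unified")  # or cache.clear() to drop all collections
assert cache.get("unified") is None
```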
--------------------------------------------------------------------------------
/tests/unit/fixtures/config_fixtures.py:
--------------------------------------------------------------------------------
```python
"""Configuration fixtures for unit tests."""
from unittest.mock import Mock, mock_open, patch
import pytest
from solr_mcp.solr.config import SolrConfig
@pytest.fixture
def mock_config(request):
"""Parameterized SolrConfig mock.
Args:
request: Pytest request object that can contain parameters:
- base_url: Custom Solr base URL
- zk_hosts: Custom ZooKeeper hosts
- timeout: Custom connection timeout
"""
# Get parameters or use defaults
base_url = getattr(request, "param", {}).get(
"base_url", "http://localhost:8983/solr"
)
zk_hosts = getattr(request, "param", {}).get("zk_hosts", ["localhost:2181"])
timeout = getattr(request, "param", {}).get("timeout", 10)
config = Mock(spec=SolrConfig)
config.solr_base_url = base_url
config.zookeeper_hosts = zk_hosts
config.connection_timeout = timeout
return config
@pytest.fixture(
params=[
# Format: (fixture_name, content, side_effect)
(
"valid",
"""
{
"solr_base_url": "http://solr:8983/solr",
"zookeeper_hosts": ["zk1:2181", "zk2:2181"],
"connection_timeout": 30
}
""",
None,
),
("invalid_json", "invalid json content", None),
(
"minimal",
"""
{
"zookeeper_hosts": ["zk1:2181"]
}
""",
None,
),
("missing", None, FileNotFoundError()),
]
)
def mock_config_file(request):
"""Parameterized fixture for different config file scenarios."""
fixture_name, content, side_effect = request.param
if side_effect:
with patch("builtins.open", side_effect=side_effect):
yield fixture_name
else:
with patch("builtins.open", mock_open(read_data=content)):
yield fixture_name
@pytest.fixture
def mock_field_manager_methods():
"""Mock FieldManager methods for testing."""
mock_fields = {
"searchable_fields": ["title", "content"],
"sortable_fields": {
"id": {"directions": ["asc", "desc"], "default_direction": "asc"},
"score": {
"directions": ["asc", "desc"],
"default_direction": "desc",
"type": "numeric",
"searchable": True,
},
},
}
def patch_get_collection_fields(field_manager):
"""Create a context manager for patching _get_collection_fields."""
return patch.object(
field_manager, "_get_collection_fields", return_value=mock_fields
)
def patch_get_searchable_fields(field_manager):
"""Create a context manager for patching _get_searchable_fields."""
return patch.object(
field_manager, "_get_searchable_fields", side_effect=Exception("API error")
)
return {
"mock_fields": mock_fields,
"patch_get_collection_fields": patch_get_collection_fields,
"patch_get_searchable_fields": patch_get_searchable_fields,
}
```
--------------------------------------------------------------------------------
/scripts/setup.sh:
--------------------------------------------------------------------------------
```bash
#!/bin/bash
# Setup script for Solr MCP Server
set -e # Exit immediately if a command exits with a non-zero status
echo "=== Setting up Solr MCP Server ==="
# Check if Docker is installed
if ! command -v docker &> /dev/null; then
echo "Docker is not installed. Please install Docker and Docker Compose first."
exit 1
fi
# Check if Docker Compose is installed
if ! command -v docker-compose &> /dev/null; then
echo "Docker Compose is not installed. Please install Docker Compose first."
exit 1
fi
# Create Python virtual environment
echo "Creating Python virtual environment..."
python3 -m venv venv
source venv/bin/activate
# Install dependencies
echo "Installing dependencies..."
pip install poetry
poetry install
# Start Docker containers
echo "Starting SolrCloud, ZooKeeper, and Ollama containers..."
docker-compose up -d
# Wait for Solr to be ready
echo "Waiting for SolrCloud to be ready..."
sleep 10
attempts=0
max_attempts=30
while ! curl -s http://localhost:8983/solr/ > /dev/null; do
attempts=$((attempts+1))
if [ $attempts -ge $max_attempts ]; then
echo "Error: SolrCloud did not start in time. Please check docker-compose logs."
exit 1
fi
echo "Waiting for SolrCloud to start... (attempt $attempts/$max_attempts)"
sleep 5
done
# Create unified collection
echo "Creating unified collection..."
python scripts/create_unified_collection.py
# Process demo data (Bitcoin whitepaper)
echo "Processing demo data..."
python scripts/process_markdown.py data/bitcoin-whitepaper.md --output data/processed/bitcoin_sections.json
# Index demo data to unified collection
echo "Indexing demo data to unified collection..."
python scripts/unified_index.py data/processed/bitcoin_sections.json --collection unified
# Test search to ensure content is indexed properly
echo "Testing search functionality..."
python -c "
import httpx
import asyncio
async def test_search():
async with httpx.AsyncClient() as client:
response = await client.get(
'http://localhost:8983/solr/unified/select',
params={
'q': 'content:\"double spend\"~5',
'wt': 'json'
}
)
results = response.json()
if results.get('response', {}).get('numFound', 0) > 0:
print('✅ Search test successful! Found documents matching \"double spend\"')
else:
print('❌ Warning: No documents found for \"double spend\". Search may not work properly.')
print(' Try running: python scripts/diagnose_search.py --collection unified --term \"double spend\"')
asyncio.run(test_search())
"
echo ""
echo "=== Setup Complete! ==="
echo ""
echo "You can now use the Solr MCP server with the following commands:"
echo ""
echo "1. Start the MCP server:"
echo " python -m solr_mcp.server"
echo ""
echo "2. Try hybrid search on the demo data:"
echo " python scripts/demo_hybrid_search.py \"blockchain\" --mode compare"
echo ""
echo "3. Use the Claude Desktop integration by configuring the MCP server"
echo " in Claude's configuration file (see README.md for details)."
echo ""
echo "For more information, please refer to the documentation in README.md."
```
--------------------------------------------------------------------------------
/solr_mcp/tools/solr_semantic_select.py:
--------------------------------------------------------------------------------
```python
"""Tool for executing semantic search queries against Solr collections."""
from typing import Dict, Optional
from solr_mcp.tools.tool_decorator import tool
@tool()
async def execute_semantic_select_query(
mcp, query: str, text: str, field: Optional[str] = None, vector_provider: str = ""
) -> Dict:
"""Execute semantic search queries against Solr collections.
Extends solr_select tool with semantic search capabilities.
Additional Parameters:
    - text: Natural language text that is converted to a vector and matched against the collection's vector field
- field: Name of the vector field to search against (optional, will auto-detect if not specified)
- vector_provider: Vector provider specification in format "model@host:port" (e.g., "nomic-embed-text@localhost:11434")
If not specified, the default vector provider will be used
The query results will be ranked based on semantic similarity to the provided text. Therefore, ORDER BY is not allowed.
Collection/Field Rules:
- Vector field must be a dense_vector or knn_vector field type
- The specified field must exist in the collection schema
- The vector provider's dimensionality must match the dimensionality of the vector field
Supported Features:
- All standard SELECT query features except ORDER BY
- Results are ordered by semantic similarity
- Hybrid search combining keyword (SQL WHERE clauses) and vector distance (text parameter)
Args:
mcp: SolrMCPServer instance
query: SQL query to execute
text: Search text to convert to vector
field: Name of the vector field to search against (optional, auto-detected if not specified)
vector_provider: Optional vector provider specification "model@host:port"
Returns:
Query results
"""
solr_client = mcp.solr_client
# Configure vector provider from parameter string
vector_provider_config = {}
if vector_provider:
# Parse "model@host:port" format
model_part = vector_provider
host_port_part = None
if "@" in vector_provider:
parts = vector_provider.split("@", 1)
model_part = parts[0]
host_port_part = parts[1]
# Set model if specified
if model_part:
vector_provider_config["model"] = model_part
# Set host:port if specified
if host_port_part:
if ":" in host_port_part:
host, port_str = host_port_part.split(":", 1)
try:
port = int(port_str)
vector_provider_config["base_url"] = f"http://{host}:{port}"
except ValueError:
# If port is not a valid integer, use the host with default port
vector_provider_config["base_url"] = f"http://{host_port_part}"
else:
# Only host specified, use default port
vector_provider_config["base_url"] = f"http://{host_port_part}:11434"
return await solr_client.execute_semantic_select_query(
query, text, field, vector_provider_config
)
```
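To make the spec parsing above concrete, here is how a few `vector_provider` values map to the resulting config, followed by an illustrative call inside an async context (server, collection, and field names are assumptions):
```python
# "nomic-embed-text"                  -> {"model": "nomic-embed-text"}
# "nomic-embed-text@localhost:11434"  -> {"model": "nomic-embed-text",
#                                         "base_url": "http://localhost:11434"}
# "nomic-embed-text@vectors.internal" -> {"model": "nomic-embed-text",
#                                         "base_url": "http://vectors.internal:11434"}

result = await execute_semantic_select_query(
    server,
    query="SELECT id, title FROM unified LIMIT 5",
    text="how does bitcoin prevent double spending",
    field="embedding",
    vector_provider="nomic-embed-text@localhost:11434",
)
```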
--------------------------------------------------------------------------------
/QUICKSTART.md:
--------------------------------------------------------------------------------
```markdown
# Solr MCP Quick Start Guide
This guide will help you get up and running with the Solr MCP server quickly.
## Prerequisites
- Python 3.10 or higher
- Docker and Docker Compose
- Git
## Step 1: Clone the Repository
```bash
git clone https://github.com/allenday/solr-mcp.git
cd solr-mcp
```
## Step 2: Start SolrCloud with Docker
```bash
docker-compose up -d
```
This will start a SolrCloud instance with ZooKeeper and Ollama for embedding generation.
Verify that Solr is running by visiting: http://localhost:8983/solr/
## Step 3: Set Up Python Environment
```bash
# Create a virtual environment
python -m venv venv
# Activate it
source venv/bin/activate # On Windows: venv\Scripts\activate
# Install Poetry
pip install poetry
# Install dependencies
poetry install
```
## Step 4: Process and Index Sample Documents
The repository includes the Bitcoin whitepaper as a sample document. Let's process and index it:
```bash
# Process the Markdown file into sections
python scripts/process_markdown.py data/bitcoin-whitepaper.md --output data/processed/bitcoin_sections.json
# Create a unified collection
python scripts/create_unified_collection.py unified
# Index the sections with embeddings
python scripts/unified_index.py data/processed/bitcoin_sections.json --collection unified
```
## Step 5: Run the MCP Server
```bash
poetry run python -m solr_mcp.server
```
By default, the server runs on http://localhost:8000.
## Step 6: Test the Search Functionality
You can test the different search capabilities using the demo scripts:
```bash
# Test keyword search
python scripts/simple_search.py "double spend" --collection unified
# Test vector search
python scripts/vector_search.py "how does bitcoin prevent fraud" --collection unified
# Test hybrid search (combining keyword and vector)
python scripts/simple_mcp_test.py
```
## Using with Claude Desktop
To use the MCP server with Claude Desktop:
1. Make sure the MCP server is running
2. In Claude Desktop, go to Settings > Tools
3. Add a new tool with:
- Name: Solr Search
- URL: http://localhost:8000
- Working Directory: /path/to/solr-mcp
Now you can ask Claude queries like:
- "Search for information about double spending in the Bitcoin whitepaper"
- "Find sections related to consensus mechanisms"
- "What does the whitepaper say about transaction verification?"
## Troubleshooting
If you encounter issues:
1. Check that Solr is running: http://localhost:8983/solr/
2. Verify the collection exists: http://localhost:8983/solr/#/~collections
3. Run the diagnostic script: `python scripts/diagnose_search.py`
4. Check the server logs for errors
## Set up linting and formatting
We use several tools to maintain code quality:
```bash
# Run code formatters (black and isort)
poetry run python scripts/format.py
# Or use the poetry script
poetry run format
# Run linters (flake8 and mypy)
poetry run python scripts/lint.py
# Or use the poetry script
poetry run lint
```
You can also run individual tools:
```bash
# Format code with Black
poetry run black solr_mcp tests
# Sort imports with isort
poetry run isort solr_mcp tests
# Run flake8 linter
poetry run flake8 solr_mcp tests
# Run mypy type checker
poetry run mypy solr_mcp tests
```
```
--------------------------------------------------------------------------------
/tests/unit/test_client.py:
--------------------------------------------------------------------------------
```python
"""Unit tests for SolrClient."""
from unittest.mock import Mock, patch
import pytest
from solr_mcp.solr.client import SolrClient
from solr_mcp.solr.interfaces import CollectionProvider, VectorSearchProvider
from .conftest import MOCK_RESPONSES, MockCollectionProvider, MockVectorProvider
class TestSolrClient:
"""Test cases for SolrClient."""
def test_init_with_defaults(self, mock_config, mock_field_manager, mock_ollama):
"""Test initialization with default dependencies."""
client = SolrClient(
config=mock_config,
field_manager=mock_field_manager,
vector_provider=mock_ollama,
)
assert client.config == mock_config
assert isinstance(client.collection_provider, CollectionProvider)
assert client.field_manager == mock_field_manager
assert client.vector_provider == mock_ollama
def test_init_with_custom_providers(self, mock_config, mock_field_manager):
"""Test initialization with custom providers."""
mock_collection_provider = MockCollectionProvider()
mock_vector_provider = MockVectorProvider()
mock_solr = Mock() # Create a simple mock
client = SolrClient(
config=mock_config,
collection_provider=mock_collection_provider,
solr_client=mock_solr,
field_manager=mock_field_manager,
vector_provider=mock_vector_provider,
)
assert client.config == mock_config
assert client.collection_provider == mock_collection_provider
assert client.field_manager == mock_field_manager
assert client.vector_provider == mock_vector_provider
@pytest.mark.asyncio
@pytest.mark.parametrize("collection", ["collection1", "test_collection"])
async def test_execute_select_query_success(
self, mock_config, mock_field_manager, collection
):
"""Test successful SQL query execution with different collections."""
# Create a mock for the query builder
mock_query_builder = Mock()
mock_query_builder.parser = Mock()
mock_query_builder.parser.preprocess_query = Mock(
return_value=f"SELECT * FROM {collection}"
)
mock_query_builder.parse_and_validate_select = Mock(
return_value=(
Mock(), # AST
collection, # Collection name
["id", "title"], # Fields
)
)
# Create a mock response
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = {
"result-set": {"docs": [{"id": "1", "field": "value"}], "numFound": 1}
}
# Create client with dependencies and patch requests.post
with patch("requests.post", return_value=mock_response):
client = SolrClient(
config=mock_config,
field_manager=mock_field_manager,
query_builder=mock_query_builder,
)
# Execute query
result = await client.execute_select_query(f"SELECT * FROM {collection}")
# Verify result structure
assert "result-set" in result
assert "docs" in result["result-set"]
assert result["result-set"]["docs"][0]["id"] == "1"
```
--------------------------------------------------------------------------------
/solr_mcp/solr/zookeeper.py:
--------------------------------------------------------------------------------
```python
"""ZooKeeper-based collection provider."""
from typing import List
import anyio
from kazoo.client import KazooClient
from kazoo.exceptions import ConnectionLoss, NoNodeError
from solr_mcp.solr.exceptions import ConnectionError
from solr_mcp.solr.interfaces import CollectionProvider
class ZooKeeperCollectionProvider(CollectionProvider):
"""Collection provider that uses ZooKeeper to discover collections."""
def __init__(self, hosts: List[str]):
"""Initialize with ZooKeeper hosts.
Args:
hosts: List of ZooKeeper hosts in format host:port
"""
self.hosts = hosts
self.zk = None
self.connect()
def connect(self):
"""Connect to ZooKeeper and verify /collections path exists."""
try:
self.zk = KazooClient(hosts=",".join(self.hosts))
self.zk.start()
# Check if /collections path exists
if not self.zk.exists("/collections"):
raise ConnectionError("ZooKeeper /collections path does not exist")
except ConnectionLoss as e:
raise ConnectionError(f"Failed to connect to ZooKeeper: {str(e)}")
except Exception as e:
raise ConnectionError(f"Error connecting to ZooKeeper: {str(e)}")
def cleanup(self):
"""Clean up ZooKeeper connection."""
if self.zk:
try:
self.zk.stop()
self.zk.close()
except Exception:
pass # Ignore cleanup errors
finally:
self.zk = None
async def list_collections(self) -> List[str]:
"""List available collections from ZooKeeper.
Returns:
List of collection names
Raises:
ConnectionError: If there is an error communicating with ZooKeeper
"""
try:
if not self.zk:
raise ConnectionError("Not connected to ZooKeeper")
collections = await anyio.to_thread.run_sync(
self.zk.get_children, "/collections"
)
return collections
except NoNodeError:
return [] # No collections exist yet
except ConnectionLoss as e:
raise ConnectionError(f"Lost connection to ZooKeeper: {str(e)}")
except Exception as e:
raise ConnectionError(f"Error listing collections: {str(e)}")
async def collection_exists(self, collection: str) -> bool:
"""Check if a collection exists in ZooKeeper.
Args:
collection: Name of the collection to check
Returns:
True if the collection exists, False otherwise
Raises:
ConnectionError: If there is an error communicating with ZooKeeper
"""
try:
if not self.zk:
raise ConnectionError("Not connected to ZooKeeper")
# Check for collection in ZooKeeper
collection_path = f"/collections/{collection}"
exists = await anyio.to_thread.run_sync(self.zk.exists, collection_path)
return exists is not None
except ConnectionLoss as e:
raise ConnectionError(f"Lost connection to ZooKeeper: {str(e)}")
except Exception as e:
raise ConnectionError(f"Error checking collection existence: {str(e)}")
```
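A minimal usage sketch for the provider above (the ZooKeeper host is illustrative):
```python
import asyncio


async def main() -> None:
    provider = ZooKeeperCollectionProvider(hosts=["localhost:2181"])
    try:
        print(await provider.list_collections())
        print(await provider.collection_exists("unified"))
    finally:
        provider.cleanup()


if __name__ == "__main__":
    asyncio.run(main())
```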
--------------------------------------------------------------------------------
/solr_mcp/solr/exceptions.py:
--------------------------------------------------------------------------------
```python
"""Solr client exceptions."""
from typing import Any, Dict, Optional
class SolrError(Exception):
"""Base exception for Solr-related errors."""
pass
class ConfigurationError(SolrError):
"""Configuration-related errors."""
pass
class ConnectionError(SolrError):
"""Exception raised for connection-related errors."""
pass
class QueryError(SolrError):
"""Base exception for query-related errors."""
def __init__(
self,
message: str,
error_type: Optional[str] = None,
response_time: Optional[int] = None,
):
self.message = message
self.error_type = error_type
self.response_time = response_time
super().__init__(self.message)
def to_dict(self) -> Dict[str, Any]:
"""Convert the error to a dictionary format."""
return {
"error_type": self.error_type,
"message": self.message,
"response_time": self.response_time,
}
class DocValuesError(QueryError):
"""Exception raised when a query requires DocValues but fields don't have them enabled."""
def __init__(self, message: str, response_time: Optional[int] = None):
super().__init__(
message, error_type="MISSING_DOCVALUES", response_time=response_time
)
class SQLParseError(QueryError):
"""Exception raised when SQL query parsing fails."""
def __init__(self, message: str, response_time: Optional[int] = None):
super().__init__(message, error_type="PARSE_ERROR", response_time=response_time)
class SQLExecutionError(QueryError):
"""Exception raised for other SQL execution errors."""
def __init__(self, message: str, response_time: Optional[int] = None):
super().__init__(
message, error_type="SOLR_SQL_ERROR", response_time=response_time
)
class SchemaError(SolrError):
"""Base exception for schema-related errors."""
def __init__(
self,
message: str,
error_type: str = "schema_error",
collection: str = "unknown",
):
"""Initialize SchemaError.
Args:
message: Error message
error_type: Type of schema error
collection: Collection name
"""
        self.error_type = error_type
        self.collection = collection
        # Store the message explicitly so to_dict() can serialize it
        self.message = message
        super().__init__(message)
def to_dict(self) -> Dict[str, Any]:
"""Convert the error to a dictionary format."""
return {
"error_type": self.error_type,
"message": self.message,
"collection": self.collection,
}
class CollectionNotFoundError(SchemaError):
"""Exception raised when a collection does not exist."""
def __init__(self, collection: str):
super().__init__(
message=f"Collection '{collection}' not found",
error_type="COLLECTION_NOT_FOUND",
collection=collection,
)
class SchemaNotFoundError(SchemaError):
"""Exception raised when a collection's schema cannot be retrieved."""
    def __init__(self, collection: str, details: Optional[str] = None):
message = f"Schema for collection '{collection}' could not be retrieved"
if details:
message += f": {details}"
super().__init__(
message=message, error_type="SCHEMA_NOT_FOUND", collection=collection
)
```
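Illustrative handling of the hierarchy above: catch the most specific query error first and fall back to `SolrError`, serializing with `to_dict()` where available:
```python
try:
    raise DocValuesError("Field 'title' has no DocValues", response_time=12)
except QueryError as e:
    # {'error_type': 'MISSING_DOCVALUES', 'message': "Field 'title' has no
    #  DocValues", 'response_time': 12}
    print(e.to_dict())
except SolrError as e:
    print(f"Solr error: {e}")
```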
--------------------------------------------------------------------------------
/tests/unit/fixtures/zookeeper_fixtures.py:
--------------------------------------------------------------------------------
```python
"""ZooKeeper fixtures for unit tests."""
from unittest.mock import MagicMock, patch
import pytest
from kazoo.client import KazooClient
from kazoo.exceptions import ConnectionLoss, NoNodeError
@pytest.fixture(
params=["success", "no_collections", "empty", "error", "connection_error"]
)
def mock_kazoo_client(request):
"""Parameterized KazooClient mock with different behavior scenarios."""
mock = MagicMock(spec=KazooClient)
scenario = request.param
if scenario == "success":
mock.get_children.return_value = ["collection1", "collection2"]
mock.exists.return_value = True
mock.start.return_value = None
mock.stop.return_value = None
elif scenario == "no_collections":
mock.exists.return_value = False
mock.get_children.side_effect = NoNodeError
mock.start.return_value = None
mock.stop.return_value = None
elif scenario == "empty":
mock.exists.return_value = True
mock.get_children.return_value = []
mock.start.return_value = None
mock.stop.return_value = None
elif scenario == "error":
mock.exists.return_value = True
mock.get_children.side_effect = ConnectionLoss("ZooKeeper error")
mock.start.return_value = None
mock.stop.return_value = None
elif scenario == "connection_error":
mock.start.side_effect = ConnectionLoss("ZooKeeper connection error")
mock.stop.return_value = None
yield mock, scenario
@pytest.fixture
def mock_kazoo_client_factory(request):
"""Factory for creating KazooClient mocks with specific behavior."""
def _create_client(scenario="success"):
mock = MagicMock(spec=KazooClient)
if scenario == "success":
mock.get_children.return_value = ["collection1", "collection2"]
mock.exists.return_value = True
mock.start.return_value = None
mock.stop.return_value = None
elif scenario == "no_collections":
mock.exists.return_value = False
mock.get_children.side_effect = NoNodeError
mock.start.return_value = None
mock.stop.return_value = None
elif scenario == "empty":
mock.exists.return_value = True
mock.get_children.return_value = []
mock.start.return_value = None
mock.stop.return_value = None
elif scenario == "error":
mock.exists.return_value = True
mock.get_children.side_effect = ConnectionLoss("ZooKeeper error")
mock.start.return_value = None
mock.stop.return_value = None
elif scenario == "connection_error":
mock.start.side_effect = ConnectionLoss("ZooKeeper connection error")
mock.stop.return_value = None
return mock
scenario = getattr(request, "param", "success")
mock_client = _create_client(scenario)
with patch("solr_mcp.solr.zookeeper.KazooClient", return_value=mock_client):
yield _create_client
@pytest.fixture
def provider(mock_kazoo_client_factory):
"""Create ZooKeeperCollectionProvider instance with mocked dependencies."""
from solr_mcp.solr.zookeeper import ZooKeeperCollectionProvider
provider = ZooKeeperCollectionProvider(hosts=["localhost:2181"])
# The KazooClient is already mocked via the factory fixture
return provider
```
--------------------------------------------------------------------------------
/scripts/check_solr.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Script to check Solr configuration and status.
"""
import asyncio
import httpx
async def check_solr_collections():
"""Check Solr collections and their configuration."""
try:
async with httpx.AsyncClient() as client:
# Get list of collections
response = await client.get(
"http://localhost:8983/solr/admin/collections",
params={"action": "LIST", "wt": "json"},
timeout=10.0
)
if response.status_code != 200:
print(f"Error getting collections: {response.status_code} - {response.text}")
return
collections_data = response.json()
if 'collections' in collections_data:
collections = collections_data['collections']
print(f"Found {len(collections)} collections: {', '.join(collections)}")
# Check each collection
for collection in collections:
# Get schema information
schema_response = await client.get(
f"http://localhost:8983/solr/{collection}/schema",
params={"wt": "json"},
timeout=10.0
)
if schema_response.status_code != 200:
print(f"Error getting schema for {collection}: {schema_response.status_code}")
continue
schema_data = schema_response.json()
# Check for vector field type
field_types = schema_data.get('schema', {}).get('fieldTypes', [])
vector_type = None
for ft in field_types:
if ft.get('class') == 'solr.DenseVectorField':
vector_type = ft
break
if vector_type:
print(f"\nCollection '{collection}' has vector field type:")
print(f" Name: {vector_type.get('name')}")
print(f" Class: {vector_type.get('class')}")
print(f" Vector Dimension: {vector_type.get('vectorDimension')}")
print(f" Similarity Function: {vector_type.get('similarityFunction')}")
else:
print(f"\nCollection '{collection}' does not have a vector field type")
# Check for vector fields
fields = schema_data.get('schema', {}).get('fields', [])
vector_fields = [f for f in fields if f.get('type') == 'knn_vector']
if vector_fields:
print(f"\n Vector fields in '{collection}':")
for field in vector_fields:
print(f" - {field.get('name')} (indexed: {field.get('indexed')}, stored: {field.get('stored')})")
else:
print(f"\n No vector fields found in '{collection}'")
else:
print("No collections found or invalid response format")
except Exception as e:
print(f"Error checking Solr: {e}")
if __name__ == "__main__":
asyncio.run(check_solr_collections())
```
--------------------------------------------------------------------------------
/tests/unit/test_parser.py:
--------------------------------------------------------------------------------
```python
"""Unit tests for QueryParser."""
import pytest
from solr_mcp.solr.exceptions import QueryError
from solr_mcp.solr.query.parser import QueryParser
@pytest.fixture
def query_parser():
"""Create QueryParser instance for testing."""
return QueryParser()
class TestQueryParser:
"""Test cases for QueryParser."""
def test_init(self, query_parser):
"""Test QueryParser initialization."""
assert isinstance(query_parser, QueryParser)
def test_preprocess_query_basic(self, query_parser):
"""Test preprocessing basic field:value syntax."""
query = "SELECT * FROM collection1 WHERE field:value"
result = query_parser.preprocess_query(query)
assert "field = 'value'" in result
def test_preprocess_query_multiple(self, query_parser):
"""Test preprocessing multiple field:value pairs."""
query = "SELECT * FROM collection1 WHERE field1:value1 AND field2:value2"
result = query_parser.preprocess_query(query)
assert "field1 = 'value1'" in result
assert "field2 = 'value2'" in result
def test_parse_select_valid(self, query_parser):
"""Test parsing valid SELECT query."""
query = "SELECT id, title FROM collection1"
ast, collection, fields = query_parser.parse_select(query)
assert ast is not None
assert collection == "collection1"
assert fields == ["id", "title"]
def test_parse_select_no_select(self, query_parser):
"""Test parsing non-SELECT query."""
query = "INSERT INTO collection1 (id) VALUES (1)"
with pytest.raises(QueryError) as exc_info:
query_parser.parse_select(query)
assert exc_info.type == QueryError
def test_parse_select_no_from(self, query_parser):
"""Test parsing query without FROM clause."""
query = "SELECT id, title"
with pytest.raises(QueryError) as exc_info:
query_parser.parse_select(query)
assert exc_info.type == QueryError
def test_parse_select_with_alias(self, query_parser):
"""Test parsing query with aliased fields."""
query = "SELECT id as doc_id, title as doc_title FROM collection1"
ast, collection, fields = query_parser.parse_select(query)
assert ast is not None
assert collection == "collection1"
assert "doc_id" in fields
assert "doc_title" in fields
def test_parse_select_with_star(self, query_parser):
"""Test parsing query with * selector."""
query = "SELECT * FROM collection1"
ast, collection, fields = query_parser.parse_select(query)
assert ast is not None
assert collection == "collection1"
assert "*" in fields
def test_parse_select_invalid_syntax(self, query_parser):
"""Test parsing query with invalid syntax."""
query = "INVALID SQL"
with pytest.raises(QueryError) as exc_info:
query_parser.parse_select(query)
assert exc_info.type == QueryError
def test_extract_sort_fields_single(self, query_parser):
"""Test extracting fields from single sort specification."""
sort_spec = "title desc"
fields = query_parser.extract_sort_fields(sort_spec)
assert fields == ["title"]
def test_extract_sort_fields_multiple(self, query_parser):
"""Test extracting fields from multiple sort specifications."""
sort_spec = "title desc, id asc"
fields = query_parser.extract_sort_fields(sort_spec)
assert fields == ["title", "id"]
```
--------------------------------------------------------------------------------
/scripts/prepare_data.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Script to prepare data for indexing in Solr with dynamic field naming conventions.
"""
import argparse
import json
import os
def prepare_data_for_solr(input_file, output_file):
"""
Modify field names to use Solr dynamic field naming conventions.
Args:
input_file: Path to the input JSON file
output_file: Path to the output JSON file
"""
# Load the input data
with open(input_file, 'r', encoding='utf-8') as f:
data = json.load(f)
# Transform the data
transformed_data = []
for doc in data:
transformed_doc = {}
# Map fields to appropriate dynamic field suffixes
for key, value in doc.items():
if key == 'id' or key == 'title' or key == 'text' or key == 'source':
# Keep standard fields as they are
transformed_doc[key] = value
elif key == 'section_number':
# Integer fields get _i suffix
transformed_doc['section_number_i'] = value
elif key == 'date_indexed':
# Date fields get _dt suffix and need proper Solr format
# Convert to Solr format YYYY-MM-DDThh:mm:ssZ
# If already a string, ensure it's in the right format
if isinstance(value, str):
# Truncate microseconds if present
if '.' in value:
parts = value.split('.')
value = parts[0] + 'Z'
elif not value.endswith('Z'):
value = value + 'Z'
transformed_doc[f'{key}_dt'] = value
elif key == 'date':
# Ensure date has proper format
if isinstance(value, str):
# If just a date (YYYY-MM-DD), add time
if len(value) == 10 and value.count('-') == 2:
value = value + 'T00:00:00Z'
# If it has time but no Z, add Z
elif 'T' in value and not value.endswith('Z'):
value = value + 'Z'
transformed_doc[f'{key}_dt'] = value
elif key == 'tags' or key == 'category':
# Multi-valued string fields get _ss suffix
transformed_doc[f'{key}_ss'] = value
elif key == 'author':
# String fields get _s suffix
transformed_doc[f'{key}_s'] = value
else:
# Default: keep as is
transformed_doc[key] = value
transformed_data.append(transformed_doc)
# Write the transformed data to output file
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(transformed_data, f, indent=2)
print(f"Prepared {len(transformed_data)} documents for Solr indexing")
print(f"Output saved to {output_file}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Prepare data for Solr indexing")
parser.add_argument("input_file", help="Path to the input JSON file")
parser.add_argument("--output", "-o", default=None, help="Path to the output JSON file")
args = parser.parse_args()
# Generate output filename if not provided
if args.output is None:
input_name = os.path.basename(args.input_file)
name, ext = os.path.splitext(input_name)
        args.output = f"data/processed/{name}_solr{ext}"
    # Ensure the output directory exists before writing
    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    prepare_data_for_solr(args.input_file, args.output)
```
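As a quick illustration of the mapping above, here is a sketch of a round trip through `prepare_data_for_solr`; it assumes you run from the repo root so `scripts` is importable, and the sample document is made up:

```python
import json
import tempfile
from scripts.prepare_data import prepare_data_for_solr

# Hypothetical input document
doc = {
    "id": "doc1",
    "title": "Bitcoin Whitepaper",
    "section_number": 3,
    "date": "2008-10-31",
    "tags": ["crypto", "payments"],
    "author": "Satoshi Nakamoto",
}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump([doc], f)
    in_path = f.name
out_path = in_path.replace(".json", "_solr.json")
prepare_data_for_solr(in_path, out_path)
with open(out_path) as f:
    print(json.load(f)[0])
# Expected keys: id, title, section_number_i, date_dt ("2008-10-31T00:00:00Z"),
# tags_ss, author_s
```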
--------------------------------------------------------------------------------
/tests/unit/tools/test_solr_default_vectorizer.py:
--------------------------------------------------------------------------------
```python
"""Tests for solr_default_vectorizer tool."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from solr_mcp.tools.solr_default_vectorizer import get_default_text_vectorizer
from solr_mcp.vector_provider.constants import DEFAULT_OLLAMA_CONFIG, MODEL_DIMENSIONS
class TestDefaultVectorizerTool:
"""Test cases for default_text_vectorizer tool."""
@pytest.mark.asyncio
async def test_get_default_text_vectorizer_with_server(self):
"""Test getting default vectorizer with a server instance."""
# Create mock server
mock_vector_manager = MagicMock()
mock_vector_manager.client.model = "nomic-embed-text"
mock_vector_manager.client.base_url = "http://test-host:8888"
mock_solr_client = MagicMock()
mock_solr_client.vector_manager = mock_vector_manager
mock_server = MagicMock()
mock_server.solr_client = mock_solr_client
# Execute tool
result = await get_default_text_vectorizer(mock_server)
# Verify result
assert result["vector_provider_model"] == "nomic-embed-text"
assert result["vector_provider_dimension"] == 768
assert result["vector_provider_host"] == "test-host"
assert result["vector_provider_port"] == 8888
assert result["vector_provider_url"] == "http://test-host:8888"
assert result["vector_provider_spec"] == "nomic-embed-text@test-host:8888"
@pytest.mark.asyncio
async def test_get_default_text_vectorizer_without_server(self):
"""Test getting default vectorizer without a server instance."""
# Create a server without vector_manager
mock_server = MagicMock(spec=["no_solr_client"])
# Use patch to avoid trying to parse MagicMock as URL
with patch(
"solr_mcp.vector_provider.constants.DEFAULT_OLLAMA_CONFIG",
{
"model": "nomic-embed-text",
"base_url": "http://localhost:11434",
"timeout": 30,
"retries": 3,
},
):
# Execute tool
result = await get_default_text_vectorizer(mock_server)
# Verify result uses defaults
assert result["vector_provider_model"] == DEFAULT_OLLAMA_CONFIG["model"]
assert (
result["vector_provider_dimension"]
== MODEL_DIMENSIONS[DEFAULT_OLLAMA_CONFIG["model"]]
)
assert result["vector_provider_host"] == "localhost"
assert result["vector_provider_port"] == 11434
assert result["vector_provider_url"] == DEFAULT_OLLAMA_CONFIG["base_url"]
assert (
result["vector_provider_spec"]
== f"{DEFAULT_OLLAMA_CONFIG['model']}@localhost:11434"
)
@pytest.mark.asyncio
async def test_get_default_text_vectorizer_unknown_model(self):
"""Test getting default vectorizer with unknown model."""
# Create mock server
mock_vector_manager = MagicMock()
mock_vector_manager.client.model = "unknown-model"
mock_vector_manager.client.base_url = "http://test-host:8888"
mock_solr_client = MagicMock()
mock_solr_client.vector_manager = mock_vector_manager
mock_server = MagicMock()
mock_server.solr_client = mock_solr_client
# Execute tool
result = await get_default_text_vectorizer(mock_server)
# Verify result with default dimension for unknown model
assert result["vector_provider_model"] == "unknown-model"
assert result["vector_provider_dimension"] == 768 # Default dimension
assert result["vector_provider_spec"] == "unknown-model@test-host:8888"
```
--------------------------------------------------------------------------------
/solr_mcp/solr/query/validator.py:
--------------------------------------------------------------------------------
```python
"""Query validation for SolrCloud client."""
import logging
from typing import List, Optional
from solr_mcp.solr.exceptions import QueryError
logger = logging.getLogger(__name__)
class QueryValidator:
"""Validates SQL queries for Solr."""
def __init__(self, field_manager):
"""Initialize the QueryValidator.
Args:
field_manager: FieldManager instance for field validation
"""
self.field_manager = field_manager
def validate_fields(self, collection: str, fields: List[str]) -> None:
"""Validate that fields exist in the collection.
Args:
collection: Collection name
fields: List of field names to validate
Raises:
QueryError: If fields are invalid
"""
try:
# Get available fields for collection
available_fields = self.field_manager.get_field_types(collection)
# Check each field exists
for field in fields:
if field not in available_fields:
raise QueryError(
f"Invalid field '{field}' - field does not exist in collection '{collection}'"
)
except QueryError:
raise
except Exception as e:
raise QueryError(f"Field validation error: {str(e)}")
def validate_sort_fields(self, collection: str, fields: List[str]) -> None:
"""Validate that fields are sortable in the collection.
Args:
collection: Collection name
fields: List of field names to validate
Raises:
QueryError: If fields are not sortable
"""
try:
self.field_manager.validate_sort_fields(collection, fields)
except Exception as e:
raise QueryError(f"Sort field validation error: {str(e)}")
def validate_sort(self, sort: Optional[str], collection: str) -> Optional[str]:
"""Validate and normalize sort parameter.
Args:
sort: Sort string in format "field direction" or just "field"
collection: Collection name
Returns:
Validated sort string or None if sort is None
Raises:
QueryError: If sort specification is invalid
"""
if not sort:
return None
parts = sort.strip().split()
if len(parts) == 1:
field = parts[0]
direction = None
elif len(parts) == 2:
field, direction = parts
else:
raise QueryError(f"Invalid sort format: {sort}")
try:
# Get sortable fields for the collection
field_info = self.field_manager.get_field_info(collection)
sortable_fields = field_info["sortable_fields"]
# Check if field is sortable
if field not in sortable_fields:
raise QueryError(f"Field '{field}' is not sortable")
# Validate direction if provided
if direction:
valid_directions = sortable_fields[field]["directions"]
if direction.lower() not in [d.lower() for d in valid_directions]:
raise QueryError(
f"Invalid sort direction '{direction}' for field '{field}'"
)
else:
# Use default direction for field
direction = sortable_fields[field]["default_direction"]
return f"{field} {direction}"
except QueryError:
raise
except Exception as e:
raise QueryError(f"Sort field validation error: {str(e)}")
```
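A minimal sketch of the contract QueryValidator expects from its field manager; the stub below is hypothetical stand-in data, not the real FieldManager from solr_mcp/solr/schema/fields.py:

```python
from types import SimpleNamespace
from solr_mcp.solr.query.validator import QueryValidator

# Hypothetical stand-in exposing just the two methods validate_* relies on
stub = SimpleNamespace(
    get_field_types=lambda collection: {"id": "string", "title": "text_general"},
    get_field_info=lambda collection: {
        "sortable_fields": {
            "id": {"directions": ["asc", "desc"], "default_direction": "asc"}
        }
    },
)
validator = QueryValidator(field_manager=stub)
validator.validate_fields("unified", ["id", "title"])  # passes silently
print(validator.validate_sort("id", "unified"))  # -> "id asc"
```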
--------------------------------------------------------------------------------
/solr_mcp/solr/utils/formatting.py:
--------------------------------------------------------------------------------
```python
"""Utilities for formatting Solr search results."""
import json
import logging
from typing import Any, Dict, List, Optional, Union
import pysolr
from solr_mcp.solr.exceptions import QueryError, SolrError
logger = logging.getLogger(__name__)
def format_search_results(
results: pysolr.Results,
start: int = 0,
include_score: bool = True,
include_facets: bool = True,
include_highlighting: bool = True,
) -> str:
"""Format Solr search results for consumption.
Args:
results: pysolr Results object
start: Start offset used in the search
include_score: Whether to include score information
include_facets: Whether to include facet information
include_highlighting: Whether to include highlighting information
Returns:
Formatted results as JSON string
"""
try:
formatted = {
"result-set": {
"numFound": results.hits,
"start": start,
"docs": list(results.docs) if hasattr(results, "docs") else [],
}
}
# Include score information if requested and available
if include_score and hasattr(results, "max_score"):
formatted["result-set"]["maxScore"] = results.max_score
# Include facets if requested and available
if include_facets and hasattr(results, "facets") and results.facets:
formatted["result-set"]["facets"] = results.facets
# Include highlighting if requested and available
if (
include_highlighting
and hasattr(results, "highlighting")
and results.highlighting
):
formatted["result-set"]["highlighting"] = results.highlighting
try:
return json.dumps(formatted, default=str)
except TypeError as e:
logger.error(f"JSON serialization error: {e}")
# Fall back to basic result format
return json.dumps(
{
"result-set": {
"numFound": results.hits,
"start": start,
"docs": (
[str(doc) for doc in results.docs]
if hasattr(results, "docs")
else []
),
}
}
)
except Exception as e:
logger.error(f"Error formatting search results: {e}")
return json.dumps({"error": str(e)})
def format_sql_response(raw_response: Dict[str, Any]) -> Dict[str, Any]:
"""Format SQL query response to a standardized structure."""
try:
# Check for error response
if "result-set" in raw_response and "docs" in raw_response["result-set"]:
docs = raw_response["result-set"]["docs"]
if len(docs) == 1 and "EXCEPTION" in docs[0]:
raise QueryError(docs[0]["EXCEPTION"])
# Return standardized response format
return {
"result-set": {
"docs": raw_response.get("result-set", {}).get("docs", []),
"numFound": len(raw_response.get("result-set", {}).get("docs", [])),
"start": 0,
}
}
except QueryError as e:
raise e
except Exception as e:
raise QueryError(f"Error formatting SQL response: {str(e)}")
def format_error_response(error: Exception) -> str:
"""Format error response as JSON string.
Args:
error: Exception object
Returns:
Error message as JSON string
"""
error_code = "INTERNAL_ERROR"
if isinstance(error, QueryError):
error_code = "QUERY_ERROR"
elif isinstance(error, SolrError):
error_code = "SOLR_ERROR"
return json.dumps({"error": {"code": error_code, "message": str(error)}})
```
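Usage is straightforward; a small sketch of the error and SQL paths:

```python
import json
from solr_mcp.solr.exceptions import QueryError
from solr_mcp.solr.utils.formatting import format_error_response, format_sql_response

# Error code is derived from the exception type
print(format_error_response(QueryError("bad query")))
# {"error": {"code": "QUERY_ERROR", "message": "bad query"}}

# SQL responses are normalized; an EXCEPTION doc is re-raised as QueryError
resp = {"result-set": {"docs": [{"id": "1"}], "numFound": 1}}
print(json.dumps(format_sql_response(resp)))
# {"result-set": {"docs": [{"id": "1"}], "numFound": 1, "start": 0}}
```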
--------------------------------------------------------------------------------
/tests/unit/fixtures/common.py:
--------------------------------------------------------------------------------
```python
"""Common fixtures and mock data for unit tests."""
from typing import List, Optional
from unittest.mock import Mock
import pytest
from solr_mcp.solr.interfaces import CollectionProvider, VectorSearchProvider
# Mock response data with various levels of detail
MOCK_RESPONSES = {
"collections": ["collection1", "collection2"],
"select": {"result-set": {"docs": [{"id": "1", "field": "value"}], "numFound": 1}},
"vector": {
"result-set": {
"docs": [{"id": "1", "field": "value", "score": 0.95}],
"numFound": 1,
}
},
"semantic": {
"result-set": {
"docs": [{"id": "1", "field": "value", "score": 0.85}],
"numFound": 1,
}
},
"schema": {
"schema": {
"fields": [
{
"name": "id",
"type": "string",
"multiValued": False,
"required": True,
},
{"name": "title", "type": "text_general", "multiValued": False},
{"name": "content", "type": "text_general", "multiValued": False},
{"name": "vector", "type": "knn_vector", "multiValued": False},
],
"fieldTypes": [
{"name": "string", "class": "solr.StrField", "sortMissingLast": True},
{
"name": "text_general",
"class": "solr.TextField",
"positionIncrementGap": "100",
},
{
"name": "knn_vector",
"class": "solr.DenseVectorField",
"vectorDimension": 768,
},
],
}
},
"field_list": {
"fields": [
{
"name": "id",
"type": "string",
"indexed": True,
"stored": True,
"docValues": True,
"multiValued": False,
},
{
"name": "_text_",
"type": "text_general",
"indexed": True,
"stored": False,
"docValues": False,
"multiValued": True,
"copies_from": ["title", "content"],
},
]
},
}
class MockCollectionProvider(CollectionProvider):
"""Mock implementation of CollectionProvider."""
def __init__(self, collections=None):
"""Initialize with optional list of collections."""
self.collections = (
collections if collections is not None else MOCK_RESPONSES["collections"]
)
async def list_collections(self) -> List[str]:
"""Return mock list of collections."""
return self.collections
async def collection_exists(self, collection: str) -> bool:
"""Check if collection exists in mock list."""
return collection in self.collections
class MockVectorProvider(VectorSearchProvider):
"""Mock vector provider for testing."""
async def execute_vector_search(self, client, vector, top_k=10):
"""Mock vector search execution."""
return {
"response": {
"docs": [
{"_docid_": "1", "score": 0.9, "_vector_distance_": 0.1},
{"_docid_": "2", "score": 0.8, "_vector_distance_": 0.2},
{"_docid_": "3", "score": 0.7, "_vector_distance_": 0.3},
],
"numFound": 3,
"start": 0,
}
}
async def get_vector(self, text: str, model: Optional[str] = None) -> List[float]:
"""Mock text to vector conversion."""
return [0.1, 0.2, 0.3]
@pytest.fixture
def valid_config_dict():
"""Valid configuration dictionary."""
return {
"solr_base_url": "http://localhost:8983/solr",
"zookeeper_hosts": ["localhost:2181"],
"connection_timeout": 10,
}
```
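A sketch of how these fixtures compose in an async test (pytest-asyncio assumed, as elsewhere in the suite):

```python
import pytest
from tests.unit.fixtures.common import MockCollectionProvider

@pytest.mark.asyncio
async def test_mock_provider_roundtrip():
    provider = MockCollectionProvider(collections=["unified"])
    assert await provider.list_collections() == ["unified"]
    assert await provider.collection_exists("unified")
    assert not await provider.collection_exists("missing")
```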
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
```yaml
version: '3.8'
services:
zookeeper:
image: zookeeper:3.9
container_name: zookeeper
ports:
- "2181:2181"
environment:
ZOO_MY_ID: 1
ZOO_SERVERS: server.1=zookeeper:2888:3888;2181
volumes:
- zookeeper_data:/data
- zookeeper_logs:/datalog
networks:
- solr-net
solr1:
build:
context: .
dockerfile: solr.Dockerfile
container_name: solr1
ports:
- "8983:8983"
environment:
- ZK_HOST=zookeeper:2181
- SOLR_JAVA_MEM=-Xms512m -Xmx512m
volumes:
- solr1_data:/var/solr
depends_on:
- zookeeper
networks:
- solr-net
command:
- solr-foreground
- -c # Run in cloud mode
healthcheck:
test: ["CMD", "wget", "-q", "--spider", "http://localhost:8983/solr/"]
interval: 5s
timeout: 10s
retries: 5
solr2:
build:
context: .
dockerfile: solr.Dockerfile
container_name: solr2
ports:
- "8984:8983"
environment:
- ZK_HOST=zookeeper:2181
- SOLR_JAVA_MEM=-Xms512m -Xmx512m
volumes:
- solr2_data:/var/solr
depends_on:
- zookeeper
- solr1
networks:
- solr-net
command:
- solr-foreground
- -c # Run in cloud mode
# Initializer service to set up Solr collections (runs once and exits)
solr-init:
image: solr:9.5
container_name: solr-init
depends_on:
solr1:
condition: service_healthy
solr2:
condition: service_started
networks:
- solr-net
environment:
- ZK_HOST=zookeeper:2181
- SOLR_HOST=solr1
- SOLR_PORT=8983
volumes:
- ./solr_config:/config
command: >
bash -c "
# Wait for Solr to be available
echo 'Waiting for Solr to be available...'
until wget -q --spider http://solr1:8983/solr; do
sleep 2
done
echo 'Solr is up!'
# Delete existing collection if it exists
echo 'Deleting existing unified collection if it exists...'
curl -s 'http://solr1:8983/solr/admin/collections?action=DELETE&name=unified' || true
# Upload the updated config
echo 'Uploading updated configuration...'
solr zk upconfig -n unified_config -d /config/unified -z zookeeper:2181
# Create the unified collection with the updated config
echo 'Creating unified collection...'
curl -s 'http://solr1:8983/solr/admin/collections?action=CREATE&name=unified&numShards=1&replicationFactor=1&collection.configName=unified_config' || echo 'Collection creation failed - check solr logs'
echo 'Initialization complete!'
"
# Ollama for embeddings
ollama:
image: ollama/ollama:latest
container_name: ollama
ports:
- "11434:11434"
volumes:
- ollama_data:/root/.ollama
networks:
- solr-net
# Run Ollama and pull the model
entrypoint: ["/bin/bash", "-c"]
command: >
"ollama serve &
sleep 15 &&
echo 'Pulling nomic-embed-text model...' &&
ollama pull nomic-embed-text &&
echo 'Model pulled successfully' &&
tail -f /dev/null"
healthcheck:
test: ["CMD-SHELL", "bash -c 'cat < /dev/null > /dev/tcp/localhost/11434'"]
interval: 15s
timeout: 5s
retries: 5
start_period: 90s
# MCP Server
mcp-server:
build:
context: .
dockerfile: Dockerfile
container_name: mcp-server
ports:
- "8000:8000"
environment:
- SOLR_MCP_ZK_HOSTS=zookeeper:2181
- SOLR_MCP_SOLR_URL=http://solr1:8983/solr
- SOLR_MCP_DEFAULT_COLLECTION=unified
- OLLAMA_BASE_URL=http://ollama:11434
depends_on:
solr-init:
condition: service_completed_successfully
ollama:
condition: service_healthy
networks:
- solr-net
volumes:
- ./:/app
networks:
solr-net:
driver: bridge
volumes:
zookeeper_data:
zookeeper_logs:
solr1_data:
solr2_data:
ollama_data:
```
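Once the stack is up, a quick smoke check against the published host ports (8983 for solr1, 11434 for Ollama) might look like the sketch below; endpoints are the stock Solr collections API and Ollama's /api/tags:

```python
# Minimal smoke check; assumes the default port mappings above
import requests

def stack_ready() -> bool:
    solr = requests.get(
        "http://localhost:8983/solr/admin/collections",
        params={"action": "LIST", "wt": "json"},
        timeout=5,
    )
    ollama = requests.get("http://localhost:11434/api/tags", timeout=5)
    return solr.status_code == 200 and ollama.status_code == 200

if __name__ == "__main__":
    print("stack ready:", stack_ready())
```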
--------------------------------------------------------------------------------
/scripts/process_markdown.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Script to process markdown files, splitting them by section headings
and preparing them for indexing in Solr with vector embeddings.
"""
import argparse
import json
import os
import re
import sys
from datetime import datetime
from typing import Dict, List, Tuple
import frontmatter
def extract_sections(markdown_content: str) -> List[Tuple[str, str]]:
"""
Extract sections from a markdown document based on headings.
Args:
markdown_content: The content of the markdown file
Returns:
List of tuples (section_title, section_content)
"""
# Split by headers (# Header)
header_pattern = r'^(#{1,6})\s+(.+?)$'
lines = markdown_content.split('\n')
sections = []
current_title = "Introduction"
current_content = []
for line in lines:
header_match = re.match(header_pattern, line, re.MULTILINE)
if header_match:
# Save previous section
if current_content:
sections.append((current_title, '\n'.join(current_content).strip()))
current_content = []
# Start new section
current_title = header_match.group(2).strip()
else:
current_content.append(line)
# Add the last section
if current_content:
sections.append((current_title, '\n'.join(current_content).strip()))
return sections
def convert_to_solr_docs(sections: List[Tuple[str, str]], filename: str, metadata: Dict) -> List[Dict]:
"""
Convert markdown sections to Solr documents.
Args:
sections: List of (title, content) tuples
filename: Original filename
metadata: Metadata from frontmatter
Returns:
List of documents ready for Solr indexing
"""
documents = []
for i, (title, content) in enumerate(sections):
# Skip empty sections
if not content.strip():
continue
doc = {
"id": f"{os.path.basename(filename)}_section_{i}",
"title": title,
"text": content,
"source": filename,
"section_number": i,
"date_indexed": datetime.now().isoformat(),
"tags": metadata.get("tags", []),
"category": metadata.get("categories", [])
}
# Add any additional metadata
for key, value in metadata.items():
if key not in ["tags", "categories"] and key not in doc:
doc[key] = value
documents.append(doc)
return documents
def process_markdown_file(file_path: str, output_file: str = None):
"""
Process a markdown file, splitting it into sections and converting to Solr documents.
Args:
file_path: Path to the markdown file
output_file: Path to save the JSON output (if None, prints to stdout)
"""
# Read and parse markdown with frontmatter
with open(file_path, 'r', encoding='utf-8') as f:
post = frontmatter.load(f)
# Extract frontmatter metadata and content
metadata = dict(post.metadata)
content = post.content
# Extract sections
sections = extract_sections(content)
# Convert to Solr documents
documents = convert_to_solr_docs(sections, file_path, metadata)
# Output
if output_file:
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(documents, f, indent=2)
print(f"Processed {file_path} into {len(documents)} sections and saved to {output_file}")
else:
print(json.dumps(documents, indent=2))
print(f"Processed {file_path} into {len(documents)} sections", file=sys.stderr)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Process markdown files for Solr indexing")
parser.add_argument("file", help="Path to the markdown file")
parser.add_argument("--output", "-o", help="Output JSON file path")
args = parser.parse_args()
process_markdown_file(args.file, args.output)
```
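A sketch of extract_sections on an inline document (run from the repo root so `scripts` is importable):

```python
from scripts.process_markdown import extract_sections

md = """# Title
Intro paragraph.

## Details
Body text.
"""
for title, body in extract_sections(md):
    print(f"{title!r} -> {body!r}")
# 'Title' -> 'Intro paragraph.'
# 'Details' -> 'Body text.'
```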
--------------------------------------------------------------------------------
/tests/unit/test_formatting.py:
--------------------------------------------------------------------------------
```python
"""Unit tests for formatting utilities."""
import json
from unittest.mock import Mock
from solr_mcp.solr.exceptions import QueryError, SolrError
from solr_mcp.solr.utils.formatting import (
format_error_response,
format_search_results,
format_sql_response,
)
class TestFormatting:
"""Test cases for formatting utilities."""
def test_format_search_results(self):
"""Test formatting Solr search results."""
# Create mock pysolr Results
mock_results = Mock()
mock_results.docs = [
{"id": "1", "title": "Test 1"},
{"id": "2", "title": "Test 2"},
]
mock_results.hits = 2
mock_results.raw_response = {
"response": {
"docs": mock_results.docs,
"numFound": mock_results.hits,
"start": 0,
}
}
formatted = format_search_results(mock_results, start=0)
result_dict = json.loads(formatted)
assert "result-set" in result_dict
assert result_dict["result-set"]["docs"] == mock_results.docs
assert result_dict["result-set"]["numFound"] == mock_results.hits
assert result_dict["result-set"]["start"] == 0
def test_format_search_results_empty(self):
"""Test formatting empty search results."""
mock_results = Mock()
mock_results.docs = []
mock_results.hits = 0
mock_results.raw_response = {
"response": {"docs": [], "numFound": 0, "start": 0}
}
formatted = format_search_results(mock_results, start=0)
result_dict = json.loads(formatted)
assert "result-set" in result_dict
assert result_dict["result-set"]["docs"] == []
assert result_dict["result-set"]["numFound"] == 0
assert result_dict["result-set"]["start"] == 0
def test_format_sql_response(self):
"""Test formatting SQL query response."""
response = {
"result-set": {
"docs": [
{"id": "1", "title": "Test 1"},
{"id": "2", "title": "Test 2"},
],
"numFound": 2,
"start": 0,
}
}
formatted = format_sql_response(response)
assert formatted == response
assert "result-set" in formatted
assert formatted["result-set"]["numFound"] == 2
assert len(formatted["result-set"]["docs"]) == 2
def test_format_sql_response_empty(self):
"""Test formatting empty SQL query response."""
response = {"result-set": {"docs": [], "numFound": 0, "start": 0}}
formatted = format_sql_response(response)
assert formatted == response
assert "result-set" in formatted
assert formatted["result-set"]["numFound"] == 0
assert formatted["result-set"]["docs"] == []
def test_format_error_response_query_error(self):
"""Test formatting QueryError response."""
error = QueryError("Invalid SQL syntax")
formatted = format_error_response(error)
error_dict = json.loads(formatted)
assert "error" in error_dict
assert error_dict["error"]["code"] == "QUERY_ERROR"
assert error_dict["error"]["message"] == "Invalid SQL syntax"
def test_format_error_response_solr_error(self):
"""Test formatting SolrError response."""
error = SolrError("Connection failed")
formatted = format_error_response(error)
error_dict = json.loads(formatted)
assert "error" in error_dict
assert error_dict["error"]["code"] == "SOLR_ERROR"
assert error_dict["error"]["message"] == "Connection failed"
def test_format_error_response_generic_error(self):
"""Test formatting generic error response."""
error = Exception("Unknown error")
formatted = format_error_response(error)
error_dict = json.loads(formatted)
assert "error" in error_dict
assert error_dict["error"]["code"] == "INTERNAL_ERROR"
assert "Unknown error" in error_dict["error"]["message"]
```
--------------------------------------------------------------------------------
/solr_mcp/solr/vector/results.py:
--------------------------------------------------------------------------------
```python
"""Vector search results handling."""
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
class VectorSearchResult(BaseModel):
"""Individual vector search result."""
docid: str = Field(description="Internal Solr document ID (_docid_)")
score: float = Field(description="Search score")
distance: Optional[float] = Field(None, description="Vector distance if available")
metadata: Dict[str, Any] = Field(
default_factory=dict, description="Additional metadata"
)
def __getitem__(self, key):
"""Make result subscriptable."""
if key == "docid":
return self.docid
elif key == "score":
return self.score
elif key == "distance":
return self.distance
elif key == "metadata":
return self.metadata
raise KeyError(f"Invalid key: {key}")
class VectorSearchResults(BaseModel):
"""Container for vector search results."""
results: List[VectorSearchResult] = Field(
default_factory=list, description="List of search results"
)
total_found: int = Field(0, description="Total number of results found")
top_k: int = Field(..., description="Number of results requested")
query_time_ms: Optional[int] = Field(
None, description="Query execution time in milliseconds"
)
@property
def docs(self) -> List[VectorSearchResult]:
"""Get list of search results."""
return self.results
@classmethod
def from_solr_response(
cls, response: Dict[str, Any], top_k: int = 10
) -> "VectorSearchResults":
"""Create VectorSearchResults from Solr response.
Args:
response: Raw Solr response dictionary
top_k: Number of results requested
Returns:
VectorSearchResults instance
"""
# Extract response header
header = response.get("responseHeader", {})
query_time = header.get("QTime")
# Extract main response section
resp = response.get("response", {})
docs = resp.get("docs", [])
# Create results list
results = []
for doc in docs:
# Handle both string and numeric _docid_
docid = doc.get("_docid_")
if docid is None:
# Try alternate field names
docid = doc.get("[docid]") or doc.get("docid") or "0"
docid = str(docid) # Ensure string type
result = VectorSearchResult(
docid=docid,
score=doc.get("score", 0.0),
distance=doc.get("_vector_distance_"),
metadata={
k: v
for k, v in doc.items()
if k
not in ["_docid_", "[docid]", "docid", "score", "_vector_distance_"]
},
)
results.append(result)
# Create VectorSearchResults
return cls(
results=results,
total_found=resp.get("numFound", 0),
top_k=top_k,
query_time_ms=query_time,
)
def to_dict(self) -> Dict[str, Any]:
"""Convert results to dictionary format.
Returns:
Dictionary representation of results
"""
return {
"results": [result.model_dump() for result in self.results],
"metadata": {
"total_found": self.total_found,
"top_k": self.top_k,
"query_time_ms": self.query_time_ms,
},
}
def get_doc_ids(self) -> List[str]:
"""Get list of document IDs from results.
Returns:
List of document IDs
"""
return [result.docid for result in self.results]
def get_scores(self) -> List[float]:
"""Get list of scores from results.
Returns:
List of scores
"""
return [result.score for result in self.results]
def get_distances(self) -> List[Optional[float]]:
"""Get list of vector distances from results.
Returns:
List of distances (None if not available)
"""
return [result.distance for result in self.results]
```
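A sketch of building results from a hand-written Solr response (values are made up):

```python
from solr_mcp.solr.vector.results import VectorSearchResults

raw = {
    "responseHeader": {"QTime": 7},
    "response": {
        "numFound": 2,
        "docs": [
            {"_docid_": 12, "score": 0.91, "_vector_distance_": 0.09, "title": "A"},
            {"_docid_": 34, "score": 0.80, "_vector_distance_": 0.20, "title": "B"},
        ],
    },
}
results = VectorSearchResults.from_solr_response(raw, top_k=2)
print(results.get_doc_ids())  # ['12', '34'] -- docids are coerced to strings
print(results.query_time_ms)  # 7
print(results.to_dict()["metadata"]["total_found"])  # 2
```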
--------------------------------------------------------------------------------
/scripts/vector_index_simple.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Simple script for indexing documents with vector embeddings.
"""
import argparse
import asyncio
import json
import os
import sys
import httpx
from typing import List
# Add the project root to the path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from solr_mcp.vector_provider.clients.ollama import OllamaVectorProvider
async def generate_embeddings(texts: List[str]) -> List[List[float]]:
"""Generate embeddings for a list of texts using Ollama.
Args:
texts: List of text strings to generate embeddings for
Returns:
List of embedding vectors
"""
    client = OllamaVectorProvider()
embeddings = []
print(f"Generating embeddings for {len(texts)} documents...")
# Process in smaller batches to avoid overwhelming Ollama
batch_size = 5
for i in range(0, len(texts), batch_size):
batch = texts[i:i+batch_size]
print(f"Processing batch {i//batch_size + 1}/{(len(texts) + batch_size - 1)//batch_size}...")
        batch_embeddings = await client.get_vectors(batch)
embeddings.extend(batch_embeddings)
return embeddings
async def index_documents(json_file: str, collection: str = "testvectors", commit: bool = True):
"""Index documents with vector embeddings.
Args:
json_file: Path to the JSON file containing documents
collection: Solr collection name
commit: Whether to commit after indexing
"""
# Load documents
with open(json_file, 'r', encoding='utf-8') as f:
documents = json.load(f)
# Extract text for embedding generation
texts = []
for doc in documents:
if 'text' in doc:
texts.append(doc['text'])
elif 'content' in doc:
texts.append(doc['content'])
else:
texts.append(doc.get('title', ''))
# Generate embeddings
embeddings = await generate_embeddings(texts)
# Prepare documents for indexing
solr_docs = []
for i, doc in enumerate(documents):
solr_doc = {
'id': doc['id'],
'title': doc['title'],
'text': doc.get('text', doc.get('content', '')),
'source': doc.get('source', 'unknown'),
'vector_model': 'nomic-embed-text',
'embedding': embeddings[i]
}
solr_docs.append(solr_doc)
# Index each document separately (a workaround for vector field issues)
print(f"Indexing {len(solr_docs)} documents to collection '{collection}'...")
async with httpx.AsyncClient() as client:
for i, doc in enumerate(solr_docs):
solr_url = f"http://localhost:8983/solr/{collection}/update/json/docs"
params = {"commit": "true"} if (commit and i == len(solr_docs) - 1) else {}
try:
response = await client.post(
solr_url,
json=doc,
params=params,
timeout=30.0
)
if response.status_code != 200:
print(f"Error indexing document {doc['id']}: {response.status_code} - {response.text}")
return False
print(f"Indexed document {i+1}/{len(solr_docs)}: {doc['id']}")
except Exception as e:
print(f"Error indexing document {doc['id']}: {e}")
return False
print(f"Successfully indexed {len(solr_docs)} documents to collection '{collection}'")
return True
async def main():
"""Main entry point."""
parser = argparse.ArgumentParser(description="Index documents with vector embeddings")
parser.add_argument("json_file", help="Path to the JSON file containing documents")
parser.add_argument("--collection", "-c", default="testvectors", help="Solr collection name")
parser.add_argument("--no-commit", dest="commit", action="store_false", help="Don't commit after indexing")
args = parser.parse_args()
result = await index_documents(args.json_file, args.collection, args.commit)
sys.exit(0 if result else 1)
if __name__ == "__main__":
asyncio.run(main())
```
--------------------------------------------------------------------------------
/solr_config/unified/conf/solrconfig.xml:
--------------------------------------------------------------------------------
```
<?xml version="1.0" encoding="UTF-8" ?>
<config>
<luceneMatchVersion>9.5.0</luceneMatchVersion>
<!-- Data Directory -->
<dataDir>${solr.data.dir:}</dataDir>
<!-- Directory for storing index files -->
<directoryFactory name="DirectoryFactory" class="solr.NRTCachingDirectoryFactory"/>
<!-- Request Handler for Search (with vector search) -->
<requestHandler name="/select" class="solr.SearchHandler">
<lst name="defaults">
<str name="echoParams">explicit</str>
<str name="wt">json</str>
<str name="indent">true</str>
<str name="df">_text_</str>
<str name="rows">10</str>
</lst>
</requestHandler>
<!-- Use built-in vector search capabilities -->
<requestHandler name="/knn" class="solr.SearchHandler">
<lst name="defaults">
<str name="echoParams">explicit</str>
<str name="wt">json</str>
<str name="indent">true</str>
</lst>
</requestHandler>
<!-- Handler for vector search -->
<requestHandler name="/vector" class="solr.SearchHandler">
<lst name="defaults">
<str name="echoParams">explicit</str>
<str name="wt">json</str>
<str name="indent">true</str>
<str name="df">_text_</str>
<str name="rows">10</str>
</lst>
</requestHandler>
<!-- SpellCheck -->
<searchComponent name="spellcheck" class="solr.SpellCheckComponent">
<str name="queryAnalyzerFieldType">text_general</str>
<lst name="spellchecker">
<str name="name">default</str>
<str name="field">_text_</str>
<str name="classname">solr.DirectSolrSpellChecker</str>
<str name="distanceMeasure">internal</str>
<float name="accuracy">0.5</float>
<int name="maxEdits">2</int>
<int name="minPrefix">1</int>
<int name="maxInspections">5</int>
<int name="minQueryLength">3</int>
<float name="maxQueryFrequency">0.01</float>
</lst>
</searchComponent>
<!-- Suggester for auto-complete -->
<searchComponent name="suggest" class="solr.SuggestComponent">
<lst name="suggester">
<str name="name">default</str>
<str name="lookupImpl">BlendedInfixLookupFactory</str>
<str name="dictionaryImpl">DocumentDictionaryFactory</str>
<str name="field">_text_</str>
<str name="weightField">popularity</str>
<str name="suggestAnalyzerFieldType">text_general</str>
<str name="buildOnStartup">false</str>
<str name="buildOnCommit">false</str>
</lst>
</searchComponent>
<!-- Request handler for suggestions -->
<requestHandler name="/suggest" class="solr.SearchHandler">
<lst name="defaults">
<str name="suggest">true</str>
<str name="suggest.dictionary">default</str>
<str name="suggest.count">10</str>
</lst>
<arr name="components">
<str>suggest</str>
</arr>
</requestHandler>
<!-- Update request handlers -->
<requestHandler name="/update" class="solr.UpdateRequestHandler"/>
<requestHandler name="/update/json" class="solr.UpdateRequestHandler">
<lst name="defaults">
<str name="stream.contentType">application/json</str>
<str name="stream.body">{}</str>
</lst>
</requestHandler>
<!-- Standard components -->
<updateHandler class="solr.DirectUpdateHandler2"/>
<requestDispatcher>
<requestParsers enableRemoteStreaming="true" multipartUploadLimitInKB="2048000" formdataUploadLimitInKB="2048"/>
<httpCaching never304="true"/>
</requestDispatcher>
<!-- Manage cache sizes -->
<query>
    <filterCache class="solr.CaffeineCache" size="512" initialSize="512" autowarmCount="0"/>
    <queryResultCache class="solr.CaffeineCache" size="512" initialSize="512" autowarmCount="0"/>
    <documentCache class="solr.CaffeineCache" size="512" initialSize="512" autowarmCount="0"/>
    <cache name="knnCache" class="solr.CaffeineCache" size="512" initialSize="512" autowarmCount="0"/>
</query>
<!-- Response Writers -->
<queryResponseWriter name="json" class="solr.JSONResponseWriter">
<str name="content-type">text/plain; charset=UTF-8</str>
</queryResponseWriter>
<!-- SQL request handler -->
<requestHandler name="/sql" class="org.apache.solr.handler.sql.SQLHandler">
<lst name="defaults">
<str name="wt">json</str>
<str name="indent">true</str>
</lst>
</requestHandler>
</config>
```
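With this config, a KNN query goes through the stock {!knn} query parser on /select; a hedged sketch follows (the field name `embedding` matches scripts/vector_index_simple.py, but depends on your collection's schema):

```python
import requests

vector = [0.1] * 768  # stand-in embedding
resp = requests.post(
    "http://localhost:8983/solr/unified/select",
    data={
        "q": "{!knn f=embedding topK=5}[" + ",".join(str(x) for x in vector) + "]",
        "fl": "id,title,score",
        "wt": "json",
    },
    timeout=30,
)
print(resp.json()["response"]["numFound"])
```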
--------------------------------------------------------------------------------
/solr_mcp/vector_provider/clients/ollama.py:
--------------------------------------------------------------------------------
```python
"""Ollama vector provider implementation."""
from typing import Any, Dict, List, Optional
import requests
from loguru import logger
from solr_mcp.solr.interfaces import VectorSearchProvider
from solr_mcp.vector_provider.constants import MODEL_DIMENSIONS, OLLAMA_EMBEDDINGS_PATH
class OllamaVectorProvider(VectorSearchProvider):
"""Vector provider that uses Ollama to vectorize text."""
def __init__(
self,
model: str = "nomic-embed-text",
base_url: str = "http://localhost:11434",
timeout: int = 30,
retries: int = 3,
):
"""Initialize the Ollama vector provider.
Args:
model: Name of the Ollama model to use
base_url: Base URL of the Ollama server
timeout: Request timeout in seconds
retries: Number of retries for failed requests
"""
self.model = model
self.base_url = base_url.rstrip("/")
self.timeout = timeout
self.retries = retries
logger.info(
f"Initialized Ollama vector provider with model={model} at {base_url} (timeout={timeout}s, retries={retries})"
)
async def get_vector(self, text: str, model: Optional[str] = None) -> List[float]:
"""Get vector for a single text.
Args:
text: Text to get vector for
model: Optional model to use for vectorization (overrides default)
Returns:
List of floats representing the text vector
Raises:
Exception: If there is an error getting vector
"""
url = f"{self.base_url}{OLLAMA_EMBEDDINGS_PATH}"
data = {"model": model or self.model, "prompt": text}
actual_model = data["model"]
for attempt in range(self.retries + 1):
try:
response = requests.post(url, json=data, timeout=self.timeout)
response.raise_for_status()
return response.json()["embedding"]
except Exception as e:
if attempt == self.retries:
raise Exception(
f"Failed to get vector with model {actual_model} after {self.retries} retries: {str(e)}"
)
logger.warning(
f"Failed to get vector with model {actual_model} (attempt {attempt + 1}/{self.retries + 1}): {str(e)}"
)
continue
async def get_vectors(
self, texts: List[str], model: Optional[str] = None
) -> List[List[float]]:
"""Get vector for multiple texts.
Args:
texts: List of texts to get vector for
model: Optional model to use for vectorization (overrides default)
Returns:
List of vectors (list of floats)
Raises:
Exception: If there is an error getting vector
"""
results = []
for text in texts:
vector = await self.get_vector(text, model)
results.append(vector)
return results
async def execute_vector_search(
self, client: Any, vector: List[float], top_k: int = 10
) -> Dict[str, Any]:
"""Execute vector similarity search.
Args:
client: Solr client instance
vector: Query vector
top_k: Number of results to return
Returns:
Dictionary containing search results
Raises:
Exception: If there is an error executing the search
"""
        try:
            # Build KNN query using Solr's {!knn} query parser; the KNN
            # expression must be the main query string, not a separate param
            vector_str = "[" + ",".join(str(x) for x in vector) + "]"
            knn_query = {"q": f"{{!knn f=vector topK={top_k}}}{vector_str}"}
            # Execute search
            results = client.search(**knn_query)
            return results
        except Exception as e:
            raise Exception(f"Vector search failed: {str(e)}")
@property
def vector_dimension(self) -> int:
"""Get the dimension of vectors produced by this provider.
Returns:
Integer dimension of the vectors
"""
return MODEL_DIMENSIONS.get(
self.model, 768
) # Default to 768 if model not found
@property
def model_name(self) -> str:
"""Get the name of the model used by this provider.
Returns:
String name of the model
"""
return self.model
```
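A minimal usage sketch (requires a running Ollama with the model pulled, per docker-compose.yml):

```python
import asyncio
from solr_mcp.vector_provider.clients.ollama import OllamaVectorProvider

async def main():
    provider = OllamaVectorProvider()  # nomic-embed-text @ localhost:11434
    vec = await provider.get_vector("peer-to-peer electronic cash")
    print(len(vec), provider.vector_dimension)  # expect 768 and 768

asyncio.run(main())
```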
--------------------------------------------------------------------------------
/tests/unit/test_interfaces.py:
--------------------------------------------------------------------------------
```python
"""Unit tests for Solr client interfaces."""
from abc import ABC
from typing import Any, Dict, List, Optional
import pytest
from solr_mcp.solr.interfaces import CollectionProvider, VectorSearchProvider
def test_collection_provider_is_abstract():
"""Test that CollectionProvider is an abstract base class."""
assert issubclass(CollectionProvider, ABC)
assert CollectionProvider.__abstractmethods__ == {
"list_collections",
"collection_exists",
}
def test_collection_provider_cannot_instantiate():
"""Test that CollectionProvider cannot be instantiated directly."""
with pytest.raises(TypeError) as exc_info:
CollectionProvider()
assert "abstract methods collection_exists, list_collections" in str(exc_info.value)
def test_collection_provider_requires_methods():
"""Test that implementations must provide required methods."""
class IncompleteProvider(CollectionProvider):
pass
with pytest.raises(TypeError) as exc_info:
IncompleteProvider()
assert "abstract methods collection_exists, list_collections" in str(exc_info.value)
@pytest.mark.asyncio
async def test_collection_provider_implementation():
"""Test that a complete implementation can be instantiated."""
class ValidProvider(CollectionProvider):
async def list_collections(self) -> List[str]:
return ["collection1"]
async def collection_exists(self, collection: str) -> bool:
return collection in ["collection1"]
provider = ValidProvider()
assert isinstance(provider, CollectionProvider)
result = await provider.list_collections()
assert result == ["collection1"]
exists = await provider.collection_exists("collection1")
assert exists is True
def test_vector_search_provider_is_abstract():
"""Test that VectorSearchProvider is an abstract base class."""
assert issubclass(VectorSearchProvider, ABC)
assert VectorSearchProvider.__abstractmethods__ == {
"execute_vector_search",
"get_vector",
}
def test_vector_search_provider_cannot_instantiate():
"""Test that VectorSearchProvider cannot be instantiated directly."""
with pytest.raises(TypeError) as exc_info:
VectorSearchProvider()
assert "abstract methods" in str(exc_info.value)
assert "execute_vector_search" in str(exc_info.value)
assert "get_vector" in str(exc_info.value)
def test_vector_search_provider_requires_all_methods():
"""Test that implementations must provide all required methods."""
class IncompleteProvider(VectorSearchProvider):
def execute_vector_search(
self,
client: Any,
vector: List[float],
field: str,
top_k: Optional[int] = None,
) -> Dict[str, Any]:
return {"response": {"docs": []}}
with pytest.raises(TypeError) as exc_info:
IncompleteProvider()
assert (
"Can't instantiate abstract class IncompleteProvider with abstract method get_vector"
== str(exc_info.value)
)
def test_vector_search_provider_implementation():
"""Test that a complete implementation can be instantiated."""
class ValidProvider(VectorSearchProvider):
def execute_vector_search(
self,
client: Any,
vector: List[float],
field: str,
top_k: Optional[int] = None,
) -> Dict[str, Any]:
return {"response": {"docs": []}}
async def get_vector(self, text: str) -> List[float]:
return [0.1, 0.2, 0.3]
provider = ValidProvider()
assert isinstance(provider, VectorSearchProvider)
assert provider.execute_vector_search(None, [0.1], "vector_field") == {
"response": {"docs": []}
}
@pytest.mark.asyncio
async def test_vector_search_provider_async_method():
"""Test that async get_vector method works correctly."""
class ValidProvider(VectorSearchProvider):
def execute_vector_search(
self,
client: Any,
vector: List[float],
field: str,
top_k: Optional[int] = None,
) -> Dict[str, Any]:
return {"response": {"docs": []}}
async def get_vector(self, text: str) -> List[float]:
return [0.1, 0.2, 0.3]
provider = ValidProvider()
result = await provider.get_vector("test")
assert result == [0.1, 0.2, 0.3]
```
--------------------------------------------------------------------------------
/solr_mcp/solr/query/parser.py:
--------------------------------------------------------------------------------
```python
"""Query parser for Solr."""
import logging
from typing import List, Tuple
from sqlglot import ParseError, parse_one
from sqlglot.expressions import (
    Alias,
    Column,
    From,
    Identifier,
    Ordered,
    Select,
    Star,
    Table,
)
from solr_mcp.solr.exceptions import QueryError
logger = logging.getLogger(__name__)
class QueryParser:
"""Parses SQL queries for Solr."""
def preprocess_query(self, query: str) -> str:
"""Preprocess query to handle field:value syntax.
Args:
query: SQL query to preprocess
Returns:
Preprocessed query
"""
        # Convert field:value to field = 'value' (split on the first colon only,
        # so values that themselves contain colons survive)
        parts = query.split()
        for i, part in enumerate(parts):
            if ":" in part and not part.startswith('"') and not part.endswith('"'):
                field, value = part.split(":", 1)
                parts[i] = f"{field} = '{value}'"
        return " ".join(parts)
def parse_select(self, query: str) -> Tuple[Select, str, List[str]]:
"""Parse a SELECT query.
Args:
query: SQL query to parse
Returns:
Tuple of (AST, collection name, selected fields)
Raises:
QueryError: If query is invalid
"""
try:
# Validate and parse query
preprocessed = self.preprocess_query(query)
try:
ast = parse_one(preprocessed)
except ParseError as e:
raise QueryError(f"Invalid SQL syntax: {str(e)}")
if not isinstance(ast, Select):
raise QueryError("Query must be a SELECT statement")
# Validate selected fields
if not ast.expressions:
raise QueryError("SELECT clause must specify at least one field")
# Get collection from FROM clause
from_expr = ast.args.get("from")
if not from_expr:
raise QueryError("FROM clause is required")
# Extract collection name
collection = None
if isinstance(from_expr, Table):
collection = from_expr.name
elif isinstance(from_expr, From):
if isinstance(from_expr.this, Table):
collection = from_expr.this.name
elif isinstance(from_expr.this, Identifier):
collection = from_expr.this.name
elif hasattr(from_expr.this, "this") and isinstance(
from_expr.this.this, (Table, Identifier)
):
collection = from_expr.this.this.name
if not collection:
raise QueryError("FROM clause must specify a collection")
# Get selected fields
fields = []
logger.debug(f"AST: {repr(ast)}")
for expr in ast.expressions:
logger.debug(f"Expression: {repr(expr)}")
logger.debug(f"Expression type: {type(expr)}")
logger.debug(f"Expression args: {expr.args}")
if isinstance(expr, Star):
fields.append("*")
elif isinstance(expr, Column):
fields.append(expr.args["this"].name)
elif isinstance(expr, Alias):
fields.append(expr.args["alias"].this)
elif isinstance(expr, Identifier):
fields.append(expr.name)
return ast, collection, fields
except QueryError as e:
raise e
except Exception as e:
raise QueryError(f"Error parsing query: {str(e)}")
def get_sort_fields(self, ast: Select) -> List[Tuple[str, str]]:
"""Get sort fields from AST.
Args:
ast: Query AST
Returns:
List of (field, direction) tuples
"""
sort_fields = []
if ast.args.get("order"):
for expr in ast.args["order"]:
if isinstance(expr, Ordered):
field = (
expr.this.name
if isinstance(expr.this, Identifier)
else expr.this.args["this"].name
)
                    direction = "DESC" if expr.args.get("desc") else "ASC"
sort_fields.append((field, direction))
return sort_fields
def extract_sort_fields(self, sort_spec: str) -> List[str]:
"""Extract field names from a sort specification.
Args:
sort_spec: Sort specification string
Returns:
List of field names
"""
fields = []
parts = sort_spec.split(",")
for part in parts:
field = part.strip().split()[0]
fields.append(field)
return fields
```
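End to end, the parser behaves like this (a sketch; collection and fields are illustrative):

```python
from solr_mcp.solr.query.parser import QueryParser

parser = QueryParser()
ast, collection, fields = parser.parse_select(
    "SELECT id, title FROM unified WHERE title:bitcoin"
)
print(collection, fields)  # unified ['id', 'title']
print(parser.extract_sort_fields("title desc, id asc"))  # ['title', 'id']
```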
--------------------------------------------------------------------------------
/tests/unit/test_validator.py:
--------------------------------------------------------------------------------
```python
"""Unit tests for QueryValidator."""
from unittest.mock import Mock
import pytest
from solr_mcp.solr.exceptions import QueryError
from solr_mcp.solr.query.validator import QueryValidator
@pytest.fixture
def mock_field_manager():
"""Mock FieldManager for testing."""
mock = Mock()
mock.get_field_types.return_value = {
"id": "string",
"title": "text_general",
"content": "text_general",
"vector": "knn_vector",
}
mock.get_field_info.return_value = {
"sortable_fields": {
"id": {"directions": ["asc", "desc"], "default_direction": "asc"},
"title": {"directions": ["asc", "desc"], "default_direction": "asc"},
}
}
return mock
@pytest.fixture
def query_validator(mock_field_manager):
"""Create QueryValidator instance with mocked dependencies."""
return QueryValidator(field_manager=mock_field_manager)
class TestQueryValidator:
"""Test cases for QueryValidator."""
def test_init(self, query_validator, mock_field_manager):
"""Test QueryValidator initialization."""
assert query_validator.field_manager == mock_field_manager
def test_validate_fields_valid(self, query_validator):
"""Test validating valid fields."""
fields = ["id", "title", "content"]
# Should not raise any exceptions
query_validator.validate_fields("collection1", fields)
def test_validate_fields_invalid(self, query_validator):
"""Test validating invalid fields."""
fields = ["id", "nonexistent_field"]
with pytest.raises(QueryError) as exc_info:
query_validator.validate_fields("collection1", fields)
assert "Invalid field 'nonexistent_field'" in str(exc_info.value)
def test_validate_fields_error_handling(self, query_validator, mock_field_manager):
"""Test error handling in validate_fields."""
mock_field_manager.get_field_types.side_effect = Exception("Test error")
with pytest.raises(QueryError) as exc_info:
query_validator.validate_fields("collection1", ["id"])
assert "Field validation error" in str(exc_info.value)
def test_validate_sort_fields_valid(self, query_validator, mock_field_manager):
"""Test validating valid sort fields."""
fields = ["id", "title"]
# Should not raise any exceptions
query_validator.validate_sort_fields("collection1", fields)
def test_validate_sort_fields_invalid(self, query_validator, mock_field_manager):
"""Test validating invalid sort fields."""
mock_field_manager.validate_sort_fields.side_effect = Exception(
"Invalid sort field"
)
with pytest.raises(QueryError) as exc_info:
query_validator.validate_sort_fields("collection1", ["nonexistent_field"])
assert "Sort field validation error" in str(exc_info.value)
def test_validate_sort_none(self, query_validator):
"""Test validating None sort parameter."""
result = query_validator.validate_sort(None, "collection1")
assert result is None
def test_validate_sort_field_only(self, query_validator):
"""Test validating sort with field only."""
result = query_validator.validate_sort("id", "collection1")
assert result == "id asc" # Uses default direction
def test_validate_sort_field_and_direction(self, query_validator):
"""Test validating sort with field and direction."""
result = query_validator.validate_sort("id desc", "collection1")
assert result == "id desc"
def test_validate_sort_invalid_format(self, query_validator):
"""Test validating sort with invalid format."""
with pytest.raises(QueryError) as exc_info:
query_validator.validate_sort("id desc asc", "collection1")
assert "Invalid sort format" in str(exc_info.value)
def test_validate_sort_non_sortable_field(self, query_validator):
"""Test validating sort with non-sortable field."""
with pytest.raises(QueryError) as exc_info:
query_validator.validate_sort("content desc", "collection1")
assert "Field 'content' is not sortable" in str(exc_info.value)
def test_validate_sort_invalid_direction(self, query_validator):
"""Test validating sort with invalid direction."""
with pytest.raises(QueryError) as exc_info:
query_validator.validate_sort("id invalid", "collection1")
assert "Invalid sort direction 'invalid'" in str(exc_info.value)
def test_validate_sort_error_handling(self, query_validator, mock_field_manager):
"""Test error handling in validate_sort."""
mock_field_manager.get_field_info.side_effect = Exception("Test error")
with pytest.raises(QueryError) as exc_info:
query_validator.validate_sort("id desc", "collection1")
assert "Sort field validation error" in str(exc_info.value)
```
--------------------------------------------------------------------------------
/solr_mcp/solr/collections.py:
--------------------------------------------------------------------------------
```python
"""Collection providers for SolrCloud."""
import logging
from typing import List, Optional
import anyio
import requests
from kazoo.client import KazooClient
from kazoo.exceptions import ConnectionLoss, NoNodeError
from solr_mcp.solr.exceptions import ConnectionError, SolrError
from solr_mcp.solr.interfaces import CollectionProvider
logger = logging.getLogger(__name__)
class HttpCollectionProvider(CollectionProvider):
"""Collection provider that uses Solr HTTP API to discover collections."""
def __init__(self, base_url: str):
"""Initialize with Solr base URL.
Args:
base_url: Base URL for Solr instance (e.g., http://localhost:8983/solr)
"""
self.base_url = base_url.rstrip("/")
async def list_collections(self) -> List[str]:
"""List all available collections using Solr HTTP API.
Returns:
List of collection names
Raises:
SolrError: If unable to retrieve collections
"""
try:
response = requests.get(f"{self.base_url}/admin/collections?action=LIST")
if response.status_code != 200:
raise SolrError(f"Failed to list collections: {response.text}")
collections = response.json().get("collections", [])
return collections
except Exception as e:
raise SolrError(f"Failed to list collections: {str(e)}")
async def collection_exists(self, collection: str) -> bool:
"""Check if a collection exists.
Args:
collection: Name of the collection to check
Returns:
True if the collection exists, False otherwise
Raises:
SolrError: If unable to check collection existence
"""
try:
collections = await self.list_collections()
return collection in collections
except Exception as e:
raise SolrError(f"Failed to check if collection exists: {str(e)}")
class ZooKeeperCollectionProvider(CollectionProvider):
"""Collection provider that uses ZooKeeper to discover collections."""
def __init__(self, hosts: List[str]):
"""Initialize with ZooKeeper hosts.
Args:
hosts: List of ZooKeeper hosts in format host:port
"""
self.hosts = hosts
self.zk = None
self.connect()
def connect(self):
"""Connect to ZooKeeper and verify /collections path exists."""
try:
self.zk = KazooClient(hosts=",".join(self.hosts))
self.zk.start()
# Check if /collections path exists
if not self.zk.exists("/collections"):
raise ConnectionError("ZooKeeper /collections path does not exist")
except ConnectionLoss as e:
raise ConnectionError(f"Failed to connect to ZooKeeper: {str(e)}")
except Exception as e:
raise ConnectionError(f"Error connecting to ZooKeeper: {str(e)}")
def cleanup(self):
"""Clean up ZooKeeper connection."""
if self.zk:
try:
self.zk.stop()
self.zk.close()
except Exception:
pass # Ignore cleanup errors
finally:
self.zk = None
async def list_collections(self) -> List[str]:
"""List available collections from ZooKeeper.
Returns:
List of collection names
Raises:
ConnectionError: If there is an error communicating with ZooKeeper
"""
try:
if not self.zk:
raise ConnectionError("Not connected to ZooKeeper")
collections = await anyio.to_thread.run_sync(
self.zk.get_children, "/collections"
)
return collections
except NoNodeError:
return [] # No collections exist yet
except ConnectionLoss as e:
raise ConnectionError(f"Lost connection to ZooKeeper: {str(e)}")
except Exception as e:
raise ConnectionError(f"Error listing collections: {str(e)}")
async def collection_exists(self, collection: str) -> bool:
"""Check if a collection exists in ZooKeeper.
Args:
collection: Name of the collection to check
Returns:
True if the collection exists, False otherwise
Raises:
ConnectionError: If there is an error communicating with ZooKeeper
"""
try:
if not self.zk:
raise ConnectionError("Not connected to ZooKeeper")
# Check for collection in ZooKeeper
collection_path = f"/collections/{collection}"
exists = await anyio.to_thread.run_sync(self.zk.exists, collection_path)
return exists is not None
except ConnectionLoss as e:
raise ConnectionError(f"Lost connection to ZooKeeper: {str(e)}")
except Exception as e:
raise ConnectionError(f"Error checking collection existence: {str(e)}")
```
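A minimal usage sketch for the HTTP provider above. The base URL and the `unified` collection name are assumptions for illustration (the repo ships a `solr_config/unified` config set), not values read from this project's configuration:

```python
# Hypothetical usage of HttpCollectionProvider; URL and collection name are
# assumptions for the example, not values taken from this repository's config.
import asyncio

from solr_mcp.solr.collections import HttpCollectionProvider


async def main() -> None:
    provider = HttpCollectionProvider(base_url="http://localhost:8983/solr")
    print("Collections:", await provider.list_collections())
    print("unified exists:", await provider.collection_exists("unified"))


if __name__ == "__main__":
    asyncio.run(main())
```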
--------------------------------------------------------------------------------
/scripts/simple_search.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Simple search script to demonstrate searching in Solr without MCP.
"""
import argparse
import asyncio
import json
import os
import sys
from typing import Dict, List, Optional
# Add the project root to the path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from solr_mcp.solr.client import SolrClient
from solr_mcp.embeddings.client import OllamaClient
async def search_by_text(query: str, collection: Optional[str] = None, rows: int = 5):
"""
Perform a text search using the SolrClient directly.
Args:
query: Search query
collection: Collection name (optional)
rows: Number of results to return
"""
# Set up Solr client
solr_client = SolrClient()
try:
# Perform the search
print(f"Searching for: '{query}'")
result = await solr_client.search(
query=query,
collection=collection,
rows=rows
)
# Display results
print(f"\n=== Results for text search: '{query}' ===\n")
display_results(result)
except Exception as e:
print(f"Error during search: {e}")
async def search_by_vector(query: str, collection: Optional[str] = None, k: int = 5):
"""
Perform a vector similarity search using the SolrClient directly.
Args:
query: Text to generate embedding from
collection: Collection name (optional)
k: Number of nearest neighbors to return
"""
# Set up clients
solr_client = SolrClient()
ollama_client = OllamaClient()
try:
# Generate embedding for the query
print(f"Generating embedding for: '{query}'")
embedding = await ollama_client.get_embedding(query)
# Perform the vector search
print(f"Performing vector search")
result = await solr_client.vector_search(
vector=embedding,
collection=collection,
k=k
)
# Display results
print(f"\n=== Results for vector search: '{query}' ===\n")
display_results(result)
except Exception as e:
print(f"Error during vector search: {e}")
def display_results(result_json: str):
"""
Display search results in a readable format.
Args:
result_json: JSON string with search results
"""
try:
data = json.loads(result_json)
if "docs" in data and isinstance(data["docs"], list):
docs = data["docs"]
if not docs:
print("No results found.")
return
for i, doc in enumerate(docs, 1):
print(f"Result {i}:")
# Handle title which could be a string or list
title = doc.get('title', 'No title')
if isinstance(title, list):
title = title[0]
print(f" Title: {title}")
print(f" ID: {doc.get('id', 'No ID')}")
if "score" in doc:
print(f" Score: {doc['score']}")
# Show a preview of the content (first 150 chars)
content = doc.get("content", "")
if content:
preview = content[:150] + "..." if len(content) > 150 else content
print(f" Preview: {preview}")
if "category" in doc:
categories = doc["category"] if isinstance(doc["category"], list) else [doc["category"]]
print(f" Categories: {', '.join(categories)}")
if "tags" in doc:
tags = doc["tags"] if isinstance(doc["tags"], list) else [doc["tags"]]
print(f" Tags: {', '.join(tags)}")
print()
print(f"Total results: {data.get('numFound', len(docs))}")
else:
print("Unexpected result format:")
print(result_json)
except json.JSONDecodeError:
print("Could not parse JSON response:")
print(result_json)
except Exception as e:
print(f"Error displaying results: {e}")
async def main():
"""Main entry point."""
parser = argparse.ArgumentParser(description="Simple search script for Solr")
parser.add_argument("query", help="Search query")
parser.add_argument("--vector", "-v", action="store_true", help="Use vector search instead of text search")
parser.add_argument("--collection", "-c", default="documents", help="Collection name")
parser.add_argument("--results", "-n", type=int, default=5, help="Number of results to return")
args = parser.parse_args()
if args.vector:
await search_by_vector(args.query, args.collection, args.results)
else:
await search_by_text(args.query, args.collection, args.results)
if __name__ == "__main__":
asyncio.run(main())
```
--------------------------------------------------------------------------------
/tests/unit/solr/test_client.py:
--------------------------------------------------------------------------------
```python
"""Tests for SolrClient."""
from unittest.mock import AsyncMock, Mock
import pytest
from solr_mcp.solr.client import SolrClient
from solr_mcp.solr.exceptions import (
    DocValuesError,
    SolrError,
    SQLParseError,
)
@pytest.mark.asyncio
async def test_init_with_defaults(mock_config):
"""Test initialization with only config."""
client = SolrClient(config=mock_config)
assert client.config == mock_config
@pytest.mark.asyncio
async def test_init_with_custom_providers(
client,
mock_config,
mock_collection_provider,
mock_field_manager,
mock_vector_provider,
mock_query_builder,
):
"""Test initialization with custom providers."""
assert client.config == mock_config
assert client.collection_provider == mock_collection_provider
assert client.field_manager == mock_field_manager
assert client.vector_provider == mock_vector_provider
assert client.query_builder == mock_query_builder
@pytest.mark.asyncio
async def test_get_or_create_client_with_collection(client):
"""Test getting Solr client with specific collection."""
solr_client = await client._get_or_create_client("test_collection")
assert solr_client is not None
@pytest.mark.asyncio
async def test_get_or_create_client_with_different_collection(client):
"""Test getting Solr client with a different collection."""
solr_client = await client._get_or_create_client("another_collection")
assert solr_client is not None
@pytest.mark.asyncio
async def test_get_or_create_client_no_collection(mock_config):
"""Test error when no collection specified."""
client = SolrClient(config=mock_config)
with pytest.raises(SolrError):
await client._get_or_create_client(None)
@pytest.mark.asyncio
async def test_list_collections_success(client):
"""Test successful collection listing."""
# Mock the collection provider's list_collections method
expected_collections = ["test_collection"]
client.collection_provider.list_collections = AsyncMock(
return_value=expected_collections
)
# Test the method
result = await client.list_collections()
assert result == expected_collections
# Verify the collection provider was called
client.collection_provider.list_collections.assert_called_once()
@pytest.mark.asyncio
async def test_list_fields_schema_error(client):
"""Test schema error handling in list_fields."""
# Mock field_manager.list_fields to raise an error
client.field_manager.list_fields = AsyncMock(side_effect=SolrError("Schema error"))
# Test that the error is propagated
with pytest.raises(SolrError):
await client.list_fields("test_collection")
@pytest.mark.asyncio
async def test_execute_select_query_success(client):
"""Test successful SQL query execution."""
# Mock parser.preprocess_query
client.query_builder.parser.preprocess_query = Mock(
return_value="SELECT * FROM test_collection"
)
# Mock the parse_and_validate_select
client.query_builder.parse_and_validate_select = Mock(
return_value=(Mock(), "test_collection", None)
)
# Mock the query executor
expected_result = {
"result-set": {"docs": [{"id": "1", "title": "Test"}], "numFound": 1}
}
client.query_executor.execute_select_query = AsyncMock(return_value=expected_result)
# Execute the query
result = await client.execute_select_query("SELECT * FROM test_collection")
# Verify the result
assert result == expected_result
client.query_executor.execute_select_query.assert_called_once_with(
query="SELECT * FROM test_collection", collection="test_collection"
)
@pytest.mark.asyncio
async def test_execute_select_query_docvalues_error(client):
"""Test SQL query with DocValues error."""
# Mock parser.preprocess_query
client.query_builder.parser.preprocess_query = Mock(
return_value="SELECT * FROM test_collection"
)
# Mock the parse_and_validate_select
client.query_builder.parse_and_validate_select = Mock(
return_value=(Mock(), "test_collection", None)
)
# Mock the query executor to raise a DocValuesError
client.query_executor.execute_select_query = AsyncMock(
side_effect=DocValuesError("must have DocValues to use this feature", 10)
)
# Execute the query and verify the error
with pytest.raises(DocValuesError):
await client.execute_select_query("SELECT * FROM test_collection")
@pytest.mark.asyncio
async def test_execute_select_query_parse_error(client):
"""Test SQL query with parse error."""
# Mock parser.preprocess_query
client.query_builder.parser.preprocess_query = Mock(return_value="INVALID SQL")
# Mock the parse_and_validate_select
client.query_builder.parse_and_validate_select = Mock(
return_value=(Mock(), "test_collection", None)
)
# Mock the query executor to raise a SQLParseError
client.query_executor.execute_select_query = AsyncMock(
side_effect=SQLParseError("parse failed: syntax error", 10)
)
# Execute the query and verify the error
with pytest.raises(SQLParseError):
await client.execute_select_query("INVALID SQL")
```
--------------------------------------------------------------------------------
/solr_mcp/solr/config.py:
--------------------------------------------------------------------------------
```python
"""Configuration for Solr client."""
import json
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
import pydantic
from pydantic import BaseModel, Field, field_validator, model_validator
from solr_mcp.solr.exceptions import ConfigurationError
logger = logging.getLogger(__name__)
class SolrConfig(BaseModel):
"""Configuration for Solr client."""
solr_base_url: str = Field(description="Base URL for Solr instance")
zookeeper_hosts: List[str] = Field(description="List of ZooKeeper hosts")
connection_timeout: int = Field(
default=10, gt=0, description="Connection timeout in seconds"
)
def __init__(self, **data):
"""Initialize SolrConfig with validation error handling."""
try:
super().__init__(**data)
except pydantic.ValidationError as e:
# Convert Pydantic validation errors to our custom ConfigurationError
for error in e.errors():
if error["type"] == "missing":
field = error["loc"][0]
raise ConfigurationError(f"{field} is required")
elif error["type"] == "greater_than":
field = error["loc"][0]
if field == "connection_timeout":
raise ConfigurationError("connection_timeout must be positive")
# If we get here, it's some other validation error
raise ConfigurationError(str(e))
@field_validator("solr_base_url")
def validate_solr_url(cls, v: str) -> str:
"""Validate Solr base URL."""
if not v:
raise ConfigurationError("solr_base_url is required")
if not v.startswith(("http://", "https://")):
raise ConfigurationError(
"Solr base URL must start with http:// or https://"
)
return v
@field_validator("zookeeper_hosts")
def validate_zookeeper_hosts(cls, v: List[str]) -> List[str]:
"""Validate ZooKeeper hosts."""
if not v:
raise ConfigurationError("zookeeper_hosts is required")
if not all(isinstance(host, str) for host in v):
raise ConfigurationError("ZooKeeper hosts must be strings")
return v
@model_validator(mode="after")
def validate_config(self) -> "SolrConfig":
"""Validate the complete configuration."""
# Validate solr_base_url
if not self.solr_base_url:
raise ConfigurationError("solr_base_url is required")
# Validate zookeeper_hosts
if not self.zookeeper_hosts:
raise ConfigurationError("zookeeper_hosts is required")
# Validate numeric fields
if self.connection_timeout <= 0:
raise ConfigurationError("connection_timeout must be positive")
return self
@classmethod
def load(cls, config_path: str) -> "SolrConfig":
"""Load configuration from JSON file.
Args:
config_path: Path to JSON config file
Returns:
SolrConfig instance
Raises:
ConfigurationError: If file cannot be loaded or is invalid
"""
try:
with open(config_path) as f:
config_dict = json.load(f)
try:
return cls(**config_dict)
except pydantic.ValidationError as e:
# Convert Pydantic validation errors to our custom ConfigurationError
for error in e.errors():
if error["type"] == "missing":
field = error["loc"][0]
raise ConfigurationError(f"{field} is required")
elif error["type"] == "greater_than":
field = error["loc"][0]
if field == "connection_timeout":
raise ConfigurationError(
"connection_timeout must be positive"
)
# If we get here, it's some other validation error
raise ConfigurationError(str(e))
except FileNotFoundError:
raise ConfigurationError(f"Configuration file not found: {config_path}")
except json.JSONDecodeError:
raise ConfigurationError(
f"Invalid JSON in configuration file: {config_path}"
)
except Exception as e:
if isinstance(e, ConfigurationError):
raise
raise ConfigurationError(f"Failed to load config: {str(e)}")
def to_dict(self) -> Dict[str, Any]:
"""Convert config to dictionary."""
return self.model_dump()
    @classmethod
    def model_validate(cls, *args, **kwargs):
"""Override model_validate to handle validation errors."""
try:
return super().model_validate(*args, **kwargs)
except pydantic.ValidationError as e:
# Convert Pydantic validation errors to our custom ConfigurationError
for error in e.errors():
if error["type"] == "missing":
field = error["loc"][0]
raise ConfigurationError(f"{field} is required")
elif error["type"] == "greater_than":
field = error["loc"][0]
if field == "connection_timeout":
raise ConfigurationError("connection_timeout must be positive")
# If we get here, it's some other validation error
raise ConfigurationError(str(e))
```
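A quick sketch of the validation behavior defined above; the URL and host values are made up for the example:

```python
# Illustrative sketch of SolrConfig validation; values are placeholders.
from solr_mcp.solr.config import SolrConfig
from solr_mcp.solr.exceptions import ConfigurationError

config = SolrConfig(
    solr_base_url="http://localhost:8983/solr",
    zookeeper_hosts=["localhost:2181"],
)  # connection_timeout falls back to its default of 10

try:
    SolrConfig(solr_base_url="http://localhost:8983/solr", zookeeper_hosts=[])
except ConfigurationError as e:
    print(e)  # -> "zookeeper_hosts is required"
```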
--------------------------------------------------------------------------------
/tests/unit/solr/vector/test_results.py:
--------------------------------------------------------------------------------
```python
"""Tests for solr_mcp.solr.vector.results module."""
from typing import Any, Dict
import pytest
from solr_mcp.solr.vector.results import VectorSearchResult, VectorSearchResults
@pytest.fixture
def sample_result_data() -> Dict[str, Any]:
"""Create sample result data."""
return {
"docid": "123",
"score": 0.95,
"distance": 0.05,
"metadata": {"title": "Test Document", "author": "Test Author"},
}
@pytest.fixture
def sample_solr_response() -> Dict[str, Any]:
"""Create sample Solr response."""
return {
"responseHeader": {"QTime": 50},
"response": {
"numFound": 2,
"docs": [
{
"_docid_": "123",
"score": 0.95,
"_vector_distance_": 0.05,
"title": "Test Document 1",
"author": "Test Author 1",
},
{
"_docid_": "456",
"score": 0.85,
"_vector_distance_": 0.15,
"title": "Test Document 2",
"author": "Test Author 2",
},
],
},
}
def test_vector_search_result_creation(sample_result_data):
"""Test VectorSearchResult creation and properties."""
result = VectorSearchResult(**sample_result_data)
assert result.docid == "123"
assert result.score == 0.95
assert result.distance == 0.05
assert result.metadata == {"title": "Test Document", "author": "Test Author"}
def test_vector_search_result_subscript(sample_result_data):
"""Test VectorSearchResult subscript access."""
result = VectorSearchResult(**sample_result_data)
assert result["docid"] == "123"
assert result["score"] == 0.95
assert result["distance"] == 0.05
assert result["metadata"] == {"title": "Test Document", "author": "Test Author"}
def test_vector_search_result_invalid_key(sample_result_data):
"""Test VectorSearchResult invalid key access."""
result = VectorSearchResult(**sample_result_data)
with pytest.raises(KeyError, match="Invalid key: invalid_key"):
_ = result["invalid_key"]
def test_vector_search_results_creation(sample_solr_response):
"""Test VectorSearchResults creation from Solr response."""
results = VectorSearchResults.from_solr_response(sample_solr_response, top_k=10)
assert len(results.results) == 2
assert results.total_found == 2
assert results.top_k == 10
assert results.query_time_ms == 50
def test_vector_search_results_docs_property(sample_solr_response):
"""Test VectorSearchResults docs property."""
results = VectorSearchResults.from_solr_response(sample_solr_response, top_k=10)
docs = results.docs
assert len(docs) == 2
assert isinstance(docs[0], VectorSearchResult)
assert docs[0].docid == "123"
assert docs[1].docid == "456"
def test_vector_search_results_alternate_docid_fields():
"""Test VectorSearchResults with alternate docid field names."""
response = {
"response": {"numFound": 1, "docs": [{"[docid]": "123", "score": 0.95}]}
}
results = VectorSearchResults.from_solr_response(response, top_k=10)
assert results.results[0].docid == "123"
response["response"]["docs"][0] = {"docid": "456", "score": 0.85}
results = VectorSearchResults.from_solr_response(response, top_k=10)
assert results.results[0].docid == "456"
def test_vector_search_results_missing_docid():
"""Test VectorSearchResults with missing docid field."""
response = {"response": {"numFound": 1, "docs": [{"score": 0.95}]}}
results = VectorSearchResults.from_solr_response(response, top_k=10)
assert results.results[0].docid == "0"
def test_vector_search_results_to_dict(sample_solr_response):
"""Test VectorSearchResults to_dict method."""
results = VectorSearchResults.from_solr_response(sample_solr_response, top_k=10)
result_dict = results.to_dict()
assert "results" in result_dict
assert "metadata" in result_dict
assert len(result_dict["results"]) == 2
assert result_dict["metadata"]["total_found"] == 2
assert result_dict["metadata"]["top_k"] == 10
assert result_dict["metadata"]["query_time_ms"] == 50
def test_vector_search_results_get_methods(sample_solr_response):
"""Test VectorSearchResults getter methods."""
results = VectorSearchResults.from_solr_response(sample_solr_response, top_k=10)
assert results.get_doc_ids() == ["123", "456"]
assert results.get_scores() == [0.95, 0.85]
assert results.get_distances() == [0.05, 0.15]
def test_vector_search_results_empty_response():
"""Test VectorSearchResults with empty response."""
empty_response = {"responseHeader": {}, "response": {"numFound": 0, "docs": []}}
results = VectorSearchResults.from_solr_response(empty_response, top_k=10)
assert len(results.results) == 0
assert results.total_found == 0
assert results.get_doc_ids() == []
assert results.get_scores() == []
assert results.get_distances() == []
def test_vector_search_results_minimal_response():
"""Test VectorSearchResults with minimal response."""
minimal_response = {"response": {"docs": [{"_docid_": "123"}]}}
results = VectorSearchResults.from_solr_response(minimal_response, top_k=10)
assert len(results.results) == 1
assert results.total_found == 0 # Default when numFound is missing
assert results.query_time_ms is None # Default when QTime is missing
```
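For reference, a short sketch of the API these tests pin down, using the same response shape as the fixtures above:

```python
# Sketch of VectorSearchResults usage; the response dict mirrors the
# fixture shape exercised by the tests above.
from solr_mcp.solr.vector.results import VectorSearchResults

response = {
    "responseHeader": {"QTime": 12},
    "response": {
        "numFound": 1,
        "docs": [{"_docid_": "42", "score": 0.9, "_vector_distance_": 0.1}],
    },
}

results = VectorSearchResults.from_solr_response(response, top_k=5)
print(results.get_doc_ids())   # ["42"]
print(results.get_scores())    # [0.9]
print(results.to_dict()["metadata"]["query_time_ms"])  # 12
```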
--------------------------------------------------------------------------------
/solr_config/unified/conf/schema.xml:
--------------------------------------------------------------------------------
```
<?xml version="1.0" encoding="UTF-8" ?>
<schema name="unified" version="1.6">
<!-- Field Types -->
<fieldType name="string" class="solr.StrField" sortMissingLast="true" omitNorms="true"/>
<fieldType name="boolean" class="solr.BoolField" sortMissingLast="true" omitNorms="true"/>
<fieldType name="int" class="solr.IntPointField" omitNorms="true" positionIncrementGap="0"/>
<fieldType name="float" class="solr.FloatPointField" omitNorms="true" positionIncrementGap="0"/>
<fieldType name="long" class="solr.LongPointField" omitNorms="true" positionIncrementGap="0"/>
<fieldType name="double" class="solr.DoublePointField" omitNorms="true" positionIncrementGap="0"/>
<fieldType name="date" class="solr.DatePointField" omitNorms="true" positionIncrementGap="0"/>
<!-- Improved text field with stemming and better tokenization -->
<fieldType name="text_general" class="solr.TextField" positionIncrementGap="100">
<analyzer type="index">
<tokenizer class="solr.StandardTokenizerFactory"/>
<filter class="solr.StopFilterFactory" ignoreCase="true" words="stopwords.txt"/>
<filter class="solr.WordDelimiterGraphFilterFactory"
generateWordParts="1"
generateNumberParts="1"
catenateWords="1"
catenateNumbers="1"
catenateAll="0"
splitOnCaseChange="1"/>
<filter class="solr.LowerCaseFilterFactory"/>
<filter class="solr.EnglishMinimalStemFilterFactory"/>
<filter class="solr.RemoveDuplicatesTokenFilterFactory"/>
</analyzer>
<analyzer type="query">
<tokenizer class="solr.StandardTokenizerFactory"/>
<filter class="solr.StopFilterFactory" ignoreCase="true" words="stopwords.txt"/>
<filter class="solr.WordDelimiterGraphFilterFactory"
generateWordParts="1"
generateNumberParts="1"
catenateWords="0"
catenateNumbers="0"
catenateAll="0"
splitOnCaseChange="1"/>
<filter class="solr.SynonymGraphFilterFactory" synonyms="synonyms.txt" ignoreCase="true" expand="true"/>
<filter class="solr.LowerCaseFilterFactory"/>
<filter class="solr.EnglishMinimalStemFilterFactory"/>
<filter class="solr.RemoveDuplicatesTokenFilterFactory"/>
</analyzer>
</fieldType>
<!-- Vector field type for embeddings -->
<fieldType name="knn_vector" class="solr.DenseVectorField"
vectorDimension="768" similarityFunction="cosine">
<vectorEncoding>FLOAT32</vectorEncoding>
</fieldType>
<!-- Fields for document -->
<!-- Unique identifier for each document -->
<field name="id" type="string" indexed="true" stored="true" required="true" multiValued="false"/>
<!-- Text fields - for full text search, use _text_ field which includes content from all these fields -->
<field name="title" type="text_general" indexed="true" stored="true"/>
<field name="content" type="text_general" indexed="true" stored="true"/>
<field name="text" type="text_general" indexed="true" stored="true"/>
<!-- Metadata fields - good for faceting and filtering -->
<field name="section" type="string" indexed="true" stored="true"/>
<field name="section_number" type="int" indexed="true" stored="true"/>
<field name="source" type="string" indexed="true" stored="true"/>
<field name="url" type="string" indexed="true" stored="true"/>
<field name="published_date" type="date" indexed="true" stored="true"/>
<field name="author" type="string" indexed="true" stored="true" multiValued="true"/>
<field name="tags" type="string" indexed="true" stored="true" multiValued="true"/>
<!-- Vector embedding field for similarity search -->
<field name="embedding" type="knn_vector" indexed="true" stored="false"/>
<!-- Dynamic field patterns -->
<dynamicField name="*_i" type="int" indexed="true" stored="true"/>
<dynamicField name="*_s" type="string" indexed="true" stored="true"/>
<dynamicField name="*_l" type="long" indexed="true" stored="true"/>
<dynamicField name="*_t" type="text_general" indexed="true" stored="true"/>
<dynamicField name="*_b" type="boolean" indexed="true" stored="true"/>
<dynamicField name="*_f" type="float" indexed="true" stored="true"/>
<dynamicField name="*_d" type="double" indexed="true" stored="true"/>
<dynamicField name="*_dt" type="date" indexed="true" stored="true"/>
<dynamicField name="*_ss" type="string" indexed="true" stored="true" multiValued="true"/>
<dynamicField name="*_vector" type="knn_vector" indexed="true" stored="true"/>
<!-- Required fields -->
<uniqueKey>id</uniqueKey>
<!-- Special fields -->
<!-- _text_ is the main field for full text search - it combines content from all text fields -->
<!-- Use this field for general text search queries like: WHERE _text_:'your search terms' -->
<field name="_text_" type="text_general" indexed="true" stored="false" multiValued="true"/>
<field name="_version_" type="long" indexed="true" stored="true"/>
<!-- Copy all text fields into _text_ for unified full-text search -->
<!-- This is why you should use _text_ for searching instead of individual fields -->
<copyField source="title" dest="_text_"/>
<copyField source="content" dest="_text_"/>
<copyField source="text" dest="_text_"/>
<copyField source="section" dest="_text_"/>
<copyField source="source" dest="_text_"/>
<copyField source="author" dest="_text_"/>
<copyField source="tags" dest="_text_"/>
<copyField source="*_t" dest="_text_"/>
<copyField source="*_s" dest="_text_"/>
</schema>
```
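The `embedding` field above is what Solr's KNN query parser targets. A minimal sketch of such a query — the URL and collection name are assumptions, and the vector must have exactly 768 components to match `vectorDimension` (the same `{!knn ...}` syntax is used by `scripts/vector_search.py` below):

```python
# Hypothetical KNN query against the `embedding` field defined in the schema
# above; the Solr URL and collection name are assumptions for illustration.
import requests

vector = [0.0] * 768  # must match vectorDimension="768"
vector_str = "[" + ",".join(str(v) for v in vector) + "]"

response = requests.post(
    "http://localhost:8983/solr/unified/select",
    data={"q": f"{{!knn f=embedding topK=5}}{vector_str}"},
    params={"fl": "id,title,score", "wt": "json"},
)
print(response.json()["response"]["numFound"])
```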
--------------------------------------------------------------------------------
/scripts/vector_search.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Test script for vector search in Solr.
"""
import argparse
import asyncio
import json
import os
import sys
from typing import Any, Dict, List, Optional
import httpx
# Add the project root to the path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from solr_mcp.embeddings.client import OllamaClient
async def generate_query_embedding(query_text: str) -> List[float]:
"""Generate embedding for a query using Ollama.
Args:
query_text: Query text to generate embedding for
Returns:
Embedding vector for the query
"""
client = OllamaClient()
print(f"Generating embedding for query: '{query_text}'")
embedding = await client.get_embedding(query_text)
return embedding
async def vector_search(
query: str,
collection: str = "testvectors",
vector_field: str = "embedding",
k: int = 5,
    filter_query: Optional[str] = None
):
"""
Perform a vector search in Solr using the generated embedding.
Args:
query: Search query text
collection: Solr collection name
vector_field: Name of the vector field
k: Number of results to return
filter_query: Optional filter query
"""
# Generate embedding for the query
query_embedding = await generate_query_embedding(query)
# Format the vector as a string that Solr expects for KNN search
vector_str = "[" + ",".join(str(v) for v in query_embedding) + "]"
# Prepare Solr KNN query
solr_url = f"http://localhost:8983/solr/{collection}/select"
# Build query parameters
params = {
"q": f"{{!knn f={vector_field} topK={k}}}{vector_str}",
"fl": "id,title,text,score,vector_model",
"wt": "json"
}
if filter_query:
params["fq"] = filter_query
print(f"Executing vector search in collection '{collection}'")
try:
# Split implementation - try POST first (to handle long vectors), fall back to GET
        async with httpx.AsyncClient() as client:
            try:
                # First try with POST to handle large vectors
                post_params = {"fl": params["fl"], "wt": params["wt"]}
                if filter_query:
                    # The GET path passes fq via params; keep it on the POST path too
                    post_params["fq"] = filter_query
                response = await client.post(
                    solr_url,
                    data={"q": params["q"]},
                    params=post_params,
                    timeout=30.0
                )
            except Exception as post_error:
                print(f"POST request failed, trying GET: {post_error}")
                # Fall back to GET with a more compact vector representation.
                # Round the components instead of truncating the vector: dropping
                # dimensions would no longer match the field's vectorDimension.
                if len(vector_str) > 800:
                    short_vector = ",".join(str(round(v, 4)) for v in query_embedding)
                    params["q"] = f"{{!knn f={vector_field} topK={k}}}[{short_vector}]"
                response = await client.get(solr_url, params=params, timeout=30.0)
if response.status_code == 200:
result = response.json()
return result
else:
print(f"Error in vector search: {response.status_code} - {response.text}")
return None
except Exception as e:
print(f"Error during vector search: {e}")
return None
def display_results(results: Dict[str, Any]):
"""Display search results in a readable format.
Args:
results: Search results from Solr
"""
if not results or 'response' not in results:
print("No valid results received")
return
print("\n=== Vector Search Results ===\n")
docs = results['response']['docs']
num_found = results['response']['numFound']
if not docs:
print("No matching documents found.")
return
print(f"Found {num_found} matching document(s):\n")
for i, doc in enumerate(docs, 1):
print(f"Result {i}:")
print(f" ID: {doc.get('id', 'N/A')}")
# Handle title which could be a string or list
title = doc.get('title', 'N/A')
if isinstance(title, list) and title:
title = title[0]
print(f" Title: {title}")
if 'score' in doc:
print(f" Score: {doc['score']}")
# Handle text which could be string or list
text = doc.get('text', '')
if isinstance(text, list) and text:
text = text[0]
if text:
preview = text[:150] + "..." if len(text) > 150 else text
print(f" Preview: {preview}")
# Print model info if available
if 'vector_model' in doc:
print(f" Model: {doc.get('vector_model')}")
print()
async def main():
"""Main entry point."""
parser = argparse.ArgumentParser(description="Test vector search in Solr")
parser.add_argument("query", help="Search query")
parser.add_argument("--collection", "-c", default="vectors", help="Solr collection name")
parser.add_argument("--field", "-f", default="embedding", help="Vector field name")
parser.add_argument("--results", "-k", type=int, default=5, help="Number of results to return")
parser.add_argument("--filter", "-fq", help="Optional filter query")
args = parser.parse_args()
results = await vector_search(
args.query,
args.collection,
args.field,
args.results,
args.filter
)
if results:
display_results(results)
if __name__ == "__main__":
asyncio.run(main())
```
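Since `vector_search` is a plain coroutine, the script can also be driven programmatically rather than via the CLI — a sketch, assuming it is run from the repository root (so `scripts/` can be put on `sys.path`); the query string is made up for illustration:

```python
# Hypothetical programmatic use of the coroutine above; assumes the working
# directory is the repo root so the script module can be imported directly.
import asyncio
import sys

sys.path.insert(0, "scripts")
from vector_search import display_results, vector_search  # noqa: E402

results = asyncio.run(
    vector_search("peer-to-peer electronic cash", collection="vectors", k=3)
)
if results:
    display_results(results)
```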