This is page 2 of 2. Use http://codebase.md/redis/mcp-redis?lines=false&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│   ├── dependabot.yml
│   └── workflows
│       ├── ci.yml
│       ├── release.yml
│       └── stale-issues.yml
├── .gitignore
├── Dockerfile
├── examples
│   └── redis_assistant.py
├── fly.toml
├── gemini-extension.json
├── GEMINI.md
├── LICENSE
├── pyproject.toml
├── README.md
├── server.json
├── src
│   ├── __init__.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── connection.py
│   │   ├── entraid_auth.py
│   │   ├── logging_utils.py
│   │   └── server.py
│   ├── main.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── hash.py
│   │   ├── json.py
│   │   ├── list.py
│   │   ├── misc.py
│   │   ├── pub_sub.py
│   │   ├── redis_query_engine.py
│   │   ├── server_management.py
│   │   ├── set.py
│   │   ├── sorted_set.py
│   │   ├── stream.py
│   │   └── string.py
│   └── version.py
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── test_config.py
│   ├── test_connection.py
│   ├── test_entraid_auth.py
│   ├── test_integration.py
│   ├── test_logging_utils.py
│   ├── test_main.py
│   ├── test_server.py
│   └── tools
│       ├── __init__.py
│       ├── test_hash.py
│       ├── test_json.py
│       ├── test_list.py
│       ├── test_pub_sub.py
│       ├── test_redis_query_engine.py
│       ├── test_server_management.py
│       ├── test_set.py
│       ├── test_sorted_set.py
│       ├── test_stream.py
│       └── test_string.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/tests/test_connection.py:
--------------------------------------------------------------------------------
```python
"""
Unit tests for src/common/connection.py
"""
from unittest.mock import Mock, patch
import pytest
from redis.exceptions import ConnectionError
from src.common.connection import RedisConnectionManager
class TestRedisConnectionManager:
    """Test cases for RedisConnectionManager class."""

    # Baseline standalone configuration shared by every test. Tests pass only
    # the keys they change to ``_configure``, so a new config key has to be
    # added here once rather than in every test.
    _STANDALONE_CFG = {
        "cluster_mode": False,
        "host": "localhost",
        "port": 6379,
        "db": 0,
        "username": None,
        "password": "",
        "ssl": False,
        "ssl_ca_path": None,
        "ssl_keyfile": None,
        "ssl_certfile": None,
        "ssl_cert_reqs": "required",
        "ssl_ca_certs": None,
    }

    @classmethod
    def _configure(cls, mock_config, *, cluster=False, **overrides):
        """Route ``mock_config[key]`` lookups to a plain dict and return that dict.

        Args:
            mock_config: The patched ``REDIS_CFG`` mock.
            cluster: If true, set ``cluster_mode`` and remove the ``"db"`` key,
                so the cluster-mode config carries no database index (looking
                ``"db"`` up then raises ``KeyError``).
            **overrides: Config keys whose baseline values should be replaced.

        Returns:
            The effective configuration dict.
        """
        cfg = dict(cls._STANDALONE_CFG)
        if cluster:
            cfg["cluster_mode"] = True
            del cfg["db"]
        cfg.update(overrides)
        # Unknown keys raise KeyError, exactly like indexing a dict literal.
        mock_config.__getitem__.side_effect = cfg.__getitem__
        return cfg

    def setup_method(self):
        """Set up test fixtures."""
        # Reset singleton instance before each test
        RedisConnectionManager._instance = None

    def teardown_method(self):
        """Clean up after each test."""
        # Reset singleton instance after each test
        RedisConnectionManager._instance = None

    @patch("src.common.connection.redis.Redis")
    @patch("src.common.connection.REDIS_CFG")
    def test_get_connection_standalone_mode(self, mock_config, mock_redis_class):
        """Test getting connection in standalone mode."""
        self._configure(mock_config)
        mock_redis_instance = Mock()
        mock_redis_class.return_value = mock_redis_instance

        connection = RedisConnectionManager.get_connection()

        assert connection == mock_redis_instance
        mock_redis_class.assert_called_once()
        # Verify connection parameters
        call_args = mock_redis_class.call_args[1]
        assert call_args["host"] == "localhost"
        assert call_args["port"] == 6379
        assert call_args["db"] == 0
        assert call_args["decode_responses"] is True
        assert call_args["max_connections"] == 10
        assert "lib_name" in call_args

    @patch("src.common.connection.redis.cluster.RedisCluster")
    @patch("src.common.connection.REDIS_CFG")
    def test_get_connection_cluster_mode(self, mock_config, mock_cluster_class):
        """Test getting connection in cluster mode."""
        self._configure(
            mock_config,
            cluster=True,
            username="testuser",
            password="testpass",
            ssl=True,
            ssl_ca_path="/path/to/ca.pem",
            ssl_keyfile="/path/to/key.pem",
            ssl_certfile="/path/to/cert.pem",
            ssl_ca_certs="/path/to/ca-bundle.pem",
        )
        mock_cluster_instance = Mock()
        mock_cluster_class.return_value = mock_cluster_instance

        connection = RedisConnectionManager.get_connection()

        assert connection == mock_cluster_instance
        mock_cluster_class.assert_called_once()
        # Verify connection parameters
        call_args = mock_cluster_class.call_args[1]
        assert call_args["host"] == "localhost"
        assert call_args["port"] == 6379
        assert call_args["username"] == "testuser"
        assert call_args["password"] == "testpass"
        assert call_args["ssl"] is True
        assert call_args["ssl_ca_path"] == "/path/to/ca.pem"
        assert call_args["decode_responses"] is True
        assert call_args["max_connections_per_node"] == 10
        assert "lib_name" in call_args

    @patch("src.common.connection.redis.Redis")
    @patch("src.common.connection.REDIS_CFG")
    def test_get_connection_singleton_behavior(self, mock_config, mock_redis_class):
        """Test that get_connection returns the same instance (singleton behavior)."""
        self._configure(mock_config)
        mock_redis_instance = Mock()
        mock_redis_class.return_value = mock_redis_instance

        connection1 = RedisConnectionManager.get_connection()
        connection2 = RedisConnectionManager.get_connection()

        assert connection1 == connection2
        assert connection1 == mock_redis_instance
        # Redis class should only be called once
        mock_redis_class.assert_called_once()

    @patch("src.common.connection.redis.Redis")
    @patch("src.common.connection.REDIS_CFG")
    def test_get_connection_with_decode_responses_false(
        self, mock_config, mock_redis_class
    ):
        """Test getting connection with decode_responses=False."""
        self._configure(mock_config)
        mock_redis_instance = Mock()
        mock_redis_class.return_value = mock_redis_instance

        connection = RedisConnectionManager.get_connection(decode_responses=False)

        assert connection == mock_redis_instance
        call_args = mock_redis_class.call_args[1]
        assert call_args["decode_responses"] is False

    @patch("src.common.connection.redis.Redis")
    @patch("src.common.connection.REDIS_CFG")
    def test_get_connection_with_ssl_configuration(self, mock_config, mock_redis_class):
        """Test getting connection with SSL configuration."""
        self._configure(
            mock_config,
            host="redis.example.com",
            port=6380,
            db=1,
            username="ssluser",
            password="sslpass",
            ssl=True,
            ssl_ca_path="/path/to/ca.pem",
            ssl_keyfile="/path/to/key.pem",
            ssl_certfile="/path/to/cert.pem",
            ssl_cert_reqs="optional",
            ssl_ca_certs="/path/to/ca-bundle.pem",
        )
        mock_redis_instance = Mock()
        mock_redis_class.return_value = mock_redis_instance

        connection = RedisConnectionManager.get_connection()

        assert connection == mock_redis_instance
        call_args = mock_redis_class.call_args[1]
        # Non-default endpoint values must be forwarded too, not just SSL ones.
        assert call_args["host"] == "redis.example.com"
        assert call_args["port"] == 6380
        assert call_args["db"] == 1
        assert call_args["ssl"] is True
        assert call_args["ssl_ca_path"] == "/path/to/ca.pem"
        assert call_args["ssl_keyfile"] == "/path/to/key.pem"
        assert call_args["ssl_certfile"] == "/path/to/cert.pem"
        assert call_args["ssl_cert_reqs"] == "optional"
        assert call_args["ssl_ca_certs"] == "/path/to/ca-bundle.pem"

    @patch("src.common.connection.redis.Redis")
    @patch("src.common.connection.REDIS_CFG")
    def test_get_connection_includes_version_in_lib_name(
        self, mock_config, mock_redis_class
    ):
        """Test that connection includes version information in lib_name."""
        self._configure(mock_config)
        mock_redis_instance = Mock()
        mock_redis_class.return_value = mock_redis_instance

        with patch("src.common.connection.__version__", "1.0.0"):
            connection = RedisConnectionManager.get_connection()

        assert connection == mock_redis_instance
        call_args = mock_redis_class.call_args[1]
        assert "redis-py(mcp-server_v1.0.0)" in call_args["lib_name"]

    @patch("src.common.connection.redis.Redis")
    @patch("src.common.connection.REDIS_CFG")
    def test_connection_error_handling(self, mock_config, mock_redis_class):
        """Test connection error handling."""
        self._configure(mock_config)
        # Mock Redis constructor to raise ConnectionError
        mock_redis_class.side_effect = ConnectionError("Connection refused")

        with pytest.raises(ConnectionError, match="Connection refused"):
            RedisConnectionManager.get_connection()

    @patch("src.common.connection.redis.cluster.RedisCluster")
    @patch("src.common.connection.REDIS_CFG")
    def test_cluster_connection_error_handling(self, mock_config, mock_cluster_class):
        """Test cluster connection error handling."""
        self._configure(mock_config, cluster=True)
        # Mock RedisCluster constructor to raise ConnectionError
        mock_cluster_class.side_effect = ConnectionError("Cluster connection failed")

        with pytest.raises(ConnectionError, match="Cluster connection failed"):
            RedisConnectionManager.get_connection()

    @patch("src.common.connection.redis.Redis")
    @patch("src.common.connection.REDIS_CFG")
    def test_reset_instance(self, mock_config, mock_redis_class):
        """Test that clearing the singleton makes the next call build a new client."""
        self._configure(mock_config)
        first_client, second_client = Mock(), Mock()
        # Each constructor call hands out the next client in order.
        mock_redis_class.side_effect = [first_client, second_client]

        assert RedisConnectionManager.get_connection() is first_client

        # Reset instance
        RedisConnectionManager._instance = None

        assert RedisConnectionManager.get_connection() is second_client
        assert mock_redis_class.call_count == 2

    @patch("src.common.connection.redis.Redis")
    @patch("src.common.connection.REDIS_CFG")
    def test_connection_parameters_filtering(self, mock_config, mock_redis_class):
        """Test that None values are properly handled in connection parameters."""
        # Baseline has username=None and password="": both must pass through as-is.
        self._configure(mock_config)
        mock_redis_instance = Mock()
        mock_redis_class.return_value = mock_redis_instance

        connection = RedisConnectionManager.get_connection()

        assert connection == mock_redis_instance
        call_args = mock_redis_class.call_args[1]
        assert call_args["username"] is None
        assert call_args["password"] == ""
        assert call_args["ssl_ca_path"] is None
```
--------------------------------------------------------------------------------
/tests/tools/test_pub_sub.py:
--------------------------------------------------------------------------------
```python
"""
Unit tests for src/tools/pub_sub.py
"""
from unittest.mock import Mock, patch
import pytest
from redis.exceptions import ConnectionError, RedisError
from src.tools.pub_sub import publish, subscribe, unsubscribe
class TestPubSubOperations:
    """Test cases for Redis pub/sub operations."""

    @staticmethod
    def _attach_pubsub(mock_redis):
        """Make ``mock_redis.pubsub()`` return a fresh Mock and return that Mock."""
        mock_pubsub = Mock()
        mock_redis.pubsub.return_value = mock_pubsub
        return mock_pubsub

    @pytest.mark.asyncio
    async def test_publish_success(self, mock_redis_connection_manager):
        """Test successful publish operation."""
        mock_redis = mock_redis_connection_manager
        # Number of subscribers that received the message
        mock_redis.publish.return_value = 2

        result = await publish("test_channel", "Hello World")

        mock_redis.publish.assert_called_once_with("test_channel", "Hello World")
        assert "Message published to channel 'test_channel'" in result

    @pytest.mark.asyncio
    async def test_publish_no_subscribers(self, mock_redis_connection_manager):
        """Test publish operation with no subscribers."""
        mock_redis = mock_redis_connection_manager
        mock_redis.publish.return_value = 0  # No subscribers

        result = await publish("empty_channel", "Hello World")

        mock_redis.publish.assert_called_once_with("empty_channel", "Hello World")
        assert "Message published to channel 'empty_channel'" in result

    @pytest.mark.asyncio
    async def test_publish_redis_error(self, mock_redis_connection_manager):
        """Test publish operation with Redis error."""
        mock_redis = mock_redis_connection_manager
        mock_redis.publish.side_effect = RedisError("Connection failed")

        result = await publish("test_channel", "Hello World")

        assert (
            "Error publishing message to channel 'test_channel': Connection failed"
            in result
        )

    @pytest.mark.asyncio
    async def test_publish_connection_error(self, mock_redis_connection_manager):
        """Test publish operation with connection error."""
        mock_redis = mock_redis_connection_manager
        mock_redis.publish.side_effect = ConnectionError("Redis server unavailable")

        result = await publish("test_channel", "Hello World")

        assert (
            "Error publishing message to channel 'test_channel': Redis server unavailable"
            in result
        )

    @pytest.mark.asyncio
    async def test_publish_empty_message(self, mock_redis_connection_manager):
        """Test publish operation with empty message."""
        mock_redis = mock_redis_connection_manager
        mock_redis.publish.return_value = 1

        result = await publish("test_channel", "")

        mock_redis.publish.assert_called_once_with("test_channel", "")
        assert "Message published to channel 'test_channel'" in result

    @pytest.mark.asyncio
    async def test_publish_numeric_message(self, mock_redis_connection_manager):
        """Test publish operation with numeric message."""
        mock_redis = mock_redis_connection_manager
        mock_redis.publish.return_value = 1

        result = await publish("test_channel", 42)

        mock_redis.publish.assert_called_once_with("test_channel", 42)
        assert "Message published to channel 'test_channel'" in result

    @pytest.mark.asyncio
    async def test_publish_json_message(self, mock_redis_connection_manager):
        """Test publish operation with JSON-like message."""
        mock_redis = mock_redis_connection_manager
        mock_redis.publish.return_value = 3
        json_message = (
            '{"type": "notification", "data": {"user": "john", "action": "login"}}'
        )

        result = await publish("notifications", json_message)

        mock_redis.publish.assert_called_once_with("notifications", json_message)
        assert "Message published to channel 'notifications'" in result

    @pytest.mark.asyncio
    async def test_publish_unicode_message(self, mock_redis_connection_manager):
        """Test publish operation with unicode message."""
        mock_redis = mock_redis_connection_manager
        mock_redis.publish.return_value = 1
        unicode_message = "Hello 世界 🌍"

        result = await publish("test_channel", unicode_message)

        mock_redis.publish.assert_called_once_with("test_channel", unicode_message)
        assert "Message published to channel 'test_channel'" in result

    @pytest.mark.asyncio
    async def test_subscribe_success(self, mock_redis_connection_manager):
        """Test successful subscribe operation."""
        mock_redis = mock_redis_connection_manager
        mock_pubsub = self._attach_pubsub(mock_redis)
        mock_pubsub.subscribe.return_value = None

        result = await subscribe("test_channel")

        mock_redis.pubsub.assert_called_once()
        mock_pubsub.subscribe.assert_called_once_with("test_channel")
        assert "Subscribed to channel 'test_channel'" in result

    @pytest.mark.asyncio
    async def test_subscribe_redis_error(self, mock_redis_connection_manager):
        """Test subscribe when creating the pubsub object raises a Redis error."""
        mock_redis = mock_redis_connection_manager
        mock_redis.pubsub.side_effect = RedisError("Connection failed")

        result = await subscribe("test_channel")

        assert (
            "Error subscribing to channel 'test_channel': Connection failed" in result
        )

    @pytest.mark.asyncio
    async def test_subscribe_pubsub_error(self, mock_redis_connection_manager):
        """Test subscribe when the pubsub object's subscribe call raises."""
        mock_redis = mock_redis_connection_manager
        mock_pubsub = self._attach_pubsub(mock_redis)
        mock_pubsub.subscribe.side_effect = RedisError("Subscribe failed")

        result = await subscribe("test_channel")

        assert "Error subscribing to channel 'test_channel': Subscribe failed" in result

    @pytest.mark.asyncio
    async def test_subscribe_multiple_channels_pattern(
        self, mock_redis_connection_manager
    ):
        """Test subscribe with a glob-like channel name.

        The name is handed to ``subscribe`` verbatim, as a literal channel.
        """
        mock_redis = mock_redis_connection_manager
        mock_pubsub = self._attach_pubsub(mock_redis)
        mock_pubsub.subscribe.return_value = None
        pattern_channel = "notifications:*"

        result = await subscribe(pattern_channel)

        mock_pubsub.subscribe.assert_called_once_with(pattern_channel)
        assert f"Subscribed to channel '{pattern_channel}'" in result

    @pytest.mark.asyncio
    async def test_unsubscribe_success(self, mock_redis_connection_manager):
        """Test successful unsubscribe operation."""
        mock_redis = mock_redis_connection_manager
        mock_pubsub = self._attach_pubsub(mock_redis)
        mock_pubsub.unsubscribe.return_value = None

        result = await unsubscribe("test_channel")

        mock_redis.pubsub.assert_called_once()
        mock_pubsub.unsubscribe.assert_called_once_with("test_channel")
        assert "Unsubscribed from channel 'test_channel'" in result

    @pytest.mark.asyncio
    async def test_unsubscribe_redis_error(self, mock_redis_connection_manager):
        """Test unsubscribe when creating the pubsub object raises a Redis error."""
        mock_redis = mock_redis_connection_manager
        mock_redis.pubsub.side_effect = RedisError("Connection failed")

        result = await unsubscribe("test_channel")

        assert (
            "Error unsubscribing from channel 'test_channel': Connection failed"
            in result
        )

    @pytest.mark.asyncio
    async def test_unsubscribe_pubsub_error(self, mock_redis_connection_manager):
        """Test unsubscribe when the pubsub object's unsubscribe call raises."""
        mock_redis = mock_redis_connection_manager
        mock_pubsub = self._attach_pubsub(mock_redis)
        mock_pubsub.unsubscribe.side_effect = RedisError("Unsubscribe failed")

        result = await unsubscribe("test_channel")

        assert (
            "Error unsubscribing from channel 'test_channel': Unsubscribe failed"
            in result
        )

    @pytest.mark.asyncio
    async def test_unsubscribe_from_all_channels(self, mock_redis_connection_manager):
        """Test unsubscribing from one explicitly named channel.

        Despite the test name, no unsubscribe-from-all call is exercised here.
        """
        mock_redis = mock_redis_connection_manager
        mock_pubsub = self._attach_pubsub(mock_redis)
        mock_pubsub.unsubscribe.return_value = None

        result = await unsubscribe("specific_channel")

        mock_pubsub.unsubscribe.assert_called_once_with("specific_channel")
        assert "Unsubscribed from channel 'specific_channel'" in result

    @pytest.mark.asyncio
    async def test_publish_to_pattern_channel(self, mock_redis_connection_manager):
        """Test publish to a colon-namespaced channel name."""
        mock_redis = mock_redis_connection_manager
        mock_redis.publish.return_value = 5
        pattern_channel = "user:123:notifications"

        result = await publish(pattern_channel, "User notification")

        mock_redis.publish.assert_called_once_with(pattern_channel, "User notification")
        assert f"Message published to channel '{pattern_channel}'" in result

    @pytest.mark.asyncio
    async def test_subscribe_with_special_characters(
        self, mock_redis_connection_manager
    ):
        """Test subscribe operation with special characters in channel name."""
        mock_redis = mock_redis_connection_manager
        mock_pubsub = self._attach_pubsub(mock_redis)
        mock_pubsub.subscribe.return_value = None
        special_channel = "channel:with:colons-and-dashes_and_underscores"

        result = await subscribe(special_channel)

        mock_pubsub.subscribe.assert_called_once_with(special_channel)
        assert f"Subscribed to channel '{special_channel}'" in result

    @pytest.mark.asyncio
    async def test_connection_manager_called_correctly(self):
        """Test that RedisConnectionManager.get_connection is called correctly."""
        with patch(
            "src.tools.pub_sub.RedisConnectionManager.get_connection"
        ) as mock_get_conn:
            mock_redis = Mock()
            mock_redis.publish.return_value = 1
            mock_get_conn.return_value = mock_redis

            await publish("test_channel", "test_message")

            mock_get_conn.assert_called_once()

    def test_function_signatures(self):
        """Test that functions have correct signatures."""
        import inspect

        # Pure introspection: nothing is awaited, so this runs as a plain test.
        assert list(inspect.signature(publish).parameters) == ["channel", "message"]
        assert list(inspect.signature(subscribe).parameters) == ["channel"]
        assert list(inspect.signature(unsubscribe).parameters) == ["channel"]

    @pytest.mark.asyncio
    async def test_publish_large_message(self, mock_redis_connection_manager):
        """Test publish operation with large message."""
        mock_redis = mock_redis_connection_manager
        mock_redis.publish.return_value = 1
        large_message = "x" * 10000  # 10KB message

        result = await publish("test_channel", large_message)

        mock_redis.publish.assert_called_once_with("test_channel", large_message)
        assert "Message published to channel 'test_channel'" in result
```
--------------------------------------------------------------------------------
/tests/tools/test_stream.py:
--------------------------------------------------------------------------------
```python
"""
Unit tests for src/tools/stream.py
"""
from unittest.mock import Mock, patch
import pytest
from redis.exceptions import RedisError
from src.tools.stream import xadd, xdel, xrange
class TestStreamOperations:
    """Test cases for Redis stream operations."""

    @pytest.mark.asyncio
    async def test_xadd_success(self, mock_redis_connection_manager):
        """Test successful stream add operation."""
        mock_redis = mock_redis_connection_manager
        mock_redis.xadd.return_value = "1234567890123-0"  # Stream entry ID
        fields = {"field1": "value1", "field2": "value2"}

        result = await xadd("test_stream", fields)

        mock_redis.xadd.assert_called_once_with("test_stream", fields)
        assert "Successfully added entry 1234567890123-0 to test_stream" in result
        assert "1234567890123-0" in result

    @pytest.mark.asyncio
    async def test_xadd_with_expiration(self, mock_redis_connection_manager):
        """Test stream add operation with expiration."""
        mock_redis = mock_redis_connection_manager
        mock_redis.xadd.return_value = "1234567890124-0"
        mock_redis.expire.return_value = True
        fields = {"message": "test message"}

        result = await xadd("test_stream", fields, 60)

        mock_redis.xadd.assert_called_once_with("test_stream", fields)
        mock_redis.expire.assert_called_once_with("test_stream", 60)
        assert "with expiration 60 seconds" in result

    @pytest.mark.asyncio
    async def test_xadd_single_field(self, mock_redis_connection_manager):
        """Test stream add operation with single field."""
        mock_redis = mock_redis_connection_manager
        mock_redis.xadd.return_value = "1234567890125-0"
        fields = {"message": "single field message"}

        result = await xadd("test_stream", fields)

        mock_redis.xadd.assert_called_once_with("test_stream", fields)
        assert "Successfully added entry 1234567890125-0 to test_stream" in result

    @pytest.mark.asyncio
    async def test_xadd_redis_error(self, mock_redis_connection_manager):
        """Test stream add operation with Redis error."""
        mock_redis = mock_redis_connection_manager
        mock_redis.xadd.side_effect = RedisError("Connection failed")
        fields = {"field1": "value1"}

        result = await xadd("test_stream", fields)

        assert "Error adding to stream test_stream: Connection failed" in result

    @pytest.mark.asyncio
    async def test_xadd_with_numeric_values(self, mock_redis_connection_manager):
        """Test stream add operation with numeric field values."""
        mock_redis = mock_redis_connection_manager
        mock_redis.xadd.return_value = "1234567890126-0"
        fields = {"count": 42, "price": 19.99, "active": True}

        result = await xadd("test_stream", fields)

        mock_redis.xadd.assert_called_once_with("test_stream", fields)
        assert "Successfully added entry 1234567890126-0 to test_stream" in result

    @pytest.mark.asyncio
    async def test_xrange_success(self, mock_redis_connection_manager):
        """Test successful stream range operation."""
        mock_redis = mock_redis_connection_manager
        mock_entries = [
            ("1234567890123-0", {"field1": "value1", "field2": "value2"}),
            ("1234567890124-0", {"field1": "value3", "field2": "value4"}),
        ]
        mock_redis.xrange.return_value = mock_entries

        result = await xrange("test_stream")

        mock_redis.xrange.assert_called_once_with("test_stream", count=1)
        assert result == str(mock_entries)

    @pytest.mark.asyncio
    async def test_xrange_with_custom_count(self, mock_redis_connection_manager):
        """Test stream range operation with custom count."""
        mock_redis = mock_redis_connection_manager
        mock_entries = [
            ("1234567890123-0", {"message": "entry1"}),
            ("1234567890124-0", {"message": "entry2"}),
            ("1234567890125-0", {"message": "entry3"}),
        ]
        mock_redis.xrange.return_value = mock_entries

        result = await xrange("test_stream", 3)

        mock_redis.xrange.assert_called_once_with("test_stream", count=3)
        assert result == str(mock_entries)

    @pytest.mark.asyncio
    async def test_xrange_empty_stream(self, mock_redis_connection_manager):
        """Test stream range operation on empty stream."""
        mock_redis = mock_redis_connection_manager
        mock_redis.xrange.return_value = []

        result = await xrange("empty_stream")

        mock_redis.xrange.assert_called_once_with("empty_stream", count=1)
        assert "Stream empty_stream is empty or does not exist" in result

    @pytest.mark.asyncio
    async def test_xrange_redis_error(self, mock_redis_connection_manager):
        """Test stream range operation with Redis error."""
        mock_redis = mock_redis_connection_manager
        mock_redis.xrange.side_effect = RedisError("Connection failed")

        result = await xrange("test_stream")

        assert "Error reading from stream test_stream: Connection failed" in result

    @pytest.mark.asyncio
    async def test_xdel_success(self, mock_redis_connection_manager):
        """Test successful stream delete operation."""
        mock_redis = mock_redis_connection_manager
        mock_redis.xdel.return_value = 1  # Number of entries deleted

        result = await xdel("test_stream", "1234567890123-0")

        mock_redis.xdel.assert_called_once_with("test_stream", "1234567890123-0")
        assert "Successfully deleted entry 1234567890123-0 from test_stream" in result

    @pytest.mark.asyncio
    async def test_xdel_entry_not_found(self, mock_redis_connection_manager):
        """Test stream delete operation when entry doesn't exist."""
        mock_redis = mock_redis_connection_manager
        mock_redis.xdel.return_value = 0  # No entries deleted

        result = await xdel("test_stream", "nonexistent-entry-id")

        assert "Entry nonexistent-entry-id not found in test_stream" in result

    @pytest.mark.asyncio
    async def test_xdel_redis_error(self, mock_redis_connection_manager):
        """Test stream delete operation with Redis error."""
        mock_redis = mock_redis_connection_manager
        mock_redis.xdel.side_effect = RedisError("Connection failed")

        result = await xdel("test_stream", "1234567890123-0")

        assert "Error deleting from stream test_stream: Connection failed" in result

    @pytest.mark.asyncio
    async def test_xadd_with_empty_fields(self, mock_redis_connection_manager):
        """Test stream add operation with empty fields dictionary."""
        mock_redis = mock_redis_connection_manager
        mock_redis.xadd.return_value = "1234567890127-0"
        fields = {}

        result = await xadd("test_stream", fields)

        mock_redis.xadd.assert_called_once_with("test_stream", fields)
        assert "Successfully added entry 1234567890127-0 to test_stream" in result

    @pytest.mark.asyncio
    async def test_xadd_with_unicode_values(self, mock_redis_connection_manager):
        """Test stream add operation with unicode field values."""
        mock_redis = mock_redis_connection_manager
        mock_redis.xadd.return_value = "1234567890128-0"
        fields = {"message": "Hello 世界 🌍", "user": "测试用户"}

        result = await xadd("test_stream", fields)

        mock_redis.xadd.assert_called_once_with("test_stream", fields)
        assert "Successfully added entry 1234567890128-0 to test_stream" in result

    @pytest.mark.asyncio
    async def test_xrange_large_count(self, mock_redis_connection_manager):
        """Test stream range operation with large count."""
        mock_redis = mock_redis_connection_manager
        mock_entries = [
            (f"123456789012{i}-0", {"data": f"entry_{i}"}) for i in range(100)
        ]
        mock_redis.xrange.return_value = mock_entries

        result = await xrange("test_stream", 100)

        mock_redis.xrange.assert_called_once_with("test_stream", count=100)
        # The function returns a string representation
        assert result == str(mock_entries)

    @pytest.mark.asyncio
    async def test_xdel_multiple_entries_behavior(self, mock_redis_connection_manager):
        """Test that xdel function handles single entry correctly."""
        mock_redis = mock_redis_connection_manager
        mock_redis.xdel.return_value = 1

        result = await xdel("test_stream", "single-entry-id")

        # Should call xdel with single entry ID, not multiple
        mock_redis.xdel.assert_called_once_with("test_stream", "single-entry-id")
        assert "Successfully deleted entry single-entry-id from test_stream" in result

    @pytest.mark.asyncio
    async def test_xadd_expiration_error(self, mock_redis_connection_manager):
        """Test stream add operation when expiration fails."""
        mock_redis = mock_redis_connection_manager
        mock_redis.xadd.return_value = "1234567890129-0"
        mock_redis.expire.side_effect = RedisError("Expire failed")
        fields = {"message": "test"}

        result = await xadd("test_stream", fields, 60)

        assert "Error adding to stream test_stream: Expire failed" in result

    @pytest.mark.asyncio
    async def test_xrange_single_entry(self, mock_redis_connection_manager):
        """Test stream range operation returning single entry."""
        mock_redis = mock_redis_connection_manager
        mock_entries = [("1234567890123-0", {"single": "entry"})]
        mock_redis.xrange.return_value = mock_entries

        result = await xrange("test_stream", 1)

        mock_redis.xrange.assert_called_once_with("test_stream", count=1)
        assert result == "[('1234567890123-0', {'single': 'entry'})]"

    @pytest.mark.asyncio
    async def test_connection_manager_called_correctly(self):
        """Test that RedisConnectionManager.get_connection is called correctly."""
        with patch(
            "src.tools.stream.RedisConnectionManager.get_connection"
        ) as mock_get_conn:
            mock_redis = Mock()
            mock_redis.xadd.return_value = "1234567890123-0"
            mock_get_conn.return_value = mock_redis

            await xadd("test_stream", {"field": "value"})

            mock_get_conn.assert_called_once()

    def test_function_signatures(self):
        """Test that functions have correct signatures."""
        import inspect

        # Pure introspection: nothing is awaited, so this runs as a plain test.
        xadd_sig = inspect.signature(xadd)
        assert list(xadd_sig.parameters) == ["key", "fields", "expiration"]
        assert xadd_sig.parameters["expiration"].default is None

        xrange_sig = inspect.signature(xrange)
        assert list(xrange_sig.parameters) == ["key", "count"]
        assert xrange_sig.parameters["count"].default == 1

        assert list(inspect.signature(xdel).parameters) == ["key", "entry_id"]

    @pytest.mark.asyncio
    async def test_xadd_with_complex_fields(self, mock_redis_connection_manager):
        """Test stream add operation with complex field structure."""
        mock_redis = mock_redis_connection_manager
        mock_redis.xadd.return_value = "1234567890130-0"
        fields = {
            "event_type": "user_action",
            "user_id": "12345",
            "timestamp": "2024-01-01T12:00:00Z",
            "metadata": '{"browser": "chrome", "version": "120"}',
            "score": 95.5,
            "active": True,
        }

        result = await xadd("events_stream", fields)

        mock_redis.xadd.assert_called_once_with("events_stream", fields)
        assert "Successfully added entry 1234567890130-0 to events_stream" in result
```
--------------------------------------------------------------------------------
/tests/tools/test_json.py:
--------------------------------------------------------------------------------
```python
"""
Unit tests for src/tools/json.py
"""
import json
import pytest
from redis.exceptions import RedisError
from src.tools.json import json_del, json_get, json_set
class TestJSONOperations:
    """Unit tests for the RedisJSON tools: json_set, json_get and json_del.

    The ``mock_redis_connection_manager`` fixture supplies the mocked Redis
    client. RedisJSON calls are stubbed and asserted on
    ``client.json.return_value``.
    """

    @staticmethod
    def _pretty(value):
        """Serialize *value* the way json_get's output is expected to look."""
        return json.dumps(value, ensure_ascii=False, indent=2)

    @pytest.mark.asyncio
    async def test_json_set_success(
        self, mock_redis_connection_manager, sample_json_data
    ):
        """json_set writes a whole document at the root path."""
        rejson = mock_redis_connection_manager.json.return_value
        rejson.set.return_value = "OK"

        outcome = await json_set("test_doc", "$", sample_json_data)

        rejson.set.assert_called_once_with("test_doc", "$", sample_json_data)
        assert "JSON value set at path '$' in 'test_doc'." in outcome

    @pytest.mark.asyncio
    async def test_json_set_with_expiration(
        self, mock_redis_connection_manager, sample_json_data
    ):
        """json_set applies a TTL to the key when an expiration is supplied."""
        client = mock_redis_connection_manager
        rejson = client.json.return_value
        rejson.set.return_value = "OK"
        client.expire.return_value = True

        outcome = await json_set("test_doc", "$.name", "John Updated", 60)

        rejson.set.assert_called_once_with("test_doc", "$.name", "John Updated")
        client.expire.assert_called_once_with("test_doc", 60)
        assert "Expires in 60 seconds" in outcome

    @pytest.mark.asyncio
    async def test_json_set_nested_path(self, mock_redis_connection_manager):
        """json_set accepts a deeply nested JSONPath."""
        rejson = mock_redis_connection_manager.json.return_value
        rejson.set.return_value = "OK"

        outcome = await json_set("test_doc", "$.user.profile.age", 25)

        rejson.set.assert_called_once_with("test_doc", "$.user.profile.age", 25)
        assert "JSON value set at path '$.user.profile.age'" in outcome

    @pytest.mark.asyncio
    async def test_json_set_redis_error(self, mock_redis_connection_manager):
        """A RedisError from JSON.SET is reported as an error message."""
        rejson = mock_redis_connection_manager.json.return_value
        rejson.set.side_effect = RedisError("JSON module not loaded")

        outcome = await json_set("test_doc", "$", {"key": "value"})

        expected = (
            "Error setting JSON value at path '$' in 'test_doc': "
            "JSON module not loaded"
        )
        assert expected in outcome

    @pytest.mark.asyncio
    async def test_json_get_success(
        self, mock_redis_connection_manager, sample_json_data
    ):
        """json_get returns the stored document as pretty-printed JSON."""
        rejson = mock_redis_connection_manager.json.return_value
        rejson.get.return_value = sample_json_data

        outcome = await json_get("test_doc", "$")

        rejson.get.assert_called_once_with("test_doc", "$")
        assert outcome == self._pretty(sample_json_data)

    @pytest.mark.asyncio
    async def test_json_get_specific_field(self, mock_redis_connection_manager):
        """json_get on a single field returns the JSONPath result list."""
        rejson = mock_redis_connection_manager.json.return_value
        rejson.get.return_value = ["John Doe"]

        outcome = await json_get("test_doc", "$.name")

        rejson.get.assert_called_once_with("test_doc", "$.name")
        assert outcome == self._pretty(["John Doe"])

    @pytest.mark.asyncio
    async def test_json_get_default_path(
        self, mock_redis_connection_manager, sample_json_data
    ):
        """json_get without a path reads from the root path '$'."""
        rejson = mock_redis_connection_manager.json.return_value
        rejson.get.return_value = sample_json_data

        outcome = await json_get("test_doc")

        rejson.get.assert_called_once_with("test_doc", "$")
        assert outcome == self._pretty(sample_json_data)

    @pytest.mark.asyncio
    async def test_json_get_not_found(self, mock_redis_connection_manager):
        """A None reply from JSON.GET yields a 'No data found' message."""
        rejson = mock_redis_connection_manager.json.return_value
        rejson.get.return_value = None

        outcome = await json_get("nonexistent_doc", "$")

        assert "No data found at path '$' in 'nonexistent_doc'" in outcome

    @pytest.mark.asyncio
    async def test_json_get_redis_error(self, mock_redis_connection_manager):
        """A RedisError from JSON.GET is reported as an error message."""
        rejson = mock_redis_connection_manager.json.return_value
        rejson.get.side_effect = RedisError("Connection failed")

        outcome = await json_get("test_doc", "$")

        expected = (
            "Error retrieving JSON value at path '$' in 'test_doc': "
            "Connection failed"
        )
        assert expected in outcome

    @pytest.mark.asyncio
    async def test_json_del_success(self, mock_redis_connection_manager):
        """json_del removes the value at the given path."""
        rejson = mock_redis_connection_manager.json.return_value
        rejson.delete.return_value = 1

        outcome = await json_del("test_doc", "$.name")

        rejson.delete.assert_called_once_with("test_doc", "$.name")
        assert "Deleted JSON value at path '$.name' in 'test_doc'" in outcome

    @pytest.mark.asyncio
    async def test_json_del_default_path(self, mock_redis_connection_manager):
        """json_del without a path deletes from the root path '$'."""
        rejson = mock_redis_connection_manager.json.return_value
        rejson.delete.return_value = 1

        outcome = await json_del("test_doc")

        rejson.delete.assert_called_once_with("test_doc", "$")
        assert "Deleted JSON value at path '$' in 'test_doc'" in outcome

    @pytest.mark.asyncio
    async def test_json_del_not_found(self, mock_redis_connection_manager):
        """A zero delete count yields a 'No JSON value found' message."""
        rejson = mock_redis_connection_manager.json.return_value
        rejson.delete.return_value = 0

        outcome = await json_del("test_doc", "$.nonexistent")

        assert "No JSON value found at path '$.nonexistent' in 'test_doc'" in outcome

    @pytest.mark.asyncio
    async def test_json_del_redis_error(self, mock_redis_connection_manager):
        """A RedisError from JSON.DEL is reported as an error message."""
        rejson = mock_redis_connection_manager.json.return_value
        rejson.delete.side_effect = RedisError("Connection failed")

        outcome = await json_del("test_doc", "$.name")

        expected = (
            "Error deleting JSON value at path '$.name' in 'test_doc': "
            "Connection failed"
        )
        assert expected in outcome

    @pytest.mark.asyncio
    async def test_json_set_with_array(self, mock_redis_connection_manager):
        """json_set passes a list value through unchanged."""
        rejson = mock_redis_connection_manager.json.return_value
        rejson.set.return_value = "OK"
        items = ["item1", "item2", "item3"]

        outcome = await json_set("test_doc", "$.items", items)

        rejson.set.assert_called_once_with("test_doc", "$.items", items)
        assert "JSON value set at path '$.items'" in outcome

    @pytest.mark.asyncio
    async def test_json_set_with_nested_object(self, mock_redis_connection_manager):
        """json_set passes a nested dict value through unchanged."""
        rejson = mock_redis_connection_manager.json.return_value
        rejson.set.return_value = "OK"
        settings = {"theme": "dark", "notifications": True}
        document = {"user": {"profile": {"name": "John", "settings": settings}}}

        outcome = await json_set("test_doc", "$", document)

        rejson.set.assert_called_once_with("test_doc", "$", document)
        assert "JSON value set at path '$'" in outcome

    @pytest.mark.asyncio
    async def test_json_get_array_element(self, mock_redis_connection_manager):
        """json_get supports indexing into an array."""
        rejson = mock_redis_connection_manager.json.return_value
        rejson.get.return_value = ["first_item"]

        outcome = await json_get("test_doc", "$.items[0]")

        rejson.get.assert_called_once_with("test_doc", "$.items[0]")
        assert outcome == self._pretty(["first_item"])

    @pytest.mark.asyncio
    async def test_json_operations_with_numeric_values(
        self, mock_redis_connection_manager
    ):
        """Integers round-trip through json_set and json_get."""
        rejson = mock_redis_connection_manager.json.return_value
        rejson.set.return_value = "OK"
        rejson.get.return_value = [42]

        await json_set("test_doc", "$.count", 42)
        rejson.set.assert_called_with("test_doc", "$.count", 42)

        outcome = await json_get("test_doc", "$.count")
        assert outcome == self._pretty([42])

    @pytest.mark.asyncio
    async def test_json_operations_with_boolean_values(
        self, mock_redis_connection_manager
    ):
        """Booleans round-trip through json_set and json_get."""
        rejson = mock_redis_connection_manager.json.return_value
        rejson.set.return_value = "OK"
        rejson.get.return_value = [True]

        await json_set("test_doc", "$.active", True)
        rejson.set.assert_called_with("test_doc", "$.active", True)

        outcome = await json_get("test_doc", "$.active")
        assert outcome == self._pretty([True])

    @pytest.mark.asyncio
    async def test_json_set_expiration_error(self, mock_redis_connection_manager):
        """A failure while applying the TTL is reported as a set error."""
        client = mock_redis_connection_manager
        client.json.return_value.set.return_value = "OK"
        client.expire.side_effect = RedisError("Expire failed")

        outcome = await json_set("test_doc", "$", {"key": "value"}, 60)

        expected = "Error setting JSON value at path '$' in 'test_doc': Expire failed"
        assert expected in outcome

    @pytest.mark.asyncio
    async def test_json_del_multiple_matches(self, mock_redis_connection_manager):
        """A recursive-descent path deleting several matches still succeeds."""
        rejson = mock_redis_connection_manager.json.return_value
        rejson.delete.return_value = 3  # three matching elements removed

        outcome = await json_del("test_doc", "$..name")

        rejson.delete.assert_called_once_with("test_doc", "$..name")
        assert "Deleted JSON value at path '$..name'" in outcome

    @pytest.mark.asyncio
    async def test_json_operations_with_null_values(
        self, mock_redis_connection_manager
    ):
        """None round-trips through json_set and json_get as JSON null."""
        rejson = mock_redis_connection_manager.json.return_value
        rejson.set.return_value = "OK"
        rejson.get.return_value = [None]

        await json_set("test_doc", "$.optional_field", None)
        rejson.set.assert_called_with("test_doc", "$.optional_field", None)

        outcome = await json_get("test_doc", "$.optional_field")
        assert outcome == self._pretty([None])
```
--------------------------------------------------------------------------------
/tests/tools/test_redis_query_engine.py:
--------------------------------------------------------------------------------
```python
"""
Unit tests for src/tools/redis_query_engine.py
"""
import json
from unittest.mock import Mock, patch
import pytest
from redis.commands.search.field import VectorField
from redis.commands.search.index_definition import IndexDefinition
from redis.commands.search.query import Query
from redis.exceptions import RedisError
from src.tools.redis_query_engine import (
create_vector_index_hash,
get_index_info,
get_indexes,
vector_search_hash,
)
class TestRedisQueryEngineOperations:
    """Test cases for Redis query engine operations.

    Index operations are reached through ``client.ft(index_name)``. Each test
    attaches a fresh ``Mock`` as that return value and asserts against it.
    The vector-search tests patch ``numpy.array`` so that the query vector
    encodes to fixed bytes.
    """

    QUERY_BYTES = b"query_vector_bytes"

    @staticmethod
    def _attach_ft(client):
        """Install a fresh Mock as ``client.ft``'s return value and return it."""
        ft = Mock()
        client.ft.return_value = ft
        return ft

    @classmethod
    def _stub_numpy(cls, np_array):
        """Make the patched ``numpy.array`` encode any vector to QUERY_BYTES.

        Both ``np.array(...).tobytes()`` and
        ``np.array(...).astype(...).tobytes()`` are stubbed. That way the
        tests do not depend on which form vector_search_hash uses.
        """
        array_obj = np_array.return_value
        array_obj.tobytes.return_value = cls.QUERY_BYTES
        array_obj.astype.return_value.tobytes.return_value = cls.QUERY_BYTES

    @pytest.mark.asyncio
    async def test_get_indexes_success(self, mock_redis_connection_manager):
        """Test successful get indexes operation."""
        mock_redis = mock_redis_connection_manager
        mock_indexes = ["index1", "index2", "vector_index"]
        mock_redis.execute_command.return_value = mock_indexes

        result = await get_indexes()

        mock_redis.execute_command.assert_called_once_with("FT._LIST")
        assert result == json.dumps(mock_indexes)

    @pytest.mark.asyncio
    async def test_get_indexes_empty(self, mock_redis_connection_manager):
        """Test get indexes operation with no indexes."""
        mock_redis = mock_redis_connection_manager
        mock_redis.execute_command.return_value = []

        result = await get_indexes()

        assert result == json.dumps([])

    @pytest.mark.asyncio
    async def test_get_indexes_redis_error(self, mock_redis_connection_manager):
        """Test get indexes operation with Redis error."""
        mock_redis = mock_redis_connection_manager
        mock_redis.execute_command.side_effect = RedisError("Search module not loaded")

        result = await get_indexes()

        assert "Error retrieving indexes: Search module not loaded" in result

    @pytest.mark.asyncio
    async def test_create_vector_index_hash_success(
        self, mock_redis_connection_manager
    ):
        """Test successful vector index creation with default parameters."""
        mock_redis = mock_redis_connection_manager
        mock_ft = self._attach_ft(mock_redis)
        mock_ft.create_index.return_value = "OK"

        result = await create_vector_index_hash()

        mock_redis.ft.assert_called_once_with("vector_index")
        mock_ft.create_index.assert_called_once()
        # create_index(fields, definition=...): inspect both arguments.
        call_args = mock_ft.create_index.call_args
        fields = call_args[0][0]
        definition = call_args[1]["definition"]
        assert len(fields) == 1
        assert isinstance(fields[0], VectorField)
        assert fields[0].name == "vector"
        assert isinstance(definition, IndexDefinition)
        assert "Index 'vector_index' created successfully." in result

    @pytest.mark.asyncio
    async def test_create_vector_index_hash_custom_params(
        self, mock_redis_connection_manager
    ):
        """Test vector index creation with custom parameters."""
        mock_redis = mock_redis_connection_manager
        mock_ft = self._attach_ft(mock_redis)
        mock_ft.create_index.return_value = "OK"

        result = await create_vector_index_hash(
            index_name="custom_index",
            vector_field="embedding",
            dim=512,
            distance_metric="COSINE",
        )

        mock_redis.ft.assert_called_once_with("custom_index")
        fields = mock_ft.create_index.call_args[0][0]
        assert fields[0].name == "embedding"
        assert "Index 'custom_index' created successfully." in result

    @pytest.mark.asyncio
    async def test_create_vector_index_hash_redis_error(
        self, mock_redis_connection_manager
    ):
        """Test vector index creation with Redis error."""
        mock_ft = self._attach_ft(mock_redis_connection_manager)
        mock_ft.create_index.side_effect = RedisError("Index already exists")

        result = await create_vector_index_hash()

        assert "Error creating index 'vector_index': Index already exists" in result

    @pytest.mark.asyncio
    async def test_vector_search_hash_success(
        self, mock_redis_connection_manager, sample_vector
    ):
        """Test successful vector search operation."""
        from types import SimpleNamespace

        mock_redis = mock_redis_connection_manager
        mock_ft = self._attach_ft(mock_redis)
        # Search hits are modeled as plain attribute bags. SimpleNamespace
        # gives each a real __dict__; assigning __dict__ on a Mock would wipe
        # the Mock's own internal state.
        mock_result = Mock()
        mock_result.docs = [
            SimpleNamespace(id="doc1", vector="binary_data", score="0.95"),
            SimpleNamespace(id="doc2", vector="binary_data", score="0.87"),
        ]
        mock_ft.search.return_value = mock_result

        with patch("numpy.array") as mock_np_array:
            self._stub_numpy(mock_np_array)
            result = await vector_search_hash(sample_vector)

        mock_redis.ft.assert_called_once_with("vector_index")
        mock_ft.search.assert_called_once()
        assert isinstance(mock_ft.search.call_args[0][0], Query)
        assert isinstance(result, list)
        assert len(result) == 2

    @pytest.mark.asyncio
    async def test_vector_search_hash_custom_params(
        self, mock_redis_connection_manager, sample_vector
    ):
        """Test vector search with custom parameters."""
        mock_redis = mock_redis_connection_manager
        mock_ft = self._attach_ft(mock_redis)
        mock_result = Mock()
        mock_result.docs = []
        mock_ft.search.return_value = mock_result

        with patch("numpy.array") as mock_np_array:
            self._stub_numpy(mock_np_array)
            result = await vector_search_hash(
                query_vector=sample_vector,
                index_name="custom_index",
                vector_field="embedding",
                k=10,
                return_fields=["title", "content"],
            )

        mock_redis.ft.assert_called_once_with("custom_index")
        assert isinstance(result, list)

    @pytest.mark.asyncio
    async def test_vector_search_hash_no_results(
        self, mock_redis_connection_manager, sample_vector
    ):
        """Test vector search with no results."""
        mock_ft = self._attach_ft(mock_redis_connection_manager)
        mock_result = Mock()
        mock_result.docs = []
        mock_ft.search.return_value = mock_result

        with patch("numpy.array") as mock_np_array:
            self._stub_numpy(mock_np_array)
            result = await vector_search_hash(sample_vector)

        assert result == []  # Empty list when no results

    @pytest.mark.asyncio
    async def test_vector_search_hash_redis_error(
        self, mock_redis_connection_manager, sample_vector
    ):
        """Test vector search with Redis error."""
        mock_ft = self._attach_ft(mock_redis_connection_manager)
        mock_ft.search.side_effect = RedisError("Index not found")

        with patch("numpy.array") as mock_np_array:
            self._stub_numpy(mock_np_array)
            result = await vector_search_hash(sample_vector)

        assert (
            "Error performing vector search on index 'vector_index': Index not found"
            in result
        )

    @pytest.mark.asyncio
    async def test_get_index_info_success(self, mock_redis_connection_manager):
        """Test successful get index info operation."""
        mock_redis = mock_redis_connection_manager
        mock_ft = self._attach_ft(mock_redis)
        mock_info = {
            "index_name": "vector_index",
            "index_options": [],
            "index_definition": ["key_type", "HASH", "prefixes", ["doc:"]],
            "attributes": [
                ["identifier", "vector", "attribute", "vector", "type", "VECTOR"]
            ],
            "num_docs": "100",
            "max_doc_id": "100",
            "num_terms": "0",
            "num_records": "100",
            "inverted_sz_mb": "0.00",
            "vector_index_sz_mb": "1.50",
            "total_inverted_index_blocks": "0",
            "offset_vectors_sz_mb": "0.00",
            "doc_table_size_mb": "0.01",
            "sortable_values_size_mb": "0.00",
            "key_table_size_mb": "0.00",
        }
        mock_ft.info.return_value = mock_info

        result = await get_index_info("vector_index")

        mock_redis.ft.assert_called_once_with("vector_index")
        mock_ft.info.assert_called_once()
        # get_index_info returns a JSON string representation
        assert result == json.dumps(mock_info, ensure_ascii=False, indent=2)

    @pytest.mark.asyncio
    async def test_get_index_info_default_index(self, mock_redis_connection_manager):
        """Test get index info for the 'vector_index' index with a minimal payload.

        The index name is passed explicitly; this test does not exercise a
        default argument.
        """
        mock_redis = mock_redis_connection_manager
        mock_ft = self._attach_ft(mock_redis)
        mock_ft.info.return_value = {"index_name": "vector_index"}

        result = await get_index_info("vector_index")

        mock_redis.ft.assert_called_once_with("vector_index")
        # get_index_info returns a JSON string representation
        assert result == json.dumps(
            {"index_name": "vector_index"}, ensure_ascii=False, indent=2
        )

    @pytest.mark.asyncio
    async def test_get_index_info_redis_error(self, mock_redis_connection_manager):
        """Test get index info with Redis error."""
        mock_ft = self._attach_ft(mock_redis_connection_manager)
        mock_ft.info.side_effect = RedisError("Index not found")

        result = await get_index_info("nonexistent_index")

        assert "Error retrieving index info: Index not found" in result

    @pytest.mark.asyncio
    async def test_create_vector_index_different_metrics(
        self, mock_redis_connection_manager
    ):
        """Test vector index creation with different distance metrics."""
        mock_ft = self._attach_ft(mock_redis_connection_manager)
        mock_ft.create_index.return_value = "OK"

        # Test L2 metric
        await create_vector_index_hash(distance_metric="L2")
        mock_ft.create_index.assert_called()

        # Test IP metric
        mock_ft.reset_mock()
        await create_vector_index_hash(distance_metric="IP")
        mock_ft.create_index.assert_called()

    @pytest.mark.asyncio
    async def test_vector_search_with_large_k(
        self, mock_redis_connection_manager, sample_vector
    ):
        """Test vector search with large k value."""
        mock_ft = self._attach_ft(mock_redis_connection_manager)
        mock_result = Mock()
        mock_result.docs = []
        mock_ft.search.return_value = mock_result

        with patch("numpy.array") as mock_np_array:
            self._stub_numpy(mock_np_array)
            result = await vector_search_hash(sample_vector, k=1000)

        assert result == []  # Empty list when no results
        # Should handle large k values
        mock_ft.search.assert_called_once()

    @pytest.mark.asyncio
    async def test_connection_manager_called_correctly(self):
        """Test that RedisConnectionManager.get_connection is called correctly."""
        with patch(
            "src.tools.redis_query_engine.RedisConnectionManager.get_connection"
        ) as mock_get_conn:
            mock_redis = Mock()
            mock_redis.execute_command.return_value = []
            mock_get_conn.return_value = mock_redis

            await get_indexes()

            mock_get_conn.assert_called_once()
```
--------------------------------------------------------------------------------
/tests/test_integration.py:
--------------------------------------------------------------------------------
```python
"""
Integration tests for Redis MCP Server.
These tests actually start the MCP server process and verify it can handle real requests.
"""
import json
import subprocess
import sys
import time
import os
from pathlib import Path
import pytest
def _redis_available():
"""Check if Redis is available for testing."""
try:
import redis
r = redis.Redis(host="localhost", port=6379, decode_responses=True)
r.ping()
return True
except Exception:
return False
def _create_server_process(project_root):
"""Create a server process with proper encoding for cross-platform compatibility."""
return subprocess.Popen(
[sys.executable, "-m", "src.main"],
cwd=project_root,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
encoding="utf-8",
errors="replace", # Replace invalid characters instead of failing
env={"REDIS_HOST": "localhost", "REDIS_PORT": "6379", **dict(os.environ)},
)
@pytest.mark.integration
class TestMCPServerIntegration:
    """Integration tests that start the actual MCP server."""

    @pytest.fixture
    def server_process(self):
        """Start the MCP server as a subprocess and stop it after the test.

        Yields the ``subprocess.Popen`` handle. Teardown terminates the
        server, escalating to kill() if it has not exited after 5 seconds.
        It then closes the stdin/stdout/stderr pipes so file descriptors do
        not leak between tests.
        """
        # Get the project root directory
        project_root = Path(__file__).parent.parent
        # Start the server process with proper encoding for cross-platform compatibility
        process = _create_server_process(project_root)
        # Give the server a moment to start
        time.sleep(1)

        yield process

        # Clean up: stop the server first, then release its pipes.
        process.terminate()
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait()
        for stream in (process.stdin, process.stdout, process.stderr):
            if stream is None:
                continue
            try:
                stream.close()
            except OSError:
                # Best effort: closing stdin can raise BrokenPipeError if
                # buffered data is flushed to an already-exited child.
                pass

    def test_server_starts_successfully(self, server_process):
        """Test that the MCP server starts without crashing."""
        # Check if process is still running
        assert server_process.poll() is None, "Server process should be running"
        # MCP servers typically write startup info to stderr; allow time for it.
        time.sleep(0.5)
        # The server should still be running
        assert server_process.poll() is None

    def test_server_handles_unicode_on_windows(self, server_process):
        """Reading server output must not raise UnicodeDecodeError.

        Regression check for the Windows Unicode decode error: the output is
        read through the UTF-8 / errors="replace" pipes configured by
        _create_server_process.
        """
        assert server_process.poll() is None, "Server process should be running"

        try:
            import select

            if sys.platform == "win32":
                # select() cannot be used on pipes on Windows, so just give the
                # process a moment; liveness is asserted below.
                time.sleep(0.1)
            else:
                # Non-blocking peek: read one character only if output is ready.
                ready, _, _ = select.select([server_process.stdout], [], [], 0.1)
                if ready:
                    try:
                        server_process.stdout.read(1)
                    except UnicodeDecodeError:
                        pytest.fail("Unicode decode error occurred")
        except Exception:
            # Other errors are out of scope: this test only checks Unicode
            # handling. (pytest.fail raises a BaseException subclass, so it
            # is not swallowed here.)
            pass

        # Main assertion: process should still be running
        assert server_process.poll() is None

    def test_server_responds_to_initialize_request(self, server_process):
        """Test that the server responds to MCP initialize request."""
        initialize_request = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {"name": "test-client", "version": "1.0.0"},
            },
        }

        # Send the request
        server_process.stdin.write(json.dumps(initialize_request) + "\n")
        server_process.stdin.flush()

        # Read the response
        response_line = server_process.stdout.readline()
        assert response_line.strip(), "Server should respond to initialize request"

        try:
            response = json.loads(response_line)
            assert response.get("jsonrpc") == "2.0"
            assert response.get("id") == 1
            assert "result" in response
        except json.JSONDecodeError:
            pytest.fail(f"Invalid JSON response: {response_line}")

    def test_server_lists_tools(self, server_process):
        """Test that the server lists the core Redis tools after initialization."""
        self._initialize_server(server_process)

        # Request tools list
        tools_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list"}
        server_process.stdin.write(json.dumps(tools_request) + "\n")
        server_process.stdin.flush()

        # Read the response
        response = json.loads(server_process.stdout.readline())
        assert response.get("jsonrpc") == "2.0"
        assert response.get("id") == 2
        assert "result" in response
        assert "tools" in response["result"]

        tool_names = [tool["name"] for tool in response["result"]["tools"]]

        # Should have basic Redis operations
        expected_tools = [
            "hset",
            "hget",
            "hdel",
            "hgetall",
            "hexists",
            "set_vector_in_hash",
            "get_vector_from_hash",
            "json_set",
            "json_get",
            "json_del",
            "lpush",
            "rpush",
            "lpop",
            "rpop",
            "lrange",
            "llen",
            "delete",
            "type",
            "expire",
            "rename",
            "scan_keys",
            "scan_all_keys",
            "publish",
            "subscribe",
            "unsubscribe",
            "get_indexes",
            "get_index_info",
            "get_indexed_keys_number",
            "create_vector_index_hash",
            "vector_search_hash",
            "dbsize",
            "info",
            "client_list",
            "sadd",
            "srem",
            "smembers",
            "zadd",
            "zrange",
            "zrem",
            "xadd",
            "xrange",
            "xdel",
            "set",
            "get",
        ]
        # Each expected tool must be registered. Iterating the other way
        # (over tool_names) would only detect unexpected extras and would
        # silently accept a missing core tool.
        for tool in expected_tools:
            assert tool in tool_names, (
                f"Expected tool '{tool}' not found in {tool_names}"
            )

    def test_server_tool_count_and_names(self, server_process):
        """Test that the server registers the correct number of tools with expected names."""
        self._initialize_server(server_process)

        # Request tools list
        tools_request = {"jsonrpc": "2.0", "id": 3, "method": "tools/list"}
        server_process.stdin.write(json.dumps(tools_request) + "\n")
        server_process.stdin.flush()

        # Read the response
        response = json.loads(server_process.stdout.readline())
        assert response.get("jsonrpc") == "2.0"
        assert response.get("id") == 3
        assert "result" in response
        assert "tools" in response["result"]

        tools = response["result"]["tools"]
        tool_names = [tool["name"] for tool in tools]

        # Expected tool count (based on @mcp.tool() decorators in codebase)
        expected_tool_count = 44
        assert len(tools) == expected_tool_count, (
            f"Expected {expected_tool_count} tools, but got {len(tools)}"
        )

        # Expected tool names (alphabetically sorted for easier verification)
        expected_tools = [
            "client_list",
            "create_vector_index_hash",
            "dbsize",
            "delete",
            "expire",
            "get",
            "get_index_info",
            "get_indexed_keys_number",
            "get_indexes",
            "get_vector_from_hash",
            "hdel",
            "hexists",
            "hget",
            "hgetall",
            "hset",
            "info",
            "json_del",
            "json_get",
            "json_set",
            "llen",
            "lpop",
            "lpush",
            "lrange",
            "publish",
            "rename",
            "rpop",
            "rpush",
            "sadd",
            "scan_all_keys",
            "scan_keys",
            "set",
            "set_vector_in_hash",
            "smembers",
            "srem",
            "subscribe",
            "type",
            "unsubscribe",
            "vector_search_hash",
            "xadd",
            "xdel",
            "xrange",
            "zadd",
            "zrange",
            "zrem",
        ]

        # Verify all expected tools are present, and nothing else is.
        missing_tools = set(expected_tools) - set(tool_names)
        extra_tools = set(tool_names) - set(expected_tools)
        assert not missing_tools, f"Missing expected tools: {sorted(missing_tools)}"
        assert not extra_tools, f"Found unexpected tools: {sorted(extra_tools)}"

        # Verify tool categories are represented
        tool_categories = {
            "string": ["get", "set"],
            "hash": ["hget", "hset", "hgetall", "hdel", "hexists"],
            "list": ["lpush", "rpush", "lpop", "rpop", "lrange", "llen"],
            "set": ["sadd", "srem", "smembers"],
            "sorted_set": ["zadd", "zrem", "zrange"],
            "stream": ["xadd", "xdel", "xrange"],
            "json": ["json_get", "json_set", "json_del"],
            "pub_sub": ["publish", "subscribe", "unsubscribe"],
            "server_mgmt": ["dbsize", "info", "client_list"],
            "misc": [
                "delete",
                "expire",
                "rename",
                "type",
                "scan_keys",
                "scan_all_keys",
            ],
            "vector_search": [
                "create_vector_index_hash",
                "vector_search_hash",
                "get_indexes",
                "get_index_info",
                "set_vector_in_hash",
                "get_vector_from_hash",
                "get_indexed_keys_number",
            ],
        }
        for category, category_tools in tool_categories.items():
            for tool in category_tools:
                assert tool in tool_names, (
                    f"Tool '{tool}' from category '{category}' not found in registered tools"
                )

    def _initialize_server(self, server_process):
        """Perform the MCP handshake on *server_process*.

        Sends `initialize`, consumes the single-line response without
        validating it, then sends the `notifications/initialized`
        notification.
        """
        initialize_request = {
            "jsonrpc": "2.0",
            "id": 1,
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {"name": "test-client", "version": "1.0.0"},
            },
        }
        server_process.stdin.write(json.dumps(initialize_request) + "\n")
        server_process.stdin.flush()
        server_process.stdout.readline()  # Read response

        # Send initialized notification
        initialized_notification = {
            "jsonrpc": "2.0",
            "method": "notifications/initialized",
        }
        server_process.stdin.write(json.dumps(initialized_notification) + "\n")
        server_process.stdin.flush()
```
--------------------------------------------------------------------------------
/tests/tools/test_hash.py:
--------------------------------------------------------------------------------
```python
"""
Unit tests for src/tools/hash.py
"""
import numpy as np
import pytest
from redis.exceptions import RedisError
from src.tools.hash import (
get_vector_from_hash,
hdel,
hexists,
hget,
hgetall,
hset,
set_vector_in_hash,
)
class TestHashOperations:
"""Test cases for Redis hash operations."""
@pytest.mark.asyncio
async def test_hset_success(self, mock_redis_connection_manager):
    """hset stores a string field and reports success."""
    client = mock_redis_connection_manager
    client.hset.return_value = 1

    message = await hset("test_hash", "field1", "value1")

    client.hset.assert_called_once_with("test_hash", "field1", "value1")
    assert "Field 'field1' set successfully in hash 'test_hash'." in message
@pytest.mark.asyncio
async def test_hset_with_expiration(self, mock_redis_connection_manager):
    """hset applies a TTL to the hash key when an expiration is given."""
    client = mock_redis_connection_manager
    client.hset.return_value = 1
    client.expire.return_value = True

    message = await hset("test_hash", "field1", "value1", 60)

    client.hset.assert_called_once_with("test_hash", "field1", "value1")
    client.expire.assert_called_once_with("test_hash", 60)
    assert "Expires in 60 seconds." in message
@pytest.mark.asyncio
async def test_hset_integer_value(self, mock_redis_connection_manager):
    """hset stringifies an int value before sending it to Redis."""
    client = mock_redis_connection_manager
    client.hset.return_value = 1

    message = await hset("test_hash", "count", 42)

    # The integer arrives at Redis as its string form.
    client.hset.assert_called_once_with("test_hash", "count", "42")
    assert "Field 'count' set successfully in hash 'test_hash'." in message
@pytest.mark.asyncio
async def test_hset_float_value(self, mock_redis_connection_manager):
"""Test hash set operation with float value."""
mock_redis = mock_redis_connection_manager
mock_redis.hset.return_value = 1
result = await hset("test_hash", "price", 19.99)
mock_redis.hset.assert_called_once_with("test_hash", "price", "19.99")
assert "Field 'price' set successfully in hash 'test_hash'." in result
@pytest.mark.asyncio
async def test_hset_redis_error(self, mock_redis_connection_manager):
"""Test hash set operation with Redis error."""
mock_redis = mock_redis_connection_manager
mock_redis.hset.side_effect = RedisError("Connection failed")
result = await hset("test_hash", "field1", "value1")
assert (
"Error setting field 'field1' in hash 'test_hash': Connection failed"
in result
)
@pytest.mark.asyncio
async def test_hget_success(self, mock_redis_connection_manager):
"""Test successful hash get operation."""
mock_redis = mock_redis_connection_manager
mock_redis.hget.return_value = "value1"
result = await hget("test_hash", "field1")
mock_redis.hget.assert_called_once_with("test_hash", "field1")
assert result == "value1"
@pytest.mark.asyncio
async def test_hget_field_not_found(self, mock_redis_connection_manager):
"""Test hash get operation when field doesn't exist."""
mock_redis = mock_redis_connection_manager
mock_redis.hget.return_value = None
result = await hget("test_hash", "nonexistent_field")
assert "Field 'nonexistent_field' not found in hash 'test_hash'" in result
@pytest.mark.asyncio
async def test_hget_redis_error(self, mock_redis_connection_manager):
"""Test hash get operation with Redis error."""
mock_redis = mock_redis_connection_manager
mock_redis.hget.side_effect = RedisError("Connection failed")
result = await hget("test_hash", "field1")
assert (
"Error getting field 'field1' from hash 'test_hash': Connection failed"
in result
)
@pytest.mark.asyncio
async def test_hgetall_success(self, mock_redis_connection_manager):
"""Test successful hash get all operation."""
mock_redis = mock_redis_connection_manager
mock_redis.hgetall.return_value = {"field1": "value1", "field2": "value2"}
result = await hgetall("test_hash")
mock_redis.hgetall.assert_called_once_with("test_hash")
assert result == {"field1": "value1", "field2": "value2"}
@pytest.mark.asyncio
async def test_hgetall_empty_hash(self, mock_redis_connection_manager):
"""Test hash get all operation on empty hash."""
mock_redis = mock_redis_connection_manager
mock_redis.hgetall.return_value = {}
result = await hgetall("empty_hash")
assert "Hash 'empty_hash' is empty or does not exist" in result
@pytest.mark.asyncio
async def test_hgetall_redis_error(self, mock_redis_connection_manager):
"""Test hash get all operation with Redis error."""
mock_redis = mock_redis_connection_manager
mock_redis.hgetall.side_effect = RedisError("Connection failed")
result = await hgetall("test_hash")
assert (
"Error getting all fields from hash 'test_hash': Connection failed"
in result
)
@pytest.mark.asyncio
async def test_hdel_success(self, mock_redis_connection_manager):
"""Test successful hash delete operation."""
mock_redis = mock_redis_connection_manager
mock_redis.hdel.return_value = 1
result = await hdel("test_hash", "field1")
mock_redis.hdel.assert_called_once_with("test_hash", "field1")
assert "Field 'field1' deleted from hash 'test_hash'." in result
@pytest.mark.asyncio
async def test_hdel_field_not_found(self, mock_redis_connection_manager):
"""Test hash delete operation when field doesn't exist."""
mock_redis = mock_redis_connection_manager
mock_redis.hdel.return_value = 0
result = await hdel("test_hash", "nonexistent_field")
assert "Field 'nonexistent_field' not found in hash 'test_hash'" in result
@pytest.mark.asyncio
async def test_hdel_redis_error(self, mock_redis_connection_manager):
"""Test hash delete operation with Redis error."""
mock_redis = mock_redis_connection_manager
mock_redis.hdel.side_effect = RedisError("Connection failed")
result = await hdel("test_hash", "field1")
assert (
"Error deleting field 'field1' from hash 'test_hash': Connection failed"
in result
)
@pytest.mark.asyncio
async def test_hexists_field_exists(self, mock_redis_connection_manager):
"""Test hash exists operation when field exists."""
mock_redis = mock_redis_connection_manager
mock_redis.hexists.return_value = True
result = await hexists("test_hash", "field1")
mock_redis.hexists.assert_called_once_with("test_hash", "field1")
assert result is True
@pytest.mark.asyncio
async def test_hexists_field_not_exists(self, mock_redis_connection_manager):
"""Test hash exists operation when field doesn't exist."""
mock_redis = mock_redis_connection_manager
mock_redis.hexists.return_value = False
result = await hexists("test_hash", "nonexistent_field")
assert result is False
@pytest.mark.asyncio
async def test_hexists_redis_error(self, mock_redis_connection_manager):
"""Test hash exists operation with Redis error."""
mock_redis = mock_redis_connection_manager
mock_redis.hexists.side_effect = RedisError("Connection failed")
result = await hexists("test_hash", "field1")
assert (
"Error checking existence of field 'field1' in hash 'test_hash': Connection failed"
in result
)
@pytest.mark.asyncio
async def test_set_vector_in_hash_success(
self, mock_redis_connection_manager, mock_numpy_array
):
"""Test successful vector set operation in hash."""
mock_redis = mock_redis_connection_manager
mock_redis.hset.return_value = 1
vector = [0.1, 0.2, 0.3, 0.4, 0.5]
result = await set_vector_in_hash("test_hash", vector)
mock_numpy_array.assert_called_once_with(vector, dtype=np.float32)
mock_redis.hset.assert_called_once_with(
"test_hash", "vector", b"mock_binary_data"
)
assert result is True
@pytest.mark.asyncio
async def test_set_vector_in_hash_custom_field(
self, mock_redis_connection_manager, mock_numpy_array
):
"""Test vector set operation with custom field name."""
mock_redis = mock_redis_connection_manager
mock_redis.hset.return_value = 1
vector = [0.1, 0.2, 0.3]
result = await set_vector_in_hash("test_hash", vector, "custom_vector")
mock_redis.hset.assert_called_once_with(
"test_hash", "custom_vector", b"mock_binary_data"
)
assert result is True
@pytest.mark.asyncio
async def test_set_vector_in_hash_redis_error(
self, mock_redis_connection_manager, mock_numpy_array
):
"""Test vector set operation with Redis error."""
mock_redis = mock_redis_connection_manager
mock_redis.hset.side_effect = RedisError("Connection failed")
vector = [0.1, 0.2, 0.3]
result = await set_vector_in_hash("test_hash", vector)
assert (
"Error storing vector in hash 'test_hash' with field 'vector': Connection failed"
in result
)
@pytest.mark.asyncio
async def test_get_vector_from_hash_success(
self, mock_redis_connection_manager, mock_numpy_frombuffer
):
"""Test successful vector get operation from hash."""
mock_redis = mock_redis_connection_manager
mock_redis.hget.return_value = b"mock_binary_data"
result = await get_vector_from_hash("test_hash")
mock_redis.hget.assert_called_once_with("test_hash", "vector")
mock_numpy_frombuffer.assert_called_once_with(
b"mock_binary_data", dtype=np.float32
)
assert result == [0.1, 0.2, 0.3]
@pytest.mark.asyncio
async def test_get_vector_from_hash_custom_field(
self, mock_redis_connection_manager, mock_numpy_frombuffer
):
"""Test vector get operation with custom field name."""
mock_redis = mock_redis_connection_manager
mock_redis.hget.return_value = b"mock_binary_data"
result = await get_vector_from_hash("test_hash", "custom_vector")
mock_redis.hget.assert_called_once_with("test_hash", "custom_vector")
assert result == [0.1, 0.2, 0.3]
@pytest.mark.asyncio
async def test_get_vector_from_hash_not_found(self, mock_redis_connection_manager):
"""Test vector get operation when field doesn't exist."""
mock_redis = mock_redis_connection_manager
mock_redis.hget.return_value = None
result = await get_vector_from_hash("test_hash")
assert "Field 'vector' not found in hash 'test_hash'." in result
@pytest.mark.asyncio
async def test_get_vector_from_hash_redis_error(
self, mock_redis_connection_manager
):
"""Test vector get operation with Redis error."""
mock_redis = mock_redis_connection_manager
mock_redis.hget.side_effect = RedisError("Connection failed")
result = await get_vector_from_hash("test_hash")
assert (
"Error retrieving vector field 'vector' from hash 'test_hash': Connection failed"
in result
)
@pytest.mark.asyncio
async def test_hset_expiration_error(self, mock_redis_connection_manager):
"""Test hash set operation when expiration fails."""
mock_redis = mock_redis_connection_manager
mock_redis.hset.return_value = 1
mock_redis.expire.side_effect = RedisError("Expire failed")
result = await hset("test_hash", "field1", "value1", 60)
# Should still report success for hset, but mention expire error
assert (
"Error setting field 'field1' in hash 'test_hash': Expire failed" in result
)
@pytest.mark.asyncio
async def test_vector_operations_with_empty_vector(
self, mock_redis_connection_manager, mock_numpy_array
):
"""Test vector operations with empty vector."""
mock_redis = mock_redis_connection_manager
mock_redis.hset.return_value = 1
empty_vector = []
result = await set_vector_in_hash("test_hash", empty_vector)
mock_numpy_array.assert_called_once_with(empty_vector, dtype=np.float32)
assert result is True
@pytest.mark.asyncio
async def test_vector_operations_with_large_vector(
self, mock_redis_connection_manager, mock_numpy_array
):
"""Test vector operations with large vector."""
mock_redis = mock_redis_connection_manager
mock_redis.hset.return_value = 1
large_vector = [0.1] * 1000 # 1000-dimensional vector
result = await set_vector_in_hash("test_hash", large_vector)
mock_numpy_array.assert_called_once_with(large_vector, dtype=np.float32)
assert result is True
```
--------------------------------------------------------------------------------
/tests/test_entraid_auth.py:
--------------------------------------------------------------------------------
```python
"""
Unit tests for Entra ID authentication in src/common/entraid_auth.py
"""
from unittest.mock import Mock, patch
import pytest
from src.common.entraid_auth import (
create_credential_provider,
EntraIDAuthenticationError,
_create_token_manager_config,
_create_service_principal_provider,
_create_managed_identity_provider,
_create_default_credential_provider,
)
class TestCreateCredentialProvider:
    """Test cases for create_credential_provider function.

    Tests that must get past the package-availability check patch
    ``ENTRAID_AVAILABLE`` to True. Their outcome then does not depend on
    whether the optional redis-entraid package happens to be installed in the
    test environment. This mirrors the explicit False patch used by
    ``test_raises_error_when_package_not_available``.
    """

    @patch("src.common.entraid_auth.is_entraid_auth_enabled")
    def test_returns_none_when_auth_disabled(self, mock_is_enabled):
        """Test that None is returned when Entra ID auth is disabled."""
        mock_is_enabled.return_value = False
        result = create_credential_provider()
        assert result is None
        mock_is_enabled.assert_called_once()

    @patch("src.common.entraid_auth.ENTRAID_AVAILABLE", False)
    @patch("src.common.entraid_auth.is_entraid_auth_enabled")
    def test_raises_error_when_package_not_available(self, mock_is_enabled):
        """Test that error is raised when redis-entraid package is not available."""
        mock_is_enabled.return_value = True
        with pytest.raises(EntraIDAuthenticationError) as exc_info:
            create_credential_provider()
        assert "redis-entraid package is required" in str(exc_info.value)

    @patch("src.common.entraid_auth.ENTRAID_AVAILABLE", True)
    @patch("src.common.entraid_auth.validate_entraid_config")
    @patch("src.common.entraid_auth.is_entraid_auth_enabled")
    def test_raises_error_on_invalid_config(self, mock_is_enabled, mock_validate):
        """Test that error is raised when configuration is invalid."""
        mock_is_enabled.return_value = True
        mock_validate.return_value = (False, "Invalid configuration")
        with pytest.raises(EntraIDAuthenticationError) as exc_info:
            create_credential_provider()
        assert "Invalid Entra ID configuration" in str(exc_info.value)

    @patch("src.common.entraid_auth.ENTRAID_AVAILABLE", True)
    @patch("src.common.entraid_auth._create_service_principal_provider")
    @patch("src.common.entraid_auth._create_token_manager_config")
    @patch.dict(
        "src.common.entraid_auth.ENTRAID_CFG", {"auth_flow": "service_principal"}
    )
    @patch("src.common.entraid_auth.validate_entraid_config")
    @patch("src.common.entraid_auth.is_entraid_auth_enabled")
    def test_creates_service_principal_provider(
        self,
        mock_is_enabled,
        mock_validate,
        mock_create_token_config,
        mock_create_sp_provider,
    ):
        """Test creating service principal credential provider."""
        mock_is_enabled.return_value = True
        mock_validate.return_value = (True, "")
        mock_token_config = Mock()
        mock_create_token_config.return_value = mock_token_config
        mock_provider = Mock()
        mock_create_sp_provider.return_value = mock_provider
        result = create_credential_provider()
        assert result == mock_provider
        mock_create_sp_provider.assert_called_once_with(mock_token_config)

    @patch("src.common.entraid_auth.ENTRAID_AVAILABLE", True)
    @patch("src.common.entraid_auth._create_managed_identity_provider")
    @patch("src.common.entraid_auth._create_token_manager_config")
    @patch.dict(
        "src.common.entraid_auth.ENTRAID_CFG", {"auth_flow": "managed_identity"}
    )
    @patch("src.common.entraid_auth.validate_entraid_config")
    @patch("src.common.entraid_auth.is_entraid_auth_enabled")
    def test_creates_managed_identity_provider(
        self,
        mock_is_enabled,
        mock_validate,
        mock_create_token_config,
        mock_create_mi_provider,
    ):
        """Test creating managed identity credential provider."""
        mock_is_enabled.return_value = True
        mock_validate.return_value = (True, "")
        mock_token_config = Mock()
        mock_create_token_config.return_value = mock_token_config
        mock_provider = Mock()
        mock_create_mi_provider.return_value = mock_provider
        result = create_credential_provider()
        assert result == mock_provider
        mock_create_mi_provider.assert_called_once_with(mock_token_config)

    @patch("src.common.entraid_auth.ENTRAID_AVAILABLE", True)
    @patch("src.common.entraid_auth._create_default_credential_provider")
    @patch("src.common.entraid_auth._create_token_manager_config")
    @patch.dict(
        "src.common.entraid_auth.ENTRAID_CFG", {"auth_flow": "default_credential"}
    )
    @patch("src.common.entraid_auth.validate_entraid_config")
    @patch("src.common.entraid_auth.is_entraid_auth_enabled")
    def test_creates_default_credential_provider(
        self,
        mock_is_enabled,
        mock_validate,
        mock_create_token_config,
        mock_create_dc_provider,
    ):
        """Test creating default credential provider."""
        mock_is_enabled.return_value = True
        mock_validate.return_value = (True, "")
        mock_token_config = Mock()
        mock_create_token_config.return_value = mock_token_config
        mock_provider = Mock()
        mock_create_dc_provider.return_value = mock_provider
        result = create_credential_provider()
        assert result == mock_provider
        mock_create_dc_provider.assert_called_once_with(mock_token_config)

    @patch("src.common.entraid_auth.ENTRAID_AVAILABLE", True)
    @patch("src.common.entraid_auth._create_token_manager_config")
    @patch.dict(
        "src.common.entraid_auth.ENTRAID_CFG", {"auth_flow": "unsupported_flow"}
    )
    @patch("src.common.entraid_auth.validate_entraid_config")
    @patch("src.common.entraid_auth.is_entraid_auth_enabled")
    def test_raises_error_on_unsupported_flow(
        self, mock_is_enabled, mock_validate, mock_create_token_config
    ):
        """Test that error is raised for unsupported auth flow."""
        mock_is_enabled.return_value = True
        mock_validate.return_value = (True, "")
        mock_create_token_config.return_value = Mock()
        with pytest.raises(EntraIDAuthenticationError) as exc_info:
            create_credential_provider()
        assert "Unsupported authentication flow" in str(exc_info.value)
class TestCreateTokenManagerConfig:
    """Tests for the _create_token_manager_config helper."""

    @patch("src.common.entraid_auth.ENTRAID_CFG")
    @patch("src.common.entraid_auth.TokenManagerConfig")
    @patch("src.common.entraid_auth.RetryPolicy")
    def test_creates_token_manager_config(
        self, mock_retry_policy_class, mock_token_config_class, mock_entraid_cfg
    ):
        """RetryPolicy and TokenManagerConfig are built from ENTRAID_CFG values."""
        cfg_values = {
            "retry_max_attempts": 3,
            "retry_delay_ms": 100,
            "token_expiration_refresh_ratio": 0.9,
            "lower_refresh_bound_millis": 30000,
            "token_request_execution_timeout_ms": 10000,
        }
        # Subscript lookups on the patched config read from the dict above;
        # unknown keys raise KeyError just like a real dict would.
        mock_entraid_cfg.__getitem__.side_effect = cfg_values.__getitem__

        retry_policy = Mock()
        mock_retry_policy_class.return_value = retry_policy
        built_config = Mock()
        mock_token_config_class.return_value = built_config

        outcome = _create_token_manager_config()

        # The retry policy receives the configured attempt count and delay.
        mock_retry_policy_class.assert_called_once_with(max_attempts=3, delay_in_ms=100)
        # The token manager config is wired to the refresh/timeout settings and
        # the retry policy created above.
        mock_token_config_class.assert_called_once_with(
            expiration_refresh_ratio=0.9,
            lower_refresh_bound_millis=30000,
            token_request_execution_timeout_in_ms=10000,
            retry_policy=retry_policy,
        )
        assert outcome == built_config
class TestCreateServicePrincipalProvider:
    """Tests for the _create_service_principal_provider helper."""

    @patch("src.common.entraid_auth.create_from_service_principal")
    @patch("src.common.entraid_auth.ENTRAID_CFG")
    def test_creates_service_principal_provider(self, mock_entraid_cfg, mock_create_sp):
        """Client id/secret/tenant from ENTRAID_CFG are forwarded to the factory."""
        cfg_values = {
            "client_id": "test-client-id",
            "client_secret": "test-secret",
            "tenant_id": "test-tenant-id",
        }
        mock_entraid_cfg.__getitem__.side_effect = cfg_values.__getitem__

        token_config = Mock()
        expected_provider = Mock()
        mock_create_sp.return_value = expected_provider

        outcome = _create_service_principal_provider(token_config)

        # Note the config key "client_secret" is passed as client_credential.
        mock_create_sp.assert_called_once_with(
            client_id="test-client-id",
            client_credential="test-secret",
            tenant_id="test-tenant-id",
            token_manager_config=token_config,
        )
        assert outcome == expected_provider
class TestCreateManagedIdentityProvider:
    """Tests for the _create_managed_identity_provider helper."""

    @patch("src.common.entraid_auth.create_from_managed_identity")
    @patch("src.common.entraid_auth.ManagedIdentityType")
    @patch("src.common.entraid_auth.ENTRAID_CFG")
    def test_creates_system_assigned_managed_identity_provider(
        self, mock_entraid_cfg, mock_identity_type_class, mock_create_mi
    ):
        """'system_assigned' maps to ManagedIdentityType.SYSTEM_ASSIGNED."""
        cfg_values = {
            "identity_type": "system_assigned",
            "resource": "https://redis.azure.com/",
        }
        mock_entraid_cfg.__getitem__.side_effect = cfg_values.__getitem__

        system_assigned = Mock()
        mock_identity_type_class.SYSTEM_ASSIGNED = system_assigned
        token_config = Mock()
        expected_provider = Mock()
        mock_create_mi.return_value = expected_provider

        outcome = _create_managed_identity_provider(token_config)

        # System-assigned identities are created without a client_id argument.
        mock_create_mi.assert_called_once_with(
            identity_type=system_assigned,
            resource="https://redis.azure.com/",
            token_manager_config=token_config,
        )
        assert outcome == expected_provider

    @patch("src.common.entraid_auth.create_from_managed_identity")
    @patch("src.common.entraid_auth.ManagedIdentityType")
    @patch("src.common.entraid_auth.ENTRAID_CFG")
    def test_creates_user_assigned_managed_identity_provider(
        self, mock_entraid_cfg, mock_identity_type_class, mock_create_mi
    ):
        """'user_assigned' maps to USER_ASSIGNED and forwards the client id."""
        cfg_values = {
            "identity_type": "user_assigned",
            "resource": "https://redis.azure.com/",
            "user_assigned_identity_client_id": "test-user-assigned-id",
        }
        mock_entraid_cfg.__getitem__.side_effect = cfg_values.__getitem__

        user_assigned = Mock()
        mock_identity_type_class.USER_ASSIGNED = user_assigned
        token_config = Mock()
        expected_provider = Mock()
        mock_create_mi.return_value = expected_provider

        outcome = _create_managed_identity_provider(token_config)

        mock_create_mi.assert_called_once_with(
            identity_type=user_assigned,
            resource="https://redis.azure.com/",
            client_id="test-user-assigned-id",
            token_manager_config=token_config,
        )
        assert outcome == expected_provider

    @patch("src.common.entraid_auth.ENTRAID_CFG")
    def test_raises_error_on_invalid_identity_type(self, mock_entraid_cfg):
        """An unrecognised identity_type raises EntraIDAuthenticationError."""
        cfg_values = {"identity_type": "invalid_type"}
        mock_entraid_cfg.__getitem__.side_effect = cfg_values.__getitem__

        with pytest.raises(EntraIDAuthenticationError) as raised:
            _create_managed_identity_provider(Mock())

        assert "Invalid identity type" in str(raised.value)
class TestCreateDefaultCredentialProvider:
    """Tests for the _create_default_credential_provider helper.

    The ``scopes`` config value is a comma-separated string. These tests check
    that it is split, trimmed, and passed to the factory as a tuple.
    """

    def _build_provider(self, mock_entraid_cfg, mock_create_dc, scopes_setting):
        """Set ENTRAID_CFG['scopes'], run the helper, return (token_config, provider, result)."""
        mock_entraid_cfg.__getitem__.side_effect = {"scopes": scopes_setting}.__getitem__
        token_config = Mock()
        expected_provider = Mock()
        mock_create_dc.return_value = expected_provider
        outcome = _create_default_credential_provider(token_config)
        return token_config, expected_provider, outcome

    @patch("src.common.entraid_auth.create_from_default_azure_credential")
    @patch("src.common.entraid_auth.ENTRAID_CFG")
    def test_creates_default_credential_provider_single_scope(
        self, mock_entraid_cfg, mock_create_dc
    ):
        """A single scope becomes a one-element tuple."""
        token_config, expected_provider, outcome = self._build_provider(
            mock_entraid_cfg, mock_create_dc, "https://redis.azure.com/.default"
        )
        mock_create_dc.assert_called_once_with(
            scopes=("https://redis.azure.com/.default",),
            token_manager_config=token_config,
        )
        assert outcome == expected_provider

    @patch("src.common.entraid_auth.create_from_default_azure_credential")
    @patch("src.common.entraid_auth.ENTRAID_CFG")
    def test_creates_default_credential_provider_multiple_scopes(
        self, mock_entraid_cfg, mock_create_dc
    ):
        """Comma-separated scopes are split into a tuple in order."""
        token_config, expected_provider, outcome = self._build_provider(
            mock_entraid_cfg,
            mock_create_dc,
            "https://redis.azure.com/.default, https://other.scope/.default",
        )
        mock_create_dc.assert_called_once_with(
            scopes=(
                "https://redis.azure.com/.default",
                "https://other.scope/.default",
            ),
            token_manager_config=token_config,
        )
        assert outcome == expected_provider

    @patch("src.common.entraid_auth.create_from_default_azure_credential")
    @patch("src.common.entraid_auth.ENTRAID_CFG")
    def test_creates_default_credential_provider_with_whitespace(
        self, mock_entraid_cfg, mock_create_dc
    ):
        """Leading/trailing whitespace around each scope is stripped."""
        token_config, expected_provider, outcome = self._build_provider(
            mock_entraid_cfg,
            mock_create_dc,
            " https://redis.azure.com/.default , https://other.scope/.default ",
        )
        mock_create_dc.assert_called_once_with(
            scopes=(
                "https://redis.azure.com/.default",
                "https://other.scope/.default",
            ),
            token_manager_config=token_config,
        )
        assert outcome == expected_provider
class TestEntraIDAuthenticationError:
    """Tests for the EntraIDAuthenticationError exception type."""

    def test_exception_can_be_raised(self):
        """Raising the exception is caught by pytest.raises for that type."""
        with pytest.raises(EntraIDAuthenticationError):
            raise EntraIDAuthenticationError("Test error message")

    def test_exception_message(self):
        """str() of the exception returns the message it was constructed with."""
        message = "Test error message"
        with pytest.raises(EntraIDAuthenticationError) as caught:
            raise EntraIDAuthenticationError(message)
        assert message == str(caught.value)

    def test_exception_is_exception_subclass(self):
        """EntraIDAuthenticationError derives from the built-in Exception."""
        assert issubclass(EntraIDAuthenticationError, Exception)
```