This is page 2 of 2. Use http://codebase.md/redis/mcp-redis?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .dockerignore
├── .env.example
├── .github
│   ├── dependabot.yml
│   └── workflows
│       ├── ci.yml
│       ├── release.yml
│       └── stale-issues.yml
├── .gitignore
├── Dockerfile
├── examples
│   └── redis_assistant.py
├── fly.toml
├── gemini-extension.json
├── GEMINI.md
├── LICENSE
├── pyproject.toml
├── README.md
├── server.json
├── src
│   ├── __init__.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── connection.py
│   │   ├── entraid_auth.py
│   │   ├── logging_utils.py
│   │   └── server.py
│   ├── main.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── hash.py
│   │   ├── json.py
│   │   ├── list.py
│   │   ├── misc.py
│   │   ├── pub_sub.py
│   │   ├── redis_query_engine.py
│   │   ├── server_management.py
│   │   ├── set.py
│   │   ├── sorted_set.py
│   │   ├── stream.py
│   │   └── string.py
│   └── version.py
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── test_config.py
│   ├── test_connection.py
│   ├── test_entraid_auth.py
│   ├── test_integration.py
│   ├── test_logging_utils.py
│   ├── test_main.py
│   ├── test_server.py
│   └── tools
│       ├── __init__.py
│       ├── test_hash.py
│       ├── test_json.py
│       ├── test_list.py
│       ├── test_pub_sub.py
│       ├── test_redis_query_engine.py
│       ├── test_server_management.py
│       ├── test_set.py
│       ├── test_sorted_set.py
│       ├── test_stream.py
│       └── test_string.py
└── uv.lock
```
# Files
--------------------------------------------------------------------------------
/tests/test_config.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Unit tests for src/common/config.py
3 | """
4 |
5 | import os
6 | from unittest.mock import patch
7 |
8 | import pytest
9 |
10 | from src.common.config import REDIS_CFG, parse_redis_uri, set_redis_config_from_cli
11 |
12 |
class TestParseRedisURI:
    """Tests for ``parse_redis_uri``: schemes, credentials, query options and db fallbacks."""

    def test_parse_basic_redis_uri(self):
        """A plain ``redis://`` URI parses into exactly ssl/host/port/db."""
        uri = "redis://localhost:6379/0"
        result = parse_redis_uri(uri)

        expected = {"ssl": False, "host": "localhost", "port": 6379, "db": 0}
        assert result == expected

    def test_parse_redis_uri_with_auth(self):
        """``user:pass@`` in the URI surfaces as ``username`` and ``password`` keys."""
        uri = "redis://user:pass@localhost:6379/1"
        result = parse_redis_uri(uri)

        expected = {
            "ssl": False,
            "host": "localhost",
            "port": 6379,
            "db": 1,
            "username": "user",
            "password": "pass",
        }
        assert result == expected

    def test_parse_rediss_uri(self):
        """The ``rediss://`` scheme turns ``ssl`` on; the other fields parse as usual."""
        # The URI must spell out the credentials and host that the expected
        # dict below asserts (user / pass / redis.example.com).
        uri = "rediss://user:pass@redis.example.com:6380/2"
        result = parse_redis_uri(uri)

        expected = {
            "ssl": True,
            "host": "redis.example.com",
            "port": 6380,
            "db": 2,
            "username": "user",
            "password": "pass",
        }
        assert result == expected

    def test_parse_uri_with_query_parameters(self):
        """SSL-related query parameters are copied into the result by name."""
        uri = "redis://localhost:6379/0?ssl_cert_reqs=optional&ssl_ca_certs=/path/to/ca.pem"
        result = parse_redis_uri(uri)

        assert result["ssl"] is False
        assert result["host"] == "localhost"
        assert result["port"] == 6379
        assert result["db"] == 0
        assert result["ssl_cert_reqs"] == "optional"
        assert result["ssl_ca_certs"] == "/path/to/ca.pem"

    def test_parse_uri_with_db_in_query(self):
        """A ``db`` query parameter selects the database when the path has none."""
        uri = "redis://localhost:6379?db=5"
        result = parse_redis_uri(uri)

        assert result["db"] == 5

    def test_parse_uri_with_ssl_parameters(self):
        """Key file, cert file and CA path query parameters are all passed through."""
        uri = "rediss://localhost:6379/0?ssl_keyfile=/key.pem&ssl_certfile=/cert.pem&ssl_ca_path=/ca.pem"
        result = parse_redis_uri(uri)

        assert result["ssl"] is True
        assert result["ssl_keyfile"] == "/key.pem"
        assert result["ssl_certfile"] == "/cert.pem"
        assert result["ssl_ca_path"] == "/ca.pem"

    def test_parse_uri_defaults(self):
        """A host-only URI falls back to port 6379 and database 0."""
        uri = "redis://example.com"
        result = parse_redis_uri(uri)

        assert result["host"] == "example.com"
        assert result["port"] == 6379  # Default port
        assert result["db"] == 0  # Default database

    def test_parse_uri_no_path(self):
        """Without a path component the database is 0."""
        uri = "redis://localhost:6379"
        result = parse_redis_uri(uri)

        assert result["db"] == 0

    def test_parse_uri_root_path(self):
        """A bare ``/`` path is treated like no path: database 0."""
        uri = "redis://localhost:6379/"
        result = parse_redis_uri(uri)

        assert result["db"] == 0

    def test_parse_uri_invalid_db_in_path(self):
        """A non-numeric path segment falls back to database 0."""
        uri = "redis://localhost:6379/invalid"
        result = parse_redis_uri(uri)

        assert result["db"] == 0  # Should default to 0

    def test_parse_uri_invalid_db_in_query(self):
        """A non-numeric ``db`` query value is tolerated rather than raising."""
        uri = "redis://localhost:6379?db=invalid"
        result = parse_redis_uri(uri)

        # Should not have db key or should be handled gracefully
        assert "db" not in result or result["db"] == 0

    def test_parse_uri_unsupported_scheme(self):
        """Any scheme other than ``redis``/``rediss`` raises ``ValueError``."""
        uri = "http://localhost:6379/0"

        with pytest.raises(ValueError, match="Unsupported scheme: http"):
            parse_redis_uri(uri)
127 |
128 |
class TestSetRedisConfigFromCLI:
    """Tests for ``set_redis_config_from_cli`` and its effect on the shared ``REDIS_CFG``."""

    def setup_method(self):
        """Snapshot ``REDIS_CFG`` so each test can mutate it freely."""
        self._cfg_snapshot = REDIS_CFG.copy()

    def teardown_method(self):
        """Put ``REDIS_CFG`` back the way it was before the test ran."""
        REDIS_CFG.clear()
        REDIS_CFG.update(self._cfg_snapshot)

    def test_set_string_values(self):
        """String options are stored unchanged."""
        set_redis_config_from_cli(
            {
                "host": "redis.example.com",
                "username": "testuser",
                "password": "testpass",
            }
        )

        assert REDIS_CFG["host"] == "redis.example.com"
        assert REDIS_CFG["username"] == "testuser"
        assert REDIS_CFG["password"] == "testpass"

    def test_set_integer_values(self):
        """Integer options keep both their value and their ``int`` type."""
        set_redis_config_from_cli({"port": 6380, "db": 2})

        port, db = REDIS_CFG["port"], REDIS_CFG["db"]
        assert port == 6380
        assert isinstance(port, int)
        assert db == 2
        assert isinstance(db, int)

    def test_set_boolean_values(self):
        """``ssl`` and ``cluster_mode`` are stored as real booleans."""
        set_redis_config_from_cli({"ssl": True, "cluster_mode": False})

        assert REDIS_CFG["ssl"] is True
        assert isinstance(REDIS_CFG["ssl"], bool)
        assert REDIS_CFG["cluster_mode"] is False
        assert isinstance(REDIS_CFG["cluster_mode"], bool)

    def test_set_none_values(self):
        """``None`` values are written through as ``None``."""
        set_redis_config_from_cli({"ssl_ca_path": None, "ssl_keyfile": None})

        assert REDIS_CFG["ssl_ca_path"] is None
        assert REDIS_CFG["ssl_keyfile"] is None

    def test_set_mixed_values(self):
        """A single call may mix strings, ints, booleans and ``None``."""
        cli_options = {
            "host": "localhost",
            "port": 6379,
            "ssl": True,
            "ssl_ca_path": "/path/to/ca.pem",
            "cluster_mode": False,
            "username": None,
        }

        set_redis_config_from_cli(cli_options)

        assert REDIS_CFG["host"] == "localhost"
        assert REDIS_CFG["port"] == 6379
        assert REDIS_CFG["ssl"] is True
        assert REDIS_CFG["ssl_ca_path"] == "/path/to/ca.pem"
        assert REDIS_CFG["cluster_mode"] is False
        assert REDIS_CFG["username"] is None

    def test_convert_string_integers(self):
        """Numeric strings for ``port`` and ``db`` end up as ``int``."""
        set_redis_config_from_cli({"port": "6380", "db": "1"})

        port, db = REDIS_CFG["port"], REDIS_CFG["db"]
        assert port == 6380
        assert isinstance(port, int)
        assert db == 1
        assert isinstance(db, int)

    def test_convert_other_booleans_to_strings(self):
        """Booleans other than ``ssl``/``cluster_mode`` are stored as lower-case strings."""
        # Any boolean key besides ssl/cluster_mode is expected to be
        # stringified for environment compatibility.
        set_redis_config_from_cli({"some_other_bool": True})

        assert REDIS_CFG["some_other_bool"] == "true"

    def test_empty_config(self):
        """An empty mapping leaves ``REDIS_CFG`` untouched."""
        before = REDIS_CFG.copy()

        set_redis_config_from_cli({})

        # Config should remain unchanged
        assert REDIS_CFG == before
238 |
239 |
@patch.dict(os.environ, {}, clear=True)
class TestRedisConfigDefaults:
    """Tests for the values ``REDIS_CFG`` gets when ``src.common.config`` is (re)imported."""

    @staticmethod
    def _reload_redis_cfg():
        """Reload ``src.common.config`` and return its freshly built ``REDIS_CFG``."""
        import importlib

        import src.common.config

        # Reloading re-runs the module body, so REDIS_CFG is rebuilt from the
        # environment that is patched in for the calling test.
        importlib.reload(src.common.config)
        return src.common.config.REDIS_CFG

    @patch("src.common.config.load_dotenv")
    def test_default_config_values(self, mock_load_dotenv):
        """With an empty environment every option has its built-in default."""
        cfg = self._reload_redis_cfg()

        assert cfg["host"] == "127.0.0.1"
        assert cfg["port"] == 6379
        assert cfg["username"] is None
        assert cfg["password"] == ""
        assert cfg["ssl"] is False
        assert cfg["cluster_mode"] is False
        assert cfg["db"] == 0

    @patch.dict(
        os.environ,
        {
            "REDIS_HOST": "redis.example.com",
            "REDIS_PORT": "6380",
            "REDIS_SSL": "true",
            "REDIS_CLUSTER_MODE": "1",
        },
    )
    @patch("src.common.config.load_dotenv")
    def test_config_from_environment(self, mock_load_dotenv):
        """``REDIS_*`` environment variables override the defaults, with type coercion."""
        cfg = self._reload_redis_cfg()

        assert cfg["host"] == "redis.example.com"
        assert cfg["port"] == 6380
        assert cfg["ssl"] is True
        assert cfg["cluster_mode"] is True
289 |
```
--------------------------------------------------------------------------------
/tests/tools/test_set.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Unit tests for src/tools/set.py
3 | """
4 |
5 | from unittest.mock import Mock, patch
6 |
7 | import pytest
8 | from redis.exceptions import RedisError
9 |
10 | from src.tools.set import sadd, smembers, srem
11 |
12 |
class TestSetOperations:
    """Tests for the set tools ``sadd``, ``srem`` and ``smembers`` against a mocked connection."""

    @pytest.mark.asyncio
    async def test_sadd_success(self, mock_redis_connection_manager):
        """Adding a new member reports success and forwards name/value to Redis."""
        mock_redis_connection_manager.sadd.return_value = 1  # Number of elements added

        message = await sadd("test_set", "member1")

        mock_redis_connection_manager.sadd.assert_called_once_with("test_set", "member1")
        assert "Value 'member1' added successfully to set 'test_set'" in message

    @pytest.mark.asyncio
    async def test_sadd_with_expiration(self, mock_redis_connection_manager):
        """Passing ``expire_seconds`` also sets a TTL and mentions it in the reply."""
        mock_redis_connection_manager.sadd.return_value = 1
        mock_redis_connection_manager.expire.return_value = True

        message = await sadd("test_set", "member1", 60)

        mock_redis_connection_manager.sadd.assert_called_once_with("test_set", "member1")
        mock_redis_connection_manager.expire.assert_called_once_with("test_set", 60)
        assert "Expires in 60 seconds" in message

    @pytest.mark.asyncio
    async def test_sadd_member_already_exists(self, mock_redis_connection_manager):
        """A zero return from Redis (member already present) still reports success."""
        mock_redis_connection_manager.sadd.return_value = 0  # Member already exists

        message = await sadd("test_set", "existing_member")

        assert "Value 'existing_member' added successfully to set 'test_set'" in message

    @pytest.mark.asyncio
    async def test_sadd_redis_error(self, mock_redis_connection_manager):
        """A ``RedisError`` from ``sadd`` is turned into an error message."""
        mock_redis_connection_manager.sadd.side_effect = RedisError("Connection failed")

        message = await sadd("test_set", "member1")

        expected = "Error adding value 'member1' to set 'test_set': Connection failed"
        assert expected in message

    @pytest.mark.asyncio
    async def test_sadd_numeric_member(self, mock_redis_connection_manager):
        """Numeric members are forwarded as-is and rendered in the message."""
        mock_redis_connection_manager.sadd.return_value = 1

        message = await sadd("test_set", 42)

        mock_redis_connection_manager.sadd.assert_called_once_with("test_set", 42)
        assert "Value '42' added successfully to set 'test_set'" in message

    @pytest.mark.asyncio
    async def test_srem_success(self, mock_redis_connection_manager):
        """Removing an existing member reports the removal."""
        mock_redis_connection_manager.srem.return_value = 1  # Number of elements removed

        message = await srem("test_set", "member1")

        mock_redis_connection_manager.srem.assert_called_once_with("test_set", "member1")
        assert "Value 'member1' removed from set 'test_set'" in message

    @pytest.mark.asyncio
    async def test_srem_member_not_exists(self, mock_redis_connection_manager):
        """A zero return from Redis is reported as 'not found'."""
        mock_redis_connection_manager.srem.return_value = 0  # Member doesn't exist

        message = await srem("test_set", "nonexistent_member")

        assert "Value 'nonexistent_member' not found in set 'test_set'" in message

    @pytest.mark.asyncio
    async def test_srem_redis_error(self, mock_redis_connection_manager):
        """A ``RedisError`` from ``srem`` is turned into an error message."""
        mock_redis_connection_manager.srem.side_effect = RedisError("Connection failed")

        message = await srem("test_set", "member1")

        expected = "Error removing value 'member1' from set 'test_set': Connection failed"
        assert expected in message

    @pytest.mark.asyncio
    async def test_srem_numeric_member(self, mock_redis_connection_manager):
        """Numeric members can be removed and are rendered in the message."""
        mock_redis_connection_manager.srem.return_value = 1

        message = await srem("test_set", 42)

        mock_redis_connection_manager.srem.assert_called_once_with("test_set", 42)
        assert "Value '42' removed from set 'test_set'" in message

    @pytest.mark.asyncio
    async def test_smembers_success(self, mock_redis_connection_manager):
        """All members returned by Redis come back to the caller."""
        mock_redis_connection_manager.smembers.return_value = {
            "member1",
            "member2",
            "member3",
        }

        members = await smembers("test_set")

        mock_redis_connection_manager.smembers.assert_called_once_with("test_set")
        assert set(members) == {"member1", "member2", "member3"}

    @pytest.mark.asyncio
    async def test_smembers_empty_set(self, mock_redis_connection_manager):
        """An empty result is reported as an empty-or-missing set."""
        mock_redis_connection_manager.smembers.return_value = set()

        message = await smembers("empty_set")

        assert "Set 'empty_set' is empty or does not exist" in message

    @pytest.mark.asyncio
    async def test_smembers_redis_error(self, mock_redis_connection_manager):
        """A ``RedisError`` from ``smembers`` is turned into an error message."""
        mock_redis_connection_manager.smembers.side_effect = RedisError(
            "Connection failed"
        )

        message = await smembers("test_set")

        assert "Error retrieving members of set 'test_set': Connection failed" in message

    @pytest.mark.asyncio
    async def test_smembers_single_member(self, mock_redis_connection_manager):
        """A one-element set comes back as a one-element list."""
        mock_redis_connection_manager.smembers.return_value = {"single_member"}

        members = await smembers("test_set")

        assert members == ["single_member"]

    @pytest.mark.asyncio
    async def test_smembers_numeric_members(self, mock_redis_connection_manager):
        """Numeric-looking string members are returned unchanged."""
        mock_redis_connection_manager.smembers.return_value = {"1", "2", "3", "42"}

        members = await smembers("numeric_set")

        assert set(members) == {"1", "2", "3", "42"}

    @pytest.mark.asyncio
    async def test_sadd_expiration_error(self, mock_redis_connection_manager):
        """A failure in ``expire`` after a successful add is reported as an add error."""
        mock_redis_connection_manager.sadd.return_value = 1
        mock_redis_connection_manager.expire.side_effect = RedisError("Expire failed")

        message = await sadd("test_set", "member1", 60)

        assert "Error adding value 'member1' to set 'test_set': Expire failed" in message

    @pytest.mark.asyncio
    async def test_sadd_with_special_characters(self, mock_redis_connection_manager):
        """Members containing colons are passed through untouched."""
        mock_redis_connection_manager.sadd.return_value = 1
        special_member = "member:with:colons"

        message = await sadd("test_set", special_member)

        mock_redis_connection_manager.sadd.assert_called_once_with(
            "test_set", special_member
        )
        expected = f"Value '{special_member}' added successfully to set 'test_set'"
        assert expected in message

    @pytest.mark.asyncio
    async def test_sadd_with_unicode_member(self, mock_redis_connection_manager):
        """Non-ASCII members (CJK text, emoji) are passed through untouched."""
        mock_redis_connection_manager.sadd.return_value = 1
        unicode_member = "测试成员 🚀"

        message = await sadd("test_set", unicode_member)

        mock_redis_connection_manager.sadd.assert_called_once_with(
            "test_set", unicode_member
        )
        expected = f"Value '{unicode_member}' added successfully to set 'test_set'"
        assert expected in message

    @pytest.mark.asyncio
    async def test_smembers_large_set(self, mock_redis_connection_manager):
        """A 1000-member set is returned in full, as a list."""
        mock_redis_connection_manager.smembers.return_value = {
            f"member_{i}" for i in range(1000)
        }

        members = await smembers("large_set")

        # smembers returns a list, not a set
        assert isinstance(members, list)
        assert len(members) == 1000

    @pytest.mark.asyncio
    async def test_srem_multiple_members_behavior(self, mock_redis_connection_manager):
        """``srem`` forwards exactly one member per call."""
        mock_redis_connection_manager.srem.return_value = 1

        message = await srem("test_set", "single_member")

        # Should call srem with single member, not multiple members
        mock_redis_connection_manager.srem.assert_called_once_with(
            "test_set", "single_member"
        )
        assert "Value 'single_member' removed from set 'test_set'" in message

    @pytest.mark.asyncio
    async def test_connection_manager_called_correctly(self):
        """``sadd`` obtains its connection through ``RedisConnectionManager.get_connection``."""
        target = "src.tools.set.RedisConnectionManager.get_connection"
        with patch(target) as mock_get_conn:
            fake_redis = Mock()
            fake_redis.sadd.return_value = 1
            mock_get_conn.return_value = fake_redis

            await sadd("test_set", "member1")

            mock_get_conn.assert_called_once()

    @pytest.mark.asyncio
    async def test_function_signatures(self):
        """The three tools expose the expected parameter names and defaults."""
        import inspect

        # sadd: (name, value, expire_seconds=None)
        sadd_parameters = inspect.signature(sadd).parameters
        assert list(sadd_parameters.keys()) == ["name", "value", "expire_seconds"]
        assert sadd_parameters["expire_seconds"].default is None

        # srem: (name, value)
        srem_parameters = inspect.signature(srem).parameters
        assert list(srem_parameters.keys()) == ["name", "value"]

        # smembers: (name)
        smembers_parameters = inspect.signature(smembers).parameters
        assert list(smembers_parameters.keys()) == ["name"]
268 |
```
--------------------------------------------------------------------------------
/tests/tools/test_server_management.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Unit tests for src/tools/server_management.py
3 | """
4 |
5 | import pytest
6 | from redis.exceptions import ConnectionError, RedisError
7 |
8 | from src.tools.server_management import client_list, dbsize, info
9 |
10 |
class TestServerManagementOperations:
    """Tests for the server tools ``dbsize``, ``info`` and ``client_list`` against a mocked connection."""

    @pytest.mark.asyncio
    async def test_dbsize_success(self, mock_redis_connection_manager):
        """The key count reported by Redis is returned unchanged."""
        mock_redis_connection_manager.dbsize.return_value = 1000

        size = await dbsize()

        mock_redis_connection_manager.dbsize.assert_called_once()
        assert size == 1000

    @pytest.mark.asyncio
    async def test_dbsize_zero_keys(self, mock_redis_connection_manager):
        """An empty database yields 0."""
        mock_redis_connection_manager.dbsize.return_value = 0

        size = await dbsize()

        assert size == 0

    @pytest.mark.asyncio
    async def test_dbsize_redis_error(self, mock_redis_connection_manager):
        """A ``RedisError`` from ``dbsize`` is turned into an error message."""
        mock_redis_connection_manager.dbsize.side_effect = RedisError("Connection failed")

        message = await dbsize()

        assert "Error getting database size: Connection failed" in message

    @pytest.mark.asyncio
    async def test_info_success_default_section(self, mock_redis_connection_manager):
        """Calling ``info()`` with no argument asks Redis for the ``default`` section."""
        default_info = {
            "redis_version": "7.0.0",
            "used_memory": "1024000",
            "connected_clients": "5",
            "total_commands_processed": "1000",
        }
        mock_redis_connection_manager.info.return_value = default_info

        reply = await info()

        mock_redis_connection_manager.info.assert_called_once_with("default")
        assert reply == default_info

    @pytest.mark.asyncio
    async def test_info_success_specific_section(self, mock_redis_connection_manager):
        """A named section is forwarded to Redis and its payload returned."""
        memory_info = {
            "used_memory": "2048000",
            "used_memory_human": "2.00M",
            "used_memory_peak": "3072000",
            "used_memory_peak_human": "3.00M",
        }
        mock_redis_connection_manager.info.return_value = memory_info

        reply = await info("memory")

        mock_redis_connection_manager.info.assert_called_once_with("memory")
        assert reply == memory_info

    @pytest.mark.asyncio
    async def test_info_all_sections(self, mock_redis_connection_manager):
        """The ``all`` section name is forwarded like any other."""
        all_info = {
            "redis_version": "7.0.0",
            "used_memory": "1024000",
            "connected_clients": "5",
            "keyspace_hits": "500",
            "keyspace_misses": "100",
        }
        mock_redis_connection_manager.info.return_value = all_info

        reply = await info("all")

        mock_redis_connection_manager.info.assert_called_once_with("all")
        assert reply == all_info

    @pytest.mark.asyncio
    async def test_info_redis_error(self, mock_redis_connection_manager):
        """A ``RedisError`` from ``info`` is turned into an error message."""
        mock_redis_connection_manager.info.side_effect = RedisError("Connection failed")

        message = await info("server")

        assert "Error retrieving Redis info: Connection failed" in message

    @pytest.mark.asyncio
    async def test_info_invalid_section(self, mock_redis_connection_manager):
        """An error raised for an unknown section is reported in the message."""
        mock_redis_connection_manager.info.side_effect = RedisError("Unknown section")

        message = await info("invalid_section")

        assert "Error retrieving Redis info: Unknown section" in message

    @pytest.mark.asyncio
    async def test_client_list_success(self, mock_redis_connection_manager):
        """The client records returned by Redis are handed back as-is."""
        first_client = {
            "id": "1",
            "addr": "127.0.0.1:12345",
            "name": "client1",
            "age": "100",
            "idle": "0",
            "flags": "N",
            "db": "0",
            "sub": "0",
            "psub": "0",
            "multi": "-1",
            "qbuf": "0",
            "qbuf-free": "32768",
            "obl": "0",
            "oll": "0",
            "omem": "0",
            "events": "r",
            "cmd": "client",
        }
        second_client = {
            "id": "2",
            "addr": "127.0.0.1:12346",
            "name": "client2",
            "age": "200",
            "idle": "5",
            "flags": "N",
            "db": "1",
            "sub": "0",
            "psub": "0",
            "multi": "-1",
            "qbuf": "0",
            "qbuf-free": "32768",
            "obl": "0",
            "oll": "0",
            "omem": "0",
            "events": "r",
            "cmd": "get",
        }
        clients = [first_client, second_client]
        mock_redis_connection_manager.client_list.return_value = clients

        reply = await client_list()

        mock_redis_connection_manager.client_list.assert_called_once()
        assert reply == clients
        assert len(reply) == 2
        assert reply[0]["id"] == "1"
        assert reply[1]["id"] == "2"

    @pytest.mark.asyncio
    async def test_client_list_empty(self, mock_redis_connection_manager):
        """No connected clients yields an empty list."""
        mock_redis_connection_manager.client_list.return_value = []

        reply = await client_list()

        assert reply == []

    @pytest.mark.asyncio
    async def test_client_list_redis_error(self, mock_redis_connection_manager):
        """A ``RedisError`` from ``client_list`` is turned into an error message."""
        mock_redis_connection_manager.client_list.side_effect = RedisError(
            "Connection failed"
        )

        message = await client_list()

        assert "Error retrieving client list: Connection failed" in message

    @pytest.mark.asyncio
    async def test_client_list_connection_error(self, mock_redis_connection_manager):
        """A ``ConnectionError`` is reported the same way as other Redis errors."""
        mock_redis_connection_manager.client_list.side_effect = ConnectionError(
            "Redis server unavailable"
        )

        message = await client_list()

        assert "Error retrieving client list: Redis server unavailable" in message

    @pytest.mark.asyncio
    async def test_info_stats_section(self, mock_redis_connection_manager):
        """The ``stats`` section payload is returned intact, including keyspace counters."""
        stats_info = {
            "total_connections_received": "1000",
            "total_commands_processed": "5000",
            "instantaneous_ops_per_sec": "10",
            "total_net_input_bytes": "1024000",
            "total_net_output_bytes": "2048000",
            "instantaneous_input_kbps": "1.5",
            "instantaneous_output_kbps": "3.0",
            "rejected_connections": "0",
            "sync_full": "0",
            "sync_partial_ok": "0",
            "sync_partial_err": "0",
            "expired_keys": "100",
            "evicted_keys": "0",
            "keyspace_hits": "4000",
            "keyspace_misses": "1000",
            "pubsub_channels": "0",
            "pubsub_patterns": "0",
            "latest_fork_usec": "0",
        }
        mock_redis_connection_manager.info.return_value = stats_info

        reply = await info("stats")

        mock_redis_connection_manager.info.assert_called_once_with("stats")
        assert reply == stats_info
        assert "keyspace_hits" in reply
        assert "keyspace_misses" in reply

    @pytest.mark.asyncio
    async def test_info_replication_section(self, mock_redis_connection_manager):
        """The ``replication`` section payload is returned intact."""
        replication_info = {
            "role": "master",
            "connected_slaves": "2",
            "master_replid": "abc123def456",
            "master_replid2": "0000000000000000000000000000000000000000",
            "master_repl_offset": "1000",
            "second_repl_offset": "-1",
            "repl_backlog_active": "1",
            "repl_backlog_size": "1048576",
            "repl_backlog_first_byte_offset": "1",
            "repl_backlog_histlen": "1000",
        }
        mock_redis_connection_manager.info.return_value = replication_info

        reply = await info("replication")

        mock_redis_connection_manager.info.assert_called_once_with("replication")
        assert reply == replication_info
        assert reply["role"] == "master"
        assert reply["connected_slaves"] == "2"

    @pytest.mark.asyncio
    async def test_dbsize_large_number(self, mock_redis_connection_manager):
        """A seven-digit key count is returned unchanged."""
        mock_redis_connection_manager.dbsize.return_value = 1000000  # 1 million keys

        size = await dbsize()

        assert size == 1000000

    @pytest.mark.asyncio
    async def test_client_list_single_client(self, mock_redis_connection_manager):
        """A single client record, with an empty name, is returned intact."""
        only_client = {
            "id": "1",
            "addr": "127.0.0.1:12345",
            "name": "",
            "age": "50",
            "idle": "0",
            "flags": "N",
            "db": "0",
            "sub": "0",
            "psub": "0",
            "multi": "-1",
            "qbuf": "0",
            "qbuf-free": "32768",
            "obl": "0",
            "oll": "0",
            "omem": "0",
            "events": "r",
            "cmd": "ping",
        }
        mock_redis_connection_manager.client_list.return_value = [only_client]

        reply = await client_list()

        assert len(reply) == 1
        assert reply[0]["id"] == "1"
        assert reply[0]["cmd"] == "ping"
301 |
```
--------------------------------------------------------------------------------
/tests/tools/test_list.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Unit tests for src/tools/list.py
3 | """
4 |
5 | import pytest
6 | from redis.exceptions import RedisError
7 |
8 | from src.tools.list import llen, lpop, lpush, lrange, rpop, rpush
9 |
10 |
11 | class TestListOperations:
12 | """Test cases for Redis list operations."""
13 |
14 | @pytest.mark.asyncio
15 | async def test_lpush_success(self, mock_redis_connection_manager):
16 | """Test successful left push operation."""
17 | mock_redis = mock_redis_connection_manager
18 | mock_redis.lpush.return_value = 2 # New length of list
19 |
20 | result = await lpush("test_list", "value1")
21 |
22 | mock_redis.lpush.assert_called_once_with("test_list", "value1")
23 | assert "Value 'value1' pushed to the left of list 'test_list'" in result
24 |
25 | @pytest.mark.asyncio
26 | async def test_lpush_with_expiration(self, mock_redis_connection_manager):
27 | """Test left push operation with expiration."""
28 | mock_redis = mock_redis_connection_manager
29 | mock_redis.lpush.return_value = 1
30 | mock_redis.expire.return_value = True
31 |
32 | result = await lpush("test_list", "value1", 60)
33 |
34 | mock_redis.lpush.assert_called_once_with("test_list", "value1")
35 | mock_redis.expire.assert_called_once_with("test_list", 60)
36 | # The implementation doesn't include expiration info in the message
37 | assert "Value 'value1' pushed to the left of list 'test_list'" in result
38 |
39 | @pytest.mark.asyncio
40 | async def test_lpush_redis_error(self, mock_redis_connection_manager):
41 | """Test left push operation with Redis error."""
42 | mock_redis = mock_redis_connection_manager
43 | mock_redis.lpush.side_effect = RedisError("Connection failed")
44 |
45 | result = await lpush("test_list", "value1")
46 |
47 | assert "Error pushing value to list 'test_list': Connection failed" in result
48 |
49 | @pytest.mark.asyncio
50 | async def test_rpush_success(self, mock_redis_connection_manager):
51 | """Test successful right push operation."""
52 | mock_redis = mock_redis_connection_manager
53 | mock_redis.rpush.return_value = 3
54 |
55 | result = await rpush("test_list", "value2")
56 |
57 | mock_redis.rpush.assert_called_once_with("test_list", "value2")
58 | assert "Value 'value2' pushed to the right of list 'test_list'" in result
59 |
60 | @pytest.mark.asyncio
61 | async def test_rpush_with_expiration(self, mock_redis_connection_manager):
62 | """Test right push operation with expiration."""
63 | mock_redis = mock_redis_connection_manager
64 | mock_redis.rpush.return_value = 1
65 | mock_redis.expire.return_value = True
66 |
67 | result = await rpush("test_list", "value2", 120)
68 |
69 | mock_redis.rpush.assert_called_once_with("test_list", "value2")
70 | mock_redis.expire.assert_called_once_with("test_list", 120)
71 | # The implementation doesn't include expiration info in the message
72 | assert "Value 'value2' pushed to the right of list 'test_list'" in result
73 |
74 | @pytest.mark.asyncio
75 | async def test_rpush_redis_error(self, mock_redis_connection_manager):
76 | """Test right push operation with Redis error."""
77 | mock_redis = mock_redis_connection_manager
78 | mock_redis.rpush.side_effect = RedisError("Connection failed")
79 |
80 | result = await rpush("test_list", "value2")
81 |
82 | assert "Error pushing value to list 'test_list': Connection failed" in result
83 |
84 | @pytest.mark.asyncio
85 | async def test_lpop_success(self, mock_redis_connection_manager):
86 | """Test successful left pop operation."""
87 | mock_redis = mock_redis_connection_manager
88 | mock_redis.lpop.return_value = "popped_value"
89 |
90 | result = await lpop("test_list")
91 |
92 | mock_redis.lpop.assert_called_once_with("test_list")
93 | assert result == "popped_value"
94 |
95 | @pytest.mark.asyncio
96 | async def test_lpop_empty_list(self, mock_redis_connection_manager):
97 | """Test left pop operation on empty list."""
98 | mock_redis = mock_redis_connection_manager
99 | mock_redis.lpop.return_value = None
100 |
101 | result = await lpop("empty_list")
102 |
103 | assert "List 'empty_list' is empty" in result
104 |
105 | @pytest.mark.asyncio
106 | async def test_lpop_redis_error(self, mock_redis_connection_manager):
107 | """Test left pop operation with Redis error."""
108 | mock_redis = mock_redis_connection_manager
109 | mock_redis.lpop.side_effect = RedisError("Connection failed")
110 |
111 | result = await lpop("test_list")
112 |
113 | assert "Error popping value from list 'test_list': Connection failed" in result
114 |
115 | @pytest.mark.asyncio
116 | async def test_rpop_success(self, mock_redis_connection_manager):
117 | """Test successful right pop operation."""
118 | mock_redis = mock_redis_connection_manager
119 | mock_redis.rpop.return_value = "right_popped_value"
120 |
121 | result = await rpop("test_list")
122 |
123 | mock_redis.rpop.assert_called_once_with("test_list")
124 | assert result == "right_popped_value"
125 |
126 | @pytest.mark.asyncio
127 | async def test_rpop_empty_list(self, mock_redis_connection_manager):
128 | """Test right pop operation on empty list."""
129 | mock_redis = mock_redis_connection_manager
130 | mock_redis.rpop.return_value = None
131 |
132 | result = await rpop("empty_list")
133 |
134 | assert "List 'empty_list' is empty" in result
135 |
136 | @pytest.mark.asyncio
137 | async def test_rpop_redis_error(self, mock_redis_connection_manager):
138 | """Test right pop operation with Redis error."""
139 | mock_redis = mock_redis_connection_manager
140 | mock_redis.rpop.side_effect = RedisError("Connection failed")
141 |
142 | result = await rpop("test_list")
143 |
144 | assert "Error popping value from list 'test_list': Connection failed" in result
145 |
146 | @pytest.mark.asyncio
147 | async def test_lrange_success(self, mock_redis_connection_manager):
148 | """Test successful list range operation."""
149 | mock_redis = mock_redis_connection_manager
150 | mock_redis.lrange.return_value = ["item1", "item2", "item3"]
151 |
152 | result = await lrange("test_list", 0, 2)
153 |
154 | mock_redis.lrange.assert_called_once_with("test_list", 0, 2)
155 | assert result == '["item1", "item2", "item3"]'
156 |
157 | @pytest.mark.asyncio
158 | async def test_lrange_default_parameters(self, mock_redis_connection_manager):
159 | """Test list range operation over the whole list (explicit start=0, stop=-1)."""
160 | mock_redis = mock_redis_connection_manager
161 | mock_redis.lrange.return_value = ["item1", "item2"]
162 |
163 | result = await lrange("test_list", 0, -1)
164 |
165 | mock_redis.lrange.assert_called_once_with("test_list", 0, -1)
166 | assert result == '["item1", "item2"]'
167 |
168 | @pytest.mark.asyncio
169 | async def test_lrange_empty_list(self, mock_redis_connection_manager):
170 | """Test list range operation on empty list."""
171 | mock_redis = mock_redis_connection_manager
172 | mock_redis.lrange.return_value = []
173 |
174 | result = await lrange("empty_list", 0, -1)
175 |
176 | assert "List 'empty_list' is empty or does not exist" in result
177 |
178 | @pytest.mark.asyncio
179 | async def test_lrange_redis_error(self, mock_redis_connection_manager):
180 | """Test list range operation with Redis error."""
181 | mock_redis = mock_redis_connection_manager
182 | mock_redis.lrange.side_effect = RedisError("Connection failed")
183 |
184 | result = await lrange("test_list", 0, -1)
185 |
186 | assert (
187 | "Error retrieving values from list 'test_list': Connection failed" in result
188 | )
189 |
190 | @pytest.mark.asyncio
191 | async def test_llen_success(self, mock_redis_connection_manager):
192 | """Test successful list length operation."""
193 | mock_redis = mock_redis_connection_manager
194 | mock_redis.llen.return_value = 5
195 |
196 | result = await llen("test_list")
197 |
198 | mock_redis.llen.assert_called_once_with("test_list")
199 | assert result == 5
200 |
201 | @pytest.mark.asyncio
202 | async def test_llen_empty_list(self, mock_redis_connection_manager):
203 | """Test list length operation on empty list."""
204 | mock_redis = mock_redis_connection_manager
205 | mock_redis.llen.return_value = 0
206 |
207 | result = await llen("empty_list")
208 |
209 | assert result == 0
210 |
211 | @pytest.mark.asyncio
212 | async def test_llen_redis_error(self, mock_redis_connection_manager):
213 | """Test list length operation with Redis error."""
214 | mock_redis = mock_redis_connection_manager
215 | mock_redis.llen.side_effect = RedisError("Connection failed")
216 |
217 | result = await llen("test_list")
218 |
219 | assert (
220 | "Error retrieving length of list 'test_list': Connection failed" in result
221 | )
222 |
223 | @pytest.mark.asyncio
224 | async def test_push_operations_with_numeric_values(
225 | self, mock_redis_connection_manager
226 | ):
227 | """Test push operations with numeric values."""
228 | mock_redis = mock_redis_connection_manager
229 | mock_redis.lpush.return_value = 1
230 | mock_redis.rpush.return_value = 2
231 |
232 | # Test with integer
233 | result1 = await lpush("test_list", 42)
234 | mock_redis.lpush.assert_called_with("test_list", 42)
235 |
236 | # Test with float
237 | result2 = await rpush("test_list", 3.14)
238 | mock_redis.rpush.assert_called_with("test_list", 3.14)
239 |
240 | assert "pushed to the left of list" in result1
241 | assert "pushed to the right of list" in result2
242 |
243 | @pytest.mark.asyncio
244 | async def test_lrange_with_negative_indices(self, mock_redis_connection_manager):
245 | """Test list range operation with negative indices."""
246 | mock_redis = mock_redis_connection_manager
247 | mock_redis.lrange.return_value = ["last_item"]
248 |
249 | result = await lrange("test_list", -1, -1)
250 |
251 | mock_redis.lrange.assert_called_once_with("test_list", -1, -1)
252 | assert result == '["last_item"]'
253 |
254 | @pytest.mark.asyncio
255 | async def test_expiration_error_handling(self, mock_redis_connection_manager):
256 | """Test expiration error handling in push operations."""
257 | mock_redis = mock_redis_connection_manager
258 | mock_redis.lpush.return_value = 1
259 | mock_redis.expire.side_effect = RedisError("Expire failed")
260 |
261 | result = await lpush("test_list", "value", 60)
262 |
263 | # Should report the expire error
264 | assert "Error pushing value to list 'test_list': Expire failed" in result
265 |
266 | @pytest.mark.asyncio
267 | async def test_push_operations_return_new_length(
268 | self, mock_redis_connection_manager
269 | ):
270 | """Test that push operations report success regardless of the new list length Redis returns."""
271 | mock_redis = mock_redis_connection_manager
272 | mock_redis.lpush.return_value = 3
273 | mock_redis.rpush.return_value = 4
274 |
275 | result1 = await lpush("test_list", "value1")
276 | result2 = await rpush("test_list", "value2")
277 |
278 | # Results should indicate successful push regardless of return value
279 | assert "pushed to the left of list" in result1
280 | assert "pushed to the right of list" in result2
281 |
```
--------------------------------------------------------------------------------
/tests/tools/test_sorted_set.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Unit tests for src/tools/sorted_set.py
3 | """
4 |
5 | from unittest.mock import Mock, patch
6 |
7 | import pytest
8 | from redis.exceptions import RedisError
9 |
10 | from src.tools.sorted_set import zadd, zrange, zrem
11 |
12 |
13 | class TestSortedSetOperations:
14 | """Test cases for Redis sorted set operations."""
15 |
16 | @pytest.mark.asyncio
17 | async def test_zadd_success(self, mock_redis_connection_manager):
18 | """Test successful sorted set add operation."""
19 | mock_redis = mock_redis_connection_manager
20 | mock_redis.zadd.return_value = 1 # Number of elements added
21 |
22 | result = await zadd("test_zset", 1.5, "member1")
23 |
24 | mock_redis.zadd.assert_called_once_with("test_zset", {"member1": 1.5})
25 | assert "Successfully added member1 to test_zset with score 1.5" in result
26 |
27 | @pytest.mark.asyncio
28 | async def test_zadd_with_expiration(self, mock_redis_connection_manager):
29 | """Test sorted set add operation with expiration."""
30 | mock_redis = mock_redis_connection_manager
31 | mock_redis.zadd.return_value = 1
32 | mock_redis.expire.return_value = True
33 |
34 | result = await zadd("test_zset", 2.0, "member1", 60)
35 |
36 | mock_redis.zadd.assert_called_once_with("test_zset", {"member1": 2.0})
37 | mock_redis.expire.assert_called_once_with("test_zset", 60)
38 | assert "and expiration 60 seconds" in result
39 |
40 | @pytest.mark.asyncio
41 | async def test_zadd_member_updated(self, mock_redis_connection_manager):
42 | """Test sorted set add operation when member score is updated."""
43 | mock_redis = mock_redis_connection_manager
44 | mock_redis.zadd.return_value = 0 # Member already exists, score updated
45 |
46 | result = await zadd("test_zset", 3.0, "existing_member")
47 |
48 | assert (
49 | "Successfully added existing_member to test_zset with score 3.0" in result
50 | )
51 |
52 | @pytest.mark.asyncio
53 | async def test_zadd_redis_error(self, mock_redis_connection_manager):
54 | """Test sorted set add operation with Redis error."""
55 | mock_redis = mock_redis_connection_manager
56 | mock_redis.zadd.side_effect = RedisError("Connection failed")
57 |
58 | result = await zadd("test_zset", 1.0, "member1")
59 |
60 | assert "Error adding to sorted set test_zset: Connection failed" in result
61 |
62 | @pytest.mark.asyncio
63 | async def test_zadd_integer_score(self, mock_redis_connection_manager):
64 | """Test sorted set add operation with integer score."""
65 | mock_redis = mock_redis_connection_manager
66 | mock_redis.zadd.return_value = 1
67 |
68 | result = await zadd("test_zset", 5, "member1")
69 |
70 | mock_redis.zadd.assert_called_once_with("test_zset", {"member1": 5})
71 | assert "Successfully added member1 to test_zset with score 5" in result
72 |
73 | @pytest.mark.asyncio
74 | async def test_zrange_success_without_scores(self, mock_redis_connection_manager):
75 | """Test successful sorted set range operation without scores."""
76 | mock_redis = mock_redis_connection_manager
77 | mock_redis.zrange.return_value = ["member1", "member2", "member3"]
78 |
79 | result = await zrange("test_zset", 0, 2)
80 |
81 | mock_redis.zrange.assert_called_once_with("test_zset", 0, 2, withscores=False)
82 | assert result == "['member1', 'member2', 'member3']"
83 |
84 | @pytest.mark.asyncio
85 | async def test_zrange_success_with_scores(self, mock_redis_connection_manager):
86 | """Test successful sorted set range operation with scores."""
87 | mock_redis = mock_redis_connection_manager
88 | mock_redis.zrange.return_value = [
89 | ("member1", 1.0),
90 | ("member2", 2.0),
91 | ("member3", 3.0),
92 | ]
93 |
94 | result = await zrange("test_zset", 0, 2, True)
95 |
96 | mock_redis.zrange.assert_called_once_with("test_zset", 0, 2, withscores=True)
97 | assert result == "[('member1', 1.0), ('member2', 2.0), ('member3', 3.0)]"
98 |
99 | @pytest.mark.asyncio
100 | async def test_zrange_default_parameters(self, mock_redis_connection_manager):
101 | """Test sorted set range operation over the full set (explicit 0 to -1) with with_scores left at its default."""
102 | mock_redis = mock_redis_connection_manager
103 | mock_redis.zrange.return_value = ["member1", "member2"]
104 |
105 | result = await zrange("test_zset", 0, -1)
106 |
107 | mock_redis.zrange.assert_called_once_with("test_zset", 0, -1, withscores=False)
108 | assert result == "['member1', 'member2']"
109 |
110 | @pytest.mark.asyncio
111 | async def test_zrange_empty_set(self, mock_redis_connection_manager):
112 | """Test sorted set range operation on empty set."""
113 | mock_redis = mock_redis_connection_manager
114 | mock_redis.zrange.return_value = []
115 |
116 | result = await zrange("empty_zset", 0, -1)
117 |
118 | mock_redis.zrange.assert_called_once_with("empty_zset", 0, -1, withscores=False)
119 | assert "Sorted set empty_zset is empty or does not exist" in result
120 |
121 | @pytest.mark.asyncio
122 | async def test_zrange_redis_error(self, mock_redis_connection_manager):
123 | """Test sorted set range operation with Redis error."""
124 | mock_redis = mock_redis_connection_manager
125 | mock_redis.zrange.side_effect = RedisError("Connection failed")
126 |
127 | result = await zrange("test_zset", 0, -1)
128 |
129 | assert "Error retrieving sorted set test_zset: Connection failed" in result
130 |
131 | @pytest.mark.asyncio
132 | async def test_zrem_success(self, mock_redis_connection_manager):
133 | """Test successful sorted set remove operation."""
134 | mock_redis = mock_redis_connection_manager
135 | mock_redis.zrem.return_value = 1 # Number of elements removed
136 |
137 | result = await zrem("test_zset", "member1")
138 |
139 | mock_redis.zrem.assert_called_once_with("test_zset", "member1")
140 | assert "Successfully removed member1 from test_zset" in result
141 |
142 | @pytest.mark.asyncio
143 | async def test_zrem_member_not_exists(self, mock_redis_connection_manager):
144 | """Test sorted set remove operation when member doesn't exist."""
145 | mock_redis = mock_redis_connection_manager
146 | mock_redis.zrem.return_value = 0 # Member doesn't exist
147 |
148 | result = await zrem("test_zset", "nonexistent_member")
149 |
150 | assert "Member nonexistent_member not found in test_zset" in result
151 |
152 | @pytest.mark.asyncio
153 | async def test_zrem_redis_error(self, mock_redis_connection_manager):
154 | """Test sorted set remove operation with Redis error."""
155 | mock_redis = mock_redis_connection_manager
156 | mock_redis.zrem.side_effect = RedisError("Connection failed")
157 |
158 | result = await zrem("test_zset", "member1")
159 |
160 | assert "Error removing from sorted set test_zset: Connection failed" in result
161 |
162 | @pytest.mark.asyncio
163 | async def test_zadd_negative_score(self, mock_redis_connection_manager):
164 | """Test sorted set add operation with negative score."""
165 | mock_redis = mock_redis_connection_manager
166 | mock_redis.zadd.return_value = 1
167 |
168 | result = await zadd("test_zset", -1.5, "negative_member")
169 |
170 | mock_redis.zadd.assert_called_once_with("test_zset", {"negative_member": -1.5})
171 | assert (
172 | "Successfully added negative_member to test_zset with score -1.5" in result
173 | )
174 |
175 | @pytest.mark.asyncio
176 | async def test_zadd_zero_score(self, mock_redis_connection_manager):
177 | """Test sorted set add operation with zero score."""
178 | mock_redis = mock_redis_connection_manager
179 | mock_redis.zadd.return_value = 1
180 |
181 | result = await zadd("test_zset", 0, "zero_member")
182 |
183 | mock_redis.zadd.assert_called_once_with("test_zset", {"zero_member": 0})
184 | assert "Successfully added zero_member to test_zset with score 0" in result
185 |
186 | @pytest.mark.asyncio
187 | async def test_zrange_negative_indices(self, mock_redis_connection_manager):
188 | """Test sorted set range operation with negative indices."""
189 | mock_redis = mock_redis_connection_manager
190 | mock_redis.zrange.return_value = ["last_member"]
191 |
192 | result = await zrange("test_zset", -1, -1)
193 |
194 | mock_redis.zrange.assert_called_once_with("test_zset", -1, -1, withscores=False)
195 | assert result == "['last_member']"
196 |
197 | @pytest.mark.asyncio
198 | async def test_zadd_expiration_error(self, mock_redis_connection_manager):
199 | """Test sorted set add operation when expiration fails."""
200 | mock_redis = mock_redis_connection_manager
201 | mock_redis.zadd.return_value = 1
202 | mock_redis.expire.side_effect = RedisError("Expire failed")
203 |
204 | result = await zadd("test_zset", 1.0, "member1", 60)
205 |
206 | assert "Error adding to sorted set test_zset: Expire failed" in result
207 |
208 | @pytest.mark.asyncio
209 | async def test_zadd_with_unicode_member(self, mock_redis_connection_manager):
210 | """Test sorted set add operation with unicode member."""
211 | mock_redis = mock_redis_connection_manager
212 | mock_redis.zadd.return_value = 1
213 |
214 | unicode_member = "测试成员 🚀"
215 | result = await zadd("test_zset", 1.0, unicode_member)
216 |
217 | mock_redis.zadd.assert_called_once_with("test_zset", {unicode_member: 1.0})
218 | assert (
219 | f"Successfully added {unicode_member} to test_zset with score 1.0" in result
220 | )
221 |
222 | @pytest.mark.asyncio
223 | async def test_zrange_large_range(self, mock_redis_connection_manager):
224 | """Test sorted set range operation with large range."""
225 | mock_redis = mock_redis_connection_manager
226 | large_result = [f"member_{i}" for i in range(1000)]
227 | mock_redis.zrange.return_value = large_result
228 |
229 | result = await zrange("large_zset", 0, 999)
230 |
231 | # The function returns a string representation
232 | assert result == str(large_result)
233 | # Check that the original list had 1000 items
234 | assert len(large_result) == 1000
235 |
236 | @pytest.mark.asyncio
237 | async def test_connection_manager_called_correctly(self):
238 | """Test that RedisConnectionManager.get_connection is called correctly."""
239 | with patch(
240 | "src.tools.sorted_set.RedisConnectionManager.get_connection"
241 | ) as mock_get_conn:
242 | mock_redis = Mock()
243 | mock_redis.zadd.return_value = 1
244 | mock_get_conn.return_value = mock_redis
245 |
246 | await zadd("test_zset", 1.0, "member1")
247 |
248 | mock_get_conn.assert_called_once()
249 |
250 | @pytest.mark.asyncio
251 | async def test_function_signatures(self):
252 | """Test that functions have correct signatures."""
253 | import inspect
254 |
255 | # Test zadd function signature
256 | zadd_sig = inspect.signature(zadd)
257 | zadd_params = list(zadd_sig.parameters.keys())
258 | assert zadd_params == ["key", "score", "member", "expiration"]
259 | assert zadd_sig.parameters["expiration"].default is None
260 |
261 | # Test zrange function signature
262 | zrange_sig = inspect.signature(zrange)
263 | zrange_params = list(zrange_sig.parameters.keys())
264 | assert zrange_params == ["key", "start", "end", "with_scores"]
265 | # start and end are required parameters (no defaults)
266 | assert zrange_sig.parameters["start"].default == inspect.Parameter.empty
267 | assert zrange_sig.parameters["end"].default == inspect.Parameter.empty
268 | assert zrange_sig.parameters["with_scores"].default is False
269 |
270 | # Test zrem function signature
271 | zrem_sig = inspect.signature(zrem)
272 | zrem_params = list(zrem_sig.parameters.keys())
273 | assert zrem_params == ["key", "member"]
274 |
```
--------------------------------------------------------------------------------
/tests/test_connection.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Unit tests for src/common/connection.py
3 | """
4 |
5 | from unittest.mock import Mock, patch
6 |
7 | import pytest
8 | from redis.exceptions import ConnectionError
9 |
10 | from src.common.connection import RedisConnectionManager
11 |
12 |
13 | class TestRedisConnectionManager:
14 | """Test cases for RedisConnectionManager class."""
15 |
16 | def setup_method(self):
17 | """Set up test fixtures."""
18 | # Reset singleton instance before each test
19 | RedisConnectionManager._instance = None
20 |
21 | def teardown_method(self):
22 | """Clean up after each test."""
23 | # Reset singleton instance after each test
24 | RedisConnectionManager._instance = None
25 |
26 | @patch("src.common.connection.redis.Redis")
27 | @patch("src.common.connection.REDIS_CFG")
28 | def test_get_connection_standalone_mode(self, mock_config, mock_redis_class):
29 | """Test getting connection in standalone mode."""
30 | mock_config.__getitem__.side_effect = lambda key: {
31 | "cluster_mode": False,
32 | "host": "localhost",
33 | "port": 6379,
34 | "db": 0,
35 | "username": None,
36 | "password": "",
37 | "ssl": False,
38 | "ssl_ca_path": None,
39 | "ssl_keyfile": None,
40 | "ssl_certfile": None,
41 | "ssl_cert_reqs": "required",
42 | "ssl_ca_certs": None,
43 | }[key]
44 |
45 | mock_redis_instance = Mock()
46 | mock_redis_class.return_value = mock_redis_instance
47 |
48 | connection = RedisConnectionManager.get_connection()
49 |
50 | assert connection == mock_redis_instance
51 | mock_redis_class.assert_called_once()
52 |
53 | # Verify connection parameters
54 | call_args = mock_redis_class.call_args[1]
55 | assert call_args["host"] == "localhost"
56 | assert call_args["port"] == 6379
57 | assert call_args["db"] == 0
58 | assert call_args["decode_responses"] is True
59 | assert call_args["max_connections"] == 10
60 | assert "lib_name" in call_args
61 |
62 | @patch("src.common.connection.redis.cluster.RedisCluster")
63 | @patch("src.common.connection.REDIS_CFG")
64 | def test_get_connection_cluster_mode(self, mock_config, mock_cluster_class):
65 | """Test getting connection in cluster mode."""
66 | mock_config.__getitem__.side_effect = lambda key: {
67 | "cluster_mode": True,
68 | "host": "localhost",
69 | "port": 6379,
70 | "username": "testuser",
71 | "password": "testpass",
72 | "ssl": True,
73 | "ssl_ca_path": "/path/to/ca.pem",
74 | "ssl_keyfile": "/path/to/key.pem",
75 | "ssl_certfile": "/path/to/cert.pem",
76 | "ssl_cert_reqs": "required",
77 | "ssl_ca_certs": "/path/to/ca-bundle.pem",
78 | }[key]
79 |
80 | mock_cluster_instance = Mock()
81 | mock_cluster_class.return_value = mock_cluster_instance
82 |
83 | connection = RedisConnectionManager.get_connection()
84 |
85 | assert connection == mock_cluster_instance
86 | mock_cluster_class.assert_called_once()
87 |
88 | # Verify connection parameters
89 | call_args = mock_cluster_class.call_args[1]
90 | assert call_args["host"] == "localhost"
91 | assert call_args["port"] == 6379
92 | assert call_args["username"] == "testuser"
93 | assert call_args["password"] == "testpass"
94 | assert call_args["ssl"] is True
95 | assert call_args["ssl_ca_path"] == "/path/to/ca.pem"
96 | assert call_args["decode_responses"] is True
97 | assert call_args["max_connections_per_node"] == 10
98 | assert "lib_name" in call_args
99 |
100 | @patch("src.common.connection.redis.Redis")
101 | @patch("src.common.connection.REDIS_CFG")
102 | def test_get_connection_singleton_behavior(self, mock_config, mock_redis_class):
103 | """Test that get_connection returns the same instance (singleton behavior)."""
104 | mock_config.__getitem__.side_effect = lambda key: {
105 | "cluster_mode": False,
106 | "host": "localhost",
107 | "port": 6379,
108 | "db": 0,
109 | "username": None,
110 | "password": "",
111 | "ssl": False,
112 | "ssl_ca_path": None,
113 | "ssl_keyfile": None,
114 | "ssl_certfile": None,
115 | "ssl_cert_reqs": "required",
116 | "ssl_ca_certs": None,
117 | }[key]
118 |
119 | mock_redis_instance = Mock()
120 | mock_redis_class.return_value = mock_redis_instance
121 |
122 | # First call
123 | connection1 = RedisConnectionManager.get_connection()
124 | # Second call
125 | connection2 = RedisConnectionManager.get_connection()
126 |
127 | assert connection1 == connection2
128 | assert connection1 == mock_redis_instance
129 | # Redis class should only be called once
130 | mock_redis_class.assert_called_once()
131 |
132 | @patch("src.common.connection.redis.Redis")
133 | @patch("src.common.connection.REDIS_CFG")
134 | def test_get_connection_with_decode_responses_false(
135 | self, mock_config, mock_redis_class
136 | ):
137 | """Test getting connection with decode_responses=False."""
138 | mock_config.__getitem__.side_effect = lambda key: {
139 | "cluster_mode": False,
140 | "host": "localhost",
141 | "port": 6379,
142 | "db": 0,
143 | "username": None,
144 | "password": "",
145 | "ssl": False,
146 | "ssl_ca_path": None,
147 | "ssl_keyfile": None,
148 | "ssl_certfile": None,
149 | "ssl_cert_reqs": "required",
150 | "ssl_ca_certs": None,
151 | }[key]
152 |
153 | mock_redis_instance = Mock()
154 | mock_redis_class.return_value = mock_redis_instance
155 |
156 | connection = RedisConnectionManager.get_connection(decode_responses=False)
157 | assert connection == mock_redis_instance
158 |
159 | call_args = mock_redis_class.call_args[1]
160 | assert call_args["decode_responses"] is False
161 |
162 | @patch("src.common.connection.redis.Redis")
163 | @patch("src.common.connection.REDIS_CFG")
164 | def test_get_connection_with_ssl_configuration(self, mock_config, mock_redis_class):
165 | """Test getting connection with SSL configuration."""
166 | mock_config.__getitem__.side_effect = lambda key: {
167 | "cluster_mode": False,
168 | "host": "redis.example.com",
169 | "port": 6380,
170 | "db": 1,
171 | "username": "ssluser",
172 | "password": "sslpass",
173 | "ssl": True,
174 | "ssl_ca_path": "/path/to/ca.pem",
175 | "ssl_keyfile": "/path/to/key.pem",
176 | "ssl_certfile": "/path/to/cert.pem",
177 | "ssl_cert_reqs": "optional",
178 | "ssl_ca_certs": "/path/to/ca-bundle.pem",
179 | }[key]
180 |
181 | mock_redis_instance = Mock()
182 | mock_redis_class.return_value = mock_redis_instance
183 |
184 | connection = RedisConnectionManager.get_connection()
185 | assert connection == mock_redis_instance
186 |
187 | call_args = mock_redis_class.call_args[1]
188 | assert call_args["ssl"] is True
189 | assert call_args["ssl_ca_path"] == "/path/to/ca.pem"
190 | assert call_args["ssl_keyfile"] == "/path/to/key.pem"
191 | assert call_args["ssl_certfile"] == "/path/to/cert.pem"
192 | assert call_args["ssl_cert_reqs"] == "optional"
193 | assert call_args["ssl_ca_certs"] == "/path/to/ca-bundle.pem"
194 |
195 | @patch("src.common.connection.redis.Redis")
196 | @patch("src.common.connection.REDIS_CFG")
197 | def test_get_connection_includes_version_in_lib_name(
198 | self, mock_config, mock_redis_class
199 | ):
200 | """Test that connection includes version information in lib_name."""
201 | mock_config.__getitem__.side_effect = lambda key: {
202 | "cluster_mode": False,
203 | "host": "localhost",
204 | "port": 6379,
205 | "db": 0,
206 | "username": None,
207 | "password": "",
208 | "ssl": False,
209 | "ssl_ca_path": None,
210 | "ssl_keyfile": None,
211 | "ssl_certfile": None,
212 | "ssl_cert_reqs": "required",
213 | "ssl_ca_certs": None,
214 | }[key]
215 |
216 | mock_redis_instance = Mock()
217 | mock_redis_class.return_value = mock_redis_instance
218 |
219 | with patch("src.common.connection.__version__", "1.0.0"):
220 | connection = RedisConnectionManager.get_connection()
221 |
222 | assert connection == mock_redis_instance
223 |
224 | call_args = mock_redis_class.call_args[1]
225 | assert "redis-py(mcp-server_v1.0.0)" in call_args["lib_name"]
226 |
227 | @patch("src.common.connection.redis.Redis")
228 | @patch("src.common.connection.REDIS_CFG")
229 | def test_connection_error_handling(self, mock_config, mock_redis_class):
230 | """Test connection error handling."""
231 | mock_config.__getitem__.side_effect = lambda key: {
232 | "cluster_mode": False,
233 | "host": "localhost",
234 | "port": 6379,
235 | "db": 0,
236 | "username": None,
237 | "password": "",
238 | "ssl": False,
239 | "ssl_ca_path": None,
240 | "ssl_keyfile": None,
241 | "ssl_certfile": None,
242 | "ssl_cert_reqs": "required",
243 | "ssl_ca_certs": None,
244 | }[key]
245 |
246 | # Mock Redis constructor to raise ConnectionError
247 | mock_redis_class.side_effect = ConnectionError("Connection refused")
248 |
249 | with pytest.raises(ConnectionError, match="Connection refused"):
250 | RedisConnectionManager.get_connection()
251 |
252 | @patch("src.common.connection.redis.cluster.RedisCluster")
253 | @patch("src.common.connection.REDIS_CFG")
254 | def test_cluster_connection_error_handling(self, mock_config, mock_cluster_class):
255 | """Test cluster connection error handling."""
256 | mock_config.__getitem__.side_effect = lambda key: {
257 | "cluster_mode": True,
258 | "host": "localhost",
259 | "port": 6379,
260 | "username": None,
261 | "password": "",
262 | "ssl": False,
263 | "ssl_ca_path": None,
264 | "ssl_keyfile": None,
265 | "ssl_certfile": None,
266 | "ssl_cert_reqs": "required",
267 | "ssl_ca_certs": None,
268 | }[key]
269 |
270 | # Mock RedisCluster constructor to raise ConnectionError
271 | mock_cluster_class.side_effect = ConnectionError("Cluster connection failed")
272 |
273 | with pytest.raises(ConnectionError, match="Cluster connection failed"):
274 | RedisConnectionManager.get_connection()
275 |
276 | def test_reset_instance(self):
277 | """Test that the singleton instance can be reset."""
278 | # Set up a mock instance
279 | mock_instance = Mock()
280 | RedisConnectionManager._instance = mock_instance
281 |
282 | # Verify instance is set
283 | assert RedisConnectionManager._instance == mock_instance
284 |
285 | # Reset instance
286 | RedisConnectionManager._instance = None
287 |
288 | # Verify instance is reset
289 | assert RedisConnectionManager._instance is None
290 |
291 | @patch("src.common.connection.redis.Redis")
292 | @patch("src.common.connection.REDIS_CFG")
293 | def test_connection_parameters_filtering(self, mock_config, mock_redis_class):
294 | """Test that None and empty-string config values are passed to redis.Redis unchanged."""
295 | mock_config.__getitem__.side_effect = lambda key: {
296 | "cluster_mode": False,
297 | "host": "localhost",
298 | "port": 6379,
299 | "db": 0,
300 | "username": None, # This should be passed as None
301 | "password": "", # This should be passed as empty string
302 | "ssl": False,
303 | "ssl_ca_path": None,
304 | "ssl_keyfile": None,
305 | "ssl_certfile": None,
306 | "ssl_cert_reqs": "required",
307 | "ssl_ca_certs": None,
308 | }[key]
309 |
310 | mock_redis_instance = Mock()
311 | mock_redis_class.return_value = mock_redis_instance
312 |
313 | connection = RedisConnectionManager.get_connection()
314 |
315 | assert connection == mock_redis_instance
316 |
317 | call_args = mock_redis_class.call_args[1]
318 | assert call_args["username"] is None
319 | assert call_args["password"] == ""
320 | assert call_args["ssl_ca_path"] is None
321 |
```
--------------------------------------------------------------------------------
/tests/tools/test_pub_sub.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Unit tests for src/tools/pub_sub.py
3 | """
4 |
5 | from unittest.mock import Mock, patch
6 |
7 | import pytest
8 | from redis.exceptions import ConnectionError, RedisError
9 |
10 | from src.tools.pub_sub import publish, subscribe, unsubscribe
11 |
12 |
13 | class TestPubSubOperations:
14 | """Test cases for Redis pub/sub operations."""
15 |
16 | @pytest.mark.asyncio
17 | async def test_publish_success(self, mock_redis_connection_manager):
18 | """Test successful publish operation."""
19 | mock_redis = mock_redis_connection_manager
20 | mock_redis.publish.return_value = (
21 | 2 # Number of subscribers that received the message
22 | )
23 |
24 | result = await publish("test_channel", "Hello World")
25 |
26 | mock_redis.publish.assert_called_once_with("test_channel", "Hello World")
27 | assert "Message published to channel 'test_channel'" in result
28 |
29 | @pytest.mark.asyncio
30 | async def test_publish_no_subscribers(self, mock_redis_connection_manager):
31 | """Test publish operation with no subscribers."""
32 | mock_redis = mock_redis_connection_manager
33 | mock_redis.publish.return_value = 0 # No subscribers
34 |
35 | result = await publish("empty_channel", "Hello World")
36 |
37 | mock_redis.publish.assert_called_once_with("empty_channel", "Hello World")
38 | assert "Message published to channel 'empty_channel'" in result
39 |
40 | @pytest.mark.asyncio
41 | async def test_publish_redis_error(self, mock_redis_connection_manager):
42 | """Test publish operation with Redis error."""
43 | mock_redis = mock_redis_connection_manager
44 | mock_redis.publish.side_effect = RedisError("Connection failed")
45 |
46 | result = await publish("test_channel", "Hello World")
47 |
48 | assert (
49 | "Error publishing message to channel 'test_channel': Connection failed"
50 | in result
51 | )
52 |
53 | @pytest.mark.asyncio
54 | async def test_publish_connection_error(self, mock_redis_connection_manager):
55 | """Test publish operation with connection error."""
56 | mock_redis = mock_redis_connection_manager
57 | mock_redis.publish.side_effect = ConnectionError("Redis server unavailable")
58 |
59 | result = await publish("test_channel", "Hello World")
60 |
61 | assert (
62 | "Error publishing message to channel 'test_channel': Redis server unavailable"
63 | in result
64 | )
65 |
66 | @pytest.mark.asyncio
67 | async def test_publish_empty_message(self, mock_redis_connection_manager):
68 | """Test publish operation with empty message."""
69 | mock_redis = mock_redis_connection_manager
70 | mock_redis.publish.return_value = 1
71 |
72 | result = await publish("test_channel", "")
73 |
74 | mock_redis.publish.assert_called_once_with("test_channel", "")
75 | assert "Message published to channel 'test_channel'" in result
76 |
77 | @pytest.mark.asyncio
78 | async def test_publish_numeric_message(self, mock_redis_connection_manager):
79 | """Test publish operation with numeric message."""
80 | mock_redis = mock_redis_connection_manager
81 | mock_redis.publish.return_value = 1
82 |
83 | result = await publish("test_channel", 42)
84 |
85 | mock_redis.publish.assert_called_once_with("test_channel", 42)
86 | assert "Message published to channel 'test_channel'" in result
87 |
88 | @pytest.mark.asyncio
89 | async def test_publish_json_message(self, mock_redis_connection_manager):
90 | """Test publish operation with JSON-like message."""
91 | mock_redis = mock_redis_connection_manager
92 | mock_redis.publish.return_value = 3
93 |
94 | json_message = (
95 | '{"type": "notification", "data": {"user": "john", "action": "login"}}'
96 | )
97 | result = await publish("notifications", json_message)
98 |
99 | mock_redis.publish.assert_called_once_with("notifications", json_message)
100 | assert "Message published to channel 'notifications'" in result
101 |
102 | @pytest.mark.asyncio
103 | async def test_publish_unicode_message(self, mock_redis_connection_manager):
104 | """Test publish operation with unicode message."""
105 | mock_redis = mock_redis_connection_manager
106 | mock_redis.publish.return_value = 1
107 |
108 | unicode_message = "Hello 世界 🌍"
109 | result = await publish("test_channel", unicode_message)
110 |
111 | mock_redis.publish.assert_called_once_with("test_channel", unicode_message)
112 | assert "Message published to channel 'test_channel'" in result
113 |
114 | @pytest.mark.asyncio
115 | async def test_subscribe_success(self, mock_redis_connection_manager):
116 | """Test successful subscribe operation."""
117 | mock_redis = mock_redis_connection_manager
118 | mock_pubsub = Mock()
119 | mock_redis.pubsub.return_value = mock_pubsub
120 | mock_pubsub.subscribe.return_value = None
121 |
122 | result = await subscribe("test_channel")
123 |
124 | mock_redis.pubsub.assert_called_once()
125 | mock_pubsub.subscribe.assert_called_once_with("test_channel")
126 | assert "Subscribed to channel 'test_channel'" in result
127 |
128 | @pytest.mark.asyncio
129 | async def test_subscribe_redis_error(self, mock_redis_connection_manager):
130 | """Test subscribe operation with Redis error."""
131 | mock_redis = mock_redis_connection_manager
132 | mock_redis.pubsub.side_effect = RedisError("Connection failed")
133 |
134 | result = await subscribe("test_channel")
135 |
136 | assert (
137 | "Error subscribing to channel 'test_channel': Connection failed" in result
138 | )
139 |
140 | @pytest.mark.asyncio
141 | async def test_subscribe_pubsub_error(self, mock_redis_connection_manager):
142 | """Test subscribe operation when pubsub.subscribe() itself raises a Redis error."""
143 | mock_redis = mock_redis_connection_manager
144 | mock_pubsub = Mock()
145 | mock_redis.pubsub.return_value = mock_pubsub
146 | mock_pubsub.subscribe.side_effect = RedisError("Subscribe failed")
147 |
148 | result = await subscribe("test_channel")
149 |
150 | assert "Error subscribing to channel 'test_channel': Subscribe failed" in result
151 |
152 | @pytest.mark.asyncio
153 | async def test_subscribe_multiple_channels_pattern(
154 | self, mock_redis_connection_manager
155 | ):
156 | """Test subscribe operation with pattern-like channel name."""
157 | mock_redis = mock_redis_connection_manager
158 | mock_pubsub = Mock()
159 | mock_redis.pubsub.return_value = mock_pubsub
160 | mock_pubsub.subscribe.return_value = None
161 |
162 | pattern_channel = "notifications:*"
163 | result = await subscribe(pattern_channel)
164 |
165 | mock_pubsub.subscribe.assert_called_once_with(pattern_channel)
166 | assert f"Subscribed to channel '{pattern_channel}'" in result
167 |
168 | @pytest.mark.asyncio
169 | async def test_unsubscribe_success(self, mock_redis_connection_manager):
170 | """Test successful unsubscribe operation."""
171 | mock_redis = mock_redis_connection_manager
172 | mock_pubsub = Mock()
173 | mock_redis.pubsub.return_value = mock_pubsub
174 | mock_pubsub.unsubscribe.return_value = None
175 |
176 | result = await unsubscribe("test_channel")
177 |
178 | mock_redis.pubsub.assert_called_once()
179 | mock_pubsub.unsubscribe.assert_called_once_with("test_channel")
180 | assert "Unsubscribed from channel 'test_channel'" in result
181 |
182 | @pytest.mark.asyncio
183 | async def test_unsubscribe_redis_error(self, mock_redis_connection_manager):
184 | """Test unsubscribe operation with Redis error."""
185 | mock_redis = mock_redis_connection_manager
186 | mock_redis.pubsub.side_effect = RedisError("Connection failed")
187 |
188 | result = await unsubscribe("test_channel")
189 |
190 | assert (
191 | "Error unsubscribing from channel 'test_channel': Connection failed"
192 | in result
193 | )
194 |
195 | @pytest.mark.asyncio
196 | async def test_unsubscribe_pubsub_error(self, mock_redis_connection_manager):
197 | """Test unsubscribe operation with pubsub error."""
198 | mock_redis = mock_redis_connection_manager
199 | mock_pubsub = Mock()
200 | mock_redis.pubsub.return_value = mock_pubsub
201 | mock_pubsub.unsubscribe.side_effect = RedisError("Unsubscribe failed")
202 |
203 | result = await unsubscribe("test_channel")
204 |
205 | assert (
206 | "Error unsubscribing from channel 'test_channel': Unsubscribe failed"
207 | in result
208 | )
209 |
210 | @pytest.mark.asyncio
211 | async def test_unsubscribe_from_all_channels(self, mock_redis_connection_manager):
212 | """Test unsubscribe operation for a specific, explicitly named channel."""
213 | mock_redis = mock_redis_connection_manager
214 | mock_pubsub = Mock()
215 | mock_redis.pubsub.return_value = mock_pubsub
216 | mock_pubsub.unsubscribe.return_value = None
217 |
218 | # Test unsubscribing from specific channel
219 | result = await unsubscribe("specific_channel")
220 |
221 | mock_pubsub.unsubscribe.assert_called_once_with("specific_channel")
222 | assert "Unsubscribed from channel 'specific_channel'" in result
223 |
224 | @pytest.mark.asyncio
225 | async def test_publish_to_pattern_channel(self, mock_redis_connection_manager):
226 | """Test publish operation to a colon-delimited (namespaced) channel name."""
227 | mock_redis = mock_redis_connection_manager
228 | mock_redis.publish.return_value = 5
229 |
230 | pattern_channel = "user:123:notifications"
231 | result = await publish(pattern_channel, "User notification")
232 |
233 | mock_redis.publish.assert_called_once_with(pattern_channel, "User notification")
234 | assert f"Message published to channel '{pattern_channel}'" in result
235 |
236 | @pytest.mark.asyncio
237 | async def test_subscribe_with_special_characters(
238 | self, mock_redis_connection_manager
239 | ):
240 | """Test subscribe operation with special characters in channel name."""
241 | mock_redis = mock_redis_connection_manager
242 | mock_pubsub = Mock()
243 | mock_redis.pubsub.return_value = mock_pubsub
244 | mock_pubsub.subscribe.return_value = None
245 |
246 | special_channel = "channel:with:colons-and-dashes_and_underscores"
247 | result = await subscribe(special_channel)
248 |
249 | mock_pubsub.subscribe.assert_called_once_with(special_channel)
250 | assert f"Subscribed to channel '{special_channel}'" in result
251 |
252 | @pytest.mark.asyncio
253 | async def test_connection_manager_called_correctly(self):
254 | """Test that RedisConnectionManager.get_connection is called correctly."""
255 | with patch(
256 | "src.tools.pub_sub.RedisConnectionManager.get_connection"
257 | ) as mock_get_conn:
258 | mock_redis = Mock()
259 | mock_redis.publish.return_value = 1
260 | mock_get_conn.return_value = mock_redis
261 |
262 | await publish("test_channel", "test_message")
263 |
264 | mock_get_conn.assert_called_once()
265 |
266 | @pytest.mark.asyncio
267 | async def test_function_signatures(self):
268 | """Test that functions have correct signatures."""
269 | import inspect
270 |
271 | # Test publish function signature
272 | publish_sig = inspect.signature(publish)
273 | publish_params = list(publish_sig.parameters.keys())
274 | assert publish_params == ["channel", "message"]
275 |
276 | # Test subscribe function signature
277 | subscribe_sig = inspect.signature(subscribe)
278 | subscribe_params = list(subscribe_sig.parameters.keys())
279 | assert subscribe_params == ["channel"]
280 |
281 | # Test unsubscribe function signature
282 | unsubscribe_sig = inspect.signature(unsubscribe)
283 | unsubscribe_params = list(unsubscribe_sig.parameters.keys())
284 | assert unsubscribe_params == ["channel"]
285 |
286 | @pytest.mark.asyncio
287 | async def test_publish_large_message(self, mock_redis_connection_manager):
288 | """Test publish operation with large message."""
289 | mock_redis = mock_redis_connection_manager
290 | mock_redis.publish.return_value = 1
291 |
292 | large_message = "x" * 10000 # 10KB message
293 | result = await publish("test_channel", large_message)
294 |
295 | mock_redis.publish.assert_called_once_with("test_channel", large_message)
296 | assert "Message published to channel 'test_channel'" in result
297 |
```
--------------------------------------------------------------------------------
/tests/tools/test_stream.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Unit tests for src/tools/stream.py
3 | """
4 |
5 | from unittest.mock import Mock, patch
6 |
7 | import pytest
8 | from redis.exceptions import RedisError
9 |
10 | from src.tools.stream import xadd, xdel, xrange
11 |
12 |
13 | class TestStreamOperations:
14 | """Test cases for Redis stream operations."""
15 |
16 | @pytest.mark.asyncio
17 | async def test_xadd_success(self, mock_redis_connection_manager):
18 | """Test successful stream add operation."""
19 | mock_redis = mock_redis_connection_manager
20 | mock_redis.xadd.return_value = "1234567890123-0" # Stream entry ID
21 |
22 | fields = {"field1": "value1", "field2": "value2"}
23 | result = await xadd("test_stream", fields)
24 |
25 | mock_redis.xadd.assert_called_once_with("test_stream", fields)
26 | assert "Successfully added entry 1234567890123-0 to test_stream" in result
27 | assert "1234567890123-0" in result
28 |
29 | @pytest.mark.asyncio
30 | async def test_xadd_with_expiration(self, mock_redis_connection_manager):
31 | """Test stream add operation with expiration."""
32 | mock_redis = mock_redis_connection_manager
33 | mock_redis.xadd.return_value = "1234567890124-0"
34 | mock_redis.expire.return_value = True
35 |
36 | fields = {"message": "test message"}
37 | result = await xadd("test_stream", fields, 60)
38 |
39 | mock_redis.xadd.assert_called_once_with("test_stream", fields)
40 | mock_redis.expire.assert_called_once_with("test_stream", 60)
41 | assert "with expiration 60 seconds" in result
42 |
43 | @pytest.mark.asyncio
44 | async def test_xadd_single_field(self, mock_redis_connection_manager):
45 | """Test stream add operation with single field."""
46 | mock_redis = mock_redis_connection_manager
47 | mock_redis.xadd.return_value = "1234567890125-0"
48 |
49 | fields = {"message": "single field message"}
50 | result = await xadd("test_stream", fields)
51 |
52 | mock_redis.xadd.assert_called_once_with("test_stream", fields)
53 | assert "Successfully added entry 1234567890125-0 to test_stream" in result
54 |
55 | @pytest.mark.asyncio
56 | async def test_xadd_redis_error(self, mock_redis_connection_manager):
57 | """Test stream add operation with Redis error."""
58 | mock_redis = mock_redis_connection_manager
59 | mock_redis.xadd.side_effect = RedisError("Connection failed")
60 |
61 | fields = {"field1": "value1"}
62 | result = await xadd("test_stream", fields)
63 |
64 | assert "Error adding to stream test_stream: Connection failed" in result
65 |
66 | @pytest.mark.asyncio
67 | async def test_xadd_with_numeric_values(self, mock_redis_connection_manager):
68 | """Test stream add operation with numeric field values."""
69 | mock_redis = mock_redis_connection_manager
70 | mock_redis.xadd.return_value = "1234567890126-0"
71 |
72 | fields = {"count": 42, "price": 19.99, "active": True}
73 | result = await xadd("test_stream", fields)
74 |
75 | mock_redis.xadd.assert_called_once_with("test_stream", fields)
76 | assert "Successfully added entry 1234567890126-0 to test_stream" in result
77 |
78 | @pytest.mark.asyncio
79 | async def test_xrange_success(self, mock_redis_connection_manager):
80 | """Test successful stream range operation."""
81 | mock_redis = mock_redis_connection_manager
82 | mock_entries = [
83 | ("1234567890123-0", {"field1": "value1", "field2": "value2"}),
84 | ("1234567890124-0", {"field1": "value3", "field2": "value4"}),
85 | ]
86 | mock_redis.xrange.return_value = mock_entries
87 |
88 | result = await xrange("test_stream")
89 |
90 | mock_redis.xrange.assert_called_once_with("test_stream", count=1)
91 | assert result == str(mock_entries)
92 |
93 | @pytest.mark.asyncio
94 | async def test_xrange_with_custom_count(self, mock_redis_connection_manager):
95 | """Test stream range operation with custom count."""
96 | mock_redis = mock_redis_connection_manager
97 | mock_entries = [
98 | ("1234567890123-0", {"message": "entry1"}),
99 | ("1234567890124-0", {"message": "entry2"}),
100 | ("1234567890125-0", {"message": "entry3"}),
101 | ]
102 | mock_redis.xrange.return_value = mock_entries
103 |
104 | result = await xrange("test_stream", 3)
105 |
106 | mock_redis.xrange.assert_called_once_with("test_stream", count=3)
107 | assert result == str(mock_entries)
108 | # Check the original mock_entries length
109 | assert len(mock_entries) == 3
110 |
111 | @pytest.mark.asyncio
112 | async def test_xrange_empty_stream(self, mock_redis_connection_manager):
113 | """Test stream range operation on empty stream."""
114 | mock_redis = mock_redis_connection_manager
115 | mock_redis.xrange.return_value = []
116 |
117 | result = await xrange("empty_stream")
118 |
119 | assert "Stream empty_stream is empty or does not exist" in result
120 |
121 | @pytest.mark.asyncio
122 | async def test_xrange_redis_error(self, mock_redis_connection_manager):
123 | """Test stream range operation with Redis error."""
124 | mock_redis = mock_redis_connection_manager
125 | mock_redis.xrange.side_effect = RedisError("Connection failed")
126 |
127 | result = await xrange("test_stream")
128 |
129 | assert "Error reading from stream test_stream: Connection failed" in result
130 |
131 | @pytest.mark.asyncio
132 | async def test_xdel_success(self, mock_redis_connection_manager):
133 | """Test successful stream delete operation."""
134 | mock_redis = mock_redis_connection_manager
135 | mock_redis.xdel.return_value = 1 # Number of entries deleted
136 |
137 | result = await xdel("test_stream", "1234567890123-0")
138 |
139 | mock_redis.xdel.assert_called_once_with("test_stream", "1234567890123-0")
140 | assert "Successfully deleted entry 1234567890123-0 from test_stream" in result
141 |
142 | @pytest.mark.asyncio
143 | async def test_xdel_entry_not_found(self, mock_redis_connection_manager):
144 | """Test stream delete operation when entry doesn't exist."""
145 | mock_redis = mock_redis_connection_manager
146 | mock_redis.xdel.return_value = 0 # No entries deleted
147 |
148 | result = await xdel("test_stream", "nonexistent-entry-id")
149 |
150 | assert "Entry nonexistent-entry-id not found in test_stream" in result
151 |
152 | @pytest.mark.asyncio
153 | async def test_xdel_redis_error(self, mock_redis_connection_manager):
154 | """Test stream delete operation with Redis error."""
155 | mock_redis = mock_redis_connection_manager
156 | mock_redis.xdel.side_effect = RedisError("Connection failed")
157 |
158 | result = await xdel("test_stream", "1234567890123-0")
159 |
160 | assert "Error deleting from stream test_stream: Connection failed" in result
161 |
162 | @pytest.mark.asyncio
163 | async def test_xadd_with_empty_fields(self, mock_redis_connection_manager):
164 | """Test stream add operation with empty fields dictionary."""
165 | mock_redis = mock_redis_connection_manager
166 | mock_redis.xadd.return_value = "1234567890127-0"
167 |
168 | fields = {}
169 | result = await xadd("test_stream", fields)
170 |
171 | mock_redis.xadd.assert_called_once_with("test_stream", fields)
172 | assert "Successfully added entry 1234567890127-0 to test_stream" in result
173 |
174 | @pytest.mark.asyncio
175 | async def test_xadd_with_unicode_values(self, mock_redis_connection_manager):
176 | """Test stream add operation with unicode field values."""
177 | mock_redis = mock_redis_connection_manager
178 | mock_redis.xadd.return_value = "1234567890128-0"
179 |
180 | fields = {"message": "Hello 世界 🌍", "user": "测试用户"}
181 | result = await xadd("test_stream", fields)
182 |
183 | mock_redis.xadd.assert_called_once_with("test_stream", fields)
184 | assert "Successfully added entry 1234567890128-0 to test_stream" in result
185 |
186 | @pytest.mark.asyncio
187 | async def test_xrange_large_count(self, mock_redis_connection_manager):
188 | """Test stream range operation with large count."""
189 | mock_redis = mock_redis_connection_manager
190 | mock_entries = [
191 | (f"123456789012{i}-0", {"data": f"entry_{i}"}) for i in range(100)
192 | ]
193 | mock_redis.xrange.return_value = mock_entries
194 |
195 | result = await xrange("test_stream", 100)
196 |
197 | mock_redis.xrange.assert_called_once_with("test_stream", count=100)
198 | # The function returns a string representation
199 | assert result == str(mock_entries)
200 | # Check the original mock_entries length
201 | assert len(mock_entries) == 100
202 |
203 | @pytest.mark.asyncio
204 | async def test_xdel_multiple_entries_behavior(self, mock_redis_connection_manager):
205 | """Test that xdel forwards exactly one entry ID to Redis for a single-entry delete."""
206 | mock_redis = mock_redis_connection_manager
207 | mock_redis.xdel.return_value = 1
208 |
209 | result = await xdel("test_stream", "single-entry-id")
210 |
211 | # Should call xdel with single entry ID, not multiple
212 | mock_redis.xdel.assert_called_once_with("test_stream", "single-entry-id")
213 | assert "Successfully deleted entry single-entry-id from test_stream" in result
214 |
215 | @pytest.mark.asyncio
216 | async def test_xadd_expiration_error(self, mock_redis_connection_manager):
217 | """Test stream add operation when expiration fails."""
218 | mock_redis = mock_redis_connection_manager
219 | mock_redis.xadd.return_value = "1234567890129-0"
220 | mock_redis.expire.side_effect = RedisError("Expire failed")
221 |
222 | fields = {"message": "test"}
223 | result = await xadd("test_stream", fields, 60)
224 |
225 | assert "Error adding to stream test_stream: Expire failed" in result
226 |
227 | @pytest.mark.asyncio
228 | async def test_xrange_single_entry(self, mock_redis_connection_manager):
229 | """Test stream range operation returning single entry."""
230 | mock_redis = mock_redis_connection_manager
231 | mock_entries = [("1234567890123-0", {"single": "entry"})]
232 | mock_redis.xrange.return_value = mock_entries
233 |
234 | result = await xrange("test_stream", 1)
235 |
236 | assert result == "[('1234567890123-0', {'single': 'entry'})]"
237 | # Check the original mock_entries length
238 | assert len(mock_entries) == 1
239 |
240 | @pytest.mark.asyncio
241 | async def test_connection_manager_called_correctly(self):
242 | """Test that RedisConnectionManager.get_connection is called correctly."""
243 | with patch(
244 | "src.tools.stream.RedisConnectionManager.get_connection"
245 | ) as mock_get_conn:
246 | mock_redis = Mock()
247 | mock_redis.xadd.return_value = "1234567890123-0"
248 | mock_get_conn.return_value = mock_redis
249 |
250 | await xadd("test_stream", {"field": "value"})
251 |
252 | mock_get_conn.assert_called_once()
253 |
254 | @pytest.mark.asyncio
255 | async def test_function_signatures(self):
256 | """Test that functions have correct signatures."""
257 | import inspect
258 |
259 | # Test xadd function signature
260 | xadd_sig = inspect.signature(xadd)
261 | xadd_params = list(xadd_sig.parameters.keys())
262 | assert xadd_params == ["key", "fields", "expiration"]
263 | assert xadd_sig.parameters["expiration"].default is None
264 |
265 | # Test xrange function signature
266 | xrange_sig = inspect.signature(xrange)
267 | xrange_params = list(xrange_sig.parameters.keys())
268 | assert xrange_params == ["key", "count"]
269 | assert xrange_sig.parameters["count"].default == 1
270 |
271 | # Test xdel function signature
272 | xdel_sig = inspect.signature(xdel)
273 | xdel_params = list(xdel_sig.parameters.keys())
274 | assert xdel_params == ["key", "entry_id"]
275 |
276 | @pytest.mark.asyncio
277 | async def test_xadd_with_complex_fields(self, mock_redis_connection_manager):
278 | """Test stream add operation with complex field structure."""
279 | mock_redis = mock_redis_connection_manager
280 | mock_redis.xadd.return_value = "1234567890130-0"
281 |
282 | fields = {
283 | "event_type": "user_action",
284 | "user_id": "12345",
285 | "timestamp": "2024-01-01T12:00:00Z",
286 | "metadata": '{"browser": "chrome", "version": "120"}',
287 | "score": 95.5,
288 | "active": True,
289 | }
290 | result = await xadd("events_stream", fields)
291 |
292 | mock_redis.xadd.assert_called_once_with("events_stream", fields)
293 | assert "Successfully added entry 1234567890130-0 to events_stream" in result
294 |
```
--------------------------------------------------------------------------------
/tests/tools/test_json.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Unit tests for src/tools/json.py
3 | """
4 |
5 | import json
6 |
7 | import pytest
8 | from redis.exceptions import RedisError
9 |
10 | from src.tools.json import json_del, json_get, json_set
11 |
12 |
13 | class TestJSONOperations:
14 | """Test cases for Redis JSON operations."""
15 |
16 | @pytest.mark.asyncio
17 | async def test_json_set_success(
18 | self, mock_redis_connection_manager, sample_json_data
19 | ):
20 | """Test successful JSON set operation."""
21 | mock_redis = mock_redis_connection_manager
22 | mock_redis.json.return_value.set.return_value = "OK"
23 |
24 | result = await json_set("test_doc", "$", sample_json_data)
25 |
26 | mock_redis.json.return_value.set.assert_called_once_with(
27 | "test_doc", "$", sample_json_data
28 | )
29 | assert "JSON value set at path '$' in 'test_doc'." in result
30 |
31 | @pytest.mark.asyncio
32 | async def test_json_set_with_expiration(
33 | self, mock_redis_connection_manager, sample_json_data
34 | ):
35 | """Test JSON set operation with expiration."""
36 | mock_redis = mock_redis_connection_manager
37 | mock_redis.json.return_value.set.return_value = "OK"
38 | mock_redis.expire.return_value = True
39 |
40 | result = await json_set("test_doc", "$.name", "John Updated", 60)
41 |
42 | mock_redis.json.return_value.set.assert_called_once_with(
43 | "test_doc", "$.name", "John Updated"
44 | )
45 | mock_redis.expire.assert_called_once_with("test_doc", 60)
46 | assert "Expires in 60 seconds" in result
47 |
48 | @pytest.mark.asyncio
49 | async def test_json_set_nested_path(self, mock_redis_connection_manager):
50 | """Test JSON set operation with nested path."""
51 | mock_redis = mock_redis_connection_manager
52 | mock_redis.json.return_value.set.return_value = "OK"
53 |
54 | result = await json_set("test_doc", "$.user.profile.age", 25)
55 |
56 | mock_redis.json.return_value.set.assert_called_once_with(
57 | "test_doc", "$.user.profile.age", 25
58 | )
59 | assert "JSON value set at path '$.user.profile.age'" in result
60 |
61 | @pytest.mark.asyncio
62 | async def test_json_set_redis_error(self, mock_redis_connection_manager):
63 | """Test JSON set operation with Redis error."""
64 | mock_redis = mock_redis_connection_manager
65 | mock_redis.json.return_value.set.side_effect = RedisError(
66 | "JSON module not loaded"
67 | )
68 |
69 | result = await json_set("test_doc", "$", {"key": "value"})
70 |
71 | assert (
72 | "Error setting JSON value at path '$' in 'test_doc': JSON module not loaded"
73 | in result
74 | )
75 |
76 | @pytest.mark.asyncio
77 | async def test_json_get_success(
78 | self, mock_redis_connection_manager, sample_json_data
79 | ):
80 | """Test successful JSON get operation."""
81 | mock_redis = mock_redis_connection_manager
82 | mock_redis.json.return_value.get.return_value = sample_json_data
83 |
84 | result = await json_get("test_doc", "$")
85 |
86 | mock_redis.json.return_value.get.assert_called_once_with("test_doc", "$")
87 | # json_get returns a JSON string representation
88 | assert result == json.dumps(sample_json_data, ensure_ascii=False, indent=2)
89 |
90 | @pytest.mark.asyncio
91 | async def test_json_get_specific_field(self, mock_redis_connection_manager):
92 | """Test JSON get operation for specific field."""
93 | mock_redis = mock_redis_connection_manager
94 | mock_redis.json.return_value.get.return_value = ["John Doe"]
95 |
96 | result = await json_get("test_doc", "$.name")
97 |
98 | mock_redis.json.return_value.get.assert_called_once_with("test_doc", "$.name")
99 | # json_get returns a JSON string representation
100 | assert result == json.dumps(["John Doe"], ensure_ascii=False, indent=2)
101 |
102 | @pytest.mark.asyncio
103 | async def test_json_get_default_path(
104 | self, mock_redis_connection_manager, sample_json_data
105 | ):
106 | """Test JSON get operation with default path."""
107 | mock_redis = mock_redis_connection_manager
108 | mock_redis.json.return_value.get.return_value = sample_json_data
109 |
110 | result = await json_get("test_doc")
111 |
112 | mock_redis.json.return_value.get.assert_called_once_with("test_doc", "$")
113 | # json_get returns a JSON string representation
114 | assert result == json.dumps(sample_json_data, ensure_ascii=False, indent=2)
115 |
116 | @pytest.mark.asyncio
117 | async def test_json_get_not_found(self, mock_redis_connection_manager):
118 | """Test JSON get operation when document doesn't exist."""
119 | mock_redis = mock_redis_connection_manager
120 | mock_redis.json.return_value.get.return_value = None
121 |
122 | result = await json_get("nonexistent_doc", "$")
123 |
124 | assert "No data found at path '$' in 'nonexistent_doc'" in result
125 |
126 | @pytest.mark.asyncio
127 | async def test_json_get_redis_error(self, mock_redis_connection_manager):
128 | """Test JSON get operation with Redis error."""
129 | mock_redis = mock_redis_connection_manager
130 | mock_redis.json.return_value.get.side_effect = RedisError("Connection failed")
131 |
132 | result = await json_get("test_doc", "$")
133 |
134 | assert (
135 | "Error retrieving JSON value at path '$' in 'test_doc': Connection failed"
136 | in result
137 | )
138 |
139 | @pytest.mark.asyncio
140 | async def test_json_del_success(self, mock_redis_connection_manager):
141 | """Test successful JSON delete operation."""
142 | mock_redis = mock_redis_connection_manager
143 | mock_redis.json.return_value.delete.return_value = 1
144 |
145 | result = await json_del("test_doc", "$.name")
146 |
147 | mock_redis.json.return_value.delete.assert_called_once_with(
148 | "test_doc", "$.name"
149 | )
150 | assert "Deleted JSON value at path '$.name' in 'test_doc'" in result
151 |
152 | @pytest.mark.asyncio
153 | async def test_json_del_default_path(self, mock_redis_connection_manager):
154 | """Test JSON delete operation with default path (entire document)."""
155 | mock_redis = mock_redis_connection_manager
156 | mock_redis.json.return_value.delete.return_value = 1
157 |
158 | result = await json_del("test_doc")
159 |
160 | mock_redis.json.return_value.delete.assert_called_once_with("test_doc", "$")
161 | assert "Deleted JSON value at path '$' in 'test_doc'" in result
162 |
163 | @pytest.mark.asyncio
async def test_json_del_not_found(self, mock_redis_connection_manager):
    """A delete count of 0 is reported as 'no JSON value found'."""
    json_api = mock_redis_connection_manager.json.return_value
    json_api.delete.return_value = 0

    message = await json_del("test_doc", "$.nonexistent")

    expected = "No JSON value found at path '$.nonexistent' in 'test_doc'"
    assert expected in message
172 |
173 | @pytest.mark.asyncio
async def test_json_del_redis_error(self, mock_redis_connection_manager):
    """A RedisError raised by the JSON ``delete`` call is turned into an error message."""
    failure = RedisError("Connection failed")
    json_api = mock_redis_connection_manager.json.return_value
    json_api.delete.side_effect = failure

    message = await json_del("test_doc", "$.name")

    expected = (
        "Error deleting JSON value at path '$.name' in 'test_doc': Connection failed"
    )
    assert expected in message
187 |
188 | @pytest.mark.asyncio
async def test_json_set_with_array(self, mock_redis_connection_manager):
    """A list value is passed to the JSON ``set`` call unchanged."""
    json_api = mock_redis_connection_manager.json.return_value
    json_api.set.return_value = "OK"
    items = ["item1", "item2", "item3"]

    message = await json_set("test_doc", "$.items", items)

    json_api.set.assert_called_once_with("test_doc", "$.items", items)
    assert "JSON value set at path '$.items'" in message
201 |
202 | @pytest.mark.asyncio
async def test_json_set_with_nested_object(self, mock_redis_connection_manager):
    """A deeply nested dict is passed to the JSON ``set`` call unchanged."""
    json_api = mock_redis_connection_manager.json.return_value
    json_api.set.return_value = "OK"

    # Build the document inside-out so each nesting level is visible.
    settings = {"theme": "dark", "notifications": True}
    profile = {"name": "John", "settings": settings}
    document = {"user": {"profile": profile}}

    message = await json_set("test_doc", "$", document)

    json_api.set.assert_called_once_with("test_doc", "$", document)
    assert "JSON value set at path '$'" in message
222 |
223 | @pytest.mark.asyncio
async def test_json_get_array_element(self, mock_redis_connection_manager):
    """Fetching an array element returns the client's reply as indented JSON text."""
    reply = ["first_item"]
    json_api = mock_redis_connection_manager.json.return_value
    json_api.get.return_value = reply

    text = await json_get("test_doc", "$.items[0]")

    json_api.get.assert_called_once_with("test_doc", "$.items[0]")
    # The tool hands back a serialised string, not the raw Python object.
    assert text == json.dumps(["first_item"], ensure_ascii=False, indent=2)
236 |
237 | @pytest.mark.asyncio
async def test_json_operations_with_numeric_values(
    self, mock_redis_connection_manager
):
    """An integer is forwarded as-is on set and serialised to JSON text on get."""
    json_api = mock_redis_connection_manager.json.return_value
    json_api.set.return_value = "OK"
    json_api.get.return_value = [42]

    # Write path: the raw int reaches the client.
    await json_set("test_doc", "$.count", 42)
    json_api.set.assert_called_with("test_doc", "$.count", 42)

    # Read path: the reply is returned as indented JSON.
    fetched = await json_get("test_doc", "$.count")
    assert fetched == json.dumps([42], ensure_ascii=False, indent=2)
253 |
254 | @pytest.mark.asyncio
async def test_json_operations_with_boolean_values(
    self, mock_redis_connection_manager
):
    """A boolean is forwarded as-is on set and serialised to JSON text on get."""
    json_api = mock_redis_connection_manager.json.return_value
    json_api.set.return_value = "OK"
    json_api.get.return_value = [True]

    # Write path: the raw bool reaches the client.
    await json_set("test_doc", "$.active", True)
    json_api.set.assert_called_with("test_doc", "$.active", True)

    # Read path: the reply is returned as indented JSON.
    fetched = await json_get("test_doc", "$.active")
    assert fetched == json.dumps([True], ensure_ascii=False, indent=2)
272 |
273 | @pytest.mark.asyncio
async def test_json_set_expiration_error(self, mock_redis_connection_manager):
    """A RedisError from the ``expire`` call surfaces as a JSON set error message."""
    redis_client = mock_redis_connection_manager
    redis_client.expire.side_effect = RedisError("Expire failed")
    redis_client.json.return_value.set.return_value = "OK"

    message = await json_set("test_doc", "$", {"key": "value"}, 60)

    expected = "Error setting JSON value at path '$' in 'test_doc': Expire failed"
    assert expected in message
286 |
287 | @pytest.mark.asyncio
async def test_json_del_multiple_matches(self, mock_redis_connection_manager):
    """A delete count above one is still reported as a successful deletion."""
    deleted_count = 3  # the path matched several elements
    json_api = mock_redis_connection_manager.json.return_value
    json_api.delete.return_value = deleted_count

    message = await json_del("test_doc", "$..name")

    json_api.delete.assert_called_once_with("test_doc", "$..name")
    assert "Deleted JSON value at path '$..name'" in message
301 |
302 | @pytest.mark.asyncio
async def test_json_operations_with_null_values(
    self, mock_redis_connection_manager
):
    """``None`` is forwarded as-is on set and serialised to JSON text on get."""
    json_api = mock_redis_connection_manager.json.return_value
    json_api.set.return_value = "OK"
    json_api.get.return_value = [None]

    # Write path: None reaches the client untouched.
    await json_set("test_doc", "$.optional_field", None)
    json_api.set.assert_called_with("test_doc", "$.optional_field", None)

    # Read path: the reply is returned as indented JSON.
    fetched = await json_get("test_doc", "$.optional_field")
    assert fetched == json.dumps([None], ensure_ascii=False, indent=2)
320 |
```
--------------------------------------------------------------------------------
/tests/tools/test_redis_query_engine.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Unit tests for src/tools/redis_query_engine.py
3 | """
4 |
5 | import json
6 | from unittest.mock import Mock, patch
7 |
8 | import pytest
9 | from redis.commands.search.field import VectorField
10 | from redis.commands.search.index_definition import IndexDefinition
11 | from redis.commands.search.query import Query
12 | from redis.exceptions import RedisError
13 |
14 | from src.tools.redis_query_engine import (
15 | create_vector_index_hash,
16 | get_index_info,
17 | get_indexes,
18 | vector_search_hash,
19 | )
20 |
21 |
22 | class TestRedisQueryEngineOperations:
23 | """Test cases for Redis query engine operations."""
24 |
25 | @pytest.mark.asyncio
async def test_get_indexes_success(self, mock_redis_connection_manager):
    """get_indexes issues ``FT._LIST`` and returns the reply as a JSON string."""
    index_names = ["index1", "index2", "vector_index"]
    redis_client = mock_redis_connection_manager
    redis_client.execute_command.return_value = index_names

    payload = await get_indexes()

    redis_client.execute_command.assert_called_once_with("FT._LIST")
    assert payload == json.dumps(index_names)
36 |
37 | @pytest.mark.asyncio
async def test_get_indexes_empty(self, mock_redis_connection_manager):
    """An empty reply from the server is returned as an empty JSON list."""
    mock_redis_connection_manager.execute_command.return_value = []

    payload = await get_indexes()

    assert payload == json.dumps([])
46 |
47 | @pytest.mark.asyncio
async def test_get_indexes_redis_error(self, mock_redis_connection_manager):
    """A RedisError from ``execute_command`` is reported as an error message."""
    failure = RedisError("Search module not loaded")
    mock_redis_connection_manager.execute_command.side_effect = failure

    message = await get_indexes()

    expected = "Error retrieving indexes: Search module not loaded"
    assert expected in message
56 |
57 | @pytest.mark.asyncio
async def test_create_vector_index_hash_success(
    self, mock_redis_connection_manager
):
    """With no arguments, one VectorField named 'vector' is created on 'vector_index'."""
    search_client = Mock()
    search_client.create_index.return_value = "OK"
    redis_client = mock_redis_connection_manager
    redis_client.ft.return_value = search_client

    message = await create_vector_index_hash()

    redis_client.ft.assert_called_once_with("vector_index")
    search_client.create_index.assert_called_once()

    # Inspect what was handed to create_index: the schema is the first
    # positional argument, the index definition is passed by keyword.
    positional, keyword = search_client.create_index.call_args
    schema = positional[0]
    index_definition = keyword["definition"]

    assert len(schema) == 1
    assert isinstance(schema[0], VectorField)
    assert schema[0].name == "vector"
    assert isinstance(index_definition, IndexDefinition)

    assert "Index 'vector_index' created successfully." in message
83 |
84 | @pytest.mark.asyncio
async def test_create_vector_index_hash_custom_params(
    self, mock_redis_connection_manager
):
    """Custom index and field names are honoured when creating the index."""
    search_client = Mock()
    search_client.create_index.return_value = "OK"
    redis_client = mock_redis_connection_manager
    redis_client.ft.return_value = search_client

    options = {
        "index_name": "custom_index",
        "vector_field": "embedding",
        "dim": 512,
        "distance_metric": "COSINE",
    }
    message = await create_vector_index_hash(**options)

    redis_client.ft.assert_called_once_with("custom_index")

    # The schema (first positional argument) must carry the custom field name.
    positional, _ = search_client.create_index.call_args
    schema = positional[0]

    assert schema[0].name == "embedding"
    assert "Index 'custom_index' created successfully." in message
109 |
110 | @pytest.mark.asyncio
async def test_create_vector_index_hash_redis_error(
    self, mock_redis_connection_manager
):
    """A RedisError from ``create_index`` is reported as an error message."""
    search_client = Mock()
    search_client.create_index.side_effect = RedisError("Index already exists")
    mock_redis_connection_manager.ft.return_value = search_client

    message = await create_vector_index_hash()

    expected = "Error creating index 'vector_index': Index already exists"
    assert expected in message
123 |
124 | @pytest.mark.asyncio
async def test_vector_search_hash_success(
    self, mock_redis_connection_manager, sample_vector
):
    """A search returning two documents yields a two-element list."""
    redis_client = mock_redis_connection_manager
    search_client = Mock()
    redis_client.ft.return_value = search_client

    # Fake search hits: each document exposes its fields through __dict__.
    hits = []
    for doc_id, score in (("doc1", "0.95"), ("doc2", "0.87")):
        hit = Mock()
        hit.__dict__ = {"id": doc_id, "vector": "binary_data", "score": score}
        hits.append(hit)

    search_reply = Mock()
    search_reply.docs = hits
    search_client.search.return_value = search_reply

    with patch("numpy.array") as fake_array:
        fake_array.return_value.tobytes.return_value = b"query_vector_bytes"

        found = await vector_search_hash(sample_vector)

        redis_client.ft.assert_called_once_with("vector_index")
        search_client.search.assert_called_once()

        # The first positional argument of search() must be a Query object.
        positional, _ = search_client.search.call_args
        assert isinstance(positional[0], Query)

        assert isinstance(found, list)
        assert len(found) == 2
157 |
158 | @pytest.mark.asyncio
async def test_vector_search_hash_custom_params(
    self, mock_redis_connection_manager, sample_vector
):
    """A custom index name is used for the search and a list is returned."""
    redis_client = mock_redis_connection_manager
    search_client = Mock()
    redis_client.ft.return_value = search_client

    empty_reply = Mock()
    empty_reply.docs = []
    search_client.search.return_value = empty_reply

    with patch("numpy.array") as fake_array:
        fake_array.return_value.tobytes.return_value = b"query_vector_bytes"

        options = {
            "query_vector": sample_vector,
            "index_name": "custom_index",
            "vector_field": "embedding",
            "k": 10,
            "return_fields": ["title", "content"],
        }
        found = await vector_search_hash(**options)

        redis_client.ft.assert_called_once_with("custom_index")
        assert isinstance(found, list)
184 |
185 | @pytest.mark.asyncio
async def test_vector_search_hash_no_results(
    self, mock_redis_connection_manager, sample_vector
):
    """A search reply without documents produces an empty list."""
    search_client = Mock()
    mock_redis_connection_manager.ft.return_value = search_client

    empty_reply = Mock()
    empty_reply.docs = []
    search_client.search.return_value = empty_reply

    with patch("numpy.array") as fake_array:
        fake_array.return_value.tobytes.return_value = b"query_vector_bytes"

        found = await vector_search_hash(sample_vector)

        # No hits means an empty list, not an error string.
        assert found == []
204 |
205 | @pytest.mark.asyncio
async def test_vector_search_hash_redis_error(
    self, mock_redis_connection_manager, sample_vector
):
    """A RedisError from ``search`` is reported as an error message."""
    search_client = Mock()
    search_client.search.side_effect = RedisError("Index not found")
    mock_redis_connection_manager.ft.return_value = search_client

    with patch("numpy.array") as fake_array:
        converted = fake_array.return_value.astype.return_value
        converted.tobytes.return_value = b"query_vector_bytes"

        message = await vector_search_hash(sample_vector)

        expected = (
            "Error performing vector search on index 'vector_index': Index not found"
        )
        assert expected in message
226 |
227 | @pytest.mark.asyncio
async def test_get_index_info_success(self, mock_redis_connection_manager):
    """The info dict returned by the search client comes back as indented JSON."""
    redis_client = mock_redis_connection_manager
    search_client = Mock()
    redis_client.ft.return_value = search_client

    index_info = {
        "index_name": "vector_index",
        "index_options": [],
        "index_definition": ["key_type", "HASH", "prefixes", ["doc:"]],
        "attributes": [
            ["identifier", "vector", "attribute", "vector", "type", "VECTOR"]
        ],
        "num_docs": "100",
        "max_doc_id": "100",
        "num_terms": "0",
        "num_records": "100",
        "inverted_sz_mb": "0.00",
        "vector_index_sz_mb": "1.50",
        "total_inverted_index_blocks": "0",
        "offset_vectors_sz_mb": "0.00",
        "doc_table_size_mb": "0.01",
        "sortable_values_size_mb": "0.00",
        "key_table_size_mb": "0.00",
    }
    search_client.info.return_value = index_info

    text = await get_index_info("vector_index")

    redis_client.ft.assert_called_once_with("vector_index")
    search_client.info.assert_called_once()
    # The tool serialises the info dict rather than returning it directly.
    assert text == json.dumps(index_info, ensure_ascii=False, indent=2)
261 |
262 | @pytest.mark.asyncio
async def test_get_index_info_default_index(self, mock_redis_connection_manager):
    """Requesting 'vector_index' queries that index and returns its info as JSON."""
    redis_client = mock_redis_connection_manager
    search_client = Mock()
    search_client.info.return_value = {"index_name": "vector_index"}
    redis_client.ft.return_value = search_client

    text = await get_index_info("vector_index")

    redis_client.ft.assert_called_once_with("vector_index")
    # The tool serialises the info dict rather than returning it directly.
    expected = json.dumps(
        {"index_name": "vector_index"}, ensure_ascii=False, indent=2
    )
    assert text == expected
277 |
278 | @pytest.mark.asyncio
async def test_get_index_info_redis_error(self, mock_redis_connection_manager):
    """A RedisError from ``info`` is reported as an error message."""
    search_client = Mock()
    search_client.info.side_effect = RedisError("Index not found")
    mock_redis_connection_manager.ft.return_value = search_client

    message = await get_index_info("nonexistent_index")

    expected = "Error retrieving index info: Index not found"
    assert expected in message
289 |
290 | @pytest.mark.asyncio
async def test_create_vector_index_different_metrics(
    self, mock_redis_connection_manager
):
    """Index creation reaches ``create_index`` for both the L2 and IP metrics."""
    search_client = Mock()
    search_client.create_index.return_value = "OK"
    mock_redis_connection_manager.ft.return_value = search_client

    for position, metric in enumerate(("L2", "IP")):
        if position:
            # Forget the previous metric's calls before the next round.
            search_client.reset_mock()
        await create_vector_index_hash(distance_metric=metric)
        search_client.create_index.assert_called()
308 |
309 | @pytest.mark.asyncio
async def test_vector_search_with_large_k(
    self, mock_redis_connection_manager, sample_vector
):
    """Passing k=1000 still performs exactly one search and returns a list."""
    search_client = Mock()
    mock_redis_connection_manager.ft.return_value = search_client

    empty_reply = Mock()
    empty_reply.docs = []
    search_client.search.return_value = empty_reply

    with patch("numpy.array") as fake_array:
        converted = fake_array.return_value.astype.return_value
        converted.tobytes.return_value = b"query_vector_bytes"

        found = await vector_search_hash(sample_vector, k=1000)
        # No hits means an empty list.
        assert found == []

        # A large k must not prevent the search from being issued.
        search_client.search.assert_called_once()
332 |
333 | @pytest.mark.asyncio
async def test_connection_manager_called_correctly(self):
    """get_indexes obtains its client through RedisConnectionManager.get_connection."""
    target = "src.tools.redis_query_engine.RedisConnectionManager.get_connection"
    with patch(target) as get_connection:
        redis_client = Mock()
        redis_client.execute_command.return_value = []
        get_connection.return_value = redis_client

        await get_indexes()

        get_connection.assert_called_once()
346 |
```
--------------------------------------------------------------------------------
/tests/test_integration.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Integration tests for Redis MCP Server.
3 |
4 | These tests actually start the MCP server process and verify it can handle real requests.
5 | """
6 |
7 | import json
8 | import subprocess
9 | import sys
10 | import time
11 | import os
12 | from pathlib import Path
13 |
14 | import pytest
15 |
16 |
17 | def _redis_available():
18 | """Check if Redis is available for testing."""
19 | try:
20 | import redis
21 |
22 | r = redis.Redis(host="localhost", port=6379, decode_responses=True)
23 | r.ping()
24 | return True
25 | except Exception:
26 | return False
27 |
28 |
def _create_server_process(project_root):
    """Launch ``python -m src.main`` in *project_root* with UTF-8 text pipes.

    REDIS_HOST and REDIS_PORT default to localhost:6379; values already present
    in the current environment take precedence over those defaults.
    """
    child_env = {"REDIS_HOST": "localhost", "REDIS_PORT": "6379"}
    child_env.update(os.environ)

    command = [sys.executable, "-m", "src.main"]
    return subprocess.Popen(
        command,
        cwd=project_root,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
        encoding="utf-8",
        # Undecodable bytes become replacement characters instead of raising.
        errors="replace",
        env=child_env,
    )
42 |
43 |
44 | @pytest.mark.integration
45 | class TestMCPServerIntegration:
46 | """Integration tests that start the actual MCP server."""
47 |
48 | @pytest.fixture
def server_process(self):
    """Start the MCP server as a subprocess and yield it to the test.

    Teardown terminates the process (killing it if it does not exit within
    five seconds) and then closes its stdin/stdout/stderr pipes so no file
    descriptors are leaked between tests.
    """
    # The project root is the parent of the tests directory.
    project_root = Path(__file__).parent.parent

    process = _create_server_process(project_root)

    # Give the server a moment to start.
    time.sleep(1)

    yield process

    # Stop the server, escalating to kill if it ignores terminate().
    process.terminate()
    try:
        process.wait(timeout=5)
    except subprocess.TimeoutExpired:
        process.kill()
        process.wait()

    # Release the pipe handles; Popen does not close them on its own here.
    for stream in (process.stdin, process.stdout, process.stderr):
        if stream is None:
            continue
        try:
            stream.close()
        except OSError:
            # Flushing stdin into a dead process can raise BrokenPipeError;
            # the handle is released regardless.
            pass
69 |
def test_server_starts_successfully(self, server_process):
    """The server process is alive right after startup and half a second later."""
    # poll() returns None while the child is still running.
    exit_code = server_process.poll()
    assert exit_code is None, "Server process should be running"

    # Allow time for any startup output before checking again.
    time.sleep(0.5)

    exit_code = server_process.poll()
    assert exit_code is None
81 |
def test_server_handles_unicode_on_windows(self, server_process):
    """Reading server output must not raise UnicodeDecodeError.

    The read attempt is best-effort: any other exception raised while
    probing the output is ignored, and the test finally checks that the
    server process is still alive.
    """
    assert server_process.poll() is None, "Server process should be running"

    try:
        import select
        import sys

        if sys.platform != "win32":
            # Wait at most 0.1s for stdout to become readable.
            readable, _, _ = select.select([server_process.stdout], [], [], 0.1)
            if readable:
                try:
                    # One character is enough to exercise the decoder.
                    server_process.stdout.read(1)
                    assert True
                except UnicodeDecodeError:
                    pytest.fail("Unicode decode error occurred")
        else:
            # select() is not used on Windows here; only check liveness.
            time.sleep(0.1)
            assert server_process.poll() is None

    except Exception:
        # Other errors are irrelevant to the Unicode check.
        pass

    # Main assertion: the process should still be running.
    assert server_process.poll() is None
117 |
def test_server_responds_to_initialize_request(self, server_process):
    """An MCP ``initialize`` request gets a JSON-RPC 2.0 result with matching id."""
    client_info = {"name": "test-client", "version": "1.0.0"}
    handshake = {
        "jsonrpc": "2.0",
        "id": 1,
        "method": "initialize",
        "params": {
            "protocolVersion": "2024-11-05",
            "capabilities": {},
            "clientInfo": client_info,
        },
    }

    # Requests are newline-delimited JSON on the server's stdin.
    server_process.stdin.write(json.dumps(handshake) + "\n")
    server_process.stdin.flush()

    reply_line = server_process.stdout.readline()
    assert reply_line.strip(), "Server should respond to initialize request"

    try:
        reply = json.loads(reply_line)
    except json.JSONDecodeError:
        pytest.fail(f"Invalid JSON response: {reply_line}")
    else:
        assert reply.get("jsonrpc") == "2.0"
        assert reply.get("id") == 1
        assert "result" in reply
149 |
def test_server_lists_tools(self, server_process):
    """The ``tools/list`` response contains every expected Redis tool.

    The handshake (initialize request, response read, initialized
    notification) is delegated to ``_initialize_server``.
    """
    self._initialize_server(server_process)

    # Request the tools list.
    tools_request = {"jsonrpc": "2.0", "id": 2, "method": "tools/list"}
    server_process.stdin.write(json.dumps(tools_request) + "\n")
    server_process.stdin.flush()

    # Read and decode the response.
    response_line = server_process.stdout.readline()
    response = json.loads(response_line)

    assert response.get("jsonrpc") == "2.0"
    assert response.get("id") == 2
    assert "result" in response
    assert "tools" in response["result"]

    tools = response["result"]["tools"]
    tool_names = [tool["name"] for tool in tools]

    # Basic Redis operations the server should expose.
    expected_tools = [
        "hset",
        "hget",
        "hdel",
        "hgetall",
        "hexists",
        "set_vector_in_hash",
        "get_vector_from_hash",
        "json_set",
        "json_get",
        "json_del",
        "lpush",
        "rpush",
        "lpop",
        "rpop",
        "lrange",
        "llen",
        "delete",
        "type",
        "expire",
        "rename",
        "scan_keys",
        "scan_all_keys",
        "publish",
        "subscribe",
        "unsubscribe",
        "get_indexes",
        "get_index_info",
        "get_indexed_keys_number",
        "create_vector_index_hash",
        "vector_search_hash",
        "dbsize",
        "info",
        "client_list",
        "sadd",
        "srem",
        "smembers",
        "zadd",
        "zrange",
        "zrem",
        "xadd",
        "xrange",
        "xdel",
        "set",
        "get",
    ]

    # Iterate the EXPECTED names and look each one up among the registered
    # tools, so the check and its failure message agree with each other.
    registered = set(tool_names)
    for tool in expected_tools:
        assert tool in registered, (
            f"Expected tool '{tool}' not found in {tool_names}"
        )
246 |
def test_server_tool_count_and_names(self, server_process):
    """The server registers exactly the expected set of 44 tools.

    Checks the total count, that no expected tool is missing, that no
    unexpected tool is present, and that every per-category tool is there.
    """
    self._initialize_server(server_process)

    # Ask for the tool list.
    listing_request = {"jsonrpc": "2.0", "id": 3, "method": "tools/list"}
    server_process.stdin.write(json.dumps(listing_request) + "\n")
    server_process.stdin.flush()

    reply = json.loads(server_process.stdout.readline())

    assert reply.get("jsonrpc") == "2.0"
    assert reply.get("id") == 3
    assert "result" in reply
    assert "tools" in reply["result"]

    tools = reply["result"]["tools"]
    tool_names = [entry["name"] for entry in tools]

    # Count of @mcp.tool() registrations the test expects.
    expected_tool_count = 44
    assert len(tools) == expected_tool_count, (
        f"Expected {expected_tool_count} tools, but got {len(tools)}"
    )

    # Alphabetically sorted for easier verification.
    expected_tools = [
        "client_list", "create_vector_index_hash", "dbsize", "delete",
        "expire", "get", "get_index_info", "get_indexed_keys_number",
        "get_indexes", "get_vector_from_hash", "hdel", "hexists",
        "hget", "hgetall", "hset", "info",
        "json_del", "json_get", "json_set", "llen",
        "lpop", "lpush", "lrange", "publish",
        "rename", "rpop", "rpush", "sadd",
        "scan_all_keys", "scan_keys", "set", "set_vector_in_hash",
        "smembers", "srem", "subscribe", "type",
        "unsubscribe", "vector_search_hash", "xadd", "xdel",
        "xrange", "zadd", "zrange", "zrem",
    ]

    # Compare as sets in both directions.
    expected_set = set(expected_tools)
    registered_set = set(tool_names)
    missing_tools = expected_set - registered_set
    extra_tools = registered_set - expected_set

    assert not missing_tools, f"Missing expected tools: {sorted(missing_tools)}"
    assert not extra_tools, f"Found unexpected tools: {sorted(extra_tools)}"

    # Every tool category must be represented by all of its tools.
    tool_categories = {
        "string": ["get", "set"],
        "hash": ["hget", "hset", "hgetall", "hdel", "hexists"],
        "list": ["lpush", "rpush", "lpop", "rpop", "lrange", "llen"],
        "set": ["sadd", "srem", "smembers"],
        "sorted_set": ["zadd", "zrem", "zrange"],
        "stream": ["xadd", "xdel", "xrange"],
        "json": ["json_get", "json_set", "json_del"],
        "pub_sub": ["publish", "subscribe", "unsubscribe"],
        "server_mgmt": ["dbsize", "info", "client_list"],
        "misc": [
            "delete",
            "expire",
            "rename",
            "type",
            "scan_keys",
            "scan_all_keys",
        ],
        "vector_search": [
            "create_vector_index_hash",
            "vector_search_hash",
            "get_indexes",
            "get_index_info",
            "set_vector_in_hash",
            "get_vector_from_hash",
            "get_indexed_keys_number",
        ],
    }

    for category in tool_categories:
        for tool in tool_categories[category]:
            assert tool in tool_names, (
                f"Tool '{tool}' from category '{category}' not found in registered tools"
            )
366 |
def _initialize_server(self, server_process):
    """Perform the MCP handshake on *server_process*.

    Sends the ``initialize`` request, reads (and discards) one response
    line, then sends the ``notifications/initialized`` notification.
    """
    stdin = server_process.stdin

    handshake = {
        "jsonrpc": "2.0",
        "id": 1,
        "method": "initialize",
        "params": {
            "protocolVersion": "2024-11-05",
            "capabilities": {},
            "clientInfo": {"name": "test-client", "version": "1.0.0"},
        },
    }
    stdin.write(json.dumps(handshake) + "\n")
    stdin.flush()
    # Consume the initialize response so later reads see only new replies.
    server_process.stdout.readline()

    ready_notice = {
        "jsonrpc": "2.0",
        "method": "notifications/initialized",
    }
    stdin.write(json.dumps(ready_notice) + "\n")
    stdin.flush()
392 |
```
--------------------------------------------------------------------------------
/tests/tools/test_hash.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Unit tests for src/tools/hash.py
3 | """
4 |
5 | import numpy as np
6 | import pytest
7 | from redis.exceptions import RedisError
8 |
9 | from src.tools.hash import (
10 | get_vector_from_hash,
11 | hdel,
12 | hexists,
13 | hget,
14 | hgetall,
15 | hset,
16 | set_vector_in_hash,
17 | )
18 |
19 |
20 | class TestHashOperations:
21 | """Test cases for Redis hash operations."""
22 |
23 | @pytest.mark.asyncio
async def test_hset_success(self, mock_redis_connection_manager):
    """Setting a string field forwards it to ``hset`` and reports success."""
    redis_client = mock_redis_connection_manager
    redis_client.hset.return_value = 1

    message = await hset("test_hash", "field1", "value1")

    redis_client.hset.assert_called_once_with("test_hash", "field1", "value1")
    expected = "Field 'field1' set successfully in hash 'test_hash'."
    assert expected in message
33 |
34 | @pytest.mark.asyncio
async def test_hset_with_expiration(self, mock_redis_connection_manager):
    """Passing an expiry makes hset also call ``expire`` and mention it in the result."""
    redis_client = mock_redis_connection_manager
    redis_client.expire.return_value = True
    redis_client.hset.return_value = 1

    message = await hset("test_hash", "field1", "value1", 60)

    redis_client.hset.assert_called_once_with("test_hash", "field1", "value1")
    redis_client.expire.assert_called_once_with("test_hash", 60)
    assert "Expires in 60 seconds." in message
46 |
47 | @pytest.mark.asyncio
async def test_hset_integer_value(self, mock_redis_connection_manager):
    """An integer value reaches the client as its string form ("42")."""
    redis_client = mock_redis_connection_manager
    redis_client.hset.return_value = 1

    message = await hset("test_hash", "count", 42)

    redis_client.hset.assert_called_once_with("test_hash", "count", "42")
    expected = "Field 'count' set successfully in hash 'test_hash'."
    assert expected in message
57 |
58 | @pytest.mark.asyncio
59 | async def test_hset_float_value(self, mock_redis_connection_manager):
60 | """Test hash set operation with float value."""
61 | mock_redis = mock_redis_connection_manager
62 | mock_redis.hset.return_value = 1
63 |
64 | result = await hset("test_hash", "price", 19.99)
65 |
66 | mock_redis.hset.assert_called_once_with("test_hash", "price", "19.99")
67 | assert "Field 'price' set successfully in hash 'test_hash'." in result
68 |
69 | @pytest.mark.asyncio
70 | async def test_hset_redis_error(self, mock_redis_connection_manager):
71 | """Test hash set operation with Redis error."""
72 | mock_redis = mock_redis_connection_manager
73 | mock_redis.hset.side_effect = RedisError("Connection failed")
74 |
75 | result = await hset("test_hash", "field1", "value1")
76 |
77 | assert (
78 | "Error setting field 'field1' in hash 'test_hash': Connection failed"
79 | in result
80 | )
81 |
82 | @pytest.mark.asyncio
83 | async def test_hget_success(self, mock_redis_connection_manager):
84 | """Test successful hash get operation."""
85 | mock_redis = mock_redis_connection_manager
86 | mock_redis.hget.return_value = "value1"
87 |
88 | result = await hget("test_hash", "field1")
89 |
90 | mock_redis.hget.assert_called_once_with("test_hash", "field1")
91 | assert result == "value1"
92 |
93 | @pytest.mark.asyncio
94 | async def test_hget_field_not_found(self, mock_redis_connection_manager):
95 | """Test hash get operation when field doesn't exist."""
96 | mock_redis = mock_redis_connection_manager
97 | mock_redis.hget.return_value = None
98 |
99 | result = await hget("test_hash", "nonexistent_field")
100 |
101 | assert "Field 'nonexistent_field' not found in hash 'test_hash'" in result
102 |
103 | @pytest.mark.asyncio
104 | async def test_hget_redis_error(self, mock_redis_connection_manager):
105 | """Test hash get operation with Redis error."""
106 | mock_redis = mock_redis_connection_manager
107 | mock_redis.hget.side_effect = RedisError("Connection failed")
108 |
109 | result = await hget("test_hash", "field1")
110 |
111 | assert (
112 | "Error getting field 'field1' from hash 'test_hash': Connection failed"
113 | in result
114 | )
115 |
116 | @pytest.mark.asyncio
117 | async def test_hgetall_success(self, mock_redis_connection_manager):
118 | """Test successful hash get all operation."""
119 | mock_redis = mock_redis_connection_manager
120 | mock_redis.hgetall.return_value = {"field1": "value1", "field2": "value2"}
121 |
122 | result = await hgetall("test_hash")
123 |
124 | mock_redis.hgetall.assert_called_once_with("test_hash")
125 | assert result == {"field1": "value1", "field2": "value2"}
126 |
127 | @pytest.mark.asyncio
128 | async def test_hgetall_empty_hash(self, mock_redis_connection_manager):
129 | """Test hash get all operation on empty hash."""
130 | mock_redis = mock_redis_connection_manager
131 | mock_redis.hgetall.return_value = {}
132 |
133 | result = await hgetall("empty_hash")
134 |
135 | assert "Hash 'empty_hash' is empty or does not exist" in result
136 |
137 | @pytest.mark.asyncio
138 | async def test_hgetall_redis_error(self, mock_redis_connection_manager):
139 | """Test hash get all operation with Redis error."""
140 | mock_redis = mock_redis_connection_manager
141 | mock_redis.hgetall.side_effect = RedisError("Connection failed")
142 |
143 | result = await hgetall("test_hash")
144 |
145 | assert (
146 | "Error getting all fields from hash 'test_hash': Connection failed"
147 | in result
148 | )
149 |
150 | @pytest.mark.asyncio
151 | async def test_hdel_success(self, mock_redis_connection_manager):
152 | """Test successful hash delete operation."""
153 | mock_redis = mock_redis_connection_manager
154 | mock_redis.hdel.return_value = 1
155 |
156 | result = await hdel("test_hash", "field1")
157 |
158 | mock_redis.hdel.assert_called_once_with("test_hash", "field1")
159 | assert "Field 'field1' deleted from hash 'test_hash'." in result
160 |
161 | @pytest.mark.asyncio
162 | async def test_hdel_field_not_found(self, mock_redis_connection_manager):
163 | """Test hash delete operation when field doesn't exist."""
164 | mock_redis = mock_redis_connection_manager
165 | mock_redis.hdel.return_value = 0
166 |
167 | result = await hdel("test_hash", "nonexistent_field")
168 |
169 | assert "Field 'nonexistent_field' not found in hash 'test_hash'" in result
170 |
171 | @pytest.mark.asyncio
172 | async def test_hdel_redis_error(self, mock_redis_connection_manager):
173 | """Test hash delete operation with Redis error."""
174 | mock_redis = mock_redis_connection_manager
175 | mock_redis.hdel.side_effect = RedisError("Connection failed")
176 |
177 | result = await hdel("test_hash", "field1")
178 |
179 | assert (
180 | "Error deleting field 'field1' from hash 'test_hash': Connection failed"
181 | in result
182 | )
183 |
184 | @pytest.mark.asyncio
185 | async def test_hexists_field_exists(self, mock_redis_connection_manager):
186 | """Test hash exists operation when field exists."""
187 | mock_redis = mock_redis_connection_manager
188 | mock_redis.hexists.return_value = True
189 |
190 | result = await hexists("test_hash", "field1")
191 |
192 | mock_redis.hexists.assert_called_once_with("test_hash", "field1")
193 | assert result is True
194 |
195 | @pytest.mark.asyncio
196 | async def test_hexists_field_not_exists(self, mock_redis_connection_manager):
197 | """Test hash exists operation when field doesn't exist."""
198 | mock_redis = mock_redis_connection_manager
199 | mock_redis.hexists.return_value = False
200 |
201 | result = await hexists("test_hash", "nonexistent_field")
202 |
203 | assert result is False
204 |
205 | @pytest.mark.asyncio
206 | async def test_hexists_redis_error(self, mock_redis_connection_manager):
207 | """Test hash exists operation with Redis error."""
208 | mock_redis = mock_redis_connection_manager
209 | mock_redis.hexists.side_effect = RedisError("Connection failed")
210 |
211 | result = await hexists("test_hash", "field1")
212 |
213 | assert (
214 | "Error checking existence of field 'field1' in hash 'test_hash': Connection failed"
215 | in result
216 | )
217 |
218 | @pytest.mark.asyncio
219 | async def test_set_vector_in_hash_success(
220 | self, mock_redis_connection_manager, mock_numpy_array
221 | ):
222 | """Test successful vector set operation in hash."""
223 | mock_redis = mock_redis_connection_manager
224 | mock_redis.hset.return_value = 1
225 |
226 | vector = [0.1, 0.2, 0.3, 0.4, 0.5]
227 | result = await set_vector_in_hash("test_hash", vector)
228 |
229 | mock_numpy_array.assert_called_once_with(vector, dtype=np.float32)
230 | mock_redis.hset.assert_called_once_with(
231 | "test_hash", "vector", b"mock_binary_data"
232 | )
233 | assert result is True
234 |
235 | @pytest.mark.asyncio
236 | async def test_set_vector_in_hash_custom_field(
237 | self, mock_redis_connection_manager, mock_numpy_array
238 | ):
239 | """Test vector set operation with custom field name."""
240 | mock_redis = mock_redis_connection_manager
241 | mock_redis.hset.return_value = 1
242 |
243 | vector = [0.1, 0.2, 0.3]
244 | result = await set_vector_in_hash("test_hash", vector, "custom_vector")
245 |
246 | mock_redis.hset.assert_called_once_with(
247 | "test_hash", "custom_vector", b"mock_binary_data"
248 | )
249 | assert result is True
250 |
251 | @pytest.mark.asyncio
252 | async def test_set_vector_in_hash_redis_error(
253 | self, mock_redis_connection_manager, mock_numpy_array
254 | ):
255 | """Test vector set operation with Redis error."""
256 | mock_redis = mock_redis_connection_manager
257 | mock_redis.hset.side_effect = RedisError("Connection failed")
258 |
259 | vector = [0.1, 0.2, 0.3]
260 | result = await set_vector_in_hash("test_hash", vector)
261 |
262 | assert (
263 | "Error storing vector in hash 'test_hash' with field 'vector': Connection failed"
264 | in result
265 | )
266 |
267 | @pytest.mark.asyncio
268 | async def test_get_vector_from_hash_success(
269 | self, mock_redis_connection_manager, mock_numpy_frombuffer
270 | ):
271 | """Test successful vector get operation from hash."""
272 | mock_redis = mock_redis_connection_manager
273 | mock_redis.hget.return_value = b"mock_binary_data"
274 |
275 | result = await get_vector_from_hash("test_hash")
276 |
277 | mock_redis.hget.assert_called_once_with("test_hash", "vector")
278 | mock_numpy_frombuffer.assert_called_once_with(
279 | b"mock_binary_data", dtype=np.float32
280 | )
281 | assert result == [0.1, 0.2, 0.3]
282 |
283 | @pytest.mark.asyncio
284 | async def test_get_vector_from_hash_custom_field(
285 | self, mock_redis_connection_manager, mock_numpy_frombuffer
286 | ):
287 | """Test vector get operation with custom field name."""
288 | mock_redis = mock_redis_connection_manager
289 | mock_redis.hget.return_value = b"mock_binary_data"
290 |
291 | result = await get_vector_from_hash("test_hash", "custom_vector")
292 |
293 | mock_redis.hget.assert_called_once_with("test_hash", "custom_vector")
294 | assert result == [0.1, 0.2, 0.3]
295 |
296 | @pytest.mark.asyncio
297 | async def test_get_vector_from_hash_not_found(self, mock_redis_connection_manager):
298 | """Test vector get operation when field doesn't exist."""
299 | mock_redis = mock_redis_connection_manager
300 | mock_redis.hget.return_value = None
301 |
302 | result = await get_vector_from_hash("test_hash")
303 |
304 | assert "Field 'vector' not found in hash 'test_hash'." in result
305 |
306 | @pytest.mark.asyncio
307 | async def test_get_vector_from_hash_redis_error(
308 | self, mock_redis_connection_manager
309 | ):
310 | """Test vector get operation with Redis error."""
311 | mock_redis = mock_redis_connection_manager
312 | mock_redis.hget.side_effect = RedisError("Connection failed")
313 |
314 | result = await get_vector_from_hash("test_hash")
315 |
316 | assert (
317 | "Error retrieving vector field 'vector' from hash 'test_hash': Connection failed"
318 | in result
319 | )
320 |
321 | @pytest.mark.asyncio
322 | async def test_hset_expiration_error(self, mock_redis_connection_manager):
323 | """Test hash set operation when expiration fails."""
324 | mock_redis = mock_redis_connection_manager
325 | mock_redis.hset.return_value = 1
326 | mock_redis.expire.side_effect = RedisError("Expire failed")
327 |
328 | result = await hset("test_hash", "field1", "value1", 60)
329 |
330 | # Should still report success for hset, but mention expire error
331 | assert (
332 | "Error setting field 'field1' in hash 'test_hash': Expire failed" in result
333 | )
334 |
335 | @pytest.mark.asyncio
336 | async def test_vector_operations_with_empty_vector(
337 | self, mock_redis_connection_manager, mock_numpy_array
338 | ):
339 | """Test vector operations with empty vector."""
340 | mock_redis = mock_redis_connection_manager
341 | mock_redis.hset.return_value = 1
342 |
343 | empty_vector = []
344 | result = await set_vector_in_hash("test_hash", empty_vector)
345 |
346 | mock_numpy_array.assert_called_once_with(empty_vector, dtype=np.float32)
347 | assert result is True
348 |
349 | @pytest.mark.asyncio
350 | async def test_vector_operations_with_large_vector(
351 | self, mock_redis_connection_manager, mock_numpy_array
352 | ):
353 | """Test vector operations with large vector."""
354 | mock_redis = mock_redis_connection_manager
355 | mock_redis.hset.return_value = 1
356 |
357 | large_vector = [0.1] * 1000 # 1000-dimensional vector
358 | result = await set_vector_in_hash("test_hash", large_vector)
359 |
360 | mock_numpy_array.assert_called_once_with(large_vector, dtype=np.float32)
361 | assert result is True
362 |
```
--------------------------------------------------------------------------------
/tests/test_entraid_auth.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Unit tests for Entra ID authentication in src/common/entraid_auth.py
3 | """
4 |
5 | from unittest.mock import Mock, patch
6 |
7 | import pytest
8 |
9 | from src.common.entraid_auth import (
10 | create_credential_provider,
11 | EntraIDAuthenticationError,
12 | _create_token_manager_config,
13 | _create_service_principal_provider,
14 | _create_managed_identity_provider,
15 | _create_default_credential_provider,
16 | )
17 |
18 |
19 | class TestCreateCredentialProvider:
20 | """Test cases for create_credential_provider function."""
21 |
22 | def setup_method(self):
23 | """Set up test fixtures."""
24 | self.original_entraid_cfg = {}
25 | self.original_entraid_available = None
26 |
27 | def teardown_method(self):
28 | """Restore original state."""
29 | pass
30 |
31 | @patch("src.common.entraid_auth.is_entraid_auth_enabled")
32 | def test_returns_none_when_auth_disabled(self, mock_is_enabled):
33 | """Test that None is returned when Entra ID auth is disabled."""
34 | mock_is_enabled.return_value = False
35 |
36 | result = create_credential_provider()
37 |
38 | assert result is None
39 | mock_is_enabled.assert_called_once()
40 |
41 | @patch("src.common.entraid_auth.ENTRAID_AVAILABLE", False)
42 | @patch("src.common.entraid_auth.is_entraid_auth_enabled")
43 | def test_raises_error_when_package_not_available(self, mock_is_enabled):
44 | """Test that error is raised when redis-entraid package is not available."""
45 | mock_is_enabled.return_value = True
46 |
47 | with pytest.raises(EntraIDAuthenticationError) as exc_info:
48 | create_credential_provider()
49 |
50 | assert "redis-entraid package is required" in str(exc_info.value)
51 |
52 | @patch("src.common.entraid_auth.validate_entraid_config")
53 | @patch("src.common.entraid_auth.is_entraid_auth_enabled")
54 | def test_raises_error_on_invalid_config(self, mock_is_enabled, mock_validate):
55 | """Test that error is raised when configuration is invalid."""
56 | mock_is_enabled.return_value = True
57 | mock_validate.return_value = (False, "Invalid configuration")
58 |
59 | with pytest.raises(EntraIDAuthenticationError) as exc_info:
60 | create_credential_provider()
61 |
62 | assert "Invalid Entra ID configuration" in str(exc_info.value)
63 |
64 | @patch("src.common.entraid_auth._create_service_principal_provider")
65 | @patch("src.common.entraid_auth._create_token_manager_config")
66 | @patch.dict(
67 | "src.common.entraid_auth.ENTRAID_CFG", {"auth_flow": "service_principal"}
68 | )
69 | @patch("src.common.entraid_auth.validate_entraid_config")
70 | @patch("src.common.entraid_auth.is_entraid_auth_enabled")
71 | def test_creates_service_principal_provider(
72 | self,
73 | mock_is_enabled,
74 | mock_validate,
75 | mock_create_token_config,
76 | mock_create_sp_provider,
77 | ):
78 | """Test creating service principal credential provider."""
79 | mock_is_enabled.return_value = True
80 | mock_validate.return_value = (True, "")
81 | mock_token_config = Mock()
82 | mock_create_token_config.return_value = mock_token_config
83 | mock_provider = Mock()
84 | mock_create_sp_provider.return_value = mock_provider
85 |
86 | result = create_credential_provider()
87 |
88 | assert result == mock_provider
89 | mock_create_sp_provider.assert_called_once_with(mock_token_config)
90 |
91 | @patch("src.common.entraid_auth._create_managed_identity_provider")
92 | @patch("src.common.entraid_auth._create_token_manager_config")
93 | @patch.dict(
94 | "src.common.entraid_auth.ENTRAID_CFG", {"auth_flow": "managed_identity"}
95 | )
96 | @patch("src.common.entraid_auth.validate_entraid_config")
97 | @patch("src.common.entraid_auth.is_entraid_auth_enabled")
98 | def test_creates_managed_identity_provider(
99 | self,
100 | mock_is_enabled,
101 | mock_validate,
102 | mock_create_token_config,
103 | mock_create_mi_provider,
104 | ):
105 | """Test creating managed identity credential provider."""
106 | mock_is_enabled.return_value = True
107 | mock_validate.return_value = (True, "")
108 | mock_token_config = Mock()
109 | mock_create_token_config.return_value = mock_token_config
110 | mock_provider = Mock()
111 | mock_create_mi_provider.return_value = mock_provider
112 |
113 | result = create_credential_provider()
114 |
115 | assert result == mock_provider
116 | mock_create_mi_provider.assert_called_once_with(mock_token_config)
117 |
118 | @patch("src.common.entraid_auth._create_default_credential_provider")
119 | @patch("src.common.entraid_auth._create_token_manager_config")
120 | @patch.dict(
121 | "src.common.entraid_auth.ENTRAID_CFG", {"auth_flow": "default_credential"}
122 | )
123 | @patch("src.common.entraid_auth.validate_entraid_config")
124 | @patch("src.common.entraid_auth.is_entraid_auth_enabled")
125 | def test_creates_default_credential_provider(
126 | self,
127 | mock_is_enabled,
128 | mock_validate,
129 | mock_create_token_config,
130 | mock_create_dc_provider,
131 | ):
132 | """Test creating default credential provider."""
133 | mock_is_enabled.return_value = True
134 | mock_validate.return_value = (True, "")
135 | mock_token_config = Mock()
136 | mock_create_token_config.return_value = mock_token_config
137 | mock_provider = Mock()
138 | mock_create_dc_provider.return_value = mock_provider
139 |
140 | result = create_credential_provider()
141 |
142 | assert result == mock_provider
143 | mock_create_dc_provider.assert_called_once_with(mock_token_config)
144 |
145 | @patch("src.common.entraid_auth._create_token_manager_config")
146 | @patch.dict(
147 | "src.common.entraid_auth.ENTRAID_CFG", {"auth_flow": "unsupported_flow"}
148 | )
149 | @patch("src.common.entraid_auth.validate_entraid_config")
150 | @patch("src.common.entraid_auth.is_entraid_auth_enabled")
151 | def test_raises_error_on_unsupported_flow(
152 | self, mock_is_enabled, mock_validate, mock_create_token_config
153 | ):
154 | """Test that error is raised for unsupported auth flow."""
155 | mock_is_enabled.return_value = True
156 | mock_validate.return_value = (True, "")
157 | mock_create_token_config.return_value = Mock()
158 |
159 | with pytest.raises(EntraIDAuthenticationError) as exc_info:
160 | create_credential_provider()
161 |
162 | assert "Unsupported authentication flow" in str(exc_info.value)
163 |
164 |
165 | class TestCreateTokenManagerConfig:
166 | """Test cases for _create_token_manager_config function."""
167 |
168 | @patch("src.common.entraid_auth.ENTRAID_CFG")
169 | @patch("src.common.entraid_auth.TokenManagerConfig")
170 | @patch("src.common.entraid_auth.RetryPolicy")
171 | def test_creates_token_manager_config(
172 | self, mock_retry_policy_class, mock_token_config_class, mock_entraid_cfg
173 | ):
174 | """Test creating token manager configuration."""
175 | mock_entraid_cfg.__getitem__.side_effect = lambda key: {
176 | "retry_max_attempts": 3,
177 | "retry_delay_ms": 100,
178 | "token_expiration_refresh_ratio": 0.9,
179 | "lower_refresh_bound_millis": 30000,
180 | "token_request_execution_timeout_ms": 10000,
181 | }[key]
182 |
183 | mock_retry_policy = Mock()
184 | mock_retry_policy_class.return_value = mock_retry_policy
185 | mock_token_config = Mock()
186 | mock_token_config_class.return_value = mock_token_config
187 |
188 | result = _create_token_manager_config()
189 |
190 | # Verify RetryPolicy was created with correct parameters
191 | mock_retry_policy_class.assert_called_once_with(max_attempts=3, delay_in_ms=100)
192 |
193 | # Verify TokenManagerConfig was created with correct parameters
194 | mock_token_config_class.assert_called_once_with(
195 | expiration_refresh_ratio=0.9,
196 | lower_refresh_bound_millis=30000,
197 | token_request_execution_timeout_in_ms=10000,
198 | retry_policy=mock_retry_policy,
199 | )
200 |
201 | assert result == mock_token_config
202 |
203 |
204 | class TestCreateServicePrincipalProvider:
205 | """Test cases for _create_service_principal_provider function."""
206 |
207 | @patch("src.common.entraid_auth.create_from_service_principal")
208 | @patch("src.common.entraid_auth.ENTRAID_CFG")
209 | def test_creates_service_principal_provider(self, mock_entraid_cfg, mock_create_sp):
210 | """Test creating service principal provider."""
211 | mock_entraid_cfg.__getitem__.side_effect = lambda key: {
212 | "client_id": "test-client-id",
213 | "client_secret": "test-secret",
214 | "tenant_id": "test-tenant-id",
215 | }[key]
216 |
217 | mock_token_config = Mock()
218 | mock_provider = Mock()
219 | mock_create_sp.return_value = mock_provider
220 |
221 | result = _create_service_principal_provider(mock_token_config)
222 |
223 | mock_create_sp.assert_called_once_with(
224 | client_id="test-client-id",
225 | client_credential="test-secret",
226 | tenant_id="test-tenant-id",
227 | token_manager_config=mock_token_config,
228 | )
229 |
230 | assert result == mock_provider
231 |
232 |
233 | class TestCreateManagedIdentityProvider:
234 | """Test cases for _create_managed_identity_provider function."""
235 |
236 | @patch("src.common.entraid_auth.create_from_managed_identity")
237 | @patch("src.common.entraid_auth.ManagedIdentityType")
238 | @patch("src.common.entraid_auth.ENTRAID_CFG")
239 | def test_creates_system_assigned_managed_identity_provider(
240 | self, mock_entraid_cfg, mock_identity_type_class, mock_create_mi
241 | ):
242 | """Test creating system-assigned managed identity provider."""
243 | mock_entraid_cfg.__getitem__.side_effect = lambda key: {
244 | "identity_type": "system_assigned",
245 | "resource": "https://redis.azure.com/",
246 | }[key]
247 |
248 | mock_identity_type = Mock()
249 | mock_identity_type_class.SYSTEM_ASSIGNED = mock_identity_type
250 | mock_token_config = Mock()
251 | mock_provider = Mock()
252 | mock_create_mi.return_value = mock_provider
253 |
254 | result = _create_managed_identity_provider(mock_token_config)
255 |
256 | mock_create_mi.assert_called_once_with(
257 | identity_type=mock_identity_type,
258 | resource="https://redis.azure.com/",
259 | token_manager_config=mock_token_config,
260 | )
261 |
262 | assert result == mock_provider
263 |
264 | @patch("src.common.entraid_auth.create_from_managed_identity")
265 | @patch("src.common.entraid_auth.ManagedIdentityType")
266 | @patch("src.common.entraid_auth.ENTRAID_CFG")
267 | def test_creates_user_assigned_managed_identity_provider(
268 | self, mock_entraid_cfg, mock_identity_type_class, mock_create_mi
269 | ):
270 | """Test creating user-assigned managed identity provider."""
271 | mock_entraid_cfg.__getitem__.side_effect = lambda key: {
272 | "identity_type": "user_assigned",
273 | "resource": "https://redis.azure.com/",
274 | "user_assigned_identity_client_id": "test-user-assigned-id",
275 | }[key]
276 |
277 | mock_identity_type = Mock()
278 | mock_identity_type_class.USER_ASSIGNED = mock_identity_type
279 | mock_token_config = Mock()
280 | mock_provider = Mock()
281 | mock_create_mi.return_value = mock_provider
282 |
283 | result = _create_managed_identity_provider(mock_token_config)
284 |
285 | mock_create_mi.assert_called_once_with(
286 | identity_type=mock_identity_type,
287 | resource="https://redis.azure.com/",
288 | client_id="test-user-assigned-id",
289 | token_manager_config=mock_token_config,
290 | )
291 |
292 | assert result == mock_provider
293 |
294 | @patch("src.common.entraid_auth.ENTRAID_CFG")
295 | def test_raises_error_on_invalid_identity_type(self, mock_entraid_cfg):
296 | """Test that error is raised for invalid identity type."""
297 | mock_entraid_cfg.__getitem__.side_effect = lambda key: {
298 | "identity_type": "invalid_type",
299 | }[key]
300 |
301 | mock_token_config = Mock()
302 |
303 | with pytest.raises(EntraIDAuthenticationError) as exc_info:
304 | _create_managed_identity_provider(mock_token_config)
305 |
306 | assert "Invalid identity type" in str(exc_info.value)
307 |
308 |
309 | class TestCreateDefaultCredentialProvider:
310 | """Test cases for _create_default_credential_provider function."""
311 |
312 | @patch("src.common.entraid_auth.create_from_default_azure_credential")
313 | @patch("src.common.entraid_auth.ENTRAID_CFG")
314 | def test_creates_default_credential_provider_single_scope(
315 | self, mock_entraid_cfg, mock_create_dc
316 | ):
317 | """Test creating default credential provider with single scope."""
318 | mock_entraid_cfg.__getitem__.side_effect = lambda key: {
319 | "scopes": "https://redis.azure.com/.default",
320 | }[key]
321 |
322 | mock_token_config = Mock()
323 | mock_provider = Mock()
324 | mock_create_dc.return_value = mock_provider
325 |
326 | result = _create_default_credential_provider(mock_token_config)
327 |
328 | mock_create_dc.assert_called_once_with(
329 | scopes=("https://redis.azure.com/.default",),
330 | token_manager_config=mock_token_config,
331 | )
332 |
333 | assert result == mock_provider
334 |
335 | @patch("src.common.entraid_auth.create_from_default_azure_credential")
336 | @patch("src.common.entraid_auth.ENTRAID_CFG")
337 | def test_creates_default_credential_provider_multiple_scopes(
338 | self, mock_entraid_cfg, mock_create_dc
339 | ):
340 | """Test creating default credential provider with multiple scopes."""
341 | mock_entraid_cfg.__getitem__.side_effect = lambda key: {
342 | "scopes": "https://redis.azure.com/.default, https://other.scope/.default",
343 | }[key]
344 |
345 | mock_token_config = Mock()
346 | mock_provider = Mock()
347 | mock_create_dc.return_value = mock_provider
348 |
349 | result = _create_default_credential_provider(mock_token_config)
350 |
351 | mock_create_dc.assert_called_once_with(
352 | scopes=(
353 | "https://redis.azure.com/.default",
354 | "https://other.scope/.default",
355 | ),
356 | token_manager_config=mock_token_config,
357 | )
358 |
359 | assert result == mock_provider
360 |
361 | @patch("src.common.entraid_auth.create_from_default_azure_credential")
362 | @patch("src.common.entraid_auth.ENTRAID_CFG")
363 | def test_creates_default_credential_provider_with_whitespace(
364 | self, mock_entraid_cfg, mock_create_dc
365 | ):
366 | """Test that scopes with whitespace are properly trimmed."""
367 | mock_entraid_cfg.__getitem__.side_effect = lambda key: {
368 | "scopes": " https://redis.azure.com/.default , https://other.scope/.default ",
369 | }[key]
370 |
371 | mock_token_config = Mock()
372 | mock_provider = Mock()
373 | mock_create_dc.return_value = mock_provider
374 |
375 | result = _create_default_credential_provider(mock_token_config)
376 |
377 | mock_create_dc.assert_called_once_with(
378 | scopes=(
379 | "https://redis.azure.com/.default",
380 | "https://other.scope/.default",
381 | ),
382 | token_manager_config=mock_token_config,
383 | )
384 |
385 | assert result == mock_provider
386 |
387 |
388 | class TestEntraIDAuthenticationError:
389 | """Test cases for EntraIDAuthenticationError exception."""
390 |
391 | def test_exception_can_be_raised(self):
392 | """Test that EntraIDAuthenticationError can be raised."""
393 | with pytest.raises(EntraIDAuthenticationError):
394 | raise EntraIDAuthenticationError("Test error message")
395 |
396 | def test_exception_message(self):
397 | """Test that exception message is preserved."""
398 | error_msg = "Test error message"
399 | with pytest.raises(EntraIDAuthenticationError) as exc_info:
400 | raise EntraIDAuthenticationError(error_msg)
401 |
402 | assert str(exc_info.value) == error_msg
403 |
404 | def test_exception_is_exception_subclass(self):
405 | """Test that EntraIDAuthenticationError is an Exception subclass."""
406 | assert issubclass(EntraIDAuthenticationError, Exception)
407 |
```