#
tokens: 44523/50000 9/114 files (page 4/6)
lines: on (toggle) GitHub
raw markdown copy reset
This is page 4 of 6. Use http://codebase.md/threatflux/yaraflux?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .dockerignore
├── .env
├── .env.example
├── .github
│   ├── dependabot.yml
│   └── workflows
│       ├── ci.yml
│       ├── codeql.yml
│       ├── publish-release.yml
│       ├── safety_scan.yml
│       ├── update-actions.yml
│       └── version-bump.yml
├── .gitignore
├── .pylintrc
├── .safety-project.ini
├── bandit.yaml
├── codecov.yml
├── docker-compose.yml
├── docker-entrypoint.sh
├── Dockerfile
├── docs
│   ├── api_mcp_architecture.md
│   ├── api.md
│   ├── architecture_diagram.md
│   ├── cli.md
│   ├── examples.md
│   ├── file_management.md
│   ├── installation.md
│   ├── mcp.md
│   ├── README.md
│   └── yara_rules.md
├── entrypoint.sh
├── examples
│   ├── claude_desktop_config.json
│   └── install_via_smithery.sh
├── glama.json
├── images
│   ├── architecture.svg
│   ├── architecture.txt
│   ├── image copy.png
│   └── image.png
├── LICENSE
├── Makefile
├── mypy.ini
├── pyproject.toml
├── pytest.ini
├── README.md
├── requirements-dev.txt
├── requirements.txt
├── SECURITY.md
├── setup.py
├── src
│   └── yaraflux_mcp_server
│       ├── __init__.py
│       ├── __main__.py
│       ├── app.py
│       ├── auth.py
│       ├── claude_mcp_tools.py
│       ├── claude_mcp.py
│       ├── config.py
│       ├── mcp_server.py
│       ├── mcp_tools
│       │   ├── __init__.py
│       │   ├── base.py
│       │   ├── file_tools.py
│       │   ├── rule_tools.py
│       │   ├── scan_tools.py
│       │   └── storage_tools.py
│       ├── models.py
│       ├── routers
│       │   ├── __init__.py
│       │   ├── auth.py
│       │   ├── files.py
│       │   ├── rules.py
│       │   └── scan.py
│       ├── run_mcp.py
│       ├── storage
│       │   ├── __init__.py
│       │   ├── base.py
│       │   ├── factory.py
│       │   ├── local.py
│       │   └── minio.py
│       ├── utils
│       │   ├── __init__.py
│       │   ├── error_handling.py
│       │   ├── logging_config.py
│       │   ├── param_parsing.py
│       │   └── wrapper_generator.py
│       └── yara_service.py
├── test.txt
├── tests
│   ├── conftest.py
│   ├── functional
│   │   └── __init__.py
│   ├── integration
│   │   └── __init__.py
│   └── unit
│       ├── __init__.py
│       ├── test_app.py
│       ├── test_auth_fixtures
│       │   ├── test_token_auth.py
│       │   └── test_user_management.py
│       ├── test_auth.py
│       ├── test_claude_mcp_tools.py
│       ├── test_cli
│       │   ├── __init__.py
│       │   ├── test_main.py
│       │   └── test_run_mcp.py
│       ├── test_config.py
│       ├── test_mcp_server.py
│       ├── test_mcp_tools
│       │   ├── test_file_tools_extended.py
│       │   ├── test_file_tools.py
│       │   ├── test_init.py
│       │   ├── test_rule_tools_extended.py
│       │   ├── test_rule_tools.py
│       │   ├── test_scan_tools_extended.py
│       │   ├── test_scan_tools.py
│       │   ├── test_storage_tools_enhanced.py
│       │   └── test_storage_tools.py
│       ├── test_mcp_tools.py
│       ├── test_routers
│       │   ├── test_auth_router.py
│       │   ├── test_files.py
│       │   ├── test_rules.py
│       │   └── test_scan.py
│       ├── test_storage
│       │   ├── test_factory.py
│       │   ├── test_local_storage.py
│       │   └── test_minio_storage.py
│       ├── test_storage_base.py
│       ├── test_utils
│       │   ├── __init__.py
│       │   ├── test_error_handling.py
│       │   ├── test_logging_config.py
│       │   ├── test_param_parsing.py
│       │   └── test_wrapper_generator.py
│       ├── test_yara_rule_compilation.py
│       └── test_yara_service.py
└── uv.lock
```

# Files

--------------------------------------------------------------------------------
/src/yaraflux_mcp_server/routers/rules.py:
--------------------------------------------------------------------------------

```python
  1 | """YARA rules router for YaraFlux MCP Server.
  2 | 
  3 | This module provides API routes for YARA rule management, including listing,
  4 | adding, updating, and deleting rules.
  5 | """
  6 | 
  7 | import logging
  8 | from datetime import UTC, datetime
  9 | from typing import List, Optional
 10 | 
 11 | from fastapi import (
 12 |     APIRouter,
 13 |     Body,
 14 |     Depends,
 15 |     File,
 16 |     Form,
 17 |     HTTPException,
 18 |     Request,
 19 |     Response,
 20 |     UploadFile,
 21 |     status,
 22 | )
 23 | 
 24 | from yaraflux_mcp_server.auth import get_current_active_user, validate_admin
 25 | from yaraflux_mcp_server.models import ErrorResponse, User, YaraRuleCreate, YaraRuleMetadata
 26 | from yaraflux_mcp_server.yara_service import YaraError, yara_service
 27 | 
 28 | # Configure logging
 29 | logger = logging.getLogger(__name__)
 30 | 
 31 | # Create router
 32 | router = APIRouter(
 33 |     prefix="/rules",
 34 |     tags=["rules"],
 35 |     responses={
 36 |         401: {"description": "Unauthorized", "model": ErrorResponse},
 37 |         403: {"description": "Forbidden", "model": ErrorResponse},
 38 |         404: {"description": "Not Found", "model": ErrorResponse},
 39 |         422: {"description": "Validation Error", "model": ErrorResponse},
 40 |     },
 41 | )
 42 | 
 43 | # Import MCP tools with safeguards
 44 | try:
 45 |     from yaraflux_mcp_server.mcp_tools import import_threatflux_rules as import_rules_tool
 46 |     from yaraflux_mcp_server.mcp_tools import validate_yara_rule as validate_rule_tool
 47 | except Exception as e:
 48 |     logger.error(f"Error importing MCP tools: {str(e)}")
 49 | 
 50 |     # Create fallback functions
 51 |     def validate_rule_tool(content: str):
 52 |         try:
 53 |             # Create a temporary rule name for validation
 54 |             temp_rule_name = f"validate_{int(datetime.now(UTC).timestamp())}.yar"
 55 |             # Validate via direct service call
 56 |             yara_service.add_rule(temp_rule_name, content)
 57 |             yara_service.delete_rule(temp_rule_name)
 58 |             return {"valid": True, "message": "Rule is valid"}
 59 |         except Exception as error:
 60 |             return {"valid": False, "message": str(error)}
 61 | 
 62 |     def import_rules_tool(url: Optional[str] = None):
 63 |         # Simple import implementation
 64 |         url_msg = f" from {url}" if url else ""
 65 |         return {"success": False, "message": f"MCP tools not available for import{url_msg}"}
 66 | 
 67 | 
 68 | @router.get("/", response_model=List[YaraRuleMetadata])
 69 | async def list_rules(source: Optional[str] = None):
 70 |     """List all YARA rules.
 71 | 
 72 |     Args:
 73 |         source: Optional source filter ("custom" or "community");
 74 |             defaults to None, which is passed to the service as-is
 75 | 
 76 |     Returns:
 77 |         List of YARA rule metadata
 78 |     """
 79 |     try:
 80 |         rules = yara_service.list_rules(source)
 81 |         return rules
 82 |     except YaraError as error:
 83 |         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(error)) from error
 84 | 
 85 | 
 86 | @router.get("/{rule_name}", response_model=dict)
 87 | async def get_rule(
 88 |     rule_name: str,
 89 |     source: Optional[str] = "custom",
 90 | ):
 91 |     """Get a YARA rule's content and metadata.
 92 | 
 93 |     Args:
 94 |         rule_name: Name of the rule
 95 |         source: Source of the rule ("custom" or "community");
 96 |             defaults to "custom"
 97 | 
 98 |     Returns:
 99 |         Rule content and metadata
100 | 
101 |     Raises:
102 |         HTTPException: If rule not found
103 |     """
104 |     try:
105 |         # Get rule content
106 |         content = yara_service.get_rule(rule_name, source)
107 | 
108 |         # Find metadata in the list of rules
109 |         metadata = None
110 |         rules = yara_service.list_rules(source)
111 |         for rule in rules:
112 |             if rule.name == rule_name:
113 |                 metadata = rule
114 |                 break
115 | 
116 |         return {
117 |             "name": rule_name,
118 |             "source": source,
119 |             "content": content,
120 |             "metadata": metadata.model_dump() if metadata else {},
121 |         }
122 |     except YaraError as error:
123 |         raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(error)) from error
124 | 
125 | 
126 | @router.get("/{rule_name}/raw")
127 | async def get_rule_raw(
128 |     rule_name: str,
129 |     source: Optional[str] = "custom",
130 | ):
131 |     """Get a YARA rule's raw content as plain text.
132 | 
133 |     Args:
134 |         rule_name: Name of the rule
135 |         source: Source of the rule ("custom" or "community");
136 |             defaults to "custom"
137 | 
138 |     Returns:
139 |         Plain text rule content
140 | 
141 |     Raises:
142 |         HTTPException: If rule not found
143 |     """
144 |     try:
145 |         # Get rule content
146 |         content = yara_service.get_rule(rule_name, source)
147 | 
148 |         # Return as plain text
149 |         return Response(content=content, media_type="text/plain")
150 |     except YaraError as error:
151 |         raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(error)) from error
152 | 
153 | 
154 | @router.post("/", response_model=YaraRuleMetadata)
155 | async def create_rule(rule: YaraRuleCreate, current_user: User = Depends(get_current_active_user)):
156 |     """Create a new YARA rule.
157 | 
158 |     Args:
159 |         rule: Rule to create
160 |         current_user: Current authenticated user
161 | 
162 |     Returns:
163 |         Metadata of the created rule
164 | 
165 |     Raises:
166 |         HTTPException: If rule creation fails
167 |     """
168 |     try:
169 |         metadata = yara_service.add_rule(rule.name, rule.content)
170 |         logger.info(f"Rule {rule.name} created by {current_user.username}")
171 |         return metadata
172 |     except YaraError as error:
173 |         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(error)) from error
174 | 
175 | 
176 | @router.post("/upload", response_model=YaraRuleMetadata)
177 | async def upload_rule(
178 |     rule_file: UploadFile = File(...),
179 |     source: str = Form("custom"),
180 |     current_user: User = Depends(get_current_active_user),
181 | ):
182 |     """Upload a YARA rule file.
183 | 
184 |     Args:
185 |         rule_file: YARA rule file to upload
186 |         source: Source of the rule ("custom" or "community")
187 |         current_user: Current authenticated user
188 | 
189 |     Returns:
190 |         Metadata of the uploaded rule
191 | 
192 |     Raises:
193 |         HTTPException: If file upload or rule creation fails
194 |     """
195 |     try:
196 |         # Read file content
197 |         content = await rule_file.read()
198 | 
199 |         # Get rule name from filename
200 |         rule_name = rule_file.filename
201 |         if not rule_name:
202 |             raise ValueError("Filename is required")
203 | 
204 |         # Add rule
205 |         metadata = yara_service.add_rule(rule_name, content.decode("utf-8"), source)
206 |         logger.info(f"Rule {rule_name} uploaded by {current_user.username}")
207 |         return metadata
208 |     except YaraError as err:
209 |         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(err)) from err
210 |     except Exception as error:
211 |         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(error)) from error
212 | 
213 | 
214 | @router.put("/{rule_name}", response_model=YaraRuleMetadata)
215 | async def update_rule(
216 |     rule_name: str,
217 |     content: str = Body(...),
218 |     source: str = "custom",
219 |     current_user: User = Depends(get_current_active_user),
220 | ):
221 |     """Update an existing YARA rule.
222 | 
223 |     Args:
224 |         rule_name: Name of the rule
225 |         content: Updated rule content
226 |         source: Source of the rule ("custom" or "community")
227 |         current_user: Current authenticated user
228 | 
229 |     Returns:
230 |         Metadata of the updated rule
231 | 
232 |     Raises:
233 |         HTTPException: If rule update fails
234 |     """
235 |     try:
236 |         metadata = yara_service.update_rule(rule_name, content, source)
237 |         logger.info(f"Rule {rule_name} updated by {current_user.username}")
238 |         return metadata
239 |     except YaraError as error:
240 |         if "Rule not found" in str(error):
241 |             raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(error)) from error
242 |         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(error)) from error
243 | 
244 | 
245 | @router.put("/{rule_name}/plain", response_model=YaraRuleMetadata)
246 | async def update_rule_plain(
247 |     rule_name: str,
248 |     source: str = "custom",
249 |     content: str = Body(..., media_type="text/plain"),
250 |     current_user: User = Depends(get_current_active_user),
251 | ):
252 |     """Update an existing YARA rule using plain text.
253 | 
254 |     This endpoint accepts the YARA rule as plain text in the request body, making it
255 |     easier to update YARA rules without having to escape special characters for JSON.
256 | 
257 |     Args:
258 |         rule_name: Name of the rule
259 |         source: Source of the rule ("custom" or "community")
260 |         content: Updated YARA rule content as plain text
261 |         current_user: Current authenticated user
262 | 
263 |     Returns:
264 |         Metadata of the updated rule
265 | 
266 |     Raises:
267 |         HTTPException: If rule update fails
268 |     """
269 |     try:
270 |         metadata = yara_service.update_rule(rule_name, content, source)
271 |         logger.info(f"Rule {rule_name} updated by {current_user.username} via plain text endpoint")
272 |         return metadata
273 |     except YaraError as error:
274 |         if "Rule not found" in str(error):
275 |             raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(error)) from error
276 |         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(error)) from error
277 | 
278 | 
279 | @router.delete("/{rule_name}")
280 | async def delete_rule(rule_name: str, source: str = "custom", current_user: User = Depends(get_current_active_user)):
281 |     """Delete a YARA rule.
282 | 
283 |     Args:
284 |         rule_name: Name of the rule
285 |         source: Source of the rule ("custom" or "community")
286 |         current_user: Current authenticated user
287 | 
288 |     Returns:
289 |         Success message
290 | 
291 |     Raises:
292 |         HTTPException: If rule deletion fails
293 |     """
294 |     try:
295 |         result = yara_service.delete_rule(rule_name, source)
296 |         if not result:
297 |             raise HTTPException(
298 |                 status_code=status.HTTP_404_NOT_FOUND,
299 |                 detail=f"Rule {rule_name} not found in {source}",
300 |             )
301 | 
302 |         logger.info(f"Rule {rule_name} deleted by {current_user.username}")
303 | 
304 |         return {"message": f"Rule {rule_name} deleted"}
305 |     except YaraError as error:
306 |         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(error)) from error
307 | 
308 | 
309 | @router.post("/import")
310 | async def import_rules(url: Optional[str] = None, current_user: User = Depends(validate_admin)):
311 |     """Import ThreatFlux YARA rules from GitHub.
312 | 
313 |     Args:
314 |         url: URL to the GitHub repository
315 |         current_user: Current authenticated admin user
316 | 
317 |     Returns:
318 |         Import result
319 | 
320 |     Raises:
321 |         HTTPException: If import fails
322 |     """
323 |     try:
324 |         result = import_rules_tool(url)
325 | 
326 |         if not result.get("success"):
327 |             raise HTTPException(
328 |                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
329 |                 detail=result.get("message", "Import failed"),
330 |             )
331 | 
332 |         logger.info(f"Rules imported from {url or 'ThreatFlux repository'} by {current_user.username}")
333 | 
334 |         return result
335 |     except Exception as error:
336 |         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(error)) from error
337 | 
338 | 
339 | @router.post("/validate")
340 | async def validate_rule(request: Request):
341 |     """Validate a YARA rule.
342 | 
343 |     This endpoint tries to handle both JSON and plain text inputs, with some format detection.
344 |     For guaranteed reliability, use the /validate/plain endpoint for plain text YARA rules.
345 | 
346 |     Args:
347 |         request: Request whose raw body holds the rule, either as plain
348 |             text or as JSON (a string, or an object with a "content" field)
349 | 
350 |     Returns:
351 |         Validation result
352 |     """
353 |     try:
354 |         # Read content as text
355 |         content = await request.body()
356 |         content_str = content.decode("utf-8")
357 | 
358 |         # Basic heuristic to detect YARA vs JSON:
359 |         # If the body starts with "rule", use it unchanged as the YARA rule
360 |         # Otherwise try JSON; if it is not valid JSON, treat it as a YARA rule
361 |         if not content_str.strip().startswith("rule"):
362 |             try:
363 |                 # Try to parse as JSON
364 |                 import json  # pylint: disable=import-outside-toplevel
365 | 
366 |                 json_content = json.loads(content_str)
367 | 
368 |                 # If it parsed as JSON, check what kind of content it has
369 |                 if isinstance(json_content, str):
370 |                     # It was a JSON string, use that as the content
371 |                     content_str = json_content
372 |                 elif isinstance(json_content, dict) and "content" in json_content:
373 |                     # It was a JSON object with a content field
374 |                     content_str = json_content["content"]
375 |             except json.JSONDecodeError:
376 |                 # It wasn't valid JSON, assume it's a YARA rule
377 |                 logger.error("Failed to decode JSON content from %s", content_str)
378 | 
379 |         # Use the validate_yara_rule MCP tool
380 |         result = validate_rule_tool(content_str)
381 |         return result
382 |     except Exception as error:
383 |         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(error)) from error
384 | 
385 | 
386 | @router.post("/validate/plain")
387 | async def validate_rule_plain(
388 |     content: str = Body(..., media_type="text/plain"),
389 | ):
390 |     """Validate a YARA rule submitted as plain text.
391 | 
392 |     This endpoint accepts the YARA rule as plain text without requiring JSON formatting.
393 | 
394 |     Args:
395 |         content: YARA rule content to validate, read from a
396 |             text/plain request body
397 | 
398 |     Returns:
399 |         Validation result
400 |     """
401 |     try:
402 |         # Use the validate_yara_rule MCP tool
403 |         result = validate_rule_tool(content)
404 |         return result
405 |     except Exception as e:
406 |         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) from e
407 | 
408 | 
409 | @router.post("/plain", response_model=YaraRuleMetadata)
410 | async def create_rule_plain(
411 |     rule_name: str,
412 |     source: str = "custom",
413 |     content: str = Body(..., media_type="text/plain"),
414 |     current_user: User = Depends(get_current_active_user),
415 | ):
416 |     """Create a new YARA rule using plain text content.
417 | 
418 |     This endpoint accepts the YARA rule as plain text in the request body, making it
419 |     easier to submit YARA rules without having to escape special characters for JSON.
420 | 
421 |     Args:
422 |         rule_name: Name of the rule file (with or without .yar extension)
423 |         source: Source of the rule ("custom" or "community")
424 |         content: YARA rule content as plain text
425 |         current_user: Current authenticated user
426 | 
427 |     Returns:
428 |         Metadata of the created rule
429 | 
430 |     Raises:
431 |         HTTPException: If rule creation fails
432 |     """
433 |     try:
434 |         metadata = yara_service.add_rule(rule_name, content, source)
435 |         logger.info(f"Rule {rule_name} created by {current_user.username} via plain text endpoint")
436 |         return metadata
437 |     except YaraError as error:
438 |         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(error)) from error
439 | 
```

--------------------------------------------------------------------------------
/tests/unit/test_mcp_tools/test_file_tools_extended.py:
--------------------------------------------------------------------------------

```python
  1 | """Extended tests for file tools to improve coverage."""
  2 | 
  3 | import base64
  4 | import json
  5 | import uuid
  6 | from io import BytesIO
  7 | from unittest.mock import MagicMock, Mock, patch
  8 | 
  9 | import pytest
 10 | 
 11 | from yaraflux_mcp_server.mcp_tools.file_tools import (
 12 |     delete_file,
 13 |     download_file,
 14 |     extract_strings,
 15 |     get_file_info,
 16 |     get_hex_view,
 17 |     list_files,
 18 |     upload_file,
 19 | )
 20 | from yaraflux_mcp_server.storage import StorageError
 21 | 
 22 | 
 23 | @patch("yaraflux_mcp_server.mcp_tools.file_tools.base64.b64decode")
 24 | @patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client")
 25 | def test_upload_file_invalid_base64(mock_get_storage, mock_b64decode):
 26 |     """Test upload_file with invalid base64 data."""
 27 |     # Mock b64decode to raise exception
 28 |     mock_b64decode.side_effect = Exception("Invalid base64 data")
 29 | 
 30 |     # Call the function with invalid base64
 31 |     result = upload_file(data="This is not valid base64!", file_name="test.txt", encoding="base64")
 32 | 
 33 |     # Verify error handling
 34 |     assert isinstance(result, dict)
 35 |     assert "success" in result
 36 |     assert result["success"] is False
 37 |     assert "message" in result
 38 |     assert "Invalid base64 data" in result["message"]
 39 | 
 40 |     # Verify storage client was not called
 41 |     mock_get_storage.assert_not_called()
 42 | 
 43 | 
 44 | @patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client")
 45 | def test_upload_file_empty_data(mock_get_storage):
 46 |     """Test upload_file with empty data."""
 47 |     # Call the function with empty data
 48 |     result = upload_file(data="", file_name="test.txt", encoding="base64")
 49 | 
 50 |     # Verify error handling
 51 |     assert isinstance(result, dict)
 52 |     assert "success" in result
 53 |     assert result["success"] is False
 54 |     assert "message" in result
 55 |     assert "cannot be empty" in result["message"].lower()
 56 | 
 57 |     # Verify storage client was not called
 58 |     mock_get_storage.assert_not_called()
 59 | 
 60 | 
 61 | @patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client")
 62 | def test_upload_file_empty_filename(mock_get_storage):
 63 |     """Test upload_file with empty filename."""
 64 |     # Call the function with empty filename
 65 |     result = upload_file(data="SGVsbG8gV29ybGQ=", file_name="", encoding="base64")  # "Hello World"
 66 | 
 67 |     # Verify error handling
 68 |     assert isinstance(result, dict)
 69 |     assert "success" in result
 70 |     assert result["success"] is False
 71 |     assert "message" in result
 72 |     assert "name cannot be empty" in result["message"].lower()
 73 | 
 74 |     # Verify storage client was not called
 75 |     mock_get_storage.assert_not_called()
 76 | 
 77 | 
 78 | @patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client")
 79 | def test_upload_file_invalid_encoding(mock_get_storage):
 80 |     """Test upload_file with invalid encoding."""
 81 |     # Call the function with invalid encoding
 82 |     result = upload_file(data="test data", file_name="test.txt", encoding="invalid")
 83 | 
 84 |     # Verify error handling
 85 |     assert isinstance(result, dict)
 86 |     assert "success" in result
 87 |     assert result["success"] is False
 88 |     assert "message" in result
 89 |     assert "Unsupported encoding" in result["message"]
 90 | 
 91 |     # Verify storage client was not called
 92 |     mock_get_storage.assert_not_called()
 93 | 
 94 | 
 95 | @patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client")
 96 | def test_upload_file_storage_error(mock_get_storage):
 97 |     """Test upload_file with storage error."""
 98 |     # Setup mock to raise StorageError
 99 |     mock_storage = Mock()
100 |     mock_storage.save_file.side_effect = StorageError("Storage error")
101 |     mock_get_storage.return_value = mock_storage
102 | 
103 |     # Call the function
104 |     result = upload_file(data="SGVsbG8gV29ybGQ=", file_name="test.txt", encoding="base64")  # "Hello World"
105 | 
106 |     # Verify error handling
107 |     assert isinstance(result, dict)
108 |     assert "success" in result
109 |     assert result["success"] is False
110 |     assert "message" in result
111 |     assert "Storage error" in result["message"]
112 | 
113 | 
114 | @patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client")
115 | def test_upload_file_general_exception(mock_get_storage):
116 |     """Test upload_file with general exception."""
117 |     # Setup mock to raise Exception
118 |     mock_storage = Mock()
119 |     mock_storage.save_file.side_effect = Exception("Unexpected error")
120 |     mock_get_storage.return_value = mock_storage
121 | 
122 |     # Call the function
123 |     result = upload_file(data="SGVsbG8gV29ybGQ=", file_name="test.txt", encoding="base64")  # "Hello World"
124 | 
125 |     # Verify error handling
126 |     assert isinstance(result, dict)
127 |     assert "success" in result
128 |     assert result["success"] is False
129 |     assert "message" in result
130 |     assert "Unexpected error" in result["message"]
131 | 
132 | 
133 | @patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client")
134 | def test_get_file_info_empty_id(mock_get_storage):
135 |     """Test get_file_info with empty file ID."""
136 |     # Call the function with empty ID
137 |     result = get_file_info(file_id="")
138 | 
139 |     # Verify error handling
140 |     assert isinstance(result, dict)
141 |     assert "success" in result
142 |     assert result["success"] is False
143 |     assert "message" in result
144 |     assert "cannot be empty" in result["message"].lower()
145 | 
146 |     # Verify storage client was not called
147 |     mock_get_storage.assert_not_called()
148 | 
149 | 
150 | @patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client")
151 | def test_list_files_invalid_page(mock_get_storage):
152 |     """Test list_files with invalid page number."""
153 |     # Call the function with invalid page
154 |     result = list_files(page=0)
155 | 
156 |     # Verify error handling
157 |     assert isinstance(result, dict)
158 |     assert "success" in result
159 |     assert result["success"] is False
160 |     assert "message" in result
161 |     assert "Page number must be positive" in result["message"]
162 | 
163 |     # Verify storage client was not called
164 |     mock_get_storage.assert_not_called()
165 | 
166 | 
167 | @patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client")
168 | def test_list_files_invalid_page_size(mock_get_storage):
169 |     """Test list_files with invalid page size."""
170 |     # Call the function with invalid page size
171 |     result = list_files(page_size=0)
172 | 
173 |     # Verify error handling
174 |     assert isinstance(result, dict)
175 |     assert "success" in result
176 |     assert result["success"] is False
177 |     assert "message" in result
178 |     assert "Page size must be positive" in result["message"]
179 | 
180 |     # Verify storage client was not called
181 |     mock_get_storage.assert_not_called()
182 | 
183 | 
184 | @patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client")
185 | def test_list_files_invalid_sort_field(mock_get_storage):
186 |     """Test list_files with invalid sort field."""
187 |     # Call the function with invalid sort field
188 |     result = list_files(sort_by="invalid_field")
189 | 
190 |     # Verify error handling
191 |     assert isinstance(result, dict)
192 |     assert "success" in result
193 |     assert result["success"] is False
194 |     assert "message" in result
195 |     assert "Invalid sort field" in result["message"]
196 | 
197 |     # Verify storage client was not called
198 |     mock_get_storage.assert_not_called()
199 | 
200 | 
201 | @patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client")
202 | def test_delete_file_empty_id(mock_get_storage):
203 |     """Test delete_file with empty file ID."""
204 |     # Call the function with empty ID
205 |     result = delete_file(file_id="")
206 | 
207 |     # Verify error handling
208 |     assert isinstance(result, dict)
209 |     assert "success" in result
210 |     assert result["success"] is False
211 |     assert "message" in result
212 |     assert "cannot be empty" in result["message"].lower()
213 | 
214 |     # Verify storage client was not called
215 |     mock_get_storage.assert_not_called()
216 | 
217 | 
218 | @patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client")
219 | def test_delete_file_storage_error(mock_get_storage):
220 |     """Test delete_file with storage error."""
221 |     # Setup mock that fails when get_file_info is called
222 |     mock_storage = Mock()
223 |     mock_storage.get_file_info.side_effect = StorageError("Storage error")
224 |     mock_get_storage.return_value = mock_storage
225 | 
226 |     # Call the function
227 |     result = delete_file(file_id="test-id")
228 | 
229 |     # Verify error handling - only the message is checked, not the success flag
230 |     assert isinstance(result, dict)
231 |     assert "Error deleting file" in result["message"]
232 |     assert "message" in result
233 |     assert "Storage error" in result["message"]
234 | 
235 | 
236 | @patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client")
237 | def test_extract_strings_empty_id(mock_get_storage):
238 |     """Test extract_strings with empty file ID."""
239 |     # Call the function with empty ID
240 |     result = extract_strings(file_id="")
241 | 
242 |     # Verify error handling
243 |     assert isinstance(result, dict)
244 |     assert "success" in result
245 |     assert result["success"] is False
246 |     assert "message" in result
247 |     assert "cannot be empty" in result["message"].lower()
248 | 
249 |     # Verify storage client was not called
250 |     mock_get_storage.assert_not_called()
251 | 
252 | 
@patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client")
def test_extract_strings_invalid_min_length(mock_get_storage):
    """extract_strings must reject a minimum length of zero."""
    response = extract_strings(file_id="test-id", min_length=0)

    # The error payload names the offending parameter
    assert isinstance(response, dict)
    assert response.keys() >= {"success", "message"}
    assert response["success"] is False
    assert "Minimum string length must be positive" in response["message"]

    # The storage client must never have been requested
    mock_get_storage.assert_not_called()
268 | 
269 | 
@patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client")
def test_extract_strings_no_string_types(mock_get_storage):
    """extract_strings must reject a request that disables both string types."""
    response = extract_strings(file_id="test-id", include_unicode=False, include_ascii=False)

    # With neither Unicode nor ASCII selected there is nothing to extract
    assert isinstance(response, dict)
    assert response.keys() >= {"success", "message"}
    assert response["success"] is False
    assert "At least one string type" in response["message"]

    # The storage client must never have been requested
    mock_get_storage.assert_not_called()
285 | 
286 | 
@patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client")
def test_get_hex_view_empty_id(mock_get_storage):
    """get_hex_view must reject an empty file ID without requesting storage."""
    response = get_hex_view(file_id="")

    # Validation failures come back as a structured error payload
    assert isinstance(response, dict)
    assert response.keys() >= {"success", "message"}
    assert response["success"] is False
    assert "cannot be empty" in response["message"].lower()

    # The storage client must never have been requested
    mock_get_storage.assert_not_called()
302 | 
303 | 
@patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client")
def test_get_hex_view_negative_offset(mock_get_storage):
    """get_hex_view must reject a negative starting offset."""
    response = get_hex_view(file_id="test-id", offset=-1)

    # The error payload names the offending parameter
    assert isinstance(response, dict)
    assert response.keys() >= {"success", "message"}
    assert response["success"] is False
    assert "Offset must be non-negative" in response["message"]

    # The storage client must never have been requested
    mock_get_storage.assert_not_called()
319 | 
320 | 
@patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client")
def test_get_hex_view_invalid_length(mock_get_storage):
    """get_hex_view must reject an explicit length of zero."""
    response = get_hex_view(file_id="test-id", length=0)

    # The error payload names the offending parameter
    assert isinstance(response, dict)
    assert response.keys() >= {"success", "message"}
    assert response["success"] is False
    assert "Length must be positive" in response["message"]

    # The storage client must never have been requested
    mock_get_storage.assert_not_called()
336 | 
337 | 
@patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client")
def test_get_hex_view_invalid_bytes_per_line(mock_get_storage):
    """get_hex_view must reject a bytes-per-line value of zero."""
    response = get_hex_view(file_id="test-id", bytes_per_line=0)

    # The error payload names the offending parameter
    assert isinstance(response, dict)
    assert response.keys() >= {"success", "message"}
    assert response["success"] is False
    assert "Bytes per line must be positive" in response["message"]

    # The storage client must never have been requested
    mock_get_storage.assert_not_called()
353 | 
354 | 
@patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client")
def test_download_file_empty_id(mock_get_storage):
    """download_file must reject an empty file ID without requesting storage."""
    response = download_file(file_id="")

    # Validation failures come back as a structured error payload
    assert isinstance(response, dict)
    assert response.keys() >= {"success", "message"}
    assert response["success"] is False
    assert "cannot be empty" in response["message"].lower()

    # The storage client must never have been requested
    mock_get_storage.assert_not_called()
370 | 
371 | 
@patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client")
def test_download_file_invalid_encoding(mock_get_storage):
    """download_file must reject an encoding other than the supported ones."""
    response = download_file(file_id="test-id", encoding="invalid")

    # The error payload identifies the encoding problem
    assert isinstance(response, dict)
    assert response.keys() >= {"success", "message"}
    assert response["success"] is False
    assert "Unsupported encoding" in response["message"]

    # The storage client must never have been requested
    mock_get_storage.assert_not_called()
387 | 
388 | 
@patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client")
def test_download_file_unicode_decode_error(mock_get_storage):
    """download_file falls back to base64 when the content is not valid UTF-8."""
    # Bytes that cannot be decoded as UTF-8, so a "text" download must fall back
    raw_bytes = b"\xff\xfe\xff\xfe"
    file_metadata = {
        "file_id": "test-id",
        "file_name": "binary.bin",
        "file_size": len(raw_bytes),
        "mime_type": "application/octet-stream",
    }

    storage_stub = Mock()
    storage_stub.get_file.return_value = raw_bytes
    storage_stub.get_file_info.return_value = file_metadata
    mock_get_storage.return_value = storage_stub

    # Ask for text even though the payload is binary
    response = download_file(file_id="test-id", encoding="text")

    # The call succeeds, but reports the encoding that was actually used
    assert isinstance(response, dict)
    assert response.keys() >= {"success", "encoding", "data"}
    assert response["success"] is True
    assert response["encoding"] == "base64"

    # Decoding the payload must give back the original bytes
    assert base64.b64decode(response["data"]) == raw_bytes
418 | 
419 | 
@patch("yaraflux_mcp_server.mcp_tools.file_tools.get_storage_client")
def test_download_file_storage_error(mock_get_storage):
    """download_file reports a failure when reading the file raises StorageError."""
    # Storage client whose get_file call always fails
    failing_storage = Mock()
    failing_storage.get_file.side_effect = StorageError("Storage error")
    mock_get_storage.return_value = failing_storage

    response = download_file(file_id="test-id")

    # The storage failure is converted into an error payload carrying its text
    assert isinstance(response, dict)
    assert response.keys() >= {"success", "message"}
    assert response["success"] is False
    assert "Storage error" in response["message"]
437 | 
```

--------------------------------------------------------------------------------
/src/yaraflux_mcp_server/mcp_tools/file_tools.py:
--------------------------------------------------------------------------------

```python
  1 | """File management tools for Claude MCP integration.
  2 | 
  3 | This module provides tools for file operations including uploading, downloading,
  4 | viewing hex dumps, and extracting strings from files. It uses direct function implementations
  5 | with inline error handling.
  6 | """
  7 | 
  8 | import base64
  9 | import logging
 10 | from typing import Any, Dict, Optional
 11 | 
 12 | from yaraflux_mcp_server.mcp_tools.base import register_tool
 13 | from yaraflux_mcp_server.storage import StorageError, get_storage_client
 14 | 
 15 | # Configure logging
 16 | logger = logging.getLogger(__name__)
 17 | 
 18 | 
 19 | @register_tool()
 20 | def upload_file(
 21 |     data: str, file_name: str, encoding: str = "base64", metadata: Optional[Dict[str, Any]] = None
 22 | ) -> Dict[str, Any]:
 23 |     """Upload a file to the storage system.
 24 | 
 25 |     This tool allows you to upload files with metadata for later retrieval and analysis.
 26 |     Files can be uploaded as base64-encoded data or plain text.
 27 | 
 28 |     For LLM users connecting through MCP, this can be invoked with natural language like:
 29 |     "Upload this file with base64 data: SGVsbG8gV29ybGQ="
 30 |     "Save this text as a file named example.txt: This is the content"
 31 |     "Store this code snippet as script.py with metadata indicating it's executable"
 32 | 
 33 |     Args:
 34 |         data: File content encoded as specified by the encoding parameter
 35 |         file_name: Name of the file
 36 |         encoding: Encoding of the data ("base64" or "text")
 37 |         metadata: Optional metadata to associate with the file
 38 | 
 39 |     Returns:
 40 |         File information including ID, size, and metadata
 41 |     """
 42 |     try:
 43 |         # Validate parameters
 44 |         if not data:
 45 |             raise ValueError("File data cannot be empty")
 46 | 
 47 |         if not file_name:
 48 |             raise ValueError("File name cannot be empty")
 49 | 
 50 |         if encoding not in ["base64", "text"]:
 51 |             raise ValueError(f"Unsupported encoding: {encoding}")
 52 | 
 53 |         # Decode the data
 54 |         if encoding == "base64":
 55 |             try:
 56 |                 decoded_data = base64.b64decode(data)
 57 |             except Exception as e:
 58 |                 raise ValueError(f"Invalid base64 data: {str(e)}") from e
 59 |         else:  # encoding == "text"
 60 |             decoded_data = data.encode("utf-8")
 61 | 
 62 |         # Save the file
 63 |         storage = get_storage_client()
 64 |         file_info = storage.save_file(file_name, decoded_data, metadata or {})
 65 | 
 66 |         return {"success": True, "message": f"File {file_name} uploaded successfully", "file_info": file_info}
 67 |     except ValueError as e:
 68 |         logger.error(f"Value error in upload_file: {str(e)}")
 69 |         return {"success": False, "message": str(e)}
 70 |     except StorageError as e:
 71 |         logger.error(f"Storage error in upload_file: {str(e)}")
 72 |         return {"success": False, "message": f"Storage error: {str(e)}"}
 73 |     except Exception as e:
 74 |         logger.error(f"Unexpected error in upload_file: {str(e)}")
 75 |         return {"success": False, "message": f"Unexpected error: {str(e)}"}
 76 | 
 77 | 
 78 | @register_tool()
 79 | def get_file_info(file_id: str) -> Dict[str, Any]:
 80 |     """Get detailed information about a file.
 81 | 
 82 |     For LLM users connecting through MCP, this can be invoked with natural language like:
 83 |     "Get details about file abc123"
 84 |     "Show me the metadata for file xyz789"
 85 |     "What's the size and upload date of file 456def?"
 86 | 
 87 |     Args:
 88 |         file_id: ID of the file
 89 | 
 90 |     Returns:
 91 |         File information including metadata
 92 |     """
 93 |     try:
 94 |         if not file_id:
 95 |             raise ValueError("File ID cannot be empty")
 96 | 
 97 |         storage = get_storage_client()
 98 |         file_info = storage.get_file_info(file_id)
 99 | 
100 |         return {"success": True, "file_info": file_info}
101 |     except StorageError as e:
102 |         logger.error(f"Error getting file info: {str(e)}")
103 |         return {"success": False, "message": f"Error getting file info: {str(e)}"}
104 |     except ValueError as e:
105 |         logger.error(f"Value error in get_file_info: {str(e)}")
106 |         return {"success": False, "message": str(e)}
107 |     except Exception as e:
108 |         logger.error(f"Unexpected error in get_file_info: {str(e)}")
109 |         return {"success": False, "message": f"Unexpected error: {str(e)}"}
110 | 
111 | 
@register_tool()
def list_files(
    page: int = 1, page_size: int = 100, sort_by: str = "uploaded_at", sort_desc: bool = True
) -> Dict[str, Any]:
    """List files with pagination and sorting.

    For LLM users connecting through MCP, this can be invoked with natural language like:
    "Show me all the uploaded files"
    "List the most recently uploaded files first"
    "Show files sorted by name in alphabetical order"
    "List the largest files first"

    Args:
        page: Page number (1-based)
        page_size: Number of items per page
        sort_by: Field to sort by (uploaded_at, file_name, file_size)
        sort_desc: Sort in descending order if True

    Returns:
        List of files with pagination info
    """
    # Failures are reported in the returned dict ("success": False plus "message")
    try:
        # Guard clauses: all arguments are checked before storage is requested
        if page < 1:
            raise ValueError("Page number must be positive")
        if page_size < 1:
            raise ValueError("Page size must be positive")

        valid_sort_fields = ["uploaded_at", "file_name", "file_size"]
        if sort_by not in valid_sort_fields:
            raise ValueError(f"Invalid sort field: {sort_by}. Must be one of {valid_sort_fields}")

        listing = get_storage_client().list_files(page, page_size, sort_by, sort_desc)

        # Copy the expected keys, substituting a fallback for any the storage layer omitted
        fallbacks = {"files": [], "total": 0, "page": page, "page_size": page_size}
        response: Dict[str, Any] = {"success": True}
        for key, fallback in fallbacks.items():
            response[key] = listing.get(key, fallback)
        return response
    except StorageError as storage_exc:
        failure = f"Error listing files: {str(storage_exc)}"
        logger.error(failure)
        return {"success": False, "message": failure}
    except ValueError as value_exc:
        logger.error(f"Value error in list_files: {str(value_exc)}")
        return {"success": False, "message": str(value_exc)}
    except Exception as unexpected_exc:
        logger.error(f"Unexpected error in list_files: {str(unexpected_exc)}")
        return {"success": False, "message": f"Unexpected error: {str(unexpected_exc)}"}
164 | 
165 | 
@register_tool()
def delete_file(file_id: str) -> Dict[str, Any]:
    """Delete a file from storage.

    For LLM users connecting through MCP, this can be invoked with natural language like:
    "Delete file abc123"
    "Remove the file with ID xyz789"
    "Please get rid of file 456def"

    Args:
        file_id: ID of the file to delete

    Returns:
        Deletion result
    """
    # Every outcome is reported as a dict with "success" and "message";
    # a successful deletion also echoes the "file_id". Nothing is raised.
    try:
        if not file_id:
            raise ValueError("File ID cannot be empty")

        storage = get_storage_client()

        # Get file info first so the success message can name the file
        try:
            file_info = storage.get_file_info(file_id)
            file_name = file_info.get("file_name", "Unknown file")
        except StorageError as e:
            # A storage failure during the lookup aborts the deletion
            logger.error(f"Error getting file info: {str(e)}")
            return {"success": False, "message": f"Error deleting file: {str(e)}"}
        except Exception as e:
            # Best effort: the name only decorates the response, so carry on with the
            # deletion, but leave a trace instead of silently discarding the failure.
            logger.warning(f"Could not determine name of file {file_id} before deletion: {str(e)}")
            file_name = "Unknown file"

        # Delete the file; a falsy result is reported as "not found or could not be deleted"
        result = storage.delete_file(file_id)

        if result:
            return {"success": True, "message": f"File {file_name} deleted successfully", "file_id": file_id}
        return {"success": False, "message": f"File {file_id} not found or could not be deleted"}
    except StorageError as e:
        logger.error(f"Error deleting file: {str(e)}")
        return {"success": False, "message": f"Error deleting file: {str(e)}"}
    except ValueError as e:
        logger.error(f"Value error in delete_file: {str(e)}")
        return {"success": False, "message": str(e)}
    except Exception as e:
        logger.error(f"Unexpected error in delete_file: {str(e)}")
        return {"success": False, "message": f"Unexpected error: {str(e)}"}
213 | 
214 | 
@register_tool()
def extract_strings(
    file_id: str,
    min_length: int = 4,
    include_unicode: bool = True,
    include_ascii: bool = True,
    limit: Optional[int] = None,
) -> Dict[str, Any]:
    """Extract strings from a file.

    This tool extracts ASCII and/or Unicode strings from a file with a specified minimum length.
    It's useful for analyzing binary files or looking for embedded text in files.

    For LLM users connecting through MCP, this can be invoked with natural language like:
    "Extract strings from file abc123"
    "Find all text strings in the file with ID xyz789"
    "Show me any readable text in file 456def with at least 8 characters"

    Args:
        file_id: ID of the file
        min_length: Minimum string length
        include_unicode: Include Unicode strings
        include_ascii: Include ASCII strings
        limit: Maximum number of strings to return

    Returns:
        Extracted strings and metadata
    """
    # Failures are reported in the returned dict ("success": False plus "message")
    try:
        # Guard clauses: all arguments are checked before storage is requested
        if not file_id:
            raise ValueError("File ID cannot be empty")
        if min_length < 1:
            raise ValueError("Minimum string length must be positive")
        if not (include_unicode or include_ascii):
            raise ValueError("At least one string type (Unicode or ASCII) must be included")

        extraction = get_storage_client().extract_strings(
            file_id, min_length=min_length, include_unicode=include_unicode, include_ascii=include_ascii, limit=limit
        )

        # Copy the expected keys; the request's own values stand in for echoed settings
        fallbacks = {
            "file_id": None,
            "file_name": None,
            "strings": [],
            "total_strings": 0,
            "min_length": min_length,
            "include_unicode": include_unicode,
            "include_ascii": include_ascii,
        }
        response: Dict[str, Any] = {"success": True}
        for key, fallback in fallbacks.items():
            response[key] = extraction.get(key, fallback)
        return response
    except StorageError as storage_exc:
        failure = f"Error extracting strings: {str(storage_exc)}"
        logger.error(failure)
        return {"success": False, "message": failure}
    except ValueError as value_exc:
        logger.error(f"Value error in extract_strings: {str(value_exc)}")
        return {"success": False, "message": str(value_exc)}
    except Exception as unexpected_exc:
        logger.error(f"Unexpected error in extract_strings: {str(unexpected_exc)}")
        return {"success": False, "message": f"Unexpected error: {str(unexpected_exc)}"}
278 | 
279 | 
@register_tool()
def get_hex_view(
    file_id: str, offset: int = 0, length: Optional[int] = None, bytes_per_line: int = 16
) -> Dict[str, Any]:
    """Get hexadecimal view of file content.

    This tool provides a hexadecimal representation of file content with optional ASCII view.
    It's useful for examining binary files or seeing the raw content of text files.

    For LLM users connecting through MCP, this can be invoked with natural language like:
    "Show me a hex dump of file abc123"
    "Display the hex representation of file xyz789"
    "I need to see the raw bytes of file 456def"

    Args:
        file_id: ID of the file
        offset: Starting offset in bytes
        length: Number of bytes to return (if None, a reasonable default is used)
        bytes_per_line: Number of bytes per line in output

    Returns:
        Hexadecimal representation of file content
    """
    # Failures are reported in the returned dict ("success": False plus "message")
    try:
        # Guard clauses: all arguments are checked before storage is requested
        if not file_id:
            raise ValueError("File ID cannot be empty")
        if offset < 0:
            raise ValueError("Offset must be non-negative")
        # length=None is allowed and is passed through to the storage layer as-is
        if length is not None and length < 1:
            raise ValueError("Length must be positive")
        if bytes_per_line < 1:
            raise ValueError("Bytes per line must be positive")

        dump = get_storage_client().get_hex_view(
            file_id, offset=offset, length=length, bytes_per_line=bytes_per_line
        )

        # Copy the expected keys, substituting a fallback for any the storage layer omitted
        fallbacks = {
            "file_id": None,
            "file_name": None,
            "hex_content": None,
            "offset": offset,
            "length": 0,
            "total_size": 0,
            "bytes_per_line": bytes_per_line,
        }
        response: Dict[str, Any] = {"success": True}
        for key, fallback in fallbacks.items():
            response[key] = dump.get(key, fallback)
        return response
    except StorageError as storage_exc:
        failure = f"Error getting hex view: {str(storage_exc)}"
        logger.error(failure)
        return {"success": False, "message": failure}
    except ValueError as value_exc:
        logger.error(f"Value error in get_hex_view: {str(value_exc)}")
        return {"success": False, "message": str(value_exc)}
    except Exception as unexpected_exc:
        logger.error(f"Unexpected error in get_hex_view: {str(unexpected_exc)}")
        return {"success": False, "message": f"Unexpected error: {str(unexpected_exc)}"}
339 | 
340 | 
@register_tool()
def download_file(file_id: str, encoding: str = "base64") -> Dict[str, Any]:
    """Download a file's content.

    This tool retrieves the content of a file, returning it in the specified encoding.

    For LLM users connecting through MCP, this can be invoked with natural language like:
    "Download file abc123 and show me its contents"
    "Get the content of file xyz789 as text if possible"
    "Retrieve file 456def for me"

    Args:
        file_id: ID of the file to download
        encoding: Encoding for the returned data ("base64" or "text")

    Returns:
        File content and metadata
    """
    # Failures are reported in the returned dict ("success": False plus "message").
    # On success the "encoding" key states the encoding actually used, which is
    # "base64" when a "text" request could not be decoded as UTF-8.
    try:
        # Validate parameters
        if not file_id:
            raise ValueError("File ID cannot be empty")

        if encoding not in ("base64", "text"):
            raise ValueError(f"Unsupported encoding: {encoding}")

        storage = get_storage_client()
        file_data = storage.get_file(file_id)
        file_info = storage.get_file_info(file_id)

        # Encode the data as requested. After validation only "text" and "base64"
        # remain, so no catch-all branch is needed.
        if encoding == "text":
            try:
                encoded_data = file_data.decode("utf-8")
            except UnicodeDecodeError:
                # Not valid UTF-8: switch to base64 so the block below encodes it
                encoding = "base64"

        if encoding == "base64":
            encoded_data = base64.b64encode(file_data).decode("ascii")

        return {
            "success": True,
            "file_id": file_id,
            "file_name": file_info.get("file_name"),
            "file_size": file_info.get("file_size"),
            "mime_type": file_info.get("mime_type"),
            "data": encoded_data,
            "encoding": encoding,
        }
    except StorageError as e:
        logger.error(f"Error downloading file: {str(e)}")
        return {"success": False, "message": f"Error downloading file: {str(e)}"}
    except ValueError as e:
        logger.error(f"Value error in download_file: {str(e)}")
        return {"success": False, "message": str(e)}
    except Exception as e:
        logger.error(f"Unexpected error in download_file: {str(e)}")
        return {"success": False, "message": f"Unexpected error: {str(e)}"}
404 | 
```

--------------------------------------------------------------------------------
/tests/unit/test_utils/test_logging_config.py:
--------------------------------------------------------------------------------

```python
  1 | """Unit tests for logging_config module."""
  2 | 
  3 | import json
  4 | import logging
  5 | import os
  6 | import sys
  7 | import threading  # Import threading here as it's needed by the module
  8 | import uuid
  9 | from datetime import datetime
 10 | from logging import LogRecord
 11 | from unittest.mock import MagicMock, Mock, patch
 12 | 
 13 | import pytest
 14 | 
 15 | from yaraflux_mcp_server.utils.logging_config import (
 16 |     JsonFormatter,
 17 |     RequestIdFilter,
 18 |     clear_request_id,
 19 |     configure_logging,
 20 |     get_request_id,
 21 |     log_entry_exit,
 22 |     mask_sensitive_data,
 23 |     set_request_id,
 24 | )
 25 | 
 26 | 
 27 | class TestRequestIdContext:
 28 |     """Tests for request ID context management functions."""
 29 | 
 30 |     def test_get_request_id(self):
 31 |         """Test getting a request ID."""
 32 |         # First call should create and return a UUID
 33 |         request_id = get_request_id()
 34 |         assert request_id is not None
 35 |         # UUID validation (basic check)
 36 |         try:
 37 |             uuid_obj = uuid.UUID(request_id)
 38 |             assert str(uuid_obj) == request_id
 39 |         except ValueError:
 40 |             pytest.fail("Request ID is not a valid UUID")
 41 | 
 42 |         # Second call should return the same ID for the same thread
 43 |         second_id = get_request_id()
 44 |         assert second_id == request_id
 45 | 
 46 |     def test_set_request_id(self):
 47 |         """Test setting a request ID."""
 48 |         # Set a specific request ID
 49 |         custom_id = "test-request-id"
 50 |         result = set_request_id(custom_id)
 51 |         assert result == custom_id
 52 | 
 53 |         # Get should now return the custom ID
 54 |         assert get_request_id() == custom_id
 55 | 
 56 |         # Set with no parameter should generate a new UUID
 57 |         new_id = set_request_id()
 58 |         assert new_id != custom_id
 59 |         assert get_request_id() == new_id
 60 | 
 61 |     def test_clear_request_id(self):
 62 |         """Test clearing the request ID."""
 63 |         # Set a request ID
 64 |         set_request_id("test-id")
 65 |         assert get_request_id() == "test-id"
 66 | 
 67 |         # Clear it
 68 |         clear_request_id()
 69 | 
 70 |         # Next get should create a new one
 71 |         new_id = get_request_id()
 72 |         assert new_id != "test-id"
 73 |         assert uuid.UUID(new_id)  # Validate it's a UUID
 74 | 
 75 | 
 76 | class TestRequestIdFilter:
 77 |     """Tests for the RequestIdFilter class."""
 78 | 
 79 |     def test_filter(self):
 80 |         """Test that the filter adds a request ID to log records."""
 81 |         # Set a known request ID
 82 |         set_request_id("test-filter-id")
 83 | 
 84 |         # Create a record
 85 |         record = logging.LogRecord(
 86 |             name="test_logger",
 87 |             level=logging.INFO,
 88 |             pathname="test_path",
 89 |             lineno=42,
 90 |             msg="Test message",
 91 |             args=(),
 92 |             exc_info=None,
 93 |         )
 94 | 
 95 |         # Apply the filter
 96 |         filter_obj = RequestIdFilter()
 97 |         result = filter_obj.filter(record)
 98 | 
 99 |         # Verify the filter added the request ID
100 |         assert result is True  # Filter should always return True
101 |         assert hasattr(record, "request_id")
102 |         assert record.request_id == "test-filter-id"
103 | 
104 |         # Clean up
105 |         clear_request_id()
106 | 
107 | 
class TestJsonFormatter:
    """Tests for the JsonFormatter class."""

    @staticmethod
    def _make_record(level, message, exc_info=None):
        """Build a LogRecord for "test_logger" at /path/to/file.py, line 42."""
        return logging.LogRecord(
            name="test_logger",
            level=level,
            pathname="/path/to/file.py",
            lineno=42,
            msg=message,
            args=(),
            exc_info=exc_info,
        )

    @staticmethod
    def _render(record):
        """Format the record with a fresh JsonFormatter and parse the JSON output."""
        return json.loads(JsonFormatter().format(record))

    def test_format_basic(self):
        """Test basic formatting of a log record."""
        record = self._make_record(logging.INFO, "Test message")
        # funcName and request_id are set by hand because the test expects them in the output
        record.funcName = "?"
        record.request_id = "test-json-id"

        log_data = self._render(record)

        # Fields whose exact value is known in advance
        expected_values = {
            "level": "INFO",
            "logger": "test_logger",
            "message": "Test message",
            "module": "file",  # Extracted from pathname
            "function": "?",
            "line": 42,
            "request_id": "test-json-id",
        }
        for field, value in expected_values.items():
            assert log_data[field] == value

        # Environment-dependent fields only need to be present
        for field in ("timestamp", "hostname", "process_id", "thread_id"):
            assert field in log_data

    def test_format_with_exception(self):
        """Test formatting a log record with an exception."""
        # Capture real exception info while the exception is being handled
        try:
            raise ValueError("Test exception")
        except ValueError:
            captured = sys.exc_info()

        record = self._make_record(logging.ERROR, "Exception occurred", exc_info=captured)
        record.request_id = "test-exception-id"

        log_data = self._render(record)

        # The exception is expected as a list of lines, one of which names the error
        assert "exception" in log_data
        exception_lines = log_data["exception"]
        assert isinstance(exception_lines, list)
        assert any("ValueError: Test exception" in line for line in exception_lines)

    def test_format_with_extra_fields(self):
        """Test formatting a log record with extra fields."""
        record = self._make_record(logging.INFO, "Test with extras")
        record.request_id = "test-extras-id"

        # Custom attributes attached to the record should appear in the output
        extras = {
            "custom_str": "custom value",
            "custom_int": 123,
            "custom_dict": {"key": "value"},
        }
        for attribute, value in extras.items():
            setattr(record, attribute, value)

        log_data = self._render(record)

        # Each extra field must come through with its value unchanged
        for attribute, value in extras.items():
            assert log_data[attribute] == value
214 | 
215 | 
class TestMaskSensitiveData:
    """Behavioural checks for ``mask_sensitive_data``."""

    def test_mask_sensitive_data_simple(self):
        """Flat mapping: sensitive keys are redacted, the rest pass through unchanged."""
        masked = mask_sensitive_data(
            {
                "username": "test_user",
                "password": "secret123",
                "api_key": "abcdef123456",
                "message": "Hello, world!",
            }
        )

        expected = {
            "username": "test_user",  # not sensitive
            "password": "**REDACTED**",
            "api_key": "**REDACTED**",
            "message": "Hello, world!",  # not sensitive
        }
        for key, value in expected.items():
            assert masked[key] == value

    def test_mask_sensitive_data_nested(self):
        """Nested dicts and lists of dicts are masked at every level."""
        nested = {
            "user": {
                "name": "Test User",
                "credentials": {
                    "password": "secret123",
                    "token": "abc123",
                },
            },
            "settings": [
                {"name": "theme", "value": "dark"},
                # The second entry uses "api_key" both as a *value* (under "name")
                # and as a *key*; the assertions below expect only the key
                # occurrence to be redacted.
                {"name": "api_key", "api_key": "xyz789"},
            ],
        }

        masked = mask_sensitive_data(nested)

        user = masked["user"]
        credentials = user["credentials"]
        assert user["name"] == "Test User"
        assert credentials["password"] == "**REDACTED**"
        assert credentials["token"] == "**REDACTED**"

        theme_entry = masked["settings"][0]
        assert theme_entry["name"] == "theme"
        assert theme_entry["value"] == "dark"

        key_entry = masked["settings"][1]
        assert key_entry["name"] == "api_key"  # a value, so left alone
        assert key_entry["api_key"] == "**REDACTED**"  # a sensitive key, so masked

    def test_mask_sensitive_data_custom_fields(self):
        """Fields listed in ``sensitive_fields`` are the ones that get redacted."""
        record = {
            "user": "test_user",
            "ssn": "123-45-6789",
            "credit_card": "4111-1111-1111-1111",
        }

        masked = mask_sensitive_data(record, sensitive_fields=["ssn", "credit_card"])

        assert masked["user"] == "test_user"
        for field in ("ssn", "credit_card"):
            assert masked[field] == "**REDACTED**"
282 | 
283 | 
@patch("logging.Logger")
class TestLogEntryExit:
    """Checks what the ``log_entry_exit`` decorator reports to the logger it is given.

    The class-level patch hands each test a mock in place of ``logging.Logger``;
    that mock is passed to the decorator as its ``logger`` argument.
    """

    def test_log_entry_exit_success(self, mock_logger):
        """A call that returns normally produces one entry log and one exit log."""

        @log_entry_exit(logger=mock_logger)
        def test_function(arg1, arg2=None):
            """Test function."""
            return arg1 + (arg2 if arg2 else 0)

        assert test_function(5, arg2=10) == 15

        # Exactly two records: entry first, exit second.
        assert mock_logger.log.call_count == 2
        recorded = mock_logger.log.call_args_list

        # The message is the second positional argument of each ``log`` call.
        entry_message = recorded[0][0][1]
        assert "Entering test_function" in entry_message
        assert "5" in entry_message  # arg1
        assert "arg2=10" in entry_message  # arg2

        exit_message = recorded[1][0][1]
        assert "Exiting test_function" in exit_message

    def test_log_entry_exit_exception(self, mock_logger):
        """A raising call logs the entry, reports the exception and re-raises it."""

        @log_entry_exit(logger=mock_logger)
        def failing_function():
            """Function that raises an exception."""
            raise ValueError("Test error")

        # The decorator must not swallow the error.
        with pytest.raises(ValueError, match="Test error"):
            failing_function()

        # One entry record via ``log`` and one report via ``exception``; no exit record.
        assert mock_logger.log.call_count == 1
        assert mock_logger.exception.call_count == 1

        entry_message = mock_logger.log.call_args_list[0][0][1]
        assert "Entering failing_function" in entry_message

        # For ``exception`` the message is the first positional argument.
        failure_message = mock_logger.exception.call_args_list[0][0][0]
        assert "Exception in failing_function" in failure_message
        assert "Test error" in failure_message
341 | 
342 | 
@patch("logging.config.dictConfig")
@patch("logging.getLogger")
class TestConfigureLogging:
    """Tests for the configure_logging function.

    ``logging.config.dictConfig`` and ``logging.getLogger`` are patched for every
    test, so the configuration dictionary can be inspected without touching the
    interpreter's real logging setup. Patch decorators are applied bottom-up,
    which is why each test receives ``mock_get_logger`` before ``mock_dict_config``.
    """

    def test_configure_logging_defaults(self, mock_get_logger, mock_dict_config):
        """Test configuring logging with default parameters."""
        # Mock the logger returned by getLogger
        mock_logger = MagicMock()
        mock_get_logger.return_value = mock_logger

        # Call configure_logging with defaults
        configure_logging()

        # Verify dictionary config was called
        mock_dict_config.assert_called_once()

        # Check that the config has the expected top-level sections
        config = mock_dict_config.call_args[0][0]
        for section in ("formatters", "filters", "handlers", "loggers"):
            assert section in config

        # Verify console handler is included by default
        assert "console" in config["handlers"]

        # Verify no file handler by default
        assert "file" not in config["handlers"]

        # Verify the logger was used to log configuration
        mock_get_logger.assert_called_with("yaraflux_mcp_server")
        mock_logger.info.assert_called_once()
        assert "Logging configured" in mock_logger.info.call_args[0][0]

    def test_configure_logging_with_file(self, mock_get_logger, mock_dict_config):
        """Test configuring logging with a file handler."""
        # Mock the logger
        mock_logger = MagicMock()
        mock_get_logger.return_value = mock_logger

        # Patch os.makedirs to track creation of the log directory without
        # touching the real filesystem
        with patch("os.makedirs") as mock_makedirs:
            # Call configure_logging with a log file
            configure_logging(log_file="/tmp/test_log.log", log_level="DEBUG")

            # Verify the log directory was created
            mock_makedirs.assert_called_once()
            assert "/tmp" in mock_makedirs.call_args[0][0]

        # Verify dictionary config was called
        mock_dict_config.assert_called_once()

        # Check the config has a file handler
        config = mock_dict_config.call_args[0][0]
        assert "file" in config["handlers"]
        assert config["handlers"]["file"]["filename"] == "/tmp/test_log.log"
        assert config["handlers"]["file"]["level"] == "DEBUG"

        # Verify both console and file handlers are used
        assert len(config["handlers"]) == 2
        assert "console" in config["handlers"]

        # Verify the root logger (empty name) was configured with both handlers
        root_logger = config["loggers"][""]
        assert "console" in root_logger["handlers"]
        assert "file" in root_logger["handlers"]

    def test_configure_logging_no_console(self, mock_get_logger, mock_dict_config):
        """Test configuring logging without console output."""
        # Mock the logger
        mock_logger = MagicMock()
        mock_get_logger.return_value = mock_logger

        # A log file is passed here too, so os.makedirs is patched exactly as in
        # test_configure_logging_with_file: a unit test must not create
        # directories on the real filesystem.
        with patch("os.makedirs"):
            # Call configure_logging with no console output
            configure_logging(log_to_console=False, log_file="/tmp/test_log.log")

        # Verify dictionary config was called
        mock_dict_config.assert_called_once()

        # Check the config has no console handler
        config = mock_dict_config.call_args[0][0]
        assert "console" not in config["handlers"]
        assert "file" in config["handlers"]

        # Verify only file handler is used
        assert len(config["handlers"]) == 1
        assert config["loggers"][""]["handlers"] == ["file"]

    def test_configure_logging_plaintext(self, mock_get_logger, mock_dict_config):
        """Test configuring logging with plaintext instead of JSON."""
        # Mock the logger
        mock_logger = MagicMock()
        mock_get_logger.return_value = mock_logger

        # Call configure_logging with plaintext formatting
        configure_logging(enable_json=False)

        # Verify dictionary config was called
        mock_dict_config.assert_called_once()

        # Check the console handler uses the formatter named "standard"
        config = mock_dict_config.call_args[0][0]
        assert config["handlers"]["console"]["formatter"] == "standard"
447 | 
```

--------------------------------------------------------------------------------
/tests/unit/test_storage/test_local_storage.py:
--------------------------------------------------------------------------------

```python
  1 | """Unit tests for the local storage client."""
  2 | 
  3 | import hashlib
  4 | import json
  5 | import os
  6 | import tempfile
  7 | from datetime import datetime
  8 | from pathlib import Path
  9 | from unittest.mock import MagicMock, Mock, patch
 10 | 
 11 | import pytest
 12 | 
 13 | from yaraflux_mcp_server.storage.base import StorageError
 14 | from yaraflux_mcp_server.storage.local import LocalStorageClient
 15 | 
 16 | 
 17 | @pytest.fixture
 18 | def temp_dir():
 19 |     """Create a temporary directory for testing."""
 20 |     with tempfile.TemporaryDirectory() as tmp_dir:
 21 |         yield Path(tmp_dir)
 22 | 
 23 | 
 24 | @pytest.fixture
 25 | def mock_settings(temp_dir):
 26 |     """Mock settings for testing."""
 27 |     with patch("yaraflux_mcp_server.storage.local.settings") as mock_settings:
 28 |         mock_settings.STORAGE_DIR = temp_dir / "storage"
 29 |         mock_settings.YARA_RULES_DIR = temp_dir / "rules"
 30 |         mock_settings.YARA_SAMPLES_DIR = temp_dir / "samples"
 31 |         mock_settings.YARA_RESULTS_DIR = temp_dir / "results"
 32 |         yield mock_settings
 33 | 
 34 | 
 35 | @pytest.fixture
 36 | def storage_client(mock_settings):
 37 |     """Create a storage client for testing."""
 38 |     client = LocalStorageClient()
 39 |     return client
 40 | 
 41 | 
 42 | class TestLocalStorageClient:
 43 |     """Tests for LocalStorageClient."""
 44 | 
 45 |     def test_init_creates_directories(self, storage_client, mock_settings):
 46 |         """Test that initialization creates the required directories."""
 47 |         # All directories should be created during initialization
 48 |         assert mock_settings.STORAGE_DIR.exists()
 49 |         assert mock_settings.YARA_RULES_DIR.exists()
 50 |         assert mock_settings.YARA_SAMPLES_DIR.exists()
 51 |         assert mock_settings.YARA_RESULTS_DIR.exists()
 52 |         assert (mock_settings.STORAGE_DIR / "files").exists()
 53 |         assert (mock_settings.STORAGE_DIR / "files_meta").exists()
 54 |         assert (mock_settings.YARA_RULES_DIR / "community").exists()
 55 |         assert (mock_settings.YARA_RULES_DIR / "custom").exists()
 56 | 
 57 |     def test_save_rule(self, storage_client, mock_settings):
 58 |         """Test saving a YARA rule."""
 59 |         rule_name = "test_rule"
 60 |         rule_content = "rule TestRule { condition: true }"
 61 | 
 62 |         # Test saving without .yar extension
 63 |         path = storage_client.save_rule(rule_name, rule_content)
 64 |         rule_path = mock_settings.YARA_RULES_DIR / "custom" / "test_rule.yar"
 65 | 
 66 |         assert path == str(rule_path)
 67 |         assert rule_path.exists()
 68 | 
 69 |         with open(rule_path, "r") as f:
 70 |             saved_content = f.read()
 71 |         assert saved_content == rule_content
 72 | 
 73 |         # Test saving with .yar extension
 74 |         rule_name_with_ext = "test_rule2.yar"
 75 |         path = storage_client.save_rule(rule_name_with_ext, rule_content)
 76 |         rule_path = mock_settings.YARA_RULES_DIR / "custom" / "test_rule2.yar"
 77 | 
 78 |         assert path == str(rule_path)
 79 |         assert rule_path.exists()
 80 | 
 81 |     def test_get_rule(self, storage_client):
 82 |         """Test getting a YARA rule."""
 83 |         rule_name = "test_get_rule"
 84 |         rule_content = "rule TestGetRule { condition: true }"
 85 | 
 86 |         # Save the rule first
 87 |         storage_client.save_rule(rule_name, rule_content)
 88 | 
 89 |         # Get the rule
 90 |         retrieved_content = storage_client.get_rule(rule_name)
 91 |         assert retrieved_content == rule_content
 92 | 
 93 |         # Test getting a rule with extension
 94 |         retrieved_content = storage_client.get_rule(f"{rule_name}.yar")
 95 |         assert retrieved_content == rule_content
 96 | 
 97 |         # Test getting a nonexistent rule
 98 |         with pytest.raises(StorageError, match="Rule not found"):
 99 |             storage_client.get_rule("nonexistent_rule")
100 | 
101 |     def test_delete_rule(self, storage_client):
102 |         """Test deleting a YARA rule."""
103 |         rule_name = "test_delete_rule"
104 |         rule_content = "rule TestDeleteRule { condition: true }"
105 | 
106 |         # Save the rule first
107 |         storage_client.save_rule(rule_name, rule_content)
108 | 
109 |         # Delete the rule
110 |         result = storage_client.delete_rule(rule_name)
111 |         assert result is True
112 | 
113 |         # Verify it's gone
114 |         with pytest.raises(StorageError, match="Rule not found"):
115 |             storage_client.get_rule(rule_name)
116 | 
117 |         # Test deleting a nonexistent rule
118 |         result = storage_client.delete_rule("nonexistent_rule")
119 |         assert result is False
120 | 
121 |     def test_list_rules(self, storage_client):
122 |         """Test listing YARA rules."""
123 |         # Save some rules
124 |         storage_client.save_rule("test_list_1", "rule Test1 { condition: true }", "custom")
125 |         storage_client.save_rule("test_list_2", "rule Test2 { condition: true }", "custom")
126 |         storage_client.save_rule("test_list_3", "rule Test3 { condition: true }", "community")
127 | 
128 |         # List all rules
129 |         rules = storage_client.list_rules()
130 |         assert len(rules) == 3
131 | 
132 |         # Check rule names
133 |         rule_names = [rule["name"] for rule in rules]
134 |         assert "test_list_1.yar" in rule_names
135 |         assert "test_list_2.yar" in rule_names
136 |         assert "test_list_3.yar" in rule_names
137 | 
138 |         # Test filtering by source
139 |         custom_rules = storage_client.list_rules(source="custom")
140 |         assert len(custom_rules) == 2
141 |         custom_names = [rule["name"] for rule in custom_rules]
142 |         assert "test_list_1.yar" in custom_names
143 |         assert "test_list_2.yar" in custom_names
144 |         assert "test_list_3.yar" not in custom_names
145 | 
146 |         community_rules = storage_client.list_rules(source="community")
147 |         assert len(community_rules) == 1
148 |         assert community_rules[0]["name"] == "test_list_3.yar"
149 | 
150 |     def test_save_sample(self, storage_client, mock_settings):
151 |         """Test saving a sample file."""
152 |         filename = "test_sample.bin"
153 |         content = b"Test sample content"
154 | 
155 |         # Save the sample
156 |         path, file_hash = storage_client.save_sample(filename, content)
157 | 
158 |         # Check the hash
159 |         expected_hash = hashlib.sha256(content).hexdigest()
160 |         assert file_hash == expected_hash
161 | 
162 |         # Verify the file exists
163 |         sample_path = Path(path)
164 |         assert sample_path.exists()
165 | 
166 |         # Check the content
167 |         with open(sample_path, "rb") as f:
168 |             saved_content = f.read()
169 |         assert saved_content == content
170 | 
171 |         # Test with file-like object
172 |         from io import BytesIO
173 | 
174 |         file_obj = BytesIO(b"File-like object content")
175 |         path2, hash2 = storage_client.save_sample("file_obj.bin", file_obj)
176 | 
177 |         # Verify the file exists
178 |         sample_path2 = Path(path2)
179 |         assert sample_path2.exists()
180 | 
181 |         # Check the content
182 |         with open(sample_path2, "rb") as f:
183 |             saved_content2 = f.read()
184 |         assert saved_content2 == b"File-like object content"
185 | 
186 |     def test_get_sample(self, storage_client):
187 |         """Test getting a sample."""
188 |         filename = "test_get_sample.bin"
189 |         content = b"Test get sample content"
190 | 
191 |         # Save the sample first
192 |         path, file_hash = storage_client.save_sample(filename, content)
193 | 
194 |         # Get by file path
195 |         retrieved_content = storage_client.get_sample(path)
196 |         assert retrieved_content == content
197 | 
198 |         # Get by hash
199 |         retrieved_content = storage_client.get_sample(file_hash)
200 |         assert retrieved_content == content
201 | 
202 |         # Test with nonexistent sample
203 |         with pytest.raises(StorageError, match="Sample not found"):
204 |             storage_client.get_sample("nonexistent_sample")
205 | 
206 |     def test_save_result(self, storage_client, mock_settings):
207 |         """Test saving a scan result."""
208 |         result_id = "test-result-12345"
209 |         result_content = {"matches": [{"rule": "test", "strings": []}]}
210 | 
211 |         # Save the result
212 |         path = storage_client.save_result(result_id, result_content)
213 | 
214 |         # Verify the file exists
215 |         result_path = Path(path)
216 |         assert result_path.exists()
217 | 
218 |         # Check the content
219 |         with open(result_path, "r") as f:
220 |             saved_content = json.load(f)
221 |         assert saved_content == result_content
222 | 
223 |         # Test with special characters in the ID
224 |         special_id = "test/result\\with:special?chars"
225 |         path = storage_client.save_result(special_id, result_content)
226 | 
227 |         # Verify the file exists with sanitized name
228 |         result_path = Path(path)
229 |         assert result_path.exists()
230 | 
231 |     def test_get_result(self, storage_client):
232 |         """Test getting a scan result."""
233 |         result_id = "test-get-result"
234 |         result_content = {"matches": [{"rule": "test_get", "strings": []}]}
235 | 
236 |         # Save the result first
237 |         path = storage_client.save_result(result_id, result_content)
238 | 
239 |         # Get by ID
240 |         retrieved_content = storage_client.get_result(result_id)
241 |         assert retrieved_content == result_content
242 | 
243 |         # Get by path
244 |         retrieved_content = storage_client.get_result(path)
245 |         assert retrieved_content == result_content
246 | 
247 |         # Test with nonexistent result
248 |         with pytest.raises(StorageError, match="Result not found"):
249 |             storage_client.get_result("nonexistent_result")
250 | 
251 |     def test_save_file(self, storage_client, mock_settings):
252 |         """Test saving a file with metadata."""
253 |         filename = "test_file.txt"
254 |         content = b"Test file content"
255 |         metadata = {"test_key": "test_value", "source": "test"}
256 | 
257 |         # Save the file
258 |         file_info = storage_client.save_file(filename, content, metadata)
259 | 
260 |         # Check the returned info
261 |         assert file_info["file_name"] == filename
262 |         assert file_info["file_size"] == len(content)
263 |         assert "file_id" in file_info
264 |         assert "file_hash" in file_info
265 |         assert file_info["metadata"] == metadata
266 | 
267 |         # Verify the metadata file exists
268 |         file_id = file_info["file_id"]
269 |         meta_path = mock_settings.STORAGE_DIR / "files_meta" / f"{file_id}.json"
270 |         assert meta_path.exists()
271 | 
272 |         # Check the metadata content
273 |         with open(meta_path, "r") as f:
274 |             saved_meta = json.load(f)
275 |         assert saved_meta["file_name"] == filename
276 |         assert saved_meta["metadata"] == metadata
277 | 
278 |         # Verify the actual file exists
279 |         file_path_components = [mock_settings.STORAGE_DIR, "files", file_id[:2], file_id[2:4], filename]
280 |         file_path = Path(*file_path_components)
281 |         assert file_path.exists()
282 | 
283 |         # Check the file content
284 |         with open(file_path, "rb") as f:
285 |             saved_content = f.read()
286 |         assert saved_content == content
287 | 
288 |         # Test with file-like object
289 |         from io import BytesIO
290 | 
291 |         file_obj = BytesIO(b"File object content")
292 |         file_info2 = storage_client.save_file("file_obj.txt", file_obj)
293 | 
294 |         # Verify the file exists
295 |         file_id2 = file_info2["file_id"]
296 |         file_path2_components = [mock_settings.STORAGE_DIR, "files", file_id2[:2], file_id2[2:4], "file_obj.txt"]
297 |         file_path2 = Path(*file_path2_components)
298 |         assert file_path2.exists()
299 | 
300 |     def test_get_file(self, storage_client):
301 |         """Test getting a file."""
302 |         filename = "test_get_file.txt"
303 |         content = b"Test get file content"
304 | 
305 |         # Save the file first
306 |         file_info = storage_client.save_file(filename, content)
307 |         file_id = file_info["file_id"]
308 | 
309 |         # Get the file
310 |         retrieved_content = storage_client.get_file(file_id)
311 |         assert retrieved_content == content
312 | 
313 |         # Test with nonexistent file
314 |         with pytest.raises(StorageError, match="File not found"):
315 |             storage_client.get_file("nonexistent-file-id")
316 | 
317 |     def test_get_file_info(self, storage_client):
318 |         """Test getting file metadata."""
319 |         filename = "test_file_info.txt"
320 |         content = b"Test file info content"
321 |         metadata = {"test_key": "test_value"}
322 | 
323 |         # Save the file first
324 |         file_info = storage_client.save_file(filename, content, metadata)
325 |         file_id = file_info["file_id"]
326 | 
327 |         # Get the file info
328 |         retrieved_info = storage_client.get_file_info(file_id)
329 | 
330 |         # Check the info
331 |         assert retrieved_info["file_name"] == filename
332 |         assert retrieved_info["file_size"] == len(content)
333 |         assert retrieved_info["metadata"] == metadata
334 | 
335 |         # Test with nonexistent file
336 |         with pytest.raises(StorageError, match="File not found"):
337 |             storage_client.get_file_info("nonexistent-file-id")
338 | 
339 |     def test_list_files(self, storage_client):
340 |         """Test listing files with pagination."""
341 |         # Save multiple files
342 |         num_files = 15
343 |         for i in range(num_files):
344 |             storage_client.save_file(f"list_file_{i}.txt", f"Content {i}".encode(), {"index": i})
345 | 
346 |         # Test default pagination
347 |         result = storage_client.list_files()
348 |         assert result["total"] == num_files
349 |         assert len(result["files"]) == num_files
350 |         assert result["page"] == 1
351 |         assert result["page_size"] == 100
352 | 
353 |         # Test custom pagination
354 |         page_size = 5
355 |         result = storage_client.list_files(page=1, page_size=page_size)
356 |         assert result["total"] == num_files
357 |         assert len(result["files"]) == page_size
358 |         assert result["page"] == 1
359 |         assert result["page_size"] == page_size
360 | 
361 |         # Test second page
362 |         result = storage_client.list_files(page=2, page_size=page_size)
363 |         assert result["total"] == num_files
364 |         assert len(result["files"]) == page_size
365 |         assert result["page"] == 2
366 | 
367 |         # Test sorting
368 |         # Default is by uploaded_at descending
369 |         result = storage_client.list_files(sort_by="file_name", sort_desc=False)
370 |         names = [f["file_name"] for f in result["files"]]
371 |         assert sorted(names) == names
372 | 
373 |         result = storage_client.list_files(sort_by="file_name", sort_desc=True)
374 |         names = [f["file_name"] for f in result["files"]]
375 |         assert sorted(names, reverse=True) == names
376 | 
377 |     def test_delete_file(self, storage_client):
378 |         """Test deleting a file."""
379 |         filename = "test_delete_file.txt"
380 |         content = b"Test delete file content"
381 | 
382 |         # Save the file first
383 |         file_info = storage_client.save_file(filename, content)
384 |         file_id = file_info["file_id"]
385 | 
386 |         # Delete the file
387 |         result = storage_client.delete_file(file_id)
388 |         assert result is True
389 | 
390 |         # Verify it's gone
391 |         with pytest.raises(StorageError, match="File not found"):
392 |             storage_client.get_file(file_id)
393 | 
394 |         with pytest.raises(StorageError, match="File not found"):
395 |             storage_client.get_file_info(file_id)
396 | 
397 |         # Test deleting a nonexistent file
398 |         result = storage_client.delete_file("nonexistent-file-id")
399 |         assert result is False
400 | 
401 |     def test_extract_strings(self, storage_client):
402 |         """Test extracting strings from a file."""
403 |         # Create a file with both ASCII and Unicode strings
404 |         content = b"Hello, world!\x00\x00\x00This is a test.\x00\x00"
405 |         content += "Unicode test string".encode("utf-16le")
406 | 
407 |         file_info = storage_client.save_file("strings_test.bin", content)
408 |         file_id = file_info["file_id"]
409 | 
410 |         # Extract strings with default settings
411 |         result = storage_client.extract_strings(file_id)
412 | 
413 |         # Check the result structure
414 |         assert result["file_id"] == file_id
415 |         assert result["file_name"] == "strings_test.bin"
416 |         assert "strings" in result
417 |         assert "total_strings" in result
418 |         assert result["min_length"] == 4
419 |         assert result["include_unicode"] is True
420 |         assert result["include_ascii"] is True
421 | 
422 |         # Check with custom settings
423 |         result = storage_client.extract_strings(file_id, min_length=10, include_unicode=False, limit=1)
424 |         assert result["min_length"] == 10
425 |         assert result["include_unicode"] is False
426 |         assert result["include_ascii"] is True
427 |         assert len(result["strings"]) <= 1  # Might be 0 if no strings meet criteria
428 | 
429 |         # Test with nonexistent file
430 |         with pytest.raises(StorageError, match="File not found"):
431 |             storage_client.extract_strings("nonexistent-file-id")
432 | 
433 |     def test_get_hex_view(self, storage_client):
434 |         """Test getting a hex view of a file."""
435 |         # Create a test file with varied content
436 |         content = bytes(range(0, 128))  # 0-127 byte values
437 |         file_info = storage_client.save_file("hex_test.bin", content)
438 |         file_id = file_info["file_id"]
439 | 
440 |         # Get hex view with default settings
441 |         result = storage_client.get_hex_view(file_id)
442 | 
443 |         # Check the result structure
444 |         assert result["file_id"] == file_id
445 |         assert result["file_name"] == "hex_test.bin"
446 |         assert "hex_content" in result
447 |         assert result["offset"] == 0
448 |         assert result["bytes_per_line"] == 16
449 |         assert result["total_size"] == len(content)
450 | 
451 |         # The hex view should contain string representations
452 |         assert "00000000" in result["hex_content"]  # Offset
453 |         assert "00 01 02 03" in result["hex_content"]  # Hex values
454 | 
455 |         # Test with custom settings
456 |         result = storage_client.get_hex_view(file_id, offset=16, length=32, bytes_per_line=8)
457 |         assert result["offset"] == 16
458 |         assert result["length"] == 32
459 |         assert result["bytes_per_line"] == 8
460 | 
461 |         # Now the hex view should start at 16 (0x10)
462 |         assert "00000010" in result["hex_content"]
463 | 
464 |         # Test with offset beyond file size
465 |         result = storage_client.get_hex_view(file_id, offset=1000)
466 |         assert result["hex_content"] == ""
467 | 
468 |         # Test with nonexistent file
469 |         with pytest.raises(StorageError, match="File not found"):
470 |             storage_client.get_hex_view("nonexistent-file-id")
471 | 
```

--------------------------------------------------------------------------------
/tests/unit/test_mcp_tools/test_rule_tools.py:
--------------------------------------------------------------------------------

```python
  1 | """Fixed tests for rule tools to improve coverage."""
  2 | 
  3 | import json
  4 | from unittest.mock import MagicMock, Mock, patch
  5 | 
  6 | import pytest
  7 | from fastapi import HTTPException
  8 | 
  9 | from yaraflux_mcp_server.mcp_tools.rule_tools import (
 10 |     add_yara_rule,
 11 |     delete_yara_rule,
 12 |     get_yara_rule,
 13 |     import_threatflux_rules,
 14 |     list_yara_rules,
 15 |     update_yara_rule,
 16 |     validate_yara_rule,
 17 | )
 18 | from yaraflux_mcp_server.yara_service import YaraError
 19 | 
 20 | 
 21 | @patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
 22 | def test_list_yara_rules_success(mock_yara_service):
 23 |     """Test list_yara_rules successfully returns rules."""
 24 |     # Setup mocks
 25 |     rule1 = Mock()
 26 |     rule1.model_dump.return_value = {"name": "rule1.yar", "source": "custom"}
 27 |     rule2 = Mock()
 28 |     rule2.model_dump.return_value = {"name": "rule2.yar", "source": "community"}
 29 |     mock_yara_service.list_rules.return_value = [rule1, rule2]
 30 | 
 31 |     # Call the function (without filters)
 32 |     result = list_yara_rules()
 33 | 
 34 |     # Verify results
 35 |     assert len(result) == 2
 36 |     assert {"name": "rule1.yar", "source": "custom"} in result
 37 |     assert {"name": "rule2.yar", "source": "community"} in result
 38 | 
 39 |     # Verify mocks were called correctly
 40 |     mock_yara_service.list_rules.assert_called_once_with(None)
 41 | 
 42 | 
 43 | @patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
 44 | def test_list_yara_rules_filtered(mock_yara_service):
 45 |     """Test list_yara_rules with source filtering."""
 46 |     # Setup mocks
 47 |     rule1 = Mock()
 48 |     rule1.model_dump.return_value = {"name": "rule1.yar", "source": "custom"}
 49 |     rule2 = Mock()
 50 |     rule2.model_dump.return_value = {"name": "rule2.yar", "source": "custom"}
 51 |     mock_yara_service.list_rules.return_value = [rule1, rule2]
 52 | 
 53 |     # Call the function with source filter
 54 |     result = list_yara_rules("custom")
 55 | 
 56 |     # Verify results
 57 |     assert len(result) == 2
 58 | 
 59 |     # Verify mocks were called correctly
 60 |     mock_yara_service.list_rules.assert_called_once_with("custom")
 61 | 
 62 | 
 63 | @patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
 64 | def test_list_yara_rules_all_source(mock_yara_service):
 65 |     """Test list_yara_rules with 'all' source."""
 66 |     # Setup mocks
 67 |     rule1 = Mock()
 68 |     rule1.model_dump.return_value = {"name": "rule1.yar", "source": "custom"}
 69 |     rule2 = Mock()
 70 |     rule2.model_dump.return_value = {"name": "rule2.yar", "source": "community"}
 71 |     mock_yara_service.list_rules.return_value = [rule1, rule2]
 72 | 
 73 |     # Call the function with 'all' source
 74 |     result = list_yara_rules("all")
 75 | 
 76 |     # Verify results
 77 |     assert len(result) == 2
 78 | 
 79 |     # Verify mocks were called correctly
 80 |     mock_yara_service.list_rules.assert_called_once_with(None)
 81 | 
 82 | 
 83 | @patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
 84 | def test_list_yara_rules_error(mock_yara_service):
 85 |     """Test list_yara_rules with an error."""
 86 |     # Setup mock to raise an exception
 87 |     mock_yara_service.list_rules.side_effect = Exception("Test error")
 88 | 
 89 |     # Call the function
 90 |     result = list_yara_rules()
 91 | 
 92 |     # Verify results
 93 |     assert result == []
 94 | 
 95 | 
 96 | @patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
 97 | def test_get_yara_rule_success(mock_yara_service):
 98 |     """Test get_yara_rule successfully retrieves a rule."""
 99 |     # Setup mocks
100 |     mock_yara_service.get_rule.return_value = "rule test { condition: true }"
101 |     rule = Mock()
102 |     rule.name = "test.yar"
103 |     rule.model_dump.return_value = {"name": "test.yar", "source": "custom"}
104 |     mock_yara_service.list_rules.return_value = [rule]
105 | 
106 |     # Call the function
107 |     result = get_yara_rule(rule_name="test.yar", source="custom")
108 | 
109 |     # Verify results
110 |     assert result["success"] is True
111 |     assert result["result"]["name"] == "test.yar"
112 |     assert result["result"]["source"] == "custom"
113 |     assert result["result"]["content"] == "rule test { condition: true }"
114 |     assert result["result"]["metadata"] == {"name": "test.yar", "source": "custom"}
115 | 
116 |     # Verify mocks were called correctly
117 |     mock_yara_service.get_rule.assert_called_once_with("test.yar", "custom")
118 |     mock_yara_service.list_rules.assert_called_once_with("custom")
119 | 
120 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_get_yara_rule_invalid_source(mock_yara_service):
    """An unknown source is rejected before the service is consulted."""
    response = get_yara_rule(rule_name="test.yar", source="invalid")

    assert response["success"] is False
    assert "Invalid source" in response["message"]

    # Validation happens first, so the service must stay untouched.
    mock_yara_service.get_rule.assert_not_called()
133 | 
134 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_get_yara_rule_no_metadata(mock_yara_service):
    """When no listed rule has the requested name, metadata falls back to {}."""
    # The only listed rule has a different name, so the metadata lookup finds nothing.
    unrelated_rule = Mock()
    unrelated_rule.name = "other_rule.yar"
    unrelated_rule.model_dump.return_value = {"name": "other_rule.yar", "source": "custom"}

    mock_yara_service.get_rule.return_value = "rule test { condition: true }"
    mock_yara_service.list_rules.return_value = [unrelated_rule]

    response = get_yara_rule(rule_name="test.yar", source="custom")

    assert response["success"] is True
    payload = response["result"]
    assert payload["name"] == "test.yar"
    assert payload["metadata"] == {}

    mock_yara_service.get_rule.assert_called_once_with("test.yar", "custom")
155 | 
156 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_get_yara_rule_error(mock_yara_service):
    """A YaraError from the service becomes a failure dict echoing name and source."""
    mock_yara_service.get_rule.side_effect = YaraError("Rule not found")

    response = get_yara_rule(rule_name="test.yar", source="custom")

    assert response["success"] is False
    assert "Rule not found" in response["message"]
    # The failure payload repeats the request parameters.
    assert response["name"] == "test.yar"
    assert response["source"] == "custom"
171 | 
172 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_validate_yara_rule_valid(mock_yara_service):
    """validate_yara_rule reports success when the service accepts the rule.

    validate_yara_rule performs no import of its own, so patching
    ``builtins.__import__`` never influenced it; the service object it
    actually uses is patched instead, keeping the test hermetic.
    """
    rule_text = "rule test { condition: true }"

    # Call the function; the mocked add_rule/delete_rule succeed by default
    result = validate_yara_rule(content=rule_text)

    # Verify results
    assert "valid" in result
    assert result["valid"] is True
    assert result["message"] == "Rule is valid"

    # The rule is added under a temporary name and that same name is removed again
    mock_yara_service.add_rule.assert_called_once()
    temp_name, submitted_content = mock_yara_service.add_rule.call_args.args[:2]
    assert submitted_content == rule_text
    mock_yara_service.delete_rule.assert_called_once_with(temp_name)
187 | 
188 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_validate_yara_rule_invalid(mock_yara_service):
    """validate_yara_rule reports a YaraError when the service rejects the rule.

    The rejection is simulated on the patched service's add_rule, which is the
    call validate_yara_rule uses to check the rule; a mock installed through
    ``builtins.__import__`` would never be reached.
    """
    # Setup mock: the service refuses the rule with a compilation error
    mock_yara_service.add_rule.side_effect = YaraError('line 1: undefined identifier "invalid"')

    # Call the function
    result = validate_yara_rule(content="rule test { condition: invalid }")

    # Verify results
    assert "valid" in result
    assert result["valid"] is False
    assert "undefined identifier" in result["message"]
    assert result["error_type"] == "YaraError"

    # Nothing was added, so there is nothing to clean up
    mock_yara_service.delete_rule.assert_not_called()
205 | 
206 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_add_yara_rule_success(mock_yara_service):
    """add_yara_rule stores the rule and returns the dumped metadata."""
    rule_text = "rule test { condition: true }"
    expected_meta = {"name": "test.yar", "source": "custom"}

    stored = Mock()
    stored.model_dump.return_value = expected_meta
    mock_yara_service.add_rule.return_value = stored

    response = add_yara_rule(name="test.yar", content=rule_text, source="custom")

    assert response["success"] is True
    assert "added successfully" in response["message"]
    assert response["metadata"] == expected_meta

    # Arguments are forwarded positionally: name, content, source.
    mock_yara_service.add_rule.assert_called_once_with("test.yar", rule_text, "custom")
225 | 
226 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_add_yara_rule_adds_extension(mock_yara_service):
    """A rule name without '.yar' gets the extension appended before saving."""
    rule_text = "rule test { condition: true }"

    stored = Mock()
    stored.model_dump.return_value = {"name": "test.yar", "source": "custom"}
    mock_yara_service.add_rule.return_value = stored

    # Deliberately pass the bare name, without the extension.
    response = add_yara_rule(name="test", content=rule_text, source="custom")

    assert response["success"] is True

    # The service must see the normalized file name.
    mock_yara_service.add_rule.assert_called_once_with("test.yar", rule_text, "custom")
243 | 
244 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_add_yara_rule_invalid_source(mock_yara_service):
    """An unknown source is rejected without touching the service."""
    response = add_yara_rule(name="test.yar", content="rule test { condition: true }", source="invalid")

    assert response["success"] is False
    assert "Invalid source" in response["message"]

    # Source validation comes first, so nothing may be stored.
    mock_yara_service.add_rule.assert_not_called()
257 | 
258 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_add_yara_rule_empty_content(mock_yara_service):
    """Whitespace-only content counts as empty and is rejected."""
    blank_content = "   "  # becomes "" once stripped

    response = add_yara_rule(name="test.yar", content=blank_content, source="custom")

    assert response["success"] is False
    assert "content cannot be empty" in response["message"]

    # Rejected input must never reach the service.
    mock_yara_service.add_rule.assert_not_called()
271 | 
272 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_add_yara_rule_error(mock_yara_service):
    """A YaraError raised while adding is surfaced in the failure message."""
    mock_yara_service.add_rule.side_effect = YaraError("Compilation error")

    response = add_yara_rule(name="test.yar", content="rule test { condition: true }", source="custom")

    assert response["success"] is False
    assert "Compilation error" in response["message"]
285 | 
286 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_update_yara_rule_success(mock_yara_service):
    """update_yara_rule checks that the rule exists, then stores the new content."""
    rule_text = "rule test { condition: true }"
    expected_meta = {"name": "test.yar", "source": "custom"}

    refreshed = Mock()
    refreshed.model_dump.return_value = expected_meta
    mock_yara_service.update_rule.return_value = refreshed

    response = update_yara_rule(name="test.yar", content=rule_text, source="custom")

    assert response["success"] is True
    assert "updated successfully" in response["message"]
    assert response["metadata"] == expected_meta

    # get_rule is the existence probe; update_rule performs the write.
    mock_yara_service.get_rule.assert_called_once_with("test.yar", "custom")
    mock_yara_service.update_rule.assert_called_once_with("test.yar", rule_text, "custom")
306 | 
307 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_update_yara_rule_not_found(mock_yara_service):
    """Updating a missing rule fails at the existence probe and writes nothing."""
    mock_yara_service.get_rule.side_effect = YaraError("Rule not found")

    response = update_yara_rule(name="test.yar", content="rule test { condition: true }", source="custom")

    assert response["success"] is False
    assert "Rule not found" in response["message"]

    # The probe ran once; the write must have been skipped.
    mock_yara_service.get_rule.assert_called_once_with("test.yar", "custom")
    mock_yara_service.update_rule.assert_not_called()
324 | 
325 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_delete_yara_rule_success(mock_yara_service):
    """A truthy result from the service is reported as a successful deletion."""
    mock_yara_service.delete_rule.return_value = True

    response = delete_yara_rule(name="test.yar", source="custom")

    assert response["success"] is True
    assert "deleted successfully" in response["message"]

    mock_yara_service.delete_rule.assert_called_once_with("test.yar", "custom")
341 | 
342 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_delete_yara_rule_not_found(mock_yara_service):
    """A falsy result from the service is reported as 'not found'."""
    mock_yara_service.delete_rule.return_value = False

    response = delete_yara_rule(name="test.yar", source="custom")

    assert response["success"] is False
    assert "not found" in response["message"]

    mock_yara_service.delete_rule.assert_called_once_with("test.yar", "custom")
358 | 
359 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_delete_yara_rule_error(mock_yara_service):
    """A YaraError raised while deleting is surfaced in the failure message."""
    mock_yara_service.delete_rule.side_effect = YaraError("Permission denied")

    response = delete_yara_rule(name="test.yar", source="custom")

    assert response["success"] is False
    assert "Permission denied" in response["message"]
372 | 
373 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.httpx")
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_import_threatflux_rules_success(mock_yara_service, mock_httpx):
    """import_threatflux_rules adds the indexed rules and reloads the service."""
    # Reply to the initial connectivity probe.
    probe_reply = MagicMock(status_code=200)

    # Reply carrying the index that lists two rule files.
    index_reply = MagicMock(status_code=200)
    index_reply.json.return_value = {"rules": ["rule1.yar", "rule2.yar"]}

    # Reply used for each individual rule download.
    rule_reply = MagicMock(status_code=200, text="rule test { condition: true }")

    # httpx.get hands the replies out in call order: probe, index, then one per rule.
    mock_httpx.get.side_effect = [probe_reply, index_reply, rule_reply, rule_reply]

    outcome = import_threatflux_rules()

    assert outcome["success"] is True
    # At least one rule was stored and the rule set was reloaded exactly once.
    assert mock_yara_service.add_rule.call_count >= 1
    mock_yara_service.load_rules.assert_called_once()
403 | 
404 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.httpx")
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_import_threatflux_rules_with_custom_url(mock_yara_service, mock_httpx):
    """A custom GitHub URL is probed through its raw.githubusercontent.com form."""
    # Reply to the initial connectivity probe.
    probe_reply = MagicMock(status_code=200)

    # Reply for index.json listing a single rule file.
    index_reply = MagicMock(status_code=200)
    index_reply.json.return_value = {"rules": ["rule1.yar"]}

    # Reply for the rule file download.
    rule_reply = MagicMock(status_code=200, text="rule test { condition: true }")

    mock_httpx.get.side_effect = [probe_reply, index_reply, rule_reply]

    outcome = import_threatflux_rules(url="https://github.com/custom/repo")

    assert outcome["success"] is True

    # The connection test targets the raw-content host with a 10 second timeout.
    mock_httpx.get.assert_any_call("https://raw.githubusercontent.com/custom/repo", timeout=10)
434 | 
435 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.httpx")
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_import_threatflux_rules_no_index(mock_yara_service, mock_httpx):
    """The import still succeeds when index.json answers 404."""
    # The connectivity probe succeeds ...
    probe_reply = MagicMock(status_code=200)

    # ... but the index request is answered with "not found".
    missing_index_reply = MagicMock(status_code=404)

    # Rule downloads attempted afterwards succeed.
    rule_reply = MagicMock(status_code=200, text="rule test { condition: true }")

    # Call order: 200 for the probe, 404 for the index, then 200s for rule files.
    mock_httpx.get.side_effect = [probe_reply, missing_index_reply, rule_reply, rule_reply]

    outcome = import_threatflux_rules()

    # Some rules should still have been imported.
    assert outcome["success"] is True
462 | 
463 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.httpx")
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_import_threatflux_rules_error(mock_yara_service, mock_httpx):
    """A failing first HTTP request yields a structured failure result."""
    # Every httpx.get call raises, starting with the connection test.
    mock_httpx.get.side_effect = Exception("Connection error")

    outcome = import_threatflux_rules()

    # The result is a dict describing the failure rather than a raised exception.
    assert isinstance(outcome, dict)
    assert "success" in outcome
    assert not outcome["success"]  # Should be False
    assert "message" in outcome
    assert "Connection error" in outcome["message"]
    assert "error" in outcome
481 | 
```

--------------------------------------------------------------------------------
/src/yaraflux_mcp_server/mcp_tools/rule_tools.py:
--------------------------------------------------------------------------------

```python
  1 | """YARA rule management tools for Claude MCP integration.
  2 | 
  3 | This module provides tools for managing YARA rules, including listing,
  4 | adding, updating, validating, and deleting rules. It uses direct function
  5 | implementations with inline error handling.
  6 | """
  7 | 
  8 | import logging
  9 | import os
 10 | import tempfile
 11 | from datetime import UTC, datetime
 12 | from pathlib import Path
 13 | from tarfile import ReadError
 14 | from typing import Any, Dict, List, Optional
 15 | 
 16 | import httpx
 17 | 
 18 | from yaraflux_mcp_server.mcp_tools.base import register_tool
 19 | from yaraflux_mcp_server.yara_service import YaraError, yara_service
 20 | 
 21 | # Configure logging
 22 | logger = logging.getLogger(__name__)
 23 | 
 24 | 
 25 | @register_tool()
 26 | def list_yara_rules(source: Optional[str] = None) -> List[Dict[str, Any]]:
 27 |     """List available YARA rules.
 28 | 
 29 |     For LLM users connecting through MCP, this can be invoked with natural language like:
 30 |     "Show me all YARA rules"
 31 |     "List custom YARA rules only"
 32 |     "What community rules are available?"
 33 | 
 34 |     Args:
 35 |         source: Optional source filter ("custom" or "community")
 36 | 
 37 |     Returns:
 38 |         List of YARA rule metadata objects
 39 |     """
 40 |     try:
 41 |         # Validate source if provided
 42 |         if source and source not in ["custom", "community", "all"]:
 43 |             raise ValueError(f"Invalid source: {source}. Must be 'custom', 'community', or 'all'")
 44 | 
 45 |         # Get rules from the YARA service
 46 |         rules = yara_service.list_rules(None if source == "all" else source)
 47 | 
 48 |         # Convert to dict for serialization
 49 |         return [rule.model_dump() for rule in rules]
 50 |     except ValueError as e:
 51 |         logger.error(f"Value error in list_yara_rules: {str(e)}")
 52 |         return []
 53 |     except Exception as e:
 54 |         logger.error(f"Error listing YARA rules: {str(e)}")
 55 |         return []
 56 | 
 57 | 
 58 | @register_tool()
 59 | def get_yara_rule(rule_name: str, source: str = "custom") -> Dict[str, Any]:
 60 |     """Get a YARA rule's content.
 61 | 
 62 |     For LLM users connecting through MCP, this can be invoked with natural language like:
 63 |     "Show me the code for rule suspicious_strings"
 64 |     "Get the content of the ransomware detection rule"
 65 |     "What does the CVE-2023-1234 rule look like?"
 66 | 
 67 |     Args:
 68 |         rule_name: Name of the rule to get
 69 |         source: Source of the rule ("custom" or "community")
 70 | 
 71 |     Returns:
 72 |         Rule content and metadata
 73 |     """
 74 |     try:
 75 |         # Validate source
 76 |         if source not in ["custom", "community"]:
 77 |             raise ValueError(f"Invalid source: {source}. Must be 'custom' or 'community'")
 78 | 
 79 |         # Get rule content
 80 |         content = yara_service.get_rule(rule_name, source)
 81 | 
 82 |         # Get rule metadata
 83 |         rules = yara_service.list_rules(source)
 84 |         metadata = None
 85 |         for rule in rules:
 86 |             if rule.name == rule_name:
 87 |                 metadata = rule
 88 |                 break
 89 | 
 90 |         # Return content and metadata
 91 |         return {
 92 |             "success": True,
 93 |             "result": {
 94 |                 "name": rule_name,
 95 |                 "source": source,
 96 |                 "content": content,
 97 |                 "metadata": metadata.model_dump() if metadata else {},
 98 |             },
 99 |         }
100 |     except YaraError as e:
101 |         logger.error(f"YARA error in get_yara_rule: {str(e)}")
102 |         return {"success": False, "message": str(e), "name": rule_name, "source": source}
103 |     except ValueError as e:
104 |         logger.error(f"Value error in get_yara_rule: {str(e)}")
105 |         return {"success": False, "message": str(e), "name": rule_name, "source": source}
106 |     except Exception as e:
107 |         logger.error(f"Unexpected error in get_yara_rule: {str(e)}")
108 |         return {"success": False, "message": f"Unexpected error: {str(e)}", "name": rule_name, "source": source}
109 | 
110 | 
@register_tool()
def validate_yara_rule(content: str) -> Dict[str, Any]:
    """Validate a YARA rule.

    For LLM users connecting through MCP, this can be invoked with natural language like:
    "Check if this YARA rule syntax is valid"
    "Validate this detection rule for me"
    "Is this YARA code correctly formatted?"

    The rule is validated by adding it to the YARA service under a temporary
    name and removing it again afterwards.

    Args:
        content: YARA rule content to validate

    Returns:
        ``{"valid": True, "message": "Rule is valid"}`` when the service accepts
        the rule; otherwise ``{"valid": False, "message": ..., "error_type": ...}``
        with detailed error information
    """
    try:
        if not content.strip():
            raise ValueError("Rule content cannot be empty")

        # A timestamp alone repeats within one second, so two validations could
        # share (and delete) each other's temporary rule; the random suffix
        # keeps every temporary name unique.
        temp_rule_name = f"validate_{int(datetime.now(UTC).timestamp())}_{os.urandom(4).hex()}.yar"

        try:
            # Attempt to add the rule (this will validate it)
            yara_service.add_rule(temp_rule_name, content)
        except YaraError as e:
            # Capture the original compilation error
            error_message = str(e)
            logger.debug("YARA compilation error: %s", error_message)
            raise YaraError("Rule validation failed: " + error_message) from e

        # The rule was accepted, so it is valid. Cleanup is best-effort: a
        # failure to remove the temporary rule must not be reported as a
        # validation failure of the rule itself.
        try:
            yara_service.delete_rule(temp_rule_name)
        except Exception as cleanup_error:
            logger.warning("Could not remove temporary rule %s: %s", temp_rule_name, cleanup_error)

        return {"valid": True, "message": "Rule is valid"}

    except YaraError as e:
        logger.error(f"YARA error in validate_yara_rule: {str(e)}")
        return {"valid": False, "message": str(e), "error_type": "YaraError"}
    except ValueError as e:
        logger.error(f"Value error in validate_yara_rule: {str(e)}")
        return {"valid": False, "message": str(e), "error_type": "ValueError"}
    except Exception as e:
        logger.error(f"Unexpected error in validate_yara_rule: {str(e)}")
        return {
            "valid": False,
            "message": f"Unexpected error: {str(e)}",
            "error_type": e.__class__.__name__,
        }
161 | 
162 | 
@register_tool()
def add_yara_rule(name: str, content: str, source: str = "custom") -> Dict[str, Any]:
    """Add a new YARA rule.

    For LLM users connecting through MCP, this can be invoked with natural language like:
    "Create a new YARA rule named suspicious_urls"
    "Add this detection rule for PowerShell obfuscation"
    "Save this YARA rule to detect malicious macros"

    Args:
        name: Name of the rule; ".yar" is appended when it is missing
        content: YARA rule content
        source: Source of the rule ("custom" or "community")

    Returns:
        Result of the operation: ``success``, ``message`` and, on success, the
        stored rule's ``metadata``
    """

    def failure(message: str) -> Dict[str, Any]:
        return {"success": False, "message": message}

    try:
        if source not in ("custom", "community"):
            raise ValueError(f"Invalid source: {source}. Must be 'custom' or 'community'")

        # Normalize the file name so stored rules always carry the .yar extension
        rule_file = name if name.endswith(".yar") else f"{name}.yar"

        # Whitespace-only content is treated as empty
        if not content.strip():
            raise ValueError("Rule content cannot be empty")

        stored = yara_service.add_rule(rule_file, content, source)

        return {
            "success": True,
            "message": f"Rule {rule_file} added successfully",
            "metadata": stored.model_dump(),
        }
    except YaraError as exc:
        logger.error(f"YARA error in add_yara_rule: {exc}")
        return failure(str(exc))
    except ValueError as exc:
        logger.error(f"Value error in add_yara_rule: {exc}")
        return failure(str(exc))
    except Exception as exc:
        logger.error(f"Unexpected error in add_yara_rule: {exc}")
        return failure(f"Unexpected error: {exc}")
206 | 
207 | 
@register_tool()
def update_yara_rule(name: str, content: str, source: str = "custom") -> Dict[str, Any]:
    """Update an existing YARA rule.

    For LLM users connecting through MCP, this can be invoked with natural language like:
    "Update the ransomware detection rule"
    "Modify the suspicious_urls rule to include these new patterns"
    "Fix the syntax error in the malicious_macros rule"

    Args:
        name: Name of the rule
        content: Updated YARA rule content
        source: Source of the rule ("custom" or "community")

    Returns:
        Result of the operation: ``success``, ``message`` and, on success, the
        updated rule's ``metadata``
    """

    def failure(message: str) -> Dict[str, Any]:
        return {"success": False, "message": message}

    try:
        if source not in ("custom", "community"):
            raise ValueError(f"Invalid source: {source}. Must be 'custom' or 'community'")

        # Existence probe: the result is discarded, only the YaraError for a
        # missing rule matters
        yara_service.get_rule(name, source)

        # Whitespace-only content is treated as empty
        if not content.strip():
            raise ValueError("Rule content cannot be empty")

        refreshed = yara_service.update_rule(name, content, source)

        return {
            "success": True,
            "message": f"Rule {name} updated successfully",
            "metadata": refreshed.model_dump(),
        }
    except YaraError as exc:
        logger.error(f"YARA error in update_yara_rule: {exc}")
        return failure(str(exc))
    except ValueError as exc:
        logger.error(f"Value error in update_yara_rule: {exc}")
        return failure(str(exc))
    except Exception as exc:
        logger.error(f"Unexpected error in update_yara_rule: {exc}")
        return failure(f"Unexpected error: {exc}")
250 | 
251 | 
@register_tool()
def delete_yara_rule(name: str, source: str = "custom") -> Dict[str, Any]:
    """Delete a YARA rule.

    For LLM users connecting through MCP, this can be invoked with natural language like:
    "Delete the ransomware detection rule"
    "Remove the rule named suspicious_urls"
    "Get rid of the outdated CVE-2020-1234 rule"

    Args:
        name: Name of the rule
        source: Source of the rule ("custom" or "community")

    Returns:
        Result of the operation: ``success`` and a ``message`` saying whether
        the rule was deleted, was not found, or why the deletion failed
    """

    def failure(message: str) -> Dict[str, Any]:
        return {"success": False, "message": message}

    try:
        if source not in ("custom", "community"):
            raise ValueError(f"Invalid source: {source}. Must be 'custom' or 'community'")

        # A falsy result from the service is reported as "not found"
        if yara_service.delete_rule(name, source):
            return {"success": True, "message": f"Rule {name} deleted successfully"}
        return failure(f"Rule {name} not found")
    except YaraError as exc:
        logger.error(f"YARA error in delete_yara_rule: {exc}")
        return failure(str(exc))
    except ValueError as exc:
        logger.error(f"Value error in delete_yara_rule: {exc}")
        return failure(str(exc))
    except Exception as exc:
        logger.error(f"Unexpected error in delete_yara_rule: {exc}")
        return failure(f"Unexpected error: {exc}")
288 | 
289 | 
def _import_indexed_rules(import_path: str) -> tuple[int, int]:
    """Import every rule file listed in ``{import_path}/index.json``.

    Args:
        import_path: Raw-content base URL that already includes the branch.

    Returns:
        ``(import_count, error_count)`` for the files named by the index.

    Raises:
        ValueError: If ``index.json`` is not answered with HTTP 200.
    """
    response = httpx.get(f"{import_path}/index.json", follow_redirects=True)
    if response.status_code != 200:
        raise ValueError("Index not found")

    import_count = 0
    error_count = 0
    for rule_file in response.json().get("rules", []):
        rule_url = f"{import_path}/{rule_file}"
        try:
            rule_response = httpx.get(rule_url, follow_redirects=True)
            if rule_response.status_code == 200:
                # Rules are stored flat, keyed by file name only.
                yara_service.add_rule(os.path.basename(rule_file), rule_response.text, "community")
                import_count += 1
            else:
                logger.warning(f"Failed to download rule {rule_file}: HTTP {rule_response.status_code}")
                error_count += 1
        except Exception as e:
            # One bad rule must not abort the rest of the import.
            logger.error(f"Error downloading rule {rule_file}: {str(e)}")
            error_count += 1
    return import_count, error_count


def _import_well_known_rules(import_path: str) -> int:
    """Best-effort import of a fixed set of guessed rule paths.

    Used when the repository has no usable ``index.json``. Files that are
    missing or fail to import are skipped and not counted as errors.

    Args:
        import_path: Raw-content base URL that already includes the branch.

    Returns:
        Number of rules that were downloaded and added.
    """
    directories = ("malware", "general", "packer", "persistence")
    filenames = ("apt.yar", "generic.yar", "capabilities.yar", "indicators.yar")

    import_count = 0
    for directory in directories:
        for filename in filenames:
            rule_url = f"{import_path}/{directory}/{filename}"
            try:
                rule_response = httpx.get(rule_url, follow_redirects=True)
                if rule_response.status_code == 200:
                    yara_service.add_rule(filename, rule_response.text, "community")
                    import_count += 1
            except Exception as e:
                # These paths are guesses, so a failure here is expected noise.
                logger.debug("Skipping %s: %s", rule_url, e)
    return import_count


def _import_local_rules(root: Path) -> tuple[int, int]:
    """Import every ``*.yar`` file found recursively under ``root``.

    Args:
        root: Existing local directory (or file path) to search.

    Returns:
        ``(import_count, error_count)`` for the files that were found.
    """
    import_count = 0
    error_count = 0
    for rule_file in root.glob("**/*.yar"):
        try:
            with open(rule_file, "r", encoding="utf-8") as f:
                rule_content = f.read()

            yara_service.add_rule(rule_file.name, rule_content, "community")
            import_count += 1
        except FileNotFoundError:
            logger.warning("Rule file not found: %s", rule_file)
            error_count += 1
        except (OSError, UnicodeDecodeError) as e:
            logger.error("Error reading rule file: %s", str(e))
            error_count += 1
        except YaraError as e:
            # Mirror the remote branch: a rejected rule is counted, not fatal.
            logger.error("Error adding rule %s: %s", rule_file.name, str(e))
            error_count += 1
    return import_count, error_count


@register_tool()
def import_threatflux_rules(url: Optional[str] = None, branch: str = "main") -> Dict[str, Any]:
    """Import ThreatFlux YARA rules from GitHub.

    For LLM users connecting through MCP, this can be invoked with natural language like:
    "Import YARA rules from ThreatFlux"
    "Get the latest detection rules from the ThreatFlux repository"
    "Import YARA rules from a custom GitHub repo"

    Args:
        url: URL to the GitHub repository, or a local directory path
            (if None or empty, use default ThreatFlux repository)
        branch: Branch name to import from

    Returns:
        Import result
    """
    try:
        # Fall back to the defaults for missing or empty arguments.
        if not url:
            url = "https://github.com/ThreatFlux/YARA-Rules"
        if not branch:
            branch = "main"

        if "github.com" in url:
            raw_url = url.replace("github.com", "raw.githubusercontent.com")

            # Probe the remote before doing any work. Only GitHub URLs are
            # probed; sending a local path through httpx would fail and make
            # the local branch below unreachable.
            # NOTE(review): the probe targets the repository root, not a branch
            # path - confirm the raw host answers such URLs with a status < 400.
            try:
                test_response = httpx.get(raw_url, timeout=10)
                if test_response.status_code >= 400:
                    raise ValueError(f"HTTP {test_response.status_code}")
            except ConnectionError as e:
                # NOTE(review): httpx transport errors derive from
                # httpx.HTTPError, not the builtin ConnectionError, so they are
                # expected to reach the outer handler instead of this one.
                logger.error("Connection error in import_threatflux_rules: %s", str(e))
                return {"success": False, "message": f"Connection error: {str(e)}", "error": str(e)}

            import_path = f"{raw_url.removesuffix('/')}/{branch}"
            try:
                import_count, error_count = _import_indexed_rules(import_path)
            except Exception as e:  # noqa
                # No usable index.json: fall back to guessing common file names.
                logger.debug("Index import failed (%s); trying well-known rule paths", e)
                import_count, error_count = _import_well_known_rules(import_path), 0
        else:
            # Anything that is not a GitHub URL is treated as a local path.
            local_root = Path(url)
            if not local_root.exists():
                raise YaraError(f"Local path not found: {url}")
            import_count, error_count = _import_local_rules(local_root)

        # Reload rules
        yara_service.load_rules()

        return {
            "success": True,
            "message": f"Imported {import_count} rules from {url} ({error_count} errors)",
            "import_count": import_count,
            "error_count": error_count,
        }
    except YaraError as e:
        logger.error(f"YARA error in import_threatflux_rules: {str(e)}")
        return {"success": False, "message": str(e)}
    except Exception as e:
        logger.error(f"Unexpected error in import_threatflux_rules: {str(e)}")
        return {
            "success": False,
            "message": f"Error importing rules: {str(e)}",
            "error": str(e),  # Include the original error message
        }
449 | 
```

--------------------------------------------------------------------------------
/tests/unit/test_mcp_tools/test_rule_tools_extended.py:
--------------------------------------------------------------------------------

```python
  1 | """Extended tests for rule tools to improve coverage."""
  2 | 
  3 | import json
  4 | from datetime import UTC, datetime
  5 | from unittest.mock import AsyncMock, MagicMock, Mock, call, patch
  6 | 
  7 | import pytest
  8 | 
  9 | from yaraflux_mcp_server.mcp_tools.rule_tools import (
 10 |     add_yara_rule,
 11 |     delete_yara_rule,
 12 |     get_yara_rule,
 13 |     import_threatflux_rules,
 14 |     list_yara_rules,
 15 |     update_yara_rule,
 16 |     validate_yara_rule,
 17 | )
 18 | from yaraflux_mcp_server.yara_service import YaraError, YaraRuleMetadata
 19 | 
 20 | 
 21 | @patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
 22 | def test_list_yara_rules_value_error(mock_yara_service):
 23 |     """Test list_yara_rules with invalid source filter."""
 24 |     # Call the function with invalid source
 25 |     result = list_yara_rules(source="invalid")
 26 | 
 27 |     # Verify error handling
 28 |     assert isinstance(result, list)
 29 |     assert len(result) == 0
 30 | 
 31 |     # Verify service not called with invalid source
 32 |     mock_yara_service.list_rules.assert_not_called()
 33 | 
 34 | 
 35 | @patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
 36 | def test_list_yara_rules_exception(mock_yara_service):
 37 |     """Test list_yara_rules with general exception."""
 38 |     # Setup mock to raise exception
 39 |     mock_yara_service.list_rules.side_effect = Exception("Service error")
 40 | 
 41 |     # Call the function
 42 |     result = list_yara_rules()
 43 | 
 44 |     # Verify error handling
 45 |     assert isinstance(result, list)
 46 |     assert len(result) == 0
 47 | 
 48 |     # Verify service was called
 49 |     mock_yara_service.list_rules.assert_called_once()
 50 | 
 51 | 
 52 | @patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
 53 | def test_list_yara_rules_all_source(mock_yara_service):
 54 |     """Test list_yara_rules with 'all' source filter."""
 55 |     # Setup mock rules
 56 |     rule1 = YaraRuleMetadata(name="rule1", source="custom", created=datetime.now(UTC), is_compiled=True)
 57 |     rule2 = YaraRuleMetadata(name="rule2", source="community", created=datetime.now(UTC), is_compiled=True)
 58 |     mock_yara_service.list_rules.return_value = [rule1, rule2]
 59 | 
 60 |     # Call the function with 'all' source
 61 |     result = list_yara_rules(source="all")
 62 | 
 63 |     # Verify the result
 64 |     assert isinstance(result, list)
 65 |     assert len(result) == 2
 66 | 
 67 |     # Verify service was called with None to get all rules
 68 |     mock_yara_service.list_rules.assert_called_with(None)
 69 | 
 70 | 
 71 | @patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
 72 | def test_get_yara_rule_invalid_source(mock_yara_service):
 73 |     """Test get_yara_rule with invalid source."""
 74 |     # Call the function with invalid source
 75 |     result = get_yara_rule(rule_name="test", source="invalid")
 76 | 
 77 |     # Verify error handling
 78 |     assert isinstance(result, dict)
 79 |     assert "success" in result
 80 |     assert result["success"] is False
 81 |     assert "message" in result
 82 |     assert "Invalid source" in result["message"]
 83 | 
 84 |     # Verify service not called with invalid source
 85 |     mock_yara_service.get_rule.assert_not_called()
 86 | 
 87 | 
 88 | @patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
 89 | def test_get_yara_rule_yara_error(mock_yara_service):
 90 |     """Test get_yara_rule with YaraError."""
 91 |     # Setup mock to raise YaraError
 92 |     mock_yara_service.get_rule.side_effect = YaraError("Rule not found")
 93 | 
 94 |     # Call the function
 95 |     result = get_yara_rule(rule_name="nonexistent", source="custom")
 96 | 
 97 |     # Verify error handling
 98 |     assert isinstance(result, dict)
 99 |     assert "success" in result
100 |     assert result["success"] is False
101 |     assert "message" in result
102 |     assert "Rule not found" in result["message"]
103 | 
104 |     # Verify service was called
105 |     mock_yara_service.get_rule.assert_called_once()
106 | 
107 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_get_yara_rule_general_exception(mock_yara_service):
    """An arbitrary exception from the service is turned into a failed result."""
    mock_yara_service.configure_mock(**{"get_rule.side_effect": Exception("Unexpected error")})

    response = get_yara_rule(rule_name="test", source="custom")

    # The exception text must appear in the reported message.
    assert isinstance(response, dict)
    assert response.get("success") is False
    assert "Unexpected error" in response.get("message", "")

    # Exactly one lookup was attempted.
    mock_yara_service.get_rule.assert_called_once()
126 | 
127 | 
def test_validate_yara_rule_empty_content():
    """Empty rule text is reported as invalid with a 'cannot be empty' message."""
    verdict = validate_yara_rule(content="")

    # The message check is case-insensitive on purpose.
    assert isinstance(verdict, dict)
    assert verdict.get("valid") is False
    assert "cannot be empty" in verdict.get("message", "").lower()
139 | 
140 | 
def test_validate_yara_rule_import_error():
    """validate_yara_rule still returns a verdict dict when dynamic imports fail."""
    # Make every importlib.import_module call fail for the duration of the call.
    with patch("importlib.import_module", side_effect=ImportError("No module named 'yara'")):
        verdict = validate_yara_rule(content="rule test { condition: true }")

    # Only the shape is checked: whether the rule is judged valid depends on
    # whether yara is actually available in the test environment.
    assert isinstance(verdict, dict)
    assert "valid" in verdict
154 | 
155 | 
def test_validate_yara_rule_complex_rule():
    """A rule with meta, strings and condition sections produces a verdict dict."""
    rule_text = """
    rule ComplexRule {
        meta:
            description = "This is a complex rule"
            author = "Test Author"
            reference = "https://example.com"
        strings:
            $a = "suspicious string"
            $b = /[0-9a-f]{32}/
            $c = { 48 54 54 50 2F 31 2E 31 }  // HTTP/1.1 in hex
        condition:
            all of ($a, $b, $c) and filesize < 1MB
    }
    """

    # yara.compile is stubbed out so no real compilation happens here.
    with patch("yara.compile"):
        verdict = validate_yara_rule(content=rule_text)

    # Only the shape of the answer is asserted.
    assert isinstance(verdict, dict)
    assert "valid" in verdict
181 | 
182 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_add_yara_rule_invalid_source(mock_yara_service):
    """add_yara_rule rejects an unknown source before calling the service."""
    response = add_yara_rule(name="test_rule", content="rule test { condition: true }", source="invalid")

    # Failure is reported through the result dict.
    assert isinstance(response, dict)
    assert response.get("success") is False
    assert "Invalid source" in response.get("message", "")

    # Nothing may be written for an invalid source.
    mock_yara_service.add_rule.assert_not_called()
198 | 
199 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_add_yara_rule_empty_content(mock_yara_service):
    """add_yara_rule refuses empty rule text before calling the service."""
    response = add_yara_rule(name="test_rule", content="", source="custom")

    # The message check is case-insensitive.
    assert isinstance(response, dict)
    assert response.get("success") is False
    assert "cannot be empty" in response.get("message", "").lower()

    # Nothing may be written for empty content.
    mock_yara_service.add_rule.assert_not_called()
215 | 
216 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_add_yara_rule_yara_error(mock_yara_service):
    """A YaraError raised while adding surfaces as a failed result with its message."""
    mock_yara_service.configure_mock(**{"add_rule.side_effect": YaraError("Failed to compile rule")})

    response = add_yara_rule(name="test_rule", content="rule test { condition: true }", source="custom")

    # The service's error text is passed through to the caller.
    assert isinstance(response, dict)
    assert response.get("success") is False
    assert "Failed to compile rule" in response.get("message", "")

    # Exactly one add was attempted.
    mock_yara_service.add_rule.assert_called_once()
235 | 
236 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_add_yara_rule_general_exception(mock_yara_service):
    """An arbitrary exception raised while adding is turned into a failed result."""
    mock_yara_service.configure_mock(**{"add_rule.side_effect": Exception("Unexpected error")})

    response = add_yara_rule(name="test_rule", content="rule test { condition: true }", source="custom")

    # The exception text must appear in the reported message.
    assert isinstance(response, dict)
    assert response.get("success") is False
    assert "Unexpected error" in response.get("message", "")

    # Exactly one add was attempted.
    mock_yara_service.add_rule.assert_called_once()
255 | 
256 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_update_yara_rule_invalid_source(mock_yara_service):
    """update_yara_rule rejects an unknown source before calling update_rule."""
    response = update_yara_rule(name="test_rule", content="rule test { condition: true }", source="invalid")

    # Failure is reported through the result dict.
    assert isinstance(response, dict)
    assert response.get("success") is False
    assert "Invalid source" in response.get("message", "")

    # No update may be issued for an invalid source.
    mock_yara_service.update_rule.assert_not_called()
272 | 
273 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_update_yara_rule_empty_content(mock_yara_service):
    """update_yara_rule refuses empty rule text and never calls update_rule."""
    response = update_yara_rule(name="test_rule", content="", source="custom")

    # The message check is case-insensitive.
    assert isinstance(response, dict)
    assert response.get("success") is False
    assert "cannot be empty" in response.get("message", "").lower()

    # No update may be issued for empty content.
    mock_yara_service.update_rule.assert_not_called()
289 | 
290 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_update_yara_rule_rule_not_found(mock_yara_service):
    """Updating a rule whose lookup fails stops before update_rule is called."""
    # The existence check (get_rule) fails.
    mock_yara_service.configure_mock(**{"get_rule.side_effect": YaraError("Rule not found")})

    response = update_yara_rule(name="nonexistent", content="rule test { condition: true }", source="custom")

    # The lookup error is what the caller sees.
    assert isinstance(response, dict)
    assert response.get("success") is False
    assert "Rule not found" in response.get("message", "")

    # The lookup ran once; the update itself was never attempted.
    mock_yara_service.get_rule.assert_called_once()
    mock_yara_service.update_rule.assert_not_called()
310 | 
311 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_update_yara_rule_yara_error(mock_yara_service):
    """A YaraError raised by update_rule itself surfaces as a failed result."""
    # The rule exists, but writing the new content fails.
    mock_yara_service.configure_mock(
        **{
            "get_rule.return_value": "original content",
            "update_rule.side_effect": YaraError("Failed to compile rule"),
        }
    )

    response = update_yara_rule(name="test_rule", content="rule test { condition: true }", source="custom")

    # The update error is what the caller sees.
    assert isinstance(response, dict)
    assert response.get("success") is False
    assert "Failed to compile rule" in response.get("message", "")

    # Both the existence check and the update ran exactly once.
    mock_yara_service.get_rule.assert_called_once()
    mock_yara_service.update_rule.assert_called_once()
332 | 
333 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_delete_yara_rule_invalid_source(mock_yara_service):
    """delete_yara_rule rejects an unknown source before calling the service."""
    response = delete_yara_rule(name="test_rule", source="invalid")

    # Failure is reported through the result dict.
    assert isinstance(response, dict)
    assert response.get("success") is False
    assert "Invalid source" in response.get("message", "")

    # Nothing may be deleted for an invalid source.
    mock_yara_service.delete_rule.assert_not_called()
349 | 
350 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_delete_yara_rule_yara_error(mock_yara_service):
    """A YaraError raised while deleting surfaces as a failed result with its message."""
    mock_yara_service.configure_mock(**{"delete_rule.side_effect": YaraError("Error deleting rule")})

    response = delete_yara_rule(name="test_rule", source="custom")

    # The service's error text is passed through to the caller.
    assert isinstance(response, dict)
    assert response.get("success") is False
    assert "Error deleting rule" in response.get("message", "")

    # Exactly one delete was attempted.
    mock_yara_service.delete_rule.assert_called_once()
369 | 
370 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_delete_yara_rule_general_exception(mock_yara_service):
    """An arbitrary exception raised while deleting is turned into a failed result."""
    mock_yara_service.configure_mock(**{"delete_rule.side_effect": Exception("Unexpected error")})

    response = delete_yara_rule(name="test_rule", source="custom")

    # The exception text must appear in the reported message.
    assert isinstance(response, dict)
    assert response.get("success") is False
    assert "Unexpected error" in response.get("message", "")

    # Exactly one delete was attempted.
    mock_yara_service.delete_rule.assert_called_once()
389 | 
390 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.httpx")
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
def test_import_threatflux_rules_connection_error(mock_yara_service, mock_httpx):
    """A failure on the very first HTTP request aborts the import.

    The mock raises a plain ``Exception``, so the error is reported by the
    tool's generic handler ("Error importing rules: ...").
    """
    # Every HTTP request fails.
    mock_httpx.get.side_effect = Exception("Connection error")

    result = import_threatflux_rules()

    # The failure is reported through the result dict, with the original
    # exception text both in the message and in the separate "error" field.
    assert isinstance(result, dict)
    assert "success" in result
    assert not result["success"]
    assert "Connection error" in str(result)
    assert "message" in result
    assert "Error importing rules: Connection error" in result["message"]
    assert result["error"] == "Connection error"

    # The import stopped at the first request and never reloaded the rules.
    mock_httpx.get.assert_called_once()
    mock_yara_service.load_rules.assert_not_called()
413 | 
414 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.httpx")
def test_import_threatflux_rules_http_error(mock_httpx):
    """An HTTP error status on the initial connection test fails the import."""
    # Every request, including the initial connection test, answers 404.
    mock_response = Mock()
    mock_response.status_code = 404
    mock_httpx.get.return_value = mock_response

    result = import_threatflux_rules()

    # A status >= 400 on the connection test is raised as "HTTP <status>" and
    # reported by the generic handler; no fallback download is attempted.
    assert isinstance(result, dict)
    assert result["success"] is False
    assert "HTTP 404" in result["message"]

    # Only the connection test itself was requested.
    mock_httpx.get.assert_called_once()
430 | 
431 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.httpx")
def test_import_threatflux_rules_no_index(mock_httpx, mock_yara_service):
    """Without index.json the import falls back to well-known rule paths."""
    rule_text = "rule test { condition: true }"
    found = Mock(status_code=200, text=rule_text)
    missing = Mock(status_code=404)
    # The only two rule files the fake repository serves.
    served = ("/malware/apt.yar", "/general/generic.yar")

    def fake_get(url, **_kwargs):
        # index.json and every other .yar path are absent; anything else
        # (i.e. the initial connection test) succeeds.
        if url.endswith("/index.json"):
            return missing
        if url.endswith(".yar") and not url.endswith(served):
            return missing
        return found

    mock_httpx.get.side_effect = fake_get

    result = import_threatflux_rules()

    # The fallback imported exactly the two served files; missing guesses are
    # skipped silently and not counted as errors.
    assert isinstance(result, dict)
    assert result["success"] is True
    assert result["import_count"] == 2
    assert result["error_count"] == 0
    mock_yara_service.add_rule.assert_any_call("apt.yar", rule_text, "community")
    mock_yara_service.add_rule.assert_any_call("generic.yar", rule_text, "community")
    assert mock_yara_service.add_rule.call_count == 2

    # At least the connection test and the index.json request were made.
    assert mock_httpx.get.call_count >= 2

    # Rules are reloaded once the import finishes.
    mock_yara_service.load_rules.assert_called_once()
466 | 
467 | 
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.yara_service")
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.httpx")
def test_import_threatflux_rules_custom_url_branch(mock_httpx, mock_yara_service):
    """A custom repository URL and branch end up in the index.json request URL."""
    # One canned response serves every request: the connection test, the
    # index (listing a single rule) and the rule download itself.
    canned = Mock()
    canned.status_code = 200
    canned.json.return_value = {"rules": ["rule1.yar"]}
    canned.text = "rule test { condition: true }"
    mock_httpx.get.return_value = canned

    outcome = import_threatflux_rules(url="https://github.com/custom/repo", branch="dev")

    # The import reports success.
    assert isinstance(outcome, dict)
    assert "success" in outcome
    assert outcome["success"] is True

    # The index was requested from the raw-content host under the given branch.
    index_url = "https://raw.githubusercontent.com/custom/repo/dev/index.json"
    mock_httpx.get.assert_any_call(index_url, follow_redirects=True)
491 | 
492 | 
# Placeholder for the local-path import scenario; it is skipped unconditionally.
@pytest.mark.skip(reason="Test skipped - requires complex patching for file:// URLs")
@patch("yaraflux_mcp_server.mcp_tools.rule_tools.httpx")
def test_import_threatflux_rules_local_path(mock_httpx):
    """Skipped placeholder for importing rules from a local path."""
    # Nothing is exercised here; the body only keeps the test collectable.
    assert True
501 | 
```

--------------------------------------------------------------------------------
/tests/unit/test_utils/test_wrapper_generator.py:
--------------------------------------------------------------------------------

```python
  1 | """Unit tests for wrapper_generator utilities."""
  2 | 
  3 | import inspect
  4 | import logging
  5 | from typing import Any, Dict, List, Optional
  6 | from unittest.mock import MagicMock, Mock, patch
  7 | 
  8 | import pytest
  9 | 
 10 | from yaraflux_mcp_server.utils.wrapper_generator import (
 11 |     create_tool_wrapper,
 12 |     extract_enhanced_docstring,
 13 |     extract_param_schema_from_func,
 14 |     register_tool_with_schema,
 15 | )
 16 | 
 17 | 
 18 | class TestCreateToolWrapper:
 19 |     """Tests for create_tool_wrapper: parameter parsing, logging, and error routing."""
 20 | 
 21 |     def test_basic_wrapper_creation(self):
 22 |         """Test that a wrapper parses a query-style string and calls the wrapped function."""
 23 | 
 24 |         # Function to wrap: joins its two parameters with a dash
 25 |         def test_function(param1: str, param2: int) -> Dict[str, Any]:
 26 |             """Test function.
 27 | 
 28 |             Args:
 29 |                 param1: First parameter
 30 |                 param2: Second parameter
 31 | 
 32 |             Returns:
 33 |                 Dictionary with result
 34 |             """
 35 |             return {"result": f"{param1}-{param2}"}
 36 | 
 37 |         # Mock MCP whose tool() decorator hands the function back unchanged
 38 |         mock_mcp = Mock()
 39 |         mock_mcp.tool.return_value = lambda f: f
 40 | 
 41 |         # Build the wrapper under test
 42 |         wrapper = create_tool_wrapper(mcp=mock_mcp, func_name="test_function", actual_func=test_function)
 43 | 
 44 |         # Verify the wrapper was registered through mcp.tool()
 45 |         mock_mcp.tool.assert_called_once()
 46 | 
 47 |         # Call the wrapper with a URL-query-style parameter string
 48 |         result = wrapper("param1=test&param2=5")
 49 | 
 50 |         # Verify both parameters reached the wrapped function
 51 |         assert result == {"result": "test-5"}
 52 | 
 53 |     @patch("yaraflux_mcp_server.utils.wrapper_generator.parse_params")
 54 |     @patch("yaraflux_mcp_server.utils.wrapper_generator.extract_typed_params")
 55 |     def test_wrapper_with_all_params(self, mock_extract_params, mock_parse_params):
 56 |         """Test that the wrapper forwards typed parameters of every kind to the function."""
 57 | 
 58 |         # Function with str, int, float, bool, list and optional parameters
 59 |         def test_function(
 60 |             string_param: str,
 61 |             int_param: int,
 62 |             float_param: float,
 63 |             bool_param: bool,
 64 |             list_param: List[str],
 65 |             optional_param: Optional[str] = None,
 66 |         ) -> Dict[str, Any]:
 67 |             """Test function with many param types."""
 68 |             return {
 69 |                 "string": string_param,
 70 |                 "int": int_param,
 71 |                 "float": float_param,
 72 |                 "bool": bool_param,
 73 |                 "list": list_param,
 74 |                 "optional": optional_param,
 75 |             }
 76 | 
 77 |         # Mock MCP whose tool() decorator hands the function back unchanged
 78 |         mock_mcp = Mock()
 79 |         mock_mcp.tool.return_value = lambda f: f
 80 | 
 81 |         # parse_params is mocked to return the raw (string) values
 82 |         mock_parse_params.return_value = {
 83 |             "string_param": "test",
 84 |             "int_param": "5",
 85 |             "float_param": "3.14",
 86 |             "bool_param": "true",
 87 |             "list_param": "a,b,c",
 88 |             "optional_param": "optional",
 89 |         }
 90 | 
 91 |         # extract_typed_params is mocked to return the converted values
 92 |         mock_extract_params.return_value = {
 93 |             "string_param": "test",
 94 |             "int_param": 5,
 95 |             "float_param": 3.14,
 96 |             "bool_param": True,
 97 |             "list_param": ["a", "b", "c"],
 98 |             "optional_param": "optional",
 99 |         }
100 | 
101 |         # Build the wrapper under test
102 |         wrapper = create_tool_wrapper(mcp=mock_mcp, func_name="test_function", actual_func=test_function)
103 | 
104 |         # Call the wrapper; the input string is irrelevant because parsing is mocked
105 |         result = wrapper("params string doesn't matter with mocks")
106 | 
107 |         # Verify the typed values were passed through to the wrapped function
108 |         expected = {
109 |             "string": "test",
110 |             "int": 5,
111 |             "float": 3.14,
112 |             "bool": True,
113 |             "list": ["a", "b", "c"],
114 |             "optional": "optional",
115 |         }
116 |         assert result == expected
117 | 
118 |     @patch("yaraflux_mcp_server.utils.wrapper_generator.logger")
119 |     def test_wrapper_logs_params(self, mock_logger):
120 |         """Test that the wrapper logs the raw params string when log_params=True."""
121 | 
122 |         # Function to wrap: joins its two parameters with a dash
123 |         def test_function(param1: str, param2: int) -> Dict[str, Any]:
124 |             """Test function."""
125 |             return {"result": f"{param1}-{param2}"}
126 | 
127 |         # Mock MCP whose tool() decorator hands the function back unchanged
128 |         mock_mcp = Mock()
129 |         mock_mcp.tool.return_value = lambda f: f
130 | 
131 |         # Build the wrapper with parameter logging enabled
132 |         wrapper = create_tool_wrapper(
133 |             mcp=mock_mcp, func_name="test_function", actual_func=test_function, log_params=True
134 |         )
135 | 
136 |         # Call the wrapper
137 |         wrapper("param1=test&param2=5")
138 | 
139 |         # Verify the single info() call on the patched module-level logger includes the params
140 |         mock_logger.info.assert_called_once_with("test_function called with params: param1=test&param2=5")
141 | 
142 |     @patch("yaraflux_mcp_server.utils.wrapper_generator.logger")
143 |     def test_wrapper_logs_without_params(self, mock_logger):
144 |         """Test that the wrapper logs only the call itself when log_params=False."""
145 | 
146 |         # Function to wrap: takes no parameters
147 |         def test_function() -> Dict[str, Any]:
148 |             """Test function with no params."""
149 |             return {"result": "success"}
150 | 
151 |         # Mock MCP whose tool() decorator hands the function back unchanged
152 |         mock_mcp = Mock()
153 |         mock_mcp.tool.return_value = lambda f: f
154 | 
155 |         # Build the wrapper with parameter logging disabled
156 |         wrapper = create_tool_wrapper(
157 |             mcp=mock_mcp, func_name="test_function", actual_func=test_function, log_params=False
158 |         )
159 | 
160 |         # Call the wrapper with an empty params string
161 |         wrapper("")
162 | 
163 |         # Verify the single info() call on the patched module-level logger omits the params
164 |         mock_logger.info.assert_called_once_with("test_function called")
165 | 
166 |     @patch("yaraflux_mcp_server.utils.wrapper_generator.handle_tool_error")
167 |     def test_wrapper_handles_missing_required_param(self, mock_handle_error):
168 |         """Test that a missing required parameter is routed to handle_tool_error."""
169 | 
170 |         # Function to wrap: has one required parameter
171 |         def test_function(required_param: str) -> Dict[str, Any]:
172 |             """Test function with required param."""
173 |             return {"result": required_param}
174 | 
175 |         # Mock MCP whose tool() decorator hands the function back unchanged
176 |         mock_mcp = Mock()
177 |         mock_mcp.tool.return_value = lambda f: f
178 | 
179 |         # The mocked error handler returns a canned error response
180 |         mock_handle_error.return_value = {"error": "Required parameter 'required_param' is missing"}
181 | 
182 |         # Build the wrapper under test
183 |         wrapper = create_tool_wrapper(mcp=mock_mcp, func_name="test_function", actual_func=test_function)
184 | 
185 |         # Call with an empty string so required_param is not supplied
186 |         result = wrapper("")
187 | 
188 |         # Verify the handler ran once and its canned response was returned
189 |         assert "error" in result
190 |         assert "required_param" in result["error"]
191 |         mock_handle_error.assert_called_once()
192 | 
193 |     @patch("yaraflux_mcp_server.utils.wrapper_generator.logger")
194 |     @patch("yaraflux_mcp_server.utils.wrapper_generator.handle_tool_error")
195 |     def test_wrapper_handles_exception(self, mock_handle_error, mock_logger):
196 |         """Test that an exception raised by the wrapped function goes to handle_tool_error."""
197 | 
198 |         # Function to wrap: always raises ValueError
199 |         def test_function() -> Dict[str, Any]:
200 |             """Test function that raises an exception."""
201 |             raise ValueError("Test exception")
202 | 
203 |         # Mock MCP whose tool() decorator hands the function back unchanged
204 |         mock_mcp = Mock()
205 |         mock_mcp.tool.return_value = lambda f: f
206 | 
207 |         # The mocked error handler returns a canned error response
208 |         mock_handle_error.return_value = {"error": "Test exception"}
209 | 
210 |         # Build the wrapper under test
211 |         wrapper = create_tool_wrapper(mcp=mock_mcp, func_name="test_function", actual_func=test_function)
212 | 
213 |         # Calling the wrapper must not raise; the error is routed to the handler
214 |         result = wrapper("")
215 | 
216 |         # Verify the handler ran once and its canned response was returned
217 |         assert result == {"error": "Test exception"}
218 |         mock_handle_error.assert_called_once()
219 | 
220 | 
221 | class TestExtractEnhancedDocstring:
222 |     """Tests for extract_enhanced_docstring and the dict structure it returns."""
223 | 
224 |     def test_extract_basic_docstring(self):
225 |         """Test that a one-line docstring yields only a description."""
226 | 
227 |         # Function whose docstring is a single summary line with no sections
228 |         def test_function(param1: str, param2: int) -> Dict[str, Any]:
229 |             """Test function docstring."""
230 |             return {"result": "success"}
231 | 
232 |         # Run the extraction
233 |         docstring = extract_enhanced_docstring(test_function)
234 | 
235 |         # Verify the four expected keys, with empty values for the absent sections
236 |         assert isinstance(docstring, dict)
237 |         assert docstring["description"] == "Test function docstring."
238 |         assert docstring["param_descriptions"] == {}
239 |         assert docstring["returns_description"] == ""
240 |         assert docstring["examples"] == []
241 | 
242 |     def test_extract_full_docstring(self):
243 |         """Test that description, Args and Returns are all extracted from a full docstring."""
244 | 
245 |         # Function whose docstring has a two-paragraph description plus Args and Returns
246 |         def test_function(param1: str, param2: int) -> Dict[str, Any]:
247 |             """Test function with full docstring.
248 | 
249 |             This function demonstrates a full docstring with Args and Returns sections.
250 | 
251 |             Args:
252 |                 param1: First parameter description
253 |                 param2: Second parameter description
254 | 
255 |             Returns:
256 |                 Dictionary with success result
257 |             """
258 |             return {"result": "success"}
259 | 
260 |         # Run the extraction
261 |         docstring = extract_enhanced_docstring(test_function)
262 | 
263 |         # Verify both description paragraphs, each parameter text, and the Returns text
264 |         assert "Test function with full docstring" in docstring["description"]
265 |         assert "This function demonstrates" in docstring["description"]
266 |         assert docstring["param_descriptions"]["param1"] == "First parameter description"
267 |         assert docstring["param_descriptions"]["param2"] == "Second parameter description"
268 |         assert docstring["returns_description"] == "Dictionary with success result"
269 | 
270 |     def test_extract_docstring_with_no_args(self):
271 |         """Test that a docstring without an Args section yields no parameter descriptions."""
272 | 
273 |         # Function whose docstring has a Returns section but no Args section
274 |         def test_function(param1: str, param2: int) -> Dict[str, Any]:
275 |             """Test function docstring.
276 | 
277 |             Returns:
278 |                 Dictionary with success result
279 |             """
280 |             return {"result": "success"}
281 | 
282 |         # Run the extraction
283 |         docstring = extract_enhanced_docstring(test_function)
284 | 
285 |         # Verify the description and Returns text are present and param_descriptions is empty
286 |         assert "Test function docstring" in docstring["description"]
287 |         assert docstring["param_descriptions"] == {}
288 |         assert docstring["returns_description"] == "Dictionary with success result"
289 | 
290 |     def test_extract_docstring_with_no_returns(self):
291 |         """Test that a docstring without a Returns section yields an empty returns text."""
292 | 
293 |         # Function whose docstring has an Args section but no Returns section
294 |         def test_function(param1: str, param2: int) -> Dict[str, Any]:
295 |             """Test function docstring.
296 | 
297 |             Args:
298 |                 param1: First parameter description
299 |                 param2: Second parameter description
300 |             """
301 |             return {"result": "success"}
302 | 
303 |         # Run the extraction
304 |         docstring = extract_enhanced_docstring(test_function)
305 | 
306 |         # Verify the description and parameter texts are present and the Returns text is empty
307 |         assert "Test function docstring" in docstring["description"]
308 |         assert docstring["param_descriptions"]["param1"] == "First parameter description"
309 |         assert docstring["param_descriptions"]["param2"] == "Second parameter description"
310 |         assert docstring["returns_description"] == ""
311 | 
312 |     def test_extract_no_docstring(self):
313 |         """Test that a function without a docstring yields empty values for every key."""
314 | 
315 |         # Function deliberately left without a docstring
316 |         def test_function(param1: str, param2: int) -> Dict[str, Any]:
317 |             return {"result": "success"}
318 | 
319 |         # Run the extraction
320 |         docstring = extract_enhanced_docstring(test_function)
321 | 
322 |         # Verify every field falls back to its empty default
323 |         assert docstring["description"] == ""
324 |         assert docstring["param_descriptions"] == {}
325 |         assert docstring["returns_description"] == ""
326 |         assert docstring["examples"] == []
327 | 
328 | 
329 | class TestExtractParamSchemaFromFunc:
330 |     """Tests for extract_param_schema_from_func ("type", "required" and "default" entries)."""
331 | 
332 |     def test_extract_basic_schema(self):
333 |         """Test that str, int and bool parameters are reported as required with their types."""
334 | 
335 |         # Function with three required parameters of basic types
336 |         def test_function(string_param: str, int_param: int, bool_param: bool) -> Dict[str, Any]:
337 |             """Test function with basic types."""
338 |             return {"result": "success"}
339 | 
340 |         # Build the schema from the function
341 |         schema = extract_param_schema_from_func(test_function)
342 | 
343 |         # Verify each parameter is present, carries its annotation, and is required
344 |         assert "string_param" in schema
345 |         assert "int_param" in schema
346 |         assert "bool_param" in schema
347 |         assert schema["string_param"]["type"] == str
348 |         assert schema["int_param"]["type"] == int
349 |         assert schema["bool_param"]["type"] == bool
350 |         assert schema["string_param"]["required"] is True
351 |         assert schema["int_param"]["required"] is True
352 |         assert schema["bool_param"]["required"] is True
353 | 
354 |     def test_extract_schema_skip_self(self):
355 |         """Test extracting schema skips 'self' parameter."""
356 | 
357 |         # Define an instance method, whose first parameter is 'self'
358 |         class TestClass:
359 |             def test_method(self, param1: str, param2: int) -> Dict[str, Any]:
360 |                 """Test method with self parameter."""
361 |                 return {"result": "success"}
362 | 
363 |         # NOTE(review): a bound method is passed; its signature may already omit 'self' - confirm
364 |         schema = extract_param_schema_from_func(TestClass().test_method)
365 | 
366 |         # Verify 'self' is absent and the real parameters are present
367 |         assert "self" not in schema
368 |         assert "param1" in schema
369 |         assert "param2" in schema
370 | 
371 |     def test_extract_schema_with_complex_types(self):
372 |         """Test schema extraction for List/Optional annotations and default values."""
373 | 
374 |         # Function mixing required parameters with optional and defaulted ones
375 |         def test_function(
376 |             simple_param: str,
377 |             list_param: List[str],
378 |             optional_param: Optional[int] = None,
379 |             default_param: str = "default",
380 |         ) -> Dict[str, Any]:
381 |             """Test function with complex types."""
382 |             return {"result": "success"}
383 | 
384 |         # Build the schema from the function
385 |         schema = extract_param_schema_from_func(test_function)
386 | 
387 |         # Verify annotations are kept as-is and only parameters without defaults are required
388 |         assert schema["simple_param"]["type"] == str
389 |         assert schema["list_param"]["type"] == List[str]
390 |         assert schema["optional_param"]["type"] == Optional[int]
391 |         assert schema["default_param"]["type"] == str
392 |         assert schema["default_param"]["default"] == "default"
393 |         assert schema["simple_param"]["required"] is True
394 |         assert schema["list_param"]["required"] is True
395 |         assert schema["optional_param"]["required"] is False
396 |         assert schema["default_param"]["required"] is False
397 | 
398 | 
399 | class TestRegisterToolWithSchema:
400 |     """Tests for register_tool_with_schema function."""
401 | 
402 |     def test_register_tool_basic(self):
403 |         """Test registering a basic tool."""
404 |         # Create mock MCP handler
405 |         mock_mcp = Mock()
406 | 
407 |         # Define a function to register
408 |         def test_tool(param1: str, param2: int) -> Dict[str, Any]:
409 |             """Test tool function."""
410 |             return {"result": f"{param1}-{param2}"}
411 | 
412 |         # Register the tool
413 |         register_tool_with_schema(
414 |             mcp=mock_mcp,
415 |             func_name="test_tool",
416 |             actual_func=test_tool,
417 |         )
418 | 
419 |         # Verify tool was registered with handler.tool()
420 |         mock_mcp.tool.assert_called_once()
421 | 
422 |     def test_register_with_custom_schema(self):
423 |         """Test registering a tool with custom schema."""
424 |         # Create mock MCP handler
425 |         mock_mcp = Mock()
426 | 
427 |         # Define a function to register
428 |         def test_tool(param1: str, param2: int) -> Dict[str, Any]:
429 |             """Test tool function."""
430 |             return {"result": "success"}
431 | 
432 |         # Define custom schema
433 |         custom_schema = {
434 |             "custom_param1": {"type": str, "description": "Custom description", "required": True},
435 |             "custom_param2": {"type": int, "required": False},
436 |         }
437 | 
438 |         # Register the tool with custom schema
439 |         register_tool_with_schema(
440 |             mcp=mock_mcp, func_name="test_tool_custom", actual_func=test_tool, param_schema=custom_schema
441 |         )
442 | 
443 |         # Verify tool was registered
444 |         mock_mcp.tool.assert_called_once()
445 | 
446 |     def test_register_tool_logs_params(self):
447 |         """Test that tool registration logs parameters."""
448 |         # Create mock MCP handler
449 |         mock_mcp = Mock()
450 | 
451 |         # Define a function to register
452 |         def test_tool(param1: str, param2: int) -> Dict[str, Any]:
453 |             """Test tool function."""
454 |             return {"result": f"{param1}-{param2}"}
455 | 
456 |         # Register the tool
457 |         result = register_tool_with_schema(
458 |             mcp=mock_mcp,
459 |             func_name="test_tool",
460 |             actual_func=test_tool,
461 |         )
462 | 
463 |         # Verify registration successful
464 |         mock_mcp.tool.assert_called_once()
465 | 
466 |     def test_register_tool_handles_exception(self):
467 |         """Test that tool registration handles exceptions."""
468 |         # Create mock MCP handler that raises exception
469 |         mock_mcp = Mock()
470 |         mock_mcp.tool.side_effect = ValueError("Registration error")
471 | 
472 |         # Define a function to register
473 |         def test_tool(param1: str) -> Dict[str, Any]:
474 |             """Test tool function."""
475 |             return {"result": param1}
476 | 
477 |         # Register the tool should handle the exception
478 |         with pytest.raises(ValueError) as excinfo:
479 |             register_tool_with_schema(
480 |                 mcp=mock_mcp,
481 |                 func_name="test_tool",
482 |                 actual_func=test_tool,
483 |             )
484 | 
485 |         assert "Registration error" in str(excinfo.value)
486 | 
487 |     def test_wrapper_preserves_docstring(self):
488 |         """Test that registered tool wrapper preserves docstring."""
489 |         # Create mock MCP handler
490 |         mock_mcp = Mock()
491 | 
492 |         # Create a mock that captures the wrapped function
493 |         def capture_wrapper(*args, **kwargs):
494 |             called_with = kwargs
495 |             return lambda f: f
496 | 
497 |         mock_mcp.tool.side_effect = capture_wrapper
498 | 
499 |         # Define a function with docstring
500 |         def test_tool(param1: str) -> Dict[str, Any]:
501 |             """Test tool docstring.
502 | 
503 |             This is a multiline docstring.
504 | 
505 |             Args:
506 |                 param1: Parameter description
507 | 
508 |             Returns:
509 |                 Dictionary with result
510 |             """
511 |             return {"result": param1}
512 | 
513 |         # Register the tool
514 |         result = register_tool_with_schema(
515 |             mcp=mock_mcp,
516 |             func_name="test_tool",
517 |             actual_func=test_tool,
518 |         )
519 | 
520 |         # Verify wrapper preserves docstring
521 |         assert result.__doc__ is not None
522 |         assert "Test tool docstring" in result.__doc__
523 |         assert "This is a multiline docstring" in result.__doc__
524 | 
```
Page 4/6 · First · Prev · Next · Last