# Directory Structure ``` ├── .gitignore ├── LICENSE ├── plugin │ ├── ida_mcp_server_plugin │ │ ├── __init__.py │ │ └── ida_mcp_core.py │ └── ida_mcp_server_plugin.py ├── pyproject.toml ├── README.md ├── Screenshots │ ├── iShot_2025-03-15_18.54.53.png │ ├── iShot_2025-03-15_19.04.06.png │ └── iShot_2025-03-15_19.06.27.png ├── src │ └── mcp_server_ida │ ├── __init__.py │ ├── __main__.py │ └── server.py ├── test │ └── IDA Debug.py └── uv.lock ``` # Files -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- ``` 1 | .DS_Store 2 | __pycache__ 3 | ida_mcp_server.egg-info 4 | alternatives 5 | ``` -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- ```markdown 1 | # IDA MCP Server 2 | 3 | > [!NOTE] 4 | > The idalib mode is under development, and it will not require installing the IDA plugin or running IDA (idalib is available from IDA Pro 9.0+). 5 | 6 | ## Overview 7 | 8 | A Model Context Protocol server for IDA interaction and automation. This server provides tools to read IDA database via Large Language Models. 9 | 10 | Please note that mcp-server-ida is currently in early development. The functionality and available tools are subject to change and expansion as we continue to develop and improve the server. 11 | 12 | ## Installation 13 | 14 | ### Using uv (recommended) 15 | 16 | When using [`uv`](https://docs.astral.sh/uv/) no specific installation is needed. We will 17 | use [`uvx`](https://docs.astral.sh/uv/guides/tools/) to directly run *mcp-server-ida*. 
18 | 19 | ### Using PIP 20 | 21 | Alternatively, you can install `mcp-server-ida` via pip: 22 | 23 | ``` 24 | pip install mcp-server-ida 25 | ``` 26 | 27 | After installation, you can run it as a script using: 28 | 29 | ``` 30 | python -m mcp_server_ida 31 | ``` 32 | 33 | ### IDA-Side 34 | 35 | Copy `repository/plugin/ida_mcp_server_plugin.py` and the `repository/plugin/ida_mcp_server_plugin` directory into IDA's plugin directory 36 | 37 | Windows: `%APPDATA%\Hex-Rays\IDA Pro\plugins` 38 | 39 | Linux/macOS: `$HOME/.idapro/plugins`, e.g. `~/.idapro/plugins` 40 | 41 | [igors-tip-of-the-week-103-sharing-plugins-between-ida-installs](https://hex-rays.com/blog/igors-tip-of-the-week-103-sharing-plugins-between-ida-installs) 42 | 43 | ## Configuration 44 | 45 | ### Usage with Claude Desktop 46 | 47 | Add this to your `claude_desktop_config.json`: 48 | 49 | <details> 50 | <summary>Using uvx</summary> 51 | 52 | ```json 53 | "mcpServers": { 54 | "ida": { 55 | "command": "uvx", 56 | "args": [ 57 | "mcp-server-ida" 58 | ] 59 | } 60 | } 61 | ``` 62 | </details> 63 | 64 | <details> 65 | <summary>Using pip installation</summary> 66 | 67 | ```json 68 | "mcpServers": { 69 | "ida": { 70 | "command": "python", 71 | "args": [ 72 | "-m", 73 | "mcp_server_ida" 74 | ] 75 | } 76 | } 77 | ``` 78 | </details> 79 | 80 | ## Debugging 81 | 82 | You can use the MCP inspector to debug the server. For uvx installations: 83 | 84 | ``` 85 | npx @modelcontextprotocol/inspector uvx mcp-server-ida 86 | ``` 87 | 88 | Or if you've installed the package in a specific directory or are developing on it: 89 | 90 | ``` 91 | cd path/to/mcp-server-ida/src 92 | npx @modelcontextprotocol/inspector uv run mcp-server-ida 93 | ``` 94 | 95 | Running `tail -n 20 -f ~/Library/Logs/Claude/mcp*.log` will show the logs from the server and may 96 | help you debug any issues. 97 | 98 | ## Development 99 | 100 | If you are doing local development, there are two ways to test your changes: 101 | 102 | 1. 
Run the MCP inspector to test your changes. See [Debugging](#debugging) for run instructions. 103 | 104 | 2. Test using the Claude desktop app. Add the following to your `claude_desktop_config.json`: 105 | 106 | ### UVX 107 | ```json 108 | { 109 | "mcpServers": { 110 | "ida": { 111 | "command": "uv", 112 | "args": [ 113 | "--directory", 114 | "/<path to mcp-server-ida>", 115 | "run", 116 | "mcp-server-ida" 117 | ] 118 | } 119 | } 120 | } 121 | ``` 122 | 123 | ## Alternatives 124 | [ida-pro-mcp](https://github.com/mrexodia/ida-pro-mcp) 125 | 126 | [ida-mcp-server-plugin](https://github.com/taida957789/ida-mcp-server-plugin) 127 | 128 | [mcp-server-idapro](https://github.com/fdrechsler/mcp-server-idapro) 129 | 130 | [pcm](https://github.com/rand-tech/pcm) 131 | 132 | 133 | ## Screenshots 134 | 135 |  136 |  137 |  138 | ``` -------------------------------------------------------------------------------- /plugin/ida_mcp_server_plugin/__init__.py: -------------------------------------------------------------------------------- ```python 1 | ``` -------------------------------------------------------------------------------- /src/mcp_server_ida/__main__.py: -------------------------------------------------------------------------------- ```python 1 | from mcp_server_ida import main 2 | 3 | main() 4 | ``` -------------------------------------------------------------------------------- /src/mcp_server_ida/__init__.py: -------------------------------------------------------------------------------- ```python 1 | import click 2 | import logging 3 | import sys 4 | from .server import serve 5 | 6 | @click.command() 7 | @click.option("-v", "--verbose", count=True) 8 | def main(verbose: bool) -> None: 9 | """MCP IDA Server - IDA functionality for MCP""" 10 | import asyncio 11 | 12 | logging_level = logging.WARN 13 | if verbose == 1: 14 | logging_level = logging.INFO 15 | elif verbose >= 2: 16 | logging_level = logging.DEBUG 17 | 18 | logging.basicConfig(level=logging_level, 
stream=sys.stderr) 19 | asyncio.run(serve()) 20 | 21 | if __name__ == "__main__": 22 | main() 23 | ``` -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- ```toml 1 | [project] 2 | name = "mcp-server-ida" 3 | version = "0.3.0" 4 | description = "A Model Context Protocol server providing tools to read, search IDA Database programmatically via LLMs" 5 | # readme = "README.md" 6 | requires-python = ">=3.10" 7 | authors = [{ name = "Mx-Iris" }] 8 | keywords = ["ida", "mcp", "llm", "automation"] 9 | license = { text = "MIT" } 10 | classifiers = [ 11 | "Development Status :: 4 - Beta", 12 | "Intended Audience :: Developers", 13 | "License :: OSI Approved :: MIT License", 14 | "Programming Language :: Python :: 3", 15 | "Programming Language :: Python :: 3.10", 16 | ] 17 | dependencies = [ 18 | "click>=8.1.7", 19 | "mcp>=1.0.0", 20 | "pydantic>=2.0.0", 21 | ] 22 | 23 | [project.scripts] 24 | mcp-server-ida = "mcp_server_ida:main" 25 | 26 | [build-system] 27 | requires = ["hatchling"] 28 | build-backend = "hatchling.build" 29 | 30 | [tool.uv] 31 | dev-dependencies = ["pyright>=1.1.389", "ruff>=0.7.3", "pytest>=8.0.0"] ``` -------------------------------------------------------------------------------- /plugin/ida_mcp_server_plugin.py: -------------------------------------------------------------------------------- ```python 1 | import idaapi 2 | import json 3 | import socket 4 | import struct 5 | import threading 6 | import traceback 7 | import time 8 | from typing import Optional, Dict, Any, List, Tuple, Union, Set, Type, cast 9 | from ida_mcp_server_plugin.ida_mcp_core import IDAMCPCore 10 | 11 | PLUGIN_NAME = "IDA MCP Server" 12 | PLUGIN_HOTKEY = "Ctrl-Alt-M" 13 | PLUGIN_VERSION = "1.0" 14 | PLUGIN_AUTHOR = "IDA MCP" 15 | 16 | # Default configuration 17 | DEFAULT_HOST = "localhost" 18 | DEFAULT_PORT = 5000 19 | 20 | class IDACommunicator: 21 | 
"""IDA Communication class""" 22 | def __init__(self, host: str = DEFAULT_HOST, port: int = DEFAULT_PORT): 23 | self.host: str = host 24 | self.port: int = port 25 | self.socket: Optional[socket.socket] = None 26 | 27 | def connect(self) -> None: 28 | pass 29 | 30 | class IDAMCPServer: 31 | def __init__(self, host: str = DEFAULT_HOST, port: int = DEFAULT_PORT): 32 | self.host: str = host 33 | self.port: int = port 34 | self.server_socket: Optional[socket.socket] = None 35 | self.running: bool = False 36 | self.thread: Optional[threading.Thread] = None 37 | self.client_counter: int = 0 38 | self.core: IDAMCPCore = IDAMCPCore() 39 | 40 | def start(self) -> bool: 41 | """Start Socket server""" 42 | if self.running: 43 | print("MCP Server already running") 44 | return False 45 | 46 | try: 47 | self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 48 | self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 49 | self.server_socket.bind((self.host, self.port)) 50 | self.server_socket.listen(5) 51 | # self.server_socket.settimeout(1.0) # Set timeout to allow server to respond to stop requests 52 | 53 | self.running = True 54 | self.thread = threading.Thread(target=self.server_loop) 55 | self.thread.daemon = True 56 | self.thread.start() 57 | 58 | print(f"MCP Server started on {self.host}:{self.port}") 59 | return True 60 | except Exception as e: 61 | print(f"Failed to start MCP Server: {str(e)}") 62 | traceback.print_exc() 63 | return False 64 | 65 | def stop(self) -> None: 66 | """Stop Socket server""" 67 | if not self.running: 68 | print("MCP Server is not running, no need to stop") 69 | return 70 | 71 | print("Stopping MCP Server...") 72 | self.running = False 73 | 74 | if self.server_socket: 75 | try: 76 | self.server_socket.close() 77 | except Exception as e: 78 | print(f"Error closing server socket: {str(e)}") 79 | self.server_socket = None 80 | 81 | if self.thread: 82 | try: 83 | self.thread.join(2.0) # Wait for thread to end, 
maximum 2 seconds 84 | except Exception as e: 85 | print(f"Error joining server thread: {str(e)}") 86 | self.thread = None 87 | 88 | print("MCP Server stopped") 89 | 90 | def send_message(self, client_socket: socket.socket, data: bytes) -> None: 91 | """Send message with length prefix""" 92 | length: int = len(data) 93 | length_bytes: bytes = struct.pack('!I', length) # 4-byte length prefix 94 | client_socket.sendall(length_bytes + data) 95 | 96 | def receive_message(self, client_socket: socket.socket) -> bytes: 97 | """Receive message with length prefix""" 98 | # Receive 4-byte length prefix 99 | length_bytes: bytes = self.receive_exactly(client_socket, 4) 100 | if not length_bytes: 101 | raise ConnectionError("Connection closed") 102 | 103 | length: int = struct.unpack('!I', length_bytes)[0] 104 | 105 | # Receive message body 106 | data: bytes = self.receive_exactly(client_socket, length) 107 | return data 108 | 109 | def receive_exactly(self, client_socket: socket.socket, n: int) -> bytes: 110 | """Receive exactly n bytes of data""" 111 | data: bytes = b'' 112 | while len(data) < n: 113 | chunk: bytes = client_socket.recv(min(n - len(data), 4096)) 114 | if not chunk: # Connection closed 115 | raise ConnectionError("Connection closed, unable to receive complete data") 116 | data += chunk 117 | return data 118 | 119 | def server_loop(self) -> None: 120 | """Server main loop""" 121 | print("Server loop started") 122 | while self.running: 123 | try: 124 | # Use timeout receive to periodically check running flag 125 | try: 126 | client_socket, client_address = self.server_socket.accept() 127 | self.client_counter += 1 128 | client_id: int = self.client_counter 129 | print(f"Client #{client_id} connected from {client_address}") 130 | 131 | # Handle client request - use thread to support multiple clients 132 | client_thread: threading.Thread = threading.Thread( 133 | target=self.handle_client, 134 | args=(client_socket, client_id) 135 | ) 136 | client_thread.daemon = 
True 137 | client_thread.start() 138 | except socket.timeout: 139 | # Timeout is just for periodically checking running flag 140 | continue 141 | except OSError as e: 142 | if self.running: # Only print error if server is running 143 | if e.errno == 9: # Bad file descriptor, usually means socket is closed 144 | print("Server socket was closed") 145 | break 146 | print(f"Socket error: {str(e)}") 147 | except Exception as e: 148 | if self.running: # Only print error if server is running 149 | print(f"Error accepting connection: {str(e)}") 150 | traceback.print_exc() 151 | except Exception as e: 152 | if self.running: 153 | print(f"Error in server loop: {str(e)}") 154 | traceback.print_exc() 155 | time.sleep(1) # Avoid high CPU usage 156 | 157 | print("Server loop ended") 158 | 159 | def handle_client(self, client_socket: socket.socket, client_id: int) -> None: 160 | """Handle client requests""" 161 | try: 162 | # Set timeout 163 | client_socket.settimeout(30) 164 | 165 | while self.running: 166 | try: 167 | # Receive message 168 | data: bytes = self.receive_message(client_socket) 169 | 170 | # Parse request 171 | request: Dict[str, Any] = json.loads(data.decode('utf-8')) 172 | request_type: str = request.get('type') 173 | request_data: Dict[str, Any] = request.get('data', {}) 174 | request_id: str = request.get('id', 'unknown') 175 | request_count: int = request.get('count', -1) 176 | 177 | print(f"Client #{client_id} request: {request_type}, ID: {request_id}, Count: {request_count}") 178 | 179 | # Handle different types of requests 180 | response: Dict[str, Any] = { 181 | "id": request_id, # Return same request ID 182 | "count": request_count # Return same request count 183 | } 184 | 185 | if request_type == "get_function_assembly_by_name": 186 | response.update(self.core.get_function_assembly_by_name(request_data.get("function_name", ""))) 187 | elif request_type == "get_function_assembly_by_address": 188 | 
response.update(self.core.get_function_assembly_by_address(request_data.get("address", 0))) 189 | elif request_type == "get_function_decompiled_by_name": 190 | response.update(self.core.get_function_decompiled_by_name(request_data.get("function_name", ""))) 191 | elif request_type == "get_function_decompiled_by_address": 192 | response.update(self.core.get_function_decompiled_by_address(request_data.get("address", 0))) 193 | elif request_type == "get_global_variable_by_name": 194 | response.update(self.core.get_global_variable_by_name(request_data.get("variable_name", ""))) 195 | elif request_type == "get_global_variable_by_address": 196 | response.update(self.core.get_global_variable_by_address(request_data.get("address", 0))) 197 | elif request_type == "get_current_function_assembly": 198 | response.update(self.core.get_current_function_assembly()) 199 | elif request_type == "get_current_function_decompiled": 200 | response.update(self.core.get_current_function_decompiled()) 201 | elif request_type == "rename_global_variable": 202 | response.update(self.core.rename_global_variable( 203 | request_data.get("old_name", ""), 204 | request_data.get("new_name", "") 205 | )) 206 | elif request_type == "rename_function": 207 | response.update(self.core.rename_function( 208 | request_data.get("old_name", ""), 209 | request_data.get("new_name", "") 210 | )) 211 | # Backward compatibility with old method names 212 | elif request_type == "get_function_assembly": 213 | response.update(self.core.get_function_assembly_by_name(request_data.get("function_name", ""))) 214 | elif request_type == "get_function_decompiled": 215 | response.update(self.core.get_function_decompiled_by_name(request_data.get("function_name", ""))) 216 | elif request_type == "get_global_variable": 217 | response.update(self.core.get_global_variable_by_name(request_data.get("variable_name", ""))) 218 | elif request_type == "add_assembly_comment": 219 | response.update(self.core.add_assembly_comment( 220 | 
request_data.get("address", ""), 221 | request_data.get("comment", ""), 222 | request_data.get("is_repeatable", False) 223 | )) 224 | elif request_type == "rename_local_variable": 225 | response.update(self.core.rename_local_variable( 226 | request_data.get("function_name", ""), 227 | request_data.get("old_name", ""), 228 | request_data.get("new_name", "") 229 | )) 230 | elif request_type == "add_function_comment": 231 | response.update(self.core.add_function_comment( 232 | request_data.get("function_name", ""), 233 | request_data.get("comment", ""), 234 | request_data.get("is_repeatable", False) 235 | )) 236 | elif request_type == "ping": 237 | response["status"] = "pong" 238 | elif request_type == "add_pseudocode_comment": 239 | response.update(self.core.add_pseudocode_comment( 240 | request_data.get("function_name", ""), 241 | request_data.get("address", ""), 242 | request_data.get("comment", ""), 243 | request_data.get("is_repeatable", False) 244 | )) 245 | elif request_type == "execute_script": 246 | response.update(self.core.execute_script( 247 | request_data.get("script", "") 248 | )) 249 | elif request_type == "execute_script_from_file": 250 | response.update(self.core.execute_script_from_file( 251 | request_data.get("file_path", "") 252 | )) 253 | elif request_type == "refresh_view": 254 | response.update(self.core.refresh_view()) 255 | elif request_type == "rename_multi_local_variables": 256 | response.update(self.core.rename_multi_local_variables( 257 | request_data.get("function_name", ""), 258 | request_data.get("rename_pairs_old2new", []) 259 | )) 260 | elif request_type == "rename_multi_global_variables": 261 | response.update(self.core.rename_multi_global_variables( 262 | request_data.get("rename_pairs_old2new", []) 263 | )) 264 | elif request_type == "rename_multi_functions": 265 | response.update(self.core.rename_multi_functions( 266 | request_data.get("rename_pairs_old2new", []) 267 | )) 268 | else: 269 | response["error"] = f"Unknown request 
type: {request_type}" 270 | 271 | # Verify response is correct 272 | if not isinstance(response, dict): 273 | print(f"Response is not a dictionary: {type(response).__name__}") 274 | response = { 275 | "id": request_id, 276 | "count": request_count, 277 | "error": f"Internal server error: response is not a dictionary but {type(response).__name__}" 278 | } 279 | 280 | # Ensure all values in response are serializable 281 | for key, value in list(response.items()): 282 | if value is None: 283 | response[key] = "null" 284 | elif not isinstance(value, (str, int, float, bool, list, dict, tuple)): 285 | print(f"Response key '{key}' has non-serializable type: {type(value).__name__}") 286 | response[key] = str(value) 287 | 288 | # Send response 289 | response_json: bytes = json.dumps(response).encode('utf-8') 290 | self.send_message(client_socket, response_json) 291 | print(f"Sent response to client #{client_id}, ID: {request_id}, Count: {request_count}") 292 | 293 | except ConnectionError as e: 294 | print(f"Connection with client #{client_id} lost: {str(e)}") 295 | return 296 | except socket.timeout: 297 | # print(f"Socket timeout with client #{client_id}") 298 | continue 299 | except json.JSONDecodeError as e: 300 | print(f"Invalid JSON request from client #{client_id}: {str(e)}") 301 | try: 302 | response: Dict[str, Any] = { 303 | "error": f"Invalid JSON request: {str(e)}" 304 | } 305 | self.send_message(client_socket, json.dumps(response).encode('utf-8')) 306 | except: 307 | print(f"Failed to send error response to client #{client_id}") 308 | except Exception as e: 309 | print(f"Error processing request from client #{client_id}: {str(e)}") 310 | traceback.print_exc() 311 | try: 312 | response: Dict[str, Any] = { 313 | "error": str(e) 314 | } 315 | self.send_message(client_socket, json.dumps(response).encode('utf-8')) 316 | except: 317 | print(f"Failed to send error response to client #{client_id}") 318 | 319 | except Exception as e: 320 | print(f"Error handling client 
#{client_id}: {str(e)}") 321 | traceback.print_exc() 322 | finally: 323 | try: 324 | client_socket.close() 325 | except: 326 | pass 327 | print(f"Client #{client_id} connection closed") 328 | 329 | 330 | # IDA Plugin class 331 | class IDAMCPPlugin(idaapi.plugin_t): 332 | flags = idaapi.PLUGIN_KEEP 333 | comment = "IDA MCP Server Plugin" 334 | help = "Provides MCP server functionality for IDA" 335 | wanted_name = PLUGIN_NAME 336 | wanted_hotkey = PLUGIN_HOTKEY 337 | 338 | def __init__(self): 339 | super(IDAMCPPlugin, self).__init__() 340 | self.server: Optional[IDAMCPServer] = None 341 | self.initialized: bool = False 342 | self.menu_items_added: bool = False 343 | print(f"IDAMCPPlugin instance created") 344 | 345 | def init(self) -> int: 346 | """Plugin initialization""" 347 | try: 348 | print(f"{PLUGIN_NAME} v{PLUGIN_VERSION} by {PLUGIN_AUTHOR}") 349 | print("Initializing plugin...") 350 | 351 | # Add menu items 352 | if not self.menu_items_added: 353 | self.create_menu_items() 354 | self.menu_items_added = True 355 | print("Menu items added") 356 | 357 | # Mark as initialized 358 | self.initialized = True 359 | print("Plugin initialized successfully") 360 | 361 | # Delay server start to avoid initialization issues 362 | idaapi.register_timer(500, self._delayed_server_start) 363 | 364 | return idaapi.PLUGIN_KEEP 365 | except Exception as e: 366 | print(f"Error initializing plugin: {str(e)}") 367 | traceback.print_exc() 368 | return idaapi.PLUGIN_SKIP 369 | 370 | def _delayed_server_start(self) -> int: 371 | """Delayed server start to avoid initialization race conditions""" 372 | try: 373 | if not self.server or not self.server.running: 374 | print("Delayed server start...") 375 | self.start_server() 376 | except Exception as e: 377 | print(f"Error in delayed server start: {str(e)}") 378 | traceback.print_exc() 379 | return -1 # Don't repeat 380 | 381 | def create_menu_items(self) -> None: 382 | """Create plugin menu items""" 383 | # Create menu items 384 | 
menu_path: str = "Edit/Plugins/" 385 | 386 | class StartServerHandler(idaapi.action_handler_t): 387 | def __init__(self, plugin: 'IDAMCPPlugin'): 388 | idaapi.action_handler_t.__init__(self) 389 | self.plugin: 'IDAMCPPlugin' = plugin 390 | 391 | def activate(self, ctx) -> int: 392 | self.plugin.start_server() 393 | return 1 394 | 395 | def update(self, ctx) -> int: 396 | return idaapi.AST_ENABLE_ALWAYS 397 | 398 | class StopServerHandler(idaapi.action_handler_t): 399 | def __init__(self, plugin: 'IDAMCPPlugin'): 400 | idaapi.action_handler_t.__init__(self) 401 | self.plugin: 'IDAMCPPlugin' = plugin 402 | 403 | def activate(self, ctx) -> int: 404 | self.plugin.stop_server() 405 | return 1 406 | 407 | def update(self, ctx) -> int: 408 | return idaapi.AST_ENABLE_ALWAYS 409 | 410 | try: 411 | # Register and add start server action 412 | start_action_name: str = "mcp:start_server" 413 | start_action_desc: idaapi.action_desc_t = idaapi.action_desc_t( 414 | start_action_name, 415 | "Start MCP Server", 416 | StartServerHandler(self), 417 | "Ctrl+Alt+S", 418 | "Start the MCP Server", 419 | 199 # Icon ID 420 | ) 421 | 422 | # Register and add stop server action 423 | stop_action_name: str = "mcp:stop_server" 424 | stop_action_desc: idaapi.action_desc_t = idaapi.action_desc_t( 425 | stop_action_name, 426 | "Stop MCP Server", 427 | StopServerHandler(self), 428 | "Ctrl+Alt+X", 429 | "Stop the MCP Server", 430 | 200 # Icon ID 431 | ) 432 | 433 | # Register actions 434 | if not idaapi.register_action(start_action_desc): 435 | print("Failed to register start server action") 436 | if not idaapi.register_action(stop_action_desc): 437 | print("Failed to register stop server action") 438 | 439 | # Add to menu 440 | if not idaapi.attach_action_to_menu(menu_path + "Start MCP Server", start_action_name, idaapi.SETMENU_APP): 441 | print("Failed to attach start server action to menu") 442 | if not idaapi.attach_action_to_menu(menu_path + "Stop MCP Server", stop_action_name, 
idaapi.SETMENU_APP): 443 | print("Failed to attach stop server action to menu") 444 | 445 | print("Menu items created successfully") 446 | except Exception as e: 447 | print(f"Error creating menu items: {str(e)}") 448 | traceback.print_exc() 449 | 450 | def start_server(self) -> None: 451 | """Start server""" 452 | if self.server and self.server.running: 453 | print("MCP Server is already running") 454 | return 455 | 456 | try: 457 | print("Creating MCP Server instance...") 458 | self.server = IDAMCPServer() 459 | print("Starting MCP Server...") 460 | if self.server.start(): 461 | print("MCP Server started successfully") 462 | else: 463 | print("Failed to start MCP Server") 464 | except Exception as e: 465 | print(f"Error starting server: {str(e)}") 466 | traceback.print_exc() 467 | 468 | def stop_server(self) -> None: 469 | """Stop server""" 470 | if not self.server: 471 | print("MCP Server instance does not exist") 472 | return 473 | 474 | if not self.server.running: 475 | print("MCP Server is not running") 476 | return 477 | 478 | try: 479 | self.server.stop() 480 | print("MCP Server stopped by user") 481 | except Exception as e: 482 | print(f"Error stopping server: {str(e)}") 483 | traceback.print_exc() 484 | 485 | def run(self, arg) -> None: 486 | """Execute when hotkey is pressed""" 487 | if not self.initialized: 488 | print("Plugin not initialized") 489 | return 490 | 491 | # Automatically start or stop server when hotkey is triggered 492 | try: 493 | if not self.server or not self.server.running: 494 | print("Hotkey triggered: starting server") 495 | self.start_server() 496 | else: 497 | print("Hotkey triggered: stopping server") 498 | self.stop_server() 499 | except Exception as e: 500 | print(f"Error in run method: {str(e)}") 501 | traceback.print_exc() 502 | 503 | def term(self) -> None: 504 | """Plugin termination""" 505 | try: 506 | if self.server and self.server.running: 507 | print("Terminating plugin: stopping server") 508 | self.server.stop() 509 | 
print(f"{PLUGIN_NAME} terminated") 510 | except Exception as e: 511 | print(f"Error terminating plugin: {str(e)}") 512 | traceback.print_exc() 513 | 514 | # Register plugin 515 | def PLUGIN_ENTRY() -> IDAMCPPlugin: 516 | return IDAMCPPlugin() 517 | ``` -------------------------------------------------------------------------------- /plugin/ida_mcp_server_plugin/ida_mcp_core.py: -------------------------------------------------------------------------------- ```python 1 | import idaapi 2 | import idautils 3 | import ida_funcs 4 | import ida_hexrays 5 | import ida_bytes 6 | import ida_name 7 | import ida_segment 8 | import ida_lines 9 | import idc 10 | import json 11 | import traceback 12 | import functools 13 | import queue 14 | from typing import Any, Callable, TypeVar, Optional, Dict, List, Union, Tuple, Type 15 | 16 | # Type variable for function return type 17 | T = TypeVar('T') 18 | 19 | class IDASyncError(Exception): 20 | """Exception raised for IDA synchronization errors""" 21 | pass 22 | 23 | # Global call stack to track synchronization calls 24 | call_stack: queue.LifoQueue[str] = queue.LifoQueue() 25 | 26 | def sync_wrapper(func: Callable[..., T], sync_type: int) -> T: 27 | """ 28 | Wrapper function to execute a function in IDA's main thread 29 | 30 | Args: 31 | func: The function to execute 32 | sync_type: Synchronization type (MFF_READ or MFF_WRITE) 33 | 34 | Returns: 35 | The result of the function execution 36 | """ 37 | if sync_type not in [idaapi.MFF_READ, idaapi.MFF_WRITE]: 38 | error_str = f'Invalid sync type {sync_type} for function {func.__name__}' 39 | print(error_str) 40 | raise IDASyncError(error_str) 41 | 42 | # Container for the result 43 | result_container: queue.Queue[Any] = queue.Queue() 44 | 45 | def execute_in_main_thread() -> int: 46 | # Check if we're already inside a sync_wrapper call 47 | if not call_stack.empty(): 48 | last_func = call_stack.get() 49 | error_str = f'Nested sync call detected: function {func.__name__} called from 
{last_func}' 50 | print(error_str) 51 | call_stack.put(last_func) # Put it back 52 | raise IDASyncError(error_str) 53 | 54 | # Add function to call stack 55 | call_stack.put(func.__name__) 56 | 57 | try: 58 | # Execute function and store result 59 | result_container.put(func()) 60 | except Exception as e: 61 | print(f"Error in {func.__name__}: {str(e)}") 62 | traceback.print_exc() 63 | result_container.put(None) 64 | finally: 65 | # Always remove function from call stack 66 | call_stack.get() 67 | 68 | return 1 # Required by execute_sync 69 | 70 | # Execute in IDA's main thread 71 | idaapi.execute_sync(execute_in_main_thread, sync_type) 72 | 73 | # Return the result 74 | return result_container.get() 75 | 76 | def idaread(func: Callable[..., T]) -> Callable[..., T]: 77 | """ 78 | Decorator for functions that read from the IDA database 79 | 80 | Args: 81 | func: The function to decorate 82 | 83 | Returns: 84 | Decorated function that executes in IDA's main thread with read access 85 | """ 86 | @functools.wraps(func) 87 | def wrapper(*args: Any, **kwargs: Any) -> T: 88 | # Create a partial function with the arguments 89 | partial_func = functools.partial(func, *args, **kwargs) 90 | # Preserve the original function name 91 | partial_func.__name__ = func.__name__ 92 | # Execute with sync_wrapper 93 | return sync_wrapper(partial_func, idaapi.MFF_READ) 94 | 95 | return wrapper 96 | 97 | def idawrite(func: Callable[..., T]) -> Callable[..., T]: 98 | """ 99 | Decorator for functions that write to the IDA database 100 | 101 | Args: 102 | func: The function to decorate 103 | 104 | Returns: 105 | Decorated function that executes in IDA's main thread with write access 106 | """ 107 | @functools.wraps(func) 108 | def wrapper(*args: Any, **kwargs: Any) -> T: 109 | # Create a partial function with the arguments 110 | partial_func = functools.partial(func, *args, **kwargs) 111 | # Preserve the original function name 112 | partial_func.__name__ = func.__name__ 113 | # Execute with 
sync_wrapper 114 | return sync_wrapper(partial_func, idaapi.MFF_WRITE) 115 | 116 | return wrapper 117 | 118 | class IDAMCPCore: 119 | """Core functionality implementation class for IDA MCP""" 120 | 121 | @idaread 122 | def get_function_assembly_by_name(self, function_name: str) -> Dict[str, Any]: 123 | """Get assembly code for a function by its name""" 124 | try: 125 | # Get function address from name 126 | func = idaapi.get_func(idaapi.get_name_ea(0, function_name)) 127 | if not func: 128 | return {"error": f"Function '{function_name}' not found"} 129 | 130 | # Call address-based implementation 131 | result = self._get_function_assembly_by_address_internal(func.start_ea) 132 | 133 | # If successful, add function name to result 134 | if "error" not in result: 135 | result["function_name"] = function_name 136 | 137 | return result 138 | except Exception as e: 139 | traceback.print_exc() 140 | return {"error": str(e)} 141 | 142 | @idaread 143 | def get_function_assembly_by_address(self, address: int) -> Dict[str, Any]: 144 | """Get assembly code for a function by its address""" 145 | return self._get_function_assembly_by_address_internal(address) 146 | 147 | def _get_function_assembly_by_address_internal(self, address: int) -> Dict[str, Any]: 148 | """Internal implementation for get_function_assembly_by_address without sync wrapper""" 149 | try: 150 | # Get function object 151 | func = ida_funcs.get_func(address) 152 | 153 | # Get function name 154 | func_name = idaapi.get_func_name(func.start_ea) 155 | 156 | if not func: 157 | return {"error": f"Invalid function at {hex(address)}"} 158 | 159 | # Collect all assembly instructions 160 | assembly_lines = [] 161 | for instr_addr in idautils.FuncItems(address): 162 | disasm = idc.GetDisasm(instr_addr) 163 | assembly_lines.append(f"{hex(instr_addr)}: {disasm}") 164 | 165 | if not assembly_lines: 166 | return {"error": "No assembly instructions found"} 167 | 168 | return {"assembly": "\n".join(assembly_lines), 
"function_name": func_name} 169 | except Exception as e: 170 | print(f"Error getting function assembly: {str(e)}") 171 | traceback.print_exc() 172 | return {"error": str(e)} 173 | 174 | 175 | @idaread 176 | def get_function_decompiled_by_name(self, function_name: str) -> Dict[str, Any]: 177 | """Get decompiled code for a function by its name""" 178 | try: 179 | # Get function address from name 180 | func_addr = idaapi.get_name_ea(0, function_name) 181 | if func_addr == idaapi.BADADDR: 182 | return {"error": f"Function '{function_name}' not found"} 183 | 184 | # Call internal implementation without decorator 185 | result = self._get_function_decompiled_by_address_internal(func_addr) 186 | 187 | # If successful, add function name to result 188 | if "error" not in result: 189 | result["function_name"] = function_name 190 | 191 | return result 192 | except Exception as e: 193 | traceback.print_exc() 194 | return {"error": str(e)} 195 | 196 | @idaread 197 | def get_function_decompiled_by_address(self, address: int) -> Dict[str, Any]: 198 | """Get decompiled code for a function by its address""" 199 | return self._get_function_decompiled_by_address_internal(address) 200 | 201 | def _get_function_decompiled_by_address_internal(self, address: int) -> Dict[str, Any]: 202 | """Internal implementation for get_function_decompiled_by_address without sync wrapper""" 203 | try: 204 | # Get function from address 205 | func = idaapi.get_func(address) 206 | if not func: 207 | return {"error": f"No function found at address 0x{address:X}"} 208 | 209 | # Get function name 210 | func_name = idaapi.get_func_name(func.start_ea) 211 | 212 | # Try to import decompiler module 213 | try: 214 | import ida_hexrays 215 | except ImportError: 216 | return {"error": "Hex-Rays decompiler is not available"} 217 | 218 | # Check if decompiler is available 219 | if not ida_hexrays.init_hexrays_plugin(): 220 | return {"error": "Unable to initialize Hex-Rays decompiler"} 221 | 222 | # Get decompiled 
function 223 | cfunc = None 224 | try: 225 | cfunc = ida_hexrays.decompile(func.start_ea) 226 | except Exception as e: 227 | return {"error": f"Unable to decompile function: {str(e)}"} 228 | 229 | if not cfunc: 230 | return {"error": "Decompilation failed"} 231 | 232 | # Get pseudocode as string 233 | decompiled_code = str(cfunc) 234 | 235 | return {"decompiled_code": decompiled_code, "function_name": func_name} 236 | except Exception as e: 237 | traceback.print_exc() 238 | return {"error": str(e)} 239 | 240 | @idaread 241 | def get_current_function_assembly(self) -> Dict[str, Any]: 242 | """Get assembly code for the function at the current cursor position""" 243 | try: 244 | # Get current address 245 | curr_addr = idaapi.get_screen_ea() 246 | if curr_addr == idaapi.BADADDR: 247 | return {"error": "No valid cursor position"} 248 | 249 | # Use the internal implementation without decorator 250 | return self._get_function_assembly_by_address_internal(curr_addr) 251 | except Exception as e: 252 | traceback.print_exc() 253 | return {"error": str(e)} 254 | 255 | @idaread 256 | def get_current_function_decompiled(self) -> Dict[str, Any]: 257 | """Get decompiled code for the function at the current cursor position""" 258 | try: 259 | # Get current address 260 | curr_addr = idaapi.get_screen_ea() 261 | if curr_addr == idaapi.BADADDR: 262 | return {"error": "No valid cursor position"} 263 | 264 | # Use the internal implementation without decorator 265 | return self._get_function_decompiled_by_address_internal(curr_addr) 266 | except Exception as e: 267 | traceback.print_exc() 268 | return {"error": str(e)} 269 | 270 | @idaread 271 | def get_global_variable_by_name(self, variable_name: str) -> Dict[str, Any]: 272 | """Get global variable information by its name""" 273 | try: 274 | # Get variable address 275 | var_addr: int = ida_name.get_name_ea(0, variable_name) 276 | if var_addr == idaapi.BADADDR: 277 | return {"error": f"Global variable '{variable_name}' not found"} 278 | 
279 | # Call internal implementation 280 | result = self._get_global_variable_by_address_internal(var_addr) 281 | 282 | # If successful, add variable name to result 283 | if "error" not in result and "variable_info" in result: 284 | # Parse the JSON string back to dict to modify it 285 | var_info = json.loads(result["variable_info"]) 286 | var_info["name"] = variable_name 287 | # Convert back to JSON string 288 | result["variable_info"] = json.dumps(var_info, indent=2) 289 | 290 | return result 291 | except Exception as e: 292 | print(f"Error getting global variable by name: {str(e)}") 293 | traceback.print_exc() 294 | return {"error": str(e)} 295 | 296 | @idaread 297 | def get_global_variable_by_address(self, address: int) -> Dict[str, Any]: 298 | """Get global variable information by its address""" 299 | return self._get_global_variable_by_address_internal(address) 300 | 301 | def _get_global_variable_by_address_internal(self, address: int) -> Dict[str, Any]: 302 | """Internal implementation for get_global_variable_by_address without sync wrapper""" 303 | try: 304 | # Verify address is valid 305 | if address == idaapi.BADADDR: 306 | return {"error": f"Invalid address: {hex(address)}"} 307 | 308 | # Get variable name if available 309 | variable_name = ida_name.get_name(address) 310 | if not variable_name: 311 | variable_name = f"unnamed_{hex(address)}" 312 | 313 | # Get variable segment 314 | segment: Optional[ida_segment.segment_t] = ida_segment.getseg(address) 315 | if not segment: 316 | return {"error": f"No segment found for address {hex(address)}"} 317 | 318 | segment_name: str = ida_segment.get_segm_name(segment) 319 | segment_class: str = ida_segment.get_segm_class(segment) 320 | 321 | # Get variable type 322 | tinfo = idaapi.tinfo_t() 323 | guess_type: bool = idaapi.guess_tinfo(tinfo, address) 324 | type_str: str = tinfo.get_type_name() if guess_type else "unknown" 325 | 326 | # Try to get variable value 327 | size: int = ida_bytes.get_item_size(address) 
328 | if size <= 0: 329 | size = 8 # Default to 8 bytes 330 | 331 | # Read data based on size 332 | value: Optional[int] = None 333 | if size == 1: 334 | value = ida_bytes.get_byte(address) 335 | elif size == 2: 336 | value = ida_bytes.get_word(address) 337 | elif size == 4: 338 | value = ida_bytes.get_dword(address) 339 | elif size == 8: 340 | value = ida_bytes.get_qword(address) 341 | 342 | # Build variable info 343 | var_info: Dict[str, Any] = { 344 | "name": variable_name, 345 | "address": hex(address), 346 | "segment": segment_name, 347 | "segment_class": segment_class, 348 | "type": type_str, 349 | "size": size, 350 | "value": hex(value) if value is not None else "N/A" 351 | } 352 | 353 | # If it's a string, try to read string content 354 | if ida_bytes.is_strlit(ida_bytes.get_flags(address)): 355 | str_value = idc.get_strlit_contents(address, -1, 0) 356 | if str_value: 357 | try: 358 | var_info["string_value"] = str_value.decode('utf-8', errors='replace') 359 | except: 360 | var_info["string_value"] = str(str_value) 361 | 362 | return {"variable_info": json.dumps(var_info, indent=2)} 363 | except Exception as e: 364 | print(f"Error getting global variable by address: {str(e)}") 365 | traceback.print_exc() 366 | return {"error": str(e)} 367 | 368 | @idawrite 369 | def rename_global_variable(self, old_name: str, new_name: str) -> Dict[str, Any]: 370 | """Rename a global variable""" 371 | return self._rename_global_variable_internal(old_name, new_name) 372 | 373 | def _rename_global_variable_internal(self, old_name: str, new_name: str) -> Dict[str, Any]: 374 | """Internal implementation for rename_global_variable without sync wrapper""" 375 | try: 376 | # Get variable address 377 | var_addr: int = ida_name.get_name_ea(0, old_name) 378 | if var_addr == idaapi.BADADDR: 379 | return {"success": False, "message": f"Variable '{old_name}' not found"} 380 | 381 | # Check if new name is already in use 382 | if ida_name.get_name_ea(0, new_name) != idaapi.BADADDR: 383 | 
return {"success": False, "message": f"Name '{new_name}' is already in use"} 384 | 385 | # Try to rename 386 | if not ida_name.set_name(var_addr, new_name): 387 | return {"success": False, "message": f"Failed to rename variable, possibly due to invalid name format or other IDA restrictions"} 388 | 389 | # Refresh view 390 | self._refresh_view_internal() 391 | 392 | return {"success": True, "message": f"Variable renamed from '{old_name}' to '{new_name}' at address {hex(var_addr)}"} 393 | 394 | except Exception as e: 395 | print(f"Error renaming variable: {str(e)}") 396 | traceback.print_exc() 397 | return {"success": False, "message": str(e)} 398 | 399 | @idawrite 400 | def rename_function(self, old_name: str, new_name: str) -> Dict[str, Any]: 401 | """Rename a function""" 402 | return self._rename_function_internal(old_name, new_name) 403 | 404 | def _rename_function_internal(self, old_name: str, new_name: str) -> Dict[str, Any]: 405 | """Internal implementation for rename_function without sync wrapper""" 406 | try: 407 | # Get function address 408 | func_addr: int = ida_name.get_name_ea(0, old_name) 409 | if func_addr == idaapi.BADADDR: 410 | return {"success": False, "message": f"Function '{old_name}' not found"} 411 | 412 | # Check if it's a function 413 | func: Optional[ida_funcs.func_t] = ida_funcs.get_func(func_addr) 414 | if not func: 415 | return {"success": False, "message": f"'{old_name}' is not a function"} 416 | 417 | # Check if new name is already in use 418 | if ida_name.get_name_ea(0, new_name) != idaapi.BADADDR: 419 | return {"success": False, "message": f"Name '{new_name}' is already in use"} 420 | 421 | # Try to rename 422 | if not ida_name.set_name(func_addr, new_name): 423 | return {"success": False, "message": f"Failed to rename function, possibly due to invalid name format or other IDA restrictions"} 424 | 425 | # Refresh view 426 | self._refresh_view_internal() 427 | 428 | return {"success": True, "message": f"Function renamed from 
'{old_name}' to '{new_name}' at address {hex(func_addr)}"} 429 | 430 | except Exception as e: 431 | print(f"Error renaming function: {str(e)}") 432 | traceback.print_exc() 433 | return {"success": False, "message": str(e)} 434 | 435 | @idawrite 436 | def add_assembly_comment(self, address: str, comment: str, is_repeatable: bool) -> Dict[str, Any]: 437 | """Add an assembly comment""" 438 | return self._add_assembly_comment_internal(address, comment, is_repeatable) 439 | 440 | def _add_assembly_comment_internal(self, address: str, comment: str, is_repeatable: bool) -> Dict[str, Any]: 441 | """Internal implementation for add_assembly_comment without sync wrapper""" 442 | try: 443 | # Convert address string to integer 444 | addr: int 445 | if isinstance(address, str): 446 | if address.startswith("0x"): 447 | addr = int(address, 16) 448 | else: 449 | try: 450 | addr = int(address, 16) # Try parsing as hex 451 | except ValueError: 452 | try: 453 | addr = int(address) # Try parsing as decimal 454 | except ValueError: 455 | return {"success": False, "message": f"Invalid address format: {address}"} 456 | else: 457 | addr = address 458 | 459 | # Check if address is valid 460 | if addr == idaapi.BADADDR or not ida_bytes.is_loaded(addr): 461 | return {"success": False, "message": f"Invalid or unloaded address: {hex(addr)}"} 462 | 463 | # Add comment 464 | result: bool = idc.set_cmt(addr, comment, is_repeatable) 465 | if result: 466 | # Refresh view 467 | self._refresh_view_internal() 468 | comment_type: str = "repeatable" if is_repeatable else "regular" 469 | return {"success": True, "message": f"Added {comment_type} assembly comment at address {hex(addr)}"} 470 | else: 471 | return {"success": False, "message": f"Failed to add assembly comment at address {hex(addr)}"} 472 | 473 | except Exception as e: 474 | print(f"Error adding assembly comment: {str(e)}") 475 | traceback.print_exc() 476 | return {"success": False, "message": str(e)} 477 | 478 | @idawrite 479 | def 
rename_local_variable(self, function_name: str, old_name: str, new_name: str) -> Dict[str, Any]: 480 | """Rename a local variable within a function""" 481 | return self._rename_local_variable_internal(function_name, old_name, new_name) 482 | 483 | def _rename_local_variable_internal(self, function_name: str, old_name: str, new_name: str) -> Dict[str, Any]: 484 | """Internal implementation for rename_local_variable without sync wrapper""" 485 | try: 486 | # Parameter validation 487 | if not function_name: 488 | return {"success": False, "message": "Function name cannot be empty"} 489 | if not old_name: 490 | return {"success": False, "message": "Old variable name cannot be empty"} 491 | if not new_name: 492 | return {"success": False, "message": "New variable name cannot be empty"} 493 | 494 | # Get function address 495 | func_addr: int = ida_name.get_name_ea(0, function_name) 496 | if func_addr == idaapi.BADADDR: 497 | return {"success": False, "message": f"Function '{function_name}' not found"} 498 | 499 | # Check if it's a function 500 | func: Optional[ida_funcs.func_t] = ida_funcs.get_func(func_addr) 501 | if not func: 502 | return {"success": False, "message": f"'{function_name}' is not a function"} 503 | 504 | # Check if decompiler is available 505 | if not ida_hexrays.init_hexrays_plugin(): 506 | return {"success": False, "message": "Hex-Rays decompiler is not available"} 507 | 508 | # Get decompilation result 509 | cfunc: Optional[ida_hexrays.cfunc_t] = ida_hexrays.decompile(func_addr) 510 | if not cfunc: 511 | return {"success": False, "message": f"Failed to decompile function '{function_name}'"} 512 | 513 | ida_hexrays.open_pseudocode(func_addr, 0) 514 | 515 | # Find local variable to rename 516 | found: bool = False 517 | renamed: bool = False 518 | lvar: Optional[ida_hexrays.lvar_t] = None 519 | 520 | # Iterate through all local variables 521 | lvars = cfunc.get_lvars() 522 | for i in range(lvars.size()): 523 | v = lvars[i] 524 | if v.name == old_name: 
525 | lvar = v 526 | found = True 527 | break 528 | 529 | if not found: 530 | return {"success": False, "message": f"Local variable '{old_name}' not found in function '{function_name}'"} 531 | 532 | # Rename local variable 533 | if ida_hexrays.rename_lvar(cfunc.entry_ea, lvar.name, new_name): 534 | renamed = True 535 | 536 | if renamed: 537 | # Refresh view 538 | self._refresh_view_internal() 539 | return {"success": True, "message": f"Local variable renamed from '{old_name}' to '{new_name}' in function '{function_name}'"} 540 | else: 541 | return {"success": False, "message": f"Failed to rename local variable from '{old_name}' to '{new_name}', possibly due to invalid name format or other IDA restrictions"} 542 | 543 | except Exception as e: 544 | print(f"Error renaming local variable: {str(e)}") 545 | traceback.print_exc() 546 | return {"success": False, "message": str(e)} 547 | 548 | @idawrite 549 | def rename_multi_local_variables(self, function_name: str, rename_pairs_old2new: List[Dict[str, str]]) -> Dict[str, Any]: 550 | """Rename multiple local variables within a function at once""" 551 | try: 552 | success_count: int = 0 553 | failed_pairs: List[Dict[str, str]] = [] 554 | 555 | for pair in rename_pairs_old2new: 556 | old_name = next(iter(pair.keys())) 557 | new_name = pair[old_name] 558 | 559 | # Call existing rename_local_variable_internal for each pair 560 | result = self._rename_local_variable_internal(function_name, old_name, new_name) 561 | 562 | if result.get("success", False): 563 | success_count += 1 564 | else: 565 | failed_pairs.append({ 566 | "old_name": old_name, 567 | "new_name": new_name, 568 | "error": result.get("message", "Unknown error") 569 | }) 570 | 571 | return { 572 | "success": True, 573 | "message": f"Renamed {success_count} out of {len(rename_pairs_old2new)} local variables", 574 | "success_count": success_count, 575 | "failed_pairs": failed_pairs 576 | } 577 | 578 | except Exception as e: 579 | print(f"Error in 
rename_multi_local_variables: {str(e)}") 580 | traceback.print_exc() 581 | return { 582 | "success": False, 583 | "message": str(e), 584 | "success_count": 0, 585 | "failed_pairs": rename_pairs_old2new 586 | } 587 | 588 | @idawrite 589 | def rename_multi_global_variables(self, rename_pairs_old2new: List[Dict[str, str]]) -> Dict[str, Any]: 590 | """Rename multiple global variables at once""" 591 | try: 592 | success_count: int = 0 593 | failed_pairs: List[Dict[str, str]] = [] 594 | 595 | for pair in rename_pairs_old2new: 596 | old_name = next(iter(pair.keys())) 597 | new_name = pair[old_name] 598 | 599 | # Call existing rename_global_variable_internal for each pair 600 | result = self._rename_global_variable_internal(old_name, new_name) 601 | 602 | if result.get("success", False): 603 | success_count += 1 604 | else: 605 | failed_pairs.append({ 606 | "old_name": old_name, 607 | "new_name": new_name, 608 | "error": result.get("message", "Unknown error") 609 | }) 610 | 611 | return { 612 | "success": True, 613 | "message": f"Renamed {success_count} out of {len(rename_pairs_old2new)} global variables", 614 | "success_count": success_count, 615 | "failed_pairs": failed_pairs 616 | } 617 | 618 | except Exception as e: 619 | print(f"Error in rename_multi_global_variables: {str(e)}") 620 | traceback.print_exc() 621 | return { 622 | "success": False, 623 | "message": str(e), 624 | "success_count": 0, 625 | "failed_pairs": rename_pairs_old2new 626 | } 627 | 628 | @idawrite 629 | def rename_multi_functions(self, rename_pairs_old2new: List[Dict[str, str]]) -> Dict[str, Any]: 630 | """Rename multiple functions at once""" 631 | try: 632 | success_count: int = 0 633 | failed_pairs: List[Dict[str, str]] = [] 634 | 635 | for pair in rename_pairs_old2new: 636 | old_name = next(iter(pair.keys())) 637 | new_name = pair[old_name] 638 | 639 | # Call existing rename_function_internal for each pair 640 | result = self._rename_function_internal(old_name, new_name) 641 | 642 | if 
result.get("success", False): 643 | success_count += 1 644 | else: 645 | failed_pairs.append({ 646 | "old_name": old_name, 647 | "new_name": new_name, 648 | "error": result.get("message", "Unknown error") 649 | }) 650 | 651 | return { 652 | "success": True, 653 | "message": f"Renamed {success_count} out of {len(rename_pairs_old2new)} functions", 654 | "success_count": success_count, 655 | "failed_pairs": failed_pairs 656 | } 657 | 658 | except Exception as e: 659 | print(f"Error in rename_multi_functions: {str(e)}") 660 | traceback.print_exc() 661 | return { 662 | "success": False, 663 | "message": str(e), 664 | "success_count": 0, 665 | "failed_pairs": rename_pairs_old2new 666 | } 667 | 668 | @idawrite 669 | def add_function_comment(self, function_name: str, comment: str, is_repeatable: bool) -> Dict[str, Any]: 670 | """Add a comment to a function""" 671 | return self._add_function_comment_internal(function_name, comment, is_repeatable) 672 | 673 | def _add_function_comment_internal(self, function_name: str, comment: str, is_repeatable: bool) -> Dict[str, Any]: 674 | """Internal implementation for add_function_comment without sync wrapper""" 675 | try: 676 | # Parameter validation 677 | if not function_name: 678 | return {"success": False, "message": "Function name cannot be empty"} 679 | if not comment: 680 | # Allow empty comment to clear the comment 681 | comment = "" 682 | 683 | # Get function address 684 | func_addr: int = ida_name.get_name_ea(0, function_name) 685 | if func_addr == idaapi.BADADDR: 686 | return {"success": False, "message": f"Function '{function_name}' not found"} 687 | 688 | # Check if it's a function 689 | func: Optional[ida_funcs.func_t] = ida_funcs.get_func(func_addr) 690 | if not func: 691 | return {"success": False, "message": f"'{function_name}' is not a function"} 692 | 693 | # Open pseudocode view 694 | ida_hexrays.open_pseudocode(func_addr, 0) 695 | 696 | # Add function comment 697 | # is_repeatable=True means show comment at all 
references to this function 698 | # is_repeatable=False means show comment only at function definition 699 | result: bool = idc.set_func_cmt(func_addr, comment, is_repeatable) 700 | 701 | if result: 702 | # Refresh view 703 | self._refresh_view_internal() 704 | comment_type: str = "repeatable" if is_repeatable else "regular" 705 | return {"success": True, "message": f"Added {comment_type} comment to function '{function_name}'"} 706 | else: 707 | return {"success": False, "message": f"Failed to add comment to function '{function_name}'"} 708 | 709 | except Exception as e: 710 | print(f"Error adding function comment: {str(e)}") 711 | traceback.print_exc() 712 | return {"success": False, "message": str(e)} 713 | 714 | @idawrite 715 | def add_pseudocode_comment(self, function_name: str, address: str, comment: str, is_repeatable: bool) -> Dict[str, Any]: 716 | """Add a comment to a specific address in the function's decompiled pseudocode""" 717 | return self._add_pseudocode_comment_internal(function_name, address, comment, is_repeatable) 718 | 719 | def _add_pseudocode_comment_internal(self, function_name: str, address: str, comment: str, is_repeatable: bool) -> Dict[str, Any]: 720 | """Internal implementation for add_pseudocode_comment without sync wrapper""" 721 | try: 722 | # Parameter validation 723 | if not function_name: 724 | return {"success": False, "message": "Function name cannot be empty"} 725 | if not address: 726 | return {"success": False, "message": "Address cannot be empty"} 727 | if not comment: 728 | # Allow empty comment to clear the comment 729 | comment = "" 730 | 731 | # Get function address 732 | func_addr: int = ida_name.get_name_ea(0, function_name) 733 | if func_addr == idaapi.BADADDR: 734 | return {"success": False, "message": f"Function '{function_name}' not found"} 735 | 736 | # Check if it's a function 737 | func: Optional[ida_funcs.func_t] = ida_funcs.get_func(func_addr) 738 | if not func: 739 | return {"success": False, "message": 
f"'{function_name}' is not a function"} 740 | 741 | # Check if decompiler is available 742 | if not ida_hexrays.init_hexrays_plugin(): 743 | return {"success": False, "message": "Hex-Rays decompiler is not available"} 744 | 745 | # Get decompilation result 746 | cfunc: Optional[ida_hexrays.cfunc_t] = ida_hexrays.decompile(func_addr) 747 | if not cfunc: 748 | return {"success": False, "message": f"Failed to decompile function '{function_name}'"} 749 | 750 | # Open pseudocode view 751 | ida_hexrays.open_pseudocode(func_addr, 0) 752 | 753 | # Convert address string to integer 754 | addr: int 755 | if isinstance(address, str): 756 | if address.startswith("0x"): 757 | addr = int(address, 16) 758 | else: 759 | try: 760 | addr = int(address, 16) # Try parsing as hex 761 | except ValueError: 762 | try: 763 | addr = int(address) # Try parsing as decimal 764 | except ValueError: 765 | return {"success": False, "message": f"Invalid address format: {address}"} 766 | else: 767 | addr = address 768 | 769 | # Check if address is valid 770 | if addr == idaapi.BADADDR or not ida_bytes.is_loaded(addr): 771 | return {"success": False, "message": f"Invalid or unloaded address: {hex(addr)}"} 772 | 773 | # Check if address is within function 774 | if not (func.start_ea <= addr < func.end_ea): 775 | return {"success": False, "message": f"Address {hex(addr)} is not within function '{function_name}'"} 776 | 777 | # Create treeloc_t object for comment location 778 | loc = ida_hexrays.treeloc_t() 779 | loc.ea = addr 780 | loc.itp = ida_hexrays.ITP_BLOCK1 # Comment location 781 | 782 | # Set comment 783 | cfunc.set_user_cmt(loc, comment) 784 | cfunc.save_user_cmts() 785 | 786 | # Refresh view 787 | self._refresh_view_internal() 788 | 789 | comment_type: str = "repeatable" if is_repeatable else "regular" 790 | return { 791 | "success": True, 792 | "message": f"Added {comment_type} comment at address {hex(addr)} in function '{function_name}'" 793 | } 794 | 795 | except Exception as e: 796 | 
print(f"Error adding pseudocode comment: {str(e)}") 797 | traceback.print_exc() 798 | return {"success": False, "message": str(e)} 799 | 800 | @idawrite 801 | def refresh_view(self) -> Dict[str, Any]: 802 | """Refresh IDA Pro view""" 803 | return self._refresh_view_internal() 804 | 805 | def _refresh_view_internal(self) -> Dict[str, Any]: 806 | """Implementation of refreshing view in IDA main thread""" 807 | try: 808 | # Refresh disassembly view 809 | idaapi.refresh_idaview_anyway() 810 | 811 | # Refresh decompilation view 812 | current_widget = idaapi.get_current_widget() 813 | if current_widget: 814 | widget_type: int = idaapi.get_widget_type(current_widget) 815 | if widget_type == idaapi.BWN_PSEUDOCODE: 816 | # If current view is pseudocode, refresh it 817 | vu = idaapi.get_widget_vdui(current_widget) 818 | if vu: 819 | vu.refresh_view(True) 820 | 821 | # Try to find and refresh all open pseudocode windows 822 | for i in range(5): # Check multiple possible pseudocode windows 823 | widget_name: str = f"Pseudocode-{chr(65+i)}" # Pseudocode-A, Pseudocode-B, ... 
824 | widget = idaapi.find_widget(widget_name) 825 | if widget: 826 | vu = idaapi.get_widget_vdui(widget) 827 | if vu: 828 | vu.refresh_view(True) 829 | 830 | return {"success": True, "message": "Views refreshed successfully"} 831 | except Exception as e: 832 | print(f"Error refreshing views: {str(e)}") 833 | traceback.print_exc() 834 | return {"success": False, "message": str(e)} 835 | 836 | @idawrite 837 | def execute_script(self, script: str) -> Dict[str, Any]: 838 | """Execute a Python script in IDA context""" 839 | return self._execute_script_internal(script) 840 | 841 | def _execute_script_internal(self, script: str) -> Dict[str, Any]: 842 | """Internal implementation for execute_script without sync wrapper""" 843 | try: 844 | print(f"Executing script, length: {len(script) if script else 0}") 845 | 846 | # Check for empty script 847 | if not script or not script.strip(): 848 | print("Error: Empty script provided") 849 | return { 850 | "success": False, 851 | "error": "Empty script provided", 852 | "stdout": "", 853 | "stderr": "", 854 | "traceback": "" 855 | } 856 | 857 | # Create a local namespace for script execution 858 | script_globals = { 859 | '__builtins__': __builtins__, 860 | 'idaapi': idaapi, 861 | 'idautils': idautils, 862 | 'idc': idc, 863 | 'ida_funcs': ida_funcs, 864 | 'ida_bytes': ida_bytes, 865 | 'ida_name': ida_name, 866 | 'ida_segment': ida_segment, 867 | 'ida_lines': ida_lines, 868 | 'ida_hexrays': ida_hexrays 869 | } 870 | script_locals = {} 871 | 872 | # Save original stdin/stdout/stderr 873 | import sys 874 | import io 875 | original_stdout = sys.stdout 876 | original_stderr = sys.stderr 877 | original_stdin = sys.stdin 878 | 879 | # Create string IO objects to capture output 880 | stdout_capture = io.StringIO() 881 | stderr_capture = io.StringIO() 882 | 883 | # Redirect stdout/stderr to capture output 884 | sys.stdout = stdout_capture 885 | sys.stderr = stderr_capture 886 | 887 | # Prevent script from trying to read from stdin 888 | 
sys.stdin = io.StringIO() 889 | 890 | try: 891 | # Create UI hooks 892 | print("Setting up UI hooks") 893 | hooks = self._create_ui_hooks() 894 | hooks.hook() 895 | 896 | # Install auto-continue handlers for common dialogs - but first, redirect stderr 897 | temp_stderr = sys.stderr 898 | auto_handler_stderr = io.StringIO() 899 | sys.stderr = auto_handler_stderr 900 | 901 | print("Installing auto handlers") 902 | self._install_auto_handlers() 903 | 904 | # Restore stderr and save auto-handler errors separately 905 | sys.stderr = stderr_capture 906 | auto_handler_errors = auto_handler_stderr.getvalue() 907 | 908 | # Only log auto-handler errors, don't include in script output 909 | if auto_handler_errors: 910 | print(f"Auto-handler setup errors (not shown to user): {auto_handler_errors}") 911 | 912 | # Execute the script 913 | print("Executing script...") 914 | exec(script, script_globals, script_locals) 915 | print("Script execution completed") 916 | 917 | # Get captured output 918 | stdout = stdout_capture.getvalue() 919 | stderr = stderr_capture.getvalue() 920 | 921 | # Filter out auto-handler messages from stdout 922 | stdout_lines = stdout.splitlines() 923 | filtered_stdout_lines = [] 924 | 925 | for line in stdout_lines: 926 | skip_line = False 927 | auto_handler_messages = [ 928 | "Setting up UI hooks", 929 | "Installing auto handlers", 930 | "Error installing auto handlers", 931 | "Found and saved", 932 | "Could not access user_cancelled", 933 | "Installed auto_", 934 | "Auto handlers installed", 935 | "Note: Could not", 936 | "Restoring IO streams", 937 | "Unhooking UI hooks", 938 | "Restoring original handlers", 939 | "Refreshing view", 940 | "Original handlers restored", 941 | "No original handlers" 942 | ] 943 | 944 | for msg in auto_handler_messages: 945 | if msg in line: 946 | skip_line = True 947 | break 948 | 949 | if not skip_line: 950 | filtered_stdout_lines.append(line) 951 | 952 | filtered_stdout = "\n".join(filtered_stdout_lines) 953 | 954 | # 
Compile script results - ensure all fields are present 955 | result = { 956 | "stdout": filtered_stdout.strip() if filtered_stdout else "", 957 | "stderr": stderr.strip() if stderr else "", 958 | "success": True, 959 | "traceback": "" 960 | } 961 | 962 | # Check for return value 963 | if "result" in script_locals: 964 | try: 965 | print(f"Script returned value of type: {type(script_locals['result']).__name__}") 966 | result["return_value"] = str(script_locals["result"]) 967 | except Exception as rv_err: 968 | print(f"Error converting return value: {str(rv_err)}") 969 | result["stderr"] += f"\nError converting return value: {str(rv_err)}" 970 | result["return_value"] = "Error: Could not convert return value to string" 971 | 972 | print(f"Returning script result with keys: {', '.join(result.keys())}") 973 | return result 974 | except Exception as e: 975 | import traceback 976 | error_msg = str(e) 977 | tb = traceback.format_exc() 978 | print(f"Script execution error: {error_msg}") 979 | print(tb) 980 | return { 981 | "success": False, 982 | "stdout": stdout_capture.getvalue().strip() if stdout_capture else "", 983 | "stderr": stderr_capture.getvalue().strip() if stderr_capture else "", 984 | "error": error_msg, 985 | "traceback": tb 986 | } 987 | finally: 988 | # Restore original stdin/stdout/stderr 989 | print("Restoring IO streams") 990 | sys.stdout = original_stdout 991 | sys.stderr = original_stderr 992 | sys.stdin = original_stdin 993 | 994 | # Unhook UI hooks 995 | print("Unhooking UI hooks") 996 | hooks.unhook() 997 | 998 | # Restore original handlers 999 | print("Restoring original handlers") 1000 | self._restore_original_handlers() 1001 | 1002 | # Refresh view to show any changes made by script 1003 | print("Refreshing view") 1004 | self._refresh_view_internal() 1005 | except Exception as e: 1006 | print(f"Error in execute_script outer scope: {str(e)}") 1007 | traceback.print_exc() 1008 | return { 1009 | "success": False, 1010 | "stdout": "", 1011 | 
"stderr": "", 1012 | "error": str(e), 1013 | "traceback": traceback.format_exc() 1014 | } 1015 | 1016 | @idawrite 1017 | def execute_script_from_file(self, file_path: str) -> Dict[str, Any]: 1018 | """Execute a Python script from a file in IDA context""" 1019 | return self._execute_script_from_file_internal(file_path) 1020 | 1021 | def _execute_script_from_file_internal(self, file_path: str) -> Dict[str, Any]: 1022 | """Internal implementation for execute_script_from_file without sync wrapper""" 1023 | try: 1024 | # Check if file path is provided 1025 | if not file_path or not file_path.strip(): 1026 | return { 1027 | "success": False, 1028 | "error": "No file path provided", 1029 | "stdout": "", 1030 | "stderr": "", 1031 | "traceback": "" 1032 | } 1033 | 1034 | # Check if file exists 1035 | import os 1036 | if not os.path.exists(file_path): 1037 | return { 1038 | "success": False, 1039 | "error": f"Script file not found: {file_path}", 1040 | "stdout": "", 1041 | "stderr": "", 1042 | "traceback": "" 1043 | } 1044 | 1045 | try: 1046 | # Read script content 1047 | with open(file_path, 'r') as f: 1048 | script = f.read() 1049 | 1050 | # Execute script using internal method 1051 | return self._execute_script_internal(script) 1052 | except Exception as file_error: 1053 | print(f"Error reading or executing script file: {str(file_error)}") 1054 | traceback.print_exc() 1055 | return { 1056 | "success": False, 1057 | "stdout": "", 1058 | "stderr": "", 1059 | "error": f"Error with script file: {str(file_error)}", 1060 | "traceback": traceback.format_exc() 1061 | } 1062 | except Exception as e: 1063 | print(f"Error executing script from file: {str(e)}") 1064 | traceback.print_exc() 1065 | return { 1066 | "success": False, 1067 | "stdout": "", 1068 | "stderr": "", 1069 | "error": str(e), 1070 | "traceback": traceback.format_exc() 1071 | } 1072 | 1073 | def _create_ui_hooks(self) -> idaapi.UI_Hooks: 1074 | """Create UI hooks to suppress dialogs during script execution""" 1075 | 
def _install_auto_handlers(self) -> None:
    """Swap IDA's interactive dialog entry points for auto-answering stubs.

    The functions being replaced are remembered in ``self._original_handlers``
    under the keys ``yn``, ``buttons``, ``text`` and ``file``. Only the
    handlers that were found and saved are replaced. Errors are printed, not
    raised, and ``self._original_handlers`` exists afterwards in every case.
    """

    def constant(value):
        # Build a handler that ignores its arguments and answers with `value`.
        def handler(*args, **kwargs):
            return value
        return handler

    auto_yes_no = constant(1)   # Return "Yes"
    auto_buttons = constant(1)  # Return first button
    auto_text = constant("")    # Return empty text
    auto_file = constant("")    # Return empty filename

    # (key in _original_handlers, attribute on ida_kernwin, stub, message)
    module_level = (
        ("buttons", "ask_buttons", auto_buttons, "Installed auto_buttons handler"),
        ("text", "ask_text", auto_text, "Installed auto_text handler"),
        ("file", "ask_file", auto_file, "Installed auto_file handler"),
    )

    try:
        import ida_kernwin

        saved = {}
        self._original_handlers = saved

        # user_cancelled lives on cvar (when present); access it defensively.
        try:
            if hasattr(ida_kernwin, 'cvar'):
                if hasattr(ida_kernwin.cvar, 'user_cancelled'):
                    saved["yn"] = ida_kernwin.cvar.user_cancelled
                    print("Found and saved user_cancelled handler")
        except Exception as yn_err:
            print(f"Note: Could not access user_cancelled: {str(yn_err)}")

        # Remember the remaining dialog functions that this build exposes.
        for key, attr, _stub, _message in module_level:
            if hasattr(ida_kernwin, attr):
                saved[key] = getattr(ida_kernwin, attr)

        # Install stubs only for what was successfully saved.
        if "yn" in saved:
            try:
                ida_kernwin.cvar.user_cancelled = auto_yes_no
                print("Installed auto_yes_no handler")
            except Exception as e:
                print(f"Could not install auto_yes_no handler: {str(e)}")

        for key, attr, stub, message in module_level:
            if key in saved:
                setattr(ida_kernwin, attr, stub)
                print(message)

        print(f"Auto handlers installed successfully. Installed handlers: {', '.join(saved.keys())}")
    except Exception as e:
        print(f"Error installing auto handlers: {str(e)}")
        traceback.print_exc()
        # Ensure _original_handlers exists even on failure
        if not hasattr(self, "_original_handlers"):
            self._original_handlers = {}
logging 2 | import socket 3 | import json 4 | import time 5 | import struct 6 | import uuid 7 | from typing import Dict, Any, List, Union, Optional, Tuple, Callable, TypeVar, Set, Awaitable, Type, cast 8 | from mcp.server import Server 9 | from mcp.server.stdio import stdio_server 10 | from mcp.types import ( 11 | TextContent, 12 | Tool, 13 | ) 14 | from enum import Enum 15 | from pydantic import BaseModel 16 | 17 | # Modify request models 18 | class GetFunctionAssemblyByName(BaseModel): 19 | function_name: str 20 | 21 | class GetFunctionAssemblyByAddress(BaseModel): 22 | address: str # Hexadecimal address as string 23 | 24 | class GetFunctionDecompiledByName(BaseModel): 25 | function_name: str 26 | 27 | class GetFunctionDecompiledByAddress(BaseModel): 28 | address: str # Hexadecimal address as string 29 | 30 | class GetGlobalVariableByName(BaseModel): 31 | variable_name: str 32 | 33 | class GetGlobalVariableByAddress(BaseModel): 34 | address: str # Hexadecimal address as string 35 | 36 | class GetCurrentFunctionAssembly(BaseModel): 37 | pass 38 | 39 | class GetCurrentFunctionDecompiled(BaseModel): 40 | pass 41 | 42 | class RenameLocalVariable(BaseModel): 43 | function_name: str 44 | old_name: str 45 | new_name: str 46 | 47 | class RenameGlobalVariable(BaseModel): 48 | old_name: str 49 | new_name: str 50 | 51 | class RenameFunction(BaseModel): 52 | old_name: str 53 | new_name: str 54 | 55 | class RenameMultiLocalVariables(BaseModel): 56 | function_name: str 57 | rename_pairs_old2new: List[Dict[str, str]] # List of dictionaries with "old_name" and "new_name" keys 58 | 59 | class RenameMultiGlobalVariables(BaseModel): 60 | rename_pairs_old2new: List[Dict[str, str]] 61 | 62 | class RenameMultiFunctions(BaseModel): 63 | rename_pairs_old2new: List[Dict[str, str]] 64 | 65 | class AddAssemblyComment(BaseModel): 66 | address: str # Can be a hexadecimal address string 67 | comment: str 68 | is_repeatable: bool = False # Whether the comment should be repeatable 69 | 70 | 
# IDA Pro communication handler
class IDAProCommunicator:
    """Blocking TCP client that exchanges JSON messages with the IDA plugin.

    Every message on the wire is a 4-byte big-endian length prefix followed by
    that many bytes of UTF-8 encoded JSON. Requests carry an ``id``, a running
    ``count``, a ``type`` and a ``data`` payload.
    """

    # Request types that get ``batch_timeout`` instead of ``default_timeout``.
    BATCH_REQUEST_TYPES = frozenset({
        "rename_multi_local_variables",
        "rename_multi_global_variables",
        "rename_multi_functions",
    })

    def __init__(self, host: str = 'localhost', port: int = 5000):
        """Store connection settings; no socket is opened here."""
        self.host: str = host
        self.port: int = port
        self.sock: Optional[socket.socket] = None
        self.logger: logging.Logger = logging.getLogger(__name__)
        self.connected: bool = False
        self.reconnect_attempts: int = 0
        self.max_reconnect_attempts: int = 5
        self.last_reconnect_time: float = 0
        self.reconnect_cooldown: int = 5  # seconds
        self.request_count: int = 0
        self.default_timeout: int = 10
        self.batch_timeout: int = 60  # it may take more time for batch operations

    def connect(self) -> bool:
        """Connect to IDA plugin.

        Returns:
            True when a connection was established. False when the attempt
            failed, or was skipped because a previous failure happened less
            than ``reconnect_cooldown`` seconds ago.
        """
        # Check if cooldown is needed
        current_time: float = time.time()
        if self.reconnect_attempts > 0 and current_time - self.last_reconnect_time < self.reconnect_cooldown:
            self.logger.debug("In reconnection cooldown, skipping")
            return False

        # Drop any previous socket first so it is never leaked.
        if self.connected or self.sock is not None:
            self.disconnect()

        new_sock: Optional[socket.socket] = None
        try:
            new_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            new_sock.settimeout(self.default_timeout)
            new_sock.connect((self.host, self.port))
        except Exception as e:
            # Close the half-created socket instead of leaving it for the GC.
            if new_sock is not None:
                try:
                    new_sock.close()
                except OSError:
                    pass
            self.last_reconnect_time = current_time
            self.reconnect_attempts += 1
            if self.reconnect_attempts <= self.max_reconnect_attempts:
                self.logger.warning(f"Failed to connect to IDA Pro: {str(e)}. Attempt {self.reconnect_attempts}/{self.max_reconnect_attempts}")
            else:
                self.logger.error(f"Failed to connect to IDA Pro after {self.max_reconnect_attempts} attempts: {str(e)}")
            return False

        self.sock = new_sock
        self.connected = True
        self.reconnect_attempts = 0
        self.logger.info(f"Connected to IDA Pro ({self.host}:{self.port})")
        return True

    def disconnect(self) -> None:
        """Disconnect from IDA Pro; safe to call when already disconnected."""
        if self.sock:
            try:
                self.sock.close()
            except OSError:
                # Best effort: the socket is discarded either way.
                pass
            self.sock = None
        self.connected = False

    def ensure_connection(self) -> bool:
        """Ensure connection is established, connecting if necessary."""
        if not self.connected:
            return self.connect()
        return True

    def send_message(self, data: bytes) -> None:
        """Send message with length prefix.

        Raises:
            ConnectionError: If there is no socket.
        """
        if self.sock is None:
            raise ConnectionError("Socket is not connected")

        length_bytes: bytes = struct.pack('!I', len(data))  # 4-byte length prefix
        self.sock.sendall(length_bytes + data)

    def receive_message(self) -> Optional[bytes]:
        """Receive message with length prefix.

        Returns:
            The message body, or None when the peer closed the connection or
            any error (including a timeout) occurred while reading.
        """
        try:
            # Receive 4-byte length prefix
            length_bytes: Optional[bytes] = self.receive_exactly(4)
            if not length_bytes:
                return None

            length: int = struct.unpack('!I', length_bytes)[0]

            # Receive message body
            return self.receive_exactly(length)
        except Exception as e:
            self.logger.error(f"Error receiving message: {str(e)}")
            return None

    def receive_exactly(self, n: int) -> Optional[bytes]:
        """Receive exactly n bytes of data.

        Returns:
            The n bytes read, or None if the connection closed first.

        Raises:
            ConnectionError: If there is no socket.
        """
        if self.sock is None:
            raise ConnectionError("Socket is not connected")

        # bytearray avoids re-copying the accumulated data on every chunk.
        buffer = bytearray()
        while len(buffer) < n:
            chunk: bytes = self.sock.recv(min(n - len(buffer), 4096))
            if not chunk:  # Connection closed
                return None
            buffer += chunk
        return bytes(buffer)

    def send_request(self, request_type: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Send request to IDA plugin and return its decoded response.

        Args:
            request_type: Name of the operation the plugin should perform.
            data: JSON-serialisable payload for that operation.

        Returns:
            The response dictionary sent by the plugin, or a dictionary with a
            single ``error`` key when connecting, sending, receiving or
            decoding failed. This method does not raise.
        """
        # Ensure connection is established
        if not self.ensure_connection():
            return {"error": "Cannot connect to IDA Pro"}

        try:
            if self.sock:
                if request_type in self.BATCH_REQUEST_TYPES:
                    self.sock.settimeout(self.batch_timeout)
                    self.logger.debug(f"Set timeout to {self.batch_timeout}s for batch operation")
                else:
                    self.sock.settimeout(self.default_timeout)
                    self.logger.debug(f"Set timeout to {self.default_timeout}s for normal operation")

            # Add request ID
            request_id: str = str(uuid.uuid4())
            self.request_count += 1
            request_count: int = self.request_count

            request: Dict[str, Any] = {
                "id": request_id,
                "count": request_count,
                "type": request_type,
                "data": data
            }

            self.logger.debug(f"Sending request: {request_id}, type: {request_type}, count: {request_count}")

            # Send request
            self.send_message(json.dumps(request).encode('utf-8'))

            # Receive response
            response_data: Optional[bytes] = self.receive_message()

            # If no data received, assume connection is closed
            if not response_data:
                self.logger.warning("No data received, connection may be closed")
                self.disconnect()
                return {"error": "No response received from IDA Pro"}

            # Parse response. The full frame has been consumed at this point,
            # so a malformed body does not desynchronise the stream and the
            # connection can be kept.
            self.logger.debug(f"Received raw data length: {len(response_data)}")
            try:
                response: Any = json.loads(response_data.decode('utf-8'))
            except (json.JSONDecodeError, UnicodeDecodeError) as e:
                self.logger.error(f"Failed to parse JSON response: {str(e)}")
                return {"error": f"Invalid JSON response: {str(e)}"}

            # Check the type BEFORE calling .get(): valid JSON may also be a
            # list, string or number.
            if not isinstance(response, dict):
                self.logger.error(f"Received response is not a dictionary: {type(response)}")
                return {"error": f"Response format error: expected dictionary, got {type(response).__name__}"}

            # Verify response ID matches (mismatch is only logged)
            response_id: Any = response.get("id")
            if response_id != request_id:
                self.logger.warning(f"Response ID mismatch! Request ID: {request_id}, Response ID: {response_id}")

            self.logger.debug(f"Received response: ID={response_id}, count={response.get('count')}")
            return response
        except Exception as e:
            self.logger.error(f"Error communicating with IDA Pro: {str(e)}")
            self.disconnect()  # Disconnect after error
            return {"error": str(e)}
        finally:
            # restore timeout
            if self.sock:
                try:
                    self.sock.settimeout(self.default_timeout)
                except OSError:
                    # Must not mask the value being returned above.
                    pass

    def ping(self) -> bool:
        """Check if connection is valid by sending a ``ping`` request."""
        response: Dict[str, Any] = self.send_request("ping", {})
        return response.get("status") == "pong"
def get_function_decompiled_by_name(self, function_name: str) -> str:
    """Return the decompiled pseudocode of the function named *function_name*.

    Sends a ``get_function_decompiled_by_name`` request through the
    communicator and formats the reply as text. Every failure is reported in
    the returned string; nothing is raised.
    """
    try:
        reply: Dict[str, Any] = self.communicator.send_request(
            "get_function_decompiled_by_name",
            dict(function_name=function_name),
        )

        # Keep the whole reply in the debug log.
        self.logger.debug(f"Decompilation response: {reply}")

        if "error" in reply:
            return f"Error retrieving decompiled code for function '{function_name}': {reply['error']}"

        code: Any = reply.get("decompiled_code")
        if code is None:
            return f"Error: No decompiled code returned for function '{function_name}'"

        type_name: str = type(code).__name__
        self.logger.debug(f"Decompiled code type is: {type_name}")

        if isinstance(code, str):
            text = code
        else:
            # Anything that is not already text is stringified.
            self.logger.warning(f"Decompiled code type is not string but {type_name}, attempting conversion")
            try:
                text = str(code)
            except Exception as e:
                return f"Error: Failed to convert decompiled code from {type_name} to string: {str(e)}"

        return f"Decompiled code for function '{function_name}':\n{text}"
    except Exception as e:
        self.logger.error(f"Error getting function decompiled code: {str(e)}", exc_info=True)
        return f"Error retrieving decompiled code for function '{function_name}': {str(e)}"
def get_global_variable_by_address(self, address: str) -> str:
    """Get global variable information by its address.

    Args:
        address: Address as text. Accepted forms are ``0x``/``0X``-prefixed
            hexadecimal, plain decimal, and bare hexadecimal that contains
            letters (e.g. ``4010a0``). Digits-only text stays decimal.
            Surrounding whitespace is ignored.

    Returns:
        A description of the variable, or an error message. Failures are
        reported in the returned string; nothing is raised.
    """
    try:
        # Convert string address to int
        text = address.strip()
        try:
            if text.lower().startswith("0x"):
                addr_int = int(text, 16)
            else:
                try:
                    addr_int = int(text)
                except ValueError:
                    # Not decimal: only hexadecimal can still make sense.
                    addr_int = int(text, 16)
        except ValueError:
            return f"Error: Invalid address format '{address}', expected hexadecimal (0x...) or decimal"

        response: Dict[str, Any] = self.communicator.send_request(
            "get_global_variable_by_address",
            {"address": addr_int}
        )

        if "error" in response:
            return f"Error retrieving global variable at address '{address}': {response['error']}"

        variable_info: Any = response.get("variable_info")
        if variable_info is None:
            return f"Error: No variable info returned for address '{address}'"

        # Work out the display name and the text form of the info together, so
        # a dict reply is not serialised and then parsed again.
        var_name = "Unknown"
        if isinstance(variable_info, str):
            info_text = variable_info
            try:
                parsed = json.loads(variable_info)
            except (ValueError, TypeError):
                # Plain-text info is fine; the name simply stays "Unknown".
                parsed = None
            if isinstance(parsed, dict) and "name" in parsed:
                var_name = parsed["name"]
        else:
            self.logger.warning(f"Variable info type is not string but {type(variable_info).__name__}, attempting conversion")
            try:
                if isinstance(variable_info, dict):
                    # A dictionary is rendered as indented JSON.
                    info_text = json.dumps(variable_info, indent=2)
                    if "name" in variable_info:
                        var_name = variable_info["name"]
                else:
                    info_text = str(variable_info)
            except Exception as e:
                return f"Error: Failed to convert variable info to string: {str(e)}"

        return f"Global variable '{var_name}' at address {address}:\n{info_text}"
    except Exception as e:
        self.logger.error(f"Error getting global variable by address: {str(e)}", exc_info=True)
        return f"Error retrieving global variable at address '{address}': {str(e)}"
def rename_local_variable(self, function_name: str, old_name: str, new_name: str) -> str:
    """Rename one local variable of *function_name* from *old_name* to *new_name*.

    Sends a ``rename_local_variable`` request and returns a sentence that
    describes the outcome. Failures are reported in that sentence, not raised.
    """
    # Every message this method produces ends with the same description.
    subject = f"local variable from '{old_name}' to '{new_name}' in function '{function_name}'"
    try:
        payload = {"function_name": function_name, "old_name": old_name, "new_name": new_name}
        reply: Dict[str, Any] = self.communicator.send_request("rename_local_variable", payload)

        if "error" in reply:
            return f"Error renaming {subject}: {reply['error']}"

        succeeded: bool = reply.get("success", False)
        detail: str = reply.get("message", "")
        prefix = "Successfully renamed" if succeeded else "Failed to rename"
        return f"{prefix} {subject}: {detail}"
    except Exception as e:
        self.logger.error(f"Error renaming local variable: {str(e)}", exc_info=True)
        return f"Error renaming {subject}: {str(e)}"
def rename_function(self, old_name: str, new_name: str) -> str:
    """Rename the function *old_name* to *new_name*.

    Sends a ``rename_function`` request and returns a sentence describing the
    outcome. Failures are reported in that sentence, not raised.
    """
    try:
        reply: Dict[str, Any] = self.communicator.send_request(
            "rename_function",
            dict(old_name=old_name, new_name=new_name),
        )
    except Exception as e:
        self.logger.error(f"Error renaming function: {str(e)}", exc_info=True)
        return f"Error renaming function from '{old_name}' to '{new_name}': {str(e)}"

    try:
        # A transport or plugin-side error takes precedence over the flags.
        if "error" in reply:
            return f"Error renaming function from '{old_name}' to '{new_name}': {reply['error']}"

        ok: bool = reply.get("success", False)
        note: str = reply.get("message", "")
        if not ok:
            return f"Failed to rename function from '{old_name}' to '{new_name}': {note}"
        return f"Successfully renamed function from '{old_name}' to '{new_name}': {note}"
    except Exception as e:
        self.logger.error(f"Error renaming function: {str(e)}", exc_info=True)
        return f"Error renaming function from '{old_name}' to '{new_name}': {str(e)}"
def rename_multi_global_variables(self, rename_pairs_old2new: List[Dict[str, str]]) -> str:
    """Rename multiple global variables at once.

    Args:
        rename_pairs_old2new: Rename requests, forwarded unchanged to the
            plugin in a ``rename_multi_global_variables`` request.

    Returns:
        A summary with the number of successful renames followed by one line
        per failed pair, or an error message. Nothing is raised.
    """
    try:
        response: Dict[str, Any] = self.communicator.send_request(
            "rename_multi_global_variables",
            {"rename_pairs_old2new": rename_pairs_old2new}
        )

        if "error" in response:
            return f"Error renaming multiple global variables: {response['error']}"

        success_count: int = response.get("success_count", 0)
        # `or []` also covers an explicit null in the response.
        failed_pairs: List[Dict[str, str]] = response.get("failed_pairs") or []

        result_parts: List[str] = [
            f"Successfully renamed {success_count} global variables"
        ]

        if failed_pairs:
            result_parts.append("\nFailed renamings:")
            for pair in failed_pairs:
                if isinstance(pair, dict):
                    # Tolerate incomplete failure records: a missing key must
                    # not turn the whole report into a KeyError message.
                    old = pair.get('old_name', '<unknown>')
                    new = pair.get('new_name', '<unknown>')
                    reason = pair.get('error', 'Unknown error')
                    result_parts.append(f"- {old} → {new}: {reason}")
                else:
                    result_parts.append(f"- {pair}")

        return "\n".join(result_parts)
    except Exception as e:
        self.logger.error(f"Error renaming multiple global variables: {str(e)}", exc_info=True)
        return f"Error renaming multiple global variables: {str(e)}"
def add_function_comment(self, function_name: str, comment: str, is_repeatable: bool = False) -> str:
    """Attach *comment* to the function named *function_name*.

    Sends an ``add_function_comment`` request carrying the function name, the
    comment text and the ``is_repeatable`` flag, and returns a sentence that
    describes the outcome. Failures are reported in that sentence, not raised.
    """
    try:
        request_data = {
            "function_name": function_name,
            "comment": comment,
            "is_repeatable": is_repeatable,
        }
        reply: Dict[str, Any] = self.communicator.send_request("add_function_comment", request_data)

        if "error" in reply:
            return f"Error adding comment to function '{function_name}': {reply['error']}"

        added: bool = reply.get("success", False)
        detail: str = reply.get("message", "")

        if not added:
            return f"Failed to add comment to function '{function_name}': {detail}"

        # Only the success message names the kind of comment that was written.
        if is_repeatable:
            kind = "repeatable"
        else:
            kind = "regular"
        return f"Successfully added {kind} comment to function '{function_name}': {detail}"
    except Exception as e:
        self.logger.error(f"Error adding function comment: {str(e)}", exc_info=True)
        return f"Error adding comment to function '{function_name}': {str(e)}"
False) 721 | message: str = response.get("message", "") 722 | 723 | if success: 724 | comment_type: str = "repeatable" if is_repeatable else "regular" 725 | return f"Successfully added {comment_type} comment at address {address} in function '{function_name}': {message}" 726 | else: 727 | return f"Failed to add comment at address {address} in function '{function_name}': {message}" 728 | except Exception as e: 729 | self.logger.error(f"Error adding pseudocode comment: {str(e)}", exc_info=True) 730 | return f"Error adding comment at address {address} in function '{function_name}': {str(e)}" 731 | 732 | def execute_script(self, script: str) -> str: 733 | """Execute a Python script in IDA Pro and return its output. The script runs in IDA's context with access to all IDA API modules.""" 734 | try: 735 | response: Dict[str, Any] = self.communicator.send_request( 736 | "execute_script", 737 | {"script": script} 738 | ) 739 | 740 | # Handle case where response is None 741 | if response is None: 742 | self.logger.error("Received None response from IDA when executing script") 743 | return "Error executing script: Received empty response from IDA" 744 | 745 | # Handle case where response contains error 746 | if "error" in response: 747 | return f"Error executing script: {response['error']}" 748 | 749 | # Handle successful execution 750 | success: bool = response.get("success", False) 751 | if not success: 752 | error_msg: str = response.get("error", "Unknown error") 753 | traceback: str = response.get("traceback", "") 754 | return f"Script execution failed: {error_msg}\n\nTraceback:\n{traceback}" 755 | 756 | # Get output - ensure all values are strings to avoid None errors 757 | stdout: str = str(response.get("stdout", "")) 758 | stderr: str = str(response.get("stderr", "")) 759 | return_value: str = str(response.get("return_value", "")) 760 | 761 | result_text: List[str] = [] 762 | result_text.append("Script executed successfully") 763 | 764 | if return_value and return_value 
!= "None": 765 | result_text.append(f"\nReturn value:\n{return_value}") 766 | 767 | if stdout: 768 | result_text.append(f"\nStandard output:\n{stdout}") 769 | 770 | if stderr: 771 | result_text.append(f"\nStandard error:\n{stderr}") 772 | 773 | return "\n".join(result_text) 774 | 775 | except Exception as e: 776 | self.logger.error(f"Error executing script: {str(e)}", exc_info=True) 777 | return f"Error executing script: {str(e)}" 778 | 779 | def execute_script_from_file(self, file_path: str) -> str: 780 | """Execute a Python script from a file path in IDA Pro and return its output. The file should be accessible from IDA's process.""" 781 | try: 782 | response: Dict[str, Any] = self.communicator.send_request( 783 | "execute_script_from_file", 784 | {"file_path": file_path} 785 | ) 786 | 787 | # Handle case where response is None 788 | if response is None: 789 | self.logger.error("Received None response from IDA when executing script from file") 790 | return f"Error executing script from file '{file_path}': Received empty response from IDA" 791 | 792 | # Handle case where response contains error 793 | if "error" in response: 794 | return f"Error executing script from file '{file_path}': {response['error']}" 795 | 796 | # Handle successful execution 797 | success: bool = response.get("success", False) 798 | if not success: 799 | error_msg: str = response.get("error", "Unknown error") 800 | traceback: str = response.get("traceback", "") 801 | return f"Script execution from file '{file_path}' failed: {error_msg}\n\nTraceback:\n{traceback}" 802 | 803 | # Get output - ensure all values are strings to avoid None errors 804 | stdout: str = str(response.get("stdout", "")) 805 | stderr: str = str(response.get("stderr", "")) 806 | return_value: str = str(response.get("return_value", "")) 807 | 808 | result_text: List[str] = [] 809 | result_text.append(f"Script from file '{file_path}' executed successfully") 810 | 811 | if return_value and return_value != "None": 812 | 
result_text.append(f"\nReturn value:\n{return_value}") 813 | 814 | if stdout: 815 | result_text.append(f"\nStandard output:\n{stdout}") 816 | 817 | if stderr: 818 | result_text.append(f"\nStandard error:\n{stderr}") 819 | 820 | return "\n".join(result_text) 821 | 822 | except Exception as e: 823 | self.logger.error(f"Error executing script from file: {str(e)}", exc_info=True) 824 | return f"Error executing script from file '{file_path}': {str(e)}" 825 | 826 | def get_function_assembly_by_address(self, address: str) -> str: 827 | """Get assembly code for a function by its address""" 828 | try: 829 | # Convert string address to int 830 | try: 831 | addr_int = int(address, 16) if address.startswith("0x") else int(address) 832 | except ValueError: 833 | return f"Error: Invalid address format '{address}', expected hexadecimal (0x...) or decimal" 834 | 835 | response: Dict[str, Any] = self.communicator.send_request( 836 | "get_function_assembly_by_address", 837 | {"address": addr_int} 838 | ) 839 | 840 | if "error" in response: 841 | return f"Error retrieving assembly for address '{address}': {response['error']}" 842 | 843 | assembly: Any = response.get("assembly") 844 | function_name: str = response.get("function_name", "Unknown function") 845 | 846 | # Verify assembly is string type 847 | if assembly is None: 848 | return f"Error: No assembly data returned for address '{address}'" 849 | if not isinstance(assembly, str): 850 | self.logger.warning(f"Assembly data type is not string but {type(assembly).__name__}, attempting conversion") 851 | assembly = str(assembly) 852 | 853 | return f"Assembly code for function '{function_name}' at address {address}:\n{assembly}" 854 | except Exception as e: 855 | self.logger.error(f"Error getting function assembly by address: {str(e)}", exc_info=True) 856 | return f"Error retrieving assembly for address '{address}': {str(e)}" 857 | 858 | def get_function_decompiled_by_address(self, address: str) -> str: 859 | """Get decompiled 
pseudocode for a function by its address""" 860 | try: 861 | # Convert string address to int 862 | try: 863 | addr_int = int(address, 16) if address.startswith("0x") else int(address) 864 | except ValueError: 865 | return f"Error: Invalid address format '{address}', expected hexadecimal (0x...) or decimal" 866 | 867 | response: Dict[str, Any] = self.communicator.send_request( 868 | "get_function_decompiled_by_address", 869 | {"address": addr_int} 870 | ) 871 | 872 | if "error" in response: 873 | return f"Error retrieving decompiled code for address '{address}': {response['error']}" 874 | 875 | decompiled_code: Any = response.get("decompiled_code") 876 | function_name: str = response.get("function_name", "Unknown function") 877 | 878 | # Detailed type checking and conversion 879 | if decompiled_code is None: 880 | return f"Error: No decompiled code returned for address '{address}'" 881 | 882 | # Ensure result is string 883 | if not isinstance(decompiled_code, str): 884 | self.logger.warning(f"Decompiled code type is not string but {type(decompiled_code).__name__}, attempting conversion") 885 | try: 886 | decompiled_code = str(decompiled_code) 887 | except Exception as e: 888 | return f"Error: Failed to convert decompiled code: {str(e)}" 889 | 890 | return f"Decompiled code for function '{function_name}' at address {address}:\n{decompiled_code}" 891 | except Exception as e: 892 | self.logger.error(f"Error getting function decompiled code by address: {str(e)}", exc_info=True) 893 | return f"Error retrieving decompiled code for address '{address}': {str(e)}" 894 | 895 | async def serve() -> None: 896 | """MCP server main entry point""" 897 | logger: logging.Logger = logging.getLogger(__name__) 898 | # Set log level to DEBUG for detailed information 899 | logger.setLevel(logging.DEBUG) 900 | server: Server = Server("mcp-ida") 901 | 902 | # Create communicator and attempt connection 903 | ida_communicator: IDAProCommunicator = IDAProCommunicator() 904 | 
logger.info("Attempting to connect to IDA Pro plugin...") 905 | 906 | if ida_communicator.connect(): 907 | logger.info("Successfully connected to IDA Pro plugin") 908 | else: 909 | logger.warning("Initial connection to IDA Pro plugin failed, will retry on request") 910 | 911 | # Create IDA functions class with persistent connection 912 | ida_functions: IDAProFunctions = IDAProFunctions(ida_communicator) 913 | 914 | @server.list_tools() 915 | async def list_tools() -> List[Tool]: 916 | """List supported tools""" 917 | return [ 918 | Tool( 919 | name=IDATools.GET_FUNCTION_ASSEMBLY_BY_NAME, 920 | description="Get assembly code for a function by name", 921 | inputSchema=GetFunctionAssemblyByName.schema(), 922 | ), 923 | Tool( 924 | name=IDATools.GET_FUNCTION_ASSEMBLY_BY_ADDRESS, 925 | description="Get assembly code for a function by address", 926 | inputSchema=GetFunctionAssemblyByAddress.schema(), 927 | ), 928 | Tool( 929 | name=IDATools.GET_FUNCTION_DECOMPILED_BY_NAME, 930 | description="Get decompiled pseudocode for a function by name", 931 | inputSchema=GetFunctionDecompiledByName.schema(), 932 | ), 933 | Tool( 934 | name=IDATools.GET_FUNCTION_DECOMPILED_BY_ADDRESS, 935 | description="Get decompiled pseudocode for a function by address", 936 | inputSchema=GetFunctionDecompiledByAddress.schema(), 937 | ), 938 | Tool( 939 | name=IDATools.GET_GLOBAL_VARIABLE_BY_NAME, 940 | description="Get information about a global variable by name", 941 | inputSchema=GetGlobalVariableByName.schema(), 942 | ), 943 | Tool( 944 | name=IDATools.GET_GLOBAL_VARIABLE_BY_ADDRESS, 945 | description="Get information about a global variable by address", 946 | inputSchema=GetGlobalVariableByAddress.schema(), 947 | ), 948 | Tool( 949 | name=IDATools.GET_CURRENT_FUNCTION_ASSEMBLY, 950 | description="Get assembly code for the function at the current cursor position", 951 | inputSchema=GetCurrentFunctionAssembly.schema(), 952 | ), 953 | Tool( 954 | name=IDATools.GET_CURRENT_FUNCTION_DECOMPILED, 955 
| description="Get decompiled pseudocode for the function at the current cursor position", 956 | inputSchema=GetCurrentFunctionDecompiled.schema(), 957 | ), 958 | Tool( 959 | name=IDATools.RENAME_LOCAL_VARIABLE, 960 | description="Rename a local variable within a function in the IDA database", 961 | inputSchema=RenameLocalVariable.schema(), 962 | ), 963 | Tool( 964 | name=IDATools.RENAME_GLOBAL_VARIABLE, 965 | description="Rename a global variable in the IDA database", 966 | inputSchema=RenameGlobalVariable.schema(), 967 | ), 968 | Tool( 969 | name=IDATools.RENAME_FUNCTION, 970 | description="Rename a function in the IDA database", 971 | inputSchema=RenameFunction.schema(), 972 | ), 973 | Tool( 974 | name=IDATools.RENAME_MULTI_LOCAL_VARIABLES, 975 | description="Rename multiple local variables within a function at once in the IDA database", 976 | inputSchema=RenameMultiLocalVariables.schema(), 977 | ), 978 | Tool( 979 | name=IDATools.RENAME_MULTI_GLOBAL_VARIABLES, 980 | description="Rename multiple global variables at once in the IDA database", 981 | inputSchema=RenameMultiGlobalVariables.schema(), 982 | ), 983 | Tool( 984 | name=IDATools.RENAME_MULTI_FUNCTIONS, 985 | description="Rename multiple functions at once in the IDA database", 986 | inputSchema=RenameMultiFunctions.schema(), 987 | ), 988 | Tool( 989 | name=IDATools.ADD_ASSEMBLY_COMMENT, 990 | description="Add a comment at a specific address in the assembly view of the IDA database", 991 | inputSchema=AddAssemblyComment.schema(), 992 | ), 993 | Tool( 994 | name=IDATools.ADD_FUNCTION_COMMENT, 995 | description="Add a comment to a function in the IDA database", 996 | inputSchema=AddFunctionComment.schema(), 997 | ), 998 | Tool( 999 | name=IDATools.ADD_PSEUDOCODE_COMMENT, 1000 | description="Add a comment to a specific address in the function's decompiled pseudocode", 1001 | inputSchema=AddPseudocodeComment.schema(), 1002 | ), 1003 | Tool( 1004 | name=IDATools.EXECUTE_SCRIPT, 1005 | description="Execute a 
Python script in IDA Pro and return its output. The script runs in IDA's context with access to all IDA API modules.", 1006 | inputSchema=ExecuteScript.schema(), 1007 | ), 1008 | Tool( 1009 | name=IDATools.EXECUTE_SCRIPT_FROM_FILE, 1010 | description="Execute a Python script from a file path in IDA Pro and return its output. The file should be accessible from IDA's process.", 1011 | inputSchema=ExecuteScriptFromFile.schema(), 1012 | ), 1013 | ] 1014 | 1015 | @server.call_tool() 1016 | async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]: 1017 | """Call tool and handle results""" 1018 | # Ensure connection exists 1019 | if not ida_communicator.connected and not ida_communicator.ensure_connection(): 1020 | return [TextContent( 1021 | type="text", 1022 | text=f"Error: Cannot connect to IDA Pro plugin. Please ensure the plugin is running." 1023 | )] 1024 | 1025 | try: 1026 | match name: 1027 | case IDATools.GET_FUNCTION_ASSEMBLY_BY_NAME: 1028 | assembly: str = ida_functions.get_function_assembly_by_name(arguments["function_name"]) 1029 | return [TextContent( 1030 | type="text", 1031 | text=assembly 1032 | )] 1033 | 1034 | case IDATools.GET_FUNCTION_ASSEMBLY_BY_ADDRESS: 1035 | assembly: str = ida_functions.get_function_assembly_by_address(arguments["address"]) 1036 | return [TextContent( 1037 | type="text", 1038 | text=assembly 1039 | )] 1040 | 1041 | case IDATools.GET_FUNCTION_DECOMPILED_BY_NAME: 1042 | decompiled: str = ida_functions.get_function_decompiled_by_name(arguments["function_name"]) 1043 | return [TextContent( 1044 | type="text", 1045 | text=decompiled 1046 | )] 1047 | 1048 | case IDATools.GET_FUNCTION_DECOMPILED_BY_ADDRESS: 1049 | decompiled: str = ida_functions.get_function_decompiled_by_address(arguments["address"]) 1050 | return [TextContent( 1051 | type="text", 1052 | text=decompiled 1053 | )] 1054 | 1055 | case IDATools.GET_GLOBAL_VARIABLE_BY_NAME: 1056 | variable_info: str = 
ida_functions.get_global_variable_by_name(arguments["variable_name"]) 1057 | return [TextContent( 1058 | type="text", 1059 | text=variable_info 1060 | )] 1061 | 1062 | case IDATools.GET_GLOBAL_VARIABLE_BY_ADDRESS: 1063 | variable_info: str = ida_functions.get_global_variable_by_address(arguments["address"]) 1064 | return [TextContent( 1065 | type="text", 1066 | text=variable_info 1067 | )] 1068 | 1069 | case IDATools.GET_CURRENT_FUNCTION_ASSEMBLY: 1070 | assembly: str = ida_functions.get_current_function_assembly() 1071 | return [TextContent( 1072 | type="text", 1073 | text=assembly 1074 | )] 1075 | 1076 | case IDATools.GET_CURRENT_FUNCTION_DECOMPILED: 1077 | decompiled: str = ida_functions.get_current_function_decompiled() 1078 | return [TextContent( 1079 | type="text", 1080 | text=decompiled 1081 | )] 1082 | 1083 | case IDATools.RENAME_LOCAL_VARIABLE: 1084 | result: str = ida_functions.rename_local_variable( 1085 | arguments["function_name"], 1086 | arguments["old_name"], 1087 | arguments["new_name"] 1088 | ) 1089 | return [TextContent( 1090 | type="text", 1091 | text=result 1092 | )] 1093 | 1094 | case IDATools.RENAME_GLOBAL_VARIABLE: 1095 | result: str = ida_functions.rename_global_variable( 1096 | arguments["old_name"], 1097 | arguments["new_name"] 1098 | ) 1099 | return [TextContent( 1100 | type="text", 1101 | text=result 1102 | )] 1103 | 1104 | case IDATools.RENAME_FUNCTION: 1105 | result: str = ida_functions.rename_function( 1106 | arguments["old_name"], 1107 | arguments["new_name"] 1108 | ) 1109 | return [TextContent( 1110 | type="text", 1111 | text=result 1112 | )] 1113 | 1114 | case IDATools.RENAME_MULTI_LOCAL_VARIABLES: 1115 | result: str = ida_functions.rename_multi_local_variables( 1116 | arguments["function_name"], 1117 | arguments["rename_pairs_old2new"] 1118 | ) 1119 | return [TextContent( 1120 | type="text", 1121 | text=result 1122 | )] 1123 | 1124 | case IDATools.RENAME_MULTI_GLOBAL_VARIABLES: 1125 | result: str = 
ida_functions.rename_multi_global_variables( 1126 | arguments["rename_pairs_old2new"] 1127 | ) 1128 | return [TextContent( 1129 | type="text", 1130 | text=result 1131 | )] 1132 | 1133 | case IDATools.RENAME_MULTI_FUNCTIONS: 1134 | result: str = ida_functions.rename_multi_functions( 1135 | arguments["rename_pairs_old2new"] 1136 | ) 1137 | return [TextContent( 1138 | type="text", 1139 | text=result 1140 | )] 1141 | 1142 | case IDATools.ADD_ASSEMBLY_COMMENT: 1143 | result: str = ida_functions.add_assembly_comment( 1144 | arguments["address"], 1145 | arguments["comment"], 1146 | arguments.get("is_repeatable", False) 1147 | ) 1148 | return [TextContent( 1149 | type="text", 1150 | text=result 1151 | )] 1152 | 1153 | case IDATools.ADD_FUNCTION_COMMENT: 1154 | result: str = ida_functions.add_function_comment( 1155 | arguments["function_name"], 1156 | arguments["comment"], 1157 | arguments.get("is_repeatable", False) 1158 | ) 1159 | return [TextContent( 1160 | type="text", 1161 | text=result 1162 | )] 1163 | 1164 | case IDATools.ADD_PSEUDOCODE_COMMENT: 1165 | result: str = ida_functions.add_pseudocode_comment( 1166 | arguments["function_name"], 1167 | arguments["address"], 1168 | arguments["comment"], 1169 | arguments.get("is_repeatable", False) 1170 | ) 1171 | return [TextContent( 1172 | type="text", 1173 | text=result 1174 | )] 1175 | 1176 | case IDATools.EXECUTE_SCRIPT: 1177 | try: 1178 | if "script" not in arguments or not arguments["script"]: 1179 | return [TextContent( 1180 | type="text", 1181 | text="Error: No script content provided" 1182 | )] 1183 | 1184 | result: str = ida_functions.execute_script(arguments["script"]) 1185 | return [TextContent( 1186 | type="text", 1187 | text=result 1188 | )] 1189 | except Exception as e: 1190 | logger.error(f"Error executing script: {str(e)}", exc_info=True) 1191 | return [TextContent( 1192 | type="text", 1193 | text=f"Error executing script: {str(e)}" 1194 | )] 1195 | 1196 | case IDATools.EXECUTE_SCRIPT_FROM_FILE: 1197 | try: 
1198 | if "file_path" not in arguments or not arguments["file_path"]: 1199 | return [TextContent( 1200 | type="text", 1201 | text="Error: No file path provided" 1202 | )] 1203 | 1204 | result: str = ida_functions.execute_script_from_file(arguments["file_path"]) 1205 | return [TextContent( 1206 | type="text", 1207 | text=result 1208 | )] 1209 | except Exception as e: 1210 | logger.error(f"Error executing script from file: {str(e)}", exc_info=True) 1211 | return [TextContent( 1212 | type="text", 1213 | text=f"Error executing script from file: {str(e)}" 1214 | )] 1215 | 1216 | case _: 1217 | raise ValueError(f"Unknown tool: {name}") 1218 | except Exception as e: 1219 | logger.error(f"Error calling tool: {str(e)}", exc_info=True) 1220 | return [TextContent( 1221 | type="text", 1222 | text=f"Error executing {name}: {str(e)}" 1223 | )] 1224 | 1225 | options = server.create_initialization_options() 1226 | async with stdio_server() as (read_stream, write_stream): 1227 | await server.run(read_stream, write_stream, options, raise_exceptions=True) 1228 | ```