# Directory Structure ``` ├── .gitignore ├── LICENSE ├── plugin │ ├── ida_mcp_server_plugin │ │ ├── __init__.py │ │ └── ida_mcp_core.py │ └── ida_mcp_server_plugin.py ├── pyproject.toml ├── README.md ├── Screenshots │ ├── iShot_2025-03-15_18.54.53.png │ ├── iShot_2025-03-15_19.04.06.png │ └── iShot_2025-03-15_19.06.27.png ├── src │ └── mcp_server_ida │ ├── __init__.py │ ├── __main__.py │ └── server.py ├── test │ └── IDA Debug.py └── uv.lock ``` # Files -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- ``` .DS_Store __pycache__ ida_mcp_server.egg-info alternatives ``` -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- ```markdown # IDA MCP Server > [!NOTE] > The idalib mode is under development, and it will not require installing the IDA plugin or running IDA (idalib is available from IDA Pro 9.0+). ## Overview A Model Context Protocol server for IDA interaction and automation. This server provides tools to read IDA database via Large Language Models. Please note that mcp-server-ida is currently in early development. The functionality and available tools are subject to change and expansion as we continue to develop and improve the server. ## Installation ### Using uv (recommended) When using [`uv`](https://docs.astral.sh/uv/) no specific installation is needed. We will use [`uvx`](https://docs.astral.sh/uv/guides/tools/) to directly run *mcp-server-ida*. ### Using PIP Alternatively you can install `mcp-server-ida` via pip: ``` pip install mcp-server-ida ``` After installation, you can run it as a script using: ``` python -m mcp_server_ida ``` ### IDA-Side Copy `repository/plugin/ida_mcp_server_plugin.py` and `repository/plugin/ida_mcp_server_plugin` directory into IDAs plugin directory Windows: `%APPDATA%\Hex-Rays\IDA Pro\plugins` Linux/macOS: `$HOME/.idapro/plugins` eg: `~/.idapro/plugins` [igors-tip-of-the-week-103-sharing-plugins-between-ida-installs](https://hex-rays.com/blog/igors-tip-of-the-week-103-sharing-plugins-between-ida-installs) ## Configuration ### Usage with Claude Desktop Add this to your `claude_desktop_config.json`: <details> <summary>Using uvx</summary> ```json "mcpServers": { "ida": { "command": "uvx", "args": [ "mcp-server-ida" ] } } ``` </details> <details> <summary>Using pip installation</summary> ```json "mcpServers": { "ida": { "command": "python", "args": [ "-m", "mcp_server_ida" ] } } ``` </details> ## Debugging You can use the MCP inspector to debug the server. For uvx installations: ``` npx @modelcontextprotocol/inspector uvx mcp-server-ida ``` Or if you've installed the package in a specific directory or are developing on it: ``` cd path/to/mcp-server-ida/src npx @modelcontextprotocol/inspector uv run mcp-server-ida ``` Running `tail -n 20 -f ~/Library/Logs/Claude/mcp*.log` will show the logs from the server and may help you debug any issues. ## Development If you are doing local development, there are two ways to test your changes: 1. Run the MCP inspector to test your changes. See [Debugging](#debugging) for run instructions. 2. Test using the Claude desktop app. Add the following to your `claude_desktop_config.json`: ### UVX ```json { "mcpServers": { "ida": { "command": "uv", "args": [ "--directory", "/<path to mcp-server-ida>", "run", "mcp-server-ida" ] } } ``` ## Alternatives [ida-pro-mcp](https://github.com/mrexodia/ida-pro-mcp) [ida-mcp-server-plugin](https://github.com/taida957789/ida-mcp-server-plugin) [mcp-server-idapro](https://github.com/fdrechsler/mcp-server-idapro) [pcm](https://github.com/rand-tech/pcm) ## Screenshots    ``` -------------------------------------------------------------------------------- /plugin/ida_mcp_server_plugin/__init__.py: -------------------------------------------------------------------------------- ```python ``` -------------------------------------------------------------------------------- /src/mcp_server_ida/__main__.py: -------------------------------------------------------------------------------- ```python from mcp_server_ida import main main() ``` -------------------------------------------------------------------------------- /src/mcp_server_ida/__init__.py: -------------------------------------------------------------------------------- ```python import click import logging import sys from .server import serve @click.command() @click.option("-v", "--verbose", count=True) def main(verbose: bool) -> None: """MCP IDA Server - IDA functionality for MCP""" import asyncio logging_level = logging.WARN if verbose == 1: logging_level = logging.INFO elif verbose >= 2: logging_level = logging.DEBUG logging.basicConfig(level=logging_level, stream=sys.stderr) asyncio.run(serve()) if __name__ == "__main__": main() ``` -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- ```toml [project] name = "mcp-server-ida" version = "0.3.0" description = "A Model Context Protocol server providing tools to read, search IDA Database programmatically via LLMs" # readme = "README.md" requires-python = ">=3.10" authors = [{ name = "Mx-Iris" }] keywords = ["ida", "mcp", "llm", "automation"] license = { text = "MIT" } classifiers = [ "Development Status :: 4 - Beta", "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10", ] dependencies = [ "click>=8.1.7", "mcp>=1.0.0", "pydantic>=2.0.0", ] [project.scripts] mcp-server-ida = "mcp_server_ida:main" [build-system] requires = ["hatchling"] build-backend = "hatchling.build" [tool.uv] dev-dependencies = ["pyright>=1.1.389", "ruff>=0.7.3", "pytest>=8.0.0"] ``` -------------------------------------------------------------------------------- /plugin/ida_mcp_server_plugin.py: -------------------------------------------------------------------------------- ```python import idaapi import json import socket import struct import threading import traceback import time from typing import Optional, Dict, Any, List, Tuple, Union, Set, Type, cast from ida_mcp_server_plugin.ida_mcp_core import IDAMCPCore PLUGIN_NAME = "IDA MCP Server" PLUGIN_HOTKEY = "Ctrl-Alt-M" PLUGIN_VERSION = "1.0" PLUGIN_AUTHOR = "IDA MCP" # Default configuration DEFAULT_HOST = "localhost" DEFAULT_PORT = 5000 class IDACommunicator: """IDA Communication class""" def __init__(self, host: str = DEFAULT_HOST, port: int = DEFAULT_PORT): self.host: str = host self.port: int = port self.socket: Optional[socket.socket] = None def connect(self) -> None: pass class IDAMCPServer: def __init__(self, host: str = DEFAULT_HOST, port: int = DEFAULT_PORT): self.host: str = host self.port: int = port self.server_socket: Optional[socket.socket] = None self.running: bool = False self.thread: Optional[threading.Thread] = None self.client_counter: int = 0 self.core: IDAMCPCore = IDAMCPCore() def start(self) -> bool: """Start Socket server""" if self.running: print("MCP Server already running") return False try: self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.server_socket.bind((self.host, self.port)) self.server_socket.listen(5) # self.server_socket.settimeout(1.0) # Set timeout to allow server to respond to stop requests self.running = True self.thread = threading.Thread(target=self.server_loop) self.thread.daemon = True self.thread.start() print(f"MCP Server started on {self.host}:{self.port}") return True except Exception as e: print(f"Failed to start MCP Server: {str(e)}") traceback.print_exc() return False def stop(self) -> None: """Stop Socket server""" if not self.running: print("MCP Server is not running, no need to stop") return print("Stopping MCP Server...") self.running = False if self.server_socket: try: self.server_socket.close() except Exception as e: print(f"Error closing server socket: {str(e)}") self.server_socket = None if self.thread: try: self.thread.join(2.0) # Wait for thread to end, maximum 2 seconds except Exception as e: print(f"Error joining server thread: {str(e)}") self.thread = None print("MCP Server stopped") def send_message(self, client_socket: socket.socket, data: bytes) -> None: """Send message with length prefix""" length: int = len(data) length_bytes: bytes = struct.pack('!I', length) # 4-byte length prefix client_socket.sendall(length_bytes + data) def receive_message(self, client_socket: socket.socket) -> bytes: """Receive message with length prefix""" # Receive 4-byte length prefix length_bytes: bytes = self.receive_exactly(client_socket, 4) if not length_bytes: raise ConnectionError("Connection closed") length: int = struct.unpack('!I', length_bytes)[0] # Receive message body data: bytes = self.receive_exactly(client_socket, length) return data def receive_exactly(self, client_socket: socket.socket, n: int) -> bytes: """Receive exactly n bytes of data""" data: bytes = b'' while len(data) < n: chunk: bytes = client_socket.recv(min(n - len(data), 4096)) if not chunk: # Connection closed raise ConnectionError("Connection closed, unable to receive complete data") data += chunk return data def server_loop(self) -> None: """Server main loop""" print("Server loop started") while self.running: try: # Use timeout receive to periodically check running flag try: client_socket, client_address = self.server_socket.accept() self.client_counter += 1 client_id: int = self.client_counter print(f"Client #{client_id} connected from {client_address}") # Handle client request - use thread to support multiple clients client_thread: threading.Thread = threading.Thread( target=self.handle_client, args=(client_socket, client_id) ) client_thread.daemon = True client_thread.start() except socket.timeout: # Timeout is just for periodically checking running flag continue except OSError as e: if self.running: # Only print error if server is running if e.errno == 9: # Bad file descriptor, usually means socket is closed print("Server socket was closed") break print(f"Socket error: {str(e)}") except Exception as e: if self.running: # Only print error if server is running print(f"Error accepting connection: {str(e)}") traceback.print_exc() except Exception as e: if self.running: print(f"Error in server loop: {str(e)}") traceback.print_exc() time.sleep(1) # Avoid high CPU usage print("Server loop ended") def handle_client(self, client_socket: socket.socket, client_id: int) -> None: """Handle client requests""" try: # Set timeout client_socket.settimeout(30) while self.running: try: # Receive message data: bytes = self.receive_message(client_socket) # Parse request request: Dict[str, Any] = json.loads(data.decode('utf-8')) request_type: str = request.get('type') request_data: Dict[str, Any] = request.get('data', {}) request_id: str = request.get('id', 'unknown') request_count: int = request.get('count', -1) print(f"Client #{client_id} request: {request_type}, ID: {request_id}, Count: {request_count}") # Handle different types of requests response: Dict[str, Any] = { "id": request_id, # Return same request ID "count": request_count # Return same request count } if request_type == "get_function_assembly_by_name": response.update(self.core.get_function_assembly_by_name(request_data.get("function_name", ""))) elif request_type == "get_function_assembly_by_address": response.update(self.core.get_function_assembly_by_address(request_data.get("address", 0))) elif request_type == "get_function_decompiled_by_name": response.update(self.core.get_function_decompiled_by_name(request_data.get("function_name", ""))) elif request_type == "get_function_decompiled_by_address": response.update(self.core.get_function_decompiled_by_address(request_data.get("address", 0))) elif request_type == "get_global_variable_by_name": response.update(self.core.get_global_variable_by_name(request_data.get("variable_name", ""))) elif request_type == "get_global_variable_by_address": response.update(self.core.get_global_variable_by_address(request_data.get("address", 0))) elif request_type == "get_current_function_assembly": response.update(self.core.get_current_function_assembly()) elif request_type == "get_current_function_decompiled": response.update(self.core.get_current_function_decompiled()) elif request_type == "rename_global_variable": response.update(self.core.rename_global_variable( request_data.get("old_name", ""), request_data.get("new_name", "") )) elif request_type == "rename_function": response.update(self.core.rename_function( request_data.get("old_name", ""), request_data.get("new_name", "") )) # Backward compatibility with old method names elif request_type == "get_function_assembly": response.update(self.core.get_function_assembly_by_name(request_data.get("function_name", ""))) elif request_type == "get_function_decompiled": response.update(self.core.get_function_decompiled_by_name(request_data.get("function_name", ""))) elif request_type == "get_global_variable": response.update(self.core.get_global_variable_by_name(request_data.get("variable_name", ""))) elif request_type == "add_assembly_comment": response.update(self.core.add_assembly_comment( request_data.get("address", ""), request_data.get("comment", ""), request_data.get("is_repeatable", False) )) elif request_type == "rename_local_variable": response.update(self.core.rename_local_variable( request_data.get("function_name", ""), request_data.get("old_name", ""), request_data.get("new_name", "") )) elif request_type == "add_function_comment": response.update(self.core.add_function_comment( request_data.get("function_name", ""), request_data.get("comment", ""), request_data.get("is_repeatable", False) )) elif request_type == "ping": response["status"] = "pong" elif request_type == "add_pseudocode_comment": response.update(self.core.add_pseudocode_comment( request_data.get("function_name", ""), request_data.get("address", ""), request_data.get("comment", ""), request_data.get("is_repeatable", False) )) elif request_type == "execute_script": response.update(self.core.execute_script( request_data.get("script", "") )) elif request_type == "execute_script_from_file": response.update(self.core.execute_script_from_file( request_data.get("file_path", "") )) elif request_type == "refresh_view": response.update(self.core.refresh_view()) elif request_type == "rename_multi_local_variables": response.update(self.core.rename_multi_local_variables( request_data.get("function_name", ""), request_data.get("rename_pairs_old2new", []) )) elif request_type == "rename_multi_global_variables": response.update(self.core.rename_multi_global_variables( request_data.get("rename_pairs_old2new", []) )) elif request_type == "rename_multi_functions": response.update(self.core.rename_multi_functions( request_data.get("rename_pairs_old2new", []) )) else: response["error"] = f"Unknown request type: {request_type}" # Verify response is correct if not isinstance(response, dict): print(f"Response is not a dictionary: {type(response).__name__}") response = { "id": request_id, "count": request_count, "error": f"Internal server error: response is not a dictionary but {type(response).__name__}" } # Ensure all values in response are serializable for key, value in list(response.items()): if value is None: response[key] = "null" elif not isinstance(value, (str, int, float, bool, list, dict, tuple)): print(f"Response key '{key}' has non-serializable type: {type(value).__name__}") response[key] = str(value) # Send response response_json: bytes = json.dumps(response).encode('utf-8') self.send_message(client_socket, response_json) print(f"Sent response to client #{client_id}, ID: {request_id}, Count: {request_count}") except ConnectionError as e: print(f"Connection with client #{client_id} lost: {str(e)}") return except socket.timeout: # print(f"Socket timeout with client #{client_id}") continue except json.JSONDecodeError as e: print(f"Invalid JSON request from client #{client_id}: {str(e)}") try: response: Dict[str, Any] = { "error": f"Invalid JSON request: {str(e)}" } self.send_message(client_socket, json.dumps(response).encode('utf-8')) except: print(f"Failed to send error response to client #{client_id}") except Exception as e: print(f"Error processing request from client #{client_id}: {str(e)}") traceback.print_exc() try: response: Dict[str, Any] = { "error": str(e) } self.send_message(client_socket, json.dumps(response).encode('utf-8')) except: print(f"Failed to send error response to client #{client_id}") except Exception as e: print(f"Error handling client #{client_id}: {str(e)}") traceback.print_exc() finally: try: client_socket.close() except: pass print(f"Client #{client_id} connection closed") # IDA Plugin class class IDAMCPPlugin(idaapi.plugin_t): flags = idaapi.PLUGIN_KEEP comment = "IDA MCP Server Plugin" help = "Provides MCP server functionality for IDA" wanted_name = PLUGIN_NAME wanted_hotkey = PLUGIN_HOTKEY def __init__(self): super(IDAMCPPlugin, self).__init__() self.server: Optional[IDAMCPServer] = None self.initialized: bool = False self.menu_items_added: bool = False print(f"IDAMCPPlugin instance created") def init(self) -> int: """Plugin initialization""" try: print(f"{PLUGIN_NAME} v{PLUGIN_VERSION} by {PLUGIN_AUTHOR}") print("Initializing plugin...") # Add menu items if not self.menu_items_added: self.create_menu_items() self.menu_items_added = True print("Menu items added") # Mark as initialized self.initialized = True print("Plugin initialized successfully") # Delay server start to avoid initialization issues idaapi.register_timer(500, self._delayed_server_start) return idaapi.PLUGIN_KEEP except Exception as e: print(f"Error initializing plugin: {str(e)}") traceback.print_exc() return idaapi.PLUGIN_SKIP def _delayed_server_start(self) -> int: """Delayed server start to avoid initialization race conditions""" try: if not self.server or not self.server.running: print("Delayed server start...") self.start_server() except Exception as e: print(f"Error in delayed server start: {str(e)}") traceback.print_exc() return -1 # Don't repeat def create_menu_items(self) -> None: """Create plugin menu items""" # Create menu items menu_path: str = "Edit/Plugins/" class StartServerHandler(idaapi.action_handler_t): def __init__(self, plugin: 'IDAMCPPlugin'): idaapi.action_handler_t.__init__(self) self.plugin: 'IDAMCPPlugin' = plugin def activate(self, ctx) -> int: self.plugin.start_server() return 1 def update(self, ctx) -> int: return idaapi.AST_ENABLE_ALWAYS class StopServerHandler(idaapi.action_handler_t): def __init__(self, plugin: 'IDAMCPPlugin'): idaapi.action_handler_t.__init__(self) self.plugin: 'IDAMCPPlugin' = plugin def activate(self, ctx) -> int: self.plugin.stop_server() return 1 def update(self, ctx) -> int: return idaapi.AST_ENABLE_ALWAYS try: # Register and add start server action start_action_name: str = "mcp:start_server" start_action_desc: idaapi.action_desc_t = idaapi.action_desc_t( start_action_name, "Start MCP Server", StartServerHandler(self), "Ctrl+Alt+S", "Start the MCP Server", 199 # Icon ID ) # Register and add stop server action stop_action_name: str = "mcp:stop_server" stop_action_desc: idaapi.action_desc_t = idaapi.action_desc_t( stop_action_name, "Stop MCP Server", StopServerHandler(self), "Ctrl+Alt+X", "Stop the MCP Server", 200 # Icon ID ) # Register actions if not idaapi.register_action(start_action_desc): print("Failed to register start server action") if not idaapi.register_action(stop_action_desc): print("Failed to register stop server action") # Add to menu if not idaapi.attach_action_to_menu(menu_path + "Start MCP Server", start_action_name, idaapi.SETMENU_APP): print("Failed to attach start server action to menu") if not idaapi.attach_action_to_menu(menu_path + "Stop MCP Server", stop_action_name, idaapi.SETMENU_APP): print("Failed to attach stop server action to menu") print("Menu items created successfully") except Exception as e: print(f"Error creating menu items: {str(e)}") traceback.print_exc() def start_server(self) -> None: """Start server""" if self.server and self.server.running: print("MCP Server is already running") return try: print("Creating MCP Server instance...") self.server = IDAMCPServer() print("Starting MCP Server...") if self.server.start(): print("MCP Server started successfully") else: print("Failed to start MCP Server") except Exception as e: print(f"Error starting server: {str(e)}") traceback.print_exc() def stop_server(self) -> None: """Stop server""" if not self.server: print("MCP Server instance does not exist") return if not self.server.running: print("MCP Server is not running") return try: self.server.stop() print("MCP Server stopped by user") except Exception as e: print(f"Error stopping server: {str(e)}") traceback.print_exc() def run(self, arg) -> None: """Execute when hotkey is pressed""" if not self.initialized: print("Plugin not initialized") return # Automatically start or stop server when hotkey is triggered try: if not self.server or not self.server.running: print("Hotkey triggered: starting server") self.start_server() else: print("Hotkey triggered: stopping server") self.stop_server() except Exception as e: print(f"Error in run method: {str(e)}") traceback.print_exc() def term(self) -> None: """Plugin termination""" try: if self.server and self.server.running: print("Terminating plugin: stopping server") self.server.stop() print(f"{PLUGIN_NAME} terminated") except Exception as e: print(f"Error terminating plugin: {str(e)}") traceback.print_exc() # Register plugin def PLUGIN_ENTRY() -> IDAMCPPlugin: return IDAMCPPlugin() ``` -------------------------------------------------------------------------------- /plugin/ida_mcp_server_plugin/ida_mcp_core.py: -------------------------------------------------------------------------------- ```python import idaapi import idautils import ida_funcs import ida_hexrays import ida_bytes import ida_name import ida_segment import ida_lines import idc import json import traceback import functools import queue from typing import Any, Callable, TypeVar, Optional, Dict, List, Union, Tuple, Type # Type variable for function return type T = TypeVar('T') class IDASyncError(Exception): """Exception raised for IDA synchronization errors""" pass # Global call stack to track synchronization calls call_stack: queue.LifoQueue[str] = queue.LifoQueue() def sync_wrapper(func: Callable[..., T], sync_type: int) -> T: """ Wrapper function to execute a function in IDA's main thread Args: func: The function to execute sync_type: Synchronization type (MFF_READ or MFF_WRITE) Returns: The result of the function execution """ if sync_type not in [idaapi.MFF_READ, idaapi.MFF_WRITE]: error_str = f'Invalid sync type {sync_type} for function {func.__name__}' print(error_str) raise IDASyncError(error_str) # Container for the result result_container: queue.Queue[Any] = queue.Queue() def execute_in_main_thread() -> int: # Check if we're already inside a sync_wrapper call if not call_stack.empty(): last_func = call_stack.get() error_str = f'Nested sync call detected: function {func.__name__} called from {last_func}' print(error_str) call_stack.put(last_func) # Put it back raise IDASyncError(error_str) # Add function to call stack call_stack.put(func.__name__) try: # Execute function and store result result_container.put(func()) except Exception as e: print(f"Error in {func.__name__}: {str(e)}") traceback.print_exc() result_container.put(None) finally: # Always remove function from call stack call_stack.get() return 1 # Required by execute_sync # Execute in IDA's main thread idaapi.execute_sync(execute_in_main_thread, sync_type) # Return the result return result_container.get() def idaread(func: Callable[..., T]) -> Callable[..., T]: """ Decorator for functions that read from the IDA database Args: func: The function to decorate Returns: Decorated function that executes in IDA's main thread with read access """ @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> T: # Create a partial function with the arguments partial_func = functools.partial(func, *args, **kwargs) # Preserve the original function name partial_func.__name__ = func.__name__ # Execute with sync_wrapper return sync_wrapper(partial_func, idaapi.MFF_READ) return wrapper def idawrite(func: Callable[..., T]) -> Callable[..., T]: """ Decorator for functions that write to the IDA database Args: func: The function to decorate Returns: Decorated function that executes in IDA's main thread with write access """ @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> T: # Create a partial function with the arguments partial_func = functools.partial(func, *args, **kwargs) # Preserve the original function name partial_func.__name__ = func.__name__ # Execute with sync_wrapper return sync_wrapper(partial_func, idaapi.MFF_WRITE) return wrapper class IDAMCPCore: """Core functionality implementation class for IDA MCP""" @idaread def get_function_assembly_by_name(self, function_name: str) -> Dict[str, Any]: """Get assembly code for a function by its name""" try: # Get function address from name func = idaapi.get_func(idaapi.get_name_ea(0, function_name)) if not func: return {"error": f"Function '{function_name}' not found"} # Call address-based implementation result = self._get_function_assembly_by_address_internal(func.start_ea) # If successful, add function name to result if "error" not in result: result["function_name"] = function_name return result except Exception as e: traceback.print_exc() return {"error": str(e)} @idaread def get_function_assembly_by_address(self, address: int) -> Dict[str, Any]: """Get assembly code for a function by its address""" return self._get_function_assembly_by_address_internal(address) def _get_function_assembly_by_address_internal(self, address: int) -> Dict[str, Any]: """Internal implementation for get_function_assembly_by_address without sync wrapper""" try: # Get function object func = ida_funcs.get_func(address) # Get function name func_name = idaapi.get_func_name(func.start_ea) if not func: return {"error": f"Invalid function at {hex(address)}"} # Collect all assembly instructions assembly_lines = [] for instr_addr in idautils.FuncItems(address): disasm = idc.GetDisasm(instr_addr) assembly_lines.append(f"{hex(instr_addr)}: {disasm}") if not assembly_lines: return {"error": "No assembly instructions found"} return {"assembly": "\n".join(assembly_lines), "function_name": func_name} except Exception as e: print(f"Error getting function assembly: {str(e)}") traceback.print_exc() return {"error": str(e)} @idaread def get_function_decompiled_by_name(self, function_name: str) -> Dict[str, Any]: """Get decompiled code for a function by its name""" try: # Get function address from name func_addr = idaapi.get_name_ea(0, function_name) if func_addr == idaapi.BADADDR: return {"error": f"Function '{function_name}' not found"} # Call internal implementation without decorator result = self._get_function_decompiled_by_address_internal(func_addr) # If successful, add function name to result if "error" not in result: result["function_name"] = function_name return result except Exception as e: traceback.print_exc() return {"error": str(e)} @idaread def get_function_decompiled_by_address(self, address: int) -> Dict[str, Any]: """Get decompiled code for a function by its address""" return self._get_function_decompiled_by_address_internal(address) def _get_function_decompiled_by_address_internal(self, address: int) -> Dict[str, Any]: """Internal implementation for get_function_decompiled_by_address without sync wrapper""" try: # Get function from address func = idaapi.get_func(address) if not func: return {"error": f"No function found at address 0x{address:X}"} # Get function name func_name = idaapi.get_func_name(func.start_ea) # Try to import decompiler module try: import ida_hexrays except ImportError: return {"error": "Hex-Rays decompiler is not available"} # Check if decompiler is available if not ida_hexrays.init_hexrays_plugin(): return {"error": "Unable to initialize Hex-Rays decompiler"} # Get decompiled function cfunc = None try: cfunc = ida_hexrays.decompile(func.start_ea) except Exception as e: return {"error": f"Unable to decompile function: {str(e)}"} if not cfunc: return {"error": "Decompilation failed"} # Get pseudocode as string decompiled_code = str(cfunc) return {"decompiled_code": decompiled_code, "function_name": func_name} except Exception as e: traceback.print_exc() return {"error": str(e)} @idaread def get_current_function_assembly(self) -> Dict[str, Any]: """Get assembly code for the function at the current cursor position""" try: # Get current address curr_addr = idaapi.get_screen_ea() if curr_addr == idaapi.BADADDR: return {"error": "No valid cursor position"} # Use the internal implementation without decorator return self._get_function_assembly_by_address_internal(curr_addr) except Exception as e: traceback.print_exc() return {"error": str(e)} @idaread def get_current_function_decompiled(self) -> Dict[str, Any]: """Get decompiled code for the function at the current cursor position""" try: # Get current address curr_addr = idaapi.get_screen_ea() if curr_addr == idaapi.BADADDR: return {"error": "No valid cursor position"} # Use the internal implementation without decorator return self._get_function_decompiled_by_address_internal(curr_addr) except Exception as e: traceback.print_exc() return {"error": str(e)} @idaread def get_global_variable_by_name(self, variable_name: str) -> Dict[str, Any]: """Get global variable information by its name""" try: # Get variable address var_addr: int = ida_name.get_name_ea(0, variable_name) if var_addr == idaapi.BADADDR: return {"error": f"Global variable '{variable_name}' not found"} # Call internal implementation result = self._get_global_variable_by_address_internal(var_addr) # If successful, add variable name to result if "error" not in result and "variable_info" in result: # Parse the JSON string back to dict to modify it var_info = json.loads(result["variable_info"]) var_info["name"] = variable_name # Convert back to JSON string result["variable_info"] = json.dumps(var_info, indent=2) return result except Exception as e: print(f"Error getting global variable by name: {str(e)}") traceback.print_exc() return {"error": str(e)} @idaread def get_global_variable_by_address(self, address: int) -> Dict[str, Any]: """Get global variable information by its address""" return self._get_global_variable_by_address_internal(address) def _get_global_variable_by_address_internal(self, address: int) -> Dict[str, Any]: """Internal implementation for get_global_variable_by_address without sync wrapper""" try: # Verify address is valid if address == idaapi.BADADDR: return {"error": f"Invalid address: {hex(address)}"} # Get variable name if available variable_name = ida_name.get_name(address) if not variable_name: variable_name = f"unnamed_{hex(address)}" # Get variable segment segment: Optional[ida_segment.segment_t] = ida_segment.getseg(address) if not segment: return {"error": f"No segment found for address {hex(address)}"} segment_name: str = ida_segment.get_segm_name(segment) segment_class: str = ida_segment.get_segm_class(segment) # Get variable type tinfo = idaapi.tinfo_t() guess_type: bool = idaapi.guess_tinfo(tinfo, address) type_str: str = tinfo.get_type_name() if guess_type else "unknown" # Try to get variable value size: int = ida_bytes.get_item_size(address) if size <= 0: size = 8 # Default to 8 bytes # Read data based on size value: Optional[int] = None if size == 1: value = ida_bytes.get_byte(address) elif size == 2: value = ida_bytes.get_word(address) elif size == 4: value = ida_bytes.get_dword(address) elif size == 8: value = ida_bytes.get_qword(address) # Build variable info var_info: Dict[str, Any] = { "name": variable_name, "address": hex(address), "segment": segment_name, "segment_class": segment_class, "type": type_str, "size": size, "value": hex(value) if value is not None else "N/A" } # If it's a string, try to read string content if ida_bytes.is_strlit(ida_bytes.get_flags(address)): str_value = idc.get_strlit_contents(address, -1, 0) if str_value: try: var_info["string_value"] = str_value.decode('utf-8', errors='replace') except: var_info["string_value"] = str(str_value) return {"variable_info": json.dumps(var_info, indent=2)} except Exception as e: print(f"Error getting global variable by address: {str(e)}") traceback.print_exc() return {"error": str(e)} @idawrite def rename_global_variable(self, old_name: str, new_name: str) -> Dict[str, Any]: """Rename a global variable""" return self._rename_global_variable_internal(old_name, new_name) def _rename_global_variable_internal(self, old_name: str, new_name: str) -> Dict[str, Any]: """Internal implementation for rename_global_variable without sync wrapper""" try: # Get variable address var_addr: int = ida_name.get_name_ea(0, old_name) if var_addr == idaapi.BADADDR: return {"success": False, "message": f"Variable '{old_name}' not found"} # Check if new name is already in use if ida_name.get_name_ea(0, new_name) != idaapi.BADADDR: return {"success": False, "message": f"Name '{new_name}' is already in use"} # Try to rename if not ida_name.set_name(var_addr, new_name): return {"success": False, "message": f"Failed to rename variable, possibly due to invalid name format or other IDA restrictions"} # Refresh view self._refresh_view_internal() return {"success": True, "message": f"Variable renamed from '{old_name}' to '{new_name}' at address {hex(var_addr)}"} except Exception as e: print(f"Error renaming variable: {str(e)}") traceback.print_exc() return {"success": False, "message": str(e)} @idawrite def rename_function(self, old_name: str, new_name: str) -> Dict[str, Any]: """Rename a function""" return self._rename_function_internal(old_name, new_name) def _rename_function_internal(self, old_name: str, new_name: str) -> Dict[str, Any]: """Internal implementation for rename_function without sync wrapper""" try: # Get function address func_addr: int = ida_name.get_name_ea(0, old_name) if func_addr == idaapi.BADADDR: return {"success": False, "message": f"Function '{old_name}' not found"} # Check if it's a function func: Optional[ida_funcs.func_t] = ida_funcs.get_func(func_addr) if not func: return {"success": False, "message": f"'{old_name}' is not a function"} # Check if new name is already in use if ida_name.get_name_ea(0, new_name) != idaapi.BADADDR: return {"success": False, "message": f"Name '{new_name}' is already in use"} # Try to rename if not ida_name.set_name(func_addr, new_name): return {"success": False, "message": f"Failed to rename function, possibly due to invalid name format or other IDA restrictions"} # Refresh view self._refresh_view_internal() return {"success": True, "message": f"Function renamed from '{old_name}' to '{new_name}' at address {hex(func_addr)}"} except Exception as e: print(f"Error renaming function: {str(e)}") traceback.print_exc() return {"success": False, "message": str(e)} @idawrite def add_assembly_comment(self, address: str, comment: str, is_repeatable: bool) -> Dict[str, Any]: """Add an assembly comment""" return self._add_assembly_comment_internal(address, comment, is_repeatable) def _add_assembly_comment_internal(self, address: str, comment: str, is_repeatable: bool) -> Dict[str, Any]: """Internal implementation for add_assembly_comment without sync wrapper""" try: # Convert address string to integer addr: int if isinstance(address, str): if address.startswith("0x"): addr = int(address, 16) else: try: addr = int(address, 16) # Try parsing as hex except ValueError: try: addr = int(address) # Try parsing as decimal except ValueError: return {"success": False, "message": f"Invalid address format: {address}"} else: addr = address # Check if address is valid if addr == idaapi.BADADDR or not ida_bytes.is_loaded(addr): return {"success": False, "message": f"Invalid or unloaded address: {hex(addr)}"} # Add comment result: bool = idc.set_cmt(addr, comment, is_repeatable) if result: # Refresh view self._refresh_view_internal() comment_type: str = "repeatable" if is_repeatable else "regular" return {"success": True, "message": f"Added {comment_type} assembly comment at address {hex(addr)}"} else: return {"success": False, "message": f"Failed to add assembly comment at address {hex(addr)}"} except Exception as e: print(f"Error adding assembly comment: {str(e)}") traceback.print_exc() return {"success": False, "message": str(e)} @idawrite def rename_local_variable(self, function_name: str, old_name: str, new_name: str) -> Dict[str, Any]: """Rename a local variable within a function""" return self._rename_local_variable_internal(function_name, old_name, new_name) def _rename_local_variable_internal(self, function_name: str, old_name: str, new_name: str) -> Dict[str, Any]: """Internal implementation for rename_local_variable without sync wrapper""" try: # Parameter validation if not function_name: return {"success": False, "message": "Function name cannot be empty"} if not old_name: return {"success": False, "message": "Old variable name cannot be empty"} if not new_name: return {"success": False, "message": "New variable name cannot be empty"} # Get function address func_addr: int = ida_name.get_name_ea(0, function_name) if func_addr == idaapi.BADADDR: return {"success": False, "message": f"Function '{function_name}' not found"} # Check if it's a function func: Optional[ida_funcs.func_t] = ida_funcs.get_func(func_addr) if not func: return {"success": False, "message": f"'{function_name}' is not a function"} # Check if decompiler is available if not ida_hexrays.init_hexrays_plugin(): return {"success": False, "message": "Hex-Rays decompiler is not available"} # Get decompilation result cfunc: Optional[ida_hexrays.cfunc_t] = ida_hexrays.decompile(func_addr) if not cfunc: return {"success": False, "message": f"Failed to decompile function '{function_name}'"} ida_hexrays.open_pseudocode(func_addr, 0) # Find local variable to rename found: bool = False renamed: bool = False lvar: Optional[ida_hexrays.lvar_t] = None # Iterate through all local variables lvars = cfunc.get_lvars() for i in range(lvars.size()): v = lvars[i] if v.name == old_name: lvar = v found = True break if not found: return {"success": False, "message": f"Local variable '{old_name}' not found in function '{function_name}'"} # Rename local variable if ida_hexrays.rename_lvar(cfunc.entry_ea, lvar.name, new_name): renamed = True if renamed: # Refresh view self._refresh_view_internal() return {"success": True, "message": f"Local variable renamed from '{old_name}' to '{new_name}' in function '{function_name}'"} else: return {"success": False, "message": f"Failed to rename local variable from '{old_name}' to '{new_name}', possibly due to invalid name format or other IDA restrictions"} except Exception as e: print(f"Error renaming local variable: {str(e)}") traceback.print_exc() return {"success": False, "message": str(e)} @idawrite def rename_multi_local_variables(self, function_name: str, rename_pairs_old2new: List[Dict[str, str]]) -> Dict[str, Any]: """Rename multiple local variables within a function at once""" try: success_count: int = 0 failed_pairs: List[Dict[str, str]] = [] for pair in rename_pairs_old2new: old_name = next(iter(pair.keys())) new_name = pair[old_name] # Call existing rename_local_variable_internal for each pair result = self._rename_local_variable_internal(function_name, old_name, new_name) if result.get("success", False): success_count += 1 else: failed_pairs.append({ "old_name": old_name, "new_name": new_name, "error": result.get("message", "Unknown error") }) return { "success": True, "message": f"Renamed {success_count} out of {len(rename_pairs_old2new)} local variables", "success_count": success_count, "failed_pairs": failed_pairs } except Exception as e: print(f"Error in rename_multi_local_variables: {str(e)}") traceback.print_exc() return { "success": False, "message": str(e), "success_count": 0, "failed_pairs": rename_pairs_old2new } @idawrite def rename_multi_global_variables(self, rename_pairs_old2new: List[Dict[str, str]]) -> Dict[str, Any]: """Rename multiple global variables at once""" try: success_count: int = 0 failed_pairs: List[Dict[str, str]] = [] for pair in rename_pairs_old2new: old_name = next(iter(pair.keys())) new_name = pair[old_name] # Call existing rename_global_variable_internal for each pair result = self._rename_global_variable_internal(old_name, new_name) if result.get("success", False): success_count += 1 else: failed_pairs.append({ "old_name": old_name, "new_name": new_name, "error": result.get("message", "Unknown error") }) return { "success": True, "message": f"Renamed {success_count} out of {len(rename_pairs_old2new)} global variables", "success_count": success_count, "failed_pairs": failed_pairs } except Exception as e: print(f"Error in rename_multi_global_variables: {str(e)}") traceback.print_exc() return { "success": False, "message": str(e), "success_count": 0, "failed_pairs": rename_pairs_old2new } @idawrite def rename_multi_functions(self, rename_pairs_old2new: List[Dict[str, str]]) -> Dict[str, Any]: """Rename multiple functions at once""" try: success_count: int = 0 failed_pairs: List[Dict[str, str]] = [] for pair in rename_pairs_old2new: old_name = next(iter(pair.keys())) new_name = pair[old_name] # Call existing rename_function_internal for each pair result = self._rename_function_internal(old_name, new_name) if result.get("success", False): success_count += 1 else: failed_pairs.append({ "old_name": old_name, "new_name": new_name, "error": result.get("message", "Unknown error") }) return { "success": True, "message": f"Renamed {success_count} out of {len(rename_pairs_old2new)} functions", "success_count": success_count, "failed_pairs": failed_pairs } except Exception as e: print(f"Error in rename_multi_functions: {str(e)}") traceback.print_exc() return { "success": False, "message": str(e), "success_count": 0, "failed_pairs": rename_pairs_old2new } @idawrite def add_function_comment(self, function_name: str, comment: str, is_repeatable: bool) -> Dict[str, Any]: """Add a comment to a function""" return self._add_function_comment_internal(function_name, comment, is_repeatable) def _add_function_comment_internal(self, function_name: str, comment: str, is_repeatable: bool) -> Dict[str, Any]: """Internal implementation for add_function_comment without sync wrapper""" try: # Parameter validation if not function_name: return {"success": False, "message": "Function name cannot be empty"} if not comment: # Allow empty comment to clear the comment comment = "" # Get function address func_addr: int = ida_name.get_name_ea(0, function_name) if func_addr == idaapi.BADADDR: return {"success": False, "message": f"Function '{function_name}' not found"} # Check if it's a function func: Optional[ida_funcs.func_t] = ida_funcs.get_func(func_addr) if not func: return {"success": False, "message": f"'{function_name}' is not a function"} # Open pseudocode view ida_hexrays.open_pseudocode(func_addr, 0) # Add function comment # is_repeatable=True means show comment at all references to this function # is_repeatable=False means show comment only at function definition result: bool = idc.set_func_cmt(func_addr, comment, is_repeatable) if result: # Refresh view self._refresh_view_internal() comment_type: str = "repeatable" if is_repeatable else "regular" return {"success": True, "message": f"Added {comment_type} comment to function '{function_name}'"} else: return {"success": False, "message": f"Failed to add comment to function '{function_name}'"} except Exception as e: print(f"Error adding function comment: {str(e)}") traceback.print_exc() return {"success": False, "message": str(e)} @idawrite def add_pseudocode_comment(self, function_name: str, address: str, comment: str, is_repeatable: bool) -> Dict[str, Any]: """Add a comment to a specific address in the function's decompiled pseudocode""" return self._add_pseudocode_comment_internal(function_name, address, comment, is_repeatable) def _add_pseudocode_comment_internal(self, function_name: str, address: str, comment: str, is_repeatable: bool) -> Dict[str, Any]: """Internal implementation for add_pseudocode_comment without sync wrapper""" try: # Parameter validation if not function_name: return {"success": False, "message": "Function name cannot be empty"} if not address: return {"success": False, "message": "Address cannot be empty"} if not comment: # Allow empty comment to clear the comment comment = "" # Get function address func_addr: int = ida_name.get_name_ea(0, function_name) if func_addr == idaapi.BADADDR: return {"success": False, "message": f"Function '{function_name}' not found"} # Check if it's a function func: Optional[ida_funcs.func_t] = ida_funcs.get_func(func_addr) if not func: return {"success": False, "message": f"'{function_name}' is not a function"} # Check if decompiler is available if not ida_hexrays.init_hexrays_plugin(): return {"success": False, "message": "Hex-Rays decompiler is not available"} # Get decompilation result cfunc: Optional[ida_hexrays.cfunc_t] = ida_hexrays.decompile(func_addr) if not cfunc: return {"success": False, "message": f"Failed to decompile function '{function_name}'"} # Open pseudocode view ida_hexrays.open_pseudocode(func_addr, 0) # Convert address string to integer addr: int if isinstance(address, str): if address.startswith("0x"): addr = int(address, 16) else: try: addr = int(address, 16) # Try parsing as hex except ValueError: try: addr = int(address) # Try parsing as decimal except ValueError: return {"success": False, "message": f"Invalid address format: {address}"} else: addr = address # Check if address is valid if addr == idaapi.BADADDR or not ida_bytes.is_loaded(addr): return {"success": False, "message": f"Invalid or unloaded address: {hex(addr)}"} # Check if address is within function if not (func.start_ea <= addr < func.end_ea): return {"success": False, "message": f"Address {hex(addr)} is not within function '{function_name}'"} # Create treeloc_t object for comment location loc = ida_hexrays.treeloc_t() loc.ea = addr loc.itp = ida_hexrays.ITP_BLOCK1 # Comment location # Set comment cfunc.set_user_cmt(loc, comment) cfunc.save_user_cmts() # Refresh view self._refresh_view_internal() comment_type: str = "repeatable" if is_repeatable else "regular" return { "success": True, "message": f"Added {comment_type} comment at address {hex(addr)} in function '{function_name}'" } except Exception as e: print(f"Error adding pseudocode comment: {str(e)}") traceback.print_exc() return {"success": False, "message": str(e)} @idawrite def refresh_view(self) -> Dict[str, Any]: """Refresh IDA Pro view""" return self._refresh_view_internal() def _refresh_view_internal(self) -> Dict[str, Any]: """Implementation of refreshing view in IDA main thread""" try: # Refresh disassembly view idaapi.refresh_idaview_anyway() # Refresh decompilation view current_widget = idaapi.get_current_widget() if current_widget: widget_type: int = idaapi.get_widget_type(current_widget) if widget_type == idaapi.BWN_PSEUDOCODE: # If current view is pseudocode, refresh it vu = idaapi.get_widget_vdui(current_widget) if vu: vu.refresh_view(True) # Try to find and refresh all open pseudocode windows for i in range(5): # Check multiple possible pseudocode windows widget_name: str = f"Pseudocode-{chr(65+i)}" # Pseudocode-A, Pseudocode-B, ... widget = idaapi.find_widget(widget_name) if widget: vu = idaapi.get_widget_vdui(widget) if vu: vu.refresh_view(True) return {"success": True, "message": "Views refreshed successfully"} except Exception as e: print(f"Error refreshing views: {str(e)}") traceback.print_exc() return {"success": False, "message": str(e)} @idawrite def execute_script(self, script: str) -> Dict[str, Any]: """Execute a Python script in IDA context""" return self._execute_script_internal(script) def _execute_script_internal(self, script: str) -> Dict[str, Any]: """Internal implementation for execute_script without sync wrapper""" try: print(f"Executing script, length: {len(script) if script else 0}") # Check for empty script if not script or not script.strip(): print("Error: Empty script provided") return { "success": False, "error": "Empty script provided", "stdout": "", "stderr": "", "traceback": "" } # Create a local namespace for script execution script_globals = { '__builtins__': __builtins__, 'idaapi': idaapi, 'idautils': idautils, 'idc': idc, 'ida_funcs': ida_funcs, 'ida_bytes': ida_bytes, 'ida_name': ida_name, 'ida_segment': ida_segment, 'ida_lines': ida_lines, 'ida_hexrays': ida_hexrays } script_locals = {} # Save original stdin/stdout/stderr import sys import io original_stdout = sys.stdout original_stderr = sys.stderr original_stdin = sys.stdin # Create string IO objects to capture output stdout_capture = io.StringIO() stderr_capture = io.StringIO() # Redirect stdout/stderr to capture output sys.stdout = stdout_capture sys.stderr = stderr_capture # Prevent script from trying to read from stdin sys.stdin = io.StringIO() try: # Create UI hooks print("Setting up UI hooks") hooks = self._create_ui_hooks() hooks.hook() # Install auto-continue handlers for common dialogs - but first, redirect stderr temp_stderr = sys.stderr auto_handler_stderr = io.StringIO() sys.stderr = auto_handler_stderr print("Installing auto handlers") self._install_auto_handlers() # Restore stderr and save auto-handler errors separately sys.stderr = stderr_capture auto_handler_errors = auto_handler_stderr.getvalue() # Only log auto-handler errors, don't include in script output if auto_handler_errors: print(f"Auto-handler setup errors (not shown to user): {auto_handler_errors}") # Execute the script print("Executing script...") exec(script, script_globals, script_locals) print("Script execution completed") # Get captured output stdout = stdout_capture.getvalue() stderr = stderr_capture.getvalue() # Filter out auto-handler messages from stdout stdout_lines = stdout.splitlines() filtered_stdout_lines = [] for line in stdout_lines: skip_line = False auto_handler_messages = [ "Setting up UI hooks", "Installing auto handlers", "Error installing auto handlers", "Found and saved", "Could not access user_cancelled", "Installed auto_", "Auto handlers installed", "Note: Could not", "Restoring IO streams", "Unhooking UI hooks", "Restoring original handlers", "Refreshing view", "Original handlers restored", "No original handlers" ] for msg in auto_handler_messages: if msg in line: skip_line = True break if not skip_line: filtered_stdout_lines.append(line) filtered_stdout = "\n".join(filtered_stdout_lines) # Compile script results - ensure all fields are present result = { "stdout": filtered_stdout.strip() if filtered_stdout else "", "stderr": stderr.strip() if stderr else "", "success": True, "traceback": "" } # Check for return value if "result" in script_locals: try: print(f"Script returned value of type: {type(script_locals['result']).__name__}") result["return_value"] = str(script_locals["result"]) except Exception as rv_err: print(f"Error converting return value: {str(rv_err)}") result["stderr"] += f"\nError converting return value: {str(rv_err)}" result["return_value"] = "Error: Could not convert return value to string" print(f"Returning script result with keys: {', '.join(result.keys())}") return result except Exception as e: import traceback error_msg = str(e) tb = traceback.format_exc() print(f"Script execution error: {error_msg}") print(tb) return { "success": False, "stdout": stdout_capture.getvalue().strip() if stdout_capture else "", "stderr": stderr_capture.getvalue().strip() if stderr_capture else "", "error": error_msg, "traceback": tb } finally: # Restore original stdin/stdout/stderr print("Restoring IO streams") sys.stdout = original_stdout sys.stderr = original_stderr sys.stdin = original_stdin # Unhook UI hooks print("Unhooking UI hooks") hooks.unhook() # Restore original handlers print("Restoring original handlers") self._restore_original_handlers() # Refresh view to show any changes made by script print("Refreshing view") self._refresh_view_internal() except Exception as e: print(f"Error in execute_script outer scope: {str(e)}") traceback.print_exc() return { "success": False, "stdout": "", "stderr": "", "error": str(e), "traceback": traceback.format_exc() } @idawrite def execute_script_from_file(self, file_path: str) -> Dict[str, Any]: """Execute a Python script from a file in IDA context""" return self._execute_script_from_file_internal(file_path) def _execute_script_from_file_internal(self, file_path: str) -> Dict[str, Any]: """Internal implementation for execute_script_from_file without sync wrapper""" try: # Check if file path is provided if not file_path or not file_path.strip(): return { "success": False, "error": "No file path provided", "stdout": "", "stderr": "", "traceback": "" } # Check if file exists import os if not os.path.exists(file_path): return { "success": False, "error": f"Script file not found: {file_path}", "stdout": "", "stderr": "", "traceback": "" } try: # Read script content with open(file_path, 'r') as f: script = f.read() # Execute script using internal method return self._execute_script_internal(script) except Exception as file_error: print(f"Error reading or executing script file: {str(file_error)}") traceback.print_exc() return { "success": False, "stdout": "", "stderr": "", "error": f"Error with script file: {str(file_error)}", "traceback": traceback.format_exc() } except Exception as e: print(f"Error executing script from file: {str(e)}") traceback.print_exc() return { "success": False, "stdout": "", "stderr": "", "error": str(e), "traceback": traceback.format_exc() } def _create_ui_hooks(self) -> idaapi.UI_Hooks: """Create UI hooks to suppress dialogs during script execution""" try: class DialogHook(idaapi.UI_Hooks): def populating_widget_popup(self, widget, popup): # Just suppress all popups return 1 def finish_populating_widget_popup(self, widget, popup): # Also suppress here return 1 def ready_to_run(self): # Always continue return 1 def updating_actions(self, ctx): # Always continue return 1 def updated_actions(self): # Always continue return 1 def ui_refresh(self, cnd): # Suppress UI refreshes return 1 hooks = DialogHook() return hooks except Exception as e: print(f"Error creating UI hooks: {str(e)}") traceback.print_exc() # Create minimal dummy hooks that won't cause errors class DummyHook: def hook(self): print("Using dummy hook (hook)") pass def unhook(self): print("Using dummy hook (unhook)") pass return DummyHook() def _install_auto_handlers(self) -> None: """Install auto-continue handlers for common dialogs""" try: import ida_kernwin # Save original handlers - with safer access to cvar.user_cancelled self._original_handlers = {} # Try to access user_cancelled more safely try: if hasattr(ida_kernwin, 'cvar') and hasattr(ida_kernwin.cvar, 'user_cancelled'): self._original_handlers["yn"] = ida_kernwin.cvar.user_cancelled print("Found and saved user_cancelled handler") except Exception as yn_err: print(f"Note: Could not access user_cancelled: {str(yn_err)}") # Save other dialog handlers if hasattr(ida_kernwin, 'ask_buttons'): self._original_handlers["buttons"] = ida_kernwin.ask_buttons if hasattr(ida_kernwin, 'ask_text'): self._original_handlers["text"] = ida_kernwin.ask_text if hasattr(ida_kernwin, 'ask_file'): self._original_handlers["file"] = ida_kernwin.ask_file # Define auto handlers def auto_yes_no(*args, **kwargs): return 1 # Return "Yes" def auto_buttons(*args, **kwargs): return 1 # Return first button def auto_text(*args, **kwargs): return "" # Return empty text def auto_file(*args, **kwargs): return "" # Return empty filename # Install auto handlers only for what we successfully saved if "yn" in self._original_handlers: try: ida_kernwin.cvar.user_cancelled = auto_yes_no print("Installed auto_yes_no handler") except Exception as e: print(f"Could not install auto_yes_no handler: {str(e)}") if "buttons" in self._original_handlers: ida_kernwin.ask_buttons = auto_buttons print("Installed auto_buttons handler") if "text" in self._original_handlers: ida_kernwin.ask_text = auto_text print("Installed auto_text handler") if "file" in self._original_handlers: ida_kernwin.ask_file = auto_file print("Installed auto_file handler") print(f"Auto handlers installed successfully. Installed handlers: {', '.join(self._original_handlers.keys())}") except Exception as e: print(f"Error installing auto handlers: {str(e)}") traceback.print_exc() # Ensure _original_handlers exists even on failure if not hasattr(self, "_original_handlers"): self._original_handlers = {} def _restore_original_handlers(self) -> None: """Restore original dialog handlers""" try: if hasattr(self, "_original_handlers"): import ida_kernwin # Restore original handlers (only what was successfully saved) if "yn" in self._original_handlers: try: ida_kernwin.cvar.user_cancelled = self._original_handlers["yn"] print("Restored user_cancelled handler") except Exception as e: print(f"Could not restore user_cancelled handler: {str(e)}") if "buttons" in self._original_handlers: ida_kernwin.ask_buttons = self._original_handlers["buttons"] print("Restored ask_buttons handler") if "text" in self._original_handlers: ida_kernwin.ask_text = self._original_handlers["text"] print("Restored ask_text handler") if "file" in self._original_handlers: ida_kernwin.ask_file = self._original_handlers["file"] print("Restored ask_file handler") saved_keys = list(self._original_handlers.keys()) if saved_keys: print(f"Original handlers restored: {', '.join(saved_keys)}") else: print("No original handlers were saved, nothing to restore") else: print("No original handlers dictionary to restore") except Exception as e: print(f"Error restoring original handlers: {str(e)}") traceback.print_exc() ``` -------------------------------------------------------------------------------- /src/mcp_server_ida/server.py: -------------------------------------------------------------------------------- ```python import logging import socket import json import time import struct import uuid from typing import Dict, Any, List, Union, Optional, Tuple, Callable, TypeVar, Set, Awaitable, Type, cast from mcp.server import Server from mcp.server.stdio import stdio_server from mcp.types import ( TextContent, Tool, ) from enum import Enum from pydantic import BaseModel # Modify request models class GetFunctionAssemblyByName(BaseModel): function_name: str class GetFunctionAssemblyByAddress(BaseModel): address: str # Hexadecimal address as string class GetFunctionDecompiledByName(BaseModel): function_name: str class GetFunctionDecompiledByAddress(BaseModel): address: str # Hexadecimal address as string class GetGlobalVariableByName(BaseModel): variable_name: str class GetGlobalVariableByAddress(BaseModel): address: str # Hexadecimal address as string class GetCurrentFunctionAssembly(BaseModel): pass class GetCurrentFunctionDecompiled(BaseModel): pass class RenameLocalVariable(BaseModel): function_name: str old_name: str new_name: str class RenameGlobalVariable(BaseModel): old_name: str new_name: str class RenameFunction(BaseModel): old_name: str new_name: str class RenameMultiLocalVariables(BaseModel): function_name: str rename_pairs_old2new: List[Dict[str, str]] # List of dictionaries with "old_name" and "new_name" keys class RenameMultiGlobalVariables(BaseModel): rename_pairs_old2new: List[Dict[str, str]] class RenameMultiFunctions(BaseModel): rename_pairs_old2new: List[Dict[str, str]] class AddAssemblyComment(BaseModel): address: str # Can be a hexadecimal address string comment: str is_repeatable: bool = False # Whether the comment should be repeatable class AddFunctionComment(BaseModel): function_name: str comment: str is_repeatable: bool = False # Whether the comment should be repeatable class AddPseudocodeComment(BaseModel): function_name: str address: str # Address in the pseudocode comment: str is_repeatable: bool = False # Whether comment should be repeated at all occurrences class ExecuteScript(BaseModel): script: str class ExecuteScriptFromFile(BaseModel): file_path: str class IDATools(str, Enum): GET_FUNCTION_ASSEMBLY_BY_NAME = "ida_get_function_assembly_by_name" GET_FUNCTION_ASSEMBLY_BY_ADDRESS = "ida_get_function_assembly_by_address" GET_FUNCTION_DECOMPILED_BY_NAME = "ida_get_function_decompiled_by_name" GET_FUNCTION_DECOMPILED_BY_ADDRESS = "ida_get_function_decompiled_by_address" GET_GLOBAL_VARIABLE_BY_NAME = "ida_get_global_variable_by_name" GET_GLOBAL_VARIABLE_BY_ADDRESS = "ida_get_global_variable_by_address" GET_CURRENT_FUNCTION_ASSEMBLY = "ida_get_current_function_assembly" GET_CURRENT_FUNCTION_DECOMPILED = "ida_get_current_function_decompiled" RENAME_LOCAL_VARIABLE = "ida_rename_local_variable" RENAME_GLOBAL_VARIABLE = "ida_rename_global_variable" RENAME_FUNCTION = "ida_rename_function" RENAME_MULTI_LOCAL_VARIABLES = "ida_rename_multi_local_variables" RENAME_MULTI_GLOBAL_VARIABLES = "ida_rename_multi_global_variables" RENAME_MULTI_FUNCTIONS = "ida_rename_multi_functions" ADD_ASSEMBLY_COMMENT = "ida_add_assembly_comment" ADD_FUNCTION_COMMENT = "ida_add_function_comment" ADD_PSEUDOCODE_COMMENT = "ida_add_pseudocode_comment" EXECUTE_SCRIPT = "ida_execute_script" EXECUTE_SCRIPT_FROM_FILE = "ida_execute_script_from_file" # IDA Pro通信处理器 class IDAProCommunicator: def __init__(self, host: str = 'localhost', port: int = 5000): self.host: str = host self.port: int = port self.sock: Optional[socket.socket] = None self.logger: logging.Logger = logging.getLogger(__name__) self.connected: bool = False self.reconnect_attempts: int = 0 self.max_reconnect_attempts: int = 5 self.last_reconnect_time: float = 0 self.reconnect_cooldown: int = 5 # seconds self.request_count: int = 0 self.default_timeout: int = 10 self.batch_timeout: int = 60 # it may take more time for batch operations def connect(self) -> bool: """Connect to IDA plugin""" # Check if cooldown is needed current_time: float = time.time() if current_time - self.last_reconnect_time < self.reconnect_cooldown and self.reconnect_attempts > 0: self.logger.debug("In reconnection cooldown, skipping") return False # If already connected, disconnect first if self.connected: self.disconnect() try: self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.sock.settimeout(self.default_timeout) self.sock.connect((self.host, self.port)) self.connected = True self.reconnect_attempts = 0 self.logger.info(f"Connected to IDA Pro ({self.host}:{self.port})") return True except Exception as e: self.last_reconnect_time = current_time self.reconnect_attempts += 1 if self.reconnect_attempts <= self.max_reconnect_attempts: self.logger.warning(f"Failed to connect to IDA Pro: {str(e)}. Attempt {self.reconnect_attempts}/{self.max_reconnect_attempts}") else: self.logger.error(f"Failed to connect to IDA Pro after {self.max_reconnect_attempts} attempts: {str(e)}") return False def disconnect(self) -> None: """Disconnect from IDA Pro""" if self.sock: try: self.sock.close() except: pass self.sock = None self.connected = False def ensure_connection(self) -> bool: """Ensure connection is established""" if not self.connected: return self.connect() return True def send_message(self, data: bytes) -> None: """Send message with length prefix""" if self.sock is None: raise ConnectionError("Socket is not connected") length: int = len(data) length_bytes: bytes = struct.pack('!I', length) # 4-byte length prefix self.sock.sendall(length_bytes + data) def receive_message(self) -> Optional[bytes]: """Receive message with length prefix""" try: # Receive 4-byte length prefix length_bytes: Optional[bytes] = self.receive_exactly(4) if not length_bytes: return None length: int = struct.unpack('!I', length_bytes)[0] # Receive message body data: Optional[bytes] = self.receive_exactly(length) return data except Exception as e: self.logger.error(f"Error receiving message: {str(e)}") return None def receive_exactly(self, n: int) -> Optional[bytes]: """Receive exactly n bytes of data""" if self.sock is None: raise ConnectionError("Socket is not connected") data: bytes = b'' while len(data) < n: chunk: bytes = self.sock.recv(min(n - len(data), 4096)) if not chunk: # Connection closed return None data += chunk return data def send_request(self, request_type: str, data: Dict[str, Any]) -> Dict[str, Any]: """Send request to IDA plugin""" # Ensure connection is established if not self.ensure_connection(): return {"error": "Cannot connect to IDA Pro"} try: if request_type in ["rename_multi_local_variables", "rename_multi_global_variables", "rename_multi_functions"]: if self.sock: self.sock.settimeout(self.batch_timeout) self.logger.debug(f"Set timeout to {self.batch_timeout}s for batch operation") else: if self.sock: self.sock.settimeout(self.default_timeout) self.logger.debug(f"Set timeout to {self.default_timeout}s for normal operation") # Add request ID request_id: str = str(uuid.uuid4()) self.request_count += 1 request_count: int = self.request_count request: Dict[str, Any] = { "id": request_id, "count": request_count, "type": request_type, "data": data } self.logger.debug(f"Sending request: {request_id}, type: {request_type}, count: {request_count}") try: # Send request request_json: bytes = json.dumps(request).encode('utf-8') self.send_message(request_json) # Receive response response_data: Optional[bytes] = self.receive_message() # If no data received, assume connection is closed if not response_data: self.logger.warning("No data received, connection may be closed") self.disconnect() return {"error": "No response received from IDA Pro"} # Parse response try: self.logger.debug(f"Received raw data length: {len(response_data)}") response: Dict[str, Any] = json.loads(response_data.decode('utf-8')) # Verify response ID matches response_id: str = response.get("id") if response_id != request_id: self.logger.warning(f"Response ID mismatch! Request ID: {request_id}, Response ID: {response_id}") self.logger.debug(f"Received response: ID={response.get('id')}, count={response.get('count')}") # Additional type verification if not isinstance(response, dict): self.logger.error(f"Received response is not a dictionary: {type(response)}") return {"error": f"Response format error: expected dictionary, got {type(response).__name__}"} return response except json.JSONDecodeError as e: self.logger.error(f"Failed to parse JSON response: {str(e)}") return {"error": f"Invalid JSON response: {str(e)}"} except Exception as e: self.logger.error(f"Error communicating with IDA Pro: {str(e)}") self.disconnect() # Disconnect after error return {"error": str(e)} finally: # restore timeout if self.sock: self.sock.settimeout(self.default_timeout) def ping(self) -> bool: """Check if connection is valid""" response: Dict[str, Any] = self.send_request("ping", {}) return response.get("status") == "pong" # Actual IDA Pro functionality implementation class IDAProFunctions: def __init__(self, communicator: IDAProCommunicator): self.communicator: IDAProCommunicator = communicator self.logger: logging.Logger = logging.getLogger(__name__) def get_function_assembly(self, function_name: str) -> str: """Get assembly code for a function by name (legacy method)""" return self.get_function_assembly_by_name(function_name) def get_function_assembly_by_name(self, function_name: str) -> str: """Get assembly code for a function by its name""" try: response: Dict[str, Any] = self.communicator.send_request( "get_function_assembly_by_name", {"function_name": function_name} ) if "error" in response: return f"Error retrieving assembly for function '{function_name}': {response['error']}" assembly: Any = response.get("assembly") # Verify assembly is string type if assembly is None: return f"Error: No assembly data returned for function '{function_name}'" if not isinstance(assembly, str): self.logger.warning(f"Assembly data type is not string but {type(assembly).__name__}, attempting conversion") assembly = str(assembly) return f"Assembly code for function '{function_name}':\n{assembly}" except Exception as e: self.logger.error(f"Error getting function assembly: {str(e)}", exc_info=True) return f"Error retrieving assembly for function '{function_name}': {str(e)}" def get_function_decompiled(self, function_name: str) -> str: """Get decompiled code for a function by name (legacy method)""" return self.get_function_decompiled_by_name(function_name) def get_function_decompiled_by_name(self, function_name: str) -> str: """Get decompiled pseudocode for a function by its name""" try: response: Dict[str, Any] = self.communicator.send_request( "get_function_decompiled_by_name", {"function_name": function_name} ) # Log complete response for debugging self.logger.debug(f"Decompilation response: {response}") if "error" in response: return f"Error retrieving decompiled code for function '{function_name}': {response['error']}" decompiled_code: Any = response.get("decompiled_code") # Detailed type checking and conversion if decompiled_code is None: return f"Error: No decompiled code returned for function '{function_name}'" # Log actual type actual_type: str = type(decompiled_code).__name__ self.logger.debug(f"Decompiled code type is: {actual_type}") # Ensure result is string if not isinstance(decompiled_code, str): self.logger.warning(f"Decompiled code type is not string but {actual_type}, attempting conversion") try: decompiled_code = str(decompiled_code) except Exception as e: return f"Error: Failed to convert decompiled code from {actual_type} to string: {str(e)}" return f"Decompiled code for function '{function_name}':\n{decompiled_code}" except Exception as e: self.logger.error(f"Error getting function decompiled code: {str(e)}", exc_info=True) return f"Error retrieving decompiled code for function '{function_name}': {str(e)}" def get_global_variable(self, variable_name: str) -> str: """Get global variable information by name (legacy method)""" return self.get_global_variable_by_name(variable_name) def get_global_variable_by_name(self, variable_name: str) -> str: """Get global variable information by its name""" try: response: Dict[str, Any] = self.communicator.send_request( "get_global_variable_by_name", {"variable_name": variable_name} ) if "error" in response: return f"Error retrieving global variable '{variable_name}': {response['error']}" variable_info: Any = response.get("variable_info") # Verify variable_info is string type if variable_info is None: return f"Error: No variable info returned for '{variable_name}'" if not isinstance(variable_info, str): self.logger.warning(f"Variable info type is not string but {type(variable_info).__name__}, attempting conversion") try: # If it's a dictionary, convert to JSON string first if isinstance(variable_info, dict): variable_info = json.dumps(variable_info, indent=2) else: variable_info = str(variable_info) except Exception as e: return f"Error: Failed to convert variable info to string: {str(e)}" return f"Global variable '{variable_name}':\n{variable_info}" except Exception as e: self.logger.error(f"Error getting global variable: {str(e)}", exc_info=True) return f"Error retrieving global variable '{variable_name}': {str(e)}" def get_global_variable_by_address(self, address: str) -> str: """Get global variable information by its address""" try: # Convert string address to int try: addr_int = int(address, 16) if address.startswith("0x") else int(address) except ValueError: return f"Error: Invalid address format '{address}', expected hexadecimal (0x...) or decimal" response: Dict[str, Any] = self.communicator.send_request( "get_global_variable_by_address", {"address": addr_int} ) if "error" in response: return f"Error retrieving global variable at address '{address}': {response['error']}" variable_info: Any = response.get("variable_info") # Verify variable_info is string type if variable_info is None: return f"Error: No variable info returned for address '{address}'" if not isinstance(variable_info, str): self.logger.warning(f"Variable info type is not string but {type(variable_info).__name__}, attempting conversion") try: # If it's a dictionary, convert to JSON string first if isinstance(variable_info, dict): variable_info = json.dumps(variable_info, indent=2) else: variable_info = str(variable_info) except Exception as e: return f"Error: Failed to convert variable info to string: {str(e)}" # Try to extract the variable name from the JSON for a better message var_name = "Unknown" try: var_info_dict = json.loads(variable_info) if isinstance(var_info_dict, dict) and "name" in var_info_dict: var_name = var_info_dict["name"] except: pass return f"Global variable '{var_name}' at address {address}:\n{variable_info}" except Exception as e: self.logger.error(f"Error getting global variable by address: {str(e)}", exc_info=True) return f"Error retrieving global variable at address '{address}': {str(e)}" def get_current_function_assembly(self) -> str: """Get assembly code for the function at current cursor position""" try: response: Dict[str, Any] = self.communicator.send_request( "get_current_function_assembly", {} ) if "error" in response: return f"Error retrieving assembly for current function: {response['error']}" assembly: Any = response.get("assembly") function_name: str = response.get("function_name", "Current function") # Verify assembly is string type if assembly is None: return f"Error: No assembly data returned for current function" if not isinstance(assembly, str): self.logger.warning(f"Assembly data type is not string but {type(assembly).__name__}, attempting conversion") assembly = str(assembly) return f"Assembly code for function '{function_name}':\n{assembly}" except Exception as e: self.logger.error(f"Error getting current function assembly: {str(e)}", exc_info=True) return f"Error retrieving assembly for current function: {str(e)}" def get_current_function_decompiled(self) -> str: """Get decompiled code for the function at current cursor position""" try: response: Dict[str, Any] = self.communicator.send_request( "get_current_function_decompiled", {} ) if "error" in response: return f"Error retrieving decompiled code for current function: {response['error']}" decompiled_code: Any = response.get("decompiled_code") function_name: str = response.get("function_name", "Current function") # Detailed type checking and conversion if decompiled_code is None: return f"Error: No decompiled code returned for current function" # Ensure result is string if not isinstance(decompiled_code, str): self.logger.warning(f"Decompiled code type is not string but {type(decompiled_code).__name__}, attempting conversion") try: decompiled_code = str(decompiled_code) except Exception as e: return f"Error: Failed to convert decompiled code: {str(e)}" return f"Decompiled code for function '{function_name}':\n{decompiled_code}" except Exception as e: self.logger.error(f"Error getting current function decompiled code: {str(e)}", exc_info=True) return f"Error retrieving decompiled code for current function: {str(e)}" def rename_local_variable(self, function_name: str, old_name: str, new_name: str) -> str: """Rename a local variable within a function""" try: response: Dict[str, Any] = self.communicator.send_request( "rename_local_variable", {"function_name": function_name, "old_name": old_name, "new_name": new_name} ) if "error" in response: return f"Error renaming local variable from '{old_name}' to '{new_name}' in function '{function_name}': {response['error']}" success: bool = response.get("success", False) message: str = response.get("message", "") if success: return f"Successfully renamed local variable from '{old_name}' to '{new_name}' in function '{function_name}': {message}" else: return f"Failed to rename local variable from '{old_name}' to '{new_name}' in function '{function_name}': {message}" except Exception as e: self.logger.error(f"Error renaming local variable: {str(e)}", exc_info=True) return f"Error renaming local variable from '{old_name}' to '{new_name}' in function '{function_name}': {str(e)}" def rename_global_variable(self, old_name: str, new_name: str) -> str: """Rename a global variable""" try: response: Dict[str, Any] = self.communicator.send_request( "rename_global_variable", {"old_name": old_name, "new_name": new_name} ) if "error" in response: return f"Error renaming global variable from '{old_name}' to '{new_name}': {response['error']}" success: bool = response.get("success", False) message: str = response.get("message", "") if success: return f"Successfully renamed global variable from '{old_name}' to '{new_name}': {message}" else: return f"Failed to rename global variable from '{old_name}' to '{new_name}': {message}" except Exception as e: self.logger.error(f"Error renaming global variable: {str(e)}", exc_info=True) return f"Error renaming global variable from '{old_name}' to '{new_name}': {str(e)}" def rename_function(self, old_name: str, new_name: str) -> str: """Rename a function""" try: response: Dict[str, Any] = self.communicator.send_request( "rename_function", {"old_name": old_name, "new_name": new_name} ) if "error" in response: return f"Error renaming function from '{old_name}' to '{new_name}': {response['error']}" success: bool = response.get("success", False) message: str = response.get("message", "") if success: return f"Successfully renamed function from '{old_name}' to '{new_name}': {message}" else: return f"Failed to rename function from '{old_name}' to '{new_name}': {message}" except Exception as e: self.logger.error(f"Error renaming function: {str(e)}", exc_info=True) return f"Error renaming function from '{old_name}' to '{new_name}': {str(e)}" def rename_multi_local_variables(self, function_name: str, rename_pairs_old2new: List[Dict[str, str]]) -> str: """Rename multiple local variables within a function at once""" try: response: Dict[str, Any] = self.communicator.send_request( "rename_multi_local_variables", { "function_name": function_name, "rename_pairs_old2new": rename_pairs_old2new } ) if "error" in response: return f"Error renaming multiple local variables in function '{function_name}': {response['error']}" success_count: int = response.get("success_count", 0) failed_pairs: List[Dict[str, str]] = response.get("failed_pairs", []) result_parts: List[str] = [ f"Successfully renamed {success_count} local variables in function '{function_name}'" ] if failed_pairs: result_parts.append("\nFailed renamings:") for pair in failed_pairs: result_parts.append(f"- {pair['old_name']} → {pair['new_name']}: {pair.get('error', 'Unknown error')}") return "\n".join(result_parts) except Exception as e: self.logger.error(f"Error renaming multiple local variables: {str(e)}", exc_info=True) return f"Error renaming multiple local variables in function '{function_name}': {str(e)}" def rename_multi_global_variables(self, rename_pairs_old2new: List[Dict[str, str]]) -> str: """Rename multiple global variables at once""" try: response: Dict[str, Any] = self.communicator.send_request( "rename_multi_global_variables", {"rename_pairs_old2new": rename_pairs_old2new} ) if "error" in response: return f"Error renaming multiple global variables: {response['error']}" success_count: int = response.get("success_count", 0) failed_pairs: List[Dict[str, str]] = response.get("failed_pairs", []) result_parts: List[str] = [ f"Successfully renamed {success_count} global variables" ] if failed_pairs: result_parts.append("\nFailed renamings:") for pair in failed_pairs: result_parts.append(f"- {pair['old_name']} → {pair['new_name']}: {pair.get('error', 'Unknown error')}") return "\n".join(result_parts) except Exception as e: self.logger.error(f"Error renaming multiple global variables: {str(e)}", exc_info=True) return f"Error renaming multiple global variables: {str(e)}" def rename_multi_functions(self, rename_pairs_old2new: List[Dict[str, str]]) -> str: """Rename multiple functions at once""" try: response: Dict[str, Any] = self.communicator.send_request( "rename_multi_functions", {"rename_pairs_old2new": rename_pairs_old2new} ) if "error" in response: return f"Error renaming multiple functions: {response['error']}" success_count: int = response.get("success_count", 0) failed_pairs: List[Dict[str, str]] = response.get("failed_pairs", []) result_parts: List[str] = [ f"Successfully renamed {success_count} functions" ] if failed_pairs: result_parts.append("\nFailed renamings:") for pair in failed_pairs: result_parts.append(f"- {pair['old_name']} → {pair['new_name']}: {pair.get('error', 'Unknown error')}") return "\n".join(result_parts) except Exception as e: self.logger.error(f"Error renaming multiple functions: {str(e)}", exc_info=True) return f"Error renaming multiple functions: {str(e)}" def add_assembly_comment(self, address: str, comment: str, is_repeatable: bool = False) -> str: """Add an assembly comment""" try: response: Dict[str, Any] = self.communicator.send_request( "add_assembly_comment", {"address": address, "comment": comment, "is_repeatable": is_repeatable} ) if "error" in response: return f"Error adding assembly comment at address '{address}': {response['error']}" success: bool = response.get("success", False) message: str = response.get("message", "") if success: comment_type: str = "repeatable" if is_repeatable else "regular" return f"Successfully added {comment_type} assembly comment at address '{address}': {message}" else: return f"Failed to add assembly comment at address '{address}': {message}" except Exception as e: self.logger.error(f"Error adding assembly comment: {str(e)}", exc_info=True) return f"Error adding assembly comment at address '{address}': {str(e)}" def add_function_comment(self, function_name: str, comment: str, is_repeatable: bool = False) -> str: """Add a comment to a function""" try: response: Dict[str, Any] = self.communicator.send_request( "add_function_comment", {"function_name": function_name, "comment": comment, "is_repeatable": is_repeatable} ) if "error" in response: return f"Error adding comment to function '{function_name}': {response['error']}" success: bool = response.get("success", False) message: str = response.get("message", "") if success: comment_type: str = "repeatable" if is_repeatable else "regular" return f"Successfully added {comment_type} comment to function '{function_name}': {message}" else: return f"Failed to add comment to function '{function_name}': {message}" except Exception as e: self.logger.error(f"Error adding function comment: {str(e)}", exc_info=True) return f"Error adding comment to function '{function_name}': {str(e)}" def add_pseudocode_comment(self, function_name: str, address: str, comment: str, is_repeatable: bool = False) -> str: """Add a comment to a specific address in the function's decompiled pseudocode""" try: response: Dict[str, Any] = self.communicator.send_request( "add_pseudocode_comment", { "function_name": function_name, "address": address, "comment": comment, "is_repeatable": is_repeatable } ) if "error" in response: return f"Error adding comment at address {address} in function '{function_name}': {response['error']}" success: bool = response.get("success", False) message: str = response.get("message", "") if success: comment_type: str = "repeatable" if is_repeatable else "regular" return f"Successfully added {comment_type} comment at address {address} in function '{function_name}': {message}" else: return f"Failed to add comment at address {address} in function '{function_name}': {message}" except Exception as e: self.logger.error(f"Error adding pseudocode comment: {str(e)}", exc_info=True) return f"Error adding comment at address {address} in function '{function_name}': {str(e)}" def execute_script(self, script: str) -> str: """Execute a Python script in IDA Pro and return its output. The script runs in IDA's context with access to all IDA API modules.""" try: response: Dict[str, Any] = self.communicator.send_request( "execute_script", {"script": script} ) # Handle case where response is None if response is None: self.logger.error("Received None response from IDA when executing script") return "Error executing script: Received empty response from IDA" # Handle case where response contains error if "error" in response: return f"Error executing script: {response['error']}" # Handle successful execution success: bool = response.get("success", False) if not success: error_msg: str = response.get("error", "Unknown error") traceback: str = response.get("traceback", "") return f"Script execution failed: {error_msg}\n\nTraceback:\n{traceback}" # Get output - ensure all values are strings to avoid None errors stdout: str = str(response.get("stdout", "")) stderr: str = str(response.get("stderr", "")) return_value: str = str(response.get("return_value", "")) result_text: List[str] = [] result_text.append("Script executed successfully") if return_value and return_value != "None": result_text.append(f"\nReturn value:\n{return_value}") if stdout: result_text.append(f"\nStandard output:\n{stdout}") if stderr: result_text.append(f"\nStandard error:\n{stderr}") return "\n".join(result_text) except Exception as e: self.logger.error(f"Error executing script: {str(e)}", exc_info=True) return f"Error executing script: {str(e)}" def execute_script_from_file(self, file_path: str) -> str: """Execute a Python script from a file path in IDA Pro and return its output. The file should be accessible from IDA's process.""" try: response: Dict[str, Any] = self.communicator.send_request( "execute_script_from_file", {"file_path": file_path} ) # Handle case where response is None if response is None: self.logger.error("Received None response from IDA when executing script from file") return f"Error executing script from file '{file_path}': Received empty response from IDA" # Handle case where response contains error if "error" in response: return f"Error executing script from file '{file_path}': {response['error']}" # Handle successful execution success: bool = response.get("success", False) if not success: error_msg: str = response.get("error", "Unknown error") traceback: str = response.get("traceback", "") return f"Script execution from file '{file_path}' failed: {error_msg}\n\nTraceback:\n{traceback}" # Get output - ensure all values are strings to avoid None errors stdout: str = str(response.get("stdout", "")) stderr: str = str(response.get("stderr", "")) return_value: str = str(response.get("return_value", "")) result_text: List[str] = [] result_text.append(f"Script from file '{file_path}' executed successfully") if return_value and return_value != "None": result_text.append(f"\nReturn value:\n{return_value}") if stdout: result_text.append(f"\nStandard output:\n{stdout}") if stderr: result_text.append(f"\nStandard error:\n{stderr}") return "\n".join(result_text) except Exception as e: self.logger.error(f"Error executing script from file: {str(e)}", exc_info=True) return f"Error executing script from file '{file_path}': {str(e)}" def get_function_assembly_by_address(self, address: str) -> str: """Get assembly code for a function by its address""" try: # Convert string address to int try: addr_int = int(address, 16) if address.startswith("0x") else int(address) except ValueError: return f"Error: Invalid address format '{address}', expected hexadecimal (0x...) or decimal" response: Dict[str, Any] = self.communicator.send_request( "get_function_assembly_by_address", {"address": addr_int} ) if "error" in response: return f"Error retrieving assembly for address '{address}': {response['error']}" assembly: Any = response.get("assembly") function_name: str = response.get("function_name", "Unknown function") # Verify assembly is string type if assembly is None: return f"Error: No assembly data returned for address '{address}'" if not isinstance(assembly, str): self.logger.warning(f"Assembly data type is not string but {type(assembly).__name__}, attempting conversion") assembly = str(assembly) return f"Assembly code for function '{function_name}' at address {address}:\n{assembly}" except Exception as e: self.logger.error(f"Error getting function assembly by address: {str(e)}", exc_info=True) return f"Error retrieving assembly for address '{address}': {str(e)}" def get_function_decompiled_by_address(self, address: str) -> str: """Get decompiled pseudocode for a function by its address""" try: # Convert string address to int try: addr_int = int(address, 16) if address.startswith("0x") else int(address) except ValueError: return f"Error: Invalid address format '{address}', expected hexadecimal (0x...) or decimal" response: Dict[str, Any] = self.communicator.send_request( "get_function_decompiled_by_address", {"address": addr_int} ) if "error" in response: return f"Error retrieving decompiled code for address '{address}': {response['error']}" decompiled_code: Any = response.get("decompiled_code") function_name: str = response.get("function_name", "Unknown function") # Detailed type checking and conversion if decompiled_code is None: return f"Error: No decompiled code returned for address '{address}'" # Ensure result is string if not isinstance(decompiled_code, str): self.logger.warning(f"Decompiled code type is not string but {type(decompiled_code).__name__}, attempting conversion") try: decompiled_code = str(decompiled_code) except Exception as e: return f"Error: Failed to convert decompiled code: {str(e)}" return f"Decompiled code for function '{function_name}' at address {address}:\n{decompiled_code}" except Exception as e: self.logger.error(f"Error getting function decompiled code by address: {str(e)}", exc_info=True) return f"Error retrieving decompiled code for address '{address}': {str(e)}" async def serve() -> None: """MCP server main entry point""" logger: logging.Logger = logging.getLogger(__name__) # Set log level to DEBUG for detailed information logger.setLevel(logging.DEBUG) server: Server = Server("mcp-ida") # Create communicator and attempt connection ida_communicator: IDAProCommunicator = IDAProCommunicator() logger.info("Attempting to connect to IDA Pro plugin...") if ida_communicator.connect(): logger.info("Successfully connected to IDA Pro plugin") else: logger.warning("Initial connection to IDA Pro plugin failed, will retry on request") # Create IDA functions class with persistent connection ida_functions: IDAProFunctions = IDAProFunctions(ida_communicator) @server.list_tools() async def list_tools() -> List[Tool]: """List supported tools""" return [ Tool( name=IDATools.GET_FUNCTION_ASSEMBLY_BY_NAME, description="Get assembly code for a function by name", inputSchema=GetFunctionAssemblyByName.schema(), ), Tool( name=IDATools.GET_FUNCTION_ASSEMBLY_BY_ADDRESS, description="Get assembly code for a function by address", inputSchema=GetFunctionAssemblyByAddress.schema(), ), Tool( name=IDATools.GET_FUNCTION_DECOMPILED_BY_NAME, description="Get decompiled pseudocode for a function by name", inputSchema=GetFunctionDecompiledByName.schema(), ), Tool( name=IDATools.GET_FUNCTION_DECOMPILED_BY_ADDRESS, description="Get decompiled pseudocode for a function by address", inputSchema=GetFunctionDecompiledByAddress.schema(), ), Tool( name=IDATools.GET_GLOBAL_VARIABLE_BY_NAME, description="Get information about a global variable by name", inputSchema=GetGlobalVariableByName.schema(), ), Tool( name=IDATools.GET_GLOBAL_VARIABLE_BY_ADDRESS, description="Get information about a global variable by address", inputSchema=GetGlobalVariableByAddress.schema(), ), Tool( name=IDATools.GET_CURRENT_FUNCTION_ASSEMBLY, description="Get assembly code for the function at the current cursor position", inputSchema=GetCurrentFunctionAssembly.schema(), ), Tool( name=IDATools.GET_CURRENT_FUNCTION_DECOMPILED, description="Get decompiled pseudocode for the function at the current cursor position", inputSchema=GetCurrentFunctionDecompiled.schema(), ), Tool( name=IDATools.RENAME_LOCAL_VARIABLE, description="Rename a local variable within a function in the IDA database", inputSchema=RenameLocalVariable.schema(), ), Tool( name=IDATools.RENAME_GLOBAL_VARIABLE, description="Rename a global variable in the IDA database", inputSchema=RenameGlobalVariable.schema(), ), Tool( name=IDATools.RENAME_FUNCTION, description="Rename a function in the IDA database", inputSchema=RenameFunction.schema(), ), Tool( name=IDATools.RENAME_MULTI_LOCAL_VARIABLES, description="Rename multiple local variables within a function at once in the IDA database", inputSchema=RenameMultiLocalVariables.schema(), ), Tool( name=IDATools.RENAME_MULTI_GLOBAL_VARIABLES, description="Rename multiple global variables at once in the IDA database", inputSchema=RenameMultiGlobalVariables.schema(), ), Tool( name=IDATools.RENAME_MULTI_FUNCTIONS, description="Rename multiple functions at once in the IDA database", inputSchema=RenameMultiFunctions.schema(), ), Tool( name=IDATools.ADD_ASSEMBLY_COMMENT, description="Add a comment at a specific address in the assembly view of the IDA database", inputSchema=AddAssemblyComment.schema(), ), Tool( name=IDATools.ADD_FUNCTION_COMMENT, description="Add a comment to a function in the IDA database", inputSchema=AddFunctionComment.schema(), ), Tool( name=IDATools.ADD_PSEUDOCODE_COMMENT, description="Add a comment to a specific address in the function's decompiled pseudocode", inputSchema=AddPseudocodeComment.schema(), ), Tool( name=IDATools.EXECUTE_SCRIPT, description="Execute a Python script in IDA Pro and return its output. The script runs in IDA's context with access to all IDA API modules.", inputSchema=ExecuteScript.schema(), ), Tool( name=IDATools.EXECUTE_SCRIPT_FROM_FILE, description="Execute a Python script from a file path in IDA Pro and return its output. The file should be accessible from IDA's process.", inputSchema=ExecuteScriptFromFile.schema(), ), ] @server.call_tool() async def call_tool(name: str, arguments: Dict[str, Any]) -> List[TextContent]: """Call tool and handle results""" # Ensure connection exists if not ida_communicator.connected and not ida_communicator.ensure_connection(): return [TextContent( type="text", text=f"Error: Cannot connect to IDA Pro plugin. Please ensure the plugin is running." )] try: match name: case IDATools.GET_FUNCTION_ASSEMBLY_BY_NAME: assembly: str = ida_functions.get_function_assembly_by_name(arguments["function_name"]) return [TextContent( type="text", text=assembly )] case IDATools.GET_FUNCTION_ASSEMBLY_BY_ADDRESS: assembly: str = ida_functions.get_function_assembly_by_address(arguments["address"]) return [TextContent( type="text", text=assembly )] case IDATools.GET_FUNCTION_DECOMPILED_BY_NAME: decompiled: str = ida_functions.get_function_decompiled_by_name(arguments["function_name"]) return [TextContent( type="text", text=decompiled )] case IDATools.GET_FUNCTION_DECOMPILED_BY_ADDRESS: decompiled: str = ida_functions.get_function_decompiled_by_address(arguments["address"]) return [TextContent( type="text", text=decompiled )] case IDATools.GET_GLOBAL_VARIABLE_BY_NAME: variable_info: str = ida_functions.get_global_variable_by_name(arguments["variable_name"]) return [TextContent( type="text", text=variable_info )] case IDATools.GET_GLOBAL_VARIABLE_BY_ADDRESS: variable_info: str = ida_functions.get_global_variable_by_address(arguments["address"]) return [TextContent( type="text", text=variable_info )] case IDATools.GET_CURRENT_FUNCTION_ASSEMBLY: assembly: str = ida_functions.get_current_function_assembly() return [TextContent( type="text", text=assembly )] case IDATools.GET_CURRENT_FUNCTION_DECOMPILED: decompiled: str = ida_functions.get_current_function_decompiled() return [TextContent( type="text", text=decompiled )] case IDATools.RENAME_LOCAL_VARIABLE: result: str = ida_functions.rename_local_variable( arguments["function_name"], arguments["old_name"], arguments["new_name"] ) return [TextContent( type="text", text=result )] case IDATools.RENAME_GLOBAL_VARIABLE: result: str = ida_functions.rename_global_variable( arguments["old_name"], arguments["new_name"] ) return [TextContent( type="text", text=result )] case IDATools.RENAME_FUNCTION: result: str = ida_functions.rename_function( arguments["old_name"], arguments["new_name"] ) return [TextContent( type="text", text=result )] case IDATools.RENAME_MULTI_LOCAL_VARIABLES: result: str = ida_functions.rename_multi_local_variables( arguments["function_name"], arguments["rename_pairs_old2new"] ) return [TextContent( type="text", text=result )] case IDATools.RENAME_MULTI_GLOBAL_VARIABLES: result: str = ida_functions.rename_multi_global_variables( arguments["rename_pairs_old2new"] ) return [TextContent( type="text", text=result )] case IDATools.RENAME_MULTI_FUNCTIONS: result: str = ida_functions.rename_multi_functions( arguments["rename_pairs_old2new"] ) return [TextContent( type="text", text=result )] case IDATools.ADD_ASSEMBLY_COMMENT: result: str = ida_functions.add_assembly_comment( arguments["address"], arguments["comment"], arguments.get("is_repeatable", False) ) return [TextContent( type="text", text=result )] case IDATools.ADD_FUNCTION_COMMENT: result: str = ida_functions.add_function_comment( arguments["function_name"], arguments["comment"], arguments.get("is_repeatable", False) ) return [TextContent( type="text", text=result )] case IDATools.ADD_PSEUDOCODE_COMMENT: result: str = ida_functions.add_pseudocode_comment( arguments["function_name"], arguments["address"], arguments["comment"], arguments.get("is_repeatable", False) ) return [TextContent( type="text", text=result )] case IDATools.EXECUTE_SCRIPT: try: if "script" not in arguments or not arguments["script"]: return [TextContent( type="text", text="Error: No script content provided" )] result: str = ida_functions.execute_script(arguments["script"]) return [TextContent( type="text", text=result )] except Exception as e: logger.error(f"Error executing script: {str(e)}", exc_info=True) return [TextContent( type="text", text=f"Error executing script: {str(e)}" )] case IDATools.EXECUTE_SCRIPT_FROM_FILE: try: if "file_path" not in arguments or not arguments["file_path"]: return [TextContent( type="text", text="Error: No file path provided" )] result: str = ida_functions.execute_script_from_file(arguments["file_path"]) return [TextContent( type="text", text=result )] except Exception as e: logger.error(f"Error executing script from file: {str(e)}", exc_info=True) return [TextContent( type="text", text=f"Error executing script from file: {str(e)}" )] case _: raise ValueError(f"Unknown tool: {name}") except Exception as e: logger.error(f"Error calling tool: {str(e)}", exc_info=True) return [TextContent( type="text", text=f"Error executing {name}: {str(e)}" )] options = server.create_initialization_options() async with stdio_server() as (read_stream, write_stream): await server.run(read_stream, write_stream, options, raise_exceptions=True) ```