# Directory Structure ``` ├── .gitignore ├── README.md ├── requirements.txt └── src ├── client.py ├── config.py ├── exceptions.py ├── models.py ├── server.py ├── service.py ├── task_handler.py └── utils.py ``` # Files -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- ``` 1 | # Python-generated files 2 | __pycache__/ 3 | *.py[oc] 4 | build/ 5 | dist/ 6 | wheels/ 7 | *.egg-info 8 | 9 | # Virtual environments 10 | .venv 11 | ``` -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- ``` 1 | # Core dependencies for Midjourney MCP Server 2 | httpx>=0.28.1,<1.0.0 3 | mcp>=1.9.1,<2.0.0 4 | pydantic>=2.11.5,<3.0.0 5 | python-dotenv>=1.1.0 6 | ``` -------------------------------------------------------------------------------- /src/exceptions.py: -------------------------------------------------------------------------------- ```python 1 | """Custom exceptions for Midjourney MCP server.""" 2 | 3 | 4 | class MidjourneyMCPError(Exception): 5 | """Base exception for Midjourney MCP errors.""" 6 | pass 7 | 8 | 9 | class ConfigurationError(MidjourneyMCPError): 10 | """Raised when there's a configuration error.""" 11 | pass 12 | 13 | 14 | class APIError(MidjourneyMCPError): 15 | """Base class for API-related errors.""" 16 | 17 | def __init__(self, message: str, status_code: int = None, response_data: dict = None): 18 | super().__init__(message) 19 | self.status_code = status_code 20 | self.response_data = response_data or {} 21 | 22 | 23 | class AuthenticationError(APIError): 24 | """Raised when API authentication fails.""" 25 | pass 26 | 27 | 28 | class RateLimitError(APIError): 29 | """Raised when API rate limit is exceeded.""" 30 | pass 31 | 32 | 33 | class TaskSubmissionError(APIError): 34 | """Raised when task submission 
fails.""" 35 | pass 36 | 37 | 38 | class TaskNotFoundError(APIError): 39 | """Raised when a task is not found.""" 40 | pass 41 | 42 | 43 | class TaskFailedError(APIError): 44 | """Raised when a task fails to complete.""" 45 | pass 46 | 47 | 48 | class TimeoutError(MidjourneyMCPError): 49 | """Raised when an operation times out.""" 50 | pass 51 | 52 | 53 | class ValidationError(MidjourneyMCPError): 54 | """Raised when input validation fails.""" 55 | pass 56 | 57 | 58 | class NetworkError(MidjourneyMCPError): 59 | """Raised when network operations fail.""" 60 | pass 61 | ``` -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- ```python 1 | """Configuration management for Midjourney MCP server.""" 2 | 3 | import os 4 | import logging 5 | import sys 6 | from typing import Optional 7 | 8 | 9 | class Config: 10 | """Configuration settings for the Midjourney MCP server.""" 11 | 12 | def __init__(self): 13 | """Initialize configuration from environment variables.""" 14 | # GPTNB API Configuration 15 | self.gptnb_api_key = os.getenv("GPTNB_API_KEY", "") 16 | self.gptnb_base_url = os.getenv("GPTNB_BASE_URL", "https://api.gptnb.ai") 17 | 18 | # Request Configuration 19 | self.timeout = int(os.getenv("TIMEOUT", "300")) 20 | self.max_retries = int(os.getenv("MAX_RETRIES", "3")) 21 | self.retry_delay = float(os.getenv("RETRY_DELAY", "1.0")) 22 | 23 | # Optional Configuration 24 | self.notify_hook = os.getenv("NOTIFY_HOOK") 25 | 26 | # Midjourney Settings 27 | self.default_suffix = os.getenv("DEFAULT_SUFFIX", "--v 6.1") 28 | 29 | # Logging 30 | self.log_level = os.getenv("LOG_LEVEL", "INFO") 31 | 32 | # Validate configuration 33 | self._validate() 34 | 35 | def _validate(self): 36 | """Validate configuration values.""" 37 | if not self.gptnb_api_key: 38 | raise ValueError("GPTNB_API_KEY environment variable is required") 39 | 40 | if not 
def setup_logging(log_level: Optional[str] = None) -> None:
    """Configure root logging for the Midjourney MCP server.

    Log records are written to ``stderr``.  The server is run over the MCP
    stdio transport, where ``stdout`` carries the protocol message stream, so
    a log line printed to ``stdout`` would corrupt that stream.

    Calling this function again replaces the handlers installed by earlier
    calls instead of stacking duplicates.

    Args:
        log_level: Log level name that overrides the configured one.  When
            ``None`` or empty, the level from the global configuration is
            used.  Names that are not real logging levels fall back to
            ``INFO``.
    """
    # Only consult the global configuration when no override is given.
    level = log_level or get_config().log_level

    # Translate the level name.  ``getattr`` may return a non-level attribute
    # of the logging module (e.g. a function), so require an int.
    numeric_level = getattr(logging, level.upper(), logging.INFO)
    if not isinstance(numeric_level, int):
        numeric_level = logging.INFO

    formatter = logging.Formatter(
        fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    root_logger = logging.getLogger()
    root_logger.setLevel(numeric_level)

    # Drop handlers from earlier configuration so records are not duplicated.
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)

    # stderr, never stdout: stdout is reserved for the stdio transport.
    console_handler = logging.StreamHandler(sys.stderr)
    console_handler.setLevel(numeric_level)
    console_handler.setFormatter(formatter)
    root_logger.addHandler(console_handler)

    # The HTTP libraries are chatty at INFO/DEBUG; keep them quiet.
    logging.getLogger("httpx").setLevel(logging.WARNING)
    logging.getLogger("httpcore").setLevel(logging.WARNING)

    # Loggers handed out by get_logger() live under this package logger.
    package_logger = logging.getLogger("midjourney_mcp")
    package_logger.setLevel(numeric_level)
class TaskManager:
    """Manages the Midjourney task lifecycle: submit, poll, and format.

    Instance attributes (set in ``__init__``):
        client: API client; its ``get_task`` coroutine is used for polling.
        config: Server configuration; ``config.timeout`` bounds polling.
        poll_interval: Seconds to sleep between two polls.
        max_poll_time: Maximum number of seconds to wait for a task.
    """

    def __init__(self, client: GPTNBClient, config: Config):
        """Initialize task manager.

        Args:
            client: API client used to query task state.
            config: Configuration object providing ``timeout``.
        """
        self.client = client
        self.config = config
        self.poll_interval = 5  # seconds
        self.max_poll_time = config.timeout

    async def submit_and_wait(self, submit_func, *args, **kwargs) -> TaskDetail:
        """Submit a task and wait for completion.

        Args:
            submit_func: Coroutine function that submits the task and
                returns a ``TaskResponse``.
            *args: Positional arguments forwarded to ``submit_func``.
            **kwargs: Keyword arguments forwarded to ``submit_func``.

        Returns:
            The finished task, as returned by ``wait_for_completion``.

        Raises:
            TaskFailedError: If the submission is rejected (``code != 1``),
                no task ID is returned, or the task itself fails.
            TimeoutError: If the task does not finish in time.
            TaskNotFoundError: If the submitted task cannot be found.
        """
        # Callables such as functools.partial objects have no __name__.
        func_name = getattr(submit_func, "__name__", repr(submit_func))
        logger.info(f"Submitting task with function: {func_name}")
        response: TaskResponse = await submit_func(*args, **kwargs)

        if response.code != 1:
            raise TaskFailedError(f"Task submission failed: {response.description}")

        if not response.result:
            raise TaskFailedError("No task ID returned from submission")

        task_id = response.result
        logger.info(f"Task submitted successfully with ID: {task_id}")

        return await self.wait_for_completion(task_id)

    async def wait_for_completion(self, task_id: str) -> TaskDetail:
        """Wait for task completion by polling.

        Only errors raised by the status request itself are treated as
        transient and retried.  A failed task, a missing task, and the
        timeout are terminal and propagate to the caller.  The timeout is
        checked on every iteration, including iterations where the status
        request failed, so polling always ends.

        Args:
            task_id: ID of the task to poll.

        Returns:
            The task once its status is ``TaskStatus.SUCCESS``.

        Raises:
            TaskFailedError: If the task status is ``TaskStatus.FAILURE``.
            TaskNotFoundError: If the client reports the task as missing.
            TimeoutError: If more than ``max_poll_time`` seconds elapse.
        """
        logger.info(f"Waiting for task completion: {task_id}")
        loop = asyncio.get_running_loop()
        start_time = loop.time()

        while True:
            # Keep the try body to the network call only, so the terminal
            # exceptions raised below are never caught by the retry handler.
            try:
                task = await self.client.get_task(task_id)
            except TaskNotFoundError:
                logger.error(f"Task {task_id} not found")
                raise
            except Exception as e:
                # Best-effort retry of polling errors until the deadline.
                logger.error(f"Error polling task {task_id}: {e}")
            else:
                logger.debug(f"Task {task_id} status: {task.status}, progress: {task.progress}")

                if task.status == TaskStatus.SUCCESS:
                    logger.info(f"Task {task_id} completed successfully")
                    return task

                if task.status == TaskStatus.FAILURE:
                    error_msg = task.failReason or "Unknown error"
                    logger.error(f"Task {task_id} failed: {error_msg}")
                    raise TaskFailedError(f"Task failed: {error_msg}")

            elapsed = loop.time() - start_time
            if elapsed > self.max_poll_time:
                logger.error(f"Task {task_id} timed out after {elapsed:.1f} seconds")
                raise TimeoutError(f"Task timed out after {self.max_poll_time} seconds")

            await asyncio.sleep(self.poll_interval)

    async def get_task_status(self, task_id: str) -> TaskDetail:
        """Get current task status.

        Args:
            task_id: ID of the task to look up.

        Returns:
            The task as returned by the client's ``get_task``.
        """
        task = await self.client.get_task(task_id)
        return task

    def format_task_result(self, task: TaskDetail) -> str:
        """Format task result for display.

        Args:
            task: Task whose status, URL, description and ID are rendered.

        Returns:
            A Markdown-formatted message that depends on the task status.
        """
        if task.status == TaskStatus.SUCCESS:
            if task.imageUrl:
                result = "✅ Task completed successfully!\n\n"
                result += f"**Image URL:** {task.imageUrl}\n\n"
                result += "🖼️ **Generated Image:**\n"
                result += "\n\n"
                result += f"📎 **Direct Link:** {task.imageUrl}\n\n"
                result += f"**Task ID:** {task.id}"
                return result
            elif task.description:
                return f"✅ Task completed successfully!\n\n**Result:** {task.description}\n\n**Task ID:** {task.id}"
            else:
                return f"✅ Task completed successfully!\n\n**Task ID:** {task.id}"

        elif task.status == TaskStatus.FAILURE:
            error_msg = task.failReason or "Unknown error"
            return f"❌ Task failed: {error_msg}\n\n**Task ID:** {task.id}"

        elif task.status == TaskStatus.IN_PROGRESS:
            progress = task.progress or "Processing"
            return f"🔄 Task in progress: {progress}\n\n**Task ID:** {task.id}"

        else:
            return f"⏳ Task status: {task.status}\n\n**Task ID:** {task.id}"
def validate_aspect_ratio(aspect_ratio: str) -> bool:
    """Validate aspect ratio format.

    The whole string must be ``<digits>:<digits>`` using ASCII digits only.
    Surrounding whitespace, including a trailing newline, is rejected.

    Args:
        aspect_ratio: Aspect ratio string (e.g., "16:9", "1:1")

    Returns:
        True if valid, False otherwise (including for non-string input)
    """
    if not isinstance(aspect_ratio, str):
        return False

    # fullmatch, not match + '$': '$' also matches before a trailing newline.
    return re.fullmatch(r'[0-9]+:[0-9]+', aspect_ratio) is not None
def validate_base64_images(base64_images: List[str], min_count: int = 1, max_count: int = 5) -> List[str]:
    """Check a list of base64 images and return them as data URLs.

    Args:
        base64_images: Base64 image strings to check
        min_count: Smallest number of images accepted
        max_count: Largest number of images accepted

    Returns:
        The images, each passed through ``format_base64_image``

    Raises:
        ValidationError: If the count is out of range or an image is invalid
    """
    # An empty or missing list is only acceptable when nothing is required.
    if not base64_images:
        if min_count <= 0:
            return []
        raise ValidationError(f"At least {min_count} image(s) required")

    total = len(base64_images)
    if total < min_count:
        raise ValidationError(f"At least {min_count} image(s) required")
    if total > max_count:
        raise ValidationError(f"Maximum {max_count} image(s) allowed")

    def _checked(position, image):
        # Validate one image, then normalise it; report its position on error.
        if validate_base64_image(image):
            return format_base64_image(image)
        raise ValidationError(f"Invalid base64 image at index {position}")

    return [_checked(position, image) for position, image in enumerate(base64_images)]
def truncate_text(text: str, max_length: int = 100) -> str:
    """Truncate text to specified length.

    Text longer than ``max_length`` is cut and ends in ``"..."``.  The
    result is never longer than ``max_length``.  When ``max_length`` is too
    small to hold the ellipsis, the text is hard-cut without one.

    Args:
        text: Input text
        max_length: Maximum length of the returned string

    Returns:
        Truncated text
    """
    if len(text) <= max_length:
        return text

    ellipsis = "..."
    if max_length < len(ellipsis):
        # A negative slice bound would count from the end and return too much.
        return text[:max(max_length, 0)]

    return text[:max_length - len(ellipsis)] + ellipsis
@mcp.tool()
async def imagine_image(
    prompt: str,
    aspect_ratio: str = "1:1",
    base64_images: Optional[List[str]] = None
) -> str:
    """Generate images from text prompts with optional reference images.

    Args:
        prompt: Text description of the image to generate (English only)
        aspect_ratio: Aspect ratio of the image (e.g., "16:9", "1:1", "9:16")
        base64_images: Optional list of reference images in base64 format

    Returns:
        Generated image URL and task information
    """
    # NOTE: the docstring above is kept verbatim because the decorated
    # function's docstring is exposed to MCP clients as the tool description.
    try:
        logger.info(f"Imagine request: {prompt[:100]}...")
        service = await get_service()
        return await service.imagine(prompt, aspect_ratio, base64_images)
    except Exception as e:
        # Tool boundary: report the failure as text instead of raising, but
        # keep the traceback in the log for diagnosis.
        logger.exception(f"Error in imagine_image: {e}")
        return f"❌ Error generating image: {str(e)}"
@mcp.tool()
async def describe_image(base64_image: str) -> str:
    """Generate text descriptions of an image.

    Args:
        base64_image: Image to describe in base64 format

    Returns:
        Text description of the image
    """
    # NOTE: the docstring above is kept verbatim because the decorated
    # function's docstring is exposed to MCP clients as the tool description.
    try:
        logger.info("Describe image request")
        service = await get_service()
        return await service.describe(base64_image)
    except Exception as e:
        # Tool boundary: report the failure as text instead of raising, but
        # keep the traceback in the log for diagnosis.
        logger.exception(f"Error in describe_image: {e}")
        return f"❌ Error describing image: {str(e)}"
@mcp.tool()
async def modal_edit(
    task_id: str,
    action: str,
    prompt: Optional[str] = None
) -> str:
    """Perform advanced editing like zoom, pan, or inpainting.

    Args:
        task_id: ID of the original generation task
        action: Edit action type (zoom, pan, inpaint, etc.)
        prompt: Additional prompt for the edit

    Returns:
        Edited image URL and task information
    """
    # NOTE: the docstring above is kept verbatim because the decorated
    # function's docstring is exposed to MCP clients as the tool description.
    try:
        logger.info(f"Modal edit request: {action} for task {task_id}")
        service = await get_service()
        return await service.modal_edit(task_id, action, prompt)
    except Exception as e:
        # Tool boundary: report the failure as text instead of raising, but
        # keep the traceback in the log for diagnosis.
        logger.exception(f"Error in modal_edit: {e}")
        return f"❌ Error in modal edit: {str(e)}"
@mcp.tool()
async def get_task_status(task_id: str) -> str:
    """Get current status of a Midjourney task.

    Args:
        task_id: Task ID to check

    Returns:
        Current task status and details
    """
    # NOTE: the docstring above is kept verbatim because the decorated
    # function's docstring is exposed to MCP clients as the tool description.
    try:
        logger.info(f"Task status request: {task_id}")
        service = await get_service()
        return await service.get_task_status(task_id)
    except Exception as e:
        # Tool boundary: report the failure as text instead of raising, but
        # keep the traceback in the log for diagnosis.
        logger.exception(f"Error in get_task_status: {e}")
        return f"❌ Error getting task status: {str(e)}"
# Pydantic v2 validator API; imported here because this block introduces it.
from pydantic import field_validator


class Button(BaseModel):
    """Button model for task actions.

    ``type`` and ``style`` may arrive from the API as integers or strings;
    both are normalised to strings before field validation runs.
    """
    customId: str = Field(..., description="Custom ID for action submission")
    label: str = Field(..., description="Button label")
    type: Union[str, int] = Field(..., description="Button type")
    style: Union[str, int] = Field(..., description="Button style")
    emoji: str = Field(..., description="Button emoji")

    # field_validator(mode='before') is the v2 replacement for the
    # deprecated v1 ``validator(..., pre=True)``.
    @field_validator('type', mode='before')
    @classmethod
    def convert_type_to_string(cls, v):
        """Convert type to string if it's an integer."""
        return str(v) if v is not None else v

    @field_validator('style', mode='before')
    @classmethod
    def convert_style_to_string(cls, v):
        """Convert style to string if it's an integer."""
        return str(v) if v is not None else v
class TaskDetail(BaseModel):
    """Detailed task information."""
    id: Optional[str] = Field(None, description="Task ID")
    action: Optional[TaskAction] = Field(None, description="Task action")
    prompt: Optional[str] = Field(None, description="Original prompt")
    promptEn: Optional[str] = Field(None, description="English prompt")
    description: Optional[str] = Field(None, description="Task description")
    status: Optional[TaskStatus] = Field(None, description="Task status")
    progress: Optional[str] = Field(None, description="Task progress")
    imageUrl: Optional[str] = Field(None, description="Generated image URL")
    failReason: Optional[str] = Field(None, description="Failure reason")
    submitTime: Optional[int] = Field(None, description="Submit timestamp")
    startTime: Optional[int] = Field(None, description="Start timestamp")
    finishTime: Optional[int] = Field(None, description="Finish timestamp")
    state: Optional[str] = Field(None, description="Custom state")
    buttons: Optional[List[Button]] = Field(default_factory=list, description="Available action buttons")
    properties: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Additional properties")

    @classmethod
    def from_api_response(cls, data: Dict[str, Any]) -> "TaskDetail":
        """Create TaskDetail from API response, handling missing/invalid fields.

        Args:
            data: Raw task dictionary from the API.

        Returns:
            A TaskDetail built from ``data`` after these clean-ups:
            an empty or unknown ``status`` becomes ``NOT_START``; an unknown
            ``action`` becomes ``None``; ``buttons`` becomes a list of
            dictionaries with every button field present and non-null;
            ``properties`` becomes a dictionary.
        """
        # Handle empty or invalid status
        status = data.get('status', 'NOT_START')
        if not status or status not in [s.value for s in TaskStatus]:
            status = 'NOT_START'

        # Handle empty or invalid action
        action = data.get('action')
        if action and action not in [a.value for a in TaskAction]:
            action = None

        # Handle buttons - anything that is not a list yields no buttons,
        # and entries that are not dictionaries are dropped.
        button_fields = ('customId', 'label', 'type', 'style', 'emoji')
        raw_buttons = data.get('buttons')
        buttons = []
        if isinstance(raw_buttons, list):
            for button_data in raw_buttons:
                if not isinstance(button_data, dict):
                    continue
                # Default both missing keys and explicit nulls to '': the
                # Button fields are required and do not accept None.
                buttons.append({
                    field: '' if button_data.get(field) is None else button_data[field]
                    for field in button_fields
                })

        # Handle properties - ensure it's a dict
        properties = data.get('properties')
        if not isinstance(properties, dict):
            properties = {}

        # Cleaned values override the raw ones carried over from ``data``.
        cleaned_data = {
            **data,
            'status': status,
            'action': action,
            'buttons': buttons,
            'properties': properties
        }

        return cls(**cleaned_data)
description="Output dimensions") 147 | notifyHook: Optional[str] = Field(None, description="Callback URL") 148 | state: Optional[str] = Field(None, description="Custom state") 149 | 150 | 151 | class DescribeRequest(BaseModel): 152 | """Request model for describe task.""" 153 | base64: str = Field(..., description="Image in base64 format") 154 | notifyHook: Optional[str] = Field(None, description="Callback URL") 155 | state: Optional[str] = Field(None, description="Custom state") 156 | 157 | 158 | class ChangeRequest(BaseModel): 159 | """Request model for change task.""" 160 | action: str = Field(..., description="Action type (UPSCALE, VARIATION, REROLL)") 161 | index: Optional[int] = Field(None, description="Image index (1-4)") 162 | taskId: str = Field(..., description="Original task ID") 163 | notifyHook: Optional[str] = Field(None, description="Callback URL") 164 | state: Optional[str] = Field(None, description="Custom state") 165 | 166 | 167 | class SwapFaceRequest(BaseModel): 168 | """Request model for swap face task.""" 169 | sourceBase64: str = Field(..., description="Source face image in base64") 170 | targetBase64: str = Field(..., description="Target image in base64") 171 | notifyHook: Optional[str] = Field(None, description="Callback URL") 172 | state: Optional[str] = Field(None, description="Custom state") 173 | 174 | 175 | class ModalRequest(BaseModel): 176 | """Request model for modal task.""" 177 | taskId: str = Field(..., description="Original task ID") 178 | maskBase64: str = Field(..., description="Mask image in base64") 179 | prompt: Optional[str] = Field(None, description="Additional prompt") 180 | notifyHook: Optional[str] = Field(None, description="Callback URL") 181 | state: Optional[str] = Field(None, description="Custom state") 182 | ``` -------------------------------------------------------------------------------- /src/client.py: -------------------------------------------------------------------------------- ```python 1 | """GPTNB API 
class GPTNBClient:
    """GPTNB API client for Midjourney operations.

    Wraps a lazily created ``httpx.AsyncClient`` and exposes one coroutine
    per API endpoint. Use it as an async context manager, or call
    :meth:`close` explicitly when done.
    """

    # Status codes that map to a dedicated exception type and message.
    # Anything else that is not 200 becomes a plain APIError.
    _STATUS_ERRORS = {
        401: (AuthenticationError, "Invalid API key or authentication failed"),
        429: (RateLimitError, "Rate limit exceeded"),
        404: (TaskNotFoundError, "Task not found"),
    }

    # Only server-side failures are worth retrying; a 4xx response will not
    # change if the same request is sent again.
    _RETRYABLE_STATUS_FLOOR = 500

    def __init__(self, config: Config):
        """Initialize the GPTNB client.

        Args:
            config: Configuration object providing the base URL, API key,
                timeout and retry settings.
        """
        self.config = config
        self.base_url = config.gptnb_base_url.rstrip('/')
        self.headers = {
            "Authorization": f"Bearer {config.gptnb_api_key}",
            "Content-Type": "application/json",
            "User-Agent": "midjourney-mcp/0.2.0"
        }
        # Created on first use by _ensure_client().
        self._client: Optional[httpx.AsyncClient] = None

    async def __aenter__(self):
        """Async context manager entry: create the HTTP client."""
        await self._ensure_client()
        return self

    async def __aexit__(self, exc_type: Optional[type], exc_val: Optional[Exception], exc_tb: Optional[object]) -> None:
        """Async context manager exit: close the HTTP client."""
        await self.close()

    async def _ensure_client(self):
        """Create the underlying HTTP client if it does not exist yet."""
        if self._client is None:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(self.config.timeout),
                headers=self.headers
            )

    async def close(self) -> None:
        """Close the HTTP client; a later request will create a new one."""
        if self._client:
            await self._client.aclose()
            self._client = None

    @staticmethod
    def _error_payload(response: Any) -> Any:
        """Decode an error response body without ever raising.

        Error pages are not guaranteed to be JSON; a decode failure must not
        hide the HTTP status that is about to be reported.

        Args:
            response: The HTTP response object.

        Returns:
            The decoded JSON body, ``{}`` for an empty body, or
            ``{"raw": <text>}`` when the body is not valid JSON.
        """
        if not response.content:
            return {}
        try:
            return response.json()
        except ValueError:
            return {"raw": response.text}

    def _build_status_error(self, response: Any) -> APIError:
        """Build the exception that corresponds to a non-200 response.

        Args:
            response: The HTTP response object.

        Returns:
            An ``APIError`` (or subclass) carrying the status code and the
            decoded response body.
        """
        status = response.status_code
        error_cls, message = self._STATUS_ERRORS.get(
            status, (APIError, f"API request failed with status {status}")
        )
        return error_cls(
            message,
            status_code=status,
            response_data=self._error_payload(response)
        )

    async def _make_request(
        self,
        method: str,
        endpoint: str,
        data: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Make HTTP request with error handling and retries.

        Timeouts, network errors, 5xx responses and unexpected failures are
        retried up to ``config.max_retries`` times with exponential backoff.
        Responses below 500 (including 401, 404 and 429) are raised
        immediately with their status code and body intact.

        Args:
            method: HTTP method
            endpoint: API endpoint path, appended to the base URL
            data: Request body data (sent as JSON)
            params: Query parameters

        Returns:
            Decoded JSON body of the 200 response

        Raises:
            AuthenticationError: On HTTP 401
            RateLimitError: On HTTP 429
            TaskNotFoundError: On HTTP 404
            APIError: For any other non-200 status or unexpected failure
            NetworkError: For network-related errors after all retries
            TimeoutError: For timeouts after all retries
        """
        await self._ensure_client()
        url = f"{self.base_url}{endpoint}"
        max_retries = self.config.max_retries

        for attempt in range(max_retries + 1):
            last_attempt = attempt == max_retries
            try:
                logger.debug("Making %s request to %s (attempt %s)", method, url, attempt + 1)

                response = await self._client.request(
                    method=method,
                    url=url,
                    json=data,
                    params=params
                )

                if response.status_code == 200:
                    return response.json()
                raise self._build_status_error(response)

            except httpx.TimeoutException as e:
                if last_attempt:
                    raise TimeoutError(f"Request timed out after {self.config.timeout} seconds") from e
                logger.warning("Request timeout on attempt %s, retrying...", attempt + 1)

            except httpx.NetworkError as e:
                if last_attempt:
                    raise NetworkError(f"Network error: {str(e)}") from e
                logger.warning("Network error on attempt %s, retrying...", attempt + 1)

            except APIError as e:
                # Re-raise the original error so status_code/response_data
                # reach the caller instead of being re-wrapped.
                status = e.status_code
                retryable = status is not None and status >= self._RETRYABLE_STATUS_FLOOR
                if last_attempt or not retryable:
                    raise
                logger.warning("Server error %s on attempt %s, retrying...", status, attempt + 1)

            except Exception as e:
                # E.g. a 200 response whose body is not valid JSON.
                if last_attempt:
                    raise APIError(f"Unexpected error: {str(e)}") from e
                logger.warning("Unexpected error on attempt %s, retrying...", attempt + 1)

            # Wait before retry
            if attempt < max_retries:
                wait_time = self.config.retry_delay * (2 ** attempt)  # Exponential backoff
                logger.debug("Waiting %s seconds before retry...", wait_time)
                await asyncio.sleep(wait_time)

        # Only reachable when max_retries is negative and the loop never runs.
        raise APIError("Max retries exceeded")

    def _prepare_request_data(self, request: Any) -> Dict[str, Any]:
        """Prepare request data with notify hook if configured.

        Args:
            request: Request object with model_dump method

        Returns:
            Request fields with ``None`` values removed; ``notifyHook`` is
            filled from the configuration when the request did not set one.
        """
        data = request.model_dump(exclude_none=True)
        if self.config.notify_hook and not data.get('notifyHook'):
            data['notifyHook'] = self.config.notify_hook
        return data

    async def _submit(self, endpoint: str, request: Any) -> TaskResponse:
        """POST a prepared request model to a submit endpoint.

        Args:
            endpoint: Submit endpoint path
            request: Request model to serialize

        Returns:
            Task response parsed from the API reply
        """
        payload = self._prepare_request_data(request)
        response_data = await self._make_request("POST", endpoint, payload)
        return TaskResponse(**response_data)

    async def submit_imagine(self, request: ImagineRequest) -> TaskResponse:
        """Submit an imagine task.

        Args:
            request: Imagine request data

        Returns:
            Task response
        """
        return await self._submit("/mj/submit/imagine", request)

    async def submit_blend(self, request: BlendRequest) -> TaskResponse:
        """Submit a blend task.

        Args:
            request: Blend request data

        Returns:
            Task response
        """
        return await self._submit("/mj/submit/blend", request)

    async def submit_describe(self, request: DescribeRequest) -> TaskResponse:
        """Submit a describe task.

        Args:
            request: Describe request data

        Returns:
            Task response
        """
        return await self._submit("/mj/submit/describe", request)

    async def submit_change(self, request: ChangeRequest) -> TaskResponse:
        """Submit a change task (upscale, variation, reroll).

        Args:
            request: Change request data

        Returns:
            Task response
        """
        return await self._submit("/mj/submit/change", request)

    async def submit_swap_face(self, request: SwapFaceRequest) -> TaskResponse:
        """Submit a swap face task.

        Args:
            request: Swap face request data

        Returns:
            Task response
        """
        return await self._submit("/mj/submit/swap-face", request)

    async def submit_modal(self, request: ModalRequest) -> TaskResponse:
        """Submit a modal task (zoom, pan, inpainting).

        Args:
            request: Modal request data

        Returns:
            Task response
        """
        return await self._submit("/mj/submit/modal", request)

    async def get_task(self, task_id: str) -> TaskDetail:
        """Get task details by ID.

        Args:
            task_id: Task ID

        Returns:
            Task details
        """
        response_data = await self._make_request("GET", f"/mj/task/{task_id}/fetch")
        return TaskDetail.from_api_response(response_data)

    async def get_tasks(self, task_ids: List[str]) -> List[TaskDetail]:
        """Get multiple tasks by IDs.

        Args:
            task_ids: List of task IDs

        Returns:
            List of task details
        """
        data = {"ids": task_ids}
        response_data = await self._make_request("POST", "/mj/task/list-by-condition", data)
        # NOTE(review): get_task() builds its result through
        # TaskDetail.from_api_response while this builds TaskDetail directly;
        # confirm whether list items need the same conversion.
        return [TaskDetail(**task) for task in response_data]
def handle_service_errors(operation_name: str):
    """Decorator to handle common service errors.

    The wrapped coroutine never raises: any exception is logged and turned
    into the string produced by ``format_error_message``.

    Args:
        operation_name: Name of the operation for error messages
    """
    def decorator(func: Callable) -> Callable:
        """Wrap ``func`` with the error-to-message conversion."""
        @wraps(func)
        async def wrapper(*args, **kwargs) -> str:
            """Run the operation and format any failure as a string."""
            try:
                return await func(*args, **kwargs)
            except Exception as e:
                logger.error(f"Error in {operation_name}: {e}")
                return format_error_message(e, operation_name)
        return wrapper
    return decorator


class MidjourneyService:
    """High-level service for Midjourney operations.

    Every public operation validates its inputs, talks to the API through
    ``GPTNBClient`` and returns a formatted string. Failures, including
    validation errors, are returned as formatted error messages rather than
    raised.
    """

    # Actions accepted by change(), and the subset that needs an image index.
    _CHANGE_ACTIONS = ("UPSCALE", "VARIATION", "REROLL")
    _INDEXED_ACTIONS = ("UPSCALE", "VARIATION")

    def __init__(self, config: Optional[Config] = None):
        """Initialize Midjourney service.

        Args:
            config: Configuration object (uses global config if None)
        """
        self.config = config or get_config()
        # Both are created in __aenter__.
        self.client: Optional[GPTNBClient] = None
        self.task_manager: Optional[TaskManager] = None

    async def __aenter__(self):
        """Async context manager entry: create the client and task manager."""
        self.client = GPTNBClient(self.config)
        await self.client.__aenter__()
        self.task_manager = TaskManager(self.client, self.config)
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: close the client if one was created."""
        if self.client:
            await self.client.__aexit__(exc_type, exc_val, exc_tb)

    @staticmethod
    def _require_task_id(response) -> str:
        """Return the task ID from a submission response or raise.

        Args:
            response: Task response returned by a submit call

        Returns:
            The task ID stored in ``response.result``

        Raises:
            Exception: If the response code is not 1 or no task ID is present
        """
        if response.code != 1:
            raise Exception(f"Task submission failed: {response.description}")
        if not response.result:
            raise Exception("No task ID returned from submission")
        return response.result

    @handle_service_errors("imagine")
    async def imagine(
        self,
        prompt: str,
        aspect_ratio: str = "1:1",
        base64_images: Optional[List[str]] = None
    ) -> str:
        """Generate images from text prompt.

        The task is only submitted; the caller polls for the result with
        ``get_task_status``.

        Args:
            prompt: Text description of the image
            aspect_ratio: Aspect ratio (e.g., "16:9", "1:1", "9:16")
            base64_images: Optional reference images

        Returns:
            Formatted result string, or a formatted error message if
            validation or submission fails
        """
        prompt = validate_prompt(prompt)
        if not validate_aspect_ratio(aspect_ratio):
            raise ValidationError(f"Invalid aspect ratio: {aspect_ratio}")

        # Add aspect ratio to prompt if not already present
        if "--ar" not in prompt:
            prompt = f"{prompt} --ar {aspect_ratio}"

        # Add default suffix if configured
        suffix = self.config.default_suffix
        if suffix and suffix not in prompt:
            prompt = f"{prompt} {suffix}"

        # Validate and format images if provided
        validated_images = None
        if base64_images:
            validated_images = validate_base64_images(base64_images, min_count=0, max_count=5)

        request = ImagineRequest(
            prompt=prompt,
            base64Array=validated_images
        )

        # Submit task (don't wait for completion)
        response = await self.client.submit_imagine(request)
        task_id = self._require_task_id(response)

        return (
            "🎨 **Image Generation Started!**\n\n"
            f"**Task ID:** {task_id}\n"
            f"**Prompt:** {prompt}\n"
            f"**Aspect Ratio:** {aspect_ratio}\n"
            "**Status:** Task submitted successfully\n\n"
            "💡 **Next Steps:**\n"
            f'Use `get_task_status(task_id="{task_id}")` to check current status\n\n'
            "⏱️ **Estimated Time:** 30-60 seconds\n"
        )

    @handle_service_errors("blend")
    async def blend(
        self,
        base64_images: List[str],
        dimensions: str = "SQUARE"
    ) -> str:
        """Blend multiple images together.

        The task is only submitted; the caller polls for the result with
        ``get_task_status``.

        Args:
            base64_images: List of 2-5 images to blend
            dimensions: Output dimensions ("PORTRAIT", "SQUARE", "LANDSCAPE")

        Returns:
            Formatted result string, or a formatted error message if
            validation or submission fails
        """
        validated_images = validate_base64_images(base64_images, min_count=2, max_count=5)

        try:
            dim_enum = Dimensions(dimensions.upper())
        except ValueError as exc:
            raise ValidationError(
                f"Invalid dimensions: {dimensions}. Must be PORTRAIT, SQUARE, or LANDSCAPE"
            ) from exc

        request = BlendRequest(
            base64Array=validated_images,
            dimensions=dim_enum
        )

        # Submit task (don't wait for completion)
        response = await self.client.submit_blend(request)
        task_id = self._require_task_id(response)

        return (
            "🎨 **Image Blending Started!**\n\n"
            f"**Task ID:** {task_id}\n"
            f"**Images:** {len(base64_images)} images to blend\n"
            f"**Dimensions:** {dimensions}\n"
            "**Status:** Task submitted successfully\n\n"
            f'💡 **Monitor Progress:** Use `get_task_status(task_id="{task_id}")`\n'
        )

    @handle_service_errors("describe")
    async def describe(self, base64_image: str) -> str:
        """Generate text description of an image.

        Args:
            base64_image: Image to describe

        Returns:
            Formatted result string, or a formatted error message on failure
        """
        validated_images = validate_base64_images([base64_image], min_count=1, max_count=1)

        request = DescribeRequest(
            base64=validated_images[0]
        )

        # Submit and wait for completion
        task = await self.task_manager.submit_and_wait(
            self.client.submit_describe, request
        )
        return self.task_manager.format_task_result(task)

    @handle_service_errors("change")
    async def change(
        self,
        task_id: str,
        action: str,
        index: Optional[int] = None
    ) -> str:
        """Create variations, upscales, or rerolls of existing images.

        Args:
            task_id: ID of the original generation task
            action: Action type ("UPSCALE", "VARIATION", "REROLL"),
                case-insensitive
            index: Image index (1-4) for UPSCALE and VARIATION

        Returns:
            Formatted result string, or a formatted error message on failure
        """
        task_id = validate_task_id(task_id)
        action = action.upper()

        if action not in self._CHANGE_ACTIONS:
            raise ValidationError(f"Invalid action: {action}. Must be UPSCALE, VARIATION, or REROLL")

        if action in self._INDEXED_ACTIONS:
            if index is None:
                raise ValidationError(f"Index is required for {action} action")
            index = validate_image_index(index)

        request = ChangeRequest(
            taskId=task_id,
            action=action,
            index=index
        )

        # Submit and wait for completion
        task = await self.task_manager.submit_and_wait(
            self.client.submit_change, request
        )
        return self.task_manager.format_task_result(task)

    @handle_service_errors("modal_edit")
    async def modal_edit(
        self,
        task_id: str,
        action: str,
        prompt: Optional[str] = None
    ) -> str:
        """Perform advanced editing like zoom, pan, or inpainting.

        This is a simplified implementation: ``action`` is only logged and
        the mask is sent empty, because the method has no parameter through
        which a mask image could be supplied.

        Args:
            task_id: ID of the original generation task
            action: Edit action type (zoom, pan, inpaint, etc.)
            prompt: Additional prompt for the edit

        Returns:
            Formatted result string, or a formatted error message on failure
        """
        task_id = validate_task_id(task_id)

        if prompt:
            prompt = validate_prompt(prompt)

        logger.info(f"Performing modal edit with action: {action}")

        request = ModalRequest(
            taskId=task_id,
            maskBase64="",  # This would need to be provided by the user
            prompt=prompt
        )

        # Submit and wait for completion
        task = await self.task_manager.submit_and_wait(
            self.client.submit_modal, request
        )
        return self.task_manager.format_task_result(task)

    @handle_service_errors("swap_face")
    async def swap_face(
        self,
        source_image: str,
        target_image: str
    ) -> str:
        """Swap faces between two images.

        Args:
            source_image: Source face image in base64 format
            target_image: Target image in base64 format

        Returns:
            Formatted result string, or a formatted error message on failure
        """
        validated_images = validate_base64_images([source_image, target_image], min_count=2, max_count=2)

        request = SwapFaceRequest(
            sourceBase64=validated_images[0],
            targetBase64=validated_images[1]
        )

        # Submit and wait for completion
        task = await self.task_manager.submit_and_wait(
            self.client.submit_swap_face, request
        )
        return self.task_manager.format_task_result(task)

    @handle_service_errors("get_task_status")
    async def get_task_status(self, task_id: str) -> str:
        """Get current status of a task.

        Args:
            task_id: Task ID to check

        Returns:
            Formatted status string, or a formatted error message on failure
        """
        task_id = validate_task_id(task_id)
        task = await self.task_manager.get_task_status(task_id)
        return self.task_manager.format_task_result(task)
# Global service instance
_service_instance: Optional[MidjourneyService] = None


async def get_service() -> MidjourneyService:
    """Get or create the global service instance.

    The instance is created and entered on first use. If entering fails,
    the global is reset so the next call starts from scratch instead of
    returning a half-initialised service.

    Returns:
        MidjourneyService instance
    """
    global _service_instance
    if _service_instance is None:
        service = MidjourneyService()
        _service_instance = service
        try:
            await service.__aenter__()
        except BaseException:
            # Do not leave an un-entered service published for later callers.
            if _service_instance is service:
                _service_instance = None
            raise
    return _service_instance


async def close_service():
    """Close the global service instance.

    The instance is detached from the global before it is closed, so it is
    never handed out while shutting down and is dropped even if closing
    raises. Does nothing when no instance exists.
    """
    global _service_instance
    service, _service_instance = _service_instance, None
    if service is not None:
        await service.__aexit__(None, None, None)