# Directory Structure ``` ├── .gitignore ├── README.md ├── requirements.txt └── src ├── client.py ├── config.py ├── exceptions.py ├── models.py ├── server.py ├── service.py ├── task_handler.py └── utils.py ``` # Files -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- ``` # Python-generated files __pycache__/ *.py[oc] build/ dist/ wheels/ *.egg-info # Virtual environments .venv ``` -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- ``` # Core dependencies for Midjourney MCP Server httpx>=0.28.1,<1.0.0 mcp>=1.9.1,<2.0.0 pydantic>=2.11.5,<3.0.0 python-dotenv>=1.1.0 ``` -------------------------------------------------------------------------------- /src/exceptions.py: -------------------------------------------------------------------------------- ```python """Custom exceptions for Midjourney MCP server.""" class MidjourneyMCPError(Exception): """Base exception for Midjourney MCP errors.""" pass class ConfigurationError(MidjourneyMCPError): """Raised when there's a configuration error.""" pass class APIError(MidjourneyMCPError): """Base class for API-related errors.""" def __init__(self, message: str, status_code: int = None, response_data: dict = None): super().__init__(message) self.status_code = status_code self.response_data = response_data or {} class AuthenticationError(APIError): """Raised when API authentication fails.""" pass class RateLimitError(APIError): """Raised when API rate limit is exceeded.""" pass class TaskSubmissionError(APIError): """Raised when task submission fails.""" pass class TaskNotFoundError(APIError): """Raised when a task is not found.""" pass class TaskFailedError(APIError): """Raised when a task fails to complete.""" pass class TimeoutError(MidjourneyMCPError): """Raised when an operation times out.""" pass class ValidationError(MidjourneyMCPError): """Raised when input validation fails.""" pass class NetworkError(MidjourneyMCPError): """Raised when network operations fail.""" pass ``` -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- ```python """Configuration management for Midjourney MCP server.""" import os import logging import sys from typing import Optional class Config: """Configuration settings for the Midjourney MCP server.""" def __init__(self): """Initialize configuration from environment variables.""" # GPTNB API Configuration self.gptnb_api_key = os.getenv("GPTNB_API_KEY", "") self.gptnb_base_url = os.getenv("GPTNB_BASE_URL", "https://api.gptnb.ai") # Request Configuration self.timeout = int(os.getenv("TIMEOUT", "300")) self.max_retries = int(os.getenv("MAX_RETRIES", "3")) self.retry_delay = float(os.getenv("RETRY_DELAY", "1.0")) # Optional Configuration self.notify_hook = os.getenv("NOTIFY_HOOK") # Midjourney Settings self.default_suffix = os.getenv("DEFAULT_SUFFIX", "--v 6.1") # Logging self.log_level = os.getenv("LOG_LEVEL", "INFO") # Validate configuration self._validate() def _validate(self): """Validate configuration values.""" if not self.gptnb_api_key: raise ValueError("GPTNB_API_KEY environment variable is required") if not self.gptnb_api_key.startswith("sk-"): raise ValueError("GPTNB API key must start with 'sk-'") if self.timeout <= 0: raise ValueError("Timeout must be positive") if self.max_retries < 0: raise ValueError("Max retries must be non-negative") if self.retry_delay < 0: raise ValueError("Retry delay must be non-negative") # Validate base URL format if not self.gptnb_base_url.startswith(('http://', 'https://')): raise ValueError("Base URL must start with http:// or https://") # Validate log level valid_log_levels = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] if self.log_level.upper() not in valid_log_levels: raise ValueError(f"Invalid log level: {self.log_level}. Must be one of {valid_log_levels}") def __repr__(self): """String representation of config (without sensitive data).""" return f"Config(api_key_configured={bool(self.gptnb_api_key)}, base_url={self.gptnb_base_url})" # Global configuration instance config = Config() def get_config() -> Config: """Get the global configuration instance.""" return config def reload_config() -> Config: """Reload configuration from environment.""" global config config = Config() return config # ============================================================================ # Logging Configuration # ============================================================================ def setup_logging(log_level: Optional[str] = None) -> None: """Setup logging configuration. Args: log_level: Log level override (uses config if None) """ config_instance = get_config() level = log_level or config_instance.log_level # Convert string level to logging constant numeric_level = getattr(logging, level.upper(), logging.INFO) # Create formatter formatter = logging.Formatter( fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S' ) # Setup root logger root_logger = logging.getLogger() root_logger.setLevel(numeric_level) # Remove existing handlers for handler in root_logger.handlers[:]: root_logger.removeHandler(handler) # Create console handler console_handler = logging.StreamHandler(sys.stdout) console_handler.setLevel(numeric_level) console_handler.setFormatter(formatter) root_logger.addHandler(console_handler) # Set specific logger levels logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("httpcore").setLevel(logging.WARNING) # Set our package logger level package_logger = logging.getLogger("midjourney_mcp") package_logger.setLevel(numeric_level) def get_logger(name: str) -> logging.Logger: """Get a logger with the given name. Args: name: Logger name Returns: Logger instance """ return logging.getLogger(f"midjourney_mcp.{name}") ``` -------------------------------------------------------------------------------- /src/task_handler.py: -------------------------------------------------------------------------------- ```python """Unified task handling for Midjourney operations - task management and scheduling.""" import asyncio import logging from client import GPTNBClient from models import TaskDetail, TaskStatus, TaskResponse from config import Config from exceptions import TaskFailedError, TimeoutError, TaskNotFoundError logger = logging.getLogger(__name__) # ============================================================================ # Task Manager # ============================================================================ class TaskManager: """Manages Midjourney task lifecycle.""" def __init__(self, client: GPTNBClient, config: Config): """Initialize task manager.""" self.client = client self.config = config self.poll_interval = 5 # seconds self.max_poll_time = config.timeout async def submit_and_wait(self, submit_func, *args, **kwargs) -> TaskDetail: """Submit a task and wait for completion.""" logger.info(f"Submitting task with function: {submit_func.__name__}") response: TaskResponse = await submit_func(*args, **kwargs) if response.code != 1: raise TaskFailedError(f"Task submission failed: {response.description}") if not response.result: raise TaskFailedError("No task ID returned from submission") task_id = response.result logger.info(f"Task submitted successfully with ID: {task_id}") # Wait for completion return await self.wait_for_completion(task_id) async def wait_for_completion(self, task_id: str) -> TaskDetail: """Wait for task completion by polling.""" logger.info(f"Waiting for task completion: {task_id}") start_time = asyncio.get_event_loop().time() while True: try: # Get task status task = await self.client.get_task(task_id) logger.debug(f"Task {task_id} status: {task.status}, progress: {task.progress}") # Check if completed if task.status == TaskStatus.SUCCESS: logger.info(f"Task {task_id} completed successfully") return task # Check if failed if task.status == TaskStatus.FAILURE: error_msg = task.failReason or "Unknown error" logger.error(f"Task {task_id} failed: {error_msg}") raise TaskFailedError(f"Task failed: {error_msg}") # Check timeout elapsed = asyncio.get_event_loop().time() - start_time if elapsed > self.max_poll_time: logger.error(f"Task {task_id} timed out after {elapsed:.1f} seconds") raise TimeoutError(f"Task timed out after {self.max_poll_time} seconds") # Wait before next poll await asyncio.sleep(self.poll_interval) except TaskNotFoundError: logger.error(f"Task {task_id} not found") raise except Exception as e: logger.error(f"Error polling task {task_id}: {e}") await asyncio.sleep(self.poll_interval) async def get_task_status(self, task_id: str) -> TaskDetail: """Get current task status.""" # Get from API task = await self.client.get_task(task_id) return task def format_task_result(self, task: TaskDetail) -> str: """Format task result for display.""" if task.status == TaskStatus.SUCCESS: if task.imageUrl: result = f"✅ Task completed successfully!\n\n" result += f"**Image URL:** {task.imageUrl}\n\n" result += f"🖼️ **Generated Image:**\n" result += f"\n\n" result += f"📎 **Direct Link:** {task.imageUrl}\n\n" result += f"**Task ID:** {task.id}" return result elif task.description: return f"✅ Task completed successfully!\n\n**Result:** {task.description}\n\n**Task ID:** {task.id}" else: return f"✅ Task completed successfully!\n\n**Task ID:** {task.id}" elif task.status == TaskStatus.FAILURE: error_msg = task.failReason or "Unknown error" return f"❌ Task failed: {error_msg}\n\n**Task ID:** {task.id}" elif task.status == TaskStatus.IN_PROGRESS: progress = task.progress or "Processing" return f"🔄 Task in progress: {progress}\n\n**Task ID:** {task.id}" else: return f"⏳ Task status: {task.status}\n\n**Task ID:** {task.id}" ``` -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- ```python """Utility functions for Midjourney MCP server.""" import base64 import re from typing import List, Optional from exceptions import ValidationError def validate_base64_image(base64_str: str) -> bool: """Validate base64 image string. Args: base64_str: Base64 encoded image string Returns: True if valid, False otherwise """ try: # Check if it has data URL prefix if base64_str.startswith('data:image/'): # Extract base64 part after comma if ',' in base64_str: base64_str = base64_str.split(',', 1)[1] # Try to decode base64.b64decode(base64_str, validate=True) return True except Exception: return False def format_base64_image(base64_str: str, image_type: str = "png") -> str: """Format base64 string with proper data URL prefix. Args: base64_str: Base64 encoded image string image_type: Image type (png, jpg, jpeg, webp) Returns: Properly formatted base64 data URL """ # Remove existing data URL prefix if present if base64_str.startswith('data:image/'): return base64_str # Add data URL prefix return f"data:image/{image_type};base64,{base64_str}" def validate_aspect_ratio(aspect_ratio: str) -> bool: """Validate aspect ratio format. Args: aspect_ratio: Aspect ratio string (e.g., "16:9", "1:1") Returns: True if valid, False otherwise """ pattern = r'^\d+:\d+$' return bool(re.match(pattern, aspect_ratio)) def validate_prompt(prompt: str) -> str: """Validate and clean prompt text. Args: prompt: Input prompt Returns: Cleaned prompt Raises: ValidationError: If prompt is invalid """ if not prompt or not prompt.strip(): raise ValidationError("Prompt cannot be empty") prompt = prompt.strip() # Check length (Midjourney has limits) if len(prompt) > 4000: raise ValidationError("Prompt is too long (max 4000 characters)") return prompt def validate_task_id(task_id: str) -> str: """Validate task ID format. Args: task_id: Task ID string Returns: Validated task ID Raises: ValidationError: If task ID is invalid """ if not task_id or not task_id.strip(): raise ValidationError("Task ID cannot be empty") task_id = task_id.strip() # Basic format validation (adjust based on GPTNB format) if not task_id.isdigit() and len(task_id) < 10: raise ValidationError("Invalid task ID format") return task_id def validate_image_index(index: int) -> int: """Validate image index for variations/upscales. Args: index: Image index (1-4) Returns: Validated index Raises: ValidationError: If index is invalid """ if not isinstance(index, int) or index < 1 or index > 4: raise ValidationError("Image index must be between 1 and 4") return index def validate_base64_images(base64_images: List[str], min_count: int = 1, max_count: int = 5) -> List[str]: """Validate list of base64 images. Args: base64_images: List of base64 image strings min_count: Minimum number of images required max_count: Maximum number of images allowed Returns: Validated list of base64 images Raises: ValidationError: If validation fails """ if not base64_images: if min_count > 0: raise ValidationError(f"At least {min_count} image(s) required") return [] if len(base64_images) < min_count: raise ValidationError(f"At least {min_count} image(s) required") if len(base64_images) > max_count: raise ValidationError(f"Maximum {max_count} image(s) allowed") # Validate each image validated_images = [] for i, img in enumerate(base64_images): if not validate_base64_image(img): raise ValidationError(f"Invalid base64 image at index {i}") validated_images.append(format_base64_image(img)) return validated_images def extract_task_id_from_response(response_text: str) -> Optional[str]: """Extract task ID from response text. Args: response_text: Response text that may contain task ID Returns: Extracted task ID or None """ # Look for patterns like "Task ID: 1234567890" or similar patterns = [ r'[Tt]ask\s+ID[:\s]+(\d+)', r'ID[:\s]+(\d+)', r'(\d{10,})', # Long numeric IDs ] for pattern in patterns: match = re.search(pattern, response_text) if match: return match.group(1) return None def format_error_message(error: Exception, context: str = "") -> str: """Format error message for user display. Args: error: Exception object context: Additional context information Returns: Formatted error message """ error_type = type(error).__name__ error_msg = str(error) if context: return f"Error in {context}: {error_type} - {error_msg}" else: return f"{error_type}: {error_msg}" def truncate_text(text: str, max_length: int = 100) -> str: """Truncate text to specified length. Args: text: Input text max_length: Maximum length Returns: Truncated text """ if len(text) <= max_length: return text return text[:max_length - 3] + "..." ``` -------------------------------------------------------------------------------- /src/server.py: -------------------------------------------------------------------------------- ```python """Midjourney MCP Server - Main entry point with tool functions.""" import logging from typing import List, Optional from mcp.server.fastmcp import FastMCP from service import get_service from config import setup_logging, get_config # Initialize logging setup_logging() logger = logging.getLogger(__name__) # Initialize MCP server mcp = FastMCP("midjourney") # Get configuration config = get_config() # ============================================================================ # MCP Tool Functions - Complete Midjourney Functionality # ============================================================================ @mcp.tool() async def imagine_image( prompt: str, aspect_ratio: str = "1:1", base64_images: Optional[List[str]] = None ) -> str: """Generate images from text prompts with optional reference images. Args: prompt: Text description of the image to generate (English only) aspect_ratio: Aspect ratio of the image (e.g., "16:9", "1:1", "9:16") base64_images: Optional list of reference images in base64 format Returns: Generated image URL and task information """ try: logger.info(f"Imagine request: {prompt[:100]}...") service = await get_service() result = await service.imagine(prompt, aspect_ratio, base64_images) return result except Exception as e: logger.error(f"Error in imagine_image: {e}") return f"❌ Error generating image: {str(e)}" @mcp.tool() async def blend_images( base64_images: List[str], dimensions: str = "SQUARE" ) -> str: """Blend multiple images together. Args: base64_images: List of 2-5 images to blend in base64 format dimensions: Output dimensions ("PORTRAIT", "SQUARE", "LANDSCAPE") Returns: Blended image URL and task information """ try: logger.info(f"Blend request: {len(base64_images)} images") service = await get_service() result = await service.blend(base64_images, dimensions) return result except Exception as e: logger.error(f"Error in blend_images: {e}") return f"❌ Error blending images: {str(e)}" @mcp.tool() async def describe_image(base64_image: str) -> str: """Generate text descriptions of an image. Args: base64_image: Image to describe in base64 format Returns: Text description of the image """ try: logger.info("Describe image request") service = await get_service() result = await service.describe(base64_image) return result except Exception as e: logger.error(f"Error in describe_image: {e}") return f"❌ Error describing image: {str(e)}" @mcp.tool() async def change_image( task_id: str, action: str, index: Optional[int] = None ) -> str: """Create variations, upscales, or rerolls of existing images. Args: task_id: ID of the original generation task action: Action type ("UPSCALE", "VARIATION", "REROLL") index: Image index (1-4) for UPSCALE and VARIATION actions Returns: Modified image URL and task information """ try: logger.info(f"Change request: {action} for task {task_id}") service = await get_service() result = await service.change(task_id, action, index) return result except Exception as e: logger.error(f"Error in change_image: {e}") return f"❌ Error changing image: {str(e)}" @mcp.tool() async def modal_edit( task_id: str, action: str, prompt: Optional[str] = None ) -> str: """Perform advanced editing like zoom, pan, or inpainting. Args: task_id: ID of the original generation task action: Edit action type (zoom, pan, inpaint, etc.) prompt: Additional prompt for the edit Returns: Edited image URL and task information """ try: logger.info(f"Modal edit request: {action} for task {task_id}") service = await get_service() result = await service.modal_edit(task_id, action, prompt) return result except Exception as e: logger.error(f"Error in modal_edit: {e}") return f"❌ Error in modal edit: {str(e)}" @mcp.tool() async def swap_face(source_image: str, target_image: str) -> str: """Swap faces between two images. Args: source_image: Source face image in base64 format target_image: Target image in base64 format Returns: Face-swapped image URL and task information """ try: logger.info("Face swap request") service = await get_service() result = await service.swap_face(source_image, target_image) return result except Exception as e: logger.error(f"Error in swap_face: {e}") return f"❌ Error swapping faces: {str(e)}" # ============================================================================ # Task Management Tools # ============================================================================ @mcp.tool() async def get_task_status(task_id: str) -> str: """Get current status of a Midjourney task. Args: task_id: Task ID to check Returns: Current task status and details """ try: logger.info(f"Task status request: {task_id}") service = await get_service() result = await service.get_task_status(task_id) return result except Exception as e: logger.error(f"Error in get_task_status: {e}") return f"❌ Error getting task status: {str(e)}" # ============================================================================ # Server Lifecycle Management # ============================================================================ # Server lifecycle will be handled by the main function def main(): """Main entry point for the MCP server.""" try: logger.info("Starting Midjourney MCP Server...") logger.info(f"Configuration: API Key configured = {bool(config.gptnb_api_key)}") logger.info(f"Base URL: {config.gptnb_base_url}") logger.info("Server started successfully!") mcp.run(transport="stdio") except KeyboardInterrupt: logger.info("Server interrupted by user") except Exception as e: logger.error(f"Server error: {e}") raise if __name__ == "__main__": main() ``` -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- ```python """Data models for Midjourney MCP server.""" from enum import Enum from typing import List, Optional, Dict, Any, Union from pydantic import BaseModel, Field, validator class TaskStatus(str, Enum): """Task status enumeration.""" NOT_START = "NOT_START" SUBMITTED = "SUBMITTED" IN_PROGRESS = "IN_PROGRESS" SUCCESS = "SUCCESS" FAILURE = "FAILURE" class TaskAction(str, Enum): """Task action enumeration.""" IMAGINE = "IMAGINE" UPSCALE = "UPSCALE" VARIATION = "VARIATION" REROLL = "REROLL" DESCRIBE = "DESCRIBE" BLEND = "BLEND" SWAP_FACE = "SWAP_FACE" SHORTEN = "SHORTEN" class Dimensions(str, Enum): """Image dimensions enumeration.""" PORTRAIT = "PORTRAIT" # 2:3 SQUARE = "SQUARE" # 1:1 LANDSCAPE = "LANDSCAPE" # 3:2 class Button(BaseModel): """Button model for task actions.""" customId: str = Field(..., description="Custom ID for action submission") label: str = Field(..., description="Button label") type: Union[str, int] = Field(..., description="Button type") style: Union[str, int] = Field(..., description="Button style") emoji: str = Field(..., description="Button emoji") @validator('type', pre=True) def convert_type_to_string(cls, v): """Convert type to string if it's an integer.""" return str(v) if v is not None else v @validator('style', pre=True) def convert_style_to_string(cls, v): """Convert style to string if it's an integer.""" return str(v) if v is not None else v class TaskResponse(BaseModel): """Task response model.""" code: int = Field(..., description="Status code") description: str = Field(..., description="Response description") result: Optional[str] = Field(None, description="Task ID") properties: Dict[str, Any] = Field(default_factory=dict, description="Additional properties") class TaskDetail(BaseModel): """Detailed task information.""" id: Optional[str] = Field(None, description="Task ID") action: Optional[TaskAction] = Field(None, description="Task action") prompt: Optional[str] = Field(None, description="Original prompt") promptEn: Optional[str] = Field(None, description="English prompt") description: Optional[str] = Field(None, description="Task description") status: Optional[TaskStatus] = Field(None, description="Task status") progress: Optional[str] = Field(None, description="Task progress") imageUrl: Optional[str] = Field(None, description="Generated image URL") failReason: Optional[str] = Field(None, description="Failure reason") submitTime: Optional[int] = Field(None, description="Submit timestamp") startTime: Optional[int] = Field(None, description="Start timestamp") finishTime: Optional[int] = Field(None, description="Finish timestamp") state: Optional[str] = Field(None, description="Custom state") buttons: Optional[List[Button]] = Field(default_factory=list, description="Available action buttons") properties: Optional[Dict[str, Any]] = Field(default_factory=dict, description="Additional properties") @classmethod def from_api_response(cls, data: Dict[str, Any]) -> "TaskDetail": """Create TaskDetail from API response, handling missing/invalid fields.""" # Handle empty or invalid status status = data.get('status', 'NOT_START') if not status or status not in [s.value for s in TaskStatus]: status = 'NOT_START' # Handle empty or invalid action action = data.get('action') if action and action not in [a.value for a in TaskAction]: action = None # Handle buttons - ensure it's a list and convert properly buttons = data.get('buttons') if buttons is None: buttons = [] elif not isinstance(buttons, list): buttons = [] else: # Convert button data to ensure proper types converted_buttons = [] for button_data in buttons: if isinstance(button_data, dict): # Ensure all required fields exist with defaults button_dict = { 'customId': button_data.get('customId', ''), 'label': button_data.get('label', ''), 'type': button_data.get('type', ''), 'style': button_data.get('style', ''), 'emoji': button_data.get('emoji', '') } converted_buttons.append(button_dict) buttons = converted_buttons # Handle properties - ensure it's a dict properties = data.get('properties') if properties is None: properties = {} elif not isinstance(properties, dict): properties = {} # Create cleaned data cleaned_data = { **data, 'status': status, 'action': action, 'buttons': buttons, 'properties': properties } return cls(**cleaned_data) class ImagineRequest(BaseModel): """Request model for imagine task.""" prompt: str = Field(..., description="Text prompt") base64Array: Optional[List[str]] = Field(None, description="Reference images in base64") notifyHook: Optional[str] = Field(None, description="Callback URL") state: Optional[str] = Field(None, description="Custom state") class BlendRequest(BaseModel): """Request model for blend task.""" base64Array: List[str] = Field(..., description="Images to blend in base64", min_items=2, max_items=5) dimensions: Dimensions = Field(default=Dimensions.SQUARE, description="Output dimensions") notifyHook: Optional[str] = Field(None, description="Callback URL") state: Optional[str] = Field(None, description="Custom state") class DescribeRequest(BaseModel): """Request model for describe task.""" base64: str = Field(..., description="Image in base64 format") notifyHook: Optional[str] = Field(None, description="Callback URL") state: Optional[str] = Field(None, description="Custom state") class ChangeRequest(BaseModel): """Request model for change task.""" action: str = Field(..., description="Action type (UPSCALE, VARIATION, REROLL)") index: Optional[int] = Field(None, description="Image index (1-4)") taskId: str = Field(..., description="Original task ID") notifyHook: Optional[str] = Field(None, description="Callback URL") state: Optional[str] = Field(None, description="Custom state") class SwapFaceRequest(BaseModel): """Request model for swap face task.""" sourceBase64: str = Field(..., description="Source face image in base64") targetBase64: str = Field(..., description="Target image in base64") notifyHook: Optional[str] = Field(None, description="Callback URL") state: Optional[str] = Field(None, description="Custom state") class ModalRequest(BaseModel): """Request model for modal task.""" taskId: str = Field(..., description="Original task ID") maskBase64: str = Field(..., description="Mask image in base64") prompt: Optional[str] = Field(None, description="Additional prompt") notifyHook: Optional[str] = Field(None, description="Callback URL") state: Optional[str] = Field(None, description="Custom state") ``` -------------------------------------------------------------------------------- /src/client.py: -------------------------------------------------------------------------------- ```python """GPTNB API client for Midjourney operations.""" import asyncio import logging from typing import Optional, Dict, Any, List import httpx from config import Config from models import ( TaskResponse, TaskDetail, ImagineRequest, BlendRequest, DescribeRequest, ChangeRequest, SwapFaceRequest, ModalRequest ) from exceptions import ( APIError, AuthenticationError, RateLimitError, TaskNotFoundError, NetworkError, TimeoutError ) logger = logging.getLogger(__name__) class GPTNBClient: """GPTNB API client for Midjourney operations.""" def __init__(self, config: Config): """Initialize the GPTNB client. Args: config: Configuration object """ self.config = config self.base_url = config.gptnb_base_url.rstrip('/') self.headers = { "Authorization": f"Bearer {config.gptnb_api_key}", "Content-Type": "application/json", "User-Agent": "midjourney-mcp/0.2.0" } self._client: Optional[httpx.AsyncClient] = None async def __aenter__(self): """Async context manager entry.""" await self._ensure_client() return self async def __aexit__(self, exc_type: Optional[type], exc_val: Optional[Exception], exc_tb: Optional[object]) -> None: """Async context manager exit.""" await self.close() async def _ensure_client(self): """Ensure HTTP client is initialized.""" if self._client is None: self._client = httpx.AsyncClient( timeout=httpx.Timeout(self.config.timeout), headers=self.headers ) async def close(self) -> None: """Close the HTTP client.""" if self._client: await self._client.aclose() self._client = None async def _make_request( self, method: str, endpoint: str, data: Optional[Dict[str, Any]] = None, params: Optional[Dict[str, Any]] = None ) -> Dict[str, Any]: """Make HTTP request with error handling and retries. Args: method: HTTP method endpoint: API endpoint data: Request body data params: Query parameters Returns: Response data Raises: APIError: For API-related errors NetworkError: For network-related errors TimeoutError: For timeout errors """ await self._ensure_client() url = f"{self.base_url}{endpoint}" for attempt in range(self.config.max_retries + 1): try: logger.debug(f"Making {method} request to {url} (attempt {attempt + 1})") response = await self._client.request( method=method, url=url, json=data, params=params ) # Handle different status codes if response.status_code == 200: return response.json() elif response.status_code == 401: raise AuthenticationError( "Invalid API key or authentication failed", status_code=response.status_code, response_data=response.json() if response.content else {} ) elif response.status_code == 429: raise RateLimitError( "Rate limit exceeded", status_code=response.status_code, response_data=response.json() if response.content else {} ) elif response.status_code == 404: raise TaskNotFoundError( "Task not found", status_code=response.status_code, response_data=response.json() if response.content else {} ) else: response_data = response.json() if response.content else {} raise APIError( f"API request failed with status {response.status_code}", status_code=response.status_code, response_data=response_data ) except httpx.TimeoutException as e: if attempt == self.config.max_retries: raise TimeoutError(f"Request timed out after {self.config.timeout} seconds") from e logger.warning(f"Request timeout on attempt {attempt + 1}, retrying...") except httpx.NetworkError as e: if attempt == self.config.max_retries: raise NetworkError(f"Network error: {str(e)}") from e logger.warning(f"Network error on attempt {attempt + 1}, retrying...") except (AuthenticationError, RateLimitError, TaskNotFoundError): # Don't retry these errors raise except Exception as e: if attempt == self.config.max_retries: raise APIError(f"Unexpected error: {str(e)}") from e logger.warning(f"Unexpected error on attempt {attempt + 1}, retrying...") # Wait before retry if attempt < self.config.max_retries: wait_time = self.config.retry_delay * (2 ** attempt) # Exponential backoff logger.debug(f"Waiting {wait_time} seconds before retry...") await asyncio.sleep(wait_time) raise APIError("Max retries exceeded") def _prepare_request_data(self, request: Any) -> Dict[str, Any]: """Prepare request data with notify hook if configured. Args: request: Request object with model_dump method Returns: Prepared request data """ data = request.model_dump(exclude_none=True) if self.config.notify_hook and not data.get('notifyHook'): data['notifyHook'] = self.config.notify_hook return data async def submit_imagine(self, request: ImagineRequest) -> TaskResponse: """Submit an imagine task. Args: request: Imagine request data Returns: Task response """ data = self._prepare_request_data(request) response_data = await self._make_request("POST", "/mj/submit/imagine", data) return TaskResponse(**response_data) async def submit_blend(self, request: BlendRequest) -> TaskResponse: """Submit a blend task. Args: request: Blend request data Returns: Task response """ data = self._prepare_request_data(request) response_data = await self._make_request("POST", "/mj/submit/blend", data) return TaskResponse(**response_data) async def submit_describe(self, request: DescribeRequest) -> TaskResponse: """Submit a describe task. Args: request: Describe request data Returns: Task response """ data = self._prepare_request_data(request) response_data = await self._make_request("POST", "/mj/submit/describe", data) return TaskResponse(**response_data) async def submit_change(self, request: ChangeRequest) -> TaskResponse: """Submit a change task (upscale, variation, reroll). Args: request: Change request data Returns: Task response """ data = self._prepare_request_data(request) response_data = await self._make_request("POST", "/mj/submit/change", data) return TaskResponse(**response_data) async def submit_swap_face(self, request: SwapFaceRequest) -> TaskResponse: """Submit a swap face task. Args: request: Swap face request data Returns: Task response """ data = self._prepare_request_data(request) response_data = await self._make_request("POST", "/mj/submit/swap-face", data) return TaskResponse(**response_data) async def submit_modal(self, request: ModalRequest) -> TaskResponse: """Submit a modal task (zoom, pan, inpainting). Args: request: Modal request data Returns: Task response """ data = self._prepare_request_data(request) response_data = await self._make_request("POST", "/mj/submit/modal", data) return TaskResponse(**response_data) async def get_task(self, task_id: str) -> TaskDetail: """Get task details by ID. Args: task_id: Task ID Returns: Task details """ response_data = await self._make_request("GET", f"/mj/task/{task_id}/fetch") return TaskDetail.from_api_response(response_data) async def get_tasks(self, task_ids: List[str]) -> List[TaskDetail]: """Get multiple tasks by IDs. Args: task_ids: List of task IDs Returns: List of task details """ data = {"ids": task_ids} response_data = await self._make_request("POST", "/mj/task/list-by-condition", data) return [TaskDetail(**task) for task in response_data] ``` -------------------------------------------------------------------------------- /src/service.py: -------------------------------------------------------------------------------- ```python """Midjourney service layer for business logic.""" import logging from functools import wraps from typing import List, Optional, Callable from client import GPTNBClient from task_handler import TaskManager from models import ( ImagineRequest, BlendRequest, DescribeRequest, ChangeRequest, SwapFaceRequest, ModalRequest, Dimensions ) from config import Config, get_config from utils import ( validate_prompt, validate_aspect_ratio, validate_base64_images, validate_task_id, validate_image_index, format_error_message ) from exceptions import ValidationError logger = logging.getLogger(__name__) def handle_service_errors(operation_name: str): """Decorator to handle common service errors. Args: operation_name: Name of the operation for error messages """ def decorator(func: Callable) -> Callable: """Decorator function.""" @wraps(func) async def wrapper(*args, **kwargs) -> str: """Wrapper function with error handling.""" try: return await func(*args, **kwargs) except Exception as e: logger.error(f"Error in {operation_name}: {e}") return format_error_message(e, operation_name) return wrapper return decorator class MidjourneyService: """High-level service for Midjourney operations.""" def __init__(self, config: Optional[Config] = None): """Initialize Midjourney service. Args: config: Configuration object (uses global config if None) """ self.config = config or get_config() self.client: Optional[GPTNBClient] = None self.task_manager: Optional[TaskManager] = None async def __aenter__(self): """Async context manager entry.""" self.client = GPTNBClient(self.config) await self.client.__aenter__() self.task_manager = TaskManager(self.client, self.config) return self async def __aexit__(self, exc_type, exc_val, exc_tb): """Async context manager exit.""" if self.client: await self.client.__aexit__(exc_type, exc_val, exc_tb) async def imagine( self, prompt: str, aspect_ratio: str = "1:1", base64_images: Optional[List[str]] = None ) -> str: """Generate images from text prompt. Args: prompt: Text description of the image aspect_ratio: Aspect ratio (e.g., "16:9", "1:1", "9:16") base64_images: Optional reference images Returns: Formatted result string Raises: ValidationError: If input validation fails MidjourneyMCPError: If operation fails """ try: # Validate inputs prompt = validate_prompt(prompt) if not validate_aspect_ratio(aspect_ratio): raise ValidationError(f"Invalid aspect ratio: {aspect_ratio}") # Add aspect ratio to prompt if not already present if "--ar" not in prompt: prompt = f"{prompt} --ar {aspect_ratio}" # Add default suffix if configured if self.config.default_suffix and self.config.default_suffix not in prompt: prompt = f"{prompt} {self.config.default_suffix}" # Validate and format images if provided validated_images = None if base64_images: validated_images = validate_base64_images(base64_images, min_count=0, max_count=5) # Create request request = ImagineRequest( prompt=prompt, base64Array=validated_images ) # Submit task (don't wait for completion) response = await self.client.submit_imagine(request) if response.code != 1: raise Exception(f"Task submission failed: {response.description}") if not response.result: raise Exception("No task ID returned from submission") task_id = response.result # Return immediate response with task ID result = f"🎨 **Image Generation Started!**\n\n" result += f"**Task ID:** {task_id}\n" result += f"**Prompt:** {prompt}\n" result += f"**Aspect Ratio:** {aspect_ratio}\n" result += f"**Status:** Task submitted successfully\n\n" result += f"💡 **Next Steps:**\n" result += f"Use `get_task_status(task_id=\"{task_id}\")` to check current status\n\n" result += f"⏱️ **Estimated Time:** 30-60 seconds\n" return result except Exception as e: logger.error(f"Error in imagine: {e}") return format_error_message(e, "imagine") async def blend( self, base64_images: List[str], dimensions: str = "SQUARE" ) -> str: """Blend multiple images together. Args: base64_images: List of 2-5 images to blend dimensions: Output dimensions ("PORTRAIT", "SQUARE", "LANDSCAPE") Returns: Formatted result string Raises: ValidationError: If input validation fails MidjourneyMCPError: If operation fails """ try: # Validate inputs validated_images = validate_base64_images(base64_images, min_count=2, max_count=5) # Validate dimensions try: dim_enum = Dimensions(dimensions.upper()) except ValueError: raise ValidationError(f"Invalid dimensions: {dimensions}. Must be PORTRAIT, SQUARE, or LANDSCAPE") # Create request request = BlendRequest( base64Array=validated_images, dimensions=dim_enum ) # Submit task (don't wait for completion) response = await self.client.submit_blend(request) if response.code != 1: raise Exception(f"Task submission failed: {response.description}") if not response.result: raise Exception("No task ID returned from submission") task_id = response.result # Return immediate response with task ID result = f"🎨 **Image Blending Started!**\n\n" result += f"**Task ID:** {task_id}\n" result += f"**Images:** {len(base64_images)} images to blend\n" result += f"**Dimensions:** {dimensions}\n" result += f"**Status:** Task submitted successfully\n\n" result += f"💡 **Monitor Progress:** Use `get_task_status(task_id=\"{task_id}\")`\n" return result except Exception as e: logger.error(f"Error in blend: {e}") return format_error_message(e, "blend") async def describe(self, base64_image: str) -> str: """Generate text description of an image. Args: base64_image: Image to describe Returns: Formatted result string Raises: ValidationError: If input validation fails MidjourneyMCPError: If operation fails """ try: # Validate image validated_images = validate_base64_images([base64_image], min_count=1, max_count=1) # Create request request = DescribeRequest( base64=validated_images[0] ) # Submit and wait for completion task = await self.task_manager.submit_and_wait( self.client.submit_describe, request ) return self.task_manager.format_task_result(task) except Exception as e: logger.error(f"Error in describe: {e}") return format_error_message(e, "describe") async def change( self, task_id: str, action: str, index: Optional[int] = None ) -> str: """Create variations, upscales, or rerolls of existing images. Args: task_id: ID of the original generation task action: Action type ("UPSCALE", "VARIATION", "REROLL") index: Image index (1-4) for UPSCALE and VARIATION Returns: Formatted result string Raises: ValidationError: If input validation fails MidjourneyMCPError: If operation fails """ try: # Validate inputs task_id = validate_task_id(task_id) action = action.upper() if action not in ["UPSCALE", "VARIATION", "REROLL"]: raise ValidationError(f"Invalid action: {action}. Must be UPSCALE, VARIATION, or REROLL") if action in ["UPSCALE", "VARIATION"]: if index is None: raise ValidationError(f"Index is required for {action} action") index = validate_image_index(index) # Create request request = ChangeRequest( taskId=task_id, action=action, index=index ) # Submit and wait for completion task = await self.task_manager.submit_and_wait( self.client.submit_change, request ) return self.task_manager.format_task_result(task) except Exception as e: logger.error(f"Error in change: {e}") return format_error_message(e, "change") async def modal_edit( self, task_id: str, action: str, prompt: Optional[str] = None ) -> str: """Perform advanced editing like zoom, pan, or inpainting. Args: task_id: ID of the original generation task action: Edit action type (zoom, pan, inpaint, etc.) prompt: Additional prompt for the edit Returns: Formatted result string Raises: ValidationError: If input validation fails MidjourneyMCPError: If operation fails """ try: # Validate inputs task_id = validate_task_id(task_id) if prompt: prompt = validate_prompt(prompt) # Note: This is a simplified implementation # Real modal operations would need mask images and specific action parameters logger.info(f"Performing modal edit with action: {action}") # Create request (Note: This is a simplified version) # In practice, modal operations require more complex parameters request = ModalRequest( taskId=task_id, maskBase64="", # This would need to be provided by the user prompt=prompt ) # Submit and wait for completion task = await self.task_manager.submit_and_wait( self.client.submit_modal, request ) return self.task_manager.format_task_result(task) except Exception as e: logger.error(f"Error in modal_edit: {e}") return format_error_message(e, "modal_edit") async def swap_face( self, source_image: str, target_image: str ) -> str: """Swap faces between two images. Args: source_image: Source face image in base64 format target_image: Target image in base64 format Returns: Formatted result string Raises: ValidationError: If input validation fails MidjourneyMCPError: If operation fails """ try: # Validate images validated_images = validate_base64_images([source_image, target_image], min_count=2, max_count=2) # Create request request = SwapFaceRequest( sourceBase64=validated_images[0], targetBase64=validated_images[1] ) # Submit and wait for completion task = await self.task_manager.submit_and_wait( self.client.submit_swap_face, request ) return self.task_manager.format_task_result(task) except Exception as e: logger.error(f"Error in swap_face: {e}") return format_error_message(e, "swap_face") async def get_task_status(self, task_id: str) -> str: """Get current status of a task. Args: task_id: Task ID to check Returns: Formatted status string Raises: ValidationError: If input validation fails MidjourneyMCPError: If operation fails """ try: # Validate task ID task_id = validate_task_id(task_id) # Get task status task = await self.task_manager.get_task_status(task_id) return self.task_manager.format_task_result(task) except Exception as e: logger.error(f"Error in get_task_status: {e}") return format_error_message(e, "get_task_status") # Global service instance _service_instance: Optional[MidjourneyService] = None async def get_service() -> MidjourneyService: """Get or create the global service instance. Returns: MidjourneyService instance """ global _service_instance if _service_instance is None: _service_instance = MidjourneyService() await _service_instance.__aenter__() return _service_instance async def close_service(): """Close the global service instance.""" global _service_instance if _service_instance is not None: await _service_instance.__aexit__(None, None, None) _service_instance = None ```