This is page 3 of 3. Use http://codebase.md/angrysky56/mcts-mcp-server?page={x} to view the full context.
# Directory Structure
```
├── .env.example
├── .gitignore
├── archive
│ ├── ANALYSIS_TOOLS.md
│ ├── First-Run.md
│ ├── fixed_tools.py
│ ├── gemini_adapter_old.py
│ ├── gemini_adapter.py
│ ├── GEMINI_SETUP.md
│ ├── QUICK_START_FIXED.md
│ ├── QUICK_START.md
│ ├── README.md
│ ├── run_test.py
│ ├── SERVER_FIX_SUMMARY.md
│ ├── setup_analysis_venv.sh
│ ├── setup_analysis.sh
│ ├── SETUP_SUMMARY.md
│ ├── test_adapter.py
│ ├── test_fixed_server.py
│ ├── test_gemini_setup.py
│ ├── test_mcp_init.py
│ ├── test_minimal.py
│ ├── test_new_adapters.py
│ ├── test_ollama.py
│ ├── test_rate_limiting.py
│ ├── test_server_debug.py
│ ├── test_server.py
│ ├── test_simple.py
│ ├── test_startup_simple.py
│ ├── test_startup.py
│ ├── TIMEOUT_FIX.md
│ ├── tools_fast.py
│ ├── tools_old.py
│ └── tools_original.py
├── image-1.png
├── image-2.png
├── image-3.png
├── image.png
├── LICENSE
├── prompts
│ ├── README.md
│ └── usage_guide.md
├── pyproject.toml
├── README.md
├── results
│ ├── cogito:32b
│ │ └── cogito:32b_1745989705
│ │ ├── best_solution.txt
│ │ └── progress.jsonl
│ ├── cogito:latest
│ │ ├── cogito:latest_1745979984
│ │ │ ├── best_solution.txt
│ │ │ └── progress.jsonl
│ │ └── cogito:latest_1745984274
│ │ ├── best_solution.txt
│ │ └── progress.jsonl
│ ├── local
│ │ ├── local_1745956311
│ │ │ ├── best_solution.txt
│ │ │ └── progress.jsonl
│ │ ├── local_1745956673
│ │ │ ├── best_solution.txt
│ │ │ └── progress.jsonl
│ │ └── local_1745958556
│ │ ├── best_solution.txt
│ │ └── progress.jsonl
│ └── qwen3:0.6b
│ ├── qwen3:0.6b_1745960624
│ │ ├── best_solution.txt
│ │ └── progress.jsonl
│ ├── qwen3:0.6b_1745960651
│ │ ├── best_solution.txt
│ │ └── progress.jsonl
│ ├── qwen3:0.6b_1745960694
│ │ ├── best_solution.txt
│ │ └── progress.jsonl
│ └── qwen3:0.6b_1745977462
│ ├── best_solution.txt
│ └── progress.jsonl
├── setup_unix.sh
├── setup_windows.bat
├── setup.py
├── setup.sh
├── src
│ └── mcts_mcp_server
│ ├── __init__.py
│ ├── analysis_tools
│ │ ├── __init__.py
│ │ ├── mcts_tools.py
│ │ └── results_processor.py
│ ├── anthropic_adapter.py
│ ├── base_llm_adapter.py
│ ├── gemini_adapter.py
│ ├── intent_handler.py
│ ├── llm_adapter.py
│ ├── llm_interface.py
│ ├── manage_server.py
│ ├── mcts_config.py
│ ├── mcts_core.py
│ ├── node.py
│ ├── ollama_adapter.py
│ ├── ollama_check.py
│ ├── ollama_utils.py
│ ├── openai_adapter.py
│ ├── rate_limiter.py
│ ├── reality_warps_adapter.py
│ ├── results_collector.py
│ ├── server.py
│ ├── state_manager.py
│ ├── tools.py
│ └── utils.py
├── USAGE_GUIDE.md
├── uv.lock
└── verify_installation.py
```
# Files
--------------------------------------------------------------------------------
/src/mcts_mcp_server/tools.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
Fixed Tools for MCTS with proper async handling
===============================================
This module fixes the async event loop issues in the MCTS MCP tools.
"""
import asyncio
import logging
import os
import threading
from collections.abc import Coroutine
from typing import Any
from dotenv import load_dotenv
from mcp.server.fastmcp import FastMCP
from .mcts_config import DEFAULT_CONFIG
from .mcts_core import MCTS
from .ollama_adapter import OllamaAdapter
from .ollama_utils import (
OLLAMA_PYTHON_PACKAGE_AVAILABLE,
check_available_models,
get_recommended_models,
)
from .state_manager import StateManager
from .utils import truncate_text
logger = logging.getLogger(__name__)
# Global state to maintain between tool calls
_global_state = {
"mcts_instance": None,
"config": None,
"state_manager": None,
"current_chat_id": None,
"active_llm_provider": os.getenv("DEFAULT_LLM_PROVIDER", "ollama"),
"active_model_name": os.getenv("DEFAULT_MODEL_NAME"),
"collect_results": False,
"current_run_id": None,
"ollama_available_models": [],
"background_loop": None,
"background_thread": None
}
def get_or_create_background_loop() -> asyncio.AbstractEventLoop | None:
"""
Get or create a background event loop that runs in a dedicated thread.
Returns:
The background event loop, or None if creation failed
Note:
This ensures all async operations use the same event loop and avoids
"bound to different event loop" issues common in MCP tools
"""
global _global_state
if _global_state["background_loop"] is None or _global_state["background_thread"] is None:
loop_created = threading.Event() # Use threading.Event instead of asyncio.Event
loop_container: dict[str, asyncio.AbstractEventLoop | None] = {"loop": None}
def create_background_loop():
"""Create and run a background event loop."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop_container["loop"] = loop
_global_state["background_loop"] = loop
loop_created.set()
try:
loop.run_forever()
except Exception as e:
logger.error(f"Background loop error: {e}")
finally:
loop.close()
# Start the background thread
thread = threading.Thread(target=create_background_loop, daemon=True)
thread.start()
_global_state["background_thread"] = thread
# Wait for loop to be created (with shorter timeout to avoid hanging)
if not loop_created.wait(timeout=2.0):
logger.warning("Background loop creation timed out")
# Don't raise an error, just return None and handle gracefully
return None
if loop_container["loop"] is None:
logger.warning("Failed to create background event loop")
return None
return _global_state["background_loop"]
def run_in_background_loop(coro: Coroutine[Any, Any, Any]) -> Any:
"""
Run a coroutine in the background event loop.
Args:
coro: The coroutine to execute
Returns:
The result of the coroutine execution
Raises:
RuntimeError: If all execution methods fail
Note:
This avoids the "bound to different event loop" issue by using
a dedicated background loop with fallback strategies
"""
loop = get_or_create_background_loop()
if loop is None:
# Fallback: try to run in a new event loop if background loop failed
logger.warning("Background loop not available, using fallback")
try:
return asyncio.run(coro)
except RuntimeError:
# If we're already in an event loop, use thread executor
try:
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(asyncio.run, coro)
return future.result(timeout=300)
except Exception as e:
raise RuntimeError(f"Failed to run coroutine: {e}") from e
if loop.is_running():
# Submit to the running loop and wait for result
future = asyncio.run_coroutine_threadsafe(coro, loop)
return future.result(timeout=300) # 5 minute timeout
else:
# This shouldn't happen if the background loop is properly managed
raise RuntimeError("Background event loop is not running")
def register_mcts_tools(mcp: FastMCP, db_path: str) -> None:
"""
Register all MCTS-related tools with the MCP server.
Args:
mcp: The FastMCP server instance to register tools with
db_path: Path to the SQLite database for state persistence
Note:
Initializes global state, loads environment variables, and registers
all tool functions with proper async handling
"""
global _global_state
# Load environment variables
load_dotenv()
# Initialize state manager
_global_state["state_manager"] = StateManager(db_path)
# Initialize config
_global_state["config"] = DEFAULT_CONFIG.copy()
# Don't check Ollama models during initialization to prevent hanging
# Models will be checked when list_ollama_models() is called
_global_state["ollama_available_models"] = []
# Set default model for ollama if needed
if _global_state["active_llm_provider"] == "ollama" and not _global_state["active_model_name"]:
_global_state["active_model_name"] = OllamaAdapter.DEFAULT_MODEL
@mcp.tool()
def initialize_mcts(question: str, chat_id: str, provider_name: str | None = None,
model_name: str | None = None, config_updates: dict[str, Any] | None = None) -> dict[str, Any]:
"""
Initialize the MCTS system with proper async handling.
Args:
question: The question or topic to analyze
chat_id: Unique identifier for this conversation session
provider_name: LLM provider to use (ollama, openai, anthropic, gemini)
model_name: Specific model name to use
config_updates: Optional configuration overrides
Returns:
dictionary containing initialization status, configuration, and metadata
Note:
Creates LLM adapter, generates initial analysis, and sets up MCTS instance
with optional state loading from previous sessions
"""
global _global_state
try:
logger.info(f"Initializing MCTS for chat ID: {chat_id}")
# Determine target provider and model
target_provider = provider_name or _global_state["active_llm_provider"]
target_model = model_name or _global_state["active_model_name"]
logger.info(f"Using LLM Provider: {target_provider}, Model: {target_model}")
# Update config if provided
if config_updates:
cfg = _global_state["config"].copy()
cfg.update(config_updates)
_global_state["config"] = cfg
else:
cfg = _global_state["config"]
_global_state["current_chat_id"] = chat_id
state_manager = _global_state["state_manager"]
loaded_state = state_manager.load_state(chat_id) if cfg.get("enable_state_persistence", True) else None
# Instantiate the appropriate adapter
llm_adapter = None
if target_provider == "ollama":
if not target_model:
target_model = OllamaAdapter.DEFAULT_MODEL
# Only validate against the cache when models have already been listed (see list_ollama_models)
if _global_state["ollama_available_models"] and target_model not in _global_state["ollama_available_models"]:
return {
"status": "model_error",
"error": f"Ollama model '{target_model}' not available",
"available_models": _global_state["ollama_available_models"]
}
llm_adapter = OllamaAdapter(model_name=target_model)
elif target_provider == "openai":
from .openai_adapter import OpenAIAdapter
if not target_model:
target_model = OpenAIAdapter.DEFAULT_MODEL
llm_adapter = OpenAIAdapter(api_key=os.getenv("OPENAI_API_KEY"), model_name=target_model)
elif target_provider == "anthropic":
from .anthropic_adapter import AnthropicAdapter
if not target_model:
target_model = AnthropicAdapter.DEFAULT_MODEL
llm_adapter = AnthropicAdapter(api_key=os.getenv("ANTHROPIC_API_KEY"), model_name=target_model)
elif target_provider == "gemini":
from .gemini_adapter import GeminiAdapter
if not target_model:
target_model = GeminiAdapter.DEFAULT_MODEL
llm_adapter = GeminiAdapter(api_key=os.getenv("GEMINI_API_KEY"), model_name=target_model)
else:
return {"error": f"Unsupported LLM provider: {target_provider}", "status": "error"}
_global_state["active_llm_provider"] = target_provider
_global_state["active_model_name"] = target_model
# Generate initial analysis using the background loop
async def generate_initial():
initial_prompt = f"<instruction>Provide an initial analysis of the following question. Be clear and concise.</instruction><question>{question}</question>"
initial_messages = [{"role": "user", "content": initial_prompt}]
return await llm_adapter.get_completion(model=target_model, messages=initial_messages)
try:
initial_analysis = run_in_background_loop(generate_initial())
except Exception as e:
logger.error(f"Failed to generate initial analysis: {e}")
return {"error": f"Failed to generate initial analysis: {str(object=e)}", "status": "error"}
# Create MCTS instance
_global_state["mcts_instance"] = MCTS(
llm_interface=llm_adapter,
question=question,
initial_analysis_content=initial_analysis or "No initial analysis available",
config=cfg,
initial_state=loaded_state
)
return {
"status": "initialized",
"question": question,
"chat_id": chat_id,
"initial_analysis": initial_analysis,
"loaded_state": loaded_state is not None,
"provider": target_provider,
"model_used": target_model,
"config": {k: v for k, v in cfg.items() if not k.startswith("_")},
"run_id": _global_state.get("current_run_id")
}
except ValueError as ve:
logger.error(f"Configuration error: {ve}")
return {"error": f"Configuration error: {ve!s}", "status": "config_error"}
except Exception as e:
logger.error(f"Error in initialize_mcts: {e}")
return {"error": f"Failed to initialize MCTS: {e!s}", "status": "error"}
@mcp.tool()
def set_active_llm(provider_name: str, model_name: str | None = None) -> dict[str, Any]:
"""
Set the active LLM provider and model for subsequent operations.
Args:
provider_name: Name of the LLM provider (ollama, openai, anthropic, gemini)
model_name: Optional specific model name to use
Returns:
dictionary containing status and confirmation message
Note:
Changes the global LLM configuration but doesn't affect already
initialized MCTS instances
"""
global _global_state
supported_providers = ["ollama", "openai", "anthropic", "gemini"]
provider_name_lower = provider_name.lower()
if provider_name_lower not in supported_providers:
return {
"status": "error",
"message": f"Unsupported provider: '{provider_name}'. Supported: {supported_providers}"
}
_global_state["active_llm_provider"] = provider_name_lower
_global_state["active_model_name"] = model_name
log_msg = f"Set active LLM provider to: {provider_name_lower}."
if model_name:
log_msg += f" Set active model to: {model_name}."
return {"status": "success", "message": log_msg}
@mcp.tool()
def list_ollama_models() -> dict[str, Any]:
"""
List all available Ollama models with recommendations.
Returns:
dictionary containing:
- status: Success or error status
- ollama_available_models: List of all available models
- current_ollama_model: Currently active model
- recommended_small_models: Models suitable for basic tasks
- recommended_medium_models: Models for complex analysis
- message: Status message
Note:
Checks Ollama server connectivity and updates global model cache
"""
logger.info("Listing Ollama models...")
# Check if Ollama server is running
try:
import httpx
with httpx.Client(base_url="http://localhost:11434", timeout=3.0) as client:
response = client.get("/")
if response.status_code != 200:
return {
"status": "error",
"message": "Ollama server not responding. Please ensure Ollama is running."
}
except Exception as e:
return {
"status": "error",
"message": f"Cannot connect to Ollama server: {e!s}"
}
# Get available models
available_models = check_available_models()
if not available_models:
return {
"status": "error",
"message": "No Ollama models found. Try 'ollama pull MODEL_NAME' to download a model."
}
# Get recommendations
recommendations = get_recommended_models(available_models)
current_model = _global_state.get("active_model_name") if _global_state.get("active_llm_provider") == "ollama" else None
# Update global state
_global_state["ollama_available_models"] = available_models
return {
"status": "success",
"ollama_available_models": available_models,
"current_ollama_model": current_model,
"recommended_small_models": recommendations["small_models"],
"recommended_medium_models": recommendations["medium_models"],
"message": f"Found {len(available_models)} Ollama models"
}
@mcp.tool()
def run_mcts(iterations: int = 1, simulations_per_iteration: int = 5, model_name: str | None = None) -> dict[str, Any]:
"""
Run the MCTS algorithm with proper async handling.
Args:
iterations: Number of MCTS iterations to run
simulations_per_iteration: Number of simulations per iteration
model_name: Optional model override (currently unused)
Returns:
dictionary containing:
- status: 'started' if successful
- message: Confirmation message
- provider: Active LLM provider
- model: Active model name
- background_thread_id: Thread ID for monitoring
Note:
Runs MCTS in a background thread to avoid blocking the MCP server
Automatically saves state if persistence is enabled
"""
global _global_state
mcts = _global_state.get("mcts_instance")
if not mcts:
return {"error": "MCTS not initialized. Call initialize_mcts first."}
active_provider = _global_state.get("active_llm_provider")
active_model = _global_state.get("active_model_name")
if not active_provider or not active_model:
return {"error": "Active LLM provider or model not set."}
# Update config for this run
temp_config = mcts.config.copy()
temp_config["max_iterations"] = iterations
temp_config["simulations_per_iteration"] = simulations_per_iteration
mcts.config = temp_config
logger.info(f"Starting MCTS run with {iterations} iterations, {simulations_per_iteration} simulations per iteration")
def run_mcts_background():
"""Run MCTS in background thread with proper async handling."""
try:
# Use the background loop for all async operations
async def run_search():
await mcts.run_search_iterations(iterations, simulations_per_iteration)
return mcts.get_final_results()
results = run_in_background_loop(run_search())
# Save state if enabled
if temp_config.get("enable_state_persistence", True) and _global_state["current_chat_id"]:
try:
_global_state["state_manager"].save_state(_global_state["current_chat_id"], mcts)
logger.info(f"Saved state for chat ID: {_global_state['current_chat_id']}")
except Exception as e:
logger.error(f"Error saving state: {e}")
# Get best node and tags
best_node = mcts.find_best_final_node()
tags = best_node.descriptive_tags if best_node else []
# Log the tags for debugging/monitoring
if tags:
logger.info(f"Best node tags: {', '.join(tags)}")
logger.info(f"MCTS run completed. Best score: {results.best_score if results else 0.0}")
except Exception as e:
logger.error(f"Error in background MCTS run: {e}")
# Start background thread
background_thread = threading.Thread(target=run_mcts_background)
background_thread.daemon = True
background_thread.start()
return {
"status": "started",
"message": f"MCTS process started with {iterations} iterations and {simulations_per_iteration} simulations per iteration.",
"provider": active_provider,
"model": active_model,
"background_thread_id": background_thread.ident
}
@mcp.tool()
def generate_synthesis() -> dict[str, Any]:
"""
Generate a final synthesis of the MCTS results.
Returns:
dictionary containing:
- synthesis: Generated synthesis text
- best_score: Best score achieved during search
- tags: Descriptive tags from best analysis
- iterations_completed: Number of iterations completed
- provider: LLM provider used
- model: Model used
Raises:
Returns error dict if MCTS not initialized or synthesis fails
Note:
Uses the same background loop as MCTS to ensure consistency
"""
global _global_state
mcts = _global_state.get("mcts_instance")
if not mcts:
return {"error": "MCTS not initialized. Call initialize_mcts first."}
try:
async def synth():
llm_adapter = mcts.llm
path_nodes = mcts.get_best_path_nodes()
path_thoughts_list = [
f"- (Node {node.sequence}): {node.thought.strip()}"
for node in path_nodes if node.thought and node.parent
]
path_thoughts_str = "\n".join(path_thoughts_list) if path_thoughts_list else "No significant development path identified."
results = mcts.get_final_results()
synth_context = {
"question_summary": mcts.question_summary,
"initial_analysis_summary": truncate_text(mcts.root.content, 300) if mcts.root else "N/A",
"best_score": f"{results.best_score:.1f}",
"path_thoughts": path_thoughts_str,
"final_best_analysis_summary": truncate_text(results.best_solution_content, 400),
"previous_best_summary": "N/A",
"unfit_markers_summary": "N/A",
"learned_approach_summary": "N/A"
}
synthesis = await llm_adapter.synthesize_result(synth_context, mcts.config)
best_node = mcts.find_best_final_node()
tags = best_node.descriptive_tags if best_node else []
return {
"synthesis": synthesis,
"best_score": results.best_score,
"tags": tags,
"iterations_completed": mcts.iterations_completed,
"provider": _global_state.get("active_llm_provider"),
"model": _global_state.get("active_model_name"),
}
# Use the background loop for synthesis generation
synthesis_result = run_in_background_loop(synth())
return synthesis_result
except Exception as e:
logger.error(f"Error generating synthesis: {e}")
return {"error": f"Synthesis generation failed: {e!s}"}
@mcp.tool()
def get_config() -> dict[str, Any]:
"""
Get the current MCTS configuration and system status.
Returns:
dictionary containing all configuration parameters, active LLM settings,
and system capabilities
Note:
Filters out internal configuration keys starting with underscore
"""
global _global_state
config = {k: v for k, v in _global_state["config"].items() if not k.startswith("_")}
config.update({
"active_llm_provider": _global_state.get("active_llm_provider"),
"active_model_name": _global_state.get("active_model_name"),
"ollama_python_package_available": OLLAMA_PYTHON_PACKAGE_AVAILABLE,
"ollama_available_models": _global_state.get("ollama_available_models", []),
"current_run_id": _global_state.get("current_run_id")
})
return config
@mcp.tool()
def update_config(config_updates: dict[str, Any]) -> dict[str, Any]:
"""
Update the MCTS configuration parameters.
Args:
config_updates: dictionary of configuration keys and new values
Returns:
Updated configuration dictionary
Note:
Provider and model changes are ignored - use set_active_llm instead
Updates both global config and active MCTS instance if present
"""
global _global_state
logger.info(f"Updating MCTS config with: {config_updates}")
# Provider and model changes should use set_active_llm
if "active_llm_provider" in config_updates or "active_model_name" in config_updates:
logger.warning("Use 'set_active_llm' tool to change LLM provider or model.")
config_updates.pop("active_llm_provider", None)
config_updates.pop("active_model_name", None)
# Update config
cfg = _global_state["config"].copy()
cfg.update(config_updates)
_global_state["config"] = cfg
mcts = _global_state.get("mcts_instance")
if mcts:
mcts.config = cfg
return get_config()
@mcp.tool()
def get_mcts_status() -> dict[str, Any]:
"""
Get the current status of the MCTS system.
Returns:
dictionary containing:
- initialized: Whether MCTS is initialized
- chat_id: Current chat session ID
- iterations_completed: Number of iterations run
- simulations_completed: Total simulations run
- best_score: Best score achieved
- best_content_summary: Truncated best solution
- tags: Tags from best analysis
- tree_depth: Maximum tree depth explored
- approach_types: List of analytical approaches used
- active_llm_provider: Current LLM provider
- active_model_name: Current model name
- run_id: Current run identifier
Note:
Provides comprehensive status for monitoring and debugging
"""
global _global_state
mcts = _global_state.get("mcts_instance")
if not mcts:
return {
"initialized": False,
"message": "MCTS not initialized. Call initialize_mcts first."
}
try:
best_node = mcts.find_best_final_node()
tags = best_node.descriptive_tags if best_node else []
return {
"initialized": True,
"chat_id": _global_state.get("current_chat_id"),
"iterations_completed": getattr(mcts, "iterations_completed", 0),
"simulations_completed": getattr(mcts, "simulations_completed", 0),
"best_score": getattr(mcts, "best_score", 0.0),
"best_content_summary": truncate_text(getattr(mcts, "best_solution", ""), 100),
"tags": tags,
"tree_depth": mcts.memory.get("depth", 0) if hasattr(mcts, "memory") else 0,
"approach_types": getattr(mcts, "approach_types", []),
"active_llm_provider": _global_state.get("active_llm_provider"),
"active_model_name": _global_state.get("active_model_name"),
"run_id": _global_state.get("current_run_id")
}
except Exception as e:
logger.error(f"Error getting MCTS status: {e}")
return {
"initialized": True,
"error": f"Error getting MCTS status: {e!s}",
"chat_id": _global_state.get("current_chat_id")
}
@mcp.tool()
def run_model_comparison(question: str, iterations: int = 2, simulations_per_iteration: int = 10) -> dict[str, Any]:
"""
Run MCTS across multiple models for comparison analysis.
Args:
question: The question to analyze across models
iterations: Number of MCTS iterations per model
simulations_per_iteration: Simulations per iteration
Returns:
dictionary containing comparison setup or error information
Note:
Currently returns a placeholder - full implementation requires
additional coordination between multiple MCTS instances
"""
if not OLLAMA_PYTHON_PACKAGE_AVAILABLE:
return {"error": "Ollama python package not available for model comparison."}
# Get available models
models = check_available_models()
recommendations = get_recommended_models(models)
comparison_models = recommendations["small_models"]
if not comparison_models:
return {"error": f"No suitable models found for comparison. Available: {models}"}
return {
"status": "started",
"message": "Model comparison feature available but not implemented in this version",
"question": question,
"models": comparison_models,
"iterations": iterations,
"simulations_per_iteration": simulations_per_iteration
}
```
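The tools above are not called directly; they are registered against a `FastMCP` instance and invoked by an MCP client. Below is a minimal sketch of that wiring, assuming the package is importable. The server name and database path are illustrative placeholders rather than values taken from the repository (the repository's own `server.py` presumably plays this role), and the call order in the comment follows the tool docstrings above.

```python
#!/usr/bin/env python3
"""Illustrative sketch: exposing the MCTS tools through a FastMCP server."""
from mcp.server.fastmcp import FastMCP

from mcts_mcp_server.tools import register_mcts_tools

# Hypothetical server name and state database path; adjust to your setup.
mcp = FastMCP("mcts-demo")
register_mcts_tools(mcp, db_path="./mcts_state.db")

if __name__ == "__main__":
    # An MCP client would then call the registered tools, typically in this order:
    #   list_ollama_models -> initialize_mcts -> run_mcts -> get_mcts_status -> generate_synthesis
    mcp.run()
```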
--------------------------------------------------------------------------------
/src/mcts_mcp_server/analysis_tools/results_processor.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MCTS Results Processor
=====================
This module provides a class for processing, analyzing, and extracting insights
from MCTS run results. It helps identify the most valuable information in MCTS
outputs and present it in a more structured and useful format.
"""
import os
import json
import logging
import datetime
from typing import Dict, Any, List, Optional, Tuple, Set, Union
import re
from pathlib import Path
logger = logging.getLogger("mcts_analysis")
class ResultsProcessor:
"""Processes and analyzes MCTS run results to extract key insights."""
def __init__(self, results_base_dir: Optional[str] = None):
"""
Initialize the results processor.
Args:
results_base_dir: Base directory for MCTS results. If None, defaults to
the standard location.
"""
if results_base_dir is None:
# Default to 'results' in the repository root
repo_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
self.results_base_dir = os.path.join(repo_dir, "results")
else:
self.results_base_dir = results_base_dir
logger.info(f"Initialized ResultsProcessor with base directory: {self.results_base_dir}")
# Cache for analyzed results
self._cache = {}
def list_runs(self, count: int = 10, model: Optional[str] = None) -> List[Dict[str, Any]]:
"""
List recent MCTS runs with key metadata.
Args:
count: Maximum number of runs to return
model: Optional model name to filter by
Returns:
List of run dictionaries with key metadata
"""
runs = []
# Walk through the results directory
for model_dir in os.listdir(self.results_base_dir):
# Skip if filtering by model and not matching
if model and model != model_dir:
continue
model_path = os.path.join(self.results_base_dir, model_dir)
if not os.path.isdir(model_path):
continue
# Check each run directory
for run_dir in os.listdir(model_path):
run_path = os.path.join(model_path, run_dir)
if not os.path.isdir(run_path):
continue
# Try to load metadata
metadata_path = os.path.join(run_path, "metadata.json")
if not os.path.exists(metadata_path):
continue
try:
with open(metadata_path, 'r') as f:
metadata = json.load(f)
# Extract key information
run_info = {
"run_id": metadata.get("run_id", run_dir),
"model": metadata.get("model_name", model_dir),
"question": metadata.get("question", "Unknown"),
"timestamp": metadata.get("timestamp", 0),
"timestamp_readable": metadata.get("timestamp_readable", "Unknown"),
"status": metadata.get("status", "Unknown"),
"score": metadata.get("results", {}).get("best_score", 0),
"iterations": metadata.get("results", {}).get("iterations_completed", 0),
"simulations": metadata.get("results", {}).get("simulations_completed", 0),
"tags": metadata.get("results", {}).get("tags", []),
"path": run_path
}
runs.append(run_info)
except Exception as e:
logger.warning(f"Failed to parse metadata from {metadata_path}: {e}")
# Sort by timestamp (newest first)
runs.sort(key=lambda r: r.get("timestamp", 0), reverse=True)
# Limit to the requested count
return runs[:count]
def get_run_details(self, run_id: str) -> Optional[Dict[str, Any]]:
"""
Get detailed information about a specific run.
Args:
run_id: Run ID or path to the run directory
Returns:
Dictionary with detailed run information or None if not found
"""
# Handle the case where run_id is a path
if os.path.isdir(run_id):
run_path = run_id
else:
# Search for the run directory
run_path = None
for model_dir in os.listdir(self.results_base_dir):
model_path = os.path.join(self.results_base_dir, model_dir)
if not os.path.isdir(model_path):
continue
potential_path = os.path.join(model_path, run_id)
if os.path.isdir(potential_path):
run_path = potential_path
break
if run_path is None:
logger.warning(f"Run not found: {run_id}")
return None
# Try to load metadata
metadata_path = os.path.join(run_path, "metadata.json")
if not os.path.exists(metadata_path):
logger.warning(f"Metadata not found at {metadata_path}")
return None
try:
with open(metadata_path, 'r') as f:
metadata = json.load(f)
# Load the best solution
best_solution = ""
solution_path = os.path.join(run_path, "best_solution.txt")
if os.path.exists(solution_path):
with open(solution_path, 'r') as f:
best_solution = f.read()
# Load progress information
progress = []
progress_path = os.path.join(run_path, "progress.jsonl")
if os.path.exists(progress_path):
with open(progress_path, 'r') as f:
for line in f:
if line.strip():
try:
progress.append(json.loads(line))
except json.JSONDecodeError:
pass
# Combine everything into a single result
result = {
"metadata": metadata,
"best_solution": best_solution,
"progress": progress,
"run_path": run_path,
"run_id": os.path.basename(run_path)
}
return result
except Exception as e:
logger.warning(f"Failed to load run details from {run_path}: {e}")
return None
def extract_key_concepts(self, solution_text: str) -> List[str]:
"""
Extract key concepts from a solution text.
Args:
solution_text: The solution text to analyze
Returns:
List of key concepts extracted from the text
"""
# Look for sections explicitly labeled as key concepts
key_concepts_match = re.search(r'Key Concepts:(.+?)($|(?:\n\n))', solution_text, re.DOTALL)
if key_concepts_match:
# Extract and clean concepts
concepts_text = key_concepts_match.group(1)
concepts = [c.strip().strip('-*•') for c in concepts_text.strip().split('\n')
if c.strip() and not c.strip().startswith('#')]
return [c for c in concepts if c]
# Fallback: Look for bulleted or numbered lists
bullet_matches = re.findall(r'(?:^|\n)[ \t]*[-•*][ \t]*(.*?)(?:$|\n)', solution_text)
if bullet_matches:
return [m.strip() for m in bullet_matches if m.strip()]
# Last resort: Split paragraphs and take short ones as potential concepts
paragraphs = [p.strip() for p in re.split(r'\n\s*\n', solution_text) if p.strip()]
return [p for p in paragraphs if len(p) < 100 and len(p.split()) < 15][:5]
def extract_key_arguments(self, solution_text: str) -> Dict[str, List[str]]:
"""
Extract key arguments for and against from solution text.
Args:
solution_text: The solution text to analyze
Returns:
Dictionary with 'for' and 'against' keys mapping to lists of arguments
"""
arguments = {"for": [], "against": []}
# Look for "Arguments For" section
for_match = re.search(r'(?:Key )?Arguments For.*?:(.+?)(?:\n\n|\n(?:Arguments|Against))',
solution_text, re.DOTALL | re.IGNORECASE)
if for_match:
for_text = for_match.group(1).strip()
# Extract bullet points
for_args = [a.strip().strip('-*•') for a in re.findall(r'(?:^|\n)[ \t]*[-•*\d\.][ \t]*(.*?)(?:$|\n)', for_text)]
arguments["for"] = [a for a in for_args if a]
# Look for "Arguments Against" section
against_match = re.search(r'(?:Key )?Arguments Against.*?:(.+?)(?:\n\n|\n(?:[A-Z]))',
solution_text, re.DOTALL | re.IGNORECASE)
if against_match:
against_text = against_match.group(1).strip()
# Extract bullet points
against_args = [a.strip().strip('-*•') for a in re.findall(r'(?:^|\n)[ \t]*[-•*\d\.][ \t]*(.*?)(?:$|\n)', against_text)]
arguments["against"] = [a for a in against_args if a]
return arguments
def extract_conclusions(self, solution_text: str, progress: List[Dict[str, Any]]) -> List[str]:
"""
Extract conclusions from a solution and progress syntheses.
Args:
solution_text: The best solution text
progress: Progress information including syntheses
Returns:
List of key conclusions
"""
conclusions = []
# Extract any section labeled "Conclusion" or at the end of the text
conclusion_match = re.search(r'(?:^|\n)Conclusion:?\s*(.*?)(?:$|\n\n)', solution_text, re.DOTALL | re.IGNORECASE)
if conclusion_match:
conclusion_text = conclusion_match.group(1).strip()
conclusions.append(conclusion_text)
else:
# Try to extract the last paragraph as a potential conclusion
paragraphs = [p.strip() for p in re.split(r'\n\s*\n', solution_text) if p.strip()]
if paragraphs and not paragraphs[-1].startswith('#') and len(paragraphs[-1]) > 50:
conclusions.append(paragraphs[-1])
# Extract syntheses from progress
for entry in progress:
if "synthesis" in entry:
# Take the last paragraph of each synthesis as a conclusion
synthesis_paragraphs = [p.strip() for p in re.split(r'\n\s*\n', entry["synthesis"]) if p.strip()]
if synthesis_paragraphs:
conclusions.append(synthesis_paragraphs[-1])
# Remove duplicates and very similar conclusions
unique_conclusions = []
for c in conclusions:
if not any(self._text_similarity(c, uc) > 0.7 for uc in unique_conclusions):
unique_conclusions.append(c)
return unique_conclusions
def _text_similarity(self, text1: str, text2: str) -> float:
"""
Calculate a simple similarity score between two texts.
Args:
text1: First text
text2: Second text
Returns:
Similarity score between 0 and 1
"""
# Normalize and tokenize
words1 = set(re.findall(r'\w+', text1.lower()))
words2 = set(re.findall(r'\w+', text2.lower()))
# Calculate Jaccard similarity
if not words1 or not words2:
return 0.0
intersection = len(words1.intersection(words2))
union = len(words1.union(words2))
return intersection / union
def analyze_run(self, run_id: str) -> Dict[str, Any]:
"""
Perform a comprehensive analysis of a run's results.
Args:
run_id: The run ID to analyze
Returns:
Dictionary with analysis results
"""
# Check cache first
if run_id in self._cache:
return self._cache[run_id]
# Get the run details
run_details = self.get_run_details(run_id)
if not run_details:
return {"error": f"Run not found: {run_id}"}
# Extract key information
best_solution = run_details.get("best_solution", "")
progress = run_details.get("progress", [])
metadata = run_details.get("metadata", {})
# Extract key insights
key_concepts = self.extract_key_concepts(best_solution)
key_arguments = self.extract_key_arguments(best_solution)
conclusions = self.extract_conclusions(best_solution, progress)
# Extract tags
tags = metadata.get("results", {}).get("tags", [])
# Prepare the analysis results
analysis = {
"run_id": run_id,
"question": metadata.get("question", "Unknown"),
"model": metadata.get("model_name", "Unknown"),
"timestamp": metadata.get("timestamp_readable", "Unknown"),
"duration": metadata.get("duration_seconds", 0),
"status": metadata.get("status", "Unknown"),
"best_score": metadata.get("results", {}).get("best_score", 0),
"tags": tags,
"key_concepts": key_concepts,
"arguments_for": key_arguments["for"],
"arguments_against": key_arguments["against"],
"conclusions": conclusions,
"path": run_details.get("run_path", "")
}
# Cache the results
self._cache[run_id] = analysis
return analysis
def compare_runs(self, run_ids: List[str]) -> Dict[str, Any]:
"""
Compare multiple runs to identify similarities and differences.
Args:
run_ids: List of run IDs to compare
Returns:
Dictionary with comparison results
"""
# Analyze each run
analyses = [self.analyze_run(run_id) for run_id in run_ids]
analyses = [a for a in analyses if "error" not in a]
if not analyses:
return {"error": "No valid runs to compare"}
# Extract shared and unique concepts
all_concepts = [set(a.get("key_concepts", [])) for a in analyses]
shared_concepts = set.intersection(*all_concepts) if all_concepts else set()
unique_concepts = {}
for i, a in enumerate(analyses):
run_id = a.get("run_id", f"run_{i}")
other_concepts = [c for j, c in enumerate(all_concepts) if j != i]
unique = all_concepts[i] - (set.union(*other_concepts) if other_concepts else set())
if unique:
unique_concepts[run_id] = list(unique)
# Get mean score
scores = [a.get("best_score", 0) for a in analyses]
mean_score = sum(scores) / len(scores) if scores else 0
# Find common arguments
all_args_for = [set(a.get("arguments_for", [])) for a in analyses]
all_args_against = [set(a.get("arguments_against", [])) for a in analyses]
shared_args_for = set.intersection(*all_args_for) if all_args_for else set()
shared_args_against = set.intersection(*all_args_against) if all_args_against else set()
# Prepare comparison results
comparison = {
"runs_compared": run_ids,
"models": [a.get("model", "Unknown") for a in analyses],
"mean_score": mean_score,
"shared_concepts": list(shared_concepts),
"unique_concepts": unique_concepts,
"shared_arguments_for": list(shared_args_for),
"shared_arguments_against": list(shared_args_against),
"best_run": max(analyses, key=lambda a: a.get("best_score", 0)).get("run_id") if analyses else None
}
return comparison
def get_best_runs(self, count: int = 5, min_score: float = 7.0) -> List[Dict[str, Any]]:
"""
Get the best MCTS runs based on score.
Args:
count: Maximum number of runs to return
min_score: Minimum score threshold
Returns:
List of best run analyses
"""
# List all runs
all_runs = self.list_runs(count=100) # Get more than we need to filter
# Filter by minimum score
qualifying_runs = [r for r in all_runs if r.get("score", 0) >= min_score]
# Sort by score (highest first)
qualifying_runs.sort(key=lambda r: r.get("score", 0), reverse=True)
# Analyze the top runs
return [self.analyze_run(r.get("run_id")) for r in qualifying_runs[:count]]
def generate_report(self, run_id: str, format: str = "markdown") -> str:
"""
Generate a comprehensive report for a run.
Args:
run_id: Run ID to generate report for
format: Output format ('markdown', 'text', or 'html')
Returns:
Formatted report as a string
"""
# Analyze the run
analysis = self.analyze_run(run_id)
if "error" in analysis:
return f"Error: {analysis['error']}"
# Get the run details for additional information
run_details = self.get_run_details(run_id)
if not run_details:
return f"Error: Run details not found for {run_id}"
# Generate the report based on the format
if format == "markdown":
return self._generate_markdown_report(analysis, run_details)
elif format == "text":
return self._generate_text_report(analysis, run_details)
elif format == "html":
return self._generate_html_report(analysis, run_details)
else:
return f"Unsupported format: {format}"
def _generate_markdown_report(self, analysis: Dict[str, Any], run_details: Dict[str, Any]) -> str:
"""Generate a markdown report."""
report = []
# Header
report.append(f"# MCTS Analysis Report: {analysis['run_id']}")
report.append("")
# Basic information
report.append("## Basic Information")
report.append("")
report.append(f"- **Question:** {analysis['question']}")
report.append(f"- **Model:** {analysis['model']}")
report.append(f"- **Date:** {analysis['timestamp']}")
report.append(f"- **Duration:** {analysis['duration']} seconds")
report.append(f"- **Score:** {analysis['best_score']}")
if analysis['tags']:
report.append(f"- **Tags:** {', '.join(analysis['tags'])}")
report.append("")
# Key concepts
if analysis.get('key_concepts'):
report.append("## Key Concepts")
report.append("")
for concept in analysis['key_concepts']:
report.append(f"- {concept}")
report.append("")
# Key arguments
if analysis.get('arguments_for') or analysis.get('arguments_against'):
report.append("## Key Arguments")
report.append("")
if analysis.get('arguments_for'):
report.append("### Arguments For")
report.append("")
for arg in analysis['arguments_for']:
report.append(f"- {arg}")
report.append("")
if analysis.get('arguments_against'):
report.append("### Arguments Against")
report.append("")
for arg in analysis['arguments_against']:
report.append(f"- {arg}")
report.append("")
# Conclusions
if analysis.get('conclusions'):
report.append("## Key Conclusions")
report.append("")
for conclusion in analysis['conclusions']:
report.append(f"> {conclusion}")
report.append("")
# Best solution
best_solution = run_details.get('best_solution', '')
if best_solution:
report.append("## Best Solution")
report.append("")
report.append("```")
report.append(best_solution)
report.append("```")
return "\n".join(report)
def _generate_text_report(self, analysis: Dict[str, Any], run_details: Dict[str, Any]) -> str:
"""Generate a plain text report."""
report = []
# Header
report.append(f"MCTS Analysis Report: {analysis['run_id']}")
report.append("=" * 80)
report.append("")
# Basic information
report.append("Basic Information:")
report.append(f" Question: {analysis['question']}")
report.append(f" Model: {analysis['model']}")
report.append(f" Date: {analysis['timestamp']}")
report.append(f" Duration: {analysis['duration']} seconds")
report.append(f" Score: {analysis['best_score']}")
if analysis['tags']:
report.append(f" Tags: {', '.join(analysis['tags'])}")
report.append("")
# Key concepts
if analysis.get('key_concepts'):
report.append("Key Concepts:")
for concept in analysis['key_concepts']:
report.append(f" * {concept}")
report.append("")
# Key arguments
if analysis.get('arguments_for') or analysis.get('arguments_against'):
report.append("Key Arguments:")
if analysis.get('arguments_for'):
report.append(" Arguments For:")
for arg in analysis['arguments_for']:
report.append(f" * {arg}")
report.append("")
if analysis.get('arguments_against'):
report.append(" Arguments Against:")
for arg in analysis['arguments_against']:
report.append(f" * {arg}")
report.append("")
# Conclusions
if analysis.get('conclusions'):
report.append("Key Conclusions:")
for conclusion in analysis['conclusions']:
report.append(f" {conclusion}")
report.append("")
# Best solution
best_solution = run_details.get('best_solution', '')
if best_solution:
report.append("Best Solution:")
report.append("-" * 80)
report.append(best_solution)
report.append("-" * 80)
return "\n".join(report)
def _generate_html_report(self, analysis: Dict[str, Any], run_details: Dict[str, Any]) -> str:
"""Generate an HTML report."""
# For now, we'll convert the markdown to basic HTML
md_report = self._generate_markdown_report(analysis, run_details)
# Convert headers
html = re.sub(r'^# (.*?)$', r'<h1>\1</h1>', md_report, flags=re.MULTILINE)
html = re.sub(r'^## (.*?)$', r'<h2>\1</h2>', html, flags=re.MULTILINE)
html = re.sub(r'^### (.*?)$', r'<h3>\1</h3>', html, flags=re.MULTILINE)
# Convert lists
html = re.sub(r'^- (.*?)$', r'<li>\1</li>', html, flags=re.MULTILINE)
html = re.sub(r'(<li>.*?</li>\n)+', r'<ul>\n\g<0></ul>', html, flags=re.DOTALL)
# Convert blockquotes
html = re.sub(r'^> (.*?)$', r'<blockquote>\1</blockquote>', html, flags=re.MULTILINE)
# Convert code blocks
html = re.sub(r'```\n(.*?)```', r'<pre><code>\1</code></pre>', html, flags=re.DOTALL)
# Convert line breaks
html = re.sub(r'\n\n', r'<br><br>', html)
# Wrap in basic HTML structure
html = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>MCTS Analysis Report: {analysis['run_id']}</title>
<style>
body {{ font-family: Arial, sans-serif; line-height: 1.6; max-width: 800px; margin: 0 auto; padding: 20px; }}
h1, h2, h3 {{ color: #333; }}
blockquote {{ background-color: #f9f9f9; border-left: 5px solid #ccc; padding: 10px 20px; margin: 20px 0; }}
pre {{ background-color: #f5f5f5; padding: 15px; overflow-x: auto; }}
ul {{ margin-bottom: 20px; }}
</style>
</head>
<body>
{html}
</body>
</html>
"""
return html
def extract_insights(self, run_id: str, max_insights: int = 5) -> List[str]:
"""
Extract key insights from a run's results.
Args:
run_id: The run ID to analyze
max_insights: Maximum number of insights to extract
Returns:
List of key insights as strings
"""
# Analyze the run
analysis = self.analyze_run(run_id)
if "error" in analysis:
return [f"Error: {analysis['error']}"]
insights = []
# Add conclusions as insights
for conclusion in analysis.get('conclusions', [])[:max_insights]:
if conclusion and not any(self._text_similarity(conclusion, i) > 0.7 for i in insights):
insights.append(conclusion)
# Add key arguments as insights if we need more
if len(insights) < max_insights:
for_args = analysis.get('arguments_for', [])
against_args = analysis.get('arguments_against', [])
# Interleave arguments for and against
all_args = []
for i in range(max(len(for_args), len(against_args))):
if i < len(for_args):
all_args.append(("For: " + for_args[i]) if for_args[i].startswith("For: ") else for_args[i])
if i < len(against_args):
all_args.append(("Against: " + against_args[i]) if against_args[i].startswith("Against: ") else against_args[i])
# Add arguments as insights
for arg in all_args:
if len(insights) >= max_insights:
break
if not any(self._text_similarity(arg, i) > 0.7 for i in insights):
insights.append(arg)
# Add key concepts as insights if we still need more
if len(insights) < max_insights:
for concept in analysis.get('key_concepts', []):
if len(insights) >= max_insights:
break
if not any(self._text_similarity(concept, i) > 0.7 for i in insights):
insights.append(concept)
return insights
def suggest_improvements(self, run_id: str) -> List[str]:
"""
Suggest improvements for MCTS runs based on analysis.
Args:
run_id: The run ID to analyze
Returns:
List of improvement suggestions
"""
# Analyze the run
analysis = self.analyze_run(run_id)
if "error" in analysis:
return [f"Error: {analysis['error']}"]
suggestions = []
# Check if we've got enough iterations
iterations = analysis.get('iterations', 0)
if iterations < 2:
suggestions.append(f"Increase iterations from {iterations} to at least 2 for more thorough exploration")
# Check score
score = analysis.get('best_score', 0)
if score < 7.0:
suggestions.append(f"Current score is {score}, which is relatively low. Try using a more sophisticated model or adjusting exploration parameters")
# Check for diverse approaches
if len(analysis.get('key_concepts', [])) < 3:
suggestions.append("Limited key concepts identified. Consider increasing exploration weight parameter for more diverse thinking")
# Check for balanced arguments
if len(analysis.get('arguments_for', [])) > 0 and len(analysis.get('arguments_against', [])) == 0:
suggestions.append("Arguments are one-sided (only 'for' arguments). Consider using a balanced prompt approach to get both sides")
elif len(analysis.get('arguments_against', [])) > 0 and len(analysis.get('arguments_for', [])) == 0:
suggestions.append("Arguments are one-sided (only 'against' arguments). Consider using a balanced prompt approach to get both sides")
# Check for bayesian parameters if score is low
if score < 8.0:
suggestions.append("Try adjusting the prior parameters (beta_prior_alpha/beta) to improve the bandit algorithm performance")
# Default suggestion
if not suggestions:
suggestions.append("The MCTS run looks good and achieved a reasonable score. For even better results, try increasing iterations or using a more capable model")
return suggestions
```
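`ResultsProcessor` is a plain Python class, so it can also be driven outside the MCP server, for example from a notebook or a small script. The sketch below assumes the package is importable and that each run directory contains the `metadata.json` file the class looks for (the tree above only shows `best_solution.txt` and `progress.jsonl`, and runs without metadata are skipped by `list_runs`); the run ID is borrowed from the directory tree purely as an example.

```python
#!/usr/bin/env python3
"""Illustrative sketch: inspecting MCTS runs with ResultsProcessor."""
from mcts_mcp_server.analysis_tools.results_processor import ResultsProcessor

# Uses the default results/ directory at the repository root; pass a path to override.
processor = ResultsProcessor()

# Summarize the five most recent runs (requires metadata.json in each run directory).
for run in processor.list_runs(count=5):
    print(f"{run['run_id']}: score={run['score']}, tags={run['tags']}")

# Deeper analysis of a single run; the ID below is an example taken from the tree above.
run_id = "cogito:latest_1745979984"
analysis = processor.analyze_run(run_id)
report = processor.generate_report(run_id, format="markdown")
insights = processor.extract_insights(run_id, max_insights=3)
print(report)
```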
--------------------------------------------------------------------------------
/archive/fixed_tools.py:
--------------------------------------------------------------------------------
```python
# -*- coding: utf-8 -*-
"""
MCP Tools for MCTS
=================
This module defines the MCP tools that expose the MCTS functionality.
"""
import asyncio
import json
import logging
import datetime
import os
import sys
import importlib.util
import subprocess
import concurrent.futures
import inspect
import traceback
from typing import Dict, Any, Optional, List
# Ensure the 'src' directory (parent of this 'mcts_mcp_server' directory) is in sys.path
_current_file_dir = os.path.dirname(os.path.abspath(__file__))
_src_dir = os.path.dirname(_current_file_dir)
if _src_dir not in sys.path:
sys.path.insert(0, _src_dir)
try:
from fastmcp import MCP
except ImportError:
# Fallback if fastmcp is not available
class MCP:
def __init__(self):
pass
def tool(self):
def decorator(func):
return func
return decorator
# Try several import strategies for DirectMcpLLMAdapter
DirectMcpLLMAdapter = None
LLM_ADAPTER_AVAILABLE = False
# Strategy 1: Direct module import
try:
from llm_adapter import DirectMcpLLMAdapter
LLM_ADAPTER_AVAILABLE = True
print("Successfully imported DirectMcpLLMAdapter (direct)")
except ImportError as e:
print(f"Failed direct import of DirectMcpLLMAdapter: {e}")
# Strategy 2: Package import
try:
from mcts_mcp_server.llm_adapter import DirectMcpLLMAdapter
LLM_ADAPTER_AVAILABLE = True
print("Successfully imported DirectMcpLLMAdapter (package)")
except ImportError as e:
print(f"Failed package import of DirectMcpLLMAdapter: {e}")
# Strategy 3: Manual module loading
try:
adapter_path = os.path.join(_current_file_dir, "llm_adapter.py") # Fixed: use _current_file_dir
if os.path.exists(adapter_path):
spec = importlib.util.spec_from_file_location("llm_adapter", adapter_path)
if spec is not None and spec.loader is not None:
llm_adapter_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(llm_adapter_module)
DirectMcpLLMAdapter = llm_adapter_module.DirectMcpLLMAdapter
LLM_ADAPTER_AVAILABLE = True
print("Successfully imported DirectMcpLLMAdapter (manual load)")
else:
print(f"Failed to create module spec or loader for {adapter_path}")
else:
print(f"llm_adapter.py file not found at {adapter_path}")
except Exception as e:
print(f"Failed manual import of DirectMcpLLMAdapter: {e}")
if not LLM_ADAPTER_AVAILABLE:
print("Warning: DirectMcpLLMAdapter not available, will need fallback")
# Try different import strategies for OllamaAdapter
OLLAMA_AVAILABLE = False
OllamaAdapter = None
try:
from ollama_adapter import OllamaAdapter
OLLAMA_AVAILABLE = True
print("Successfully imported OllamaAdapter (direct)")
except ImportError as e:
print(f"Failed direct import: {e}")
try:
from mcts_mcp_server.ollama_adapter import OllamaAdapter
OLLAMA_AVAILABLE = True
print("Successfully imported OllamaAdapter (package)")
except ImportError as e:
print(f"Failed package import: {e}")
# Rest of the imports
try:
from mcts_core import MCTS, DEFAULT_CONFIG, truncate_text
except ImportError as e:
print(f"Failed to import mcts_core: {e}")
try:
from state_manager import StateManager
except ImportError as e:
print(f"Failed to import state_manager: {e}")
# Initialize logger
logger = logging.getLogger(__name__)
# Global state
_global_state = {
"mcts_instance": None,
"config": None,
"state_manager": None,
"current_chat_id": None,
"ollama_model": "qwen3:0.6b",
"available_models": []
}
def register_mcts_tools(mcp: MCP, db_path: str):
"""
Register all MCTS-related tools with the MCP server.
Args:
mcp: The FastMCP instance to register tools with
db_path: Path to the state database
"""
# Initialize state manager
try:
_global_state["state_manager"] = StateManager(db_path)
_global_state["config"] = DEFAULT_CONFIG.copy()
except Exception as e:
logger.error(f"Failed to initialize state manager: {e}")
return
@mcp.tool()
def test_tool() -> Dict[str, Any]:
"""Test tool to verify the system is working."""
return {
"status": "success",
"message": "MCTS tools are loaded and working",
"adapters_available": {
"ollama": OLLAMA_AVAILABLE,
"llm_adapter": LLM_ADAPTER_AVAILABLE
}
}
# Add more tools here as needed
logger.info("MCTS tools registered successfully")
@mcp.tool()
def initialize_mcts(question: str, chat_id: str, model_name: Optional[str] = None, config_updates: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
Initialize the MCTS system with a new question.
Args:
question: The question or text to analyze
chat_id: Unique identifier for the chat session
model_name: Optional specific Ollama model to use
config_updates: Optional dictionary of configuration updates
Returns:
Dictionary with initialization status and initial analysis
"""
global _global_state
try:
logger.info(f"Initializing MCTS for chat ID: {chat_id}")
# Update config if provided
if config_updates:
cfg = _global_state["config"].copy()
cfg.update(config_updates)
_global_state["config"] = cfg
else:
cfg = _global_state["config"]
# Store chat ID for state persistence
_global_state["current_chat_id"] = chat_id
# Try to load previous state
state_manager = _global_state["state_manager"]
loaded_state = None
if cfg.get("enable_state_persistence", True):
loaded_state = state_manager.load_state(chat_id)
if loaded_state:
logger.info(f"Loaded previous state for chat ID: {chat_id}")
else:
logger.info(f"No previous state found for chat ID: {chat_id}")
# Initialize the LLM adapter - ALWAYS use Ollama
logger.info("Initializing LLM adapter...")
# Get available Ollama models
available_models = check_available_models()
# If user specified a model, try to use it
if model_name:
if model_name in available_models:
_global_state["ollama_model"] = model_name
logger.info(f"Using user-specified model: {model_name}")
else:
# Try to find a model with the same base name
model_base = model_name.split(':')[0]
matching_models = [m for m in available_models if m.startswith(model_base + ':')]
if matching_models:
model_name = matching_models[0]
_global_state["ollama_model"] = model_name
logger.info(f"Found similar model: {model_name}")
else:
logger.warning(f"Model '{model_name}' not found. Using default model selection.")
# Make sure we have a selected model
model_name = _global_state["ollama_model"]
if not available_models or model_name not in available_models:
# If we have models but current selection isn't valid, pick a new one
if available_models:
select_default_model(available_models)
model_name = _global_state["ollama_model"]
logger.info(f"Selected model not available, using {model_name}")
# Always try to use Ollama first
logger.info(f"Using OllamaAdapter with model {model_name}")
try:
if OLLAMA_AVAILABLE and OllamaAdapter is not None:
# Check if OllamaAdapter is properly implemented
try:
# Check if OllamaAdapter is properly implemented before instantiation
import inspect
if inspect.isabstract(OllamaAdapter):
abstract_methods = getattr(OllamaAdapter, '__abstractmethods__', set())
raise NotImplementedError(f"OllamaAdapter has unimplemented abstract methods: {abstract_methods}")
llm_adapter = OllamaAdapter(model_name=model_name, mcp_server=mcp)
# Test if the adapter has required methods implemented
if not hasattr(llm_adapter, 'get_completion') or not callable(getattr(llm_adapter, 'get_completion')):
raise NotImplementedError("OllamaAdapter.get_completion not properly implemented")
if not hasattr(llm_adapter, 'get_streaming_completion') or not callable(getattr(llm_adapter, 'get_streaming_completion')):
raise NotImplementedError("OllamaAdapter.get_streaming_completion not properly implemented")
except (TypeError, NotImplementedError) as e:
logger.error(f"OllamaAdapter is not properly implemented: {e}")
raise ImportError("OllamaAdapter implementation incomplete")
else:
raise ImportError("OllamaAdapter not available")
except Exception as e:
# Only use the fallback adapter if Ollama fails completely
logger.error(f"Failed to initialize Ollama adapter: {e}")
logger.info("Using DirectMcpLLMAdapter as fallback")
llm_adapter = DirectMcpLLMAdapter(mcp)
# Test the adapter with a simple query
try:
async def test_adapter():
# Ensure we have a valid model name for testing
test_model = model_name or _global_state["ollama_model"] or DEFAULT_MODEL
test_result = await llm_adapter.get_completion(test_model, [{"role": "user", "content": "Test message"}])
return test_result
run_async(test_adapter())
logger.info("LLM adapter working properly")
except Exception as e:
logger.warning(f"Error testing LLM adapter: {e}. Using default LocalInferenceLLMAdapter.")
from llm_adapter import LocalInferenceLLMAdapter
llm_adapter = LocalInferenceLLMAdapter()
# Generate initial analysis
logger.info("Generating initial analysis...")
initial_prompt = f"<instruction>Provide an initial analysis and interpretation of the core themes, arguments, and potential implications presented. Identify key concepts. Respond with clear, natural language text ONLY.</instruction><question>{question}</question>"
initial_messages = [{"role": "user", "content": initial_prompt}]
# Call LLM for initial analysis (synchronously)
async def get_initial_analysis():
return await llm_adapter.get_completion(model_name or _global_state["ollama_model"], initial_messages)
initial_analysis = run_async(get_initial_analysis())
# Ensure initial_analysis is a string
if initial_analysis is None:
initial_analysis = "Initial analysis not available."
# Initialize MCTS
logger.info("Creating MCTS instance...")
async def init_mcts():
return MCTS(
llm_interface=llm_adapter,
question=question,
initial_analysis_content=initial_analysis,
config=cfg,
initial_state=loaded_state
)
_global_state["mcts_instance"] = run_async(init_mcts())
# Start collecting results if enabled
if _global_state["collect_results"] and COLLECTOR_AVAILABLE and results_collector is not None:
current_run_id = results_collector.start_run(
model_name=_global_state["ollama_model"] or DEFAULT_MODEL, # Always use Ollama model
question=question,
config=cfg
)
_global_state["current_run_id"] = current_run_id
logger.info(f"Started collecting results for run ID: {current_run_id}")
# Return success and initial analysis
return {
"status": "initialized",
"question": question,
"chat_id": chat_id,
"initial_analysis": initial_analysis,
"loaded_state": loaded_state is not None,
"adapter_type": "ollama", # Always use Olloma
"model": _global_state["ollama_model"],
"config": {k: v for k, v in cfg.items() if not k.startswith("_")}, # Filter internal config
"run_id": _global_state.get("current_run_id")
}
except Exception as e:
logger.error(f"Error in initialize_mcts {e}")
return {"error": f"Failed to initialize MCTS: {str(e)}"}
@mcp.tool()
def set_ollama_model(model_name: str) -> Dict[str, Any]:
"""
Set the Ollama model to use for future MCTS runs.
Args:
model_name: Name of the Ollama model (e.g., "qwen3:0.6b", "deepseek-r1:1.5b", etc.)
Returns:
Status message
"""
global _global_state
if not OLLAMA_AVAILABLE:
return {"error": "Olloma support is not available. Make sure olloma package is installed."}
# Refresh available models
available_models = check_available_models()
# Check if model is available
if model_name not in available_models:
# Try partial matches (just the model name without version specification)
model_base = model_name.split(':')[0]
matching_models = [m for m in available_models if m.startswith(model_base + ':')]
if matching_models:
# Found models with the same base name
model_name = matching_models[0]
_global_state["olloma_model"] = model_name
return {
"status": "success",
"message": f"Model '{model_name}' selected from available models with base name '{model_base}'.",
"available_similar_models": matching_models
}
else:
return {
"status": "warning",
"message": f"Model '{model_name}' is not available. Available models: {available_models}. You may need to pull it with 'olloma pull {model_name}'.",
"available_models": available_models
}
_global_state["olloma_model"] = model_name
return {
"status": "success",
"message": f"Set Olloma model to {model_name}. It will be used in the next MCTS initialization."
}
@mcp.tool()
def list_ollama_models() -> Dict[str, Any]:
"""
List all available Ollama models.
Returns:
Dictionary with available models and their details
"""
# Force direct command line call for reliability but with better error handling
try:
import subprocess
result = subprocess.run(['ollama', 'list'], capture_output=True, text=True, check=True)
lines = result.stdout.strip().split('\n')
# Skip the header line if present
if len(lines) > 1 and "NAME" in lines[0] and "ID" in lines[0]:
lines = lines[1:]
# Extract model names
available_models = []
model_details = []
for line in lines:
if not line.strip():
continue
parts = line.split()
if len(parts) >= 3: # We need at least NAME, ID, and SIZE
model_name = parts[0]
model_id = parts[1]
model_size = parts[2]
if ':' not in model_name:
model_name += ':latest'
available_models.append(model_name)
model_details.append({
"name": model_name,
"id": model_id,
"size": model_size
})
# Update global state
if available_models:
_global_state["available_models"] = available_models
# Select a default model if needed
if not _global_state["olloma_model"] or _global_state["olloma_model"] not in available_models:
select_default_model(available_models)
return {
"status": "success",
"available_models": available_models,
"model_details": model_details,
"current_model": _global_state["olloma_model"],
"recommended_small_models": SMALL_MODELS,
"recommended_medium_models": MEDIUM_MODELS
}
except Exception as e:
logger.warning(f"Command-line list failed: {e}")
# Fall back to check_available_models as a second attempt
available_models = check_available_models()
# Get more detailed model information when possible
model_details = []
try:
import subprocess
import json
# Try using the ollama show command to get detailed info
for model in available_models:
try:
result = subprocess.run(['ollama', 'show', model, '--json'],
capture_output=True, text=True, check=False)
if result.returncode == 0 and result.stdout.strip():
details = json.loads(result.stdout)
model_details.append({
"name": model,
"parameter_size": details.get("parameter_size", "unknown"),
"quantization": details.get("quantization_level", "unknown"),
"family": details.get("family", "unknown"),
"size_mb": round(details.get("size", 0) / (1024 * 1024), 1)
})
except Exception as e:
logger.warning(f"Error getting details for model {model}: {e}")
except Exception as e:
logger.warning(f"Error getting detailed model information: {e}")
return {
"status": "success",
"available_models": available_models,
"model_details": model_details,
"current_model": _global_state["olloma_model"],
"recommended_small_models": SMALL_MODELS,
"recommended_medium_models": MEDIUM_MODELS
}
@mcp.tool()
def run_mcts(iterations: int = 1, simulations_per_iteration: int = 5) -> Dict[str, Any]:
"""
Run the MCTS algorithm for the specified number of iterations.
Args:
iterations: Number of MCTS iterations to run (default: 1)
simulations_per_iteration: Number of simulations per iteration (default: 5)
Returns:
Dictionary with results of the MCTS run
"""
global _global_state
mcts = _global_state.get("mcts_instance")
if not mcts:
return {"error": "MCTS not initialized. Call initialize_mcts first."}
# Override config values for this run
temp_config = mcts.config.copy()
temp_config["max_iterations"] = iterations
temp_config["simulations_per_iteration"] = simulations_per_iteration
mcts.config = temp_config
logger.info(f"Running MCTS with {iterations} iterations, {simulations_per_iteration} simulations per iteration...")
# Update collector status if enabled
if _global_state["collect_results"] and COLLECTOR_AVAILABLE and results_collector is not None and _global_state.get("current_run_id"):
results_collector.update_run_status(
_global_state["current_run_id"],
"running",
{
"iterations": iterations,
"simulations_per_iteration": simulations_per_iteration,
"timestamp": int(datetime.datetime.now().timestamp())
}
)
# Run MCTS (synchronously)
async def run_search():
await mcts.run_search_iterations(iterations, simulations_per_iteration)
return mcts.get_final_results()
try:
results = run_async(run_search())
except Exception as e:
logger.error(f"Error running MCTS: {e}")
# Update collector with failure if enabled
if _global_state["collect_results"] and COLLECTOR_AVAILABLE and results_collector is not None and _global_state.get("current_run_id"):
results_collector.update_run_status(
_global_state["current_run_id"],
"failed",
{"error": str(e), "timestamp": int(datetime.datetime.now().timestamp())}
)
return {"error": f"MCTS run failed: {str(e)}"}
# Check if results is None
if results is None:
logger.error("MCTS search returned None results")
return {"error": "MCTS search returned no results"}
# Save state if enabled
if temp_config.get("enable_state_persistence", True) and _global_state["current_chat_id"]:
try:
_global_state["state_manager"].save_state(_global_state["current_chat_id"], mcts)
logger.info(f"Saved state for chat ID: {_global_state['current_chat_id']}")
except Exception as e:
logger.error(f"Error saving state: {e}")
# Find best node and tags
best_node = mcts.find_best_final_node()
tags = best_node.descriptive_tags if best_node else []
# Prepare results
result_dict = {
"status": "completed",
"best_score": getattr(results, 'best_score', 0.0),
"best_solution": getattr(results, 'best_solution_content', ''),
"tags": tags,
"iterations_completed": mcts.iterations_completed,
"simulations_completed": mcts.simulations_completed,
"model": _global_state["olloma_model"], # Always use Olloma model
}
# Save results to collector if enabled
if _global_state["collect_results"] and COLLECTOR_AVAILABLE and results_collector is not None and _global_state.get("current_run_id"):
results_collector.save_run_results(_global_state["current_run_id"], result_dict)
result_dict["run_id"] = _global_state["current_run_id"]
return result_dict
@mcp.tool()
def generate_synthesis() -> Dict[str, Any]:
"""
Generate a final synthesis of the MCTS results.
Returns:
Dictionary with the synthesis and related information
"""
global _global_state
mcts = _global_state.get("mcts_instance")
if not mcts:
return {"error": "MCTS not initialized. Call initialize_mcts first."}
logger.info("Generating synthesis of MCTS results...")
# Use same LLM adapter as the MCTS instance
llm_adapter = mcts.llm
async def synth():
# Prepare context for synthesis
path_nodes = mcts.get_best_path_nodes()
path_thoughts_list = [
f"- (Node {node.sequence}): {node.thought.strip()}"
for node in path_nodes if node.thought and node.parent
]
path_thoughts_str = "\n".join(path_thoughts_list) if path_thoughts_list else "No significant development path identified."
results = mcts.get_final_results()
synth_context = {
"question_summary": mcts.question_summary,
"initial_analysis_summary": truncate_text(mcts.root.content, 300) if mcts.root else "N/A",
"best_score": f"{results.best_score:.1f}",
"path_thoughts": path_thoughts_str,
"final_best_analysis_summary": truncate_text(results.best_solution_content, 400),
# Add defaults for unused but required keys
"previous_best_summary": "N/A",
"unfit_markers_summary": "N/A",
"learned_approach_summary": "N/A"
}
# Use the synthesize_result method from the LLMInterface
synthesis = await llm_adapter.synthesize_result(synth_context, mcts.config)
best_node = mcts.find_best_final_node()
tags = best_node.descriptive_tags if best_node else []
return {
"synthesis": synthesis,
"best_score": results.best_score,
"tags": tags,
"iterations_completed": mcts.iterations_completed,
"model": _global_state["olloma_model"], # Always use Olloma model
}
try:
synthesis_result = run_async(synth())
if synthesis_result is None:
return {"error": "Failed to generate synthesis - no results returned"}
# Update results in collector if enabled
if _global_state["collect_results"] and COLLECTOR_AVAILABLE and results_collector is not None and _global_state.get("current_run_id"):
results_collector.update_run_status(
_global_state["current_run_id"],
"completed",
{"synthesis": synthesis_result.get("synthesis")}
)
synthesis_result["run_id"] = _global_state["current_run_id"]
return synthesis_result
except Exception as e:
logger.error(f"Error generating synthesis: {e}")
return {"error": f"Synthesis generation failed: {str(e)}"} @mcp.tool()
def get_config() -> Dict[str, Any]:
"""
Get the current MCTS configuration.
Returns:
Dictionary with the current configuration values
"""
global _global_state
# Add Ollama-specific config
config = {k: v for k, v in _global_state["config"].items() if not k.startswith("_")}
config.update({
"olloma_model": _global_state["olloma_model"],
"available_models": _global_state["available_models"],
"collect_results": _global_state["collect_results"],
"current_run_id": _global_state.get("current_run_id")
})
return config
@mcp.tool()
def update_config(config_updates: Dict[str, Any]) -> Dict[str, Any]:
"""
Update the MCTS configuration.
Args:
config_updates: Dictionary with configuration keys and values to update
Returns:
Dictionary with the updated configuration
"""
global _global_state
logger.info(f"Updating config with: {config_updates}")
if "olloma_model" in config_updates:
model_name = config_updates.pop("olloma_model")
# Check if model is available
if not _global_state["available_models"]:
check_available_models()
if model_name in _global_state["available_models"] or not _global_state["available_models"]:
_global_state["olloma_model"] = model_name
else:
logger.warning(f"Model {model_name} not available, keeping current model {_global_state['olloma_model']}")
if "collect_results" in config_updates:
_global_state["collect_results"] = bool(config_updates.pop("collect_results"))
# Update regular MCTS config
cfg = _global_state["config"].copy()
cfg.update(config_updates)
_global_state["config"] = cfg
# If MCTS instance exists, update its config
mcts = _global_state.get("mcts_instance")
if mcts:
mcts.config = cfg
# Return filtered config (without private items)
config = {k: v for k, v in cfg.items() if not k.startswith("_")}
# Add Ollama-specific config
config.update({
"olloma_model": _global_state["olloma_model"],
"olloma_available": OLLAMA_AVAILABLE,
"available_models": _global_state["available_models"],
"collect_results": _global_state["collect_results"],
"current_run_id": _global_state.get("current_run_id")
})
return config
@mcp.tool()
def get_mcts_status() -> Dict[str, Any]:
"""
Get the current status of the MCTS system.
Returns:
Dictionary with status information
"""
global _global_state
mcts = _global_state.get("mcts_instance")
if not mcts:
return {
"initialized": False,
"message": "MCTS not initialized. Call initialize_mcts first."
}
try:
# Get best node and extract information
best_node = mcts.find_best_final_node()
tags = best_node.descriptive_tags if best_node else []
return {
"initialized": True,
"chat_id": _global_state.get("current_chat_id"),
"iterations_completed": getattr(mcts, "iterations_completed", 0),
"simulations_completed": getattr(mcts, "simulations_completed", 0),
"best_score": getattr(mcts, "best_score", 0.0),
"best_content_summary": truncate_text(getattr(mcts, "best_solution", ""), 100)
if hasattr(mcts, "best_solution") else "N/A",
"tags": tags,
"tree_depth": mcts.memory.get("depth", 0) if hasattr(mcts, "memory") else 0,
"approach_types": getattr(mcts, "approach_types", []),
"adapter_type": "ollama", # Always use Olloma
"model": _global_state["olloma_model"], # Always use Olloma model
"collected_results": _global_state["collect_results"],
"run_id": _global_state.get("current_run_id")
}
except Exception as e:
logger.error(f"Error getting MCTS status: {e}")
return {
"initialized": True,
"error": f"Error getting MCTS status: {str(e)}",
"chat_id": _global_state.get("current_chat_id")
}
@mcp.tool()
def run_model_comparison(question: str, iterations: int = 2, simulations_per_iteration: int = 10) -> Dict[str, Any]:
"""
Run MCTS with the same question across multiple models for comparison.
Args:
question: The question to analyze with MCTS
iterations: Number of MCTS iterations per model (default: 2)
simulations_per_iteration: Simulations per iteration (default: 10)
Returns:
Dictionary with the run IDs for each model
"""
if not OLLAMA_AVAILABLE:
return {"error": "Olloma is not available. Cannot run model comparison."}
if not COLLECTOR_AVAILABLE or results_collector is None:
return {"error": "Results collector is not available. Cannot track comparison results."}
# Refresh available models
models = check_available_models()
# Filter to only include our preferred small models for faster comparison
preferred_models = ["qwen3:0.6b", "deepseek-r1:1.5b", "cogito:latest"]
comparison_models = [m for m in models if any(sm in m for sm in preferred_models)]
if not comparison_models:
return {"error": f"No suitable models found. Please pull at least one of: {preferred_models}"}
# Set up comparison config
config = _global_state["config"].copy()
config.update({
"max_iterations": iterations,
"simulations_per_iteration": simulations_per_iteration
})
# Start comparison
run_ids = results_collector.compare_models(
question=question,
models=comparison_models,
config=config,
iterations=iterations,
simulations_per_iter=simulations_per_iteration
)
return {
"status": "started",
"question": question,
"iterations": iterations,
"simulations_per_iteration": simulations_per_iteration,
"models": comparison_models,
"run_ids": run_ids
}
```
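The tools above are plain Python functions registered on a FastMCP server, so an MCP client drives the whole workflow over the protocol. Below is a minimal sketch of one end-to-end pass using the reference `mcp` Python client; the launch command (`uv run mcts-mcp-server`), the question, and the chat id are placeholders for whatever your local setup uses.

```python
import asyncio

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client


async def main() -> None:
    # Placeholder launch command; adjust to however the server is started locally.
    params = StdioServerParameters(command="uv", args=["run", "mcts-mcp-server"])
    async with stdio_client(params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            # 1. Seed the search with a question and a chat/session id.
            init = await session.call_tool(
                "initialize_mcts",
                {"question": "What are the core trade-offs of remote work?", "chat_id": "demo-1"},
            )
            # 2. Run a short search, then 3. ask for the final synthesis.
            run = await session.call_tool(
                "run_mcts", {"iterations": 1, "simulations_per_iteration": 5}
            )
            synthesis = await session.call_tool("generate_synthesis", {})
            print(init, run, synthesis, sep="\n\n")


asyncio.run(main())
```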
--------------------------------------------------------------------------------
/archive/tools_old.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
MCP Tools for MCTS
=================
This module defines the MCP tools that expose the MCTS functionality.
"""
import asyncio
import json
import logging
import datetime
import os # Added for os.getenv
from dotenv import load_dotenv # Added for .env loading
from typing import Dict, Any, Optional, List
from mcp.server.fastmcp import FastMCP
from .llm_adapter import DirectMcpLLMAdapter # Changed to relative import
from .ollama_utils import (
OLLAMA_PYTHON_PACKAGE_AVAILABLE, # Renamed, reflects if 'ollama' python package is installed
# OllamaAdapter moved to its own file
# SMALL_MODELS,
# MEDIUM_MODELS,
# DEFAULT_MODEL was removed from ollama_utils
check_available_models,
get_recommended_models
)
from .ollama_adapter import OllamaAdapter # Import new OllamaAdapter
# Import from the MCTS core implementation
# Make sure these imports are correct based on previous refactorings
from .mcts_core import MCTS # MCTS uses DEFAULT_CONFIG, APPROACH_TAXONOMY, APPROACH_METADATA internally
from .state_manager import StateManager
from .mcts_config import DEFAULT_CONFIG # For MCTS and general tool use
from .utils import truncate_text # For get_mcts_status
# Import the results collector
try:
from results_collector import collector as results_collector
COLLECTOR_AVAILABLE = True
except ImportError:
COLLECTOR_AVAILABLE = False
results_collector = None
# Import the analysis tools
try:
from analysis_tools import register_mcts_analysis_tools
ANALYSIS_TOOLS_AVAILABLE = True
except ImportError:
ANALYSIS_TOOLS_AVAILABLE = False
register_mcts_analysis_tools = None
logger = logging.getLogger(__name__) # Changed to __name__ for consistency
# Global state to maintain between tool calls
_global_state = {
"mcts_instance": None,
"config": None, # Will be initialized with DEFAULT_CONFIG from mcts_config.py
"state_manager": None,
"current_chat_id": None,
"active_llm_provider": os.getenv("DEFAULT_LLM_PROVIDER", "ollama"),
# DEFAULT_MODEL_NAME from .env, or None. Provider-specific defaults handled in initialize_mcts.
"active_model_name": os.getenv("DEFAULT_MODEL_NAME"),
"collect_results": COLLECTOR_AVAILABLE,
"current_run_id": None,
"ollama_available_models": [] # Specifically for Ollama, populated by check_available_models
}
def run_async(coro):
"""
Utility to run an async function in a synchronous context.
Uses a thread-based approach to avoid event loop conflicts.
"""
import threading
import concurrent.futures
import functools
# Function to run in a separate thread
def thread_runner():
result = None
exception = None
loop = None
try:
# Create a new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Run the coroutine and store the result
result = loop.run_until_complete(coro)
except Exception as e:
exception = e
finally:
# Clean up
if loop is not None:
loop.close()
return result, exception
# Run the function in a separate thread
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(thread_runner)
result, exception = future.result()
# If there was an exception, log and re-raise it
if exception is not None:
logger.error(f"Error in run_async: {exception}")
raise exception
return result
# check_available_models and get_recommended_models moved to ollama_utils.py
def register_mcts_tools(mcp: FastMCP, db_path: str):
"""
Register all MCTS-related tools with the MCP server.
Args:
mcp: The FastMCP instance to register tools with
db_path: Path to the state database
"""
global _global_state
# Load environment variables from .env file
load_dotenv()
# Initialize state manager for persistence
_global_state["state_manager"] = StateManager(db_path)
# Initialize config with defaults from mcts_config.py
_global_state["config"] = DEFAULT_CONFIG.copy()
# Populate available models from ollama_utils
_global_state["ollama_available_models"] = check_available_models()
if not _global_state["ollama_available_models"]:
logger.warning("No Ollama models detected by ollama_utils.check_available_models(). Ollama provider might not function correctly if selected.")
# Set a provider-specific default model if active_model_name is None AND current provider is ollama
if _global_state["active_llm_provider"] == "ollama" and not _global_state["active_model_name"]:
_global_state["active_model_name"] = OllamaAdapter.DEFAULT_MODEL # Use class default
# Register the analysis tools if available
if ANALYSIS_TOOLS_AVAILABLE and register_mcts_analysis_tools is not None:
# Get the results directory path
# Ensure os is imported if not already: import os
repo_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
results_dir = os.path.join(repo_dir, "results")
# Register the analysis tools
register_mcts_analysis_tools(mcp, results_dir)
logger.info("Registered MCTS analysis tools")
@mcp.tool()
def initialize_mcts(question: str, chat_id: str, provider_name: Optional[str] = None, model_name: Optional[str] = None, config_updates: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""
Initialize the MCTS system with a new question, LLM provider, and model.
Args:
question: The question or text to analyze.
chat_id: Unique identifier for the chat session.
provider_name: Name of the LLM provider (e.g., "ollama", "openai", "anthropic", "gemini").
Defaults to DEFAULT_LLM_PROVIDER from .env or "ollama".
model_name: Specific model name for the provider. Defaults to DEFAULT_MODEL_NAME from .env or provider-specific default.
config_updates: Optional dictionary of configuration updates.
Returns:
Dictionary with initialization status and initial analysis.
"""
global _global_state
llm_adapter = None # Initialize llm_adapter to None
try:
logger.info(f"Initializing MCTS for chat ID: {chat_id}")
# Determine target provider and model
target_provider = provider_name or _global_state["active_llm_provider"]
target_model = model_name or _global_state["active_model_name"] # This could be None if not in .env
logger.info(f"Attempting to use LLM Provider: {target_provider}, Model: {target_model}")
# Update config if provided
if config_updates:
cfg = _global_state["config"].copy()
cfg.update(config_updates)
_global_state["config"] = cfg
else:
cfg = _global_state["config"]
_global_state["current_chat_id"] = chat_id
state_manager = _global_state["state_manager"]
loaded_state = state_manager.load_state(chat_id) if cfg.get("enable_state_persistence", True) else None
if loaded_state: logger.info(f"Loaded previous state for chat ID: {chat_id}")
else: logger.info(f"No previous state found for chat ID: {chat_id}")
# Instantiate the appropriate adapter
if target_provider == "ollama":
# OllamaAdapter might need specific checks like check_available_models
if not target_model: target_model = OllamaAdapter.DEFAULT_MODEL # Use class default if still None
if target_model not in _global_state["ollama_available_models"]: # Check against list for this provider
return {
"status": "model_error",
"error": f"Ollama model '{target_model}' not in available list: {_global_state['ollama_available_models']}",
"message": "Please select an available Ollama model or ensure it is pulled."
}
llm_adapter = OllamaAdapter(model_name=target_model)
elif target_provider == "openai":
from .openai_adapter import OpenAIAdapter
if not target_model: target_model = OpenAIAdapter.DEFAULT_MODEL
llm_adapter = OpenAIAdapter(api_key=os.getenv("OPENAI_API_KEY"), model_name=target_model)
elif target_provider == "anthropic":
from .anthropic_adapter import AnthropicAdapter
if not target_model: target_model = AnthropicAdapter.DEFAULT_MODEL
llm_adapter = AnthropicAdapter(api_key=os.getenv("ANTHROPIC_API_KEY"), model_name=target_model)
elif target_provider == "gemini":
from .gemini_adapter import GeminiAdapter
if not target_model: target_model = GeminiAdapter.DEFAULT_MODEL
llm_adapter = GeminiAdapter(api_key=os.getenv("GEMINI_API_KEY"), model_name=target_model)
else:
return {"error": f"Unsupported LLM provider: {target_provider}. Supported: ollama, openai, anthropic, gemini.", "status": "error"}
_global_state["active_llm_provider"] = target_provider
_global_state["active_model_name"] = target_model
logger.info(f"Successfully initialized LLM adapter for Provider: {target_provider}, Model: {target_model}")
# Test adapter (optional, can be removed for speed)
try:
async def test_adapter_briefly():
return await llm_adapter.get_completion(model=target_model, messages=[{"role": "user", "content": "Brief test query."}])
test_result = run_async(test_adapter_briefly())
logger.info(f"Adapter test successful: {truncate_text(test_result, 50)}")
except Exception as e:
logger.error(f"Failed to test LLM adapter for {target_provider} model {target_model}: {e}", exc_info=True)
# If adapter test fails, it's a significant issue. Let the error propagate or return specific error.
# Removing the DirectMcpLLMAdapter fallback here as 'mcp' is not in local scope
# and primary adapter initialization errors (e.g. API keys) are caught by ValueError.
return {"error": f"LLM adapter for {target_provider} failed test: {e}", "status": "adapter_test_error"}
# Generate initial analysis
initial_prompt_format = "<instruction>Provide an initial analysis and interpretation of the core themes, arguments, and potential implications presented. Identify key concepts. Respond with clear, natural language text ONLY.</instruction><question>{question}</question>"
initial_messages = [{"role": "user", "content": initial_prompt_format.format(question=question)}]
initial_analysis = run_async(llm_adapter.get_completion(model=target_model, messages=initial_messages))
_global_state["mcts_instance"] = MCTS(
llm_interface=llm_adapter,
question=question,
initial_analysis_content=initial_analysis or "No initial analysis available",
config=cfg,
initial_state=loaded_state
)
if _global_state["collect_results"] and COLLECTOR_AVAILABLE and results_collector is not None:
_global_state["current_run_id"] = results_collector.start_run(
model_name=target_model,
question=question,
config=cfg
)
logger.info(f"Started collecting results for run ID: {_global_state['current_run_id']}")
return {
"status": "initialized",
"question": question,
"chat_id": chat_id,
"initial_analysis": initial_analysis,
"loaded_state": loaded_state is not None,
"provider": target_provider,
"model_used": target_model,
"config": {k: v for k, v in cfg.items() if not k.startswith("_")},
"run_id": _global_state.get("current_run_id")
}
except ValueError as ve: # Catch API key errors specifically
logger.error(f"Configuration error in initialize_mcts: {ve}", exc_info=True)
return {"error": f"Configuration error: {str(ve)}", "status": "config_error"}
except Exception as e:
logger.error(f"Error in initialize_mcts: {e}", exc_info=True)
return {"error": f"Failed to initialize MCTS: {str(e)}", "status": "error"}
@mcp.tool()
def set_active_llm(provider_name: str, model_name: Optional[str] = None) -> Dict[str, Any]:
"""
Set the active LLM provider and optionally a model name for future MCTS runs.
Args:
provider_name: Name of the LLM provider (e.g., "ollama", "openai", "anthropic", "gemini").
model_name: Optional specific model name for the provider. If None, provider's default will be used.
Returns:
Status message.
"""
global _global_state
supported_providers = ["ollama", "openai", "anthropic", "gemini"]
provider_name_lower = provider_name.lower()
if provider_name_lower not in supported_providers:
return {
"status": "error",
"message": f"Unsupported LLM provider: '{provider_name}'. Supported providers are: {supported_providers}"
}
_global_state["active_llm_provider"] = provider_name_lower
# If a model name is provided, set it. Otherwise, it will use the provider's default or .env DEFAULT_MODEL_NAME
# specific to that provider during initialize_mcts.
_global_state["active_model_name"] = model_name
log_msg = f"Set active LLM provider to: {provider_name_lower}."
if model_name:
log_msg += f" Set active model to: {model_name}."
else:
log_msg += f" Active model will be provider's default or from .env DEFAULT_MODEL_NAME."
logger.info(log_msg)
return {
"status": "success",
"message": log_msg
}
@mcp.tool()
def list_ollama_models() -> Dict[str, Any]:
"""
List all available Ollama models.
Returns:
Dictionary with available models and their details
"""
logger.info("Listing Ollama models...")
# Check if Ollama server is running first
try:
import httpx
client = httpx.Client(base_url="http://localhost:11434", timeout=3.0)
response = client.get("/")
if response.status_code != 200:
logger.error(f"Ollama server health check failed with status code: {response.status_code}")
return {
"status": "error",
"message": "Ollama server not responding. Please ensure Ollama is running with 'ollama serve'",
"diagnostics": {
"server_check": f"Failed with status {response.status_code}",
"server_url": "http://localhost:11434"
}
}
logger.info("Ollama server is running and responding to requests")
except Exception as e:
logger.error(f"Ollama server health check failed: {e}")
return {
"status": "error",
"message": "Unable to connect to Ollama server. Please ensure Ollama is running with 'ollama serve'",
"diagnostics": {
"error": str(e),
"server_url": "http://localhost:11434"
}
}
# Get models using our comprehensive function
available_models = check_available_models()
# If we got no models, return detailed error
if not available_models:
return {
"status": "error",
"message": "No Ollama models detected. You may need to pull models using 'ollama pull MODEL_NAME'",
"diagnostics": {
"server_check": "Server appears to be running but no models detected",
"suggestion": "Try running 'ollama pull qwen3:0.6b' or 'ollama pull cogito:latest' to download a model"
}
}
# Get more detailed model information when possible
model_details = []
try:
import subprocess
import json
import sys
# Try using ollama show command to get detailed info
for model in available_models:
try:
if sys.platform == 'win32':
cmd = ['ollama.exe', 'show', model, '--json']
else:
cmd = ['ollama', 'show', model, '--json']
result = subprocess.run(cmd, capture_output=True, text=True, check=False)
if result.returncode == 0 and result.stdout.strip():
try:
details = json.loads(result.stdout)
model_details.append({
"name": model,
"parameter_size": details.get("parameter_size", "unknown"),
"quantization": details.get("quantization_level", "unknown"),
"family": details.get("family", "unknown"),
"size_mb": round(details.get("size", 0) / (1024 * 1024), 1)
})
except json.JSONDecodeError:
# Add basic info if JSON parsing fails
model_details.append({
"name": model,
"parameter_size": "unknown",
"quantization": "unknown",
"family": "unknown"
})
except Exception as e:
logger.warning(f"Error getting details for model {model}: {e}")
# Still add basic info
model_details.append({
"name": model,
"note": "Details unavailable"
})
except Exception as e:
logger.warning(f"Error getting detailed model information: {e}")
# Get recommended models
recommendations = get_recommended_models(available_models)
# Determine if a model is already selected (specific to Ollama for this tool)
current_ollama_model = _global_state.get("active_model_name") if _global_state.get("active_llm_provider") == "ollama" else None
if not current_ollama_model and _global_state.get("active_llm_provider") == "ollama":
# If provider is ollama but no model is set, try the class default for OllamaAdapter
current_ollama_model = OllamaAdapter.DEFAULT_MODEL
model_selected = current_ollama_model is not None and current_ollama_model in available_models
# Customize message based on model selection status
if model_selected:
message = f"Current Ollama model for 'ollama' provider: {current_ollama_model}. Use set_active_llm to change provider or model."
else:
message = "To use Ollama, please use set_active_llm(provider_name='ollama', model_name='your_model') to select a model."
# Clear active_model_name if it's an Ollama model but not in the available list
if _global_state.get("active_llm_provider") == "ollama" and current_ollama_model and current_ollama_model not in available_models:
logger.warning(f"Current active model {current_ollama_model} (Ollama) not found in available models. Clearing active_model_name.")
_global_state["active_model_name"] = None # Will fall back to default or require setting
current_ollama_model = None # For the message
# Update global state with Ollama-specific available models
_global_state["ollama_available_models"] = available_models # Store specifically for ollama
return {
"status": "success",
"ollama_available_models": available_models, # Specifically Ollama models
"model_details": model_details, # Details for Ollama models
"current_ollama_model_for_provider": current_ollama_model, # If provider is ollama
"recommended_small_ollama_models": recommendations["small_models"],
"recommended_medium_ollama_models": recommendations["medium_models"],
"message": message,
"model_selected": model_selected
}
@mcp.tool()
def run_mcts(iterations: int = 1, simulations_per_iteration: int = 5, model_name: Optional[str] = None) -> Dict[str, Any]: # Restored def
"""
Run the MCTS algorithm for the specified number of iterations using the currently active LLM provider and model.
The model_name parameter here is currently not used, as the model is determined by initialize_mcts or set_active_llm.
This could be changed to allow overriding the model for a specific run if desired.
Args:
iterations: Number of MCTS iterations to run (default: 1)
simulations_per_iteration: Number of simulations per iteration (default: 5)
model_name: (Currently not used, model is taken from _global_state set by initialize_mcts or set_active_llm)
Returns:
Dictionary with status message about the background run
"""
global _global_state
mcts = _global_state.get("mcts_instance")
if not mcts:
return {"error": "MCTS not initialized. Call initialize_mcts first."}
active_provider = _global_state.get("active_llm_provider")
active_model = _global_state.get("active_model_name")
if not active_provider or not active_model:
return {"error": "Active LLM provider or model not set. Call initialize_mcts or set_active_llm first."}
# Override config values for this run
temp_config = mcts.config.copy()
temp_config["max_iterations"] = iterations
temp_config["simulations_per_iteration"] = simulations_per_iteration
mcts.config = temp_config # This updates the config in the MCTS instance
logger.info(f"Starting MCTS background run with {iterations} iterations, {simulations_per_iteration} simulations per iteration using Provider: {active_provider}, Model: {active_model}...")
# Update collector status if enabled
if _global_state["collect_results"] and COLLECTOR_AVAILABLE and results_collector is not None and _global_state.get("current_run_id"):
results_collector.update_run_status(
_global_state["current_run_id"],
"running",
{
"iterations": iterations,
"simulations_per_iteration": simulations_per_iteration,
"provider": active_provider,
"model": active_model,
"timestamp": int(datetime.datetime.now().timestamp())
}
)
# Start MCTS search in a background thread and return immediately
import threading
def run_mcts_background():
try:
# Run the search asynchronously (wrap in run_async)
async def run_search():
await mcts.run_search_iterations(iterations, simulations_per_iteration)
return mcts.get_final_results()
results = run_async(run_search())
# After completion, save state and update results
if temp_config.get("enable_state_persistence", True) and _global_state["current_chat_id"]:
try:
_global_state["state_manager"].save_state(_global_state["current_chat_id"], mcts)
logger.info(f"Saved state for chat ID: {_global_state['current_chat_id']}")
except Exception as e:
logger.error(f"Error saving state: {e}")
# Find best node and tags
best_node = mcts.find_best_final_node()
tags = best_node.descriptive_tags if best_node else []
# Prepare results
result_dict = {
"status": "completed",
"best_score": results.best_score if results else 0.0,
"best_solution": results.best_solution_content if results else "No solution found",
"tags": tags,
"iterations_completed": mcts.iterations_completed,
"simulations_completed": mcts.simulations_completed,
"provider": _global_state.get("active_llm_provider"),
"model": _global_state.get("active_model_name"),
}
# Save results to collector if enabled
if _global_state["collect_results"] and COLLECTOR_AVAILABLE and results_collector is not None and _global_state.get("current_run_id"):
results_collector.save_run_results(_global_state["current_run_id"], result_dict)
logger.info(f"Saved results for run ID: {_global_state['current_run_id']}")
except Exception as e:
logger.error(f"Error in background MCTS run: {e}")
# Update collector with failure if enabled
if _global_state["collect_results"] and COLLECTOR_AVAILABLE and results_collector is not None and _global_state.get("current_run_id"):
results_collector.update_run_status(
_global_state["current_run_id"],
"failed",
{"error": str(e), "timestamp": int(datetime.datetime.now().timestamp())}
)
# Start the background thread
background_thread = threading.Thread(target=run_mcts_background)
background_thread.daemon = True # Allow the thread to exit when the main process exits
background_thread.start()
# Return immediately with a status message
return {
"status": "started",
"message": f"MCTS process started in background with {iterations} iterations and {simulations_per_iteration} simulations per iteration.",
"provider": _global_state.get("active_llm_provider"),
"model": _global_state.get("active_model_name"),
"run_id": _global_state.get("current_run_id"),
"results_path": f"/home/ty/Repositories/ai_workspace/mcts-mcp-server/results/{_global_state.get('active_llm_provider')}_{_global_state.get('active_model_name')}_{_global_state.get('current_run_id')}", # Adjusted path
"background_thread_id": background_thread.ident
}
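# Usage note: because the return above reports "started" immediately, a typical
# client polls get_mcts_status() until iterations_completed reaches the requested
# count and then calls generate_synthesis() for the final write-up.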
@mcp.tool()
def generate_synthesis() -> Dict[str, Any]:
"""
Generate a final synthesis of the MCTS results.
Returns:
Dictionary with the synthesis and related information
"""
global _global_state
mcts = _global_state.get("mcts_instance")
if not mcts:
return {"error": "MCTS not initialized. Call initialize_mcts first."}
logger.info("Generating synthesis of MCTS results...")
# Use same LLM adapter as the MCTS instance
llm_adapter = mcts.llm
async def synth():
# Prepare context for synthesis
path_nodes = mcts.get_best_path_nodes()
path_thoughts_list = [
f"- (Node {node.sequence}): {node.thought.strip()}"
for node in path_nodes if node.thought and node.parent
]
path_thoughts_str = "\n".join(path_thoughts_list) if path_thoughts_list else "No significant development path identified."
results = mcts.get_final_results()
synth_context = {
"question_summary": mcts.question_summary,
"initial_analysis_summary": truncate_text(mcts.root.content, 300) if mcts.root else "N/A",
"best_score": f"{results.best_score:.1f}",
"path_thoughts": path_thoughts_str,
"final_best_analysis_summary": truncate_text(results.best_solution_content, 400),
# Add defaults for unused but required keys
"previous_best_summary": "N/A",
"unfit_markers_summary": "N/A",
"learned_approach_summary": "N/A"
}
# Use the synthesize_result method from the LLMInterface
synthesis = await llm_adapter.synthesize_result(synth_context, mcts.config)
best_node = mcts.find_best_final_node()
tags = best_node.descriptive_tags if best_node else []
return {
"synthesis": synthesis,
"best_score": results.best_score,
"tags": tags,
"iterations_completed": mcts.iterations_completed,
"provider": _global_state.get("active_llm_provider"),
"model": _global_state.get("active_model_name"),
}
try:
synthesis_result = run_async(synth())
# Handle case where synthesis_result is None
if synthesis_result is None:
return {"error": "Synthesis generation returned no result"}
# Update results in collector if enabled
if _global_state["collect_results"] and COLLECTOR_AVAILABLE and results_collector is not None and _global_state.get("current_run_id"):
results_collector.update_run_status(
_global_state["current_run_id"],
"completed",
{"synthesis": synthesis_result.get("synthesis")}
)
synthesis_result["run_id"] = _global_state["current_run_id"]
return synthesis_result
except Exception as e:
logger.error(f"Error generating synthesis: {e}")
return {"error": f"Synthesis generation failed: {str(e)}"}
@mcp.tool()
def get_config() -> Dict[str, Any]:
"""
Get the current MCTS configuration.
Returns:
Dictionary with the current configuration values
"""
global _global_state
# Add active LLM provider and model info
config = {k: v for k, v in _global_state["config"].items() if not k.startswith("_")}
config.update({
"active_llm_provider": _global_state.get("active_llm_provider"),
"active_model_name": _global_state.get("active_model_name"),
"ollama_python_package_available": OLLAMA_PYTHON_PACKAGE_AVAILABLE, # For info
"ollama_available_models": _global_state.get("ollama_available_models", []),
"collect_results": _global_state.get("collect_results"),
"current_run_id": _global_state.get("current_run_id")
})
return config
@mcp.tool()
def update_config(config_updates: Dict[str, Any]) -> Dict[str, Any]:
"""
Update the MCTS configuration. Provider and model are updated via set_active_llm.
Args:
config_updates: Dictionary with MCTS configuration keys and values to update.
To change LLM provider or model, use `set_active_llm` tool.
Returns:
Dictionary with the updated configuration.
"""
global _global_state
logger.info(f"Updating MCTS config with: {config_updates}")
# Provider and model name changes should be handled by set_active_llm
if "active_llm_provider" in config_updates or "active_model_name" in config_updates:
logger.warning("Use 'set_active_llm' tool to change LLM provider or model name. These keys will be ignored in update_config.")
config_updates.pop("active_llm_provider", None)
config_updates.pop("active_model_name", None)
config_updates.pop("ollama_model", None) # old key
if "collect_results" in config_updates:
_global_state["collect_results"] = bool(config_updates.pop("collect_results"))
# Update regular MCTS config
cfg = _global_state["config"].copy()
cfg.update(config_updates) # Apply remaining valid config updates
_global_state["config"] = cfg
mcts = _global_state.get("mcts_instance")
if mcts:
mcts.config = cfg # Update config in existing MCTS instance
# Return current effective config using get_config()
return get_config()
@mcp.tool()
def get_mcts_status() -> Dict[str, Any]:
"""
Get the current status of the MCTS system.
Returns:
Dictionary with status information
"""
global _global_state
mcts = _global_state.get("mcts_instance")
if not mcts:
return {
"initialized": False,
"message": "MCTS not initialized. Call initialize_mcts first."
}
try:
# Get best node and extract information
best_node = mcts.find_best_final_node()
tags = best_node.descriptive_tags if best_node else []
return {
"initialized": True,
"chat_id": _global_state.get("current_chat_id"),
"iterations_completed": getattr(mcts, "iterations_completed", 0),
"simulations_completed": getattr(mcts, "simulations_completed", 0),
"best_score": getattr(mcts, "best_score", 0.0),
"best_content_summary": truncate_text(getattr(mcts, "best_solution", ""), 100)
if hasattr(mcts, "best_solution") else "N/A",
"tags": tags,
"tree_depth": mcts.memory.get("depth", 0) if hasattr(mcts, "memory") else 0,
"approach_types": getattr(mcts, "approach_types", []),
"active_llm_provider": _global_state.get("active_llm_provider"),
"active_model_name": _global_state.get("active_model_name"),
"collected_results": _global_state.get("collect_results"),
"run_id": _global_state.get("current_run_id")
}
except Exception as e:
logger.error(f"Error getting MCTS status: {e}")
return {
"initialized": True,
"error": f"Error getting MCTS status: {str(e)}",
"chat_id": _global_state.get("current_chat_id")
}
@mcp.tool()
def run_model_comparison(question: str, iterations: int = 2, simulations_per_iteration: int = 10) -> Dict[str, Any]:
"""
Run MCTS with the same question across multiple models for comparison.
Args:
question: The question to analyze with MCTS
iterations: Number of MCTS iterations per model (default: 2)
simulations_per_iteration: Simulations per iteration (default: 10)
Returns:
Dictionary with the run IDs for each model
"""
# Use OLLAMA_PYTHON_PACKAGE_AVAILABLE to check if the package needed for check_available_models is there
if not OLLAMA_PYTHON_PACKAGE_AVAILABLE: # Check if the ollama python lib is there
return {"error": "Ollama python package not available. Cannot run Ollama model comparison."}
if not COLLECTOR_AVAILABLE or results_collector is None:
return {"error": "Results collector is not available. Cannot track comparison results."}
# Refresh available models
models = check_available_models()
# Filter to only include our preferred small models for faster comparison using constants from ollama_utils
# Import SMALL_MODELS if needed here, or pass them to this function, or make get_recommended_models more flexible
# For now, assuming get_recommended_models in ollama_utils handles this logic
recommendations = get_recommended_models(models)
comparison_models = recommendations["small_models"] # Example: Use recommended small models
if not comparison_models:
# If no "small" models, maybe try medium or any available? For now, error out.
return {"error": f"No suitable Ollama models found from recommended small list for comparison. Available: {models}"}
# Set up comparison config
config = _global_state["config"].copy()
config.update({
"max_iterations": iterations,
"simulations_per_iteration": simulations_per_iteration
})
# Start comparison
run_ids = results_collector.compare_models(
question=question,
models=comparison_models,
config=config,
iterations=iterations,
simulations_per_iter=simulations_per_iteration
)
return {
"status": "started",
"question": question,
"iterations": iterations,
"simulations_per_iteration": simulations_per_iteration,
"models": comparison_models,
"run_ids": run_ids
}
```
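Both this archived version and the current tools module route every coroutine through a thread-based `run_async` helper so that synchronous MCP tool functions can call async LLM adapters without colliding with an already-running event loop. A self-contained sketch of that pattern, with the MCTS specifics stripped out:

```python
import asyncio
import concurrent.futures


def run_async(coro):
    """Run a coroutine to completion on a fresh event loop in a worker thread."""
    def runner():
        loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(loop)
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        # future.result() re-raises any exception raised inside the coroutine.
        return pool.submit(runner).result()


async def _demo() -> str:
    await asyncio.sleep(0.01)
    return "ok"


if __name__ == "__main__":
    print(run_async(_demo()))  # -> ok
```

The archived helper above also captures and logs the exception before re-raising, but the control flow is the same.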
--------------------------------------------------------------------------------
/src/mcts_mcp_server/mcts_core.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Core MCTS Implementation
=======================
This module implements the Monte Carlo Tree Search (MCTS) algorithm
for advanced analysis and reasoning.
"""
import asyncio
import copy # Added for deep copy
import logging
import math
import random
import re
# import json # No longer used directly
# import os # No longer used directly
# import sqlite3 # No longer used directly
# from datetime import datetime # No longer used directly
from collections import Counter, namedtuple
from typing import ( # Using built-in types for Dict, List, Set, Tuple; Protocol unused
Any,
)
import numpy as np
from .intent_handler import (
EVAL_ANSWER_PROMPT,
FINAL_SYNTHESIS_PROMPT,
INITIAL_PROMPT,
INTENT_CLASSIFIER_PROMPT,
TAG_GENERATION_PROMPT,
THOUGHTS_PROMPT,
UPDATE_PROMPT,
IntentHandler,
IntentResult,
)
from .llm_interface import LLMInterface # Moved to its own file
# Note: TfidfVectorizer, ENGLISH_STOP_WORDS, cosine_similarity are now managed within utils.py
# If mcts_core directly needs them elsewhere, specific imports might be needed,
# but for now, they are primarily used by functions moved to utils.py.
from .mcts_config import APPROACH_METADATA, APPROACH_TAXONOMY, DEFAULT_CONFIG
from .node import Node
from .utils import (
SKLEARN_AVAILABLE,
_summarize_text,
calculate_semantic_distance,
setup_logger,
truncate_text,
)
# Initialize main logger for this module using the utility setup function
logger = setup_logger(__name__)
# ==============================================================================
# MCTS Class
# ==============================================================================
# Define a simple structure to hold MCTS results
MCTSResult = namedtuple("MCTSResult", ["best_score", "best_solution_content", "mcts_instance"])
class MCTS:
"""Implements the Monte Carlo Tree Search algorithm for analysis."""
def __init__(self,
llm_interface: LLMInterface,
question: str,
initial_analysis_content: str,
config: dict[str, Any] | None = None,
initial_state: dict[str, Any] | None = None):
"""
Initializes the MCTS instance.
Args:
llm_interface: An object implementing the LLMInterface protocol.
question: The original user question or text to analyze.
initial_analysis_content: The initial analysis generated by the LLM.
config: MCTS configuration dictionary (uses DEFAULT_CONFIG if None).
initial_state: Optional dictionary containing state loaded from a previous run.
"""
self.llm = llm_interface
self.question = question
self.config = config if config is not None else copy.deepcopy(DEFAULT_CONFIG)
self.debug_logging = self.config.get("debug_logging", False)
# Update logger level based on config (logger is now module-level)
# The setup_logger in utils.py can be called again if needed,
# or we can adjust the existing logger's level directly.
# For now, the MCTS __init__ will set its own logger's level.
# Use a dedicated logger for each MCTS instance to avoid conflicts with global logger level.
self.logger = logging.getLogger(f"{__name__}.MCTS.{id(self)}")
if self.debug_logging:
self.logger.setLevel(logging.DEBUG)
else:
self.logger.setLevel(logging.INFO)
# Explicitly set loaded_initial_state for consistent attribute access
self.loaded_initial_state = initial_state if initial_state is not None else {}
self.question_summary = _summarize_text(self.question, max_words=50) # Use imported _summarize_text
# Initialize node_sequence to avoid duplicates if loading from previous state
if self.loaded_initial_state and "max_sequence" in self.loaded_initial_state:
self.node_sequence = int(self.loaded_initial_state["max_sequence"])
else:
self.node_sequence = 0
# Runtime state
self.iterations_completed = 0
self.random_state = random.SystemRandom() # Use a cryptographically secure random instance
self.explored_approaches: dict[str, list[str]] = {} # Track thoughts per approach type
self.explored_thoughts: set[str] = set() # Track unique thoughts generated
self.approach_types: list[str] = ["initial"] # Track unique approach types encountered
self.surprising_nodes: list[Node] = []
# Initialize memory, loading depth and branches from initial_state if available
loaded_depth = self.loaded_initial_state.get("depth", 0) if hasattr(self, "loaded_initial_state") and self.loaded_initial_state else 0
loaded_branches = self.loaded_initial_state.get("branches", 0) if hasattr(self, "loaded_initial_state") and self.loaded_initial_state else 0
self.memory: dict[str, Any] = {
"depth": loaded_depth,
"branches": loaded_branches,
"high_scoring_nodes": [],
}
# Store unfit markers loaded/identified
self.unfit_markers: list[dict[str, Any]] = []
if hasattr(self, "loaded_initial_state") and self.loaded_initial_state:
loaded_unfit = self.loaded_initial_state.get("unfit_markers", [])
if isinstance(loaded_unfit, list):
self.unfit_markers = loaded_unfit
else:
logger.warning("Loaded unfit_markers is not a list; ignoring.")
# --- Initialize Priors and Best Solution based on Loaded State ---
prior_alpha = max(1e-9, self.config["beta_prior_alpha"])
prior_beta = max(1e-9, self.config["beta_prior_beta"])
# Approach Priors (for Bayesian mode)
self.approach_alphas: dict[str, float] = {}
self.approach_betas: dict[str, float] = {}
initial_priors = self.loaded_initial_state.get("approach_priors", {}) if hasattr(self, "loaded_initial_state") and self.loaded_initial_state else {}
if (self.config["use_bayesian_evaluation"] and initial_priors and
isinstance(initial_priors.get("alpha"), dict) and
isinstance(initial_priors.get("beta"), dict)):
self.approach_alphas = {k: max(1e-9, v) for k, v in initial_priors["alpha"].items()}
self.approach_betas = {k: max(1e-9, v) for k, v in initial_priors["beta"].items()}
logger.info("Loaded approach priors from previous state.")
else:
# Initialize default priors for all known approaches + initial/variant
all_approach_keys = [*APPROACH_TAXONOMY.keys(), "initial", "variant"]
self.approach_alphas = dict.fromkeys(all_approach_keys, prior_alpha)
self.approach_betas = dict.fromkeys(all_approach_keys, prior_beta)
if self.config["use_bayesian_evaluation"]:
logger.info("Initialized default approach priors.")
# Approach Scores (for non-Bayesian mode) - Simple average tracking
self.approach_scores: dict[str, float] = {} # Average score per approach
# Best Solution Tracking
self.best_score: float = 0.0
self.best_solution: str = initial_analysis_content # Start with the initial analysis
if self.loaded_initial_state:
self.best_score = float(self.loaded_initial_state.get("best_score", 0.0))
# Store the previously best solution content for context or comparison purposes.
# This is not used as the starting point for the new search, but may be useful for reporting or analysis.
self.previous_best_solution_content = self.loaded_initial_state.get("best_solution_content")
logger.info(f"Initialized best score ({self.best_score}) tracker from previous state.")
# Load unfit markers
self.unfit_markers = self.loaded_initial_state.get("unfit_markers", [])
if self.unfit_markers:
logger.info(f"Loaded {len(self.unfit_markers)} unfit markers from previous state.")
# --- Initialize Root Node ---
self.root = Node(
content=initial_analysis_content,
sequence=self.get_next_sequence(),
parent=None,
max_children=self.config["max_children"],
use_bayesian_evaluation=self.config["use_bayesian_evaluation"],
beta_prior_alpha=prior_alpha, # Root starts with default priors
beta_prior_beta=prior_beta,
approach_type="initial",
approach_family="general",
)
# Initial simulation/backpropagation for the root node?
# Not doing this in the original code, root starts with 0 visits/priors.
logger.info(f"MCTS Initialized. Root Node Seq: {self.root.sequence}. Initial Best Score: {self.best_score:.2f}")
if self.debug_logging:
logger.debug(f"Initial Root Content: {truncate_text(self.root.content, 100)}")
def get_next_sequence(self) -> int:
"""
Gets the next sequential ID for a node.
Returns:
The next available sequence number for node identification
"""
self.node_sequence += 1
return self.node_sequence
def get_context_for_node(self, node: Node) -> dict[str, str]:
"""
Gathers comprehensive context for LLM prompts based on current MCTS state.
Args:
node: The node to generate context for
Returns:
Dictionary containing all context information with string values
Note:
Includes context from loaded state, current run, sibling nodes, and high-scoring examples
"""
cfg = self.config
best_answer_str = str(self.best_solution) if self.best_solution else "N/A"
# --- Base Context ---
context = {
"question_summary": self.question_summary,
"best_answer": truncate_text(best_answer_str, 300),
"best_score": f"{self.best_score:.1f}",
"current_answer": truncate_text(node.content, 300),
"current_sequence": str(node.sequence),
"current_approach": node.approach_type,
"current_tags": ", ".join(node.descriptive_tags) if node.descriptive_tags else "None",
"tree_depth": str(self.memory.get("depth", 0)),
"branches": str(self.memory.get("branches", 0)),
"approach_types": ", ".join(self.approach_types),
# Initialize context from loaded state with defaults
"previous_best_summary": "N/A",
"unfit_markers_summary": "None",
"learned_approach_summary": "Default priors",
"explored_approaches": "None yet.",
"high_scoring_examples": "None yet.",
"sibling_approaches": "", # Default empty, populated below if applicable
}
# --- Context from Loaded State ---
if self.loaded_initial_state:
context["previous_best_summary"] = self.loaded_initial_state.get("best_solution_summary", "N/A")
unfit = self.loaded_initial_state.get("unfit_markers", [])
if unfit:
markers_str = "; ".join([
f"'{m.get('summary', m.get('id', 'Unknown'))}' ({m.get('reason', 'Unknown')})"
for m in unfit[:5] # Show first 5
])
context["unfit_markers_summary"] = markers_str + ("..." if len(unfit) > 5 else "")
else:
context["unfit_markers_summary"] = "None recorded"
priors = self.loaded_initial_state.get("approach_priors")
if priors and "alpha" in priors and "beta" in priors:
means = {}
for app, alpha in priors["alpha"].items():
beta = priors["beta"].get(app, 1.0)
alpha, beta = max(1e-9, alpha), max(1e-9, beta)
if alpha + beta > 1e-9:
means[app] = (alpha / (alpha + beta)) * 10
sorted_means = sorted(means.items(), key=lambda item: item[1], reverse=True)
top_approaches = [f"{app} ({score:.1f})" for app, score in sorted_means[:3]]
context["learned_approach_summary"] = f"Favors: {', '.join(top_approaches)}" + ("..." if len(sorted_means) > 3 else "")
else:
context["learned_approach_summary"] = "Priors not loaded or incomplete"
# --- Context from Current MCTS Run ---
try: # Explored Thought Types (using current run's data)
if cfg["track_explored_approaches"] and self.explored_approaches:
exp_app_text = []
current_alphas = self.approach_alphas
current_betas = self.approach_betas
sorted_approach_keys = sorted(self.explored_approaches.keys())
for app in sorted_approach_keys:
thoughts = self.explored_approaches.get(app, [])
if thoughts:
count = len(thoughts)
score_text = ""
if cfg["use_bayesian_evaluation"]:
alpha = current_alphas.get(app, 1)
beta = current_betas.get(app, 1)
alpha, beta = max(1e-9, alpha), max(1e-9, beta)
if (alpha + beta) > 1e-9:
score_text = f"(β-Mean: {alpha / (alpha + beta):.2f}, N={count})"
else:
score_text = f"(N={count})"
else:
score = self.approach_scores.get(app, 0) # Use simple avg score
count_non_bayes = sum(1 for n in self._find_nodes_by_approach(app) if n.visits > 0) # More accurate count?
if count_non_bayes > 0:
score_text = f"(Avg: {score:.1f}, N={count_non_bayes})" # Use avg score if tracked
else:
score_text = f"(N={count})"
sample_count = min(2, len(thoughts))
sample = thoughts[-sample_count:]
exp_app_text.append(f"- {app} {score_text}: {'; '.join([f'{truncate_text(str(t), 50)}' for t in sample])}")
if exp_app_text:
context["explored_approaches"] = "\n".join(exp_app_text)
except Exception as e:
logger.error(f"Ctx err (approaches): {e}")
context["explored_approaches"] = "Error generating approach context."
try: # High Scoring Examples
if self.memory["high_scoring_nodes"]:
high_score_text = [
f"- Score {score:.1f} ({app}): {truncate_text(content, 70)}"
for score, content, app, thought in self.memory["high_scoring_nodes"]
]
context["high_scoring_examples"] = "\n".join(["Top Examples:", *high_score_text])
except Exception as e:
logger.error(f"Ctx err (high scores): {e}")
context["high_scoring_examples"] = "Error generating high score context."
try: # Sibling Context
if cfg["sibling_awareness"] and node.parent and len(node.parent.children) > 1:
siblings = [c for c in node.parent.children if c is not None and c != node and c.visits > 0] # Only visited siblings
if siblings:
sib_app_text = []
sorted_siblings = sorted(siblings, key=lambda s: s.sequence)
for s in sorted_siblings:
if s.thought: # Only show siblings that generated a thought
score = s.get_average_score()
tags_str = f"Tags: [{', '.join(s.descriptive_tags)}]" if s.descriptive_tags else ""
sib_app_text.append(f'"{truncate_text(str(s.thought), 50)}" -> (Score: {score:.1f} {tags_str})')
if sib_app_text:
context["sibling_approaches"] = "\n".join(["Siblings:"] + [f"- {sa}" for sa in sib_app_text])
except Exception as e:
logger.error(f"Ctx err (siblings): {e}")
context["sibling_approaches"] = "Error generating sibling context."
# Ensure all values are strings for final formatting
safe_context = {k: str(v) if v is not None else "" for k, v in context.items()}
return safe_context
def _calculate_uct(self, node: Node, parent_visits: int) -> float:
"""
Calculates the UCT (Upper Confidence Bound for Trees) score for node selection.
Args:
node: Node to calculate UCT score for
parent_visits: Number of visits to the parent node
Returns:
UCT score (higher values indicate better selection candidates)
Note:
Incorporates exploitation, exploration, surprise bonus, diversity bonus, and unfit penalties
"""
cfg = self.config
if node.visits == 0:
return float('inf') # Prioritize unvisited nodes
# 1. Exploitation Term (normalized 0-1)
exploitation = node.get_bayesian_mean() if cfg["use_bayesian_evaluation"] else (node.get_average_score() / 10.0)
# 2. Exploration Term
log_parent_visits = math.log(max(1, parent_visits))
exploration = cfg["exploration_weight"] * math.sqrt(log_parent_visits / node.visits)
# 3. Penalty for Unfit Nodes (using loaded/identified markers)
penalty = 0.0
is_unfit = False
if self.unfit_markers:
node_summary = node.thought or node.content # Use thought if available, else content
node_tags_set = set(node.descriptive_tags)
for marker in self.unfit_markers:
# Quick checks first
if marker.get("id") == node.id or marker.get("sequence") == node.sequence:
is_unfit = True
break
# Check content similarity using node_summary
marker_summary = marker.get('summary', '')
if marker_summary and node_summary and calculate_semantic_distance(str(node_summary), str(marker_summary)) < 0.2:
is_unfit = True
break
# Check tag overlap
marker_tags_set = set(marker.get('tags', []))
if node_tags_set and marker_tags_set and len(node_tags_set.intersection(marker_tags_set)) > 0:
is_unfit = True
break # Simple tag overlap check
# Apply penalty if unfit and *not* surprising (allow surprise to override)
if is_unfit and not node.is_surprising:
penalty = -100.0 # Strong penalty to avoid selecting unfit nodes
if self.debug_logging:
logger.debug(f"Applying UCT penalty to unfit Node {node.sequence}")
# 4. Surprise Bonus
surprise_bonus = 0.3 if node.is_surprising else 0.0 # Simple fixed bonus
# 5. Diversity Bonus (relative to siblings)
diversity_bonus = 0.0
if node.parent and len(node.parent.children) > 1 and cfg["score_diversity_bonus"] > 0:
my_score_norm = node.get_average_score() / 10.0
sibling_scores = [
(sib.get_average_score() / 10.0)
for sib in node.parent.children
if sib is not None and sib != node and sib.visits > 0
]
if sibling_scores:
sibling_avg = sum(sibling_scores) / len(sibling_scores)
diversity_bonus = cfg["score_diversity_bonus"] * abs(my_score_norm - sibling_avg)
# Combine terms
uct_value = exploitation + exploration + surprise_bonus + diversity_bonus + penalty
# Ensure finite return, default to low value if not
return uct_value if math.isfinite(uct_value) else -float('inf')
def _collect_non_leaf_nodes(self, node: Node, non_leaf_nodes: list[Node], max_depth: int, current_depth: int = 0):
"""
Helper method to find nodes that can still be expanded within a depth limit.
Args:
node: Current node to examine
non_leaf_nodes: List to append expandable nodes to
max_depth: Maximum depth to search
current_depth: Current recursion depth
"""
if current_depth > max_depth or node is None:
return
# Node is non-leaf if it HAS children AND is not fully expanded yet
if node.children and not node.fully_expanded():
non_leaf_nodes.append(node)
for child in node.children:
# Recursive call only if child exists
if child is not None:
self._collect_non_leaf_nodes(child, non_leaf_nodes, max_depth, current_depth + 1)
async def select(self) -> Node:
"""
Selects a node for expansion using UCT or Thompson Sampling.
Returns:
Selected leaf or expandable node for the next expansion
Note:
Implements branch enhancement for forced exploration and handles both UCT and Thompson sampling
"""
cfg = self.config
node = self.root
selection_path_ids = [node.id] # Track path by ID
# Optional: Branch Enhancement (Force exploration of less visited branches)
force_interval = cfg["force_exploration_interval"]
if (force_interval > 0 and self.simulations_completed > 0 and
self.simulations_completed % force_interval == 0 and self.memory["depth"] > 1):
candidate_nodes = []
# Collect expandable nodes up to half the current max depth
self._collect_non_leaf_nodes(self.root, candidate_nodes, max_depth=max(1, self.memory["depth"] // 2))
expandable_candidates = [n for n in candidate_nodes if not n.fully_expanded()]
if expandable_candidates:
forced_node = self.random_state.choice(expandable_candidates)
if self.debug_logging:
logger.debug(f"BRANCH ENHANCE: Forcing selection of Node {forced_node.sequence}")
# Need to return the actual node selected by force
return forced_node # Exit selection early with the forced node
# Standard Selection Loop
while node.children: # While the node has children listed
valid_children = [child for child in node.children if child is not None]
if not valid_children:
logger.warning(f"Node {node.sequence} has empty children list or only None entries. Stopping selection.")
break # Cannot proceed
parent_visits = node.visits
unvisited = [child for child in valid_children if child.visits == 0]
if unvisited:
selected_child = self.random_state.choice(unvisited)
node = selected_child # Move to the unvisited child
selection_path_ids.append(node.id)
break # Stop selection here, this node will be expanded/simulated
# If all children visited, use selection strategy
if cfg["use_thompson_sampling"] and cfg["use_bayesian_evaluation"]:
# Thompson Sampling
samples = [(child, child.thompson_sample()) for child in valid_children]
if not samples:
logger.warning(f"No valid Thompson samples for children of {node.sequence}. Selecting randomly.")
selected_child = self.random_state.choice(valid_children)
else:
selected_child, _ = max(samples, key=lambda x: x[1])
node = selected_child
else:
# UCT Selection
uct_values = []
for child in valid_children:
try:
uct = self._calculate_uct(child, parent_visits)
if math.isfinite(uct):
uct_values.append((child, uct))
else:
logger.warning(f"UCT for child {child.sequence} was non-finite. Skipping.")
except Exception as uct_err:
logger.error(f"UCT calculation error for node {child.sequence}: {uct_err}")
if not uct_values:
logger.warning(f"No valid UCT values for children of {node.sequence}. Selecting randomly.")
if not valid_children: # Should not happen if loop condition met, but safety check
logger.error(f"Selection error: Node {node.sequence} has no valid children. Cannot proceed.")
return node # Return current node as selection cannot advance
selected_child = self.random_state.choice(valid_children)
else:
uct_values.sort(key=lambda x: x[1], reverse=True) # Highest UCT wins
selected_child = uct_values[0][0]
node = selected_child
selection_path_ids.append(node.id) # Add selected node to path
# If the newly selected node is not fully expanded, stop selection (it's the target)
# Or if it has no children (it's a leaf node)
if not node.fully_expanded() or not node.children:
break
# Update max depth seen
current_depth = len(selection_path_ids) - 1
self.memory["depth"] = max(self.memory.get("depth", 0), current_depth)
if self.debug_logging:
            path_nodes = (self._find_node_by_id(nid) for nid in selection_path_ids)
            path_seq = [(n.sequence if n else '?') for n in path_nodes]
logger.debug(f"Selection path (Sequences): {' -> '.join(map(str, path_seq))}")
return node # Return the selected leaf or expandable node
def _classify_approach(self, thought: str) -> tuple[str, str]:
"""
Classifies a thought into an approach type and family using keyword matching.
Args:
thought: The thought text to classify
Returns:
Tuple of (approach_type, approach_family)
Note:
Uses APPROACH_TAXONOMY for keyword matching and APPROACH_METADATA for family assignment
"""
approach_type = "variant" # Default if no keywords match
approach_family = "general"
if not thought or not isinstance(thought, str):
return approach_type, approach_family
thought_lower = thought.lower()
approach_scores = {
app: sum(1 for kw in kws if kw in thought_lower)
for app, kws in APPROACH_TAXONOMY.items() if kws # Check if keywords exist
}
positive_scores = {app: score for app, score in approach_scores.items() if score > 0}
if positive_scores:
max_score = max(positive_scores.values())
# Handle ties by random choice among best
best_approaches = [app for app, score in positive_scores.items() if score == max_score]
approach_type = self.random_state.choice(best_approaches)
# Get family from metadata
approach_family = APPROACH_METADATA.get(approach_type, {}).get("family", "general")
if self.debug_logging:
logger.debug(f"Classified thought '{truncate_text(thought, 50)}' as: {approach_type} ({approach_family})")
return approach_type, approach_family
def _check_surprise(self, parent_node: Node, new_content: str, new_approach_type: str, new_approach_family: str) -> tuple[bool, str]:
"""
Checks if new node content/approach is surprising relative to parent.
Args:
parent_node: Parent node for comparison
new_content: New content to evaluate
new_approach_type: Approach type of new content
new_approach_family: Approach family of new content
Returns:
Tuple of (is_surprising, explanation)
Note:
Considers semantic distance, approach family shifts, and novelty factors
"""
cfg = self.config
surprise_factors = []
is_surprising = False
surprise_explanation = ""
# 1. Semantic Distance Check
if cfg["use_semantic_distance"]:
try:
parent_content = str(parent_node.content) if parent_node.content else ""
new_content_str = str(new_content) if new_content else ""
if parent_content and new_content_str:
dist = calculate_semantic_distance(parent_content, new_content_str, use_tfidf=True) # Can disable TFIDF here if too slow
if dist > cfg["surprise_threshold"]:
surprise_factors.append({
"type": "semantic", "value": dist,
"weight": cfg["surprise_semantic_weight"],
"desc": f"Semantic dist ({dist:.2f})"
})
except Exception as e:
logger.warning(f"Semantic distance check failed: {e}")
# 2. Shift in Thought Approach Family
parent_family = parent_node.approach_family
if parent_family != new_approach_family and new_approach_family != "general":
surprise_factors.append({
"type": "family_shift", "value": 1.0,
"weight": cfg["surprise_philosophical_shift_weight"],
"desc": f"Shift '{parent_family}'->'{new_approach_family}'"
})
# 3. Novelty of Thought Approach Family (using BFS on current tree)
try:
family_counts = Counter()
queue = [(self.root, 0)] if self.root else []
processed_bfs = set()
nodes_visited = 0
max_bfs_nodes = 100
max_bfs_depth = 5
while queue and nodes_visited < max_bfs_nodes:
curr_node, depth = queue.pop(0)
if curr_node is None or curr_node.id in processed_bfs or depth > max_bfs_depth:
continue
processed_bfs.add(curr_node.id)
nodes_visited += 1
family_counts[curr_node.approach_family] += 1
if depth + 1 <= max_bfs_depth:
queue.extend([(child, depth + 1) for child in curr_node.children if child is not None])
# If the new family has been seen <= 1 times (itself) and isn't 'general'
if family_counts.get(new_approach_family, 0) <= 1 and new_approach_family != "general":
surprise_factors.append({
"type": "novelty", "value": 0.8, # Slightly less value than shift/semantic maybe?
"weight": cfg["surprise_novelty_weight"],
"desc": f"Novel approach family ('{new_approach_family}')"
})
except Exception as e:
logger.warning(f"Novelty check BFS failed: {e}", exc_info=self.debug_logging)
# Calculate combined weighted score
if surprise_factors:
total_weighted_score = sum(f["value"] * f["weight"] for f in surprise_factors)
total_weight = sum(f["weight"] for f in surprise_factors)
combined_score = (total_weighted_score / total_weight) if total_weight > 1e-6 else 0.0
if combined_score >= cfg["surprise_overall_threshold"]:
is_surprising = True
factor_descs = [f"- {f['desc']} (Val:{f['value']:.2f}, W:{f['weight']:.1f})" for f in surprise_factors]
surprise_explanation = (f"Combined surprise ({combined_score:.2f} >= {cfg['surprise_overall_threshold']}):\n" + "\n".join(factor_descs))
if self.debug_logging:
logger.debug(f"Surprise DETECTED for node sequence {parent_node.sequence+1}: Score={combined_score:.2f}\n{surprise_explanation}")
return is_surprising, surprise_explanation
async def expand(self, node: Node) -> Node | None:
"""
Expands a node by generating a thought and creating a new child analysis.
Args:
node: Node to expand
Returns:
Newly created child node, or None if expansion failed
Raises:
Exception: If LLM calls fail or node expansion encounters errors
"""
cfg = self.config
if node.fully_expanded():
logger.warning(f"Attempted to expand fully expanded Node {node.sequence}. Returning None.")
return None
if not node.content:
logger.warning(f"Attempted to expand Node {node.sequence} with no content. Returning None.")
return None
try:
context = self.get_context_for_node(node)
# 1. Generate Thought
if self.debug_logging:
logger.debug(f"Generating thought for Node {node.sequence}")
thought = await self.llm.generate_thought(context, cfg)
if not isinstance(thought, str) or not thought.strip() or "Error:" in thought:
logger.error(f"Invalid thought generation for Node {node.sequence}: '{thought}'")
return None # Expansion failed
thought = thought.strip()
if self.debug_logging:
logger.debug(f"Node {node.sequence} -> Thought: '{truncate_text(thought, 80)}'")
# Check thought against unfit markers (simple check)
is_unfit_thought = False
if self.unfit_markers:
for marker in self.unfit_markers:
marker_summary = marker.get('summary')
if marker_summary and calculate_semantic_distance(thought, marker_summary) < 0.15: # Strict threshold
is_unfit_thought = True
logger.warning(f"Generated thought for Node {node.sequence} resembles unfit marker '{marker_summary}'. Skipping expansion.")
break
if is_unfit_thought:
return None # Skip expansion if thought seems unfit
# Classify approach based on thought
approach_type, approach_family = self._classify_approach(thought)
self.explored_thoughts.add(thought)
if approach_type not in self.approach_types:
self.approach_types.append(approach_type)
if approach_type not in self.explored_approaches:
self.explored_approaches[approach_type] = []
self.explored_approaches[approach_type].append(thought)
# 2. Update Analysis based on Thought
if self.debug_logging:
logger.debug(f"Updating analysis for Node {node.sequence} based on thought")
# Pass original content in context for update prompt
context_for_update = context.copy()
context_for_update['answer'] = node.content # Use 'answer' key as expected by UPDATE_PROMPT
context_for_update['improvements'] = thought # Use 'improvements' key
new_content = await self.llm.update_analysis(thought, context_for_update, cfg)
if not isinstance(new_content, str) or not new_content.strip() or "Error:" in new_content:
logger.error(f"Invalid new content generation for Node {node.sequence}: '{new_content}'")
return None # Expansion failed
new_content = new_content.strip()
if self.debug_logging:
logger.debug(f"Node {node.sequence} -> New Content: '{truncate_text(new_content, 80)}'")
# 3. Generate Tags for New Content
new_tags = await self.llm.generate_tags(new_content, cfg)
if self.debug_logging:
logger.debug(f"Generated Tags for new node: {new_tags}")
# 4. Check for Surprise
is_surprising, surprise_explanation = self._check_surprise(node, new_content, approach_type, approach_family)
# 5. Create Child Node
child = Node(
content=new_content,
parent=node,
sequence=self.get_next_sequence(),
thought=thought,
approach_type=approach_type,
approach_family=approach_family,
max_children=cfg["max_children"],
use_bayesian_evaluation=cfg["use_bayesian_evaluation"],
beta_prior_alpha=cfg["beta_prior_alpha"], # Child starts with default priors
beta_prior_beta=cfg["beta_prior_beta"]
)
child.descriptive_tags = new_tags
child.is_surprising = is_surprising
child.surprise_explanation = surprise_explanation
# Add child to parent
node.add_child(child)
if is_surprising:
self.surprising_nodes.append(child)
# Update branch count if this adds a new branch
if len(node.children) > 1:
self.memory["branches"] += 1
if self.debug_logging:
logger.debug(f"Successfully expanded Node {node.sequence} -> Child {child.sequence}")
return child
except Exception as e:
logger.error(f"Expand error on Node {node.sequence}: {e}", exc_info=self.debug_logging)
return None
async def simulate(self, node: Node) -> float | None:
"""
Simulates (evaluates) a node using the LLM to get a quality score.
Args:
node: Node to evaluate
Returns:
Score from 1-10, or None if simulation failed
Note:
Updates approach performance tracking and high-scoring node memory
"""
cfg = self.config
if not node.content:
logger.warning(f"Cannot simulate Node {node.sequence}: Content is empty. Returning default score 5.0")
return 5.0
try:
context = self.get_context_for_node(node)
# Ensure context has the key expected by the eval prompt
context['answer_to_evaluate'] = node.content
if self.debug_logging:
logger.debug(f"Evaluating Node {node.sequence}")
raw_score = await self.llm.evaluate_analysis(node.content, context, cfg)
# Validate score is int 1-10
if not isinstance(raw_score, int) or not (1 <= raw_score <= 10):
logger.error(f"Evaluation for Node {node.sequence} returned invalid score: {raw_score}. Defaulting to 5.")
raw_score = 5
score = float(raw_score)
node.raw_scores.append(raw_score) # Keep track of raw scores received
approach = node.approach_type if node.approach_type else "unknown"
# Update approach performance tracking
if cfg["use_bayesian_evaluation"]:
# Use raw score for pseudo counts (scale 1-10)
pseudo_successes = max(0, raw_score - 1) # 10 -> 9 successes
pseudo_failures = max(0, 10 - raw_score) # 3 -> 7 failures
# Ensure approach exists in prior dicts, initializing if necessary
current_alpha = self.approach_alphas.setdefault(approach, cfg["beta_prior_alpha"])
current_beta = self.approach_betas.setdefault(approach, cfg["beta_prior_beta"])
# Update priors safely
self.approach_alphas[approach] = max(1e-9, current_alpha + pseudo_successes)
self.approach_betas[approach] = max(1e-9, current_beta + pseudo_failures)
else:
# Non-Bayesian: Update simple average score (e.g., using EMA)
current_avg = self.approach_scores.get(approach, score) # Initialize with current score if first time
self.approach_scores[approach] = 0.7 * score + 0.3 * current_avg # EMA update
if self.debug_logging:
logger.debug(f"Node {node.sequence} evaluation result: {score:.1f}/10")
# Update high score memory (use score 1-10)
if score >= 7:
entry = (score, node.content, approach, node.thought)
self.memory["high_scoring_nodes"].append(entry)
# Sort and trim memory
self.memory["high_scoring_nodes"].sort(key=lambda x: x[0], reverse=True)
self.memory["high_scoring_nodes"] = self.memory["high_scoring_nodes"][:cfg["memory_cutoff"]]
return score
except Exception as e:
logger.error(f"Simulate error for Node {node.sequence}: {e}", exc_info=self.debug_logging)
return None # Indicate simulation failure
def backpropagate(self, node: Node, score: float) -> None:
"""
Backpropagates simulation score up the tree to update node statistics.
Args:
node: Starting node for backpropagation
score: Score to backpropagate (1-10 scale)
Note:
Updates visit counts and either Bayesian parameters or cumulative values
"""
cfg = self.config
if score is None or not math.isfinite(score):
logger.warning(f"Invalid score ({score}) received for backpropagation from Node {node.sequence}. Skipping.")
return
if self.debug_logging:
logger.debug(f"Backpropagating score {score:.2f} from Node {node.sequence}")
# Use 1-10 score for pseudo counts in Bayesian updates
pseudo_successes = max(0, score - 1) # Use 1-10 score for pseudo counts
pseudo_failures = max(0, 10 - score)
temp_node: Node | None = node
path_len = 0
while temp_node:
temp_node.visits += 1
if cfg["use_bayesian_evaluation"]:
if temp_node.alpha is not None and temp_node.beta is not None:
# Update using pseudo counts from 1-10 score
temp_node.alpha = max(1e-9, temp_node.alpha + pseudo_successes)
temp_node.beta = max(1e-9, temp_node.beta + pseudo_failures)
else:
logger.warning(f"Node {temp_node.sequence} missing alpha/beta during backprop.")
else: # Non-Bayesian: Add score to cumulative value
if temp_node.value is not None:
temp_node.value += score # Add the raw score (1-10)
else: # Initialize if missing (should only happen for root if not pre-simulated)
logger.warning(f"Node {temp_node.sequence} missing value during non-Bayesian backprop. Initializing.")
temp_node.value = score
temp_node = temp_node.parent
path_len += 1
if self.debug_logging:
logger.debug(f"Backpropagation complete for Node {node.sequence} (Path length: {path_len})")
async def run_search_iterations(self, num_iterations: int, simulations_per_iteration: int) -> None:
"""
Runs the main MCTS search loop with concurrent simulation batches.
Args:
num_iterations: Number of iterations to run
simulations_per_iteration: Simulations per iteration
Note:
Implements early stopping and concurrent batch processing for performance
"""
cfg = self.config
logger.info(f"Starting MCTS search: {num_iterations} iterations, {simulations_per_iteration} simulations/iter.")
# Performance optimization - run multiple simulations concurrently
max_concurrent = 3 # Set a reasonable limit for concurrency
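        # Note: the simulations in a batch are awaited together via asyncio.gather; they
        # interleave cooperatively at await points (the LLM calls), and because tree state
        # is shared, concurrent simulations may occasionally select overlapping nodes.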
for i in range(num_iterations):
self.iterations_completed = i + 1
logger.info(f"--- Starting Iteration {self.iterations_completed}/{num_iterations} ---")
# Process simulations in batches for better concurrency
for batch_start in range(0, simulations_per_iteration, max_concurrent):
batch_size = min(max_concurrent, simulations_per_iteration - batch_start)
batch_tasks = []
# Create tasks for the batch
for j in range(batch_start, batch_start + batch_size):
sim_num = j + 1
task = asyncio.create_task(self._run_single_simulation(sim_num, simulations_per_iteration))
batch_tasks.append(task)
# Wait for the batch to complete
await asyncio.gather(*batch_tasks)
# Check early stopping after each batch
if (cfg["early_stopping"] and
self.best_score >= cfg["early_stopping_threshold"] and
self.high_score_counter >= cfg["early_stopping_stability"]):
logger.info(f"EARLY STOPPING criteria met during Iteration {self.iterations_completed}.")
return # Exit early
# --- End of Simulations for Iteration i ---
logger.info(f"--- Finished Iteration {self.iterations_completed}. Current Best Score: {self.best_score:.2f} ---")
# Re-check early stopping condition after the iteration
if (cfg["early_stopping"] and
self.best_score >= cfg["early_stopping_threshold"] and
self.high_score_counter >= cfg["early_stopping_stability"]):
logger.info(f"EARLY STOPPING criteria met at end of Iteration {self.iterations_completed}.")
break # Exit outer iteration loop
logger.info("MCTS search finished.")
async def _run_single_simulation(self, current_sim_num: int, total_sims: int) -> None:
"""
Runs a single MCTS simulation cycle (select-expand-simulate-backpropagate).
Args:
current_sim_num: Current simulation number (for logging)
total_sims: Total simulations in this batch
Note:
Core MCTS algorithm implementation with error handling and best score tracking
"""
self.simulations_completed += 1
cfg = self.config
if self.debug_logging:
logger.debug(f"--- Sim {current_sim_num}/{total_sims} ---")
# 1. Select
leaf = await self.select()
if not leaf:
logger.error(f"Sim {current_sim_num}: Selection returned None. Skipping simulation.")
return
# 2. Expand (if not terminal and not fully expanded)
node_to_simulate = leaf
if not leaf.fully_expanded() and leaf.content: # Check content exists
if self.debug_logging:
logger.debug(f"Sim {current_sim_num}: Attempting expansion from Node {leaf.sequence}")
expanded_node = await self.expand(leaf)
if expanded_node:
node_to_simulate = expanded_node # Simulate the newly expanded node
if self.debug_logging:
logger.debug(f"Sim {current_sim_num}: Expanded {leaf.sequence} -> {node_to_simulate.sequence}")
else:
if self.debug_logging:
logger.warning(f"Sim {current_sim_num}: Expansion failed for {leaf.sequence}. Simulating original leaf.")
node_to_simulate = leaf # Simulate original leaf if expansion failed
elif self.debug_logging:
logger.debug(f"Sim {current_sim_num}: Node {leaf.sequence} is fully expanded or has no content. Simulating directly.")
# 3. Simulate
score = None
if node_to_simulate and node_to_simulate.content:
score = await self.simulate(node_to_simulate)
elif node_to_simulate:
logger.warning(f"Sim {current_sim_num}: Skipping simulation for {node_to_simulate.sequence} (no content).")
score = 5.0 # Assign default score
else: # Should not happen if selection worked
logger.error(f"Sim {current_sim_num}: node_to_simulate is None after select/expand. Skipping simulation.")
return # Skip backprop
# 4. Backpropagate
if score is not None:
self.backpropagate(node_to_simulate, score)
# Update overall best score/solution found so far
if score > self.best_score:
logger.info(f"Sim {current_sim_num}: ✨ New best! Score: {score:.1f} (Node {node_to_simulate.sequence})")
self.best_score = score
self.best_solution = str(node_to_simulate.content)
self.high_score_counter = 0 # Reset stability counter
elif score == self.best_score:
# If score matches best, don't reset counter
pass
else: # Score is lower than best
self.high_score_counter = 0 # Reset stability counter if score drops
# Check early stopping (threshold) - based on overall best score
if cfg["early_stopping"] and self.best_score >= cfg["early_stopping_threshold"]:
self.high_score_counter += 1 # Increment counter only if score >= threshold
if self.debug_logging:
logger.debug(f"Sim {current_sim_num}: Best score ({self.best_score:.1f}) >= threshold. Stability: {self.high_score_counter}/{cfg['early_stopping_stability']}")
else: # Simulation failed (score is None)
if node_to_simulate:
logger.warning(f"Sim {current_sim_num}: Simulation failed for Node {node_to_simulate.sequence}. No score obtained.")
self.high_score_counter = 0 # Reset stability counter if sim fails
def get_final_results(self) -> MCTSResult:
"""
Returns the best score and solution found during MCTS search.
Returns:
MCTSResult namedtuple containing best_score, best_solution_content, and mcts_instance
Note:
Cleans solution content of <think> tags and returns structured results
"""
# Clean the best solution content of <think> tags if present
cleaned_solution = self.best_solution
if cleaned_solution and isinstance(cleaned_solution, str):
# First try to remove the entire <think> block if it's a pure think block
clean_attempt = re.sub(r'<think>.*?</think>', '', cleaned_solution, flags=re.DOTALL)
# If that removes everything, keep the original but strip just the tags
if not clean_attempt.strip() and ("<think>" in cleaned_solution or "</think>" in cleaned_solution):
cleaned_solution = re.sub(r'</?think>', '', cleaned_solution)
else:
cleaned_solution = clean_attempt
# In a real app, you might want more detailed results (e.g., best node path)
return MCTSResult(
best_score=self.best_score,
best_solution_content=cleaned_solution.strip() if isinstance(cleaned_solution, str) else cleaned_solution,
mcts_instance=self # Return self for further analysis if needed
)
def find_best_final_node(self) -> Node | None:
"""
Finds the node object corresponding to the best solution content using BFS.
Returns:
Node with content matching best solution, or None if not found
Note:
Performs content cleaning and exact matching with score proximity as tiebreaker
"""
if not self.best_solution or not self.root:
return None
queue = [self.root]
visited_ids = {self.root.id}
best_match_node = None
min_score_diff = float('inf')
# Clean target solution content once
target_content = str(self.best_solution).strip()
target_content = re.sub(r"^```(json|markdown)?\s*", "", target_content, flags=re.IGNORECASE | re.MULTILINE)
target_content = re.sub(r"\s*```$", "", target_content, flags=re.MULTILINE).strip()
while queue:
current = queue.pop(0)
if current is None:
continue
# Clean node content for comparison
node_content = str(current.content).strip()
node_content = re.sub(r"^```(json|markdown)?\s*", "", node_content, flags=re.IGNORECASE | re.MULTILINE)
node_content = re.sub(r"\s*```$", "", node_content, flags=re.MULTILINE).strip()
# Check for exact content match (after cleaning)
if node_content == target_content:
score_diff = abs(current.get_average_score() - self.best_score)
if best_match_node is None or score_diff < min_score_diff:
best_match_node = current
min_score_diff = score_diff
# Add valid children to queue
for child in current.children:
if child and child.id not in visited_ids:
visited_ids.add(child.id)
queue.append(child)
if not best_match_node:
logger.warning("Could not find node object exactly matching best solution content.")
return best_match_node
def _find_node_by_id(self, node_id: str) -> Node | None:
"""
Finds a node by its unique ID using breadth-first search.
Args:
node_id: Unique identifier of the node to find
Returns:
Node with matching ID, or None if not found
"""
if not self.root:
return None
queue = [self.root]
visited = {self.root.id}
while queue:
current = queue.pop(0)
if current.id == node_id:
return current
for child in current.children:
if child and child.id not in visited:
visited.add(child.id)
queue.append(child)
return None
def _find_nodes_by_approach(self, approach_type: str) -> list[Node]:
"""
Finds all nodes with a specific approach type using breadth-first search.
Args:
approach_type: The approach type to search for
Returns:
List of nodes with matching approach type
"""
nodes = []
if not self.root:
return nodes
queue = [self.root]
visited = {self.root.id}
while queue:
current = queue.pop(0)
if current.approach_type == approach_type:
nodes.append(current)
for child in current.children:
if child and child.id not in visited:
visited.add(child.id)
queue.append(child)
return nodes
def export_tree_summary(self) -> dict[str, Any]:
"""
Exports a summary of the tree structure and key nodes.
Returns:
Dictionary containing tree structure in JSON format
"""
if not self.root:
return {"error": "No root node"}
return self.root.node_to_json()
def get_best_path_nodes(self) -> list[Node]:
"""
Traces the path from root to the best scoring node found.
Returns:
List of nodes from root to best node (in order)
"""
best_node = self.find_best_final_node()
if not best_node:
return []
path = []
current = best_node
while current:
path.append(current)
current = current.parent
return path[::-1] # Reverse to get root -> best order
# ==============================================================================
# Intent Handling
# ==============================================================================
# StateManager import moved to the top with other local imports
```