#
tokens: 44503/50000 4/56 files (page 3/6)
lines: on (toggle) GitHub
raw markdown copy reset
This is page 3 of 6. Use http://codebase.md/arthurcolle/openai-mcp?lines=true&page={x} to view the full context.

# Directory Structure

```
├── .gitignore
├── claude_code
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-312.pyc
│   │   └── mcp_server.cpython-312.pyc
│   ├── claude.py
│   ├── commands
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-312.pyc
│   │   │   └── serve.cpython-312.pyc
│   │   ├── client.py
│   │   ├── multi_agent_client.py
│   │   └── serve.py
│   ├── config
│   │   └── __init__.py
│   ├── examples
│   │   ├── agents_config.json
│   │   ├── claude_mcp_config.html
│   │   ├── claude_mcp_config.json
│   │   ├── echo_server.py
│   │   └── README.md
│   ├── lib
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   └── __init__.cpython-312.pyc
│   │   ├── context
│   │   │   └── __init__.py
│   │   ├── monitoring
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   └── server_metrics.cpython-312.pyc
│   │   │   ├── cost_tracker.py
│   │   │   └── server_metrics.py
│   │   ├── providers
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── openai.py
│   │   ├── rl
│   │   │   ├── __init__.py
│   │   │   ├── grpo.py
│   │   │   ├── mcts.py
│   │   │   └── tool_optimizer.py
│   │   ├── tools
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   ├── base.cpython-312.pyc
│   │   │   │   ├── file_tools.cpython-312.pyc
│   │   │   │   └── manager.cpython-312.pyc
│   │   │   ├── ai_tools.py
│   │   │   ├── base.py
│   │   │   ├── code_tools.py
│   │   │   ├── file_tools.py
│   │   │   ├── manager.py
│   │   │   └── search_tools.py
│   │   └── ui
│   │       ├── __init__.py
│   │       └── tool_visualizer.py
│   ├── mcp_server.py
│   ├── README_MCP_CLIENT.md
│   ├── README_MULTI_AGENT.md
│   └── util
│       └── __init__.py
├── claude.py
├── cli.py
├── data
│   └── prompt_templates.json
├── deploy_modal_mcp.py
├── deploy.sh
├── examples
│   ├── agents_config.json
│   └── echo_server.py
├── install.sh
├── mcp_modal_adapter.py
├── mcp_server.py
├── modal_mcp_server.py
├── README_modal_mcp.md
├── README.md
├── requirements.txt
├── setup.py
├── static
│   └── style.css
├── templates
│   └── index.html
└── web-client.html
```

# Files

--------------------------------------------------------------------------------
/claude_code/lib/rl/grpo.py:
--------------------------------------------------------------------------------

```python
   1 | """
   2 | Group Relative Policy Optimization (GRPO) for multi-agent learning in Claude Code.
   3 | This module provides a multi-agent GRPO implementation that learns from interactions.
   4 | """
   5 | 
   6 | import numpy as np
   7 | import torch
   8 | import torch.nn as nn
   9 | import torch.nn.functional as F
  10 | import torch.optim as optim
  11 | from torch.distributions import Categorical
  12 | from typing import List, Dict, Tuple, Optional, Any, Union, Callable
  13 | from dataclasses import dataclass
  14 | from collections import deque
  15 | import random
  16 | import time
  17 | 
  18 | 
@dataclass
class Experience:
    """A single step (one transition) of experience for reinforcement learning.

    Field order defines the dataclass-generated ``__init__`` signature, so it
    must not change without updating any positional construction sites.
    """
    # Feature vector for the step (produced by the feature extractor in
    # MultiAgentGroupRL.observe); stacked into a float tensor for training.
    state: Any
    # Action taken; converted to an integer (long) tensor by the GRPO training
    # code, so presumably a discrete action index.
    action: Any
    # Scalar reward received for this transition.
    reward: float
    # Feature vector after the action was taken; used to bootstrap values.
    next_state: Any
    # When True, bootstrapping from next_state and GAE accumulation are cut off.
    done: bool
    # Optional extra metadata; not read by the GRPO training code.
    info: Optional[Dict[str, Any]] = None
  28 | 
  29 | 
  30 | class ExperienceBuffer:
  31 |     """Buffer to store and sample experiences for training."""
  32 |     
  33 |     def __init__(self, capacity: int = 100000):
  34 |         """
  35 |         Initialize the experience buffer.
  36 |         
  37 |         Args:
  38 |             capacity: Maximum number of experiences to store
  39 |         """
  40 |         self.buffer = deque(maxlen=capacity)
  41 |     
  42 |     def add(self, experience: Experience) -> None:
  43 |         """Add an experience to the buffer."""
  44 |         self.buffer.append(experience)
  45 |     
  46 |     def sample(self, batch_size: int) -> List[Experience]:
  47 |         """Sample a batch of experiences from the buffer."""
  48 |         return random.sample(self.buffer, min(batch_size, len(self.buffer)))
  49 |     
  50 |     def __len__(self) -> int:
  51 |         """Get the current size of the buffer."""
  52 |         return len(self.buffer)
  53 | 
  54 | 
  55 | class PolicyNetwork(nn.Module):
  56 |     """Neural network to represent a policy."""
  57 |     
  58 |     def __init__(self, input_dim: int, hidden_dims: List[int], output_dim: int):
  59 |         """
  60 |         Initialize the policy network.
  61 |         
  62 |         Args:
  63 |             input_dim: Dimension of the input state
  64 |             hidden_dims: List of hidden layer dimensions
  65 |             output_dim: Dimension of the action space
  66 |         """
  67 |         super(PolicyNetwork, self).__init__()
  68 |         
  69 |         # Create the input layer
  70 |         layers = [nn.Linear(input_dim, hidden_dims[0]), nn.ReLU()]
  71 |         
  72 |         # Create hidden layers
  73 |         for i in range(len(hidden_dims) - 1):
  74 |             layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1]))
  75 |             layers.append(nn.ReLU())
  76 |         
  77 |         # Create output layer
  78 |         layers.append(nn.Linear(hidden_dims[-1], output_dim))
  79 |         
  80 |         self.network = nn.Sequential(*layers)
  81 |     
  82 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
  83 |         """Forward pass through the network."""
  84 |         return self.network(x)
  85 | 
  86 | 
  87 | class ValueNetwork(nn.Module):
  88 |     """Neural network to represent a value function."""
  89 |     
  90 |     def __init__(self, input_dim: int, hidden_dims: List[int]):
  91 |         """
  92 |         Initialize the value network.
  93 |         
  94 |         Args:
  95 |             input_dim: Dimension of the input state
  96 |             hidden_dims: List of hidden layer dimensions
  97 |         """
  98 |         super(ValueNetwork, self).__init__()
  99 |         
 100 |         # Create the input layer
 101 |         layers = [nn.Linear(input_dim, hidden_dims[0]), nn.ReLU()]
 102 |         
 103 |         # Create hidden layers
 104 |         for i in range(len(hidden_dims) - 1):
 105 |             layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1]))
 106 |             layers.append(nn.ReLU())
 107 |         
 108 |         # Create output layer (scalar value)
 109 |         layers.append(nn.Linear(hidden_dims[-1], 1))
 110 |         
 111 |         self.network = nn.Sequential(*layers)
 112 |     
 113 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
 114 |         """Forward pass through the network."""
 115 |         return self.network(x)
 116 | 
 117 | 
 118 | class GRPO:
 119 |     """
 120 |     Group Relative Policy Optimization implementation for multi-agent learning.
 121 |     GRPO extends PPO by considering relative performance within a group of agents.
 122 |     """
 123 |     
 124 |     def __init__(
 125 |         self,
 126 |         state_dim: int,
 127 |         action_dim: int,
 128 |         hidden_dims: List[int] = [64, 64],
 129 |         lr_policy: float = 3e-4,
 130 |         lr_value: float = 1e-3,
 131 |         gamma: float = 0.99,
 132 |         gae_lambda: float = 0.95,
 133 |         clip_ratio: float = 0.2,
 134 |         target_kl: float = 0.01,
 135 |         value_coef: float = 0.5,
 136 |         entropy_coef: float = 0.01,
 137 |         max_grad_norm: float = 0.5,
 138 |         use_gae: bool = True,
 139 |         normalize_advantages: bool = True,
 140 |         relative_advantage_weight: float = 0.5,
 141 |         device: str = "cuda" if torch.cuda.is_available() else "cpu",
 142 |     ):
 143 |         """
 144 |         Initialize the GRPO agent.
 145 |         
 146 |         Args:
 147 |             state_dim: Dimension of the state space
 148 |             action_dim: Dimension of the action space
 149 |             hidden_dims: Dimensions of hidden layers in networks
 150 |             lr_policy: Learning rate for policy network
 151 |             lr_value: Learning rate for value network
 152 |             gamma: Discount factor
 153 |             gae_lambda: Lambda for GAE
 154 |             clip_ratio: PPO clipping parameter
 155 |             target_kl: Target KL divergence for early stopping
 156 |             value_coef: Value loss coefficient
 157 |             entropy_coef: Entropy bonus coefficient
 158 |             max_grad_norm: Maximum gradient norm for clipping
 159 |             use_gae: Whether to use GAE
 160 |             normalize_advantages: Whether to normalize advantages
 161 |             relative_advantage_weight: Weight for relative advantage component
 162 |             device: Device to run the model on
 163 |         """
 164 |         self.state_dim = state_dim
 165 |         self.action_dim = action_dim
 166 |         self.gamma = gamma
 167 |         self.gae_lambda = gae_lambda
 168 |         self.clip_ratio = clip_ratio
 169 |         self.target_kl = target_kl
 170 |         self.value_coef = value_coef
 171 |         self.entropy_coef = entropy_coef
 172 |         self.max_grad_norm = max_grad_norm
 173 |         self.use_gae = use_gae
 174 |         self.normalize_advantages = normalize_advantages
 175 |         self.relative_advantage_weight = relative_advantage_weight
 176 |         self.device = device
 177 |         
 178 |         # Initialize networks
 179 |         self.policy = PolicyNetwork(state_dim, hidden_dims, action_dim).to(device)
 180 |         self.value = ValueNetwork(state_dim, hidden_dims).to(device)
 181 |         
 182 |         # Initialize optimizers
 183 |         self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=lr_policy)
 184 |         self.value_optimizer = optim.Adam(self.value.parameters(), lr=lr_value)
 185 |         
 186 |         # Initialize experience buffer
 187 |         self.buffer = ExperienceBuffer()
 188 |         
 189 |         # Group-level buffers for relative advantage computation
 190 |         self.group_rewards = []
 191 |         self.agent_id = None  # Will be set when joining a group
 192 |     
 193 |     def set_agent_id(self, agent_id: str) -> None:
 194 |         """Set the agent's ID within the group."""
 195 |         self.agent_id = agent_id
 196 |     
 197 |     def get_action(self, state: np.ndarray, deterministic: bool = False) -> Tuple[int, float]:
 198 |         """
 199 |         Get an action from the policy for the given state.
 200 |         
 201 |         Args:
 202 |             state: The current state
 203 |             deterministic: Whether to return the most likely action
 204 |             
 205 |         Returns:
 206 |             Tuple of (action, log probability)
 207 |         """
 208 |         # Convert state to tensor
 209 |         state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
 210 |         
 211 |         # Get action distributions
 212 |         with torch.no_grad():
 213 |             logits = self.policy(state_tensor)
 214 |             distribution = Categorical(logits=logits)
 215 |             
 216 |             if deterministic:
 217 |                 action = torch.argmax(logits, dim=1).item()
 218 |             else:
 219 |                 action = distribution.sample().item()
 220 |                 
 221 |             log_prob = distribution.log_prob(torch.tensor(action)).item()
 222 |         
 223 |         return action, log_prob
 224 |     
 225 |     def get_value(self, state: np.ndarray) -> float:
 226 |         """
 227 |         Get the estimated value of a state.
 228 |         
 229 |         Args:
 230 |             state: The state to evaluate
 231 |             
 232 |         Returns:
 233 |             The estimated value
 234 |         """
 235 |         # Convert state to tensor
 236 |         state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
 237 |         
 238 |         # Get value estimate
 239 |         with torch.no_grad():
 240 |             value = self.value(state_tensor).item()
 241 |         
 242 |         return value
 243 |     
 244 |     def learn(
 245 |         self,
 246 |         batch_size: int = 64,
 247 |         epochs: int = 10,
 248 |         group_rewards: Optional[Dict[str, List[float]]] = None
 249 |     ) -> Dict[str, float]:
 250 |         """
 251 |         Update policy and value networks based on collected experience.
 252 |         
 253 |         Args:
 254 |             batch_size: Size of batches to use for updates
 255 |             epochs: Number of epochs to train for
 256 |             group_rewards: Rewards collected by all agents in the group
 257 |             
 258 |         Returns:
 259 |             Dictionary of training metrics
 260 |         """
 261 |         if len(self.buffer) < batch_size:
 262 |             return {"policy_loss": 0, "value_loss": 0, "kl": 0}
 263 |         
 264 |         # Prepare data for training
 265 |         states, actions, old_log_probs, returns, advantages = self._prepare_training_data(
 266 |             group_rewards)
 267 |         
 268 |         # Training metrics
 269 |         metrics = {
 270 |             "policy_loss": 0,
 271 |             "value_loss": 0,
 272 |             "entropy": 0,
 273 |             "kl": 0,
 274 |         }
 275 |         
 276 |         # Run training for multiple epochs
 277 |         for epoch in range(epochs):
 278 |             # Generate random indices for batching
 279 |             indices = np.random.permutation(len(states))
 280 |             
 281 |             # Process in batches
 282 |             for start_idx in range(0, len(states), batch_size):
 283 |                 # Get batch indices
 284 |                 batch_indices = indices[start_idx:start_idx + batch_size]
 285 |                 
 286 |                 # Extract batch data
 287 |                 batch_states = states[batch_indices]
 288 |                 batch_actions = actions[batch_indices]
 289 |                 batch_old_log_probs = old_log_probs[batch_indices]
 290 |                 batch_returns = returns[batch_indices]
 291 |                 batch_advantages = advantages[batch_indices]
 292 |                 
 293 |                 # Update policy
 294 |                 policy_loss, entropy, kl = self._update_policy(
 295 |                     batch_states, batch_actions, batch_old_log_probs, batch_advantages)
 296 |                 
 297 |                 # Early stopping based on KL divergence
 298 |                 if kl > 1.5 * self.target_kl:
 299 |                     break
 300 |                 
 301 |                 # Update value function
 302 |                 value_loss = self._update_value(batch_states, batch_returns)
 303 |                 
 304 |                 # Update metrics
 305 |                 metrics["policy_loss"] += policy_loss
 306 |                 metrics["value_loss"] += value_loss
 307 |                 metrics["entropy"] += entropy
 308 |                 metrics["kl"] += kl
 309 |             
 310 |             # Check for early stopping after each epoch
 311 |             if metrics["kl"] / (epoch + 1) > self.target_kl:
 312 |                 break
 313 |         
 314 |         # Normalize metrics by number of updates
 315 |         num_updates = epochs * ((len(states) + batch_size - 1) // batch_size)
 316 |         for key in metrics:
 317 |             metrics[key] /= num_updates
 318 |         
 319 |         # Clear buffer after training
 320 |         self.buffer = ExperienceBuffer()
 321 |         
 322 |         return metrics
 323 |     
 324 |     def _prepare_training_data(
 325 |         self, group_rewards: Optional[Dict[str, List[float]]] = None
 326 |     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
 327 |         """
 328 |         Prepare data for training from the experience buffer.
 329 |         
 330 |         Args:
 331 |             group_rewards: Rewards collected by all agents in the group
 332 |             
 333 |         Returns:
 334 |             Tuple of (states, actions, old_log_probs, returns, advantages)
 335 |         """
 336 |         # Collect experiences from buffer
 337 |         experiences = list(self.buffer.buffer)
 338 |         
 339 |         # Extract components
 340 |         states = torch.FloatTensor([exp.state for exp in experiences]).to(self.device)
 341 |         actions = torch.LongTensor([exp.action for exp in experiences]).to(self.device)
 342 |         rewards = torch.FloatTensor([exp.reward for exp in experiences]).to(self.device)
 343 |         next_states = torch.FloatTensor([exp.next_state for exp in experiences]).to(self.device)
 344 |         dones = torch.FloatTensor([float(exp.done) for exp in experiences]).to(self.device)
 345 |         
 346 |         # Compute values for all states and next states
 347 |         with torch.no_grad():
 348 |             values = self.value(states).squeeze()
 349 |             next_values = self.value(next_states).squeeze()
 350 |         
 351 |         # Compute advantages and returns
 352 |         if self.use_gae:
 353 |             # Generalized Advantage Estimation
 354 |             advantages = self._compute_gae(rewards, values, next_values, dones)
 355 |         else:
 356 |             # Regular advantages
 357 |             advantages = rewards + self.gamma * next_values * (1 - dones) - values
 358 |         
 359 |         # Compute returns (for value function)
 360 |         returns = advantages + values
 361 |         
 362 |         # If group rewards are provided, compute relative advantages
 363 |         if group_rewards is not None and self.agent_id in group_rewards:
 364 |             relative_advantages = self._compute_relative_advantages(
 365 |                 advantages, group_rewards)
 366 |             
 367 |             # Combine regular and relative advantages
 368 |             advantages = (1 - self.relative_advantage_weight) * advantages + \
 369 |                          self.relative_advantage_weight * relative_advantages
 370 |         
 371 |         # Normalize advantages if enabled
 372 |         if self.normalize_advantages:
 373 |             advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
 374 |         
 375 |         # Get old log probabilities
 376 |         old_log_probs = torch.FloatTensor(
 377 |             [self._compute_log_prob(exp.state, exp.action) for exp in experiences]
 378 |         ).to(self.device)
 379 |         
 380 |         return states, actions, old_log_probs, returns, advantages
 381 |     
 382 |     def _compute_gae(
 383 |         self, rewards: torch.Tensor, values: torch.Tensor, 
 384 |         next_values: torch.Tensor, dones: torch.Tensor
 385 |     ) -> torch.Tensor:
 386 |         """
 387 |         Compute advantages using Generalized Advantage Estimation.
 388 |         
 389 |         Args:
 390 |             rewards: Batch of rewards
 391 |             values: Batch of state values
 392 |             next_values: Batch of next state values
 393 |             dones: Batch of done flags
 394 |             
 395 |         Returns:
 396 |             Batch of advantage estimates
 397 |         """
 398 |         # Initialize advantages
 399 |         advantages = torch.zeros_like(rewards)
 400 |         
 401 |         # Initialize gae
 402 |         gae = 0
 403 |         
 404 |         # Compute advantages in reverse order
 405 |         for t in reversed(range(len(rewards))):
 406 |             # Compute TD error
 407 |             delta = rewards[t] + self.gamma * next_values[t] * (1 - dones[t]) - values[t]
 408 |             
 409 |             # Update gae
 410 |             gae = delta + self.gamma * self.gae_lambda * (1 - dones[t]) * gae
 411 |             
 412 |             # Store advantage
 413 |             advantages[t] = gae
 414 |         
 415 |         return advantages
 416 |     
 417 |     def _compute_relative_advantages(
 418 |         self, advantages: torch.Tensor, group_rewards: Dict[str, List[float]]
 419 |     ) -> torch.Tensor:
 420 |         """
 421 |         Compute relative advantages compared to other agents in the group.
 422 |         
 423 |         Args:
 424 |             advantages: This agent's advantages
 425 |             group_rewards: Rewards collected by all agents in the group
 426 |             
 427 |         Returns:
 428 |             Relative advantages
 429 |         """
 430 |         # Compute mean reward for each agent
 431 |         agent_mean_rewards = {
 432 |             agent_id: sum(rewards) / max(1, len(rewards))
 433 |             for agent_id, rewards in group_rewards.items()
 434 |         }
 435 |         
 436 |         # Compute mean reward across all agents
 437 |         group_mean_reward = sum(agent_mean_rewards.values()) / len(agent_mean_rewards)
 438 |         
 439 |         # Compute relative performance factor
 440 |         # Higher if this agent is doing better than the group average
 441 |         if self.agent_id in agent_mean_rewards:
 442 |             relative_factor = agent_mean_rewards[self.agent_id] / (group_mean_reward + 1e-8)
 443 |         else:
 444 |             relative_factor = 1.0
 445 |         
 446 |         # Apply the relative factor to the advantages
 447 |         relative_advantages = advantages * relative_factor
 448 |         
 449 |         return relative_advantages
 450 |     
 451 |     def _compute_log_prob(self, state: np.ndarray, action: int) -> float:
 452 |         """
 453 |         Compute the log probability of an action given a state.
 454 |         
 455 |         Args:
 456 |             state: The state
 457 |             action: The action
 458 |             
 459 |         Returns:
 460 |             The log probability
 461 |         """
 462 |         # Convert state to tensor
 463 |         state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
 464 |         
 465 |         # Get action distribution
 466 |         with torch.no_grad():
 467 |             logits = self.policy(state_tensor)
 468 |             distribution = Categorical(logits=logits)
 469 |             log_prob = distribution.log_prob(torch.tensor(action, device=self.device)).item()
 470 |         
 471 |         return log_prob
 472 |     
 473 |     def _update_policy(
 474 |         self, 
 475 |         states: torch.Tensor, 
 476 |         actions: torch.Tensor, 
 477 |         old_log_probs: torch.Tensor, 
 478 |         advantages: torch.Tensor
 479 |     ) -> Tuple[float, float, float]:
 480 |         """
 481 |         Update the policy network using PPO.
 482 |         
 483 |         Args:
 484 |             states: Batch of states
 485 |             actions: Batch of actions
 486 |             old_log_probs: Batch of old log probabilities
 487 |             advantages: Batch of advantages
 488 |             
 489 |         Returns:
 490 |             Tuple of (policy_loss, entropy, kl_divergence)
 491 |         """
 492 |         # Get action distributions
 493 |         logits = self.policy(states)
 494 |         distribution = Categorical(logits=logits)
 495 |         
 496 |         # Get new log probabilities
 497 |         new_log_probs = distribution.log_prob(actions)
 498 |         
 499 |         # Compute probability ratio
 500 |         ratio = torch.exp(new_log_probs - old_log_probs)
 501 |         
 502 |         # Compute surrogate objectives
 503 |         surrogate1 = ratio * advantages
 504 |         surrogate2 = torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio) * advantages
 505 |         
 506 |         # Compute policy loss (negative because we're maximizing)
 507 |         policy_loss = -torch.min(surrogate1, surrogate2).mean()
 508 |         
 509 |         # Compute entropy bonus
 510 |         entropy = distribution.entropy().mean()
 511 |         
 512 |         # Add entropy bonus to loss
 513 |         loss = policy_loss - self.entropy_coef * entropy
 514 |         
 515 |         # Compute approximate KL divergence for monitoring
 516 |         with torch.no_grad():
 517 |             kl = (old_log_probs - new_log_probs).mean().item()
 518 |         
 519 |         # Update policy network
 520 |         self.policy_optimizer.zero_grad()
 521 |         loss.backward()
 522 |         nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
 523 |         self.policy_optimizer.step()
 524 |         
 525 |         return policy_loss.item(), entropy.item(), kl
 526 |     
 527 |     def _update_value(self, states: torch.Tensor, returns: torch.Tensor) -> float:
 528 |         """
 529 |         Update the value network.
 530 |         
 531 |         Args:
 532 |             states: Batch of states
 533 |             returns: Batch of returns
 534 |             
 535 |         Returns:
 536 |             Value loss
 537 |         """
 538 |         # Get value predictions
 539 |         values = self.value(states).squeeze()
 540 |         
 541 |         # Compute value loss
 542 |         value_loss = F.mse_loss(values, returns)
 543 |         
 544 |         # Update value network
 545 |         self.value_optimizer.zero_grad()
 546 |         value_loss.backward()
 547 |         nn.utils.clip_grad_norm_(self.value.parameters(), self.max_grad_norm)
 548 |         self.value_optimizer.step()
 549 |         
 550 |         return value_loss.item()
 551 | 
 552 | 
 553 | class MultiAgentGroupRL:
 554 |     """
 555 |     Multi-agent reinforcement learning system using GRPO for Claude Code.
 556 |     This class manages multiple GRPO agents that learn in a coordinated way.
 557 |     """
 558 |     
 559 |     def __init__(
 560 |         self,
 561 |         agent_configs: List[Dict[str, Any]],
 562 |         feature_extractor: Callable[[Dict[str, Any]], np.ndarray],
 563 |         reward_function: Callable[[Dict[str, Any], str, Any], float],
 564 |         update_interval: int = 1000,
 565 |         training_epochs: int = 10,
 566 |         batch_size: int = 64,
 567 |         save_dir: str = "./models",
 568 |         device: str = "cuda" if torch.cuda.is_available() else "cpu",
 569 |     ):
 570 |         """
 571 |         Initialize the multi-agent RL system.
 572 |         
 573 |         Args:
 574 |             agent_configs: List of configurations for each agent
 575 |             feature_extractor: Function to extract state features
 576 |             reward_function: Function to compute rewards
 577 |             update_interval: How often to update agents (in steps)
 578 |             training_epochs: Number of epochs to train for each update
 579 |             batch_size: Batch size for training
 580 |             save_dir: Directory to save models
 581 |             device: Device to run on
 582 |         """
 583 |         self.feature_extractor = feature_extractor
 584 |         self.reward_function = reward_function
 585 |         self.update_interval = update_interval
 586 |         self.training_epochs = training_epochs
 587 |         self.batch_size = batch_size
 588 |         self.save_dir = save_dir
 589 |         self.device = device
 590 |         
 591 |         # Initialize agents
 592 |         self.agents = {}
 593 |         for config in agent_configs:
 594 |             agent_id = config["id"]
 595 |             state_dim = config["state_dim"]
 596 |             action_dim = config["action_dim"]
 597 |             
 598 |             # Create GRPO agent
 599 |             agent = GRPO(
 600 |                 state_dim=state_dim,
 601 |                 action_dim=action_dim,
 602 |                 hidden_dims=config.get("hidden_dims", [64, 64]),
 603 |                 device=device,
 604 |                 **{k: v for k, v in config.items() if k not in ["id", "state_dim", "action_dim", "hidden_dims"]}
 605 |             )
 606 |             
 607 |             # Set agent ID
 608 |             agent.set_agent_id(agent_id)
 609 |             
 610 |             self.agents[agent_id] = agent
 611 |         
 612 |         # Track steps for periodic updates
 613 |         self.total_steps = 0
 614 |         
 615 |         # Store rewards for relative advantage computation
 616 |         self.agent_rewards = {agent_id: [] for agent_id in self.agents}
 617 |     
 618 |     def select_action(
 619 |         self, agent_id: str, observation: Dict[str, Any], deterministic: bool = False
 620 |     ) -> Tuple[Any, float]:
 621 |         """
 622 |         Select an action for the specified agent.
 623 |         
 624 |         Args:
 625 |             agent_id: ID of the agent
 626 |             observation: Current observation
 627 |             deterministic: Whether to select deterministically
 628 |             
 629 |         Returns:
 630 |             Tuple of (action, log probability)
 631 |         """
 632 |         if agent_id not in self.agents:
 633 |             raise ValueError(f"Unknown agent ID: {agent_id}")
 634 |         
 635 |         # Extract features
 636 |         state = self.feature_extractor(observation)
 637 |         
 638 |         # Get action from agent
 639 |         action, log_prob = self.agents[agent_id].get_action(state, deterministic)
 640 |         
 641 |         return action, log_prob
 642 |     
 643 |     def observe(
 644 |         self, 
 645 |         agent_id: str, 
 646 |         observation: Dict[str, Any],
 647 |         action: Any,
 648 |         reward: float,
 649 |         next_observation: Dict[str, Any],
 650 |         done: bool,
 651 |         info: Optional[Dict[str, Any]] = None
 652 |     ) -> None:
 653 |         """
 654 |         Record an observation for the specified agent.
 655 |         
 656 |         Args:
 657 |             agent_id: ID of the agent
 658 |             observation: Current observation
 659 |             action: Action taken
 660 |             reward: Reward received
 661 |             next_observation: Next observation
 662 |             done: Whether the episode is done
 663 |             info: Additional information
 664 |         """
 665 |         if agent_id not in self.agents:
 666 |             raise ValueError(f"Unknown agent ID: {agent_id}")
 667 |         
 668 |         # Extract features
 669 |         state = self.feature_extractor(observation)
 670 |         next_state = self.feature_extractor(next_observation)
 671 |         
 672 |         # Create experience
 673 |         exp = Experience(
 674 |             state=state,
 675 |             action=action,
 676 |             reward=reward,
 677 |             next_state=next_state,
 678 |             done=done,
 679 |             info=info
 680 |         )
 681 |         
 682 |         # Add experience to agent's buffer
 683 |         self.agents[agent_id].buffer.add(exp)
 684 |         
 685 |         # Store reward for relative advantage computation
 686 |         self.agent_rewards[agent_id].append(reward)
 687 |         
 688 |         # Increment step counter
 689 |         self.total_steps += 1
 690 |         
 691 |         # Perform updates if needed
 692 |         if self.total_steps % self.update_interval == 0:
 693 |             self.update_all_agents()
 694 |     
 695 |     def update_all_agents(self) -> Dict[str, Dict[str, float]]:
 696 |         """
 697 |         Update all agents' policies.
 698 |         
 699 |         Returns:
 700 |             Dictionary of training metrics for each agent
 701 |         """
 702 |         # Store metrics for each agent
 703 |         metrics = {}
 704 |         
 705 |         # Update each agent
 706 |         for agent_id, agent in self.agents.items():
 707 |             # Train the agent with group rewards
 708 |             agent_metrics = agent.learn(
 709 |                 batch_size=self.batch_size,
 710 |                 epochs=self.training_epochs,
 711 |                 group_rewards=self.agent_rewards
 712 |             )
 713 |             
 714 |             metrics[agent_id] = agent_metrics
 715 |         
 716 |         # Reset reward tracking
 717 |         self.agent_rewards = {agent_id: [] for agent_id in self.agents}
 718 |         
 719 |         return metrics
 720 |     
 721 |     def save_agents(self, suffix: str = "") -> None:
 722 |         """
 723 |         Save all agents' models.
 724 |         
 725 |         Args:
 726 |             suffix: Optional suffix for saved files
 727 |         """
 728 |         import os
 729 |         
 730 |         # Create save directory if it doesn't exist
 731 |         os.makedirs(self.save_dir, exist_ok=True)
 732 |         
 733 |         # Save each agent
 734 |         for agent_id, agent in self.agents.items():
 735 |             # Create file path
 736 |             file_path = os.path.join(self.save_dir, f"{agent_id}{suffix}.pt")
 737 |             
 738 |             # Save model
 739 |             torch.save({
 740 |                 "policy_state_dict": agent.policy.state_dict(),
 741 |                 "value_state_dict": agent.value.state_dict(),
 742 |                 "policy_optimizer_state_dict": agent.policy_optimizer.state_dict(),
 743 |                 "value_optimizer_state_dict": agent.value_optimizer.state_dict(),
 744 |             }, file_path)
 745 |     
 746 |     def load_agents(self, suffix: str = "") -> None:
 747 |         """
 748 |         Load all agents' models.
 749 |         
 750 |         Args:
 751 |             suffix: Optional suffix for loaded files
 752 |         """
 753 |         import os
 754 |         
 755 |         # Load each agent
 756 |         for agent_id, agent in self.agents.items():
 757 |             # Create file path
 758 |             file_path = os.path.join(self.save_dir, f"{agent_id}{suffix}.pt")
 759 |             
 760 |             # Check if file exists
 761 |             if not os.path.exists(file_path):
 762 |                 print(f"Warning: Model file not found for agent {agent_id}")
 763 |                 continue
 764 |             
 765 |             # Load model
 766 |             checkpoint = torch.load(file_path, map_location=self.device)
 767 |             
 768 |             # Load state dicts
 769 |             agent.policy.load_state_dict(checkpoint["policy_state_dict"])
 770 |             agent.value.load_state_dict(checkpoint["value_state_dict"])
 771 |             agent.policy_optimizer.load_state_dict(checkpoint["policy_optimizer_state_dict"])
 772 |             agent.value_optimizer.load_state_dict(checkpoint["value_optimizer_state_dict"])
 773 | 
 774 | 
 775 | class ToolSelectionGRPO:
 776 |     """
 777 |     Specialized GRPO implementation for tool selection in Claude Code.
 778 |     This class adapts the MultiAgentGroupRL for the specific context of tool selection.
 779 |     """
 780 |     
 781 |     def __init__(
 782 |         self,
 783 |         tool_registry: Any,  # Should be a reference to the tool registry
 784 |         context_evaluator: Callable,  # Function to evaluate quality of response given context
 785 |         state_dim: int = 768,  # Embedding dimension for query
 786 |         num_agents: int = 3,  # Number of agents in the group
 787 |         update_interval: int = 100,
 788 |         device: str = "cuda" if torch.cuda.is_available() else "cpu",
 789 |     ):
 790 |         """
 791 |         Initialize the GRPO tool selector.
 792 |         
 793 |         Args:
 794 |             tool_registry: Registry containing available tools
 795 |             context_evaluator: Function to evaluate response quality
 796 |             state_dim: Dimension of state features
 797 |             num_agents: Number of agents in the group
 798 |             update_interval: How often to update agents
 799 |             device: Device to run on
 800 |         """
 801 |         self.tool_registry = tool_registry
 802 |         self.context_evaluator = context_evaluator
 803 |         
 804 |         # Get all available tools
 805 |         self.tool_names = tool_registry.get_all_tool_names()
 806 |         self.action_dim = len(self.tool_names)
 807 |         
 808 |         # Define agent configurations
 809 |         agent_configs = [
 810 |             {
 811 |                 "id": f"tool_agent_{i}",
 812 |                 "state_dim": state_dim,
 813 |                 "action_dim": self.action_dim,
 814 |                 "hidden_dims": [256, 128],
 815 |                 "relative_advantage_weight": 0.7 if i > 0 else 0.3,  # Different weights
 816 |                 "entropy_coef": 0.02 if i == 0 else 0.01,  # Different exploration rates
 817 |             }
 818 |             for i in range(num_agents)
 819 |         ]
 820 |         
 821 |         # Initialize multi-agent RL system
 822 |         self.rl_system = MultiAgentGroupRL(
 823 |             agent_configs=agent_configs,
 824 |             feature_extractor=self._extract_features,
 825 |             reward_function=self._compute_reward,
 826 |             update_interval=update_interval,
 827 |             device=device,
 828 |         )
 829 |         
 830 |         # Track current episode
 831 |         self.current_episode = {agent_id: {} for agent_id in self.rl_system.agents}
 832 |     
 833 |     def select_tool(self, user_query: str, context: Dict[str, Any], visualizer=None) -> str:
 834 |         """
 835 |         Select the best tool to use for a given user query and context.
 836 |         
 837 |         Args:
 838 |             user_query: The user's query
 839 |             context: The current conversation context
 840 |             visualizer: Optional visualizer to display the selection process
 841 |             
 842 |         Returns:
 843 |             The name of the best tool to use
 844 |         """
 845 |         # Create observation
 846 |         observation = {
 847 |             "query": user_query,
 848 |             "context": context,
 849 |         }
 850 |         
 851 |         # If visualizer is provided, start it
 852 |         if visualizer:
 853 |             visualizer.start()
 854 |             visualizer.add_execution(
 855 |                 execution_id="tool_selection",
 856 |                 tool_name="GRPO Tool Selection",
 857 |                 parameters={"query": user_query[:100] + "..." if len(user_query) > 100 else user_query}
 858 |             )
 859 |         
 860 |         # Select agent to use (round-robin for now)
 861 |         agent_id = f"tool_agent_{self.rl_system.total_steps % len(self.rl_system.agents)}"
 862 |         
 863 |         # Update visualizer if provided
 864 |         if visualizer:
 865 |             visualizer.update_progress("tool_selection", 0.3)
 866 |         
 867 |         # Get action from agent
 868 |         action_idx, _ = self.rl_system.select_action(
 869 |             agent_id=agent_id,
 870 |             observation=observation,
 871 |             deterministic=False  # Use exploratory actions during learning
 872 |         )
 873 |         
 874 |         # Update visualizer if provided
 875 |         if visualizer:
 876 |             visualizer.update_progress("tool_selection", 0.6)
 877 |         
 878 |         # Store initial information for the episode
 879 |         self.current_episode[agent_id] = {
 880 |             "observation": observation,
 881 |             "action_idx": action_idx,
 882 |             "initial_quality": self.context_evaluator(context),
 883 |         }
 884 |         
 885 |         # Map action index to tool name
 886 |         tool_name = self.tool_names[action_idx]
 887 |         
 888 |         # Complete visualization if provided
 889 |         if visualizer:
 890 |             # Create detailed metrics for visualization
 891 |             agent_data = {}
 892 |             for aid, agent in self.rl_system.agents.items():
 893 |                 # Get all tool probabilities for this agent
 894 |                 with torch.no_grad():
 895 |                     state = self.rl_system._extract_features(observation)
 896 |                     state_tensor = torch.FloatTensor(state).unsqueeze(0).to(agent.device)
 897 |                     logits = agent.policy(state_tensor)
 898 |                     probs = F.softmax(logits, dim=1).squeeze().cpu().numpy()
 899 |                 
 900 |                 # Add to metrics
 901 |                 agent_data[aid] = {
 902 |                     "selected": aid == agent_id,
 903 |                     "tool_probabilities": {
 904 |                         self.tool_names[i]: float(prob) 
 905 |                         for i, prob in enumerate(probs)
 906 |                     }
 907 |                 }
 908 |             
 909 |             # Complete the visualization
 910 |             visualizer.complete_execution(
 911 |                 execution_id="tool_selection",
 912 |                 result={
 913 |                     "selected_tool": tool_name,
 914 |                     "selected_agent": agent_id,
 915 |                     "agent_data": agent_data
 916 |                 },
 917 |                 status="success"
 918 |             )
 919 |             visualizer.stop()
 920 |         
 921 |         return tool_name
 922 |     
 923 |     def observe_result(
 924 |         self, agent_id: str, result: Any, context: Dict[str, Any], done: bool = True
 925 |     ) -> None:
 926 |         """
 927 |         Observe the result of using a tool.
 928 |         
 929 |         Args:
 930 |             agent_id: The ID of the agent that selected the tool
 931 |             result: The result of using the tool
 932 |             context: The updated context after using the tool
 933 |             done: Whether the interaction is complete
 934 |         """
 935 |         if agent_id not in self.current_episode:
 936 |             return
 937 |         
 938 |         # Get episode information
 939 |         episode = self.current_episode[agent_id]
 940 |         observation = episode["observation"]
 941 |         action_idx = episode["action_idx"]
 942 |         initial_quality = episode["initial_quality"]
 943 |         
 944 |         # Create next observation
 945 |         next_observation = {
 946 |             "query": observation["query"],
 947 |             "context": context,
 948 |             "result": result,
 949 |         }
 950 |         
 951 |         # Compute reward
 952 |         reward = self._compute_reward(observation, action_idx, result, context, initial_quality)
 953 |         
 954 |         # Record observation
 955 |         self.rl_system.observe(
 956 |             agent_id=agent_id,
 957 |             observation=observation,
 958 |             action=action_idx,
 959 |             reward=reward,
 960 |             next_observation=next_observation,
 961 |             done=done,
 962 |         )
 963 |         
 964 |         # Clear episode if done
 965 |         if done:
 966 |             self.current_episode[agent_id] = {}
 967 |     
 968 |     def _extract_features(self, observation: Dict[str, Any]) -> np.ndarray:
 969 |         """Extract features from an observation."""
 970 |         # This would ideally use an embedding model
 971 |         # For now, return a random vector as a placeholder
 972 |         return np.random.randn(768)
 973 |     
 974 |     def _compute_reward(
 975 |         self, 
 976 |         observation: Dict[str, Any], 
 977 |         action_idx: int, 
 978 |         result: Any,
 979 |         context: Dict[str, Any], 
 980 |         initial_quality: float
 981 |     ) -> float:
 982 |         """Compute the reward for an action."""
 983 |         # Compute the quality improvement
 984 |         final_quality = self.context_evaluator(context)
 985 |         quality_improvement = final_quality - initial_quality
 986 |         
 987 |         # Base reward on quality improvement
 988 |         reward = max(0, quality_improvement * 10)  # Scale for better learning
 989 |         
 990 |         return reward
 991 |     
 992 |     def update(self) -> Dict[str, Dict[str, float]]:
 993 |         """
 994 |         Trigger an update of all agents.
 995 |         
 996 |         Returns:
 997 |             Dictionary of training metrics
 998 |         """
 999 |         return self.rl_system.update_all_agents()
1000 |     
1001 |     def save(self, suffix: str = "") -> None:
1002 |         """Save all agents."""
1003 |         self.rl_system.save_agents(suffix)
1004 |     
1005 |     def load(self, suffix: str = "") -> None:
1006 |         """Load all agents."""
1007 |         self.rl_system.load_agents(suffix)
```

--------------------------------------------------------------------------------
/claude_code/lib/rl/mcts.py:
--------------------------------------------------------------------------------

```python
  1 | """
  2 | Monte Carlo Tree Search implementation for decision making in Claude Code.
  3 | This module provides an advanced MCTS implementation that can be used to select
  4 | optimal actions/tools based on simulated outcomes.
  5 | """
  6 | 
import math
import random
import time
from dataclasses import dataclass
from typing import List, Dict, Any, Callable, Tuple, Optional, Union

import numpy as np
 12 | 
 13 | 
 14 | @dataclass
 15 | class MCTSNode:
 16 |     """Represents a node in the Monte Carlo search tree."""
 17 |     state: Any
 18 |     parent: Optional['MCTSNode'] = None
 19 |     action_taken: Any = None
 20 |     visits: int = 0
 21 |     value: float = 0.0
 22 |     children: Dict[Any, 'MCTSNode'] = None
 23 |     
 24 |     def __post_init__(self):
 25 |         if self.children is None:
 26 |             self.children = {}
 27 |     
 28 |     def is_fully_expanded(self, possible_actions: List[Any]) -> bool:
 29 |         """Check if all possible actions have been tried from this node."""
 30 |         return all(action in self.children for action in possible_actions)
 31 |     
 32 |     def is_terminal(self) -> bool:
 33 |         """Check if this node represents a terminal state."""
 34 |         # This should be customized based on your environment
 35 |         return False
 36 |     
 37 |     def best_child(self, exploration_weight: float = 1.0) -> 'MCTSNode':
 38 |         """Select the best child node according to UCB1 formula."""
 39 |         if not self.children:
 40 |             return None
 41 |             
 42 |         def ucb_score(child: MCTSNode) -> float:
 43 |             exploitation = child.value / child.visits if child.visits > 0 else 0
 44 |             exploration = math.sqrt(2 * math.log(self.visits) / child.visits) if child.visits > 0 else float('inf')
 45 |             return exploitation + exploration_weight * exploration
 46 |             
 47 |         return max(self.children.values(), key=ucb_score)
 48 | 
 49 | 
 50 | class AdvancedMCTS:
 51 |     """
 52 |     Advanced Monte Carlo Tree Search implementation with various enhancements:
 53 |     - Progressive widening for large/continuous action spaces
 54 |     - RAVE (Rapid Action Value Estimation)
 55 |     - Parallel simulations
 56 |     - Dynamic exploration weight
 57 |     - Customizable simulation and backpropagation strategies
 58 |     """
 59 |     
 60 |     def __init__(
 61 |         self, 
 62 |         state_evaluator: Callable[[Any], float],
 63 |         action_generator: Callable[[Any], List[Any]],
 64 |         simulator: Callable[[Any, Any], Any],
 65 |         max_iterations: int = 1000,
 66 |         exploration_weight: float = 1.0,
 67 |         time_limit: Optional[float] = None,
 68 |         progressive_widening: bool = False,
 69 |         pw_coef: float = 0.5,
 70 |         pw_power: float = 0.5,
 71 |         use_rave: bool = False,
 72 |         rave_equiv_param: float = 1000,
 73 |     ):
 74 |         """
 75 |         Initialize the MCTS algorithm.
 76 |         
 77 |         Args:
 78 |             state_evaluator: Function to evaluate the value of a state (terminal or not)
 79 |             action_generator: Function to generate possible actions from a state
 80 |             simulator: Function to simulate taking an action in a state, returning new state
 81 |             max_iterations: Maximum number of search iterations
 82 |             exploration_weight: Controls exploration vs exploitation balance
 83 |             time_limit: Optional time limit for search in seconds
 84 |             progressive_widening: Whether to use progressive widening for large action spaces
 85 |             pw_coef: Coefficient for progressive widening
 86 |             pw_power: Power for progressive widening
 87 |             use_rave: Whether to use RAVE (Rapid Action Value Estimation)
 88 |             rave_equiv_param: RAVE equivalence parameter
 89 |         """
 90 |         self.state_evaluator = state_evaluator
 91 |         self.action_generator = action_generator 
 92 |         self.simulator = simulator
 93 |         self.max_iterations = max_iterations
 94 |         self.exploration_weight = exploration_weight
 95 |         self.time_limit = time_limit
 96 |         
 97 |         # Progressive widening parameters
 98 |         self.progressive_widening = progressive_widening
 99 |         self.pw_coef = pw_coef
100 |         self.pw_power = pw_power
101 |         
102 |         # RAVE parameters
103 |         self.use_rave = use_rave
104 |         self.rave_equiv_param = rave_equiv_param
105 |         self.rave_values = {}  # (state, action) -> (value, visits)
106 |     
107 |     def search(self, initial_state: Any, visualizer=None) -> Any:
108 |         """
109 |         Perform MCTS search from the initial state and return the best action.
110 |         
111 |         Args:
112 |             initial_state: The starting state for the search
113 |             visualizer: Optional visualizer to show progress
114 |             
115 |         Returns:
116 |             The best action found by the search
117 |         """
118 |         root = MCTSNode(state=initial_state)
119 |         
120 |         # Initialize visualizer if provided
121 |         if visualizer:
122 |             visualizer.set_search_parameters(root, self.max_iterations)
123 |         
124 |         # Run iterations of the MCTS algorithm
125 |         for iteration in range(self.max_iterations):
126 |             # Selection phase
127 |             selected_node = self._select(root)
128 |             
129 |             # Expansion phase (if not terminal)
130 |             expanded_node = None
131 |             if not selected_node.is_terminal():
132 |                 expanded_node = self._expand(selected_node)
133 |             else:
134 |                 expanded_node = selected_node
135 |             
136 |             # Simulation phase
137 |             simulation_path = []
138 |             if visualizer:
139 |                 # Track simulation path for visualization
140 |                 current = expanded_node
141 |                 current_state = current.state
142 |                 while current.parent:
143 |                     simulation_path.insert(0, (current.parent.state, current.action_taken))
144 |                     current = current.parent
145 |             
146 |             simulation_result = self._simulate(expanded_node)
147 |             
148 |             # Backpropagation phase
149 |             self._backpropagate(expanded_node, simulation_result)
150 |             
151 |             # Update visualization
152 |             if visualizer:
153 |                 # Find current best action
154 |                 best_action = None
155 |                 if root.children:
156 |                     best_action = max(root.children.items(), key=lambda x: x[1].visits)[0]
157 |                 
158 |                 # Update visualizer
159 |                 visualizer.update_iteration(
160 |                     iteration=iteration + 1,
161 |                     selected_node=selected_node,
162 |                     expanded_node=expanded_node,
163 |                     simulation_path=simulation_path,
164 |                     simulation_result=simulation_result,
165 |                     best_action=best_action
166 |                 )
167 |             
168 |         # Return the action that leads to the child with the highest value
169 |         if not root.children:
170 |             possible_actions = self.action_generator(root.state)
171 |             if possible_actions:
172 |                 best_action = random.choice(possible_actions)
173 |                 if visualizer:
174 |                     visualizer.update_iteration(
175 |                         iteration=self.max_iterations,
176 |                         best_action=best_action
177 |                     )
178 |                 return best_action
179 |             return None
180 |         
181 |         best_action = max(root.children.items(), key=lambda x: x[1].visits)[0]
182 |         if visualizer:
183 |             visualizer.update_iteration(
184 |                 iteration=self.max_iterations,
185 |                 best_action=best_action
186 |             )
187 |         return best_action
188 |     
189 |     def _select(self, node: MCTSNode) -> MCTSNode:
190 |         """
191 |         Select a node to expand using UCB1 and progressive widening if enabled.
192 |         
193 |         Args:
194 |             node: The current node
195 |             
196 |         Returns:
197 |             The selected node for expansion
198 |         """
199 |         while not node.is_terminal():
200 |             possible_actions = self.action_generator(node.state)
201 |             
202 |             # Handle progressive widening if enabled
203 |             if self.progressive_widening:
204 |                 max_children = max(1, int(self.pw_coef * (node.visits ** self.pw_power)))
205 |                 if len(node.children) < min(max_children, len(possible_actions)):
206 |                     return node
207 |             
208 |             # If not fully expanded, select this node for expansion
209 |             if not node.is_fully_expanded(possible_actions):
210 |                 return node
211 |                 
212 |             # Otherwise, select the best child according to UCB1
213 |             node = node.best_child(self.exploration_weight)
214 |             if node is None:
215 |                 break
216 |                 
217 |         return node
218 |     
219 |     def _expand(self, node: MCTSNode) -> MCTSNode:
220 |         """
221 |         Expand the node by selecting an untried action and creating a new child node.
222 |         
223 |         Args:
224 |             node: The node to expand
225 |             
226 |         Returns:
227 |             The newly created child node
228 |         """
229 |         possible_actions = self.action_generator(node.state)
230 |         untried_actions = [a for a in possible_actions if a not in node.children]
231 |         
232 |         if not untried_actions:
233 |             return node
234 |             
235 |         action = random.choice(untried_actions)
236 |         new_state = self.simulator(node.state, action)
237 |         child_node = MCTSNode(
238 |             state=new_state,
239 |             parent=node,
240 |             action_taken=action
241 |         )
242 |         node.children[action] = child_node
243 |         return child_node
244 |     
245 |     def _simulate(self, node: MCTSNode, depth: int = 10) -> float:
246 |         """
247 |         Simulate a random playout from the given node until a terminal state or max depth.
248 |         
249 |         Args:
250 |             node: The node to start simulation from
251 |             depth: Maximum simulation depth
252 |             
253 |         Returns:
254 |             The value of the simulated outcome
255 |         """
256 |         state = node.state
257 |         current_depth = 0
258 |         
259 |         # Continue simulation until we reach a terminal state or max depth
260 |         while current_depth < depth:
261 |             if self._is_terminal_state(state):
262 |                 break
263 |                 
264 |             possible_actions = self.action_generator(state)
265 |             if not possible_actions:
266 |                 break
267 |                 
268 |             action = random.choice(possible_actions)
269 |             state = self.simulator(state, action)
270 |             current_depth += 1
271 |             
272 |         return self.state_evaluator(state)
273 |     
274 |     def _is_terminal_state(self, state: Any) -> bool:
275 |         """Determine if the state is terminal."""
276 |         # This should be customized based on your environment
277 |         return False
278 |     
279 |     def _backpropagate(self, node: MCTSNode, value: float) -> None:
280 |         """
281 |         Backpropagate the simulation result up the tree.
282 |         
283 |         Args:
284 |             node: The leaf node where simulation started
285 |             value: The value from the simulation
286 |         """
287 |         while node is not None:
288 |             node.visits += 1
289 |             node.value += value
290 |             
291 |             # Update RAVE values if enabled
292 |             if self.use_rave and node.parent is not None:
293 |                 state_hash = self._hash_state(node.parent.state)
294 |                 action = node.action_taken
295 |                 if (state_hash, action) not in self.rave_values:
296 |                     self.rave_values[(state_hash, action)] = [0, 0]  # [value, visits]
297 |                 rave_value, rave_visits = self.rave_values[(state_hash, action)]
298 |                 self.rave_values[(state_hash, action)] = [
299 |                     rave_value + value,
300 |                     rave_visits + 1
301 |                 ]
302 |                 
303 |             node = node.parent
304 |     
305 |     def _hash_state(self, state: Any) -> int:
306 |         """Create a hash of the state for RAVE table lookups."""
307 |         # This should be customized based on your state representation
308 |         if hasattr(state, "__hash__"):
309 |             return hash(state)
310 |         return hash(str(state))
311 | 
312 | 
313 | class MCTSToolSelector:
314 |     """
315 |     Specialized MCTS implementation for selecting optimal tools in Claude Code.
316 |     This class adapts the AdvancedMCTS for the specific context of tool selection.
317 |     """
318 |     
319 |     def __init__(
320 |         self,
321 |         tool_registry: Any,  # Should be a reference to the tool registry
322 |         context_evaluator: Callable,  # Function to evaluate quality of response given context
323 |         max_iterations: int = 200,
324 |         exploration_weight: float = 1.0,
325 |         use_learning: bool = True,
326 |         tool_history_weight: float = 0.7,
327 |         enable_plan_generation: bool = True,
328 |         use_semantic_similarity: bool = True,
329 |         adaptation_rate: float = 0.05
330 |     ):
331 |         """
332 |         Initialize the MCTS tool selector with enhanced intelligence.
333 |         
334 |         Args:
335 |             tool_registry: Registry containing available tools
336 |             context_evaluator: Function to evaluate response quality
337 |             max_iterations: Maximum search iterations
338 |             exploration_weight: Controls exploration vs exploitation
339 |             use_learning: Whether to use learning from past tool selections
340 |             tool_history_weight: Weight given to historical tool performance
341 |             enable_plan_generation: Generate complete tool sequences as plans
342 |             use_semantic_similarity: Use semantic similarity for tool relevance
343 |             adaptation_rate: Rate at which the system adapts to new patterns
344 |         """
345 |         self.tool_registry = tool_registry
346 |         self.context_evaluator = context_evaluator
347 |         self.use_learning = use_learning
348 |         self.tool_history_weight = tool_history_weight
349 |         self.enable_plan_generation = enable_plan_generation
350 |         self.use_semantic_similarity = use_semantic_similarity
351 |         self.adaptation_rate = adaptation_rate
352 |         
353 |         # Tool performance history by query type
354 |         self.tool_history = {}
355 |         
356 |         # Tool sequence effectiveness records
357 |         self.sequence_effectiveness = {}
358 |         
359 |         # Semantic fingerprints for tools and queries
360 |         self.tool_fingerprints = {}
361 |         self.query_clusters = {}
362 |         
363 |         # Cached simulation results for similar queries
364 |         self.simulation_cache = {}
365 |         
366 |         # Initialize the MCTS algorithm
367 |         self.mcts = AdvancedMCTS(
368 |             state_evaluator=self._evaluate_state,
369 |             action_generator=self._generate_actions,
370 |             simulator=self._simulate_action,
371 |             max_iterations=max_iterations,
372 |             exploration_weight=exploration_weight,
373 |             progressive_widening=True
374 |         )
375 |         
376 |         # Initialize tool fingerprints
377 |         self._initialize_tool_fingerprints()
378 |     
379 |     def _initialize_tool_fingerprints(self):
380 |         """Initialize semantic fingerprints for all available tools."""
381 |         if not self.use_semantic_similarity:
382 |             return
383 |             
384 |         for tool_name in self.tool_registry.get_all_tool_names():
385 |             tool = self.tool_registry.get_tool(tool_name)
386 |             if tool and hasattr(tool, 'description'):
387 |                 # In a real implementation, this would compute an embedding
388 |                 # For now, we'll use a simple keyword extraction as a placeholder
389 |                 keywords = set(word.lower() for word in tool.description.split() 
390 |                              if len(word) > 3)
391 |                 self.tool_fingerprints[tool_name] = {
392 |                     'keywords': keywords,
393 |                     'description': tool.description,
394 |                     'usage_contexts': set()
395 |                 }
396 |     
397 |     def select_tool(self, user_query: str, context: Dict[str, Any], visualizer=None) -> Union[str, List[str]]:
398 |         """
399 |         Select the best tool to use for a given user query and context.
400 |         
401 |         Args:
402 |             user_query: The user's query
403 |             context: The current conversation context
404 |             visualizer: Optional visualizer to show the selection process
405 |             
406 |         Returns:
407 |             Either a single tool name or a sequence of tool names (if plan generation is enabled)
408 |         """
409 |         # Analyze query to determine its type/characteristics
410 |         query_type = self._analyze_query(user_query)
411 |         
412 |         # Update semantic fingerprints with this query
413 |         if self.use_semantic_similarity:
414 |             self._update_query_clusters(user_query, query_type)
415 |         
416 |         initial_state = {
417 |             'query': user_query,
418 |             'query_type': query_type,
419 |             'context': context,
420 |             'actions_taken': [],
421 |             'response_quality': 0.0,
422 |             'steps_remaining': 3 if self.enable_plan_generation else 1,
423 |             'step_results': {}
424 |         }
425 |         
426 |         # First check if we have a high-confidence cached result for similar queries
427 |         cached_result = self._check_cache(user_query, query_type)
428 |         if cached_result and random.random() > 0.1:  # 10% random exploration
429 |             if visualizer:
430 |                 visualizer.add_execution(
431 |                     execution_id="mcts_cache_hit",
432 |                     tool_name="MCTS Tool Selection (cached)",
433 |                     parameters={"query": user_query[:100] + "..." if len(user_query) > 100 else user_query}
434 |                 )
435 |                 visualizer.complete_execution(
436 |                     execution_id="mcts_cache_hit",
437 |                     result={"selected_tool": cached_result, "source": "cache"},
438 |                     status="success"
439 |                 )
440 |             return cached_result
441 |         
442 |         # Run MCTS search
443 |         best_action = self.mcts.search(initial_state, visualizer)
444 |         
445 |         # If plan generation is enabled, we might want to return a sequence
446 |         if self.enable_plan_generation:
447 |             # Extract the most promising action sequence from search
448 |             plan = self._extract_plan_from_search()
449 |             if plan and len(plan) > 1:
450 |                 # Store this plan in our cache
451 |                 self._cache_result(user_query, query_type, plan)
452 |                 return plan
453 |         
454 |         # Store single action in cache
455 |         self._cache_result(user_query, query_type, best_action)
456 |         return best_action
457 |     
458 |     def _analyze_query(self, query: str) -> str:
459 |         """
460 |         Analyze a query to determine its type and characteristics.
461 |         
462 |         Args:
463 |             query: The user query
464 |             
465 |         Returns:
466 |             A string identifying the query type
467 |         """
468 |         query_lower = query.lower()
469 |         
470 |         # Check for search-related queries
471 |         if any(term in query_lower for term in ['find', 'search', 'where', 'look for']):
472 |             return 'search'
473 |             
474 |         # Check for explanation queries
475 |         if any(term in query_lower for term in ['explain', 'how', 'why', 'what is']):
476 |             return 'explanation'
477 |             
478 |         # Check for file operation queries
479 |         if any(term in query_lower for term in ['file', 'read', 'write', 'edit', 'create']):
480 |             return 'file_operation'
481 |             
482 |         # Check for execution queries
483 |         if any(term in query_lower for term in ['run', 'execute', 'start']):
484 |             return 'execution'
485 |             
486 |         # Check for debugging queries
487 |         if any(term in query_lower for term in ['debug', 'fix', 'error', 'problem']):
488 |             return 'debugging'
489 |             
490 |         # Default to general
491 |         return 'general'
492 |     
493 |     def _update_query_clusters(self, query: str, query_type: str):
494 |         """
495 |         Update query clusters with new query information.
496 |         
497 |         Args:
498 |             query: The user query
499 |             query_type: The type of query
500 |         """
501 |         # Extract query keywords
502 |         keywords = set(word.lower() for word in query.split() if len(word) > 3)
503 |         
504 |         # Update query clusters
505 |         if query_type not in self.query_clusters:
506 |             self.query_clusters[query_type] = {
507 |                 'keywords': set(),
508 |                 'queries': []
509 |             }
510 |             
511 |         # Add keywords to cluster
512 |         self.query_clusters[query_type]['keywords'].update(keywords)
513 |         
514 |         # Add query to cluster (limit to last 50)
515 |         self.query_clusters[query_type]['queries'].append(query)
516 |         if len(self.query_clusters[query_type]['queries']) > 50:
517 |             self.query_clusters[query_type]['queries'].pop(0)
518 |             
519 |         # Update tool fingerprints with these keywords
520 |         for tool_name, fingerprint in self.tool_fingerprints.items():
521 |             # If tool has been used successfully for this query type before
522 |             if tool_name in self.tool_history.get(query_type, {}) and \
523 |                self.tool_history[query_type][tool_name]['success_rate'] > 0.6:
524 |                 fingerprint['usage_contexts'].add(query_type)
525 |     
526 |     def _check_cache(self, query: str, query_type: str) -> Union[str, List[str], None]:
527 |         """
528 |         Check if we have a cached result for a similar query.
529 |         
530 |         Args:
531 |             query: The user query
532 |             query_type: The type of query
533 |             
534 |         Returns:
535 |             A cached tool selection or None
536 |         """
537 |         if not self.use_learning or query_type not in self.tool_history:
538 |             return None
539 |             
540 |         # Find the most successful tool for this query type
541 |         type_history = self.tool_history[query_type]
542 |         best_tools = sorted(
543 |             [(tool, data['success_rate']) for tool, data in type_history.items()],
544 |             key=lambda x: x[1], 
545 |             reverse=True
546 |         )
547 |         
548 |         # Only use cache if we have a high confidence result
549 |         if best_tools and best_tools[0][1] > 0.75:
550 |             return best_tools[0][0]
551 |             
552 |         return None
553 |     
554 |     def _cache_result(self, query: str, query_type: str, action: Union[str, List[str]]):
555 |         """
556 |         Cache a result for future similar queries.
557 |         
558 |         Args:
559 |             query: The user query
560 |             query_type: The type of query
561 |             action: The selected action or plan
562 |         """
563 |         # Store in simulation cache
564 |         query_key = self._get_query_cache_key(query)
565 |         self.simulation_cache[query_key] = {
566 |             'action': action,
567 |             'timestamp': self._get_timestamp(),
568 |             'query_type': query_type
569 |         }
570 |         
571 |         # Limit cache size
572 |         if len(self.simulation_cache) > 1000:
573 |             # Remove oldest entries
574 |             oldest_key = min(self.simulation_cache.keys(), 
575 |                            key=lambda k: self.simulation_cache[k]['timestamp'])
576 |             del self.simulation_cache[oldest_key]
577 |     
578 |     def _get_query_cache_key(self, query: str) -> str:
579 |         """Generate a cache key for a query."""
580 |         # In a real implementation, this might use a hash of query embeddings
581 |         # For now, use a simple keyword approach
582 |         keywords = ' '.join(sorted(set(word.lower() for word in query.split() if len(word) > 3)))
583 |         return keywords[:100]  # Limit key length
584 |     
585 |     def _get_timestamp(self):
586 |         """Get current timestamp."""
587 |         import time
588 |         return time.time()
589 |     
590 |     def _evaluate_state(self, state: Dict[str, Any]) -> float:
591 |         """
592 |         Evaluate the quality of a state based on response quality and steps.
593 |         
594 |         Args:
595 |             state: The current state
596 |             
597 |         Returns:
598 |             A quality score
599 |         """
600 |         # Base score is the response quality
601 |         score = state['response_quality']
602 |         
603 |         # If plan generation is enabled, we want to encourage complete plans
604 |         if self.enable_plan_generation:
605 |             steps_completed = len(state['actions_taken'])
606 |             total_steps = steps_completed + state['steps_remaining']
607 |             
608 |             # Add bonus for completing more steps
609 |             if total_steps > 0:
610 |                 step_completion_bonus = steps_completed / total_steps
611 |                 score += step_completion_bonus * 0.2  # 20% bonus for step completion
612 |         
613 |         return score
614 |     
615 |     def _generate_actions(self, state: Dict[str, Any]) -> List[str]:
616 |         """
617 |         Generate possible tool actions from the current state with intelligent filtering.
618 |         
619 |         Args:
620 |             state: The current state
621 |             
622 |         Returns:
623 |             List of possible actions
624 |         """
625 |         # Get query type
626 |         query_type = state['query_type']
627 |         query = state['query']
628 |         
629 |         # Get all available tools
630 |         all_tools = set(self.tool_registry.get_all_tool_names())
631 |         
632 |         # Tools already used in this sequence
633 |         used_tools = set(state['actions_taken'])
634 |         
635 |         # Remaining tools
636 |         remaining_tools = all_tools - used_tools
637 |         
638 |         # If we're using learning, prioritize tools based on history
639 |         if self.use_learning and query_type in self.tool_history:
640 |             prioritized_tools = []
641 |             
642 |             # First, add tools that have been successful for this query type
643 |             type_history = self.tool_history[query_type]
644 |             
645 |             # Check for successful tools
646 |             for tool in remaining_tools:
647 |                 if tool in type_history and type_history[tool]['success_rate'] > 0.5:
648 |                     prioritized_tools.append(tool)
649 |                     
650 |             # If we have at least some tools, return them
651 |             if prioritized_tools and random.random() < self.tool_history_weight:
652 |                 return prioritized_tools
653 |         
654 |         # If using semantic similarity, filter by relevant tools
655 |         if self.use_semantic_similarity:
656 |             query_keywords = set(word.lower() for word in query.split() if len(word) > 3)
657 |             
658 |             # Score tools by semantic similarity to query
659 |             scored_tools = []
660 |             for tool in remaining_tools:
661 |                 if tool in self.tool_fingerprints:
662 |                     fingerprint = self.tool_fingerprints[tool]
663 |                     
664 |                     # Calculate keyword overlap
665 |                     keyword_overlap = len(query_keywords.intersection(fingerprint['keywords']))
666 |                     
667 |                     # Check if tool has been used for this query type
668 |                     context_match = 1.0 if query_type in fingerprint['usage_contexts'] else 0.0
669 |                     
670 |                     # Combined score
671 |                     score = keyword_overlap * 0.7 + context_match * 0.3
672 |                     
673 |                     scored_tools.append((tool, score))
674 |             
675 |             # Sort and filter tools
676 |             scored_tools.sort(key=lambda x: x[1], reverse=True)
677 |             
678 |             # Take top half of tools if we have enough
679 |             if len(scored_tools) > 2:
680 |                 return [t[0] for t in scored_tools[:max(2, len(scored_tools) // 2)]]
681 |         
682 |         # If we reach here, use all remaining tools
683 |         return list(remaining_tools)
684 |     
685 |     def _simulate_action(self, state: Dict[str, Any], action: str) -> Dict[str, Any]:
686 |         """
687 |         Simulate taking an action (using a tool) in the given state with enhanced modeling.
688 |         
689 |         Args:
690 |             state: The current state
691 |             action: The tool action to simulate
692 |             
693 |         Returns:
694 |             The new state after taking the action
695 |         """
696 |         # Create a new state with the action added
697 |         new_state = state.copy()
698 |         new_actions = state['actions_taken'].copy()
699 |         new_actions.append(action)
700 |         new_state['actions_taken'] = new_actions
701 |         
702 |         # Decrement steps remaining if using plan generation
703 |         if self.enable_plan_generation and new_state['steps_remaining'] > 0:
704 |             new_state['steps_remaining'] -= 1
705 |         
706 |         # Get query type and query
707 |         query_type = state['query_type']
708 |         query = state['query']
709 |         
710 |         # Simulate step result
711 |         step_results = state['step_results'].copy()
712 |         step_results[action] = self._simulate_tool_result(action, query)
713 |         new_state['step_results'] = step_results
714 |         
715 |         # Estimate tool relevance based on learning or semantic similarity
716 |         tool_relevance = self._estimate_tool_relevance(action, query, query_type)
717 |         
718 |         # Check for sequence effects (tools that work well together)
719 |         sequence_bonus = 0.0
720 |         if len(new_actions) > 1:
721 |             prev_tool = new_actions[-2]
722 |             sequence_key = f"{prev_tool}->{action}"
723 |             if sequence_key in self.sequence_effectiveness:
724 |                 sequence_bonus = self.sequence_effectiveness[sequence_key] * 0.3  # 30% weight for sequence effects
725 |         
726 |         # Update quality based on relevance and sequence effects
727 |         current_quality = state['response_quality']
728 |         quality_improvement = tool_relevance + sequence_bonus
729 |         
730 |         # Add diminishing returns effect for additional tools
731 |         if len(new_actions) > 1:
732 |             diminishing_factor = 1.0 / len(new_actions)
733 |             quality_improvement *= diminishing_factor
734 |         
735 |         new_quality = min(1.0, current_quality + quality_improvement)
736 |         new_state['response_quality'] = new_quality
737 |         
738 |         return new_state
739 |     
740 |     def _simulate_tool_result(self, tool_name: str, query: str) -> Dict[str, Any]:
741 |         """
742 |         Simulate the result of using a tool for a query.
743 |         
744 |         Args:
745 |             tool_name: The name of the tool
746 |             query: The user query
747 |             
748 |         Returns:
749 |             A simulated result
750 |         """
751 |         # In a real implementation, this would be a more sophisticated simulation
752 |         return {
753 |             "tool": tool_name,
754 |             "success_probability": self._estimate_tool_relevance(tool_name, query),
755 |             "simulated": True
756 |         }
757 |     
758 |     def _estimate_tool_relevance(self, tool_name: str, query: str, query_type: str = None) -> float:
759 |         """
760 |         Estimate how relevant a tool is for a given query using history and semantics.
761 |         
762 |         Args:
763 |             tool_name: The name of the tool
764 |             query: The user query
765 |             query_type: Optional query type
766 |             
767 |         Returns:
768 |             A relevance score between 0.0 and 1.0
769 |         """
770 |         relevance_score = 0.0
771 |         
772 |         # If we have historical data for this query type
773 |         if self.use_learning and query_type and query_type in self.tool_history and \
774 |            tool_name in self.tool_history[query_type]:
775 |             
776 |             # Get historical success rate
777 |             history_score = self.tool_history[query_type][tool_name]['success_rate']
778 |             relevance_score += history_score * self.tool_history_weight
779 |         
780 |         # If we're using semantic similarity
781 |         if self.use_semantic_similarity and tool_name in self.tool_fingerprints:
782 |             fingerprint = self.tool_fingerprints[tool_name]
783 |             
784 |             # Calculate keyword overlap
785 |             query_keywords = set(word.lower() for word in query.split() if len(word) > 3)
786 |             keyword_overlap = len(query_keywords.intersection(fingerprint['keywords']))
787 |             
788 |             # Normalize by query keywords
789 |             if query_keywords:
790 |                 semantic_score = keyword_overlap / len(query_keywords)
791 |                 relevance_score += semantic_score * (1.0 - self.tool_history_weight)
792 |         
793 |         # Ensure we have a minimum score for exploration
794 |         if relevance_score < 0.1:
795 |             relevance_score = 0.1 + (random.random() * 0.1)  # Random boost between 0.1-0.2
796 |         
797 |         return relevance_score
798 |     
799 |     def _extract_plan_from_search(self) -> List[str]:
800 |         """
801 |         Extract a complete plan (tool sequence) from the search results.
802 |         
803 |         Returns:
804 |             A list of tool names representing the plan
805 |         """
806 |         # In a real implementation, this would extract the highest value path 
807 |         # from the search tree. For now, return None to indicate no plan extraction.
808 |         return None
809 |     
810 |     def update_tool_history(self, tool_name: str, query: str, success: bool, 
811 |                           execution_time: float, result: Any = None):
812 |         """
813 |         Update the tool history with the results of using a tool.
814 |         
815 |         Args:
816 |             tool_name: The name of the tool used
817 |             query: The query the tool was used for
818 |             success: Whether the tool was successful
819 |             execution_time: The execution time in seconds
820 |             result: Optional result of the tool execution
821 |         """
822 |         if not self.use_learning:
823 |             return
824 |             
825 |         # Get query type
826 |         query_type = self._analyze_query(query)
827 |         
828 |         # Initialize history entry if needed
829 |         if query_type not in self.tool_history:
830 |             self.tool_history[query_type] = {}
831 |             
832 |         if tool_name not in self.tool_history[query_type]:
833 |             self.tool_history[query_type][tool_name] = {
834 |                 'success_count': 0,
835 |                 'failure_count': 0,
836 |                 'total_time': 0.0,
837 |                 'success_rate': 0.0,
838 |                 'avg_time': 0.0,
839 |                 'examples': []
840 |             }
841 |         
842 |         # Update history
843 |         history = self.tool_history[query_type][tool_name]
844 |         
845 |         # Update counts
846 |         if success:
847 |             history['success_count'] += 1
848 |         else:
849 |             history['failure_count'] += 1
850 |             
851 |         # Update time
852 |         history['total_time'] += execution_time
853 |         
854 |         # Update success rate
855 |         total = history['success_count'] + history['failure_count']
856 |         history['success_rate'] = history['success_count'] / total if total > 0 else 0.0
857 |         
858 |         # Update average time
859 |         history['avg_time'] = history['total_time'] / total if total > 0 else 0.0
860 |         
861 |         # Add example (limit to last 5)
862 |         history['examples'].append({
863 |             'query': query,
864 |             'success': success,
865 |             'timestamp': self._get_timestamp()
866 |         })
867 |         if len(history['examples']) > 5:
868 |             history['examples'].pop(0)
869 |             
870 |         # Update tool fingerprint
871 |         if self.use_semantic_similarity and tool_name in self.tool_fingerprints:
872 |             if success:
873 |                 # Add query type to usage contexts
874 |                 self.tool_fingerprints[tool_name]['usage_contexts'].add(query_type)
875 |                 
876 |                 # Add query keywords to tool fingerprint (with decay)
877 |                 query_keywords = set(word.lower() for word in query.split() if len(word) > 3)
878 |                 current_keywords = self.tool_fingerprints[tool_name]['keywords']
879 |                 
880 |                 # Add new keywords with adaptation rate
881 |                 for keyword in query_keywords:
882 |                     if keyword not in current_keywords:
883 |                         if random.random() < self.adaptation_rate:
884 |                             current_keywords.add(keyword)
885 |     
886 |     def update_sequence_effectiveness(self, tool_sequence: List[str], success: bool, quality_score: float):
887 |         """
888 |         Update the effectiveness record for a sequence of tools.
889 |         
890 |         Args:
891 |             tool_sequence: The sequence of tools used
892 |             success: Whether the sequence was successful
893 |             quality_score: A quality score for the sequence
894 |         """
895 |         if not self.use_learning or len(tool_sequence) < 2:
896 |             return
897 |             
898 |         # Update pairwise effectiveness
899 |         for i in range(len(tool_sequence) - 1):
900 |             first_tool = tool_sequence[i]
901 |             second_tool = tool_sequence[i + 1]
902 |             sequence_key = f"{first_tool}->{second_tool}"
903 |             
904 |             if sequence_key not in self.sequence_effectiveness:
905 |                 self.sequence_effectiveness[sequence_key] = 0.5  # Initial neutral score
906 |                 
907 |             # Update score with decay
908 |             current_score = self.sequence_effectiveness[sequence_key]
909 |             if success:
910 |                 # Increase score with quality bonus
911 |                 new_score = current_score + self.adaptation_rate * quality_score
912 |             else:
913 |                 # Decrease score
914 |                 new_score = current_score - self.adaptation_rate
915 |                 
916 |             # Clamp between 0 and 1
917 |             self.sequence_effectiveness[sequence_key] = max(0.0, min(1.0, new_score))
```

--------------------------------------------------------------------------------
/claude_code/lib/ui/tool_visualizer.py:
--------------------------------------------------------------------------------

```python
   1 | #!/usr/bin/env python3
   2 | # claude_code/lib/ui/tool_visualizer.py
   3 | """Real-time tool execution visualization."""
   4 | 
   5 | import logging
   6 | import time
   7 | import json
   8 | from typing import Dict, List, Any, Optional
   9 | 
  10 | from rich.console import Console
  11 | from rich.panel import Panel
  12 | from rich.progress import Progress, BarColumn, TextColumn, TimeElapsedColumn
  13 | from rich.table import Table
  14 | from rich.box import ROUNDED
  15 | from rich.text import Text
  16 | from rich.live import Live
  17 | from rich.layout import Layout
  18 | from rich.syntax import Syntax
  19 | 
  20 | from ..tools.base import ToolResult
  21 | 
  22 | logger = logging.getLogger(__name__)
  23 | 
  24 | 
  25 | class ToolCallVisualizer:
  26 |     """Visualizes tool calls in real-time."""
  27 |     
  28 |     def __init__(self, console: Console):
  29 |         """Initialize the tool call visualizer.
  30 |         
  31 |         Args:
  32 |             console: Rich console instance
  33 |         """
  34 |         self.console = console
  35 |         self.active_calls: Dict[str, Dict[str, Any]] = {}
  36 |         self.completed_calls: List[Dict[str, Any]] = []
  37 |         self.layout = self._create_layout()
  38 |         self.live = Live(self.layout, console=console, refresh_per_second=4, auto_refresh=False)
  39 |         self.max_completed_calls = 5
  40 |         
  41 |         # Keep track of recent tool results for routines
  42 |         self.recent_tool_results: List[ToolResult] = []
  43 |         self.max_recent_results = 20  # Maximum number of recent results to track
  44 |         
  45 |     def _create_layout(self) -> Layout:
  46 |         """Create the layout for the tool call visualization.
  47 |         
  48 |         Returns:
  49 |             Layout object
  50 |         """
  51 |         layout = Layout()
  52 |         layout.split(
  53 |             Layout(name="active", size=3),
  54 |             Layout(name="completed", size=3)
  55 |         )
  56 |         return layout
  57 |     
  58 |     def _create_active_calls_panel(self) -> Panel:
  59 |         """Create a panel with active tool calls.
  60 |         
  61 |         Returns:
  62 |             Panel with active call information
  63 |         """
  64 |         if not self.active_calls:
  65 |             return Panel(
  66 |                 "No active tool calls",
  67 |                 title="[bold blue]Active Tool Calls[/bold blue]",
  68 |                 border_style="blue",
  69 |                 box=ROUNDED
  70 |             )
  71 |         
  72 |         # Create progress bars for each active call
  73 |         progress = Progress(
  74 |             TextColumn("[bold blue]{task.description}"),
  75 |             BarColumn(),
  76 |             TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
  77 |             TimeElapsedColumn(),
  78 |             expand=True,
  79 |             console=self.console
  80 |         )
  81 |         
  82 |         # Add tasks for each active call
  83 |         for call_id, call_info in self.active_calls.items():
  84 |             if "task_id" not in call_info:
  85 |                 # Create a new task for this call
  86 |                 description = f"{call_info['tool_name']} ({call_id[:6]}...)"
  87 |                 task_id = progress.add_task(description, total=100, completed=int(call_info["progress"] * 100))
  88 |                 call_info["task_id"] = task_id
  89 |             else:
  90 |                 # Update existing task
  91 |                 progress.update(call_info["task_id"], completed=int(call_info["progress"] * 100))
  92 |         
  93 |         # Create a table with parameter information
  94 |         table = Table(show_header=True, header_style="bold cyan", box=ROUNDED, expand=True)
  95 |         table.add_column("Tool")
  96 |         table.add_column("Parameters")
  97 |         
  98 |         for call_id, call_info in self.active_calls.items():
  99 |             # Format parameters nicely
 100 |             params = call_info.get("parameters", {})
 101 |             if params:
 102 |                 formatted_params = "\n".join([f"{k}: {self._format_value(v)}" for k, v in params.items()])
 103 |             else:
 104 |                 formatted_params = "None"
 105 |             
 106 |             table.add_row(call_info["tool_name"], formatted_params)
 107 |         
 108 |         return Panel(
 109 |             progress,
 110 |             title="[bold blue]Active Tool Calls[/bold blue]",
 111 |             border_style="blue",
 112 |             box=ROUNDED
 113 |         )
 114 |     
 115 |     def _create_completed_calls_panel(self) -> Panel:
 116 |         """Create a panel with completed tool calls.
 117 |         
 118 |         Returns:
 119 |             Panel with completed call information
 120 |         """
 121 |         if not self.completed_calls:
 122 |             return Panel(
 123 |                 "No completed tool calls",
 124 |                 title="[bold green]Recent Tool Results[/bold green]",
 125 |                 border_style="green",
 126 |                 box=ROUNDED
 127 |             )
 128 |         
 129 |         # Create a table for results
 130 |         table = Table(show_header=True, header_style="bold green", box=ROUNDED, expand=True)
 131 |         table.add_column("Tool")
 132 |         table.add_column("Status")
 133 |         table.add_column("Time")
 134 |         table.add_column("Result Preview")
 135 |         
 136 |         # Show only the most recent completed calls
 137 |         for call_info in self.completed_calls[-self.max_completed_calls:]:
 138 |             tool_name = call_info["tool_name"]
 139 |             status = call_info["status"]
 140 |             execution_time = f"{call_info['execution_time']:.2f}s"
 141 |             
 142 |             # Format result preview
 143 |             result = call_info.get("result", "")
 144 |             if result:
 145 |                 # Truncate and format result
 146 |                 preview = self._format_result_preview(result, tool_name)
 147 |             else:
 148 |                 preview = "No result"
 149 |             
 150 |             # Status with color
 151 |             status_text = Text(status)
 152 |             if status == "success":
 153 |                 status_text.stylize("bold green")
 154 |             else:
 155 |                 status_text.stylize("bold red")
 156 |             
 157 |             table.add_row(tool_name, status_text, execution_time, preview)
 158 |         
 159 |         return Panel(
 160 |             table,
 161 |             title="[bold green]Recent Tool Results[/bold green]",
 162 |             border_style="green",
 163 |             box=ROUNDED
 164 |         )
 165 |     
 166 |     def _format_value(self, value: Any) -> str:
 167 |         """Format a parameter value for display.
 168 |         
 169 |         Args:
 170 |             value: Parameter value
 171 |             
 172 |         Returns:
 173 |             Formatted string
 174 |         """
 175 |         if isinstance(value, (dict, list)):
 176 |             # Convert complex structures to JSON with indentation
 177 |             return json.dumps(value, indent=2)
 178 |         return str(value)
 179 |     
 180 |     def _format_result_preview(self, result: str, tool_name: str) -> str:
 181 |         """Format a result preview.
 182 |         
 183 |         Args:
 184 |             result: Result string
 185 |             tool_name: Name of the tool
 186 |             
 187 |         Returns:
 188 |             Formatted preview string
 189 |         """
 190 |         # Truncate result for preview
 191 |         if len(result) > 200:
 192 |             preview = result[:200] + "..."
 193 |         else:
 194 |             preview = result
 195 |         
 196 |         # Clean up newlines for display
 197 |         preview = preview.replace("\n", "\\n")
 198 |         
 199 |         return preview
 200 |     
 201 |     def start(self) -> None:
 202 |         """Start the visualization."""
 203 |         self.live.start()
 204 |         self.refresh()
 205 |     
 206 |     def stop(self) -> None:
 207 |         """Stop the visualization."""
 208 |         self.live.stop()
 209 |     
 210 |     def refresh(self) -> None:
 211 |         """Refresh the visualization."""
 212 |         # Update the layout with current information
 213 |         self.layout["active"].update(self._create_active_calls_panel())
 214 |         self.layout["completed"].update(self._create_completed_calls_panel())
 215 |         
 216 |         # Refresh the live display
 217 |         self.live.refresh()
 218 |     
 219 |     def add_tool_call(self, tool_call_id: str, tool_name: str, parameters: Dict[str, Any]) -> None:
 220 |         """Add a new tool call to visualize.
 221 |         
 222 |         Args:
 223 |             tool_call_id: ID of the tool call
 224 |             tool_name: Name of the tool
 225 |             parameters: Tool parameters
 226 |         """
 227 |         self.active_calls[tool_call_id] = {
 228 |             "tool_name": tool_name,
 229 |             "parameters": parameters,
 230 |             "start_time": time.time(),
 231 |             "progress": 0.0
 232 |         }
 233 |         self.refresh()
 234 |     
 235 |     def update_progress(self, tool_call_id: str, progress: float) -> None:
 236 |         """Update the progress of a tool call.
 237 |         
 238 |         Args:
 239 |             tool_call_id: ID of the tool call
 240 |             progress: Progress value (0-1)
 241 |         """
 242 |         if tool_call_id in self.active_calls:
 243 |             self.active_calls[tool_call_id]["progress"] = progress
 244 |             self.refresh()
 245 |     
 246 |     def complete_tool_call(self, tool_call_id: str, result: ToolResult) -> None:
 247 |         """Mark a tool call as complete.
 248 |         
 249 |         Args:
 250 |             tool_call_id: ID of the tool call
 251 |             result: Tool execution result
 252 |         """
 253 |         if tool_call_id in self.active_calls:
 254 |             call_info = self.active_calls[tool_call_id].copy()
 255 |             
 256 |             # Add result information
 257 |             call_info["result"] = result.result
 258 |             call_info["status"] = result.status
 259 |             call_info["execution_time"] = result.execution_time
 260 |             call_info["end_time"] = time.time()
 261 |             
 262 |             # Add to completed calls
 263 |             self.completed_calls.append(call_info)
 264 |             
 265 |             # Trim completed calls if needed
 266 |             if len(self.completed_calls) > self.max_completed_calls * 2:
 267 |                 self.completed_calls = self.completed_calls[-self.max_completed_calls:]
 268 |             
 269 |             # Remove from active calls
 270 |             del self.active_calls[tool_call_id]
 271 |             
 272 |             # Store in recent tool results for routines
 273 |             if result.status == "success":
 274 |                 self.recent_tool_results.append(result)
 275 |                 # Keep only the most recent results
 276 |                 if len(self.recent_tool_results) > self.max_recent_results:
 277 |                     self.recent_tool_results.pop(0)
 278 |             
 279 |             self.refresh()
 280 |     
 281 |     def show_result_detail(self, result: ToolResult) -> None:
 282 |         """Display detailed result information.
 283 |         
 284 |         Args:
 285 |             result: Tool execution result
 286 |         """
 287 |         # Detect if result might be code
 288 |         content = result.result
 289 |         if content.startswith(("def ", "class ", "import ", "from ")) or "```" in content:
 290 |             # Try to extract code blocks
 291 |             if "```" in content:
 292 |                 blocks = content.split("```")
 293 |                 # Find a code block with a language specifier
 294 |                 for i in range(1, len(blocks), 2):
 295 |                     if i < len(blocks):
 296 |                         lang = blocks[i].split("\n")[0].strip()
 297 |                         code = "\n".join(blocks[i].split("\n")[1:])
 298 |                         if lang and code:
 299 |                             # Attempt to display as syntax-highlighted code
 300 |                             try:
 301 |                                 syntax = Syntax(code, lang, theme="monokai", line_numbers=True)
 302 |                                 self.console.print(Panel(syntax, title=f"[bold]Result: {result.name}[/bold]"))
 303 |                                 return
 304 |                             except Exception:
 305 |                                 pass
 306 |             
 307 |             # If we can't extract a code block, try to detect language
 308 |             for lang in ["python", "javascript", "bash", "json"]:
 309 |                 try:
 310 |                     syntax = Syntax(content, lang, theme="monokai", line_numbers=True)
 311 |                     self.console.print(Panel(syntax, title=f"[bold]Result: {result.name}[/bold]"))
 312 |                     return
 313 |                 except Exception:
 314 |                     pass
 315 |         
 316 |         # Just print as regular text if not code or if highlighting failed
 317 |         self.console.print(Panel(content, title=f"[bold]Result: {result.name}[/bold]"))
 318 | 
 319 | 
 320 | class MCTSVisualizer:
 321 |     """Visualizes the Monte Carlo Tree Search process in real-time with enhanced intelligence."""
 322 |     
 323 |     def __init__(self, console: Console):
 324 |         """Initialize the MCTS visualizer.
 325 |         
 326 |         Args:
 327 |             console: Rich console instance
 328 |         """
 329 |         self.console = console
 330 |         self.root_node = None
 331 |         self.current_iteration = 0
 332 |         self.max_iterations = 0
 333 |         self.best_action = None
 334 |         self.active_simulation = None
 335 |         self.simulation_path = []
 336 |         self.layout = self._create_layout()
 337 |         self.live = Live(self.layout, console=console, refresh_per_second=10, auto_refresh=False)
 338 |         
 339 |         # Intelligence enhancement - track history
 340 |         self.action_history = {}  # Track action performance over time
 341 |         self.visit_distribution = {}  # Track how visits are distributed
 342 |         self.exploration_patterns = []  # Track exploration patterns
 343 |         self.quality_metrics = {"search_efficiency": 0.0, "exploration_balance": 0.0}
 344 |         self.auto_improvement_enabled = True
 345 |         
 346 |     def _create_layout(self) -> Layout:
 347 |         """Create the layout for MCTS visualization.
 348 |         
 349 |         Returns:
 350 |             Layout object
 351 |         """
 352 |         layout = Layout()
 353 |         
 354 |         # Create the main sections with more detailed visualization
 355 |         layout.split(
 356 |             Layout(name="header", size=3),
 357 |             Layout(name="main"),
 358 |             Layout(name="intelligence", size=7),  # New section for intelligence metrics
 359 |             Layout(name="stats", size=5)
 360 |         )
 361 |         
 362 |         # Split the main section into tree, simulation and action insights
 363 |         layout["main"].split_row(
 364 |             Layout(name="tree", ratio=2),
 365 |             Layout(name="simulation", ratio=1),
 366 |             Layout(name="insights", ratio=1)  # New section for action insights
 367 |         )
 368 |         
 369 |         return layout
 370 |         
 371 |     def set_search_parameters(self, root_node: Any, max_iterations: int, additional_params: Dict[str, Any] = None) -> None:
 372 |         """Set the search parameters with enhanced intelligence options.
 373 |         
 374 |         Args:
 375 |             root_node: The root node of the search tree
 376 |             max_iterations: Maximum number of iterations
 377 |             additional_params: Additional parameters for intelligent search
 378 |         """
 379 |         self.root_node = root_node
 380 |         self.max_iterations = max_iterations
 381 |         self.current_iteration = 0
 382 |         
 383 |         # Initialize intelligence tracking
 384 |         self.action_history = {}
 385 |         self.visit_distribution = {}
 386 |         self.exploration_patterns = []
 387 |         
 388 |         # Set additional intelligence parameters
 389 |         if additional_params:
 390 |             self.auto_improvement_enabled = additional_params.get('auto_improvement', True)
 391 |             
 392 |             # Apply any initial intelligence strategies
 393 |             if additional_params.get('initial_action_bias'):
 394 |                 self.action_history = additional_params['initial_action_bias']
 395 |                 
 396 |         self.refresh()
 397 |         
 398 |     def update_iteration(self, iteration: int, selected_node: Any = None, 
 399 |                         expanded_node: Any = None, simulation_path: List[Any] = None,
 400 |                         simulation_result: float = None, best_action: Any = None,
 401 |                         node_values: Dict[str, float] = None) -> None:
 402 |         """Update the current iteration status with enhanced tracking.
 403 |         
 404 |         Args:
 405 |             iteration: Current iteration number
 406 |             selected_node: Node selected in this iteration
 407 |             expanded_node: Node expanded in this iteration
 408 |             simulation_path: Path of the simulation
 409 |             simulation_result: Result of the simulation
 410 |             best_action: Current best action
 411 |             node_values: Values of important nodes in the search (for visualization)
 412 |         """
 413 |         self.current_iteration = iteration
 414 |         self.selected_node = selected_node
 415 |         self.expanded_node = expanded_node
 416 |         self.simulation_path = simulation_path or []
 417 |         self.simulation_result = simulation_result
 418 |         
 419 |         if best_action is not None:
 420 |             self.best_action = best_action
 421 |             
 422 |         # Intelligence tracking - update action history
 423 |         if self.simulation_path and simulation_result is not None:
 424 |             for _, action in self.simulation_path:
 425 |                 if action is not None:
 426 |                     action_str = str(action)
 427 |                     if action_str not in self.action_history:
 428 |                         self.action_history[action_str] = {
 429 |                             "visits": 0,
 430 |                             "total_value": 0.0,
 431 |                             "iterations": []
 432 |                         }
 433 |                     
 434 |                     self.action_history[action_str]["visits"] += 1
 435 |                     self.action_history[action_str]["total_value"] += simulation_result
 436 |                     self.action_history[action_str]["iterations"].append(iteration)
 437 |         
 438 |         # Update exploration pattern
 439 |         if selected_node:
 440 |             # Record exploration choice
 441 |             self.exploration_patterns.append({
 442 |                 "iteration": iteration,
 443 |                 "node_depth": self._get_node_depth(selected_node),
 444 |                 "node_breadth": len(getattr(selected_node, "children", {})),
 445 |                 "value_estimate": getattr(selected_node, "value", 0) / max(1, getattr(selected_node, "visits", 1))
 446 |             })
 447 |             
 448 |         # Update visit distribution
 449 |         if self.root_node and hasattr(self.root_node, "children"):
 450 |             self._update_visit_distribution()
 451 |             
 452 |         # Update quality metrics
 453 |         self._update_quality_metrics()
 454 |             
 455 |         self.refresh()
 456 |         
 457 |     def start(self) -> None:
 458 |         """Start the visualization."""
 459 |         self.live.start()
 460 |         self.refresh()
 461 |         
 462 |     def stop(self) -> None:
 463 |         """Stop the visualization."""
 464 |         self.live.stop()
 465 |         
 466 |     def refresh(self) -> None:
 467 |         """Refresh the visualization."""
 468 |         # Update header
 469 |         header_content = f"[bold blue]Enhanced Monte Carlo Tree Search - Iteration {self.current_iteration}/{self.max_iterations}[/bold blue]"
 470 |         if self.best_action:
 471 |             header_content += f" | Best Action: {self.best_action}"
 472 |             
 473 |         intelligence_status = "[green]Enabled[/green]" if self.auto_improvement_enabled else "[yellow]Disabled[/yellow]"
 474 |         header_content += f" | Intelligent Search: {intelligence_status}"
 475 |             
 476 |         self.layout["header"].update(Panel(header_content, border_style="blue"))
 477 |         
 478 |         # Update tree visualization
 479 |         self.layout["tree"].update(self._create_tree_panel())
 480 |         
 481 |         # Update simulation visualization
 482 |         self.layout["simulation"].update(self._create_simulation_panel())
 483 |         
 484 |         # Update action insights panel
 485 |         self.layout["insights"].update(self._create_insights_panel())
 486 |         
 487 |         # Update intelligence metrics
 488 |         self.layout["intelligence"].update(self._create_intelligence_panel())
 489 |         
 490 |         # Update stats
 491 |         self.layout["stats"].update(self._create_stats_panel())
 492 |         
 493 |         # Refresh the live display
 494 |         self.live.refresh()
 495 |         
 496 |     def _create_tree_panel(self) -> Panel:
 497 |         """Create a panel showing the current state of the search tree.
 498 |         
 499 |         Returns:
 500 |             Panel with tree visualization
 501 |         """
 502 |         if not self.root_node:
 503 |             return Panel("No search tree initialized", title="[bold]Search Tree[/bold]")
 504 |             
 505 |         # Create a table to show the tree structure
 506 |         from rich.tree import Tree
 507 |         from rich import box
 508 |         
 509 |         tree = Tree("🔍 Root Node", guide_style="bold blue")
 510 |         
 511 |         # Limit the depth and breadth for display
 512 |         max_depth = 3
 513 |         max_children = 5
 514 |         
 515 |         def add_node(node, tree_node, depth=0, path=None):
 516 |             if depth >= max_depth or not node or not hasattr(node, "children"):
 517 |                 return
 518 |                 
 519 |             if path is None:
 520 |                 path = []
 521 |                 
 522 |             # Add children nodes
 523 |             children = list(node.children.items())
 524 |             if not children:
 525 |                 return
 526 |                 
 527 |             # Sort children by a combination of visits and value
 528 |             def node_score(node_pair):
 529 |                 child_node = node_pair[1]
 530 |                 visits = getattr(child_node, "visits", 0)
 531 |                 value = getattr(child_node, "value", 0)
 532 |                 
 533 |                 # Combine visits and value for scoring
 534 |                 if visits > 0:
 535 |                     # Use UCB-style formula for ranking
 536 |                     exploitation = value / visits
 537 |                     exploration = (2 * 0.5 * (math.log(node.visits) / visits)) if node.visits > 0 and visits > 0 else 0
 538 |                     return exploitation + exploration
 539 |                 return 0
 540 |             
 541 |             # Sort by this smarter formula
 542 |             children.sort(key=node_score, reverse=True)
 543 |             children = children[:max_children]
 544 |             
 545 |             for action, child in children:
 546 |                 # Format node information
 547 |                 visits = getattr(child, "visits", 0)
 548 |                 value = getattr(child, "value", 0)
 549 |                 
 550 |                 # Highlight the node with more sophisticated coloring
 551 |                 style = ""
 552 |                 if child == self.selected_node:
 553 |                     style = "bold yellow"
 554 |                 elif child == self.expanded_node:
 555 |                     style = "bold green"
 556 |                 else:
 557 |                     # Color based on value
 558 |                     if visits > 0:
 559 |                         avg_value = value / visits
 560 |                         if avg_value > 0.7:
 561 |                             style = "green"
 562 |                         elif avg_value > 0.4:
 563 |                             style = "blue"
 564 |                         elif avg_value > 0.2:
 565 |                             style = "yellow"
 566 |                         else:
 567 |                             style = "red"
 568 |                 
 569 |                 # Create the node label with enhanced information
 570 |                 current_path = path + [action]
 571 |                 
 572 |                 if visits > 0:
 573 |                     avg_value = value / visits
 574 |                     confidence = min(1.0, math.sqrt(visits) / 5) * 100  # Simple confidence estimate
 575 |                     label = f"[{style}]{action}: (Visits: {visits}, Value: {avg_value:.3f}, Conf: {confidence:.0f}%)[/{style}]"
 576 |                 else:
 577 |                     label = f"[{style}]{action}: (New)[/{style}]"
 578 |                 
 579 |                 # Add the child node to the tree
 580 |                 child_tree = tree_node.add(label)
 581 |                 
 582 |                 # Recursively add its children
 583 |                 add_node(child, child_tree, depth + 1, current_path)
 584 |         
 585 |         # Start building the tree from the root
 586 |         if hasattr(self.root_node, "children"):
 587 |             # Add math import for node scoring
 588 |             import math
 589 |             add_node(self.root_node, tree)
 590 |             
 591 |         return Panel(tree, title="[bold]Search Tree[/bold]", border_style="blue")
 592 |     
 593 |     def _create_simulation_panel(self) -> Panel:
 594 |         """Create a panel showing the current simulation with enhanced analytics.
 595 |         
 596 |         Returns:
 597 |             Panel with simulation visualization
 598 |         """
 599 |         if not self.simulation_path:
 600 |             return Panel("No active simulation", title="[bold]Current Simulation[/bold]")
 601 |             
 602 |         # Create a list of simulation steps
 603 |         from rich.table import Table
 604 |         
 605 |         table = Table(box=None, expand=True)
 606 |         table.add_column("Step")
 607 |         table.add_column("Action")
 608 |         table.add_column("Expected Value")  # New column
 609 |         
 610 |         for i, (state, action) in enumerate(self.simulation_path):
 611 |             # Get expected value for this action
 612 |             action_str = str(action) if action is not None else "None"
 613 |             expected_value = "N/A"
 614 |             
 615 |             if action_str in self.action_history:
 616 |                 history = self.action_history[action_str]
 617 |                 if history["visits"] > 0:
 618 |                     expected_value = f"{history['total_value'] / history['visits']:.3f}"
 619 |             
 620 |             table.add_row(f"Step {i+1}", f"{action}", expected_value)
 621 |             
 622 |         if self.simulation_result is not None:
 623 |             # Add path quality metric
 624 |             path_quality = "Low"
 625 |             if self.simulation_result > 0.7:
 626 |                 path_quality = "[bold green]High[/bold green]"
 627 |             elif self.simulation_result > 0.4:
 628 |                 path_quality = "[yellow]Medium[/yellow]"
 629 |             else:
 630 |                 path_quality = "[red]Low[/red]"
 631 |                 
 632 |             table.add_row("Result", 
 633 |                         f"[bold green]{self.simulation_result:.3f}[/bold green]", 
 634 |                         f"Path Quality: {path_quality}")
 635 |             
 636 |         return Panel(table, title="[bold]Current Simulation[/bold]", border_style="green")
 637 |     
 638 |     def _create_insights_panel(self) -> Panel:
 639 |         """Create a panel showing action insights from learned patterns.
 640 |         
 641 |         Returns:
 642 |             Panel with action insights
 643 |         """
 644 |         from rich.table import Table
 645 |         
 646 |         if not self.action_history:
 647 |             return Panel("No action insights available yet", title="[bold]Action Insights[/bold]")
 648 |             
 649 |         # Get top performing actions
 650 |         top_actions = []
 651 |         for action, data in self.action_history.items():
 652 |             if data["visits"] >= 3:  # Only consider actions with enough samples
 653 |                 avg_value = data["total_value"] / data["visits"]
 654 |                 top_actions.append((action, avg_value, data["visits"]))
 655 |                 
 656 |         # Sort by value and take top 5
 657 |         top_actions.sort(key=lambda x: x[1], reverse=True)
 658 |         top_actions = top_actions[:5]
 659 |         
 660 |         # Create insights table
 661 |         table = Table(box=None, expand=True)
 662 |         table.add_column("Action")
 663 |         table.add_column("Avg Value")
 664 |         table.add_column("Visits")
 665 |         table.add_column("Trend")
 666 |         
 667 |         for action, avg_value, visits in top_actions:
 668 |             # Generate trend indicator based on recent performance
 669 |             trend = "→"
 670 |             history = self.action_history[action]["iterations"]
 671 |             if len(history) >= 5:
 672 |                 recent = set(history[-3:])  # Last 3 iterations
 673 |                 if self.current_iteration - max(recent) <= 5:
 674 |                     trend = "↑"  # Recently used
 675 |                 elif self.current_iteration - max(recent) >= 10:
 676 |                     trend = "↓"  # Not used recently
 677 |             
 678 |             # Color code based on value
 679 |             if avg_value > 0.7:
 680 |                 value_str = f"[green]{avg_value:.3f}[/green]"
 681 |             elif avg_value > 0.4:
 682 |                 value_str = f"[blue]{avg_value:.3f}[/blue]"
 683 |             else:
 684 |                 value_str = f"[yellow]{avg_value:.3f}[/yellow]"
 685 |                 
 686 |             table.add_row(str(action), value_str, str(visits), trend)
 687 |             
 688 |         return Panel(table, title="[bold]Action Insights[/bold]", border_style="cyan")
 689 |     
 690 |     def _create_intelligence_panel(self) -> Panel:
 691 |         """Create a panel showing intelligence metrics and learning patterns.
 692 |         
 693 |         Returns:
 694 |             Panel with intelligence visualization
 695 |         """
 696 |         from rich.table import Table
 697 |         from rich.columns import Columns
 698 |         
 699 |         # Create metrics table
 700 |         metrics_table = Table(box=None, expand=True)
 701 |         metrics_table.add_column("Metric")
 702 |         metrics_table.add_column("Value")
 703 |         
 704 |         # Add search quality metrics
 705 |         for metric, value in self.quality_metrics.items():
 706 |             formatted_name = metric.replace("_", " ").title()
 707 |             # Color based on value
 708 |             if value > 0.7:
 709 |                 value_str = f"[green]{value:.2f}[/green]"
 710 |             elif value > 0.4:
 711 |                 value_str = f"[blue]{value:.2f}[/blue]"
 712 |             else:
 713 |                 value_str = f"[yellow]{value:.2f}[/yellow]"
 714 |                 
 715 |             metrics_table.add_row(formatted_name, value_str)
 716 |             
 717 |         # Create exploration table
 718 |         exploration_table = Table(box=None, expand=True)
 719 |         exploration_table.add_column("Pattern")
 720 |         exploration_table.add_column("Value")
 721 |         
 722 |         # Add exploration patterns
 723 |         if self.exploration_patterns:
 724 |             # Average depth of exploration
 725 |             avg_depth = sum(p["node_depth"] for p in self.exploration_patterns) / len(self.exploration_patterns)
 726 |             exploration_table.add_row("Avg Exploration Depth", f"{avg_depth:.2f}")
 727 |             
 728 |             # Depth trend (increasing or decreasing)
 729 |             if len(self.exploration_patterns) >= 5:
 730 |                 recent_avg = sum(p["node_depth"] for p in self.exploration_patterns[-5:]) / 5
 731 |                 earlier_avg = sum(p["node_depth"] for p in self.exploration_patterns[:-5]) / max(1, len(self.exploration_patterns) - 5)
 732 |                 
 733 |                 if recent_avg > earlier_avg * 1.2:
 734 |                     trend = "[green]Deepening[/green]"
 735 |                 elif recent_avg < earlier_avg * 0.8:
 736 |                     trend = "[yellow]Shallowing[/yellow]"
 737 |                 else:
 738 |                     trend = "[blue]Stable[/blue]"
 739 |                     
 740 |                 exploration_table.add_row("Depth Trend", trend)
 741 |                 
 742 |             # Exploration-exploitation balance
 743 |             if len(self.exploration_patterns) >= 3:
 744 |                 # Higher values = more exploitation of known good paths
 745 |                 exploitation_ratio = sum(1 for p in self.exploration_patterns[-10:] 
 746 |                                      if p["value_estimate"] > 0.5) / min(10, len(self.exploration_patterns))
 747 |                 
 748 |                 if exploitation_ratio > 0.7:
 749 |                     balance = "[yellow]Heavy Exploitation[/yellow]"
 750 |                 elif exploitation_ratio < 0.3:
 751 |                     balance = "[yellow]Heavy Exploration[/yellow]"
 752 |                 else:
 753 |                     balance = "[green]Balanced[/green]"
 754 |                     
 755 |                 exploration_table.add_row("Search Balance", balance)
 756 |                 
 757 |         # Combine tables into columns
 758 |         columns = Columns([metrics_table, exploration_table])
 759 |         
 760 |         return Panel(columns, title="[bold]Intelligence Metrics[/bold]", border_style="magenta")
 761 |     
 762 |     def _create_stats_panel(self) -> Panel:
 763 |         """Create a panel showing search statistics with enhanced metrics.
 764 |         
 765 |         Returns:
 766 |             Panel with statistics
 767 |         """
 768 |         if not self.root_node:
 769 |             return Panel("No statistics available", title="[bold]Search Statistics[/bold]")
 770 |             
 771 |         # Collect statistics
 772 |         total_nodes = 0
 773 |         max_depth = 0
 774 |         total_visits = getattr(self.root_node, "visits", 0)
 775 |         avg_branching = 0
 776 |         
 777 |         def count_nodes(node, depth=0):
 778 |             nonlocal total_nodes, max_depth, avg_branching
 779 |             if not node or not hasattr(node, "children"):
 780 |                 return
 781 |                 
 782 |             total_nodes += 1
 783 |             max_depth = max(max_depth, depth)
 784 |             
 785 |             # Count children for branching factor
 786 |             num_children = len(node.children)
 787 |             if num_children > 0:
 788 |                 avg_branching += num_children
 789 |                 
 790 |             for child in node.children.values():
 791 |                 count_nodes(child, depth + 1)
 792 |                 
 793 |         count_nodes(self.root_node)
 794 |         
 795 |         # Calculate average branching factor
 796 |         if total_nodes > 1:  # Root node doesn't count for avg branching
 797 |             avg_branching /= (total_nodes - 1) 
 798 |         
 799 |         # Create a table of statistics
 800 |         from rich.table import Table
 801 |         
 802 |         table = Table(box=None, expand=True)
 803 |         table.add_column("Metric")
 804 |         table.add_column("Value")
 805 |         
 806 |         table.add_row("Total Nodes", str(total_nodes))
 807 |         table.add_row("Max Depth", str(max_depth))
 808 |         table.add_row("Total Visits", str(total_visits))
 809 |         table.add_row("Avg Branching", f"{avg_branching:.2f}")
 810 |         table.add_row("Progress", f"{self.current_iteration / self.max_iterations:.1%}")
 811 |         
 812 |         # Efficiency estimate (higher is better)
 813 |         if total_visits > 0:
 814 |             visit_efficiency = total_nodes / total_visits
 815 |             efficiency_str = f"{visit_efficiency:.2f}"
 816 |             table.add_row("Search Efficiency", efficiency_str)
 817 |         
 818 |         return Panel(table, title="[bold]Search Statistics[/bold]", border_style="magenta")
 819 |     
 820 |     def _get_node_depth(self, node):
 821 |         """Calculate the depth of a node in the tree."""
 822 |         depth = 0
 823 |         current = node
 824 |         while getattr(current, "parent", None) is not None:
 825 |             depth += 1
 826 |             current = current.parent
 827 |         return depth
 828 |     
 829 |     def _update_visit_distribution(self):
 830 |         """Update the distribution of visits across the tree."""
 831 |         levels = {}
 832 |         
 833 |         def count_visits_by_level(node, depth=0):
 834 |             if not node or not hasattr(node, "children"):
 835 |                 return
 836 |                 
 837 |             # Initialize level if not present
 838 |             if depth not in levels:
 839 |                 levels[depth] = {"visits": 0, "nodes": 0}
 840 |                 
 841 |             # Update level stats
 842 |             levels[depth]["visits"] += getattr(node, "visits", 0)
 843 |             levels[depth]["nodes"] += 1
 844 |             
 845 |             # Process children
 846 |             for child in node.children.values():
 847 |                 count_visits_by_level(child, depth + 1)
 848 |                 
 849 |         # Start counting from root
 850 |         count_visits_by_level(self.root_node)
 851 |         
 852 |         # Update visit distribution
 853 |         self.visit_distribution = levels
 854 |     
 855 |     def _update_quality_metrics(self):
 856 |         """Update quality metrics for the search process."""
 857 |         # Search efficiency - ratio of valuable nodes to total nodes
 858 |         # Higher values indicate more efficient search
 859 |         if self.visit_distribution:
 860 |             useful_visits = sum(level["visits"] for depth, level in self.visit_distribution.items() 
 861 |                                if depth > 0)  # Exclude root
 862 |             total_visits = sum(level["visits"] for level in self.visit_distribution.values())
 863 |             
 864 |             if total_visits > 0:
 865 |                 self.quality_metrics["search_efficiency"] = useful_visits / total_visits
 866 |             
 867 |         # Exploration balance - how well the algorithm balances exploration vs exploitation
 868 |         if self.exploration_patterns:
 869 |             # Calculate variance in exploration depth
 870 |             depths = [p["node_depth"] for p in self.exploration_patterns[-20:]]  # Last 20 iterations
 871 |             if depths:
 872 |                 import statistics
 873 |                 try:
 874 |                     depth_variance = statistics.variance(depths) if len(depths) > 1 else 0
 875 |                     # Normalize to 0-1 range (higher variance = more balanced exploration)
 876 |                     normalized_variance = min(1.0, depth_variance / 5.0)  # Assume variance > 5 is high
 877 |                     self.quality_metrics["exploration_balance"] = normalized_variance
 878 |                 except statistics.StatisticsError:
 879 |                     pass
 880 | 
 881 | 
 882 | class ParallelExecutionVisualizer:
 883 |     """Visualizes parallel execution of tool calls in real-time."""
 884 |     
 885 |     def __init__(self, console: Console):
 886 |         """Initialize the parallel execution visualizer.
 887 |         
 888 |         Args:
 889 |             console: Rich console instance
 890 |         """
 891 |         self.console = console
 892 |         self.active_executions = {}
 893 |         self.completed_executions = []
 894 |         self.layout = self._create_layout()
 895 |         self.live = Live(self.layout, console=console, refresh_per_second=10, auto_refresh=False)
 896 |         
 897 |     def _create_layout(self) -> Layout:
 898 |         """Create the layout for parallel execution visualization.
 899 |         
 900 |         Returns:
 901 |             Layout object
 902 |         """
 903 |         layout = Layout()
 904 |         
 905 |         # Create the main sections
 906 |         layout.split(
 907 |             Layout(name="header", size=3),
 908 |             Layout(name="executions"),
 909 |             Layout(name="metrics", size=5)
 910 |         )
 911 |         
 912 |         return layout
 913 |         
 914 |     def add_execution(self, execution_id: str, tool_name: str, parameters: Dict[str, Any]) -> None:
 915 |         """Add a new execution to visualize.
 916 |         
 917 |         Args:
 918 |             execution_id: Unique ID for the execution
 919 |             tool_name: Name of the tool being executed
 920 |             parameters: Parameters for the execution
 921 |         """
 922 |         self.active_executions[execution_id] = {
 923 |             "tool_name": tool_name,
 924 |             "parameters": parameters,
 925 |             "start_time": time.time(),
 926 |             "progress": 0.0,
 927 |             "status": "running"
 928 |         }
 929 |         self.refresh()
 930 |         
 931 |     def update_progress(self, execution_id: str, progress: float) -> None:
 932 |         """Update the progress of an execution.
 933 |         
 934 |         Args:
 935 |             execution_id: ID of the execution
 936 |             progress: Progress value (0-1)
 937 |         """
 938 |         if execution_id in self.active_executions:
 939 |             self.active_executions[execution_id]["progress"] = progress
 940 |             self.refresh()
 941 |             
 942 |     def complete_execution(self, execution_id: str, result: Any, status: str = "success") -> None:
 943 |         """Mark an execution as complete.
 944 |         
 945 |         Args:
 946 |             execution_id: ID of the execution
 947 |             result: Result of the execution
 948 |             status: Status of completion
 949 |         """
 950 |         if execution_id in self.active_executions:
 951 |             execution = self.active_executions[execution_id].copy()
 952 |             execution["end_time"] = time.time()
 953 |             execution["duration"] = execution["end_time"] - execution["start_time"]
 954 |             execution["result"] = result
 955 |             execution["status"] = status
 956 |             
 957 |             # Move to completed executions
 958 |             self.completed_executions.append(execution)
 959 |             del self.active_executions[execution_id]
 960 |             
 961 |             # Limit completed executions list
 962 |             if len(self.completed_executions) > 20:
 963 |                 self.completed_executions = self.completed_executions[-20:]
 964 |                 
 965 |             self.refresh()
 966 |             
 967 |     def start(self) -> None:
 968 |         """Start the visualization."""
 969 |         self.live.start()
 970 |         self.refresh()
 971 |         
 972 |     def stop(self) -> None:
 973 |         """Stop the visualization."""
 974 |         self.live.stop()
 975 |         
 976 |     def refresh(self) -> None:
 977 |         """Refresh the visualization."""
 978 |         # Update header
 979 |         header_content = f"[bold blue]Parallel Execution Monitor[/bold blue] | Active: {len(self.active_executions)} | Completed: {len(self.completed_executions)}"
 980 |         self.layout["header"].update(Panel(header_content, border_style="blue"))
 981 |         
 982 |         # Update executions visualization
 983 |         self.layout["executions"].update(self._create_executions_panel())
 984 |         
 985 |         # Update metrics
 986 |         self.layout["metrics"].update(self._create_metrics_panel())
 987 |         
 988 |         # Refresh the live display
 989 |         self.live.refresh()
 990 |         
 991 |     def _create_executions_panel(self) -> Panel:
 992 |         """Create a panel showing active and recent executions.
 993 |         
 994 |         Returns:
 995 |             Panel with executions visualization
 996 |         """
 997 |         from rich.table import Table
 998 |         from rich.progress import BarColumn, Progress, TextColumn
 999 |         
1000 |         # Create progress bars for active executions
1001 |         progress_group = Table.grid(expand=True)
1002 |         
1003 |         if self.active_executions:
1004 |             # Create a progress group
1005 |             progress = Progress(
1006 |                 TextColumn("[bold blue]{task.description}"),
1007 |                 BarColumn(bar_width=None),
1008 |                 TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
1009 |                 TextColumn("| Elapsed: {task.elapsed:.2f}s"),
1010 |                 expand=True
1011 |             )
1012 |             
1013 |             # Add tasks for each active execution
1014 |             for exec_id, execution in self.active_executions.items():
1015 |                 tool_name = execution["tool_name"]
1016 |                 description = f"{tool_name} ({exec_id[:8]}...)"
1017 |                 task_id = progress.add_task(description, total=100, completed=int(execution["progress"] * 100))
1018 |                 
1019 |             progress_group.add_row(progress)
1020 |         else:
1021 |             progress_group.add_row("[italic]No active executions[/italic]")
1022 |             
1023 |         # Create a table for completed executions
1024 |         completed_table = Table(show_header=True, header_style="bold blue", expand=True)
1025 |         completed_table.add_column("Tool")
1026 |         completed_table.add_column("Duration")
1027 |         completed_table.add_column("Status")
1028 |         completed_table.add_column("Result Preview")
1029 |         
1030 |         if self.completed_executions:
1031 |             # Most recent first
1032 |             for execution in reversed(self.completed_executions[-10:]):
1033 |                 tool_name = execution["tool_name"]
1034 |                 duration = f"{execution['duration']:.2f}s"
1035 |                 status = execution["status"]
1036 |                 
1037 |                 # Format result preview
1038 |                 result = str(execution.get("result", ""))
1039 |                 preview = result[:50] + "..." if len(result) > 50 else result
1040 |                 
1041 |                 # Add status with color
1042 |                 status_text = f"[green]{status}[/green]" if status == "success" else f"[red]{status}[/red]"
1043 |                 
1044 |                 completed_table.add_row(tool_name, duration, status_text, preview)
1045 |         else:
1046 |             completed_table.add_row("[italic]No completed executions[/italic]", "", "", "")
1047 |             
1048 |         # Combine both into a layout
1049 |         layout = Layout()
1050 |         layout.split(
1051 |             Layout(name="active", size=len(self.active_executions) * 2 + 3 if self.active_executions else 3),
1052 |             Layout(name="completed")
1053 |         )
1054 |         layout["active"].update(Panel(progress_group, title="[bold]Active Executions[/bold]", border_style="blue"))
1055 |         layout["completed"].update(Panel(completed_table, title="[bold]Recent Completions[/bold]", border_style="green"))
1056 |         
1057 |         return layout
1058 |     
1059 |     def _create_metrics_panel(self) -> Panel:
1060 |         """Create a panel showing execution metrics.
1061 |         
1062 |         Returns:
1063 |             Panel with metrics visualization
1064 |         """
1065 |         from rich.table import Table
1066 |         
1067 |         # Calculate metrics
1068 |         total_executions = len(self.completed_executions)
1069 |         successful = sum(1 for e in self.completed_executions if e["status"] == "success")
1070 |         failed = total_executions - successful
1071 |         
1072 |         if total_executions > 0:
1073 |             success_rate = successful / total_executions
1074 |             avg_duration = sum(e["duration"] for e in self.completed_executions) / total_executions
1075 |         else:
1076 |             success_rate = 0
1077 |             avg_duration = 0
1078 |             
1079 |         # Create metrics table
1080 |         table = Table(box=None, expand=True)
1081 |         table.add_column("Metric")
1082 |         table.add_column("Value")
1083 |         
1084 |         table.add_row("Total Executions", str(total_executions))
1085 |         table.add_row("Success Rate", f"{success_rate:.1%}")
1086 |         table.add_row("Average Duration", f"{avg_duration:.2f}s")
1087 |         table.add_row("Current Parallelism", str(len(self.active_executions)))
1088 |         
1089 |         return Panel(table, title="[bold]Execution Metrics[/bold]", border_style="magenta")
1090 | 
1091 | 
class MultiPanelLayout:
    """Creates a multi-panel layout for the entire UI.

    Named sections: "conversation", "tools" (split into "active_tools" and
    "cost"), and "input". Auto-refresh is disabled, so the screen only redraws
    when ``refresh()`` (or ``update_section()``) is called.
    """

    def __init__(self, console: Console):
        """Initialize the multi-panel layout.

        Args:
            console: Rich console instance
        """
        self.console = console
        self.layout = self._create_layout()
        self.live = Live(self.layout, console=console, refresh_per_second=4, auto_refresh=False)

    def _create_layout(self) -> Layout:
        """Create the main application layout.

        Returns:
            Layout object
        """
        layout = Layout()

        # Three stacked rows sized 3:2:1.
        layout.split(
            Layout(name="conversation", ratio=3),
            Layout(name="tools", ratio=2),
            Layout(name="input", ratio=1)
        )

        # The tools row is split into a flexible panel and a 30-column cost panel.
        layout["tools"].split_row(
            Layout(name="active_tools"),
            Layout(name="cost", size=30)
        )

        return layout

    def start(self) -> None:
        """Start the live display."""
        self.live.start()

    def stop(self) -> None:
        """Stop the live display."""
        self.live.stop()

    def refresh(self) -> None:
        """Refresh the display."""
        self.live.refresh()

    def update_section(self, section: str, content: Any) -> None:
        """Replace the renderable of a named section and redraw.

        Nested sections such as "active_tools" or "cost" are found too.
        Unknown section names are ignored.

        Args:
            section: Section name
            content: Content to display
        """
        # Rich's Layout defines neither __contains__ nor __iter__, so
        # ``section in self.layout`` falls back to __getitem__(0), which raises
        # KeyError instead of returning False. Layout.get() returns None for
        # unknown names.
        target = self.layout.get(section)
        if target is not None:
            target.update(content)
            self.refresh()
```

--------------------------------------------------------------------------------
/mcp_server.py:
--------------------------------------------------------------------------------

```python
   1 | #!/usr/bin/env python3
   2 | """
   3 | Model Context Protocol (MCP) Server Implementation
   4 | 
   5 | This module implements the Model Context Protocol server capabilities,
   6 | allowing the assistant to be used as an MCP-compatible context provider.
   7 | """
   8 | 
   9 | import os
  10 | import json
  11 | import time
  12 | import uuid
  13 | import sys
  14 | import logging
  15 | import asyncio
  16 | import tiktoken
  17 | import re
  18 | from datetime import datetime
  19 | from typing import Dict, List, Any, Optional, Union, AsyncGenerator
  20 | from fastapi import FastAPI, HTTPException, Request, Response, Depends, BackgroundTasks, Query
  21 | from fastapi.responses import JSONResponse, StreamingResponse
  22 | from fastapi.middleware.cors import CORSMiddleware
  23 | from fastapi.staticfiles import StaticFiles
  24 | from fastapi.templating import Jinja2Templates
  25 | from pydantic import BaseModel, Field
  26 | import uvicorn
  27 | import openai
  28 | from openai import OpenAI
  29 | import prometheus_client
  30 | from prometheus_client import Counter, Histogram, Gauge
  31 | 
  32 | # Configure logging
  33 | logging.basicConfig(
  34 |     level=logging.INFO,
  35 |     format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
  36 | )
  37 | logger = logging.getLogger("mcp_server")
  38 | 
  39 | # MCP Protocol Models
  40 | class MCPHealthResponse(BaseModel):
  41 |     """Health check response for MCP protocol"""
  42 |     status: str = "healthy"
  43 |     version: str = "1.0.0"
  44 |     protocol_version: str = "0.1.0"
  45 |     provider: str = "OpenAI Code Assistant"
  46 |     models: List[str] = ["gpt-4o", "gpt-4-turbo", "gpt-3.5-turbo"]
  47 |     uptime: Optional[float] = None
  48 |     request_count: Optional[int] = None
  49 |     cache_hit_ratio: Optional[float] = None
  50 | 
  51 | class MCPContextRequest(BaseModel):
  52 |     """Request for context generation from a prompt template"""
  53 |     prompt_id: str
  54 |     parameters: Dict[str, Any] = Field(default_factory=dict, description="Parameters to fill in the prompt template")
  55 |     model: Optional[str] = Field(None, description="Model to use for context generation")
  56 |     stream: bool = Field(False, description="Whether to stream the response")
  57 |     user: Optional[str] = Field(None, description="User identifier for tracking")
  58 |     conversation_id: Optional[str] = Field(None, description="Conversation identifier")
  59 |     message_id: Optional[str] = Field(None, description="Message identifier")
  60 | 
  61 | class MCPContextResponse(BaseModel):
  62 |     """Response containing generated context"""
  63 |     context: str = Field(..., description="The generated context")
  64 |     context_id: str = Field(..., description="Unique identifier for this context")
  65 |     model: str = Field(..., description="Model used for generation")
  66 |     usage: Dict[str, int] = Field(default_factory=dict, description="Token usage statistics")
  67 |     metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
  68 | 
  69 | class MCPErrorResponse(BaseModel):
  70 |     """Error response format"""
  71 |     error: str = Field(..., description="Error message")
  72 |     error_type: str = Field(..., description="Type of error")
  73 |     status_code: int = Field(..., description="HTTP status code")
  74 |     details: Optional[Dict[str, Any]] = Field(None, description="Additional error details")
  75 | 
  76 | class MCPPromptTemplate(BaseModel):
  77 |     """Prompt template definition"""
  78 |     id: str = Field(..., description="Unique identifier for the template")
  79 |     template: str = Field(..., description="The prompt template with parameter placeholders")
  80 |     description: Optional[str] = Field(None, description="Description of the template")
  81 |     parameters: Dict[str, Dict[str, Any]] = Field(default_factory=dict, description="Parameter definitions")
  82 |     default_model: Optional[str] = Field(None, description="Default model to use with this template")
  83 |     metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
  84 | 
  85 | class MCPPromptLibraryResponse(BaseModel):
  86 |     """Response containing a list of prompt templates"""
  87 |     prompts: List[MCPPromptTemplate] = Field(..., description="List of prompt templates")
  88 |     count: int = Field(..., description="Number of templates")
  89 | 
  90 | # MCP Server Implementation
  91 | # Prometheus metrics
  92 | REQUEST_COUNT = Counter('mcp_requests_total', 'Total number of requests processed', ['endpoint', 'status'])
  93 | REQUEST_LATENCY = Histogram('mcp_request_latency_seconds', 'Request latency in seconds', ['endpoint'])
  94 | CACHE_HIT = Counter('mcp_cache_hits_total', 'Total number of cache hits')
  95 | CACHE_MISS = Counter('mcp_cache_misses_total', 'Total number of cache misses')
  96 | ACTIVE_CONNECTIONS = Gauge('mcp_active_connections', 'Number of active connections')
  97 | TOKEN_USAGE = Counter('mcp_token_usage_total', 'Total number of tokens used', ['model', 'type'])
  98 | 
  99 | # Cache implementation
 100 | class CacheManager:
 101 |     """Manages caching for context responses"""
 102 |     
 103 |     def __init__(self, cache_type="memory", redis_url=None, ttl=3600):
 104 |         self.cache_type = cache_type
 105 |         self.redis_url = redis_url
 106 |         self.ttl = ttl
 107 |         self.memory_cache = {}
 108 |         self.redis_client = None
 109 |         
 110 |         if cache_type == "redis" and redis_url:
 111 |             try:
 112 |                 import redis
 113 |                 self.redis_client = redis.from_url(redis_url)
 114 |                 logging.info(f"Redis cache initialized with URL: {redis_url}")
 115 |             except ImportError:
 116 |                 logging.warning("Redis package not installed. Falling back to memory cache.")
 117 |                 self.cache_type = "memory"
 118 |             except Exception as e:
 119 |                 logging.error(f"Failed to connect to Redis: {str(e)}")
 120 |                 self.cache_type = "memory"
 121 |     
 122 |     async def get(self, key):
 123 |         """Get item from cache"""
 124 |         if self.cache_type == "redis" and self.redis_client:
 125 |             try:
 126 |                 value = self.redis_client.get(key)
 127 |                 if value:
 128 |                     CACHE_HIT.inc()
 129 |                     return json.loads(value)
 130 |                 CACHE_MISS.inc()
 131 |                 return None
 132 |             except Exception as e:
 133 |                 logging.error(f"Redis get error: {str(e)}")
 134 |                 CACHE_MISS.inc()
 135 |                 return None
 136 |         else:
 137 |             # Memory cache
 138 |             if key in self.memory_cache:
 139 |                 if time.time() - self.memory_cache[key]["timestamp"] < self.ttl:
 140 |                     CACHE_HIT.inc()
 141 |                     return self.memory_cache[key]["data"]
 142 |                 else:
 143 |                     # Expired
 144 |                     del self.memory_cache[key]
 145 |             CACHE_MISS.inc()
 146 |             return None
 147 |     
 148 |     async def set(self, key, value, ttl=None):
 149 |         """Set item in cache"""
 150 |         if ttl is None:
 151 |             ttl = self.ttl
 152 |             
 153 |         if self.cache_type == "redis" and self.redis_client:
 154 |             try:
 155 |                 self.redis_client.setex(key, ttl, json.dumps(value))
 156 |             except Exception as e:
 157 |                 logging.error(f"Redis set error: {str(e)}")
 158 |         else:
 159 |             # Memory cache
 160 |             self.memory_cache[key] = {
 161 |                 "data": value,
 162 |                 "timestamp": time.time()
 163 |             }
 164 |     
 165 |     async def delete(self, key):
 166 |         """Delete item from cache"""
 167 |         if self.cache_type == "redis" and self.redis_client:
 168 |             try:
 169 |                 self.redis_client.delete(key)
 170 |             except Exception as e:
 171 |                 logging.error(f"Redis delete error: {str(e)}")
 172 |         else:
 173 |             # Memory cache
 174 |             if key in self.memory_cache:
 175 |                 del self.memory_cache[key]
 176 |     
 177 |     async def clear(self):
 178 |         """Clear all cache"""
 179 |         if self.cache_type == "redis" and self.redis_client:
 180 |             try:
 181 |                 self.redis_client.flushdb()
 182 |             except Exception as e:
 183 |                 logging.error(f"Redis flush error: {str(e)}")
 184 |         else:
 185 |             # Memory cache
 186 |             self.memory_cache = {}
 187 | 
 188 | class MCPServer:
 189 |     """Model Context Protocol Server Implementation"""
 190 |     
 191 |     def __init__(self, cache_type="memory", redis_url=None):
 192 |         self.app = FastAPI(
 193 |             title="OpenAI Code Assistant MCP Server",
 194 |             description="Model Context Protocol server for OpenAI Code Assistant",
 195 |             version="1.0.0",
 196 |             docs_url="/docs",
 197 |             redoc_url="/redoc",
 198 |             openapi_url="/openapi.json",
 199 |         )
 200 |         
 201 |         # Initialize OpenAI client
 202 |         self.openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
 203 |         
 204 |         # Initialize cache
 205 |         self.cache = CacheManager(cache_type=cache_type, redis_url=redis_url)
 206 |         
 207 |         # Initialize tokenizer
 208 |         self.tokenizer = tiktoken.get_encoding("cl100k_base")
 209 |         
 210 |         # Setup routes and middleware
 211 |         self.setup_routes()
 212 |         self.setup_middleware()
 213 |         
 214 |         # Load templates and static files
 215 |         self.templates_dir = os.path.join(os.path.dirname(__file__), "templates")
 216 |         os.makedirs(self.templates_dir, exist_ok=True)
 217 |         self.static_dir = os.path.join(os.path.dirname(__file__), "static")
 218 |         os.makedirs(self.static_dir, exist_ok=True)
 219 |         
 220 |         # Create default template if it doesn't exist
 221 |         self._create_default_template()
 222 |         
 223 |         # Initialize templates
 224 |         self.templates = Jinja2Templates(directory=self.templates_dir)
 225 |         
 226 |         # Mount static files
 227 |         self.app.mount("/static", StaticFiles(directory=self.static_dir), name="static")
 228 |         
 229 |         # Load prompt templates
 230 |         self.prompt_templates = self._load_prompt_templates()
 231 |         
 232 |         # Initialize metrics
 233 |         self.request_count = 0
 234 |         self.start_time = time.time()
 235 |         
 236 |     def setup_middleware(self):
 237 |         """Configure middleware for the FastAPI app"""
 238 |         # Add CORS middleware
 239 |         self.app.add_middleware(
 240 |             CORSMiddleware,
 241 |             allow_origins=["*"],
 242 |             allow_credentials=True,
 243 |             allow_methods=["*"],
 244 |             allow_headers=["*"],
 245 |         )
 246 |         
 247 |         # Add request tracking middleware
 248 |         @self.app.middleware("http")
 249 |         async def track_requests(request: Request, call_next):
 250 |             # Increment active connections
 251 |             ACTIVE_CONNECTIONS.inc()
 252 |             
 253 |             # Track request start time
 254 |             start_time = time.time()
 255 |             
 256 |             # Process request
 257 |             try:
 258 |                 response = await call_next(request)
 259 |                 
 260 |                 # Record metrics
 261 |                 endpoint = request.url.path
 262 |                 status = response.status_code
 263 |                 REQUEST_COUNT.labels(endpoint=endpoint, status=status).inc()
 264 |                 REQUEST_LATENCY.labels(endpoint=endpoint).observe(time.time() - start_time)
 265 |                 
 266 |                 # Increment total request count
 267 |                 self.request_count += 1
 268 |                 
 269 |                 return response
 270 |             finally:
 271 |                 # Decrement active connections
 272 |                 ACTIVE_CONNECTIONS.dec()
 273 |     
 274 |     def _create_default_template(self):
 275 |         """Create default dashboard template if it doesn't exist"""
 276 |         index_path = os.path.join(self.templates_dir, "index.html")
 277 |         if not os.path.exists(index_path):
 278 |             with open(index_path, "w") as f:
 279 |                 f.write("""
 280 | <!DOCTYPE html>
 281 | <html>
 282 | <head>
 283 |     <title>OpenAI Code Assistant MCP Server</title>
 284 |     <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css">
 285 |     <style>
 286 |         body { padding: 20px; }
 287 |         .card { margin-bottom: 20px; }
 288 |     </style>
 289 | </head>
 290 | <body>
 291 |     <div class="container">
 292 |         <h1>OpenAI Code Assistant MCP Server</h1>
 293 |         <div class="row">
 294 |             <div class="col-md-6">
 295 |                 <div class="card">
 296 |                     <div class="card-header">Server Status</div>
 297 |                     <div class="card-body">
 298 |                         <p><strong>Status:</strong> {{ status }}</p>
 299 |                         <p><strong>Uptime:</strong> {{ uptime }}</p>
 300 |                         <p><strong>Requests Served:</strong> {{ request_count }}</p>
 301 |                         <p><strong>Cache Hit Ratio:</strong> {{ cache_hit_ratio }}%</p>
 302 |                     </div>
 303 |                 </div>
 304 |             </div>
 305 |             <div class="col-md-6">
 306 |                 <div class="card">
 307 |                     <div class="card-header">Available Models</div>
 308 |                     <div class="card-body">
 309 |                         <ul>
 310 |                             {% for model in models %}
 311 |                             <li>{{ model }}</li>
 312 |                             {% endfor %}
 313 |                         </ul>
 314 |                     </div>
 315 |                 </div>
 316 |             </div>
 317 |         </div>
 318 |         
 319 |         <h2>Available Prompt Templates</h2>
 320 |         <div class="row">
 321 |             {% for template in templates %}
 322 |             <div class="col-md-6">
 323 |                 <div class="card">
 324 |                     <div class="card-header">{{ template.id }}</div>
 325 |                     <div class="card-body">
 326 |                         <p><strong>Description:</strong> {{ template.description }}</p>
 327 |                         <p><strong>Parameters:</strong> {{ template.parameters|join(", ") }}</p>
 328 |                         <p><strong>Default Model:</strong> {{ template.default_model }}</p>
 329 |                     </div>
 330 |                 </div>
 331 |             </div>
 332 |             {% endfor %}
 333 |         </div>
 334 |         
 335 |         <h2>API Documentation</h2>
 336 |         <p>
 337 |             <a href="/docs" class="btn btn-primary">Interactive API Docs</a>
 338 |             <a href="/redoc" class="btn btn-secondary">ReDoc API Docs</a>
 339 |             <a href="/metrics" class="btn btn-info">Prometheus Metrics</a>
 340 |         </p>
 341 |     </div>
 342 |     
 343 |     <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.bundle.min.js"></script>
 344 | </body>
 345 | </html>
 346 |                 """)
 347 |     
 348 |     def setup_routes(self):
 349 |         """Configure API routes for MCP protocol"""
 350 |         
 351 |         # MCP Protocol Routes
 352 |         # Dashboard route
 353 |         @self.app.get("/", tags=["Dashboard"])
 354 |         async def dashboard(request: Request):
 355 |             """Dashboard showing server status and available templates"""
 356 |             # Calculate cache hit ratio
 357 |             cache_hits = prometheus_client.REGISTRY.get_sample_value('mcp_cache_hits_total') or 0
 358 |             cache_misses = prometheus_client.REGISTRY.get_sample_value('mcp_cache_misses_total') or 0
 359 |             total_cache_requests = cache_hits + cache_misses
 360 |             cache_hit_ratio = (cache_hits / total_cache_requests * 100) if total_cache_requests > 0 else 0
 361 |             
 362 |             # Format uptime
 363 |             uptime_seconds = time.time() - self.start_time
 364 |             days, remainder = divmod(uptime_seconds, 86400)
 365 |             hours, remainder = divmod(remainder, 3600)
 366 |             minutes, seconds = divmod(remainder, 60)
 367 |             uptime_str = f"{int(days)}d {int(hours)}h {int(minutes)}m {int(seconds)}s"
 368 |             
 369 |             # Get template information
 370 |             templates = []
 371 |             for template_id, template in self.prompt_templates.items():
 372 |                 templates.append({
 373 |                     "id": template_id,
 374 |                     "description": template.get("description", ""),
 375 |                     "parameters": list(template.get("parameters", {}).keys()),
 376 |                     "default_model": template.get("default_model", "gpt-4o")
 377 |                 })
 378 |             
 379 |             return self.templates.TemplateResponse("index.html", {
 380 |                 "request": request,
 381 |                 "status": "Healthy",
 382 |                 "uptime": uptime_str,
 383 |                 "request_count": self.request_count,
 384 |                 "cache_hit_ratio": round(cache_hit_ratio, 2),
 385 |                 "models": ["gpt-4o", "gpt-4-turbo", "gpt-3.5-turbo"],
 386 |                 "templates": templates
 387 |             })
 388 |         
 389 |         # Prometheus metrics endpoint
 390 |         @self.app.get("/metrics", tags=["Monitoring"])
 391 |         async def metrics():
 392 |             """Expose Prometheus metrics"""
 393 |             return Response(prometheus_client.generate_latest(), media_type="text/plain")
 394 |         
 395 |         # Health check endpoints
 396 |         @self.app.get("/health", response_model=MCPHealthResponse, tags=["Health"])
 397 |         async def health():
 398 |             """Health check endpoint"""
 399 |             # Calculate cache hit ratio
 400 |             cache_hits = prometheus_client.REGISTRY.get_sample_value('mcp_cache_hits_total') or 0
 401 |             cache_misses = prometheus_client.REGISTRY.get_sample_value('mcp_cache_misses_total') or 0
 402 |             total_cache_requests = cache_hits + cache_misses
 403 |             cache_hit_ratio = (cache_hits / total_cache_requests) if total_cache_requests > 0 else 0
 404 |             
 405 |             return MCPHealthResponse(
 406 |                 status="healthy",
 407 |                 uptime=time.time() - self.start_time,
 408 |                 request_count=self.request_count,
 409 |                 cache_hit_ratio=cache_hit_ratio
 410 |             )
 411 |         
 412 |         @self.app.post("/context", response_model=MCPContextResponse, tags=["Context"])
 413 |         async def get_context(
 414 |             request: MCPContextRequest, 
 415 |             background_tasks: BackgroundTasks,
 416 |             use_cache: bool = Query(True, description="Whether to use cached results if available")
 417 |         ):
 418 |             """
 419 |             Get context for a prompt template with parameters.
 420 |             
 421 |             This endpoint processes a prompt template with the provided parameters
 422 |             and returns the generated context. It can optionally use OpenAI models
 423 |             to enhance the context.
 424 |             """
 425 |             try:
 426 |                 # Check if prompt template exists
 427 |                 if request.prompt_id not in self.prompt_templates:
 428 |                     raise HTTPException(
 429 |                         status_code=404,
 430 |                         detail=f"Prompt template '{request.prompt_id}' not found"
 431 |                     )
 432 |                 
 433 |                 # Get prompt template
 434 |                 template = self.prompt_templates[request.prompt_id]
 435 |                 
 436 |                 # Use default model if not specified
 437 |                 model = request.model or template.get("default_model", "gpt-4o")
 438 |                 
 439 |                 # Generate context ID
 440 |                 context_id = str(uuid.uuid4())
 441 |                 
 442 |                 # Generate cache key
 443 |                 cache_key = f"{request.prompt_id}:{json.dumps(request.parameters, sort_keys=True)}:{model}"
 444 |                 
 445 |                 # Check cache if enabled
 446 |                 if use_cache:
 447 |                     cached_result = await self.cache.get(cache_key)
 448 |                     if cached_result:
 449 |                         # Update context ID for this request
 450 |                         cached_result["context_id"] = context_id
 451 |                         return MCPContextResponse(**cached_result)
 452 |                 
 453 |                 # Process template with parameters
 454 |                 processed_template = self._process_template(template["template"], request.parameters)
 455 |                 
 456 |                 # Check if we should use OpenAI to enhance the context
 457 |                 if template.get("use_openai", False):
 458 |                     # Generate context using OpenAI
 459 |                     context, usage = await self._generate_with_openai(
 460 |                         processed_template, 
 461 |                         model, 
 462 |                         template.get("system_prompt")
 463 |                     )
 464 |                 else:
 465 |                     # Use the processed template directly
 466 |                     context = processed_template
 467 |                     
 468 |                     # Calculate token usage
 469 |                     token_count = len(self.tokenizer.encode(context))
 470 |                     usage = {
 471 |                         "prompt_tokens": token_count,
 472 |                         "completion_tokens": 0,
 473 |                         "total_tokens": token_count
 474 |                     }
 475 |                 
 476 |                 # Track token usage in Prometheus
 477 |                 TOKEN_USAGE.labels(model=model, type="prompt").inc(usage["prompt_tokens"])
 478 |                 TOKEN_USAGE.labels(model=model, type="completion").inc(usage["completion_tokens"])
 479 |                 
 480 |                 # Create response
 481 |                 response = MCPContextResponse(
 482 |                     context=context,
 483 |                     context_id=context_id,
 484 |                     model=model,
 485 |                     usage=usage,
 486 |                     metadata={
 487 |                         "prompt_id": request.prompt_id,
 488 |                         "timestamp": time.time(),
 489 |                         "parameters": request.parameters
 490 |                     }
 491 |                 )
 492 |                 
 493 |                 # Store in cache
 494 |                 await self.cache.set(cache_key, response.dict())
 495 |                 
 496 |                 return response
 497 |                 
 498 |             except Exception as e:
 499 |                 logger.error(f"Error processing context request: {str(e)}", exc_info=True)
 500 |                 raise HTTPException(
 501 |                     status_code=500,
 502 |                     detail=f"Error processing context: {str(e)}"
 503 |                 )
 504 |         
 505 |         @self.app.post("/context/stream", tags=["Context"])
 506 |         async def stream_context(request: MCPContextRequest):
 507 |             """
 508 |             Stream context generation.
 509 |             
 510 |             Similar to /context but streams the response as it's generated.
 511 |             """
 512 |             try:
 513 |                 # Check if prompt template exists
 514 |                 if request.prompt_id not in self.prompt_templates:
 515 |                     raise HTTPException(
 516 |                         status_code=404,
 517 |                         detail=f"Prompt template '{request.prompt_id}' not found"
 518 |                     )
 519 |                 
 520 |                 # Get prompt template
 521 |                 template = self.prompt_templates[request.prompt_id]
 522 |                 
 523 |                 # Use default model if not specified
 524 |                 model = request.model or template.get("default_model", "gpt-4o")
 525 |                 
 526 |                 # Generate context ID
 527 |                 context_id = str(uuid.uuid4())
 528 |                 
 529 |                 # Process template with parameters
 530 |                 processed_template = self._process_template(template["template"], request.parameters)
 531 |                 
 532 |                 # Stream the context generation
 533 |                 return StreamingResponse(
 534 |                     self._stream_context(processed_template, model, context_id, template.get("system_prompt")),
 535 |                     media_type="text/event-stream"
 536 |                 )
 537 |                 
 538 |             except Exception as e:
 539 |                 logger.error(f"Error streaming context: {str(e)}", exc_info=True)
 540 |                 raise HTTPException(
 541 |                     status_code=500,
 542 |                     detail=f"Error streaming context: {str(e)}"
 543 |                 )
 544 |         
 545 |         @self.app.get("/prompts", response_model=MCPPromptLibraryResponse, tags=["Prompts"])
 546 |         async def get_prompts():
 547 |             """
 548 |             Get available prompt templates.
 549 |             
 550 |             Returns a list of all prompt templates available in the system.
 551 |             """
 552 |             prompts = [
 553 |                 MCPPromptTemplate(
 554 |                     id=prompt_id,
 555 |                     template=template["template"],
 556 |                     description=template.get("description", ""),
 557 |                     parameters=template.get("parameters", {}),
 558 |                     default_model=template.get("default_model", "gpt-4o"),
 559 |                     metadata=template.get("metadata", {})
 560 |                 )
 561 |                 for prompt_id, template in self.prompt_templates.items()
 562 |             ]
 563 |             
 564 |             return MCPPromptLibraryResponse(
 565 |                 prompts=prompts,
 566 |                 count=len(prompts)
 567 |             )
 568 |         
 569 |         @self.app.get("/prompts/{prompt_id}", response_model=MCPPromptTemplate, tags=["Prompts"])
 570 |         async def get_prompt(prompt_id: str):
 571 |             """
 572 |             Get a specific prompt template.
 573 |             
 574 |             Returns the details of a specific prompt template by ID.
 575 |             """
 576 |             if prompt_id not in self.prompt_templates:
 577 |                 raise HTTPException(
 578 |                     status_code=404,
 579 |                     detail=f"Prompt template '{prompt_id}' not found"
 580 |                 )
 581 |             
 582 |             template = self.prompt_templates[prompt_id]
 583 |             return MCPPromptTemplate(
 584 |                 id=prompt_id,
 585 |                 template=template["template"],
 586 |                 description=template.get("description", ""),
 587 |                 parameters=template.get("parameters", {}),
 588 |                 default_model=template.get("default_model", "gpt-4o"),
 589 |                 metadata=template.get("metadata", {})
 590 |             )
 591 |         
 592 |         @self.app.post("/prompts", response_model=MCPPromptTemplate, status_code=201, tags=["Prompts"])
 593 |         async def create_prompt(prompt: MCPPromptTemplate):
 594 |             """
 595 |             Create a new prompt template.
 596 |             
 597 |             Adds a new prompt template to the system.
 598 |             """
 599 |             if prompt.id in self.prompt_templates:
 600 |                 raise HTTPException(
 601 |                     status_code=409,
 602 |                     detail=f"Prompt template '{prompt.id}' already exists"
 603 |                 )
 604 |             
 605 |             self.prompt_templates[prompt.id] = {
 606 |                 "template": prompt.template,
 607 |                 "description": prompt.description,
 608 |                 "parameters": prompt.parameters,
 609 |                 "default_model": prompt.default_model,
 610 |                 "metadata": prompt.metadata
 611 |             }
 612 |             
 613 |             # Save updated templates
 614 |             self._save_prompt_templates()
 615 |             
 616 |             return prompt
 617 |         
 618 |         @self.app.put("/prompts/{prompt_id}", response_model=MCPPromptTemplate, tags=["Prompts"])
 619 |         async def update_prompt(prompt_id: str, prompt: MCPPromptTemplate):
 620 |             """
 621 |             Update an existing prompt template.
 622 |             
 623 |             Updates the details of an existing prompt template.
 624 |             """
 625 |             if prompt_id != prompt.id:
 626 |                 raise HTTPException(
 627 |                     status_code=400,
 628 |                     detail="Prompt ID in path must match prompt ID in body"
 629 |                 )
 630 |             
 631 |             if prompt_id not in self.prompt_templates:
 632 |                 raise HTTPException(
 633 |                     status_code=404,
 634 |                     detail=f"Prompt template '{prompt_id}' not found"
 635 |                 )
 636 |             
 637 |             self.prompt_templates[prompt_id] = {
 638 |                 "template": prompt.template,
 639 |                 "description": prompt.description,
 640 |                 "parameters": prompt.parameters,
 641 |                 "default_model": prompt.default_model,
 642 |                 "metadata": prompt.metadata
 643 |             }
 644 |             
 645 |             # Save updated templates
 646 |             self._save_prompt_templates()
 647 |             
 648 |             return prompt
 649 |         
 650 |         @self.app.delete("/prompts/{prompt_id}", tags=["Prompts"])
 651 |         async def delete_prompt(prompt_id: str):
 652 |             """
 653 |             Delete a prompt template.
 654 |             
 655 |             Removes a prompt template from the system.
 656 |             """
 657 |             if prompt_id not in self.prompt_templates:
 658 |                 raise HTTPException(
 659 |                     status_code=404,
 660 |                     detail=f"Prompt template '{prompt_id}' not found"
 661 |                 )
 662 |             
 663 |             del self.prompt_templates[prompt_id]
 664 |             
 665 |             # Save updated templates
 666 |             self._save_prompt_templates()
 667 |             
 668 |             return {"status": "deleted", "prompt_id": prompt_id}
 669 |         
 670 |         # Additional endpoints for a more complete MCP server
 671 |         @self.app.get("/models", tags=["Models"])
 672 |         async def get_models():
 673 |             """
 674 |             Get available models.
 675 |             
 676 |             Returns a list of models that can be used with this MCP server.
 677 |             """
 678 |             return {
 679 |                 "models": [
 680 |                     {
 681 |                         "id": "gpt-4o",
 682 |                         "name": "GPT-4o",
 683 |                         "description": "OpenAI's most advanced model",
 684 |                         "context_length": 128000,
 685 |                         "is_default": True
 686 |                     },
 687 |                     {
 688 |                         "id": "gpt-4-turbo",
 689 |                         "name": "GPT-4 Turbo",
 690 |                         "description": "Optimized version of GPT-4",
 691 |                         "context_length": 128000,
 692 |                         "is_default": False
 693 |                     },
 694 |                     {
 695 |                         "id": "gpt-3.5-turbo",
 696 |                         "name": "GPT-3.5 Turbo",
 697 |                         "description": "Fast and efficient model",
 698 |                         "context_length": 16385,
 699 |                         "is_default": False
 700 |                     }
 701 |                 ],
 702 |                 "count": 3
 703 |             }
 704 |         
 705 |         @self.app.get("/stats", tags=["System"])
 706 |         async def get_stats():
 707 |             """
 708 |             Get server statistics.
 709 |             
 710 |             Returns usage statistics and system information.
 711 |             """
 712 |             return {
 713 |                 "uptime": time.time() - self.start_time,
 714 |                 "prompt_templates_count": len(self.prompt_templates),
 715 |                 "cache_size": len(self.context_cache),
 716 |                 "requests_served": {
 717 |                     "context": 0,  # This would be tracked in a real implementation
 718 |                     "prompts": 0,
 719 |                     "total": 0
 720 |                 },
 721 |                 "system_info": {
 722 |                     "python_version": sys.version,
 723 |                     "platform": sys.platform
 724 |                 }
 725 |             }
 726 |             
 727 |         @self.app.post("/context/stream", tags=["Context"])
 728 |         async def stream_context(request: MCPContextRequest):
 729 |             """
 730 |             Stream context generation.
 731 |             
 732 |             Similar to /context but streams the response as it's generated.
 733 |             """
 734 |             # In a real implementation, this would stream the response
 735 |             # For now, we'll just return a simple response
 736 |             return JSONResponse(
 737 |                 content={"message": "Streaming not implemented in this version"},
 738 |                 status_code=501
 739 |             )
 740 |             
 741 |         # Error handlers
 742 |         @self.app.exception_handler(HTTPException)
 743 |         async def http_exception_handler(request: Request, exc: HTTPException):
 744 |             """Handle HTTP exceptions in MCP format"""
 745 |             return JSONResponse(
 746 |                 status_code=exc.status_code,
 747 |                 content={
 748 |                     "error": exc.detail,
 749 |                     "error_type": "http_error",
 750 |                     "status_code": exc.status_code,
 751 |                     "details": exc.detail if isinstance(exc.detail, dict) else None
 752 |                 }
 753 |             )
 754 |         
 755 |         @self.app.exception_handler(Exception)
 756 |         async def general_exception_handler(request: Request, exc: Exception):
 757 |             """Handle general exceptions in MCP format"""
 758 |             logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)
 759 |             return JSONResponse(
 760 |                 status_code=500,
 761 |                 content={
 762 |                     "error": str(exc),
 763 |                     "error_type": "server_error",
 764 |                     "status_code": 500,
 765 |                     "details": None
 766 |                 }
 767 |             )
 768 |     
 769 |     def _load_prompt_templates(self) -> Dict[str, Dict[str, Any]]:
 770 |         """Load prompt templates from file or initialize defaults"""
 771 |         templates_file = os.path.join(os.path.dirname(__file__), "data", "prompt_templates.json")
 772 |         
 773 |         # Create directory if it doesn't exist
 774 |         os.makedirs(os.path.dirname(templates_file), exist_ok=True)
 775 |         
 776 |         # Try to load existing templates
 777 |         if os.path.exists(templates_file):
 778 |             try:
 779 |                 with open(templates_file, "r") as f:
 780 |                     templates = json.load(f)
 781 |                     logger.info(f"Loaded {len(templates)} prompt templates from {templates_file}")
 782 |                     return templates
 783 |             except Exception as e:
 784 |                 logger.error(f"Error loading prompt templates: {str(e)}")
 785 |         
 786 |         # Initialize with enhanced default templates
 787 |         default_templates = {
 788 |             "greeting": {
 789 |                 "template": "Hello! The current time is {time}. How can I help you today?",
 790 |                 "description": "A simple greeting template",
 791 |                 "parameters": {
 792 |                     "time": {
 793 |                         "type": "string",
 794 |                         "description": "The current time"
 795 |                     }
 796 |                 },
 797 |                 "default_model": "gpt-4o",
 798 |                 "metadata": {
 799 |                     "category": "general"
 800 |                 }
 801 |             },
 802 |             "code_review": {
 803 |                 "template": "Please review the following code:\n\n```{language}\n{code}\n```\n\nFocus on: {focus_areas}",
 804 |                 "description": "Template for code review requests",
 805 |                 "parameters": {
 806 |                     "language": {
 807 |                         "type": "string",
 808 |                         "description": "Programming language of the code"
 809 |                     },
 810 |                     "code": {
 811 |                         "type": "string",
 812 |                         "description": "The code to review"
 813 |                     },
 814 |                     "focus_areas": {
 815 |                         "type": "string",
 816 |                         "description": "Areas to focus on during review (e.g., 'performance, security')"
 817 |                     }
 818 |                 },
 819 |                 "default_model": "gpt-4o",
 820 |                 "use_openai": True,
 821 |                 "system_prompt": "You are a code review expert. Analyze the provided code and provide constructive feedback focusing on the specified areas.",
 822 |                 "metadata": {
 823 |                     "category": "development"
 824 |                 }
 825 |             },
 826 |             "system_prompt": {
 827 |                 "template": "You are OpenAI Code Assistant, a CLI tool that helps users with software engineering tasks and general information.\nUse the available tools to assist the user with their requests.\n\n# Tone and style\nYou should be concise, direct, and to the point. When you run a non-trivial bash command, \nyou should explain what the command does and why you are running it.\nOutput text to communicate with the user; all text you output outside of tool use is displayed to the user.\nRemember that your output will be displayed on a command line interface.\n\n# Tool usage policy\n- When doing file search, remember to search effectively with the available tools.\n- Always use the appropriate tool for the task.\n- Use parallel tool calls when appropriate to improve performance.\n- NEVER commit changes unless the user explicitly asks you to.\n- For weather queries, use the Weather tool to provide real-time information.\n\n# Tasks\nThe user will primarily request you perform software engineering tasks:\n1. Solving bugs\n2. Adding new functionality \n3. Refactoring code\n4. Explaining code\n5. Writing tests\n\nFor these tasks:\n1. Use search tools to understand the codebase\n2. Implement solutions using the available tools\n3. Verify solutions with tests if possible\n4. Run lint and typecheck commands when appropriate\n\nThe user may also ask for general information:\n1. Weather conditions\n2. Simple calculations\n3. General knowledge questions\n\n# Code style\n- Follow the existing code style of the project\n- Maintain consistent naming conventions\n- Use appropriate libraries that are already in the project\n- Add comments when code is complex or non-obvious\n\nIMPORTANT: You should minimize output tokens as much as possible while maintaining helpfulness, \nquality, and accuracy. Answer concisely with short lines of text unless the user asks for detail.",
 828 |                 "description": "System prompt for the assistant",
 829 |                 "parameters": {},
 830 |                 "default_model": "gpt-4o",
 831 |                 "metadata": {
 832 |                     "category": "system"
 833 |                 }
 834 |             },
 835 |             "documentation": {
 836 |                 "template": "Generate documentation for the following code:\n\n```{language}\n{code}\n```\n\nFormat: {format}",
 837 |                 "description": "Generate code documentation",
 838 |                 "parameters": {
 839 |                     "language": {
 840 |                         "type": "string",
 841 |                         "description": "Programming language of the code"
 842 |                     },
 843 |                     "code": {
 844 |                         "type": "string",
 845 |                         "description": "The code to document"
 846 |                     },
 847 |                     "format": {
 848 |                         "type": "string",
 849 |                         "description": "Documentation format (e.g., 'markdown', 'docstring', 'jsdoc')",
 850 |                         "default": "markdown"
 851 |                     }
 852 |                 },
 853 |                 "default_model": "gpt-4o",
 854 |                 "use_openai": True,
 855 |                 "system_prompt": "You are a technical documentation expert. Generate clear, concise, and accurate documentation for the provided code.",
 856 |                 "metadata": {
 857 |                     "category": "development"
 858 |                 }
 859 |             },
 860 |             "explain_code": {
 861 |                 "template": "Explain how the following code works:\n\n```{language}\n{code}\n```\n\nDetail level: {detail_level}",
 862 |                 "description": "Explain code functionality",
 863 |                 "parameters": {
 864 |                     "language": {
 865 |                         "type": "string",
 866 |                         "description": "Programming language of the code"
 867 |                     },
 868 |                     "code": {
 869 |                         "type": "string",
 870 |                         "description": "The code to explain"
 871 |                     },
 872 |                     "detail_level": {
 873 |                         "type": "string",
 874 |                         "description": "Level of detail in the explanation (e.g., 'basic', 'intermediate', 'advanced')",
 875 |                         "default": "intermediate"
 876 |                     }
 877 |                 },
 878 |                 "default_model": "gpt-4o",
 879 |                 "use_openai": True,
 880 |                 "system_prompt": "You are a programming instructor. Explain the provided code clearly at the requested level of detail.",
 881 |                 "metadata": {
 882 |                     "category": "education"
 883 |                 }
 884 |             },
 885 |             "current_time": {
 886 |                 "template": "The current time is {{now:%Y-%m-%d %H:%M:%S}}.",
 887 |                 "description": "Get the current time",
 888 |                 "parameters": {},
 889 |                 "default_model": "gpt-4o",
 890 |                 "metadata": {
 891 |                     "category": "utility"
 892 |                 }
 893 |             }
 894 |         }
 895 |         
 896 |         # Save default templates
 897 |         try:
 898 |             with open(templates_file, "w") as f:
 899 |                 json.dump(default_templates, f, indent=2)
 900 |         except Exception as e:
 901 |             logger.error(f"Error saving default prompt templates: {str(e)}")
 902 |         
 903 |         return default_templates
 904 |     
 905 |     def _save_prompt_templates(self):
 906 |         """Save prompt templates to file"""
 907 |         templates_file = os.path.join(os.path.dirname(__file__), "data", "prompt_templates.json")
 908 |         
 909 |         try:
 910 |             with open(templates_file, "w") as f:
 911 |                 json.dump(self.prompt_templates, f, indent=2)
 912 |         except Exception as e:
 913 |             logger.error(f"Error saving prompt templates: {str(e)}")
 914 |     
 915 |     async def _generate_with_openai(self, prompt: str, model: str, system_prompt: Optional[str] = None) -> tuple:
 916 |         """Generate context using OpenAI API"""
 917 |         messages = []
 918 |         
 919 |         # Add system prompt if provided
 920 |         if system_prompt:
 921 |             messages.append({"role": "system", "content": system_prompt})
 922 |         
 923 |         # Add user prompt
 924 |         messages.append({"role": "user", "content": prompt})
 925 |         
 926 |         # Call OpenAI API
 927 |         try:
 928 |             response = await asyncio.to_thread(
 929 |                 self.openai_client.chat.completions.create,
 930 |                 model=model,
 931 |                 messages=messages,
 932 |                 temperature=0.0,  # Use deterministic output for context generation
 933 |                 max_tokens=4000
 934 |             )
 935 |             
 936 |             # Extract content and usage
 937 |             content = response.choices[0].message.content
 938 |             usage = {
 939 |                 "prompt_tokens": response.usage.prompt_tokens,
 940 |                 "completion_tokens": response.usage.completion_tokens,
 941 |                 "total_tokens": response.usage.total_tokens
 942 |             }
 943 |             
 944 |             return content, usage
 945 |             
 946 |         except Exception as e:
 947 |             logger.error(f"OpenAI API error: {str(e)}")
 948 |             raise ValueError(f"Error generating context with OpenAI: {str(e)}")
 949 |     
 950 |     async def _stream_context(self, prompt: str, model: str, context_id: str, system_prompt: Optional[str] = None) -> AsyncGenerator[str, None]:
 951 |         """Stream context generation using OpenAI API"""
 952 |         messages = []
 953 |         
 954 |         # Add system prompt if provided
 955 |         if system_prompt:
 956 |             messages.append({"role": "system", "content": system_prompt})
 957 |         
 958 |         # Add user prompt
 959 |         messages.append({"role": "user", "content": prompt})
 960 |         
 961 |         # Initial event with context ID
 962 |         yield f"data: {json.dumps({'context_id': context_id, 'event': 'start'})}\n\n"
 963 |         
 964 |         try:
 965 |             # Call OpenAI API with streaming
 966 |             stream = await asyncio.to_thread(
 967 |                 self.openai_client.chat.completions.create,
 968 |                 model=model,
 969 |                 messages=messages,
 970 |                 temperature=0.0,
 971 |                 max_tokens=4000,
 972 |                 stream=True
 973 |             )
 974 |             
 975 |             full_content = ""
 976 |             
 977 |             # Process the stream
 978 |             for chunk in stream:
 979 |                 if chunk.choices and chunk.choices[0].delta.content:
 980 |                     content_piece = chunk.choices[0].delta.content
 981 |                     full_content += content_piece
 982 |                     
 983 |                     # Yield the content piece
 984 |                     yield f"data: {json.dumps({'content': content_piece, 'event': 'content'})}\n\n"
 985 |             
 986 |             # Calculate token usage
 987 |             prompt_tokens = len(self.tokenizer.encode(prompt))
 988 |             completion_tokens = len(self.tokenizer.encode(full_content))
 989 |             total_tokens = prompt_tokens + completion_tokens
 990 |             
 991 |             # Track token usage
 992 |             TOKEN_USAGE.labels(model=model, type="prompt").inc(prompt_tokens)
 993 |             TOKEN_USAGE.labels(model=model, type="completion").inc(completion_tokens)
 994 |             
 995 |             # Final event with complete context and usage
 996 |             yield f"data: {json.dumps({
 997 |                 'event': 'end',
 998 |                 'context': full_content,
 999 |                 'usage': {
1000 |                     'prompt_tokens': prompt_tokens,
1001 |                     'completion_tokens': completion_tokens,
1002 |                     'total_tokens': total_tokens
1003 |                 }
1004 |             })}\n\n"
1005 |             
1006 |         except Exception as e:
1007 |             logger.error(f"Error streaming context: {str(e)}")
1008 |             yield f"data: {json.dumps({'event': 'error', 'error': str(e)})}\n\n"
1009 |     
1010 |     def _process_template(self, template: str, parameters: Dict[str, Any]) -> str:
1011 |         """Process a template with parameters"""
1012 |         try:
1013 |             # Handle date/time formatting if needed
1014 |             processed_params = parameters.copy()
1015 |             for key, value in processed_params.items():
1016 |                 if isinstance(value, str) and value.startswith("{{now") and value.endswith("}}"):
1017 |                     # Extract format string if present
1018 |                     format_match = re.search(r"{{now:(.+)}}", value)
1019 |                     if format_match:
1020 |                         format_string = format_match.group(1)
1021 |                         processed_params[key] = datetime.now().strftime(format_string)
1022 |                     else:
1023 |                         processed_params[key] = datetime.now().isoformat()
1024 |             
1025 |             return template.format(**processed_params)
1026 |         except KeyError as e:
1027 |             raise ValueError(f"Missing required parameter: {e}")
1028 |         except Exception as e:
1029 |             raise ValueError(f"Error processing template: {str(e)}")
1030 |     
1031 |     def start(self, host: str = "127.0.0.1", port: int = 8000, reload: bool = False):
1032 |         """Start the MCP server"""
1033 |         uvicorn.run(self.app, host=host, port=port, reload=reload)
1034 | 
def create_mcp_app():
    """Factory: build a new ``MCPServer`` and return its FastAPI ``app``."""
    return MCPServer().app
1039 | 
if __name__ == "__main__":
    # The template loader also creates this directory. Doing it up front keeps
    # direct script runs self-sufficient.
    data_dir = os.path.join(os.path.dirname(__file__), "data")
    os.makedirs(data_dir, exist_ok=True)

    # Build the server and serve with start()'s default host and port.
    server = MCPServer()
    server.start()
1047 | 
```
Page 3/6 · First · Prev · Next · Last