This is page 3 of 6. Use http://codebase.md/arthurcolle/openai-mcp?lines=true&page={x} to view the full context.
# Directory Structure
```
├── .gitignore
├── claude_code
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-312.pyc
│   │   └── mcp_server.cpython-312.pyc
│   ├── claude.py
│   ├── commands
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-312.pyc
│   │   │   └── serve.cpython-312.pyc
│   │   ├── client.py
│   │   ├── multi_agent_client.py
│   │   └── serve.py
│   ├── config
│   │   └── __init__.py
│   ├── examples
│   │   ├── agents_config.json
│   │   ├── claude_mcp_config.html
│   │   ├── claude_mcp_config.json
│   │   ├── echo_server.py
│   │   └── README.md
│   ├── lib
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   └── __init__.cpython-312.pyc
│   │   ├── context
│   │   │   └── __init__.py
│   │   ├── monitoring
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   └── server_metrics.cpython-312.pyc
│   │   │   ├── cost_tracker.py
│   │   │   └── server_metrics.py
│   │   ├── providers
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── openai.py
│   │   ├── rl
│   │   │   ├── __init__.py
│   │   │   ├── grpo.py
│   │   │   ├── mcts.py
│   │   │   └── tool_optimizer.py
│   │   ├── tools
│   │   │   ├── __init__.py
│   │   │   ├── __pycache__
│   │   │   │   ├── __init__.cpython-312.pyc
│   │   │   │   ├── base.cpython-312.pyc
│   │   │   │   ├── file_tools.cpython-312.pyc
│   │   │   │   └── manager.cpython-312.pyc
│   │   │   ├── ai_tools.py
│   │   │   ├── base.py
│   │   │   ├── code_tools.py
│   │   │   ├── file_tools.py
│   │   │   ├── manager.py
│   │   │   └── search_tools.py
│   │   └── ui
│   │       ├── __init__.py
│   │       └── tool_visualizer.py
│   ├── mcp_server.py
│   ├── README_MCP_CLIENT.md
│   ├── README_MULTI_AGENT.md
│   └── util
│       └── __init__.py
├── claude.py
├── cli.py
├── data
│   └── prompt_templates.json
├── deploy_modal_mcp.py
├── deploy.sh
├── examples
│   ├── agents_config.json
│   └── echo_server.py
├── install.sh
├── mcp_modal_adapter.py
├── mcp_server.py
├── modal_mcp_server.py
├── README_modal_mcp.md
├── README.md
├── requirements.txt
├── setup.py
├── static
│   └── style.css
├── templates
│   └── index.html
└── web-client.html
```
# Files
--------------------------------------------------------------------------------
/claude_code/lib/rl/grpo.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Group Relative Policy Optimization (GRPO) for multi-agent learning in Claude Code.
3 | This module provides a multi-agent GRPO implementation that learns from interactions.
4 | """
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | from torch.distributions import Categorical
12 | from typing import List, Dict, Tuple, Optional, Any, Union, Callable
13 | from dataclasses import dataclass
14 | from collections import deque
15 | import random
16 | import time
17 |
18 |
19 | @dataclass
20 | class Experience:
21 | """A single step of experience for reinforcement learning."""
22 | state: Any
23 | action: Any
24 | reward: float
25 | next_state: Any
26 | done: bool
27 | info: Optional[Dict[str, Any]] = None
28 |
29 |
30 | class ExperienceBuffer:
31 | """Buffer to store and sample experiences for training."""
32 |
33 | def __init__(self, capacity: int = 100000):
34 | """
35 | Initialize the experience buffer.
36 |
37 | Args:
38 | capacity: Maximum number of experiences to store
39 | """
40 | self.buffer = deque(maxlen=capacity)
41 |
42 | def add(self, experience: Experience) -> None:
43 | """Add an experience to the buffer."""
44 | self.buffer.append(experience)
45 |
46 | def sample(self, batch_size: int) -> List[Experience]:
47 | """Sample a batch of experiences from the buffer."""
48 | return random.sample(self.buffer, min(batch_size, len(self.buffer)))
49 |
50 | def __len__(self) -> int:
51 | """Get the current size of the buffer."""
52 | return len(self.buffer)
53 |
54 |
55 | class PolicyNetwork(nn.Module):
56 | """Neural network to represent a policy."""
57 |
58 | def __init__(self, input_dim: int, hidden_dims: List[int], output_dim: int):
59 | """
60 | Initialize the policy network.
61 |
62 | Args:
63 | input_dim: Dimension of the input state
64 | hidden_dims: List of hidden layer dimensions
65 | output_dim: Dimension of the action space
66 | """
67 | super(PolicyNetwork, self).__init__()
68 |
69 | # Create the input layer
70 | layers = [nn.Linear(input_dim, hidden_dims[0]), nn.ReLU()]
71 |
72 | # Create hidden layers
73 | for i in range(len(hidden_dims) - 1):
74 | layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1]))
75 | layers.append(nn.ReLU())
76 |
77 | # Create output layer
78 | layers.append(nn.Linear(hidden_dims[-1], output_dim))
79 |
80 | self.network = nn.Sequential(*layers)
81 |
82 | def forward(self, x: torch.Tensor) -> torch.Tensor:
83 | """Forward pass through the network."""
84 | return self.network(x)
85 |
86 |
87 | class ValueNetwork(nn.Module):
88 | """Neural network to represent a value function."""
89 |
90 | def __init__(self, input_dim: int, hidden_dims: List[int]):
91 | """
92 | Initialize the value network.
93 |
94 | Args:
95 | input_dim: Dimension of the input state
96 | hidden_dims: List of hidden layer dimensions
97 | """
98 | super(ValueNetwork, self).__init__()
99 |
100 | # Create the input layer
101 | layers = [nn.Linear(input_dim, hidden_dims[0]), nn.ReLU()]
102 |
103 | # Create hidden layers
104 | for i in range(len(hidden_dims) - 1):
105 | layers.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1]))
106 | layers.append(nn.ReLU())
107 |
108 | # Create output layer (scalar value)
109 | layers.append(nn.Linear(hidden_dims[-1], 1))
110 |
111 | self.network = nn.Sequential(*layers)
112 |
113 | def forward(self, x: torch.Tensor) -> torch.Tensor:
114 | """Forward pass through the network."""
115 | return self.network(x)
116 |
117 |
118 | class GRPO:
119 | """
120 | Group Relative Policy Optimization implementation for multi-agent learning.
121 | GRPO extends PPO by considering relative performance within a group of agents.
122 | """
123 |
124 | def __init__(
125 | self,
126 | state_dim: int,
127 | action_dim: int,
128 | hidden_dims: List[int] = [64, 64],
129 | lr_policy: float = 3e-4,
130 | lr_value: float = 1e-3,
131 | gamma: float = 0.99,
132 | gae_lambda: float = 0.95,
133 | clip_ratio: float = 0.2,
134 | target_kl: float = 0.01,
135 | value_coef: float = 0.5,
136 | entropy_coef: float = 0.01,
137 | max_grad_norm: float = 0.5,
138 | use_gae: bool = True,
139 | normalize_advantages: bool = True,
140 | relative_advantage_weight: float = 0.5,
141 | device: str = "cuda" if torch.cuda.is_available() else "cpu",
142 | ):
143 | """
144 | Initialize the GRPO agent.
145 |
146 | Args:
147 | state_dim: Dimension of the state space
148 | action_dim: Dimension of the action space
149 | hidden_dims: Dimensions of hidden layers in networks
150 | lr_policy: Learning rate for policy network
151 | lr_value: Learning rate for value network
152 | gamma: Discount factor
153 | gae_lambda: Lambda for GAE
154 | clip_ratio: PPO clipping parameter
155 | target_kl: Target KL divergence for early stopping
156 | value_coef: Value loss coefficient
157 | entropy_coef: Entropy bonus coefficient
158 | max_grad_norm: Maximum gradient norm for clipping
159 | use_gae: Whether to use GAE
160 | normalize_advantages: Whether to normalize advantages
161 | relative_advantage_weight: Weight for relative advantage component
162 | device: Device to run the model on
163 | """
164 | self.state_dim = state_dim
165 | self.action_dim = action_dim
166 | self.gamma = gamma
167 | self.gae_lambda = gae_lambda
168 | self.clip_ratio = clip_ratio
169 | self.target_kl = target_kl
170 | self.value_coef = value_coef
171 | self.entropy_coef = entropy_coef
172 | self.max_grad_norm = max_grad_norm
173 | self.use_gae = use_gae
174 | self.normalize_advantages = normalize_advantages
175 | self.relative_advantage_weight = relative_advantage_weight
176 | self.device = device
177 |
178 | # Initialize networks
179 | self.policy = PolicyNetwork(state_dim, hidden_dims, action_dim).to(device)
180 | self.value = ValueNetwork(state_dim, hidden_dims).to(device)
181 |
182 | # Initialize optimizers
183 | self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=lr_policy)
184 | self.value_optimizer = optim.Adam(self.value.parameters(), lr=lr_value)
185 |
186 | # Initialize experience buffer
187 | self.buffer = ExperienceBuffer()
188 |
189 | # Group-level buffers for relative advantage computation
190 | self.group_rewards = []
191 | self.agent_id = None # Will be set when joining a group
192 |
193 | def set_agent_id(self, agent_id: str) -> None:
194 | """Set the agent's ID within the group."""
195 | self.agent_id = agent_id
196 |
197 | def get_action(self, state: np.ndarray, deterministic: bool = False) -> Tuple[int, float]:
198 | """
199 | Get an action from the policy for the given state.
200 |
201 | Args:
202 | state: The current state
203 | deterministic: Whether to return the most likely action
204 |
205 | Returns:
206 | Tuple of (action, log probability)
207 | """
208 | # Convert state to tensor
209 | state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
210 |
211 | # Get action distributions
212 | with torch.no_grad():
213 | logits = self.policy(state_tensor)
214 | distribution = Categorical(logits=logits)
215 |
216 | if deterministic:
217 | action = torch.argmax(logits, dim=1).item()
218 | else:
219 | action = distribution.sample().item()
220 |
221 | log_prob = distribution.log_prob(torch.tensor(action, device=self.device)).item()
222 |
223 | return action, log_prob
224 |
225 | def get_value(self, state: np.ndarray) -> float:
226 | """
227 | Get the estimated value of a state.
228 |
229 | Args:
230 | state: The state to evaluate
231 |
232 | Returns:
233 | The estimated value
234 | """
235 | # Convert state to tensor
236 | state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
237 |
238 | # Get value estimate
239 | with torch.no_grad():
240 | value = self.value(state_tensor).item()
241 |
242 | return value
243 |
244 | def learn(
245 | self,
246 | batch_size: int = 64,
247 | epochs: int = 10,
248 | group_rewards: Optional[Dict[str, List[float]]] = None
249 | ) -> Dict[str, float]:
250 | """
251 | Update policy and value networks based on collected experience.
252 |
253 | Args:
254 | batch_size: Size of batches to use for updates
255 | epochs: Number of epochs to train for
256 | group_rewards: Rewards collected by all agents in the group
257 |
258 | Returns:
259 | Dictionary of training metrics
260 | """
261 | if len(self.buffer) < batch_size:
262 | return {"policy_loss": 0, "value_loss": 0, "kl": 0}
263 |
264 | # Prepare data for training
265 | states, actions, old_log_probs, returns, advantages = self._prepare_training_data(
266 | group_rewards)
267 |
268 | # Training metrics
269 | metrics = {
270 | "policy_loss": 0,
271 | "value_loss": 0,
272 | "entropy": 0,
273 | "kl": 0,
274 | }
275 |
276 | # Run training for multiple epochs
277 | for epoch in range(epochs):
278 | # Generate random indices for batching
279 | indices = np.random.permutation(len(states))
280 |
281 | # Process in batches
282 | for start_idx in range(0, len(states), batch_size):
283 | # Get batch indices
284 | batch_indices = indices[start_idx:start_idx + batch_size]
285 |
286 | # Extract batch data
287 | batch_states = states[batch_indices]
288 | batch_actions = actions[batch_indices]
289 | batch_old_log_probs = old_log_probs[batch_indices]
290 | batch_returns = returns[batch_indices]
291 | batch_advantages = advantages[batch_indices]
292 |
293 | # Update policy
294 | policy_loss, entropy, kl = self._update_policy(
295 | batch_states, batch_actions, batch_old_log_probs, batch_advantages)
296 |
297 | # Early stopping based on KL divergence
298 | if kl > 1.5 * self.target_kl:
299 | break
300 |
301 | # Update value function
302 | value_loss = self._update_value(batch_states, batch_returns)
303 |
304 | # Update metrics
305 | metrics["policy_loss"] += policy_loss
306 | metrics["value_loss"] += value_loss
307 | metrics["entropy"] += entropy
308 | metrics["kl"] += kl
309 |
310 | # Check for early stopping after each epoch
311 | if metrics["kl"] / (epoch + 1) > self.target_kl:
312 | break
313 |
314 | # Normalize metrics by the nominal update count (an upper bound if KL early stopping fired)
315 | num_updates = epochs * ((len(states) + batch_size - 1) // batch_size)
316 | for key in metrics:
317 | metrics[key] /= num_updates
318 |
319 | # Clear buffer after training
320 | self.buffer = ExperienceBuffer()
321 |
322 | return metrics
323 |
324 | def _prepare_training_data(
325 | self, group_rewards: Optional[Dict[str, List[float]]] = None
326 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
327 | """
328 | Prepare data for training from the experience buffer.
329 |
330 | Args:
331 | group_rewards: Rewards collected by all agents in the group
332 |
333 | Returns:
334 | Tuple of (states, actions, old_log_probs, returns, advantages)
335 | """
336 | # Collect experiences from buffer
337 | experiences = list(self.buffer.buffer)
338 |
339 | # Extract components
340 | states = torch.FloatTensor([exp.state for exp in experiences]).to(self.device)
341 | actions = torch.LongTensor([exp.action for exp in experiences]).to(self.device)
342 | rewards = torch.FloatTensor([exp.reward for exp in experiences]).to(self.device)
343 | next_states = torch.FloatTensor([exp.next_state for exp in experiences]).to(self.device)
344 | dones = torch.FloatTensor([float(exp.done) for exp in experiences]).to(self.device)
345 |
346 | # Compute values for all states and next states
347 | with torch.no_grad():
348 | values = self.value(states).squeeze()
349 | next_values = self.value(next_states).squeeze()
350 |
351 | # Compute advantages and returns
352 | if self.use_gae:
353 | # Generalized Advantage Estimation
354 | advantages = self._compute_gae(rewards, values, next_values, dones)
355 | else:
356 | # Regular advantages
357 | advantages = rewards + self.gamma * next_values * (1 - dones) - values
358 |
359 | # Compute returns (for value function)
360 | returns = advantages + values
361 |
362 | # If group rewards are provided, compute relative advantages
363 | if group_rewards is not None and self.agent_id in group_rewards:
364 | relative_advantages = self._compute_relative_advantages(
365 | advantages, group_rewards)
366 |
367 | # Combine regular and relative advantages
368 | advantages = (1 - self.relative_advantage_weight) * advantages + \
369 | self.relative_advantage_weight * relative_advantages
370 |
371 | # Normalize advantages if enabled
372 | if self.normalize_advantages:
373 | advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
374 |
375 | # Recompute log probabilities under the not-yet-updated policy; these serve as the "old" terms in the PPO ratio
376 | old_log_probs = torch.FloatTensor(
377 | [self._compute_log_prob(exp.state, exp.action) for exp in experiences]
378 | ).to(self.device)
379 |
380 | return states, actions, old_log_probs, returns, advantages
381 |
382 | def _compute_gae(
383 | self, rewards: torch.Tensor, values: torch.Tensor,
384 | next_values: torch.Tensor, dones: torch.Tensor
385 | ) -> torch.Tensor:
386 | """
387 | Compute advantages using Generalized Advantage Estimation.
388 |
389 | Args:
390 | rewards: Batch of rewards
391 | values: Batch of state values
392 | next_values: Batch of next state values
393 | dones: Batch of done flags
394 |
395 | Returns:
396 | Batch of advantage estimates
397 | """
398 | # Initialize advantages
399 | advantages = torch.zeros_like(rewards)
400 |
401 | # Initialize gae
402 | gae = 0
403 |
404 | # Compute advantages in reverse order
405 | for t in reversed(range(len(rewards))):
406 | # Compute TD error
407 | delta = rewards[t] + self.gamma * next_values[t] * (1 - dones[t]) - values[t]
408 |
409 | # Update gae
410 | gae = delta + self.gamma * self.gae_lambda * (1 - dones[t]) * gae
411 |
412 | # Store advantage
413 | advantages[t] = gae
414 |
415 | return advantages
416 |
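# Worked example of the recursion above (illustrative, not from the file):
# with gamma=0.99, gae_lambda=0.95 and a two-step episode where
# rewards=[1.0, 0.0], values=[0.5, 0.4], next_values=[0.4, 0.0], dones=[0, 1]:
#   t=1: delta = 0.0 + 0.99*0.0*(1-1) - 0.4 = -0.400, so gae = -0.400
#   t=0: delta = 1.0 + 0.99*0.4*(1-0) - 0.5 = 0.896, so gae = 0.896 + 0.99*0.95*(1-0)*(-0.400) ≈ 0.520
# giving advantages ≈ [0.520, -0.400].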
417 | def _compute_relative_advantages(
418 | self, advantages: torch.Tensor, group_rewards: Dict[str, List[float]]
419 | ) -> torch.Tensor:
420 | """
421 | Compute relative advantages compared to other agents in the group.
422 |
423 | Args:
424 | advantages: This agent's advantages
425 | group_rewards: Rewards collected by all agents in the group
426 |
427 | Returns:
428 | Relative advantages
429 | """
430 | # Compute mean reward for each agent
431 | agent_mean_rewards = {
432 | agent_id: sum(rewards) / max(1, len(rewards))
433 | for agent_id, rewards in group_rewards.items()
434 | }
435 |
436 | # Compute mean reward across all agents
437 | group_mean_reward = sum(agent_mean_rewards.values()) / len(agent_mean_rewards)
438 |
439 | # Compute relative performance factor
440 | # Higher if this agent is doing better than the group average
441 | if self.agent_id in agent_mean_rewards:
442 | relative_factor = agent_mean_rewards[self.agent_id] / (group_mean_reward + 1e-8)
443 | else:
444 | relative_factor = 1.0
445 |
446 | # Apply the relative factor to the advantages
447 | relative_advantages = advantages * relative_factor
448 |
449 | return relative_advantages
450 |
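# Worked example (illustrative, not from the file): with
# group_rewards = {"a": [1.0, 3.0], "b": [1.0, 1.0]}, the per-agent means are
# a=2.0 and b=1.0 and the group mean is 1.5, so agent "a" gets
# relative_factor = 2.0 / 1.5 ≈ 1.33 and all of its advantages are scaled up
# by about a third, while agent "b" gets 1.0 / 1.5 ≈ 0.67.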
451 | def _compute_log_prob(self, state: np.ndarray, action: int) -> float:
452 | """
453 | Compute the log probability of an action given a state.
454 |
455 | Args:
456 | state: The state
457 | action: The action
458 |
459 | Returns:
460 | The log probability
461 | """
462 | # Convert state to tensor
463 | state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
464 |
465 | # Get action distribution
466 | with torch.no_grad():
467 | logits = self.policy(state_tensor)
468 | distribution = Categorical(logits=logits)
469 | log_prob = distribution.log_prob(torch.tensor(action, device=self.device)).item()
470 |
471 | return log_prob
472 |
473 | def _update_policy(
474 | self,
475 | states: torch.Tensor,
476 | actions: torch.Tensor,
477 | old_log_probs: torch.Tensor,
478 | advantages: torch.Tensor
479 | ) -> Tuple[float, float, float]:
480 | """
481 | Update the policy network using PPO.
482 |
483 | Args:
484 | states: Batch of states
485 | actions: Batch of actions
486 | old_log_probs: Batch of old log probabilities
487 | advantages: Batch of advantages
488 |
489 | Returns:
490 | Tuple of (policy_loss, entropy, kl_divergence)
491 | """
492 | # Get action distributions
493 | logits = self.policy(states)
494 | distribution = Categorical(logits=logits)
495 |
496 | # Get new log probabilities
497 | new_log_probs = distribution.log_prob(actions)
498 |
499 | # Compute probability ratio
500 | ratio = torch.exp(new_log_probs - old_log_probs)
501 |
502 | # Compute surrogate objectives
503 | surrogate1 = ratio * advantages
504 | surrogate2 = torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio) * advantages
505 |
506 | # Compute policy loss (negative because we're maximizing)
507 | policy_loss = -torch.min(surrogate1, surrogate2).mean()
508 |
509 | # Compute entropy bonus
510 | entropy = distribution.entropy().mean()
511 |
512 | # Add entropy bonus to loss
513 | loss = policy_loss - self.entropy_coef * entropy
514 |
515 | # Compute approximate KL divergence for monitoring
516 | with torch.no_grad():
517 | kl = (old_log_probs - new_log_probs).mean().item()
518 |
519 | # Update policy network
520 | self.policy_optimizer.zero_grad()
521 | loss.backward()
522 | nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
523 | self.policy_optimizer.step()
524 |
525 | return policy_loss.item(), entropy.item(), kl
526 |
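# Worked example of the clipped objective (illustrative, not from the file):
# with clip_ratio=0.2, ratio=1.5 and advantage=+1.0, surrogate1 = 1.5 and
# surrogate2 = clamp(1.5, 0.8, 1.2) * 1.0 = 1.2; min(1.5, 1.2) = 1.2, so any
# improvement beyond the 20% trust region is discarded. With advantage=-1.0,
# min(-1.5, -1.2) = -1.5: moving probability mass the wrong way is never clipped.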
527 | def _update_value(self, states: torch.Tensor, returns: torch.Tensor) -> float:
528 | """
529 | Update the value network.
530 |
531 | Args:
532 | states: Batch of states
533 | returns: Batch of returns
534 |
535 | Returns:
536 | Value loss
537 | """
538 | # Get value predictions
539 | values = self.value(states).squeeze()
540 |
541 | # Compute value loss
542 | value_loss = F.mse_loss(values, returns)
543 |
544 | # Update value network
545 | self.value_optimizer.zero_grad()
546 | value_loss.backward()
547 | nn.utils.clip_grad_norm_(self.value.parameters(), self.max_grad_norm)
548 | self.value_optimizer.step()
549 |
550 | return value_loss.item()
551 |
552 |
553 | class MultiAgentGroupRL:
554 | """
555 | Multi-agent reinforcement learning system using GRPO for Claude Code.
556 | This class manages multiple GRPO agents that learn in a coordinated way.
557 | """
558 |
559 | def __init__(
560 | self,
561 | agent_configs: List[Dict[str, Any]],
562 | feature_extractor: Callable[[Dict[str, Any]], np.ndarray],
563 | reward_function: Callable[[Dict[str, Any], str, Any], float],
564 | update_interval: int = 1000,
565 | training_epochs: int = 10,
566 | batch_size: int = 64,
567 | save_dir: str = "./models",
568 | device: str = "cuda" if torch.cuda.is_available() else "cpu",
569 | ):
570 | """
571 | Initialize the multi-agent RL system.
572 |
573 | Args:
574 | agent_configs: List of configurations for each agent
575 | feature_extractor: Function to extract state features
576 | reward_function: Function to compute rewards
577 | update_interval: How often to update agents (in steps)
578 | training_epochs: Number of epochs to train for each update
579 | batch_size: Batch size for training
580 | save_dir: Directory to save models
581 | device: Device to run on
582 | """
583 | self.feature_extractor = feature_extractor
584 | self.reward_function = reward_function
585 | self.update_interval = update_interval
586 | self.training_epochs = training_epochs
587 | self.batch_size = batch_size
588 | self.save_dir = save_dir
589 | self.device = device
590 |
591 | # Initialize agents
592 | self.agents = {}
593 | for config in agent_configs:
594 | agent_id = config["id"]
595 | state_dim = config["state_dim"]
596 | action_dim = config["action_dim"]
597 |
598 | # Create GRPO agent
599 | agent = GRPO(
600 | state_dim=state_dim,
601 | action_dim=action_dim,
602 | hidden_dims=config.get("hidden_dims", [64, 64]),
603 | device=device,
604 | **{k: v for k, v in config.items() if k not in ["id", "state_dim", "action_dim", "hidden_dims", "device"]}  # "device" excluded: passed explicitly above
605 | )
606 |
607 | # Set agent ID
608 | agent.set_agent_id(agent_id)
609 |
610 | self.agents[agent_id] = agent
611 |
612 | # Track steps for periodic updates
613 | self.total_steps = 0
614 |
615 | # Store rewards for relative advantage computation
616 | self.agent_rewards = {agent_id: [] for agent_id in self.agents}
617 |
618 | def select_action(
619 | self, agent_id: str, observation: Dict[str, Any], deterministic: bool = False
620 | ) -> Tuple[Any, float]:
621 | """
622 | Select an action for the specified agent.
623 |
624 | Args:
625 | agent_id: ID of the agent
626 | observation: Current observation
627 | deterministic: Whether to select deterministically
628 |
629 | Returns:
630 | Tuple of (action, log probability)
631 | """
632 | if agent_id not in self.agents:
633 | raise ValueError(f"Unknown agent ID: {agent_id}")
634 |
635 | # Extract features
636 | state = self.feature_extractor(observation)
637 |
638 | # Get action from agent
639 | action, log_prob = self.agents[agent_id].get_action(state, deterministic)
640 |
641 | return action, log_prob
642 |
643 | def observe(
644 | self,
645 | agent_id: str,
646 | observation: Dict[str, Any],
647 | action: Any,
648 | reward: float,
649 | next_observation: Dict[str, Any],
650 | done: bool,
651 | info: Optional[Dict[str, Any]] = None
652 | ) -> None:
653 | """
654 | Record an observation for the specified agent.
655 |
656 | Args:
657 | agent_id: ID of the agent
658 | observation: Current observation
659 | action: Action taken
660 | reward: Reward received
661 | next_observation: Next observation
662 | done: Whether the episode is done
663 | info: Additional information
664 | """
665 | if agent_id not in self.agents:
666 | raise ValueError(f"Unknown agent ID: {agent_id}")
667 |
668 | # Extract features
669 | state = self.feature_extractor(observation)
670 | next_state = self.feature_extractor(next_observation)
671 |
672 | # Create experience
673 | exp = Experience(
674 | state=state,
675 | action=action,
676 | reward=reward,
677 | next_state=next_state,
678 | done=done,
679 | info=info
680 | )
681 |
682 | # Add experience to agent's buffer
683 | self.agents[agent_id].buffer.add(exp)
684 |
685 | # Store reward for relative advantage computation
686 | self.agent_rewards[agent_id].append(reward)
687 |
688 | # Increment step counter
689 | self.total_steps += 1
690 |
691 | # Perform updates if needed
692 | if self.total_steps % self.update_interval == 0:
693 | self.update_all_agents()
694 |
695 | def update_all_agents(self) -> Dict[str, Dict[str, float]]:
696 | """
697 | Update all agents' policies.
698 |
699 | Returns:
700 | Dictionary of training metrics for each agent
701 | """
702 | # Store metrics for each agent
703 | metrics = {}
704 |
705 | # Update each agent
706 | for agent_id, agent in self.agents.items():
707 | # Train the agent with group rewards
708 | agent_metrics = agent.learn(
709 | batch_size=self.batch_size,
710 | epochs=self.training_epochs,
711 | group_rewards=self.agent_rewards
712 | )
713 |
714 | metrics[agent_id] = agent_metrics
715 |
716 | # Reset reward tracking
717 | self.agent_rewards = {agent_id: [] for agent_id in self.agents}
718 |
719 | return metrics
720 |
721 | def save_agents(self, suffix: str = "") -> None:
722 | """
723 | Save all agents' models.
724 |
725 | Args:
726 | suffix: Optional suffix for saved files
727 | """
728 | import os
729 |
730 | # Create save directory if it doesn't exist
731 | os.makedirs(self.save_dir, exist_ok=True)
732 |
733 | # Save each agent
734 | for agent_id, agent in self.agents.items():
735 | # Create file path
736 | file_path = os.path.join(self.save_dir, f"{agent_id}{suffix}.pt")
737 |
738 | # Save model
739 | torch.save({
740 | "policy_state_dict": agent.policy.state_dict(),
741 | "value_state_dict": agent.value.state_dict(),
742 | "policy_optimizer_state_dict": agent.policy_optimizer.state_dict(),
743 | "value_optimizer_state_dict": agent.value_optimizer.state_dict(),
744 | }, file_path)
745 |
746 | def load_agents(self, suffix: str = "") -> None:
747 | """
748 | Load all agents' models.
749 |
750 | Args:
751 | suffix: Optional suffix for loaded files
752 | """
753 | import os
754 |
755 | # Load each agent
756 | for agent_id, agent in self.agents.items():
757 | # Create file path
758 | file_path = os.path.join(self.save_dir, f"{agent_id}{suffix}.pt")
759 |
760 | # Check if file exists
761 | if not os.path.exists(file_path):
762 | print(f"Warning: Model file not found for agent {agent_id}")
763 | continue
764 |
765 | # Load model
766 | checkpoint = torch.load(file_path, map_location=self.device)
767 |
768 | # Load state dicts
769 | agent.policy.load_state_dict(checkpoint["policy_state_dict"])
770 | agent.value.load_state_dict(checkpoint["value_state_dict"])
771 | agent.policy_optimizer.load_state_dict(checkpoint["policy_optimizer_state_dict"])
772 | agent.value_optimizer.load_state_dict(checkpoint["value_optimizer_state_dict"])
773 |
774 |
775 | class ToolSelectionGRPO:
776 | """
777 | Specialized GRPO implementation for tool selection in Claude Code.
778 | This class adapts the MultiAgentGroupRL for the specific context of tool selection.
779 | """
780 |
781 | def __init__(
782 | self,
783 | tool_registry: Any, # Should be a reference to the tool registry
784 | context_evaluator: Callable, # Function to evaluate quality of response given context
785 | state_dim: int = 768, # Embedding dimension for query
786 | num_agents: int = 3, # Number of agents in the group
787 | update_interval: int = 100,
788 | device: str = "cuda" if torch.cuda.is_available() else "cpu",
789 | ):
790 | """
791 | Initialize the GRPO tool selector.
792 |
793 | Args:
794 | tool_registry: Registry containing available tools
795 | context_evaluator: Function to evaluate response quality
796 | state_dim: Dimension of state features
797 | num_agents: Number of agents in the group
798 | update_interval: How often to update agents
799 | device: Device to run on
800 | """
801 | self.tool_registry = tool_registry
802 | self.context_evaluator = context_evaluator
803 |
804 | # Get all available tools
805 | self.tool_names = tool_registry.get_all_tool_names()
806 | self.action_dim = len(self.tool_names)
807 |
808 | # Define agent configurations
809 | agent_configs = [
810 | {
811 | "id": f"tool_agent_{i}",
812 | "state_dim": state_dim,
813 | "action_dim": self.action_dim,
814 | "hidden_dims": [256, 128],
815 | "relative_advantage_weight": 0.7 if i > 0 else 0.3, # Different weights
816 | "entropy_coef": 0.02 if i == 0 else 0.01, # Different exploration rates
817 | }
818 | for i in range(num_agents)
819 | ]
820 |
821 | # Initialize multi-agent RL system
822 | self.rl_system = MultiAgentGroupRL(
823 | agent_configs=agent_configs,
824 | feature_extractor=self._extract_features,
825 | reward_function=self._compute_reward,
826 | update_interval=update_interval,
827 | device=device,
828 | )
829 |
830 | # Track current episode
831 | self.current_episode = {agent_id: {} for agent_id in self.rl_system.agents}
832 |
833 | def select_tool(self, user_query: str, context: Dict[str, Any], visualizer=None) -> str:
834 | """
835 | Select the best tool to use for a given user query and context.
836 |
837 | Args:
838 | user_query: The user's query
839 | context: The current conversation context
840 | visualizer: Optional visualizer to display the selection process
841 |
842 | Returns:
843 | The name of the best tool to use
844 | """
845 | # Create observation
846 | observation = {
847 | "query": user_query,
848 | "context": context,
849 | }
850 |
851 | # If visualizer is provided, start it
852 | if visualizer:
853 | visualizer.start()
854 | visualizer.add_execution(
855 | execution_id="tool_selection",
856 | tool_name="GRPO Tool Selection",
857 | parameters={"query": user_query[:100] + "..." if len(user_query) > 100 else user_query}
858 | )
859 |
860 | # Select agent to use (round-robin for now)
861 | agent_id = f"tool_agent_{self.rl_system.total_steps % len(self.rl_system.agents)}"
862 |
863 | # Update visualizer if provided
864 | if visualizer:
865 | visualizer.update_progress("tool_selection", 0.3)
866 |
867 | # Get action from agent
868 | action_idx, _ = self.rl_system.select_action(
869 | agent_id=agent_id,
870 | observation=observation,
871 | deterministic=False # Use exploratory actions during learning
872 | )
873 |
874 | # Update visualizer if provided
875 | if visualizer:
876 | visualizer.update_progress("tool_selection", 0.6)
877 |
878 | # Store initial information for the episode
879 | self.current_episode[agent_id] = {
880 | "observation": observation,
881 | "action_idx": action_idx,
882 | "initial_quality": self.context_evaluator(context),
883 | }
884 |
885 | # Map action index to tool name
886 | tool_name = self.tool_names[action_idx]
887 |
888 | # Complete visualization if provided
889 | if visualizer:
890 | # Create detailed metrics for visualization
891 | agent_data = {}
892 | for aid, agent in self.rl_system.agents.items():
893 | # Get all tool probabilities for this agent
894 | with torch.no_grad():
895 | state = self._extract_features(observation)
896 | state_tensor = torch.FloatTensor(state).unsqueeze(0).to(agent.device)
897 | logits = agent.policy(state_tensor)
898 | probs = F.softmax(logits, dim=1).squeeze().cpu().numpy()
899 |
900 | # Add to metrics
901 | agent_data[aid] = {
902 | "selected": aid == agent_id,
903 | "tool_probabilities": {
904 | self.tool_names[i]: float(prob)
905 | for i, prob in enumerate(probs)
906 | }
907 | }
908 |
909 | # Complete the visualization
910 | visualizer.complete_execution(
911 | execution_id="tool_selection",
912 | result={
913 | "selected_tool": tool_name,
914 | "selected_agent": agent_id,
915 | "agent_data": agent_data
916 | },
917 | status="success"
918 | )
919 | visualizer.stop()
920 |
921 | return tool_name
922 |
923 | def observe_result(
924 | self, agent_id: str, result: Any, context: Dict[str, Any], done: bool = True
925 | ) -> None:
926 | """
927 | Observe the result of using a tool.
928 |
929 | Args:
930 | agent_id: The ID of the agent that selected the tool
931 | result: The result of using the tool
932 | context: The updated context after using the tool
933 | done: Whether the interaction is complete
934 | """
935 | if agent_id not in self.current_episode:
936 | return
937 |
938 | # Get episode information
939 | episode = self.current_episode[agent_id]
940 | observation = episode["observation"]
941 | action_idx = episode["action_idx"]
942 | initial_quality = episode["initial_quality"]
943 |
944 | # Create next observation
945 | next_observation = {
946 | "query": observation["query"],
947 | "context": context,
948 | "result": result,
949 | }
950 |
951 | # Compute reward
952 | reward = self._compute_reward(observation, action_idx, result, context, initial_quality)
953 |
954 | # Record observation
955 | self.rl_system.observe(
956 | agent_id=agent_id,
957 | observation=observation,
958 | action=action_idx,
959 | reward=reward,
960 | next_observation=next_observation,
961 | done=done,
962 | )
963 |
964 | # Clear episode if done
965 | if done:
966 | self.current_episode[agent_id] = {}
967 |
968 | def _extract_features(self, observation: Dict[str, Any]) -> np.ndarray:
969 | """Extract features from an observation."""
970 | # This would ideally use an embedding model
971 | # For now, return a random vector as a placeholder
972 | return np.random.randn(768)
973 |
974 | def _compute_reward(
975 | self,
976 | observation: Dict[str, Any],
977 | action_idx: int,
978 | result: Any,
979 | context: Dict[str, Any],
980 | initial_quality: float
981 | ) -> float:
982 | """Compute the reward for an action."""
983 | # Compute the quality improvement
984 | final_quality = self.context_evaluator(context)
985 | quality_improvement = final_quality - initial_quality
986 |
987 | # Base reward on quality improvement
988 | reward = max(0, quality_improvement * 10) # Scale for better learning
989 |
990 | return reward
991 |
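# Worked example (illustrative, not from the file): if the context evaluator
# scores the context 0.60 before the tool runs and 0.68 after, the reward is
# max(0, (0.68 - 0.60) * 10) = 0.8; quality regressions are floored at zero
# rather than punished.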
992 | def update(self) -> Dict[str, Dict[str, float]]:
993 | """
994 | Trigger an update of all agents.
995 |
996 | Returns:
997 | Dictionary of training metrics
998 | """
999 | return self.rl_system.update_all_agents()
1000 |
1001 | def save(self, suffix: str = "") -> None:
1002 | """Save all agents."""
1003 | self.rl_system.save_agents(suffix)
1004 |
1005 | def load(self, suffix: str = "") -> None:
1006 | """Load all agents."""
1007 | self.rl_system.load_agents(suffix)
```
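For orientation, here is a minimal usage sketch of the classes above. It is illustrative only, not a file from the repository: the 4-dimensional state, the two dummy agents, and the action-as-reward signal are assumptions chosen just to exercise the `select_action`/`observe` loop and the periodic training every `update_interval` steps.
```python
import numpy as np

from claude_code.lib.rl.grpo import MultiAgentGroupRL

# Toy feature extractor; the real system would embed the observation text.
def extract_features(obs):
    return np.asarray(obs["features"], dtype=np.float32)

# Placeholder reward function (MultiAgentGroupRL stores it, but in the code
# above rewards are supplied directly to observe()).
def reward_fn(obs, agent_id, action):
    return 0.0

system = MultiAgentGroupRL(
    agent_configs=[
        {"id": "agent_0", "state_dim": 4, "action_dim": 2},
        {"id": "agent_1", "state_dim": 4, "action_dim": 2},
    ],
    feature_extractor=extract_features,
    reward_function=reward_fn,
    update_interval=8,  # update frequently so this short demo trains
    batch_size=4,
)

obs = {"features": [0.0, 0.0, 0.0, 0.0]}
for step in range(16):
    for agent_id in ("agent_0", "agent_1"):
        action, log_prob = system.select_action(agent_id, obs)
        next_obs = {"features": np.random.randn(4).tolist()}
        # Reward action 1 so the agents see a learnable (if trivial) signal.
        system.observe(agent_id, obs, action, float(action),
                       next_observation=next_obs, done=False)
        obs = next_obs
```
In the repository itself, `ToolSelectionGRPO` plays this role: it builds the agent configs, supplies `_extract_features` and `_compute_reward`, and drives `select_action`/`observe` from tool selections and their observed results.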
--------------------------------------------------------------------------------
/claude_code/lib/rl/mcts.py:
--------------------------------------------------------------------------------
```python
1 | """
2 | Monte Carlo Tree Search implementation for decision making in Claude Code.
3 | This module provides an advanced MCTS implementation that can be used to select
4 | optimal actions/tools based on simulated outcomes.
5 | """
6 |
7 | import math
8 | import numpy as np
9 | import random
10 | from typing import List, Dict, Any, Callable, Tuple, Optional, Union
11 | from dataclasses import dataclass
12 |
13 |
14 | @dataclass
15 | class MCTSNode:
16 | """Represents a node in the Monte Carlo search tree."""
17 | state: Any
18 | parent: Optional['MCTSNode'] = None
19 | action_taken: Any = None
20 | visits: int = 0
21 | value: float = 0.0
22 | children: Optional[Dict[Any, 'MCTSNode']] = None
23 |
24 | def __post_init__(self):
25 | if self.children is None:
26 | self.children = {}
27 |
28 | def is_fully_expanded(self, possible_actions: List[Any]) -> bool:
29 | """Check if all possible actions have been tried from this node."""
30 | return all(action in self.children for action in possible_actions)
31 |
32 | def is_terminal(self) -> bool:
33 | """Check if this node represents a terminal state."""
34 | # This should be customized based on your environment
35 | return False
36 |
37 | def best_child(self, exploration_weight: float = 1.0) -> 'MCTSNode':
38 | """Select the best child node according to UCB1 formula."""
39 | if not self.children:
40 | return None
41 |
42 | def ucb_score(child: MCTSNode) -> float:
43 | exploitation = child.value / child.visits if child.visits > 0 else 0
44 | exploration = math.sqrt(2 * math.log(self.visits) / child.visits) if child.visits > 0 else float('inf')
45 | return exploitation + exploration_weight * exploration
46 |
47 | return max(self.children.values(), key=ucb_score)
48 |
49 |
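# Worked UCB1 example (illustrative, not from the file): for a parent with
# visits=100 and a child with value=6.0 accumulated over visits=10,
# exploitation = 6.0/10 = 0.6 and exploration = sqrt(2*ln(100)/10) ≈ 0.96, so
# with exploration_weight=1.0 the child scores ≈ 1.56. An unvisited sibling
# scores inf and is therefore always tried before any child is revisited.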
50 | class AdvancedMCTS:
51 | """
52 | Advanced Monte Carlo Tree Search implementation with various enhancements:
53 | - Progressive widening for large/continuous action spaces
54 | - RAVE (Rapid Action Value Estimation)
55 | - Injectable state evaluator, action generator, and simulator
56 | - Configurable iteration budget and exploration weight
57 | - Customizable simulation and backpropagation strategies
58 | """
59 |
60 | def __init__(
61 | self,
62 | state_evaluator: Callable[[Any], float],
63 | action_generator: Callable[[Any], List[Any]],
64 | simulator: Callable[[Any, Any], Any],
65 | max_iterations: int = 1000,
66 | exploration_weight: float = 1.0,
67 | time_limit: Optional[float] = None,
68 | progressive_widening: bool = False,
69 | pw_coef: float = 0.5,
70 | pw_power: float = 0.5,
71 | use_rave: bool = False,
72 | rave_equiv_param: float = 1000,
73 | ):
74 | """
75 | Initialize the MCTS algorithm.
76 |
77 | Args:
78 | state_evaluator: Function to evaluate the value of a state (terminal or not)
79 | action_generator: Function to generate possible actions from a state
80 | simulator: Function to simulate taking an action in a state, returning new state
81 | max_iterations: Maximum number of search iterations
82 | exploration_weight: Controls exploration vs exploitation balance
83 | time_limit: Optional time limit for search in seconds (stored but not yet enforced by search())
84 | progressive_widening: Whether to use progressive widening for large action spaces
85 | pw_coef: Coefficient for progressive widening
86 | pw_power: Power for progressive widening
87 | use_rave: Whether to use RAVE (Rapid Action Value Estimation)
88 | rave_equiv_param: RAVE equivalence parameter
89 | """
90 | self.state_evaluator = state_evaluator
91 | self.action_generator = action_generator
92 | self.simulator = simulator
93 | self.max_iterations = max_iterations
94 | self.exploration_weight = exploration_weight
95 | self.time_limit = time_limit
96 |
97 | # Progressive widening parameters
98 | self.progressive_widening = progressive_widening
99 | self.pw_coef = pw_coef
100 | self.pw_power = pw_power
101 |
102 | # RAVE parameters
103 | self.use_rave = use_rave
104 | self.rave_equiv_param = rave_equiv_param
105 | self.rave_values = {} # (state, action) -> (value, visits)
106 |
107 | def search(self, initial_state: Any, visualizer=None) -> Any:
108 | """
109 | Perform MCTS search from the initial state and return the best action.
110 |
111 | Args:
112 | initial_state: The starting state for the search
113 | visualizer: Optional visualizer to show progress
114 |
115 | Returns:
116 | The best action found by the search
117 | """
118 | root = MCTSNode(state=initial_state)
119 |
120 | # Initialize visualizer if provided
121 | if visualizer:
122 | visualizer.set_search_parameters(root, self.max_iterations)
123 |
124 | # Run iterations of the MCTS algorithm
125 | for iteration in range(self.max_iterations):
126 | # Selection phase
127 | selected_node = self._select(root)
128 |
129 | # Expansion phase (if not terminal)
130 | expanded_node = None
131 | if not selected_node.is_terminal():
132 | expanded_node = self._expand(selected_node)
133 | else:
134 | expanded_node = selected_node
135 |
136 | # Simulation phase
137 | simulation_path = []
138 | if visualizer:
139 | # Track simulation path for visualization
140 | current = expanded_node
141 | # Walk back to the root, collecting (parent_state, action) pairs in order
142 | while current.parent:
143 | simulation_path.insert(0, (current.parent.state, current.action_taken))
144 | current = current.parent
145 |
146 | simulation_result = self._simulate(expanded_node)
147 |
148 | # Backpropagation phase
149 | self._backpropagate(expanded_node, simulation_result)
150 |
151 | # Update visualization
152 | if visualizer:
153 | # Find current best action
154 | best_action = None
155 | if root.children:
156 | best_action = max(root.children.items(), key=lambda x: x[1].visits)[0]
157 |
158 | # Update visualizer
159 | visualizer.update_iteration(
160 | iteration=iteration + 1,
161 | selected_node=selected_node,
162 | expanded_node=expanded_node,
163 | simulation_path=simulation_path,
164 | simulation_result=simulation_result,
165 | best_action=best_action
166 | )
167 |
168 | # Return the action that leads to the most-visited child (the robust-child criterion)
169 | if not root.children:
170 | possible_actions = self.action_generator(root.state)
171 | if possible_actions:
172 | best_action = random.choice(possible_actions)
173 | if visualizer:
174 | visualizer.update_iteration(
175 | iteration=self.max_iterations,
176 | best_action=best_action
177 | )
178 | return best_action
179 | return None
180 |
181 | best_action = max(root.children.items(), key=lambda x: x[1].visits)[0]
182 | if visualizer:
183 | visualizer.update_iteration(
184 | iteration=self.max_iterations,
185 | best_action=best_action
186 | )
187 | return best_action
188 |
189 | def _select(self, node: MCTSNode) -> MCTSNode:
190 | """
191 | Select a node to expand using UCB1 and progressive widening if enabled.
192 |
193 | Args:
194 | node: The current node
195 |
196 | Returns:
197 | The selected node for expansion
198 | """
199 | while not node.is_terminal():
200 | possible_actions = self.action_generator(node.state)
201 |
202 | # Handle progressive widening if enabled
203 | if self.progressive_widening:
204 | max_children = max(1, int(self.pw_coef * (node.visits ** self.pw_power)))
205 | if len(node.children) < min(max_children, len(possible_actions)):
206 | return node
207 |
208 | # If not fully expanded, select this node for expansion
209 | if not node.is_fully_expanded(possible_actions):
210 | return node
211 |
212 | # Otherwise, select the best child according to UCB1
213 | node = node.best_child(self.exploration_weight)
214 | if node is None:
215 | break
216 |
217 | return node
218 |
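# Worked example of the widening threshold (illustrative, not from the file):
# with pw_coef=0.5 and pw_power=0.5, a node may have at most
# max(1, int(0.5 * visits**0.5)) children: 1 child until 16 visits, 2 from 16
# to 35, 3 from 36 to 63, and so on, so the branching factor grows with the
# square root of the visit count instead of all actions expanding at once.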
219 | def _expand(self, node: MCTSNode) -> MCTSNode:
220 | """
221 | Expand the node by selecting an untried action and creating a new child node.
222 |
223 | Args:
224 | node: The node to expand
225 |
226 | Returns:
227 | The newly created child node
228 | """
229 | possible_actions = self.action_generator(node.state)
230 | untried_actions = [a for a in possible_actions if a not in node.children]
231 |
232 | if not untried_actions:
233 | return node
234 |
235 | action = random.choice(untried_actions)
236 | new_state = self.simulator(node.state, action)
237 | child_node = MCTSNode(
238 | state=new_state,
239 | parent=node,
240 | action_taken=action
241 | )
242 | node.children[action] = child_node
243 | return child_node
244 |
245 | def _simulate(self, node: MCTSNode, depth: int = 10) -> float:
246 | """
247 | Simulate a random playout from the given node until a terminal state or max depth.
248 |
249 | Args:
250 | node: The node to start simulation from
251 | depth: Maximum simulation depth
252 |
253 | Returns:
254 | The value of the simulated outcome
255 | """
256 | state = node.state
257 | current_depth = 0
258 |
259 | # Continue simulation until we reach a terminal state or max depth
260 | while current_depth < depth:
261 | if self._is_terminal_state(state):
262 | break
263 |
264 | possible_actions = self.action_generator(state)
265 | if not possible_actions:
266 | break
267 |
268 | action = random.choice(possible_actions)
269 | state = self.simulator(state, action)
270 | current_depth += 1
271 |
272 | return self.state_evaluator(state)
273 |
274 | def _is_terminal_state(self, state: Any) -> bool:
275 | """Determine if the state is terminal."""
276 | # This should be customized based on your environment
277 | return False
278 |
279 | def _backpropagate(self, node: MCTSNode, value: float) -> None:
280 | """
281 | Backpropagate the simulation result up the tree.
282 |
283 | Args:
284 | node: The leaf node where simulation started
285 | value: The value from the simulation
286 | """
287 | while node is not None:
288 | node.visits += 1
289 | node.value += value
290 |
291 | # Update RAVE values if enabled
292 | if self.use_rave and node.parent is not None:
293 | state_hash = self._hash_state(node.parent.state)
294 | action = node.action_taken
295 | if (state_hash, action) not in self.rave_values:
296 | self.rave_values[(state_hash, action)] = [0, 0] # [value, visits]
297 | rave_value, rave_visits = self.rave_values[(state_hash, action)]
298 | self.rave_values[(state_hash, action)] = [
299 | rave_value + value,
300 | rave_visits + 1
301 | ]
302 |
303 | node = node.parent
304 |
305 | def _hash_state(self, state: Any) -> int:
306 | """Create a hash of the state for RAVE table lookups."""
307 | # This should be customized based on your state representation
308 | if hasattr(state, "__hash__"):
309 | return hash(state)
310 | return hash(str(state))
311 |
312 |
313 | class MCTSToolSelector:
314 | """
315 | Specialized MCTS implementation for selecting optimal tools in Claude Code.
316 | This class adapts the AdvancedMCTS for the specific context of tool selection.
317 | """
318 |
319 | def __init__(
320 | self,
321 | tool_registry: Any, # Should be a reference to the tool registry
322 | context_evaluator: Callable, # Function to evaluate quality of response given context
323 | max_iterations: int = 200,
324 | exploration_weight: float = 1.0,
325 | use_learning: bool = True,
326 | tool_history_weight: float = 0.7,
327 | enable_plan_generation: bool = True,
328 | use_semantic_similarity: bool = True,
329 | adaptation_rate: float = 0.05
330 | ):
331 | """
332 | Initialize the MCTS tool selector with enhanced intelligence.
333 |
334 | Args:
335 | tool_registry: Registry containing available tools
336 | context_evaluator: Function to evaluate response quality
337 | max_iterations: Maximum search iterations
338 | exploration_weight: Controls exploration vs exploitation
339 | use_learning: Whether to use learning from past tool selections
340 | tool_history_weight: Weight given to historical tool performance
341 | enable_plan_generation: Generate complete tool sequences as plans
342 | use_semantic_similarity: Use semantic similarity for tool relevance
343 | adaptation_rate: Rate at which the system adapts to new patterns
344 | """
345 | self.tool_registry = tool_registry
346 | self.context_evaluator = context_evaluator
347 | self.use_learning = use_learning
348 | self.tool_history_weight = tool_history_weight
349 | self.enable_plan_generation = enable_plan_generation
350 | self.use_semantic_similarity = use_semantic_similarity
351 | self.adaptation_rate = adaptation_rate
352 |
353 | # Tool performance history by query type
354 | self.tool_history = {}
355 |
356 | # Tool sequence effectiveness records
357 | self.sequence_effectiveness = {}
358 |
359 | # Semantic fingerprints for tools and queries
360 | self.tool_fingerprints = {}
361 | self.query_clusters = {}
362 |
363 | # Cached simulation results for similar queries
364 | self.simulation_cache = {}
365 |
366 | # Initialize the MCTS algorithm
367 | self.mcts = AdvancedMCTS(
368 | state_evaluator=self._evaluate_state,
369 | action_generator=self._generate_actions,
370 | simulator=self._simulate_action,
371 | max_iterations=max_iterations,
372 | exploration_weight=exploration_weight,
373 | progressive_widening=True
374 | )
375 |
376 | # Initialize tool fingerprints
377 | self._initialize_tool_fingerprints()
378 |
379 | def _initialize_tool_fingerprints(self):
380 | """Initialize semantic fingerprints for all available tools."""
381 | if not self.use_semantic_similarity:
382 | return
383 |
384 | for tool_name in self.tool_registry.get_all_tool_names():
385 | tool = self.tool_registry.get_tool(tool_name)
386 | if tool and hasattr(tool, 'description'):
387 | # In a real implementation, this would compute an embedding
388 | # For now, we'll use a simple keyword extraction as a placeholder
389 | keywords = set(word.lower() for word in tool.description.split()
390 | if len(word) > 3)
391 | self.tool_fingerprints[tool_name] = {
392 | 'keywords': keywords,
393 | 'description': tool.description,
394 | 'usage_contexts': set()
395 | }
396 |
397 | def select_tool(self, user_query: str, context: Dict[str, Any], visualizer=None) -> Union[str, List[str]]:
398 | """
399 | Select the best tool to use for a given user query and context.
400 |
401 | Args:
402 | user_query: The user's query
403 | context: The current conversation context
404 | visualizer: Optional visualizer to show the selection process
405 |
406 | Returns:
407 | Either a single tool name or a sequence of tool names (if plan generation is enabled)
408 | """
409 | # Analyze query to determine its type/characteristics
410 | query_type = self._analyze_query(user_query)
411 |
412 | # Update semantic fingerprints with this query
413 | if self.use_semantic_similarity:
414 | self._update_query_clusters(user_query, query_type)
415 |
416 | initial_state = {
417 | 'query': user_query,
418 | 'query_type': query_type,
419 | 'context': context,
420 | 'actions_taken': [],
421 | 'response_quality': 0.0,
422 | 'steps_remaining': 3 if self.enable_plan_generation else 1,
423 | 'step_results': {}
424 | }
425 |
426 | # First check if we have a high-confidence cached result for similar queries
427 | cached_result = self._check_cache(user_query, query_type)
428 | if cached_result and random.random() > 0.1: # 10% random exploration
429 | if visualizer:
430 | visualizer.add_execution(
431 | execution_id="mcts_cache_hit",
432 | tool_name="MCTS Tool Selection (cached)",
433 | parameters={"query": user_query[:100] + "..." if len(user_query) > 100 else user_query}
434 | )
435 | visualizer.complete_execution(
436 | execution_id="mcts_cache_hit",
437 | result={"selected_tool": cached_result, "source": "cache"},
438 | status="success"
439 | )
440 | return cached_result
441 |
442 | # Run MCTS search
443 | best_action = self.mcts.search(initial_state, visualizer)
444 |
445 | # If plan generation is enabled, we might want to return a sequence
446 | if self.enable_plan_generation:
447 | # Extract the most promising action sequence from search
448 | plan = self._extract_plan_from_search()
449 | if plan and len(plan) > 1:
450 | # Store this plan in our cache
451 | self._cache_result(user_query, query_type, plan)
452 | return plan
453 |
454 | # Store single action in cache
455 | self._cache_result(user_query, query_type, best_action)
456 | return best_action
457 |
458 | def _analyze_query(self, query: str) -> str:
459 | """
460 | Analyze a query to determine its type and characteristics.
461 |
462 | Args:
463 | query: The user query
464 |
465 | Returns:
466 | A string identifying the query type
467 | """
468 | query_lower = query.lower()
469 |
470 | # Check for search-related queries
471 | if any(term in query_lower for term in ['find', 'search', 'where', 'look for']):
472 | return 'search'
473 |
474 | # Check for explanation queries
475 | if any(term in query_lower for term in ['explain', 'how', 'why', 'what is']):
476 | return 'explanation'
477 |
478 | # Check for file operation queries
479 | if any(term in query_lower for term in ['file', 'read', 'write', 'edit', 'create']):
480 | return 'file_operation'
481 |
482 | # Check for execution queries
483 | if any(term in query_lower for term in ['run', 'execute', 'start']):
484 | return 'execution'
485 |
486 | # Check for debugging queries
487 | if any(term in query_lower for term in ['debug', 'fix', 'error', 'problem']):
488 | return 'debugging'
489 |
490 | # Default to general
491 | return 'general'
492 |
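# Illustrative mappings for the rules above (not from the file):
# "find the config loader" -> 'search', "explain how MCTS works" -> 'explanation',
# "run the tests" -> 'execution', "fix this error" -> 'debugging'.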
493 | def _update_query_clusters(self, query: str, query_type: str):
494 | """
495 | Update query clusters with new query information.
496 |
497 | Args:
498 | query: The user query
499 | query_type: The type of query
500 | """
501 | # Extract query keywords
502 | keywords = set(word.lower() for word in query.split() if len(word) > 3)
503 |
504 | # Update query clusters
505 | if query_type not in self.query_clusters:
506 | self.query_clusters[query_type] = {
507 | 'keywords': set(),
508 | 'queries': []
509 | }
510 |
511 | # Add keywords to cluster
512 | self.query_clusters[query_type]['keywords'].update(keywords)
513 |
514 | # Add query to cluster (limit to last 50)
515 | self.query_clusters[query_type]['queries'].append(query)
516 | if len(self.query_clusters[query_type]['queries']) > 50:
517 | self.query_clusters[query_type]['queries'].pop(0)
518 |
519 | # Update tool fingerprints with these keywords
520 | for tool_name, fingerprint in self.tool_fingerprints.items():
521 | # If tool has been used successfully for this query type before
522 | if tool_name in self.tool_history.get(query_type, {}) and \
523 | self.tool_history[query_type][tool_name]['success_rate'] > 0.6:
524 | fingerprint['usage_contexts'].add(query_type)
525 |
526 | def _check_cache(self, query: str, query_type: str) -> Union[str, List[str], None]:
527 | """
528 | Check if we have a cached result for a similar query.
529 |
530 | Args:
531 | query: The user query
532 | query_type: The type of query
533 |
534 | Returns:
535 | A cached tool selection or None
536 | """
537 | if not self.use_learning or query_type not in self.tool_history:
538 | return None
539 |
540 | # Find the most successful tool for this query type
541 | type_history = self.tool_history[query_type]
542 | best_tools = sorted(
543 | [(tool, data['success_rate']) for tool, data in type_history.items()],
544 | key=lambda x: x[1],
545 | reverse=True
546 | )
547 |
548 | # Only use cache if we have a high confidence result
549 | if best_tools and best_tools[0][1] > 0.75:
550 | return best_tools[0][0]
551 |
552 | return None
553 |
554 | def _cache_result(self, query: str, query_type: str, action: Union[str, List[str]]):
555 | """
556 | Cache a result for future similar queries.
557 |
558 | Args:
559 | query: The user query
560 | query_type: The type of query
561 | action: The selected action or plan
562 | """
563 | # Store in simulation cache
564 | query_key = self._get_query_cache_key(query)
565 | self.simulation_cache[query_key] = {
566 | 'action': action,
567 | 'timestamp': self._get_timestamp(),
568 | 'query_type': query_type
569 | }
570 |
571 | # Limit cache size
572 | if len(self.simulation_cache) > 1000:
573 | # Remove oldest entries
574 | oldest_key = min(self.simulation_cache.keys(),
575 | key=lambda k: self.simulation_cache[k]['timestamp'])
576 | del self.simulation_cache[oldest_key]
577 |
578 | def _get_query_cache_key(self, query: str) -> str:
579 | """Generate a cache key for a query."""
580 | # In a real implementation, this might use a hash of query embeddings
581 | # For now, use a simple keyword approach
582 | keywords = ' '.join(sorted(set(word.lower() for word in query.split() if len(word) > 3)))
583 | return keywords[:100] # Limit key length
584 |
585 | def _get_timestamp(self):
586 | """Get current timestamp."""
587 | import time
588 | return time.time()
589 |
590 | def _evaluate_state(self, state: Dict[str, Any]) -> float:
591 | """
592 | Evaluate the quality of a state based on response quality and steps.
593 |
594 | Args:
595 | state: The current state
596 |
597 | Returns:
598 | A quality score
599 | """
600 | # Base score is the response quality
601 | score = state['response_quality']
602 |
603 | # If plan generation is enabled, we want to encourage complete plans
604 | if self.enable_plan_generation:
605 | steps_completed = len(state['actions_taken'])
606 | total_steps = steps_completed + state['steps_remaining']
607 |
608 | # Add bonus for completing more steps
609 | if total_steps > 0:
610 | step_completion_bonus = steps_completed / total_steps
611 | score += step_completion_bonus * 0.2 # 20% bonus for step completion
612 |
613 | return score
614 |
615 | def _generate_actions(self, state: Dict[str, Any]) -> List[str]:
616 | """
617 | Generate possible tool actions from the current state with intelligent filtering.
618 |
619 | Args:
620 | state: The current state
621 |
622 | Returns:
623 | List of possible actions
624 | """
625 | # Get query type
626 | query_type = state['query_type']
627 | query = state['query']
628 |
629 | # Get all available tools
630 | all_tools = set(self.tool_registry.get_all_tool_names())
631 |
632 | # Tools already used in this sequence
633 | used_tools = set(state['actions_taken'])
634 |
635 | # Remaining tools
636 | remaining_tools = all_tools - used_tools
637 |
638 | # If we're using learning, prioritize tools based on history
639 | if self.use_learning and query_type in self.tool_history:
640 | prioritized_tools = []
641 |
642 | # First, add tools that have been successful for this query type
643 | type_history = self.tool_history[query_type]
644 |
645 | # Check for successful tools
646 | for tool in remaining_tools:
647 | if tool in type_history and type_history[tool]['success_rate'] > 0.5:
648 | prioritized_tools.append(tool)
649 |
650 | # If we have at least some tools, return them
651 | if prioritized_tools and random.random() < self.tool_history_weight:
652 | return prioritized_tools
653 |
654 | # If using semantic similarity, filter by relevant tools
655 | if self.use_semantic_similarity:
656 | query_keywords = set(word.lower() for word in query.split() if len(word) > 3)
657 |
658 | # Score tools by semantic similarity to query
659 | scored_tools = []
660 | for tool in remaining_tools:
661 | if tool in self.tool_fingerprints:
662 | fingerprint = self.tool_fingerprints[tool]
663 |
664 | # Calculate keyword overlap
665 | keyword_overlap = len(query_keywords.intersection(fingerprint['keywords']))
666 |
667 | # Check if tool has been used for this query type
668 | context_match = 1.0 if query_type in fingerprint['usage_contexts'] else 0.0
669 |
670 | # Combined score
671 | score = keyword_overlap * 0.7 + context_match * 0.3
672 |
673 | scored_tools.append((tool, score))
674 |
675 | # Sort and filter tools
676 | scored_tools.sort(key=lambda x: x[1], reverse=True)
677 |
678 | # Take top half of tools if we have enough
679 | if len(scored_tools) > 2:
680 | return [t[0] for t in scored_tools[:max(2, len(scored_tools) // 2)]]
681 |
682 | # If we reach here, use all remaining tools
683 | return list(remaining_tools)
684 |
685 | def _simulate_action(self, state: Dict[str, Any], action: str) -> Dict[str, Any]:
686 | """
687 | Simulate taking an action (using a tool) in the given state with enhanced modeling.
688 |
689 | Args:
690 | state: The current state
691 | action: The tool action to simulate
692 |
693 | Returns:
694 | The new state after taking the action
695 | """
696 | # Create a new state with the action added
697 | new_state = state.copy()
698 | new_actions = state['actions_taken'].copy()
699 | new_actions.append(action)
700 | new_state['actions_taken'] = new_actions
701 |
702 | # Decrement steps remaining if using plan generation
703 | if self.enable_plan_generation and new_state['steps_remaining'] > 0:
704 | new_state['steps_remaining'] -= 1
705 |
706 | # Get query type and query
707 | query_type = state['query_type']
708 | query = state['query']
709 |
710 | # Simulate step result
711 | step_results = state['step_results'].copy()
712 | step_results[action] = self._simulate_tool_result(action, query)
713 | new_state['step_results'] = step_results
714 |
715 | # Estimate tool relevance based on learning or semantic similarity
716 | tool_relevance = self._estimate_tool_relevance(action, query, query_type)
717 |
718 | # Check for sequence effects (tools that work well together)
719 | sequence_bonus = 0.0
720 | if len(new_actions) > 1:
721 | prev_tool = new_actions[-2]
722 | sequence_key = f"{prev_tool}->{action}"
723 | if sequence_key in self.sequence_effectiveness:
724 | sequence_bonus = self.sequence_effectiveness[sequence_key] * 0.3 # 30% weight for sequence effects
725 |
726 | # Update quality based on relevance and sequence effects
727 | current_quality = state['response_quality']
728 | quality_improvement = tool_relevance + sequence_bonus
729 |
730 | # Add diminishing returns effect for additional tools
731 | if len(new_actions) > 1:
732 | diminishing_factor = 1.0 / len(new_actions)
733 | quality_improvement *= diminishing_factor
734 |
735 | new_quality = min(1.0, current_quality + quality_improvement)
736 | new_state['response_quality'] = new_quality
737 |
738 | return new_state
739 |
740 | def _simulate_tool_result(self, tool_name: str, query: str) -> Dict[str, Any]:
741 | """
742 | Simulate the result of using a tool for a query.
743 |
744 | Args:
745 | tool_name: The name of the tool
746 | query: The user query
747 |
748 | Returns:
749 | A simulated result
750 | """
751 | # In a real implementation, this would be a more sophisticated simulation
752 | return {
753 | "tool": tool_name,
754 | "success_probability": self._estimate_tool_relevance(tool_name, query),
755 | "simulated": True
756 | }
757 |
758 | def _estimate_tool_relevance(self, tool_name: str, query: str, query_type: str = None) -> float:
759 | """
760 | Estimate how relevant a tool is for a given query using history and semantics.
761 |
762 | Args:
763 | tool_name: The name of the tool
764 | query: The user query
765 | query_type: Optional query type
766 |
767 | Returns:
768 | A relevance score between 0.0 and 1.0
769 | """
770 | relevance_score = 0.0
771 |
772 | # If we have historical data for this query type
773 | if self.use_learning and query_type and query_type in self.tool_history and \
774 | tool_name in self.tool_history[query_type]:
775 |
776 | # Get historical success rate
777 | history_score = self.tool_history[query_type][tool_name]['success_rate']
778 | relevance_score += history_score * self.tool_history_weight
779 |
780 | # If we're using semantic similarity
781 | if self.use_semantic_similarity and tool_name in self.tool_fingerprints:
782 | fingerprint = self.tool_fingerprints[tool_name]
783 |
784 | # Calculate keyword overlap
785 | query_keywords = set(word.lower() for word in query.split() if len(word) > 3)
786 | keyword_overlap = len(query_keywords.intersection(fingerprint['keywords']))
787 |
788 | # Normalize by query keywords
789 | if query_keywords:
790 | semantic_score = keyword_overlap / len(query_keywords)
791 | relevance_score += semantic_score * (1.0 - self.tool_history_weight)
792 |
793 | # Ensure we have a minimum score for exploration
794 | if relevance_score < 0.1:
795 | relevance_score = 0.1 + (random.random() * 0.1) # Random boost between 0.1-0.2
796 |
797 | return relevance_score
798 |
799 |     def _extract_plan_from_search(self) -> Optional[List[str]]:
800 |         """
801 |         Extract a complete plan (tool sequence) from the search results.
802 | 
803 |         Returns:
804 |             A list of tool names representing the plan, or None if unavailable
805 |         """
806 |         # In a real implementation, this would extract the highest-value path
807 |         # from the search tree. For now, return None to indicate no plan extraction.
808 |         return None
809 |
810 | def update_tool_history(self, tool_name: str, query: str, success: bool,
811 | execution_time: float, result: Any = None):
812 | """
813 | Update the tool history with the results of using a tool.
814 |
815 | Args:
816 | tool_name: The name of the tool used
817 | query: The query the tool was used for
818 | success: Whether the tool was successful
819 | execution_time: The execution time in seconds
820 | result: Optional result of the tool execution
821 | """
822 | if not self.use_learning:
823 | return
824 |
825 | # Get query type
826 | query_type = self._analyze_query(query)
827 |
828 | # Initialize history entry if needed
829 | if query_type not in self.tool_history:
830 | self.tool_history[query_type] = {}
831 |
832 | if tool_name not in self.tool_history[query_type]:
833 | self.tool_history[query_type][tool_name] = {
834 | 'success_count': 0,
835 | 'failure_count': 0,
836 | 'total_time': 0.0,
837 | 'success_rate': 0.0,
838 | 'avg_time': 0.0,
839 | 'examples': []
840 | }
841 |
842 | # Update history
843 | history = self.tool_history[query_type][tool_name]
844 |
845 | # Update counts
846 | if success:
847 | history['success_count'] += 1
848 | else:
849 | history['failure_count'] += 1
850 |
851 | # Update time
852 | history['total_time'] += execution_time
853 |
854 | # Update success rate
855 | total = history['success_count'] + history['failure_count']
856 | history['success_rate'] = history['success_count'] / total if total > 0 else 0.0
857 |
858 | # Update average time
859 | history['avg_time'] = history['total_time'] / total if total > 0 else 0.0
860 |
861 | # Add example (limit to last 5)
862 | history['examples'].append({
863 | 'query': query,
864 | 'success': success,
865 | 'timestamp': self._get_timestamp()
866 | })
867 | if len(history['examples']) > 5:
868 | history['examples'].pop(0)
869 |
870 | # Update tool fingerprint
871 | if self.use_semantic_similarity and tool_name in self.tool_fingerprints:
872 | if success:
873 | # Add query type to usage contexts
874 | self.tool_fingerprints[tool_name]['usage_contexts'].add(query_type)
875 |
876 | # Add query keywords to tool fingerprint (with decay)
877 | query_keywords = set(word.lower() for word in query.split() if len(word) > 3)
878 | current_keywords = self.tool_fingerprints[tool_name]['keywords']
879 |
880 | # Add new keywords with adaptation rate
881 | for keyword in query_keywords:
882 | if keyword not in current_keywords:
883 | if random.random() < self.adaptation_rate:
884 | current_keywords.add(keyword)
885 |
886 | def update_sequence_effectiveness(self, tool_sequence: List[str], success: bool, quality_score: float):
887 | """
888 | Update the effectiveness record for a sequence of tools.
889 |
890 | Args:
891 | tool_sequence: The sequence of tools used
892 | success: Whether the sequence was successful
893 | quality_score: A quality score for the sequence
894 | """
895 | if not self.use_learning or len(tool_sequence) < 2:
896 | return
897 |
898 | # Update pairwise effectiveness
899 | for i in range(len(tool_sequence) - 1):
900 | first_tool = tool_sequence[i]
901 | second_tool = tool_sequence[i + 1]
902 | sequence_key = f"{first_tool}->{second_tool}"
903 |
904 | if sequence_key not in self.sequence_effectiveness:
905 | self.sequence_effectiveness[sequence_key] = 0.5 # Initial neutral score
906 |
907 | # Update score with decay
908 | current_score = self.sequence_effectiveness[sequence_key]
909 | if success:
910 | # Increase score with quality bonus
911 | new_score = current_score + self.adaptation_rate * quality_score
912 | else:
913 | # Decrease score
914 | new_score = current_score - self.adaptation_rate
915 |
916 | # Clamp between 0 and 1
917 | self.sequence_effectiveness[sequence_key] = max(0.0, min(1.0, new_score))
```
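The two `update_*` hooks above only pay off if the host application reports outcomes back after each tool execution. Below is a minimal sketch of that feedback loop; `optimizer` (an instance of the class above) and `run_tool` (whatever executes a tool and returns a success flag plus a result) are assumed stand-ins, not names from this repository:

```python
import time


def execute_with_learning(optimizer, run_tool, tool_name: str, query: str):
    """Run a tool, then feed the outcome back into the optimizer's history."""
    start = time.time()
    success, result = run_tool(tool_name, query)  # assumed executor: (bool, Any)
    optimizer.update_tool_history(
        tool_name=tool_name,
        query=query,
        success=success,
        execution_time=time.time() - start,
        result=result,
    )
    return success, result


# After a full sequence finishes, pairwise effectiveness can be reported too,
# so the sequence bonus used during simulation reflects real usage:
#   optimizer.update_sequence_effectiveness(tools_used, success=True, quality_score=0.8)
```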
--------------------------------------------------------------------------------
/claude_code/lib/ui/tool_visualizer.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | # claude_code/lib/ui/tool_visualizer.py
3 | """Real-time tool execution visualization."""
4 |
5 | import logging
6 | import time
7 | import json
8 | from typing import Dict, List, Any, Optional
9 |
10 | from rich.console import Console, Group
11 | from rich.panel import Panel
12 | from rich.progress import Progress, BarColumn, TextColumn, TimeElapsedColumn
13 | from rich.table import Table
14 | from rich.box import ROUNDED
15 | from rich.text import Text
16 | from rich.live import Live
17 | from rich.layout import Layout
18 | from rich.syntax import Syntax
19 |
20 | from ..tools.base import ToolResult
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 |
25 | class ToolCallVisualizer:
26 | """Visualizes tool calls in real-time."""
27 |
28 | def __init__(self, console: Console):
29 | """Initialize the tool call visualizer.
30 |
31 | Args:
32 | console: Rich console instance
33 | """
34 | self.console = console
35 | self.active_calls: Dict[str, Dict[str, Any]] = {}
36 | self.completed_calls: List[Dict[str, Any]] = []
37 | self.layout = self._create_layout()
38 | self.live = Live(self.layout, console=console, refresh_per_second=4, auto_refresh=False)
39 | self.max_completed_calls = 5
40 |
41 | # Keep track of recent tool results for routines
42 | self.recent_tool_results: List[ToolResult] = []
43 | self.max_recent_results = 20 # Maximum number of recent results to track
44 |
45 | def _create_layout(self) -> Layout:
46 | """Create the layout for the tool call visualization.
47 |
48 | Returns:
49 | Layout object
50 | """
51 | layout = Layout()
52 | layout.split(
53 | Layout(name="active", size=3),
54 | Layout(name="completed", size=3)
55 | )
56 | return layout
57 |
58 | def _create_active_calls_panel(self) -> Panel:
59 | """Create a panel with active tool calls.
60 |
61 | Returns:
62 | Panel with active call information
63 | """
64 | if not self.active_calls:
65 | return Panel(
66 | "No active tool calls",
67 | title="[bold blue]Active Tool Calls[/bold blue]",
68 | border_style="blue",
69 | box=ROUNDED
70 | )
71 |
72 | # Create progress bars for each active call
73 | progress = Progress(
74 | TextColumn("[bold blue]{task.description}"),
75 | BarColumn(),
76 | TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
77 | TimeElapsedColumn(),
78 | expand=True,
79 | console=self.console
80 | )
81 |
 82 |         # Add a task for each active call. The Progress instance is rebuilt
 83 |         # on every refresh, so task IDs cached from a previous render would be
 84 |         # stale; re-add each task with its current completion instead.
 85 |         for call_id, call_info in self.active_calls.items():
 86 |             description = f"{call_info['tool_name']} ({call_id[:6]}...)"
 87 |             progress.add_task(
 88 |                 description,
 89 |                 total=100,
 90 |                 completed=int(call_info["progress"] * 100)
 91 |             )
92 |
93 | # Create a table with parameter information
94 | table = Table(show_header=True, header_style="bold cyan", box=ROUNDED, expand=True)
95 | table.add_column("Tool")
96 | table.add_column("Parameters")
97 |
98 | for call_id, call_info in self.active_calls.items():
99 | # Format parameters nicely
100 | params = call_info.get("parameters", {})
101 | if params:
102 | formatted_params = "\n".join([f"{k}: {self._format_value(v)}" for k, v in params.items()])
103 | else:
104 | formatted_params = "None"
105 |
106 | table.add_row(call_info["tool_name"], formatted_params)
107 |
108 |         # Render the progress bars together with the parameter table
109 |         return Panel(
110 |             Group(progress, table),
111 |             title="[bold blue]Active Tool Calls[/bold blue]",
112 |             border_style="blue", box=ROUNDED
113 |         )
114 |
115 | def _create_completed_calls_panel(self) -> Panel:
116 | """Create a panel with completed tool calls.
117 |
118 | Returns:
119 | Panel with completed call information
120 | """
121 | if not self.completed_calls:
122 | return Panel(
123 | "No completed tool calls",
124 | title="[bold green]Recent Tool Results[/bold green]",
125 | border_style="green",
126 | box=ROUNDED
127 | )
128 |
129 | # Create a table for results
130 | table = Table(show_header=True, header_style="bold green", box=ROUNDED, expand=True)
131 | table.add_column("Tool")
132 | table.add_column("Status")
133 | table.add_column("Time")
134 | table.add_column("Result Preview")
135 |
136 | # Show only the most recent completed calls
137 | for call_info in self.completed_calls[-self.max_completed_calls:]:
138 | tool_name = call_info["tool_name"]
139 | status = call_info["status"]
140 | execution_time = f"{call_info['execution_time']:.2f}s"
141 |
142 | # Format result preview
143 | result = call_info.get("result", "")
144 | if result:
145 | # Truncate and format result
146 | preview = self._format_result_preview(result, tool_name)
147 | else:
148 | preview = "No result"
149 |
150 | # Status with color
151 | status_text = Text(status)
152 | if status == "success":
153 | status_text.stylize("bold green")
154 | else:
155 | status_text.stylize("bold red")
156 |
157 | table.add_row(tool_name, status_text, execution_time, preview)
158 |
159 | return Panel(
160 | table,
161 | title="[bold green]Recent Tool Results[/bold green]",
162 | border_style="green",
163 | box=ROUNDED
164 | )
165 |
166 | def _format_value(self, value: Any) -> str:
167 | """Format a parameter value for display.
168 |
169 | Args:
170 | value: Parameter value
171 |
172 | Returns:
173 | Formatted string
174 | """
175 | if isinstance(value, (dict, list)):
176 | # Convert complex structures to JSON with indentation
177 | return json.dumps(value, indent=2)
178 | return str(value)
179 |
180 | def _format_result_preview(self, result: str, tool_name: str) -> str:
181 | """Format a result preview.
182 |
183 | Args:
184 | result: Result string
185 | tool_name: Name of the tool
186 |
187 | Returns:
188 | Formatted preview string
189 | """
190 | # Truncate result for preview
191 | if len(result) > 200:
192 | preview = result[:200] + "..."
193 | else:
194 | preview = result
195 |
196 | # Clean up newlines for display
197 | preview = preview.replace("\n", "\\n")
198 |
199 | return preview
200 |
201 | def start(self) -> None:
202 | """Start the visualization."""
203 | self.live.start()
204 | self.refresh()
205 |
206 | def stop(self) -> None:
207 | """Stop the visualization."""
208 | self.live.stop()
209 |
210 | def refresh(self) -> None:
211 | """Refresh the visualization."""
212 | # Update the layout with current information
213 | self.layout["active"].update(self._create_active_calls_panel())
214 | self.layout["completed"].update(self._create_completed_calls_panel())
215 |
216 | # Refresh the live display
217 | self.live.refresh()
218 |
219 | def add_tool_call(self, tool_call_id: str, tool_name: str, parameters: Dict[str, Any]) -> None:
220 | """Add a new tool call to visualize.
221 |
222 | Args:
223 | tool_call_id: ID of the tool call
224 | tool_name: Name of the tool
225 | parameters: Tool parameters
226 | """
227 | self.active_calls[tool_call_id] = {
228 | "tool_name": tool_name,
229 | "parameters": parameters,
230 | "start_time": time.time(),
231 | "progress": 0.0
232 | }
233 | self.refresh()
234 |
235 | def update_progress(self, tool_call_id: str, progress: float) -> None:
236 | """Update the progress of a tool call.
237 |
238 | Args:
239 | tool_call_id: ID of the tool call
240 | progress: Progress value (0-1)
241 | """
242 | if tool_call_id in self.active_calls:
243 | self.active_calls[tool_call_id]["progress"] = progress
244 | self.refresh()
245 |
246 | def complete_tool_call(self, tool_call_id: str, result: ToolResult) -> None:
247 | """Mark a tool call as complete.
248 |
249 | Args:
250 | tool_call_id: ID of the tool call
251 | result: Tool execution result
252 | """
253 | if tool_call_id in self.active_calls:
254 | call_info = self.active_calls[tool_call_id].copy()
255 |
256 | # Add result information
257 | call_info["result"] = result.result
258 | call_info["status"] = result.status
259 | call_info["execution_time"] = result.execution_time
260 | call_info["end_time"] = time.time()
261 |
262 | # Add to completed calls
263 | self.completed_calls.append(call_info)
264 |
265 | # Trim completed calls if needed
266 | if len(self.completed_calls) > self.max_completed_calls * 2:
267 | self.completed_calls = self.completed_calls[-self.max_completed_calls:]
268 |
269 | # Remove from active calls
270 | del self.active_calls[tool_call_id]
271 |
272 | # Store in recent tool results for routines
273 | if result.status == "success":
274 | self.recent_tool_results.append(result)
275 | # Keep only the most recent results
276 | if len(self.recent_tool_results) > self.max_recent_results:
277 | self.recent_tool_results.pop(0)
278 |
279 | self.refresh()
280 |
281 | def show_result_detail(self, result: ToolResult) -> None:
282 | """Display detailed result information.
283 |
284 | Args:
285 | result: Tool execution result
286 | """
287 | # Detect if result might be code
288 | content = result.result
289 | if content.startswith(("def ", "class ", "import ", "from ")) or "```" in content:
290 | # Try to extract code blocks
291 | if "```" in content:
292 | blocks = content.split("```")
293 | # Find a code block with a language specifier
294 | for i in range(1, len(blocks), 2):
295 | if i < len(blocks):
296 | lang = blocks[i].split("\n")[0].strip()
297 | code = "\n".join(blocks[i].split("\n")[1:])
298 | if lang and code:
299 | # Attempt to display as syntax-highlighted code
300 | try:
301 | syntax = Syntax(code, lang, theme="monokai", line_numbers=True)
302 | self.console.print(Panel(syntax, title=f"[bold]Result: {result.name}[/bold]"))
303 | return
304 | except Exception:
305 | pass
306 |
307 | # If we can't extract a code block, try to detect language
308 | for lang in ["python", "javascript", "bash", "json"]:
309 | try:
310 | syntax = Syntax(content, lang, theme="monokai", line_numbers=True)
311 | self.console.print(Panel(syntax, title=f"[bold]Result: {result.name}[/bold]"))
312 | return
313 | except Exception:
314 | pass
315 |
316 | # Just print as regular text if not code or if highlighting failed
317 | self.console.print(Panel(content, title=f"[bold]Result: {result.name}[/bold]"))
318 |
319 |
320 | class MCTSVisualizer:
321 | """Visualizes the Monte Carlo Tree Search process in real-time with enhanced intelligence."""
322 |
323 | def __init__(self, console: Console):
324 | """Initialize the MCTS visualizer.
325 |
326 | Args:
327 | console: Rich console instance
328 | """
329 | self.console = console
330 | self.root_node = None
331 | self.current_iteration = 0
332 | self.max_iterations = 0
333 |         self.best_action = None
334 |         self.active_simulation = self.selected_node = self.expanded_node = self.simulation_result = None
335 |         self.simulation_path = []
336 | self.layout = self._create_layout()
337 | self.live = Live(self.layout, console=console, refresh_per_second=10, auto_refresh=False)
338 |
339 | # Intelligence enhancement - track history
340 | self.action_history = {} # Track action performance over time
341 | self.visit_distribution = {} # Track how visits are distributed
342 | self.exploration_patterns = [] # Track exploration patterns
343 | self.quality_metrics = {"search_efficiency": 0.0, "exploration_balance": 0.0}
344 | self.auto_improvement_enabled = True
345 |
346 | def _create_layout(self) -> Layout:
347 | """Create the layout for MCTS visualization.
348 |
349 | Returns:
350 | Layout object
351 | """
352 | layout = Layout()
353 |
354 | # Create the main sections with more detailed visualization
355 | layout.split(
356 | Layout(name="header", size=3),
357 | Layout(name="main"),
358 | Layout(name="intelligence", size=7), # New section for intelligence metrics
359 | Layout(name="stats", size=5)
360 | )
361 |
362 | # Split the main section into tree, simulation and action insights
363 | layout["main"].split_row(
364 | Layout(name="tree", ratio=2),
365 | Layout(name="simulation", ratio=1),
366 | Layout(name="insights", ratio=1) # New section for action insights
367 | )
368 |
369 | return layout
370 |
371 | def set_search_parameters(self, root_node: Any, max_iterations: int, additional_params: Dict[str, Any] = None) -> None:
372 | """Set the search parameters with enhanced intelligence options.
373 |
374 | Args:
375 | root_node: The root node of the search tree
376 | max_iterations: Maximum number of iterations
377 | additional_params: Additional parameters for intelligent search
378 | """
379 | self.root_node = root_node
380 | self.max_iterations = max_iterations
381 | self.current_iteration = 0
382 |
383 | # Initialize intelligence tracking
384 | self.action_history = {}
385 | self.visit_distribution = {}
386 | self.exploration_patterns = []
387 |
388 | # Set additional intelligence parameters
389 | if additional_params:
390 | self.auto_improvement_enabled = additional_params.get('auto_improvement', True)
391 |
392 | # Apply any initial intelligence strategies
393 | if additional_params.get('initial_action_bias'):
394 | self.action_history = additional_params['initial_action_bias']
395 |
396 | self.refresh()
397 |
398 | def update_iteration(self, iteration: int, selected_node: Any = None,
399 | expanded_node: Any = None, simulation_path: List[Any] = None,
400 | simulation_result: float = None, best_action: Any = None,
401 | node_values: Dict[str, float] = None) -> None:
402 | """Update the current iteration status with enhanced tracking.
403 |
404 | Args:
405 | iteration: Current iteration number
406 | selected_node: Node selected in this iteration
407 | expanded_node: Node expanded in this iteration
408 | simulation_path: Path of the simulation
409 | simulation_result: Result of the simulation
410 | best_action: Current best action
411 | node_values: Values of important nodes in the search (for visualization)
412 | """
413 | self.current_iteration = iteration
414 | self.selected_node = selected_node
415 | self.expanded_node = expanded_node
416 | self.simulation_path = simulation_path or []
417 | self.simulation_result = simulation_result
418 |
419 | if best_action is not None:
420 | self.best_action = best_action
421 |
422 | # Intelligence tracking - update action history
423 | if self.simulation_path and simulation_result is not None:
424 | for _, action in self.simulation_path:
425 | if action is not None:
426 | action_str = str(action)
427 | if action_str not in self.action_history:
428 | self.action_history[action_str] = {
429 | "visits": 0,
430 | "total_value": 0.0,
431 | "iterations": []
432 | }
433 |
434 | self.action_history[action_str]["visits"] += 1
435 | self.action_history[action_str]["total_value"] += simulation_result
436 | self.action_history[action_str]["iterations"].append(iteration)
437 |
438 | # Update exploration pattern
439 | if selected_node:
440 | # Record exploration choice
441 | self.exploration_patterns.append({
442 | "iteration": iteration,
443 | "node_depth": self._get_node_depth(selected_node),
444 | "node_breadth": len(getattr(selected_node, "children", {})),
445 | "value_estimate": getattr(selected_node, "value", 0) / max(1, getattr(selected_node, "visits", 1))
446 | })
447 |
448 | # Update visit distribution
449 | if self.root_node and hasattr(self.root_node, "children"):
450 | self._update_visit_distribution()
451 |
452 | # Update quality metrics
453 | self._update_quality_metrics()
454 |
455 | self.refresh()
456 |
457 | def start(self) -> None:
458 | """Start the visualization."""
459 | self.live.start()
460 | self.refresh()
461 |
462 | def stop(self) -> None:
463 | """Stop the visualization."""
464 | self.live.stop()
465 |
466 | def refresh(self) -> None:
467 | """Refresh the visualization."""
468 | # Update header
469 | header_content = f"[bold blue]Enhanced Monte Carlo Tree Search - Iteration {self.current_iteration}/{self.max_iterations}[/bold blue]"
470 | if self.best_action:
471 | header_content += f" | Best Action: {self.best_action}"
472 |
473 | intelligence_status = "[green]Enabled[/green]" if self.auto_improvement_enabled else "[yellow]Disabled[/yellow]"
474 | header_content += f" | Intelligent Search: {intelligence_status}"
475 |
476 | self.layout["header"].update(Panel(header_content, border_style="blue"))
477 |
478 | # Update tree visualization
479 | self.layout["tree"].update(self._create_tree_panel())
480 |
481 | # Update simulation visualization
482 | self.layout["simulation"].update(self._create_simulation_panel())
483 |
484 | # Update action insights panel
485 | self.layout["insights"].update(self._create_insights_panel())
486 |
487 | # Update intelligence metrics
488 | self.layout["intelligence"].update(self._create_intelligence_panel())
489 |
490 | # Update stats
491 | self.layout["stats"].update(self._create_stats_panel())
492 |
493 | # Refresh the live display
494 | self.live.refresh()
495 |
496 | def _create_tree_panel(self) -> Panel:
497 | """Create a panel showing the current state of the search tree.
498 |
499 | Returns:
500 | Panel with tree visualization
501 | """
502 | if not self.root_node:
503 | return Panel("No search tree initialized", title="[bold]Search Tree[/bold]")
504 |
505 |         # Build a Rich tree to show the search tree structure
506 |         import math  # used by the nested node_score for UCB-style ranking
507 |         from rich.tree import Tree
508 | 
509 | tree = Tree("🔍 Root Node", guide_style="bold blue")
510 |
511 | # Limit the depth and breadth for display
512 | max_depth = 3
513 | max_children = 5
514 |
515 | def add_node(node, tree_node, depth=0, path=None):
516 | if depth >= max_depth or not node or not hasattr(node, "children"):
517 | return
518 |
519 | if path is None:
520 | path = []
521 |
522 | # Add children nodes
523 | children = list(node.children.items())
524 | if not children:
525 | return
526 |
527 | # Sort children by a combination of visits and value
528 | def node_score(node_pair):
529 | child_node = node_pair[1]
530 | visits = getattr(child_node, "visits", 0)
531 | value = getattr(child_node, "value", 0)
532 |
533 | # Combine visits and value for scoring
534 | if visits > 0:
535 | # Use UCB-style formula for ranking
536 | exploitation = value / visits
537 | exploration = (2 * 0.5 * (math.log(node.visits) / visits)) if node.visits > 0 and visits > 0 else 0
538 | return exploitation + exploration
539 | return 0
540 |
541 | # Sort by this smarter formula
542 | children.sort(key=node_score, reverse=True)
543 | children = children[:max_children]
544 |
545 | for action, child in children:
546 | # Format node information
547 | visits = getattr(child, "visits", 0)
548 | value = getattr(child, "value", 0)
549 |
550 | # Highlight the node with more sophisticated coloring
551 | style = ""
552 | if child == self.selected_node:
553 | style = "bold yellow"
554 | elif child == self.expanded_node:
555 | style = "bold green"
556 | else:
557 | # Color based on value
558 | if visits > 0:
559 | avg_value = value / visits
560 | if avg_value > 0.7:
561 | style = "green"
562 | elif avg_value > 0.4:
563 | style = "blue"
564 | elif avg_value > 0.2:
565 | style = "yellow"
566 | else:
567 | style = "red"
568 |
569 | # Create the node label with enhanced information
570 | current_path = path + [action]
571 |
572 | if visits > 0:
573 | avg_value = value / visits
574 | confidence = min(1.0, math.sqrt(visits) / 5) * 100 # Simple confidence estimate
575 | label = f"[{style}]{action}: (Visits: {visits}, Value: {avg_value:.3f}, Conf: {confidence:.0f}%)[/{style}]"
576 | else:
577 | label = f"[{style}]{action}: (New)[/{style}]"
578 |
579 | # Add the child node to the tree
580 | child_tree = tree_node.add(label)
581 |
582 | # Recursively add its children
583 | add_node(child, child_tree, depth + 1, current_path)
584 |
585 |         # Start building the tree from the root. math is imported at the
586 |         # top of this method so the nested node_score/add_node closures can
587 |         # resolve it when they run.
588 |         if hasattr(self.root_node, "children"):
589 |             add_node(self.root_node, tree)
590 |
591 | return Panel(tree, title="[bold]Search Tree[/bold]", border_style="blue")
592 |
593 | def _create_simulation_panel(self) -> Panel:
594 | """Create a panel showing the current simulation with enhanced analytics.
595 |
596 | Returns:
597 | Panel with simulation visualization
598 | """
599 | if not self.simulation_path:
600 | return Panel("No active simulation", title="[bold]Current Simulation[/bold]")
601 |
602 | # Create a list of simulation steps
603 | from rich.table import Table
604 |
605 | table = Table(box=None, expand=True)
606 | table.add_column("Step")
607 | table.add_column("Action")
608 | table.add_column("Expected Value") # New column
609 |
610 | for i, (state, action) in enumerate(self.simulation_path):
611 | # Get expected value for this action
612 | action_str = str(action) if action is not None else "None"
613 | expected_value = "N/A"
614 |
615 | if action_str in self.action_history:
616 | history = self.action_history[action_str]
617 | if history["visits"] > 0:
618 | expected_value = f"{history['total_value'] / history['visits']:.3f}"
619 |
620 | table.add_row(f"Step {i+1}", f"{action}", expected_value)
621 |
622 | if self.simulation_result is not None:
623 | # Add path quality metric
624 | path_quality = "Low"
625 | if self.simulation_result > 0.7:
626 | path_quality = "[bold green]High[/bold green]"
627 | elif self.simulation_result > 0.4:
628 | path_quality = "[yellow]Medium[/yellow]"
629 | else:
630 | path_quality = "[red]Low[/red]"
631 |
632 | table.add_row("Result",
633 | f"[bold green]{self.simulation_result:.3f}[/bold green]",
634 | f"Path Quality: {path_quality}")
635 |
636 | return Panel(table, title="[bold]Current Simulation[/bold]", border_style="green")
637 |
638 | def _create_insights_panel(self) -> Panel:
639 | """Create a panel showing action insights from learned patterns.
640 |
641 | Returns:
642 | Panel with action insights
643 | """
644 | from rich.table import Table
645 |
646 | if not self.action_history:
647 | return Panel("No action insights available yet", title="[bold]Action Insights[/bold]")
648 |
649 | # Get top performing actions
650 | top_actions = []
651 | for action, data in self.action_history.items():
652 | if data["visits"] >= 3: # Only consider actions with enough samples
653 | avg_value = data["total_value"] / data["visits"]
654 | top_actions.append((action, avg_value, data["visits"]))
655 |
656 | # Sort by value and take top 5
657 | top_actions.sort(key=lambda x: x[1], reverse=True)
658 | top_actions = top_actions[:5]
659 |
660 | # Create insights table
661 | table = Table(box=None, expand=True)
662 | table.add_column("Action")
663 | table.add_column("Avg Value")
664 | table.add_column("Visits")
665 | table.add_column("Trend")
666 |
667 | for action, avg_value, visits in top_actions:
668 | # Generate trend indicator based on recent performance
669 | trend = "→"
670 | history = self.action_history[action]["iterations"]
671 | if len(history) >= 5:
672 | recent = set(history[-3:]) # Last 3 iterations
673 | if self.current_iteration - max(recent) <= 5:
674 | trend = "↑" # Recently used
675 | elif self.current_iteration - max(recent) >= 10:
676 | trend = "↓" # Not used recently
677 |
678 | # Color code based on value
679 | if avg_value > 0.7:
680 | value_str = f"[green]{avg_value:.3f}[/green]"
681 | elif avg_value > 0.4:
682 | value_str = f"[blue]{avg_value:.3f}[/blue]"
683 | else:
684 | value_str = f"[yellow]{avg_value:.3f}[/yellow]"
685 |
686 | table.add_row(str(action), value_str, str(visits), trend)
687 |
688 | return Panel(table, title="[bold]Action Insights[/bold]", border_style="cyan")
689 |
690 | def _create_intelligence_panel(self) -> Panel:
691 | """Create a panel showing intelligence metrics and learning patterns.
692 |
693 | Returns:
694 | Panel with intelligence visualization
695 | """
696 | from rich.table import Table
697 | from rich.columns import Columns
698 |
699 | # Create metrics table
700 | metrics_table = Table(box=None, expand=True)
701 | metrics_table.add_column("Metric")
702 | metrics_table.add_column("Value")
703 |
704 | # Add search quality metrics
705 | for metric, value in self.quality_metrics.items():
706 | formatted_name = metric.replace("_", " ").title()
707 | # Color based on value
708 | if value > 0.7:
709 | value_str = f"[green]{value:.2f}[/green]"
710 | elif value > 0.4:
711 | value_str = f"[blue]{value:.2f}[/blue]"
712 | else:
713 | value_str = f"[yellow]{value:.2f}[/yellow]"
714 |
715 | metrics_table.add_row(formatted_name, value_str)
716 |
717 | # Create exploration table
718 | exploration_table = Table(box=None, expand=True)
719 | exploration_table.add_column("Pattern")
720 | exploration_table.add_column("Value")
721 |
722 | # Add exploration patterns
723 | if self.exploration_patterns:
724 | # Average depth of exploration
725 | avg_depth = sum(p["node_depth"] for p in self.exploration_patterns) / len(self.exploration_patterns)
726 | exploration_table.add_row("Avg Exploration Depth", f"{avg_depth:.2f}")
727 |
728 | # Depth trend (increasing or decreasing)
729 | if len(self.exploration_patterns) >= 5:
730 | recent_avg = sum(p["node_depth"] for p in self.exploration_patterns[-5:]) / 5
731 | earlier_avg = sum(p["node_depth"] for p in self.exploration_patterns[:-5]) / max(1, len(self.exploration_patterns) - 5)
732 |
733 | if recent_avg > earlier_avg * 1.2:
734 | trend = "[green]Deepening[/green]"
735 | elif recent_avg < earlier_avg * 0.8:
736 | trend = "[yellow]Shallowing[/yellow]"
737 | else:
738 | trend = "[blue]Stable[/blue]"
739 |
740 | exploration_table.add_row("Depth Trend", trend)
741 |
742 | # Exploration-exploitation balance
743 | if len(self.exploration_patterns) >= 3:
744 | # Higher values = more exploitation of known good paths
745 | exploitation_ratio = sum(1 for p in self.exploration_patterns[-10:]
746 | if p["value_estimate"] > 0.5) / min(10, len(self.exploration_patterns))
747 |
748 | if exploitation_ratio > 0.7:
749 | balance = "[yellow]Heavy Exploitation[/yellow]"
750 | elif exploitation_ratio < 0.3:
751 | balance = "[yellow]Heavy Exploration[/yellow]"
752 | else:
753 | balance = "[green]Balanced[/green]"
754 |
755 | exploration_table.add_row("Search Balance", balance)
756 |
757 | # Combine tables into columns
758 | columns = Columns([metrics_table, exploration_table])
759 |
760 | return Panel(columns, title="[bold]Intelligence Metrics[/bold]", border_style="magenta")
761 |
762 | def _create_stats_panel(self) -> Panel:
763 | """Create a panel showing search statistics with enhanced metrics.
764 |
765 | Returns:
766 | Panel with statistics
767 | """
768 | if not self.root_node:
769 | return Panel("No statistics available", title="[bold]Search Statistics[/bold]")
770 |
771 | # Collect statistics
772 | total_nodes = 0
773 | max_depth = 0
774 | total_visits = getattr(self.root_node, "visits", 0)
775 | avg_branching = 0
776 |
777 | def count_nodes(node, depth=0):
778 | nonlocal total_nodes, max_depth, avg_branching
779 | if not node or not hasattr(node, "children"):
780 | return
781 |
782 | total_nodes += 1
783 | max_depth = max(max_depth, depth)
784 |
785 | # Count children for branching factor
786 | num_children = len(node.children)
787 | if num_children > 0:
788 | avg_branching += num_children
789 |
790 | for child in node.children.values():
791 | count_nodes(child, depth + 1)
792 |
793 | count_nodes(self.root_node)
794 |
795 | # Calculate average branching factor
796 | if total_nodes > 1: # Root node doesn't count for avg branching
797 | avg_branching /= (total_nodes - 1)
798 |
799 | # Create a table of statistics
800 | from rich.table import Table
801 |
802 | table = Table(box=None, expand=True)
803 | table.add_column("Metric")
804 | table.add_column("Value")
805 |
806 | table.add_row("Total Nodes", str(total_nodes))
807 | table.add_row("Max Depth", str(max_depth))
808 | table.add_row("Total Visits", str(total_visits))
809 | table.add_row("Avg Branching", f"{avg_branching:.2f}")
810 | table.add_row("Progress", f"{self.current_iteration / self.max_iterations:.1%}")
811 |
812 | # Efficiency estimate (higher is better)
813 | if total_visits > 0:
814 | visit_efficiency = total_nodes / total_visits
815 | efficiency_str = f"{visit_efficiency:.2f}"
816 | table.add_row("Search Efficiency", efficiency_str)
817 |
818 | return Panel(table, title="[bold]Search Statistics[/bold]", border_style="magenta")
819 |
820 | def _get_node_depth(self, node):
821 | """Calculate the depth of a node in the tree."""
822 | depth = 0
823 | current = node
824 | while getattr(current, "parent", None) is not None:
825 | depth += 1
826 | current = current.parent
827 | return depth
828 |
829 | def _update_visit_distribution(self):
830 | """Update the distribution of visits across the tree."""
831 | levels = {}
832 |
833 | def count_visits_by_level(node, depth=0):
834 | if not node or not hasattr(node, "children"):
835 | return
836 |
837 | # Initialize level if not present
838 | if depth not in levels:
839 | levels[depth] = {"visits": 0, "nodes": 0}
840 |
841 | # Update level stats
842 | levels[depth]["visits"] += getattr(node, "visits", 0)
843 | levels[depth]["nodes"] += 1
844 |
845 | # Process children
846 | for child in node.children.values():
847 | count_visits_by_level(child, depth + 1)
848 |
849 | # Start counting from root
850 | count_visits_by_level(self.root_node)
851 |
852 | # Update visit distribution
853 | self.visit_distribution = levels
854 |
855 | def _update_quality_metrics(self):
856 | """Update quality metrics for the search process."""
857 | # Search efficiency - ratio of valuable nodes to total nodes
858 | # Higher values indicate more efficient search
859 | if self.visit_distribution:
860 | useful_visits = sum(level["visits"] for depth, level in self.visit_distribution.items()
861 | if depth > 0) # Exclude root
862 | total_visits = sum(level["visits"] for level in self.visit_distribution.values())
863 |
864 | if total_visits > 0:
865 | self.quality_metrics["search_efficiency"] = useful_visits / total_visits
866 |
867 | # Exploration balance - how well the algorithm balances exploration vs exploitation
868 | if self.exploration_patterns:
869 | # Calculate variance in exploration depth
870 | depths = [p["node_depth"] for p in self.exploration_patterns[-20:]] # Last 20 iterations
871 | if depths:
872 | import statistics
873 | try:
874 | depth_variance = statistics.variance(depths) if len(depths) > 1 else 0
875 | # Normalize to 0-1 range (higher variance = more balanced exploration)
876 | normalized_variance = min(1.0, depth_variance / 5.0) # Assume variance > 5 is high
877 | self.quality_metrics["exploration_balance"] = normalized_variance
878 | except statistics.StatisticsError:
879 | pass
880 |
881 |
882 | class ParallelExecutionVisualizer:
883 | """Visualizes parallel execution of tool calls in real-time."""
884 |
885 | def __init__(self, console: Console):
886 | """Initialize the parallel execution visualizer.
887 |
888 | Args:
889 | console: Rich console instance
890 | """
891 | self.console = console
892 | self.active_executions = {}
893 | self.completed_executions = []
894 | self.layout = self._create_layout()
895 | self.live = Live(self.layout, console=console, refresh_per_second=10, auto_refresh=False)
896 |
897 | def _create_layout(self) -> Layout:
898 | """Create the layout for parallel execution visualization.
899 |
900 | Returns:
901 | Layout object
902 | """
903 | layout = Layout()
904 |
905 | # Create the main sections
906 | layout.split(
907 | Layout(name="header", size=3),
908 | Layout(name="executions"),
909 | Layout(name="metrics", size=5)
910 | )
911 |
912 | return layout
913 |
914 | def add_execution(self, execution_id: str, tool_name: str, parameters: Dict[str, Any]) -> None:
915 | """Add a new execution to visualize.
916 |
917 | Args:
918 | execution_id: Unique ID for the execution
919 | tool_name: Name of the tool being executed
920 | parameters: Parameters for the execution
921 | """
922 | self.active_executions[execution_id] = {
923 | "tool_name": tool_name,
924 | "parameters": parameters,
925 | "start_time": time.time(),
926 | "progress": 0.0,
927 | "status": "running"
928 | }
929 | self.refresh()
930 |
931 | def update_progress(self, execution_id: str, progress: float) -> None:
932 | """Update the progress of an execution.
933 |
934 | Args:
935 | execution_id: ID of the execution
936 | progress: Progress value (0-1)
937 | """
938 | if execution_id in self.active_executions:
939 | self.active_executions[execution_id]["progress"] = progress
940 | self.refresh()
941 |
942 | def complete_execution(self, execution_id: str, result: Any, status: str = "success") -> None:
943 | """Mark an execution as complete.
944 |
945 | Args:
946 | execution_id: ID of the execution
947 | result: Result of the execution
948 | status: Status of completion
949 | """
950 | if execution_id in self.active_executions:
951 | execution = self.active_executions[execution_id].copy()
952 | execution["end_time"] = time.time()
953 | execution["duration"] = execution["end_time"] - execution["start_time"]
954 | execution["result"] = result
955 | execution["status"] = status
956 |
957 | # Move to completed executions
958 | self.completed_executions.append(execution)
959 | del self.active_executions[execution_id]
960 |
961 | # Limit completed executions list
962 | if len(self.completed_executions) > 20:
963 | self.completed_executions = self.completed_executions[-20:]
964 |
965 | self.refresh()
966 |
967 | def start(self) -> None:
968 | """Start the visualization."""
969 | self.live.start()
970 | self.refresh()
971 |
972 | def stop(self) -> None:
973 | """Stop the visualization."""
974 | self.live.stop()
975 |
976 | def refresh(self) -> None:
977 | """Refresh the visualization."""
978 | # Update header
979 | header_content = f"[bold blue]Parallel Execution Monitor[/bold blue] | Active: {len(self.active_executions)} | Completed: {len(self.completed_executions)}"
980 | self.layout["header"].update(Panel(header_content, border_style="blue"))
981 |
982 | # Update executions visualization
983 | self.layout["executions"].update(self._create_executions_panel())
984 |
985 | # Update metrics
986 | self.layout["metrics"].update(self._create_metrics_panel())
987 |
988 | # Refresh the live display
989 | self.live.refresh()
990 |
991 |     def _create_executions_panel(self) -> Layout:
992 |         """Create a layout showing active and recent executions.
993 | 
994 |         Returns:
995 |             Layout containing the active and completed execution panels
996 |         """
997 | from rich.table import Table
998 | from rich.progress import BarColumn, Progress, TextColumn
999 |
1000 | # Create progress bars for active executions
1001 | progress_group = Table.grid(expand=True)
1002 |
1003 | if self.active_executions:
1004 | # Create a progress group
1005 | progress = Progress(
1006 | TextColumn("[bold blue]{task.description}"),
1007 | BarColumn(bar_width=None),
1008 | TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
1009 | TextColumn("| Elapsed: {task.elapsed:.2f}s"),
1010 | expand=True
1011 | )
1012 |
1013 | # Add tasks for each active execution
1014 | for exec_id, execution in self.active_executions.items():
1015 | tool_name = execution["tool_name"]
1016 | description = f"{tool_name} ({exec_id[:8]}...)"
1017 |             progress.add_task(description, total=100, completed=int(execution["progress"] * 100))
1018 |
1019 | progress_group.add_row(progress)
1020 | else:
1021 | progress_group.add_row("[italic]No active executions[/italic]")
1022 |
1023 | # Create a table for completed executions
1024 | completed_table = Table(show_header=True, header_style="bold blue", expand=True)
1025 | completed_table.add_column("Tool")
1026 | completed_table.add_column("Duration")
1027 | completed_table.add_column("Status")
1028 | completed_table.add_column("Result Preview")
1029 |
1030 | if self.completed_executions:
1031 | # Most recent first
1032 | for execution in reversed(self.completed_executions[-10:]):
1033 | tool_name = execution["tool_name"]
1034 | duration = f"{execution['duration']:.2f}s"
1035 | status = execution["status"]
1036 |
1037 | # Format result preview
1038 | result = str(execution.get("result", ""))
1039 | preview = result[:50] + "..." if len(result) > 50 else result
1040 |
1041 | # Add status with color
1042 | status_text = f"[green]{status}[/green]" if status == "success" else f"[red]{status}[/red]"
1043 |
1044 | completed_table.add_row(tool_name, duration, status_text, preview)
1045 | else:
1046 | completed_table.add_row("[italic]No completed executions[/italic]", "", "", "")
1047 |
1048 | # Combine both into a layout
1049 | layout = Layout()
1050 | layout.split(
1051 | Layout(name="active", size=len(self.active_executions) * 2 + 3 if self.active_executions else 3),
1052 | Layout(name="completed")
1053 | )
1054 | layout["active"].update(Panel(progress_group, title="[bold]Active Executions[/bold]", border_style="blue"))
1055 | layout["completed"].update(Panel(completed_table, title="[bold]Recent Completions[/bold]", border_style="green"))
1056 |
1057 | return layout
1058 |
1059 | def _create_metrics_panel(self) -> Panel:
1060 | """Create a panel showing execution metrics.
1061 |
1062 | Returns:
1063 | Panel with metrics visualization
1064 | """
1065 | from rich.table import Table
1066 |
1067 | # Calculate metrics
1068 | total_executions = len(self.completed_executions)
1069 | successful = sum(1 for e in self.completed_executions if e["status"] == "success")
1070 | failed = total_executions - successful
1071 |
1072 | if total_executions > 0:
1073 | success_rate = successful / total_executions
1074 | avg_duration = sum(e["duration"] for e in self.completed_executions) / total_executions
1075 | else:
1076 | success_rate = 0
1077 | avg_duration = 0
1078 |
1079 | # Create metrics table
1080 | table = Table(box=None, expand=True)
1081 | table.add_column("Metric")
1082 | table.add_column("Value")
1083 |
1084 | table.add_row("Total Executions", str(total_executions))
1085 | table.add_row("Success Rate", f"{success_rate:.1%}")
1086 | table.add_row("Average Duration", f"{avg_duration:.2f}s")
1087 | table.add_row("Current Parallelism", str(len(self.active_executions)))
1088 |
1089 | return Panel(table, title="[bold]Execution Metrics[/bold]", border_style="magenta")
1090 |
1091 |
1092 | class MultiPanelLayout:
1093 | """Creates a multi-panel layout for the entire UI."""
1094 |
1095 | def __init__(self, console: Console):
1096 | """Initialize the multi-panel layout.
1097 |
1098 | Args:
1099 | console: Rich console instance
1100 | """
1101 | self.console = console
1102 | self.layout = self._create_layout()
1103 | self.live = Live(self.layout, console=console, refresh_per_second=4, auto_refresh=False)
1104 |
1105 | def _create_layout(self) -> Layout:
1106 | """Create the main application layout.
1107 |
1108 | Returns:
1109 | Layout object
1110 | """
1111 | layout = Layout()
1112 |
1113 | # Split into three main sections
1114 | layout.split(
1115 | Layout(name="conversation", ratio=3),
1116 | Layout(name="tools", ratio=2),
1117 | Layout(name="input", ratio=1)
1118 | )
1119 |
1120 | # Further split the tools section
1121 | layout["tools"].split_row(
1122 | Layout(name="active_tools"),
1123 | Layout(name="cost", size=30)
1124 | )
1125 |
1126 | return layout
1127 |
1128 | def start(self) -> None:
1129 | """Start the live display."""
1130 | self.live.start()
1131 |
1132 | def stop(self) -> None:
1133 | """Stop the live display."""
1134 | self.live.stop()
1135 |
1136 | def refresh(self) -> None:
1137 | """Refresh the display."""
1138 | self.live.refresh()
1139 |
1140 | def update_section(self, section: str, content: Any) -> None:
1141 | """Update a section of the layout.
1142 |
1143 | Args:
1144 | section: Section name
1145 | content: Content to display
1146 | """
1147 | if section in self.layout:
1148 | self.layout[section].update(content)
1149 | self.refresh()
```
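For orientation, a hypothetical driver for `ToolCallVisualizer` is sketched below. The `SimpleNamespace` stands in for `ToolResult` (defined in `claude_code/lib/tools/base.py`; its constructor is not shown on this page) and supplies only the attributes `complete_tool_call` actually reads (`result`, `status`, `execution_time`); the tool name and parameters are illustrative:

```python
import time
import uuid
from types import SimpleNamespace

from rich.console import Console

from claude_code.lib.ui.tool_visualizer import ToolCallVisualizer

console = Console()
viz = ToolCallVisualizer(console)
viz.start()
try:
    call_id = str(uuid.uuid4())
    viz.add_tool_call(call_id, "search_files", {"pattern": "**/*.py"})
    for pct in (0.25, 0.5, 0.75, 1.0):
        time.sleep(0.2)  # stand-in for real tool work
        viz.update_progress(call_id, pct)
    # Duck-typed stand-in for ToolResult with the fields the visualizer reads
    fake_result = SimpleNamespace(
        result="42 files matched", status="success", execution_time=0.8
    )
    viz.complete_tool_call(call_id, fake_result)
finally:
    viz.stop()
```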
--------------------------------------------------------------------------------
/mcp_server.py:
--------------------------------------------------------------------------------
```python
1 | #!/usr/bin/env python3
2 | """
3 | Model Context Protocol (MCP) Server Implementation
4 |
5 | This module implements the Model Context Protocol server capabilities,
6 | allowing the assistant to be used as an MCP-compatible context provider.
7 | """
8 |
9 | import os
10 | import json
11 | import time
12 | import uuid
13 | import sys
14 | import logging
15 | import asyncio
16 | import tiktoken
17 | import re
18 | from datetime import datetime
19 | from typing import Dict, List, Any, Optional, Union, AsyncGenerator
20 | from fastapi import FastAPI, HTTPException, Request, Response, Depends, BackgroundTasks, Query
21 | from fastapi.responses import JSONResponse, StreamingResponse
22 | from fastapi.middleware.cors import CORSMiddleware
23 | from fastapi.staticfiles import StaticFiles
24 | from fastapi.templating import Jinja2Templates
25 | from pydantic import BaseModel, Field
26 | import uvicorn
27 | import openai
28 | from openai import OpenAI
29 | import prometheus_client
30 | from prometheus_client import Counter, Histogram, Gauge
31 |
32 | # Configure logging
33 | logging.basicConfig(
34 | level=logging.INFO,
35 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
36 | )
37 | logger = logging.getLogger("mcp_server")
38 |
39 | # MCP Protocol Models
40 | class MCPHealthResponse(BaseModel):
41 | """Health check response for MCP protocol"""
42 | status: str = "healthy"
43 | version: str = "1.0.0"
44 | protocol_version: str = "0.1.0"
45 | provider: str = "OpenAI Code Assistant"
46 | models: List[str] = ["gpt-4o", "gpt-4-turbo", "gpt-3.5-turbo"]
47 | uptime: Optional[float] = None
48 | request_count: Optional[int] = None
49 | cache_hit_ratio: Optional[float] = None
50 |
51 | class MCPContextRequest(BaseModel):
52 | """Request for context generation from a prompt template"""
53 | prompt_id: str
54 | parameters: Dict[str, Any] = Field(default_factory=dict, description="Parameters to fill in the prompt template")
55 | model: Optional[str] = Field(None, description="Model to use for context generation")
56 | stream: bool = Field(False, description="Whether to stream the response")
57 | user: Optional[str] = Field(None, description="User identifier for tracking")
58 | conversation_id: Optional[str] = Field(None, description="Conversation identifier")
59 | message_id: Optional[str] = Field(None, description="Message identifier")
60 |
61 | class MCPContextResponse(BaseModel):
62 | """Response containing generated context"""
63 | context: str = Field(..., description="The generated context")
64 | context_id: str = Field(..., description="Unique identifier for this context")
65 | model: str = Field(..., description="Model used for generation")
66 | usage: Dict[str, int] = Field(default_factory=dict, description="Token usage statistics")
67 | metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
68 |
69 | class MCPErrorResponse(BaseModel):
70 | """Error response format"""
71 | error: str = Field(..., description="Error message")
72 | error_type: str = Field(..., description="Type of error")
73 | status_code: int = Field(..., description="HTTP status code")
74 | details: Optional[Dict[str, Any]] = Field(None, description="Additional error details")
75 |
76 | class MCPPromptTemplate(BaseModel):
77 | """Prompt template definition"""
78 | id: str = Field(..., description="Unique identifier for the template")
79 | template: str = Field(..., description="The prompt template with parameter placeholders")
80 | description: Optional[str] = Field(None, description="Description of the template")
81 | parameters: Dict[str, Dict[str, Any]] = Field(default_factory=dict, description="Parameter definitions")
82 | default_model: Optional[str] = Field(None, description="Default model to use with this template")
83 | metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
84 |
85 | class MCPPromptLibraryResponse(BaseModel):
86 | """Response containing a list of prompt templates"""
87 | prompts: List[MCPPromptTemplate] = Field(..., description="List of prompt templates")
88 | count: int = Field(..., description="Number of templates")
89 |
90 | # MCP Server Implementation
91 | # Prometheus metrics
92 | REQUEST_COUNT = Counter('mcp_requests_total', 'Total number of requests processed', ['endpoint', 'status'])
93 | REQUEST_LATENCY = Histogram('mcp_request_latency_seconds', 'Request latency in seconds', ['endpoint'])
94 | CACHE_HIT = Counter('mcp_cache_hits_total', 'Total number of cache hits')
95 | CACHE_MISS = Counter('mcp_cache_misses_total', 'Total number of cache misses')
96 | ACTIVE_CONNECTIONS = Gauge('mcp_active_connections', 'Number of active connections')
97 | TOKEN_USAGE = Counter('mcp_token_usage_total', 'Total number of tokens used', ['model', 'type'])
98 |
99 | # Cache implementation
100 | class CacheManager:
101 | """Manages caching for context responses"""
102 |
103 | def __init__(self, cache_type="memory", redis_url=None, ttl=3600):
104 | self.cache_type = cache_type
105 | self.redis_url = redis_url
106 | self.ttl = ttl
107 | self.memory_cache = {}
108 | self.redis_client = None
109 |
110 | if cache_type == "redis" and redis_url:
111 | try:
112 | import redis
113 | self.redis_client = redis.from_url(redis_url)
114 | logging.info(f"Redis cache initialized with URL: {redis_url}")
115 | except ImportError:
116 | logging.warning("Redis package not installed. Falling back to memory cache.")
117 | self.cache_type = "memory"
118 | except Exception as e:
119 | logging.error(f"Failed to connect to Redis: {str(e)}")
120 | self.cache_type = "memory"
121 |
122 | async def get(self, key):
123 | """Get item from cache"""
124 | if self.cache_type == "redis" and self.redis_client:
125 | try:
126 | value = self.redis_client.get(key)
127 | if value:
128 | CACHE_HIT.inc()
129 | return json.loads(value)
130 | CACHE_MISS.inc()
131 | return None
132 | except Exception as e:
133 | logging.error(f"Redis get error: {str(e)}")
134 | CACHE_MISS.inc()
135 | return None
136 | else:
137 |             # Memory cache (honor the per-entry TTL when one was stored)
138 |             if key in self.memory_cache:
139 |                 entry = self.memory_cache[key]
140 |                 if time.time() - entry["timestamp"] < entry.get("ttl", self.ttl):
141 |                     CACHE_HIT.inc()
142 |                     return entry["data"]
143 |                 # Expired
144 |                 del self.memory_cache[key]
145 |             CACHE_MISS.inc()
146 |             return None
147 |
148 | async def set(self, key, value, ttl=None):
149 | """Set item in cache"""
150 | if ttl is None:
151 | ttl = self.ttl
152 |
153 | if self.cache_type == "redis" and self.redis_client:
154 | try:
155 | self.redis_client.setex(key, ttl, json.dumps(value))
156 | except Exception as e:
157 | logging.error(f"Redis set error: {str(e)}")
158 |         else:
159 |             # Memory cache: record the TTL so get() can honor custom values
160 |             self.memory_cache[key] = {
161 |                 "data": value,
162 |                 "timestamp": time.time(), "ttl": ttl
163 |             }
164 |
165 | async def delete(self, key):
166 | """Delete item from cache"""
167 | if self.cache_type == "redis" and self.redis_client:
168 | try:
169 | self.redis_client.delete(key)
170 | except Exception as e:
171 | logging.error(f"Redis delete error: {str(e)}")
172 | else:
173 | # Memory cache
174 | if key in self.memory_cache:
175 | del self.memory_cache[key]
176 |
177 | async def clear(self):
178 | """Clear all cache"""
179 | if self.cache_type == "redis" and self.redis_client:
180 | try:
181 | self.redis_client.flushdb()
182 | except Exception as e:
183 | logging.error(f"Redis flush error: {str(e)}")
184 | else:
185 | # Memory cache
186 | self.memory_cache = {}
187 |
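# Usage sketch (hypothetical, memory backend). Values should be JSON-serializable
# so the same call sites keep working when cache_type="redis":
#   cache = CacheManager(cache_type="memory", ttl=60)
#   await cache.set("ctx:greeting", {"context": "Hello"})
#   assert (await cache.get("ctx:greeting"))["context"] == "Hello"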
188 | class MCPServer:
189 | """Model Context Protocol Server Implementation"""
190 |
191 | def __init__(self, cache_type="memory", redis_url=None):
192 | self.app = FastAPI(
193 | title="OpenAI Code Assistant MCP Server",
194 | description="Model Context Protocol server for OpenAI Code Assistant",
195 | version="1.0.0",
196 | docs_url="/docs",
197 | redoc_url="/redoc",
198 | openapi_url="/openapi.json",
199 | )
200 |
201 | # Initialize OpenAI client
202 | self.openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
203 |
204 | # Initialize cache
205 | self.cache = CacheManager(cache_type=cache_type, redis_url=redis_url)
206 |
207 | # Initialize tokenizer
208 | self.tokenizer = tiktoken.get_encoding("cl100k_base")
209 |
210 | # Setup routes and middleware
211 | self.setup_routes()
212 | self.setup_middleware()
213 |
214 | # Load templates and static files
215 | self.templates_dir = os.path.join(os.path.dirname(__file__), "templates")
216 | os.makedirs(self.templates_dir, exist_ok=True)
217 | self.static_dir = os.path.join(os.path.dirname(__file__), "static")
218 | os.makedirs(self.static_dir, exist_ok=True)
219 |
220 | # Create default template if it doesn't exist
221 | self._create_default_template()
222 |
223 | # Initialize templates
224 | self.templates = Jinja2Templates(directory=self.templates_dir)
225 |
226 | # Mount static files
227 | self.app.mount("/static", StaticFiles(directory=self.static_dir), name="static")
228 |
229 | # Load prompt templates
230 | self.prompt_templates = self._load_prompt_templates()
231 |
232 | # Initialize metrics
233 | self.request_count = 0
234 | self.start_time = time.time()
235 |
236 | def setup_middleware(self):
237 | """Configure middleware for the FastAPI app"""
238 | # Add CORS middleware
239 | self.app.add_middleware(
240 | CORSMiddleware,
241 | allow_origins=["*"],
242 | allow_credentials=True,
243 | allow_methods=["*"],
244 | allow_headers=["*"],
245 | )
246 |
247 | # Add request tracking middleware
248 | @self.app.middleware("http")
249 | async def track_requests(request: Request, call_next):
250 | # Increment active connections
251 | ACTIVE_CONNECTIONS.inc()
252 |
253 | # Track request start time
254 | start_time = time.time()
255 |
256 | # Process request
257 | try:
258 | response = await call_next(request)
259 |
260 | # Record metrics
261 | endpoint = request.url.path
262 | status = response.status_code
263 | REQUEST_COUNT.labels(endpoint=endpoint, status=status).inc()
264 | REQUEST_LATENCY.labels(endpoint=endpoint).observe(time.time() - start_time)
265 |
266 | # Increment total request count
267 | self.request_count += 1
268 |
269 | return response
270 | finally:
271 | # Decrement active connections
272 | ACTIVE_CONNECTIONS.dec()
273 |
274 | def _create_default_template(self):
275 | """Create default dashboard template if it doesn't exist"""
276 | index_path = os.path.join(self.templates_dir, "index.html")
277 | if not os.path.exists(index_path):
278 | with open(index_path, "w") as f:
279 | f.write("""
280 | <!DOCTYPE html>
281 | <html>
282 | <head>
283 | <title>OpenAI Code Assistant MCP Server</title>
284 |     <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/bootstrap/dist/css/bootstrap.min.css">
285 | <style>
286 | body { padding: 20px; }
287 | .card { margin-bottom: 20px; }
288 | </style>
289 | </head>
290 | <body>
291 | <div class="container">
292 | <h1>OpenAI Code Assistant MCP Server</h1>
293 | <div class="row">
294 | <div class="col-md-6">
295 | <div class="card">
296 | <div class="card-header">Server Status</div>
297 | <div class="card-body">
298 | <p><strong>Status:</strong> {{ status }}</p>
299 | <p><strong>Uptime:</strong> {{ uptime }}</p>
300 | <p><strong>Requests Served:</strong> {{ request_count }}</p>
301 | <p><strong>Cache Hit Ratio:</strong> {{ cache_hit_ratio }}%</p>
302 | </div>
303 | </div>
304 | </div>
305 | <div class="col-md-6">
306 | <div class="card">
307 | <div class="card-header">Available Models</div>
308 | <div class="card-body">
309 | <ul>
310 | {% for model in models %}
311 | <li>{{ model }}</li>
312 | {% endfor %}
313 | </ul>
314 | </div>
315 | </div>
316 | </div>
317 | </div>
318 |
319 | <h2>Available Prompt Templates</h2>
320 | <div class="row">
321 | {% for template in templates %}
322 | <div class="col-md-6">
323 | <div class="card">
324 | <div class="card-header">{{ template.id }}</div>
325 | <div class="card-body">
326 | <p><strong>Description:</strong> {{ template.description }}</p>
327 | <p><strong>Parameters:</strong> {{ template.parameters|join(", ") }}</p>
328 | <p><strong>Default Model:</strong> {{ template.default_model }}</p>
329 | </div>
330 | </div>
331 | </div>
332 | {% endfor %}
333 | </div>
334 |
335 | <h2>API Documentation</h2>
336 | <p>
337 | <a href="/docs" class="btn btn-primary">Interactive API Docs</a>
338 | <a href="/redoc" class="btn btn-secondary">ReDoc API Docs</a>
339 | <a href="/metrics" class="btn btn-info">Prometheus Metrics</a>
340 | </p>
341 | </div>
342 |
343 |     <script src="https://cdn.jsdelivr.net/npm/bootstrap/dist/js/bootstrap.bundle.min.js"></script>
344 | </body>
345 | </html>
346 | """)
347 |
348 | def setup_routes(self):
349 | """Configure API routes for MCP protocol"""
350 |
351 | # MCP Protocol Routes
352 | # Dashboard route
353 | @self.app.get("/", tags=["Dashboard"])
354 | async def dashboard(request: Request):
355 | """Dashboard showing server status and available templates"""
356 | # Calculate cache hit ratio
357 | cache_hits = prometheus_client.REGISTRY.get_sample_value('mcp_cache_hits_total') or 0
358 | cache_misses = prometheus_client.REGISTRY.get_sample_value('mcp_cache_misses_total') or 0
359 | total_cache_requests = cache_hits + cache_misses
360 | cache_hit_ratio = (cache_hits / total_cache_requests * 100) if total_cache_requests > 0 else 0
361 |
362 | # Format uptime
363 | uptime_seconds = time.time() - self.start_time
364 | days, remainder = divmod(uptime_seconds, 86400)
365 | hours, remainder = divmod(remainder, 3600)
366 | minutes, seconds = divmod(remainder, 60)
367 | uptime_str = f"{int(days)}d {int(hours)}h {int(minutes)}m {int(seconds)}s"
368 |
369 | # Get template information
370 | templates = []
371 | for template_id, template in self.prompt_templates.items():
372 | templates.append({
373 | "id": template_id,
374 | "description": template.get("description", ""),
375 | "parameters": list(template.get("parameters", {}).keys()),
376 | "default_model": template.get("default_model", "gpt-4o")
377 | })
378 |
379 | return self.templates.TemplateResponse("index.html", {
380 | "request": request,
381 | "status": "Healthy",
382 | "uptime": uptime_str,
383 | "request_count": self.request_count,
384 | "cache_hit_ratio": round(cache_hit_ratio, 2),
385 | "models": ["gpt-4o", "gpt-4-turbo", "gpt-3.5-turbo"],
386 | "templates": templates
387 | })
388 |
389 | # Prometheus metrics endpoint
390 | @self.app.get("/metrics", tags=["Monitoring"])
391 | async def metrics():
392 | """Expose Prometheus metrics"""
393 | return Response(prometheus_client.generate_latest(), media_type="text/plain")
394 |
395 | # Health check endpoints
396 | @self.app.get("/health", response_model=MCPHealthResponse, tags=["Health"])
397 | async def health():
398 | """Health check endpoint"""
399 | # Calculate cache hit ratio
400 | cache_hits = prometheus_client.REGISTRY.get_sample_value('mcp_cache_hits_total') or 0
401 | cache_misses = prometheus_client.REGISTRY.get_sample_value('mcp_cache_misses_total') or 0
402 | total_cache_requests = cache_hits + cache_misses
403 | cache_hit_ratio = (cache_hits / total_cache_requests) if total_cache_requests > 0 else 0
404 |
405 | return MCPHealthResponse(
406 | status="healthy",
407 | uptime=time.time() - self.start_time,
408 | request_count=self.request_count,
409 | cache_hit_ratio=cache_hit_ratio
410 | )
411 |
412 | @self.app.post("/context", response_model=MCPContextResponse, tags=["Context"])
413 | async def get_context(
414 | request: MCPContextRequest,
415 | background_tasks: BackgroundTasks,
416 | use_cache: bool = Query(True, description="Whether to use cached results if available")
417 | ):
418 | """
419 | Get context for a prompt template with parameters.
420 |
421 | This endpoint processes a prompt template with the provided parameters
422 | and returns the generated context. It can optionally use OpenAI models
423 | to enhance the context.
424 | """
425 | try:
426 | # Check if prompt template exists
427 | if request.prompt_id not in self.prompt_templates:
428 | raise HTTPException(
429 | status_code=404,
430 | detail=f"Prompt template '{request.prompt_id}' not found"
431 | )
432 |
433 | # Get prompt template
434 | template = self.prompt_templates[request.prompt_id]
435 |
436 | # Use default model if not specified
437 | model = request.model or template.get("default_model", "gpt-4o")
438 |
439 | # Generate context ID
440 | context_id = str(uuid.uuid4())
441 |
442 | # Generate cache key
443 | cache_key = f"{request.prompt_id}:{json.dumps(request.parameters, sort_keys=True)}:{model}"
444 |
445 | # Check cache if enabled
446 | if use_cache:
447 | cached_result = await self.cache.get(cache_key)
448 | if cached_result:
449 | # Update context ID for this request
450 | cached_result["context_id"] = context_id
451 | return MCPContextResponse(**cached_result)
452 |
453 | # Process template with parameters
454 | processed_template = self._process_template(template["template"], request.parameters)
455 |
456 | # Check if we should use OpenAI to enhance the context
457 | if template.get("use_openai", False):
458 | # Generate context using OpenAI
459 | context, usage = await self._generate_with_openai(
460 | processed_template,
461 | model,
462 | template.get("system_prompt")
463 | )
464 | else:
465 | # Use the processed template directly
466 | context = processed_template
467 |
468 | # Calculate token usage
469 | token_count = len(self.tokenizer.encode(context))
470 | usage = {
471 | "prompt_tokens": token_count,
472 | "completion_tokens": 0,
473 | "total_tokens": token_count
474 | }
475 |
476 | # Track token usage in Prometheus
477 | TOKEN_USAGE.labels(model=model, type="prompt").inc(usage["prompt_tokens"])
478 | TOKEN_USAGE.labels(model=model, type="completion").inc(usage["completion_tokens"])
479 |
480 | # Create response
481 | response = MCPContextResponse(
482 | context=context,
483 | context_id=context_id,
484 | model=model,
485 | usage=usage,
486 | metadata={
487 | "prompt_id": request.prompt_id,
488 | "timestamp": time.time(),
489 | "parameters": request.parameters
490 | }
491 | )
492 |
493 | # Store in cache
494 | await self.cache.set(cache_key, response.dict())
495 |
496 | return response
497 |
498 | except Exception as e:
499 | logger.error(f"Error processing context request: {str(e)}", exc_info=True)
500 | raise HTTPException(
501 | status_code=500,
502 | detail=f"Error processing context: {str(e)}"
503 | )
504 |
505 | @self.app.post("/context/stream", tags=["Context"])
506 | async def stream_context(request: MCPContextRequest):
507 | """
508 | Stream context generation.
509 |
510 | Similar to /context but streams the response as it's generated.
511 | """
512 | try:
513 | # Check if prompt template exists
514 | if request.prompt_id not in self.prompt_templates:
515 | raise HTTPException(
516 | status_code=404,
517 | detail=f"Prompt template '{request.prompt_id}' not found"
518 | )
519 |
520 | # Get prompt template
521 | template = self.prompt_templates[request.prompt_id]
522 |
523 | # Use default model if not specified
524 | model = request.model or template.get("default_model", "gpt-4o")
525 |
526 | # Generate context ID
527 | context_id = str(uuid.uuid4())
528 |
529 | # Process template with parameters
530 | processed_template = self._process_template(template["template"], request.parameters)
531 |
532 | # Stream the context generation
533 | return StreamingResponse(
534 | self._stream_context(processed_template, model, context_id, template.get("system_prompt")),
535 | media_type="text/event-stream"
536 | )
537 |
538 | except Exception as e:
539 | logger.error(f"Error streaming context: {str(e)}", exc_info=True)
540 | raise HTTPException(
541 | status_code=500,
542 | detail=f"Error streaming context: {str(e)}"
543 | )
544 |
545 | @self.app.get("/prompts", response_model=MCPPromptLibraryResponse, tags=["Prompts"])
546 | async def get_prompts():
547 | """
548 | Get available prompt templates.
549 |
550 | Returns a list of all prompt templates available in the system.
551 | """
552 | prompts = [
553 | MCPPromptTemplate(
554 | id=prompt_id,
555 | template=template["template"],
556 | description=template.get("description", ""),
557 | parameters=template.get("parameters", {}),
558 | default_model=template.get("default_model", "gpt-4o"),
559 | metadata=template.get("metadata", {})
560 | )
561 | for prompt_id, template in self.prompt_templates.items()
562 | ]
563 |
564 | return MCPPromptLibraryResponse(
565 | prompts=prompts,
566 | count=len(prompts)
567 | )
568 |
569 | @self.app.get("/prompts/{prompt_id}", response_model=MCPPromptTemplate, tags=["Prompts"])
570 | async def get_prompt(prompt_id: str):
571 | """
572 | Get a specific prompt template.
573 |
574 | Returns the details of a specific prompt template by ID.
575 | """
576 | if prompt_id not in self.prompt_templates:
577 | raise HTTPException(
578 | status_code=404,
579 | detail=f"Prompt template '{prompt_id}' not found"
580 | )
581 |
582 | template = self.prompt_templates[prompt_id]
583 | return MCPPromptTemplate(
584 | id=prompt_id,
585 | template=template["template"],
586 | description=template.get("description", ""),
587 | parameters=template.get("parameters", {}),
588 | default_model=template.get("default_model", "gpt-4o"),
589 | metadata=template.get("metadata", {})
590 | )
591 |
592 | @self.app.post("/prompts", response_model=MCPPromptTemplate, status_code=201, tags=["Prompts"])
593 | async def create_prompt(prompt: MCPPromptTemplate):
594 | """
595 | Create a new prompt template.
596 |
597 | Adds a new prompt template to the system.
598 | """
599 | if prompt.id in self.prompt_templates:
600 | raise HTTPException(
601 | status_code=409,
602 | detail=f"Prompt template '{prompt.id}' already exists"
603 | )
604 |
605 | self.prompt_templates[prompt.id] = {
606 | "template": prompt.template,
607 | "description": prompt.description,
608 | "parameters": prompt.parameters,
609 | "default_model": prompt.default_model,
610 | "metadata": prompt.metadata
611 | }
612 |
613 | # Save updated templates
614 | self._save_prompt_templates()
615 |
616 | return prompt
617 |
618 | @self.app.put("/prompts/{prompt_id}", response_model=MCPPromptTemplate, tags=["Prompts"])
619 | async def update_prompt(prompt_id: str, prompt: MCPPromptTemplate):
620 | """
621 | Update an existing prompt template.
622 |
623 | Updates the details of an existing prompt template.
624 | """
625 | if prompt_id != prompt.id:
626 | raise HTTPException(
627 | status_code=400,
628 | detail="Prompt ID in path must match prompt ID in body"
629 | )
630 |
631 | if prompt_id not in self.prompt_templates:
632 | raise HTTPException(
633 | status_code=404,
634 | detail=f"Prompt template '{prompt_id}' not found"
635 | )
636 |
637 | self.prompt_templates[prompt_id] = {
638 | "template": prompt.template,
639 | "description": prompt.description,
640 | "parameters": prompt.parameters,
641 | "default_model": prompt.default_model,
642 | "metadata": prompt.metadata
643 | }
644 |
645 | # Save updated templates
646 | self._save_prompt_templates()
647 |
648 | return prompt
649 |
650 | @self.app.delete("/prompts/{prompt_id}", tags=["Prompts"])
651 | async def delete_prompt(prompt_id: str):
652 | """
653 | Delete a prompt template.
654 |
655 | Removes a prompt template from the system.
656 | """
657 | if prompt_id not in self.prompt_templates:
658 | raise HTTPException(
659 | status_code=404,
660 | detail=f"Prompt template '{prompt_id}' not found"
661 | )
662 |
663 | del self.prompt_templates[prompt_id]
664 |
665 | # Save updated templates
666 | self._save_prompt_templates()
667 |
668 | return {"status": "deleted", "prompt_id": prompt_id}
669 |
670 | # Additional endpoints for a more complete MCP server
671 | @self.app.get("/models", tags=["Models"])
672 | async def get_models():
673 | """
674 | Get available models.
675 |
676 | Returns a list of models that can be used with this MCP server.
677 | """
678 | return {
679 | "models": [
680 | {
681 | "id": "gpt-4o",
682 | "name": "GPT-4o",
683 | "description": "OpenAI's most advanced model",
684 | "context_length": 128000,
685 | "is_default": True
686 | },
687 | {
688 | "id": "gpt-4-turbo",
689 | "name": "GPT-4 Turbo",
690 | "description": "Optimized version of GPT-4",
691 | "context_length": 128000,
692 | "is_default": False
693 | },
694 | {
695 | "id": "gpt-3.5-turbo",
696 | "name": "GPT-3.5 Turbo",
697 | "description": "Fast and efficient model",
698 | "context_length": 16385,
699 | "is_default": False
700 | }
701 | ],
702 | "count": 3
703 | }
704 |
705 | @self.app.get("/stats", tags=["System"])
706 | async def get_stats():
707 | """
708 | Get server statistics.
709 |
710 | Returns usage statistics and system information.
711 | """
712 | return {
713 | "uptime": time.time() - self.start_time,
714 | "prompt_templates_count": len(self.prompt_templates),
715 |                 "cache_size": len(self.cache.memory_cache) if self.cache.cache_type == "memory" else None,
716 |                 "requests_served": {
717 |                     "context": 0,  # per-endpoint counts are exposed via the Prometheus metrics
718 |                     "prompts": 0,
719 |                     "total": self.request_count
720 |                 },
721 | "system_info": {
722 | "python_version": sys.version,
723 | "platform": sys.platform
724 | }
725 | }
740 |
741 | # Error handlers
742 | @self.app.exception_handler(HTTPException)
743 | async def http_exception_handler(request: Request, exc: HTTPException):
744 | """Handle HTTP exceptions in MCP format"""
745 | return JSONResponse(
746 | status_code=exc.status_code,
747 | content={
748 | "error": exc.detail,
749 | "error_type": "http_error",
750 | "status_code": exc.status_code,
751 | "details": exc.detail if isinstance(exc.detail, dict) else None
752 | }
753 | )
754 |
755 | @self.app.exception_handler(Exception)
756 | async def general_exception_handler(request: Request, exc: Exception):
757 | """Handle general exceptions in MCP format"""
758 | logger.error(f"Unhandled exception: {str(exc)}", exc_info=True)
759 | return JSONResponse(
760 | status_code=500,
761 | content={
762 | "error": str(exc),
763 | "error_type": "server_error",
764 | "status_code": 500,
765 | "details": None
766 | }
767 | )
768 |
769 | def _load_prompt_templates(self) -> Dict[str, Dict[str, Any]]:
770 | """Load prompt templates from file or initialize defaults"""
771 | templates_file = os.path.join(os.path.dirname(__file__), "data", "prompt_templates.json")
772 |
773 | # Create directory if it doesn't exist
774 | os.makedirs(os.path.dirname(templates_file), exist_ok=True)
775 |
776 | # Try to load existing templates
777 | if os.path.exists(templates_file):
778 | try:
779 | with open(templates_file, "r") as f:
780 | templates = json.load(f)
781 | logger.info(f"Loaded {len(templates)} prompt templates from {templates_file}")
782 | return templates
783 | except Exception as e:
784 | logger.error(f"Error loading prompt templates: {str(e)}")
785 |
786 | # Initialize with enhanced default templates
787 | default_templates = {
788 | "greeting": {
789 | "template": "Hello! The current time is {time}. How can I help you today?",
790 | "description": "A simple greeting template",
791 | "parameters": {
792 | "time": {
793 | "type": "string",
794 | "description": "The current time"
795 | }
796 | },
797 | "default_model": "gpt-4o",
798 | "metadata": {
799 | "category": "general"
800 | }
801 | },
802 | "code_review": {
803 | "template": "Please review the following code:\n\n```{language}\n{code}\n```\n\nFocus on: {focus_areas}",
804 | "description": "Template for code review requests",
805 | "parameters": {
806 | "language": {
807 | "type": "string",
808 | "description": "Programming language of the code"
809 | },
810 | "code": {
811 | "type": "string",
812 | "description": "The code to review"
813 | },
814 | "focus_areas": {
815 | "type": "string",
816 | "description": "Areas to focus on during review (e.g., 'performance, security')"
817 | }
818 | },
819 | "default_model": "gpt-4o",
820 | "use_openai": True,
821 | "system_prompt": "You are a code review expert. Analyze the provided code and provide constructive feedback focusing on the specified areas.",
822 | "metadata": {
823 | "category": "development"
824 | }
825 | },
826 | "system_prompt": {
827 | "template": "You are OpenAI Code Assistant, a CLI tool that helps users with software engineering tasks and general information.\nUse the available tools to assist the user with their requests.\n\n# Tone and style\nYou should be concise, direct, and to the point. When you run a non-trivial bash command, \nyou should explain what the command does and why you are running it.\nOutput text to communicate with the user; all text you output outside of tool use is displayed to the user.\nRemember that your output will be displayed on a command line interface.\n\n# Tool usage policy\n- When doing file search, remember to search effectively with the available tools.\n- Always use the appropriate tool for the task.\n- Use parallel tool calls when appropriate to improve performance.\n- NEVER commit changes unless the user explicitly asks you to.\n- For weather queries, use the Weather tool to provide real-time information.\n\n# Tasks\nThe user will primarily request you perform software engineering tasks:\n1. Solving bugs\n2. Adding new functionality \n3. Refactoring code\n4. Explaining code\n5. Writing tests\n\nFor these tasks:\n1. Use search tools to understand the codebase\n2. Implement solutions using the available tools\n3. Verify solutions with tests if possible\n4. Run lint and typecheck commands when appropriate\n\nThe user may also ask for general information:\n1. Weather conditions\n2. Simple calculations\n3. General knowledge questions\n\n# Code style\n- Follow the existing code style of the project\n- Maintain consistent naming conventions\n- Use appropriate libraries that are already in the project\n- Add comments when code is complex or non-obvious\n\nIMPORTANT: You should minimize output tokens as much as possible while maintaining helpfulness, \nquality, and accuracy. Answer concisely with short lines of text unless the user asks for detail.",
828 | "description": "System prompt for the assistant",
829 | "parameters": {},
830 | "default_model": "gpt-4o",
831 | "metadata": {
832 | "category": "system"
833 | }
834 | },
835 | "documentation": {
836 | "template": "Generate documentation for the following code:\n\n```{language}\n{code}\n```\n\nFormat: {format}",
837 | "description": "Generate code documentation",
838 | "parameters": {
839 | "language": {
840 | "type": "string",
841 | "description": "Programming language of the code"
842 | },
843 | "code": {
844 | "type": "string",
845 | "description": "The code to document"
846 | },
847 | "format": {
848 | "type": "string",
849 | "description": "Documentation format (e.g., 'markdown', 'docstring', 'jsdoc')",
850 | "default": "markdown"
851 | }
852 | },
853 | "default_model": "gpt-4o",
854 | "use_openai": True,
855 | "system_prompt": "You are a technical documentation expert. Generate clear, concise, and accurate documentation for the provided code.",
856 | "metadata": {
857 | "category": "development"
858 | }
859 | },
860 | "explain_code": {
861 | "template": "Explain how the following code works:\n\n```{language}\n{code}\n```\n\nDetail level: {detail_level}",
862 | "description": "Explain code functionality",
863 | "parameters": {
864 | "language": {
865 | "type": "string",
866 | "description": "Programming language of the code"
867 | },
868 | "code": {
869 | "type": "string",
870 | "description": "The code to explain"
871 | },
872 | "detail_level": {
873 | "type": "string",
874 | "description": "Level of detail in the explanation (e.g., 'basic', 'intermediate', 'advanced')",
875 | "default": "intermediate"
876 | }
877 | },
878 | "default_model": "gpt-4o",
879 | "use_openai": True,
880 | "system_prompt": "You are a programming instructor. Explain the provided code clearly at the requested level of detail.",
881 | "metadata": {
882 | "category": "education"
883 | }
884 | },
885 | "current_time": {
886 | "template": "The current time is {{now:%Y-%m-%d %H:%M:%S}}.",
887 | "description": "Get the current time",
888 | "parameters": {},
889 | "default_model": "gpt-4o",
890 | "metadata": {
891 | "category": "utility"
892 | }
893 | }
894 | }
895 |
896 | # Save default templates
897 | try:
898 | with open(templates_file, "w") as f:
899 | json.dump(default_templates, f, indent=2)
900 | except Exception as e:
901 | logger.error(f"Error saving default prompt templates: {str(e)}")
902 |
903 | return default_templates
904 |
905 | def _save_prompt_templates(self):
906 | """Save prompt templates to file"""
907 | templates_file = os.path.join(os.path.dirname(__file__), "data", "prompt_templates.json")
908 |
909 | try:
910 | with open(templates_file, "w") as f:
911 | json.dump(self.prompt_templates, f, indent=2)
912 | except Exception as e:
913 | logger.error(f"Error saving prompt templates: {str(e)}")
914 |
915 | async def _generate_with_openai(self, prompt: str, model: str, system_prompt: Optional[str] = None) -> tuple:
916 | """Generate context using OpenAI API"""
917 | messages = []
918 |
919 | # Add system prompt if provided
920 | if system_prompt:
921 | messages.append({"role": "system", "content": system_prompt})
922 |
923 | # Add user prompt
924 | messages.append({"role": "user", "content": prompt})
925 |
926 | # Call OpenAI API
927 | try:
928 | response = await asyncio.to_thread(
929 | self.openai_client.chat.completions.create,
930 | model=model,
931 | messages=messages,
932 | temperature=0.0, # Use deterministic output for context generation
933 | max_tokens=4000
934 | )
935 |
936 | # Extract content and usage
937 | content = response.choices[0].message.content
938 | usage = {
939 | "prompt_tokens": response.usage.prompt_tokens,
940 | "completion_tokens": response.usage.completion_tokens,
941 | "total_tokens": response.usage.total_tokens
942 | }
943 |
944 | return content, usage
945 |
946 | except Exception as e:
947 | logger.error(f"OpenAI API error: {str(e)}")
948 | raise ValueError(f"Error generating context with OpenAI: {str(e)}")
949 |
950 | async def _stream_context(self, prompt: str, model: str, context_id: str, system_prompt: Optional[str] = None) -> AsyncGenerator[str, None]:
951 | """Stream context generation using OpenAI API"""
952 | messages = []
953 |
954 | # Add system prompt if provided
955 | if system_prompt:
956 | messages.append({"role": "system", "content": system_prompt})
957 |
958 | # Add user prompt
959 | messages.append({"role": "user", "content": prompt})
960 |
961 | # Initial event with context ID
962 | yield f"data: {json.dumps({'context_id': context_id, 'event': 'start'})}\n\n"
963 |
964 | try:
965 | # Call OpenAI API with streaming
966 | stream = await asyncio.to_thread(
967 | self.openai_client.chat.completions.create,
968 | model=model,
969 | messages=messages,
970 | temperature=0.0,
971 | max_tokens=4000,
972 | stream=True
973 | )
974 |
975 | full_content = ""
976 |
977 | # Process the stream
978 | for chunk in stream:
979 | if chunk.choices and chunk.choices[0].delta.content:
980 | content_piece = chunk.choices[0].delta.content
981 | full_content += content_piece
982 |
983 | # Yield the content piece
984 | yield f"data: {json.dumps({'content': content_piece, 'event': 'content'})}\n\n"
985 |
986 | # Calculate token usage
987 | prompt_tokens = len(self.tokenizer.encode(prompt))
988 | completion_tokens = len(self.tokenizer.encode(full_content))
989 | total_tokens = prompt_tokens + completion_tokens
990 |
991 | # Track token usage
992 | TOKEN_USAGE.labels(model=model, type="prompt").inc(prompt_tokens)
993 | TOKEN_USAGE.labels(model=model, type="completion").inc(completion_tokens)
994 |
995 |             # Final event with the complete context and usage. Build the
996 |             # payload first: an f-string literal cannot span multiple lines.
997 |             final_payload = {
998 |                 "event": "end",
999 |                 "context": full_content,
1000 |                 "usage": {"prompt_tokens": prompt_tokens,
1001 |                           "completion_tokens": completion_tokens,
1002 |                           "total_tokens": total_tokens}
1003 |             }
1004 |             yield f"data: {json.dumps(final_payload)}\n\n"
1005 |
1006 | except Exception as e:
1007 | logger.error(f"Error streaming context: {str(e)}")
1008 | yield f"data: {json.dumps({'event': 'error', 'error': str(e)})}\n\n"
1009 |
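    # SSE wire format produced above (one JSON object per "data:" line,
    # blank-line separated), as implied by the yields in _stream_context:
    #   {"context_id": "...", "event": "start"}
    #   {"content": "<chunk>", "event": "content"}   # repeated per streamed chunk
    #   {"event": "end", "context": "<full text>", "usage": {...}}
    #   {"event": "error", "error": "<message>"}     # only on failure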
1010 |     def _process_template(self, template: str, parameters: Dict[str, Any]) -> str:
1011 |         """Process a template with parameters"""
1012 |         def _resolve_now(fmt: Optional[str]) -> str:
1013 |             return datetime.now().strftime(fmt) if fmt else datetime.now().isoformat()
1014 |         try:
1015 |             # Resolve {{now}} / {{now:<format>}} markers in parameter values
1016 |             processed_params = parameters.copy()
1017 |             for key, value in processed_params.items():
1018 |                 if isinstance(value, str) and value.startswith("{{now") and value.endswith("}}"):
1019 |                     format_match = re.search(r"{{now:(.+)}}", value)
1020 |                     processed_params[key] = _resolve_now(format_match.group(1) if format_match else None)
1021 |             result = template.format(**processed_params)
1022 |             # str.format collapses "{{now:...}}" to "{now:...}", so also resolve
1023 |             # markers embedded in the template body (e.g. the current_time template)
1024 |             return re.sub(r"\{now(?::([^}]+))?\}",
1025 |                           lambda m: _resolve_now(m.group(1)), result)
1026 |         except KeyError as e:
1027 |             raise ValueError(f"Missing required parameter: {e}")
1028 |         except Exception as e:
1029 |             raise ValueError(f"Error processing template: {str(e)}")
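    # Example (hypothetical): the default "current_time" template resolves its
    # embedded marker without any parameters:
    #   self._process_template("The current time is {{now:%Y-%m-%d}}.", {})
    #   -> "The current time is 2024-01-01."  (illustrative date)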
1030 |
1031 |     def start(self, host: str = "127.0.0.1", port: int = 8000, reload: bool = False):
1032 |         """Start the MCP server. Note: uvicorn requires an app import string (not an app object) to enable reload."""
1033 |         uvicorn.run(self.app, host=host, port=port, reload=reload)
1034 |
1035 | def create_mcp_app():
1036 | """Factory function for creating the FastAPI app"""
1037 | server = MCPServer()
1038 | return server.app
1039 |
1040 | if __name__ == "__main__":
1041 | # Create data directory if it doesn't exist
1042 | os.makedirs(os.path.join(os.path.dirname(__file__), "data"), exist_ok=True)
1043 |
1044 | # Start server
1045 | server = MCPServer()
1046 | server.start()
1047 |
```
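The endpoints above can be exercised end to end once the server is running (`python mcp_server.py`). The sketch below is illustrative, not part of the repository: it assumes a default local server on `127.0.0.1:8000`, the bundled `greeting` and `explain_code` templates, and the third-party `requests` package; the streaming call additionally requires `OPENAI_API_KEY` to be set on the server side.

```python
import json
import requests  # assumed third-party dependency, not declared in this file

BASE = "http://127.0.0.1:8000"

# List the registered prompt templates (GET /prompts)
prompts = requests.get(f"{BASE}/prompts").json()
print(prompts["count"], "templates:", [p["id"] for p in prompts["prompts"]])

# Render the bundled "greeting" template; no OpenAI call is made because
# the template does not set use_openai (POST /context)
resp = requests.post(f"{BASE}/context",
                     json={"prompt_id": "greeting",
                           "parameters": {"time": "12:00"}})
resp.raise_for_status()
body = resp.json()
print(body["context"], body["usage"])

# Stream an OpenAI-backed template as server-sent events (POST /context/stream)
req = {"prompt_id": "explain_code",
       "parameters": {"language": "python",
                      "code": "print('hi')",
                      "detail_level": "basic"}}
with requests.post(f"{BASE}/context/stream", json=req, stream=True) as r:
    for line in r.iter_lines(decode_unicode=True):
        if line and line.startswith("data: "):
            event = json.loads(line[len("data: "):])
            if event.get("event") == "content":
                print(event["content"], end="", flush=True)
```

Any template listed by `/prompts` can be substituted for `prompt_id`, provided every key in its `parameters` definition is supplied.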