This is page 14 of 19. Use http://codebase.md/beehiveinnovations/gemini-mcp-server?page={x} to view the full context.
# Directory Structure
```
├── .claude
│ ├── commands
│ │ └── fix-github-issue.md
│ └── settings.json
├── .coveragerc
├── .dockerignore
├── .env.example
├── .gitattributes
├── .github
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE
│ │ ├── bug_report.yml
│ │ ├── config.yml
│ │ ├── documentation.yml
│ │ ├── feature_request.yml
│ │ └── tool_addition.yml
│ ├── pull_request_template.md
│ └── workflows
│ ├── docker-pr.yml
│ ├── docker-release.yml
│ ├── semantic-pr.yml
│ ├── semantic-release.yml
│ └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── AGENTS.md
├── CHANGELOG.md
├── claude_config_example.json
├── CLAUDE.md
├── clink
│ ├── __init__.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── claude.py
│ │ ├── codex.py
│ │ └── gemini.py
│ ├── constants.py
│ ├── models.py
│ ├── parsers
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── claude.py
│ │ ├── codex.py
│ │ └── gemini.py
│ └── registry.py
├── code_quality_checks.ps1
├── code_quality_checks.sh
├── communication_simulator_test.py
├── conf
│ ├── __init__.py
│ ├── azure_models.json
│ ├── cli_clients
│ │ ├── claude.json
│ │ ├── codex.json
│ │ └── gemini.json
│ ├── custom_models.json
│ ├── dial_models.json
│ ├── gemini_models.json
│ ├── openai_models.json
│ ├── openrouter_models.json
│ └── xai_models.json
├── config.py
├── docker
│ ├── README.md
│ └── scripts
│ ├── build.ps1
│ ├── build.sh
│ ├── deploy.ps1
│ ├── deploy.sh
│ └── healthcheck.py
├── docker-compose.yml
├── Dockerfile
├── docs
│ ├── adding_providers.md
│ ├── adding_tools.md
│ ├── advanced-usage.md
│ ├── ai_banter.md
│ ├── ai-collaboration.md
│ ├── azure_openai.md
│ ├── configuration.md
│ ├── context-revival.md
│ ├── contributions.md
│ ├── custom_models.md
│ ├── docker-deployment.md
│ ├── gemini-setup.md
│ ├── getting-started.md
│ ├── index.md
│ ├── locale-configuration.md
│ ├── logging.md
│ ├── model_ranking.md
│ ├── testing.md
│ ├── tools
│ │ ├── analyze.md
│ │ ├── apilookup.md
│ │ ├── challenge.md
│ │ ├── chat.md
│ │ ├── clink.md
│ │ ├── codereview.md
│ │ ├── consensus.md
│ │ ├── debug.md
│ │ ├── docgen.md
│ │ ├── listmodels.md
│ │ ├── planner.md
│ │ ├── precommit.md
│ │ ├── refactor.md
│ │ ├── secaudit.md
│ │ ├── testgen.md
│ │ ├── thinkdeep.md
│ │ ├── tracer.md
│ │ └── version.md
│ ├── troubleshooting.md
│ ├── vcr-testing.md
│ └── wsl-setup.md
├── examples
│ ├── claude_config_macos.json
│ └── claude_config_wsl.json
├── LICENSE
├── providers
│ ├── __init__.py
│ ├── azure_openai.py
│ ├── base.py
│ ├── custom.py
│ ├── dial.py
│ ├── gemini.py
│ ├── openai_compatible.py
│ ├── openai.py
│ ├── openrouter.py
│ ├── registries
│ │ ├── __init__.py
│ │ ├── azure.py
│ │ ├── base.py
│ │ ├── custom.py
│ │ ├── dial.py
│ │ ├── gemini.py
│ │ ├── openai.py
│ │ ├── openrouter.py
│ │ └── xai.py
│ ├── registry_provider_mixin.py
│ ├── registry.py
│ ├── shared
│ │ ├── __init__.py
│ │ ├── model_capabilities.py
│ │ ├── model_response.py
│ │ ├── provider_type.py
│ │ └── temperature.py
│ └── xai.py
├── pyproject.toml
├── pytest.ini
├── README.md
├── requirements-dev.txt
├── requirements.txt
├── run_integration_tests.ps1
├── run_integration_tests.sh
├── run-server.ps1
├── run-server.sh
├── scripts
│ └── sync_version.py
├── server.py
├── simulator_tests
│ ├── __init__.py
│ ├── base_test.py
│ ├── conversation_base_test.py
│ ├── log_utils.py
│ ├── test_analyze_validation.py
│ ├── test_basic_conversation.py
│ ├── test_chat_simple_validation.py
│ ├── test_codereview_validation.py
│ ├── test_consensus_conversation.py
│ ├── test_consensus_three_models.py
│ ├── test_consensus_workflow_accurate.py
│ ├── test_content_validation.py
│ ├── test_conversation_chain_validation.py
│ ├── test_cross_tool_comprehensive.py
│ ├── test_cross_tool_continuation.py
│ ├── test_debug_certain_confidence.py
│ ├── test_debug_validation.py
│ ├── test_line_number_validation.py
│ ├── test_logs_validation.py
│ ├── test_model_thinking_config.py
│ ├── test_o3_model_selection.py
│ ├── test_o3_pro_expensive.py
│ ├── test_ollama_custom_url.py
│ ├── test_openrouter_fallback.py
│ ├── test_openrouter_models.py
│ ├── test_per_tool_deduplication.py
│ ├── test_planner_continuation_history.py
│ ├── test_planner_validation_old.py
│ ├── test_planner_validation.py
│ ├── test_precommitworkflow_validation.py
│ ├── test_prompt_size_limit_bug.py
│ ├── test_refactor_validation.py
│ ├── test_secaudit_validation.py
│ ├── test_testgen_validation.py
│ ├── test_thinkdeep_validation.py
│ ├── test_token_allocation_validation.py
│ ├── test_vision_capability.py
│ └── test_xai_models.py
├── systemprompts
│ ├── __init__.py
│ ├── analyze_prompt.py
│ ├── chat_prompt.py
│ ├── clink
│ │ ├── codex_codereviewer.txt
│ │ ├── default_codereviewer.txt
│ │ ├── default_planner.txt
│ │ └── default.txt
│ ├── codereview_prompt.py
│ ├── consensus_prompt.py
│ ├── debug_prompt.py
│ ├── docgen_prompt.py
│ ├── generate_code_prompt.py
│ ├── planner_prompt.py
│ ├── precommit_prompt.py
│ ├── refactor_prompt.py
│ ├── secaudit_prompt.py
│ ├── testgen_prompt.py
│ ├── thinkdeep_prompt.py
│ └── tracer_prompt.py
├── tests
│ ├── __init__.py
│ ├── CASSETTE_MAINTENANCE.md
│ ├── conftest.py
│ ├── gemini_cassettes
│ │ ├── chat_codegen
│ │ │ └── gemini25_pro_calculator
│ │ │ └── mldev.json
│ │ ├── chat_cross
│ │ │ └── step1_gemini25_flash_number
│ │ │ └── mldev.json
│ │ └── consensus
│ │ └── step2_gemini25_flash_against
│ │ └── mldev.json
│ ├── http_transport_recorder.py
│ ├── mock_helpers.py
│ ├── openai_cassettes
│ │ ├── chat_cross_step2_gpt5_reminder.json
│ │ ├── chat_gpt5_continuation.json
│ │ ├── chat_gpt5_moon_distance.json
│ │ ├── consensus_step1_gpt5_for.json
│ │ └── o3_pro_basic_math.json
│ ├── pii_sanitizer.py
│ ├── sanitize_cassettes.py
│ ├── test_alias_target_restrictions.py
│ ├── test_auto_mode_comprehensive.py
│ ├── test_auto_mode_custom_provider_only.py
│ ├── test_auto_mode_model_listing.py
│ ├── test_auto_mode_provider_selection.py
│ ├── test_auto_mode.py
│ ├── test_auto_model_planner_fix.py
│ ├── test_azure_openai_provider.py
│ ├── test_buggy_behavior_prevention.py
│ ├── test_cassette_semantic_matching.py
│ ├── test_challenge.py
│ ├── test_chat_codegen_integration.py
│ ├── test_chat_cross_model_continuation.py
│ ├── test_chat_openai_integration.py
│ ├── test_chat_simple.py
│ ├── test_clink_claude_agent.py
│ ├── test_clink_claude_parser.py
│ ├── test_clink_codex_agent.py
│ ├── test_clink_gemini_agent.py
│ ├── test_clink_gemini_parser.py
│ ├── test_clink_integration.py
│ ├── test_clink_parsers.py
│ ├── test_clink_tool.py
│ ├── test_collaboration.py
│ ├── test_config.py
│ ├── test_consensus_integration.py
│ ├── test_consensus_schema.py
│ ├── test_consensus.py
│ ├── test_conversation_continuation_integration.py
│ ├── test_conversation_field_mapping.py
│ ├── test_conversation_file_features.py
│ ├── test_conversation_memory.py
│ ├── test_conversation_missing_files.py
│ ├── test_custom_openai_temperature_fix.py
│ ├── test_custom_provider.py
│ ├── test_debug.py
│ ├── test_deploy_scripts.py
│ ├── test_dial_provider.py
│ ├── test_directory_expansion_tracking.py
│ ├── test_disabled_tools.py
│ ├── test_docker_claude_desktop_integration.py
│ ├── test_docker_config_complete.py
│ ├── test_docker_healthcheck.py
│ ├── test_docker_implementation.py
│ ├── test_docker_mcp_validation.py
│ ├── test_docker_security.py
│ ├── test_docker_volume_persistence.py
│ ├── test_file_protection.py
│ ├── test_gemini_token_usage.py
│ ├── test_image_support_integration.py
│ ├── test_image_validation.py
│ ├── test_integration_utf8.py
│ ├── test_intelligent_fallback.py
│ ├── test_issue_245_simple.py
│ ├── test_large_prompt_handling.py
│ ├── test_line_numbers_integration.py
│ ├── test_listmodels_restrictions.py
│ ├── test_listmodels.py
│ ├── test_mcp_error_handling.py
│ ├── test_model_enumeration.py
│ ├── test_model_metadata_continuation.py
│ ├── test_model_resolution_bug.py
│ ├── test_model_restrictions.py
│ ├── test_o3_pro_output_text_fix.py
│ ├── test_o3_temperature_fix_simple.py
│ ├── test_openai_compatible_token_usage.py
│ ├── test_openai_provider.py
│ ├── test_openrouter_provider.py
│ ├── test_openrouter_registry.py
│ ├── test_parse_model_option.py
│ ├── test_per_tool_model_defaults.py
│ ├── test_pii_sanitizer.py
│ ├── test_pip_detection_fix.py
│ ├── test_planner.py
│ ├── test_precommit_workflow.py
│ ├── test_prompt_regression.py
│ ├── test_prompt_size_limit_bug_fix.py
│ ├── test_provider_retry_logic.py
│ ├── test_provider_routing_bugs.py
│ ├── test_provider_utf8.py
│ ├── test_providers.py
│ ├── test_rate_limit_patterns.py
│ ├── test_refactor.py
│ ├── test_secaudit.py
│ ├── test_server.py
│ ├── test_supported_models_aliases.py
│ ├── test_thinking_modes.py
│ ├── test_tools.py
│ ├── test_tracer.py
│ ├── test_utf8_localization.py
│ ├── test_utils.py
│ ├── test_uvx_resource_packaging.py
│ ├── test_uvx_support.py
│ ├── test_workflow_file_embedding.py
│ ├── test_workflow_metadata.py
│ ├── test_workflow_prompt_size_validation_simple.py
│ ├── test_workflow_utf8.py
│ ├── test_xai_provider.py
│ ├── transport_helpers.py
│ └── triangle.png
├── tools
│ ├── __init__.py
│ ├── analyze.py
│ ├── apilookup.py
│ ├── challenge.py
│ ├── chat.py
│ ├── clink.py
│ ├── codereview.py
│ ├── consensus.py
│ ├── debug.py
│ ├── docgen.py
│ ├── listmodels.py
│ ├── models.py
│ ├── planner.py
│ ├── precommit.py
│ ├── refactor.py
│ ├── secaudit.py
│ ├── shared
│ │ ├── __init__.py
│ │ ├── base_models.py
│ │ ├── base_tool.py
│ │ ├── exceptions.py
│ │ └── schema_builders.py
│ ├── simple
│ │ ├── __init__.py
│ │ └── base.py
│ ├── testgen.py
│ ├── thinkdeep.py
│ ├── tracer.py
│ ├── version.py
│ └── workflow
│ ├── __init__.py
│ ├── base.py
│ ├── schema_builders.py
│ └── workflow_mixin.py
├── utils
│ ├── __init__.py
│ ├── client_info.py
│ ├── conversation_memory.py
│ ├── env.py
│ ├── file_types.py
│ ├── file_utils.py
│ ├── image_utils.py
│ ├── model_context.py
│ ├── model_restrictions.py
│ ├── security_config.py
│ ├── storage_backend.py
│ └── token_utils.py
└── zen-mcp-server
```
# Files
--------------------------------------------------------------------------------
/tools/secaudit.py:
--------------------------------------------------------------------------------
```python
"""
SECAUDIT Workflow tool - Comprehensive security audit with systematic investigation
This tool provides a structured workflow for comprehensive security assessment and analysis.
It guides the CLI agent through systematic investigation steps with forced pauses between each step
to ensure thorough security examination, vulnerability identification, and compliance assessment
before proceeding. The tool supports complex security scenarios including OWASP Top 10 coverage,
compliance framework mapping, and technology-specific security patterns.
Key features:
- Step-by-step security audit workflow with progress tracking
- Context-aware file embedding (references during investigation, full content for analysis)
- Automatic security issue tracking with severity classification
- Expert analysis integration with external models
- Support for focused security audits (OWASP, compliance, technology-specific)
- Confidence-based workflow optimization
- Risk-based prioritization and remediation planning
"""
import logging
from typing import TYPE_CHECKING, Any, Literal, Optional
from pydantic import Field, model_validator
if TYPE_CHECKING:
from tools.models import ToolModelCategory
from config import TEMPERATURE_ANALYTICAL
from systemprompts import SECAUDIT_PROMPT
from tools.shared.base_models import WorkflowRequest
from .workflow.base import WorkflowTool
logger = logging.getLogger(__name__)

# Tool-specific field descriptions for the security audit workflow, keyed by
# request field name. SecauditRequest uses them as pydantic Field descriptions
# and SecauditTool.get_input_schema() reuses them for JSON schema properties.
SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS = dict(
    step=(
        "Step 1: outline the audit strategy (OWASP Top 10, auth, validation, etc.). Later steps: report findings. "
        "MANDATORY: use `relevant_files` for code references and avoid large snippets."
    ),
    step_number="Current security-audit step number (starts at 1).",
    total_steps="Expected number of audit steps; adjust as new risks surface.",
    next_step_required=(
        "True while additional threat analysis remains; set False once you are ready to hand off for validation."
    ),
    findings=(
        "Summarize vulnerabilities, auth issues, validation gaps, compliance notes, and positives; "
        "update prior findings as needed."
    ),
    files_checked="Absolute paths for every file inspected, including rejected candidates.",
    relevant_files="Absolute paths for security-relevant files (auth modules, configs, sensitive code).",
    relevant_context="Security-critical classes/methods (e.g. 'AuthService.login', 'encryption_helper').",
    issues_found=(
        "Security issues with severity (critical/high/medium/low) and descriptions "
        "(vulns, auth flaws, injection, crypto, config)."
    ),
    confidence=(
        "exploring/low/medium/high/very_high/almost_certain/certain. "
        "'certain' blocks external validation—use only when fully complete."
    ),
    images="Optional absolute paths to diagrams or threat models that inform the audit.",
    security_scope=(
        "Security context (web, mobile, API, cloud, etc.) including stack, user types, "
        "data sensitivity, and threat landscape."
    ),
    threat_level=(
        "Assess the threat level: low (internal/low-risk), medium (customer-facing/business data), "
        "high (regulated or sensitive), critical (financial/healthcare/PII)."
    ),
    compliance_requirements=(
        "Applicable compliance frameworks or standards (SOC2, PCI DSS, HIPAA, GDPR, ISO 27001, NIST, etc.)."
    ),
    audit_focus="Primary focus area: owasp, compliance, infrastructure, dependencies, or comprehensive.",
    severity_filter="Minimum severity to include when reporting security issues.",
)
class SecauditRequest(WorkflowRequest):
    """Request model for one step of the security audit workflow.

    Extends WorkflowRequest with step bookkeeping, investigation tracking and
    security-specific settings (scope, threat level, compliance frameworks,
    audit focus, severity filter) consumed by SecauditTool.
    """

    # Required fields for each investigation step
    step: str = Field(..., description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["step"])
    step_number: int = Field(..., description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["step_number"])
    total_steps: int = Field(..., description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["total_steps"])
    next_step_required: bool = Field(..., description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["next_step_required"])

    # Investigation tracking fields
    findings: str = Field(..., description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["findings"])
    files_checked: list[str] = Field(
        default_factory=list, description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["files_checked"]
    )
    relevant_files: list[str] = Field(
        default_factory=list, description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["relevant_files"]
    )
    relevant_context: list[str] = Field(
        default_factory=list, description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["relevant_context"]
    )
    # Each issue dict is expected to carry "severity" and "description" keys
    # (those are the keys the tool reads when formatting issues).
    issues_found: list[dict] = Field(
        default_factory=list, description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["issues_found"]
    )
    # Free-form here; the tool's JSON input schema restricts it to an enum.
    confidence: Optional[str] = Field("low", description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["confidence"])

    # Optional images for visual context
    images: Optional[list[str]] = Field(default=None, description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["images"])

    # Security audit-specific fields
    security_scope: Optional[str] = Field(None, description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["security_scope"])
    threat_level: Optional[Literal["low", "medium", "high", "critical"]] = Field(
        "medium", description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["threat_level"]
    )
    compliance_requirements: Optional[list[str]] = Field(
        default_factory=list, description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["compliance_requirements"]
    )
    audit_focus: Optional[Literal["owasp", "compliance", "infrastructure", "dependencies", "comprehensive"]] = Field(
        "comprehensive", description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["audit_focus"]
    )
    severity_filter: Optional[Literal["critical", "high", "medium", "low", "all"]] = Field(
        "all", description=SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS["severity_filter"]
    )

    @model_validator(mode="after")
    def validate_security_audit_request(self):
        """Log advisory warnings for missing or unrecognised security settings.

        This validator never rejects a request: it only emits warnings and
        returns the model unchanged.
        """
        # Step 1 establishes the audit context, so a missing scope is flagged there.
        # NOTE(review): despite the log wording, no default is assigned here;
        # security_scope stays None.
        if self.step_number == 1 and not self.security_scope:
            logger.warning("Security scope not provided for security audit - defaulting to general application")

        # Compare framework names on a normalised form (upper-case, alphanumerics
        # only). This way common spellings such as "pci-dss", "PCI_DSS", "SOC 2",
        # "iso27001" or "FEDRAMP" are recognised instead of triggering spurious
        # "unknown" warnings.
        if self.compliance_requirements:
            known_frameworks = {"SOC2", "PCIDSS", "HIPAA", "GDPR", "ISO27001", "NIST", "FEDRAMP", "FISMA"}
            for req in self.compliance_requirements:
                normalized = "".join(ch for ch in str(req).upper() if ch.isalnum())
                if normalized not in known_frameworks:
                    logger.warning(f"Unknown compliance requirement: {req}")

        return self
class SecauditTool(WorkflowTool):
"""
Comprehensive security audit workflow tool.
Provides systematic security assessment through multi-step investigation
covering OWASP Top 10, compliance requirements, and technology-specific
security patterns. Follows established WorkflowTool patterns while adding
security-specific capabilities.
"""
def __init__(self):
    """Initialise the workflow tool and reset the per-audit state."""
    super().__init__()
    # Opening step text; filled in by store_initial_issue(), None until then.
    self.initial_request = None
    # Security settings captured on step 1 by prepare_step_data().
    self.security_config = {}
def get_name(self) -> str:
    """Unique tool identifier; also quoted in the step-guidance text."""
    tool_name = "secaudit"
    return tool_name
def get_description(self) -> str:
    """Human-readable summary of what the secaudit tool does and when to use it."""
    sentences = (
        "Performs comprehensive security audit with systematic vulnerability assessment.",
        "Use for OWASP Top 10 analysis, compliance evaluation, threat modeling, and security architecture review.",
        "Guides through structured security investigation with expert validation.",
    )
    return " ".join(sentences)
def get_system_prompt(self) -> str:
    """System prompt for the expert security model (SECAUDIT_PROMPT from systemprompts)."""
    prompt = SECAUDIT_PROMPT
    return prompt
def get_default_temperature(self) -> float:
    """Sampling temperature for audits: the shared TEMPERATURE_ANALYTICAL setting from config."""
    temperature = TEMPERATURE_ANALYTICAL
    return temperature
def get_model_category(self) -> "ToolModelCategory":
    """Security audits request a model from the extended-reasoning category."""
    # Imported at call time: the module-level import only exists under TYPE_CHECKING.
    from tools.models import ToolModelCategory as _Category

    return _Category.EXTENDED_REASONING
def get_workflow_request_model(self) -> type:
    """Pydantic model class used to validate each secaudit step request."""
    request_model = SecauditRequest
    return request_model
def get_tool_fields(self) -> dict[str, dict[str, Any]]:
    """Expose the secaudit field-description mapping.

    NOTE(review): the object returned maps field names to plain description
    strings, not to schema dicts as the return annotation suggests.
    get_input_schema() in this class builds its own property definitions
    and does not call this method.
    """
    field_descriptions = SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS
    return field_descriptions
def get_required_actions(
    self, step_number: int, confidence: str, findings: str, total_steps: int, request=None
) -> list[str]:
    """Return the checklist the agent must work through for a given audit step.

    Steps 1-6 each target a fixed security domain:
    1. scoping
    2. authentication/authorisation
    3. input handling and crypto
    4. OWASP Top 10
    5. dependencies and configuration
    6. compliance and remediation

    Any other step number receives a generic "keep investigating" checklist.

    Only ``step_number`` affects the result. ``confidence``, ``findings``,
    ``total_steps`` and ``request`` are accepted for interface compatibility.

    Returns:
        A freshly built list of instruction strings on every call.
    """
    checklists = {
        1: (
            "Identify application type, technology stack, and security scope",
            "Map attack surface, entry points, and data flows",
            "Determine relevant security standards and compliance requirements",
            "Establish threat landscape and risk context for the application",
        ),
        2: (
            "Analyze authentication mechanisms and session management",
            "Check authorization controls, access patterns, and privilege escalation risks",
            "Assess multi-factor authentication, password policies, and account security",
            "Review identity and access management implementations",
        ),
        3: (
            "Examine input validation and sanitization mechanisms across all entry points",
            "Check for injection vulnerabilities (SQL, XSS, Command, LDAP, NoSQL)",
            "Review data encryption, sensitive data handling, and cryptographic implementations",
            "Analyze API input validation, rate limiting, and request/response security",
        ),
        4: (
            "Conduct OWASP Top 10 (2021) systematic review across all categories",
            "Check each OWASP category methodically with specific findings and evidence",
            "Cross-reference findings with application context and technology stack",
            "Prioritize vulnerabilities based on exploitability and business impact",
        ),
        5: (
            "Analyze third-party dependencies for known vulnerabilities and outdated versions",
            "Review configuration security, default settings, and hardening measures",
            "Check for hardcoded secrets, credentials, and sensitive information exposure",
            "Assess logging, monitoring, incident response, and security observability",
        ),
        6: (
            "Evaluate compliance requirements and identify gaps in controls",
            "Assess business impact and risk levels of all identified findings",
            "Create prioritized remediation roadmap with timeline and effort estimates",
            "Document comprehensive security posture and recommendations",
        ),
    }
    fallback = (
        "Continue systematic security investigation based on emerging findings",
        "Deep-dive into specific security concerns identified in previous steps",
        "Validate security hypotheses and confirm vulnerability assessments",
    )
    # list() so every caller gets its own mutable copy, as with the literal lists.
    return list(checklists.get(step_number, fallback))
def should_call_expert_analysis(self, consolidated_findings, request=None) -> bool:
    """Decide whether the investigation warrants an external expert review.

    Returns False when a request is supplied and it opted out of the
    assistant model. Otherwise the result is True as soon as any of these
    holds, checked in order:
    - at least one relevant file was identified;
    - at least two findings were recorded;
    - at least one issue was logged.
    """
    if request and not self.get_request_use_assistant_model(request):
        return False

    if len(consolidated_findings.relevant_files) > 0:
        return True
    if len(consolidated_findings.findings) >= 2:
        return True
    return len(consolidated_findings.issues_found) > 0
def prepare_expert_analysis_context(self, consolidated_findings) -> str:
    """Assemble the context document sent to the expert security model.

    Sections appear in a fixed order, each wrapped in ``=== NAME ===`` /
    ``=== END ... ===`` markers:

    1. the original audit request (or a placeholder when none was stored);
    2. the agent's investigation summary;
    3. the security configuration captured on step 1. Only entries with a
       truthy value are shown, and the section is omitted when none remain;
    4. relevant files, security-critical code elements, identified issues,
       assessment evolution and images, each only when non-empty.

    Args:
        consolidated_findings: Accumulated workflow state. It must expose
            ``relevant_files``, ``relevant_context``, ``issues_found``,
            ``hypotheses`` (dicts with ``step``, ``confidence`` and
            ``hypothesis`` keys) and ``images``, plus whatever
            ``_build_security_audit_summary`` reads.

    Returns:
        All sections joined with newlines.
    """
    context_parts = [
        f"=== SECURITY AUDIT REQUEST ===\n{self.initial_request or 'Security audit workflow initiated'}\n=== END REQUEST ==="
    ]

    # Agent's own step-by-step investigation summary
    investigation_summary = self._build_security_audit_summary(consolidated_findings)
    context_parts.append(
        f"\n=== AGENT'S SECURITY INVESTIGATION ===\n{investigation_summary}\n=== END INVESTIGATION ==="
    )

    # Security configuration captured on step 1. Unset (falsy) values are
    # dropped. List values such as compliance_requirements are rendered as
    # comma-separated text rather than a Python list repr. The section is
    # skipped entirely when nothing is left to show.
    config_lines = []
    for key, value in (self.security_config or {}).items():
        if not value:
            continue
        if isinstance(value, (list, tuple)):
            value = ", ".join(str(item) for item in value)
        config_lines.append(f"- {key}: {value}")
    if config_lines:
        config_text = "\n".join(config_lines)
        context_parts.append(f"\n=== SECURITY CONFIGURATION ===\n{config_text}\n=== END CONFIGURATION ===")

    # Files the agent flagged as security-relevant
    if consolidated_findings.relevant_files:
        files_text = "\n".join(f"- {file}" for file in consolidated_findings.relevant_files)
        context_parts.append(f"\n=== RELEVANT FILES ===\n{files_text}\n=== END FILES ===")

    # Security-critical classes/methods
    if consolidated_findings.relevant_context:
        methods_text = "\n".join(f"- {method}" for method in consolidated_findings.relevant_context)
        context_parts.append(
            f"\n=== SECURITY-CRITICAL CODE ELEMENTS ===\n{methods_text}\n=== END CODE ELEMENTS ==="
        )

    # Issues grouped by severity
    if consolidated_findings.issues_found:
        issues_text = self._format_security_issues(consolidated_findings.issues_found)
        context_parts.append(f"\n=== SECURITY ISSUES IDENTIFIED ===\n{issues_text}\n=== END ISSUES ===")

    # How the agent's assessment evolved across steps
    if consolidated_findings.hypotheses:
        assessments_text = "\n".join(
            f"Step {h['step']} ({h['confidence']} confidence): {h['hypothesis']}"
            for h in consolidated_findings.hypotheses
        )
        context_parts.append(f"\n=== ASSESSMENT EVOLUTION ===\n{assessments_text}\n=== END ASSESSMENTS ===")

    # Diagrams / threat models supplied as images
    if consolidated_findings.images:
        images_text = "\n".join(f"- {img}" for img in consolidated_findings.images)
        context_parts.append(
            f"\n=== VISUAL SECURITY INFORMATION ===\n{images_text}\n=== END VISUAL INFORMATION ==="
        )

    return "\n".join(context_parts)
def _format_security_issues(self, issues_found: list[dict]) -> str:
"""
Format security issues for expert analysis.
Organizes security findings by severity for clear expert review.
"""
if not issues_found:
return "No security issues identified during systematic investigation."
# Group issues by severity
severity_groups = {"critical": [], "high": [], "medium": [], "low": []}
for issue in issues_found:
severity = issue.get("severity", "low").lower()
description = issue.get("description", "No description provided")
if severity in severity_groups:
severity_groups[severity].append(description)
else:
severity_groups["low"].append(f"[{severity.upper()}] {description}")
formatted_issues = []
for severity in ["critical", "high", "medium", "low"]:
if severity_groups[severity]:
formatted_issues.append(f"\n{severity.upper()} SEVERITY:")
for issue in severity_groups[severity]:
formatted_issues.append(f" • {issue}")
return "\n".join(formatted_issues) if formatted_issues else "No security issues identified."
def _build_security_audit_summary(self, consolidated_findings) -> str:
"""Prepare a comprehensive summary of the security audit investigation."""
summary_parts = [
"=== SYSTEMATIC SECURITY AUDIT INVESTIGATION SUMMARY ===",
f"Total steps: {len(consolidated_findings.findings)}",
f"Files examined: {len(consolidated_findings.files_checked)}",
f"Relevant files identified: {len(consolidated_findings.relevant_files)}",
f"Security-critical elements analyzed: {len(consolidated_findings.relevant_context)}",
f"Security issues identified: {len(consolidated_findings.issues_found)}",
"",
"=== INVESTIGATION PROGRESSION ===",
]
for finding in consolidated_findings.findings:
summary_parts.append(finding)
return "\n".join(summary_parts)
def get_input_schema(self) -> dict[str, Any]:
    """Build the JSON input schema for the secaudit tool.

    The security-specific property definitions below are handed to
    WorkflowSchemaBuilder as ``tool_specific_fields``, together with the model
    field schema, the effective auto-mode flag and the tool name. Property
    descriptions come from SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS so they stay in
    sync with SecauditRequest.
    """
    from .workflow.schema_builders import WorkflowSchemaBuilder

    def prop(json_type: str, field: str, **constraints: Any) -> dict[str, Any]:
        # Key order is kept stable: "type", then any constraints, then "description".
        return {"type": json_type, **constraints, "description": SECAUDIT_WORKFLOW_FIELD_DESCRIPTIONS[field]}

    def string_array(field: str) -> dict[str, Any]:
        # Each property gets its own "items" dict, so no two entries share state.
        return prop("array", field, items={"type": "string"})

    secaudit_field_overrides = {
        "step": prop("string", "step"),
        "step_number": prop("integer", "step_number", minimum=1),
        "total_steps": prop("integer", "total_steps", minimum=1),
        "next_step_required": prop("boolean", "next_step_required"),
        "findings": prop("string", "findings"),
        "files_checked": string_array("files_checked"),
        "relevant_files": string_array("relevant_files"),
        "confidence": prop(
            "string",
            "confidence",
            enum=["exploring", "low", "medium", "high", "very_high", "almost_certain", "certain"],
        ),
        "issues_found": prop("array", "issues_found", items={"type": "object"}),
        "images": string_array("images"),
        # Security audit-specific fields (for step 1)
        "security_scope": prop("string", "security_scope"),
        "threat_level": prop(
            "string",
            "threat_level",
            enum=["low", "medium", "high", "critical"],
            default="medium",
        ),
        "compliance_requirements": string_array("compliance_requirements"),
        "audit_focus": prop(
            "string",
            "audit_focus",
            enum=["owasp", "compliance", "infrastructure", "dependencies", "comprehensive"],
            default="comprehensive",
        ),
        "severity_filter": prop(
            "string",
            "severity_filter",
            enum=["critical", "high", "medium", "low", "all"],
            default="all",
        ),
    }

    return WorkflowSchemaBuilder.build_schema(
        tool_specific_fields=secaudit_field_overrides,
        model_field_schema=self.get_model_field_schema(),
        auto_mode=self.is_effective_auto_mode(),
        tool_name=self.get_name(),
    )
# Hook method overrides for security audit-specific behavior
def prepare_step_data(self, request) -> dict:
    """Translate a secaudit request into the generic step-data dict.

    ``findings`` is exposed a second time as ``hypothesis`` for compatibility,
    and a missing ``images`` list becomes ``[]``. On step 1 the
    security-specific settings are also snapshotted onto ``self.security_config``.
    """
    step_data = dict(
        step=request.step,
        step_number=request.step_number,
        findings=request.findings,
        files_checked=request.files_checked,
        relevant_files=request.relevant_files,
        relevant_context=request.relevant_context,
        issues_found=request.issues_found,
        confidence=request.confidence,
        hypothesis=request.findings,  # findings double as the hypothesis for compatibility
        images=request.images or [],
    )

    if request.step_number == 1:
        # Later steps keep whatever configuration the first step supplied.
        self.security_config = {
            name: getattr(request, name)
            for name in ("security_scope", "threat_level", "compliance_requirements", "audit_focus", "severity_filter")
        }

    return step_data
def should_skip_expert_analysis(self, request, consolidated_findings) -> bool:
    """Skip expert review only on the final step when the agent reports "certain" confidence."""
    is_certain = request.confidence == "certain"
    if not is_certain:
        return False
    return not request.next_step_required
def store_initial_issue(self, step_description: str):
    """Remember the opening step text; it heads the expert-analysis context."""
    setattr(self, "initial_request", step_description)
def should_include_files_in_expert_prompt(self) -> bool:
    """Always embed file contents in the expert prompt for a full security review."""
    include_files = True
    return include_files
def should_embed_system_prompt(self) -> bool:
    """Always embed the system prompt so the expert model has the audit framing."""
    embed_prompt = True
    return embed_prompt
def get_expert_thinking_mode(self) -> str:
    """Thinking mode requested from the expert model: "high", for thorough analysis."""
    thinking_mode = "high"
    return thinking_mode
def get_expert_analysis_instruction(self) -> str:
    """Instruction sent to the expert model alongside the audit context."""
    request_sentence = "Please provide comprehensive security analysis based on the investigation findings."
    focus_sentence = (
        "Focus on identifying any remaining vulnerabilities, validating the completeness of the analysis, "
        "and providing final recommendations for security improvements, following the OWASP-based "
        "format specified in the system prompt."
    )
    return f"{request_sentence} {focus_sentence}"
def get_completion_next_steps_message(self, expert_analysis_used: bool = False) -> str:
    """Closing instructions shown once the security audit has finished.

    When ``expert_analysis_used`` is true and get_expert_analysis_guidance()
    returns non-empty text, that guidance is appended after a blank line.
    """
    message = " ".join(
        (
            "SECURITY AUDIT IS COMPLETE. You MUST now summarize and present ALL security findings organized by",
            "severity (Critical → High → Medium → Low), specific code locations with line numbers, and exact",
            "remediation steps for each vulnerability. Clearly prioritize the top 3 security issues that need",
            "immediate attention. Provide concrete, actionable guidance for each vulnerability—make it easy for",
            "developers to understand exactly what needs to be fixed and how to implement the security improvements.",
        )
    )

    if not expert_analysis_used:
        return message

    guidance = self.get_expert_analysis_guidance()
    return f"{message}\n\n{guidance}" if guidance else message
def get_expert_analysis_guidance(self) -> str:
    """Tell the agent to critically validate, not blindly accept, the expert's security findings."""
    sentences = (
        "IMPORTANT: Analysis from an assistant model has been provided above.",
        "You MUST critically evaluate and validate the expert security findings rather than accepting them blindly.",
        "Cross-reference the expert analysis with your own investigation findings, verify that suggested security "
        "improvements are appropriate for this application's context and threat model, and ensure recommendations "
        "align with the project's security requirements.",
        "Present a synthesis that combines your systematic security review with validated expert insights, clearly "
        "distinguishing between vulnerabilities you've independently confirmed and additional insights from expert "
        "analysis.",
    )
    return " ".join(sentences)
def get_step_guidance_message(self, request) -> str:
    """
    Return the "next steps" text for the current security-audit step.

    Delegates to get_security_audit_step_guidance() with the request's step
    number and confidence, then extracts the "next_steps" entry.
    """
    guidance = self.get_security_audit_step_guidance(
        request.step_number,
        request.confidence,
        request,
    )
    return guidance["next_steps"]
def get_security_audit_step_guidance(self, step_number: int, confidence: str, request) -> dict[str, Any]:
    """
    Build step-specific "next steps" instructions for the security audit workflow.

    Args:
        step_number: The step number the caller just submitted.
        confidence: The caller's reported confidence level.
        request: The workflow request; its ``findings`` and ``total_steps`` are
            forwarded to get_required_actions().

    Returns:
        A dict with a single "next_steps" key holding the instruction text.
    """
    # Always computed first (even for step 1), matching the original call order.
    actions = self.get_required_actions(step_number, confidence, request.findings, request.total_steps)
    tool = self.get_name()
    upcoming = step_number + 1

    def numbered_checklist() -> str:
        # "1. first\n2. second\n..." built from the required actions
        return "\n".join(f"{position}. {item}" for position, item in enumerate(actions, start=1))

    if step_number == 1:
        text = (
            f"MANDATORY: DO NOT call the {tool} tool again immediately. You MUST first examine "
            "the code files thoroughly using appropriate tools. CRITICAL AWARENESS: You need to understand "
            "the security landscape, identify potential vulnerabilities across OWASP Top 10 categories, "
            "and look for authentication flaws, injection points, cryptographic issues, and authorization bypasses. "
            "Use file reading tools, security analysis, and systematic examination to gather comprehensive information. "
            f"Only call {tool} again AFTER completing your security investigation. When you call "
            f"{tool} next time, use step_number: {upcoming} and report specific "
            "files examined, vulnerabilities found, and security assessments discovered."
        )
    elif confidence in ("exploring", "low"):
        checklist = numbered_checklist()
        text = (
            f"STOP! Do NOT call {tool} again yet. Based on your findings, you've identified areas that need "
            f"deeper security analysis. MANDATORY ACTIONS before calling {tool} step {upcoming}:\n"
            f"{checklist}\n\n"
            f"Only call {tool} again with step_number: {upcoming} AFTER completing these security audit tasks."
        )
    elif confidence in ("medium", "high"):
        checklist = numbered_checklist()
        text = (
            f"WAIT! Your security audit needs final verification. DO NOT call {tool} immediately. REQUIRED ACTIONS:\n"
            f"{checklist}\n\n"
            "REMEMBER: Ensure you have identified all significant vulnerabilities across all severity levels and "
            "verified the completeness of your security review. Document findings with specific file references and "
            f"line numbers where applicable, then call {tool} with step_number: {upcoming}."
        )
    else:
        # Any other confidence value: only the first two required actions are listed.
        first_two = ", ".join(actions[:2])
        text = (
            f"PAUSE SECURITY AUDIT. Before calling {tool} step {upcoming}, you MUST examine more code thoroughly. "
            f"Required: {first_two}. "
            f"Your next {tool} call (step_number: {upcoming}) must include "
            f"NEW evidence from actual security analysis, not just theories. NO recursive {tool} calls "
            "without investigation work!"
        )

    return {"next_steps": text}
def customize_workflow_response(self, response_data: dict, request) -> dict:
    """
    Adapt the generic workflow response to the security-audit vocabulary.

    On step 1 the initial request text is remembered and, when relevant files
    were supplied, the audit configuration is stored for the expert-analysis
    phase. Tool-name based status values and keys are then renamed to their
    ``security_audit`` equivalents. The renamed status section is enriched with
    a per-severity vulnerability count and the request's confidence.

    Args:
        response_data: Response dict built by the workflow base; modified in place.
        request: The current workflow request.

    Returns:
        The same ``response_data`` dict after modification.
    """
    if request.step_number == 1:
        self.initial_request = request.step
        # Captured for later use by expert analysis
        if request.relevant_files:
            self.security_config = {
                "relevant_files": request.relevant_files,
                "security_scope": request.security_scope,
                "threat_level": request.threat_level,
                "compliance_requirements": request.compliance_requirements,
                "audit_focus": request.audit_focus,
                "severity_filter": request.severity_filter,
            }

    name = self.get_name()

    # Generic status values -> security-audit specific values
    renamed_statuses = {
        f"{name}_in_progress": "security_audit_in_progress",
        f"pause_for_{name}": "pause_for_security_audit",
        f"{name}_required": "security_audit_required",
        f"{name}_complete": "security_audit_complete",
    }
    current_status = response_data["status"]
    if current_status in renamed_statuses:
        response_data["status"] = renamed_statuses[current_status]

    status_key = f"{name}_status"
    if status_key in response_data:
        audit_status = response_data.pop(status_key)
        response_data["security_audit_status"] = audit_status

        # Tally recorded issues by severity; issues without one count as "unknown"
        severity_counts = {}
        audit_status["vulnerabilities_by_severity"] = severity_counts
        for issue in self.consolidated_findings.issues_found:
            severity = issue.get("severity", "unknown")
            severity_counts[severity] = severity_counts.get(severity, 0) + 1
        audit_status["audit_confidence"] = self.get_request_confidence(request)

    # Remaining tool-name based keys -> security-audit equivalents
    for generic_key, audit_key in (
        (f"complete_{name}", "complete_security_audit"),
        (f"{name}_complete", "security_audit_complete"),
    ):
        if generic_key in response_data:
            response_data[audit_key] = response_data.pop(generic_key)

    return response_data
# Override inheritance hooks for security audit-specific behavior
def get_completion_status(self) -> str:
    """Return the status string this tool reports on completion."""
    status = "security_analysis_complete"
    return status
def get_completion_data_key(self) -> str:
    """Return the response key under which completion data is stored for security audits."""
    data_key = "complete_security_audit"
    return data_key
def get_final_analysis_from_request(self, request):
    """Return the final analysis, which for security audits is the request's ``findings`` field."""
    findings = request.findings
    return findings
def get_confidence_level(self, request) -> str:
    """Return the fixed confidence label "certain"; the request is not consulted."""
    level = "certain"
    return level
def get_completion_message(self) -> str:
    """Return the completion message used when the audit is reported with CERTAIN confidence."""
    fragments = [
        "Security audit complete with CERTAIN confidence. You have identified all significant vulnerabilities ",
        "and provided comprehensive security analysis. MANDATORY: Present the user with the complete security audit results ",
        "categorized by severity, and IMMEDIATELY proceed with implementing the highest priority security fixes ",
        "or provide specific guidance for vulnerability remediation. Focus on actionable security recommendations.",
    ]
    return "".join(fragments)
def get_skip_reason(self) -> str:
    """Return the explanation reported when expert analysis is not performed for this audit."""
    reason = "Completed comprehensive security audit with full confidence locally"
    return reason
def get_skip_expert_analysis_status(self) -> str:
    """Return the status label recorded when expert analysis is skipped for a security audit."""
    skipped_status = "skipped_due_to_certain_audit_confidence"
    return skipped_status
def prepare_work_summary(self) -> str:
    """Summarize the audit so far by passing the consolidated findings to the summary builder."""
    findings = self.consolidated_findings
    summary = self._build_security_audit_summary(findings)
    return summary
def get_request_model(self):
    """Return the request model class (SecauditRequest) this tool validates input against."""
    model_class = SecauditRequest
    return model_class
async def prepare_prompt(self, request: SecauditRequest) -> str:
    """
    Always return an empty prompt.

    Workflow tools build their prompts through execute_workflow() instead,
    so this hook is intentionally a no-op.
    """
    empty_prompt = ""
    return empty_prompt
```
--------------------------------------------------------------------------------
/simulator_tests/test_testgen_validation.py:
--------------------------------------------------------------------------------
```python
#!/usr/bin/env python3
"""
TestGen Tool Validation Test
Tests the testgen tool's capabilities using the workflow architecture.
This validates that the workflow-based implementation guides Claude through
systematic test generation analysis before creating comprehensive test suites.
"""
import json
from typing import Optional
from .conversation_base_test import ConversationBaseTest
class TestGenValidationTest(ConversationBaseTest):
    """Test testgen tool with workflow architecture"""

    @property
    def test_name(self) -> str:
        """Identifier of this simulator test."""
        return "testgen_validation"

    @property
    def test_description(self) -> str:
        """Human-readable description of this simulator test."""
        return "TestGen tool validation with step-by-step test planning"

    def run_test(self) -> bool:
        """Test testgen tool capabilities

        Runs every validation scenario in order and stops at the first failure.

        Returns:
            True when all scenarios pass; False otherwise. Unexpected exceptions
            are logged and reported as False rather than propagated.
        """
        # Set up the test environment
        self.setUp()

        try:
            self.logger.info("Test: TestGen tool validation")

            # Create sample code files to test
            self._create_test_code_files()

            # Test 1: Single investigation session with multiple steps
            # (also stores self.test_continuation_id, reused by Test 3)
            if not self._test_single_test_generation_session():
                return False

            # Test 2: Test generation with pattern following
            if not self._test_generation_with_pattern_following():
                return False

            # Test 3: Complete test generation with expert analysis
            if not self._test_complete_generation_with_analysis():
                return False

            # Test 4: Certain confidence behavior
            if not self._test_certain_confidence():
                return False

            # Test 5: Context-aware file embedding
            if not self._test_context_aware_file_embedding():
                return False

            # Test 6: Multi-step test planning
            if not self._test_multi_step_test_planning():
                return False

            self.logger.info(" ✅ All testgen validation tests passed")
            return True

        except Exception as e:
            self.logger.error(f"TestGen validation test failed: {e}")
            return False

    def _create_test_code_files(self):
        """Create sample code files for test generation

        Sets ``self.calculator_file`` (module under test) and
        ``self.existing_test_file`` (a test file used as a style pattern).
        """
        # Create a calculator module with various functions
        calculator_code = """#!/usr/bin/env python3
\"\"\"
Simple calculator module for demonstration
\"\"\"
def add(a, b):
\"\"\"Add two numbers\"\"\"
return a + b
def subtract(a, b):
\"\"\"Subtract b from a\"\"\"
return a - b
def multiply(a, b):
\"\"\"Multiply two numbers\"\"\"
return a * b
def divide(a, b):
\"\"\"Divide a by b\"\"\"
if b == 0:
raise ValueError("Cannot divide by zero")
return a / b
def calculate_percentage(value, percentage):
\"\"\"Calculate percentage of a value\"\"\"
if percentage < 0:
raise ValueError("Percentage cannot be negative")
if percentage > 100:
raise ValueError("Percentage cannot exceed 100")
return (value * percentage) / 100
def power(base, exponent):
\"\"\"Calculate base raised to exponent\"\"\"
if base == 0 and exponent < 0:
raise ValueError("Cannot raise 0 to negative power")
return base ** exponent
"""

        # Create test file
        self.calculator_file = self.create_additional_test_file("calculator.py", calculator_code)
        self.logger.info(f" ✅ Created calculator module: {self.calculator_file}")

        # Create a simple existing test file to use as pattern
        existing_test = """#!/usr/bin/env python3
import pytest
from calculator import add, subtract
class TestCalculatorBasic:
\"\"\"Test basic calculator operations\"\"\"
def test_add_positive_numbers(self):
\"\"\"Test adding two positive numbers\"\"\"
assert add(2, 3) == 5
assert add(10, 20) == 30
def test_add_negative_numbers(self):
\"\"\"Test adding negative numbers\"\"\"
assert add(-5, -3) == -8
assert add(-10, 5) == -5
def test_subtract_positive(self):
\"\"\"Test subtracting positive numbers\"\"\"
assert subtract(10, 3) == 7
assert subtract(5, 5) == 0
"""

        self.existing_test_file = self.create_additional_test_file("test_calculator_basic.py", existing_test)
        self.logger.info(f" ✅ Created existing test file: {self.existing_test_file}")

    def _test_single_test_generation_session(self) -> bool:
        """Test a complete test generation session with multiple steps

        On success stores the session's continuation id in
        ``self.test_continuation_id`` for reuse by later scenarios.
        """
        try:
            self.logger.info(" 1.1: Testing single test generation session")

            # Step 1: Start investigation
            self.logger.info(" 1.1.1: Step 1 - Initial test planning")
            response1, continuation_id = self.call_mcp_tool(
                "testgen",
                {
                    "step": "I need to generate comprehensive tests for the calculator module. Let me start by analyzing the code structure and understanding the functionality.",
                    "step_number": 1,
                    "total_steps": 4,
                    "next_step_required": True,
                    "findings": "Calculator module contains 6 functions: add, subtract, multiply, divide, calculate_percentage, and power. Each has specific error conditions that need testing.",
                    "files_checked": [self.calculator_file],
                    "relevant_files": [self.calculator_file],
                    "relevant_context": ["add", "subtract", "multiply", "divide", "calculate_percentage", "power"],
                },
            )

            if not response1 or not continuation_id:
                self.logger.error("Failed to get initial test planning response")
                return False

            # Parse and validate JSON response
            response1_data = self._parse_testgen_response(response1)
            if not response1_data:
                return False

            # Validate step 1 response structure
            if not self._validate_step_response(response1_data, 1, 4, True, "pause_for_test_analysis"):
                return False

            self.logger.info(f" ✅ Step 1 successful, continuation_id: {continuation_id}")

            # Step 2: Analyze test requirements
            self.logger.info(" 1.1.2: Step 2 - Test requirements analysis")
            response2, _ = self.call_mcp_tool(
                "testgen",
                {
                    "step": "Now analyzing the test requirements for each function, identifying edge cases and boundary conditions.",
                    "step_number": 2,
                    "total_steps": 4,
                    "next_step_required": True,
                    "findings": "Identified key test scenarios: (1) divide - zero division error, (2) calculate_percentage - negative/over 100 validation, (3) power - zero to negative power error. Need tests for normal cases and edge cases.",
                    "files_checked": [self.calculator_file],
                    "relevant_files": [self.calculator_file],
                    "relevant_context": ["divide", "calculate_percentage", "power"],
                    "confidence": "medium",
                    "continuation_id": continuation_id,
                },
            )

            if not response2:
                self.logger.error("Failed to continue test planning to step 2")
                return False

            response2_data = self._parse_testgen_response(response2)
            # Parse failures are already logged; avoid misleading validation errors
            if not response2_data:
                return False
            if not self._validate_step_response(response2_data, 2, 4, True, "pause_for_test_analysis"):
                return False

            # Check test generation status tracking
            test_status = response2_data.get("test_generation_status", {})
            if test_status.get("test_scenarios_identified", 0) < 3:
                self.logger.error("Test scenarios not properly tracked")
                return False

            if test_status.get("analysis_confidence") != "medium":
                self.logger.error("Confidence level not properly tracked")
                return False

            self.logger.info(" ✅ Step 2 successful with proper tracking")

            # Store continuation_id for next test
            self.test_continuation_id = continuation_id
            return True

        except Exception as e:
            self.logger.error(f"Single test generation session test failed: {e}")
            return False

    def _test_generation_with_pattern_following(self) -> bool:
        """Test test generation following existing patterns"""
        try:
            self.logger.info(" 1.2: Testing test generation with pattern following")

            # Start a new investigation with existing test patterns
            self.logger.info(" 1.2.1: Start test generation with pattern reference")
            response1, continuation_id = self.call_mcp_tool(
                "testgen",
                {
                    "step": "Generating tests for remaining calculator functions following existing test patterns",
                    "step_number": 1,
                    "total_steps": 3,
                    "next_step_required": True,
                    "findings": "Found existing test pattern using pytest with class-based organization and descriptive test names",
                    "files_checked": [self.calculator_file, self.existing_test_file],
                    "relevant_files": [self.calculator_file, self.existing_test_file],
                    "relevant_context": ["TestCalculatorBasic", "multiply", "divide", "calculate_percentage", "power"],
                },
            )

            if not response1 or not continuation_id:
                self.logger.error("Failed to start pattern following test")
                return False

            # Step 2: Analyze patterns
            self.logger.info(" 1.2.2: Step 2 - Pattern analysis")
            response2, _ = self.call_mcp_tool(
                "testgen",
                {
                    "step": "Analyzing the existing test patterns to maintain consistency",
                    "step_number": 2,
                    "total_steps": 3,
                    "next_step_required": True,
                    "findings": "Existing tests use: class-based organization (TestCalculatorBasic), descriptive method names (test_operation_scenario), multiple assertions per test, pytest framework",
                    "files_checked": [self.existing_test_file],
                    "relevant_files": [self.calculator_file, self.existing_test_file],
                    "confidence": "high",
                    "continuation_id": continuation_id,
                },
            )

            if not response2:
                self.logger.error("Failed to continue to step 2")
                return False

            self.logger.info(" ✅ Pattern analysis successful")
            return True

        except Exception as e:
            self.logger.error(f"Pattern following test failed: {e}")
            return False

    def _test_complete_generation_with_analysis(self) -> bool:
        """Test complete test generation ending with expert analysis"""
        try:
            self.logger.info(" 1.3: Testing complete test generation with expert analysis")

            # Use the continuation from first test or start fresh
            continuation_id = getattr(self, "test_continuation_id", None)
            if not continuation_id:
                # Start fresh if no continuation available
                self.logger.info(" 1.3.0: Starting fresh test generation")
                response0, continuation_id = self.call_mcp_tool(
                    "testgen",
                    {
                        "step": "Analyzing calculator module for comprehensive test generation",
                        "step_number": 1,
                        "total_steps": 2,
                        "next_step_required": True,
                        "findings": "Identified 6 functions needing tests with various edge cases",
                        "files_checked": [self.calculator_file],
                        "relevant_files": [self.calculator_file],
                        "relevant_context": ["add", "subtract", "multiply", "divide", "calculate_percentage", "power"],
                    },
                )
                if not response0 or not continuation_id:
                    self.logger.error("Failed to start fresh test generation")
                    return False

            # Final step - trigger expert analysis
            self.logger.info(" 1.3.1: Final step - complete test planning")
            response_final, _ = self.call_mcp_tool(
                "testgen",
                {
                    "step": "Test planning complete. Identified all test scenarios including edge cases, error conditions, and boundary values for comprehensive coverage.",
                    "step_number": 2,
                    "total_steps": 2,
                    "next_step_required": False,  # Final step - triggers expert analysis
                    "findings": "Complete test plan: normal operations, edge cases (zero, negative), error conditions (divide by zero, invalid percentage, zero to negative power), boundary values",
                    "files_checked": [self.calculator_file],
                    "relevant_files": [self.calculator_file],
                    "relevant_context": ["add", "subtract", "multiply", "divide", "calculate_percentage", "power"],
                    "confidence": "high",
                    "continuation_id": continuation_id,
                    "model": "flash",  # Use flash for expert analysis
                },
            )

            if not response_final:
                self.logger.error("Failed to complete test generation")
                return False

            response_final_data = self._parse_testgen_response(response_final)
            if not response_final_data:
                return False

            # Validate final response structure
            if response_final_data.get("status") != "calling_expert_analysis":
                self.logger.error(
                    f"Expected status 'calling_expert_analysis', got '{response_final_data.get('status')}'"
                )
                return False

            if not response_final_data.get("test_generation_complete"):
                self.logger.error("Expected test_generation_complete=true for final step")
                return False

            # Check for expert analysis
            if "expert_analysis" not in response_final_data:
                self.logger.error("Missing expert_analysis in final response")
                return False

            expert_analysis = response_final_data.get("expert_analysis", {})

            # Check for expected analysis content
            analysis_text = json.dumps(expert_analysis, ensure_ascii=False).lower()

            # Look for test generation indicators (soft check: only warns)
            test_indicators = ["test", "edge", "boundary", "error", "coverage", "pytest"]
            found_indicators = sum(1 for indicator in test_indicators if indicator in analysis_text)

            if found_indicators >= 4:
                self.logger.info(" ✅ Expert analysis provided comprehensive test suggestions")
            else:
                # Denominator derived from the list so it stays correct if indicators change
                self.logger.warning(
                    f" ⚠️ Expert analysis may not have fully addressed test generation "
                    f"(found {found_indicators}/{len(test_indicators)} indicators)"
                )

            # Check complete test generation summary
            if "complete_test_generation" not in response_final_data:
                self.logger.error("Missing complete_test_generation in final response")
                return False

            complete_generation = response_final_data["complete_test_generation"]
            if not complete_generation.get("relevant_context"):
                self.logger.error("Missing relevant context in complete test generation")
                return False

            self.logger.info(" ✅ Complete test generation with expert analysis successful")
            return True

        except Exception as e:
            self.logger.error(f"Complete test generation test failed: {e}")
            return False

    def _test_certain_confidence(self) -> bool:
        """Test certain confidence behavior - should skip expert analysis"""
        try:
            self.logger.info(" 1.4: Testing certain confidence behavior")

            # Test certain confidence - should skip expert analysis
            self.logger.info(" 1.4.1: Certain confidence test generation")
            response_certain, _ = self.call_mcp_tool(
                "testgen",
                {
                    "step": "I have fully analyzed the code and identified all test scenarios with 100% certainty. Test plan is complete.",
                    "step_number": 1,
                    "total_steps": 1,
                    "next_step_required": False,  # Final step
                    "findings": "Complete test coverage plan: all functions covered with normal cases, edge cases, and error conditions. Ready for implementation.",
                    "files_checked": [self.calculator_file],
                    "relevant_files": [self.calculator_file],
                    "relevant_context": ["add", "subtract", "multiply", "divide", "calculate_percentage", "power"],
                    "confidence": "certain",  # This should skip expert analysis
                    "model": "flash",
                },
            )

            if not response_certain:
                self.logger.error("Failed to test certain confidence")
                return False

            response_certain_data = self._parse_testgen_response(response_certain)
            if not response_certain_data:
                return False

            # Validate certain confidence response - should skip expert analysis
            if response_certain_data.get("status") != "test_generation_complete_ready_for_implementation":
                self.logger.error(
                    f"Expected status 'test_generation_complete_ready_for_implementation', got '{response_certain_data.get('status')}'"
                )
                return False

            if not response_certain_data.get("skip_expert_analysis"):
                self.logger.error("Expected skip_expert_analysis=true for certain confidence")
                return False

            expert_analysis = response_certain_data.get("expert_analysis", {})
            if expert_analysis.get("status") != "skipped_due_to_certain_test_confidence":
                self.logger.error("Expert analysis should be skipped for certain confidence")
                return False

            self.logger.info(" ✅ Certain confidence behavior working correctly")
            return True

        except Exception as e:
            self.logger.error(f"Certain confidence test failed: {e}")
            return False

    def call_mcp_tool(self, tool_name: str, params: dict) -> tuple[Optional[str], Optional[str]]:
        """Call an MCP tool in-process - override for testgen-specific response handling

        Returns:
            ``(response_text, continuation_id)``; both are None when the tool
            produced no response text.
        """
        # Use in-process implementation to maintain conversation memory
        response_text, _ = self.call_mcp_tool_direct(tool_name, params)

        if not response_text:
            return None, None

        # Extract continuation_id from testgen response specifically
        continuation_id = self._extract_testgen_continuation_id(response_text)

        return response_text, continuation_id

    def _extract_testgen_continuation_id(self, response_text: str) -> Optional[str]:
        """Extract continuation_id from testgen response

        Returns None when the text is not JSON or is JSON that is not an object.
        """
        try:
            # Parse the response
            response_data = json.loads(response_text)
        except (json.JSONDecodeError, TypeError) as e:
            self.logger.debug(f"Failed to parse response for testgen continuation_id: {e}")
            return None

        # A JSON array/scalar has no .get(); treat it as "no continuation id"
        if not isinstance(response_data, dict):
            self.logger.debug("Testgen response is not a JSON object; no continuation_id available")
            return None
        return response_data.get("continuation_id")

    def _parse_testgen_response(self, response_text: str) -> dict:
        """Parse testgen tool JSON response

        Returns:
            The decoded JSON object, or ``{}`` when the text cannot be parsed or
            does not decode to an object (callers rely on ``.get`` and on an
            empty dict signalling failure).
        """
        try:
            # Parse the response - it should be direct JSON
            parsed = json.loads(response_text)
        except (json.JSONDecodeError, TypeError) as e:
            self.logger.error(f"Failed to parse testgen response as JSON: {e}")
            self.logger.error(f"Response text: {str(response_text)[:500]}...")
            return {}

        if not isinstance(parsed, dict):
            self.logger.error(f"Expected a JSON object from testgen, got {type(parsed).__name__}")
            return {}
        return parsed

    def _validate_step_response(
        self,
        response_data: dict,
        expected_step: int,
        expected_total: int,
        expected_next_required: bool,
        expected_status: str,
    ) -> bool:
        """Validate a test generation step response structure

        Checks status, step_number, total_steps and next_step_required against
        the expected values and requires non-empty ``test_generation_status``
        and ``next_steps`` entries. Logs the first mismatch and returns False.
        """
        try:
            # Check status
            if response_data.get("status") != expected_status:
                self.logger.error(f"Expected status '{expected_status}', got '{response_data.get('status')}'")
                return False

            # Check step number
            if response_data.get("step_number") != expected_step:
                self.logger.error(f"Expected step_number {expected_step}, got {response_data.get('step_number')}")
                return False

            # Check total steps
            if response_data.get("total_steps") != expected_total:
                self.logger.error(f"Expected total_steps {expected_total}, got {response_data.get('total_steps')}")
                return False

            # Check next_step_required
            if response_data.get("next_step_required") != expected_next_required:
                self.logger.error(
                    f"Expected next_step_required {expected_next_required}, got {response_data.get('next_step_required')}"
                )
                return False

            # Check test_generation_status exists
            if "test_generation_status" not in response_data:
                self.logger.error("Missing test_generation_status in response")
                return False

            # Check next_steps guidance
            if not response_data.get("next_steps"):
                self.logger.error("Missing next_steps guidance in response")
                return False

            return True

        except Exception as e:
            self.logger.error(f"Error validating step response: {e}")
            return False

    def _test_context_aware_file_embedding(self) -> bool:
        """Test context-aware file embedding optimization

        Intermediate steps are expected to report ``reference_only`` file
        context; the final step is expected to report ``fully_embedded``.
        """
        try:
            self.logger.info(" 1.5: Testing context-aware file embedding")

            # Create additional test files
            utils_code = """#!/usr/bin/env python3
def validate_number(n):
\"\"\"Validate if input is a number\"\"\"
return isinstance(n, (int, float))
def format_result(result):
\"\"\"Format calculation result\"\"\"
if isinstance(result, float):
return round(result, 2)
return result
"""

            math_helpers_code = """#!/usr/bin/env python3
import math
def factorial(n):
\"\"\"Calculate factorial of n\"\"\"
if n < 0:
raise ValueError("Factorial not defined for negative numbers")
return math.factorial(n)
def is_prime(n):
\"\"\"Check if number is prime\"\"\"
if n < 2:
return False
for i in range(2, int(n**0.5) + 1):
if n % i == 0:
return False
return True
"""

            # Create test files
            utils_file = self.create_additional_test_file("utils.py", utils_code)
            math_file = self.create_additional_test_file("math_helpers.py", math_helpers_code)

            # Test 1: New conversation, intermediate step - should only reference files
            self.logger.info(" 1.5.1: New conversation intermediate step (should reference only)")
            response1, continuation_id = self.call_mcp_tool(
                "testgen",
                {
                    "step": "Starting test generation for utility modules",
                    "step_number": 1,
                    "total_steps": 3,
                    "next_step_required": True,  # Intermediate step
                    "findings": "Initial analysis of utility functions",
                    "files_checked": [utils_file, math_file],
                    "relevant_files": [utils_file],  # This should be referenced, not embedded
                    "relevant_context": ["validate_number", "format_result"],
                    "confidence": "low",
                    "model": "flash",
                },
            )

            if not response1 or not continuation_id:
                self.logger.error("Failed to start context-aware file embedding test")
                return False

            response1_data = self._parse_testgen_response(response1)
            if not response1_data:
                return False

            # Check file context - should be reference_only for intermediate step
            file_context = response1_data.get("file_context", {})
            if file_context.get("type") != "reference_only":
                self.logger.error(f"Expected reference_only file context, got: {file_context.get('type')}")
                return False

            self.logger.info(" ✅ Intermediate step correctly uses reference_only file context")

            # Test 2: Final step - should embed files for expert analysis
            self.logger.info(" 1.5.2: Final step (should embed files)")
            response2, _ = self.call_mcp_tool(
                "testgen",
                {
                    "step": "Test planning complete - all test scenarios identified",
                    "step_number": 2,
                    "total_steps": 2,
                    "next_step_required": False,  # Final step - should embed files
                    "continuation_id": continuation_id,
                    "findings": "Complete test plan for all utility functions with edge cases",
                    "files_checked": [utils_file, math_file],
                    "relevant_files": [utils_file, math_file],  # Should be fully embedded
                    "relevant_context": ["validate_number", "format_result", "factorial", "is_prime"],
                    "confidence": "high",
                    "model": "flash",
                },
            )

            if not response2:
                self.logger.error("Failed to complete to final step")
                return False

            response2_data = self._parse_testgen_response(response2)
            if not response2_data:
                return False

            # Check file context - should be fully_embedded for final step
            file_context2 = response2_data.get("file_context", {})
            if file_context2.get("type") != "fully_embedded":
                self.logger.error(
                    f"Expected fully_embedded file context for final step, got: {file_context2.get('type')}"
                )
                return False

            # Verify expert analysis was called for final step
            if response2_data.get("status") != "calling_expert_analysis":
                self.logger.error("Final step should trigger expert analysis")
                return False

            self.logger.info(" ✅ Context-aware file embedding test completed successfully")
            return True

        except Exception as e:
            self.logger.error(f"Context-aware file embedding test failed: {e}")
            return False

    def _test_multi_step_test_planning(self) -> bool:
        """Test multi-step test planning with complex code

        Walks a four-step session and checks that step 1 uses reference-only
        file context while the final step triggers expert analysis with fully
        embedded files.
        """
        try:
            self.logger.info(" 1.6: Testing multi-step test planning")

            # Create a complex class to test
            complex_code = """#!/usr/bin/env python3
import asyncio
from typing import List, Dict, Optional
class DataProcessor:
\"\"\"Complex data processor with async operations\"\"\"
def __init__(self, batch_size: int = 100):
self.batch_size = batch_size
self.processed_count = 0
self.error_count = 0
self.cache: Dict[str, any] = {}
async def process_batch(self, items: List[dict]) -> List[dict]:
\"\"\"Process a batch of items asynchronously\"\"\"
if not items:
return []
if len(items) > self.batch_size:
raise ValueError(f"Batch size {len(items)} exceeds limit {self.batch_size}")
results = []
for item in items:
try:
result = await self._process_single_item(item)
results.append(result)
self.processed_count += 1
except Exception as e:
self.error_count += 1
results.append({"error": str(e), "item": item})
return results
async def _process_single_item(self, item: dict) -> dict:
\"\"\"Process a single item with caching\"\"\"
item_id = item.get('id')
if not item_id:
raise ValueError("Item must have an ID")
# Check cache
if item_id in self.cache:
return self.cache[item_id]
# Simulate async processing
await asyncio.sleep(0.01)
processed = {
'id': item_id,
'processed': True,
'value': item.get('value', 0) * 2
}
# Cache result
self.cache[item_id] = processed
return processed
def get_stats(self) -> Dict[str, int]:
\"\"\"Get processing statistics\"\"\"
return {
'processed': self.processed_count,
'errors': self.error_count,
'cache_size': len(self.cache),
'success_rate': self.processed_count / (self.processed_count + self.error_count) if (self.processed_count + self.error_count) > 0 else 0
}
"""

            # Create test file
            processor_file = self.create_additional_test_file("data_processor.py", complex_code)

            # Step 1: Start investigation
            self.logger.info(" 1.6.1: Step 1 - Start complex test planning")
            response1, continuation_id = self.call_mcp_tool(
                "testgen",
                {
                    "step": "Analyzing complex DataProcessor class for comprehensive test generation",
                    "step_number": 1,
                    "total_steps": 4,
                    "next_step_required": True,
                    "findings": "DataProcessor is an async class with caching, error handling, and statistics. Need async test patterns.",
                    "files_checked": [processor_file],
                    "relevant_files": [processor_file],
                    "relevant_context": ["DataProcessor", "process_batch", "_process_single_item", "get_stats"],
                    "confidence": "low",
                    "model": "flash",
                },
            )

            if not response1 or not continuation_id:
                self.logger.error("Failed to start multi-step test planning")
                return False

            response1_data = self._parse_testgen_response(response1)
            # Parse failures are already logged; don't report them as a file-context mismatch
            if not response1_data:
                return False

            # Validate step 1
            file_context1 = response1_data.get("file_context", {})
            if file_context1.get("type") != "reference_only":
                self.logger.error("Step 1 should use reference_only file context")
                return False

            self.logger.info(" ✅ Step 1: Started complex test planning")

            # Step 2: Analyze async patterns
            self.logger.info(" 1.6.2: Step 2 - Async pattern analysis")
            response2, _ = self.call_mcp_tool(
                "testgen",
                {
                    "step": "Analyzing async patterns and edge cases for testing",
                    "step_number": 2,
                    "total_steps": 4,
                    "next_step_required": True,
                    "continuation_id": continuation_id,
                    "findings": "Key test areas: async batch processing, cache behavior, error handling, batch size limits, empty items, statistics calculation",
                    "files_checked": [processor_file],
                    "relevant_files": [processor_file],
                    "relevant_context": ["process_batch", "_process_single_item"],
                    "confidence": "medium",
                    "model": "flash",
                },
            )

            if not response2:
                self.logger.error("Failed to continue to step 2")
                return False

            self.logger.info(" ✅ Step 2: Async patterns analyzed")

            # Step 3: Edge case identification
            self.logger.info(" 1.6.3: Step 3 - Edge case identification")
            response3, _ = self.call_mcp_tool(
                "testgen",
                {
                    "step": "Identifying all edge cases and boundary conditions",
                    "step_number": 3,
                    "total_steps": 4,
                    "next_step_required": True,
                    "continuation_id": continuation_id,
                    "findings": "Edge cases: empty batch, oversized batch, items without ID, cache hits/misses, concurrent processing, error accumulation",
                    "files_checked": [processor_file],
                    "relevant_files": [processor_file],
                    "confidence": "high",
                    "model": "flash",
                },
            )

            if not response3:
                self.logger.error("Failed to continue to step 3")
                return False

            self.logger.info(" ✅ Step 3: Edge cases identified")

            # Step 4: Final test plan with expert analysis
            self.logger.info(" 1.6.4: Step 4 - Complete test plan")
            response4, _ = self.call_mcp_tool(
                "testgen",
                {
                    "step": "Test planning complete with comprehensive coverage strategy",
                    "step_number": 4,
                    "total_steps": 4,
                    "next_step_required": False,  # Final step
                    "continuation_id": continuation_id,
                    "findings": "Complete async test suite plan: unit tests for each method, integration tests for batch processing, edge case coverage, performance tests",
                    "files_checked": [processor_file],
                    "relevant_files": [processor_file],
                    "confidence": "high",
                    "model": "flash",
                },
            )

            if not response4:
                self.logger.error("Failed to complete to final step")
                return False

            response4_data = self._parse_testgen_response(response4)
            if not response4_data:
                return False

            # Validate final step
            if response4_data.get("status") != "calling_expert_analysis":
                self.logger.error("Final step should trigger expert analysis")
                return False

            file_context4 = response4_data.get("file_context", {})
            if file_context4.get("type") != "fully_embedded":
                self.logger.error("Final step should use fully_embedded file context")
                return False

            self.logger.info(" ✅ Multi-step test planning completed successfully")
            return True

        except Exception as e:
            self.logger.error(f"Multi-step test planning test failed: {e}")
            return False
```
--------------------------------------------------------------------------------
/tests/test_model_restrictions.py:
--------------------------------------------------------------------------------
```python
"""Tests for model restriction functionality."""
import os
from unittest.mock import MagicMock, patch
import pytest
from providers.gemini import GeminiModelProvider
from providers.openai import OpenAIModelProvider
from providers.shared import ProviderType
from utils.model_restrictions import ModelRestrictionService
class TestModelRestrictionService:
    """Test cases for ModelRestrictionService."""

    @pytest.fixture(autouse=True)
    def _isolate_allowlist_env(self, monkeypatch):
        """Remove any ambient ``*_ALLOWED_MODELS`` settings before each test.

        Most tests below patch only the variables they care about
        (``patch.dict`` without ``clear=True``). They then assert that the
        remaining providers are unrestricted. If the developer or CI shell
        already exports e.g. ``GOOGLE_ALLOWED_MODELS``, those assertions would
        fail for reasons unrelated to the code under test. ``monkeypatch``
        restores the original values after the test.
        """
        for var in ("OPENAI_ALLOWED_MODELS", "GOOGLE_ALLOWED_MODELS", "OPENROUTER_ALLOWED_MODELS"):
            monkeypatch.delenv(var, raising=False)

    def test_no_restrictions_by_default(self):
        """Test that no restrictions exist when env vars are not set."""
        with patch.dict(os.environ, {}, clear=True):
            service = ModelRestrictionService()

            # Should allow all models
            assert service.is_allowed(ProviderType.OPENAI, "o3")
            assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
            assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
            assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash")
            assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-opus-4")
            assert service.is_allowed(ProviderType.OPENROUTER, "openai/o3")

            # Should have no restrictions
            assert not service.has_restrictions(ProviderType.OPENAI)
            assert not service.has_restrictions(ProviderType.GOOGLE)
            assert not service.has_restrictions(ProviderType.OPENROUTER)

    def test_load_single_model_restriction(self):
        """Test loading a single allowed model."""
        with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini"}):
            service = ModelRestrictionService()

            # Should only allow o3-mini
            assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
            assert not service.is_allowed(ProviderType.OPENAI, "o3")
            assert not service.is_allowed(ProviderType.OPENAI, "o4-mini")

            # Google and OpenRouter should have no restrictions
            assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")
            assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-opus-4")

    def test_load_multiple_models_restriction(self):
        """Test loading multiple allowed models."""
        with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
            # Instantiate providers so alias resolution for allow-lists is available
            openai_provider = OpenAIModelProvider(api_key="test-key")
            gemini_provider = GeminiModelProvider(api_key="test-key")

            from providers.registry import ModelProviderRegistry

            def fake_get_provider(provider_type, force_new=False):
                mapping = {
                    ProviderType.OPENAI: openai_provider,
                    ProviderType.GOOGLE: gemini_provider,
                }
                return mapping.get(provider_type)

            with patch.object(ModelProviderRegistry, "get_provider", side_effect=fake_get_provider):
                service = ModelRestrictionService()

                # Check OpenAI models
                assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
                assert service.is_allowed(ProviderType.OPENAI, "o4-mini")
                assert not service.is_allowed(ProviderType.OPENAI, "o3")

                # Check Google models
                assert service.is_allowed(ProviderType.GOOGLE, "flash")
                assert service.is_allowed(ProviderType.GOOGLE, "pro")
                assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro")

    def test_case_insensitive_and_whitespace_handling(self):
        """Test that model names are case-insensitive and whitespace is trimmed."""
        with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": " O3-MINI , o4-Mini "}):
            service = ModelRestrictionService()

            # Should work with any case
            assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
            assert service.is_allowed(ProviderType.OPENAI, "O3-MINI")
            assert service.is_allowed(ProviderType.OPENAI, "o4-mini")
            assert service.is_allowed(ProviderType.OPENAI, "O4-Mini")

    def test_empty_string_allows_all(self):
        """Test that empty string allows all models (same as unset)."""
        with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "", "GOOGLE_ALLOWED_MODELS": "flash"}):
            service = ModelRestrictionService()

            # OpenAI should allow all models (empty string = no restrictions)
            assert service.is_allowed(ProviderType.OPENAI, "o3")
            assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
            assert service.is_allowed(ProviderType.OPENAI, "o4-mini")

            # Google should only allow flash (and its resolved name)
            assert service.is_allowed(ProviderType.GOOGLE, "flash")
            assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash", "flash")
            assert not service.is_allowed(ProviderType.GOOGLE, "pro")
            assert not service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro", "pro")

    def test_filter_models(self):
        """Test filtering a list of models based on restrictions."""
        with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini"}):
            service = ModelRestrictionService()

            models = ["o3", "o3-mini", "o4-mini", "o3-pro"]
            filtered = service.filter_models(ProviderType.OPENAI, models)

            assert filtered == ["o3-mini", "o4-mini"]

    def test_get_allowed_models(self):
        """Test getting the set of allowed models."""
        with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mini"}):
            service = ModelRestrictionService()

            allowed = service.get_allowed_models(ProviderType.OPENAI)
            assert allowed == {"o3-mini", "o4-mini"}

            # No restrictions for Google
            assert service.get_allowed_models(ProviderType.GOOGLE) is None

    def test_shorthand_names_in_restrictions(self):
        """Test that shorthand names work in restrictions."""
        with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4mini,o3mini", "GOOGLE_ALLOWED_MODELS": "flash,pro"}):
            # Instantiate providers so the registry can resolve aliases
            OpenAIModelProvider(api_key="test-key")
            GeminiModelProvider(api_key="test-key")

            service = ModelRestrictionService()

            # When providers check models, they pass both resolved and original names
            # OpenAI: 'o4mini' shorthand allows o4-mini
            assert service.is_allowed(ProviderType.OPENAI, "o4-mini", "o4mini")  # How providers actually call it
            assert service.is_allowed(ProviderType.OPENAI, "o4-mini")  # Canonical should also be allowed

            # OpenAI: o3-mini allowed directly
            assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
            assert not service.is_allowed(ProviderType.OPENAI, "o3")

            # Google should allow both models via shorthands
            assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-flash", "flash")
            assert service.is_allowed(ProviderType.GOOGLE, "gemini-2.5-pro", "pro")

            # Also test that full names work when specified in restrictions
            assert service.is_allowed(ProviderType.OPENAI, "o3-mini", "o3mini")  # Even with shorthand

    def test_validation_against_known_models(self, caplog):
        """Test validation warnings for unknown models."""
        with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini,o4-mimi"}):  # Note the typo: o4-mimi
            service = ModelRestrictionService()

            # Create mock provider with known models
            mock_provider = MagicMock()
            mock_provider.MODEL_CAPABILITIES = {
                "o3": {"context_window": 200000},
                "o3-mini": {"context_window": 200000},
                "o4-mini": {"context_window": 200000},
            }
            mock_provider.list_models.return_value = ["o3", "o3-mini", "o4-mini"]

            provider_instances = {ProviderType.OPENAI: mock_provider}
            service.validate_against_known_models(provider_instances)

            # Should have logged a warning about the typo
            assert "o4-mimi" in caplog.text
            assert "not a recognized" in caplog.text

    def test_openrouter_model_restrictions(self):
        """Test OpenRouter model restrictions functionality."""
        with patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus,sonnet"}):
            service = ModelRestrictionService()

            # Should only allow specified OpenRouter models
            assert service.is_allowed(ProviderType.OPENROUTER, "opus")
            assert service.is_allowed(ProviderType.OPENROUTER, "sonnet")
            assert service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-opus-4", "opus")  # With original name
            assert not service.is_allowed(ProviderType.OPENROUTER, "haiku")
            assert not service.is_allowed(ProviderType.OPENROUTER, "anthropic/claude-3-haiku")
            assert not service.is_allowed(ProviderType.OPENROUTER, "mistral-large")

            # Other providers should have no restrictions
            assert service.is_allowed(ProviderType.OPENAI, "o3")
            assert service.is_allowed(ProviderType.GOOGLE, "pro")

            # Should have restrictions for OpenRouter
            assert service.has_restrictions(ProviderType.OPENROUTER)
            assert not service.has_restrictions(ProviderType.OPENAI)
            assert not service.has_restrictions(ProviderType.GOOGLE)

    def test_openrouter_filter_models(self):
        """Test filtering OpenRouter models based on restrictions."""
        with patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus,mistral"}):
            service = ModelRestrictionService()

            models = ["opus", "sonnet", "haiku", "mistral", "llama"]
            filtered = service.filter_models(ProviderType.OPENROUTER, models)

            assert filtered == ["opus", "mistral"]

    def test_combined_provider_restrictions(self):
        """Test that restrictions work correctly when set for multiple providers."""
        with patch.dict(
            os.environ,
            {
                "OPENAI_ALLOWED_MODELS": "o3-mini",
                "GOOGLE_ALLOWED_MODELS": "flash",
                "OPENROUTER_ALLOWED_MODELS": "opus,sonnet",
            },
        ):
            service = ModelRestrictionService()

            # OpenAI restrictions
            assert service.is_allowed(ProviderType.OPENAI, "o3-mini")
            assert not service.is_allowed(ProviderType.OPENAI, "o3")

            # Google restrictions
            assert service.is_allowed(ProviderType.GOOGLE, "flash")
            assert not service.is_allowed(ProviderType.GOOGLE, "pro")

            # OpenRouter restrictions
            assert service.is_allowed(ProviderType.OPENROUTER, "opus")
            assert service.is_allowed(ProviderType.OPENROUTER, "sonnet")
            assert not service.is_allowed(ProviderType.OPENROUTER, "haiku")

            # All providers should have restrictions
            assert service.has_restrictions(ProviderType.OPENAI)
            assert service.has_restrictions(ProviderType.GOOGLE)
            assert service.has_restrictions(ProviderType.OPENROUTER)
class TestProviderIntegration:
    """Test integration with actual providers."""

    @pytest.fixture(autouse=True)
    def _reset_restriction_cache_after_test(self):
        """Drop the module-level restriction-service cache once each test finishes.

        Each test clears ``utils.model_restrictions._restriction_service`` while
        its patched environment is active. The cache is then presumably rebuilt
        from that patched allow-list. Without this teardown, the restricted
        service would outlive the test and affect whatever runs next.
        """
        yield
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini"})
    def test_openai_provider_respects_restrictions(self):
        """Test that OpenAI provider respects restrictions."""
        # Clear any cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        provider = OpenAIModelProvider(api_key="test-key")

        # Should validate allowed model
        assert provider.validate_model_name("o3-mini")

        # Should not validate disallowed model
        assert not provider.validate_model_name("o3")

        # get_capabilities should raise for disallowed model
        with pytest.raises(ValueError) as exc_info:
            provider.get_capabilities("o3")
        assert "not allowed by restriction policy" in str(exc_info.value)

    @patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash,flash"})
    def test_gemini_provider_respects_restrictions(self):
        """Test that Gemini provider respects restrictions."""
        # Clear any cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        provider = GeminiModelProvider(api_key="test-key")

        # Should validate allowed models (both shorthand and full name allowed)
        assert provider.validate_model_name("flash")
        assert provider.validate_model_name("gemini-2.5-flash")

        # Should not validate disallowed model
        assert not provider.validate_model_name("pro")
        assert not provider.validate_model_name("gemini-2.5-pro")

        # get_capabilities should raise for disallowed model
        with pytest.raises(ValueError) as exc_info:
            provider.get_capabilities("pro")
        assert "not allowed by restriction policy" in str(exc_info.value)

    @patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "flash"})
    def test_gemini_parameter_order_regression_protection(self):
        """Test that prevents regression of parameter order bug in is_allowed calls.

        This test specifically catches the bug where parameters were incorrectly
        passed as (provider, user_input, resolved_name) instead of
        (provider, resolved_name, user_input).

        The bug was subtle because the is_allowed method uses OR logic, so it
        worked in most cases by accident. This test creates a scenario where
        the parameter order matters.
        """
        # Clear any cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        provider = GeminiModelProvider(api_key="test-key")

        from providers.registry import ModelProviderRegistry

        with patch.object(ModelProviderRegistry, "get_provider", return_value=provider):
            # Test case: Only alias "flash" is allowed, not the full name
            # If parameters are in wrong order, this test will catch it

            # Should allow "flash" alias
            assert provider.validate_model_name("flash")

            # Should allow getting capabilities for "flash"
            capabilities = provider.get_capabilities("flash")
            assert capabilities.model_name == "gemini-2.5-flash"

            # Canonical form should also be allowed now that alias is on the allowlist
            assert provider.validate_model_name("gemini-2.5-flash")
            # Unrelated models remain blocked
            assert not provider.validate_model_name("pro")
            assert not provider.validate_model_name("gemini-2.5-pro")

    @patch.dict(os.environ, {"GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"})
    def test_gemini_parameter_order_edge_case_full_name_only(self):
        """Test parameter order with only full name allowed, not alias.

        This is the reverse scenario - only the full canonical name is allowed,
        not the shorthand alias. This tests that the parameter order is correct
        when resolving aliases.
        """
        # Clear any cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        provider = GeminiModelProvider(api_key="test-key")

        # Should allow full name
        assert provider.validate_model_name("gemini-2.5-flash")

        # Should also allow alias that resolves to allowed full name
        # This works because is_allowed checks both resolved_name and original_name
        assert provider.validate_model_name("flash")

        # Should not allow "pro" alias
        assert not provider.validate_model_name("pro")
        assert not provider.validate_model_name("gemini-2.5-pro")
class TestCustomProviderOpenRouterRestrictions:
    """Test custom provider integration with OpenRouter restrictions."""

    @pytest.fixture(autouse=True)
    def _reset_restriction_cache_after_test(self):
        """Drop the module-level restriction-service cache once each test finishes.

        Each test clears ``utils.model_restrictions._restriction_service`` under
        its patched environment. Resetting it again on teardown keeps a
        service built from that patched allow-list from leaking into later tests.
        """
        yield
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    @patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus,sonnet", "OPENROUTER_API_KEY": "test-key"})
    def test_custom_provider_respects_openrouter_restrictions(self):
        """Test that custom provider correctly defers OpenRouter models to OpenRouter provider."""
        # Clear any cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        from providers.custom import CustomProvider

        provider = CustomProvider(base_url="http://test.com/v1")

        # CustomProvider should NOT validate OpenRouter models - they should be deferred to OpenRouter
        assert not provider.validate_model_name("opus")
        assert not provider.validate_model_name("sonnet")
        assert not provider.validate_model_name("haiku")

        # Should still validate custom models defined in conf/custom_models.json
        assert provider.validate_model_name("local-llama")

    @patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus", "OPENROUTER_API_KEY": "test-key"})
    def test_custom_provider_openrouter_capabilities_restrictions(self):
        """Test that custom provider's get_capabilities correctly handles OpenRouter models."""
        # Clear any cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        from providers.custom import CustomProvider

        provider = CustomProvider(base_url="http://test.com/v1")

        # For OpenRouter models, CustomProvider should defer by raising
        with pytest.raises(ValueError):
            provider.get_capabilities("opus")

        # Should raise for disallowed OpenRouter model (still defers)
        with pytest.raises(ValueError):
            provider.get_capabilities("haiku")

        # Should still work for custom models
        capabilities = provider.get_capabilities("local-llama")
        assert capabilities.provider == ProviderType.CUSTOM

    @patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "opus"}, clear=False)
    def test_custom_provider_no_openrouter_key_ignores_restrictions(self):
        """Test that when OpenRouter key is not set, cloud models are rejected regardless of restrictions."""
        # Make sure OPENROUTER_API_KEY is not set; patch.dict restores it after the test.
        os.environ.pop("OPENROUTER_API_KEY", None)

        # Clear any cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        from providers.custom import CustomProvider

        provider = CustomProvider(base_url="http://test.com/v1")

        # Should not validate OpenRouter models when key is not available
        assert not provider.validate_model_name("opus")  # Even though it's in allowed list
        assert not provider.validate_model_name("haiku")

        # Should still validate custom models
        assert provider.validate_model_name("local-llama")

    @patch.dict(os.environ, {"OPENROUTER_ALLOWED_MODELS": "", "OPENROUTER_API_KEY": "test-key"})
    def test_custom_provider_empty_restrictions_allows_all_openrouter(self):
        """Test that custom provider correctly defers OpenRouter models regardless of restrictions."""
        # Clear any cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        from providers.custom import CustomProvider

        provider = CustomProvider(base_url="http://test.com/v1")

        # CustomProvider should NOT validate OpenRouter models - they should be deferred to OpenRouter
        assert not provider.validate_model_name("opus")
        assert not provider.validate_model_name("sonnet")
        assert not provider.validate_model_name("haiku")
class TestRegistryIntegration:
    """Test integration with ModelProviderRegistry."""

    @pytest.fixture(autouse=True)
    def _reset_restriction_cache_after_test(self):
        """Drop the module-level restriction-service cache once each test finishes.

        The tests below clear ``utils.model_restrictions._restriction_service``
        under a patched environment. Resetting it on teardown keeps that
        restricted service from leaking into later tests.
        """
        yield
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    @staticmethod
    def _make_list_models(mock_provider, provider_type):
        """Build a ``list_models`` stand-in for a mocked provider.

        The returned callable has a keyword-only signature. At call time it
        walks ``mock_provider.MODEL_CAPABILITIES`` and filters entries through
        the global restriction service for ``provider_type``:
        - String values are treated as aliases. The alias is kept when its
          target is allowed and ``include_aliases`` is true.
        - Any other value is checked by its own key.
        ``lowercase`` lower-cases the result. ``unique`` drops duplicates while
        keeping first-seen order.
        """

        def list_models(
            *,
            respect_restrictions: bool = True,
            include_aliases: bool = True,
            lowercase: bool = False,
            unique: bool = False,
        ):
            from utils.model_restrictions import get_restriction_service

            restriction_service = get_restriction_service() if respect_restrictions else None
            models = []
            for model_name, config in mock_provider.MODEL_CAPABILITIES.items():
                if isinstance(config, str):
                    target_model = config
                    if restriction_service and not restriction_service.is_allowed(provider_type, target_model):
                        continue
                    if include_aliases:
                        models.append(model_name)
                else:
                    if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
                        continue
                    models.append(model_name)
            if lowercase:
                models = [m.lower() for m in models]
            if unique:
                # dict preserves insertion order, so this keeps first occurrences.
                models = list(dict.fromkeys(models))
            return models

        return list_models

    @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini", "GOOGLE_ALLOWED_MODELS": "flash"})
    def test_registry_with_shorthand_restrictions(self):
        """Test that registry handles shorthand restrictions correctly."""
        # Clear cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        from providers.registry import ModelProviderRegistry

        # Clear registry cache
        ModelProviderRegistry.clear_cache()

        # Get available models with restrictions
        # This test documents current behavior - get_available_models doesn't handle aliases
        ModelProviderRegistry.get_available_models(respect_restrictions=True)

        # Currently, this will be empty because get_available_models doesn't
        # recognize that "mini" allows "o4-mini"
        # This is a known limitation that should be documented

    @patch("providers.registry.ModelProviderRegistry.get_provider")
    def test_get_available_models_respects_restrictions(self, mock_get_provider):
        """Test that registry filters models based on restrictions."""
        from providers.registry import ModelProviderRegistry

        # Mock providers
        mock_openai = MagicMock()
        mock_openai.MODEL_CAPABILITIES = {
            "o3": {"context_window": 200000},
            "o3-mini": {"context_window": 200000},
        }
        mock_openai.get_provider_type.return_value = ProviderType.OPENAI
        mock_openai.list_models = MagicMock(side_effect=self._make_list_models(mock_openai, ProviderType.OPENAI))

        mock_gemini = MagicMock()
        mock_gemini.MODEL_CAPABILITIES = {
            "gemini-2.5-pro": {"context_window": 1048576},
            "gemini-2.5-flash": {"context_window": 1048576},
        }
        mock_gemini.get_provider_type.return_value = ProviderType.GOOGLE
        mock_gemini.list_models = MagicMock(side_effect=self._make_list_models(mock_gemini, ProviderType.GOOGLE))

        def get_provider_side_effect(provider_type):
            if provider_type == ProviderType.OPENAI:
                return mock_openai
            elif provider_type == ProviderType.GOOGLE:
                return mock_gemini
            return None

        mock_get_provider.side_effect = get_provider_side_effect

        # Set up registry with providers. Keep the original mapping so it can be
        # restored afterwards; the registry is presumably shared across tests.
        registry = ModelProviderRegistry()
        original_providers = registry._providers
        registry._providers = {
            ProviderType.OPENAI: type(mock_openai),
            ProviderType.GOOGLE: type(mock_gemini),
        }

        try:
            with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3-mini", "GOOGLE_ALLOWED_MODELS": "gemini-2.5-flash"}):
                # Clear cached restriction service
                import utils.model_restrictions

                utils.model_restrictions._restriction_service = None

                available = ModelProviderRegistry.get_available_models(respect_restrictions=True)

                # Should only include allowed models
                assert "o3-mini" in available
                assert "o3" not in available
                assert "gemini-2.5-flash" in available
                assert "gemini-2.5-pro" not in available
        finally:
            registry._providers = original_providers
class TestShorthandRestrictions:
    """Test that shorthand model names work correctly in restrictions."""

    @pytest.fixture(autouse=True)
    def _reset_restriction_cache_after_test(self):
        """Drop the module-level restriction-service cache once each test finishes.

        Each test clears ``utils.model_restrictions._restriction_service`` under
        a patched allow-list. Resetting it on teardown keeps that restricted
        service from leaking into later tests.
        """
        yield
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "mini", "GOOGLE_ALLOWED_MODELS": "flash"})
    def test_providers_validate_shorthands_correctly(self):
        """Test that providers correctly validate shorthand names."""
        # Clear cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        # Test OpenAI provider
        openai_provider = OpenAIModelProvider(api_key="test-key")
        gemini_provider = GeminiModelProvider(api_key="test-key")

        from providers.registry import ModelProviderRegistry

        def registry_side_effect(provider_type, force_new=False):
            mapping = {
                ProviderType.OPENAI: openai_provider,
                ProviderType.GOOGLE: gemini_provider,
            }
            return mapping.get(provider_type)

        with patch.object(ModelProviderRegistry, "get_provider", side_effect=registry_side_effect):
            assert openai_provider.validate_model_name("mini")  # Should work with shorthand
            assert openai_provider.validate_model_name("gpt-5-mini")  # Canonical resolved from shorthand
            assert not openai_provider.validate_model_name("o4-mini")  # Unrelated model still blocked
            assert not openai_provider.validate_model_name("o3-mini")

            # Test Gemini provider
            assert gemini_provider.validate_model_name("flash")  # Should work with shorthand
            assert gemini_provider.validate_model_name("gemini-2.5-flash")  # Canonical allowed
            assert not gemini_provider.validate_model_name("pro")  # Not allowed

    @patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o3mini,mini,o4-mini"})
    def test_multiple_shorthands_for_same_model(self):
        """Test that multiple shorthands work correctly."""
        # Clear cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        openai_provider = OpenAIModelProvider(api_key="test-key")

        # Both shorthands should work
        assert openai_provider.validate_model_name("mini")  # mini -> o4-mini
        assert openai_provider.validate_model_name("o3mini")  # o3mini -> o3-mini

        # Resolved names should be allowed when their shorthands are present
        assert openai_provider.validate_model_name("o4-mini")  # Explicitly allowed
        assert openai_provider.validate_model_name("o3-mini")  # Allowed via shorthand

        # Other models should not work
        assert not openai_provider.validate_model_name("o3")
        assert not openai_provider.validate_model_name("o3-pro")

    @patch.dict(
        os.environ,
        {"OPENAI_ALLOWED_MODELS": "mini,o4-mini", "GOOGLE_ALLOWED_MODELS": "flash,gemini-2.5-flash"},
    )
    def test_both_shorthand_and_full_name_allowed(self):
        """Test that we can allow both shorthand and full names."""
        # Clear cached restriction service
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

        # OpenAI - both mini and o4-mini are allowed
        openai_provider = OpenAIModelProvider(api_key="test-key")
        assert openai_provider.validate_model_name("mini")
        assert openai_provider.validate_model_name("o4-mini")

        # Gemini - both flash and full name are allowed
        gemini_provider = GeminiModelProvider(api_key="test-key")
        assert gemini_provider.validate_model_name("flash")
        assert gemini_provider.validate_model_name("gemini-2.5-flash")
class TestAutoModeWithRestrictions:
    """Test auto mode behavior with restrictions."""

    @pytest.fixture(autouse=True)
    def _reset_restriction_cache_after_test(self):
        """Drop the module-level restriction-service cache once each test finishes.

        Both tests clear ``utils.model_restrictions._restriction_service`` under
        a patched allow-list. Resetting it on teardown keeps that restricted
        service from leaking into later tests.
        """
        yield
        import utils.model_restrictions

        utils.model_restrictions._restriction_service = None

    @staticmethod
    def _make_list_models(mock_provider, provider_type):
        """Build a ``list_models`` stand-in for a mocked provider.

        The returned callable has a keyword-only signature. At call time it
        walks ``mock_provider.MODEL_CAPABILITIES`` and filters entries through
        the global restriction service for ``provider_type``:
        - String values are treated as aliases. The alias is kept when its
          target is allowed and ``include_aliases`` is true.
        - Any other value is checked by its own key.
        ``lowercase`` lower-cases the result. ``unique`` drops duplicates while
        keeping first-seen order.
        """

        def list_models(
            *,
            respect_restrictions: bool = True,
            include_aliases: bool = True,
            lowercase: bool = False,
            unique: bool = False,
        ):
            from utils.model_restrictions import get_restriction_service

            restriction_service = get_restriction_service() if respect_restrictions else None
            models = []
            for model_name, config in mock_provider.MODEL_CAPABILITIES.items():
                if isinstance(config, str):
                    target_model = config
                    if restriction_service and not restriction_service.is_allowed(provider_type, target_model):
                        continue
                    if include_aliases:
                        models.append(model_name)
                else:
                    if restriction_service and not restriction_service.is_allowed(provider_type, model_name):
                        continue
                    models.append(model_name)
            if lowercase:
                models = [m.lower() for m in models]
            if unique:
                # dict preserves insertion order, so this keeps first occurrences.
                models = list(dict.fromkeys(models))
            return models

        return list_models

    @patch("providers.registry.ModelProviderRegistry.get_provider")
    def test_fallback_model_respects_restrictions(self, mock_get_provider):
        """Test that fallback model selection respects restrictions."""
        from providers.registry import ModelProviderRegistry
        from tools.models import ToolModelCategory

        # Mock providers
        mock_openai = MagicMock()
        mock_openai.MODEL_CAPABILITIES = {
            "o3": {"context_window": 200000},
            "o3-mini": {"context_window": 200000},
            "o4-mini": {"context_window": 200000},
        }
        mock_openai.get_provider_type.return_value = ProviderType.OPENAI
        mock_openai.list_models = MagicMock(side_effect=self._make_list_models(mock_openai, ProviderType.OPENAI))

        # Add get_preferred_model method to mock to match new implementation
        def get_preferred_model(category, allowed_models):
            # Simple preference logic for testing - just return first allowed model
            return allowed_models[0] if allowed_models else None

        mock_openai.get_preferred_model = get_preferred_model

        def get_provider_side_effect(provider_type):
            if provider_type == ProviderType.OPENAI:
                return mock_openai
            return None

        mock_get_provider.side_effect = get_provider_side_effect

        # Set up registry. Keep the original mapping so it can be restored
        # afterwards; the registry is presumably shared across tests.
        registry = ModelProviderRegistry()
        original_providers = registry._providers
        registry._providers = {ProviderType.OPENAI: type(mock_openai)}

        try:
            with patch.dict(os.environ, {"OPENAI_ALLOWED_MODELS": "o4-mini"}):
                # Clear cached restriction service
                import utils.model_restrictions

                utils.model_restrictions._restriction_service = None

                # Should pick o4-mini instead of o3-mini for fast response
                model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)
                assert model == "o4-mini"
        finally:
            registry._providers = original_providers

    def test_fallback_with_shorthand_restrictions(self, monkeypatch):
        """Test fallback model selection with shorthand restrictions."""
        # Use monkeypatch to set environment variables with automatic cleanup
        monkeypatch.setenv("OPENAI_ALLOWED_MODELS", "mini")
        monkeypatch.setenv("GEMINI_API_KEY", "")
        monkeypatch.setenv("OPENAI_API_KEY", "test-key")

        # Clear caches and reset registry
        import utils.model_restrictions
        from providers.registry import ModelProviderRegistry
        from tools.models import ToolModelCategory

        utils.model_restrictions._restriction_service = None

        # Store original providers for restoration
        registry = ModelProviderRegistry()
        original_providers = registry._providers.copy()
        original_initialized = registry._initialized_providers.copy()

        try:
            # Clear registry and register only OpenAI and Gemini providers
            # (the provider classes come from the module-level imports).
            ModelProviderRegistry._instance = None
            ModelProviderRegistry.register_provider(ProviderType.OPENAI, OpenAIModelProvider)
            ModelProviderRegistry.register_provider(ProviderType.GOOGLE, GeminiModelProvider)

            # Even with "mini" restriction, fallback should work if provider handles it correctly
            # This tests the real-world scenario
            model = ModelProviderRegistry.get_preferred_fallback_model(ToolModelCategory.FAST_RESPONSE)

            # The fallback will depend on how get_available_models handles aliases
            # When "mini" is allowed, it's returned as the allowed model
            # "mini" is now an alias for gpt-5-mini, but the list shows "mini" itself
            assert model in ["mini", "gpt-5-mini", "o4-mini", "gemini-2.5-flash"]
        finally:
            # Restore original registry state
            registry = ModelProviderRegistry()
            registry._providers.clear()
            registry._initialized_providers.clear()
            registry._providers.update(original_providers)
            registry._initialized_providers.update(original_initialized)
```
--------------------------------------------------------------------------------
/providers/openai_compatible.py:
--------------------------------------------------------------------------------
```python
"""Base class for OpenAI-compatible API providers."""
import copy
import ipaddress
import logging
from typing import Optional
from urllib.parse import urlparse
from openai import OpenAI
from utils.env import get_env, suppress_env_vars
from utils.image_utils import validate_image
from .base import ModelProvider
from .shared import (
ModelCapabilities,
ModelResponse,
ProviderType,
)
class OpenAICompatibleProvider(ModelProvider):
"""Shared implementation for OpenAI API lookalikes.
The class owns HTTP client configuration (timeouts, proxy hardening,
custom headers) and normalises the OpenAI SDK responses into
:class:`~providers.shared.ModelResponse`. Concrete subclasses only need to
provide capability metadata and any provider-specific request tweaks.
"""
DEFAULT_HEADERS = {}
FRIENDLY_NAME = "OpenAI Compatible"
def __init__(self, api_key: str, base_url: str = None, **kwargs):
    """Set up state shared by every OpenAI-compatible provider.

    Args:
        api_key: Credential for the endpoint. If it is empty and ``base_url``
            is not a localhost URL, a warning is logged.
        base_url: Optional endpoint root URL. When given, it is validated via
            ``_validate_base_url`` and influences timeout selection.
        **kwargs: Extra provider options. ``organization`` is read here. The
            full mapping is forwarded to the base initializer and to
            ``_configure_timeouts``.
    """
    # Created before the base initializer runs. NOTE(review): presumably the
    # base class may trigger allow-list lookups that touch this cache; confirm.
    # _parse_allowed_models may replace it again below.
    self._allowed_alias_cache: dict[str, str] = {}
    super().__init__(api_key, **kwargs)

    self._client = None  # SDK client is not created here
    self.base_url = base_url
    self.organization = kwargs.get("organization")
    self.allowed_models = self._parse_allowed_models()
    # Timeout choice depends on base_url, so it must be computed after base_url is set.
    self.timeout_config = self._configure_timeouts(**kwargs)

    if not self.base_url:
        return

    self._validate_base_url()
    # A remote endpoint without credentials is permitted, but worth flagging.
    if not self._is_localhost_url() and not api_key:
        logging.warning(
            f"Using external URL '{self.base_url}' without API key. "
            "This may be insecure. Consider setting an API key for authentication."
        )
def _ensure_model_allowed(
    self,
    capabilities: ModelCapabilities,
    canonical_name: str,
    requested_name: str,
) -> None:
    """Run the base-class restriction check, then this provider's own allow-list.

    A model passes the provider allow-list when any of these hold:
    - the requested name is on the list (case-insensitive);
    - the canonical name is on the list (case-insensitive);
    - some list entry resolves (via ``_resolve_model_name``) to the canonical name.
    In the last case, the canonical name is added to ``self.allowed_models``
    and cached, so later checks hit the fast path.

    Raises:
        ValueError: if the model is not permitted by the provider allow-list.
            Exceptions from the base-class check propagate unchanged.
    """
    super()._ensure_model_allowed(capabilities, canonical_name, requested_name)

    allowlist = self.allowed_models
    if allowlist is None:
        return  # no provider-specific allow-list configured

    canonical = canonical_name.lower()
    if requested_name.lower() in allowlist or canonical in allowlist:
        return

    # Iterate over a snapshot: a match below adds to the live set.
    for entry in list(allowlist):
        resolved = self._allowed_alias_cache.get(entry)
        if resolved is None:
            try:
                candidate = self._resolve_model_name(entry)
            except Exception:
                # Unresolvable entries are ignored rather than failing the check.
                continue
            if not candidate:
                continue
            resolved = candidate.lower()
            self._allowed_alias_cache[entry] = resolved

        if resolved == canonical:
            # Remember the canonical name so the direct membership test succeeds next time.
            self._allowed_alias_cache[canonical] = canonical
            allowlist.add(canonical)
            return

    raise ValueError(
        f"Model '{requested_name}' is not allowed by restriction policy. Allowed models: {sorted(self.allowed_models)}"
    )
def _parse_allowed_models(self) -> Optional[set[str]]:
    """Read this provider's model allow-list from the environment.

    The variable name is ``<PROVIDER_TYPE>_ALLOWED_MODELS`` (the provider type's
    value upper-cased) and holds a comma-separated list of model names.

    Returns:
        Set of allowed model names (lowercase) or None if not configured.
        A non-empty result also resets ``self._allowed_alias_cache``.
    """
    provider = self.get_provider_type()
    env_var = f"{provider.value.upper()}_ALLOWED_MODELS"
    raw_value = get_env(env_var, "") or ""

    if raw_value:
        # Normalise to lowercase so comparisons are case-insensitive; blanks are dropped.
        parsed = set()
        for item in raw_value.split(","):
            cleaned = item.strip()
            if cleaned:
                parsed.add(cleaned.lower())

        if parsed:
            logging.info(f"Configured allowed models for {self.FRIENDLY_NAME}: {sorted(parsed)}")
            self._allowed_alias_cache = {}
            return parsed

    # Only proxy-style providers get the "no allow-list" hint.
    if provider not in [ProviderType.GOOGLE, ProviderType.OPENAI]:
        logging.info(
            f"Model allow-list not configured for {self.FRIENDLY_NAME} - all models permitted. "
            f"To restrict access, set {env_var} with comma-separated model names."
        )

    return None
def _configure_timeouts(self, **kwargs):
    """Configure timeout settings based on provider type and custom settings.

    Custom URLs and local models often need longer timeouts due to:
    - Network latency on local networks
    - Extended thinking models taking longer to respond
    - Local inference being slower than cloud APIs

    Each of the four httpx timeout phases (connect, read, write, pool) is
    resolved in this order:

    1. An explicit keyword argument (``connect_timeout``, ``read_timeout``,
       ``write_timeout``, ``pool_timeout``) that is not None.
    2. The matching environment variable (``CUSTOM_CONNECT_TIMEOUT``,
       ``CUSTOM_READ_TIMEOUT``, ``CUSTOM_WRITE_TIMEOUT``, ``CUSTOM_POOL_TIMEOUT``).
       Empty or non-numeric values are ignored with a warning rather than
       raising during provider construction.
    3. A default chosen by endpoint kind: no base URL, a localhost/private
       network base URL, or any other custom base URL.

    Args:
        **kwargs: Provider keyword arguments; only the ``*_timeout`` keys above are read.

    Returns:
        httpx.Timeout object with appropriate timeout settings
    """
    import httpx

    # Default timeouts in seconds - more generous than OpenAI's 5s connect timeout.
    defaults = {"connect": 30.0, "read": 600.0, "write": 600.0, "pool": 600.0}

    if self.base_url and self._is_localhost_url():
        # Local endpoints get the longest budget (local inference / extended thinking).
        defaults = {"connect": 60.0, "read": 1800.0, "write": 1800.0, "pool": 1800.0}
        logging.info(f"Using extended timeouts for local endpoint: {self.base_url}")
    elif self.base_url:
        defaults = {"connect": 45.0, "read": 900.0, "write": 900.0, "pool": 900.0}
        logging.info(f"Using extended timeouts for custom endpoint: {self.base_url}")

    def _resolve(phase: str) -> float:
        """Resolve one timeout phase from kwargs, then the environment, then defaults."""
        explicit = kwargs.get(f"{phase}_timeout")
        if explicit is not None:
            return float(explicit)

        env_name = f"CUSTOM_{phase.upper()}_TIMEOUT"
        raw = get_env(env_name)
        if raw is None or not str(raw).strip():
            # Unset or blank (e.g. "CUSTOM_READ_TIMEOUT=" in a .env file): use the default.
            return defaults[phase]
        try:
            return float(raw)
        except (TypeError, ValueError):
            logging.warning(f"Ignoring invalid {env_name} value {raw!r}; using default of {defaults[phase]}s")
            return defaults[phase]

    connect_timeout = _resolve("connect")
    read_timeout = _resolve("read")
    write_timeout = _resolve("write")
    pool_timeout = _resolve("pool")

    timeout = httpx.Timeout(connect=connect_timeout, read=read_timeout, write=write_timeout, pool=pool_timeout)

    logging.debug(
        f"Configured timeouts - Connect: {connect_timeout}s, Read: {read_timeout}s, "
        f"Write: {write_timeout}s, Pool: {pool_timeout}s"
    )

    return timeout
def _is_localhost_url(self) -> bool:
"""Check if the base URL points to localhost or local network.
Returns:
True if URL is localhost or local network, False otherwise
"""
if not self.base_url:
return False
try:
parsed = urlparse(self.base_url)
hostname = parsed.hostname
# Check for common localhost patterns
if hostname in ["localhost", "127.0.0.1", "::1"]:
return True
# Check for private network ranges (local network)
if hostname:
try:
ip = ipaddress.ip_address(hostname)
return ip.is_private or ip.is_loopback
except ValueError:
# Not an IP address, might be a hostname
pass
return False
except Exception:
return False
def _validate_base_url(self) -> None:
"""Validate base URL for security (SSRF protection).
Raises:
ValueError: If URL is invalid or potentially unsafe
"""
if not self.base_url:
return
try:
parsed = urlparse(self.base_url)
# Check URL scheme - only allow http/https
if parsed.scheme not in ("http", "https"):
raise ValueError(f"Invalid URL scheme: {parsed.scheme}. Only http/https allowed.")
# Check hostname exists
if not parsed.hostname:
raise ValueError("URL must include a hostname")
# Check port is valid (if specified)
port = parsed.port
if port is not None and (port < 1 or port > 65535):
raise ValueError(f"Invalid port number: {port}. Must be between 1 and 65535.")
except Exception as e:
if isinstance(e, ValueError):
raise
raise ValueError(f"Invalid base URL '{self.base_url}': {str(e)}")
@property
def client(self):
    """Lazily create and cache the OpenAI SDK client.

    The client is built while proxy environment variables are suppressed
    (``suppress_env_vars``), using a custom ``httpx.Client`` configured with
    ``self.timeout_config`` (or a 30s timeout when that is unset/falsy) and
    ``follow_redirects=True``. When a ``_test_transport`` attribute exists it is
    used as the httpx transport (HTTP recording/replay in tests). If building
    that client fails, the partially built httpx client is closed and a minimal
    ``OpenAI`` client with only ``api_key``/``base_url`` is attempted instead.

    Returns:
        The cached OpenAI client instance.

    Raises:
        Exception: Whatever the minimal ``OpenAI(...)`` construction raises if
            the fallback also fails.
    """
    if self._client is None:
        import httpx

        proxy_env_vars = ["HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"]

        with suppress_env_vars(*proxy_env_vars):
            http_client = None
            try:
                # Create a custom httpx client that explicitly avoids proxy parameters
                timeout_config = (
                    self.timeout_config
                    if hasattr(self, "timeout_config") and self.timeout_config
                    else httpx.Timeout(30.0)
                )

                # Minimal httpx config to avoid proxy conflicts
                # Note: proxies parameter was removed in httpx 0.28.0
                http_client_kwargs = {"timeout": timeout_config, "follow_redirects": True}
                if hasattr(self, "_test_transport"):
                    # Use custom transport for testing (HTTP recording/replay)
                    http_client_kwargs["transport"] = self._test_transport
                http_client = httpx.Client(**http_client_kwargs)

                # Keep client initialization minimal to avoid proxy parameter conflicts
                client_kwargs = {
                    "api_key": self.api_key,
                    "http_client": http_client,
                }

                if self.base_url:
                    client_kwargs["base_url"] = self.base_url

                if self.organization:
                    client_kwargs["organization"] = self.organization

                # Add default headers if any
                if self.DEFAULT_HEADERS:
                    client_kwargs["default_headers"] = self.DEFAULT_HEADERS.copy()

                logging.debug(
                    "OpenAI client initialized with custom httpx client and timeout: %s",
                    timeout_config,
                )

                # Create OpenAI client with custom httpx client
                self._client = OpenAI(**client_kwargs)

            except Exception as e:
                if http_client is not None:
                    # The custom httpx client holds a connection pool and will not be
                    # handed to any OpenAI client now, so release it instead of leaking it.
                    try:
                        http_client.close()
                    except Exception as close_error:
                        logging.debug("Failed to close unused httpx client: %s", close_error)

                # If all else fails, try absolute minimal client without custom httpx
                logging.warning(
                    "Failed to create client with custom httpx, falling back to minimal config: %s",
                    e,
                )
                try:
                    minimal_kwargs = {"api_key": self.api_key}
                    if self.base_url:
                        minimal_kwargs["base_url"] = self.base_url
                    self._client = OpenAI(**minimal_kwargs)
                except Exception as fallback_error:
                    logging.error("Even minimal OpenAI client creation failed: %s", fallback_error)
                    raise

    return self._client
def _sanitize_for_logging(self, params: dict) -> dict:
"""Sanitize sensitive data from parameters before logging.
Args:
params: Dictionary of API parameters
Returns:
dict: Sanitized copy of parameters safe for logging
"""
sanitized = copy.deepcopy(params)
# Sanitize messages content
if "input" in sanitized:
for msg in sanitized.get("input", []):
if isinstance(msg, dict) and "content" in msg:
for content_item in msg.get("content", []):
if isinstance(content_item, dict) and "text" in content_item:
# Truncate long text and add ellipsis
text = content_item["text"]
if len(text) > 100:
content_item["text"] = text[:100] + "... [truncated]"
# Remove any API keys that might be in headers/auth
sanitized.pop("api_key", None)
sanitized.pop("authorization", None)
return sanitized
def _safe_extract_output_text(self, response) -> str:
"""Safely extract output_text from o3-pro response with validation.
Args:
response: Response object from OpenAI SDK
Returns:
str: The output text content
Raises:
ValueError: If output_text is missing, None, or not a string
"""
logging.debug(f"Response object type: {type(response)}")
logging.debug(f"Response attributes: {dir(response)}")
if not hasattr(response, "output_text"):
raise ValueError(f"o3-pro response missing output_text field. Response type: {type(response).__name__}")
content = response.output_text
logging.debug(f"Extracted output_text: '{content}' (type: {type(content)})")
if content is None:
raise ValueError("o3-pro returned None for output_text")
if not isinstance(content, str):
raise ValueError(f"o3-pro output_text is not a string. Got type: {type(content).__name__}")
return content
def _generate_with_responses_endpoint(
    self,
    model_name: str,
    messages: list,
    temperature: float,
    max_output_tokens: Optional[int] = None,
    capabilities: Optional[ModelCapabilities] = None,
    **kwargs,
) -> ModelResponse:
    """Generate content using the /v1/responses endpoint for reasoning models.

    Chat-completions style ``messages`` are translated into Responses API
    ``input`` items:

    - ``system`` and ``user`` messages become ``user`` input (system text is
      not prefixed).
    - ``assistant`` messages become ``assistant`` input.
    - Messages with any other role are dropped.
    - String content becomes a single ``input_text``/``output_text`` part.
    - Chat-style part lists are mapped part by part: ``text`` becomes a text
      part and ``image_url`` becomes ``input_image``.

    Args:
        model_name: Model name sent to the API.
        messages: Chat-style messages with ``role`` and ``content`` keys.
        temperature: Accepted for signature parity with chat completions; not sent.
        max_output_tokens: Optional cap, sent as the Responses API
            ``max_output_tokens`` parameter.
        capabilities: Model capabilities; a truthy ``default_reasoning_effort``
            overrides the ``"medium"`` reasoning effort.
        **kwargs: Accepted for signature parity; not forwarded.

    Returns:
        ModelResponse with the output text, usage and response metadata.

    Raises:
        RuntimeError: If the request still fails after the retry policy is exhausted.
    """

    def _to_response_parts(content, text_type: str) -> list:
        # Plain (non-list) content becomes a single text part.
        if not isinstance(content, list):
            return [{"type": text_type, "text": content}]
        parts = []
        for part in content:
            part_type = part.get("type") if isinstance(part, dict) else None
            if part_type == "text":
                parts.append({"type": text_type, "text": part.get("text", "")})
            elif part_type == "image_url":
                # Chat format nests the URL ({"image_url": {"url": ...}}); the
                # Responses API expects the URL string directly on input_image.
                image_ref = part.get("image_url")
                url = image_ref.get("url") if isinstance(image_ref, dict) else image_ref
                parts.append({"type": "input_image", "image_url": url})
            else:
                # Unknown part shapes are forwarded unchanged.
                parts.append(part)
        return parts

    # Convert messages to the correct format for responses endpoint
    input_messages = []
    for message in messages:
        role = message.get("role", "")
        content = message.get("content", "")

        if role in ("system", "user"):
            # For o3-pro, system content is sent as plain user input (no "System:"
            # prefix) to avoid policy violations.
            input_messages.append({"role": "user", "content": _to_response_parts(content, "input_text")})
        elif role == "assistant":
            input_messages.append({"role": "assistant", "content": _to_response_parts(content, "output_text")})

    # Use the nested reasoning object supported by the responses endpoint
    effort = "medium"
    if capabilities and capabilities.default_reasoning_effort:
        effort = capabilities.default_reasoning_effort

    completion_params = {
        "model": model_name,
        "input": input_messages,
        "reasoning": {"effort": effort},
        "store": True,
    }

    # The Responses API names its output cap "max_output_tokens" (the chat-style
    # "max_completion_tokens" is not a valid parameter for responses.create).
    if max_output_tokens:
        completion_params["max_output_tokens"] = max_output_tokens

    # Only explicitly supported parameters are sent; chat-completion extras in
    # kwargs are intentionally not forwarded.

    # Retry logic with progressive delays
    max_retries = 4
    retry_delays = [1, 3, 5, 8]
    # Mutable holder so the nested _attempt can count how many calls were made.
    attempt_counter = {"value": 0}

    def _attempt() -> ModelResponse:
        attempt_counter["value"] += 1
        import json

        sanitized_params = self._sanitize_for_logging(completion_params)
        logging.info(
            f"o3-pro API request (sanitized): {json.dumps(sanitized_params, indent=2, ensure_ascii=False)}"
        )

        response = self.client.responses.create(**completion_params)

        content = self._safe_extract_output_text(response)

        usage = None
        if hasattr(response, "usage"):
            usage = self._extract_usage(response)
        elif hasattr(response, "input_tokens") and hasattr(response, "output_tokens"):
            input_tokens = getattr(response, "input_tokens", 0) or 0
            output_tokens = getattr(response, "output_tokens", 0) or 0
            usage = {
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
                "total_tokens": input_tokens + output_tokens,
            }

        return ModelResponse(
            content=content,
            usage=usage,
            model_name=model_name,
            friendly_name=self.FRIENDLY_NAME,
            provider=self.get_provider_type(),
            metadata={
                "model": getattr(response, "model", model_name),
                "id": getattr(response, "id", ""),
                "created": getattr(response, "created_at", 0),
                "endpoint": "responses",
            },
        )

    try:
        return self._run_with_retries(
            operation=_attempt,
            max_attempts=max_retries,
            delays=retry_delays,
            log_prefix="responses endpoint",
        )
    except Exception as exc:
        attempts = max(attempt_counter["value"], 1)
        error_msg = f"responses endpoint error after {attempts} attempt{'s' if attempts > 1 else ''}: {exc}"
        logging.error(error_msg)
        raise RuntimeError(error_msg) from exc
def generate_content(
    self,
    prompt: str,
    model_name: str,
    system_prompt: Optional[str] = None,
    temperature: float = 0.3,
    max_output_tokens: Optional[int] = None,
    images: Optional[list[str]] = None,
    **kwargs,
) -> ModelResponse:
    """Generate content using the OpenAI-compatible API.

    Builds a chat-completions request and sends it with retries. The request
    has an optional system message and a user message with the prompt text
    plus, for image-capable models, processed images. Models whose
    capabilities (or static capability map) set ``use_openai_response_api``
    are routed to ``_generate_with_responses_endpoint`` instead, with no
    fallback to chat completions.

    Args:
        prompt: User prompt to send to the model
        model_name: Canonical model name or its alias
        system_prompt: Optional system prompt for model behavior
        temperature: Sampling temperature; adjusted through the model's
            capabilities when they are available
        max_output_tokens: Maximum tokens to generate; only sent (as
            ``max_tokens``) when the model has an effective temperature
        images: Optional list of image paths or data URLs to include with the prompt (for vision models)
        **kwargs: Additional provider-specific parameters. Only ``top_p``,
            ``frequency_penalty``, ``presence_penalty``, ``seed``, ``stop`` and
            ``stream`` are copied into the request. Models without an effective
            temperature only receive ``seed`` and ``stop``.

    Returns:
        ModelResponse with generated content and metadata

    Raises:
        ValueError: If ``validate_model_name`` rejects the model.
        RuntimeError: If the API call still fails after the retry policy
            (raised here or by the responses-endpoint path).
    """
    # Validate model name against allow-list
    if not self.validate_model_name(model_name):
        raise ValueError(f"Model '{model_name}' not in allowed models list. Allowed models: {self.allowed_models}")

    capabilities: Optional[ModelCapabilities]
    try:
        capabilities = self.get_capabilities(model_name)
    except Exception as exc:
        # Unknown models still proceed; None means "no capability data" below.
        logging.debug(f"Falling back to generic capabilities for {model_name}: {exc}")
        capabilities = None

    # Get effective temperature for this model from capabilities when available.
    # A None result means the model does not accept a temperature at all.
    if capabilities:
        effective_temperature = capabilities.get_effective_temperature(temperature)
        if effective_temperature is not None and effective_temperature != temperature:
            logging.debug(
                f"Adjusting temperature from {temperature} to {effective_temperature} for model {model_name}"
            )
    else:
        effective_temperature = temperature

    # Only validate if temperature is not None (meaning the model supports it)
    if effective_temperature is not None:
        # Validate parameters with the effective temperature
        self.validate_parameters(model_name, effective_temperature)

    # Resolve to canonical model name
    resolved_model = self._resolve_model_name(model_name)

    # Prepare messages
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    # Prepare user message with text and potentially images
    user_content = []
    user_content.append({"type": "text", "text": prompt})

    # Add images if provided and model supports vision
    if images and capabilities and capabilities.supports_images:
        for image_path in images:
            try:
                image_content = self._process_image(image_path)
                # _process_image returns None for invalid images; those are skipped.
                if image_content:
                    user_content.append(image_content)
            except Exception as e:
                logging.warning(f"Failed to process image {image_path}: {e}")
                # Continue with other images and text
                continue
    elif images and (not capabilities or not capabilities.supports_images):
        logging.warning(f"Model {resolved_model} does not support images, ignoring {len(images)} image(s)")

    # Add user message
    if len(user_content) == 1:
        # Only text content, use simple string format for compatibility
        messages.append({"role": "user", "content": prompt})
    else:
        # Text + images, use content array format
        messages.append({"role": "user", "content": user_content})

    # Prepare completion parameters.
    # Streaming is disabled by default: MCP doesn't use streaming, and this avoids
    # issues with O3 model access. A "stream" kwarg can still override it below
    # for models that accept sampling parameters.
    completion_params = {
        "model": resolved_model,
        "messages": messages,
        "stream": False,
    }

    # Use the effective temperature we calculated earlier
    supports_sampling = effective_temperature is not None

    if supports_sampling:
        completion_params["temperature"] = effective_temperature

    # Add max tokens if specified and model supports it
    # O3/O4 models that don't support temperature also don't support max_tokens
    if max_output_tokens and supports_sampling:
        completion_params["max_tokens"] = max_output_tokens

    # Add any additional OpenAI-specific parameters (other kwargs are ignored).
    # Use capabilities to filter parameters for reasoning models
    for key, value in kwargs.items():
        if key in ["top_p", "frequency_penalty", "presence_penalty", "seed", "stop", "stream"]:
            # Reasoning models (those that don't support temperature) also don't support these parameters
            if not supports_sampling and key in ["top_p", "frequency_penalty", "presence_penalty", "stream"]:
                continue  # Skip unsupported parameters for reasoning models
            completion_params[key] = value

    # Check if this model needs the Responses API endpoint
    # Prefer capability metadata; fall back to static map when capabilities unavailable
    use_responses_api = False

    if capabilities is not None:
        use_responses_api = getattr(capabilities, "use_openai_response_api", False)
    else:
        static_capabilities = self.get_all_model_capabilities().get(resolved_model)
        if static_capabilities is not None:
            use_responses_api = getattr(static_capabilities, "use_openai_response_api", False)

    if use_responses_api:
        # These models require the /v1/responses endpoint for stateful context
        # If it fails, we should not fall back to chat/completions
        return self._generate_with_responses_endpoint(
            model_name=resolved_model,
            messages=messages,
            temperature=temperature,
            max_output_tokens=max_output_tokens,
            capabilities=capabilities,
            **kwargs,
        )

    # Retry logic with progressive delays (policy applied by _run_with_retries)
    max_retries = 4  # Total of 4 attempts
    retry_delays = [1, 3, 5, 8]  # Progressive delays: 1s, 3s, 5s, 8s
    # Mutable holder so the nested _attempt can count how many calls were made.
    attempt_counter = {"value": 0}

    def _attempt() -> ModelResponse:
        attempt_counter["value"] += 1
        response = self.client.chat.completions.create(**completion_params)

        # NOTE(review): message.content may be None for some responses - confirm
        # ModelResponse tolerates that.
        content = response.choices[0].message.content
        usage = self._extract_usage(response)

        return ModelResponse(
            content=content,
            usage=usage,
            model_name=resolved_model,
            friendly_name=self.FRIENDLY_NAME,
            provider=self.get_provider_type(),
            metadata={
                "finish_reason": response.choices[0].finish_reason,
                "model": response.model,
                "id": response.id,
                "created": response.created,
            },
        )

    try:
        return self._run_with_retries(
            operation=_attempt,
            max_attempts=max_retries,
            delays=retry_delays,
            log_prefix=f"{self.FRIENDLY_NAME} API ({resolved_model})",
        )
    except Exception as exc:
        # Report how many attempts were actually made (at least one).
        attempts = max(attempt_counter["value"], 1)
        error_msg = (
            f"{self.FRIENDLY_NAME} API error for model {resolved_model} after {attempts} attempt"
            f"{'s' if attempts > 1 else ''}: {exc}"
        )
        logging.error(error_msg)
        raise RuntimeError(error_msg) from exc
def validate_parameters(self, model_name: str, temperature: float, **kwargs) -> None:
    """Best-effort parameter validation for proxy/compatible providers.

    Looks up the model's capabilities and delegates to the base-class
    validation. Any failure, including a rejected temperature, is logged as a
    warning and never raised, because proxied models may not have accurate
    capability data.

    Args:
        model_name: Canonical model name or its alias
        temperature: Temperature to validate
        **kwargs: Additional parameters to validate (forwarded to the base class)
    """
    try:
        model_caps = self.get_capabilities(model_name)

        # Generic capabilities carry an "_is_generic" marker attribute.
        is_generic = hasattr(model_caps, "_is_generic")
        if is_generic:
            logging.debug(
                f"Using generic parameter validation for {model_name}. Actual model constraints may differ."
            )

        super().validate_parameters(model_name, temperature, **kwargs)
    except Exception as problem:
        logging.warning(f"Parameter validation limited for {model_name}: {problem}")
def _extract_usage(self, response) -> dict[str, int]:
"""Extract token usage from OpenAI response.
Args:
response: OpenAI API response object
Returns:
Dictionary with usage statistics
"""
usage = {}
if hasattr(response, "usage") and response.usage:
# Safely extract token counts with None handling
usage["input_tokens"] = getattr(response.usage, "prompt_tokens", 0) or 0
usage["output_tokens"] = getattr(response.usage, "completion_tokens", 0) or 0
usage["total_tokens"] = getattr(response.usage, "total_tokens", 0) or 0
return usage
def count_tokens(self, text: str, model_name: str) -> int:
    """Count tokens with tiktoken, falling back to the base-class estimate.

    The model name is resolved first. tiktoken's per-model encoding is used
    when it knows the model, otherwise ``cl100k_base``. If tiktoken is missing
    or fails for any reason, the base-class ``count_tokens`` is used instead.
    """
    resolved_model = self._resolve_model_name(model_name)

    try:
        import tiktoken

        try:
            encoder = tiktoken.encoding_for_model(resolved_model)
        except KeyError:
            # Model unknown to tiktoken: use the generic cl100k_base encoding.
            encoder = tiktoken.get_encoding("cl100k_base")

        token_ids = encoder.encode(text)
        return len(token_ids)
    except Exception as exc:
        # ImportError is an Exception subclass, so a missing tiktoken lands here too.
        logging.debug("tiktoken unavailable for %s: %s", resolved_model, exc)

    return super().count_tokens(text, model_name)
def _is_error_retryable(self, error: Exception) -> bool:
"""Determine if an error should be retried based on structured error codes.
Uses OpenAI API error structure instead of text pattern matching for reliability.
Args:
error: Exception from OpenAI API call
Returns:
True if error should be retried, False otherwise
"""
error_str = str(error).lower()
# Check for 429 errors first - these need special handling
if "429" in error_str:
# Try to extract structured error information
error_type = None
error_code = None
# Parse structured error from OpenAI API response
# Format: "Error code: 429 - {'error': {'type': 'tokens', 'code': 'rate_limit_exceeded', ...}}"
try:
import ast
import json
import re
# Extract JSON part from error string using regex
# Look for pattern: {...} (from first { to last })
json_match = re.search(r"\{.*\}", str(error))
if json_match:
json_like_str = json_match.group(0)
# First try: parse as Python literal (handles single quotes safely)
try:
error_data = ast.literal_eval(json_like_str)
except (ValueError, SyntaxError):
# Fallback: try JSON parsing with simple quote replacement
# (for cases where it's already valid JSON or simple replacements work)
json_str = json_like_str.replace("'", '"')
error_data = json.loads(json_str)
if "error" in error_data:
error_info = error_data["error"]
error_type = error_info.get("type")
error_code = error_info.get("code")
except (json.JSONDecodeError, ValueError, SyntaxError, AttributeError):
# Fall back to checking hasattr for OpenAI SDK exception objects
if hasattr(error, "response") and hasattr(error.response, "json"):
try:
response_data = error.response.json()
if "error" in response_data:
error_info = response_data["error"]
error_type = error_info.get("type")
error_code = error_info.get("code")
except Exception:
pass
# Determine if 429 is retryable based on structured error codes
if error_type == "tokens":
# Token-related 429s are typically non-retryable (request too large)
logging.debug(f"Non-retryable 429: token-related error (type={error_type}, code={error_code})")
return False
elif error_code in ["invalid_request_error", "context_length_exceeded"]:
# These are permanent failures
logging.debug(f"Non-retryable 429: permanent failure (type={error_type}, code={error_code})")
return False
else:
# Other 429s (like requests per minute) are retryable
logging.debug(f"Retryable 429: rate limiting (type={error_type}, code={error_code})")
return True
# For non-429 errors, check if they're retryable
retryable_indicators = [
"timeout",
"connection",
"network",
"temporary",
"unavailable",
"retry",
"408", # Request timeout
"500", # Internal server error
"502", # Bad gateway
"503", # Service unavailable
"504", # Gateway timeout
"ssl", # SSL errors
"handshake", # Handshake failures
]
return any(indicator in error_str for indicator in retryable_indicators)
def _process_image(self, image_path: str) -> Optional[dict]:
    """Turn an image path or data URL into an OpenAI ``image_url`` content part.

    Data URLs are validated and passed through as-is. File paths are validated
    with ``validate_image`` and then base64-encoded into a data URL using the
    MIME type it reports.

    Returns:
        ``{"type": "image_url", "image_url": {"url": ...}}`` on success, or None
        if validation fails (logged as a warning) or any other error occurs
        (logged as an error).
    """
    try:
        if not image_path.startswith("data:"):
            # File path: validation also yields the raw bytes and MIME type.
            raw_bytes, mime = validate_image(image_path)

            import base64

            encoded = base64.b64encode(raw_bytes).decode()
            logging.debug(f"Processing image '{image_path}' as MIME type '{mime}'")
            url = f"data:{mime};base64,{encoded}"
        else:
            # Data URL (e.g. data:image/png;base64,...): validate, then use directly.
            validate_image(image_path)
            url = image_path

        return {"type": "image_url", "image_url": {"url": url}}
    except ValueError as e:
        logging.warning(str(e))
        return None
    except Exception as e:
        logging.error(f"Error processing image {image_path}: {e}")
        return None
```
--------------------------------------------------------------------------------
/tools/tracer.py:
--------------------------------------------------------------------------------
```python
"""
Tracer Workflow tool - Step-by-step code tracing and dependency analysis
This tool provides a structured workflow for comprehensive code tracing and analysis.
It guides the CLI agent through systematic investigation steps with forced pauses between each step
to ensure thorough code examination, dependency mapping, and execution flow analysis before proceeding.
The tracer guides users through sequential code analysis with full context awareness and
the ability to revise and adapt as understanding deepens.
Key features:
- Sequential tracing with systematic investigation workflow
- Support for precision tracing (execution flow) and dependencies tracing (structural relationships)
- Self-contained completion with detailed output formatting instructions
- Context-aware analysis that builds understanding step by step
- No external expert analysis needed - provides comprehensive guidance internally
Perfect for: method/function execution flow analysis, dependency mapping, call chain tracing,
structural relationship analysis, architectural understanding, and code comprehension.
"""
import logging
from typing import TYPE_CHECKING, Any, Literal, Optional
from pydantic import Field, field_validator
if TYPE_CHECKING:
from tools.models import ToolModelCategory
from config import TEMPERATURE_ANALYTICAL
from systemprompts import TRACER_PROMPT
from tools.shared.base_models import WorkflowRequest
from .workflow.base import WorkflowTool
logger = logging.getLogger(__name__)

# Tool-specific field descriptions for tracer workflow.
# These strings are used as the pydantic Field descriptions on TracerRequest, and
# the "trace_mode", "target_description" and "images" entries are also reused in
# TracerTool.get_tool_fields(), so they form part of the tool's input schema.
# They are runtime text: edit them deliberately, not cosmetically.
TRACER_WORKFLOW_FIELD_DESCRIPTIONS = {
    "step": (
        "The plan for the current tracing step. Step 1: State the tracing strategy. Later steps: Report findings and adapt the plan. "
        "CRITICAL: For 'precision' mode, focus on execution flow and call chains. For 'dependencies' mode, focus on structural relationships. "
        "If trace_mode is 'ask' in step 1, you MUST prompt the user to choose a mode."
    ),
    "step_number": (
        "The index of the current step in the tracing sequence, beginning at 1. Each step should build upon or "
        "revise the previous one."
    ),
    "total_steps": (
        "Your current estimate for how many steps will be needed to complete the tracing analysis. "
        "Adjust as new findings emerge."
    ),
    "next_step_required": (
        "Set to true if you plan to continue the investigation with another step. False means you believe the "
        "tracing analysis is complete and ready for final output formatting."
    ),
    "findings": (
        "Summary of discoveries from this step, including execution paths, dependency relationships, call chains, and structural patterns. "
        "IMPORTANT: Document both direct (immediate calls) and indirect (transitive, side effects) relationships."
    ),
    "files_checked": (
        "List all files examined (absolute paths). Include even ruled-out files to track exploration path."
    ),
    "relevant_files": (
        "Subset of files_checked directly relevant to the tracing target (absolute paths). Include implementation files, "
        "dependencies, or files demonstrating key relationships."
    ),
    "relevant_context": (
        "List methods/functions central to the tracing analysis, in 'ClassName.methodName' or 'functionName' format. "
        "Prioritize those in the execution flow or dependency chain."
    ),
    # NOTE(review): the tracer declares use_assistant_model=False and requires_model() is False,
    # so the "PREVENTS external model validation" wording may be moot here - confirm before editing.
    "confidence": (
        "Your confidence in the tracing analysis. Use: 'exploring', 'low', 'medium', 'high', 'very_high', 'almost_certain', 'certain'. "
        "CRITICAL: 'certain' implies the analysis is 100% complete locally and PREVENTS external model validation."
    ),
    "trace_mode": "Type of tracing: 'ask' (default - prompts user to choose mode), 'precision' (execution flow) or 'dependencies' (structural relationships)",
    "target_description": (
        "Description of what to trace and WHY. Include context about what you're trying to understand or analyze."
    ),
    "images": ("Optional paths to architecture diagrams or flow charts that help understand the tracing context."),
}
class TracerRequest(WorkflowRequest):
    """Request model for tracer workflow investigation steps.

    Adds the tracer-specific inputs (``trace_mode``, ``target_description``,
    ``images``) on top of ``WorkflowRequest``. It also declares several fields
    with ``exclude=True`` so pydantic leaves them out of serialized output.
    Those excluded fields presumably override same-named fields defined on
    ``WorkflowRequest``, which is not visible here - confirm against the base model.
    """

    # Required fields for each investigation step
    step: str = Field(..., description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["step"])
    step_number: int = Field(..., description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["step_number"])
    total_steps: int = Field(..., description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["total_steps"])
    next_step_required: bool = Field(..., description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["next_step_required"])

    # Investigation tracking fields
    findings: str = Field(..., description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["findings"])
    files_checked: list[str] = Field(
        default_factory=list, description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["files_checked"]
    )
    relevant_files: list[str] = Field(
        default_factory=list, description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["relevant_files"]
    )
    relevant_context: list[str] = Field(
        default_factory=list, description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["relevant_context"]
    )
    # Defaults to "exploring"; the allowed vocabulary is only described in the
    # field text, it is not enforced by a validator here.
    confidence: Optional[str] = Field("exploring", description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["confidence"])

    # Tracer-specific fields (used in step 1 to initialize)
    # Default "ask": per the field description, the agent should prompt the user to choose a mode.
    trace_mode: Optional[Literal["precision", "dependencies", "ask"]] = Field(
        "ask", description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["trace_mode"]
    )
    target_description: Optional[str] = Field(
        None, description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["target_description"]
    )
    images: Optional[list[str]] = Field(default=None, description=TRACER_WORKFLOW_FIELD_DESCRIPTIONS["images"])

    # Exclude fields not relevant to tracing workflow
    issues_found: list[dict] = Field(default_factory=list, exclude=True, description="Tracing doesn't track issues")
    hypothesis: Optional[str] = Field(default=None, exclude=True, description="Tracing doesn't use hypothesis")

    # Exclude other non-tracing fields
    temperature: Optional[float] = Field(default=None, exclude=True)
    thinking_mode: Optional[str] = Field(default=None, exclude=True)
    # Defaults to False, matching the "Tracing is self-contained" description.
    use_assistant_model: Optional[bool] = Field(default=False, exclude=True, description="Tracing is self-contained")

    @field_validator("step_number")
    @classmethod
    def validate_step_number(cls, v):
        """Reject step numbers below 1 (steps are counted from 1)."""
        if v < 1:
            raise ValueError("step_number must be at least 1")
        return v

    @field_validator("total_steps")
    @classmethod
    def validate_total_steps(cls, v):
        """Reject a total step estimate below 1."""
        if v < 1:
            raise ValueError("total_steps must be at least 1")
        return v
class TracerTool(WorkflowTool):
"""
Tracer workflow tool for step-by-step code tracing and dependency analysis.
This tool implements a structured tracing workflow that guides users through
methodical investigation steps, ensuring thorough code examination, dependency
mapping, and execution flow analysis before reaching conclusions. It supports
both precision tracing (execution flow) and dependencies tracing (structural relationships).
"""
def __init__(self):
    """Initialise the tracer on top of WorkflowTool's own setup."""
    super().__init__()
    # Tracer-specific attributes, both starting empty.
    self.trace_config = {}
    self.initial_request = None
def get_name(self) -> str:
    """Return the identifier this tool is registered under."""
    tool_name = "tracer"
    return tool_name
def get_description(self) -> str:
    """Return the one-paragraph summary advertised for this tool."""
    sentences = (
        "Performs systematic code tracing with modes for execution flow or dependency mapping. ",
        "Use for method execution analysis, call chain tracing, dependency mapping, and architectural understanding. ",
        "Supports precision mode (execution flow) and dependencies mode (structural relationships).",
    )
    return "".join(sentences)
def get_system_prompt(self) -> str:
    """Return the tracer system prompt (``TRACER_PROMPT``)."""
    prompt_text = TRACER_PROMPT
    return prompt_text
def get_default_temperature(self) -> float:
    """Return the default sampling temperature (``TEMPERATURE_ANALYTICAL``)."""
    default_temperature = TEMPERATURE_ANALYTICAL
    return default_temperature
def get_model_category(self) -> "ToolModelCategory":
    """Report the EXTENDED_REASONING category, since tracing is analytical work."""
    # Imported lazily; at module level the name is only imported under TYPE_CHECKING.
    from tools.models import ToolModelCategory as _Category

    return _Category.EXTENDED_REASONING
def requires_model(self) -> bool:
    """
    Report that no model has to be resolved at the MCP boundary for this tool.

    The tracer only structures the tracing steps and supplies output-formatting
    guidance; it does not call an external AI model itself.

    Returns:
        bool: Always False.
    """
    needs_model = False
    return needs_model
def get_workflow_request_model(self):
    """Return the pydantic model that validates tracer requests (``TracerRequest``)."""
    request_model = TracerRequest
    return request_model
def get_tool_fields(self) -> dict[str, dict[str, Any]]:
"""Return tracing-specific field definitions beyond the standard workflow fields."""
return {
# Tracer-specific fields
"trace_mode": {
"type": "string",
"enum": ["precision", "dependencies", "ask"],
"description": TRACER_WORKFLOW_FIELD_DESCRIPTIONS["trace_mode"],
},
"target_description": {
"type": "string",
"description": TRACER_WORKFLOW_FIELD_DESCRIPTIONS["target_description"],
},
"images": {
"type": "array",
"items": {"type": "string"},
"description": TRACER_WORKFLOW_FIELD_DESCRIPTIONS["images"],
},
}
def get_input_schema(self) -> dict[str, Any]:
"""Generate input schema using WorkflowSchemaBuilder with field exclusion."""
from .workflow.schema_builders import WorkflowSchemaBuilder
# Exclude investigation-specific fields that tracing doesn't need
excluded_workflow_fields = [
"issues_found", # Tracing doesn't track issues
"hypothesis", # Tracing doesn't use hypothesis
]
# Exclude common fields that tracing doesn't need
excluded_common_fields = [
"temperature", # Tracing doesn't need temperature control
"thinking_mode", # Tracing doesn't need thinking mode
"absolute_file_paths", # Tracing uses relevant_files instead
]
return WorkflowSchemaBuilder.build_schema(
tool_specific_fields=self.get_tool_fields(),
required_fields=["target_description", "trace_mode"], # Step 1 requires these
model_field_schema=self.get_model_field_schema(),
auto_mode=self.is_effective_auto_mode(),
tool_name=self.get_name(),
excluded_workflow_fields=excluded_workflow_fields,
excluded_common_fields=excluded_common_fields,
)
# ================================================================================
# Abstract Methods - Required Implementation from BaseWorkflowMixin
# ================================================================================
def get_required_actions(
self, step_number: int, confidence: str, findings: str, total_steps: int, request=None
) -> list[str]:
"""Define required actions for each tracing phase."""
if step_number == 1:
# Check if we're in ask mode and need to prompt for mode selection
if self.get_trace_mode() == "ask":
return [
"MUST ask user to choose between precision or dependencies mode",
"Explain precision mode: traces execution flow, call chains, and usage patterns (best for methods/functions)",
"Explain dependencies mode: maps structural relationships and bidirectional dependencies (best for classes/modules)",
"Wait for user's mode selection before proceeding with investigation",
]
# Initial tracing investigation tasks (when mode is already selected)
return [
"Search for and locate the target method/function/class/module in the codebase",
"Read and understand the implementation of the target code",
"Identify the file location, complete signature, and basic structure",
"Begin mapping immediate relationships (what it calls, what calls it)",
"Understand the context and purpose of the target code",
]
elif confidence in ["exploring", "low"]:
# Need deeper investigation
return [
"Trace deeper into the execution flow or dependency relationships",
"Examine how the target code is used throughout the codebase",
"Map additional layers of dependencies or call chains",
"Look for conditional execution paths, error handling, and edge cases",
"Understand the broader architectural context and patterns",
]
elif confidence in ["medium", "high"]:
# Close to completion - need final verification
return [
"Verify completeness of the traced relationships and execution paths",
"Check for any missed dependencies, usage patterns, or execution branches",
"Confirm understanding of side effects, state changes, and external interactions",
"Validate that the tracing covers all significant code relationships",
"Prepare comprehensive findings for final output formatting",
]
else:
# General investigation needed
return [
"Continue systematic tracing of code relationships and execution paths",
"Gather more evidence using appropriate code analysis techniques",
"Test assumptions about code behavior and dependency relationships",
"Look for patterns that enhance understanding of the code structure",
"Focus on areas that haven't been thoroughly traced yet",
]
def should_call_expert_analysis(self, consolidated_findings, request=None) -> bool:
"""Tracer is self-contained and doesn't need expert analysis."""
return False
def prepare_expert_analysis_context(self, consolidated_findings) -> str:
"""Tracer doesn't use expert analysis."""
return ""
def requires_expert_analysis(self) -> bool:
"""Tracer is self-contained like the planner tool."""
return False
# ================================================================================
# Workflow Customization - Match Planner Behavior
# ================================================================================
def prepare_step_data(self, request) -> dict:
"""
Prepare step data from request with tracer-specific fields.
"""
step_data = {
"step": request.step,
"step_number": request.step_number,
"findings": request.findings,
"files_checked": request.files_checked,
"relevant_files": request.relevant_files,
"relevant_context": request.relevant_context,
"issues_found": [], # Tracer doesn't track issues
"confidence": request.confidence or "exploring",
"hypothesis": None, # Tracer doesn't use hypothesis
"images": request.images or [],
# Tracer-specific fields
"trace_mode": request.trace_mode,
"target_description": request.target_description,
}
return step_data
def build_base_response(self, request, continuation_id: str = None) -> dict:
"""
Build the base response structure with tracer-specific fields.
"""
# Use work_history from workflow mixin for consistent step tracking
current_step_count = len(self.work_history) + 1
response_data = {
"status": f"{self.get_name()}_in_progress",
"step_number": request.step_number,
"total_steps": request.total_steps,
"next_step_required": request.next_step_required,
"step_content": request.step,
f"{self.get_name()}_status": {
"files_checked": len(self.consolidated_findings.files_checked),
"relevant_files": len(self.consolidated_findings.relevant_files),
"relevant_context": len(self.consolidated_findings.relevant_context),
"issues_found": len(self.consolidated_findings.issues_found),
"images_collected": len(self.consolidated_findings.images),
"current_confidence": self.get_request_confidence(request),
"step_history_length": current_step_count,
},
"metadata": {
"trace_mode": self.trace_config.get("trace_mode", "unknown"),
"target_description": self.trace_config.get("target_description", ""),
"step_history_length": current_step_count,
},
}
if continuation_id:
response_data["continuation_id"] = continuation_id
return response_data
def handle_work_continuation(self, response_data: dict, request) -> dict:
"""
Handle work continuation with tracer-specific guidance.
"""
response_data["status"] = f"pause_for_{self.get_name()}"
response_data[f"{self.get_name()}_required"] = True
# Get tracer-specific required actions
required_actions = self.get_required_actions(
request.step_number, request.confidence or "exploring", request.findings, request.total_steps
)
response_data["required_actions"] = required_actions
# Generate step-specific guidance
if request.step_number == 1:
# Check if we're in ask mode and need to prompt for mode selection
if self.get_trace_mode() == "ask":
response_data["next_steps"] = (
f"STOP! You MUST ask the user to choose a tracing mode before proceeding. "
f"Present these options clearly:\\n\\n"
f"**PRECISION MODE**: Traces execution flow, call chains, and usage patterns. "
f"Best for understanding how a specific method or function works, what it calls, "
f"and how data flows through the execution path.\\n\\n"
f"**DEPENDENCIES MODE**: Maps structural relationships and bidirectional dependencies. "
f"Best for understanding how a class or module relates to other components, "
f"what depends on it, and what it depends on.\\n\\n"
f"After the user selects a mode, call {self.get_name()} again with step_number: 1 "
f"but with the chosen trace_mode (either 'precision' or 'dependencies')."
)
else:
response_data["next_steps"] = (
f"MANDATORY: DO NOT call the {self.get_name()} tool again immediately. You MUST first investigate "
f"the codebase to understand the target code. CRITICAL AWARENESS: You need to find and understand "
f"the target method/function/class/module, examine its implementation, and begin mapping its "
f"relationships. Use file reading tools, code search, and systematic examination to gather "
f"comprehensive information about the target. Only call {self.get_name()} again AFTER completing "
f"your investigation. When you call {self.get_name()} next time, use step_number: {request.step_number + 1} "
f"and report specific files examined, code structure discovered, and initial relationship findings."
)
elif request.confidence in ["exploring", "low"]:
next_step = request.step_number + 1
response_data["next_steps"] = (
f"STOP! Do NOT call {self.get_name()} again yet. Based on your findings, you've identified areas that need "
f"deeper tracing analysis. MANDATORY ACTIONS before calling {self.get_name()} step {next_step}:\\n"
+ "\\n".join(f"{i+1}. {action}" for i, action in enumerate(required_actions))
+ f"\\n\\nOnly call {self.get_name()} again with step_number: {next_step} AFTER "
+ "completing these tracing investigations."
)
elif request.confidence in ["medium", "high"]:
next_step = request.step_number + 1
response_data["next_steps"] = (
f"WAIT! Your tracing analysis needs final verification. DO NOT call {self.get_name()} immediately. "
f"REQUIRED ACTIONS:\\n"
+ "\\n".join(f"{i+1}. {action}" for i, action in enumerate(required_actions))
+ f"\\n\\nREMEMBER: Ensure you have traced all significant relationships and execution paths. "
f"Document findings with specific file references and method signatures, then call {self.get_name()} "
f"with step_number: {next_step}."
)
else:
# General investigation needed
next_step = request.step_number + 1
remaining_steps = request.total_steps - request.step_number
response_data["next_steps"] = (
f"Continue systematic tracing with step {next_step}. Approximately {remaining_steps} steps remaining. "
f"Focus on deepening your understanding of the code relationships and execution patterns."
)
return response_data
def customize_workflow_response(self, response_data: dict, request) -> dict:
"""
Customize response to match tracer tool format with output instructions.
"""
# Store trace configuration on first step
if request.step_number == 1:
self.initial_request = request.step
self.trace_config = {
"trace_mode": request.trace_mode,
"target_description": request.target_description,
}
# Update metadata with trace configuration
if "metadata" in response_data:
response_data["metadata"]["trace_mode"] = request.trace_mode or "unknown"
response_data["metadata"]["target_description"] = request.target_description or ""
# If in ask mode, mark this as mode selection phase
if request.trace_mode == "ask":
response_data["mode_selection_required"] = True
response_data["status"] = "mode_selection_required"
# Add tracer-specific output instructions for final steps
if not request.next_step_required:
response_data["tracing_complete"] = True
response_data["trace_summary"] = f"TRACING COMPLETE: {request.step}"
# Get mode-specific output instructions
trace_mode = self.trace_config.get("trace_mode", "precision")
rendering_instructions = self._get_rendering_instructions(trace_mode)
response_data["output"] = {
"instructions": (
"This is a structured tracing analysis response. Present the comprehensive tracing findings "
"using the specific rendering format for the trace mode. Follow the exact formatting guidelines "
"provided in rendering_instructions. Include all discovered relationships, execution paths, "
"and dependencies with precise file references and line numbers."
),
"format": f"{trace_mode}_trace_analysis",
"rendering_instructions": rendering_instructions,
"presentation_guidelines": {
"completed_trace": (
"Use the exact rendering format specified for the trace mode. Include comprehensive "
"diagrams, tables, and structured analysis. Reference specific file paths and line numbers. "
"Follow formatting rules precisely."
),
"step_content": "Present as main analysis with clear structure and actionable insights.",
"continuation": "Use continuation_id for related tracing sessions or follow-up analysis",
},
}
response_data["next_steps"] = (
f"Tracing analysis complete. Present the comprehensive {trace_mode} trace analysis to the user "
f"using the exact rendering format specified in the output instructions. Follow the formatting "
f"guidelines precisely, including diagrams, tables, and file references. After presenting the "
f"analysis, offer to help with related tracing tasks or use the continuation_id for follow-up analysis."
)
# Convert generic status names to tracer-specific ones
tool_name = self.get_name()
status_mapping = {
f"{tool_name}_in_progress": "tracing_in_progress",
f"pause_for_{tool_name}": "pause_for_tracing",
f"{tool_name}_required": "tracing_required",
f"{tool_name}_complete": "tracing_complete",
}
if response_data["status"] in status_mapping:
response_data["status"] = status_mapping[response_data["status"]]
return response_data
def _get_rendering_instructions(self, trace_mode: str) -> str:
"""
Get mode-specific rendering instructions for the CLI agent.
Args:
trace_mode: Either "precision" or "dependencies"
Returns:
str: Complete rendering instructions for the specified mode
"""
if trace_mode == "precision":
return self._get_precision_rendering_instructions()
else: # dependencies mode
return self._get_dependencies_rendering_instructions()
def _get_precision_rendering_instructions(self) -> str:
"""Get rendering instructions for precision trace mode."""
return """
## MANDATORY RENDERING INSTRUCTIONS FOR PRECISION TRACE
You MUST render the trace analysis using ONLY the Vertical Indented Flow Style:
### CALL FLOW DIAGRAM - Vertical Indented Style
**EXACT FORMAT TO FOLLOW:**
```
[ClassName::MethodName] (file: /complete/file/path.ext, line: ##)
↓
[AnotherClass::calledMethod] (file: /path/to/file.ext, line: ##)
↓
[ThirdClass::nestedMethod] (file: /path/file.ext, line: ##)
↓
[DeeperClass::innerCall] (file: /path/inner.ext, line: ##) ? if some_condition
↓
[ServiceClass::processData] (file: /services/service.ext, line: ##)
↓
[RepositoryClass::saveData] (file: /data/repo.ext, line: ##)
↓
[ClientClass::sendRequest] (file: /clients/client.ext, line: ##)
↓
[EmailService::sendEmail] (file: /email/service.ext, line: ##) ⚠️ ambiguous branch
→
[SMSService::sendSMS] (file: /sms/service.ext, line: ##) ⚠️ ambiguous branch
```
**CRITICAL FORMATTING RULES:**
1. **Method Names**: Use the actual naming convention of the project language you're analyzing. Automatically detect and adapt to the project's conventions (camelCase, snake_case, PascalCase, etc.) based on the codebase structure and file extensions.
2. **Vertical Flow Arrows**:
- Use `↓` for standard sequential calls (vertical flow)
- Use `→` for parallel/alternative calls (horizontal branch)
- NEVER use other arrow types
3. **Indentation Logic**:
- Start at column 0 for entry point
- Indent 2 spaces for each nesting level
- Maintain consistent indentation for same call depth
- Sibling calls at same level should have same indentation
4. **Conditional Calls**:
- Add `? if condition_description` after method for conditional execution
- Use actual condition names from code when possible
5. **Ambiguous Branches**:
- Mark with `⚠️ ambiguous branch` when execution path is uncertain
- Use `→` to show alternative paths at same indentation level
6. **File Path Format**:
- Use complete relative paths from project root
- Include actual file extensions from the project
- Show exact line numbers where method is defined
### ADDITIONAL ANALYSIS VIEWS
**1. BRANCHING & SIDE EFFECT TABLE**
| Location | Condition | Branches | Uncertain |
|----------|-----------|----------|-----------|
| CompleteFileName.ext:## | if actual_condition_from_code | method1(), method2(), else skip | No |
| AnotherFile.ext:## | if boolean_check | callMethod(), else return | No |
| ThirdFile.ext:## | if validation_passes | processData(), else throw | Yes |
**2. SIDE EFFECTS**
```
Side Effects:
- [database] Specific database operation description (CompleteFileName.ext:##)
- [network] Specific network call description (CompleteFileName.ext:##)
- [filesystem] Specific file operation description (CompleteFileName.ext:##)
- [state] State changes or property modifications (CompleteFileName.ext:##)
- [memory] Memory allocation or cache operations (CompleteFileName.ext:##)
```
**3. USAGE POINTS**
```
Usage Points:
1. FileName.ext:## - Context description of where/why it's called
2. AnotherFile.ext:## - Context description of usage scenario
3. ThirdFile.ext:## - Context description of calling pattern
4. FourthFile.ext:## - Context description of integration point
```
**4. ENTRY POINTS**
```
Entry Points:
- ClassName::methodName (context: where this flow typically starts)
- AnotherClass::entryMethod (context: alternative entry scenario)
- ThirdClass::triggerMethod (context: event-driven entry point)
```
**ABSOLUTE REQUIREMENTS:**
- Use ONLY the vertical indented style for the call flow diagram
- Present ALL FOUR additional analysis views (Branching Table, Side Effects, Usage Points, Entry Points)
- Adapt method naming to match the project's programming language conventions
- Use exact file paths and line numbers from the actual codebase
- DO NOT invent or guess method names or locations
- Follow indentation rules precisely for call hierarchy
- Mark uncertain execution paths clearly
- Provide contextual descriptions in Usage Points and Entry Points sections
- Include comprehensive side effects categorization (database, network, filesystem, state, memory)"""
def _get_dependencies_rendering_instructions(self) -> str:
"""Get rendering instructions for dependencies trace mode."""
return """
## MANDATORY RENDERING INSTRUCTIONS FOR DEPENDENCIES TRACE
You MUST render the trace analysis using ONLY the Bidirectional Arrow Flow Style:
### DEPENDENCY FLOW DIAGRAM - Bidirectional Arrow Style
**EXACT FORMAT TO FOLLOW:**
```
INCOMING DEPENDENCIES → [TARGET_CLASS/MODULE] → OUTGOING DEPENDENCIES
CallerClass::callerMethod ←────┐
AnotherCaller::anotherMethod ←─┤
ThirdCaller::thirdMethod ←─────┤
│
[TARGET_CLASS/MODULE]
│
├────→ FirstDependency::method
├────→ SecondDependency::method
└────→ ThirdDependency::method
TYPE RELATIONSHIPS:
InterfaceName ──implements──→ [TARGET_CLASS] ──extends──→ BaseClass
DTOClass ──uses──→ [TARGET_CLASS] ──uses──→ EntityClass
```
**CRITICAL FORMATTING RULES:**
1. **Target Placement**: Always place the target class/module in square brackets `[TARGET_NAME]` at the center
2. **Incoming Dependencies**: Show on the left side with `←` arrows pointing INTO the target
3. **Outgoing Dependencies**: Show on the right side with `→` arrows pointing OUT FROM the target
4. **Arrow Alignment**: Use consistent spacing and alignment for visual clarity
5. **Method Naming**: Use the project's actual naming conventions detected from the codebase
6. **File References**: Include complete file paths and line numbers
**VISUAL LAYOUT RULES:**
1. **Header Format**: Always start with the flow direction indicator
2. **Left Side (Incoming)**:
- List all callers with `←` arrows
- Use `┐`, `┤`, `┘` box drawing characters for clean connection lines
- Align arrows consistently
3. **Center (Target)**:
- Enclose target in square brackets
- Position centrally between incoming and outgoing
4. **Right Side (Outgoing)**:
- List all dependencies with `→` arrows
- Use `├`, `└` box drawing characters for branching
- Maintain consistent spacing
5. **Type Relationships Section**:
- Use `──relationship──→` format with double hyphens
- Show inheritance, implementation, and usage relationships
- Place below the main flow diagram
**DEPENDENCY TABLE:**
| Type | From/To | Method | File | Line |
|------|---------|--------|------|------|
| incoming_call | From: CallerClass | callerMethod | /complete/path/file.ext | ## |
| outgoing_call | To: TargetClass | targetMethod | /complete/path/file.ext | ## |
| implements | Self: ThisClass | — | /complete/path/file.ext | — |
| extends | Self: ThisClass | — | /complete/path/file.ext | — |
| uses_type | Self: ThisClass | — | /complete/path/file.ext | — |
**ABSOLUTE REQUIREMENTS:**
- Use ONLY the bidirectional arrow flow style shown above
- Automatically detect and use the project's naming conventions
- Use exact file paths and line numbers from the actual codebase
- DO NOT invent or guess method/class names
- Maintain visual alignment and consistent spacing
- Include type relationships section when applicable
- Show clear directional flow with proper arrows"""
# ================================================================================
# Hook Method Overrides for Tracer-Specific Behavior
# ================================================================================
def get_completion_status(self) -> str:
"""Tracer uses tracing-specific status."""
return "tracing_complete"
def get_completion_data_key(self) -> str:
"""Tracer uses 'complete_tracing' key."""
return "complete_tracing"
def get_completion_message(self) -> str:
"""Tracer-specific completion message."""
return (
"Tracing analysis complete. Present the comprehensive trace analysis to the user "
"using the specified rendering format and offer to help with related tracing tasks."
)
def get_skip_reason(self) -> str:
"""Tracer-specific skip reason."""
return "Tracer is self-contained and completes analysis without external assistance"
def get_skip_expert_analysis_status(self) -> str:
"""Tracer-specific expert analysis skip status."""
return "skipped_by_tool_design"
def store_initial_issue(self, step_description: str):
"""Store initial tracing description."""
self.initial_tracing_description = step_description
def get_initial_request(self, fallback_step: str) -> str:
"""Get initial tracing description."""
try:
return self.initial_tracing_description
except AttributeError:
return fallback_step
def get_request_confidence(self, request) -> str:
"""Get confidence from request for tracer workflow."""
try:
return request.confidence or "exploring"
except AttributeError:
return "exploring"
def get_trace_mode(self) -> str:
"""Get current trace mode. Override for custom trace mode handling."""
try:
return self.trace_config.get("trace_mode", "ask")
except AttributeError:
return "ask"
# Required abstract methods from BaseTool
def get_request_model(self):
"""Return the tracer-specific request model."""
return TracerRequest
async def prepare_prompt(self, request) -> str:
"""Not used - workflow tools use execute_workflow()."""
return "" # Workflow tools use execute_workflow() directly
```
--------------------------------------------------------------------------------
/tools/codereview.py:
--------------------------------------------------------------------------------
```python
"""
CodeReview Workflow tool - Systematic code review with step-by-step analysis
This tool provides a structured workflow for comprehensive code review and analysis.
It guides the CLI agent through systematic investigation steps with forced pauses between each step
to ensure thorough code examination, issue identification, and quality assessment before proceeding.
The tool supports complex review scenarios including security analysis, performance evaluation,
and architectural assessment.
Key features:
- Step-by-step code review workflow with progress tracking
- Context-aware file embedding (references during investigation, full content for analysis)
- Automatic issue tracking with severity classification
- Expert analysis integration with external models
- Support for focused reviews (security, performance, architecture)
- Confidence-based workflow optimization
"""
import logging
from typing import TYPE_CHECKING, Any, Literal, Optional
from pydantic import Field, model_validator
if TYPE_CHECKING:
from tools.models import ToolModelCategory
from config import TEMPERATURE_ANALYTICAL
from systemprompts import CODEREVIEW_PROMPT
from tools.shared.base_models import WorkflowRequest
from .workflow.base import WorkflowTool
logger = logging.getLogger(__name__)

# Field descriptions for the code review workflow, keyed by request field name.
# Used both by CodeReviewRequest (pydantic Field descriptions) and by the JSON
# input-schema overrides built in CodeReviewTool.get_input_schema.
CODEREVIEW_WORKFLOW_FIELD_DESCRIPTIONS = dict(
    step=(
        "Review narrative. Step 1: outline the review strategy. Later steps: report findings. MUST cover quality, security, "
        "performance, and architecture. Reference code via `relevant_files`; avoid dumping large snippets."
    ),
    step_number="Current review step (starts at 1) – each step should build on the last.",
    total_steps=(
        "Number of review steps planned. External validation: two steps (analysis + summary). Internal validation: one step. "
        "Use the same limits when continuing an existing review via continuation_id."
    ),
    next_step_required=(
        "True when another review step follows. External validation: step 1 → True, step 2 → False. Internal validation: set False immediately. "
        "Apply the same rule on continuation flows."
    ),
    findings="Capture findings (positive and negative) across quality, security, performance, and architecture; update each step.",
    files_checked="Absolute paths of every file reviewed, including those ruled out.",
    relevant_files="Step 1: list all files/dirs under review. Must be absolute full non-abbreviated paths. Final step: narrow to files tied to key findings.",
    relevant_context="Functions or methods central to findings (e.g. 'Class.method' or 'function_name').",
    issues_found="Issues with severity (critical/high/medium/low) and descriptions.",
    review_validation_type="Set 'external' (default) for expert follow-up or 'internal' for local-only review.",
    images="Optional diagram or screenshot paths that clarify review context.",
    review_type="Review focus: full, security, performance, or quick.",
    focus_on="Optional note on areas to emphasise (e.g. 'threading', 'auth flow').",
    standards="Coding standards or style guides to enforce.",
    severity_filter="Lowest severity to include when reporting issues (critical/high/medium/low/all).",
)
class CodeReviewRequest(WorkflowRequest):
    """
    Request model for a single step of the code review workflow.

    Step 1 must name the files or directories under review via ``relevant_files``
    (enforced by ``validate_step_one_requirements``). Review configuration fields
    (``review_type``, ``focus_on``, ``standards``, ``severity_filter``) are meant for step 1.
    ``temperature`` and ``thinking_mode`` are redeclared with ``exclude=True``; the
    ``model`` field is deliberately left as inherited.
    """

    # Per-step workflow bookkeeping (required on every call)
    step: str = Field(..., description=CODEREVIEW_WORKFLOW_FIELD_DESCRIPTIONS["step"])
    step_number: int = Field(..., description=CODEREVIEW_WORKFLOW_FIELD_DESCRIPTIONS["step_number"])
    total_steps: int = Field(..., description=CODEREVIEW_WORKFLOW_FIELD_DESCRIPTIONS["total_steps"])
    next_step_required: bool = Field(
        ..., description=CODEREVIEW_WORKFLOW_FIELD_DESCRIPTIONS["next_step_required"]
    )

    # Accumulated investigation state (lists default to fresh empty lists)
    findings: str = Field(..., description=CODEREVIEW_WORKFLOW_FIELD_DESCRIPTIONS["findings"])
    files_checked: list[str] = Field(
        description=CODEREVIEW_WORKFLOW_FIELD_DESCRIPTIONS["files_checked"], default_factory=list
    )
    relevant_files: list[str] = Field(
        description=CODEREVIEW_WORKFLOW_FIELD_DESCRIPTIONS["relevant_files"], default_factory=list
    )
    relevant_context: list[str] = Field(
        description=CODEREVIEW_WORKFLOW_FIELD_DESCRIPTIONS["relevant_context"], default_factory=list
    )
    issues_found: list[dict] = Field(
        description=CODEREVIEW_WORKFLOW_FIELD_DESCRIPTIONS["issues_found"], default_factory=list
    )

    # Deprecated: retained only so older callers that still send `confidence` keep validating
    confidence: Optional[str] = Field(default="low", exclude=True)

    # "external" asks for expert follow-up; "internal" keeps the review local
    review_validation_type: Optional[Literal["external", "internal"]] = Field(
        default="external",
        description=CODEREVIEW_WORKFLOW_FIELD_DESCRIPTIONS.get("review_validation_type", ""),
    )

    # Optional visual context (diagrams, screenshots)
    images: Optional[list[str]] = Field(
        default=None, description=CODEREVIEW_WORKFLOW_FIELD_DESCRIPTIONS["images"]
    )

    # Review configuration, consulted when the review is initialised on step 1
    review_type: Optional[Literal["full", "security", "performance", "quick"]] = Field(
        default="full", description=CODEREVIEW_WORKFLOW_FIELD_DESCRIPTIONS["review_type"]
    )
    focus_on: Optional[str] = Field(default=None, description=CODEREVIEW_WORKFLOW_FIELD_DESCRIPTIONS["focus_on"])
    standards: Optional[str] = Field(default=None, description=CODEREVIEW_WORKFLOW_FIELD_DESCRIPTIONS["standards"])
    severity_filter: Optional[Literal["critical", "high", "medium", "low", "all"]] = Field(
        default="all", description=CODEREVIEW_WORKFLOW_FIELD_DESCRIPTIONS["severity_filter"]
    )

    # Inherited knobs this tool does not expose; marked exclude=True (model stays available)
    temperature: Optional[float] = Field(default=None, exclude=True)
    thinking_mode: Optional[str] = Field(default=None, exclude=True)

    @model_validator(mode="after")
    def validate_step_one_requirements(self):
        """Reject a first step that does not name any files or directories to review.

        Returns:
            The validated model instance.

        Raises:
            ValueError: If ``step_number`` is 1 and ``relevant_files`` is empty.
        """
        if self.step_number != 1 or self.relevant_files:
            return self
        raise ValueError("Step 1 requires 'relevant_files' field to specify code files or directories to review")
class CodeReviewTool(WorkflowTool):
"""
Code Review workflow tool for step-by-step code review and expert analysis.
This tool implements a structured code review workflow that guides users through
methodical investigation steps, ensuring thorough code examination, issue identification,
and quality assessment before reaching conclusions. It supports complex review scenarios
including security audits, performance analysis, architectural review, and maintainability assessment.
"""
def __init__(self):
    """Run base workflow initialisation, then reset the code-review-specific state.

    ``review_config`` starts as an empty dict and ``initial_request`` as None.
    """
    super().__init__()
    self.review_config = {}
    self.initial_request = None
def get_name(self) -> str:
    """Return this tool's identifier, ``"codereview"``."""
    tool_identifier = "codereview"
    return tool_identifier
def get_description(self) -> str:
    """Return the human-readable description of the code review tool."""
    sentences = (
        "Performs systematic, step-by-step code review with expert validation. ",
        "Use for comprehensive analysis covering quality, security, performance, and architecture. ",
        "Guides through structured investigation to ensure thoroughness.",
    )
    return "".join(sentences)
def get_system_prompt(self) -> str:
    """Return the code review system prompt (``CODEREVIEW_PROMPT``)."""
    prompt_text = CODEREVIEW_PROMPT
    return prompt_text
def get_default_temperature(self) -> float:
    """Return the default sampling temperature: the shared analytical setting."""
    analytical_temperature = TEMPERATURE_ANALYTICAL
    return analytical_temperature
def get_model_category(self) -> "ToolModelCategory":
    """Request the extended-reasoning model category, since code review needs thorough analysis."""
    # Imported lazily; at module level the name is only imported under TYPE_CHECKING
    from tools.models import ToolModelCategory as _ModelCategory

    return _ModelCategory.EXTENDED_REASONING
def get_workflow_request_model(self):
    """Return the pydantic model that validates each code review workflow step."""
    request_model = CodeReviewRequest
    return request_model
def get_input_schema(self) -> dict[str, Any]:
    """Build the MCP input schema for the code review workflow.

    Per-field JSON-schema fragments (types, bounds, enums, defaults and descriptions
    from ``CODEREVIEW_WORKFLOW_FIELD_DESCRIPTIONS``) are passed to
    ``WorkflowSchemaBuilder.build_schema`` as tool-specific fields.
    """
    from .workflow.schema_builders import WorkflowSchemaBuilder

    descriptions = CODEREVIEW_WORKFLOW_FIELD_DESCRIPTIONS

    def text_field(key):
        # Plain string field
        return {"type": "string", "description": descriptions[key]}

    def counter_field(key):
        # Integer field that must be at least 1
        return {"type": "integer", "minimum": 1, "description": descriptions[key]}

    def string_list_field(key):
        # Array-of-strings field
        return {"type": "array", "items": {"type": "string"}, "description": descriptions[key]}

    field_overrides = {
        "step": text_field("step"),
        "step_number": counter_field("step_number"),
        "total_steps": counter_field("total_steps"),
        "next_step_required": {"type": "boolean", "description": descriptions["next_step_required"]},
        "findings": text_field("findings"),
        "files_checked": string_list_field("files_checked"),
        "relevant_files": string_list_field("relevant_files"),
        "review_validation_type": {
            "type": "string",
            "enum": ["external", "internal"],
            "default": "external",
            "description": descriptions.get("review_validation_type", ""),
        },
        "issues_found": {
            "type": "array",
            "items": {"type": "object"},
            "description": descriptions["issues_found"],
        },
        "images": string_list_field("images"),
        # Review configuration fields (relevant on step 1)
        "review_type": {
            "type": "string",
            "enum": ["full", "security", "performance", "quick"],
            "default": "full",
            "description": descriptions["review_type"],
        },
        "focus_on": text_field("focus_on"),
        "standards": text_field("standards"),
        "severity_filter": {
            "type": "string",
            "enum": ["critical", "high", "medium", "low", "all"],
            "default": "all",
            "description": descriptions["severity_filter"],
        },
    }

    return WorkflowSchemaBuilder.build_schema(
        tool_specific_fields=field_overrides,
        model_field_schema=self.get_model_field_schema(),
        auto_mode=self.is_effective_auto_mode(),
        tool_name=self.get_name(),
    )
def get_required_actions(
    self, step_number: int, confidence: str, findings: str, total_steps: int, request=None
) -> list[str]:
    """Return the checklist of actions the agent must perform at a review step.

    Args:
        step_number: Current workflow step (1-based).
        confidence: Accepted for interface compatibility; not consulted here.
        findings: Accepted for interface compatibility; not consulted here.
        total_steps: Accepted for interface compatibility; not consulted here.
        request: Optional workflow request. When it describes a continuation
            whose validation type is "external", a shortened fast-track
            checklist is returned instead of the normal one.

    Returns:
        A list of human-readable action strings.
    """
    if request:
        cont_id = self.get_request_continuation_id(request)
        vtype = self.get_review_validation_type(request)
        if cont_id and vtype == "external":
            # Fast track: prior conversation exists, so only a brief pass is needed
            if step_number != 1:
                return ["Complete review and proceed to expert analysis"]
            return [
                "Quickly review the code files to understand context",
                "Identify any critical issues that need immediate attention",
                "Note main architectural patterns and design decisions",
                "Prepare summary of key findings for expert validation",
            ]

    if step_number == 1:
        # Initial investigation of the code under review
        return [
            "Read and understand the code files specified for review",
            "Examine the overall structure, architecture, and design patterns used",
            "Identify the main components, classes, and functions in the codebase",
            "Understand the business logic and intended functionality",
            "Look for obvious issues: bugs, security concerns, performance problems",
            "Note any code smells, anti-patterns, or areas of concern",
        ]

    if step_number == 2:
        # Deeper, targeted analysis
        return [
            "Examine specific code sections you've identified as concerning",
            "Analyze security implications: input validation, authentication, authorization",
            "Check for performance issues: algorithmic complexity, resource usage, inefficiencies",
            "Look for architectural problems: tight coupling, missing abstractions, scalability issues",
            "Identify code quality issues: readability, maintainability, error handling",
            "Search for over-engineering, unnecessary complexity, or design patterns that could be simplified",
        ]

    if step_number >= 3:
        # Final verification for all later steps
        return [
            "Verify all identified issues have been properly documented with severity levels",
            "Check for any missed critical security vulnerabilities or performance bottlenecks",
            "Confirm that architectural concerns and code quality issues are comprehensively captured",
            "Ensure positive aspects and well-implemented patterns are also noted",
            "Validate that your assessment aligns with the review type and focus areas specified",
            "Double-check that findings are actionable and provide clear guidance for improvements",
        ]

    # Only reached for step numbers below 1: generic continued investigation
    return [
        "Continue examining the codebase for additional patterns and potential issues",
        "Gather more evidence using appropriate code analysis techniques",
        "Test your assumptions about code behavior and design decisions",
        "Look for patterns that confirm or refute your current assessment",
        "Focus on areas that haven't been thoroughly examined yet",
    ]
def should_call_expert_analysis(self, consolidated_findings, request=None) -> bool:
    """
    Decide whether the external (expert) model should be consulted.

    Returns False if the request explicitly disables the assistant model,
    True unconditionally for continuations with "external" validation, and
    otherwise True only when the investigation produced meaningful data:
    at least one relevant file, two or more findings, or any issue.
    """
    assistant_disabled = request and not self.get_request_use_assistant_model(request)
    if assistant_disabled:
        return False

    cont_id = self.get_request_continuation_id(request)
    vtype = self.get_review_validation_type(request)
    if cont_id and vtype == "external":
        # External continuations always get expert validation
        return True

    return (
        len(consolidated_findings.relevant_files) > 0
        or len(consolidated_findings.findings) >= 2
        or len(consolidated_findings.issues_found) > 0
    )
def prepare_expert_analysis_context(self, consolidated_findings) -> str:
    """Assemble the context text sent to the expert model for final code review validation.

    The output contains delimited sections: the original request, the agent's
    investigation summary, and, when present, the review configuration,
    relevant code elements, issues found, assessment evolution, and images.
    Sections and list items are separated by real newline characters. The
    previous implementation joined with an escaped backslash-n, which put
    literal backslash-n text into the prompt instead of line breaks.

    Args:
        consolidated_findings: Aggregated workflow findings. This method reads
            its relevant_context, issues_found (dicts with optional
            "severity"/"description"), hypotheses (dicts with "step",
            "confidence", "hypothesis"), and images.

    Returns:
        The full expert-analysis context as a single string.
    """
    request_text = self.initial_request or "Code review workflow initiated"
    context_parts = [f"=== CODE REVIEW REQUEST ===\n{request_text}\n=== END REQUEST ==="]

    # Summary of the agent's own step-by-step investigation
    investigation_summary = self._build_code_review_summary(consolidated_findings)
    context_parts.append(
        f"\n=== AGENT'S CODE REVIEW INVESTIGATION ===\n{investigation_summary}\n=== END INVESTIGATION ==="
    )

    # Review configuration (only keys with truthy values are listed)
    if self.review_config:
        config_text = "\n".join(f"- {key}: {value}" for key, value in self.review_config.items() if value)
        context_parts.append(f"\n=== REVIEW CONFIGURATION ===\n{config_text}\n=== END CONFIGURATION ===")

    if consolidated_findings.relevant_context:
        methods_text = "\n".join(f"- {method}" for method in consolidated_findings.relevant_context)
        context_parts.append(f"\n=== RELEVANT CODE ELEMENTS ===\n{methods_text}\n=== END CODE ELEMENTS ===")

    if consolidated_findings.issues_found:
        # A missing or null severity falls back to "unknown" instead of crashing on .upper()
        issues_text = "\n".join(
            f"[{str(issue.get('severity') or 'unknown').upper()}] {issue.get('description', 'No description')}"
            for issue in consolidated_findings.issues_found
        )
        context_parts.append(f"\n=== ISSUES IDENTIFIED ===\n{issues_text}\n=== END ISSUES ===")

    if consolidated_findings.hypotheses:
        assessments_text = "\n".join(
            f"Step {h['step']} ({h['confidence']} confidence): {h['hypothesis']}"
            for h in consolidated_findings.hypotheses
        )
        context_parts.append(f"\n=== ASSESSMENT EVOLUTION ===\n{assessments_text}\n=== END ASSESSMENTS ===")

    if consolidated_findings.images:
        images_text = "\n".join(f"- {img}" for img in consolidated_findings.images)
        context_parts.append(f"\n=== VISUAL REVIEW INFORMATION ===\n{images_text}\n=== END VISUAL INFORMATION ===")

    return "\n".join(context_parts)
def _build_code_review_summary(self, consolidated_findings) -> str:
"""Prepare a comprehensive summary of the code review investigation."""
summary_parts = [
"=== SYSTEMATIC CODE REVIEW INVESTIGATION SUMMARY ===",
f"Total steps: {len(consolidated_findings.findings)}",
f"Files examined: {len(consolidated_findings.files_checked)}",
f"Relevant files identified: {len(consolidated_findings.relevant_files)}",
f"Code elements analyzed: {len(consolidated_findings.relevant_context)}",
f"Issues identified: {len(consolidated_findings.issues_found)}",
"",
"=== INVESTIGATION PROGRESSION ===",
]
for finding in consolidated_findings.findings:
summary_parts.append(finding)
return "\\n".join(summary_parts)
def should_include_files_in_expert_prompt(self) -> bool:
    """Always attach the reviewed files to the expert prompt; a thorough review needs the source."""
    include_files = True
    return include_files
def should_embed_system_prompt(self) -> bool:
    """Always embed the system prompt so the expert model receives the review context."""
    embed = True
    return embed
def get_expert_thinking_mode(self) -> str:
    """Thinking mode requested from the expert model: "high", for a thorough review."""
    mode = "high"
    return mode
def get_expert_analysis_instruction(self) -> str:
    """Instruction text telling the expert model how to finish the code review."""
    sentences = (
        "Please provide comprehensive code review analysis based on the investigation findings. ",
        "Focus on identifying any remaining issues, validating the completeness of the analysis, ",
        "and providing final recommendations for code improvements, following the severity-based ",
        "format specified in the system prompt.",
    )
    return "".join(sentences)
# Hook method overrides for code review-specific behavior
def prepare_step_data(self, request) -> dict:
    """
    Translate a code review request into the generic step-data dict used internally.

    ``findings`` is duplicated under ``hypothesis``, and ``confidence`` is
    fixed at "high"; both exist only for compatibility with the generic
    workflow machinery. ``images`` falls back to an empty list.
    """
    return dict(
        step=request.step,
        step_number=request.step_number,
        findings=request.findings,
        files_checked=request.files_checked,
        relevant_files=request.relevant_files,
        relevant_context=request.relevant_context,
        issues_found=request.issues_found,
        review_validation_type=self.get_review_validation_type(request),
        hypothesis=request.findings,  # findings doubles as hypothesis for compatibility
        images=request.images or [],
        confidence="high",  # placeholder value expected by the workflow mixin
    )
def should_skip_expert_analysis(self, request, consolidated_findings) -> bool:
    """
    Decide whether to bypass the expert model for this review.

    Expert analysis is skipped only when the validation type is "internal"
    and the agent reports no further steps (next_step_required is falsy).
    Any non-internal validation type, including an external continuation,
    never skips.
    """
    continuation_id = self.get_request_continuation_id(request)
    validation_type = self.get_review_validation_type(request)
    internal_only = validation_type == "internal"

    if continuation_id and not internal_only:
        # External continuation: expert analysis is always performed
        return False

    return internal_only and not request.next_step_required
def store_initial_issue(self, step_description: str):
    """Remember the original review request text so it can be quoted to the expert model."""
    setattr(self, "initial_request", step_description)
# Override inheritance hooks for code review-specific behavior
def get_review_validation_type(self, request) -> str:
    """Return the request's review_validation_type, or "external" if it is missing or empty."""
    requested = getattr(request, "review_validation_type", None)
    return requested or "external"
def get_completion_status(self) -> str:
    """Status string reported when a code review finishes."""
    status = "code_review_complete_ready_for_implementation"
    return status
def get_completion_data_key(self) -> str:
    """Response key under which completed code review data is stored."""
    key = "complete_code_review"
    return key
def get_final_analysis_from_request(self, request):
    """The final analysis of a code review is the request's ``findings`` field."""
    final_analysis = request.findings
    return final_analysis
def get_confidence_level(self, request) -> str:
    """Completed code reviews always report "certain" confidence; ``request`` is ignored."""
    level = "certain"
    return level
def get_completion_message(self) -> str:
    """Message telling the agent the review is done and what it must present next."""
    pieces = [
        "Code review complete. You have identified all significant issues ",
        "and provided comprehensive analysis. MANDATORY: Present the user with the complete review results ",
        "categorized by severity, and IMMEDIATELY proceed with implementing the highest priority fixes ",
        "or provide specific guidance for improvements. Focus on actionable recommendations.",
    ]
    return "".join(pieces)
def get_skip_reason(self) -> str:
    """Explanation reported when expert validation is skipped for an internal-only review."""
    reason = "Completed comprehensive code review with internal analysis only (no external model validation)"
    return reason
def get_skip_expert_analysis_status(self) -> str:
    """Status string reported when expert analysis is skipped (internal validation type)."""
    status = "skipped_due_to_internal_analysis_type"
    return status
def prepare_work_summary(self) -> str:
    """Summarise the review so far by delegating to _build_code_review_summary."""
    findings = self.consolidated_findings
    summary = self._build_code_review_summary(findings)
    return summary
def get_completion_next_steps_message(self, expert_analysis_used: bool = False) -> str:
    """
    Build the final next-steps instruction shown when the code review completes.

    If ``expert_analysis_used`` is true and get_expert_analysis_guidance()
    returns non-empty text, that guidance is appended after a blank line.
    """
    message = (
        "CODE REVIEW IS COMPLETE. You MUST now summarize and present ALL review findings organized by "
        "severity (Critical → High → Medium → Low), specific code locations with line numbers, and exact "
        "recommendations for improvement. Clearly prioritize the top 3 issues that need immediate attention. "
        "Provide concrete, actionable guidance for each issue—make it easy for a developer to understand "
        "exactly what needs to be fixed and how to implement the improvements."
    )
    if not expert_analysis_used:
        return message
    guidance = self.get_expert_analysis_guidance()
    return f"{message}\n\n{guidance}" if guidance else message
def get_expert_analysis_guidance(self) -> str:
    """
    Instruction telling the agent to critically validate, not blindly accept, the expert's review.
    """
    sentences = [
        "IMPORTANT: Analysis from an assistant model has been provided above. You MUST critically evaluate and validate ",
        "the expert findings rather than accepting them blindly. Cross-reference the expert analysis with ",
        "your own investigation findings, verify that suggested improvements are appropriate for this ",
        "codebase's context and patterns, and ensure recommendations align with the project's standards. ",
        "Present a synthesis that combines your systematic review with validated expert insights, clearly ",
        "distinguishing between findings you've independently confirmed and additional insights from expert analysis.",
    ]
    return "".join(sentences)
def get_step_guidance_message(self, request) -> str:
    """
    Return the ``next_steps`` text from get_code_review_step_guidance for the request's current step.
    """
    return self.get_code_review_step_guidance(request.step_number, request)["next_steps"]
def get_code_review_step_guidance(self, step_number: int, request) -> dict[str, Any]:
    """
    Provide step-specific guidance for the code review workflow.

    Uses get_required_actions (the single source of truth for per-step
    checklists) and formats those actions into the guidance message returned
    to the calling agent. Checklists and paragraph breaks use real newline
    characters. The previous version emitted the escaped two-character
    backslash-n sequence, so the agent saw literal backslash-n text.

    Args:
        step_number: Current workflow step.
        request: The workflow request; reads findings, total_steps,
            step_number and next_step_required.

    Returns:
        A dict with a single "next_steps" key holding the guidance text.
    """
    required_actions = self.get_required_actions(
        step_number,
        "medium",  # Dummy confidence value kept for backward compatibility
        request.findings or "",
        request.total_steps,
        request,  # Pass request for continuation-aware decisions
    )
    tool_name = self.get_name()
    # Numbered checklist, built once and reused by every branch that shows it
    numbered_actions = "\n".join(f"{i}. {action}" for i, action in enumerate(required_actions, start=1))

    # Context-aware flags for continuations
    continuation_id = self.get_request_continuation_id(request)
    validation_type = self.get_review_validation_type(request)
    is_external_continuation = continuation_id and validation_type == "external"
    is_internal_continuation = continuation_id and validation_type == "internal"
    about_to_finish_external = not request.next_step_required and validation_type == "external"
    # The agent declared >= 3 steps but tried to stop before the last one
    violates_minimum_steps = (
        request.total_steps >= 3 and request.step_number < request.total_steps and not request.next_step_required
    )

    if step_number == 1:
        if is_external_continuation:
            # Fast track for external continuations
            return {
                "next_steps": (
                    "You are on step 1 of MAXIMUM 2 steps for continuation. CRITICAL: Quickly review the code NOW. "
                    "MANDATORY ACTIONS:\n"
                    + numbered_actions
                    + "\n\nSet next_step_required=True and step_number=2 for the next call to trigger expert analysis."
                )
            }
        if is_internal_continuation:
            next_steps = (
                "Continuing previous conversation with internal validation only. The analysis will build "
                "upon the prior findings without external model validation. REQUIRED ACTIONS:\n" + numbered_actions
            )
        else:
            # Normal flow for new reviews
            next_steps = (
                f"MANDATORY: DO NOT call the {tool_name} tool again immediately. You MUST first examine "
                "the code files thoroughly using appropriate tools. CRITICAL AWARENESS: You need to:\n"
                + numbered_actions
                + f"\n\nOnly call {tool_name} again AFTER completing your investigation. "
                f"When you call {tool_name} next time, use step_number: {step_number + 1} "
                "and report specific files examined, issues found, and code quality assessments discovered."
            )
    elif step_number == 2:
        if violates_minimum_steps:
            next_steps = (
                f"ERROR: You set total_steps={request.total_steps} but next_step_required=False on step {request.step_number}. "
                "This violates the minimum step requirement. You MUST set next_step_required=True until you reach the final step. "
                f"Call {tool_name} again with next_step_required=True and continue your investigation."
            )
        elif is_external_continuation or about_to_finish_external:
            # Fast-track completion, or about to complete with external validation
            next_steps = (
                "Proceeding immediately to expert analysis. "
                f"MANDATORY: call {tool_name} tool immediately again, and set next_step_required=False to "
                "trigger external validation NOW."
            )
        else:
            # Normal flow: deeper analysis needed
            next_steps = (
                f"STOP! Do NOT call {tool_name} again yet. You are on step 2 of {request.total_steps} minimum required steps. "
                f"MANDATORY ACTIONS before calling {tool_name} step {step_number + 1}:\n"
                + numbered_actions
                + f"\n\nRemember: You MUST set next_step_required=True until step {request.total_steps}. "
                + f"Only call {tool_name} again with step_number: {step_number + 1} AFTER completing these code review tasks."
            )
    elif step_number >= 3:
        if about_to_finish_external:
            next_steps = (
                "Completing review and proceeding to expert analysis. "
                "Ensure all findings are documented with specific file references and line numbers."
            )
        else:
            # Later steps: final verification
            next_steps = (
                f"WAIT! Your code review needs final verification. DO NOT call {tool_name} immediately. REQUIRED ACTIONS:\n"
                + numbered_actions
                + "\n\nREMEMBER: Ensure you have identified all significant issues across all severity levels and "
                "verified the completeness of your review. Document findings with specific file references and "
                f"line numbers where applicable, then call {tool_name} with step_number: {step_number + 1}."
            )
    else:
        # Fallback for step numbers below 1; check the minimum-step violation first
        if violates_minimum_steps:
            next_steps = (
                f"ERROR: You set total_steps={request.total_steps} but next_step_required=False on step {request.step_number}. "
                f"This violates the minimum step requirement. You MUST set next_step_required=True until step {request.total_steps}."
            )
        elif about_to_finish_external:
            next_steps = (
                "Completing review. "
                "Ensure all findings are documented with specific file references and severity levels."
            )
        else:
            next_steps = (
                f"PAUSE REVIEW. Before calling {tool_name} step {step_number + 1}, you MUST examine more code thoroughly. "
                + "Required: "
                + ", ".join(required_actions[:2])
                + ". "
                + f"Your next {tool_name} call (step_number: {step_number + 1}) must include "
                f"NEW evidence from actual code analysis, not just theories. NO recursive {tool_name} calls "
                "without investigation work!"
            )

    return {"next_steps": next_steps}
def customize_workflow_response(self, response_data: dict, request) -> dict:
    """
    Rewrite the generic workflow response into the code review response format.

    On step 1 this also records the initial request and, when relevant files
    are supplied, the review configuration later used by
    prepare_expert_analysis_context. Generic "<tool>_*" status values and
    keys are renamed to their "code_review" equivalents. The status block is
    enriched with per-severity issue counts and the validation type.
    ``response_data`` is mutated in place and returned.
    """
    if request.step_number == 1:
        self.initial_request = request.step
        # Keep the review configuration for the expert analysis prompt
        if request.relevant_files:
            self.review_config = {
                "relevant_files": request.relevant_files,
                "review_type": request.review_type,
                "focus_on": request.focus_on,
                "standards": request.standards,
                "severity_filter": request.severity_filter,
            }

    tool_name = self.get_name()
    renamed_statuses = {
        f"{tool_name}_in_progress": "code_review_in_progress",
        f"pause_for_{tool_name}": "pause_for_code_review",
        f"{tool_name}_required": "code_review_required",
        f"{tool_name}_complete": "code_review_complete",
    }
    current_status = response_data["status"]
    if current_status in renamed_statuses:
        response_data["status"] = renamed_statuses[current_status]

    status_key = f"{tool_name}_status"
    if status_key in response_data:
        review_status = response_data.pop(status_key)
        response_data["code_review_status"] = review_status
        # Count issues per severity; issues without a severity count as "unknown"
        severity_counts = {}
        review_status["issues_by_severity"] = severity_counts
        for issue in self.consolidated_findings.issues_found:
            level = issue.get("severity", "unknown")
            severity_counts[level] = severity_counts.get(level, 0) + 1
        review_status["review_validation_type"] = self.get_review_validation_type(request)

    # Rename the completion payload key and the completion flag
    for generic_key, review_key in (
        (f"complete_{tool_name}", "complete_code_review"),
        (f"{tool_name}_complete", "code_review_complete"),
    ):
        if generic_key in response_data:
            response_data[review_key] = response_data.pop(generic_key)

    return response_data
# Required abstract methods from BaseTool
def get_request_model(self):
    """Return the request model class (CodeReviewRequest) that validates this tool's input."""
    request_model = CodeReviewRequest
    return request_model
async def prepare_prompt(self, request) -> str:
    """Unused by workflow tools, which build prompts in execute_workflow(); always returns an empty string."""
    empty_prompt = ""
    return empty_prompt
```